forked from quic-go/quic-go
add an integration test for 0-RTT GET requests (#4386)
This commit is contained in:
@@ -5,6 +5,7 @@ import (
|
||||
"bytes"
|
||||
"compress/gzip"
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
@@ -12,12 +13,14 @@ import (
|
||||
"net/http"
|
||||
"os"
|
||||
"strconv"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"golang.org/x/sync/errgroup"
|
||||
|
||||
"github.com/quic-go/quic-go"
|
||||
"github.com/quic-go/quic-go/http3"
|
||||
quicproxy "github.com/quic-go/quic-go/integrationtests/tools/proxy"
|
||||
|
||||
. "github.com/onsi/ginkgo/v2"
|
||||
. "github.com/onsi/gomega"
|
||||
@@ -86,7 +89,7 @@ var _ = Describe("HTTP tests", func() {
|
||||
server = &http3.Server{
|
||||
Handler: mux,
|
||||
TLSConfig: getTLSConfig(),
|
||||
QuicConfig: getQuicConfig(nil),
|
||||
QuicConfig: getQuicConfig(&quic.Config{Allow0RTT: true}),
|
||||
}
|
||||
|
||||
addr, err := net.ResolveUDPAddr("udp", "0.0.0.0:0")
|
||||
@@ -112,9 +115,12 @@ var _ = Describe("HTTP tests", func() {
|
||||
|
||||
BeforeEach(func() {
|
||||
rt = &http3.RoundTripper{
|
||||
TLSClientConfig: getTLSClientConfigWithoutServerName(),
|
||||
TLSClientConfig: getTLSClientConfigWithoutServerName(),
|
||||
QuicConfig: getQuicConfig(&quic.Config{
|
||||
MaxIdleTimeout: 10 * time.Second,
|
||||
Allow0RTT: true,
|
||||
}),
|
||||
DisableCompression: true,
|
||||
QuicConfig: getQuicConfig(&quic.Config{MaxIdleTimeout: 10 * time.Second}),
|
||||
}
|
||||
client = &http.Client{Transport: rt}
|
||||
})
|
||||
@@ -572,4 +578,67 @@ var _ = Describe("HTTP tests", func() {
|
||||
}})
|
||||
Expect(err).To(MatchError(err))
|
||||
})
|
||||
|
||||
Context("0-RTT", func() {
|
||||
runCountingProxy := func(serverPort int, rtt time.Duration) (*quicproxy.QuicProxy, *atomic.Uint32) {
|
||||
var num0RTTPackets atomic.Uint32
|
||||
proxy, err := quicproxy.NewQuicProxy("localhost:0", &quicproxy.Opts{
|
||||
RemoteAddr: fmt.Sprintf("localhost:%d", serverPort),
|
||||
DelayPacket: func(_ quicproxy.Direction, data []byte) time.Duration {
|
||||
if contains0RTTPacket(data) {
|
||||
num0RTTPackets.Add(1)
|
||||
}
|
||||
return rtt / 2
|
||||
},
|
||||
})
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
return proxy, &num0RTTPackets
|
||||
}
|
||||
|
||||
It("sends 0-RTT GET requests", func() {
|
||||
proxy, num0RTTPackets := runCountingProxy(port, scaleDuration(50*time.Millisecond))
|
||||
defer proxy.Close()
|
||||
|
||||
tlsConf := getTLSClientConfigWithoutServerName()
|
||||
puts := make(chan string, 10)
|
||||
tlsConf.ClientSessionCache = newClientSessionCache(tls.NewLRUClientSessionCache(10), nil, puts)
|
||||
rt := &http3.RoundTripper{
|
||||
TLSClientConfig: tlsConf,
|
||||
QuicConfig: getQuicConfig(&quic.Config{
|
||||
MaxIdleTimeout: 10 * time.Second,
|
||||
Allow0RTT: true,
|
||||
}),
|
||||
DisableCompression: true,
|
||||
}
|
||||
defer rt.Close()
|
||||
|
||||
mux.HandleFunc("/0rtt", func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Write([]byte(strconv.FormatBool(!r.TLS.HandshakeComplete)))
|
||||
})
|
||||
req, err := http.NewRequest(http3.MethodGet0RTT, fmt.Sprintf("https://localhost:%d/0rtt", proxy.LocalPort()), nil)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
rsp, err := rt.RoundTrip(req)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(rsp.StatusCode).To(BeEquivalentTo(200))
|
||||
data, err := io.ReadAll(rsp.Body)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(string(data)).To(Equal("false"))
|
||||
Expect(num0RTTPackets.Load()).To(BeZero())
|
||||
Eventually(puts).Should(Receive())
|
||||
|
||||
rt2 := &http3.RoundTripper{
|
||||
TLSClientConfig: rt.TLSClientConfig,
|
||||
QuicConfig: rt.QuicConfig,
|
||||
DisableCompression: true,
|
||||
}
|
||||
defer rt2.Close()
|
||||
rsp, err = rt2.RoundTrip(req)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(rsp.StatusCode).To(BeEquivalentTo(200))
|
||||
data, err = io.ReadAll(rsp.Body)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(string(data)).To(Equal("true"))
|
||||
Expect(num0RTTPackets.Load()).To(BeNumerically(">", 0))
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
@@ -32,13 +32,17 @@ var _ tls.ClientSessionCache = &clientSessionCache{}
|
||||
|
||||
func (c *clientSessionCache) Get(sessionKey string) (*tls.ClientSessionState, bool) {
|
||||
session, ok := c.cache.Get(sessionKey)
|
||||
c.gets <- sessionKey
|
||||
if c.gets != nil {
|
||||
c.gets <- sessionKey
|
||||
}
|
||||
return session, ok
|
||||
}
|
||||
|
||||
func (c *clientSessionCache) Put(sessionKey string, cs *tls.ClientSessionState) {
|
||||
c.cache.Put(sessionKey, cs)
|
||||
c.puts <- sessionKey
|
||||
if c.puts != nil {
|
||||
c.puts <- sessionKey
|
||||
}
|
||||
}
|
||||
|
||||
var _ = Describe("TLS session resumption", func() {
|
||||
|
||||
@@ -50,6 +50,23 @@ func (m metadataClientSessionCache) Put(key string, session *tls.ClientSessionSt
|
||||
m.cache.Put(key, session)
|
||||
}
|
||||
|
||||
// contains0RTTPacket says if a packet contains a 0-RTT long header packet.
|
||||
// It correctly handles coalesced packets.
|
||||
func contains0RTTPacket(data []byte) bool {
|
||||
for len(data) > 0 {
|
||||
if !wire.IsLongHeaderPacket(data[0]) {
|
||||
return false
|
||||
}
|
||||
hdr, _, rest, err := wire.ParsePacket(data)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
if hdr.Type == protocol.PacketType0RTT {
|
||||
return true
|
||||
}
|
||||
data = rest
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
var _ = Describe("0-RTT", func() {
|
||||
rtt := scaleDuration(5 * time.Millisecond)
|
||||
|
||||
@@ -58,23 +75,13 @@ var _ = Describe("0-RTT", func() {
|
||||
proxy, err := quicproxy.NewQuicProxy("localhost:0", &quicproxy.Opts{
|
||||
RemoteAddr: fmt.Sprintf("localhost:%d", serverPort),
|
||||
DelayPacket: func(_ quicproxy.Direction, data []byte) time.Duration {
|
||||
for len(data) > 0 {
|
||||
if !wire.IsLongHeaderPacket(data[0]) {
|
||||
break
|
||||
}
|
||||
hdr, _, rest, err := wire.ParsePacket(data)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
if hdr.Type == protocol.PacketType0RTT {
|
||||
num0RTTPackets.Add(1)
|
||||
break
|
||||
}
|
||||
data = rest
|
||||
if contains0RTTPacket(data) {
|
||||
num0RTTPackets.Add(1)
|
||||
}
|
||||
return rtt / 2
|
||||
},
|
||||
})
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
||||
return proxy, &num0RTTPackets
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user