From a19f99e98b10f5b0514afeeb18a86fbd5e6f470c Mon Sep 17 00:00:00 2001 From: Marten Seemann Date: Sun, 24 Mar 2024 07:26:02 +1000 Subject: [PATCH] add an integration test for 0-RTT GET requests (#4386) --- integrationtests/self/http_test.go | 75 +++++++++++++++++++++++- integrationtests/self/resumption_test.go | 8 ++- integrationtests/self/zero_rtt_test.go | 31 ++++++---- 3 files changed, 97 insertions(+), 17 deletions(-) diff --git a/integrationtests/self/http_test.go b/integrationtests/self/http_test.go index 10564c17..bff99d34 100644 --- a/integrationtests/self/http_test.go +++ b/integrationtests/self/http_test.go @@ -5,6 +5,7 @@ import ( "bytes" "compress/gzip" "context" + "crypto/tls" "errors" "fmt" "io" @@ -12,12 +13,14 @@ import ( "net/http" "os" "strconv" + "sync/atomic" "time" "golang.org/x/sync/errgroup" "github.com/quic-go/quic-go" "github.com/quic-go/quic-go/http3" + quicproxy "github.com/quic-go/quic-go/integrationtests/tools/proxy" . "github.com/onsi/ginkgo/v2" . "github.com/onsi/gomega" @@ -86,7 +89,7 @@ var _ = Describe("HTTP tests", func() { server = &http3.Server{ Handler: mux, TLSConfig: getTLSConfig(), - QuicConfig: getQuicConfig(nil), + QuicConfig: getQuicConfig(&quic.Config{Allow0RTT: true}), } addr, err := net.ResolveUDPAddr("udp", "0.0.0.0:0") @@ -112,9 +115,12 @@ var _ = Describe("HTTP tests", func() { BeforeEach(func() { rt = &http3.RoundTripper{ - TLSClientConfig: getTLSClientConfigWithoutServerName(), + TLSClientConfig: getTLSClientConfigWithoutServerName(), + QuicConfig: getQuicConfig(&quic.Config{ + MaxIdleTimeout: 10 * time.Second, + Allow0RTT: true, + }), DisableCompression: true, - QuicConfig: getQuicConfig(&quic.Config{MaxIdleTimeout: 10 * time.Second}), } client = &http.Client{Transport: rt} }) @@ -572,4 +578,67 @@ var _ = Describe("HTTP tests", func() { }}) Expect(err).To(MatchError(err)) }) + + Context("0-RTT", func() { + runCountingProxy := func(serverPort int, rtt time.Duration) (*quicproxy.QuicProxy, *atomic.Uint32) { + var num0RTTPackets atomic.Uint32 + proxy, err := quicproxy.NewQuicProxy("localhost:0", &quicproxy.Opts{ + RemoteAddr: fmt.Sprintf("localhost:%d", serverPort), + DelayPacket: func(_ quicproxy.Direction, data []byte) time.Duration { + if contains0RTTPacket(data) { + num0RTTPackets.Add(1) + } + return rtt / 2 + }, + }) + Expect(err).ToNot(HaveOccurred()) + return proxy, &num0RTTPackets + } + + It("sends 0-RTT GET requests", func() { + proxy, num0RTTPackets := runCountingProxy(port, scaleDuration(50*time.Millisecond)) + defer proxy.Close() + + tlsConf := getTLSClientConfigWithoutServerName() + puts := make(chan string, 10) + tlsConf.ClientSessionCache = newClientSessionCache(tls.NewLRUClientSessionCache(10), nil, puts) + rt := &http3.RoundTripper{ + TLSClientConfig: tlsConf, + QuicConfig: getQuicConfig(&quic.Config{ + MaxIdleTimeout: 10 * time.Second, + Allow0RTT: true, + }), + DisableCompression: true, + } + defer rt.Close() + + mux.HandleFunc("/0rtt", func(w http.ResponseWriter, r *http.Request) { + w.Write([]byte(strconv.FormatBool(!r.TLS.HandshakeComplete))) + }) + req, err := http.NewRequest(http3.MethodGet0RTT, fmt.Sprintf("https://localhost:%d/0rtt", proxy.LocalPort()), nil) + Expect(err).ToNot(HaveOccurred()) + rsp, err := rt.RoundTrip(req) + Expect(err).ToNot(HaveOccurred()) + Expect(rsp.StatusCode).To(BeEquivalentTo(200)) + data, err := io.ReadAll(rsp.Body) + Expect(err).ToNot(HaveOccurred()) + Expect(string(data)).To(Equal("false")) + Expect(num0RTTPackets.Load()).To(BeZero()) + Eventually(puts).Should(Receive()) + + rt2 := &http3.RoundTripper{ + TLSClientConfig: rt.TLSClientConfig, + QuicConfig: rt.QuicConfig, + DisableCompression: true, + } + defer rt2.Close() + rsp, err = rt2.RoundTrip(req) + Expect(err).ToNot(HaveOccurred()) + Expect(rsp.StatusCode).To(BeEquivalentTo(200)) + data, err = io.ReadAll(rsp.Body) + Expect(err).ToNot(HaveOccurred()) + Expect(string(data)).To(Equal("true")) + Expect(num0RTTPackets.Load()).To(BeNumerically(">", 0)) + }) + }) }) diff --git a/integrationtests/self/resumption_test.go b/integrationtests/self/resumption_test.go index 23f241be..f8bb3a00 100644 --- a/integrationtests/self/resumption_test.go +++ b/integrationtests/self/resumption_test.go @@ -32,13 +32,17 @@ var _ tls.ClientSessionCache = &clientSessionCache{} func (c *clientSessionCache) Get(sessionKey string) (*tls.ClientSessionState, bool) { session, ok := c.cache.Get(sessionKey) - c.gets <- sessionKey + if c.gets != nil { + c.gets <- sessionKey + } return session, ok } func (c *clientSessionCache) Put(sessionKey string, cs *tls.ClientSessionState) { c.cache.Put(sessionKey, cs) - c.puts <- sessionKey + if c.puts != nil { + c.puts <- sessionKey + } } var _ = Describe("TLS session resumption", func() { diff --git a/integrationtests/self/zero_rtt_test.go b/integrationtests/self/zero_rtt_test.go index 67d1f8af..46b3786d 100644 --- a/integrationtests/self/zero_rtt_test.go +++ b/integrationtests/self/zero_rtt_test.go @@ -50,6 +50,23 @@ func (m metadataClientSessionCache) Put(key string, session *tls.ClientSessionSt m.cache.Put(key, session) } +// contains0RTTPacket says if a packet contains a 0-RTT long header packet. +// It correctly handles coalesced packets. +func contains0RTTPacket(data []byte) bool { + for len(data) > 0 { + if !wire.IsLongHeaderPacket(data[0]) { + return false + } + hdr, _, rest, err := wire.ParsePacket(data) + Expect(err).ToNot(HaveOccurred()) + if hdr.Type == protocol.PacketType0RTT { + return true + } + data = rest + } + return false +} + var _ = Describe("0-RTT", func() { rtt := scaleDuration(5 * time.Millisecond) @@ -58,23 +75,13 @@ var _ = Describe("0-RTT", func() { proxy, err := quicproxy.NewQuicProxy("localhost:0", &quicproxy.Opts{ RemoteAddr: fmt.Sprintf("localhost:%d", serverPort), DelayPacket: func(_ quicproxy.Direction, data []byte) time.Duration { - for len(data) > 0 { - if !wire.IsLongHeaderPacket(data[0]) { - break - } - hdr, _, rest, err := wire.ParsePacket(data) - Expect(err).ToNot(HaveOccurred()) - if hdr.Type == protocol.PacketType0RTT { - num0RTTPackets.Add(1) - break - } - data = rest + if contains0RTTPacket(data) { + num0RTTPackets.Add(1) } return rtt / 2 }, }) Expect(err).ToNot(HaveOccurred()) - return proxy, &num0RTTPackets }