From 459a6f3df90c512dfd3b4b4ab193c0c3a539759e Mon Sep 17 00:00:00 2001 From: Marten Seemann Date: Mon, 3 Jun 2024 18:42:58 +0800 Subject: [PATCH] fix the server's 0-RTT rejection logic when using GetConfigForClient (#4550) --- http3/server.go | 4 + integrationtests/self/zero_rtt_test.go | 152 ++++++++++++------------ internal/handshake/crypto_setup.go | 42 +------ internal/handshake/crypto_setup_test.go | 74 ------------ internal/{handshake => qtls}/conn.go | 2 +- internal/qtls/qtls.go | 34 +++++- internal/qtls/qtls_test.go | 81 ++++++++++++- 7 files changed, 196 insertions(+), 193 deletions(-) rename internal/{handshake => qtls}/conn.go (97%) diff --git a/http3/server.go b/http3/server.go index 18853b89..7e6ffd54 100644 --- a/http3/server.go +++ b/http3/server.go @@ -94,6 +94,10 @@ func ConfigureTLSConfig(tlsConf *tls.Config) *tls.Config { if config == nil { return nil, nil } + // Workaround for https://github.com/golang/go/issues/60506. + // This initializes the session tickets _before_ cloning the config. + _, _ = config.DecryptTicket(nil, tls.ConnectionState{}) + config = config.Clone() config.NextProtos = []string{proto} return config, nil diff --git a/integrationtests/self/zero_rtt_test.go b/integrationtests/self/zero_rtt_test.go index 46b3786d..bf966b9c 100644 --- a/integrationtests/self/zero_rtt_test.go +++ b/integrationtests/self/zero_rtt_test.go @@ -813,87 +813,93 @@ var _ = Describe("0-RTT", func() { Entry("doesn't reject 0-RTT when the server's transport connection flow control limit increased", func(c *quic.Config, limit uint64) { c.InitialConnectionReceiveWindow = limit }), ) - for _, l := range []int{0, 15} { - connIDLen := l + test0RTTRejection := func(tlsConf *tls.Config) { + clientConf := getTLSClientConfig() + dialAndReceiveSessionTicket(tlsConf, nil, clientConf) + // now dial new connection with different transport parameters + counter, tracer := newPacketTracer() + ln, err := quic.ListenAddrEarly( + "localhost:0", + tlsConf, + getQuicConfig(&quic.Config{ + MaxIncomingUniStreams: 1, + Tracer: newTracer(tracer), + }), + ) + Expect(err).ToNot(HaveOccurred()) + defer ln.Close() + proxy, num0RTTPackets := runCountingProxy(ln.Addr().(*net.UDPAddr).Port) + defer proxy.Close() - It(fmt.Sprintf("correctly deals with 0-RTT rejections, for %d byte connection IDs", connIDLen), func() { - tlsConf := getTLSConfig() - clientConf := getTLSClientConfig() - dialAndReceiveSessionTicket(tlsConf, nil, clientConf) - // now dial new connection with different transport parameters - counter, tracer := newPacketTracer() - ln, err := quic.ListenAddrEarly( - "localhost:0", - tlsConf, - getQuicConfig(&quic.Config{ - MaxIncomingUniStreams: 1, - Tracer: newTracer(tracer), - }), - ) + conn, err := quic.DialAddrEarly( + context.Background(), + fmt.Sprintf("localhost:%d", proxy.LocalPort()), + clientConf, + getQuicConfig(&quic.Config{}), + ) + Expect(err).ToNot(HaveOccurred()) + // The client remembers that it was allowed to open 2 uni-directional streams. + firstStr, err := conn.OpenUniStream() + Expect(err).ToNot(HaveOccurred()) + written := make(chan struct{}, 2) + go func() { + defer GinkgoRecover() + defer func() { written <- struct{}{} }() + _, err := firstStr.Write([]byte("first flight")) Expect(err).ToNot(HaveOccurred()) - defer ln.Close() - proxy, num0RTTPackets := runCountingProxy(ln.Addr().(*net.UDPAddr).Port) - defer proxy.Close() + }() + secondStr, err := conn.OpenUniStream() + Expect(err).ToNot(HaveOccurred()) + go func() { + defer GinkgoRecover() + defer func() { written <- struct{}{} }() + _, err := secondStr.Write([]byte("first flight")) + Expect(err).ToNot(HaveOccurred()) + }() - conn, err := quic.DialAddrEarly( - context.Background(), - fmt.Sprintf("localhost:%d", proxy.LocalPort()), - clientConf, - getQuicConfig(nil), - ) - Expect(err).ToNot(HaveOccurred()) - // The client remembers that it was allowed to open 2 uni-directional streams. - firstStr, err := conn.OpenUniStream() - Expect(err).ToNot(HaveOccurred()) - written := make(chan struct{}, 2) - go func() { - defer GinkgoRecover() - defer func() { written <- struct{}{} }() - _, err := firstStr.Write([]byte("first flight")) - Expect(err).ToNot(HaveOccurred()) - }() - secondStr, err := conn.OpenUniStream() - Expect(err).ToNot(HaveOccurred()) - go func() { - defer GinkgoRecover() - defer func() { written <- struct{}{} }() - _, err := secondStr.Write([]byte("first flight")) - Expect(err).ToNot(HaveOccurred()) - }() + ctx, cancel := context.WithTimeout(context.Background(), time.Second) + defer cancel() + _, err = conn.AcceptStream(ctx) + Expect(err).To(MatchError(quic.Err0RTTRejected)) + Eventually(written).Should(Receive()) + Eventually(written).Should(Receive()) + _, err = firstStr.Write([]byte("foobar")) + Expect(err).To(MatchError(quic.Err0RTTRejected)) + _, err = conn.OpenUniStream() + Expect(err).To(MatchError(quic.Err0RTTRejected)) - ctx, cancel := context.WithTimeout(context.Background(), time.Second) - defer cancel() - _, err = conn.AcceptStream(ctx) - Expect(err).To(MatchError(quic.Err0RTTRejected)) - Eventually(written).Should(Receive()) - Eventually(written).Should(Receive()) - _, err = firstStr.Write([]byte("foobar")) - Expect(err).To(MatchError(quic.Err0RTTRejected)) - _, err = conn.OpenUniStream() - Expect(err).To(MatchError(quic.Err0RTTRejected)) + _, err = conn.AcceptStream(ctx) + Expect(err).To(Equal(quic.Err0RTTRejected)) - _, err = conn.AcceptStream(ctx) - Expect(err).To(Equal(quic.Err0RTTRejected)) + newConn := conn.NextConnection() + str, err := newConn.OpenUniStream() + Expect(err).ToNot(HaveOccurred()) + _, err = newConn.OpenUniStream() + Expect(err).To(HaveOccurred()) + Expect(err.Error()).To(ContainSubstring("too many open streams")) + _, err = str.Write([]byte("second flight")) + Expect(err).ToNot(HaveOccurred()) + Expect(str.Close()).To(Succeed()) + Expect(conn.CloseWithError(0, "")).To(Succeed()) - newConn := conn.NextConnection() - str, err := newConn.OpenUniStream() - Expect(err).ToNot(HaveOccurred()) - _, err = newConn.OpenUniStream() - Expect(err).To(HaveOccurred()) - Expect(err.Error()).To(ContainSubstring("too many open streams")) - _, err = str.Write([]byte("second flight")) - Expect(err).ToNot(HaveOccurred()) - Expect(str.Close()).To(Succeed()) - Expect(conn.CloseWithError(0, "")).To(Succeed()) - - // The client should send 0-RTT packets, but the server doesn't process them. - num0RTT := num0RTTPackets.Load() - fmt.Fprintf(GinkgoWriter, "Sent %d 0-RTT packets.", num0RTT) - Expect(num0RTT).ToNot(BeZero()) - Expect(get0RTTPackets(counter.getRcvdLongHeaderPackets())).To(BeEmpty()) - }) + // The client should send 0-RTT packets, but the server doesn't process them. + num0RTT := num0RTTPackets.Load() + fmt.Fprintf(GinkgoWriter, "Sent %d 0-RTT packets.", num0RTT) + Expect(num0RTT).ToNot(BeZero()) + Expect(get0RTTPackets(counter.getRcvdLongHeaderPackets())).To(BeEmpty()) } + It("correctly deals with 0-RTT rejections", func() { + test0RTTRejection(getTLSConfig()) + }) + + It("correctly deals with 0-RTT rejections, when the server uses GetConfigForClient", func() { + tlsConf := getTLSConfig() + test0RTTRejection(&tls.Config{ + GetConfigForClient: func(*tls.ClientHelloInfo) (*tls.Config, error) { return tlsConf, nil }, + }) + }) + It("queues 0-RTT packets, if the Initial is delayed", func() { tlsConf := getTLSConfig() clientConf := getTLSClientConfig() diff --git a/internal/handshake/crypto_setup.go b/internal/handshake/crypto_setup.go index e8ed4e89..20bcc474 100644 --- a/internal/handshake/crypto_setup.go +++ b/internal/handshake/crypto_setup.go @@ -123,44 +123,12 @@ func NewCryptoSetupServer( ) cs.allow0RTT = allow0RTT - quicConf := &tls.QUICConfig{TLSConfig: tlsConf} - qtls.SetupConfigForServer(quicConf, cs.getDataForSessionTicket, cs.handleSessionTicket) - addConnToClientHelloInfo(quicConf.TLSConfig, localAddr, remoteAddr) - - cs.tlsConf = quicConf.TLSConfig - cs.conn = tls.QUICServer(quicConf) - + tlsConf = qtls.SetupConfigForServer(tlsConf, localAddr, remoteAddr, cs.getDataForSessionTicket, cs.handleSessionTicket) + cs.tlsConf = tlsConf + cs.conn = tls.QUICServer(&tls.QUICConfig{TLSConfig: tlsConf}) return cs } -// The tls.Config contains two callbacks that pass in a tls.ClientHelloInfo. -// Since crypto/tls doesn't do it, we need to make sure to set the Conn field with a fake net.Conn -// that allows the caller to get the local and the remote address. -func addConnToClientHelloInfo(conf *tls.Config, localAddr, remoteAddr net.Addr) { - if conf.GetConfigForClient != nil { - gcfc := conf.GetConfigForClient - conf.GetConfigForClient = func(info *tls.ClientHelloInfo) (*tls.Config, error) { - info.Conn = &conn{localAddr: localAddr, remoteAddr: remoteAddr} - c, err := gcfc(info) - if c != nil { - c = c.Clone() - // This won't be necessary anymore once https://github.com/golang/go/issues/63722 is accepted. - c.MinVersion = tls.VersionTLS13 - // We're returning a tls.Config here, so we need to apply this recursively. - addConnToClientHelloInfo(c, localAddr, remoteAddr) - } - return c, err - } - } - if conf.GetCertificate != nil { - gc := conf.GetCertificate - conf.GetCertificate = func(info *tls.ClientHelloInfo) (*tls.Certificate, error) { - info.Conn = &conn{localAddr: localAddr, remoteAddr: remoteAddr} - return gc(info) - } - } -} - func newCryptoSetup( connID protocol.ConnectionID, tp *wire.TransportParameters, @@ -376,9 +344,7 @@ func (h *cryptoSetup) getDataForSessionTicket() []byte { // Due to limitations in crypto/tls, it's only possible to generate a single session ticket per connection. // It is only valid for the server. func (h *cryptoSetup) GetSessionTicket() ([]byte, error) { - if err := h.conn.SendSessionTicket(tls.QUICSessionTicketOptions{ - EarlyData: h.allow0RTT, - }); err != nil { + if err := h.conn.SendSessionTicket(tls.QUICSessionTicketOptions{EarlyData: h.allow0RTT}); err != nil { // Session tickets might be disabled by tls.Config.SessionTicketsDisabled. // We can't check h.tlsConfig here, since the actual config might have been obtained from // the GetConfigForClient callback. diff --git a/internal/handshake/crypto_setup_test.go b/internal/handshake/crypto_setup_test.go index 0054867e..aa57f281 100644 --- a/internal/handshake/crypto_setup_test.go +++ b/internal/handshake/crypto_setup_test.go @@ -9,7 +9,6 @@ import ( "crypto/x509/pkix" "math/big" "net" - "reflect" "time" mocktls "github.com/quic-go/quic-go/internal/mocks/tls" @@ -106,79 +105,6 @@ var _ = Describe("Crypto Setup TLS", func() { Expect(err.Error()).To(ContainSubstring("tls: handshake data received at wrong level")) }) - Context("filling in a net.Conn in tls.ClientHelloInfo", func() { - var ( - local = &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 42} - remote = &net.UDPAddr{IP: net.IPv4(192, 168, 0, 1), Port: 1337} - ) - - It("wraps GetCertificate", func() { - var localAddr, remoteAddr net.Addr - tlsConf := &tls.Config{ - GetCertificate: func(info *tls.ClientHelloInfo) (*tls.Certificate, error) { - localAddr = info.Conn.LocalAddr() - remoteAddr = info.Conn.RemoteAddr() - cert := generateCert() - return &cert, nil - }, - } - addConnToClientHelloInfo(tlsConf, local, remote) - _, err := tlsConf.GetCertificate(&tls.ClientHelloInfo{}) - Expect(err).ToNot(HaveOccurred()) - Expect(localAddr).To(Equal(local)) - Expect(remoteAddr).To(Equal(remote)) - }) - - It("wraps GetConfigForClient", func() { - var localAddr, remoteAddr net.Addr - tlsConf := &tls.Config{ - GetConfigForClient: func(info *tls.ClientHelloInfo) (*tls.Config, error) { - localAddr = info.Conn.LocalAddr() - remoteAddr = info.Conn.RemoteAddr() - return &tls.Config{}, nil - }, - } - addConnToClientHelloInfo(tlsConf, local, remote) - conf, err := tlsConf.GetConfigForClient(&tls.ClientHelloInfo{}) - Expect(err).ToNot(HaveOccurred()) - Expect(localAddr).To(Equal(local)) - Expect(remoteAddr).To(Equal(remote)) - Expect(conf).ToNot(BeNil()) - Expect(conf.MinVersion).To(BeEquivalentTo(tls.VersionTLS13)) - }) - - It("wraps GetConfigForClient, recursively", func() { - var localAddr, remoteAddr net.Addr - tlsConf := &tls.Config{} - var innerConf *tls.Config - getCert := func(info *tls.ClientHelloInfo) (*tls.Certificate, error) { //nolint:unparam - localAddr = info.Conn.LocalAddr() - remoteAddr = info.Conn.RemoteAddr() - cert := generateCert() - return &cert, nil - } - tlsConf.GetConfigForClient = func(info *tls.ClientHelloInfo) (*tls.Config, error) { - innerConf = tlsConf.Clone() - // set the MaxVersion, so we can check that quic-go doesn't overwrite the user's config - innerConf.MaxVersion = tls.VersionTLS12 - innerConf.GetCertificate = getCert - return innerConf, nil - } - addConnToClientHelloInfo(tlsConf, local, remote) - conf, err := tlsConf.GetConfigForClient(&tls.ClientHelloInfo{}) - Expect(err).ToNot(HaveOccurred()) - Expect(conf).ToNot(BeNil()) - Expect(conf.MinVersion).To(BeEquivalentTo(tls.VersionTLS13)) - _, err = conf.GetCertificate(&tls.ClientHelloInfo{}) - Expect(err).ToNot(HaveOccurred()) - Expect(localAddr).To(Equal(local)) - Expect(remoteAddr).To(Equal(remote)) - // make sure that the tls.Config returned by GetConfigForClient isn't modified - Expect(reflect.ValueOf(innerConf.GetCertificate).Pointer() == reflect.ValueOf(getCert).Pointer()).To(BeTrue()) - Expect(innerConf.MaxVersion).To(BeEquivalentTo(tls.VersionTLS12)) - }) - }) - Context("doing the handshake", func() { newRTTStatsWithRTT := func(rtt time.Duration) *utils.RTTStats { rttStats := &utils.RTTStats{} diff --git a/internal/handshake/conn.go b/internal/qtls/conn.go similarity index 97% rename from internal/handshake/conn.go rename to internal/qtls/conn.go index 54af823b..6660ac66 100644 --- a/internal/handshake/conn.go +++ b/internal/qtls/conn.go @@ -1,4 +1,4 @@ -package handshake +package qtls import ( "net" diff --git a/internal/qtls/qtls.go b/internal/qtls/qtls.go index fdebd06e..cdfe82a2 100644 --- a/internal/qtls/qtls.go +++ b/internal/qtls/qtls.go @@ -4,20 +4,23 @@ import ( "bytes" "crypto/tls" "fmt" + "net" "github.com/quic-go/quic-go/internal/protocol" ) -func SetupConfigForServer(qconf *tls.QUICConfig, getData func() []byte, handleSessionTicket func([]byte, bool) bool) { - conf := qconf.TLSConfig - +func SetupConfigForServer( + conf *tls.Config, + localAddr, remoteAddr net.Addr, + getData func() []byte, + handleSessionTicket func([]byte, bool) bool, +) *tls.Config { // Workaround for https://github.com/golang/go/issues/60506. // This initializes the session tickets _before_ cloning the config. _, _ = conf.DecryptTicket(nil, tls.ConnectionState{}) conf = conf.Clone() conf.MinVersion = tls.VersionTLS13 - qconf.TLSConfig = conf // add callbacks to save transport parameters into the session ticket origWrapSession := conf.WrapSession @@ -58,6 +61,29 @@ func SetupConfigForServer(qconf *tls.QUICConfig, getData func() []byte, handleSe return state, nil } + // The tls.Config contains two callbacks that pass in a tls.ClientHelloInfo. + // Since crypto/tls doesn't do it, we need to make sure to set the Conn field with a fake net.Conn + // that allows the caller to get the local and the remote address. + if conf.GetConfigForClient != nil { + gcfc := conf.GetConfigForClient + conf.GetConfigForClient = func(info *tls.ClientHelloInfo) (*tls.Config, error) { + info.Conn = &conn{localAddr: localAddr, remoteAddr: remoteAddr} + c, err := gcfc(info) + if c != nil { + // We're returning a tls.Config here, so we need to apply this recursively. + c = SetupConfigForServer(c, localAddr, remoteAddr, getData, handleSessionTicket) + } + return c, err + } + } + if conf.GetCertificate != nil { + gc := conf.GetCertificate + conf.GetCertificate = func(info *tls.ClientHelloInfo) (*tls.Certificate, error) { + info.Conn = &conn{localAddr: localAddr, remoteAddr: remoteAddr} + return gc(info) + } + } + return conf } func SetupConfigForClient( diff --git a/internal/qtls/qtls_test.go b/internal/qtls/qtls_test.go index 1acb6928..b041af74 100644 --- a/internal/qtls/qtls_test.go +++ b/internal/qtls/qtls_test.go @@ -2,6 +2,8 @@ package qtls import ( "crypto/tls" + "net" + "reflect" "github.com/quic-go/quic-go/internal/protocol" @@ -41,13 +43,86 @@ var _ = Describe("interface go crypto/tls", func() { }) Context("setting up a tls.Config for the server", func() { + var ( + local = &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 42} + remote = &net.UDPAddr{IP: net.IPv4(192, 168, 0, 1), Port: 1337} + ) + It("sets the minimum TLS version to TLS 1.3", func() { orig := &tls.Config{MinVersion: tls.VersionTLS12} - conf := &tls.QUICConfig{TLSConfig: orig} - SetupConfigForServer(conf, nil, nil) - Expect(conf.TLSConfig.MinVersion).To(BeEquivalentTo(tls.VersionTLS13)) + conf := SetupConfigForServer(orig, local, remote, nil, nil) + Expect(conf.MinVersion).To(BeEquivalentTo(tls.VersionTLS13)) // check that the original config wasn't modified Expect(orig.MinVersion).To(BeEquivalentTo(tls.VersionTLS12)) }) + + It("wraps GetCertificate", func() { + var localAddr, remoteAddr net.Addr + tlsConf := &tls.Config{ + GetCertificate: func(info *tls.ClientHelloInfo) (*tls.Certificate, error) { + localAddr = info.Conn.LocalAddr() + remoteAddr = info.Conn.RemoteAddr() + return &tls.Certificate{}, nil + }, + } + conf := SetupConfigForServer(tlsConf, local, remote, nil, nil) + _, err := conf.GetCertificate(&tls.ClientHelloInfo{}) + Expect(err).ToNot(HaveOccurred()) + Expect(localAddr).To(Equal(local)) + Expect(remoteAddr).To(Equal(remote)) + }) + + It("wraps GetConfigForClient", func() { + var localAddr, remoteAddr net.Addr + tlsConf := SetupConfigForServer( + &tls.Config{ + GetConfigForClient: func(info *tls.ClientHelloInfo) (*tls.Config, error) { + localAddr = info.Conn.LocalAddr() + remoteAddr = info.Conn.RemoteAddr() + return &tls.Config{}, nil + }, + }, + local, + remote, + nil, + nil, + ) + conf, err := tlsConf.GetConfigForClient(&tls.ClientHelloInfo{}) + Expect(err).ToNot(HaveOccurred()) + Expect(localAddr).To(Equal(local)) + Expect(remoteAddr).To(Equal(remote)) + Expect(conf).ToNot(BeNil()) + Expect(conf.MinVersion).To(BeEquivalentTo(tls.VersionTLS13)) + }) + + It("wraps GetConfigForClient, recursively", func() { + var localAddr, remoteAddr net.Addr + tlsConf := &tls.Config{} + var innerConf *tls.Config + getCert := func(info *tls.ClientHelloInfo) (*tls.Certificate, error) { //nolint:unparam + localAddr = info.Conn.LocalAddr() + remoteAddr = info.Conn.RemoteAddr() + return &tls.Certificate{}, nil + } + tlsConf.GetConfigForClient = func(info *tls.ClientHelloInfo) (*tls.Config, error) { + innerConf = tlsConf.Clone() + // set the MaxVersion, so we can check that quic-go doesn't overwrite the user's config + innerConf.MaxVersion = tls.VersionTLS12 + innerConf.GetCertificate = getCert + return innerConf, nil + } + tlsConf = SetupConfigForServer(tlsConf, local, remote, nil, nil) + conf, err := tlsConf.GetConfigForClient(&tls.ClientHelloInfo{}) + Expect(err).ToNot(HaveOccurred()) + Expect(conf).ToNot(BeNil()) + Expect(conf.MinVersion).To(BeEquivalentTo(tls.VersionTLS13)) + _, err = conf.GetCertificate(&tls.ClientHelloInfo{}) + Expect(err).ToNot(HaveOccurred()) + Expect(localAddr).To(Equal(local)) + Expect(remoteAddr).To(Equal(remote)) + // make sure that the tls.Config returned by GetConfigForClient isn't modified + Expect(reflect.ValueOf(innerConf.GetCertificate).Pointer() == reflect.ValueOf(getCert).Pointer()).To(BeTrue()) + Expect(innerConf.MaxVersion).To(BeEquivalentTo(tls.VersionTLS12)) + }) }) })