From 892851eb8cf7b57e74b0598dbb05c14d516b39ab Mon Sep 17 00:00:00 2001 From: Marten Seemann Date: Mon, 22 Jan 2024 21:04:25 -0800 Subject: [PATCH] add Transport config options to limit the number of handshakes (#4248) * add Transport config options to limit the number of handshakes * fix accounting for failed handshakes * increase handshake limits, improve documentation --- integrationtests/self/handshake_test.go | 177 +++++++++++++++-- server.go | 145 ++++++++------ server_test.go | 244 +++++++++++++++++------- transport.go | 42 ++++ 4 files changed, 466 insertions(+), 142 deletions(-) diff --git a/integrationtests/self/handshake_test.go b/integrationtests/self/handshake_test.go index ce27949f..73f0973f 100644 --- a/integrationtests/self/handshake_test.go +++ b/integrationtests/self/handshake_test.go @@ -7,6 +7,7 @@ import ( "fmt" "io" "net" + "sync/atomic" "time" "github.com/quic-go/quic-go" @@ -14,6 +15,7 @@ import ( "github.com/quic-go/quic-go/internal/protocol" "github.com/quic-go/quic-go/internal/qerr" "github.com/quic-go/quic-go/internal/qtls" + "github.com/quic-go/quic-go/logging" . "github.com/onsi/ginkgo/v2" . "github.com/onsi/gomega" @@ -301,7 +303,7 @@ var _ = Describe("Handshake tests", func() { }) }) - Context("rate limiting", func() { + Context("queuening and accepting connections", func() { var ( server *quic.Listener pconn net.PacketConn @@ -343,8 +345,11 @@ var _ = Describe("Handshake tests", func() { } time.Sleep(25 * time.Millisecond) // wait a bit for the connection to be queued - _, err := dial() - Expect(err).To(HaveOccurred()) + conn, err := dial() + Expect(err).ToNot(HaveOccurred()) + ctx, cancel := context.WithTimeout(context.Background(), 500*time.Millisecond) + defer cancel() + _, err = conn.AcceptStream(ctx) var transportErr *quic.TransportError Expect(errors.As(err, &transportErr)).To(BeTrue()) Expect(transportErr.ErrorCode).To(Equal(quic.ConnectionRefused)) @@ -353,18 +358,21 @@ var _ = Describe("Handshake tests", func() { _, err = server.Accept(context.Background()) Expect(err).ToNot(HaveOccurred()) // dial again, and expect that this dial succeeds - conn, err := dial() + conn2, err := dial() Expect(err).ToNot(HaveOccurred()) - defer conn.CloseWithError(0, "") + defer conn2.CloseWithError(0, "") time.Sleep(25 * time.Millisecond) // wait a bit for the connection to be queued - _, err = dial() - Expect(err).To(HaveOccurred()) + conn3, err := dial() + Expect(err).ToNot(HaveOccurred()) + ctx, cancel = context.WithTimeout(context.Background(), 500*time.Millisecond) + defer cancel() + _, err = conn3.AcceptStream(ctx) Expect(errors.As(err, &transportErr)).To(BeTrue()) Expect(transportErr.ErrorCode).To(Equal(quic.ConnectionRefused)) }) - It("removes closed connections from the accept queue", func() { + It("also returns closed connections from the accept queue", func() { firstConn, err := dial() Expect(err).ToNot(HaveOccurred()) @@ -375,8 +383,11 @@ var _ = Describe("Handshake tests", func() { } time.Sleep(scaleDuration(20 * time.Millisecond)) // wait a bit for the connection to be queued - _, err = dial() - Expect(err).To(HaveOccurred()) + conn, err := dial() + Expect(err).ToNot(HaveOccurred()) + ctx, cancel := context.WithTimeout(context.Background(), 500*time.Millisecond) + defer cancel() + _, err = conn.AcceptStream(ctx) var transportErr *quic.TransportError Expect(errors.As(err, &transportErr)).To(BeTrue()) Expect(transportErr.ErrorCode).To(Equal(quic.ConnectionRefused)) @@ -388,8 +399,11 @@ var _ = Describe("Handshake tests", func() { time.Sleep(scaleDuration(200 * time.Millisecond)) // dial again, and expect that this fails again - _, err = dial() - Expect(err).To(HaveOccurred()) + conn2, err := dial() + Expect(err).ToNot(HaveOccurred()) + ctx, cancel = context.WithTimeout(context.Background(), 500*time.Millisecond) + defer cancel() + _, err = conn2.AcceptStream(ctx) Expect(errors.As(err, &transportErr)).To(BeTrue()) Expect(transportErr.ErrorCode).To(Equal(quic.ConnectionRefused)) @@ -448,6 +462,145 @@ var _ = Describe("Handshake tests", func() { }) }) + Context("limiting handshakes", func() { + var conn *net.UDPConn + + BeforeEach(func() { + addr, err := net.ResolveUDPAddr("udp", "localhost:0") + Expect(err).ToNot(HaveOccurred()) + conn, err = net.ListenUDP("udp", addr) + Expect(err).ToNot(HaveOccurred()) + }) + + AfterEach(func() { conn.Close() }) + + It("sends a Retry when the number of handshakes reaches MaxUnvalidatedHandshakes", func() { + const limit = 3 + tr := quic.Transport{ + Conn: conn, + MaxUnvalidatedHandshakes: limit, + } + defer tr.Close() + + // Block all handshakes. + handshakes := make(chan struct{}) + var tlsConf tls.Config + tlsConf.GetConfigForClient = func(*tls.ClientHelloInfo) (*tls.Config, error) { + handshakes <- struct{}{} + return getTLSConfig(), nil + } + ln, err := tr.Listen(&tlsConf, getQuicConfig(nil)) + Expect(err).ToNot(HaveOccurred()) + defer ln.Close() + + const additional = 2 + results := make([]struct{ retry, closed atomic.Bool }, limit+additional) + // Dial the server from multiple clients. All handshakes will get blocked on the handshakes channel. + // Since we're dialing limit+2 times, we expect limit handshakes to go through with a Retry, and + // exactly 2 to experience a Retry. + for i := 0; i < limit+additional; i++ { + go func(index int) { + defer GinkgoRecover() + quicConf := getQuicConfig(&quic.Config{ + Tracer: func(context.Context, logging.Perspective, quic.ConnectionID) *logging.ConnectionTracer { + return &logging.ConnectionTracer{ + ReceivedRetry: func(*logging.Header) { results[index].retry.Store(true) }, + ClosedConnection: func(error) { results[index].closed.Store(true) }, + } + }, + }) + conn, err := quic.DialAddr(context.Background(), ln.Addr().String(), getTLSClientConfig(), quicConf) + Expect(err).ToNot(HaveOccurred()) + conn.CloseWithError(0, "") + }(i) + } + numRetries := func() (n int) { + for i := 0; i < limit+additional; i++ { + if results[i].retry.Load() { + n++ + } + } + return + } + numClosed := func() (n int) { + for i := 0; i < limit+2; i++ { + if results[i].closed.Load() { + n++ + } + } + return + } + Eventually(numRetries).Should(Equal(additional)) + // allow the handshakes to complete + for i := 0; i < limit+additional; i++ { + Eventually(handshakes).Should(Receive()) + } + Eventually(numClosed).Should(Equal(limit + additional)) + Expect(numRetries()).To(Equal(additional)) // just to be on the safe side + }) + + It("rejects connections when the number of handshakes reaches MaxHandshakes", func() { + const limit = 3 + tr := quic.Transport{ + Conn: conn, + MaxHandshakes: limit, + } + defer tr.Close() + + // Block all handshakes. + handshakes := make(chan struct{}) + var tlsConf tls.Config + tlsConf.GetConfigForClient = func(*tls.ClientHelloInfo) (*tls.Config, error) { + handshakes <- struct{}{} + return getTLSConfig(), nil + } + ln, err := tr.Listen(&tlsConf, getQuicConfig(nil)) + Expect(err).ToNot(HaveOccurred()) + defer ln.Close() + + const additional = 2 + // Dial the server from multiple clients. All handshakes will get blocked on the handshakes channel. + // Since we're dialing limit+2 times, we expect limit handshakes to go through with a Retry, and + // exactly 2 to experience a Retry. + var numSuccessful, numFailed atomic.Int32 + for i := 0; i < limit+additional; i++ { + go func() { + defer GinkgoRecover() + quicConf := getQuicConfig(&quic.Config{ + Tracer: func(context.Context, logging.Perspective, quic.ConnectionID) *logging.ConnectionTracer { + return &logging.ConnectionTracer{ + ReceivedRetry: func(*logging.Header) { Fail("didn't expect any Retry") }, + } + }, + }) + conn, err := quic.DialAddr(context.Background(), ln.Addr().String(), getTLSClientConfig(), quicConf) + if err != nil { + var transportErr *quic.TransportError + if !errors.As(err, &transportErr) || transportErr.ErrorCode != qerr.ConnectionRefused { + Fail(fmt.Sprintf("expected CONNECTION_REFUSED error, got %v", err)) + } + numFailed.Add(1) + return + } + numSuccessful.Add(1) + conn.CloseWithError(0, "") + }() + } + Eventually(func() int { return int(numFailed.Load()) }).Should(Equal(additional)) + // allow the handshakes to complete + for i := 0; i < limit; i++ { + Eventually(handshakes).Should(Receive()) + } + Eventually(func() int { return int(numSuccessful.Load()) }).Should(Equal(limit)) + + // make sure that the server is reachable again after these handshakes have completed + go func() { <-handshakes }() // allow this handshake to complete immediately + conn, err := quic.DialAddr(context.Background(), ln.Addr().String(), getTLSClientConfig(), getQuicConfig(nil)) + Expect(err).ToNot(HaveOccurred()) + conn.CloseWithError(0, "") + }) + }) + Context("ALPN", func() { It("negotiates an application protocol", func() { ln, err := quic.ListenAddr("localhost:0", getTLSConfig(), serverConfig) diff --git a/server.go b/server.go index e5c32de3..87bdd6c2 100644 --- a/server.go +++ b/server.go @@ -7,6 +7,7 @@ import ( "fmt" "net" "sync" + "sync/atomic" "time" "github.com/quic-go/quic-go/internal/handshake" @@ -110,6 +111,11 @@ type baseServer struct { connectionRefusedQueue chan rejectedPacket retryQueue chan rejectedPacket + maxNumHandshakesUnvalidated int + maxNumHandshakesTotal int + numHandshakesUnvalidated atomic.Int64 + numHandshakesValidated atomic.Int64 + connQueue chan quicConn tracer *logging.Tracer @@ -238,31 +244,34 @@ func newServer( onClose func(), tokenGeneratorKey TokenGeneratorKey, maxTokenAge time.Duration, + maxNumHandshakesUnvalidated, maxNumHandshakesTotal int, disableVersionNegotiation bool, acceptEarly bool, ) *baseServer { s := &baseServer{ - conn: conn, - tlsConf: tlsConf, - config: config, - tokenGenerator: handshake.NewTokenGenerator(tokenGeneratorKey), - maxTokenAge: maxTokenAge, - connIDGenerator: connIDGenerator, - connHandler: connHandler, - connQueue: make(chan quicConn, protocol.MaxAcceptQueueSize), - errorChan: make(chan struct{}), - running: make(chan struct{}), - receivedPackets: make(chan receivedPacket, protocol.MaxServerUnprocessedPackets), - versionNegotiationQueue: make(chan receivedPacket, 4), - invalidTokenQueue: make(chan rejectedPacket, 4), - connectionRefusedQueue: make(chan rejectedPacket, 4), - retryQueue: make(chan rejectedPacket, 8), - newConn: newConnection, - tracer: tracer, - logger: utils.DefaultLogger.WithPrefix("server"), - acceptEarlyConns: acceptEarly, - disableVersionNegotiation: disableVersionNegotiation, - onClose: onClose, + conn: conn, + tlsConf: tlsConf, + config: config, + tokenGenerator: handshake.NewTokenGenerator(tokenGeneratorKey), + maxTokenAge: maxTokenAge, + maxNumHandshakesUnvalidated: maxNumHandshakesUnvalidated, + maxNumHandshakesTotal: maxNumHandshakesTotal, + connIDGenerator: connIDGenerator, + connHandler: connHandler, + connQueue: make(chan quicConn, protocol.MaxAcceptQueueSize), + errorChan: make(chan struct{}), + running: make(chan struct{}), + receivedPackets: make(chan receivedPacket, protocol.MaxServerUnprocessedPackets), + versionNegotiationQueue: make(chan receivedPacket, 4), + invalidTokenQueue: make(chan rejectedPacket, 4), + connectionRefusedQueue: make(chan rejectedPacket, 4), + retryQueue: make(chan rejectedPacket, 8), + newConn: newConnection, + tracer: tracer, + logger: utils.DefaultLogger.WithPrefix("server"), + acceptEarlyConns: acceptEarly, + disableVersionNegotiation: disableVersionNegotiation, + onClose: onClose, } if acceptEarly { s.zeroRTTQueues = map[protocol.ConnectionID]*zeroRTTQueue{} @@ -570,8 +579,8 @@ func (s *baseServer) handleInitialImpl(p receivedPacket, hdr *wire.Header) error } } - clientAddrIsValid := s.validateToken(token, p.remoteAddr) - if token != nil && !clientAddrIsValid { + clientAddrValidated := s.validateToken(token, p.remoteAddr) + if token != nil && !clientAddrValidated { // For invalid and expired non-retry tokens, we don't send an INVALID_TOKEN error. // We just ignore them, and act as if there was no token on this packet at all. // This also means we might send a Retry later. @@ -590,7 +599,25 @@ func (s *baseServer) handleInitialImpl(p receivedPacket, hdr *wire.Header) error return nil } } - if token == nil && s.config.RequireAddressValidation(p.remoteAddr) { + + // Until the next call to handleInitialImpl, these numbers are guaranteed to not increase. + // They might decrease if another connection completes the handshake. + numHandshakesUnvalidated := s.numHandshakesUnvalidated.Load() + numHandshakesValidated := s.numHandshakesValidated.Load() + + // Check the total handshake limit first. It's better to reject than to initiate a retry. + if total := numHandshakesUnvalidated + numHandshakesValidated; total >= int64(s.maxNumHandshakesTotal) { + s.logger.Debugf("Rejecting new connection. Server currently busy. Currently handshaking: %d (max %d)", total, s.maxNumHandshakesTotal) + delete(s.zeroRTTQueues, hdr.DestConnectionID) + select { + case s.connectionRefusedQueue <- rejectedPacket{receivedPacket: p, hdr: hdr}: + default: + // drop packet if we can't send out the CONNECTION_REFUSED fast enough + p.buffer.Release() + } + return nil + } + if token == nil && (s.config.RequireAddressValidation(p.remoteAddr) || numHandshakesUnvalidated >= int64(s.maxNumHandshakesUnvalidated)) { // Retry invalidates all 0-RTT packets sent. delete(s.zeroRTTQueues, hdr.DestConnectionID) select { @@ -602,17 +629,6 @@ func (s *baseServer) handleInitialImpl(p receivedPacket, hdr *wire.Header) error return nil } - if queueLen := len(s.connQueue); queueLen >= protocol.MaxAcceptQueueSize { - s.logger.Debugf("Rejecting new connection. Server currently busy. Accept queue length: %d (max %d)", queueLen, protocol.MaxAcceptQueueSize) - select { - case s.connectionRefusedQueue <- rejectedPacket{receivedPacket: p, hdr: hdr}: - default: - // drop packet if we can't send out the CONNECTION_REFUSED fast enough - p.buffer.Release() - } - return nil - } - connID, err := s.connIDGenerator.GenerateConnectionID() if err != nil { return err @@ -652,7 +668,7 @@ func (s *baseServer) handleInitialImpl(p receivedPacket, hdr *wire.Header) error config, s.tlsConf, s.tokenGenerator, - clientAddrIsValid, + clientAddrValidated, tracer, tracingID, s.logger, @@ -677,8 +693,31 @@ func (s *baseServer) handleInitialImpl(p receivedPacket, hdr *wire.Header) error } return nil } + if clientAddrValidated { + s.numHandshakesValidated.Add(1) + } else { + s.numHandshakesUnvalidated.Add(1) + } go conn.run() - go s.handleNewConn(conn) + go func() { + completed := s.handleNewConn(conn) + if clientAddrValidated { + if s.numHandshakesValidated.Add(-1) < 0 { + panic("server BUG: number of validated handshakes negative") + } + } else if s.numHandshakesUnvalidated.Add(-1) < 0 { + panic("server BUG: number of unvalidated handshakes negative") + } + if !completed { + return + } + + select { + case s.connQueue <- conn: + default: + conn.closeWithTransportError(ConnectionRefused) + } + }() if conn == nil { p.buffer.Release() return nil @@ -686,34 +725,28 @@ func (s *baseServer) handleInitialImpl(p receivedPacket, hdr *wire.Header) error return nil } -func (s *baseServer) handleNewConn(conn quicConn) { - connCtx := conn.Context() +func (s *baseServer) handleNewConn(conn quicConn) bool { if s.acceptEarlyConns { // wait until the early connection is ready, the handshake fails, or the server is closed select { case <-s.errorChan: conn.closeWithTransportError(ConnectionRefused) - return + return false + case <-conn.Context().Done(): + return false case <-conn.earlyConnReady(): - case <-connCtx.Done(): - return - } - } else { - // wait until the handshake is complete (or fails) - select { - case <-s.errorChan: - conn.closeWithTransportError(ConnectionRefused) - return - case <-conn.HandshakeComplete(): - case <-connCtx.Done(): - return + return true } } - + // wait until the handshake completes, fails, or the server is closed select { - case s.connQueue <- conn: - default: - conn.destroy(&qerr.TransportError{ErrorCode: ConnectionRefused}) + case <-s.errorChan: + conn.closeWithTransportError(ConnectionRefused) + return false + case <-conn.Context().Done(): + return false + case <-conn.HandshakeComplete(): + return true } } diff --git a/server_test.go b/server_test.go index 5500cd09..6bd8e034 100644 --- a/server_test.go +++ b/server_test.go @@ -83,6 +83,25 @@ var _ = Describe("Server", func() { return hdr } + checkConnectionCloseError := func(b []byte, origHdr *wire.Header, errorCode qerr.TransportErrorCode) { + replyHdr := parseHeader(b) + Expect(replyHdr.Type).To(Equal(protocol.PacketTypeInitial)) + Expect(replyHdr.SrcConnectionID).To(Equal(origHdr.DestConnectionID)) + Expect(replyHdr.DestConnectionID).To(Equal(origHdr.SrcConnectionID)) + _, opener := handshake.NewInitialAEAD(origHdr.DestConnectionID, protocol.PerspectiveClient, replyHdr.Version) + extHdr, err := unpackLongHeader(opener, replyHdr, b, origHdr.Version) + Expect(err).ToNot(HaveOccurred()) + data, err := opener.Open(nil, b[extHdr.ParsedLen():], extHdr.PacketNumber, b[:extHdr.ParsedLen()]) + Expect(err).ToNot(HaveOccurred()) + _, f, err := wire.NewFrameParser(false).ParseNext(data, protocol.EncryptionInitial, origHdr.Version) + Expect(err).ToNot(HaveOccurred()) + Expect(f).To(BeAssignableToTypeOf(&wire.ConnectionCloseFrame{})) + ccf := f.(*wire.ConnectionCloseFrame) + Expect(ccf.IsApplicationError).To(BeFalse()) + Expect(ccf.ErrorCode).To(BeEquivalentTo(errorCode)) + Expect(ccf.ReasonPhrase).To(BeEmpty()) + } + BeforeEach(func() { conn = NewMockPacketConn(mockCtrl) conn.EXPECT().LocalAddr().Return(&net.UDPAddr{}).AnyTimes() @@ -534,6 +553,9 @@ var _ = Describe("Server", func() { }) It("drops packets if the receive queue is full", func() { + serv.maxNumHandshakesTotal = 10000 + serv.maxNumHandshakesUnvalidated = 10000 + phm.EXPECT().Get(gomock.Any()).AnyTimes() phm.EXPECT().AddWithConnID(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(func(_, _ protocol.ConnectionID, fn func() (packetHandler, bool)) bool { phm.EXPECT().GetStatelessResetToken(gomock.Any()) @@ -542,7 +564,7 @@ var _ = Describe("Server", func() { }).AnyTimes() acceptConn := make(chan struct{}) - var counter uint32 // to be used as an atomic, so we query it in Eventually + var counter atomic.Uint32 serv.newConn = func( _ sendConn, runner connRunner, @@ -563,7 +585,7 @@ var _ = Describe("Server", func() { _ protocol.VersionNumber, ) quicConn { <-acceptConn - atomic.AddUint32(&counter, 1) + counter.Add(1) conn := NewMockQUICConn(mockCtrl) conn.EXPECT().handlePacket(gomock.Any()).MaxTimes(1) conn.EXPECT().run().MaxTimes(1) @@ -590,10 +612,10 @@ var _ = Describe("Server", func() { close(acceptConn) Eventually( - func() uint32 { return atomic.LoadUint32(&counter) }, + func() uint32 { return counter.Load() }, scaleDuration(100*time.Millisecond), ).Should(BeEquivalentTo(protocol.MaxServerUnprocessedPackets + 1)) - Consistently(func() uint32 { return atomic.LoadUint32(&counter) }).Should(BeEquivalentTo(protocol.MaxServerUnprocessedPackets + 1)) + Consistently(func() uint32 { return counter.Load() }).Should(BeEquivalentTo(protocol.MaxServerUnprocessedPackets + 1)) }) It("only creates a single connection for a duplicate Initial", func() { @@ -633,7 +655,20 @@ var _ = Describe("Server", func() { Eventually(done).Should(BeClosed()) }) - It("rejects new connection attempts if the accept queue is full", func() { + It("limits the number of unvalidated handshakes", func() { + const limit = 3 + serv.maxNumHandshakesTotal = 10000 + serv.maxNumHandshakesUnvalidated = limit + + phm.EXPECT().Get(gomock.Any()).AnyTimes() + phm.EXPECT().AddWithConnID(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(func(_, _ protocol.ConnectionID, fn func() (packetHandler, bool)) bool { + phm.EXPECT().GetStatelessResetToken(gomock.Any()) + _, ok := fn() + return ok + }).AnyTimes() + + handshakeChan := make(chan struct{}) + connChan := make(chan *MockQUICConn, 1) serv.newConn = func( _ sendConn, runner connRunner, @@ -653,73 +688,140 @@ var _ = Describe("Server", func() { _ utils.Logger, _ protocol.VersionNumber, ) quicConn { - conn := NewMockQUICConn(mockCtrl) + conn := <-connChan conn.EXPECT().handlePacket(gomock.Any()) conn.EXPECT().run() conn.EXPECT().Context().Return(context.Background()) - c := make(chan struct{}) - close(c) - conn.EXPECT().HandshakeComplete().Return(c) + conn.EXPECT().HandshakeComplete().Return(handshakeChan) return conn } - phm.EXPECT().Get(gomock.Any()).Times(protocol.MaxAcceptQueueSize + 1) - phm.EXPECT().AddWithConnID(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(func(_, _ protocol.ConnectionID, fn func() (packetHandler, bool)) bool { - phm.EXPECT().GetStatelessResetToken(gomock.Any()) - _, ok := fn() - return ok - }).Times(protocol.MaxAcceptQueueSize) - - var wg sync.WaitGroup - wg.Add(protocol.MaxAcceptQueueSize) - for i := 0; i < protocol.MaxAcceptQueueSize; i++ { - go func() { - defer GinkgoRecover() - defer wg.Done() - serv.handlePacket(getInitialWithRandomDestConnID()) - // make sure there are no Write calls on the packet conn - time.Sleep(50 * time.Millisecond) - }() + // Initiate the maximum number of allowed connection attempts. + for i := 0; i < limit; i++ { + conn := NewMockQUICConn(mockCtrl) + connChan <- conn + serv.handlePacket(getInitialWithRandomDestConnID()) } - wg.Wait() + + // Now initiate another connection attempt. p := getInitialWithRandomDestConnID() - hdr, _, _, err := wire.ParsePacket(p.data) - Expect(err).ToNot(HaveOccurred()) - tracer.EXPECT().SentPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()) done := make(chan struct{}) - conn.EXPECT().WriteTo(gomock.Any(), p.remoteAddr).DoAndReturn(func(b []byte, _ net.Addr) (int, error) { + tracer.EXPECT().SentPacket(p.remoteAddr, gomock.Any(), gomock.Any(), gomock.Any()).Do(func(_ net.Addr, replyHdr *logging.Header, _ logging.ByteCount, frames []logging.Frame) { + defer GinkgoRecover() + Expect(replyHdr.Type).To(Equal(protocol.PacketTypeRetry)) + }) + conn.EXPECT().WriteTo(gomock.Any(), gomock.Any()).DoAndReturn(func(b []byte, _ net.Addr) (int, error) { + defer GinkgoRecover() defer close(done) - rejectHdr := parseHeader(b) - Expect(rejectHdr.Type).To(Equal(protocol.PacketTypeInitial)) - Expect(rejectHdr.Version).To(Equal(hdr.Version)) - Expect(rejectHdr.DestConnectionID).To(Equal(hdr.SrcConnectionID)) - Expect(rejectHdr.SrcConnectionID).To(Equal(hdr.DestConnectionID)) + hdr, _, _, err := wire.ParsePacket(b) + Expect(err).ToNot(HaveOccurred()) + Expect(hdr.Type).To(Equal(protocol.PacketTypeRetry)) return len(b), nil }) serv.handlePacket(p) Eventually(done).Should(BeClosed()) + + close(handshakeChan) + for i := 0; i < limit; i++ { + _, err := serv.Accept(context.Background()) + Expect(err).ToNot(HaveOccurred()) + } + for i := 0; i < limit; i++ { + conn := NewMockQUICConn(mockCtrl) + conn.EXPECT().closeWithTransportError(gomock.Any()).MaxTimes(1) // called when the server is closed + connChan <- conn + serv.handlePacket(getInitialWithRandomDestConnID()) + } + }) + + It("limits the number of total handshakes", func() { + const limit = 3 + serv.maxNumHandshakesTotal = limit + serv.maxNumHandshakesUnvalidated = limit // same limit, but we check that we send CONNECTION_REFUSED and not Retry + + phm.EXPECT().Get(gomock.Any()).AnyTimes() + phm.EXPECT().AddWithConnID(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(func(_, _ protocol.ConnectionID, fn func() (packetHandler, bool)) bool { + phm.EXPECT().GetStatelessResetToken(gomock.Any()) + _, ok := fn() + return ok + }).AnyTimes() + + handshakeChan := make(chan struct{}) + connChan := make(chan *MockQUICConn, 1) + serv.newConn = func( + _ sendConn, + runner connRunner, + _ protocol.ConnectionID, + _ *protocol.ConnectionID, + _ protocol.ConnectionID, + _ protocol.ConnectionID, + _ protocol.ConnectionID, + _ ConnectionIDGenerator, + _ protocol.StatelessResetToken, + _ *Config, + _ *tls.Config, + _ *handshake.TokenGenerator, + _ bool, + _ *logging.ConnectionTracer, + _ uint64, + _ utils.Logger, + _ protocol.VersionNumber, + ) quicConn { + conn := <-connChan + conn.EXPECT().handlePacket(gomock.Any()) + conn.EXPECT().run() + conn.EXPECT().Context().Return(context.Background()) + conn.EXPECT().HandshakeComplete().Return(handshakeChan) + return conn + } + + for i := 0; i < limit; i++ { + conn := NewMockQUICConn(mockCtrl) + connChan <- conn + serv.handlePacket(getInitialWithRandomDestConnID()) + } + + p := getInitialWithRandomDestConnID() + done := make(chan struct{}) + tracer.EXPECT().SentPacket(p.remoteAddr, gomock.Any(), gomock.Any(), gomock.Any()).Do(func(_ net.Addr, replyHdr *logging.Header, _ logging.ByteCount, frames []logging.Frame) { + defer GinkgoRecover() + hdr, _, _, err := wire.ParsePacket(p.data) + Expect(err).ToNot(HaveOccurred()) + Expect(replyHdr.Type).To(Equal(protocol.PacketTypeInitial)) + Expect(replyHdr.SrcConnectionID).To(Equal(hdr.DestConnectionID)) + Expect(replyHdr.DestConnectionID).To(Equal(hdr.SrcConnectionID)) + Expect(frames).To(HaveLen(1)) + Expect(frames[0]).To(BeAssignableToTypeOf(&wire.ConnectionCloseFrame{})) + ccf := frames[0].(*logging.ConnectionCloseFrame) + Expect(ccf.IsApplicationError).To(BeFalse()) + Expect(ccf.ErrorCode).To(BeEquivalentTo(qerr.ConnectionRefused)) + }) + conn.EXPECT().WriteTo(gomock.Any(), gomock.Any()).DoAndReturn(func(b []byte, _ net.Addr) (int, error) { + defer GinkgoRecover() + defer close(done) + hdr, _, _, err := wire.ParsePacket(p.data) + Expect(err).ToNot(HaveOccurred()) + checkConnectionCloseError(b, hdr, qerr.ConnectionRefused) + return len(b), nil + }) + serv.handlePacket(p) + Eventually(done).Should(BeClosed()) + + close(handshakeChan) + for i := 0; i < limit; i++ { + _, err := serv.Accept(context.Background()) + Expect(err).ToNot(HaveOccurred()) + } + for i := 0; i < limit; i++ { + conn := NewMockQUICConn(mockCtrl) + conn.EXPECT().closeWithTransportError(gomock.Any()).MaxTimes(1) // called when the server is closed + connChan <- conn + serv.handlePacket(getInitialWithRandomDestConnID()) + } }) }) Context("token validation", func() { - checkInvalidToken := func(b []byte, origHdr *wire.Header) { - replyHdr := parseHeader(b) - Expect(replyHdr.Type).To(Equal(protocol.PacketTypeInitial)) - Expect(replyHdr.SrcConnectionID).To(Equal(origHdr.DestConnectionID)) - Expect(replyHdr.DestConnectionID).To(Equal(origHdr.SrcConnectionID)) - _, opener := handshake.NewInitialAEAD(origHdr.DestConnectionID, protocol.PerspectiveClient, replyHdr.Version) - extHdr, err := unpackLongHeader(opener, replyHdr, b, origHdr.Version) - Expect(err).ToNot(HaveOccurred()) - data, err := opener.Open(nil, b[extHdr.ParsedLen():], extHdr.PacketNumber, b[:extHdr.ParsedLen()]) - Expect(err).ToNot(HaveOccurred()) - _, f, err := wire.NewFrameParser(false).ParseNext(data, protocol.EncryptionInitial, origHdr.Version) - Expect(err).ToNot(HaveOccurred()) - Expect(f).To(BeAssignableToTypeOf(&wire.ConnectionCloseFrame{})) - ccf := f.(*wire.ConnectionCloseFrame) - Expect(ccf.ErrorCode).To(BeEquivalentTo(qerr.InvalidToken)) - Expect(ccf.ReasonPhrase).To(BeEmpty()) - } - It("decodes the token from the token field", func() { raddr := &net.UDPAddr{IP: net.IPv4(192, 168, 13, 37), Port: 1337} token, err := serv.tokenGenerator.NewRetryToken(raddr, protocol.ConnectionID{}, protocol.ConnectionID{}) @@ -771,7 +873,7 @@ var _ = Describe("Server", func() { done := make(chan struct{}) conn.EXPECT().WriteTo(gomock.Any(), raddr).DoAndReturn(func(b []byte, _ net.Addr) (int, error) { defer close(done) - checkInvalidToken(b, hdr) + checkConnectionCloseError(b, hdr, qerr.InvalidToken) return len(b), nil }) phm.EXPECT().Get(gomock.Any()) @@ -809,7 +911,7 @@ var _ = Describe("Server", func() { done := make(chan struct{}) conn.EXPECT().WriteTo(gomock.Any(), raddr).DoAndReturn(func(b []byte, _ net.Addr) (int, error) { defer close(done) - checkInvalidToken(b, hdr) + checkConnectionCloseError(b, hdr, qerr.InvalidToken) return len(b), nil }) phm.EXPECT().Get(gomock.Any()) @@ -1186,8 +1288,7 @@ var _ = Describe("Server", func() { }) It("rejects new connection attempts if the accept queue is full", func() { - senderAddr := &net.UDPAddr{IP: net.IPv4(1, 2, 3, 4), Port: 42} - + connChan := make(chan *MockQUICConn, 1) serv.baseServer.newConn = func( _ sendConn, runner connRunner, @@ -1209,7 +1310,7 @@ var _ = Describe("Server", func() { ) quicConn { ready := make(chan struct{}) close(ready) - conn := NewMockQUICConn(mockCtrl) + conn := <-connChan conn.EXPECT().handlePacket(gomock.Any()) conn.EXPECT().run() conn.EXPECT().earlyConnReady().Return(ready) @@ -1224,27 +1325,22 @@ var _ = Describe("Server", func() { return ok }).Times(protocol.MaxAcceptQueueSize) for i := 0; i < protocol.MaxAcceptQueueSize; i++ { + conn := NewMockQUICConn(mockCtrl) + connChan <- conn serv.baseServer.handlePacket(getInitialWithRandomDestConnID()) } Eventually(serv.baseServer.connQueue).Should(HaveLen(protocol.MaxAcceptQueueSize)) - // make sure there are no Write calls on the packet conn - time.Sleep(50 * time.Millisecond) - p := getInitialWithRandomDestConnID() - hdr := parseHeader(p.data) - done := make(chan struct{}) - conn.EXPECT().WriteTo(gomock.Any(), senderAddr).DoAndReturn(func(b []byte, _ net.Addr) (int, error) { - defer close(done) - rejectHdr := parseHeader(b) - Expect(rejectHdr.Type).To(Equal(protocol.PacketTypeInitial)) - Expect(rejectHdr.Version).To(Equal(hdr.Version)) - Expect(rejectHdr.DestConnectionID).To(Equal(hdr.SrcConnectionID)) - Expect(rejectHdr.SrcConnectionID).To(Equal(hdr.DestConnectionID)) - return len(b), nil + phm.EXPECT().AddWithConnID(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(func(_, _ protocol.ConnectionID, fn func() (packetHandler, bool)) bool { + phm.EXPECT().GetStatelessResetToken(gomock.Any()) + _, ok := fn() + return ok }) - serv.baseServer.handlePacket(p) - Eventually(done).Should(BeClosed()) + conn := NewMockQUICConn(mockCtrl) + conn.EXPECT().closeWithTransportError(ConnectionRefused) + connChan <- conn + serv.baseServer.handlePacket(getInitialWithRandomDestConnID()) }) It("doesn't accept new connections if they were closed in the mean time", func() { diff --git a/transport.go b/transport.go index 6dee5184..5316754a 100644 --- a/transport.go +++ b/transport.go @@ -5,6 +5,7 @@ import ( "crypto/rand" "crypto/tls" "errors" + "math" "net" "sync" "sync/atomic" @@ -18,6 +19,18 @@ import ( var errListenerAlreadySet = errors.New("listener already set") +const ( + // defaultMaxNumUnvalidatedHandshakes is the default value for Transport.MaxUnvalidatedHandshakes. + defaultMaxNumUnvalidatedHandshakes = 128 + // defaultMaxNumHandshakes is the default value for Transport.MaxHandshakes. + // It's not clear how to choose a reasonable value that works for all use cases. + // In production, implementations should: + // 1. Choose a lower value. + // 2. Implement some kind of IP-address based filtering using the Config.GetConfigForClient + // callback in order to prevent flooding attacks from a single / small number of IP addresses. + defaultMaxNumHandshakes = math.MaxInt32 +) + // The Transport is the central point to manage incoming and outgoing QUIC connections. // QUIC demultiplexes connections based on their QUIC Connection IDs, not based on the 4-tuple. // This means that a single UDP socket can be used for listening for incoming connections, as well as @@ -77,6 +90,25 @@ type Transport struct { // It has no effect for clients. DisableVersionNegotiationPackets bool + // MaxUnvalidatedHandshakes is the maximum number of concurrent incoming QUIC handshakes + // originating from unvalidated source addresses. + // If the number of handshakes from unvalidated addresses reaches this number, new incoming + // connection attempts will need to proof reachability at the respective source address using the + // Retry mechanism, as described in RFC 9000 section 8.1.2. + // Validating the source address adds one additional network roundtrip to the handshake. + // If unset, a default value of 128 will be used. + // When set to a negative value, every connection attempt will need to validate the source address. + // It does not make sense to set this value higher than MaxHandshakes. + MaxUnvalidatedHandshakes int + // MaxHandshakes is the maximum number of concurrent incoming handshakes, both from validated + // and unvalidated source addresses. + // If unset, the number of concurrent handshakes will not be limited. + // Applications should choose a reasonable value based on their thread model, and consider + // implementing IP-based rate limiting using Config.GetConfigForClient. + // If the number of handshakes reaches this number, new connection attempts will be rejected by + // terminating the connection attempt using a CONNECTION_REFUSED error. + MaxHandshakes int + // A Tracer traces events that don't belong to a single QUIC connection. Tracer *logging.Tracer @@ -151,6 +183,14 @@ func (t *Transport) createServer(tlsConf *tls.Config, conf *Config, allow0RTT bo if err := t.init(false); err != nil { return nil, err } + maxUnvalidatedHandshakes := t.MaxUnvalidatedHandshakes + if maxUnvalidatedHandshakes == 0 { + maxUnvalidatedHandshakes = defaultMaxNumUnvalidatedHandshakes + } + maxHandshakes := t.MaxHandshakes + if maxHandshakes == 0 { + maxHandshakes = defaultMaxNumHandshakes + } s := newServer( t.conn, t.handlerMap, @@ -161,6 +201,8 @@ func (t *Transport) createServer(tlsConf *tls.Config, conf *Config, allow0RTT bo t.closeServer, *t.TokenGeneratorKey, t.MaxTokenAge, + maxUnvalidatedHandshakes, + maxHandshakes, t.DisableVersionNegotiationPackets, allow0RTT, )