From 90514d53d170ec2cdbed7c190ce88e1adbe974a1 Mon Sep 17 00:00:00 2001 From: Marten Seemann Date: Sun, 6 Jan 2019 13:36:52 +0700 Subject: [PATCH] reject new connection attempts if the server's accept queue is full --- integrationtests/self/handshake_test.go | 48 +++++++++++++++++ internal/handshake/crypto_setup.go | 2 +- internal/handshake/initial_aead.go | 3 +- internal/handshake/initial_aead_test.go | 12 ++--- internal/protocol/params.go | 4 ++ server.go | 72 +++++++++++++++++++++++-- server_test.go | 58 ++++++++++++++++++++ 7 files changed, 186 insertions(+), 13 deletions(-) diff --git a/integrationtests/self/handshake_test.go b/integrationtests/self/handshake_test.go index 23fb9f647..6d0f76ac7 100644 --- a/integrationtests/self/handshake_test.go +++ b/integrationtests/self/handshake_test.go @@ -4,9 +4,11 @@ import ( "crypto/tls" "fmt" "net" + "time" quic "github.com/lucas-clemente/quic-go" "github.com/lucas-clemente/quic-go/internal/protocol" + "github.com/lucas-clemente/quic-go/internal/qerr" "github.com/lucas-clemente/quic-go/internal/testdata" . "github.com/onsi/ginkgo" . "github.com/onsi/gomega" @@ -143,4 +145,50 @@ var _ = Describe("Handshake tests", func() { }) } }) + + Context("rate limiting", func() { + It("rejects new connection attempts if connections don't get accepted", func() { + // start the server, but don't call Accept + serverConfig.AcceptCookie = func(net.Addr, *quic.Cookie) bool { + return true + } + server, err := quic.ListenAddr("localhost:0", testdata.GetTLSConfig(), serverConfig) + Expect(err).ToNot(HaveOccurred()) + defer server.Close() + + dial := func() (quic.Session, error) { + return quic.DialAddr( + fmt.Sprintf("localhost:%d", server.Addr().(*net.UDPAddr).Port), + &tls.Config{RootCAs: testdata.GetRootCA()}, + nil, + ) + } + + for i := 0; i < protocol.MaxAcceptQueueSize; i++ { + sess, err := dial() + Expect(err).ToNot(HaveOccurred()) + defer sess.Close() + } + time.Sleep(25 * time.Millisecond) // wait a bit for the sessions to be queued + + _, err = dial() + Expect(err).To(HaveOccurred()) + // TODO(#1567): use the SERVER_BUSY error code + Expect(err.(*qerr.QuicError).ErrorCode).To(Equal(qerr.PeerGoingAway)) + + // now accept one session, freeing one spot in the queue + _, err = server.Accept() + Expect(err).ToNot(HaveOccurred()) + // dial again, and expect that this dial succeeds + sess, err := dial() + Expect(err).ToNot(HaveOccurred()) + defer sess.Close() + time.Sleep(25 * time.Millisecond) // wait a bit for the session to be queued + + _, err = dial() + Expect(err).To(HaveOccurred()) + // TODO(#1567): use the SERVER_BUSY error code + Expect(err.(*qerr.QuicError).ErrorCode).To(Equal(qerr.PeerGoingAway)) + }) + }) }) diff --git a/internal/handshake/crypto_setup.go b/internal/handshake/crypto_setup.go index e5afaab63..17def032a 100644 --- a/internal/handshake/crypto_setup.go +++ b/internal/handshake/crypto_setup.go @@ -181,7 +181,7 @@ func newCryptoSetup( logger utils.Logger, perspective protocol.Perspective, ) (CryptoSetup, <-chan struct{} /* ClientHello written */, error) { - initialSealer, initialOpener, err := newInitialAEAD(connID, perspective) + initialSealer, initialOpener, err := NewInitialAEAD(connID, perspective) if err != nil { return nil, nil, err } diff --git a/internal/handshake/initial_aead.go b/internal/handshake/initial_aead.go index 1a6efb086..f3a106c56 100644 --- a/internal/handshake/initial_aead.go +++ b/internal/handshake/initial_aead.go @@ -11,7 +11,8 @@ import ( var quicVersion1Salt = []byte{0xef, 0x4f, 0xb0, 0xab, 0xb4, 0x74, 0x70, 0xc4, 0x1b, 0xef, 0xcf, 0x80, 0x31, 0x33, 0x4f, 0xae, 0x48, 0x5e, 0x09, 0xa0} -func newInitialAEAD(connID protocol.ConnectionID, pers protocol.Perspective) (Sealer, Opener, error) { +// NewInitialAEAD creates a new AEAD for Initial encryption / decryption. +func NewInitialAEAD(connID protocol.ConnectionID, pers protocol.Perspective) (Sealer, Opener, error) { clientSecret, serverSecret := computeSecrets(connID) var mySecret, otherSecret []byte if pers == protocol.PerspectiveClient { diff --git a/internal/handshake/initial_aead_test.go b/internal/handshake/initial_aead_test.go index 3670bc458..d85e9d068 100644 --- a/internal/handshake/initial_aead_test.go +++ b/internal/handshake/initial_aead_test.go @@ -67,9 +67,9 @@ var _ = Describe("Initial AEAD using AES-GCM", func() { It("seals and opens", func() { connectionID := protocol.ConnectionID{0x12, 0x34, 0x56, 0x78, 0x90, 0xab, 0xcd, 0xef} - clientSealer, clientOpener, err := newInitialAEAD(connectionID, protocol.PerspectiveClient) + clientSealer, clientOpener, err := NewInitialAEAD(connectionID, protocol.PerspectiveClient) Expect(err).ToNot(HaveOccurred()) - serverSealer, serverOpener, err := newInitialAEAD(connectionID, protocol.PerspectiveServer) + serverSealer, serverOpener, err := NewInitialAEAD(connectionID, protocol.PerspectiveServer) Expect(err).ToNot(HaveOccurred()) clientMessage := clientSealer.Seal(nil, []byte("foobar"), 42, []byte("aad")) @@ -85,9 +85,9 @@ var _ = Describe("Initial AEAD using AES-GCM", func() { It("doesn't work if initialized with different connection IDs", func() { c1 := protocol.ConnectionID{0, 0, 0, 0, 0, 0, 0, 1} c2 := protocol.ConnectionID{0, 0, 0, 0, 0, 0, 0, 2} - clientSealer, _, err := newInitialAEAD(c1, protocol.PerspectiveClient) + clientSealer, _, err := NewInitialAEAD(c1, protocol.PerspectiveClient) Expect(err).ToNot(HaveOccurred()) - _, serverOpener, err := newInitialAEAD(c2, protocol.PerspectiveServer) + _, serverOpener, err := NewInitialAEAD(c2, protocol.PerspectiveServer) Expect(err).ToNot(HaveOccurred()) clientMessage := clientSealer.Seal(nil, []byte("foobar"), 42, []byte("aad")) @@ -97,9 +97,9 @@ var _ = Describe("Initial AEAD using AES-GCM", func() { It("encrypts und decrypts the header", func() { connID := protocol.ConnectionID{0xde, 0xca, 0xfb, 0xad} - clientSealer, clientOpener, err := newInitialAEAD(connID, protocol.PerspectiveClient) + clientSealer, clientOpener, err := NewInitialAEAD(connID, protocol.PerspectiveClient) Expect(err).ToNot(HaveOccurred()) - serverSealer, serverOpener, err := newInitialAEAD(connID, protocol.PerspectiveServer) + serverSealer, serverOpener, err := NewInitialAEAD(connID, protocol.PerspectiveServer) Expect(err).ToNot(HaveOccurred()) // the first byte and the last 4 bytes should be encrypted diff --git a/internal/protocol/params.go b/internal/protocol/params.go index 7f4661e81..e6f9493fa 100644 --- a/internal/protocol/params.go +++ b/internal/protocol/params.go @@ -53,6 +53,10 @@ const SkipPacketAveragePeriodLength PacketNumber = 500 // MaxTrackedSkippedPackets is the maximum number of skipped packet numbers the SentPacketHandler keep track of for Optimistic ACK attack mitigation const MaxTrackedSkippedPackets = 10 +// MaxAcceptQueueSize is the maximum number of sessions that the server queues for accepting. +// If the queue is full, new connection attempts will be rejected. +const MaxAcceptQueueSize = 32 + // CookieExpiryTime is the valid time of a cookie const CookieExpiryTime = 24 * time.Hour diff --git a/server.go b/server.go index 54f0c487f..d7124bc96 100644 --- a/server.go +++ b/server.go @@ -8,10 +8,12 @@ import ( "io" "net" "sync" + "sync/atomic" "time" "github.com/lucas-clemente/quic-go/internal/handshake" "github.com/lucas-clemente/quic-go/internal/protocol" + "github.com/lucas-clemente/quic-go/internal/qerr" "github.com/lucas-clemente/quic-go/internal/utils" "github.com/lucas-clemente/quic-go/internal/wire" ) @@ -88,7 +90,8 @@ type server struct { errorChan chan struct{} closed bool - sessionQueue chan Session + sessionQueue chan Session + sessionQueueLen int32 // to be used as an atomic sessionRunner sessionRunner @@ -164,9 +167,15 @@ func listen(conn net.PacketConn, tlsConf *tls.Config, config *Config) (*server, func (s *server) setup() error { s.sessionRunner = &runner{ - onHandshakeCompleteImpl: func(sess Session) { go func() { s.sessionQueue <- sess }() }, - retireConnectionIDImpl: s.sessionHandler.Retire, - removeConnectionIDImpl: s.sessionHandler.Remove, + onHandshakeCompleteImpl: func(sess Session) { + go func() { + atomic.AddInt32(&s.sessionQueueLen, 1) + s.sessionQueue <- sess // blocks until the session is accepted + atomic.AddInt32(&s.sessionQueueLen, -1) + }() + }, + retireConnectionIDImpl: s.sessionHandler.Retire, + removeConnectionIDImpl: s.sessionHandler.Remove, } cookieGenerator, err := handshake.NewCookieGenerator() if err != nil { @@ -333,7 +342,7 @@ func (s *server) handleInitial(p *receivedPacket) { s.logger.Errorf("Error occurred handling initial packet: %s", err) return } - if sess == nil { // a retry was done + if sess == nil { // a retry was done, or the connection attempt was rejected p.buffer.Release() return } @@ -371,6 +380,11 @@ func (s *server) handleInitialImpl(p *receivedPacket) (quicSession, protocol.Con return nil, nil, s.sendRetry(p.remoteAddr, hdr) } + if queueLen := atomic.LoadInt32(&s.sessionQueueLen); queueLen >= protocol.MaxAcceptQueueSize { + s.logger.Debugf("Rejecting new connection. Server currently busy. Accept queue length: %d (max %d)", queueLen, protocol.MaxAcceptQueueSize) + return nil, nil, s.sendServerBusy(p.remoteAddr, hdr) + } + connID, err := protocol.GenerateConnectionID(s.config.ConnectionIDLength) if err != nil { return nil, nil, err @@ -460,6 +474,54 @@ func (s *server) sendRetry(remoteAddr net.Addr, hdr *wire.Header) error { return nil } +func (s *server) sendServerBusy(remoteAddr net.Addr, hdr *wire.Header) error { + sealer, _, err := handshake.NewInitialAEAD(hdr.DestConnectionID, protocol.PerspectiveServer) + if err != nil { + return err + } + packetBuffer := getPacketBuffer() + defer packetBuffer.Release() + buf := bytes.NewBuffer(packetBuffer.Slice[:0]) + + // TODO(#1567): use the SERVER_BUSY error code + ccf := &wire.ConnectionCloseFrame{ErrorCode: qerr.PeerGoingAway} + + replyHdr := &wire.ExtendedHeader{} + replyHdr.IsLongHeader = true + replyHdr.Type = protocol.PacketTypeInitial + replyHdr.Version = hdr.Version + replyHdr.SrcConnectionID = hdr.DestConnectionID + replyHdr.DestConnectionID = hdr.SrcConnectionID + replyHdr.PacketNumberLen = protocol.PacketNumberLen4 + replyHdr.Length = 4 /* packet number len */ + ccf.Length(hdr.Version) + protocol.ByteCount(sealer.Overhead()) + if err := replyHdr.Write(buf, hdr.Version); err != nil { + return err + } + payloadOffset := buf.Len() + + if err := ccf.Write(buf, hdr.Version); err != nil { + return err + } + + raw := buf.Bytes() + _ = sealer.Seal(raw[payloadOffset:payloadOffset], raw[payloadOffset:], replyHdr.PacketNumber, raw[:payloadOffset]) + raw = raw[0 : buf.Len()+sealer.Overhead()] + + pnOffset := payloadOffset - int(replyHdr.PacketNumberLen) + sealer.EncryptHeader( + raw[pnOffset+4:pnOffset+4+16], + &raw[0], + raw[pnOffset:payloadOffset], + ) + + replyHdr.Log(s.logger) + wire.LogFrame(s.logger, ccf, true) + if _, err := s.conn.WriteTo(raw, remoteAddr); err != nil { + s.logger.Debugf("Error rejecting connection: %s", err) + } + return nil +} + func (s *server) sendVersionNegotiationPacket(p *receivedPacket) { defer p.buffer.Release() hdr := p.hdr diff --git a/server_test.go b/server_test.go index 3dfef8150..39618c2c2 100644 --- a/server_test.go +++ b/server_test.go @@ -6,6 +6,7 @@ import ( "errors" "net" "reflect" + "sync" "time" "github.com/lucas-clemente/quic-go/internal/handshake" @@ -317,6 +318,63 @@ var _ = Describe("Server", func() { Eventually(run).Should(BeClosed()) Eventually(done).Should(BeClosed()) }) + + It("rejects new connection attempts if the accept queue is full", func() { + serv.config.AcceptCookie = func(_ net.Addr, _ *Cookie) bool { return true } + senderAddr := &net.UDPAddr{IP: net.IPv4(1, 2, 3, 4), Port: 42} + + hdr := &wire.Header{ + Type: protocol.PacketTypeInitial, + SrcConnectionID: protocol.ConnectionID{5, 4, 3, 2, 1}, + DestConnectionID: protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}, + Version: protocol.VersionTLS, + } + p := &receivedPacket{ + remoteAddr: senderAddr, + hdr: hdr, + data: bytes.Repeat([]byte{0}, protocol.MinInitialPacketSize), + } + serv.newSession = func( + _ connection, + runner sessionRunner, + _ protocol.ConnectionID, + _ protocol.ConnectionID, + _ protocol.ConnectionID, + _ *Config, + _ *tls.Config, + _ *handshake.TransportParameters, + _ utils.Logger, + _ protocol.VersionNumber, + ) (quicSession, error) { + sess := NewMockQuicSession(mockCtrl) + sess.EXPECT().handlePacket(p) + sess.EXPECT().run() + runner.onHandshakeComplete(sess) + return sess, nil + } + + var wg sync.WaitGroup + wg.Add(protocol.MaxAcceptQueueSize) + for i := 0; i < protocol.MaxAcceptQueueSize; i++ { + go func() { + defer GinkgoRecover() + defer wg.Done() + serv.handlePacket(insertPacketBuffer(p)) + Consistently(conn.dataWritten).ShouldNot(Receive()) + }() + } + wg.Wait() + serv.handlePacket(insertPacketBuffer(p)) + var reject mockPacketConnWrite + Eventually(conn.dataWritten).Should(Receive(&reject)) + Expect(reject.to).To(Equal(senderAddr)) + rejectHdr, err := wire.ParseHeader(bytes.NewReader(reject.data), 0) + Expect(err).ToNot(HaveOccurred()) + Expect(rejectHdr.Type).To(Equal(protocol.PacketTypeInitial)) + Expect(rejectHdr.Version).To(Equal(hdr.Version)) + Expect(rejectHdr.DestConnectionID).To(Equal(hdr.SrcConnectionID)) + Expect(rejectHdr.SrcConnectionID).To(Equal(hdr.DestConnectionID)) + }) }) Context("accepting sessions", func() {