diff --git a/internal/protocol/params.go b/internal/protocol/params.go index fe88b495..8a7aad07 100644 --- a/internal/protocol/params.go +++ b/internal/protocol/params.go @@ -12,7 +12,7 @@ const MaxPacketSizeIPv6 = 1232 const MaxCongestionWindowPackets = 10000 // MaxUndecryptablePackets limits the number of undecryptable packets that are queued in the session. -const MaxUndecryptablePackets = 10 +const MaxUndecryptablePackets = 33 // ConnectionFlowControlMultiplier determines how much larger the connection flow control windows needs to be relative to any stream's flow control window // This is the value that Chromium is using @@ -39,6 +39,9 @@ const DefaultMaxIncomingStreams = 100 // DefaultMaxIncomingUniStreams is the maximum number of unidirectional streams that a peer may open const DefaultMaxIncomingUniStreams = 100 +// MaxServerUnprocessedPackets is the max number of packets stored in the server that are not yet processed. +const MaxServerUnprocessedPackets = 1024 + // MaxSessionUnprocessedPackets is the max number of packets stored in each session that are not yet processed. const MaxSessionUnprocessedPackets = MaxCongestionWindowPackets @@ -158,3 +161,15 @@ const MaxAckDelayInclGranularity = MaxAckDelay + TimerGranularity // KeyUpdateInterval is the maximum number of packets we send or receive before initiating a key udpate. const KeyUpdateInterval = 100 * 1000 + +// Max0RTTQueueingDuration is the maximum time that we store 0-RTT packets in order to wait for the corresponding Initial to be received. +const Max0RTTQueueingDuration = 100 * time.Millisecond + +// Max0RTTQueues is the maximum number of connections that we buffer 0-RTT packets for. +const Max0RTTQueues = 32 + +// Max0RTTQueueLen is the maximum number of 0-RTT packets that we buffer for each connection. +// When a new session is created, all buffered packets are passed to the session immediately. +// To avoid blocking, this value has to be smaller than MaxSessionUnprocessedPackets. +// To avoid packets being dropped as undecryptable by the session, this value has to be smaller than MaxUndecryptablePackets. +const Max0RTTQueueLen = 32 diff --git a/internal/protocol/params_test.go b/internal/protocol/params_test.go new file mode 100644 index 00000000..b144054a --- /dev/null +++ b/internal/protocol/params_test.go @@ -0,0 +1,13 @@ +package protocol + +import ( + . "github.com/onsi/ginkgo" + . "github.com/onsi/gomega" +) + +var _ = Describe("Parameters", func() { + It("can queue more packets in the session than in the 0-RTT queue", func() { + Expect(MaxSessionUnprocessedPackets).To(BeNumerically(">", Max0RTTQueueLen)) + Expect(MaxUndecryptablePackets).To(BeNumerically(">", Max0RTTQueueLen)) + }) +}) diff --git a/server.go b/server.go index 07cd79e9..455c92ca 100644 --- a/server.go +++ b/server.go @@ -68,6 +68,7 @@ type baseServer struct { tokenGenerator *handshake.TokenGenerator + zeroRTTQueue *zeroRTTQueue sessionHandler packetHandlerManager receivedPackets chan *receivedPacket @@ -175,9 +176,10 @@ func listen(conn net.PacketConn, tlsConf *tls.Config, config *Config, acceptEarl config: config, tokenGenerator: tokenGenerator, sessionHandler: sessionHandler, + zeroRTTQueue: newZeroRTTQueue(), sessionQueue: make(chan quicSession), errorChan: make(chan struct{}), - receivedPackets: make(chan *receivedPacket, 1000), + receivedPackets: make(chan *receivedPacket, protocol.MaxServerUnprocessedPackets), newSession: newSession, logger: utils.DefaultLogger.WithPrefix("server"), acceptEarlySessions: acceptEarly, @@ -283,14 +285,14 @@ func (s *baseServer) Addr() net.Addr { } func (s *baseServer) handlePacket(p *receivedPacket) { - s.receivedPackets <- p + select { + case s.receivedPackets <- p: + default: + s.logger.Debugf("Dropping packet from %s (%d bytes). Server receive queue full.", p.remoteAddr, len(p.data)) + } } -func (s *baseServer) handlePacketImpl(p *receivedPacket) bool /* was the packet passed on to a session */ { - if len(p.data) < protocol.MinInitialPacketSize { - s.logger.Debugf("Dropping a packet that is too small to be a valid Initial (%d bytes)", len(p.data)) - return false - } +func (s *baseServer) handlePacketImpl(p *receivedPacket) bool /* was the packet handled */ { // If we're creating a new session, the packet will be passed to the session. // The header will then be parsed again. hdr, _, _, err := wire.ParsePacket(p.data, s.config.ConnectionIDLength) @@ -302,17 +304,26 @@ func (s *baseServer) handlePacketImpl(p *receivedPacket) bool /* was the packet if !hdr.IsLongHeader { return false } + if hdr.Type == protocol.PacketTypeInitial && len(p.data) < protocol.MinInitialPacketSize { + s.logger.Debugf("Dropping a packet that is too small to be a valid Initial (%d bytes)", len(p.data)) + return false + } // send a Version Negotiation Packet if the client is speaking a different protocol version if !protocol.IsSupportedVersion(s.config.Versions, hdr.Version) { go s.sendVersionNegotiationPacket(p, hdr) return false } - if hdr.IsLongHeader && hdr.Type != protocol.PacketTypeInitial { - // Drop long header packets. - // There's litte point in sending a Stateless Reset, since the client - // might not have received the token yet. - s.logger.Debugf("Dropping long header packet of type %s (%d bytes)", hdr.Type, len(p.data)) - return false + if hdr.IsLongHeader { + if hdr.Type == protocol.PacketType0RTT { + s.zeroRTTQueue.Enqueue(hdr.DestConnectionID, p) + return true + } else if hdr.Type != protocol.PacketTypeInitial { + // Drop long header packets. + // There's litte point in sending a Stateless Reset, since the client + // might not have received the token yet. + s.logger.Debugf("Dropping long header packet of type %s (%d bytes)", hdr.Type, len(p.data)) + return false + } } s.logger.Debugf("<- Received Initial packet.") @@ -382,7 +393,15 @@ func (s *baseServer) handleInitialImpl(p *receivedPacket, hdr *wire.Header) (qui connID, hdr.Version, ) - if sess != nil { + if sess == nil { + return nil, nil + } + sess.handlePacket(p) + for { + p := s.zeroRTTQueue.Dequeue(hdr.DestConnectionID) + if p == nil { + break + } sess.handlePacket(p) } return sess, nil diff --git a/server_test.go b/server_test.go index b3ac89db..d5fe1403 100644 --- a/server_test.go +++ b/server_test.go @@ -11,6 +11,7 @@ import ( "runtime/pprof" "strings" "sync" + "sync/atomic" "time" "github.com/lucas-clemente/quic-go/qlog" @@ -389,6 +390,100 @@ var _ = Describe("Server", func() { Eventually(done).Should(BeClosed()) }) + It("passes queued 0-RTT packets to the session", func() { + serv.config.AcceptToken = func(_ net.Addr, _ *Token) bool { return true } + var createdSession bool + sess := NewMockQuicSession(mockCtrl) + connID := protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8, 9} + initialPacket := getInitial(connID) + zeroRTTPacket := getPacket(&wire.Header{ + IsLongHeader: true, + Type: protocol.PacketType0RTT, + SrcConnectionID: protocol.ConnectionID{5, 4, 3, 2, 1}, + DestConnectionID: connID, + Version: protocol.VersionTLS, + }, []byte("foobar")) + sess.EXPECT().Context().Return(context.Background()).MaxTimes(1) + sess.EXPECT().HandshakeComplete().Return(context.Background()).MaxTimes(1) + sess.EXPECT().run().MaxTimes(1) + gomock.InOrder( + sess.EXPECT().handlePacket(initialPacket), + sess.EXPECT().handlePacket(zeroRTTPacket), + ) + serv.newSession = func( + _ connection, + runner sessionRunner, + _ protocol.ConnectionID, + _ protocol.ConnectionID, + _ protocol.ConnectionID, + _ protocol.ConnectionID, + _ [16]byte, + _ *Config, + _ *tls.Config, + _ *handshake.TokenGenerator, + _ bool, + _ qlog.Tracer, + _ utils.Logger, + _ protocol.VersionNumber, + ) quicSession { + createdSession = true + return sess + } + + // Receive the 0-RTT packet first. + Expect(serv.handlePacketImpl(zeroRTTPacket)).To(BeTrue()) + // Then receive the Initial packet. + phm.EXPECT().GetStatelessResetToken(gomock.Any()) + phm.EXPECT().Add(gomock.Any(), sess).Return(true).Times(2) + Expect(serv.handlePacketImpl(initialPacket)).To(BeTrue()) + Expect(createdSession).To(BeTrue()) + }) + + It("drops packets if the receive queue is full", func() { + phm.EXPECT().GetStatelessResetToken(gomock.Any()).AnyTimes() + phm.EXPECT().Add(gomock.Any(), gomock.Any()).AnyTimes() + + serv.config.AcceptToken = func(net.Addr, *Token) bool { return true } + acceptSession := make(chan struct{}) + var counter uint32 // to be used as an atomic, so we query it in Eventually + serv.newSession = func( + _ connection, + runner sessionRunner, + _ protocol.ConnectionID, + _ protocol.ConnectionID, + _ protocol.ConnectionID, + _ protocol.ConnectionID, + _ [16]byte, + _ *Config, + _ *tls.Config, + _ *handshake.TokenGenerator, + _ bool, + _ qlog.Tracer, + _ utils.Logger, + _ protocol.VersionNumber, + ) quicSession { + <-acceptSession + atomic.AddUint32(&counter, 1) + return nil + } + + serv.handlePacket(getInitial(protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8})) + var wg sync.WaitGroup + for i := 0; i < 3*protocol.MaxServerUnprocessedPackets; i++ { + wg.Add(1) + go func() { + defer GinkgoRecover() + defer wg.Done() + serv.handlePacket(getInitial(protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8})) + }() + } + wg.Wait() + + close(acceptSession) + Eventually(func() uint32 { return atomic.LoadUint32(&counter) }).Should(BeEquivalentTo(protocol.MaxServerUnprocessedPackets + 1)) + Consistently(func() uint32 { return atomic.LoadUint32(&counter) }).Should(BeEquivalentTo(protocol.MaxServerUnprocessedPackets + 1)) + }) + It("only creates a single session for a duplicate Initial", func() { serv.config.AcceptToken = func(_ net.Addr, _ *Token) bool { return true } var createdSession bool diff --git a/zero_rtt_queue.go b/zero_rtt_queue.go new file mode 100644 index 00000000..f2d6eb4b --- /dev/null +++ b/zero_rtt_queue.go @@ -0,0 +1,71 @@ +package quic + +import ( + "sync" + "time" + + "github.com/lucas-clemente/quic-go/internal/protocol" +) + +type zeroRTTQueueEntry struct { + timer *time.Timer + packets []*receivedPacket +} + +type zeroRTTQueue struct { + mutex sync.Mutex + queue map[string]*zeroRTTQueueEntry +} + +func newZeroRTTQueue() *zeroRTTQueue { + return &zeroRTTQueue{queue: make(map[string]*zeroRTTQueueEntry)} +} + +func (h *zeroRTTQueue) Enqueue(connID protocol.ConnectionID, p *receivedPacket) { + h.mutex.Lock() + defer h.mutex.Unlock() + + cid := string(connID) + if _, ok := h.queue[cid]; !ok { + if len(h.queue) >= protocol.Max0RTTQueues { + return + } + h.queue[cid] = &zeroRTTQueueEntry{timer: time.AfterFunc(protocol.Max0RTTQueueingDuration, func() { h.deleteQueue(connID) })} + } + entry := h.queue[cid] + if len(entry.packets) >= protocol.Max0RTTQueueLen { + return + } + entry.packets = append(entry.packets, p) +} + +func (h *zeroRTTQueue) Dequeue(connID protocol.ConnectionID) *receivedPacket { + h.mutex.Lock() + defer h.mutex.Unlock() + + entry, ok := h.queue[string(connID)] + if !ok { + return nil + } + p := entry.packets[0] + entry.packets = entry.packets[1:] + if len(entry.packets) == 0 { + entry.timer.Stop() + delete(h.queue, string(connID)) + } + return p +} + +func (h *zeroRTTQueue) deleteQueue(connID protocol.ConnectionID) { + h.mutex.Lock() + defer h.mutex.Unlock() + + entry, ok := h.queue[string(connID)] + if !ok { + return + } + for _, p := range entry.packets { + p.buffer.Release() + } + delete(h.queue, string(connID)) +} diff --git a/zero_rtt_queue_test.go b/zero_rtt_queue_test.go new file mode 100644 index 00000000..4b0bb45c --- /dev/null +++ b/zero_rtt_queue_test.go @@ -0,0 +1,113 @@ +package quic + +import ( + "encoding/binary" + "time" + + "github.com/lucas-clemente/quic-go/internal/protocol" + + . "github.com/onsi/ginkgo" + . "github.com/onsi/gomega" +) + +var _ = Describe("0-RTT queue", func() { + var q *zeroRTTQueue + + BeforeEach(func() { + q = newZeroRTTQueue() + }) + + AfterEach(func() { + // dequeue all packets to make sure the timers are stopped + q.mutex.Lock() + for connID := range q.queue { + for { + q.mutex.Unlock() + p := q.Dequeue(protocol.ConnectionID(connID)) + q.mutex.Lock() + if p != nil { + break + } + } + } + q.mutex.Unlock() + }) + + It("stores a 0-RTT packet", func() { + connID := protocol.ConnectionID{0xde, 0xad, 0xbe, 0xef} + p := &receivedPacket{data: []byte("foobar")} + q.Enqueue(connID, p) + Expect(q.Dequeue(connID)).To(Equal(p)) + Expect(q.Dequeue(connID)).To(BeNil()) + }) + + It("returns a nil packet for unknown connection IDs", func() { + Expect(q.Dequeue(protocol.ConnectionID{0x42})).To(BeNil()) + }) + + It("only stores packets for Max0RTTQueues connection", func() { + // fill up the queues + for i := 0; i < protocol.Max0RTTQueues; i++ { + data := make([]byte, 4) + binary.BigEndian.PutUint32(data, uint32(i)) + q.Enqueue(protocol.ConnectionID(data), &receivedPacket{data: data}) + } + // now try to enqueue a packet for another connection ID + connID := protocol.ConnectionID{0xde, 0xad, 0xbe, 0xef} + q.Enqueue(connID, &receivedPacket{data: []byte("foobar")}) + Expect(q.Dequeue(connID)).To(BeNil()) + // check that the other queues were all saved + for i := 0; i < protocol.Max0RTTQueues; i++ { + connID := make([]byte, 4) + binary.BigEndian.PutUint32(connID, uint32(i)) + p := q.Dequeue(connID) + Expect(p).ToNot(BeNil()) + Expect(binary.BigEndian.Uint32(p.data)).To(BeEquivalentTo(i)) + } + }) + + It("removes queues when packets are dequeued", func() { + // fill up the queues + for i := 0; i < protocol.Max0RTTQueues; i++ { + data := make([]byte, 4) + binary.BigEndian.PutUint32(data, uint32(i)) + q.Enqueue(protocol.ConnectionID(data), &receivedPacket{data: data}) + } + // now try to enqueue a packet for another connection ID + connID := protocol.ConnectionID{0xde, 0xad, 0xbe, 0xef} + q.Enqueue(connID, &receivedPacket{data: []byte("foobar")}) + Expect(q.Dequeue(connID)).To(BeNil()) + // dequeue the packet from the first queue + Expect(q.Dequeue(protocol.ConnectionID{0, 0, 0, 0})).ToNot(BeNil()) + // now it should be possible to queue another packet + q.Enqueue(connID, &receivedPacket{data: []byte("foobar")}) + Expect(q.Dequeue(connID)).ToNot(BeNil()) + }) + + It("limits the number of packets it stores for one connection", func() { + connID := protocol.ConnectionID{0xde, 0xad, 0xbe, 0xef} + // fill up the queue + for i := 0; i < protocol.Max0RTTQueueLen; i++ { + data := make([]byte, 4) + binary.BigEndian.PutUint32(data, uint32(i)) + q.Enqueue(connID, &receivedPacket{data: data}) + } + // The queue is full now. This packet will be dropped. + q.Enqueue(connID, &receivedPacket{data: []byte("foobar")}) + for i := 0; i < protocol.Max0RTTQueueLen; i++ { + p := q.Dequeue(connID) + Expect(p).ToNot(BeNil()) + Expect(binary.BigEndian.Uint32(p.data)).To(BeEquivalentTo(i)) + } + // The queue should now be empty. + Expect(q.Dequeue(connID)).To(BeNil()) + }) + + It("deletes packets if they aren't dequeued after a short while", func() { + connID := protocol.ConnectionID{0xde, 0xad, 0xbe, 0xef} + p := &receivedPacket{data: []byte("foobar"), buffer: getPacketBuffer()} + q.Enqueue(connID, p) + time.Sleep(protocol.Max0RTTQueueingDuration * 3 / 2) + Expect(q.Dequeue(connID)).To(BeNil()) + }) +})