diff --git a/internal/protocol/params.go b/internal/protocol/params.go index f5b56eaa0..4f52709f7 100644 --- a/internal/protocol/params.go +++ b/internal/protocol/params.go @@ -15,7 +15,7 @@ const MaxPacketSizeIPv6 = 1232 const MaxCongestionWindowPackets = 10000 // MaxUndecryptablePackets limits the number of undecryptable packets that are queued in the session. -const MaxUndecryptablePackets = 33 +const MaxUndecryptablePackets = 32 // ConnectionFlowControlMultiplier determines how much larger the connection flow control windows needs to be relative to any stream's flow control window // This is the value that Chromium is using @@ -192,4 +192,4 @@ const Max0RTTQueues = 32 // When a new session is created, all buffered packets are passed to the session immediately. // To avoid blocking, this value has to be smaller than MaxSessionUnprocessedPackets. // To avoid packets being dropped as undecryptable by the session, this value has to be smaller than MaxUndecryptablePackets. -const Max0RTTQueueLen = 32 +const Max0RTTQueueLen = 31 diff --git a/server.go b/server.go index c1ee6237a..40ff7c8b6 100644 --- a/server.go +++ b/server.go @@ -463,13 +463,7 @@ func (s *baseServer) handleInitialImpl(p *receivedPacket, hdr *wire.Header) erro return nil } sess.handlePacket(p) - for { - p := s.zeroRTTQueue.Dequeue(hdr.DestConnectionID) - if p == nil { - break - } - sess.handlePacket(p) - } + s.zeroRTTQueue.DequeueToSession(hdr.DestConnectionID, sess) return nil } diff --git a/zero_rtt_queue.go b/zero_rtt_queue.go index 2fb1e75dd..b0173f9db 100644 --- a/zero_rtt_queue.go +++ b/zero_rtt_queue.go @@ -45,21 +45,23 @@ func (h *zeroRTTQueue) Enqueue(connID protocol.ConnectionID, p *receivedPacket) entry.packets = append(entry.packets, p) } -func (h *zeroRTTQueue) Dequeue(connID protocol.ConnectionID) *receivedPacket { +func (h *zeroRTTQueue) DequeueToSession(connID protocol.ConnectionID, sess packetHandler) { h.mutex.Lock() defer h.mutex.Unlock() + h.dequeueToSession(connID, sess) +} + +func (h *zeroRTTQueue) dequeueToSession(connID protocol.ConnectionID, sess packetHandler) { entry, ok := h.queue[string(connID)] if !ok { - return nil + return } - p := entry.packets[0] - entry.packets = entry.packets[1:] - if len(entry.packets) == 0 { - entry.timer.Stop() - delete(h.queue, string(connID)) + entry.timer.Stop() + for _, p := range entry.packets { + sess.handlePacket(p) } - return p + delete(h.queue, string(connID)) } func (h *zeroRTTQueue) deleteQueue(connID protocol.ConnectionID) { diff --git a/zero_rtt_queue_test.go b/zero_rtt_queue_test.go index 62048d9b8..b4d379967 100644 --- a/zero_rtt_queue_test.go +++ b/zero_rtt_queue_test.go @@ -4,6 +4,7 @@ import ( "encoding/binary" "time" + "github.com/golang/mock/gomock" "github.com/lucas-clemente/quic-go/internal/protocol" . "github.com/onsi/ginkgo" @@ -14,6 +15,16 @@ var _ = Describe("0-RTT queue", func() { var q *zeroRTTQueue queueDuration := scaleDuration(50 * time.Millisecond) + getQueuedPackets := func(connID protocol.ConnectionID) []*receivedPacket { + var packets []*receivedPacket + sess := NewMockPacketHandler(mockCtrl) + sess.EXPECT().handlePacket(gomock.Any()).Do(func(p *receivedPacket) { + packets = append(packets, p) + }).AnyTimes() + q.DequeueToSession(connID, sess) + return packets + } + BeforeEach(func() { q = newZeroRTTQueue() q.queueDuration = queueDuration @@ -23,14 +34,9 @@ var _ = Describe("0-RTT queue", func() { // dequeue all packets to make sure the timers are stopped q.mutex.Lock() for connID := range q.queue { - for { - q.mutex.Unlock() - p := q.Dequeue(protocol.ConnectionID(connID)) - q.mutex.Lock() - if p != nil { - break - } - } + sess := NewMockPacketHandler(mockCtrl) + sess.EXPECT().handlePacket(gomock.Any()).AnyTimes() + q.dequeueToSession(protocol.ConnectionID(connID), sess) } q.mutex.Unlock() }) @@ -39,32 +45,34 @@ var _ = Describe("0-RTT queue", func() { connID := protocol.ConnectionID{0xde, 0xad, 0xbe, 0xef} p := &receivedPacket{data: []byte("foobar")} q.Enqueue(connID, p) - Expect(q.Dequeue(connID)).To(Equal(p)) - Expect(q.Dequeue(connID)).To(BeNil()) + queuedPackets := getQueuedPackets(connID) + Expect(queuedPackets).To(Equal([]*receivedPacket{p})) }) - It("returns a nil packet for unknown connection IDs", func() { - Expect(q.Dequeue(protocol.ConnectionID{0x42})).To(BeNil()) + It("doesn't dequeue for unknown connection IDs", func() { + Expect(getQueuedPackets(protocol.ConnectionID{0x42})).To(BeEmpty()) }) It("only stores packets for Max0RTTQueues connection", func() { + getConnID := func(i int) protocol.ConnectionID { + connID := make([]byte, 4) + binary.BigEndian.PutUint32(connID, uint32(i)) + return connID + } + // fill up the queues for i := 0; i < protocol.Max0RTTQueues; i++ { - data := make([]byte, 4) - binary.BigEndian.PutUint32(data, uint32(i)) - q.Enqueue(protocol.ConnectionID(data), &receivedPacket{data: data}) + connID := getConnID(i) + q.Enqueue(connID, &receivedPacket{data: []byte(connID)}) } // now try to enqueue a packet for another connection ID connID := protocol.ConnectionID{0xde, 0xad, 0xbe, 0xef} q.Enqueue(connID, &receivedPacket{data: []byte("foobar")}) - Expect(q.Dequeue(connID)).To(BeNil()) // check that the other queues were all saved for i := 0; i < protocol.Max0RTTQueues; i++ { - connID := make([]byte, 4) - binary.BigEndian.PutUint32(connID, uint32(i)) - p := q.Dequeue(connID) - Expect(p).ToNot(BeNil()) - Expect(binary.BigEndian.Uint32(p.data)).To(BeEquivalentTo(i)) + queuedPackets := getQueuedPackets(getConnID(i)) + Expect(queuedPackets).To(HaveLen(1)) + Expect(binary.BigEndian.Uint32(queuedPackets[0].data)).To(BeEquivalentTo(i)) } }) @@ -78,12 +86,12 @@ var _ = Describe("0-RTT queue", func() { // now try to enqueue a packet for another connection ID connID := protocol.ConnectionID{0xde, 0xad, 0xbe, 0xef} q.Enqueue(connID, &receivedPacket{data: []byte("foobar")}) - Expect(q.Dequeue(connID)).To(BeNil()) + Expect(getQueuedPackets(connID)).To(BeEmpty()) // dequeue the packet from the first queue - Expect(q.Dequeue(protocol.ConnectionID{0, 0, 0, 0})).ToNot(BeNil()) + Expect(getQueuedPackets(protocol.ConnectionID{0, 0, 0, 0})).ToNot(BeNil()) // now it should be possible to queue another packet q.Enqueue(connID, &receivedPacket{data: []byte("foobar")}) - Expect(q.Dequeue(connID)).ToNot(BeNil()) + Expect(getQueuedPackets(connID)).ToNot(BeNil()) }) It("limits the number of packets it stores for one connection", func() { @@ -96,13 +104,11 @@ var _ = Describe("0-RTT queue", func() { } // The queue is full now. This packet will be dropped. q.Enqueue(connID, &receivedPacket{data: []byte("foobar")}) - for i := 0; i < protocol.Max0RTTQueueLen; i++ { - p := q.Dequeue(connID) - Expect(p).ToNot(BeNil()) + queuedPackets := getQueuedPackets(connID) + Expect(queuedPackets).To(HaveLen(protocol.Max0RTTQueueLen)) + for i, p := range queuedPackets { Expect(binary.BigEndian.Uint32(p.data)).To(BeEquivalentTo(i)) } - // The queue should now be empty. - Expect(q.Dequeue(connID)).To(BeNil()) }) It("deletes packets if they aren't dequeued after a short while", func() { @@ -110,6 +116,6 @@ var _ = Describe("0-RTT queue", func() { p := &receivedPacket{data: []byte("foobar"), buffer: getPacketBuffer()} q.Enqueue(connID, p) time.Sleep(queueDuration * 3 / 2) - Expect(q.Dequeue(connID)).To(BeNil()) + Expect(getQueuedPackets(connID)).To(BeNil()) }) })