From b21822ffc78f70063099bb5e2e88d732cb915e75 Mon Sep 17 00:00:00 2001 From: Marten Seemann Date: Fri, 28 Aug 2020 18:03:45 +0700 Subject: [PATCH] make it easier to test the timer in the zeroRTTQueue --- zero_rtt_queue.go | 14 ++++++++++---- zero_rtt_queue_test.go | 4 +++- 2 files changed, 13 insertions(+), 5 deletions(-) diff --git a/zero_rtt_queue.go b/zero_rtt_queue.go index f2d6eb4b..2fb1e75d 100644 --- a/zero_rtt_queue.go +++ b/zero_rtt_queue.go @@ -13,12 +13,16 @@ type zeroRTTQueueEntry struct { } type zeroRTTQueue struct { - mutex sync.Mutex - queue map[string]*zeroRTTQueueEntry + mutex sync.Mutex + queue map[string]*zeroRTTQueueEntry + queueDuration time.Duration // so we can set it in tests } func newZeroRTTQueue() *zeroRTTQueue { - return &zeroRTTQueue{queue: make(map[string]*zeroRTTQueueEntry)} + return &zeroRTTQueue{ + queue: make(map[string]*zeroRTTQueueEntry), + queueDuration: protocol.Max0RTTQueueingDuration, + } } func (h *zeroRTTQueue) Enqueue(connID protocol.ConnectionID, p *receivedPacket) { @@ -30,7 +34,9 @@ func (h *zeroRTTQueue) Enqueue(connID protocol.ConnectionID, p *receivedPacket) if len(h.queue) >= protocol.Max0RTTQueues { return } - h.queue[cid] = &zeroRTTQueueEntry{timer: time.AfterFunc(protocol.Max0RTTQueueingDuration, func() { h.deleteQueue(connID) })} + h.queue[cid] = &zeroRTTQueueEntry{timer: time.AfterFunc(h.queueDuration, func() { + h.deleteQueue(connID) + })} } entry := h.queue[cid] if len(entry.packets) >= protocol.Max0RTTQueueLen { diff --git a/zero_rtt_queue_test.go b/zero_rtt_queue_test.go index 4b0bb45c..898b980e 100644 --- a/zero_rtt_queue_test.go +++ b/zero_rtt_queue_test.go @@ -12,9 +12,11 @@ import ( var _ = Describe("0-RTT queue", func() { var q *zeroRTTQueue + queueDuration := scaleDuration(20 * time.Millisecond) BeforeEach(func() { q = newZeroRTTQueue() + q.queueDuration = queueDuration }) AfterEach(func() { @@ -107,7 +109,7 @@ var _ = Describe("0-RTT queue", func() { connID := protocol.ConnectionID{0xde, 0xad, 0xbe, 0xef} p := &receivedPacket{data: []byte("foobar"), buffer: getPacketBuffer()} q.Enqueue(connID, p) - time.Sleep(protocol.Max0RTTQueueingDuration * 3 / 2) + time.Sleep(queueDuration * 3 / 2) Expect(q.Dequeue(connID)).To(BeNil()) }) })