From 4822def788bacd5eb3ad208543089281bf1b2f30 Mon Sep 17 00:00:00 2001 From: Lucas Clemente Date: Wed, 18 May 2016 23:21:33 +0200 Subject: [PATCH] implement timeout-based retransmission (RTO) fixes #56 --- ackhandler/interfaces.go | 4 + ackhandler/sent_packet_handler.go | 32 ++++++++ ackhandler/sent_packet_handler_test.go | 104 +++++++++++++++++++++++-- packet_packer_test.go | 2 + session.go | 10 ++- session_test.go | 26 +++++++ 6 files changed, 170 insertions(+), 8 deletions(-) diff --git a/ackhandler/interfaces.go b/ackhandler/interfaces.go index b7715979..f9e5e6f9 100644 --- a/ackhandler/interfaces.go +++ b/ackhandler/interfaces.go @@ -1,6 +1,8 @@ package ackhandler import ( + "time" + "github.com/lucas-clemente/quic-go/frames" "github.com/lucas-clemente/quic-go/protocol" ) @@ -17,6 +19,8 @@ type SentPacketHandler interface { GetLargestObserved() protocol.PacketNumber AllowsSending() bool + + TimeToFirstRTO() time.Duration } // ReceivedPacketHandler handles ACKs needed to send for incoming packets diff --git a/ackhandler/sent_packet_handler.go b/ackhandler/sent_packet_handler.go index c3ea7e62..3f26fd8d 100644 --- a/ackhandler/sent_packet_handler.go +++ b/ackhandler/sent_packet_handler.go @@ -251,6 +251,8 @@ func (h *sentPacketHandler) ReceivedAck(ackFrame *frames.AckFrame) error { } func (h *sentPacketHandler) HasPacketForRetransmission() bool { + h.queuePacketsRTO() + if len(h.retransmissionQueue) > 0 { return true } @@ -288,3 +290,33 @@ func (h *sentPacketHandler) getRTO() time.Duration { } return utils.MaxDuration(rto, protocol.MinRetransmissionTime) } + +func (h *sentPacketHandler) queuePacketsRTO() { + queued := false + now := time.Now() + for _, p := range h.packetHistory { + if p == nil || p.Retransmitted || p.rtoTime.After(now) { + continue + } + h.queuePacketForRetransmission(p) + queued = true + } + if queued { + h.congestion.OnRetransmissionTimeout(true) + } +} + +func (h *sentPacketHandler) TimeToFirstRTO() time.Duration { + now := time.Now() + min := utils.InfDuration + for _, p := range h.packetHistory { + if p == nil || p.Retransmitted { + continue + } + if now.After(p.rtoTime) { + return 0 + } + min = utils.MinDuration(min, p.rtoTime.Sub(now)) + } + return min +} diff --git a/ackhandler/sent_packet_handler_test.go b/ackhandler/sent_packet_handler_test.go index c006bdf2..8703d9fc 100644 --- a/ackhandler/sent_packet_handler_test.go +++ b/ackhandler/sent_packet_handler_test.go @@ -12,9 +12,10 @@ import ( ) type mockCongestion struct { - nCalls int - argsOnPacketSent []interface{} - argsOnCongestionEvent []interface{} + nCalls int + argsOnPacketSent []interface{} + argsOnCongestionEvent []interface{} + onRetransmissionTimeout bool } func (m *mockCongestion) TimeUntilSend(now time.Time, bytesInFlight protocol.ByteCount) time.Duration { @@ -37,15 +38,19 @@ func (m *mockCongestion) OnCongestionEvent(rttUpdated bool, bytesInFlight protoc m.argsOnCongestionEvent = []interface{}{rttUpdated, bytesInFlight, ackedPackets, lostPackets} } -func (m *mockCongestion) SetNumEmulatedConnections(n int) { panic("not implemented") } -func (m *mockCongestion) OnRetransmissionTimeout(packetsRetransmitted bool) { panic("not implemented") } -func (m *mockCongestion) OnConnectionMigration() { panic("not implemented") } -func (m *mockCongestion) SetSlowStartLargeReduction(enabled bool) { panic("not implemented") } +func (m *mockCongestion) OnRetransmissionTimeout(packetsRetransmitted bool) { + m.nCalls++ + m.onRetransmissionTimeout = true +} func (m *mockCongestion) RetransmissionDelay() time.Duration { return protocol.DefaultRetransmissionTime } +func (m *mockCongestion) SetNumEmulatedConnections(n int) { panic("not implemented") } +func (m *mockCongestion) OnConnectionMigration() { panic("not implemented") } +func (m *mockCongestion) SetSlowStartLargeReduction(enabled bool) { panic("not implemented") } + var _ = Describe("SentPacketHandler", func() { var ( handler *sentPacketHandler @@ -556,6 +561,14 @@ var _ = Describe("SentPacketHandler", func() { Expect(err).NotTo(HaveOccurred()) Expect(handler.AllowsSending()).To(BeFalse()) }) + + It("should call OnRetransmissionTimeout", func() { + err := handler.SentPacket(&Packet{PacketNumber: 1, Frames: []frames.Frame{}, Length: 1}) + Expect(err).NotTo(HaveOccurred()) + handler.packetHistory[1].rtoTime = time.Now().Add(-time.Second) + handler.queuePacketsRTO() + Expect(cong.onRetransmissionTimeout).To(BeTrue()) + }) }) Context("calculating RTO", func() { @@ -583,4 +596,81 @@ var _ = Describe("SentPacketHandler", func() { Expect(utils.AbsDuration(expected.Sub(val))).To(BeNumerically("<", time.Millisecond)) }) }) + + Context("RTO retransmission", func() { + Context("calculating the time to first RTO", func() { + It("defaults to inf", func() { + Expect(handler.TimeToFirstRTO()).To(Equal(utils.InfDuration)) + }) + + It("returns time to RTO", func() { + err := handler.SentPacket(&Packet{PacketNumber: 1, Frames: []frames.Frame{}, Length: 1}) + Expect(err).NotTo(HaveOccurred()) + Expect(handler.TimeToFirstRTO()).To(BeNumerically("~", protocol.DefaultRetransmissionTime, time.Millisecond)) + }) + + It("returns 0 when RTOs are required", func() { + err := handler.SentPacket(&Packet{PacketNumber: 1, Frames: []frames.Frame{}, Length: 1}) + Expect(err).NotTo(HaveOccurred()) + handler.packetHistory[1].rtoTime = time.Now().Add(-time.Second) + Expect(handler.TimeToFirstRTO()).To(BeZero()) + }) + + It("ignores nil packets", func() { + handler.packetHistory[1] = nil + handler.queuePacketsRTO() + Expect(handler.TimeToFirstRTO()).To(Equal(utils.InfDuration)) + }) + }) + + Context("queuing packets due to RTO", func() { + It("does nothing if not required", func() { + err := handler.SentPacket(&Packet{PacketNumber: 1, Frames: []frames.Frame{}, Length: 1}) + Expect(err).NotTo(HaveOccurred()) + handler.queuePacketsRTO() + Expect(handler.retransmissionQueue).To(BeEmpty()) + }) + + It("queues a packet if RTO expired", func() { + p := &Packet{PacketNumber: 1, Frames: []frames.Frame{}, Length: 1} + err := handler.SentPacket(p) + Expect(err).NotTo(HaveOccurred()) + handler.packetHistory[1].rtoTime = time.Now().Add(-time.Second) + handler.queuePacketsRTO() + Expect(handler.retransmissionQueue).To(HaveLen(1)) + Expect(handler.retransmissionQueue[0]).To(Equal(p)) + }) + + It("does not queue retransmittedpackets", func() { + p := &Packet{PacketNumber: 1, Frames: []frames.Frame{}, Length: 1, Retransmitted: true} + err := handler.SentPacket(p) + Expect(err).NotTo(HaveOccurred()) + handler.packetHistory[1].rtoTime = time.Now().Add(-time.Second) + handler.queuePacketsRTO() + Expect(handler.retransmissionQueue).To(BeEmpty()) + }) + + It("ignores nil packets", func() { + handler.packetHistory[1] = nil + handler.queuePacketsRTO() + Expect(handler.retransmissionQueue).To(BeEmpty()) + }) + }) + + It("works with HasPacketForRetransmission", func() { + p := &Packet{PacketNumber: 1, Frames: []frames.Frame{}, Length: 1} + err := handler.SentPacket(p) + Expect(err).NotTo(HaveOccurred()) + handler.packetHistory[1].rtoTime = time.Now().Add(-time.Second) + Expect(handler.HasPacketForRetransmission()).To(BeTrue()) + }) + + It("works with DequeuePacketForRetransmission", func() { + p := &Packet{PacketNumber: 1, Frames: []frames.Frame{}, Length: 1} + err := handler.SentPacket(p) + Expect(err).NotTo(HaveOccurred()) + handler.packetHistory[1].rtoTime = time.Now().Add(-time.Second) + Expect(handler.DequeuePacketForRetransmission()).To(Equal(p)) + }) + }) }) diff --git a/packet_packer_test.go b/packet_packer_test.go index 798bbf2b..7171e72a 100644 --- a/packet_packer_test.go +++ b/packet_packer_test.go @@ -2,6 +2,7 @@ package quic import ( "bytes" + "time" "github.com/lucas-clemente/quic-go/ackhandler" "github.com/lucas-clemente/quic-go/crypto" @@ -21,6 +22,7 @@ func (h *mockSentPacketHandler) HasPacketForRetransmission() bool func (h *mockSentPacketHandler) BytesInFlight() protocol.ByteCount { return 0 } func (h *mockSentPacketHandler) GetLargestObserved() protocol.PacketNumber { return 1 } func (h *mockSentPacketHandler) AllowsSending() bool { panic("not implemented") } +func (h *mockSentPacketHandler) TimeToFirstRTO() time.Duration { panic("not implemented") } func newMockSentPacketHandler() ackhandler.SentPacketHandler { return &mockSentPacketHandler{} diff --git a/session.go b/session.go index 40bfd008..3ace6ede 100644 --- a/session.go +++ b/session.go @@ -136,12 +136,18 @@ func (s *Session) run() { default: } - // receive at a nil channel blocks forever + // Note: receive at a nil channel blocks forever + var smallPacketSendTimer <-chan time.Time if !s.smallPacketDelayedOccurranceTime.IsZero() { smallPacketSendTimer = time.After(time.Now().Sub(s.smallPacketDelayedOccurranceTime)) } + var rtoTimer <-chan time.Time + if d := s.sentPacketHandler.TimeToFirstRTO(); d != utils.InfDuration { + rtoTimer = time.After(d) + } + var err error select { case <-s.closeChan: @@ -157,6 +163,8 @@ func (s *Session) run() { err = s.maybeSendPacket() case <-smallPacketSendTimer: err = s.sendPacket() + case <-rtoTimer: + err = s.sendPacket() case <-s.aeadChanged: s.tryDecryptingQueuedPackets() case <-time.After(s.connectionParametersManager.GetIdleConnectionStateLifetime()): diff --git a/session_test.go b/session_test.go index 43166a7e..c4543585 100644 --- a/session_test.go +++ b/session_test.go @@ -564,4 +564,30 @@ var _ = Describe("Session", func() { } close(done) }, 0.5) + + It("retransmits RTO packets", func() { + // We simulate consistently low RTTs, so that the test works faster + n := protocol.PacketNumber(10) + for p := protocol.PacketNumber(1); p < n; p++ { + err := session.sentPacketHandler.SentPacket(&ackhandler.Packet{PacketNumber: p, Length: 1}) + Expect(err).NotTo(HaveOccurred()) + time.Sleep(time.Microsecond) + err = session.sentPacketHandler.ReceivedAck(&frames.AckFrame{LargestObserved: p}) + Expect(err).NotTo(HaveOccurred()) + } + // Now, we send a single packet, and expect that it was retransmitted later + go session.run() + Expect(conn.written).To(BeEmpty()) + err := session.sentPacketHandler.SentPacket(&ackhandler.Packet{ + PacketNumber: n, + Length: 1, + Frames: []frames.Frame{&frames.StreamFrame{ + Data: bytes.Repeat([]byte{'a'}, int(protocol.SmallPacketPayloadSizeThreshold)+1), + }}, + }) + session.packer.lastPacketNumber = n + Expect(err).NotTo(HaveOccurred()) + session.scheduleSending() + Eventually(func() bool { return len(conn.written) > 0 }).Should(BeTrue()) + }) })