From daa8d08fba1d5e092193e5ddec64325aa37e7ba6 Mon Sep 17 00:00:00 2001 From: Marten Seemann Date: Sat, 30 Sep 2017 22:44:26 +0700 Subject: [PATCH] implement loss detection for handshake packets --- ackhandler/interfaces.go | 1 + ackhandler/sent_packet_handler.go | 47 +++++++++++++++++++++-- ackhandler/sent_packet_handler_test.go | 52 +++++++++++++++++++++++++- congestion/rtt_stats.go | 1 + session.go | 1 + session_test.go | 4 +- 6 files changed, 99 insertions(+), 7 deletions(-) diff --git a/ackhandler/interfaces.go b/ackhandler/interfaces.go index c7a9c56d..6b2810e8 100644 --- a/ackhandler/interfaces.go +++ b/ackhandler/interfaces.go @@ -12,6 +12,7 @@ type SentPacketHandler interface { // SentPacket may modify the packet SentPacket(packet *Packet) error ReceivedAck(ackFrame *wire.AckFrame, withPacketNumber protocol.PacketNumber, recvTime time.Time) error + SetHandshakeComplete() SendingAllowed() bool GetStopWaitingFrame(force bool) *wire.StopWaitingFrame diff --git a/ackhandler/sent_packet_handler.go b/ackhandler/sent_packet_handler.go index 5b7c9df0..c527995f 100644 --- a/ackhandler/sent_packet_handler.go +++ b/ackhandler/sent_packet_handler.go @@ -16,8 +16,13 @@ const ( // Maximum reordering in time space before time based loss detection considers a packet lost. // In fraction of an RTT. timeReorderingFraction = 1.0 / 8 + // The default RTT used before an RTT sample is taken. + // Note: This constant is also defined in the congestion package. + defaultInitialRTT = 100 * time.Millisecond // defaultRTOTimeout is the RTO time on new connections defaultRTOTimeout = 500 * time.Millisecond + // Minimum time in the future a tail loss probe alarm may be set for. + minTPLTimeout = 10 * time.Millisecond // Minimum time in the future an RTO alarm may be set for. minRTOTimeout = 200 * time.Millisecond // maxRTOTimeout is the maximum RTO time @@ -56,6 +61,10 @@ type sentPacketHandler struct { congestion congestion.SendAlgorithm rttStats *congestion.RTTStats + handshakeComplete bool + // The number of times the handshake packets have been retransmitted without receiving an ack. + handshakeCount uint32 + // The number of times an RTO has been sent without receiving an ack. rtoCount uint32 @@ -95,6 +104,10 @@ func (h *sentPacketHandler) ShouldSendRetransmittablePacket() bool { return h.numNonRetransmittablePackets >= protocol.MaxNonRetransmittablePackets } +func (h *sentPacketHandler) SetHandshakeComplete() { + h.handshakeComplete = true +} + func (h *sentPacketHandler) SentPacket(packet *Packet) error { if packet.PacketNumber <= h.lastSentPacketNumber { return errPacketNumberNotIncreasing @@ -247,9 +260,10 @@ func (h *sentPacketHandler) updateLossDetectionAlarm() { return } - // TODO(#496): Handle handshake packets separately // TODO(#497): TLP - if !h.lossTime.IsZero() { + if !h.handshakeComplete { + h.alarm = time.Now().Add(h.computeHandshakeTimeout()) + } else if !h.lossTime.IsZero() { // Early retransmit timer or time loss detection. h.alarm = h.lossTime } else { @@ -291,9 +305,10 @@ func (h *sentPacketHandler) detectLostPackets() { } func (h *sentPacketHandler) OnAlarm() { - // TODO(#496): Handle handshake packets separately // TODO(#497): TLP - if !h.lossTime.IsZero() { + if !h.handshakeComplete { + h.queueHandshakePacketsForRetransmission() + } else if !h.lossTime.IsZero() { // Early retransmit or time loss detection h.detectLostPackets() } else { @@ -312,6 +327,7 @@ func (h *sentPacketHandler) GetAlarmTimeout() time.Time { func (h *sentPacketHandler) onPacketAcked(packetElement *PacketElement) { h.bytesInFlight -= packetElement.Value.Length h.rtoCount = 0 + h.handshakeCount = 0 // TODO(#497): h.tlpCount = 0 h.packetHistory.Remove(packetElement) } @@ -372,6 +388,18 @@ func (h *sentPacketHandler) queueRTO(el *PacketElement) { h.congestion.OnRetransmissionTimeout(true) } +func (h *sentPacketHandler) queueHandshakePacketsForRetransmission() { + var handshakePackets []*PacketElement + for el := h.packetHistory.Front(); el != nil; el = el.Next() { + if el.Value.EncryptionLevel < protocol.EncryptionForwardSecure { + handshakePackets = append(handshakePackets, el) + } + } + for _, el := range handshakePackets { + h.queuePacketForRetransmission(el) + } +} + func (h *sentPacketHandler) queuePacketForRetransmission(packetElement *PacketElement) { packet := &packetElement.Value h.bytesInFlight -= packet.Length @@ -380,6 +408,17 @@ func (h *sentPacketHandler) queuePacketForRetransmission(packetElement *PacketEl h.stopWaitingManager.QueuedRetransmissionForPacketNumber(packet.PacketNumber) } +func (h *sentPacketHandler) computeHandshakeTimeout() time.Duration { + duration := 2 * h.rttStats.SmoothedRTT() + if duration == 0 { + duration = 2 * defaultInitialRTT + } + duration = utils.MaxDuration(duration, minTPLTimeout) + // exponential backoff + // There's an implicit limit to this set by the handshake timeout. + return duration << h.handshakeCount +} + func (h *sentPacketHandler) computeRTOTimeout() time.Duration { rto := h.congestion.RetransmissionDelay() if rto == 0 { diff --git a/ackhandler/sent_packet_handler_test.go b/ackhandler/sent_packet_handler_test.go index c9d79598..7cc87ac6 100644 --- a/ackhandler/sent_packet_handler_test.go +++ b/ackhandler/sent_packet_handler_test.go @@ -58,13 +58,27 @@ func (m *mockCongestion) OnPacketLost(n protocol.PacketNumber, l protocol.ByteCo } func retransmittablePacket(num protocol.PacketNumber) *Packet { - return &Packet{PacketNumber: num, Length: 1, Frames: []wire.Frame{&wire.PingFrame{}}} + return &Packet{ + PacketNumber: num, + Length: 1, + Frames: []wire.Frame{&wire.PingFrame{}}, + EncryptionLevel: protocol.EncryptionForwardSecure, + } } func nonRetransmittablePacket(num protocol.PacketNumber) *Packet { return &Packet{PacketNumber: num, Length: 1, Frames: []wire.Frame{&wire.AckFrame{}}} } +func handshakePacket(num protocol.PacketNumber) *Packet { + return &Packet{ + PacketNumber: num, + Length: 1, + Frames: []wire.Frame{&wire.PingFrame{}}, + EncryptionLevel: protocol.EncryptionUnencrypted, + } +} + var _ = Describe("SentPacketHandler", func() { var ( handler *sentPacketHandler @@ -74,6 +88,7 @@ var _ = Describe("SentPacketHandler", func() { BeforeEach(func() { rttStats := &congestion.RTTStats{} handler = NewSentPacketHandler(rttStats).(*sentPacketHandler) + handler.SetHandshakeComplete() streamFrame = wire.StreamFrame{ StreamID: 5, Data: []byte{0x13, 0x37}, @@ -807,6 +822,41 @@ var _ = Describe("SentPacketHandler", func() { handler.OnAlarm() Expect(handler.DequeuePacketForRetransmission()).ToNot(BeNil()) Expect(handler.DequeuePacketForRetransmission()).ToNot(BeNil()) + Expect(handler.DequeuePacketForRetransmission()).To(BeNil()) + }) + }) + + Context("retransmission for handshake packets", func() { + BeforeEach(func() { + handler.handshakeComplete = false + }) + + It("detects the handshake timeout", func() { + // send handshake packets: 1, 2, 4 + // send a forward-secure packet: 3 + err := handler.SentPacket(handshakePacket(1)) + Expect(err).ToNot(HaveOccurred()) + err = handler.SentPacket(handshakePacket(2)) + Expect(err).ToNot(HaveOccurred()) + err = handler.SentPacket(retransmittablePacket(3)) + Expect(err).ToNot(HaveOccurred()) + err = handler.SentPacket(handshakePacket(4)) + Expect(err).ToNot(HaveOccurred()) + + err = handler.ReceivedAck(&wire.AckFrame{LargestAcked: 1, LowestAcked: 1}, 1, time.Now().Add(time.Hour)) + Expect(err).NotTo(HaveOccurred()) + Expect(handler.lossTime.IsZero()).To(BeTrue()) + Expect(handler.GetAlarmTimeout().Sub(time.Now())).To(BeNumerically("~", handler.computeHandshakeTimeout(), time.Minute)) + + handler.OnAlarm() + p := handler.DequeuePacketForRetransmission() + Expect(p).ToNot(BeNil()) + Expect(p.PacketNumber).To(Equal(protocol.PacketNumber(2))) + p = handler.DequeuePacketForRetransmission() + Expect(p).ToNot(BeNil()) + Expect(p.PacketNumber).To(Equal(protocol.PacketNumber(4))) + Expect(handler.packetHistory.Len()).To(Equal(1)) + Expect(handler.packetHistory.Front().Value.PacketNumber).To(Equal(protocol.PacketNumber(3))) }) }) diff --git a/congestion/rtt_stats.go b/congestion/rtt_stats.go index 546c1cb9..624957ce 100644 --- a/congestion/rtt_stats.go +++ b/congestion/rtt_stats.go @@ -7,6 +7,7 @@ import ( ) const ( + // Note: This constant is also defined in the ackhandler package. initialRTTus = 100 * 1000 rttAlpha float32 = 0.125 oneMinusAlpha float32 = (1 - rttAlpha) diff --git a/session.go b/session.go index 1f91e3bd..7169265b 100644 --- a/session.go +++ b/session.go @@ -316,6 +316,7 @@ runLoop: if !ok { // the aeadChanged chan was closed. This means that the handshake is completed. s.handshakeComplete = true aeadChanged = nil // prevent this case from ever being selected again + s.sentPacketHandler.SetHandshakeComplete() close(s.handshakeChan) close(s.handshakeCompleteChan) } else { diff --git a/session_test.go b/session_test.go index 1ec63e3c..eb326103 100644 --- a/session_test.go +++ b/session_test.go @@ -82,11 +82,10 @@ func (h *mockSentPacketHandler) SentPacket(packet *ackhandler.Packet) error { h.sentPackets = append(h.sentPackets, packet) return nil } - func (h *mockSentPacketHandler) ReceivedAck(ackFrame *wire.AckFrame, withPacketNumber protocol.PacketNumber, recvTime time.Time) error { return nil } - +func (h *mockSentPacketHandler) SetHandshakeComplete() {} func (h *mockSentPacketHandler) GetLeastUnacked() protocol.PacketNumber { return 1 } func (h *mockSentPacketHandler) GetAlarmTimeout() time.Time { panic("not implemented") } func (h *mockSentPacketHandler) OnAlarm() { panic("not implemented") } @@ -1170,6 +1169,7 @@ var _ = Describe("Session", func() { }) It("retransmits RTO packets", func() { + sess.sentPacketHandler.SetHandshakeComplete() n := protocol.PacketNumber(10) sess.packer.cryptoSetup = &mockCryptoSetup{encLevelSeal: protocol.EncryptionForwardSecure} // We simulate consistently low RTTs, so that the test works faster