From fbeba1f73a7c8484e761d82708495fc3fc6c1638 Mon Sep 17 00:00:00 2001 From: Marten Seemann Date: Thu, 1 Feb 2018 16:24:15 +0800 Subject: [PATCH] make sure that at least every 20th ACK is retransmittable This is important because we need to make sure that we can remove old ACK ranges from the received packet history. The logic we implemented before was not correct, since we only made sure that every 20th packet would be retransmittable, but we didn't have any guarantee that this packet would also contain an ACK frame. --- ackhandler/interfaces.go | 1 - ackhandler/sent_packet_handler.go | 9 -- ackhandler/sent_packet_handler_test.go | 24 ----- .../mocks/ackhandler/sent_packet_handler.go | 12 --- internal/protocol/server_parameters.go | 4 +- packet_packer.go | 34 +++---- packet_packer_test.go | 99 ++++++++++--------- session.go | 4 - session_test.go | 27 ----- 9 files changed, 73 insertions(+), 141 deletions(-) diff --git a/ackhandler/interfaces.go b/ackhandler/interfaces.go index 248eb76e6..09b2c0172 100644 --- a/ackhandler/interfaces.go +++ b/ackhandler/interfaces.go @@ -31,7 +31,6 @@ type SentPacketHandler interface { GetStopWaitingFrame(force bool) *wire.StopWaitingFrame GetLowestPacketNotConfirmedAcked() protocol.PacketNumber - ShouldSendRetransmittablePacket() bool DequeuePacketForRetransmission() (packet *Packet) GetLeastUnacked() protocol.PacketNumber diff --git a/ackhandler/sent_packet_handler.go b/ackhandler/sent_packet_handler.go index 6c0fd0195..a4cd7f34d 100644 --- a/ackhandler/sent_packet_handler.go +++ b/ackhandler/sent_packet_handler.go @@ -38,8 +38,6 @@ type sentPacketHandler struct { nextPacketSendTime time.Time skippedPackets []protocol.PacketNumber - numNonRetransmittablePackets int // number of non-retransmittable packets since the last retransmittable packet - largestAcked protocol.PacketNumber largestReceivedPacketWithAck protocol.PacketNumber // lowestPacketNotConfirmedAcked is the lowest packet number that we sent an ACK for, but haven't received confirmation, that this ACK actually arrived @@ -96,10 +94,6 @@ func (h *sentPacketHandler) lowestUnacked() protocol.PacketNumber { return h.largestAcked + 1 } -func (h *sentPacketHandler) ShouldSendRetransmittablePacket() bool { - return h.numNonRetransmittablePackets >= protocol.MaxNonRetransmittablePackets -} - func (h *sentPacketHandler) SetHandshakeComplete() { var queue []*Packet for _, packet := range h.retransmissionQueue { @@ -142,9 +136,6 @@ func (h *sentPacketHandler) SentPacket(packet *Packet) error { packet.largestAcked = largestAcked h.bytesInFlight += packet.Length h.packetHistory.PushBack(*packet) - h.numNonRetransmittablePackets = 0 - } else { - h.numNonRetransmittablePackets++ } h.congestion.OnPacketSent( diff --git a/ackhandler/sent_packet_handler_test.go b/ackhandler/sent_packet_handler_test.go index 6601d38df..a2fd0f651 100644 --- a/ackhandler/sent_packet_handler_test.go +++ b/ackhandler/sent_packet_handler_test.go @@ -186,30 +186,6 @@ var _ = Describe("SentPacketHandler", func() { }) }) - Context("forcing retransmittable packets", func() { - It("says that every 20th packet should be retransmittable", func() { - // send 19 non-retransmittable packets - for i := 1; i <= protocol.MaxNonRetransmittablePackets; i++ { - Expect(handler.ShouldSendRetransmittablePacket()).To(BeFalse()) - err := handler.SentPacket(nonRetransmittablePacket(protocol.PacketNumber(i))) - Expect(err).ToNot(HaveOccurred()) - } - Expect(handler.ShouldSendRetransmittablePacket()).To(BeTrue()) - }) - - It("resets the counter when a retransmittable packet is sent", func() { - // send 19 non-retransmittable packets - for i := 1; i <= protocol.MaxNonRetransmittablePackets; i++ { - Expect(handler.ShouldSendRetransmittablePacket()).To(BeFalse()) - err := handler.SentPacket(nonRetransmittablePacket(protocol.PacketNumber(i))) - Expect(err).ToNot(HaveOccurred()) - } - err := handler.SentPacket(retransmittablePacket(20)) - Expect(err).ToNot(HaveOccurred()) - Expect(handler.ShouldSendRetransmittablePacket()).To(BeFalse()) - }) - }) - Context("DoS mitigation", func() { It("checks the size of the packet history, for unacked packets", func() { i := protocol.PacketNumber(1) diff --git a/internal/mocks/ackhandler/sent_packet_handler.go b/internal/mocks/ackhandler/sent_packet_handler.go index bad045ed9..54ff26f10 100644 --- a/internal/mocks/ackhandler/sent_packet_handler.go +++ b/internal/mocks/ackhandler/sent_packet_handler.go @@ -165,18 +165,6 @@ func (mr *MockSentPacketHandlerMockRecorder) ShouldSendNumPackets() *gomock.Call return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ShouldSendNumPackets", reflect.TypeOf((*MockSentPacketHandler)(nil).ShouldSendNumPackets)) } -// ShouldSendRetransmittablePacket mocks base method -func (m *MockSentPacketHandler) ShouldSendRetransmittablePacket() bool { - ret := m.ctrl.Call(m, "ShouldSendRetransmittablePacket") - ret0, _ := ret[0].(bool) - return ret0 -} - -// ShouldSendRetransmittablePacket indicates an expected call of ShouldSendRetransmittablePacket -func (mr *MockSentPacketHandlerMockRecorder) ShouldSendRetransmittablePacket() *gomock.Call { - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ShouldSendRetransmittablePacket", reflect.TypeOf((*MockSentPacketHandler)(nil).ShouldSendRetransmittablePacket)) -} - // TimeUntilSend mocks base method func (m *MockSentPacketHandler) TimeUntilSend() time.Time { ret := m.ctrl.Call(m, "TimeUntilSend") diff --git a/internal/protocol/server_parameters.go b/internal/protocol/server_parameters.go index e0aad6cff..61e5a2dfc 100644 --- a/internal/protocol/server_parameters.go +++ b/internal/protocol/server_parameters.go @@ -90,8 +90,8 @@ const MaxTrackedSentPackets = 2 * DefaultMaxCongestionWindow // MaxTrackedReceivedAckRanges is the maximum number of ACK ranges tracked const MaxTrackedReceivedAckRanges = DefaultMaxCongestionWindow -// MaxNonRetransmittablePackets is the maximum number of non-retransmittable packets that we send in a row -const MaxNonRetransmittablePackets = 19 +// MaxNonRetransmittableAcks is the maximum number of packets containing an ACK, but no retransmittable frames, that we send in a row +const MaxNonRetransmittableAcks = 19 // RetransmittablePacketsBeforeAck is the number of retransmittable that an ACK is sent for const RetransmittablePacketsBeforeAck = 10 diff --git a/packet_packer.go b/packet_packer.go index 74e46e35c..ead37404c 100644 --- a/packet_packer.go +++ b/packet_packer.go @@ -37,12 +37,12 @@ type packetPacker struct { controlFrameMutex sync.Mutex controlFrames []wire.Frame - stopWaiting *wire.StopWaitingFrame - ackFrame *wire.AckFrame - leastUnacked protocol.PacketNumber - omitConnectionID bool - hasSentPacket bool // has the packetPacker already sent a packet - makeNextPacketRetransmittable bool + stopWaiting *wire.StopWaitingFrame + ackFrame *wire.AckFrame + leastUnacked protocol.PacketNumber + omitConnectionID bool + hasSentPacket bool // has the packetPacker already sent a packet + numNonRetransmittableAcks int } func newPacketPacker(connectionID protocol.ConnectionID, @@ -169,14 +169,18 @@ func (p *packetPacker) PackPacket() (*packedPacket, error) { if len(payloadFrames) == 1 && p.stopWaiting != nil { return nil, nil } - // check if this packet only contains an ACK and / or STOP_WAITING - if !ackhandler.HasRetransmittableFrames(payloadFrames) { - if p.makeNextPacketRetransmittable { - payloadFrames = append(payloadFrames, &wire.PingFrame{}) - p.makeNextPacketRetransmittable = false + if p.ackFrame != nil { + // check if this packet only contains an ACK (and maybe a STOP_WAITING) + if len(payloadFrames) == 1 || (p.stopWaiting != nil && len(payloadFrames) == 2) { + if p.numNonRetransmittableAcks >= protocol.MaxNonRetransmittableAcks { + payloadFrames = append(payloadFrames, &wire.PingFrame{}) + p.numNonRetransmittableAcks = 0 + } else { + p.numNonRetransmittableAcks++ + } + } else { + p.numNonRetransmittableAcks = 0 } - } else { // this packet already contains a retransmittable frame. No need to send a PING - p.makeNextPacketRetransmittable = false } p.stopWaiting = nil p.ackFrame = nil @@ -392,7 +396,3 @@ func (p *packetPacker) SetLeastUnacked(leastUnacked protocol.PacketNumber) { func (p *packetPacker) SetOmitConnectionID() { p.omitConnectionID = true } - -func (p *packetPacker) MakeNextPacketRetransmittable() { - p.makeNextPacketRetransmittable = true -} diff --git a/packet_packer_test.go b/packet_packer_test.go index 936d25546..bc0d1b587 100644 --- a/packet_packer_test.go +++ b/packet_packer_test.go @@ -378,53 +378,62 @@ var _ = Describe("Packet packer", func() { Expect(packer.packetNumberGenerator.Peek()).To(Equal(protocol.PacketNumber(2))) }) - It("adds a PING frame when it's supposed to send a retransmittable packet", func() { - mockStreamFramer.EXPECT().HasCryptoStreamData().Times(2) - mockStreamFramer.EXPECT().PopStreamFrames(gomock.Any()).Times(2) - packer.QueueControlFrame(&wire.AckFrame{}) - packer.QueueControlFrame(&wire.StopWaitingFrame{}) - packer.MakeNextPacketRetransmittable() - p, err := packer.PackPacket() - Expect(p).ToNot(BeNil()) - Expect(err).ToNot(HaveOccurred()) - Expect(p.frames).To(HaveLen(3)) - Expect(p.frames).To(ContainElement(&wire.PingFrame{})) - // make sure the next packet doesn't contain another PING - packer.QueueControlFrame(&wire.AckFrame{}) - p, err = packer.PackPacket() - Expect(p).ToNot(BeNil()) - Expect(err).ToNot(HaveOccurred()) - Expect(p.frames).To(HaveLen(1)) - }) + Context("making ACK packets retransmittable", func() { + sendMaxNumNonRetransmittableAcks := func() { + mockStreamFramer.EXPECT().HasCryptoStreamData().Times(protocol.MaxNonRetransmittableAcks) + mockStreamFramer.EXPECT().PopStreamFrames(gomock.Any()).Times(protocol.MaxNonRetransmittableAcks) + for i := 0; i < protocol.MaxNonRetransmittableAcks; i++ { + packer.QueueControlFrame(&wire.AckFrame{}) + p, err := packer.PackPacket() + Expect(p).ToNot(BeNil()) + Expect(err).ToNot(HaveOccurred()) + Expect(p.frames).To(HaveLen(1)) + } + } - It("waits until there's something to send before adding a PING frame", func() { - mockStreamFramer.EXPECT().HasCryptoStreamData().Times(2) - mockStreamFramer.EXPECT().PopStreamFrames(gomock.Any()).Times(2) - packer.MakeNextPacketRetransmittable() - p, err := packer.PackPacket() - Expect(err).ToNot(HaveOccurred()) - Expect(p).To(BeNil()) - packer.QueueControlFrame(&wire.AckFrame{}) - p, err = packer.PackPacket() - Expect(err).ToNot(HaveOccurred()) - Expect(p.frames).To(HaveLen(2)) - Expect(p.frames).To(ContainElement(&wire.PingFrame{})) - }) + It("adds a PING frame when it's supposed to send a retransmittable packet", func() { + sendMaxNumNonRetransmittableAcks() + mockStreamFramer.EXPECT().HasCryptoStreamData().Times(2) + mockStreamFramer.EXPECT().PopStreamFrames(gomock.Any()).Times(2) + packer.QueueControlFrame(&wire.AckFrame{}) + p, err := packer.PackPacket() + Expect(p).ToNot(BeNil()) + Expect(err).ToNot(HaveOccurred()) + Expect(p.frames).To(ContainElement(&wire.PingFrame{})) + // make sure the next packet doesn't contain another PING + packer.QueueControlFrame(&wire.AckFrame{}) + p, err = packer.PackPacket() + Expect(p).ToNot(BeNil()) + Expect(err).ToNot(HaveOccurred()) + Expect(p.frames).To(HaveLen(1)) + }) - It("doesn't send a PING if it already sent another retransmittable frame", func() { - mockStreamFramer.EXPECT().HasCryptoStreamData().Times(2) - mockStreamFramer.EXPECT().PopStreamFrames(gomock.Any()).Times(2) - packer.MakeNextPacketRetransmittable() - packer.QueueControlFrame(&wire.MaxDataFrame{}) - p, err := packer.PackPacket() - Expect(p).ToNot(BeNil()) - Expect(err).ToNot(HaveOccurred()) - Expect(p.frames).To(HaveLen(1)) - packer.QueueControlFrame(&wire.AckFrame{}) - p, err = packer.PackPacket() - Expect(p).ToNot(BeNil()) - Expect(err).ToNot(HaveOccurred()) - Expect(p.frames).To(HaveLen(1)) + It("waits until there's something to send before adding a PING frame", func() { + sendMaxNumNonRetransmittableAcks() + mockStreamFramer.EXPECT().HasCryptoStreamData().Times(2) + mockStreamFramer.EXPECT().PopStreamFrames(gomock.Any()).Times(2) + p, err := packer.PackPacket() + Expect(err).ToNot(HaveOccurred()) + Expect(p).To(BeNil()) + packer.QueueControlFrame(&wire.AckFrame{}) + p, err = packer.PackPacket() + Expect(err).ToNot(HaveOccurred()) + Expect(p.frames).To(HaveLen(2)) + Expect(p.frames).To(ContainElement(&wire.PingFrame{})) + }) + + It("doesn't send a PING if it already sent another retransmittable frame", func() { + sendMaxNumNonRetransmittableAcks() + mockStreamFramer.EXPECT().HasCryptoStreamData() + mockStreamFramer.EXPECT().PopStreamFrames(gomock.Any()) + packer.QueueControlFrame(&wire.MaxDataFrame{}) + packer.QueueControlFrame(&wire.AckFrame{}) + p, err := packer.PackPacket() + Expect(p).ToNot(BeNil()) + Expect(err).ToNot(HaveOccurred()) + Expect(p.frames).To(HaveLen(2)) + Expect(p.frames).ToNot(ContainElement(&wire.PingFrame{})) + }) }) Context("STREAM frame handling", func() { diff --git a/session.go b/session.go index 2127b2954..f3a81dc0a 100644 --- a/session.go +++ b/session.go @@ -841,10 +841,6 @@ func (s *session) sendPacket() (bool, error) { s.packer.QueueControlFrame(swf) } } - // add a retransmittable frame - if s.sentPacketHandler.ShouldSendRetransmittablePacket() { - s.packer.MakeNextPacketRetransmittable() - } packet, err := s.packer.PackPacket() if err != nil || packet == nil { return false, err diff --git a/session_test.go b/session_test.go index 5030ca89a..67670a6a1 100644 --- a/session_test.go +++ b/session_test.go @@ -628,24 +628,6 @@ var _ = Describe("Session", func() { Expect(mconn.written).To(Receive(ContainSubstring(string([]byte{0x03, 0x5e})))) }) - It("sends a retransmittable packet when required by the SentPacketHandler", func() { - ack := &wire.AckFrame{LargestAcked: 1000} - sess.packer.QueueControlFrame(ack) - sph := mockackhandler.NewMockSentPacketHandler(mockCtrl) - sph.EXPECT().GetLeastUnacked().AnyTimes() - sph.EXPECT().DequeuePacketForRetransmission() - sph.EXPECT().ShouldSendRetransmittablePacket().Return(true) - sph.EXPECT().SentPacket(gomock.Any()).Do(func(p *ackhandler.Packet) { - Expect(p.Frames).To(HaveLen(2)) - Expect(p.Frames).To(ContainElement(ack)) - }) - sess.sentPacketHandler = sph - sent, err := sess.sendPacket() - Expect(err).NotTo(HaveOccurred()) - Expect(sent).To(BeTrue()) - Expect(mconn.written).To(HaveLen(1)) - }) - It("adds a MAX_DATA frames", func() { fc := mocks.NewMockConnectionFlowController(mockCtrl) fc.EXPECT().GetWindowUpdate().Return(protocol.ByteCount(0x1337)) @@ -654,7 +636,6 @@ var _ = Describe("Session", func() { sph := mockackhandler.NewMockSentPacketHandler(mockCtrl) sph.EXPECT().GetLeastUnacked().AnyTimes() sph.EXPECT().DequeuePacketForRetransmission() - sph.EXPECT().ShouldSendRetransmittablePacket() sph.EXPECT().SentPacket(gomock.Any()).Do(func(p *ackhandler.Packet) { Expect(p.Frames).To(Equal([]wire.Frame{ &wire.MaxDataFrame{ByteOffset: 0x1337}, @@ -674,7 +655,6 @@ var _ = Describe("Session", func() { sph := mockackhandler.NewMockSentPacketHandler(mockCtrl) sph.EXPECT().GetLeastUnacked().AnyTimes() sph.EXPECT().DequeuePacketForRetransmission() - sph.EXPECT().ShouldSendRetransmittablePacket() sph.EXPECT().SentPacket(gomock.Any()).Do(func(p *ackhandler.Packet) { Expect(p.Frames).To(ContainElement(&wire.MaxStreamDataFrame{StreamID: 2, ByteOffset: 20})) }) @@ -692,7 +672,6 @@ var _ = Describe("Session", func() { sph := mockackhandler.NewMockSentPacketHandler(mockCtrl) sph.EXPECT().GetLeastUnacked().AnyTimes() sph.EXPECT().DequeuePacketForRetransmission() - sph.EXPECT().ShouldSendRetransmittablePacket() sph.EXPECT().SentPacket(gomock.Any()).Do(func(p *ackhandler.Packet) { Expect(p.Frames).To(Equal([]wire.Frame{ &wire.BlockedFrame{Offset: 1337}, @@ -721,7 +700,6 @@ var _ = Describe("Session", func() { sph.EXPECT().GetLeastUnacked().AnyTimes() sph.EXPECT().GetStopWaitingFrame(gomock.Any()) sph.EXPECT().DequeuePacketForRetransmission() - sph.EXPECT().ShouldSendRetransmittablePacket() sph.EXPECT().SentPacket(gomock.Any()).Do(func(p *ackhandler.Packet) { sentPacket = p }) @@ -749,7 +727,6 @@ var _ = Describe("Session", func() { sph.EXPECT().GetAlarmTimeout().AnyTimes() sph.EXPECT().GetLeastUnacked().AnyTimes() sph.EXPECT().DequeuePacketForRetransmission().AnyTimes() - sph.EXPECT().ShouldSendRetransmittablePacket().AnyTimes() sess.sentPacketHandler = sph sess.packer.hasSentPacket = true streamManager.EXPECT().CloseWithError(gomock.Any()) @@ -981,7 +958,6 @@ var _ = Describe("Session", func() { EncryptionLevel: protocol.EncryptionForwardSecure, }) sph.EXPECT().DequeuePacketForRetransmission() - sph.EXPECT().ShouldSendRetransmittablePacket() sph.EXPECT().SentPacket(gomock.Any()).Do(func(p *ackhandler.Packet) { Expect(p.Frames).To(Equal([]wire.Frame{swf, f})) Expect(p.EncryptionLevel).To(Equal(protocol.EncryptionForwardSecure)) @@ -1004,7 +980,6 @@ var _ = Describe("Session", func() { EncryptionLevel: protocol.EncryptionForwardSecure, }) sph.EXPECT().DequeuePacketForRetransmission() - sph.EXPECT().ShouldSendRetransmittablePacket() sph.EXPECT().SentPacket(gomock.Any()).Do(func(p *ackhandler.Packet) { Expect(p.Frames).To(Equal([]wire.Frame{f})) Expect(p.EncryptionLevel).To(Equal(protocol.EncryptionForwardSecure)) @@ -1038,7 +1013,6 @@ var _ = Describe("Session", func() { sph.EXPECT().DequeuePacketForRetransmission().Return(p2) sph.EXPECT().DequeuePacketForRetransmission() sph.EXPECT().GetStopWaitingFrame(true).Return(&wire.StopWaitingFrame{}) - sph.EXPECT().ShouldSendRetransmittablePacket() sph.EXPECT().SentPacket(gomock.Any()).Do(func(p *ackhandler.Packet) { Expect(p.Frames).To(HaveLen(3)) }) @@ -1116,7 +1090,6 @@ var _ = Describe("Session", func() { sph.EXPECT().GetLeastUnacked().Times(2) sph.EXPECT().DequeuePacketForRetransmission() sph.EXPECT().GetStopWaitingFrame(gomock.Any()) - sph.EXPECT().ShouldSendRetransmittablePacket() sph.EXPECT().ShouldSendNumPackets().Return(1) sph.EXPECT().SentPacket(gomock.Any()).Do(func(p *ackhandler.Packet) { Expect(p.Frames[0]).To(BeAssignableToTypeOf(&wire.AckFrame{}))