diff --git a/ackhandler/interfaces.go b/ackhandler/interfaces.go index 248eb76e6..09b2c0172 100644 --- a/ackhandler/interfaces.go +++ b/ackhandler/interfaces.go @@ -31,7 +31,6 @@ type SentPacketHandler interface { GetStopWaitingFrame(force bool) *wire.StopWaitingFrame GetLowestPacketNotConfirmedAcked() protocol.PacketNumber - ShouldSendRetransmittablePacket() bool DequeuePacketForRetransmission() (packet *Packet) GetLeastUnacked() protocol.PacketNumber diff --git a/ackhandler/sent_packet_handler.go b/ackhandler/sent_packet_handler.go index 6c0fd0195..a4cd7f34d 100644 --- a/ackhandler/sent_packet_handler.go +++ b/ackhandler/sent_packet_handler.go @@ -38,8 +38,6 @@ type sentPacketHandler struct { nextPacketSendTime time.Time skippedPackets []protocol.PacketNumber - numNonRetransmittablePackets int // number of non-retransmittable packets since the last retransmittable packet - largestAcked protocol.PacketNumber largestReceivedPacketWithAck protocol.PacketNumber // lowestPacketNotConfirmedAcked is the lowest packet number that we sent an ACK for, but haven't received confirmation, that this ACK actually arrived @@ -96,10 +94,6 @@ func (h *sentPacketHandler) lowestUnacked() protocol.PacketNumber { return h.largestAcked + 1 } -func (h *sentPacketHandler) ShouldSendRetransmittablePacket() bool { - return h.numNonRetransmittablePackets >= protocol.MaxNonRetransmittablePackets -} - func (h *sentPacketHandler) SetHandshakeComplete() { var queue []*Packet for _, packet := range h.retransmissionQueue { @@ -142,9 +136,6 @@ func (h *sentPacketHandler) SentPacket(packet *Packet) error { packet.largestAcked = largestAcked h.bytesInFlight += packet.Length h.packetHistory.PushBack(*packet) - h.numNonRetransmittablePackets = 0 - } else { - h.numNonRetransmittablePackets++ } h.congestion.OnPacketSent( diff --git a/ackhandler/sent_packet_handler_test.go b/ackhandler/sent_packet_handler_test.go index 6601d38df..a2fd0f651 100644 --- a/ackhandler/sent_packet_handler_test.go +++ b/ackhandler/sent_packet_handler_test.go @@ -186,30 +186,6 @@ var _ = Describe("SentPacketHandler", func() { }) }) - Context("forcing retransmittable packets", func() { - It("says that every 20th packet should be retransmittable", func() { - // send 19 non-retransmittable packets - for i := 1; i <= protocol.MaxNonRetransmittablePackets; i++ { - Expect(handler.ShouldSendRetransmittablePacket()).To(BeFalse()) - err := handler.SentPacket(nonRetransmittablePacket(protocol.PacketNumber(i))) - Expect(err).ToNot(HaveOccurred()) - } - Expect(handler.ShouldSendRetransmittablePacket()).To(BeTrue()) - }) - - It("resets the counter when a retransmittable packet is sent", func() { - // send 19 non-retransmittable packets - for i := 1; i <= protocol.MaxNonRetransmittablePackets; i++ { - Expect(handler.ShouldSendRetransmittablePacket()).To(BeFalse()) - err := handler.SentPacket(nonRetransmittablePacket(protocol.PacketNumber(i))) - Expect(err).ToNot(HaveOccurred()) - } - err := handler.SentPacket(retransmittablePacket(20)) - Expect(err).ToNot(HaveOccurred()) - Expect(handler.ShouldSendRetransmittablePacket()).To(BeFalse()) - }) - }) - Context("DoS mitigation", func() { It("checks the size of the packet history, for unacked packets", func() { i := protocol.PacketNumber(1) diff --git a/internal/mocks/ackhandler/sent_packet_handler.go b/internal/mocks/ackhandler/sent_packet_handler.go index bad045ed9..54ff26f10 100644 --- a/internal/mocks/ackhandler/sent_packet_handler.go +++ b/internal/mocks/ackhandler/sent_packet_handler.go @@ -165,18 +165,6 @@ func (mr *MockSentPacketHandlerMockRecorder) ShouldSendNumPackets() *gomock.Call return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ShouldSendNumPackets", reflect.TypeOf((*MockSentPacketHandler)(nil).ShouldSendNumPackets)) } -// ShouldSendRetransmittablePacket mocks base method -func (m *MockSentPacketHandler) ShouldSendRetransmittablePacket() bool { - ret := m.ctrl.Call(m, "ShouldSendRetransmittablePacket") - ret0, _ := ret[0].(bool) - return ret0 -} - -// ShouldSendRetransmittablePacket indicates an expected call of ShouldSendRetransmittablePacket -func (mr *MockSentPacketHandlerMockRecorder) ShouldSendRetransmittablePacket() *gomock.Call { - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ShouldSendRetransmittablePacket", reflect.TypeOf((*MockSentPacketHandler)(nil).ShouldSendRetransmittablePacket)) -} - // TimeUntilSend mocks base method func (m *MockSentPacketHandler) TimeUntilSend() time.Time { ret := m.ctrl.Call(m, "TimeUntilSend") diff --git a/internal/protocol/server_parameters.go b/internal/protocol/server_parameters.go index e0aad6cff..61e5a2dfc 100644 --- a/internal/protocol/server_parameters.go +++ b/internal/protocol/server_parameters.go @@ -90,8 +90,8 @@ const MaxTrackedSentPackets = 2 * DefaultMaxCongestionWindow // MaxTrackedReceivedAckRanges is the maximum number of ACK ranges tracked const MaxTrackedReceivedAckRanges = DefaultMaxCongestionWindow -// MaxNonRetransmittablePackets is the maximum number of non-retransmittable packets that we send in a row -const MaxNonRetransmittablePackets = 19 +// MaxNonRetransmittableAcks is the maximum number of packets containing an ACK, but no retransmittable frames, that we send in a row +const MaxNonRetransmittableAcks = 19 // RetransmittablePacketsBeforeAck is the number of retransmittable that an ACK is sent for const RetransmittablePacketsBeforeAck = 10 diff --git a/packet_packer.go b/packet_packer.go index 74e46e35c..ead37404c 100644 --- a/packet_packer.go +++ b/packet_packer.go @@ -37,12 +37,12 @@ type packetPacker struct { controlFrameMutex sync.Mutex controlFrames []wire.Frame - stopWaiting *wire.StopWaitingFrame - ackFrame *wire.AckFrame - leastUnacked protocol.PacketNumber - omitConnectionID bool - hasSentPacket bool // has the packetPacker already sent a packet - makeNextPacketRetransmittable bool + stopWaiting *wire.StopWaitingFrame + ackFrame *wire.AckFrame + leastUnacked protocol.PacketNumber + omitConnectionID bool + hasSentPacket bool // has the packetPacker already sent a packet + numNonRetransmittableAcks int } func newPacketPacker(connectionID protocol.ConnectionID, @@ -169,14 +169,18 @@ func (p *packetPacker) PackPacket() (*packedPacket, error) { if len(payloadFrames) == 1 && p.stopWaiting != nil { return nil, nil } - // check if this packet only contains an ACK and / or STOP_WAITING - if !ackhandler.HasRetransmittableFrames(payloadFrames) { - if p.makeNextPacketRetransmittable { - payloadFrames = append(payloadFrames, &wire.PingFrame{}) - p.makeNextPacketRetransmittable = false + if p.ackFrame != nil { + // check if this packet only contains an ACK (and maybe a STOP_WAITING) + if len(payloadFrames) == 1 || (p.stopWaiting != nil && len(payloadFrames) == 2) { + if p.numNonRetransmittableAcks >= protocol.MaxNonRetransmittableAcks { + payloadFrames = append(payloadFrames, &wire.PingFrame{}) + p.numNonRetransmittableAcks = 0 + } else { + p.numNonRetransmittableAcks++ + } + } else { + p.numNonRetransmittableAcks = 0 } - } else { // this packet already contains a retransmittable frame. No need to send a PING - p.makeNextPacketRetransmittable = false } p.stopWaiting = nil p.ackFrame = nil @@ -392,7 +396,3 @@ func (p *packetPacker) SetLeastUnacked(leastUnacked protocol.PacketNumber) { func (p *packetPacker) SetOmitConnectionID() { p.omitConnectionID = true } - -func (p *packetPacker) MakeNextPacketRetransmittable() { - p.makeNextPacketRetransmittable = true -} diff --git a/packet_packer_test.go b/packet_packer_test.go index 936d25546..bc0d1b587 100644 --- a/packet_packer_test.go +++ b/packet_packer_test.go @@ -378,53 +378,62 @@ var _ = Describe("Packet packer", func() { Expect(packer.packetNumberGenerator.Peek()).To(Equal(protocol.PacketNumber(2))) }) - It("adds a PING frame when it's supposed to send a retransmittable packet", func() { - mockStreamFramer.EXPECT().HasCryptoStreamData().Times(2) - mockStreamFramer.EXPECT().PopStreamFrames(gomock.Any()).Times(2) - packer.QueueControlFrame(&wire.AckFrame{}) - packer.QueueControlFrame(&wire.StopWaitingFrame{}) - packer.MakeNextPacketRetransmittable() - p, err := packer.PackPacket() - Expect(p).ToNot(BeNil()) - Expect(err).ToNot(HaveOccurred()) - Expect(p.frames).To(HaveLen(3)) - Expect(p.frames).To(ContainElement(&wire.PingFrame{})) - // make sure the next packet doesn't contain another PING - packer.QueueControlFrame(&wire.AckFrame{}) - p, err = packer.PackPacket() - Expect(p).ToNot(BeNil()) - Expect(err).ToNot(HaveOccurred()) - Expect(p.frames).To(HaveLen(1)) - }) + Context("making ACK packets retransmittable", func() { + sendMaxNumNonRetransmittableAcks := func() { + mockStreamFramer.EXPECT().HasCryptoStreamData().Times(protocol.MaxNonRetransmittableAcks) + mockStreamFramer.EXPECT().PopStreamFrames(gomock.Any()).Times(protocol.MaxNonRetransmittableAcks) + for i := 0; i < protocol.MaxNonRetransmittableAcks; i++ { + packer.QueueControlFrame(&wire.AckFrame{}) + p, err := packer.PackPacket() + Expect(p).ToNot(BeNil()) + Expect(err).ToNot(HaveOccurred()) + Expect(p.frames).To(HaveLen(1)) + } + } - It("waits until there's something to send before adding a PING frame", func() { - mockStreamFramer.EXPECT().HasCryptoStreamData().Times(2) - mockStreamFramer.EXPECT().PopStreamFrames(gomock.Any()).Times(2) - packer.MakeNextPacketRetransmittable() - p, err := packer.PackPacket() - Expect(err).ToNot(HaveOccurred()) - Expect(p).To(BeNil()) - packer.QueueControlFrame(&wire.AckFrame{}) - p, err = packer.PackPacket() - Expect(err).ToNot(HaveOccurred()) - Expect(p.frames).To(HaveLen(2)) - Expect(p.frames).To(ContainElement(&wire.PingFrame{})) - }) + It("adds a PING frame when it's supposed to send a retransmittable packet", func() { + sendMaxNumNonRetransmittableAcks() + mockStreamFramer.EXPECT().HasCryptoStreamData().Times(2) + mockStreamFramer.EXPECT().PopStreamFrames(gomock.Any()).Times(2) + packer.QueueControlFrame(&wire.AckFrame{}) + p, err := packer.PackPacket() + Expect(p).ToNot(BeNil()) + Expect(err).ToNot(HaveOccurred()) + Expect(p.frames).To(ContainElement(&wire.PingFrame{})) + // make sure the next packet doesn't contain another PING + packer.QueueControlFrame(&wire.AckFrame{}) + p, err = packer.PackPacket() + Expect(p).ToNot(BeNil()) + Expect(err).ToNot(HaveOccurred()) + Expect(p.frames).To(HaveLen(1)) + }) - It("doesn't send a PING if it already sent another retransmittable frame", func() { - mockStreamFramer.EXPECT().HasCryptoStreamData().Times(2) - mockStreamFramer.EXPECT().PopStreamFrames(gomock.Any()).Times(2) - packer.MakeNextPacketRetransmittable() - packer.QueueControlFrame(&wire.MaxDataFrame{}) - p, err := packer.PackPacket() - Expect(p).ToNot(BeNil()) - Expect(err).ToNot(HaveOccurred()) - Expect(p.frames).To(HaveLen(1)) - packer.QueueControlFrame(&wire.AckFrame{}) - p, err = packer.PackPacket() - Expect(p).ToNot(BeNil()) - Expect(err).ToNot(HaveOccurred()) - Expect(p.frames).To(HaveLen(1)) + It("waits until there's something to send before adding a PING frame", func() { + sendMaxNumNonRetransmittableAcks() + mockStreamFramer.EXPECT().HasCryptoStreamData().Times(2) + mockStreamFramer.EXPECT().PopStreamFrames(gomock.Any()).Times(2) + p, err := packer.PackPacket() + Expect(err).ToNot(HaveOccurred()) + Expect(p).To(BeNil()) + packer.QueueControlFrame(&wire.AckFrame{}) + p, err = packer.PackPacket() + Expect(err).ToNot(HaveOccurred()) + Expect(p.frames).To(HaveLen(2)) + Expect(p.frames).To(ContainElement(&wire.PingFrame{})) + }) + + It("doesn't send a PING if it already sent another retransmittable frame", func() { + sendMaxNumNonRetransmittableAcks() + mockStreamFramer.EXPECT().HasCryptoStreamData() + mockStreamFramer.EXPECT().PopStreamFrames(gomock.Any()) + packer.QueueControlFrame(&wire.MaxDataFrame{}) + packer.QueueControlFrame(&wire.AckFrame{}) + p, err := packer.PackPacket() + Expect(p).ToNot(BeNil()) + Expect(err).ToNot(HaveOccurred()) + Expect(p.frames).To(HaveLen(2)) + Expect(p.frames).ToNot(ContainElement(&wire.PingFrame{})) + }) }) Context("STREAM frame handling", func() { diff --git a/session.go b/session.go index 2127b2954..f3a81dc0a 100644 --- a/session.go +++ b/session.go @@ -841,10 +841,6 @@ func (s *session) sendPacket() (bool, error) { s.packer.QueueControlFrame(swf) } } - // add a retransmittable frame - if s.sentPacketHandler.ShouldSendRetransmittablePacket() { - s.packer.MakeNextPacketRetransmittable() - } packet, err := s.packer.PackPacket() if err != nil || packet == nil { return false, err diff --git a/session_test.go b/session_test.go index 5030ca89a..67670a6a1 100644 --- a/session_test.go +++ b/session_test.go @@ -628,24 +628,6 @@ var _ = Describe("Session", func() { Expect(mconn.written).To(Receive(ContainSubstring(string([]byte{0x03, 0x5e})))) }) - It("sends a retransmittable packet when required by the SentPacketHandler", func() { - ack := &wire.AckFrame{LargestAcked: 1000} - sess.packer.QueueControlFrame(ack) - sph := mockackhandler.NewMockSentPacketHandler(mockCtrl) - sph.EXPECT().GetLeastUnacked().AnyTimes() - sph.EXPECT().DequeuePacketForRetransmission() - sph.EXPECT().ShouldSendRetransmittablePacket().Return(true) - sph.EXPECT().SentPacket(gomock.Any()).Do(func(p *ackhandler.Packet) { - Expect(p.Frames).To(HaveLen(2)) - Expect(p.Frames).To(ContainElement(ack)) - }) - sess.sentPacketHandler = sph - sent, err := sess.sendPacket() - Expect(err).NotTo(HaveOccurred()) - Expect(sent).To(BeTrue()) - Expect(mconn.written).To(HaveLen(1)) - }) - It("adds a MAX_DATA frames", func() { fc := mocks.NewMockConnectionFlowController(mockCtrl) fc.EXPECT().GetWindowUpdate().Return(protocol.ByteCount(0x1337)) @@ -654,7 +636,6 @@ var _ = Describe("Session", func() { sph := mockackhandler.NewMockSentPacketHandler(mockCtrl) sph.EXPECT().GetLeastUnacked().AnyTimes() sph.EXPECT().DequeuePacketForRetransmission() - sph.EXPECT().ShouldSendRetransmittablePacket() sph.EXPECT().SentPacket(gomock.Any()).Do(func(p *ackhandler.Packet) { Expect(p.Frames).To(Equal([]wire.Frame{ &wire.MaxDataFrame{ByteOffset: 0x1337}, @@ -674,7 +655,6 @@ var _ = Describe("Session", func() { sph := mockackhandler.NewMockSentPacketHandler(mockCtrl) sph.EXPECT().GetLeastUnacked().AnyTimes() sph.EXPECT().DequeuePacketForRetransmission() - sph.EXPECT().ShouldSendRetransmittablePacket() sph.EXPECT().SentPacket(gomock.Any()).Do(func(p *ackhandler.Packet) { Expect(p.Frames).To(ContainElement(&wire.MaxStreamDataFrame{StreamID: 2, ByteOffset: 20})) }) @@ -692,7 +672,6 @@ var _ = Describe("Session", func() { sph := mockackhandler.NewMockSentPacketHandler(mockCtrl) sph.EXPECT().GetLeastUnacked().AnyTimes() sph.EXPECT().DequeuePacketForRetransmission() - sph.EXPECT().ShouldSendRetransmittablePacket() sph.EXPECT().SentPacket(gomock.Any()).Do(func(p *ackhandler.Packet) { Expect(p.Frames).To(Equal([]wire.Frame{ &wire.BlockedFrame{Offset: 1337}, @@ -721,7 +700,6 @@ var _ = Describe("Session", func() { sph.EXPECT().GetLeastUnacked().AnyTimes() sph.EXPECT().GetStopWaitingFrame(gomock.Any()) sph.EXPECT().DequeuePacketForRetransmission() - sph.EXPECT().ShouldSendRetransmittablePacket() sph.EXPECT().SentPacket(gomock.Any()).Do(func(p *ackhandler.Packet) { sentPacket = p }) @@ -749,7 +727,6 @@ var _ = Describe("Session", func() { sph.EXPECT().GetAlarmTimeout().AnyTimes() sph.EXPECT().GetLeastUnacked().AnyTimes() sph.EXPECT().DequeuePacketForRetransmission().AnyTimes() - sph.EXPECT().ShouldSendRetransmittablePacket().AnyTimes() sess.sentPacketHandler = sph sess.packer.hasSentPacket = true streamManager.EXPECT().CloseWithError(gomock.Any()) @@ -981,7 +958,6 @@ var _ = Describe("Session", func() { EncryptionLevel: protocol.EncryptionForwardSecure, }) sph.EXPECT().DequeuePacketForRetransmission() - sph.EXPECT().ShouldSendRetransmittablePacket() sph.EXPECT().SentPacket(gomock.Any()).Do(func(p *ackhandler.Packet) { Expect(p.Frames).To(Equal([]wire.Frame{swf, f})) Expect(p.EncryptionLevel).To(Equal(protocol.EncryptionForwardSecure)) @@ -1004,7 +980,6 @@ var _ = Describe("Session", func() { EncryptionLevel: protocol.EncryptionForwardSecure, }) sph.EXPECT().DequeuePacketForRetransmission() - sph.EXPECT().ShouldSendRetransmittablePacket() sph.EXPECT().SentPacket(gomock.Any()).Do(func(p *ackhandler.Packet) { Expect(p.Frames).To(Equal([]wire.Frame{f})) Expect(p.EncryptionLevel).To(Equal(protocol.EncryptionForwardSecure)) @@ -1038,7 +1013,6 @@ var _ = Describe("Session", func() { sph.EXPECT().DequeuePacketForRetransmission().Return(p2) sph.EXPECT().DequeuePacketForRetransmission() sph.EXPECT().GetStopWaitingFrame(true).Return(&wire.StopWaitingFrame{}) - sph.EXPECT().ShouldSendRetransmittablePacket() sph.EXPECT().SentPacket(gomock.Any()).Do(func(p *ackhandler.Packet) { Expect(p.Frames).To(HaveLen(3)) }) @@ -1116,7 +1090,6 @@ var _ = Describe("Session", func() { sph.EXPECT().GetLeastUnacked().Times(2) sph.EXPECT().DequeuePacketForRetransmission() sph.EXPECT().GetStopWaitingFrame(gomock.Any()) - sph.EXPECT().ShouldSendRetransmittablePacket() sph.EXPECT().ShouldSendNumPackets().Return(1) sph.EXPECT().SentPacket(gomock.Any()).Do(func(p *ackhandler.Packet) { Expect(p.Frames[0]).To(BeAssignableToTypeOf(&wire.AckFrame{}))