diff --git a/ackhandler/interfaces.go b/ackhandler/interfaces.go index d44a2e0c..a9fc824b 100644 --- a/ackhandler/interfaces.go +++ b/ackhandler/interfaces.go @@ -19,9 +19,6 @@ type SentPacketHandler interface { GetAlarmTimeout() time.Time OnAlarm() - - // TODO(lclemente): Remove this now that the logic is simpler - CheckForError() error } // ReceivedPacketHandler handles ACKs needed to send for incoming packets diff --git a/ackhandler/sent_packet_handler.go b/ackhandler/sent_packet_handler.go index a6013a98..5d0048e3 100644 --- a/ackhandler/sent_packet_handler.go +++ b/ackhandler/sent_packet_handler.go @@ -94,6 +94,10 @@ func (h *sentPacketHandler) SentPacket(packet *Packet) error { return errPacketNumberNotIncreasing } + if protocol.PacketNumber(len(h.retransmissionQueue)+h.packetHistory.Len()+1) > protocol.MaxTrackedSentPackets { + return ErrTooManyTrackedSentPackets + } + for p := h.lastSentPacketNumber + 1; p < packet.PacketNumber; p++ { h.skippedPackets = append(h.skippedPackets, p) @@ -334,14 +338,6 @@ func (h *sentPacketHandler) SendingAllowed() bool { return !(congestionLimited || maxTrackedLimited) } -func (h *sentPacketHandler) CheckForError() error { - length := len(h.retransmissionQueue) + h.packetHistory.Len() - if protocol.PacketNumber(length) > protocol.MaxTrackedSentPackets { - return ErrTooManyTrackedSentPackets - } - return nil -} - func (h *sentPacketHandler) retransmitOldestTwoPackets() { if p := h.packetHistory.Front(); p != nil { h.queueRTO(p) diff --git a/ackhandler/sent_packet_handler_test.go b/ackhandler/sent_packet_handler_test.go index de2bea35..7ad3d305 100644 --- a/ackhandler/sent_packet_handler_test.go +++ b/ackhandler/sent_packet_handler_test.go @@ -207,12 +207,14 @@ var _ = Describe("SentPacketHandler", func() { Context("DoS mitigation", func() { It("checks the size of the packet history, for unacked packets", func() { - for i := protocol.PacketNumber(1); i < protocol.MaxTrackedSentPackets+10; i++ { - packet := Packet{PacketNumber: protocol.PacketNumber(i), Frames: []frames.Frame{&streamFrame}, Length: 1} + i := protocol.PacketNumber(1) + for ; i <= protocol.MaxTrackedSentPackets; i++ { + packet := Packet{PacketNumber: protocol.PacketNumber(i), Length: 1} err := handler.SentPacket(&packet) Expect(err).ToNot(HaveOccurred()) } - err := handler.CheckForError() + packet := Packet{PacketNumber: protocol.PacketNumber(i), Length: 1} + err := handler.SentPacket(&packet) Expect(err).To(MatchError(ErrTooManyTrackedSentPackets)) }) diff --git a/session.go b/session.go index 56b53796..11c940d8 100644 --- a/session.go +++ b/session.go @@ -567,11 +567,6 @@ func (s *session) closeStreamsWithError(err error) { func (s *session) sendPacket() error { // Repeatedly try sending until we don't have any more data, or run out of the congestion window for { - err := s.sentPacketHandler.CheckForError() - if err != nil { - return err - } - if !s.sentPacketHandler.SendingAllowed() { return nil } diff --git a/session_test.go b/session_test.go index 43e9287b..a996618a 100644 --- a/session_test.go +++ b/session_test.go @@ -75,11 +75,9 @@ func (h *mockSentPacketHandler) ReceivedAck(ackFrame *frames.AckFrame, withPacke } func (h *mockSentPacketHandler) GetLeastUnacked() protocol.PacketNumber { return 1 } -func (h *mockSentPacketHandler) CheckForError() error { return nil } - -func (h *mockSentPacketHandler) GetAlarmTimeout() time.Time { panic("not implemented") } -func (h *mockSentPacketHandler) OnAlarm() { panic("not implemented") } -func (h *mockSentPacketHandler) SendingAllowed() bool { return !h.congestionLimited } +func (h *mockSentPacketHandler) GetAlarmTimeout() time.Time { panic("not implemented") } +func (h *mockSentPacketHandler) OnAlarm() { panic("not implemented") } +func (h *mockSentPacketHandler) SendingAllowed() bool { return !h.congestionLimited } func (h *mockSentPacketHandler) GetStopWaitingFrame(force bool) *frames.StopWaitingFrame { h.requestedStopWaiting = true @@ -1331,18 +1329,6 @@ var _ = Describe("Session", func() { }) }) - It("errors when the SentPacketHandler has too many packets tracked", func() { - streamFrame := frames.StreamFrame{StreamID: 5, Data: []byte("foobar")} - for i := protocol.PacketNumber(1); i < protocol.MaxTrackedSentPackets+10; i++ { - packet := ackhandler.Packet{PacketNumber: protocol.PacketNumber(i), Frames: []frames.Frame{&streamFrame}, Length: 1} - err := sess.sentPacketHandler.SentPacket(&packet) - Expect(err).ToNot(HaveOccurred()) - } - // now sess.sentPacketHandler.CheckForError will return an error - err := sess.sendPacket() - Expect(err).To(MatchError(ackhandler.ErrTooManyTrackedSentPackets)) - }) - It("stores up to MaxSessionUnprocessedPackets packets", func(done Done) { // Nothing here should block for i := protocol.PacketNumber(0); i < protocol.MaxSessionUnprocessedPackets+10; i++ {