Move SentPacketHandler.CheckForError into SentPacket

This commit is contained in:
Lucas Clemente
2017-03-31 21:07:36 +02:00
parent 1c5380c49b
commit 612fa16a43
5 changed files with 12 additions and 36 deletions

View File

@@ -19,9 +19,6 @@ type SentPacketHandler interface {
GetAlarmTimeout() time.Time
OnAlarm()
// TODO(lclemente): Remove this now that the logic is simpler
CheckForError() error
}
// ReceivedPacketHandler handles ACKs needed to send for incoming packets

View File

@@ -94,6 +94,10 @@ func (h *sentPacketHandler) SentPacket(packet *Packet) error {
return errPacketNumberNotIncreasing
}
if protocol.PacketNumber(len(h.retransmissionQueue)+h.packetHistory.Len()+1) > protocol.MaxTrackedSentPackets {
return ErrTooManyTrackedSentPackets
}
for p := h.lastSentPacketNumber + 1; p < packet.PacketNumber; p++ {
h.skippedPackets = append(h.skippedPackets, p)
@@ -334,14 +338,6 @@ func (h *sentPacketHandler) SendingAllowed() bool {
return !(congestionLimited || maxTrackedLimited)
}
func (h *sentPacketHandler) CheckForError() error {
length := len(h.retransmissionQueue) + h.packetHistory.Len()
if protocol.PacketNumber(length) > protocol.MaxTrackedSentPackets {
return ErrTooManyTrackedSentPackets
}
return nil
}
func (h *sentPacketHandler) retransmitOldestTwoPackets() {
if p := h.packetHistory.Front(); p != nil {
h.queueRTO(p)

View File

@@ -207,12 +207,14 @@ var _ = Describe("SentPacketHandler", func() {
Context("DoS mitigation", func() {
It("checks the size of the packet history, for unacked packets", func() {
for i := protocol.PacketNumber(1); i < protocol.MaxTrackedSentPackets+10; i++ {
packet := Packet{PacketNumber: protocol.PacketNumber(i), Frames: []frames.Frame{&streamFrame}, Length: 1}
i := protocol.PacketNumber(1)
for ; i <= protocol.MaxTrackedSentPackets; i++ {
packet := Packet{PacketNumber: protocol.PacketNumber(i), Length: 1}
err := handler.SentPacket(&packet)
Expect(err).ToNot(HaveOccurred())
}
err := handler.CheckForError()
packet := Packet{PacketNumber: protocol.PacketNumber(i), Length: 1}
err := handler.SentPacket(&packet)
Expect(err).To(MatchError(ErrTooManyTrackedSentPackets))
})

View File

@@ -567,11 +567,6 @@ func (s *session) closeStreamsWithError(err error) {
func (s *session) sendPacket() error {
// Repeatedly try sending until we don't have any more data, or run out of the congestion window
for {
err := s.sentPacketHandler.CheckForError()
if err != nil {
return err
}
if !s.sentPacketHandler.SendingAllowed() {
return nil
}

View File

@@ -75,11 +75,9 @@ func (h *mockSentPacketHandler) ReceivedAck(ackFrame *frames.AckFrame, withPacke
}
func (h *mockSentPacketHandler) GetLeastUnacked() protocol.PacketNumber { return 1 }
func (h *mockSentPacketHandler) CheckForError() error { return nil }
func (h *mockSentPacketHandler) GetAlarmTimeout() time.Time { panic("not implemented") }
func (h *mockSentPacketHandler) OnAlarm() { panic("not implemented") }
func (h *mockSentPacketHandler) SendingAllowed() bool { return !h.congestionLimited }
func (h *mockSentPacketHandler) GetAlarmTimeout() time.Time { panic("not implemented") }
func (h *mockSentPacketHandler) OnAlarm() { panic("not implemented") }
func (h *mockSentPacketHandler) SendingAllowed() bool { return !h.congestionLimited }
func (h *mockSentPacketHandler) GetStopWaitingFrame(force bool) *frames.StopWaitingFrame {
h.requestedStopWaiting = true
@@ -1331,18 +1329,6 @@ var _ = Describe("Session", func() {
})
})
It("errors when the SentPacketHandler has too many packets tracked", func() {
streamFrame := frames.StreamFrame{StreamID: 5, Data: []byte("foobar")}
for i := protocol.PacketNumber(1); i < protocol.MaxTrackedSentPackets+10; i++ {
packet := ackhandler.Packet{PacketNumber: protocol.PacketNumber(i), Frames: []frames.Frame{&streamFrame}, Length: 1}
err := sess.sentPacketHandler.SentPacket(&packet)
Expect(err).ToNot(HaveOccurred())
}
// now sess.sentPacketHandler.CheckForError will return an error
err := sess.sendPacket()
Expect(err).To(MatchError(ackhandler.ErrTooManyTrackedSentPackets))
})
It("stores up to MaxSessionUnprocessedPackets packets", func(done Done) {
// Nothing here should block
for i := protocol.PacketNumber(0); i < protocol.MaxSessionUnprocessedPackets+10; i++ {