forked from quic-go/quic-go
Move SentPacketHandler.CheckForError into SentPacket
This commit is contained in:
@@ -19,9 +19,6 @@ type SentPacketHandler interface {
|
||||
|
||||
GetAlarmTimeout() time.Time
|
||||
OnAlarm()
|
||||
|
||||
// TODO(lclemente): Remove this now that the logic is simpler
|
||||
CheckForError() error
|
||||
}
|
||||
|
||||
// ReceivedPacketHandler handles ACKs needed to send for incoming packets
|
||||
|
||||
@@ -94,6 +94,10 @@ func (h *sentPacketHandler) SentPacket(packet *Packet) error {
|
||||
return errPacketNumberNotIncreasing
|
||||
}
|
||||
|
||||
if protocol.PacketNumber(len(h.retransmissionQueue)+h.packetHistory.Len()+1) > protocol.MaxTrackedSentPackets {
|
||||
return ErrTooManyTrackedSentPackets
|
||||
}
|
||||
|
||||
for p := h.lastSentPacketNumber + 1; p < packet.PacketNumber; p++ {
|
||||
h.skippedPackets = append(h.skippedPackets, p)
|
||||
|
||||
@@ -334,14 +338,6 @@ func (h *sentPacketHandler) SendingAllowed() bool {
|
||||
return !(congestionLimited || maxTrackedLimited)
|
||||
}
|
||||
|
||||
func (h *sentPacketHandler) CheckForError() error {
|
||||
length := len(h.retransmissionQueue) + h.packetHistory.Len()
|
||||
if protocol.PacketNumber(length) > protocol.MaxTrackedSentPackets {
|
||||
return ErrTooManyTrackedSentPackets
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (h *sentPacketHandler) retransmitOldestTwoPackets() {
|
||||
if p := h.packetHistory.Front(); p != nil {
|
||||
h.queueRTO(p)
|
||||
|
||||
@@ -207,12 +207,14 @@ var _ = Describe("SentPacketHandler", func() {
|
||||
|
||||
Context("DoS mitigation", func() {
|
||||
It("checks the size of the packet history, for unacked packets", func() {
|
||||
for i := protocol.PacketNumber(1); i < protocol.MaxTrackedSentPackets+10; i++ {
|
||||
packet := Packet{PacketNumber: protocol.PacketNumber(i), Frames: []frames.Frame{&streamFrame}, Length: 1}
|
||||
i := protocol.PacketNumber(1)
|
||||
for ; i <= protocol.MaxTrackedSentPackets; i++ {
|
||||
packet := Packet{PacketNumber: protocol.PacketNumber(i), Length: 1}
|
||||
err := handler.SentPacket(&packet)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
}
|
||||
err := handler.CheckForError()
|
||||
packet := Packet{PacketNumber: protocol.PacketNumber(i), Length: 1}
|
||||
err := handler.SentPacket(&packet)
|
||||
Expect(err).To(MatchError(ErrTooManyTrackedSentPackets))
|
||||
})
|
||||
|
||||
|
||||
@@ -567,11 +567,6 @@ func (s *session) closeStreamsWithError(err error) {
|
||||
func (s *session) sendPacket() error {
|
||||
// Repeatedly try sending until we don't have any more data, or run out of the congestion window
|
||||
for {
|
||||
err := s.sentPacketHandler.CheckForError()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if !s.sentPacketHandler.SendingAllowed() {
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -75,11 +75,9 @@ func (h *mockSentPacketHandler) ReceivedAck(ackFrame *frames.AckFrame, withPacke
|
||||
}
|
||||
|
||||
func (h *mockSentPacketHandler) GetLeastUnacked() protocol.PacketNumber { return 1 }
|
||||
func (h *mockSentPacketHandler) CheckForError() error { return nil }
|
||||
|
||||
func (h *mockSentPacketHandler) GetAlarmTimeout() time.Time { panic("not implemented") }
|
||||
func (h *mockSentPacketHandler) OnAlarm() { panic("not implemented") }
|
||||
func (h *mockSentPacketHandler) SendingAllowed() bool { return !h.congestionLimited }
|
||||
func (h *mockSentPacketHandler) GetAlarmTimeout() time.Time { panic("not implemented") }
|
||||
func (h *mockSentPacketHandler) OnAlarm() { panic("not implemented") }
|
||||
func (h *mockSentPacketHandler) SendingAllowed() bool { return !h.congestionLimited }
|
||||
|
||||
func (h *mockSentPacketHandler) GetStopWaitingFrame(force bool) *frames.StopWaitingFrame {
|
||||
h.requestedStopWaiting = true
|
||||
@@ -1331,18 +1329,6 @@ var _ = Describe("Session", func() {
|
||||
})
|
||||
})
|
||||
|
||||
It("errors when the SentPacketHandler has too many packets tracked", func() {
|
||||
streamFrame := frames.StreamFrame{StreamID: 5, Data: []byte("foobar")}
|
||||
for i := protocol.PacketNumber(1); i < protocol.MaxTrackedSentPackets+10; i++ {
|
||||
packet := ackhandler.Packet{PacketNumber: protocol.PacketNumber(i), Frames: []frames.Frame{&streamFrame}, Length: 1}
|
||||
err := sess.sentPacketHandler.SentPacket(&packet)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
}
|
||||
// now sess.sentPacketHandler.CheckForError will return an error
|
||||
err := sess.sendPacket()
|
||||
Expect(err).To(MatchError(ackhandler.ErrTooManyTrackedSentPackets))
|
||||
})
|
||||
|
||||
It("stores up to MaxSessionUnprocessedPackets packets", func(done Done) {
|
||||
// Nothing here should block
|
||||
for i := protocol.PacketNumber(0); i < protocol.MaxSessionUnprocessedPackets+10; i++ {
|
||||
|
||||
Reference in New Issue
Block a user