diff --git a/internal/ackhandler/interfaces.go b/internal/ackhandler/interfaces.go index 09b2c0172..9ec702165 100644 --- a/internal/ackhandler/interfaces.go +++ b/internal/ackhandler/interfaces.go @@ -10,7 +10,7 @@ import ( // SentPacketHandler handles ACKs received for outgoing packets type SentPacketHandler interface { // SentPacket may modify the packet - SentPacket(packet *Packet) error + SentPacket(packet *Packet) ReceivedAck(ackFrame *wire.AckFrame, withPacketNumber protocol.PacketNumber, encLevel protocol.EncryptionLevel, recvTime time.Time) error SetHandshakeComplete() diff --git a/internal/ackhandler/sent_packet_handler.go b/internal/ackhandler/sent_packet_handler.go index 11fa2f911..786d42f50 100644 --- a/internal/ackhandler/sent_packet_handler.go +++ b/internal/ackhandler/sent_packet_handler.go @@ -112,14 +112,9 @@ func (h *sentPacketHandler) SetHandshakeComplete() { h.handshakeComplete = true } -func (h *sentPacketHandler) SentPacket(packet *Packet) error { - if protocol.PacketNumber(len(h.retransmissionQueue)+h.packetHistory.Len()+1) > protocol.MaxTrackedSentPackets { - return errors.New("Too many outstanding non-acked and non-retransmitted packets") - } - +func (h *sentPacketHandler) SentPacket(packet *Packet) { for p := h.lastSentPacketNumber + 1; p < packet.PacketNumber; p++ { h.skippedPackets = append(h.skippedPackets, p) - if len(h.skippedPackets) > protocol.MaxTrackedSkippedPackets { h.skippedPackets = h.skippedPackets[1:] } @@ -144,7 +139,6 @@ func (h *sentPacketHandler) SentPacket(packet *Packet) error { h.bytesInFlight += packet.Length h.packetHistory.PushBack(*packet) } - h.congestion.OnPacketSent( now, h.bytesInFlight, @@ -154,9 +148,7 @@ func (h *sentPacketHandler) SentPacket(packet *Packet) error { ) h.nextPacketSendTime = utils.MaxTime(h.nextPacketSendTime, now).Add(h.congestion.TimeUntilSend(h.bytesInFlight)) - h.updateLossDetectionAlarm(now) - return nil } func (h *sentPacketHandler) ReceivedAck(ackFrame *wire.AckFrame, withPacketNumber protocol.PacketNumber, encLevel protocol.EncryptionLevel, rcvTime time.Time) error { diff --git a/internal/ackhandler/sent_packet_handler_test.go b/internal/ackhandler/sent_packet_handler_test.go index 5418b0007..7af10d186 100644 --- a/internal/ackhandler/sent_packet_handler_test.go +++ b/internal/ackhandler/sent_packet_handler_test.go @@ -68,10 +68,8 @@ var _ = Describe("SentPacketHandler", func() { It("accepts two consecutive packets", func() { packet1 := Packet{PacketNumber: 1, Frames: []wire.Frame{&streamFrame}, Length: 1} packet2 := Packet{PacketNumber: 2, Frames: []wire.Frame{&streamFrame}, Length: 2} - err := handler.SentPacket(&packet1) - Expect(err).ToNot(HaveOccurred()) - err = handler.SentPacket(&packet2) - Expect(err).ToNot(HaveOccurred()) + handler.SentPacket(&packet1) + handler.SentPacket(&packet2) Expect(handler.lastSentPacketNumber).To(Equal(protocol.PacketNumber(2))) Expect(handler.packetHistory.Front().Value.PacketNumber).To(Equal(protocol.PacketNumber(1))) Expect(handler.packetHistory.Back().Value.PacketNumber).To(Equal(protocol.PacketNumber(2))) @@ -82,11 +80,9 @@ var _ = Describe("SentPacketHandler", func() { It("accepts packet number 0", func() { packet1 := Packet{PacketNumber: 0, Frames: []wire.Frame{&streamFrame}, Length: 1} packet2 := Packet{PacketNumber: 1, Frames: []wire.Frame{&streamFrame}, Length: 2} - err := handler.SentPacket(&packet1) - Expect(err).ToNot(HaveOccurred()) + handler.SentPacket(&packet1) Expect(handler.lastSentPacketNumber).To(BeZero()) - err = handler.SentPacket(&packet2) - Expect(err).ToNot(HaveOccurred()) + handler.SentPacket(&packet2) Expect(handler.lastSentPacketNumber).To(Equal(protocol.PacketNumber(1))) Expect(handler.packetHistory.Front().Value.PacketNumber).To(Equal(protocol.PacketNumber(0))) Expect(handler.packetHistory.Back().Value.PacketNumber).To(Equal(protocol.PacketNumber(1))) @@ -96,14 +92,12 @@ var _ = Describe("SentPacketHandler", func() { It("stores the sent time", func() { packet := Packet{PacketNumber: 1, Frames: []wire.Frame{&streamFrame}, Length: 1} - err := handler.SentPacket(&packet) - Expect(err).ToNot(HaveOccurred()) + handler.SentPacket(&packet) Expect(handler.packetHistory.Front().Value.sendTime.Unix()).To(BeNumerically("~", time.Now().Unix(), 1)) }) It("does not store non-retransmittable packets", func() { - err := handler.SentPacket(&Packet{PacketNumber: 1, Length: 1}) - Expect(err).ToNot(HaveOccurred()) + handler.SentPacket(&Packet{PacketNumber: 1, Length: 1}) Expect(handler.packetHistory.Len()).To(BeZero()) }) @@ -111,10 +105,8 @@ var _ = Describe("SentPacketHandler", func() { It("works with non-consecutive packet numbers", func() { packet1 := Packet{PacketNumber: 1, Frames: []wire.Frame{&streamFrame}, Length: 1} packet2 := Packet{PacketNumber: 3, Frames: []wire.Frame{&streamFrame}, Length: 2} - err := handler.SentPacket(&packet1) - Expect(err).ToNot(HaveOccurred()) - err = handler.SentPacket(&packet2) - Expect(err).ToNot(HaveOccurred()) + handler.SentPacket(&packet1) + handler.SentPacket(&packet2) Expect(handler.lastSentPacketNumber).To(Equal(protocol.PacketNumber(3))) el := handler.packetHistory.Front() Expect(el.Value.PacketNumber).To(Equal(protocol.PacketNumber(1))) @@ -129,12 +121,9 @@ var _ = Describe("SentPacketHandler", func() { packet1 := Packet{PacketNumber: 1, Frames: []wire.Frame{&streamFrame}, Length: 1} packet2 := Packet{PacketNumber: 3, Frames: []wire.Frame{&streamFrame}, Length: 2} packet3 := Packet{PacketNumber: 5, Frames: []wire.Frame{&streamFrame}, Length: 2} - err := handler.SentPacket(&packet1) - Expect(err).ToNot(HaveOccurred()) - err = handler.SentPacket(&packet2) - Expect(err).ToNot(HaveOccurred()) - err = handler.SentPacket(&packet3) - Expect(err).ToNot(HaveOccurred()) + handler.SentPacket(&packet1) + handler.SentPacket(&packet2) + handler.SentPacket(&packet3) Expect(handler.skippedPackets).To(HaveLen(2)) Expect(handler.skippedPackets).To(Equal([]protocol.PacketNumber{2, 4})) }) @@ -142,10 +131,8 @@ var _ = Describe("SentPacketHandler", func() { It("recognizes multiple consecutive skipped packets", func() { packet1 := Packet{PacketNumber: 1, Frames: []wire.Frame{&streamFrame}, Length: 1} packet2 := Packet{PacketNumber: 4, Frames: []wire.Frame{&streamFrame}, Length: 2} - err := handler.SentPacket(&packet1) - Expect(err).ToNot(HaveOccurred()) - err = handler.SentPacket(&packet2) - Expect(err).ToNot(HaveOccurred()) + handler.SentPacket(&packet1) + handler.SentPacket(&packet2) Expect(handler.skippedPackets).To(HaveLen(2)) Expect(handler.skippedPackets).To(Equal([]protocol.PacketNumber{2, 3})) }) @@ -153,8 +140,7 @@ var _ = Describe("SentPacketHandler", func() { It("limits the lengths of the skipped packet slice", func() { for i := 0; i < protocol.MaxTrackedSkippedPackets+5; i++ { packet := Packet{PacketNumber: protocol.PacketNumber(2*i + 1), Frames: []wire.Frame{&streamFrame}, Length: 1} - err := handler.SentPacket(&packet) - Expect(err).ToNot(HaveOccurred()) + handler.SentPacket(&packet) } Expect(handler.skippedPackets).To(HaveLen(protocol.MaxUndecryptablePackets)) Expect(handler.skippedPackets[0]).To(Equal(protocol.PacketNumber(10))) @@ -186,20 +172,6 @@ var _ = Describe("SentPacketHandler", func() { }) }) - Context("DoS mitigation", func() { - It("checks the size of the packet history, for unacked packets", func() { - i := protocol.PacketNumber(1) - for ; i <= protocol.MaxTrackedSentPackets; i++ { - err := handler.SentPacket(retransmittablePacket(i)) - Expect(err).ToNot(HaveOccurred()) - } - err := handler.SentPacket(retransmittablePacket(i)) - Expect(err).To(MatchError("Too many outstanding non-acked and non-retransmitted packets")) - }) - - // TODO: add a test that the length of the retransmission queue is considered, even if packets have already been ACKed. Relevant once we drop support for QUIC 33 and earlier - }) - Context("ACK processing", func() { var packets []*Packet @@ -219,8 +191,7 @@ var _ = Describe("SentPacketHandler", func() { {PacketNumber: 12, Frames: []wire.Frame{&streamFrame}, Length: 1}, } for _, packet := range packets { - err := handler.SentPacket(packet) - Expect(err).NotTo(HaveOccurred()) + handler.SentPacket(packet) } // Increase RTT, because the tests would be flaky otherwise handler.rttStats.UpdateRTT(time.Hour, 0, time.Now()) @@ -327,7 +298,7 @@ var _ = Describe("SentPacketHandler", func() { }) It("rejects an ACK that acks packets with a higher encryption level", func() { - err := handler.SentPacket(&Packet{ + handler.SentPacket(&Packet{ PacketNumber: 13, EncryptionLevel: protocol.EncryptionForwardSecure, Frames: []wire.Frame{&streamFrame}, @@ -337,8 +308,7 @@ var _ = Describe("SentPacketHandler", func() { LargestAcked: 13, LowestAcked: 13, } - Expect(err).ToNot(HaveOccurred()) - err = handler.ReceivedAck(&ack, 1, protocol.EncryptionSecure, time.Now()) + err := handler.ReceivedAck(&ack, 1, protocol.EncryptionSecure, time.Now()) Expect(err).To(MatchError("Received ACK with encryption level encrypted (not forward-secure) that acks a packet 13 (encryption level forward-secure)")) }) @@ -518,8 +488,7 @@ var _ = Describe("SentPacketHandler", func() { {PacketNumber: 15, Frames: []wire.Frame{&streamFrame}, Length: 1}, } for _, packet := range morePackets { - err := handler.SentPacket(packet) - Expect(err).NotTo(HaveOccurred()) + handler.SentPacket(packet) } }) @@ -625,14 +594,11 @@ var _ = Describe("SentPacketHandler", func() { packet1 := Packet{PacketNumber: 1, Frames: []wire.Frame{&streamFrame}, Length: 1} packet2 := Packet{PacketNumber: 2, Frames: []wire.Frame{&streamFrame}, Length: 2} packet3 := Packet{PacketNumber: 3, Frames: []wire.Frame{&streamFrame}, Length: 3} - err := handler.SentPacket(&packet1) - Expect(err).NotTo(HaveOccurred()) + handler.SentPacket(&packet1) Expect(handler.bytesInFlight).To(Equal(protocol.ByteCount(1))) - err = handler.SentPacket(&packet2) - Expect(err).NotTo(HaveOccurred()) + handler.SentPacket(&packet2) Expect(handler.bytesInFlight).To(Equal(protocol.ByteCount(1 + 2))) - err = handler.SentPacket(&packet3) - Expect(err).NotTo(HaveOccurred()) + handler.SentPacket(&packet3) Expect(handler.bytesInFlight).To(Equal(protocol.ByteCount(1 + 2 + 3))) // Increase RTT, because the tests would be flaky otherwise @@ -647,8 +613,7 @@ var _ = Describe("SentPacketHandler", func() { {First: 1, Last: 1}, }, } - err = handler.ReceivedAck(&ack, 1, protocol.EncryptionUnencrypted, time.Now()) - Expect(err).NotTo(HaveOccurred()) + handler.ReceivedAck(&ack, 1, protocol.EncryptionUnencrypted, time.Now()) Expect(handler.bytesInFlight).To(Equal(protocol.ByteCount(2))) handler.packetHistory.Front().Value.sendTime = time.Now().Add(-time.Hour) @@ -680,8 +645,7 @@ var _ = Describe("SentPacketHandler", func() { Length: 42, Frames: []wire.Frame{&wire.PingFrame{}}, } - err := handler.SentPacket(p) - Expect(err).NotTo(HaveOccurred()) + handler.SentPacket(p) }) It("should call MaybeExitSlowStart and OnPacketAcked", func() { @@ -805,13 +769,11 @@ var _ = Describe("SentPacketHandler", func() { Context("Delay-based loss detection", func() { It("detects a packet as lost", func() { - err := handler.SentPacket(retransmittablePacket(1)) - Expect(err).NotTo(HaveOccurred()) - err = handler.SentPacket(retransmittablePacket(2)) - Expect(err).NotTo(HaveOccurred()) + handler.SentPacket(retransmittablePacket(1)) + handler.SentPacket(retransmittablePacket(2)) Expect(handler.lossTime.IsZero()).To(BeTrue()) - err = handler.ReceivedAck(&wire.AckFrame{LargestAcked: 2, LowestAcked: 2}, 1, protocol.EncryptionForwardSecure, time.Now().Add(time.Hour)) + err := handler.ReceivedAck(&wire.AckFrame{LargestAcked: 2, LowestAcked: 2}, 1, protocol.EncryptionForwardSecure, time.Now().Add(time.Hour)) Expect(err).NotTo(HaveOccurred()) Expect(handler.lossTime.IsZero()).To(BeFalse()) @@ -826,15 +788,12 @@ var _ = Describe("SentPacketHandler", func() { }) It("does not detect packets as lost without ACKs", func() { - err := handler.SentPacket(nonRetransmittablePacket(1)) - Expect(err).NotTo(HaveOccurred()) - err = handler.SentPacket(retransmittablePacket(2)) - Expect(err).NotTo(HaveOccurred()) - err = handler.SentPacket(retransmittablePacket(3)) - Expect(err).NotTo(HaveOccurred()) + handler.SentPacket(nonRetransmittablePacket(1)) + handler.SentPacket(retransmittablePacket(2)) + handler.SentPacket(retransmittablePacket(3)) Expect(handler.lossTime.IsZero()).To(BeTrue()) - err = handler.ReceivedAck(&wire.AckFrame{LargestAcked: 1, LowestAcked: 1}, 1, protocol.EncryptionUnencrypted, time.Now().Add(time.Hour)) + err := handler.ReceivedAck(&wire.AckFrame{LargestAcked: 1, LowestAcked: 1}, 1, protocol.EncryptionUnencrypted, time.Now().Add(time.Hour)) Expect(err).NotTo(HaveOccurred()) Expect(handler.lossTime.IsZero()).To(BeTrue()) Expect(time.Until(handler.GetAlarmTimeout())).To(BeNumerically("~", handler.computeRTOTimeout(), time.Minute)) @@ -855,16 +814,12 @@ var _ = Describe("SentPacketHandler", func() { It("detects the handshake timeout", func() { // send handshake packets: 1, 2, 4 // send a forward-secure packet: 3 - err := handler.SentPacket(handshakePacket(1)) - Expect(err).ToNot(HaveOccurred()) - err = handler.SentPacket(handshakePacket(2)) - Expect(err).ToNot(HaveOccurred()) - err = handler.SentPacket(retransmittablePacket(3)) - Expect(err).ToNot(HaveOccurred()) - err = handler.SentPacket(handshakePacket(4)) - Expect(err).ToNot(HaveOccurred()) + handler.SentPacket(handshakePacket(1)) + handler.SentPacket(handshakePacket(2)) + handler.SentPacket(retransmittablePacket(3)) + handler.SentPacket(handshakePacket(4)) - err = handler.ReceivedAck(&wire.AckFrame{LargestAcked: 1, LowestAcked: 1}, 1, protocol.EncryptionSecure, time.Now()) + err := handler.ReceivedAck(&wire.AckFrame{LargestAcked: 1, LowestAcked: 1}, 1, protocol.EncryptionSecure, time.Now()) Expect(err).NotTo(HaveOccurred()) Expect(handler.lossTime.IsZero()).To(BeTrue()) handshakeTimeout := handler.computeHandshakeTimeout() @@ -887,10 +842,8 @@ var _ = Describe("SentPacketHandler", func() { Context("RTO retransmission", func() { It("queues two packets if RTO expires", func() { - err := handler.SentPacket(retransmittablePacket(1)) - Expect(err).NotTo(HaveOccurred()) - err = handler.SentPacket(retransmittablePacket(2)) - Expect(err).NotTo(HaveOccurred()) + handler.SentPacket(retransmittablePacket(1)) + handler.SentPacket(retransmittablePacket(2)) handler.rttStats.UpdateRTT(time.Hour, 0, time.Now()) Expect(handler.lossTime.IsZero()).To(BeTrue()) diff --git a/internal/mocks/ackhandler/sent_packet_handler.go b/internal/mocks/ackhandler/sent_packet_handler.go index b28a970ad..80bb62c41 100644 --- a/internal/mocks/ackhandler/sent_packet_handler.go +++ b/internal/mocks/ackhandler/sent_packet_handler.go @@ -132,10 +132,8 @@ func (mr *MockSentPacketHandlerMockRecorder) SendingAllowed() *gomock.Call { } // SentPacket mocks base method -func (m *MockSentPacketHandler) SentPacket(arg0 *ackhandler.Packet) error { - ret := m.ctrl.Call(m, "SentPacket", arg0) - ret0, _ := ret[0].(error) - return ret0 +func (m *MockSentPacketHandler) SentPacket(arg0 *ackhandler.Packet) { + m.ctrl.Call(m, "SentPacket", arg0) } // SentPacket indicates an expected call of SentPacket diff --git a/session.go b/session.go index 31e2296fe..67d19480c 100644 --- a/session.go +++ b/session.go @@ -877,16 +877,13 @@ func (s *session) sendPacket() (bool, error) { func (s *session) sendPackedPacket(packet *packedPacket) error { defer putPacketBuffer(&packet.raw) - err := s.sentPacketHandler.SentPacket(&ackhandler.Packet{ + s.sentPacketHandler.SentPacket(&ackhandler.Packet{ PacketNumber: packet.header.PacketNumber, PacketType: packet.header.Type, Frames: packet.frames, Length: protocol.ByteCount(len(packet.raw)), EncryptionLevel: packet.encryptionLevel, }) - if err != nil { - return err - } s.logPacket(packet) return s.conn.Write(packet.raw) } diff --git a/session_test.go b/session_test.go index 97f5780f9..3bfaa1529 100644 --- a/session_test.go +++ b/session_test.go @@ -1091,7 +1091,7 @@ var _ = Describe("Session", func() { Expect(sess.rttStats.SmoothedRTT()).To(Equal(rtt)) // make sure it worked sess.packer.packetNumberGenerator.next = n + 1 // Now, we send a single packet, and expect that it was retransmitted later - err := sess.sentPacketHandler.SentPacket(&ackhandler.Packet{ + sess.sentPacketHandler.SentPacket(&ackhandler.Packet{ PacketNumber: n, Length: 1, Frames: []wire.Frame{&wire.StreamFrame{ @@ -1099,7 +1099,6 @@ var _ = Describe("Session", func() { }}, EncryptionLevel: protocol.EncryptionForwardSecure, }) - Expect(err).NotTo(HaveOccurred()) go sess.run() defer sess.Close(nil) sess.scheduleSending()