From 2d64953e0ec5af764146cd4ee5bb7b16f8dd064f Mon Sep 17 00:00:00 2001 From: Marten Seemann Date: Mon, 5 Mar 2018 17:08:18 +0700 Subject: [PATCH] remove incorrect error check when sending a packet There's no need for a check if more than protocol.MaxTrackedSentPackets packets were sent. There are certain situations where we allow (via SendingAllowed()) sending of more packets, and we shouldn't throw an error when the session then actually sends these packets. --- internal/ackhandler/interfaces.go | 2 +- internal/ackhandler/sent_packet_handler.go | 10 +- .../ackhandler/sent_packet_handler_test.go | 121 ++++++------------ .../mocks/ackhandler/sent_packet_handler.go | 6 +- session.go | 5 +- session_test.go | 3 +- 6 files changed, 43 insertions(+), 104 deletions(-) diff --git a/internal/ackhandler/interfaces.go b/internal/ackhandler/interfaces.go index 09b2c0172..9ec702165 100644 --- a/internal/ackhandler/interfaces.go +++ b/internal/ackhandler/interfaces.go @@ -10,7 +10,7 @@ import ( // SentPacketHandler handles ACKs received for outgoing packets type SentPacketHandler interface { // SentPacket may modify the packet - SentPacket(packet *Packet) error + SentPacket(packet *Packet) ReceivedAck(ackFrame *wire.AckFrame, withPacketNumber protocol.PacketNumber, encLevel protocol.EncryptionLevel, recvTime time.Time) error SetHandshakeComplete() diff --git a/internal/ackhandler/sent_packet_handler.go b/internal/ackhandler/sent_packet_handler.go index 11fa2f911..786d42f50 100644 --- a/internal/ackhandler/sent_packet_handler.go +++ b/internal/ackhandler/sent_packet_handler.go @@ -112,14 +112,9 @@ func (h *sentPacketHandler) SetHandshakeComplete() { h.handshakeComplete = true } -func (h *sentPacketHandler) SentPacket(packet *Packet) error { - if protocol.PacketNumber(len(h.retransmissionQueue)+h.packetHistory.Len()+1) > protocol.MaxTrackedSentPackets { - return errors.New("Too many outstanding non-acked and non-retransmitted packets") - } - +func (h *sentPacketHandler) SentPacket(packet *Packet) { for p := h.lastSentPacketNumber + 1; p < packet.PacketNumber; p++ { h.skippedPackets = append(h.skippedPackets, p) - if len(h.skippedPackets) > protocol.MaxTrackedSkippedPackets { h.skippedPackets = h.skippedPackets[1:] } @@ -144,7 +139,6 @@ func (h *sentPacketHandler) SentPacket(packet *Packet) error { h.bytesInFlight += packet.Length h.packetHistory.PushBack(*packet) } - h.congestion.OnPacketSent( now, h.bytesInFlight, @@ -154,9 +148,7 @@ func (h *sentPacketHandler) SentPacket(packet *Packet) error { ) h.nextPacketSendTime = utils.MaxTime(h.nextPacketSendTime, now).Add(h.congestion.TimeUntilSend(h.bytesInFlight)) - h.updateLossDetectionAlarm(now) - return nil } func (h *sentPacketHandler) ReceivedAck(ackFrame *wire.AckFrame, withPacketNumber protocol.PacketNumber, encLevel protocol.EncryptionLevel, rcvTime time.Time) error { diff --git a/internal/ackhandler/sent_packet_handler_test.go b/internal/ackhandler/sent_packet_handler_test.go index 5418b0007..7af10d186 100644 --- a/internal/ackhandler/sent_packet_handler_test.go +++ b/internal/ackhandler/sent_packet_handler_test.go @@ -68,10 +68,8 @@ var _ = Describe("SentPacketHandler", func() { It("accepts two consecutive packets", func() { packet1 := Packet{PacketNumber: 1, Frames: []wire.Frame{&streamFrame}, Length: 1} packet2 := Packet{PacketNumber: 2, Frames: []wire.Frame{&streamFrame}, Length: 2} - err := handler.SentPacket(&packet1) - Expect(err).ToNot(HaveOccurred()) - err = handler.SentPacket(&packet2) - Expect(err).ToNot(HaveOccurred()) + handler.SentPacket(&packet1) + handler.SentPacket(&packet2) Expect(handler.lastSentPacketNumber).To(Equal(protocol.PacketNumber(2))) Expect(handler.packetHistory.Front().Value.PacketNumber).To(Equal(protocol.PacketNumber(1))) Expect(handler.packetHistory.Back().Value.PacketNumber).To(Equal(protocol.PacketNumber(2))) @@ -82,11 +80,9 @@ var _ = Describe("SentPacketHandler", func() { It("accepts packet number 0", func() { packet1 := Packet{PacketNumber: 0, Frames: []wire.Frame{&streamFrame}, Length: 1} packet2 := Packet{PacketNumber: 1, Frames: []wire.Frame{&streamFrame}, Length: 2} - err := handler.SentPacket(&packet1) - Expect(err).ToNot(HaveOccurred()) + handler.SentPacket(&packet1) Expect(handler.lastSentPacketNumber).To(BeZero()) - err = handler.SentPacket(&packet2) - Expect(err).ToNot(HaveOccurred()) + handler.SentPacket(&packet2) Expect(handler.lastSentPacketNumber).To(Equal(protocol.PacketNumber(1))) Expect(handler.packetHistory.Front().Value.PacketNumber).To(Equal(protocol.PacketNumber(0))) Expect(handler.packetHistory.Back().Value.PacketNumber).To(Equal(protocol.PacketNumber(1))) @@ -96,14 +92,12 @@ var _ = Describe("SentPacketHandler", func() { It("stores the sent time", func() { packet := Packet{PacketNumber: 1, Frames: []wire.Frame{&streamFrame}, Length: 1} - err := handler.SentPacket(&packet) - Expect(err).ToNot(HaveOccurred()) + handler.SentPacket(&packet) Expect(handler.packetHistory.Front().Value.sendTime.Unix()).To(BeNumerically("~", time.Now().Unix(), 1)) }) It("does not store non-retransmittable packets", func() { - err := handler.SentPacket(&Packet{PacketNumber: 1, Length: 1}) - Expect(err).ToNot(HaveOccurred()) + handler.SentPacket(&Packet{PacketNumber: 1, Length: 1}) Expect(handler.packetHistory.Len()).To(BeZero()) }) @@ -111,10 +105,8 @@ var _ = Describe("SentPacketHandler", func() { It("works with non-consecutive packet numbers", func() { packet1 := Packet{PacketNumber: 1, Frames: []wire.Frame{&streamFrame}, Length: 1} packet2 := Packet{PacketNumber: 3, Frames: []wire.Frame{&streamFrame}, Length: 2} - err := handler.SentPacket(&packet1) - Expect(err).ToNot(HaveOccurred()) - err = handler.SentPacket(&packet2) - Expect(err).ToNot(HaveOccurred()) + handler.SentPacket(&packet1) + handler.SentPacket(&packet2) Expect(handler.lastSentPacketNumber).To(Equal(protocol.PacketNumber(3))) el := handler.packetHistory.Front() Expect(el.Value.PacketNumber).To(Equal(protocol.PacketNumber(1))) @@ -129,12 +121,9 @@ var _ = Describe("SentPacketHandler", func() { packet1 := Packet{PacketNumber: 1, Frames: []wire.Frame{&streamFrame}, Length: 1} packet2 := Packet{PacketNumber: 3, Frames: []wire.Frame{&streamFrame}, Length: 2} packet3 := Packet{PacketNumber: 5, Frames: []wire.Frame{&streamFrame}, Length: 2} - err := handler.SentPacket(&packet1) - Expect(err).ToNot(HaveOccurred()) - err = handler.SentPacket(&packet2) - Expect(err).ToNot(HaveOccurred()) - err = handler.SentPacket(&packet3) - Expect(err).ToNot(HaveOccurred()) + handler.SentPacket(&packet1) + handler.SentPacket(&packet2) + handler.SentPacket(&packet3) Expect(handler.skippedPackets).To(HaveLen(2)) Expect(handler.skippedPackets).To(Equal([]protocol.PacketNumber{2, 4})) }) @@ -142,10 +131,8 @@ var _ = Describe("SentPacketHandler", func() { It("recognizes multiple consecutive skipped packets", func() { packet1 := Packet{PacketNumber: 1, Frames: []wire.Frame{&streamFrame}, Length: 1} packet2 := Packet{PacketNumber: 4, Frames: []wire.Frame{&streamFrame}, Length: 2} - err := handler.SentPacket(&packet1) - Expect(err).ToNot(HaveOccurred()) - err = handler.SentPacket(&packet2) - Expect(err).ToNot(HaveOccurred()) + handler.SentPacket(&packet1) + handler.SentPacket(&packet2) Expect(handler.skippedPackets).To(HaveLen(2)) Expect(handler.skippedPackets).To(Equal([]protocol.PacketNumber{2, 3})) }) @@ -153,8 +140,7 @@ var _ = Describe("SentPacketHandler", func() { It("limits the lengths of the skipped packet slice", func() { for i := 0; i < protocol.MaxTrackedSkippedPackets+5; i++ { packet := Packet{PacketNumber: protocol.PacketNumber(2*i + 1), Frames: []wire.Frame{&streamFrame}, Length: 1} - err := handler.SentPacket(&packet) - Expect(err).ToNot(HaveOccurred()) + handler.SentPacket(&packet) } Expect(handler.skippedPackets).To(HaveLen(protocol.MaxUndecryptablePackets)) Expect(handler.skippedPackets[0]).To(Equal(protocol.PacketNumber(10))) @@ -186,20 +172,6 @@ var _ = Describe("SentPacketHandler", func() { }) }) - Context("DoS mitigation", func() { - It("checks the size of the packet history, for unacked packets", func() { - i := protocol.PacketNumber(1) - for ; i <= protocol.MaxTrackedSentPackets; i++ { - err := handler.SentPacket(retransmittablePacket(i)) - Expect(err).ToNot(HaveOccurred()) - } - err := handler.SentPacket(retransmittablePacket(i)) - Expect(err).To(MatchError("Too many outstanding non-acked and non-retransmitted packets")) - }) - - // TODO: add a test that the length of the retransmission queue is considered, even if packets have already been ACKed. Relevant once we drop support for QUIC 33 and earlier - }) - Context("ACK processing", func() { var packets []*Packet @@ -219,8 +191,7 @@ var _ = Describe("SentPacketHandler", func() { {PacketNumber: 12, Frames: []wire.Frame{&streamFrame}, Length: 1}, } for _, packet := range packets { - err := handler.SentPacket(packet) - Expect(err).NotTo(HaveOccurred()) + handler.SentPacket(packet) } // Increase RTT, because the tests would be flaky otherwise handler.rttStats.UpdateRTT(time.Hour, 0, time.Now()) @@ -327,7 +298,7 @@ var _ = Describe("SentPacketHandler", func() { }) It("rejects an ACK that acks packets with a higher encryption level", func() { - err := handler.SentPacket(&Packet{ + handler.SentPacket(&Packet{ PacketNumber: 13, EncryptionLevel: protocol.EncryptionForwardSecure, Frames: []wire.Frame{&streamFrame}, @@ -337,8 +308,7 @@ var _ = Describe("SentPacketHandler", func() { LargestAcked: 13, LowestAcked: 13, } - Expect(err).ToNot(HaveOccurred()) - err = handler.ReceivedAck(&ack, 1, protocol.EncryptionSecure, time.Now()) + err := handler.ReceivedAck(&ack, 1, protocol.EncryptionSecure, time.Now()) Expect(err).To(MatchError("Received ACK with encryption level encrypted (not forward-secure) that acks a packet 13 (encryption level forward-secure)")) }) @@ -518,8 +488,7 @@ var _ = Describe("SentPacketHandler", func() { {PacketNumber: 15, Frames: []wire.Frame{&streamFrame}, Length: 1}, } for _, packet := range morePackets { - err := handler.SentPacket(packet) - Expect(err).NotTo(HaveOccurred()) + handler.SentPacket(packet) } }) @@ -625,14 +594,11 @@ var _ = Describe("SentPacketHandler", func() { packet1 := Packet{PacketNumber: 1, Frames: []wire.Frame{&streamFrame}, Length: 1} packet2 := Packet{PacketNumber: 2, Frames: []wire.Frame{&streamFrame}, Length: 2} packet3 := Packet{PacketNumber: 3, Frames: []wire.Frame{&streamFrame}, Length: 3} - err := handler.SentPacket(&packet1) - Expect(err).NotTo(HaveOccurred()) + handler.SentPacket(&packet1) Expect(handler.bytesInFlight).To(Equal(protocol.ByteCount(1))) - err = handler.SentPacket(&packet2) - Expect(err).NotTo(HaveOccurred()) + handler.SentPacket(&packet2) Expect(handler.bytesInFlight).To(Equal(protocol.ByteCount(1 + 2))) - err = handler.SentPacket(&packet3) - Expect(err).NotTo(HaveOccurred()) + handler.SentPacket(&packet3) Expect(handler.bytesInFlight).To(Equal(protocol.ByteCount(1 + 2 + 3))) // Increase RTT, because the tests would be flaky otherwise @@ -647,8 +613,7 @@ var _ = Describe("SentPacketHandler", func() { {First: 1, Last: 1}, }, } - err = handler.ReceivedAck(&ack, 1, protocol.EncryptionUnencrypted, time.Now()) - Expect(err).NotTo(HaveOccurred()) + handler.ReceivedAck(&ack, 1, protocol.EncryptionUnencrypted, time.Now()) Expect(handler.bytesInFlight).To(Equal(protocol.ByteCount(2))) handler.packetHistory.Front().Value.sendTime = time.Now().Add(-time.Hour) @@ -680,8 +645,7 @@ var _ = Describe("SentPacketHandler", func() { Length: 42, Frames: []wire.Frame{&wire.PingFrame{}}, } - err := handler.SentPacket(p) - Expect(err).NotTo(HaveOccurred()) + handler.SentPacket(p) }) It("should call MaybeExitSlowStart and OnPacketAcked", func() { @@ -805,13 +769,11 @@ var _ = Describe("SentPacketHandler", func() { Context("Delay-based loss detection", func() { It("detects a packet as lost", func() { - err := handler.SentPacket(retransmittablePacket(1)) - Expect(err).NotTo(HaveOccurred()) - err = handler.SentPacket(retransmittablePacket(2)) - Expect(err).NotTo(HaveOccurred()) + handler.SentPacket(retransmittablePacket(1)) + handler.SentPacket(retransmittablePacket(2)) Expect(handler.lossTime.IsZero()).To(BeTrue()) - err = handler.ReceivedAck(&wire.AckFrame{LargestAcked: 2, LowestAcked: 2}, 1, protocol.EncryptionForwardSecure, time.Now().Add(time.Hour)) + err := handler.ReceivedAck(&wire.AckFrame{LargestAcked: 2, LowestAcked: 2}, 1, protocol.EncryptionForwardSecure, time.Now().Add(time.Hour)) Expect(err).NotTo(HaveOccurred()) Expect(handler.lossTime.IsZero()).To(BeFalse()) @@ -826,15 +788,12 @@ var _ = Describe("SentPacketHandler", func() { }) It("does not detect packets as lost without ACKs", func() { - err := handler.SentPacket(nonRetransmittablePacket(1)) - Expect(err).NotTo(HaveOccurred()) - err = handler.SentPacket(retransmittablePacket(2)) - Expect(err).NotTo(HaveOccurred()) - err = handler.SentPacket(retransmittablePacket(3)) - Expect(err).NotTo(HaveOccurred()) + handler.SentPacket(nonRetransmittablePacket(1)) + handler.SentPacket(retransmittablePacket(2)) + handler.SentPacket(retransmittablePacket(3)) Expect(handler.lossTime.IsZero()).To(BeTrue()) - err = handler.ReceivedAck(&wire.AckFrame{LargestAcked: 1, LowestAcked: 1}, 1, protocol.EncryptionUnencrypted, time.Now().Add(time.Hour)) + err := handler.ReceivedAck(&wire.AckFrame{LargestAcked: 1, LowestAcked: 1}, 1, protocol.EncryptionUnencrypted, time.Now().Add(time.Hour)) Expect(err).NotTo(HaveOccurred()) Expect(handler.lossTime.IsZero()).To(BeTrue()) Expect(time.Until(handler.GetAlarmTimeout())).To(BeNumerically("~", handler.computeRTOTimeout(), time.Minute)) @@ -855,16 +814,12 @@ var _ = Describe("SentPacketHandler", func() { It("detects the handshake timeout", func() { // send handshake packets: 1, 2, 4 // send a forward-secure packet: 3 - err := handler.SentPacket(handshakePacket(1)) - Expect(err).ToNot(HaveOccurred()) - err = handler.SentPacket(handshakePacket(2)) - Expect(err).ToNot(HaveOccurred()) - err = handler.SentPacket(retransmittablePacket(3)) - Expect(err).ToNot(HaveOccurred()) - err = handler.SentPacket(handshakePacket(4)) - Expect(err).ToNot(HaveOccurred()) + handler.SentPacket(handshakePacket(1)) + handler.SentPacket(handshakePacket(2)) + handler.SentPacket(retransmittablePacket(3)) + handler.SentPacket(handshakePacket(4)) - err = handler.ReceivedAck(&wire.AckFrame{LargestAcked: 1, LowestAcked: 1}, 1, protocol.EncryptionSecure, time.Now()) + err := handler.ReceivedAck(&wire.AckFrame{LargestAcked: 1, LowestAcked: 1}, 1, protocol.EncryptionSecure, time.Now()) Expect(err).NotTo(HaveOccurred()) Expect(handler.lossTime.IsZero()).To(BeTrue()) handshakeTimeout := handler.computeHandshakeTimeout() @@ -887,10 +842,8 @@ var _ = Describe("SentPacketHandler", func() { Context("RTO retransmission", func() { It("queues two packets if RTO expires", func() { - err := handler.SentPacket(retransmittablePacket(1)) - Expect(err).NotTo(HaveOccurred()) - err = handler.SentPacket(retransmittablePacket(2)) - Expect(err).NotTo(HaveOccurred()) + handler.SentPacket(retransmittablePacket(1)) + handler.SentPacket(retransmittablePacket(2)) handler.rttStats.UpdateRTT(time.Hour, 0, time.Now()) Expect(handler.lossTime.IsZero()).To(BeTrue()) diff --git a/internal/mocks/ackhandler/sent_packet_handler.go b/internal/mocks/ackhandler/sent_packet_handler.go index b28a970ad..80bb62c41 100644 --- a/internal/mocks/ackhandler/sent_packet_handler.go +++ b/internal/mocks/ackhandler/sent_packet_handler.go @@ -132,10 +132,8 @@ func (mr *MockSentPacketHandlerMockRecorder) SendingAllowed() *gomock.Call { } // SentPacket mocks base method -func (m *MockSentPacketHandler) SentPacket(arg0 *ackhandler.Packet) error { - ret := m.ctrl.Call(m, "SentPacket", arg0) - ret0, _ := ret[0].(error) - return ret0 +func (m *MockSentPacketHandler) SentPacket(arg0 *ackhandler.Packet) { + m.ctrl.Call(m, "SentPacket", arg0) } // SentPacket indicates an expected call of SentPacket diff --git a/session.go b/session.go index 31e2296fe..67d19480c 100644 --- a/session.go +++ b/session.go @@ -877,16 +877,13 @@ func (s *session) sendPacket() (bool, error) { func (s *session) sendPackedPacket(packet *packedPacket) error { defer putPacketBuffer(&packet.raw) - err := s.sentPacketHandler.SentPacket(&ackhandler.Packet{ + s.sentPacketHandler.SentPacket(&ackhandler.Packet{ PacketNumber: packet.header.PacketNumber, PacketType: packet.header.Type, Frames: packet.frames, Length: protocol.ByteCount(len(packet.raw)), EncryptionLevel: packet.encryptionLevel, }) - if err != nil { - return err - } s.logPacket(packet) return s.conn.Write(packet.raw) } diff --git a/session_test.go b/session_test.go index 97f5780f9..3bfaa1529 100644 --- a/session_test.go +++ b/session_test.go @@ -1091,7 +1091,7 @@ var _ = Describe("Session", func() { Expect(sess.rttStats.SmoothedRTT()).To(Equal(rtt)) // make sure it worked sess.packer.packetNumberGenerator.next = n + 1 // Now, we send a single packet, and expect that it was retransmitted later - err := sess.sentPacketHandler.SentPacket(&ackhandler.Packet{ + sess.sentPacketHandler.SentPacket(&ackhandler.Packet{ PacketNumber: n, Length: 1, Frames: []wire.Frame{&wire.StreamFrame{ @@ -1099,7 +1099,6 @@ var _ = Describe("Session", func() { }}, EncryptionLevel: protocol.EncryptionForwardSecure, }) - Expect(err).NotTo(HaveOccurred()) go sess.run() defer sess.Close(nil) sess.scheduleSending()