diff --git a/internal/ackhandler/interfaces.go b/internal/ackhandler/interfaces.go index ac74f855..5777d97a 100644 --- a/internal/ackhandler/interfaces.go +++ b/internal/ackhandler/interfaces.go @@ -27,7 +27,7 @@ type Packet struct { type SentPacketHandler interface { // SentPacket may modify the packet SentPacket(packet *Packet) - ReceivedAck(ackFrame *wire.AckFrame, encLevel protocol.EncryptionLevel, recvTime time.Time) error + ReceivedAck(ackFrame *wire.AckFrame, encLevel protocol.EncryptionLevel, recvTime time.Time) (bool /* 1-RTT packet acked */, error) ReceivedBytes(protocol.ByteCount) DropPackets(protocol.EncryptionLevel) ResetForRetry() error diff --git a/internal/ackhandler/sent_packet_handler.go b/internal/ackhandler/sent_packet_handler.go index 3c4eab27..bbd1fb44 100644 --- a/internal/ackhandler/sent_packet_handler.go +++ b/internal/ackhandler/sent_packet_handler.go @@ -277,12 +277,12 @@ func (h *sentPacketHandler) sentPacketImpl(packet *Packet) bool /* is ack-elicit return isAckEliciting } -func (h *sentPacketHandler) ReceivedAck(ack *wire.AckFrame, encLevel protocol.EncryptionLevel, rcvTime time.Time) error { +func (h *sentPacketHandler) ReceivedAck(ack *wire.AckFrame, encLevel protocol.EncryptionLevel, rcvTime time.Time) (bool /* contained 1-RTT packet */, error) { pnSpace := h.getPacketNumberSpace(encLevel) largestAcked := ack.LargestAcked() if largestAcked > pnSpace.largestSent { - return qerr.NewError(qerr.ProtocolViolation, "Received ACK for an unsent packet") + return false, qerr.NewError(qerr.ProtocolViolation, "Received ACK for an unsent packet") } pnSpace.largestAcked = utils.MaxPacketNumber(pnSpace.largestAcked, largestAcked) @@ -299,7 +299,7 @@ func (h *sentPacketHandler) ReceivedAck(ack *wire.AckFrame, encLevel protocol.En priorInFlight := h.bytesInFlight ackedPackets, err := h.detectAndRemoveAckedPackets(ack, encLevel) if err != nil || len(ackedPackets) == 0 { - return err + return false, err } // update the RTT, if the largest acked is newly acknowledged if len(ackedPackets) > 0 { @@ -317,12 +317,16 @@ func (h *sentPacketHandler) ReceivedAck(ack *wire.AckFrame, encLevel protocol.En } } if err := h.detectLostPackets(rcvTime, encLevel); err != nil { - return err + return false, err } + var acked1RTTPacket bool for _, p := range ackedPackets { if p.includedInBytesInFlight && !p.declaredLost { h.congestion.OnPacketAcked(p.PacketNumber, p.Length, priorInFlight, rcvTime) } + if p.EncryptionLevel == protocol.Encryption1RTT { + acked1RTTPacket = true + } h.removeFromBytesInFlight(p) } @@ -341,7 +345,7 @@ func (h *sentPacketHandler) ReceivedAck(ack *wire.AckFrame, encLevel protocol.En pnSpace.history.DeleteOldPackets(rcvTime) h.setLossDetectionTimer() - return nil + return acked1RTTPacket, nil } func (h *sentPacketHandler) GetLowestPacketNotConfirmedAcked() protocol.PacketNumber { diff --git a/internal/ackhandler/sent_packet_handler_test.go b/internal/ackhandler/sent_packet_handler_test.go index fef39e1b..20fa985f 100644 --- a/internal/ackhandler/sent_packet_handler_test.go +++ b/internal/ackhandler/sent_packet_handler_test.go @@ -153,21 +153,44 @@ var _ = Describe("SentPacketHandler", func() { Expect(handler.bytesInFlight).To(Equal(protocol.ByteCount(10))) }) - Context("ACK validation", func() { + Context("ACK processing", func() { It("accepts ACKs sent in packet 0", func() { ack := &wire.AckFrame{AckRanges: []wire.AckRange{{Smallest: 0, Largest: 5}}} - Expect(handler.ReceivedAck(ack, protocol.Encryption1RTT, time.Now())).To(Succeed()) + _, err := handler.ReceivedAck(ack, protocol.Encryption1RTT, time.Now()) + Expect(err).ToNot(HaveOccurred()) Expect(handler.appDataPackets.largestAcked).To(Equal(protocol.PacketNumber(5))) }) + It("says if a 1-RTT packet was acknowledged", func() { + handler.SentPacket(ackElicitingPacket(&Packet{PacketNumber: 100, EncryptionLevel: protocol.Encryption0RTT})) + handler.SentPacket(ackElicitingPacket(&Packet{PacketNumber: 101, EncryptionLevel: protocol.Encryption0RTT})) + handler.SentPacket(ackElicitingPacket(&Packet{PacketNumber: 102, EncryptionLevel: protocol.Encryption1RTT})) + acked1RTT, err := handler.ReceivedAck( + &wire.AckFrame{AckRanges: []wire.AckRange{{Smallest: 100, Largest: 101}}}, + protocol.Encryption1RTT, + time.Now(), + ) + Expect(err).ToNot(HaveOccurred()) + Expect(acked1RTT).To(BeFalse()) + acked1RTT, err = handler.ReceivedAck( + &wire.AckFrame{AckRanges: []wire.AckRange{{Smallest: 101, Largest: 102}}}, + protocol.Encryption1RTT, + time.Now(), + ) + Expect(err).ToNot(HaveOccurred()) + Expect(acked1RTT).To(BeTrue()) + }) + It("accepts multiple ACKs sent in the same packet", func() { ack1 := &wire.AckFrame{AckRanges: []wire.AckRange{{Smallest: 0, Largest: 3}}} ack2 := &wire.AckFrame{AckRanges: []wire.AckRange{{Smallest: 0, Largest: 4}}} - Expect(handler.ReceivedAck(ack1, protocol.Encryption1RTT, time.Now())).To(Succeed()) + _, err := handler.ReceivedAck(ack1, protocol.Encryption1RTT, time.Now()) + Expect(err).ToNot(HaveOccurred()) Expect(handler.appDataPackets.largestAcked).To(Equal(protocol.PacketNumber(3))) // this wouldn't happen in practice // for testing purposes, we pretend send a different ACK frame in a duplicated packet, to be able to verify that it actually doesn't get processed - Expect(handler.ReceivedAck(ack2, protocol.Encryption1RTT, time.Now())).To(Succeed()) + _, err = handler.ReceivedAck(ack2, protocol.Encryption1RTT, time.Now()) + Expect(err).ToNot(HaveOccurred()) Expect(handler.appDataPackets.largestAcked).To(Equal(protocol.PacketNumber(4))) }) @@ -175,20 +198,24 @@ var _ = Describe("SentPacketHandler", func() { handler.SentPacket(ackElicitingPacket(&Packet{PacketNumber: 100})) handler.SentPacket(ackElicitingPacket(&Packet{PacketNumber: 102})) ack := &wire.AckFrame{AckRanges: []wire.AckRange{{Smallest: 100, Largest: 102}}} - Expect(handler.ReceivedAck(ack, protocol.Encryption1RTT, time.Now())).To(MatchError("received an ACK for skipped packet number: 101 (1-RTT)")) + _, err := handler.ReceivedAck(ack, protocol.Encryption1RTT, time.Now()) + Expect(err).To(MatchError("received an ACK for skipped packet number: 101 (1-RTT)")) }) It("rejects ACKs with a too high LargestAcked packet number", func() { ack := &wire.AckFrame{AckRanges: []wire.AckRange{{Smallest: 0, Largest: 9999}}} - Expect(handler.ReceivedAck(ack, protocol.Encryption1RTT, time.Now())).To(MatchError("PROTOCOL_VIOLATION: Received ACK for an unsent packet")) + _, err := handler.ReceivedAck(ack, protocol.Encryption1RTT, time.Now()) + Expect(err).To(MatchError("PROTOCOL_VIOLATION: Received ACK for an unsent packet")) Expect(handler.bytesInFlight).To(Equal(protocol.ByteCount(10))) }) It("ignores repeated ACKs", func() { ack := &wire.AckFrame{AckRanges: []wire.AckRange{{Smallest: 0, Largest: 3}}} - Expect(handler.ReceivedAck(ack, protocol.Encryption1RTT, time.Now())).To(Succeed()) + _, err := handler.ReceivedAck(ack, protocol.Encryption1RTT, time.Now()) + Expect(err).ToNot(HaveOccurred()) Expect(handler.bytesInFlight).To(Equal(protocol.ByteCount(6))) - Expect(handler.ReceivedAck(ack, protocol.Encryption1RTT, time.Now())).To(Succeed()) + _, err = handler.ReceivedAck(ack, protocol.Encryption1RTT, time.Now()) + Expect(err).ToNot(HaveOccurred()) Expect(handler.appDataPackets.largestAcked).To(Equal(protocol.PacketNumber(3))) Expect(handler.bytesInFlight).To(Equal(protocol.ByteCount(6))) }) @@ -221,7 +248,8 @@ var _ = Describe("SentPacketHandler", func() { It("adjusts the LargestAcked, and adjusts the bytes in flight", func() { ack := &wire.AckFrame{AckRanges: []wire.AckRange{{Smallest: 0, Largest: 5}}} - Expect(handler.ReceivedAck(ack, protocol.Encryption1RTT, time.Now())).To(Succeed()) + _, err := handler.ReceivedAck(ack, protocol.Encryption1RTT, time.Now()) + Expect(err).ToNot(HaveOccurred()) Expect(handler.appDataPackets.largestAcked).To(Equal(protocol.PacketNumber(5))) expectInPacketHistoryOrLost([]protocol.PacketNumber{6, 7, 8, 9}, protocol.Encryption1RTT) Expect(handler.bytesInFlight).To(Equal(protocol.ByteCount(4))) @@ -229,7 +257,8 @@ var _ = Describe("SentPacketHandler", func() { It("acks packet 0", func() { ack := &wire.AckFrame{AckRanges: []wire.AckRange{{Smallest: 0, Largest: 0}}} - Expect(handler.ReceivedAck(ack, protocol.Encryption1RTT, time.Now())).To(Succeed()) + _, err := handler.ReceivedAck(ack, protocol.Encryption1RTT, time.Now()) + Expect(err).ToNot(HaveOccurred()) Expect(getPacket(0, protocol.Encryption1RTT)).To(BeNil()) expectInPacketHistoryOrLost([]protocol.PacketNumber{1, 2, 3, 4, 5, 6, 7, 8, 9}, protocol.Encryption1RTT) }) @@ -247,7 +276,8 @@ var _ = Describe("SentPacketHandler", func() { }}, })) ack := &wire.AckFrame{AckRanges: []wire.AckRange{{Smallest: 13, Largest: 13}}} - Expect(handler.ReceivedAck(ack, protocol.Encryption1RTT, time.Now())).To(Succeed()) + _, err := handler.ReceivedAck(ack, protocol.Encryption1RTT, time.Now()) + Expect(err).ToNot(HaveOccurred()) Expect(acked).To(BeTrue()) }) @@ -258,13 +288,15 @@ var _ = Describe("SentPacketHandler", func() { {Smallest: 1, Largest: 3}, }, } - Expect(handler.ReceivedAck(ack, protocol.Encryption1RTT, time.Now())).To(Succeed()) + _, err := handler.ReceivedAck(ack, protocol.Encryption1RTT, time.Now()) + Expect(err).ToNot(HaveOccurred()) expectInPacketHistoryOrLost([]protocol.PacketNumber{0, 4, 5}, protocol.Encryption1RTT) }) It("does not ack packets below the LowestAcked", func() { ack := &wire.AckFrame{AckRanges: []wire.AckRange{{Smallest: 3, Largest: 8}}} - Expect(handler.ReceivedAck(ack, protocol.Encryption1RTT, time.Now())).To(Succeed()) + _, err := handler.ReceivedAck(ack, protocol.Encryption1RTT, time.Now()) + Expect(err).ToNot(HaveOccurred()) expectInPacketHistoryOrLost([]protocol.PacketNumber{0, 1, 2, 9}, protocol.Encryption1RTT) }) @@ -277,7 +309,8 @@ var _ = Describe("SentPacketHandler", func() { {Smallest: 1, Largest: 1}, }, } - Expect(handler.ReceivedAck(ack, protocol.Encryption1RTT, time.Now())).To(Succeed()) + _, err := handler.ReceivedAck(ack, protocol.Encryption1RTT, time.Now()) + Expect(err).ToNot(HaveOccurred()) expectInPacketHistoryOrLost([]protocol.PacketNumber{0, 2, 4, 5, 8}, protocol.Encryption1RTT) }) @@ -288,16 +321,19 @@ var _ = Describe("SentPacketHandler", func() { {Smallest: 1, Largest: 4}, }, } - Expect(handler.ReceivedAck(ack1, protocol.Encryption1RTT, time.Now())).To(Succeed()) + _, err := handler.ReceivedAck(ack1, protocol.Encryption1RTT, time.Now()) + Expect(err).ToNot(HaveOccurred()) expectInPacketHistoryOrLost([]protocol.PacketNumber{0, 5, 7, 8, 9}, protocol.Encryption1RTT) ack2 := &wire.AckFrame{AckRanges: []wire.AckRange{{Smallest: 1, Largest: 6}}} // now ack 5 - Expect(handler.ReceivedAck(ack2, protocol.Encryption1RTT, time.Now())).To(Succeed()) + _, err = handler.ReceivedAck(ack2, protocol.Encryption1RTT, time.Now()) + Expect(err).ToNot(HaveOccurred()) expectInPacketHistoryOrLost([]protocol.PacketNumber{0, 7, 8, 9}, protocol.Encryption1RTT) }) It("processes an ACK that contains old ACK ranges", func() { ack1 := &wire.AckFrame{AckRanges: []wire.AckRange{{Smallest: 1, Largest: 6}}} - Expect(handler.ReceivedAck(ack1, protocol.Encryption1RTT, time.Now())).To(Succeed()) + _, err := handler.ReceivedAck(ack1, protocol.Encryption1RTT, time.Now()) + Expect(err).ToNot(HaveOccurred()) expectInPacketHistoryOrLost([]protocol.PacketNumber{0, 7, 8, 9}, protocol.Encryption1RTT) ack2 := &wire.AckFrame{ AckRanges: []wire.AckRange{ @@ -306,7 +342,8 @@ var _ = Describe("SentPacketHandler", func() { {Smallest: 1, Largest: 1}, }, } - Expect(handler.ReceivedAck(ack2, protocol.Encryption1RTT, time.Now())).To(Succeed()) + _, err = handler.ReceivedAck(ack2, protocol.Encryption1RTT, time.Now()) + Expect(err).ToNot(HaveOccurred()) expectInPacketHistoryOrLost([]protocol.PacketNumber{0, 7, 9}, protocol.Encryption1RTT) }) }) @@ -320,13 +357,16 @@ var _ = Describe("SentPacketHandler", func() { getPacket(6, protocol.Encryption1RTT).SendTime = now.Add(-1 * time.Minute) // Now, check that the proper times are used when calculating the deltas ack := &wire.AckFrame{AckRanges: []wire.AckRange{{Smallest: 1, Largest: 1}}} - Expect(handler.ReceivedAck(ack, protocol.Encryption1RTT, time.Now())).To(Succeed()) + _, err := handler.ReceivedAck(ack, protocol.Encryption1RTT, time.Now()) + Expect(err).ToNot(HaveOccurred()) Expect(handler.rttStats.LatestRTT()).To(BeNumerically("~", 10*time.Minute, 1*time.Second)) ack = &wire.AckFrame{AckRanges: []wire.AckRange{{Smallest: 1, Largest: 2}}} - Expect(handler.ReceivedAck(ack, protocol.Encryption1RTT, time.Now())).To(Succeed()) + _, err = handler.ReceivedAck(ack, protocol.Encryption1RTT, time.Now()) + Expect(err).ToNot(HaveOccurred()) Expect(handler.rttStats.LatestRTT()).To(BeNumerically("~", 5*time.Minute, 1*time.Second)) ack = &wire.AckFrame{AckRanges: []wire.AckRange{{Smallest: 1, Largest: 6}}} - Expect(handler.ReceivedAck(ack, protocol.Encryption1RTT, time.Now())).To(Succeed()) + _, err = handler.ReceivedAck(ack, protocol.Encryption1RTT, time.Now()) + Expect(err).ToNot(HaveOccurred()) Expect(handler.rttStats.LatestRTT()).To(BeNumerically("~", 1*time.Minute, 1*time.Second)) }) @@ -340,7 +380,8 @@ var _ = Describe("SentPacketHandler", func() { AckRanges: []wire.AckRange{{Smallest: 1, Largest: 1}}, DelayTime: 5 * time.Minute, } - Expect(handler.ReceivedAck(ack, protocol.EncryptionInitial, time.Now())).To(Succeed()) + _, err := handler.ReceivedAck(ack, protocol.EncryptionInitial, time.Now()) + Expect(err).ToNot(HaveOccurred()) Expect(handler.rttStats.LatestRTT()).To(BeNumerically("~", 10*time.Minute, 1*time.Second)) }) @@ -353,7 +394,8 @@ var _ = Describe("SentPacketHandler", func() { AckRanges: []wire.AckRange{{Smallest: 1, Largest: 1}}, DelayTime: 5 * time.Minute, } - Expect(handler.ReceivedAck(ack, protocol.Encryption1RTT, time.Now())).To(Succeed()) + _, err := handler.ReceivedAck(ack, protocol.Encryption1RTT, time.Now()) + Expect(err).ToNot(HaveOccurred()) Expect(handler.rttStats.LatestRTT()).To(BeNumerically("~", 5*time.Minute, 1*time.Second)) }) @@ -366,7 +408,8 @@ var _ = Describe("SentPacketHandler", func() { AckRanges: []wire.AckRange{{Smallest: 1, Largest: 1}}, DelayTime: 5 * time.Minute, } - Expect(handler.ReceivedAck(ack, protocol.Encryption1RTT, time.Now())).To(Succeed()) + _, err := handler.ReceivedAck(ack, protocol.Encryption1RTT, time.Now()) + Expect(err).ToNot(HaveOccurred()) Expect(handler.rttStats.LatestRTT()).To(BeNumerically("~", 9*time.Minute, 1*time.Second)) }) }) @@ -402,25 +445,30 @@ var _ = Describe("SentPacketHandler", func() { It("determines which ACK we have received an ACK for", func() { ack := &wire.AckFrame{AckRanges: []wire.AckRange{{Smallest: 13, Largest: 15}}} - Expect(handler.ReceivedAck(ack, protocol.Encryption1RTT, time.Now())).To(Succeed()) + _, err := handler.ReceivedAck(ack, protocol.Encryption1RTT, time.Now()) + Expect(err).ToNot(HaveOccurred()) Expect(handler.GetLowestPacketNotConfirmedAcked()).To(Equal(protocol.PacketNumber(201))) }) It("doesn't do anything when the acked packet didn't contain an ACK", func() { ack1 := &wire.AckFrame{AckRanges: []wire.AckRange{{Smallest: 13, Largest: 13}}} ack2 := &wire.AckFrame{AckRanges: []wire.AckRange{{Smallest: 15, Largest: 15}}} - Expect(handler.ReceivedAck(ack1, protocol.Encryption1RTT, time.Now())).To(Succeed()) + _, err := handler.ReceivedAck(ack1, protocol.Encryption1RTT, time.Now()) + Expect(err).ToNot(HaveOccurred()) Expect(handler.GetLowestPacketNotConfirmedAcked()).To(Equal(protocol.PacketNumber(101))) - Expect(handler.ReceivedAck(ack2, protocol.Encryption1RTT, time.Now())).To(Succeed()) + _, err = handler.ReceivedAck(ack2, protocol.Encryption1RTT, time.Now()) + Expect(err).ToNot(HaveOccurred()) Expect(handler.GetLowestPacketNotConfirmedAcked()).To(Equal(protocol.PacketNumber(101))) }) It("doesn't decrease the value", func() { ack1 := &wire.AckFrame{AckRanges: []wire.AckRange{{Smallest: 14, Largest: 14}}} ack2 := &wire.AckFrame{AckRanges: []wire.AckRange{{Smallest: 13, Largest: 13}}} - Expect(handler.ReceivedAck(ack1, protocol.Encryption1RTT, time.Now())).To(Succeed()) + _, err := handler.ReceivedAck(ack1, protocol.Encryption1RTT, time.Now()) + Expect(err).ToNot(HaveOccurred()) Expect(handler.GetLowestPacketNotConfirmedAcked()).To(Equal(protocol.PacketNumber(201))) - Expect(handler.ReceivedAck(ack2, protocol.Encryption1RTT, time.Now())).To(Succeed()) + _, err = handler.ReceivedAck(ack2, protocol.Encryption1RTT, time.Now()) + Expect(err).ToNot(HaveOccurred()) Expect(handler.GetLowestPacketNotConfirmedAcked()).To(Equal(protocol.PacketNumber(201))) }) }) @@ -462,7 +510,8 @@ var _ = Describe("SentPacketHandler", func() { handler.SentPacket(ackElicitingPacket(&Packet{PacketNumber: 2})) handler.SentPacket(ackElicitingPacket(&Packet{PacketNumber: 3})) ack := &wire.AckFrame{AckRanges: []wire.AckRange{{Smallest: 1, Largest: 2}}} - Expect(handler.ReceivedAck(ack, protocol.Encryption1RTT, rcvTime)).To(Succeed()) + _, err := handler.ReceivedAck(ack, protocol.Encryption1RTT, rcvTime) + Expect(err).ToNot(HaveOccurred()) }) It("doesn't call OnPacketAcked when a retransmitted packet is acked", func() { @@ -476,10 +525,12 @@ var _ = Describe("SentPacketHandler", func() { cong.EXPECT().OnPacketAcked(protocol.PacketNumber(2), protocol.ByteCount(1), protocol.ByteCount(2), gomock.Any()), ) ack := &wire.AckFrame{AckRanges: []wire.AckRange{{Smallest: 2, Largest: 2}}} - Expect(handler.ReceivedAck(ack, protocol.Encryption1RTT, time.Now())).To(Succeed()) + _, err := handler.ReceivedAck(ack, protocol.Encryption1RTT, time.Now()) + Expect(err).ToNot(HaveOccurred()) // don't EXPECT any further calls to the congestion controller ack = &wire.AckFrame{AckRanges: []wire.AckRange{{Smallest: 1, Largest: 2}}} - Expect(handler.ReceivedAck(ack, protocol.Encryption1RTT, time.Now())).To(Succeed()) + _, err = handler.ReceivedAck(ack, protocol.Encryption1RTT, time.Now()) + Expect(err).ToNot(HaveOccurred()) }) It("doesn't call OnPacketLost when a Path MTU probe packet is lost", func() { @@ -498,7 +549,8 @@ var _ = Describe("SentPacketHandler", func() { cong.EXPECT().OnPacketAcked(protocol.PacketNumber(2), protocol.ByteCount(1), protocol.ByteCount(2), gomock.Any()), ) ack := &wire.AckFrame{AckRanges: []wire.AckRange{{Smallest: 2, Largest: 2}}} - Expect(handler.ReceivedAck(ack, protocol.Encryption1RTT, time.Now())).To(Succeed()) + _, err := handler.ReceivedAck(ack, protocol.Encryption1RTT, time.Now()) + Expect(err).ToNot(HaveOccurred()) Expect(mtuPacketDeclaredLost).To(BeTrue()) Expect(handler.bytesInFlight).To(BeZero()) }) @@ -516,7 +568,8 @@ var _ = Describe("SentPacketHandler", func() { cong.EXPECT().OnPacketAcked(protocol.PacketNumber(2), protocol.ByteCount(1), protocol.ByteCount(4), gomock.Any()), ) ack := &wire.AckFrame{AckRanges: []wire.AckRange{{Smallest: 2, Largest: 2}}} - Expect(handler.ReceivedAck(ack, protocol.Encryption1RTT, time.Now().Add(-30*time.Minute))).To(Succeed()) + _, err := handler.ReceivedAck(ack, protocol.Encryption1RTT, time.Now().Add(-30*time.Minute)) + Expect(err).ToNot(HaveOccurred()) // receive the second ACK gomock.InOrder( cong.EXPECT().MaybeExitSlowStart(), @@ -524,7 +577,8 @@ var _ = Describe("SentPacketHandler", func() { cong.EXPECT().OnPacketAcked(protocol.PacketNumber(4), protocol.ByteCount(1), protocol.ByteCount(2), gomock.Any()), ) ack = &wire.AckFrame{AckRanges: []wire.AckRange{{Smallest: 4, Largest: 4}}} - Expect(handler.ReceivedAck(ack, protocol.Encryption1RTT, time.Now())).To(Succeed()) + _, err = handler.ReceivedAck(ack, protocol.Encryption1RTT, time.Now()) + Expect(err).ToNot(HaveOccurred()) }) It("passes the bytes in flight to the congestion controller", func() { @@ -587,7 +641,8 @@ var _ = Describe("SentPacketHandler", func() { handler.SentPacket(ackElicitingPacket(&Packet{PacketNumber: 10})) handler.SentPacket(ackElicitingPacket(&Packet{PacketNumber: 11})) ack := &wire.AckFrame{AckRanges: []wire.AckRange{{Smallest: 10, Largest: 11}}} - Expect(handler.ReceivedAck(ack, protocol.Encryption1RTT, time.Now())).To(Succeed()) + _, err := handler.ReceivedAck(ack, protocol.Encryption1RTT, time.Now()) + Expect(err).ToNot(HaveOccurred()) Expect(handler.GetLossDetectionTimeout()).To(BeZero()) }) @@ -636,7 +691,8 @@ var _ = Describe("SentPacketHandler", func() { Expect(handler.OnLossDetectionTimeout()).To(Succeed()) Expect(handler.SendMode()).To(Equal(SendPTOAppData)) Expect(handler.ptoCount).To(BeEquivalentTo(1)) - Expect(handler.ReceivedAck(&wire.AckFrame{AckRanges: []wire.AckRange{{Smallest: 1, Largest: 1}}}, protocol.Encryption1RTT, time.Now())).To(Succeed()) + _, err := handler.ReceivedAck(&wire.AckFrame{AckRanges: []wire.AckRange{{Smallest: 1, Largest: 1}}}, protocol.Encryption1RTT, time.Now()) + Expect(err).ToNot(HaveOccurred()) Expect(handler.ptoCount).To(BeZero()) }) @@ -793,7 +849,8 @@ var _ = Describe("SentPacketHandler", func() { Expect(handler.OnLossDetectionTimeout()).To(Succeed()) Expect(handler.SendMode()).To(Equal(SendPTOAppData)) ack := &wire.AckFrame{AckRanges: []wire.AckRange{{Smallest: 1, Largest: 1}}} - Expect(handler.ReceivedAck(ack, protocol.Encryption1RTT, time.Now())).To(Succeed()) + _, err := handler.ReceivedAck(ack, protocol.Encryption1RTT, time.Now()) + Expect(err).ToNot(HaveOccurred()) Expect(handler.SendMode()).To(Equal(SendAny)) }) @@ -867,11 +924,12 @@ var _ = Describe("SentPacketHandler", func() { It("sends an Initial packet to unblock the server", func() { handler.SentPacket(initialPacket(&Packet{PacketNumber: 1})) - Expect(handler.ReceivedAck( + _, err := handler.ReceivedAck( &wire.AckFrame{AckRanges: []wire.AckRange{{Smallest: 1, Largest: 1}}}, protocol.EncryptionInitial, time.Now(), - )).To(Succeed()) + ) + Expect(err).ToNot(HaveOccurred()) // No packets are outstanding at this point. // Make sure that a probe packet is sent. Expect(handler.GetLossDetectionTimeout()).ToNot(BeZero()) @@ -885,22 +943,24 @@ var _ = Describe("SentPacketHandler", func() { // Now receive an ACK for a Handshake packet. // This tells the client that the server completed address validation. handler.SentPacket(handshakePacket(&Packet{PacketNumber: 1})) - Expect(handler.ReceivedAck( + _, err = handler.ReceivedAck( &wire.AckFrame{AckRanges: []wire.AckRange{{Smallest: 1, Largest: 1}}}, protocol.EncryptionHandshake, time.Now(), - )).To(Succeed()) + ) + Expect(err).ToNot(HaveOccurred()) // Make sure that no timer is set at this point. Expect(handler.GetLossDetectionTimeout()).To(BeZero()) }) It("sends a Handshake packet to unblock the server, if Initial keys were already dropped", func() { handler.SentPacket(initialPacket(&Packet{PacketNumber: 1})) - Expect(handler.ReceivedAck( + _, err := handler.ReceivedAck( &wire.AckFrame{AckRanges: []wire.AckRange{{Smallest: 1, Largest: 1}}}, protocol.EncryptionInitial, time.Now(), - )).To(Succeed()) + ) + Expect(err).ToNot(HaveOccurred()) handler.SentPacket(handshakePacketNonAckEliciting(&Packet{PacketNumber: 1})) // also drops Initial packets Expect(handler.GetLossDetectionTimeout()).ToNot(BeZero()) @@ -908,11 +968,12 @@ var _ = Describe("SentPacketHandler", func() { Expect(handler.SendMode()).To(Equal(SendPTOHandshake)) // Now receive an ACK for this packet, and send another one. - Expect(handler.ReceivedAck( + _, err = handler.ReceivedAck( &wire.AckFrame{AckRanges: []wire.AckRange{{Smallest: 1, Largest: 1}}}, protocol.EncryptionHandshake, time.Now(), - )).To(Succeed()) + ) + Expect(err).ToNot(HaveOccurred()) handler.SentPacket(handshakePacketNonAckEliciting(&Packet{PacketNumber: 2})) Expect(handler.GetLossDetectionTimeout()).To(BeZero()) }) @@ -929,11 +990,12 @@ var _ = Describe("SentPacketHandler", func() { It("correctly sets the timer after the Initial packet number space has been dropped", func() { handler.SentPacket(initialPacket(&Packet{PacketNumber: 1, SendTime: time.Now().Add(-42 * time.Second)})) - Expect(handler.ReceivedAck( + _, err := handler.ReceivedAck( &wire.AckFrame{AckRanges: []wire.AckRange{{Smallest: 1, Largest: 1}}}, protocol.EncryptionInitial, time.Now(), - )).To(Succeed()) + ) + Expect(err).ToNot(HaveOccurred()) handler.SentPacket(handshakePacketNonAckEliciting(&Packet{PacketNumber: 1, SendTime: time.Now()})) Expect(handler.initialPackets).To(BeNil()) @@ -950,7 +1012,8 @@ var _ = Describe("SentPacketHandler", func() { Expect(handler.OnLossDetectionTimeout()).To(Succeed()) Expect(handler.SendMode()).To(Equal(SendPTOInitial)) Expect(handler.ptoCount).To(BeEquivalentTo(1)) - Expect(handler.ReceivedAck(&wire.AckFrame{AckRanges: []wire.AckRange{{Smallest: 1, Largest: 1}}}, protocol.EncryptionInitial, time.Now())).To(Succeed()) + _, err := handler.ReceivedAck(&wire.AckFrame{AckRanges: []wire.AckRange{{Smallest: 1, Largest: 1}}}, protocol.EncryptionInitial, time.Now()) + Expect(err).ToNot(HaveOccurred()) Expect(handler.ptoCount).To(BeEquivalentTo(1)) }) }) @@ -962,7 +1025,8 @@ var _ = Describe("SentPacketHandler", func() { handler.SentPacket(ackElicitingPacket(&Packet{PacketNumber: i})) } ack := &wire.AckFrame{AckRanges: []wire.AckRange{{Smallest: 6, Largest: 6}}} - Expect(handler.ReceivedAck(ack, protocol.Encryption1RTT, now)).To(Succeed()) + _, err := handler.ReceivedAck(ack, protocol.Encryption1RTT, now) + Expect(err).ToNot(HaveOccurred()) expectInPacketHistory([]protocol.PacketNumber{4, 5}, protocol.Encryption1RTT) Expect(lostPackets).To(Equal([]protocol.PacketNumber{1, 2, 3})) }) @@ -976,7 +1040,8 @@ var _ = Describe("SentPacketHandler", func() { Expect(handler.appDataPackets.lossTime.IsZero()).To(BeTrue()) ack := &wire.AckFrame{AckRanges: []wire.AckRange{{Smallest: 2, Largest: 2}}} - Expect(handler.ReceivedAck(ack, protocol.Encryption1RTT, now)).To(Succeed()) + _, err := handler.ReceivedAck(ack, protocol.Encryption1RTT, now) + Expect(err).ToNot(HaveOccurred()) // no need to set an alarm, since packet 1 was already declared lost Expect(handler.appDataPackets.lossTime.IsZero()).To(BeTrue()) Expect(handler.bytesInFlight).To(BeZero()) @@ -992,7 +1057,8 @@ var _ = Describe("SentPacketHandler", func() { Expect(handler.appDataPackets.lossTime.IsZero()).To(BeTrue()) ack := &wire.AckFrame{AckRanges: []wire.AckRange{{Smallest: 2, Largest: 2}}} - Expect(handler.ReceivedAck(ack, protocol.Encryption1RTT, now.Add(-time.Second))).To(Succeed()) + _, err := handler.ReceivedAck(ack, protocol.Encryption1RTT, now.Add(-time.Second)) + Expect(err).ToNot(HaveOccurred()) Expect(handler.rttStats.SmoothedRTT()).To(Equal(time.Second)) // Packet 1 should be considered lost (1+1/8) RTTs after it was sent. @@ -1014,7 +1080,8 @@ var _ = Describe("SentPacketHandler", func() { Expect(handler.initialPackets.lossTime.IsZero()).To(BeTrue()) ack := &wire.AckFrame{AckRanges: []wire.AckRange{{Smallest: 2, Largest: 2}}} - Expect(handler.ReceivedAck(ack, protocol.EncryptionInitial, now.Add(-time.Second))).To(Succeed()) + _, err := handler.ReceivedAck(ack, protocol.EncryptionInitial, now.Add(-time.Second)) + Expect(err).ToNot(HaveOccurred()) Expect(handler.rttStats.SmoothedRTT()).To(Equal(time.Second)) // Packet 1 should be considered lost (1+1/8) RTTs after it was sent. @@ -1035,7 +1102,8 @@ var _ = Describe("SentPacketHandler", func() { EncryptionLevel: protocol.Encryption1RTT, })) ack := &wire.AckFrame{AckRanges: []wire.AckRange{{Smallest: 13, Largest: 13}}} - Expect(handler.ReceivedAck(ack, protocol.EncryptionHandshake, time.Now())).To(MatchError("PROTOCOL_VIOLATION: Received ACK for an unsent packet")) + _, err := handler.ReceivedAck(ack, protocol.EncryptionHandshake, time.Now()) + Expect(err).To(MatchError("PROTOCOL_VIOLATION: Received ACK for an unsent packet")) }) It("deletes Initial packets, as a server", func() { diff --git a/internal/mocks/ackhandler/sent_packet_handler.go b/internal/mocks/ackhandler/sent_packet_handler.go index c4c9ef7b..0b1390b5 100644 --- a/internal/mocks/ackhandler/sent_packet_handler.go +++ b/internal/mocks/ackhandler/sent_packet_handler.go @@ -135,11 +135,12 @@ func (mr *MockSentPacketHandlerMockRecorder) QueueProbePacket(arg0 interface{}) } // ReceivedAck mocks base method. -func (m *MockSentPacketHandler) ReceivedAck(arg0 *wire.AckFrame, arg1 protocol.EncryptionLevel, arg2 time.Time) error { +func (m *MockSentPacketHandler) ReceivedAck(arg0 *wire.AckFrame, arg1 protocol.EncryptionLevel, arg2 time.Time) (bool, error) { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "ReceivedAck", arg0, arg1, arg2) - ret0, _ := ret[0].(error) - return ret0 + ret0, _ := ret[0].(bool) + ret1, _ := ret[1].(error) + return ret0, ret1 } // ReceivedAck indicates an expected call of ReceivedAck. diff --git a/session.go b/session.go index 0dd9f828..9899f309 100644 --- a/session.go +++ b/session.go @@ -1364,12 +1364,16 @@ func (s *session) handleHandshakeDoneFrame() error { } func (s *session) handleAckFrame(frame *wire.AckFrame, encLevel protocol.EncryptionLevel) error { - if err := s.sentPacketHandler.ReceivedAck(frame, encLevel, s.lastPacketReceivedTime); err != nil { + acked1RTTPacket, err := s.sentPacketHandler.ReceivedAck(frame, encLevel, s.lastPacketReceivedTime) + if err != nil { return err } - if encLevel != protocol.Encryption1RTT { + if !acked1RTTPacket { return nil } + if s.perspective == protocol.PerspectiveClient && !s.handshakeConfirmed { + s.handleHandshakeConfirmed() + } return s.cryptoStreamHandler.SetLargest1RTTAcked(frame.LargestAcked()) } diff --git a/session_test.go b/session_test.go index fdfc80a5..699b34cc 100644 --- a/session_test.go +++ b/session_test.go @@ -2564,6 +2564,18 @@ var _ = Describe("Client Session", func() { Expect(sess.handleHandshakeDoneFrame()).To(Succeed()) }) + It("interprets an ACK for 1-RTT packets as confirmation of the handshake", func() { + sess.peerParams = &wire.TransportParameters{} + sph := mockackhandler.NewMockSentPacketHandler(mockCtrl) + sess.sentPacketHandler = sph + ack := &wire.AckFrame{AckRanges: []wire.AckRange{{Smallest: 1, Largest: 3}}} + sph.EXPECT().ReceivedAck(ack, protocol.Encryption1RTT, gomock.Any()).Return(true, nil) + sph.EXPECT().SetHandshakeConfirmed() + cryptoSetup.EXPECT().SetLargest1RTTAcked(protocol.PacketNumber(3)) + cryptoSetup.EXPECT().SetHandshakeConfirmed() + Expect(sess.handleAckFrame(ack, protocol.Encryption1RTT)).To(Succeed()) + }) + Context("handling tokens", func() { var mockTokenStore *MockTokenStore