diff --git a/framer.go b/framer.go index 382515048..f1f6f4351 100644 --- a/framer.go +++ b/framer.go @@ -10,6 +10,8 @@ import ( ) type framer interface { + HasData() bool + QueueControlFrame(wire.Frame) AppendControlFrames([]ackhandler.Frame, protocol.ByteCount) ([]ackhandler.Frame, protocol.ByteCount) @@ -43,6 +45,19 @@ func newFramer( } } +func (f *framerI) HasData() bool { + f.mutex.Lock() + hasData := len(f.streamQueue) > 0 + f.mutex.Unlock() + if hasData { + return true + } + f.controlFrameMutex.Lock() + hasData = len(f.controlFrames) > 0 + f.controlFrameMutex.Unlock() + return hasData +} + func (f *framerI) QueueControlFrame(frame wire.Frame) { f.controlFrameMutex.Lock() f.controlFrames = append(f.controlFrames, frame) diff --git a/framer_test.go b/framer_test.go index 5c89d63e9..04caf767e 100644 --- a/framer_test.go +++ b/framer_test.go @@ -50,6 +50,16 @@ var _ = Describe("Framer", func() { Expect(length).To(Equal(mdf.Length(version) + msf.Length(version))) }) + It("says if it has data", func() { + Expect(framer.HasData()).To(BeFalse()) + f := &wire.MaxDataFrame{ByteOffset: 0x42} + framer.QueueControlFrame(f) + Expect(framer.HasData()).To(BeTrue()) + frames, _ := framer.AppendControlFrames(nil, 1000) + Expect(frames).To(HaveLen(1)) + Expect(framer.HasData()).To(BeFalse()) + }) + It("appends to the slice given", func() { ping := &wire.PingFrame{} mdf := &wire.MaxDataFrame{ByteOffset: 0x42} @@ -99,6 +109,25 @@ var _ = Describe("Framer", func() { Expect(length).To(Equal(f.Length(version))) }) + It("says if it has data", func() { + streamGetter.EXPECT().GetOrOpenSendStream(id1).Return(stream1, nil).Times(2) + Expect(framer.HasData()).To(BeFalse()) + framer.AddActiveStream(id1) + Expect(framer.HasData()).To(BeTrue()) + f1 := &wire.StreamFrame{StreamID: id1, Data: []byte("foo")} + f2 := &wire.StreamFrame{StreamID: id1, Data: []byte("bar")} + stream1.EXPECT().popStreamFrame(gomock.Any()).Return(&ackhandler.Frame{Frame: f1}, true) + stream1.EXPECT().popStreamFrame(gomock.Any()).Return(&ackhandler.Frame{Frame: f2}, false) + frames, _ := framer.AppendStreamFrames(nil, protocol.MaxByteCount) + Expect(frames).To(HaveLen(1)) + Expect(frames[0].Frame).To(Equal(f1)) + Expect(framer.HasData()).To(BeTrue()) + frames, _ = framer.AppendStreamFrames(nil, protocol.MaxByteCount) + Expect(frames).To(HaveLen(1)) + Expect(frames[0].Frame).To(Equal(f2)) + Expect(framer.HasData()).To(BeFalse()) + }) + It("appends to a frame slice", func() { streamGetter.EXPECT().GetOrOpenSendStream(id1).Return(stream1, nil) f := &wire.StreamFrame{ diff --git a/integrationtests/self/packetization_test.go b/integrationtests/self/packetization_test.go new file mode 100644 index 000000000..88b368704 --- /dev/null +++ b/integrationtests/self/packetization_test.go @@ -0,0 +1,113 @@ +package self_test + +import ( + "context" + "fmt" + "net" + "sync/atomic" + "time" + + "github.com/lucas-clemente/quic-go" + quicproxy "github.com/lucas-clemente/quic-go/integrationtests/tools/proxy" + + . "github.com/onsi/ginkgo" + . "github.com/onsi/gomega" +) + +var _ = Describe("Packetization", func() { + var ( + server quic.Listener + proxy *quicproxy.QuicProxy + incoming uint32 + outgoing uint32 + ) + + BeforeEach(func() { + incoming = 0 + outgoing = 0 + var err error + server, err = quic.ListenAddr( + "localhost:0", + getTLSConfig(), + getQuicConfigForServer(&quic.Config{AcceptToken: func(net.Addr, *quic.Token) bool { return true }}), + ) + Expect(err).ToNot(HaveOccurred()) + serverAddr := fmt.Sprintf("localhost:%d", server.Addr().(*net.UDPAddr).Port) + + proxy, err = quicproxy.NewQuicProxy("localhost:0", &quicproxy.Opts{ + RemoteAddr: serverAddr, + DelayPacket: func(dir quicproxy.Direction, _ []byte) time.Duration { + switch dir { + case quicproxy.DirectionIncoming: + atomic.AddUint32(&incoming, 1) + case quicproxy.DirectionOutgoing: + atomic.AddUint32(&outgoing, 1) + } + return 5 * time.Millisecond + }, + }) + Expect(err).ToNot(HaveOccurred()) + }) + + AfterEach(func() { + Expect(proxy.Close()).To(Succeed()) + Expect(server.Close()).To(Succeed()) + }) + + // In this test, the client sends 100 small messages. The server echoes these messages. + // This means that every endpoint will send 100 ack-eliciting packets in short succession. + // This test then tests that no more than 110 packets are sent in every direction, making sure that ACK are bundled. + It("bundles ACKs", func() { + const numMsg = 100 + + sess, err := quic.DialAddr( + fmt.Sprintf("localhost:%d", proxy.LocalPort()), + getTLSClientConfig(), + getQuicConfigForClient(nil), + ) + Expect(err).ToNot(HaveOccurred()) + + go func() { + defer GinkgoRecover() + sess, err := server.Accept(context.Background()) + Expect(err).ToNot(HaveOccurred()) + str, err := sess.AcceptStream(context.Background()) + Expect(err).ToNot(HaveOccurred()) + b := make([]byte, 1) + // Echo every byte received from the client. + for { + if _, err := str.Read(b); err != nil { + break + } + _, err = str.Write(b) + Expect(err).ToNot(HaveOccurred()) + } + }() + + str, err := sess.OpenStreamSync(context.Background()) + Expect(err).ToNot(HaveOccurred()) + b := make([]byte, 1) + // Send numMsg 1-byte messages. + for i := 0; i < numMsg; i++ { + _, err = str.Write([]byte{uint8(i)}) + Expect(err).ToNot(HaveOccurred()) + _, err = str.Read(b) + Expect(err).ToNot(HaveOccurred()) + Expect(b[0]).To(Equal(uint8(i))) + } + Expect(sess.CloseWithError(0, "")).To(Succeed()) + + numIncoming := atomic.LoadUint32(&incoming) + numOutgoing := atomic.LoadUint32(&outgoing) + fmt.Fprintf(GinkgoWriter, "incoming packets: %d\n", numIncoming) + fmt.Fprintf(GinkgoWriter, "outgoing packets: %d\n", numOutgoing) + Expect(numIncoming).To(And( + BeNumerically(">", numMsg), + BeNumerically("<", numMsg+10), + )) + Expect(numOutgoing).To(And( + BeNumerically(">", numMsg), + BeNumerically("<", numMsg+10), + )) + }) +}) diff --git a/internal/ackhandler/interfaces.go b/internal/ackhandler/interfaces.go index d6109f91f..a233704d9 100644 --- a/internal/ackhandler/interfaces.go +++ b/internal/ackhandler/interfaces.go @@ -68,5 +68,5 @@ type ReceivedPacketHandler interface { DropPackets(protocol.EncryptionLevel) GetAlarmTimeout() time.Time - GetAckFrame(protocol.EncryptionLevel) *wire.AckFrame + GetAckFrame(encLevel protocol.EncryptionLevel, onlyIfQueued bool) *wire.AckFrame } diff --git a/internal/ackhandler/received_packet_handler.go b/internal/ackhandler/received_packet_handler.go index 814addb2e..1b7076395 100644 --- a/internal/ackhandler/received_packet_handler.go +++ b/internal/ackhandler/received_packet_handler.go @@ -113,20 +113,20 @@ func (h *receivedPacketHandler) GetAlarmTimeout() time.Time { return utils.MinNonZeroTime(utils.MinNonZeroTime(initialAlarm, handshakeAlarm), oneRTTAlarm) } -func (h *receivedPacketHandler) GetAckFrame(encLevel protocol.EncryptionLevel) *wire.AckFrame { +func (h *receivedPacketHandler) GetAckFrame(encLevel protocol.EncryptionLevel, onlyIfQueued bool) *wire.AckFrame { var ack *wire.AckFrame switch encLevel { case protocol.EncryptionInitial: if h.initialPackets != nil { - ack = h.initialPackets.GetAckFrame() + ack = h.initialPackets.GetAckFrame(onlyIfQueued) } case protocol.EncryptionHandshake: if h.handshakePackets != nil { - ack = h.handshakePackets.GetAckFrame() + ack = h.handshakePackets.GetAckFrame(onlyIfQueued) } case protocol.Encryption1RTT: // 0-RTT packets can't contain ACK frames - return h.appDataPackets.GetAckFrame() + return h.appDataPackets.GetAckFrame(onlyIfQueued) default: return nil } diff --git a/internal/ackhandler/received_packet_handler_test.go b/internal/ackhandler/received_packet_handler_test.go index deb323727..7fbf21190 100644 --- a/internal/ackhandler/received_packet_handler_test.go +++ b/internal/ackhandler/received_packet_handler_test.go @@ -40,17 +40,17 @@ var _ = Describe("Received Packet Handler", func() { Expect(handler.ReceivedPacket(3, protocol.EncryptionInitial, sendTime, true)).To(Succeed()) Expect(handler.ReceivedPacket(2, protocol.EncryptionHandshake, sendTime, true)).To(Succeed()) Expect(handler.ReceivedPacket(4, protocol.Encryption1RTT, sendTime, true)).To(Succeed()) - initialAck := handler.GetAckFrame(protocol.EncryptionInitial) + initialAck := handler.GetAckFrame(protocol.EncryptionInitial, true) Expect(initialAck).ToNot(BeNil()) Expect(initialAck.AckRanges).To(HaveLen(1)) Expect(initialAck.AckRanges[0]).To(Equal(wire.AckRange{Smallest: 2, Largest: 3})) Expect(initialAck.DelayTime).To(BeZero()) - handshakeAck := handler.GetAckFrame(protocol.EncryptionHandshake) + handshakeAck := handler.GetAckFrame(protocol.EncryptionHandshake, true) Expect(handshakeAck).ToNot(BeNil()) Expect(handshakeAck.AckRanges).To(HaveLen(1)) Expect(handshakeAck.AckRanges[0]).To(Equal(wire.AckRange{Smallest: 1, Largest: 2})) Expect(handshakeAck.DelayTime).To(BeZero()) - oneRTTAck := handler.GetAckFrame(protocol.Encryption1RTT) + oneRTTAck := handler.GetAckFrame(protocol.Encryption1RTT, true) Expect(oneRTTAck).ToNot(BeNil()) Expect(oneRTTAck.AckRanges).To(HaveLen(1)) Expect(oneRTTAck.AckRanges[0]).To(Equal(wire.AckRange{Smallest: 4, Largest: 5})) @@ -64,7 +64,7 @@ var _ = Describe("Received Packet Handler", func() { sendTime := time.Now().Add(-time.Second) Expect(handler.ReceivedPacket(2, protocol.Encryption0RTT, sendTime, true)).To(Succeed()) Expect(handler.ReceivedPacket(3, protocol.Encryption1RTT, sendTime, true)).To(Succeed()) - ack := handler.GetAckFrame(protocol.Encryption1RTT) + ack := handler.GetAckFrame(protocol.Encryption1RTT, true) Expect(ack).ToNot(BeNil()) Expect(ack.AckRanges).To(HaveLen(1)) Expect(ack.AckRanges[0]).To(Equal(wire.AckRange{Smallest: 2, Largest: 3})) @@ -93,10 +93,10 @@ var _ = Describe("Received Packet Handler", func() { sendTime := time.Now().Add(-time.Second) Expect(handler.ReceivedPacket(2, protocol.EncryptionInitial, sendTime, true)).To(Succeed()) Expect(handler.ReceivedPacket(1, protocol.EncryptionHandshake, sendTime, true)).To(Succeed()) - Expect(handler.GetAckFrame(protocol.EncryptionInitial)).ToNot(BeNil()) + Expect(handler.GetAckFrame(protocol.EncryptionInitial, true)).ToNot(BeNil()) handler.DropPackets(protocol.EncryptionInitial) - Expect(handler.GetAckFrame(protocol.EncryptionInitial)).To(BeNil()) - Expect(handler.GetAckFrame(protocol.EncryptionHandshake)).ToNot(BeNil()) + Expect(handler.GetAckFrame(protocol.EncryptionInitial, true)).To(BeNil()) + Expect(handler.GetAckFrame(protocol.EncryptionHandshake, true)).ToNot(BeNil()) }) It("drops Handshake packets", func() { @@ -105,10 +105,10 @@ var _ = Describe("Received Packet Handler", func() { sendTime := time.Now().Add(-time.Second) Expect(handler.ReceivedPacket(1, protocol.EncryptionHandshake, sendTime, true)).To(Succeed()) Expect(handler.ReceivedPacket(2, protocol.Encryption1RTT, sendTime, true)).To(Succeed()) - Expect(handler.GetAckFrame(protocol.EncryptionHandshake)).ToNot(BeNil()) + Expect(handler.GetAckFrame(protocol.EncryptionHandshake, true)).ToNot(BeNil()) handler.DropPackets(protocol.EncryptionInitial) - Expect(handler.GetAckFrame(protocol.EncryptionHandshake)).To(BeNil()) - Expect(handler.GetAckFrame(protocol.Encryption1RTT)).ToNot(BeNil()) + Expect(handler.GetAckFrame(protocol.EncryptionHandshake, true)).To(BeNil()) + Expect(handler.GetAckFrame(protocol.Encryption1RTT, true)).ToNot(BeNil()) }) It("does nothing when dropping 0-RTT packets", func() { @@ -121,7 +121,7 @@ var _ = Describe("Received Packet Handler", func() { sentPackets.EXPECT().GetLowestPacketNotConfirmedAcked().Times(2) Expect(handler.ReceivedPacket(1, protocol.Encryption1RTT, sendTime, true)).To(Succeed()) Expect(handler.ReceivedPacket(2, protocol.Encryption1RTT, sendTime, true)).To(Succeed()) - ack := handler.GetAckFrame(protocol.Encryption1RTT) + ack := handler.GetAckFrame(protocol.Encryption1RTT, true) Expect(ack).ToNot(BeNil()) Expect(ack.LowestAcked()).To(Equal(protocol.PacketNumber(1))) Expect(ack.LargestAcked()).To(Equal(protocol.PacketNumber(2))) @@ -129,7 +129,7 @@ var _ = Describe("Received Packet Handler", func() { Expect(handler.ReceivedPacket(3, protocol.Encryption1RTT, sendTime, true)).To(Succeed()) sentPackets.EXPECT().GetLowestPacketNotConfirmedAcked().Return(protocol.PacketNumber(2)) Expect(handler.ReceivedPacket(4, protocol.Encryption1RTT, sendTime, true)).To(Succeed()) - ack = handler.GetAckFrame(protocol.Encryption1RTT) + ack = handler.GetAckFrame(protocol.Encryption1RTT, true) Expect(ack).ToNot(BeNil()) Expect(ack.LowestAcked()).To(Equal(protocol.PacketNumber(2))) Expect(ack.LargestAcked()).To(Equal(protocol.PacketNumber(4))) diff --git a/internal/ackhandler/received_packet_history.go b/internal/ackhandler/received_packet_history.go index 6bf398c9d..63d3f8ee3 100644 --- a/internal/ackhandler/received_packet_history.go +++ b/internal/ackhandler/received_packet_history.go @@ -22,25 +22,26 @@ func newReceivedPacketHistory() *receivedPacketHistory { } // ReceivedPacket registers a packet with PacketNumber p and updates the ranges -func (h *receivedPacketHistory) ReceivedPacket(p protocol.PacketNumber) { +func (h *receivedPacketHistory) ReceivedPacket(p protocol.PacketNumber) bool /* is a new packet (and not a duplicate / delayed packet) */ { // ignore delayed packets, if we already deleted the range if p < h.deletedBelow { - return + return false } - h.addToRanges(p) + isNew := h.addToRanges(p) h.maybeDeleteOldRanges() + return isNew } -func (h *receivedPacketHistory) addToRanges(p protocol.PacketNumber) { +func (h *receivedPacketHistory) addToRanges(p protocol.PacketNumber) bool /* is a new packet (and not a duplicate / delayed packet) */ { if h.ranges.Len() == 0 { h.ranges.PushBack(utils.PacketInterval{Start: p, End: p}) - return + return true } for el := h.ranges.Back(); el != nil; el = el.Prev() { // p already included in an existing range. Nothing to do here if p >= el.Value.Start && p <= el.Value.End { - return + return false } var rangeExtended bool @@ -58,20 +59,20 @@ func (h *receivedPacketHistory) addToRanges(p protocol.PacketNumber) { if prev != nil && prev.Value.End+1 == el.Value.Start { // merge two ranges prev.Value.End = el.Value.End h.ranges.Remove(el) - return } - return // if the two ranges were not merge, we're done here + return true // if the two ranges were not merge, we're done here } // create a new range at the end if p > el.Value.End { h.ranges.InsertAfter(utils.PacketInterval{Start: p, End: p}, el) - return + return true } } // create a new range at the beginning h.ranges.InsertBefore(utils.PacketInterval{Start: p, End: p}, h.ranges.Front()) + return true } // Delete old ranges, if we're tracking more than 500 of them. diff --git a/internal/ackhandler/received_packet_history_test.go b/internal/ackhandler/received_packet_history_test.go index d55d9bf55..6c48a70e8 100644 --- a/internal/ackhandler/received_packet_history_test.go +++ b/internal/ackhandler/received_packet_history_test.go @@ -19,54 +19,55 @@ var _ = Describe("receivedPacketHistory", func() { Context("ranges", func() { It("adds the first packet", func() { - hist.ReceivedPacket(4) + Expect(hist.ReceivedPacket(4)).To(BeTrue()) Expect(hist.ranges.Len()).To(Equal(1)) Expect(hist.ranges.Front().Value).To(Equal(utils.PacketInterval{Start: 4, End: 4})) }) It("doesn't care about duplicate packets", func() { - hist.ReceivedPacket(4) + Expect(hist.ReceivedPacket(4)).To(BeTrue()) + Expect(hist.ReceivedPacket(4)).To(BeFalse()) Expect(hist.ranges.Len()).To(Equal(1)) Expect(hist.ranges.Front().Value).To(Equal(utils.PacketInterval{Start: 4, End: 4})) }) It("adds a few consecutive packets", func() { - hist.ReceivedPacket(4) - hist.ReceivedPacket(5) - hist.ReceivedPacket(6) + Expect(hist.ReceivedPacket(4)).To(BeTrue()) + Expect(hist.ReceivedPacket(5)).To(BeTrue()) + Expect(hist.ReceivedPacket(6)).To(BeTrue()) Expect(hist.ranges.Len()).To(Equal(1)) Expect(hist.ranges.Front().Value).To(Equal(utils.PacketInterval{Start: 4, End: 6})) }) It("doesn't care about a duplicate packet contained in an existing range", func() { - hist.ReceivedPacket(4) - hist.ReceivedPacket(5) - hist.ReceivedPacket(6) - hist.ReceivedPacket(5) + Expect(hist.ReceivedPacket(4)).To(BeTrue()) + Expect(hist.ReceivedPacket(5)).To(BeTrue()) + Expect(hist.ReceivedPacket(6)).To(BeTrue()) + Expect(hist.ReceivedPacket(5)).To(BeFalse()) Expect(hist.ranges.Len()).To(Equal(1)) Expect(hist.ranges.Front().Value).To(Equal(utils.PacketInterval{Start: 4, End: 6})) }) It("extends a range at the front", func() { - hist.ReceivedPacket(4) - hist.ReceivedPacket(3) + Expect(hist.ReceivedPacket(4)).To(BeTrue()) + Expect(hist.ReceivedPacket(3)).To(BeTrue()) Expect(hist.ranges.Len()).To(Equal(1)) Expect(hist.ranges.Front().Value).To(Equal(utils.PacketInterval{Start: 3, End: 4})) }) It("creates a new range when a packet is lost", func() { - hist.ReceivedPacket(4) - hist.ReceivedPacket(6) + Expect(hist.ReceivedPacket(4)).To(BeTrue()) + Expect(hist.ReceivedPacket(6)).To(BeTrue()) Expect(hist.ranges.Len()).To(Equal(2)) Expect(hist.ranges.Front().Value).To(Equal(utils.PacketInterval{Start: 4, End: 4})) Expect(hist.ranges.Back().Value).To(Equal(utils.PacketInterval{Start: 6, End: 6})) }) It("creates a new range in between two ranges", func() { - hist.ReceivedPacket(4) - hist.ReceivedPacket(10) + Expect(hist.ReceivedPacket(4)).To(BeTrue()) + Expect(hist.ReceivedPacket(10)).To(BeTrue()) Expect(hist.ranges.Len()).To(Equal(2)) - hist.ReceivedPacket(7) + Expect(hist.ReceivedPacket(7)).To(BeTrue()) Expect(hist.ranges.Len()).To(Equal(3)) Expect(hist.ranges.Front().Value).To(Equal(utils.PacketInterval{Start: 4, End: 4})) Expect(hist.ranges.Front().Next().Value).To(Equal(utils.PacketInterval{Start: 7, End: 7})) @@ -74,47 +75,47 @@ var _ = Describe("receivedPacketHistory", func() { }) It("creates a new range before an existing range for a belated packet", func() { - hist.ReceivedPacket(6) - hist.ReceivedPacket(4) + Expect(hist.ReceivedPacket(6)).To(BeTrue()) + Expect(hist.ReceivedPacket(4)).To(BeTrue()) Expect(hist.ranges.Len()).To(Equal(2)) Expect(hist.ranges.Front().Value).To(Equal(utils.PacketInterval{Start: 4, End: 4})) Expect(hist.ranges.Back().Value).To(Equal(utils.PacketInterval{Start: 6, End: 6})) }) It("extends a previous range at the end", func() { - hist.ReceivedPacket(4) - hist.ReceivedPacket(7) - hist.ReceivedPacket(5) + Expect(hist.ReceivedPacket(4)).To(BeTrue()) + Expect(hist.ReceivedPacket(7)).To(BeTrue()) + Expect(hist.ReceivedPacket(5)).To(BeTrue()) Expect(hist.ranges.Len()).To(Equal(2)) Expect(hist.ranges.Front().Value).To(Equal(utils.PacketInterval{Start: 4, End: 5})) Expect(hist.ranges.Back().Value).To(Equal(utils.PacketInterval{Start: 7, End: 7})) }) It("extends a range at the front", func() { - hist.ReceivedPacket(4) - hist.ReceivedPacket(7) - hist.ReceivedPacket(6) + Expect(hist.ReceivedPacket(4)).To(BeTrue()) + Expect(hist.ReceivedPacket(7)).To(BeTrue()) + Expect(hist.ReceivedPacket(6)).To(BeTrue()) Expect(hist.ranges.Len()).To(Equal(2)) Expect(hist.ranges.Front().Value).To(Equal(utils.PacketInterval{Start: 4, End: 4})) Expect(hist.ranges.Back().Value).To(Equal(utils.PacketInterval{Start: 6, End: 7})) }) It("closes a range", func() { - hist.ReceivedPacket(6) - hist.ReceivedPacket(4) + Expect(hist.ReceivedPacket(6)).To(BeTrue()) + Expect(hist.ReceivedPacket(4)).To(BeTrue()) Expect(hist.ranges.Len()).To(Equal(2)) - hist.ReceivedPacket(5) + Expect(hist.ReceivedPacket(5)).To(BeTrue()) Expect(hist.ranges.Len()).To(Equal(1)) Expect(hist.ranges.Front().Value).To(Equal(utils.PacketInterval{Start: 4, End: 6})) }) It("closes a range in the middle", func() { - hist.ReceivedPacket(1) - hist.ReceivedPacket(10) - hist.ReceivedPacket(4) - hist.ReceivedPacket(6) + Expect(hist.ReceivedPacket(1)).To(BeTrue()) + Expect(hist.ReceivedPacket(10)).To(BeTrue()) + Expect(hist.ReceivedPacket(4)).To(BeTrue()) + Expect(hist.ReceivedPacket(6)).To(BeTrue()) Expect(hist.ranges.Len()).To(Equal(4)) - hist.ReceivedPacket(5) + Expect(hist.ReceivedPacket(5)).To(BeTrue()) Expect(hist.ranges.Len()).To(Equal(3)) Expect(hist.ranges.Front().Value).To(Equal(utils.PacketInterval{Start: 1, End: 1})) Expect(hist.ranges.Front().Next().Value).To(Equal(utils.PacketInterval{Start: 4, End: 6})) @@ -129,38 +130,38 @@ var _ = Describe("receivedPacketHistory", func() { }) It("deletes a range", func() { - hist.ReceivedPacket(4) - hist.ReceivedPacket(5) - hist.ReceivedPacket(10) + Expect(hist.ReceivedPacket(4)).To(BeTrue()) + Expect(hist.ReceivedPacket(5)).To(BeTrue()) + Expect(hist.ReceivedPacket(10)).To(BeTrue()) hist.DeleteBelow(6) Expect(hist.ranges.Len()).To(Equal(1)) Expect(hist.ranges.Front().Value).To(Equal(utils.PacketInterval{Start: 10, End: 10})) }) It("deletes multiple ranges", func() { - hist.ReceivedPacket(1) - hist.ReceivedPacket(5) - hist.ReceivedPacket(10) + Expect(hist.ReceivedPacket(1)).To(BeTrue()) + Expect(hist.ReceivedPacket(5)).To(BeTrue()) + Expect(hist.ReceivedPacket(10)).To(BeTrue()) hist.DeleteBelow(8) Expect(hist.ranges.Len()).To(Equal(1)) Expect(hist.ranges.Front().Value).To(Equal(utils.PacketInterval{Start: 10, End: 10})) }) It("adjusts a range, if packets are delete from an existing range", func() { - hist.ReceivedPacket(3) - hist.ReceivedPacket(4) - hist.ReceivedPacket(5) - hist.ReceivedPacket(6) - hist.ReceivedPacket(7) + Expect(hist.ReceivedPacket(3)).To(BeTrue()) + Expect(hist.ReceivedPacket(4)).To(BeTrue()) + Expect(hist.ReceivedPacket(5)).To(BeTrue()) + Expect(hist.ReceivedPacket(6)).To(BeTrue()) + Expect(hist.ReceivedPacket(7)).To(BeTrue()) hist.DeleteBelow(5) Expect(hist.ranges.Len()).To(Equal(1)) Expect(hist.ranges.Front().Value).To(Equal(utils.PacketInterval{Start: 5, End: 7})) }) It("adjusts a range, if only one packet remains in the range", func() { - hist.ReceivedPacket(4) - hist.ReceivedPacket(5) - hist.ReceivedPacket(10) + Expect(hist.ReceivedPacket(4)).To(BeTrue()) + Expect(hist.ReceivedPacket(5)).To(BeTrue()) + Expect(hist.ReceivedPacket(10)).To(BeTrue()) hist.DeleteBelow(5) Expect(hist.ranges.Len()).To(Equal(2)) Expect(hist.ranges.Front().Value).To(Equal(utils.PacketInterval{Start: 5, End: 5})) @@ -168,27 +169,27 @@ var _ = Describe("receivedPacketHistory", func() { }) It("keeps a one-packet range, if deleting up to the packet directly below", func() { - hist.ReceivedPacket(4) + Expect(hist.ReceivedPacket(4)).To(BeTrue()) hist.DeleteBelow(4) Expect(hist.ranges.Len()).To(Equal(1)) Expect(hist.ranges.Front().Value).To(Equal(utils.PacketInterval{Start: 4, End: 4})) }) It("doesn't add delayed packets below deleted ranges", func() { - hist.ReceivedPacket(4) - hist.ReceivedPacket(5) - hist.ReceivedPacket(6) + Expect(hist.ReceivedPacket(4)).To(BeTrue()) + Expect(hist.ReceivedPacket(5)).To(BeTrue()) + Expect(hist.ReceivedPacket(6)).To(BeTrue()) hist.DeleteBelow(5) Expect(hist.ranges.Len()).To(Equal(1)) Expect(hist.ranges.Front().Value).To(Equal(utils.PacketInterval{Start: 5, End: 6})) - hist.ReceivedPacket(2) + Expect(hist.ReceivedPacket(2)).To(BeFalse()) Expect(hist.ranges.Len()).To(Equal(1)) Expect(hist.ranges.Front().Value).To(Equal(utils.PacketInterval{Start: 5, End: 6})) }) It("doesn't create more than MaxNumAckRanges ranges", func() { for i := protocol.PacketNumber(0); i < protocol.MaxNumAckRanges; i++ { - hist.ReceivedPacket(2 * i) + Expect(hist.ReceivedPacket(2 * i)).To(BeTrue()) } Expect(hist.ranges.Len()).To(Equal(protocol.MaxNumAckRanges)) Expect(hist.ranges.Front().Value).To(Equal(utils.PacketInterval{Start: 0, End: 0})) @@ -205,21 +206,21 @@ var _ = Describe("receivedPacketHistory", func() { }) It("gets a single ACK range", func() { - hist.ReceivedPacket(4) - hist.ReceivedPacket(5) + Expect(hist.ReceivedPacket(4)).To(BeTrue()) + Expect(hist.ReceivedPacket(5)).To(BeTrue()) ackRanges := hist.GetAckRanges() Expect(ackRanges).To(HaveLen(1)) Expect(ackRanges[0]).To(Equal(wire.AckRange{Smallest: 4, Largest: 5})) }) It("gets multiple ACK ranges", func() { - hist.ReceivedPacket(4) - hist.ReceivedPacket(5) - hist.ReceivedPacket(6) - hist.ReceivedPacket(1) - hist.ReceivedPacket(11) - hist.ReceivedPacket(10) - hist.ReceivedPacket(2) + Expect(hist.ReceivedPacket(4)).To(BeTrue()) + Expect(hist.ReceivedPacket(5)).To(BeTrue()) + Expect(hist.ReceivedPacket(6)).To(BeTrue()) + Expect(hist.ReceivedPacket(1)).To(BeTrue()) + Expect(hist.ReceivedPacket(11)).To(BeTrue()) + Expect(hist.ReceivedPacket(10)).To(BeTrue()) + Expect(hist.ReceivedPacket(2)).To(BeTrue()) ackRanges := hist.GetAckRanges() Expect(ackRanges).To(HaveLen(3)) Expect(ackRanges[0]).To(Equal(wire.AckRange{Smallest: 10, Largest: 11})) @@ -234,15 +235,15 @@ var _ = Describe("receivedPacketHistory", func() { }) It("gets a single ACK range", func() { - hist.ReceivedPacket(4) - hist.ReceivedPacket(5) + Expect(hist.ReceivedPacket(4)).To(BeTrue()) + Expect(hist.ReceivedPacket(5)).To(BeTrue()) Expect(hist.GetHighestAckRange()).To(Equal(wire.AckRange{Smallest: 4, Largest: 5})) }) It("gets the highest of multiple ACK ranges", func() { - hist.ReceivedPacket(3) - hist.ReceivedPacket(6) - hist.ReceivedPacket(7) + Expect(hist.ReceivedPacket(3)).To(BeTrue()) + Expect(hist.ReceivedPacket(6)).To(BeTrue()) + Expect(hist.ReceivedPacket(7)).To(BeTrue()) Expect(hist.GetHighestAckRange()).To(Equal(wire.AckRange{Smallest: 6, Largest: 7})) }) }) diff --git a/internal/ackhandler/received_packet_tracker.go b/internal/ackhandler/received_packet_tracker.go index 5cd973e33..b2c060842 100644 --- a/internal/ackhandler/received_packet_tracker.go +++ b/internal/ackhandler/received_packet_tracker.go @@ -19,9 +19,11 @@ type receivedPacketTracker struct { maxAckDelay time.Duration rttStats *congestion.RTTStats + hasNewAck bool // true as soon as we received an ack-eliciting new packet + ackQueued bool // true once we received more than 2 (or later in the connection 10) ack-eliciting packets + packetsReceivedSinceLastAck int ackElicitingPacketsReceivedSinceLastAck int - ackQueued bool ackAlarm time.Time lastAck *wire.AckFrame @@ -55,7 +57,9 @@ func (h *receivedPacketTracker) ReceivedPacket(packetNumber protocol.PacketNumbe h.largestObservedReceivedTime = rcvTime } - h.packetHistory.ReceivedPacket(packetNumber) + if isNew := h.packetHistory.ReceivedPacket(packetNumber); isNew && shouldInstigateAck { + h.hasNewAck = true + } h.maybeQueueAck(packetNumber, rcvTime, shouldInstigateAck, isMissing) } @@ -96,7 +100,9 @@ func (h *receivedPacketTracker) maybeQueueAck(packetNumber protocol.PacketNumber // always ack the first packet if h.lastAck == nil { - h.logger.Debugf("\tQueueing ACK because the first packet should be acknowledged.") + if !h.ackQueued { + h.logger.Debugf("\tQueueing ACK because the first packet should be acknowledged.") + } h.ackQueued = true return } @@ -163,13 +169,18 @@ func (h *receivedPacketTracker) maybeQueueAck(packetNumber protocol.PacketNumber } } -func (h *receivedPacketTracker) GetAckFrame() *wire.AckFrame { - now := time.Now() - if !h.ackQueued && (h.ackAlarm.IsZero() || h.ackAlarm.After(now)) { +func (h *receivedPacketTracker) GetAckFrame(onlyIfQueued bool) *wire.AckFrame { + if !h.hasNewAck { return nil } - if h.logger.Debug() && !h.ackQueued && !h.ackAlarm.IsZero() { - h.logger.Debugf("Sending ACK because the ACK timer expired.") + now := time.Now() + if onlyIfQueued { + if !h.ackQueued && (h.ackAlarm.IsZero() || h.ackAlarm.After(now)) { + return nil + } + if h.logger.Debug() && !h.ackQueued && !h.ackAlarm.IsZero() { + h.logger.Debugf("Sending ACK because the ACK timer expired.") + } } ack := &wire.AckFrame{ @@ -182,6 +193,7 @@ func (h *receivedPacketTracker) GetAckFrame() *wire.AckFrame { h.lastAck = ack h.ackAlarm = time.Time{} h.ackQueued = false + h.hasNewAck = false h.packetsReceivedSinceLastAck = 0 h.ackElicitingPacketsReceivedSinceLastAck = 0 return ack diff --git a/internal/ackhandler/received_packet_tracker_test.go b/internal/ackhandler/received_packet_tracker_test.go index 87262c4ee..0d6f22271 100644 --- a/internal/ackhandler/received_packet_tracker_test.go +++ b/internal/ackhandler/received_packet_tracker_test.go @@ -55,7 +55,7 @@ var _ = Describe("Received Packet Tracker", func() { for i := 1; i <= 10; i++ { tracker.ReceivedPacket(protocol.PacketNumber(i), time.Time{}, true) } - Expect(tracker.GetAckFrame()).ToNot(BeNil()) + Expect(tracker.GetAckFrame(true)).ToNot(BeNil()) Expect(tracker.ackQueued).To(BeFalse()) } @@ -63,22 +63,22 @@ var _ = Describe("Received Packet Tracker", func() { for i := 1; i <= minReceivedBeforeAckDecimation; i++ { tracker.ReceivedPacket(protocol.PacketNumber(i), time.Time{}, true) } - Expect(tracker.GetAckFrame()).ToNot(BeNil()) + Expect(tracker.GetAckFrame(true)).ToNot(BeNil()) Expect(tracker.ackQueued).To(BeFalse()) } It("always queues an ACK for the first packet", func() { - tracker.ReceivedPacket(1, time.Now(), false) + tracker.ReceivedPacket(1, time.Now(), true) Expect(tracker.ackQueued).To(BeTrue()) Expect(tracker.GetAlarmTimeout()).To(BeZero()) - Expect(tracker.GetAckFrame().DelayTime).To(BeNumerically("~", 0, time.Second)) + Expect(tracker.GetAckFrame(true).DelayTime).To(BeNumerically("~", 0, time.Second)) }) It("works with packet number 0", func() { - tracker.ReceivedPacket(0, time.Now(), false) + tracker.ReceivedPacket(0, time.Now(), true) Expect(tracker.ackQueued).To(BeTrue()) Expect(tracker.GetAlarmTimeout()).To(BeZero()) - Expect(tracker.GetAckFrame().DelayTime).To(BeNumerically("~", 0, time.Second)) + Expect(tracker.GetAckFrame(true).DelayTime).To(BeNumerically("~", 0, time.Second)) }) It("queues an ACK for every second ack-eliciting packet at the beginning", func() { @@ -92,10 +92,21 @@ var _ = Describe("Received Packet Tracker", func() { Expect(tracker.ackQueued).To(BeTrue()) p++ // dequeue the ACK frame - Expect(tracker.GetAckFrame()).ToNot(BeNil()) + Expect(tracker.GetAckFrame(true)).ToNot(BeNil()) } }) + It("resets the counter when a non-queued ACK frame is generated", func() { + receiveAndAck10Packets() + rcvTime := time.Now() + tracker.ReceivedPacket(11, rcvTime, true) + Expect(tracker.GetAckFrame(false)).ToNot(BeNil()) + tracker.ReceivedPacket(12, rcvTime, true) + Expect(tracker.GetAckFrame(true)).To(BeNil()) + tracker.ReceivedPacket(13, rcvTime, true) + Expect(tracker.GetAckFrame(false)).ToNot(BeNil()) + }) + It("queues an ACK for every 10 ack-eliciting packet, if they are arriving fast", func() { receiveAndAck10Packets() p := protocol.PacketNumber(10000) @@ -125,7 +136,7 @@ var _ = Describe("Received Packet Tracker", func() { receiveAndAck10Packets() tracker.ReceivedPacket(11, time.Time{}, true) tracker.ReceivedPacket(13, time.Time{}, true) - ack := tracker.GetAckFrame() // ACK: 1-11 and 13, missing: 12 + ack := tracker.GetAckFrame(true) // ACK: 1-11 and 13, missing: 12 Expect(ack).ToNot(BeNil()) Expect(ack.HasMissingRanges()).To(BeTrue()) Expect(tracker.ackQueued).To(BeFalse()) @@ -138,12 +149,12 @@ var _ = Describe("Received Packet Tracker", func() { // 11 is missing tracker.ReceivedPacket(12, time.Time{}, true) tracker.ReceivedPacket(13, time.Time{}, true) - ack := tracker.GetAckFrame() // ACK: 1-10, 12-13 + ack := tracker.GetAckFrame(true) // ACK: 1-10, 12-13 Expect(ack).ToNot(BeNil()) // now receive 11 tracker.IgnoreBelow(12) tracker.ReceivedPacket(11, time.Time{}, false) - ack = tracker.GetAckFrame() + ack = tracker.GetAckFrame(true) Expect(ack).To(BeNil()) }) @@ -169,151 +180,180 @@ var _ = Describe("Received Packet Tracker", func() { tracker.ReceivedPacket(p+10, now, true) // we now know that packets p+7, p+8 and p+9 Expect(rttStats.MinRTT()).To(Equal(rtt)) Expect(tracker.ackAlarm.Sub(now)).To(Equal(rtt / 8)) - ack := tracker.GetAckFrame() + ack := tracker.GetAckFrame(true) Expect(ack.HasMissingRanges()).To(BeTrue()) Expect(ack).ToNot(BeNil()) }) }) Context("ACK generation", func() { - BeforeEach(func() { - tracker.ackQueued = true - }) - - It("generates a simple ACK frame", func() { + It("generates an ACK for an ack-eliciting packet, if no ACK is queued yet", func() { tracker.ReceivedPacket(1, time.Time{}, true) + // The first packet is always acknowledged. + Expect(tracker.GetAckFrame(true)).ToNot(BeNil()) + tracker.ReceivedPacket(2, time.Time{}, true) - ack := tracker.GetAckFrame() + Expect(tracker.GetAckFrame(true)).To(BeNil()) + ack := tracker.GetAckFrame(false) Expect(ack).ToNot(BeNil()) + Expect(ack.LowestAcked()).To(Equal(protocol.PacketNumber(1))) Expect(ack.LargestAcked()).To(Equal(protocol.PacketNumber(2))) - Expect(ack.LowestAcked()).To(Equal(protocol.PacketNumber(1))) - Expect(ack.HasMissingRanges()).To(BeFalse()) }) - It("generates an ACK for packet number 0", func() { - tracker.ReceivedPacket(0, time.Time{}, true) - ack := tracker.GetAckFrame() - Expect(ack).ToNot(BeNil()) - Expect(ack.LargestAcked()).To(Equal(protocol.PacketNumber(0))) - Expect(ack.LowestAcked()).To(Equal(protocol.PacketNumber(0))) - Expect(ack.HasMissingRanges()).To(BeFalse()) - }) - - It("sets the delay time", func() { + It("doesn't generate ACK for a non-ack-eliciting packet, if no ACK is queued yet", func() { tracker.ReceivedPacket(1, time.Time{}, true) - tracker.ReceivedPacket(2, time.Now().Add(-1337*time.Millisecond), true) - ack := tracker.GetAckFrame() - Expect(ack).ToNot(BeNil()) - Expect(ack.DelayTime).To(BeNumerically("~", 1337*time.Millisecond, 50*time.Millisecond)) - }) + // The first packet is always acknowledged. + Expect(tracker.GetAckFrame(true)).ToNot(BeNil()) - It("uses a 0 delay time if the delay would be negative", func() { - tracker.ReceivedPacket(0, time.Now().Add(time.Hour), true) - ack := tracker.GetAckFrame() - Expect(ack).ToNot(BeNil()) - Expect(ack.DelayTime).To(BeZero()) - }) - - It("saves the last sent ACK", func() { - tracker.ReceivedPacket(1, time.Time{}, true) - ack := tracker.GetAckFrame() - Expect(ack).ToNot(BeNil()) - Expect(tracker.lastAck).To(Equal(ack)) - tracker.ReceivedPacket(2, time.Time{}, true) - tracker.ackQueued = true - ack = tracker.GetAckFrame() - Expect(ack).ToNot(BeNil()) - Expect(tracker.lastAck).To(Equal(ack)) - }) - - It("generates an ACK frame with missing packets", func() { - tracker.ReceivedPacket(1, time.Time{}, true) - tracker.ReceivedPacket(4, time.Time{}, true) - ack := tracker.GetAckFrame() - Expect(ack).ToNot(BeNil()) - Expect(ack.LargestAcked()).To(Equal(protocol.PacketNumber(4))) - Expect(ack.LowestAcked()).To(Equal(protocol.PacketNumber(1))) - Expect(ack.AckRanges).To(Equal([]wire.AckRange{ - {Smallest: 4, Largest: 4}, - {Smallest: 1, Largest: 1}, - })) - }) - - It("generates an ACK for packet number 0 and other packets", func() { - tracker.ReceivedPacket(0, time.Time{}, true) - tracker.ReceivedPacket(1, time.Time{}, true) + tracker.ReceivedPacket(2, time.Time{}, false) + Expect(tracker.GetAckFrame(false)).To(BeNil()) tracker.ReceivedPacket(3, time.Time{}, true) - ack := tracker.GetAckFrame() + ack := tracker.GetAckFrame(false) Expect(ack).ToNot(BeNil()) + Expect(ack.LowestAcked()).To(Equal(protocol.PacketNumber(1))) Expect(ack.LargestAcked()).To(Equal(protocol.PacketNumber(3))) - Expect(ack.LowestAcked()).To(Equal(protocol.PacketNumber(0))) - Expect(ack.AckRanges).To(Equal([]wire.AckRange{ - {Smallest: 3, Largest: 3}, - {Smallest: 0, Largest: 1}, - })) }) - It("doesn't add delayed packets to the packetHistory", func() { - tracker.IgnoreBelow(7) - tracker.ReceivedPacket(4, time.Time{}, true) - tracker.ReceivedPacket(10, time.Time{}, true) - ack := tracker.GetAckFrame() - Expect(ack).ToNot(BeNil()) - Expect(ack.LargestAcked()).To(Equal(protocol.PacketNumber(10))) - Expect(ack.LowestAcked()).To(Equal(protocol.PacketNumber(10))) - }) + Context("for queued ACKs", func() { + BeforeEach(func() { + tracker.ackQueued = true + }) - It("deletes packets from the packetHistory when a lower limit is set", func() { - for i := 1; i <= 12; i++ { - tracker.ReceivedPacket(protocol.PacketNumber(i), time.Time{}, true) - } - tracker.IgnoreBelow(7) - // check that the packets were deleted from the receivedPacketHistory by checking the values in an ACK frame - ack := tracker.GetAckFrame() - Expect(ack).ToNot(BeNil()) - Expect(ack.LargestAcked()).To(Equal(protocol.PacketNumber(12))) - Expect(ack.LowestAcked()).To(Equal(protocol.PacketNumber(7))) - Expect(ack.HasMissingRanges()).To(BeFalse()) - }) + It("generates a simple ACK frame", func() { + tracker.ReceivedPacket(1, time.Time{}, true) + tracker.ReceivedPacket(2, time.Time{}, true) + ack := tracker.GetAckFrame(true) + Expect(ack).ToNot(BeNil()) + Expect(ack.LargestAcked()).To(Equal(protocol.PacketNumber(2))) + Expect(ack.LowestAcked()).To(Equal(protocol.PacketNumber(1))) + Expect(ack.HasMissingRanges()).To(BeFalse()) + }) - // TODO: remove this test when dropping support for STOP_WAITINGs - It("handles a lower limit of 0", func() { - tracker.IgnoreBelow(0) - tracker.ReceivedPacket(1337, time.Time{}, true) - ack := tracker.GetAckFrame() - Expect(ack).ToNot(BeNil()) - Expect(ack.LargestAcked()).To(Equal(protocol.PacketNumber(1337))) - }) + It("generates an ACK for packet number 0", func() { + tracker.ReceivedPacket(0, time.Time{}, true) + ack := tracker.GetAckFrame(true) + Expect(ack).ToNot(BeNil()) + Expect(ack.LargestAcked()).To(Equal(protocol.PacketNumber(0))) + Expect(ack.LowestAcked()).To(Equal(protocol.PacketNumber(0))) + Expect(ack.HasMissingRanges()).To(BeFalse()) + }) - It("resets all counters needed for the ACK queueing decision when sending an ACK", func() { - tracker.ReceivedPacket(1, time.Time{}, true) - tracker.ackAlarm = time.Now().Add(-time.Minute) - Expect(tracker.GetAckFrame()).ToNot(BeNil()) - Expect(tracker.packetsReceivedSinceLastAck).To(BeZero()) - Expect(tracker.GetAlarmTimeout()).To(BeZero()) - Expect(tracker.ackElicitingPacketsReceivedSinceLastAck).To(BeZero()) - Expect(tracker.ackQueued).To(BeFalse()) - }) + It("sets the delay time", func() { + tracker.ReceivedPacket(1, time.Time{}, true) + tracker.ReceivedPacket(2, time.Now().Add(-1337*time.Millisecond), true) + ack := tracker.GetAckFrame(true) + Expect(ack).ToNot(BeNil()) + Expect(ack.DelayTime).To(BeNumerically("~", 1337*time.Millisecond, 50*time.Millisecond)) + }) - It("doesn't generate an ACK when none is queued and the timer is not set", func() { - tracker.ReceivedPacket(1, time.Time{}, true) - tracker.ackQueued = false - tracker.ackAlarm = time.Time{} - Expect(tracker.GetAckFrame()).To(BeNil()) - }) + It("uses a 0 delay time if the delay would be negative", func() { + tracker.ReceivedPacket(0, time.Now().Add(time.Hour), true) + ack := tracker.GetAckFrame(true) + Expect(ack).ToNot(BeNil()) + Expect(ack.DelayTime).To(BeZero()) + }) - It("doesn't generate an ACK when none is queued and the timer has not yet expired", func() { - tracker.ReceivedPacket(1, time.Time{}, true) - tracker.ackQueued = false - tracker.ackAlarm = time.Now().Add(time.Minute) - Expect(tracker.GetAckFrame()).To(BeNil()) - }) + It("saves the last sent ACK", func() { + tracker.ReceivedPacket(1, time.Time{}, true) + ack := tracker.GetAckFrame(true) + Expect(ack).ToNot(BeNil()) + Expect(tracker.lastAck).To(Equal(ack)) + tracker.ReceivedPacket(2, time.Time{}, true) + tracker.ackQueued = true + ack = tracker.GetAckFrame(true) + Expect(ack).ToNot(BeNil()) + Expect(tracker.lastAck).To(Equal(ack)) + }) - It("generates an ACK when the timer has expired", func() { - tracker.ReceivedPacket(1, time.Time{}, true) - tracker.ackQueued = false - tracker.ackAlarm = time.Now().Add(-time.Minute) - Expect(tracker.GetAckFrame()).ToNot(BeNil()) + It("generates an ACK frame with missing packets", func() { + tracker.ReceivedPacket(1, time.Time{}, true) + tracker.ReceivedPacket(4, time.Time{}, true) + ack := tracker.GetAckFrame(true) + Expect(ack).ToNot(BeNil()) + Expect(ack.LargestAcked()).To(Equal(protocol.PacketNumber(4))) + Expect(ack.LowestAcked()).To(Equal(protocol.PacketNumber(1))) + Expect(ack.AckRanges).To(Equal([]wire.AckRange{ + {Smallest: 4, Largest: 4}, + {Smallest: 1, Largest: 1}, + })) + }) + + It("generates an ACK for packet number 0 and other packets", func() { + tracker.ReceivedPacket(0, time.Time{}, true) + tracker.ReceivedPacket(1, time.Time{}, true) + tracker.ReceivedPacket(3, time.Time{}, true) + ack := tracker.GetAckFrame(true) + Expect(ack).ToNot(BeNil()) + Expect(ack.LargestAcked()).To(Equal(protocol.PacketNumber(3))) + Expect(ack.LowestAcked()).To(Equal(protocol.PacketNumber(0))) + Expect(ack.AckRanges).To(Equal([]wire.AckRange{ + {Smallest: 3, Largest: 3}, + {Smallest: 0, Largest: 1}, + })) + }) + + It("doesn't add delayed packets to the packetHistory", func() { + tracker.IgnoreBelow(7) + tracker.ReceivedPacket(4, time.Time{}, true) + tracker.ReceivedPacket(10, time.Time{}, true) + ack := tracker.GetAckFrame(true) + Expect(ack).ToNot(BeNil()) + Expect(ack.LargestAcked()).To(Equal(protocol.PacketNumber(10))) + Expect(ack.LowestAcked()).To(Equal(protocol.PacketNumber(10))) + }) + + It("deletes packets from the packetHistory when a lower limit is set", func() { + for i := 1; i <= 12; i++ { + tracker.ReceivedPacket(protocol.PacketNumber(i), time.Time{}, true) + } + tracker.IgnoreBelow(7) + // check that the packets were deleted from the receivedPacketHistory by checking the values in an ACK frame + ack := tracker.GetAckFrame(true) + Expect(ack).ToNot(BeNil()) + Expect(ack.LargestAcked()).To(Equal(protocol.PacketNumber(12))) + Expect(ack.LowestAcked()).To(Equal(protocol.PacketNumber(7))) + Expect(ack.HasMissingRanges()).To(BeFalse()) + }) + + // TODO: remove this test when dropping support for STOP_WAITINGs + It("handles a lower limit of 0", func() { + tracker.IgnoreBelow(0) + tracker.ReceivedPacket(1337, time.Time{}, true) + ack := tracker.GetAckFrame(true) + Expect(ack).ToNot(BeNil()) + Expect(ack.LargestAcked()).To(Equal(protocol.PacketNumber(1337))) + }) + + It("resets all counters needed for the ACK queueing decision when sending an ACK", func() { + tracker.ReceivedPacket(1, time.Time{}, true) + tracker.ackAlarm = time.Now().Add(-time.Minute) + Expect(tracker.GetAckFrame(true)).ToNot(BeNil()) + Expect(tracker.packetsReceivedSinceLastAck).To(BeZero()) + Expect(tracker.GetAlarmTimeout()).To(BeZero()) + Expect(tracker.ackElicitingPacketsReceivedSinceLastAck).To(BeZero()) + Expect(tracker.ackQueued).To(BeFalse()) + }) + + It("doesn't generate an ACK when none is queued and the timer is not set", func() { + tracker.ReceivedPacket(1, time.Time{}, true) + tracker.ackQueued = false + tracker.ackAlarm = time.Time{} + Expect(tracker.GetAckFrame(true)).To(BeNil()) + }) + + It("doesn't generate an ACK when none is queued and the timer has not yet expired", func() { + tracker.ReceivedPacket(1, time.Time{}, true) + tracker.ackQueued = false + tracker.ackAlarm = time.Now().Add(time.Minute) + Expect(tracker.GetAckFrame(true)).To(BeNil()) + }) + + It("generates an ACK when the timer has expired", func() { + tracker.ReceivedPacket(1, time.Time{}, true) + tracker.ackQueued = false + tracker.ackAlarm = time.Now().Add(-time.Minute) + Expect(tracker.GetAckFrame(true)).ToNot(BeNil()) + }) }) }) }) diff --git a/internal/mocks/ackhandler/received_packet_handler.go b/internal/mocks/ackhandler/received_packet_handler.go index 064c62b96..6362cec17 100644 --- a/internal/mocks/ackhandler/received_packet_handler.go +++ b/internal/mocks/ackhandler/received_packet_handler.go @@ -49,17 +49,17 @@ func (mr *MockReceivedPacketHandlerMockRecorder) DropPackets(arg0 interface{}) * } // GetAckFrame mocks base method -func (m *MockReceivedPacketHandler) GetAckFrame(arg0 protocol.EncryptionLevel) *wire.AckFrame { +func (m *MockReceivedPacketHandler) GetAckFrame(arg0 protocol.EncryptionLevel, arg1 bool) *wire.AckFrame { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "GetAckFrame", arg0) + ret := m.ctrl.Call(m, "GetAckFrame", arg0, arg1) ret0, _ := ret[0].(*wire.AckFrame) return ret0 } // GetAckFrame indicates an expected call of GetAckFrame -func (mr *MockReceivedPacketHandlerMockRecorder) GetAckFrame(arg0 interface{}) *gomock.Call { +func (mr *MockReceivedPacketHandlerMockRecorder) GetAckFrame(arg0, arg1 interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetAckFrame", reflect.TypeOf((*MockReceivedPacketHandler)(nil).GetAckFrame), arg0) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetAckFrame", reflect.TypeOf((*MockReceivedPacketHandler)(nil).GetAckFrame), arg0, arg1) } // GetAlarmTimeout mocks base method diff --git a/mock_ack_frame_source_test.go b/mock_ack_frame_source_test.go index c088a947a..528719a08 100644 --- a/mock_ack_frame_source_test.go +++ b/mock_ack_frame_source_test.go @@ -36,15 +36,15 @@ func (m *MockAckFrameSource) EXPECT() *MockAckFrameSourceMockRecorder { } // GetAckFrame mocks base method -func (m *MockAckFrameSource) GetAckFrame(arg0 protocol.EncryptionLevel) *wire.AckFrame { +func (m *MockAckFrameSource) GetAckFrame(arg0 protocol.EncryptionLevel, arg1 bool) *wire.AckFrame { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "GetAckFrame", arg0) + ret := m.ctrl.Call(m, "GetAckFrame", arg0, arg1) ret0, _ := ret[0].(*wire.AckFrame) return ret0 } // GetAckFrame indicates an expected call of GetAckFrame -func (mr *MockAckFrameSourceMockRecorder) GetAckFrame(arg0 interface{}) *gomock.Call { +func (mr *MockAckFrameSourceMockRecorder) GetAckFrame(arg0, arg1 interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetAckFrame", reflect.TypeOf((*MockAckFrameSource)(nil).GetAckFrame), arg0) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetAckFrame", reflect.TypeOf((*MockAckFrameSource)(nil).GetAckFrame), arg0, arg1) } diff --git a/mock_frame_source_test.go b/mock_frame_source_test.go index 24693286a..56407b834 100644 --- a/mock_frame_source_test.go +++ b/mock_frame_source_test.go @@ -64,3 +64,17 @@ func (mr *MockFrameSourceMockRecorder) AppendStreamFrames(arg0, arg1 interface{} mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AppendStreamFrames", reflect.TypeOf((*MockFrameSource)(nil).AppendStreamFrames), arg0, arg1) } + +// HasData mocks base method +func (m *MockFrameSource) HasData() bool { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "HasData") + ret0, _ := ret[0].(bool) + return ret0 +} + +// HasData indicates an expected call of HasData +func (mr *MockFrameSourceMockRecorder) HasData() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "HasData", reflect.TypeOf((*MockFrameSource)(nil).HasData)) +} diff --git a/packet_packer.go b/packet_packer.go index 92a771c19..6606e4471 100644 --- a/packet_packer.go +++ b/packet_packer.go @@ -135,12 +135,13 @@ type sealingManager interface { } type frameSource interface { + HasData() bool AppendStreamFrames([]ackhandler.Frame, protocol.ByteCount) ([]ackhandler.Frame, protocol.ByteCount) AppendControlFrames([]ackhandler.Frame, protocol.ByteCount) ([]ackhandler.Frame, protocol.ByteCount) } type ackFrameSource interface { - GetAckFrame(protocol.EncryptionLevel) *wire.AckFrame + GetAckFrame(encLevel protocol.EncryptionLevel, onlyIfQueued bool) *wire.AckFrame } type packetPacker struct { @@ -279,18 +280,18 @@ func (p *packetPacker) MaybePackAckPacket(handshakeConfirmed bool) (*packedPacke var encLevel protocol.EncryptionLevel var ack *wire.AckFrame if !handshakeConfirmed { - ack = p.acks.GetAckFrame(protocol.EncryptionInitial) + ack = p.acks.GetAckFrame(protocol.EncryptionInitial, true) if ack != nil { encLevel = protocol.EncryptionInitial } else { - ack = p.acks.GetAckFrame(protocol.EncryptionHandshake) + ack = p.acks.GetAckFrame(protocol.EncryptionHandshake, true) if ack != nil { encLevel = protocol.EncryptionHandshake } } } if ack == nil { - ack = p.acks.GetAckFrame(protocol.Encryption1RTT) + ack = p.acks.GetAckFrame(protocol.Encryption1RTT, true) if ack == nil { return nil, nil } @@ -431,11 +432,12 @@ func (p *packetPacker) maybeAppendCryptoPacket(buffer *packetBuffer, maxPacketSi } } + hasData := s.HasData() var ack *wire.AckFrame if encLevel != protocol.EncryptionHandshake || buffer.Len() == 0 { - ack = p.acks.GetAckFrame(encLevel) + ack = p.acks.GetAckFrame(encLevel, !hasRetransmission && !hasData) } - if !s.HasData() && !hasRetransmission && ack == nil { + if !hasData && !hasRetransmission && ack == nil { // nothing to send return nil, nil } @@ -499,7 +501,7 @@ func (p *packetPacker) maybeAppendAppDataPacket(buffer *packetBuffer, maxPacketS headerLen := header.GetLength(p.version) maxSize := maxPacketSize - buffer.Len() - protocol.ByteCount(sealer.Overhead()) - headerLen - payload := p.composeNextPacket(maxSize, encLevel != protocol.Encryption0RTT && buffer.Len() == 0) + payload := p.composeNextPacket(maxSize, encLevel == protocol.Encryption1RTT && buffer.Len() == 0) // check if we have anything to send if len(payload.frames) == 0 && payload.ack == nil { @@ -523,35 +525,44 @@ func (p *packetPacker) maybeAppendAppDataPacket(buffer *packetBuffer, maxPacketS func (p *packetPacker) composeNextPacket(maxFrameSize protocol.ByteCount, ackAllowed bool) payload { var payload payload - var ack *wire.AckFrame + hasData := p.framer.HasData() + hasRetransmission := p.retransmissionQueue.HasAppData() if ackAllowed { - ack = p.acks.GetAckFrame(protocol.Encryption1RTT) + ack = p.acks.GetAckFrame(protocol.Encryption1RTT, !hasRetransmission && !hasData) if ack != nil { payload.ack = ack payload.length += ack.Length(p.version) } } - for { - remainingLen := maxFrameSize - payload.length - if remainingLen < protocol.MinStreamFrameSize { - break - } - f := p.retransmissionQueue.GetAppDataFrame(remainingLen) - if f == nil { - break - } - payload.frames = append(payload.frames, ackhandler.Frame{Frame: f}) - payload.length += f.Length(p.version) + if ack == nil && !hasData && !hasRetransmission { + return payload } - var lengthAdded protocol.ByteCount - payload.frames, lengthAdded = p.framer.AppendControlFrames(payload.frames, maxFrameSize-payload.length) - payload.length += lengthAdded + if hasRetransmission { + for { + remainingLen := maxFrameSize - payload.length + if remainingLen < protocol.MinStreamFrameSize { + break + } + f := p.retransmissionQueue.GetAppDataFrame(remainingLen) + if f == nil { + break + } + payload.frames = append(payload.frames, ackhandler.Frame{Frame: f}) + payload.length += f.Length(p.version) + } + } - payload.frames, lengthAdded = p.framer.AppendStreamFrames(payload.frames, maxFrameSize-payload.length) - payload.length += lengthAdded + if hasData { + var lengthAdded protocol.ByteCount + payload.frames, lengthAdded = p.framer.AppendControlFrames(payload.frames, maxFrameSize-payload.length) + payload.length += lengthAdded + + payload.frames, lengthAdded = p.framer.AppendStreamFrames(payload.frames, maxFrameSize-payload.length) + payload.length += lengthAdded + } return payload } diff --git a/packet_packer_test.go b/packet_packer_test.go index 69aacaba2..97fa328a9 100644 --- a/packet_packer_test.go +++ b/packet_packer_test.go @@ -167,12 +167,13 @@ var _ = Describe("Packet packer", func() { } }), ) + framer.EXPECT().HasData().Return(true) sealingManager.EXPECT().GetInitialSealer().Return(nil, nil) sealingManager.EXPECT().GetHandshakeSealer().Return(nil, nil) sealingManager.EXPECT().Get1RTTSealer().Return(sealer, nil) - ackFramer.EXPECT().GetAckFrame(protocol.EncryptionInitial) - ackFramer.EXPECT().GetAckFrame(protocol.EncryptionHandshake) - ackFramer.EXPECT().GetAckFrame(protocol.Encryption1RTT) + ackFramer.EXPECT().GetAckFrame(protocol.EncryptionInitial, true) + ackFramer.EXPECT().GetAckFrame(protocol.EncryptionHandshake, true) + ackFramer.EXPECT().GetAckFrame(protocol.Encryption1RTT, false) expectAppendControlFrames() f := &wire.StreamFrame{Data: []byte{0xde, 0xca, 0xfb, 0xad}} expectAppendStreamFrames(ackhandler.Frame{Frame: f}) @@ -205,9 +206,9 @@ var _ = Describe("Packet packer", func() { Context("packing ACK packets", func() { It("doesn't pack a packet if there's no ACK to send", func() { - ackFramer.EXPECT().GetAckFrame(protocol.EncryptionInitial) - ackFramer.EXPECT().GetAckFrame(protocol.EncryptionHandshake) - ackFramer.EXPECT().GetAckFrame(protocol.Encryption1RTT) + ackFramer.EXPECT().GetAckFrame(protocol.EncryptionInitial, true) + ackFramer.EXPECT().GetAckFrame(protocol.EncryptionHandshake, true) + ackFramer.EXPECT().GetAckFrame(protocol.Encryption1RTT, true) p, err := packer.MaybePackAckPacket(false) Expect(err).ToNot(HaveOccurred()) Expect(p).To(BeNil()) @@ -218,8 +219,8 @@ var _ = Describe("Packet packer", func() { pnManager.EXPECT().PopPacketNumber(protocol.EncryptionHandshake).Return(protocol.PacketNumber(0x42)) sealingManager.EXPECT().GetHandshakeSealer().Return(getSealer(), nil) ack := &wire.AckFrame{AckRanges: []wire.AckRange{{Smallest: 1, Largest: 10}}} - ackFramer.EXPECT().GetAckFrame(protocol.EncryptionInitial) - ackFramer.EXPECT().GetAckFrame(protocol.EncryptionHandshake).Return(ack) + ackFramer.EXPECT().GetAckFrame(protocol.EncryptionInitial, true) + ackFramer.EXPECT().GetAckFrame(protocol.EncryptionHandshake, true).Return(ack) p, err := packer.MaybePackAckPacket(false) Expect(err).NotTo(HaveOccurred()) Expect(p).ToNot(BeNil()) @@ -232,7 +233,7 @@ var _ = Describe("Packet packer", func() { pnManager.EXPECT().PopPacketNumber(protocol.Encryption1RTT).Return(protocol.PacketNumber(0x42)) sealingManager.EXPECT().Get1RTTSealer().Return(getSealer(), nil) ack := &wire.AckFrame{AckRanges: []wire.AckRange{{Smallest: 1, Largest: 10}}} - ackFramer.EXPECT().GetAckFrame(protocol.Encryption1RTT).Return(ack) + ackFramer.EXPECT().GetAckFrame(protocol.Encryption1RTT, true).Return(ack) p, err := packer.MaybePackAckPacket(true) Expect(err).NotTo(HaveOccurred()) Expect(p).ToNot(BeNil()) @@ -248,10 +249,10 @@ var _ = Describe("Packet packer", func() { sealingManager.EXPECT().GetHandshakeSealer().Return(nil, nil).AnyTimes() sealingManager.EXPECT().Get1RTTSealer().Return(nil, handshake.ErrKeysNotYetAvailable).AnyTimes() initialStream.EXPECT().HasData().AnyTimes() - ackFramer.EXPECT().GetAckFrame(protocol.EncryptionInitial).AnyTimes() + ackFramer.EXPECT().GetAckFrame(protocol.EncryptionInitial, true).AnyTimes() handshakeStream.EXPECT().HasData().AnyTimes() - ackFramer.EXPECT().GetAckFrame(protocol.EncryptionHandshake).AnyTimes() - ackFramer.EXPECT().GetAckFrame(protocol.Encryption1RTT).AnyTimes() + ackFramer.EXPECT().GetAckFrame(protocol.EncryptionHandshake, true).AnyTimes() + ackFramer.EXPECT().GetAckFrame(protocol.Encryption1RTT, true).AnyTimes() }) It("packs a 0-RTT packet", func() { @@ -259,6 +260,7 @@ var _ = Describe("Packet packer", func() { pnManager.EXPECT().PeekPacketNumber(protocol.Encryption0RTT).Return(protocol.PacketNumber(0x42), protocol.PacketNumberLen2) pnManager.EXPECT().PopPacketNumber(protocol.Encryption0RTT).Return(protocol.PacketNumber(0x42)) cf := ackhandler.Frame{Frame: &wire.MaxDataFrame{ByteOffset: 0x1337}} + framer.EXPECT().HasData().Return(true) framer.EXPECT().AppendControlFrames(nil, gomock.Any()).DoAndReturn(func(frames []ackhandler.Frame, _ protocol.ByteCount) ([]ackhandler.Frame, protocol.ByteCount) { return append(frames, cf), cf.Length(packer.version) }) @@ -422,20 +424,12 @@ var _ = Describe("Packet packer", func() { }) Context("packing normal packets", func() { - BeforeEach(func() { - initialStream.EXPECT().HasData().AnyTimes() - ackFramer.EXPECT().GetAckFrame(protocol.EncryptionInitial).AnyTimes() - handshakeStream.EXPECT().HasData().AnyTimes() - ackFramer.EXPECT().GetAckFrame(protocol.EncryptionHandshake).AnyTimes() - }) - It("returns nil when no packet is queued", func() { pnManager.EXPECT().PeekPacketNumber(protocol.Encryption1RTT).Return(protocol.PacketNumber(0x42), protocol.PacketNumberLen2) // don't expect any calls to PopPacketNumber sealingManager.EXPECT().Get1RTTSealer().Return(getSealer(), nil) - ackFramer.EXPECT().GetAckFrame(protocol.Encryption1RTT) - framer.EXPECT().AppendControlFrames(nil, gomock.Any()) - framer.EXPECT().AppendStreamFrames(nil, gomock.Any()) + ackFramer.EXPECT().GetAckFrame(protocol.Encryption1RTT, true) + framer.EXPECT().HasData() p, err := packer.PackPacket() Expect(p).To(BeNil()) Expect(err).ToNot(HaveOccurred()) @@ -445,7 +439,8 @@ var _ = Describe("Packet packer", func() { pnManager.EXPECT().PeekPacketNumber(protocol.Encryption1RTT).Return(protocol.PacketNumber(0x42), protocol.PacketNumberLen2) pnManager.EXPECT().PopPacketNumber(protocol.Encryption1RTT).Return(protocol.PacketNumber(0x42)) sealingManager.EXPECT().Get1RTTSealer().Return(getSealer(), nil) - ackFramer.EXPECT().GetAckFrame(protocol.Encryption1RTT) + framer.EXPECT().HasData().Return(true) + ackFramer.EXPECT().GetAckFrame(protocol.Encryption1RTT, false) expectAppendControlFrames() f := &wire.StreamFrame{ StreamID: 5, @@ -465,7 +460,8 @@ var _ = Describe("Packet packer", func() { pnManager.EXPECT().PeekPacketNumber(protocol.Encryption1RTT).Return(protocol.PacketNumber(0x42), protocol.PacketNumberLen2) pnManager.EXPECT().PopPacketNumber(protocol.Encryption1RTT).Return(protocol.PacketNumber(0x42)) sealingManager.EXPECT().Get1RTTSealer().Return(getSealer(), nil) - ackFramer.EXPECT().GetAckFrame(protocol.Encryption1RTT) + framer.EXPECT().HasData().Return(true) + ackFramer.EXPECT().GetAckFrame(protocol.Encryption1RTT, false) expectAppendControlFrames() expectAppendStreamFrames(ackhandler.Frame{Frame: &wire.StreamFrame{ StreamID: 5, @@ -480,10 +476,9 @@ var _ = Describe("Packet packer", func() { pnManager.EXPECT().PeekPacketNumber(protocol.Encryption1RTT).Return(protocol.PacketNumber(0x42), protocol.PacketNumberLen2) pnManager.EXPECT().PopPacketNumber(protocol.Encryption1RTT).Return(protocol.PacketNumber(0x42)) ack := &wire.AckFrame{AckRanges: []wire.AckRange{{Largest: 42, Smallest: 1}}} - ackFramer.EXPECT().GetAckFrame(protocol.Encryption1RTT).Return(ack) + framer.EXPECT().HasData() + ackFramer.EXPECT().GetAckFrame(protocol.Encryption1RTT, true).Return(ack) sealingManager.EXPECT().Get1RTTSealer().Return(getSealer(), nil) - expectAppendControlFrames() - expectAppendStreamFrames() p, err := packer.PackPacket() Expect(err).NotTo(HaveOccurred()) Expect(p).ToNot(BeNil()) @@ -494,7 +489,8 @@ var _ = Describe("Packet packer", func() { pnManager.EXPECT().PeekPacketNumber(protocol.Encryption1RTT).Return(protocol.PacketNumber(0x42), protocol.PacketNumberLen2) pnManager.EXPECT().PopPacketNumber(protocol.Encryption1RTT).Return(protocol.PacketNumber(0x42)) sealingManager.EXPECT().Get1RTTSealer().Return(getSealer(), nil) - ackFramer.EXPECT().GetAckFrame(protocol.Encryption1RTT) + framer.EXPECT().HasData().Return(true) + ackFramer.EXPECT().GetAckFrame(protocol.Encryption1RTT, false) frames := []ackhandler.Frame{ {Frame: &wire.ResetStreamFrame{}}, {Frame: &wire.MaxDataFrame{}}, @@ -511,7 +507,8 @@ var _ = Describe("Packet packer", func() { It("accounts for the space consumed by control frames", func() { pnManager.EXPECT().PeekPacketNumber(protocol.Encryption1RTT).Return(protocol.PacketNumber(0x42), protocol.PacketNumberLen2) sealingManager.EXPECT().Get1RTTSealer().Return(getSealer(), nil) - ackFramer.EXPECT().GetAckFrame(protocol.Encryption1RTT) + framer.EXPECT().HasData().Return(true) + ackFramer.EXPECT().GetAckFrame(protocol.Encryption1RTT, false) var maxSize protocol.ByteCount gomock.InOrder( framer.EXPECT().AppendControlFrames(gomock.Any(), gomock.Any()).DoAndReturn(func(fs []ackhandler.Frame, maxLen protocol.ByteCount) ([]ackhandler.Frame, protocol.ByteCount) { @@ -535,6 +532,8 @@ var _ = Describe("Packet packer", func() { sealingManager.EXPECT().GetHandshakeSealer().Return(sealer, nil) sealingManager.EXPECT().Get1RTTSealer().Return(nil, handshake.ErrKeysNotYetAvailable) packer.retransmissionQueue.AddHandshake(&wire.PingFrame{}) + handshakeStream.EXPECT().HasData() + ackFramer.EXPECT().GetAckFrame(protocol.EncryptionHandshake, false) packet, err := packer.PackCoalescedPacket(protocol.MaxByteCount) Expect(err).ToNot(HaveOccurred()) Expect(packet).ToNot(BeNil()) @@ -573,7 +572,8 @@ var _ = Describe("Packet packer", func() { pnManager.EXPECT().PopPacketNumber(protocol.Encryption1RTT).Return(protocol.PacketNumber(0x42)) sealer := getSealer() sealingManager.EXPECT().Get1RTTSealer().Return(sealer, nil) - ackFramer.EXPECT().GetAckFrame(protocol.Encryption1RTT) + framer.EXPECT().HasData().Return(true) + ackFramer.EXPECT().GetAckFrame(protocol.Encryption1RTT, false) expectAppendControlFrames() expectAppendStreamFrames(ackhandler.Frame{Frame: f}) packet, err := packer.PackPacket() @@ -622,7 +622,8 @@ var _ = Describe("Packet packer", func() { pnManager.EXPECT().PeekPacketNumber(protocol.Encryption1RTT).Return(protocol.PacketNumber(0x42), protocol.PacketNumberLen2) pnManager.EXPECT().PopPacketNumber(protocol.Encryption1RTT).Return(protocol.PacketNumber(0x42)) sealingManager.EXPECT().Get1RTTSealer().Return(getSealer(), nil) - ackFramer.EXPECT().GetAckFrame(protocol.Encryption1RTT) + framer.EXPECT().HasData().Return(true) + ackFramer.EXPECT().GetAckFrame(protocol.Encryption1RTT, false) expectAppendControlFrames() expectAppendStreamFrames(ackhandler.Frame{Frame: f1}, ackhandler.Frame{Frame: f2}, ackhandler.Frame{Frame: f3}) p, err := packer.PackPacket() @@ -640,7 +641,8 @@ var _ = Describe("Packet packer", func() { pnManager.EXPECT().PeekPacketNumber(protocol.Encryption1RTT).Return(protocol.PacketNumber(0x42), protocol.PacketNumberLen2) pnManager.EXPECT().PopPacketNumber(protocol.Encryption1RTT).Return(protocol.PacketNumber(0x42)) sealingManager.EXPECT().Get1RTTSealer().Return(getSealer(), nil) - ackFramer.EXPECT().GetAckFrame(protocol.Encryption1RTT).Return(&wire.AckFrame{AckRanges: []wire.AckRange{{Smallest: 1, Largest: 1}}}) + framer.EXPECT().HasData().Return(true) + ackFramer.EXPECT().GetAckFrame(protocol.Encryption1RTT, false).Return(&wire.AckFrame{AckRanges: []wire.AckRange{{Smallest: 1, Largest: 1}}}) expectAppendControlFrames() expectAppendStreamFrames() p, err := packer.PackPacket() @@ -656,7 +658,8 @@ var _ = Describe("Packet packer", func() { pnManager.EXPECT().PeekPacketNumber(protocol.Encryption1RTT).Return(protocol.PacketNumber(0x42), protocol.PacketNumberLen2) pnManager.EXPECT().PopPacketNumber(protocol.Encryption1RTT).Return(protocol.PacketNumber(0x42)) sealingManager.EXPECT().Get1RTTSealer().Return(getSealer(), nil) - ackFramer.EXPECT().GetAckFrame(protocol.Encryption1RTT).Return(&wire.AckFrame{AckRanges: []wire.AckRange{{Smallest: 1, Largest: 1}}}) + framer.EXPECT().HasData().Return(true) + ackFramer.EXPECT().GetAckFrame(protocol.Encryption1RTT, false).Return(&wire.AckFrame{AckRanges: []wire.AckRange{{Smallest: 1, Largest: 1}}}) expectAppendControlFrames() expectAppendStreamFrames() p, err := packer.PackPacket() @@ -667,7 +670,8 @@ var _ = Describe("Packet packer", func() { pnManager.EXPECT().PeekPacketNumber(protocol.Encryption1RTT).Return(protocol.PacketNumber(0x42), protocol.PacketNumberLen2) pnManager.EXPECT().PopPacketNumber(protocol.Encryption1RTT).Return(protocol.PacketNumber(0x42)) sealingManager.EXPECT().Get1RTTSealer().Return(getSealer(), nil) - ackFramer.EXPECT().GetAckFrame(protocol.Encryption1RTT).Return(&wire.AckFrame{AckRanges: []wire.AckRange{{Smallest: 1, Largest: 1}}}) + framer.EXPECT().HasData().Return(true) + ackFramer.EXPECT().GetAckFrame(protocol.Encryption1RTT, false).Return(&wire.AckFrame{AckRanges: []wire.AckRange{{Smallest: 1, Largest: 1}}}) expectAppendControlFrames() expectAppendStreamFrames() p, err = packer.PackPacket() @@ -682,9 +686,10 @@ var _ = Describe("Packet packer", func() { // nothing to send pnManager.EXPECT().PeekPacketNumber(protocol.Encryption1RTT).Return(protocol.PacketNumber(0x42), protocol.PacketNumberLen2) sealingManager.EXPECT().Get1RTTSealer().Return(getSealer(), nil) + framer.EXPECT().HasData().Return(true) expectAppendControlFrames() expectAppendStreamFrames() - ackFramer.EXPECT().GetAckFrame(protocol.Encryption1RTT) + ackFramer.EXPECT().GetAckFrame(protocol.Encryption1RTT, false) p, err := packer.PackPacket() Expect(err).ToNot(HaveOccurred()) Expect(p).To(BeNil()) @@ -694,8 +699,9 @@ var _ = Describe("Packet packer", func() { pnManager.EXPECT().PeekPacketNumber(protocol.Encryption1RTT).Return(protocol.PacketNumber(0x42), protocol.PacketNumberLen2) pnManager.EXPECT().PopPacketNumber(protocol.Encryption1RTT).Return(protocol.PacketNumber(0x42)) sealingManager.EXPECT().Get1RTTSealer().Return(getSealer(), nil) + framer.EXPECT().HasData().Return(true) ack := &wire.AckFrame{AckRanges: []wire.AckRange{{Smallest: 1, Largest: 1}}} - ackFramer.EXPECT().GetAckFrame(protocol.Encryption1RTT).Return(ack) + ackFramer.EXPECT().GetAckFrame(protocol.Encryption1RTT, false).Return(ack) p, err = packer.PackPacket() Expect(err).ToNot(HaveOccurred()) Expect(p.ack).To(Equal(ack)) @@ -707,7 +713,8 @@ var _ = Describe("Packet packer", func() { pnManager.EXPECT().PeekPacketNumber(protocol.Encryption1RTT).Return(protocol.PacketNumber(0x42), protocol.PacketNumberLen2) pnManager.EXPECT().PopPacketNumber(protocol.Encryption1RTT).Return(protocol.PacketNumber(0x42)) sealingManager.EXPECT().Get1RTTSealer().Return(getSealer(), nil) - ackFramer.EXPECT().GetAckFrame(protocol.Encryption1RTT) + framer.EXPECT().HasData().Return(true) + ackFramer.EXPECT().GetAckFrame(protocol.Encryption1RTT, false) expectAppendStreamFrames() expectAppendControlFrames(ackhandler.Frame{Frame: &wire.MaxDataFrame{}}) p, err := packer.PackPacket() @@ -721,7 +728,8 @@ var _ = Describe("Packet packer", func() { It("sets the maximum packet size", func() { pnManager.EXPECT().PeekPacketNumber(protocol.Encryption1RTT).Return(protocol.PacketNumber(0x42), protocol.PacketNumberLen2).Times(2) sealingManager.EXPECT().Get1RTTSealer().Return(getSealer(), nil).Times(2) - ackFramer.EXPECT().GetAckFrame(protocol.Encryption1RTT).Times(2) + framer.EXPECT().HasData().Return(true).Times(2) + ackFramer.EXPECT().GetAckFrame(protocol.Encryption1RTT, false).Times(2) var initialMaxPacketSize protocol.ByteCount framer.EXPECT().AppendControlFrames(gomock.Any(), gomock.Any()).Do(func(_ []ackhandler.Frame, maxLen protocol.ByteCount) ([]ackhandler.Frame, protocol.ByteCount) { initialMaxPacketSize = maxLen @@ -746,7 +754,8 @@ var _ = Describe("Packet packer", func() { It("doesn't increase the max packet size", func() { pnManager.EXPECT().PeekPacketNumber(protocol.Encryption1RTT).Return(protocol.PacketNumber(0x42), protocol.PacketNumberLen2).Times(2) sealingManager.EXPECT().Get1RTTSealer().Return(getSealer(), nil).Times(2) - ackFramer.EXPECT().GetAckFrame(protocol.Encryption1RTT).Times(2) + framer.EXPECT().HasData().Return(true).Times(2) + ackFramer.EXPECT().GetAckFrame(protocol.Encryption1RTT, false).Times(2) var initialMaxPacketSize protocol.ByteCount framer.EXPECT().AppendControlFrames(gomock.Any(), gomock.Any()).Do(func(_ []ackhandler.Frame, maxLen protocol.ByteCount) ([]ackhandler.Frame, protocol.ByteCount) { initialMaxPacketSize = maxLen @@ -778,7 +787,7 @@ var _ = Describe("Packet packer", func() { Offset: 0x1337, Data: []byte("foobar"), } - ackFramer.EXPECT().GetAckFrame(protocol.EncryptionInitial) + ackFramer.EXPECT().GetAckFrame(protocol.EncryptionInitial, false) initialStream.EXPECT().HasData().Return(true).AnyTimes() initialStream.EXPECT().PopCryptoFrame(gomock.Any()).Return(f) sealingManager.EXPECT().GetInitialSealer().Return(getSealer(), nil) @@ -795,8 +804,8 @@ var _ = Describe("Packet packer", func() { pnManager.EXPECT().PopPacketNumber(protocol.EncryptionHandshake).Return(protocol.PacketNumber(0x42)) sealingManager.EXPECT().GetInitialSealer().Return(mocks.NewMockShortHeaderSealer(mockCtrl), nil) sealingManager.EXPECT().GetHandshakeSealer().Return(getSealer(), nil) - ackFramer.EXPECT().GetAckFrame(protocol.EncryptionInitial) - ackFramer.EXPECT().GetAckFrame(protocol.EncryptionHandshake) + ackFramer.EXPECT().GetAckFrame(protocol.EncryptionInitial, true) + ackFramer.EXPECT().GetAckFrame(protocol.EncryptionHandshake, false) initialStream.EXPECT().HasData() handshakeStream.EXPECT().HasData().Return(true).Times(2) handshakeStream.EXPECT().PopCryptoFrame(gomock.Any()).DoAndReturn(func(size protocol.ByteCount) *wire.CryptoFrame { @@ -822,7 +831,7 @@ var _ = Describe("Packet packer", func() { sealingManager.EXPECT().GetInitialSealer().Return(getSealer(), nil) sealingManager.EXPECT().GetHandshakeSealer().Return(getSealer(), nil) sealingManager.EXPECT().Get1RTTSealer().Return(nil, handshake.ErrKeysNotYetAvailable) - ackFramer.EXPECT().GetAckFrame(protocol.EncryptionInitial) + ackFramer.EXPECT().GetAckFrame(protocol.EncryptionInitial, false) // don't EXPECT any calls for a Handshake ACK frame initialStream.EXPECT().HasData().Return(true).Times(2) initialStream.EXPECT().PopCryptoFrame(gomock.Any()).DoAndReturn(func(size protocol.ByteCount) *wire.CryptoFrame { @@ -860,7 +869,8 @@ var _ = Describe("Packet packer", func() { sealingManager.EXPECT().GetHandshakeSealer().Return(nil, handshake.ErrKeysNotYetAvailable) sealingManager.EXPECT().Get0RTTSealer().Return(getSealer(), nil) sealingManager.EXPECT().Get1RTTSealer().Return(nil, handshake.ErrKeysNotYetAvailable) - ackFramer.EXPECT().GetAckFrame(protocol.EncryptionInitial) + framer.EXPECT().HasData().Return(true) + ackFramer.EXPECT().GetAckFrame(protocol.EncryptionInitial, false) // don't EXPECT any calls for a Handshake ACK frame initialStream.EXPECT().HasData().Return(true).Times(2) initialStream.EXPECT().PopCryptoFrame(gomock.Any()).DoAndReturn(func(size protocol.ByteCount) *wire.CryptoFrame { @@ -895,7 +905,8 @@ var _ = Describe("Packet packer", func() { sealingManager.EXPECT().GetInitialSealer().Return(nil, handshake.ErrKeysDropped) sealingManager.EXPECT().GetHandshakeSealer().Return(getSealer(), nil) sealingManager.EXPECT().Get1RTTSealer().Return(getSealer(), nil) - ackFramer.EXPECT().GetAckFrame(protocol.EncryptionHandshake) + framer.EXPECT().HasData().Return(true) + ackFramer.EXPECT().GetAckFrame(protocol.EncryptionHandshake, false) // don't EXPECT any calls for a 1-RTT ACK frame handshakeStream.EXPECT().HasData().Return(true).Times(2) handshakeStream.EXPECT().PopCryptoFrame(gomock.Any()).DoAndReturn(func(size protocol.ByteCount) *wire.CryptoFrame { @@ -926,7 +937,7 @@ var _ = Describe("Packet packer", func() { pnManager.EXPECT().PopPacketNumber(protocol.EncryptionInitial).Return(protocol.PacketNumber(0x24)) sealingManager.EXPECT().GetInitialSealer().Return(getSealer(), nil) // don't EXPECT any calls to GetHandshakeSealer and Get1RTTSealer - ackFramer.EXPECT().GetAckFrame(protocol.EncryptionInitial) + ackFramer.EXPECT().GetAckFrame(protocol.EncryptionInitial, false) initialStream.EXPECT().HasData().Return(true).Times(2) initialStream.EXPECT().PopCryptoFrame(gomock.Any()).DoAndReturn(func(size protocol.ByteCount) *wire.CryptoFrame { s := size - protocol.MinCoalescedPacketSize @@ -955,7 +966,7 @@ var _ = Describe("Packet packer", func() { pnManager.EXPECT().PopPacketNumber(protocol.EncryptionInitial).Return(protocol.PacketNumber(0x24)) sealingManager.EXPECT().GetInitialSealer().Return(getSealer(), nil) // don't EXPECT any calls to GetHandshakeSealer and Get1RTTSealer - ackFramer.EXPECT().GetAckFrame(protocol.EncryptionInitial) + ackFramer.EXPECT().GetAckFrame(protocol.EncryptionInitial, false) initialStream.EXPECT().HasData().Return(true).Times(2) initialStream.EXPECT().PopCryptoFrame(gomock.Any()).DoAndReturn(func(s protocol.ByteCount) *wire.CryptoFrame { f := &wire.CryptoFrame{Offset: 0x1337} @@ -979,7 +990,8 @@ var _ = Describe("Packet packer", func() { sealingManager.EXPECT().GetHandshakeSealer().Return(getSealer(), nil) oneRTTSealer := getSealer() sealingManager.EXPECT().Get1RTTSealer().Return(oneRTTSealer, nil) - ackFramer.EXPECT().GetAckFrame(protocol.EncryptionHandshake) + ackFramer.EXPECT().GetAckFrame(protocol.EncryptionHandshake, false) + framer.EXPECT().HasData().Return(true) handshakeStream.EXPECT().HasData().Return(true).Times(2) handshakeStream.EXPECT().PopCryptoFrame(gomock.Any()).Return(&wire.CryptoFrame{ Offset: 0x1337, @@ -999,6 +1011,44 @@ var _ = Describe("Packet packer", func() { Expect(appDataSize).To(Equal(size - p.packets[0].length - p.packets[1].header.GetLength(packer.version) - protocol.ByteCount(oneRTTSealer.Overhead()))) }) + It("pads if payload length + packet number length is smaller than 4, for Long Header packets", func() { + pnManager.EXPECT().PeekPacketNumber(protocol.EncryptionHandshake).Return(protocol.PacketNumber(0x42), protocol.PacketNumberLen1) + pnManager.EXPECT().PopPacketNumber(protocol.EncryptionHandshake).Return(protocol.PacketNumber(0x42)) + sealer := getSealer() + sealingManager.EXPECT().GetInitialSealer().Return(nil, handshake.ErrKeysDropped) + sealingManager.EXPECT().GetHandshakeSealer().Return(sealer, nil) + sealingManager.EXPECT().Get1RTTSealer().Return(nil, handshake.ErrKeysNotYetAvailable) + packer.retransmissionQueue.AddHandshake(&wire.PingFrame{}) + handshakeStream.EXPECT().HasData() + ackFramer.EXPECT().GetAckFrame(protocol.EncryptionHandshake, false) + packet, err := packer.PackCoalescedPacket(protocol.MaxByteCount) + Expect(err).ToNot(HaveOccurred()) + Expect(packet).ToNot(BeNil()) + Expect(packet.packets).To(HaveLen(1)) + // cut off the tag that the mock sealer added + // packet.buffer.Data = packet.buffer.Data[:packet.buffer.Len()-protocol.ByteCount(sealer.Overhead())] + hdr, _, _, err := wire.ParsePacket(packet.buffer.Data, len(packer.getDestConnID())) + Expect(err).ToNot(HaveOccurred()) + r := bytes.NewReader(packet.buffer.Data) + extHdr, err := hdr.ParseExtended(r, packer.version) + Expect(err).ToNot(HaveOccurred()) + Expect(extHdr.PacketNumberLen).To(Equal(protocol.PacketNumberLen1)) + Expect(r.Len()).To(Equal(4 - 1 /* packet number length */ + sealer.Overhead())) + // the first bytes of the payload should be a 2 PADDING frames... + firstPayloadByte, err := r.ReadByte() + Expect(err).ToNot(HaveOccurred()) + Expect(firstPayloadByte).To(Equal(byte(0))) + secondPayloadByte, err := r.ReadByte() + Expect(err).ToNot(HaveOccurred()) + Expect(secondPayloadByte).To(Equal(byte(0))) + // ... followed by the PING + frameParser := wire.NewFrameParser(packer.version) + frame, err := frameParser.ParseNext(r, protocol.Encryption1RTT) + Expect(err).ToNot(HaveOccurred()) + Expect(frame).To(BeAssignableToTypeOf(&wire.PingFrame{})) + Expect(r.Len()).To(Equal(sealer.Overhead())) + }) + It("adds retransmissions", func() { f := &wire.CryptoFrame{Data: []byte("Initial")} retransmissionQueue.AddInitial(f) @@ -1008,7 +1058,7 @@ var _ = Describe("Packet packer", func() { sealingManager.EXPECT().GetInitialSealer().Return(getSealer(), nil) sealingManager.EXPECT().GetHandshakeSealer().Return(nil, handshake.ErrKeysNotYetAvailable) sealingManager.EXPECT().Get1RTTSealer().Return(nil, handshake.ErrKeysNotYetAvailable) - ackFramer.EXPECT().GetAckFrame(protocol.EncryptionInitial) + ackFramer.EXPECT().GetAckFrame(protocol.EncryptionInitial, false) initialStream.EXPECT().HasData() p, err := packer.PackCoalescedPacket(protocol.MaxByteCount) Expect(err).ToNot(HaveOccurred()) @@ -1021,7 +1071,7 @@ var _ = Describe("Packet packer", func() { It("sends an Initial packet containing only an ACK", func() { ack := &wire.AckFrame{AckRanges: []wire.AckRange{{Smallest: 10, Largest: 20}}} - ackFramer.EXPECT().GetAckFrame(protocol.EncryptionInitial).Return(ack) + ackFramer.EXPECT().GetAckFrame(protocol.EncryptionInitial, true).Return(ack) initialStream.EXPECT().HasData().Times(2) sealingManager.EXPECT().GetInitialSealer().Return(getSealer(), nil) sealingManager.EXPECT().GetHandshakeSealer().Return(nil, handshake.ErrKeysNotYetAvailable) @@ -1039,7 +1089,7 @@ var _ = Describe("Packet packer", func() { sealingManager.EXPECT().GetHandshakeSealer().Return(nil, handshake.ErrKeysNotYetAvailable) sealingManager.EXPECT().Get1RTTSealer().Return(nil, handshake.ErrKeysNotYetAvailable) initialStream.EXPECT().HasData() - ackFramer.EXPECT().GetAckFrame(protocol.EncryptionInitial) + ackFramer.EXPECT().GetAckFrame(protocol.EncryptionInitial, true) p, err := packer.PackCoalescedPacket(protocol.MaxByteCount) Expect(err).ToNot(HaveOccurred()) Expect(p).To(BeNil()) @@ -1047,8 +1097,8 @@ var _ = Describe("Packet packer", func() { It("sends a Handshake packet containing only an ACK", func() { ack := &wire.AckFrame{AckRanges: []wire.AckRange{{Smallest: 10, Largest: 20}}} - ackFramer.EXPECT().GetAckFrame(protocol.EncryptionInitial) - ackFramer.EXPECT().GetAckFrame(protocol.EncryptionHandshake).Return(ack) + ackFramer.EXPECT().GetAckFrame(protocol.EncryptionInitial, true) + ackFramer.EXPECT().GetAckFrame(protocol.EncryptionHandshake, true).Return(ack) initialStream.EXPECT().HasData() handshakeStream.EXPECT().HasData().Times(2) sealingManager.EXPECT().GetInitialSealer().Return(mocks.NewMockShortHeaderSealer(mockCtrl), nil) @@ -1072,7 +1122,7 @@ var _ = Describe("Packet packer", func() { sealingManager.EXPECT().GetHandshakeSealer().Return(nil, handshake.ErrKeysNotYetAvailable) sealingManager.EXPECT().Get0RTTSealer().Return(nil, handshake.ErrKeysNotYetAvailable) sealingManager.EXPECT().Get1RTTSealer().Return(nil, handshake.ErrKeysNotYetAvailable) - ackFramer.EXPECT().GetAckFrame(protocol.EncryptionInitial) + ackFramer.EXPECT().GetAckFrame(protocol.EncryptionInitial, false) initialStream.EXPECT().HasData().Return(true).Times(2) initialStream.EXPECT().PopCryptoFrame(gomock.Any()).Return(f) packer.perspective = protocol.PerspectiveClient @@ -1095,7 +1145,7 @@ var _ = Describe("Packet packer", func() { sealingManager.EXPECT().GetHandshakeSealer().Return(nil, handshake.ErrKeysNotYetAvailable) sealingManager.EXPECT().Get0RTTSealer().Return(nil, handshake.ErrKeysNotYetAvailable) sealingManager.EXPECT().Get1RTTSealer().Return(nil, handshake.ErrKeysNotYetAvailable) - ackFramer.EXPECT().GetAckFrame(protocol.EncryptionInitial).Return(ack) + ackFramer.EXPECT().GetAckFrame(protocol.EncryptionInitial, false).Return(ack) initialStream.EXPECT().HasData().Return(true).Times(2) initialStream.EXPECT().PopCryptoFrame(gomock.Any()).Return(f) packer.version = protocol.VersionTLS @@ -1114,7 +1164,7 @@ var _ = Describe("Packet packer", func() { f := &wire.CryptoFrame{Data: []byte("Initial")} retransmissionQueue.AddInitial(f) sealingManager.EXPECT().GetInitialSealer().Return(getSealer(), nil) - ackFramer.EXPECT().GetAckFrame(protocol.EncryptionInitial) + ackFramer.EXPECT().GetAckFrame(protocol.EncryptionInitial, false) initialStream.EXPECT().HasData() pnManager.EXPECT().PeekPacketNumber(protocol.EncryptionInitial).Return(protocol.PacketNumber(0x42), protocol.PacketNumberLen2) pnManager.EXPECT().PopPacketNumber(protocol.EncryptionInitial).Return(protocol.PacketNumber(0x42)) @@ -1134,7 +1184,7 @@ var _ = Describe("Packet packer", func() { f := &wire.CryptoFrame{Data: []byte("Initial")} retransmissionQueue.AddInitial(f) sealingManager.EXPECT().GetInitialSealer().Return(getSealer(), nil) - ackFramer.EXPECT().GetAckFrame(protocol.EncryptionInitial) + ackFramer.EXPECT().GetAckFrame(protocol.EncryptionInitial, false) initialStream.EXPECT().HasData() pnManager.EXPECT().PeekPacketNumber(protocol.EncryptionInitial).Return(protocol.PacketNumber(0x42), protocol.PacketNumberLen2) pnManager.EXPECT().PopPacketNumber(protocol.EncryptionInitial).Return(protocol.PacketNumber(0x42)) @@ -1152,7 +1202,7 @@ var _ = Describe("Packet packer", func() { f := &wire.CryptoFrame{Data: []byte("Handshake")} retransmissionQueue.AddHandshake(f) sealingManager.EXPECT().GetHandshakeSealer().Return(getSealer(), nil) - ackFramer.EXPECT().GetAckFrame(protocol.EncryptionHandshake) + ackFramer.EXPECT().GetAckFrame(protocol.EncryptionHandshake, false) handshakeStream.EXPECT().HasData() pnManager.EXPECT().PeekPacketNumber(protocol.EncryptionHandshake).Return(protocol.PacketNumber(0x42), protocol.PacketNumberLen2) pnManager.EXPECT().PopPacketNumber(protocol.EncryptionHandshake).Return(protocol.PacketNumber(0x42)) @@ -1170,9 +1220,10 @@ var _ = Describe("Packet packer", func() { f := &wire.StreamFrame{Data: []byte("1-RTT")} retransmissionQueue.AddInitial(f) sealingManager.EXPECT().Get1RTTSealer().Return(getSealer(), nil) - ackFramer.EXPECT().GetAckFrame(protocol.Encryption1RTT) + ackFramer.EXPECT().GetAckFrame(protocol.Encryption1RTT, false) pnManager.EXPECT().PeekPacketNumber(protocol.Encryption1RTT).Return(protocol.PacketNumber(0x42), protocol.PacketNumberLen2) pnManager.EXPECT().PopPacketNumber(protocol.Encryption1RTT).Return(protocol.PacketNumber(0x42)) + framer.EXPECT().HasData().Return(true) expectAppendControlFrames() expectAppendStreamFrames(ackhandler.Frame{Frame: f}) @@ -1186,10 +1237,9 @@ var _ = Describe("Packet packer", func() { It("returns nil if there's no probe data to send", func() { sealingManager.EXPECT().Get1RTTSealer().Return(getSealer(), nil) - ackFramer.EXPECT().GetAckFrame(protocol.Encryption1RTT) + ackFramer.EXPECT().GetAckFrame(protocol.Encryption1RTT, true) pnManager.EXPECT().PeekPacketNumber(protocol.Encryption1RTT).Return(protocol.PacketNumber(0x42), protocol.PacketNumberLen2) - expectAppendControlFrames() - expectAppendStreamFrames() + framer.EXPECT().HasData() packet, err := packer.MaybePackProbePacket(protocol.Encryption1RTT) Expect(err).ToNot(HaveOccurred()) diff --git a/retransmission_queue.go b/retransmission_queue.go index d29d3660b..404b72b79 100644 --- a/retransmission_queue.go +++ b/retransmission_queue.go @@ -47,6 +47,10 @@ func (q *retransmissionQueue) HasHandshakeData() bool { return len(q.handshakeCryptoData) > 0 || len(q.handshake) > 0 } +func (q *retransmissionQueue) HasAppData() bool { + return len(q.appData) > 0 +} + func (q *retransmissionQueue) AddAppData(f wire.Frame) { if _, ok := f.(*wire.StreamFrame); ok { panic("STREAM frames are handled with their respective streams.") diff --git a/retransmission_queue_test.go b/retransmission_queue_test.go index c9e940f01..213f732fb 100644 --- a/retransmission_queue_test.go +++ b/retransmission_queue_test.go @@ -176,9 +176,12 @@ var _ = Describe("Retransmission queue", func() { It("queues and retrieves a control frame", func() { f := &wire.MaxDataFrame{ByteOffset: 0x42} + Expect(q.HasAppData()).To(BeFalse()) q.AddAppData(f) + Expect(q.HasAppData()).To(BeTrue()) Expect(q.GetAppDataFrame(f.Length(version) - 1)).To(BeNil()) Expect(q.GetAppDataFrame(f.Length(version))).To(Equal(f)) + Expect(q.HasAppData()).To(BeFalse()) }) }) })