diff --git a/packet_packer.go b/packet_packer.go index ff9831ab4..b1dea6f26 100644 --- a/packet_packer.go +++ b/packet_packer.go @@ -91,7 +91,7 @@ func (p *packetContents) ToAckHandlerPacket(now time.Time, q *retransmissionQueu p.frames[i].OnLost = q.AddInitial case protocol.EncryptionHandshake: p.frames[i].OnLost = q.AddHandshake - case protocol.Encryption1RTT: + case protocol.Encryption0RTT, protocol.Encryption1RTT: p.frames[i].OnLost = q.AddAppData } } diff --git a/packet_packer_test.go b/packet_packer_test.go index ca25a92a4..6c86fa5f1 100644 --- a/packet_packer_test.go +++ b/packet_packer_test.go @@ -16,7 +16,9 @@ import ( mockackhandler "github.com/lucas-clemente/quic-go/internal/mocks/ackhandler" "github.com/lucas-clemente/quic-go/internal/protocol" "github.com/lucas-clemente/quic-go/internal/wire" + . "github.com/onsi/ginkgo" + . "github.com/onsi/ginkgo/extensions/table" . "github.com/onsi/gomega" ) @@ -1278,19 +1280,26 @@ var _ = Describe("Converting to AckHandler packets", func() { Expect(p.LargestAcked).To(Equal(protocol.InvalidPacketNumber)) }) - It("doesn't overwrite the OnLost callback, if it is set", func() { - var pingLost bool - packet := &packetContents{ - header: &wire.ExtendedHeader{Header: wire.Header{Type: protocol.PacketTypeHandshake}}, - frames: []ackhandler.Frame{ - {Frame: &wire.MaxDataFrame{}}, - {Frame: &wire.PingFrame{}, OnLost: func(wire.Frame) { pingLost = true }}, - }, - } - p := packet.ToAckHandlerPacket(time.Now(), newRetransmissionQueue(protocol.VersionTLS)) - Expect(p.Frames).To(HaveLen(2)) - Expect(p.Frames[0].OnLost).ToNot(BeNil()) - p.Frames[1].OnLost(nil) - Expect(pingLost).To(BeTrue()) - }) + DescribeTable( + "doesn't overwrite the OnLost callback, if it is set", + func(hdr wire.Header) { + var pingLost bool + packet := &packetContents{ + header: &wire.ExtendedHeader{Header: hdr}, + frames: []ackhandler.Frame{ + {Frame: &wire.MaxDataFrame{}}, + {Frame: &wire.PingFrame{}, OnLost: func(wire.Frame) { pingLost = true }}, + }, + } + p := packet.ToAckHandlerPacket(time.Now(), newRetransmissionQueue(protocol.VersionTLS)) + Expect(p.Frames).To(HaveLen(2)) + Expect(p.Frames[0].OnLost).ToNot(BeNil()) + p.Frames[1].OnLost(nil) + Expect(pingLost).To(BeTrue()) + }, + Entry(protocol.EncryptionInitial.String(), wire.Header{IsLongHeader: true, Type: protocol.PacketTypeInitial}), + Entry(protocol.EncryptionHandshake.String(), wire.Header{IsLongHeader: true, Type: protocol.PacketTypeHandshake}), + Entry(protocol.Encryption0RTT.String(), wire.Header{IsLongHeader: true, Type: protocol.PacketType0RTT}), + Entry(protocol.Encryption1RTT.String(), wire.Header{}), + ) })