diff --git a/ackhandler/retransmittable.go b/ackhandler/retransmittable.go new file mode 100644 index 00000000..17437b8c --- /dev/null +++ b/ackhandler/retransmittable.go @@ -0,0 +1,38 @@ +package ackhandler + +import ( + "github.com/lucas-clemente/quic-go/frames" +) + +// Returns a new slice with all non-retransmittable frames deleted. +func stripNonRetransmittableFrames(fs []frames.Frame) []frames.Frame { + res := make([]frames.Frame, 0, len(fs)) + for _, f := range fs { + if IsFrameRetransmittable(f) { + res = append(res, f) + } + } + return res +} + +// IsFrameRetransmittable returns true if the frame should be retransmitted. +func IsFrameRetransmittable(f frames.Frame) bool { + switch f.(type) { + case *frames.StopWaitingFrame: + return false + case *frames.AckFrame: + return false + default: + return true + } +} + +// HasRetransmittableFrames returns true if at least one frame is retransmittable. +func HasRetransmittableFrames(fs []frames.Frame) bool { + for _, f := range fs { + if IsFrameRetransmittable(f) { + return true + } + } + return false +} diff --git a/ackhandler/retransmittable_test.go b/ackhandler/retransmittable_test.go new file mode 100644 index 00000000..4a5ea858 --- /dev/null +++ b/ackhandler/retransmittable_test.go @@ -0,0 +1,44 @@ +package ackhandler + +import ( + "reflect" + + "github.com/lucas-clemente/quic-go/frames" + . "github.com/onsi/ginkgo" + . "github.com/onsi/gomega" +) + +var _ = Describe("retransmittable frames", func() { + for fl, el := range map[frames.Frame]bool{ + &frames.AckFrame{}: false, + &frames.StopWaitingFrame{}: false, + &frames.BlockedFrame{}: true, + &frames.ConnectionCloseFrame{}: true, + &frames.GoawayFrame{}: true, + &frames.PingFrame{}: true, + &frames.RstStreamFrame{}: true, + &frames.StreamFrame{}: true, + &frames.WindowUpdateFrame{}: true, + } { + f := fl + e := el + fName := reflect.ValueOf(f).Elem().Type().Name() + + It("works for "+fName, func() { + Expect(IsFrameRetransmittable(f)).To(Equal(e)) + }) + + It("stripping non-retransmittable frames works for "+fName, func() { + s := []frames.Frame{f} + if e { + Expect(stripNonRetransmittableFrames(s)).To(Equal([]frames.Frame{f})) + } else { + Expect(stripNonRetransmittableFrames(s)).To(BeEmpty()) + } + }) + + It("HasRetransmittableFrames works for "+fName, func() { + Expect(HasRetransmittableFrames([]frames.Frame{f})).To(Equal(e)) + }) + } +}) diff --git a/packet_unpacker.go b/packet_unpacker.go index 30dee80a..c92e6a53 100644 --- a/packet_unpacker.go +++ b/packet_unpacker.go @@ -10,6 +10,11 @@ import ( "github.com/lucas-clemente/quic-go/qerr" ) +type unpackedPacket struct { + encryptionLevel protocol.EncryptionLevel + frames []frames.Frame +} + type quicAEAD interface { Open(dst, src []byte, packetNumber protocol.PacketNumber, associatedData []byte) ([]byte, protocol.EncryptionLevel, error) } diff --git a/session.go b/session.go index 6e187a7e..0bebe993 100644 --- a/session.go +++ b/session.go @@ -405,7 +405,8 @@ func (s *session) handlePacketImpl(p *receivedPacket) error { // Only do this after decrypting, so we are sure the packet is not attacker-controlled s.largestRcvdPacketNumber = utils.MaxPacketNumber(s.largestRcvdPacketNumber, hdr.PacketNumber) - if err = s.receivedPacketHandler.ReceivedPacket(hdr.PacketNumber, packet.IsRetransmittable()); err != nil { + isRetransmittable := ackhandler.HasRetransmittableFrames(packet.frames) + if err = s.receivedPacketHandler.ReceivedPacket(hdr.PacketNumber, isRetransmittable); err != nil { return err } diff --git a/unpacked_packet.go b/unpacked_packet.go deleted file mode 100644 index 0636b8f1..00000000 --- a/unpacked_packet.go +++ /dev/null @@ -1,31 +0,0 @@ -package quic - -import ( - "github.com/lucas-clemente/quic-go/frames" - "github.com/lucas-clemente/quic-go/protocol" -) - -type unpackedPacket struct { - encryptionLevel protocol.EncryptionLevel - frames []frames.Frame -} - -func (u *unpackedPacket) IsRetransmittable() bool { - for _, f := range u.frames { - switch f.(type) { - case *frames.StreamFrame: - return true - case *frames.RstStreamFrame: - return true - case *frames.WindowUpdateFrame: - return true - case *frames.BlockedFrame: - return true - case *frames.PingFrame: - return true - case *frames.GoawayFrame: - return true - } - } - return false -} diff --git a/unpacked_packet_test.go b/unpacked_packet_test.go deleted file mode 100644 index 82112a26..00000000 --- a/unpacked_packet_test.go +++ /dev/null @@ -1,46 +0,0 @@ -package quic - -import ( - "github.com/lucas-clemente/quic-go/frames" - . "github.com/onsi/ginkgo" - . "github.com/onsi/gomega" -) - -var _ = Describe("Unpacked packet", func() { - var packet *unpackedPacket - BeforeEach(func() { - packet = &unpackedPacket{} - }) - - It("says that an empty packet is not retransmittable", func() { - Expect(packet.IsRetransmittable()).To(BeFalse()) - }) - - It("detects the frame types", func() { - packet.frames = []frames.Frame{&frames.AckFrame{}} - Expect(packet.IsRetransmittable()).To(BeFalse()) - packet.frames = []frames.Frame{&frames.BlockedFrame{}} - Expect(packet.IsRetransmittable()).To(BeTrue()) - packet.frames = []frames.Frame{&frames.GoawayFrame{}} - Expect(packet.IsRetransmittable()).To(BeTrue()) - packet.frames = []frames.Frame{&frames.PingFrame{}} - Expect(packet.IsRetransmittable()).To(BeTrue()) - packet.frames = []frames.Frame{&frames.StreamFrame{}} - Expect(packet.IsRetransmittable()).To(BeTrue()) - packet.frames = []frames.Frame{&frames.RstStreamFrame{}} - Expect(packet.IsRetransmittable()).To(BeTrue()) - packet.frames = []frames.Frame{&frames.StopWaitingFrame{}} - Expect(packet.IsRetransmittable()).To(BeFalse()) - packet.frames = []frames.Frame{&frames.WindowUpdateFrame{}} - Expect(packet.IsRetransmittable()).To(BeTrue()) - }) - - It("says that a packet is retransmittable if it contains one retransmittable frame", func() { - packet.frames = []frames.Frame{ - &frames.AckFrame{}, - &frames.WindowUpdateFrame{}, - &frames.StopWaitingFrame{}, - } - Expect(packet.IsRetransmittable()).To(BeTrue()) - }) -})