diff --git a/ackhandler/packet.go b/ackhandler/packet.go index 6fe567c6..d748547e 100644 --- a/ackhandler/packet.go +++ b/ackhandler/packet.go @@ -19,31 +19,17 @@ type Packet struct { SendTime time.Time } -// GetStreamFramesForRetransmission gets all the streamframes for retransmission -func (p *Packet) GetStreamFramesForRetransmission() []*frames.StreamFrame { - var streamFrames []*frames.StreamFrame +// GetFramesForRetransmission gets all the frames for retransmission +func (p *Packet) GetFramesForRetransmission() []frames.Frame { + var fs []frames.Frame for _, frame := range p.Frames { - if streamFrame, isStreamFrame := frame.(*frames.StreamFrame); isStreamFrame { - streamFrames = append(streamFrames, streamFrame) - } - } - return streamFrames -} - -// GetControlFramesForRetransmission gets all the control frames for retransmission -func (p *Packet) GetControlFramesForRetransmission() []frames.Frame { - var controlFrames []frames.Frame - for _, frame := range p.Frames { - // omit ACKs - if _, isStreamFrame := frame.(*frames.StreamFrame); isStreamFrame { + switch frame.(type) { + case *frames.AckFrame: + continue + case *frames.StopWaitingFrame: continue } - - _, isAck := frame.(*frames.AckFrame) - _, isStopWaiting := frame.(*frames.StopWaitingFrame) - if !isAck && !isStopWaiting { - controlFrames = append(controlFrames, frame) - } + fs = append(fs, frame) } - return controlFrames + return fs } diff --git a/ackhandler/packet_test.go b/ackhandler/packet_test.go index f713bf86..32452a32 100644 --- a/ackhandler/packet_test.go +++ b/ackhandler/packet_test.go @@ -7,88 +7,45 @@ import ( ) var _ = Describe("Packet", func() { - Context("getFramesForRetransmission", func() { - var packet Packet - var streamFrame1, streamFrame2 *frames.StreamFrame - var ackFrame1, ackFrame2 *frames.AckFrame - var stopWaitingFrame *frames.StopWaitingFrame - var rstStreamFrame *frames.RstStreamFrame - var windowUpdateFrame *frames.WindowUpdateFrame + Context("getting frames for retransmission", func() { + ackFrame := &frames.AckFrame{LargestAcked: 13} + stopWaitingFrame := &frames.StopWaitingFrame{LeastUnacked: 7331} + windowUpdateFrame := &frames.WindowUpdateFrame{StreamID: 999} - BeforeEach(func() { - streamFrame1 = &frames.StreamFrame{ - StreamID: 5, - Data: []byte{0x13, 0x37}, - } - streamFrame2 = &frames.StreamFrame{ - StreamID: 6, - Data: []byte{0xDE, 0xCA, 0xFB, 0xAD}, - } - ackFrame1 = &frames.AckFrame{ - LargestAcked: 13, - } - ackFrame2 = &frames.AckFrame{ - LargestAcked: 333, - } - rstStreamFrame = &frames.RstStreamFrame{ - StreamID: 555, - ErrorCode: 1337, - } - stopWaitingFrame = &frames.StopWaitingFrame{ - LeastUnacked: 7331, - } - windowUpdateFrame = &frames.WindowUpdateFrame{ - StreamID: 999, - } - packet = Packet{ - PacketNumber: 1337, - Frames: []frames.Frame{windowUpdateFrame, streamFrame1, ackFrame1, streamFrame2, rstStreamFrame, ackFrame2, stopWaitingFrame}, + streamFrame := &frames.StreamFrame{ + StreamID: 5, + Data: []byte{0x13, 0x37}, + } + + rstStreamFrame := &frames.RstStreamFrame{ + StreamID: 555, + ErrorCode: 1337, + } + + It("returns nil if there are no retransmittable frames", func() { + packet := &Packet{ + Frames: []frames.Frame{ackFrame, stopWaitingFrame}, } + Expect(packet.GetFramesForRetransmission()).To(BeNil()) }) - It("gets all StreamFrames", func() { - streamFrames := packet.GetStreamFramesForRetransmission() - Expect(streamFrames).To(HaveLen(2)) - Expect(streamFrames).To(ContainElement(streamFrame1)) - Expect(streamFrames).To(ContainElement(streamFrame2)) - }) - - It("gets all control frames", func() { - controlFrames := packet.GetControlFramesForRetransmission() - Expect(controlFrames).To(HaveLen(2)) - Expect(controlFrames).To(ContainElement(rstStreamFrame)) - Expect(controlFrames).To(ContainElement(windowUpdateFrame)) - }) - - It("does not return any ACK frames", func() { - controlFrames := packet.GetControlFramesForRetransmission() - Expect(controlFrames).ToNot(ContainElement(ackFrame1)) - Expect(controlFrames).ToNot(ContainElement(ackFrame2)) - }) - - It("does not return any ACK frames", func() { - controlFrames := packet.GetControlFramesForRetransmission() - Expect(controlFrames).ToNot(ContainElement(stopWaitingFrame)) - }) - - It("returns an empty slice of StreamFrames if no StreamFrames are queued", func() { - // overwrite the globally defined packet here - packet := Packet{ - PacketNumber: 1337, - Frames: []frames.Frame{ackFrame1, rstStreamFrame}, + It("returns all retransmittable frames", func() { + packet := &Packet{ + Frames: []frames.Frame{ + windowUpdateFrame, + ackFrame, + stopWaitingFrame, + streamFrame, + rstStreamFrame, + }, } - streamFrames := packet.GetStreamFramesForRetransmission() - Expect(streamFrames).To(BeEmpty()) + fs := packet.GetFramesForRetransmission() + Expect(fs).To(ContainElement(streamFrame)) + Expect(fs).To(ContainElement(rstStreamFrame)) + Expect(fs).To(ContainElement(windowUpdateFrame)) + Expect(fs).ToNot(ContainElement(stopWaitingFrame)) + Expect(fs).ToNot(ContainElement(ackFrame)) }) - It("returns an empty slice of control frames if no applicable control frames are queued", func() { - // overwrite the globally defined packet here - packet := Packet{ - PacketNumber: 1337, - Frames: []frames.Frame{streamFrame1, ackFrame1, stopWaitingFrame}, - } - controlFrames := packet.GetControlFramesForRetransmission() - Expect(controlFrames).To(BeEmpty()) - }) }) }) diff --git a/session.go b/session.go index 21295e5b..cefe6369 100644 --- a/session.go +++ b/session.go @@ -497,9 +497,13 @@ func (s *Session) sendPacket() error { utils.Debugf("\tDequeueing retransmission for packet 0x%x", retransmitPacket.PacketNumber) // resend the frames that were in the packet - controlFrames = append(controlFrames, retransmitPacket.GetControlFramesForRetransmission()...) - for _, streamFrame := range retransmitPacket.GetStreamFramesForRetransmission() { - s.streamFramer.AddFrameForRetransmission(streamFrame) + for _, frame := range retransmitPacket.GetFramesForRetransmission() { + switch frame.(type) { + case *frames.StreamFrame: + s.streamFramer.AddFrameForRetransmission(frame.(*frames.StreamFrame)) + default: + controlFrames = append(controlFrames, frame) + } } }