diff --git a/packet_packer.go b/packet_packer.go index cf919409..3664b00c 100644 --- a/packet_packer.go +++ b/packet_packer.go @@ -30,6 +30,7 @@ type packetPacker struct { controlFrames []frames.Frame stopWaiting *frames.StopWaitingFrame + ackFrame *frames.AckFrame leastUnacked protocol.PacketNumber } @@ -65,16 +66,20 @@ func (p *packetPacker) PackConnectionClose(ccf *frames.ConnectionCloseFrame) (*p }, err } -func (p *packetPacker) PackAckPacket(ackframe *frames.AckFrame) (*packedPacket, error) { +func (p *packetPacker) PackAckPacket() (*packedPacket, error) { + if p.ackFrame == nil { + return nil, errors.New("packet packer BUG: no ack frame queued") + } encLevel, sealer := p.cryptoSetup.GetSealer() ph := p.getPublicHeader(encLevel) - frames := []frames.Frame{ackframe} + frames := []frames.Frame{p.ackFrame} if p.stopWaiting != nil { p.stopWaiting.PacketNumber = ph.PacketNumber p.stopWaiting.PacketNumberLen = ph.PacketNumberLen frames = append(frames, p.stopWaiting) p.stopWaiting = nil } + p.ackFrame = nil raw, err := p.writeAndSealPacket(ph, frames, sealer) return &packedPacket{ number: ph.PacketNumber, @@ -84,8 +89,8 @@ func (p *packetPacker) PackAckPacket(ackframe *frames.AckFrame) (*packedPacket, }, err } -// RetransmitNonForwardSecurePacket retransmits a handshake packet, that was sent with less than forward-secure encryption -func (p *packetPacker) RetransmitNonForwardSecurePacket(packet *ackhandler.Packet) (*packedPacket, error) { +// PackHandshakeRetransmission retransmits a handshake packet, that was sent with less than forward-secure encryption +func (p *packetPacker) PackHandshakeRetransmission(packet *ackhandler.Packet) (*packedPacket, error) { if packet.EncryptionLevel == protocol.EncryptionForwardSecure { return nil, errors.New("PacketPacker BUG: forward-secure encrypted handshake packets don't need special treatment") } @@ -144,6 +149,7 @@ func (p *packetPacker) PackPacket() (*packedPacket, error) { return nil, nil } p.stopWaiting = nil + p.ackFrame = nil raw, err := p.writeAndSealPacket(publicHeader, payloadFrames, sealer) if err != nil { @@ -185,9 +191,24 @@ func (p *packetPacker) composeNextPacket( var payloadLength protocol.ByteCount var payloadFrames []frames.Frame + // STOP_WAITING and ACK will always fit if p.stopWaiting != nil { - p.controlFrames = append(p.controlFrames, p.stopWaiting) + payloadFrames = append(payloadFrames, p.stopWaiting) + l, err := p.stopWaiting.MinLength(p.version) + if err != nil { + return nil, err + } + payloadLength += l } + if p.ackFrame != nil { + payloadFrames = append(payloadFrames, p.ackFrame) + l, err := p.ackFrame.MinLength(p.version) + if err != nil { + return nil, err + } + payloadLength += l + } + for len(p.controlFrames) > 0 { frame := p.controlFrames[len(p.controlFrames)-1] minLength, err := frame.MinLength(p.version) @@ -232,10 +253,13 @@ func (p *packetPacker) composeNextPacket( return payloadFrames, nil } -func (p *packetPacker) QueueControlFrameForNextPacket(f frames.Frame) { - if swf, ok := f.(*frames.StopWaitingFrame); ok { - p.stopWaiting = swf - } else { +func (p *packetPacker) QueueControlFrameForNextPacket(frame frames.Frame) { + switch f := frame.(type) { + case *frames.StopWaitingFrame: + p.stopWaiting = f + case *frames.AckFrame: + p.ackFrame = f + default: p.controlFrames = append(p.controlFrames, f) } } diff --git a/packet_packer_test.go b/packet_packer_test.go index e5d8f6a4..d692f64d 100644 --- a/packet_packer_test.go +++ b/packet_packer_test.go @@ -638,7 +638,7 @@ var _ = Describe("Packet packer", func() { EncryptionLevel: protocol.EncryptionUnencrypted, Frames: []frames.Frame{sf}, } - p, err := packer.RetransmitNonForwardSecurePacket(packet) + p, err := packer.PackHandshakeRetransmission(packet) Expect(err).ToNot(HaveOccurred()) Expect(p.frames).To(ContainElement(sf)) Expect(p.frames).To(ContainElement(swf)) @@ -652,7 +652,7 @@ var _ = Describe("Packet packer", func() { EncryptionLevel: protocol.EncryptionSecure, Frames: []frames.Frame{sf}, } - p, err := packer.RetransmitNonForwardSecurePacket(packet) + p, err := packer.PackHandshakeRetransmission(packet) Expect(err).ToNot(HaveOccurred()) Expect(p.frames).To(ContainElement(sf)) Expect(p.frames).To(ContainElement(swf)) @@ -667,7 +667,7 @@ var _ = Describe("Packet packer", func() { EncryptionLevel: protocol.EncryptionSecure, Frames: []frames.Frame{sf}, } - p, err := packer.RetransmitNonForwardSecurePacket(packet) + p, err := packer.PackHandshakeRetransmission(packet) Expect(err).ToNot(HaveOccurred()) Expect(p.encryptionLevel).To(Equal(protocol.EncryptionSecure)) }) @@ -684,7 +684,7 @@ var _ = Describe("Packet packer", func() { }, }, } - _, err := packer.RetransmitNonForwardSecurePacket(packet) + _, err := packer.PackHandshakeRetransmission(packet) Expect(err).To(MatchError("PacketPacker BUG: packet too large")) }) @@ -692,13 +692,13 @@ var _ = Describe("Packet packer", func() { p := &ackhandler.Packet{ EncryptionLevel: protocol.EncryptionForwardSecure, } - _, err := packer.RetransmitNonForwardSecurePacket(p) + _, err := packer.PackHandshakeRetransmission(p) Expect(err).To(MatchError("PacketPacker BUG: forward-secure encrypted handshake packets don't need special treatment")) }) It("refuses to retransmit packets without a StopWaitingFrame", func() { packer.stopWaiting = nil - _, err := packer.RetransmitNonForwardSecurePacket(&ackhandler.Packet{ + _, err := packer.PackHandshakeRetransmission(&ackhandler.Packet{ EncryptionLevel: protocol.EncryptionSecure, }) Expect(err).To(MatchError("PacketPacker BUG: Handshake retransmissions must contain a StopWaitingFrame")) @@ -707,14 +707,16 @@ var _ = Describe("Packet packer", func() { Context("packing ACK packets", func() { It("packs ACK packets", func() { - p, err := packer.PackAckPacket(&frames.AckFrame{}) + packer.QueueControlFrameForNextPacket(&frames.AckFrame{}) + p, err := packer.PackAckPacket() Expect(err).NotTo(HaveOccurred()) Expect(p.frames).To(Equal([]frames.Frame{&frames.AckFrame{DelayTime: math.MaxInt64}})) }) It("packs ACK packets with SWFs", func() { + packer.QueueControlFrameForNextPacket(&frames.AckFrame{}) packer.QueueControlFrameForNextPacket(&frames.StopWaitingFrame{}) - p, err := packer.PackAckPacket(&frames.AckFrame{}) + p, err := packer.PackAckPacket() Expect(err).NotTo(HaveOccurred()) Expect(p.frames).To(Equal([]frames.Frame{ &frames.AckFrame{DelayTime: math.MaxInt64}, diff --git a/session.go b/session.go index 8fd702af..1d50ac13 100644 --- a/session.go +++ b/session.go @@ -570,19 +570,24 @@ func (s *session) sendPacket() error { for _, wuf := range windowUpdateFrames { s.packer.QueueControlFrameForNextPacket(wuf) } + + ack := s.receivedPacketHandler.GetAckFrame() + if ack != nil { + s.packer.QueueControlFrameForNextPacket(ack) + } + // Repeatedly try sending until we don't have any more data, or run out of the congestion window for { if !s.sentPacketHandler.SendingAllowed() { - // If we aren't allowed to send, at least try sending an ACK frame - ack := s.receivedPacketHandler.GetAckFrame() if ack == nil { return nil } + // If we aren't allowed to send, at least try sending an ACK frame swf := s.sentPacketHandler.GetStopWaitingFrame(false) if swf != nil { s.packer.QueueControlFrameForNextPacket(swf) } - packet, err := s.packer.PackAckPacket(ack) + packet, err := s.packer.PackAckPacket() if err != nil { return err } @@ -603,7 +608,7 @@ func (s *session) sendPacket() error { } utils.Debugf("\tDequeueing handshake retransmission for packet 0x%x", retransmitPacket.PacketNumber) s.packer.QueueControlFrameForNextPacket(s.sentPacketHandler.GetStopWaitingFrame(true)) - packet, err := s.packer.RetransmitNonForwardSecurePacket(retransmitPacket) + packet, err := s.packer.PackHandshakeRetransmission(retransmitPacket) if err != nil { return err } @@ -630,10 +635,6 @@ func (s *session) sendPacket() error { } } - ack := s.receivedPacketHandler.GetAckFrame() - if ack != nil { - s.packer.QueueControlFrameForNextPacket(ack) - } hasRetransmission := s.streamFramer.HasFramesForRetransmission() if ack != nil || hasRetransmission { swf := s.sentPacketHandler.GetStopWaitingFrame(hasRetransmission) @@ -654,6 +655,7 @@ func (s *session) sendPacket() error { s.packer.QueueControlFrameForNextPacket(f) } windowUpdateFrames = nil + ack = nil s.nextAckScheduledTime = time.Time{} } }