diff --git a/session.go b/session.go index 922ef99a2..3ec8c9ec2 100644 --- a/session.go +++ b/session.go @@ -396,13 +396,23 @@ runLoop: s.keepAlivePingSent = true } - if !s.sentPacketHandler.SendingAllowed() { // if congestion limited, at least try sending an ACK frame + sendingAllowed := s.sentPacketHandler.SendingAllowed() + if !sendingAllowed { // if congestion limited, at least try sending an ACK frame if err := s.maybeSendAckOnlyPacket(); err != nil { s.closeLocal(err) } } else { - if err := s.sendPacket(); err != nil { - s.closeLocal(err) + // repeatedly try sending until we don't have any more data, or run out of the congestion window + for sendingAllowed { + sentPacket, err := s.sendPacket() + if err != nil { + s.closeLocal(err) + break + } + if !sentPacket { + break + } + sendingAllowed = s.sentPacketHandler.SendingAllowed() } } @@ -723,7 +733,7 @@ func (s *session) maybeSendAckOnlyPacket() error { return s.sendPackedPacket(packet) } -func (s *session) sendPacket() error { +func (s *session) sendPacket() (bool, error) { s.packer.SetLeastUnacked(s.sentPacketHandler.GetLeastUnacked()) if offset := s.connFlowController.GetWindowUpdate(); offset != 0 { @@ -739,69 +749,65 @@ func (s *session) sendPacket() error { s.packer.QueueControlFrame(ack) } - // Repeatedly try sending until we don't have any more data, or run out of the congestion window + // check for retransmissions first for { - if !s.sentPacketHandler.SendingAllowed() { - return nil + retransmitPacket := s.sentPacketHandler.DequeuePacketForRetransmission() + if retransmitPacket == nil { + break } - // check for retransmissions first - for { - retransmitPacket := s.sentPacketHandler.DequeuePacketForRetransmission() - if retransmitPacket == nil { - break + // retransmit handshake packets + if retransmitPacket.EncryptionLevel != protocol.EncryptionForwardSecure { + if s.handshakeComplete { + // don't retransmit handshake packets when the handshake is complete + continue } - - if retransmitPacket.EncryptionLevel != protocol.EncryptionForwardSecure { - if s.handshakeComplete { - // Don't retransmit handshake packets when the handshake is complete - continue - } - utils.Debugf("\tDequeueing handshake retransmission for packet 0x%x", retransmitPacket.PacketNumber) - if !s.version.UsesIETFFrameFormat() { - s.packer.QueueControlFrame(s.sentPacketHandler.GetStopWaitingFrame(true)) - } - packet, err := s.packer.PackHandshakeRetransmission(retransmitPacket) - if err != nil { - return err - } - if err := s.sendPackedPacket(packet); err != nil { - return err - } - } else { - utils.Debugf("\tDequeueing retransmission for packet 0x%x", retransmitPacket.PacketNumber) - // resend the frames that were in the packet - for _, frame := range retransmitPacket.GetFramesForRetransmission() { - // TODO: only retransmit WINDOW_UPDATEs if they actually enlarge the window - switch f := frame.(type) { - case *wire.StreamFrame: - s.streamFramer.AddFrameForRetransmission(f) - default: - s.packer.QueueControlFrame(frame) - } - } + utils.Debugf("\tDequeueing handshake retransmission for packet 0x%x", retransmitPacket.PacketNumber) + if !s.version.UsesIETFFrameFormat() { + s.packer.QueueControlFrame(s.sentPacketHandler.GetStopWaitingFrame(true)) } + packet, err := s.packer.PackHandshakeRetransmission(retransmitPacket) + if err != nil { + return false, err + } + if err := s.sendPackedPacket(packet); err != nil { + return false, err + } + return true, nil } - hasRetransmission := s.streamFramer.HasFramesForRetransmission() - if !s.version.UsesIETFFrameFormat() && (ack != nil || hasRetransmission) { - if swf := s.sentPacketHandler.GetStopWaitingFrame(hasRetransmission); swf != nil { - s.packer.QueueControlFrame(swf) + // queue all retransmittable frames sent in forward-secure packets + utils.Debugf("\tDequeueing retransmission for packet 0x%x", retransmitPacket.PacketNumber) + // resend the frames that were in the packet + for _, frame := range retransmitPacket.GetFramesForRetransmission() { + // TODO: only retransmit WINDOW_UPDATEs if they actually enlarge the window + switch f := frame.(type) { + case *wire.StreamFrame: + s.streamFramer.AddFrameForRetransmission(f) + default: + s.packer.QueueControlFrame(frame) } } - // add a retransmittable frame - if s.sentPacketHandler.ShouldSendRetransmittablePacket() { - s.packer.MakeNextPacketRetransmittable() - } - packet, err := s.packer.PackPacket() - if err != nil || packet == nil { - return err - } - if err = s.sendPackedPacket(packet); err != nil { - return err - } - ack = nil } + + hasRetransmission := s.streamFramer.HasFramesForRetransmission() + if !s.version.UsesIETFFrameFormat() && (ack != nil || hasRetransmission) { + if swf := s.sentPacketHandler.GetStopWaitingFrame(hasRetransmission); swf != nil { + s.packer.QueueControlFrame(swf) + } + } + // add a retransmittable frame + if s.sentPacketHandler.ShouldSendRetransmittablePacket() { + s.packer.MakeNextPacketRetransmittable() + } + packet, err := s.packer.PackPacket() + if err != nil || packet == nil { + return false, err + } + if err := s.sendPackedPacket(packet); err != nil { + return false, err + } + return true, nil } func (s *session) sendPackedPacket(packet *packedPacket) error { diff --git a/session_test.go b/session_test.go index 32926fac9..ca909a8d5 100644 --- a/session_test.go +++ b/session_test.go @@ -739,8 +739,9 @@ var _ = Describe("Session", func() { packetNumber := protocol.PacketNumber(0x035e) err := sess.receivedPacketHandler.ReceivedPacket(packetNumber, true) Expect(err).ToNot(HaveOccurred()) - err = sess.sendPacket() + sent, err := sess.sendPacket() Expect(err).NotTo(HaveOccurred()) + Expect(sent).To(BeTrue()) Expect(mconn.written).To(HaveLen(1)) Expect(mconn.written).To(Receive(ContainSubstring(string([]byte{0x03, 0x5e})))) }) @@ -752,15 +753,14 @@ var _ = Describe("Session", func() { sph.EXPECT().GetLeastUnacked().AnyTimes() sph.EXPECT().DequeuePacketForRetransmission() sph.EXPECT().ShouldSendRetransmittablePacket().Return(true) - sph.EXPECT().SendingAllowed().Return(true) - sph.EXPECT().SendingAllowed() sph.EXPECT().SentPacket(gomock.Any()).Do(func(p *ackhandler.Packet) { Expect(p.Frames).To(HaveLen(2)) Expect(p.Frames).To(ContainElement(ack)) }) sess.sentPacketHandler = sph - err := sess.sendPacket() - Expect(err).ToNot(HaveOccurred()) + sent, err := sess.sendPacket() + Expect(err).NotTo(HaveOccurred()) + Expect(sent).To(BeTrue()) Expect(mconn.written).To(HaveLen(1)) }) @@ -773,16 +773,15 @@ var _ = Describe("Session", func() { sph.EXPECT().GetLeastUnacked().AnyTimes() sph.EXPECT().DequeuePacketForRetransmission() sph.EXPECT().ShouldSendRetransmittablePacket() - sph.EXPECT().SendingAllowed().Return(true) - sph.EXPECT().SendingAllowed() sph.EXPECT().SentPacket(gomock.Any()).Do(func(p *ackhandler.Packet) { Expect(p.Frames).To(Equal([]wire.Frame{ &wire.MaxDataFrame{ByteOffset: 0x1337}, })) }) sess.sentPacketHandler = sph - err := sess.sendPacket() - Expect(err).ToNot(HaveOccurred()) + sent, err := sess.sendPacket() + Expect(err).NotTo(HaveOccurred()) + Expect(sent).To(BeTrue()) }) It("adds MAX_STREAM_DATA frames", func() { @@ -794,14 +793,13 @@ var _ = Describe("Session", func() { sph.EXPECT().GetLeastUnacked().AnyTimes() sph.EXPECT().DequeuePacketForRetransmission() sph.EXPECT().ShouldSendRetransmittablePacket() - sph.EXPECT().SendingAllowed().Return(true) - sph.EXPECT().SendingAllowed() sph.EXPECT().SentPacket(gomock.Any()).Do(func(p *ackhandler.Packet) { Expect(p.Frames).To(ContainElement(&wire.MaxStreamDataFrame{StreamID: 2, ByteOffset: 20})) }) sess.sentPacketHandler = sph - err := sess.sendPacket() - Expect(err).ToNot(HaveOccurred()) + sent, err := sess.sendPacket() + Expect(err).NotTo(HaveOccurred()) + Expect(sent).To(BeTrue()) }) It("adds a BLOCKED frame when it is connection-level flow control blocked", func() { @@ -813,16 +811,42 @@ var _ = Describe("Session", func() { sph.EXPECT().GetLeastUnacked().AnyTimes() sph.EXPECT().DequeuePacketForRetransmission() sph.EXPECT().ShouldSendRetransmittablePacket() - sph.EXPECT().SendingAllowed().Return(true) - sph.EXPECT().SendingAllowed() sph.EXPECT().SentPacket(gomock.Any()).Do(func(p *ackhandler.Packet) { Expect(p.Frames).To(Equal([]wire.Frame{ &wire.BlockedFrame{Offset: 1337}, })) }) sess.sentPacketHandler = sph - err := sess.sendPacket() - Expect(err).ToNot(HaveOccurred()) + sent, err := sess.sendPacket() + Expect(err).NotTo(HaveOccurred()) + Expect(sent).To(BeTrue()) + }) + + It("sends multiple packets", func() { + sess.queueControlFrame(&wire.MaxDataFrame{ByteOffset: 1}) + sph := mockackhandler.NewMockSentPacketHandler(mockCtrl) + sph.EXPECT().DequeuePacketForRetransmission().Times(2) + sph.EXPECT().GetAlarmTimeout().AnyTimes() + sph.EXPECT().GetLeastUnacked().AnyTimes() + sph.EXPECT().ShouldSendRetransmittablePacket().Times(2) + sph.EXPECT().SentPacket(gomock.Any()).Times(2) + sph.EXPECT().SendingAllowed().Do(func() { // after sending the first packet + // make sure there's something to send + sess.packer.QueueControlFrame(&wire.MaxDataFrame{ByteOffset: 2}) + }).Return(true).Times(2) // allow 2 packets... + sph.EXPECT().SendingAllowed() // ...then report that we're congestion limited + sess.sentPacketHandler = sph + done := make(chan struct{}) + go func() { + defer GinkgoRecover() + sess.run() + close(done) + }() + sess.scheduleSending() + Eventually(mconn.written).Should(HaveLen(2)) + // make the go routine return + sess.Close(nil) + Eventually(done).Should(BeClosed()) }) It("sends public reset", func() { @@ -843,8 +867,6 @@ var _ = Describe("Session", func() { sph.EXPECT().GetStopWaitingFrame(gomock.Any()) sph.EXPECT().DequeuePacketForRetransmission() sph.EXPECT().ShouldSendRetransmittablePacket() - sph.EXPECT().SendingAllowed().Return(true) - sph.EXPECT().SendingAllowed() sph.EXPECT().SentPacket(gomock.Any()).Do(func(p *ackhandler.Packet) { sentPacket = p }) @@ -855,8 +877,9 @@ var _ = Describe("Session", func() { sess.streamFramer.AddFrameForRetransmission(f) _, err := sess.GetOrOpenStream(5) Expect(err).ToNot(HaveOccurred()) - err = sess.sendPacket() + sent, err := sess.sendPacket() Expect(err).NotTo(HaveOccurred()) + Expect(sent).To(BeTrue()) Expect(mconn.written).To(HaveLen(1)) Expect(sentPacket.PacketNumber).To(Equal(protocol.PacketNumber(0x1337 + 9))) Expect(sentPacket.Frames).To(ContainElement(f)) @@ -937,8 +960,6 @@ var _ = Describe("Session", func() { sess.packer.hasSentPacket = true // make sure this is not the first packet the packer sends sph = mockackhandler.NewMockSentPacketHandler(mockCtrl) sph.EXPECT().GetLeastUnacked().AnyTimes() - sph.EXPECT().SendingAllowed().Return(true) - sph.EXPECT().ShouldSendRetransmittablePacket() sess.sentPacketHandler = sph sess.packer.cryptoSetup = &mockCryptoSetup{encLevelSeal: protocol.EncryptionForwardSecure} }) @@ -952,13 +973,13 @@ var _ = Describe("Session", func() { Frames: []wire.Frame{sf}, EncryptionLevel: protocol.EncryptionUnencrypted, }) - sph.EXPECT().DequeuePacketForRetransmission() sph.EXPECT().SentPacket(gomock.Any()).Do(func(p *ackhandler.Packet) { Expect(p.EncryptionLevel).To(Equal(protocol.EncryptionUnencrypted)) Expect(p.Frames).To(Equal([]wire.Frame{swf, sf})) }) - err := sess.sendPacket() - Expect(err).ToNot(HaveOccurred()) + sent, err := sess.sendPacket() + Expect(err).NotTo(HaveOccurred()) + Expect(sent).To(BeTrue()) Expect(mconn.written).To(HaveLen(1)) }) @@ -970,13 +991,13 @@ var _ = Describe("Session", func() { Frames: []wire.Frame{sf}, EncryptionLevel: protocol.EncryptionUnencrypted, }) - sph.EXPECT().DequeuePacketForRetransmission() sph.EXPECT().SentPacket(gomock.Any()).Do(func(p *ackhandler.Packet) { Expect(p.EncryptionLevel).To(Equal(protocol.EncryptionUnencrypted)) Expect(p.Frames).To(Equal([]wire.Frame{sf})) }) - err := sess.sendPacket() - Expect(err).ToNot(HaveOccurred()) + sent, err := sess.sendPacket() + Expect(err).NotTo(HaveOccurred()) + Expect(sent).To(BeTrue()) Expect(mconn.written).To(HaveLen(1)) }) @@ -989,8 +1010,10 @@ var _ = Describe("Session", func() { EncryptionLevel: protocol.EncryptionSecure, }) sph.EXPECT().DequeuePacketForRetransmission() - err := sess.sendPacket() - Expect(err).ToNot(HaveOccurred()) + sph.EXPECT().ShouldSendRetransmittablePacket() + sent, err := sess.sendPacket() + Expect(err).NotTo(HaveOccurred()) + Expect(sent).To(BeFalse()) Expect(mconn.written).To(BeEmpty()) }) }) @@ -1009,13 +1032,14 @@ var _ = Describe("Session", func() { EncryptionLevel: protocol.EncryptionForwardSecure, }) sph.EXPECT().DequeuePacketForRetransmission() + sph.EXPECT().ShouldSendRetransmittablePacket() sph.EXPECT().SentPacket(gomock.Any()).Do(func(p *ackhandler.Packet) { Expect(p.Frames).To(Equal([]wire.Frame{swf, f})) Expect(p.EncryptionLevel).To(Equal(protocol.EncryptionForwardSecure)) }) - sph.EXPECT().SendingAllowed() - err := sess.sendPacket() + sent, err := sess.sendPacket() Expect(err).NotTo(HaveOccurred()) + Expect(sent).To(BeTrue()) Expect(mconn.written).To(HaveLen(1)) }) @@ -1031,17 +1055,18 @@ var _ = Describe("Session", func() { EncryptionLevel: protocol.EncryptionForwardSecure, }) sph.EXPECT().DequeuePacketForRetransmission() + sph.EXPECT().ShouldSendRetransmittablePacket() sph.EXPECT().SentPacket(gomock.Any()).Do(func(p *ackhandler.Packet) { Expect(p.Frames).To(Equal([]wire.Frame{f})) Expect(p.EncryptionLevel).To(Equal(protocol.EncryptionForwardSecure)) }) - sph.EXPECT().SendingAllowed() - err := sess.sendPacket() + sent, err := sess.sendPacket() Expect(err).NotTo(HaveOccurred()) + Expect(sent).To(BeTrue()) Expect(mconn.written).To(HaveLen(1)) }) - It("sends a StreamFrame from a packet queued for retransmission", func() { + It("sends a STREAM frame from a packet queued for retransmission", func() { f1 := wire.StreamFrame{ StreamID: 0x5, Data: []byte("foobar"), @@ -1064,11 +1089,13 @@ var _ = Describe("Session", func() { sph.EXPECT().DequeuePacketForRetransmission().Return(p2) sph.EXPECT().DequeuePacketForRetransmission() sph.EXPECT().GetStopWaitingFrame(true).Return(&wire.StopWaitingFrame{}) - sph.EXPECT().SentPacket(gomock.Any()) - sph.EXPECT().SendingAllowed() - - err := sess.sendPacket() + sph.EXPECT().ShouldSendRetransmittablePacket() + sph.EXPECT().SentPacket(gomock.Any()).Do(func(p *ackhandler.Packet) { + Expect(p.Frames).To(HaveLen(3)) + }) + sent, err := sess.sendPacket() Expect(err).NotTo(HaveOccurred()) + Expect(sent).To(BeTrue()) Expect(mconn.written).To(HaveLen(1)) packet := <-mconn.written Expect(packet).To(ContainSubstring("foobar")) @@ -1130,16 +1157,21 @@ var _ = Describe("Session", func() { It("sets the timer to the ack timer", func() { rph := mockackhandler.NewMockReceivedPacketHandler(mockCtrl) rph.EXPECT().GetAckFrame().Return(&wire.AckFrame{LargestAcked: 0x1337}) + rph.EXPECT().GetAckFrame() rph.EXPECT().GetAlarmTimeout().Return(time.Now().Add(10 * time.Millisecond)).MinTimes(1) sess.receivedPacketHandler = rph + done := make(chan struct{}) go func() { defer GinkgoRecover() sess.run() + close(done) }() - defer sess.Close(nil) time.Sleep(10 * time.Millisecond) Eventually(func() int { return len(mconn.written) }).ShouldNot(BeZero()) Expect(mconn.written).To(Receive(ContainSubstring(string([]byte{0x13, 0x37})))) + // make sure the go routine returns + sess.Close(nil) + Eventually(done).Should(BeClosed()) }) })