diff --git a/session_test.go b/session_test.go index 7f9b61d2..677cf580 100644 --- a/session_test.go +++ b/session_test.go @@ -1004,12 +1004,10 @@ var _ = Describe("Session", func() { }) Context("sending packets", func() { + var sessionDone chan struct{} + BeforeEach(func() { - cryptoSetup.EXPECT().RunHandshake().MaxTimes(1) - go func() { - defer GinkgoRecover() - sess.run() - }() + sessionDone = make(chan struct{}) }) AfterEach(func() { @@ -1021,27 +1019,38 @@ var _ = Describe("Session", func() { qlogger.EXPECT().Export() sess.shutdown() Eventually(sess.Context().Done()).Should(BeClosed()) + Eventually(sessionDone).Should(BeClosed()) }) + runSession := func() { + go func() { + defer GinkgoRecover() + cryptoSetup.EXPECT().RunHandshake().MaxTimes(1) + sess.run() + close(sessionDone) + }() + } + It("sends packets", func() { sess.handshakeConfirmed = true + runSession() p := getPacket(1) packer.EXPECT().PackPacket().Return(p, nil) - sess.receivedPacketHandler.ReceivedPacket(0x035e, protocol.Encryption1RTT, time.Now(), true) - mconn.EXPECT().Write(gomock.Any()) + packer.EXPECT().PackPacket().Return(nil, nil) + sent := make(chan struct{}) + mconn.EXPECT().Write(gomock.Any()).Do(func([]byte) { close(sent) }) qlogger.EXPECT().SentPacket(p.header, p.buffer.Len(), nil, []wire.Frame{}) - sent, err := sess.sendPacket() - Expect(err).NotTo(HaveOccurred()) - Expect(sent).To(BeTrue()) + sess.scheduleSending() + Eventually(sent).Should(BeClosed()) }) It("doesn't send packets if there's nothing to send", func() { sess.handshakeConfirmed = true + runSession() packer.EXPECT().PackPacket().Return(nil, nil) sess.receivedPacketHandler.ReceivedPacket(0x035e, protocol.Encryption1RTT, time.Now(), true) - sent, err := sess.sendPacket() - Expect(err).NotTo(HaveOccurred()) - Expect(sent).To(BeFalse()) + sess.scheduleSending() + time.Sleep(50 * time.Millisecond) // make sure there are no calls to mconn.Write() }) It("sends ACK only packets", func() { @@ -1050,6 +1059,7 @@ var _ = Describe("Session", func() { sph.EXPECT().SendMode().Return(ackhandler.SendAck) sph.EXPECT().ShouldSendNumPackets().Return(1000) packer.EXPECT().MaybePackAckPacket(false) + runSession() sess.sentPacketHandler = sph Expect(sess.sendPackets()).To(Succeed()) }) @@ -1058,14 +1068,17 @@ var _ = Describe("Session", func() { sess.handshakeConfirmed = true fc := mocks.NewMockConnectionFlowController(mockCtrl) fc.EXPECT().IsNewlyBlocked().Return(true, protocol.ByteCount(1337)) + fc.EXPECT().IsNewlyBlocked() p := getPacket(1) packer.EXPECT().PackPacket().Return(p, nil) + packer.EXPECT().PackPacket().Return(nil, nil) sess.connFlowController = fc - mconn.EXPECT().Write(gomock.Any()) + runSession() + sent := make(chan struct{}) + mconn.EXPECT().Write(gomock.Any()).Do(func([]byte) { close(sent) }) qlogger.EXPECT().SentPacket(p.header, p.length, nil, []wire.Frame{}) - sent, err := sess.sendPacket() - Expect(err).NotTo(HaveOccurred()) - Expect(sent).To(BeTrue()) + sess.scheduleSending() + Eventually(sent).Should(BeClosed()) frames, _ := sess.framer.AppendControlFrames(nil, 1000) Expect(frames).To(Equal([]ackhandler.Frame{{Frame: &wire.DataBlockedFrame{DataLimit: 1337}}})) }) @@ -1074,8 +1087,11 @@ var _ = Describe("Session", func() { sph := mockackhandler.NewMockSentPacketHandler(mockCtrl) sph.EXPECT().GetLossDetectionTimeout().AnyTimes() sph.EXPECT().SendMode().Return(ackhandler.SendNone) + sph.EXPECT().TimeUntilSend().AnyTimes() sess.sentPacketHandler = sph - Expect(sess.sendPackets()).To(Succeed()) + runSession() + sess.scheduleSending() + time.Sleep(50 * time.Millisecond) }) for _, enc := range []protocol.EncryptionLevel{protocol.EncryptionInitial, protocol.EncryptionHandshake, protocol.Encryption1RTT} { @@ -1102,7 +1118,7 @@ var _ = Describe("Session", func() { It("sends a probe packet", func() { sph := mockackhandler.NewMockSentPacketHandler(mockCtrl) sph.EXPECT().GetLossDetectionTimeout().AnyTimes() - sph.EXPECT().TimeUntilSend() + sph.EXPECT().TimeUntilSend().AnyTimes() sph.EXPECT().SendMode().Return(sendMode) sph.EXPECT().ShouldSendNumPackets().Return(1) sph.EXPECT().QueueProbePacket(encLevel) @@ -1112,15 +1128,18 @@ var _ = Describe("Session", func() { Expect(packet.PacketNumber).To(Equal(protocol.PacketNumber(123))) }) sess.sentPacketHandler = sph - mconn.EXPECT().Write(gomock.Any()) + runSession() + sent := make(chan struct{}) + mconn.EXPECT().Write(gomock.Any()).Do(func([]byte) { close(sent) }) qlogger.EXPECT().SentPacket(p.header, p.length, gomock.Any(), gomock.Any()) - Expect(sess.sendPackets()).To(Succeed()) + sess.scheduleSending() + Eventually(sent).Should(BeClosed()) }) It("sends a PING as a probe packet", func() { sph := mockackhandler.NewMockSentPacketHandler(mockCtrl) sph.EXPECT().GetLossDetectionTimeout().AnyTimes() - sph.EXPECT().TimeUntilSend() + sph.EXPECT().TimeUntilSend().AnyTimes() sph.EXPECT().SendMode().Return(sendMode) sph.EXPECT().ShouldSendNumPackets().Return(1) sph.EXPECT().QueueProbePacket(encLevel).Return(false) @@ -1130,9 +1149,12 @@ var _ = Describe("Session", func() { Expect(packet.PacketNumber).To(Equal(protocol.PacketNumber(123))) }) sess.sentPacketHandler = sph - mconn.EXPECT().Write(gomock.Any()) + runSession() + sent := make(chan struct{}) + mconn.EXPECT().Write(gomock.Any()).Do(func([]byte) { close(sent) }) qlogger.EXPECT().SentPacket(p.header, p.length, gomock.Any(), gomock.Any()) - Expect(sess.sendPackets()).To(Succeed()) + sess.scheduleSending() + Eventually(sent).Should(BeClosed()) // We're using a mock packet packer in this test. // We therefore need to test separately that the PING was actually queued. Expect(getFrame(1000)).To(BeAssignableToTypeOf(&wire.PingFrame{}))