diff --git a/session_test.go b/session_test.go index edaf044c..677cf580 100644 --- a/session_test.go +++ b/session_test.go @@ -388,29 +388,35 @@ var _ = Describe("Session", func() { Context("closing", func() { var ( - runErr error + runErr chan error expectedRunErr error ) BeforeEach(func() { - Eventually(areSessionsRunning).Should(BeFalse()) - go func() { - defer GinkgoRecover() - cryptoSetup.EXPECT().RunHandshake().MaxTimes(1) - runErr = sess.run() - }() - Eventually(areSessionsRunning).Should(BeTrue()) + runErr = make(chan error, 1) expectedRunErr = nil }) AfterEach(func() { if expectedRunErr != nil { - Expect(runErr).To(MatchError(expectedRunErr)) + Eventually(runErr).Should(Receive(MatchError(expectedRunErr))) + } else { + Eventually(runErr).Should(Receive()) } }) + runSession := func() { + go func() { + defer GinkgoRecover() + cryptoSetup.EXPECT().RunHandshake().MaxTimes(1) + runErr <- sess.run() + }() + Eventually(areSessionsRunning).Should(BeTrue()) + } + It("shuts down without error", func() { sess.handshakeComplete = true + runSession() streamManager.EXPECT().CloseWithError(qerr.NewApplicationError(0, "")) expectReplaceWithClosed() cryptoSetup.EXPECT().Close() @@ -429,6 +435,7 @@ var _ = Describe("Session", func() { }) It("only closes once", func() { + runSession() streamManager.EXPECT().CloseWithError(gomock.Any()) expectReplaceWithClosed() cryptoSetup.EXPECT().Close() @@ -442,6 +449,7 @@ var _ = Describe("Session", func() { }) It("closes with an error", func() { + runSession() streamManager.EXPECT().CloseWithError(qerr.NewApplicationError(0x1337, "test error")) expectReplaceWithClosed() cryptoSetup.EXPECT().Close() @@ -459,6 +467,7 @@ var _ = Describe("Session", func() { }) It("includes the frame type in transport-level close frames", func() { + runSession() testErr := qerr.NewErrorWithFrameType(0x1337, 0x42, "test error") streamManager.EXPECT().CloseWithError(testErr) expectReplaceWithClosed() @@ -478,6 +487,7 @@ var _ = Describe("Session", func() { }) It("closes the session in order to recreate it", func() { + runSession() streamManager.EXPECT().CloseWithError(gomock.Any()) sessionRunner.EXPECT().Remove(gomock.Any()).AnyTimes() cryptoSetup.EXPECT().Close() @@ -489,6 +499,7 @@ var _ = Describe("Session", func() { }) It("destroys the session", func() { + runSession() testErr := errors.New("close") streamManager.EXPECT().CloseWithError(gomock.Any()) sessionRunner.EXPECT().Remove(gomock.Any()).AnyTimes() @@ -501,6 +512,7 @@ var _ = Describe("Session", func() { }) It("cancels the context when the run loop exists", func() { + runSession() streamManager.EXPECT().CloseWithError(gomock.Any()) expectReplaceWithClosed() cryptoSetup.EXPECT().Close() @@ -524,6 +536,7 @@ var _ = Describe("Session", func() { unpacker := NewMockUnpacker(mockCtrl) sess.handshakeConfirmed = true sess.unpacker = unpacker + runSession() cryptoSetup.EXPECT().Close() streamManager.EXPECT().CloseWithError(gomock.Any()) sessionRunner.EXPECT().ReplaceWithClosed(gomock.Any(), gomock.Any()).AnyTimes() @@ -991,12 +1004,10 @@ var _ = Describe("Session", func() { }) Context("sending packets", func() { + var sessionDone chan struct{} + BeforeEach(func() { - cryptoSetup.EXPECT().RunHandshake().MaxTimes(1) - go func() { - defer GinkgoRecover() - sess.run() - }() + sessionDone = make(chan struct{}) }) AfterEach(func() { @@ -1008,27 +1019,38 @@ var _ = Describe("Session", func() { qlogger.EXPECT().Export() sess.shutdown() Eventually(sess.Context().Done()).Should(BeClosed()) + Eventually(sessionDone).Should(BeClosed()) }) + runSession := func() { + go func() { + defer GinkgoRecover() + cryptoSetup.EXPECT().RunHandshake().MaxTimes(1) + sess.run() + close(sessionDone) + }() + } + It("sends packets", func() { sess.handshakeConfirmed = true + runSession() p := getPacket(1) packer.EXPECT().PackPacket().Return(p, nil) - sess.receivedPacketHandler.ReceivedPacket(0x035e, protocol.Encryption1RTT, time.Now(), true) - mconn.EXPECT().Write(gomock.Any()) + packer.EXPECT().PackPacket().Return(nil, nil) + sent := make(chan struct{}) + mconn.EXPECT().Write(gomock.Any()).Do(func([]byte) { close(sent) }) qlogger.EXPECT().SentPacket(p.header, p.buffer.Len(), nil, []wire.Frame{}) - sent, err := sess.sendPacket() - Expect(err).NotTo(HaveOccurred()) - Expect(sent).To(BeTrue()) + sess.scheduleSending() + Eventually(sent).Should(BeClosed()) }) It("doesn't send packets if there's nothing to send", func() { sess.handshakeConfirmed = true + runSession() packer.EXPECT().PackPacket().Return(nil, nil) sess.receivedPacketHandler.ReceivedPacket(0x035e, protocol.Encryption1RTT, time.Now(), true) - sent, err := sess.sendPacket() - Expect(err).NotTo(HaveOccurred()) - Expect(sent).To(BeFalse()) + sess.scheduleSending() + time.Sleep(50 * time.Millisecond) // make sure there are no calls to mconn.Write() }) It("sends ACK only packets", func() { @@ -1037,6 +1059,7 @@ var _ = Describe("Session", func() { sph.EXPECT().SendMode().Return(ackhandler.SendAck) sph.EXPECT().ShouldSendNumPackets().Return(1000) packer.EXPECT().MaybePackAckPacket(false) + runSession() sess.sentPacketHandler = sph Expect(sess.sendPackets()).To(Succeed()) }) @@ -1045,14 +1068,17 @@ var _ = Describe("Session", func() { sess.handshakeConfirmed = true fc := mocks.NewMockConnectionFlowController(mockCtrl) fc.EXPECT().IsNewlyBlocked().Return(true, protocol.ByteCount(1337)) + fc.EXPECT().IsNewlyBlocked() p := getPacket(1) packer.EXPECT().PackPacket().Return(p, nil) + packer.EXPECT().PackPacket().Return(nil, nil) sess.connFlowController = fc - mconn.EXPECT().Write(gomock.Any()) + runSession() + sent := make(chan struct{}) + mconn.EXPECT().Write(gomock.Any()).Do(func([]byte) { close(sent) }) qlogger.EXPECT().SentPacket(p.header, p.length, nil, []wire.Frame{}) - sent, err := sess.sendPacket() - Expect(err).NotTo(HaveOccurred()) - Expect(sent).To(BeTrue()) + sess.scheduleSending() + Eventually(sent).Should(BeClosed()) frames, _ := sess.framer.AppendControlFrames(nil, 1000) Expect(frames).To(Equal([]ackhandler.Frame{{Frame: &wire.DataBlockedFrame{DataLimit: 1337}}})) }) @@ -1061,8 +1087,11 @@ var _ = Describe("Session", func() { sph := mockackhandler.NewMockSentPacketHandler(mockCtrl) sph.EXPECT().GetLossDetectionTimeout().AnyTimes() sph.EXPECT().SendMode().Return(ackhandler.SendNone) + sph.EXPECT().TimeUntilSend().AnyTimes() sess.sentPacketHandler = sph - Expect(sess.sendPackets()).To(Succeed()) + runSession() + sess.scheduleSending() + time.Sleep(50 * time.Millisecond) }) for _, enc := range []protocol.EncryptionLevel{protocol.EncryptionInitial, protocol.EncryptionHandshake, protocol.Encryption1RTT} { @@ -1089,7 +1118,7 @@ var _ = Describe("Session", func() { It("sends a probe packet", func() { sph := mockackhandler.NewMockSentPacketHandler(mockCtrl) sph.EXPECT().GetLossDetectionTimeout().AnyTimes() - sph.EXPECT().TimeUntilSend() + sph.EXPECT().TimeUntilSend().AnyTimes() sph.EXPECT().SendMode().Return(sendMode) sph.EXPECT().ShouldSendNumPackets().Return(1) sph.EXPECT().QueueProbePacket(encLevel) @@ -1099,15 +1128,18 @@ var _ = Describe("Session", func() { Expect(packet.PacketNumber).To(Equal(protocol.PacketNumber(123))) }) sess.sentPacketHandler = sph - mconn.EXPECT().Write(gomock.Any()) + runSession() + sent := make(chan struct{}) + mconn.EXPECT().Write(gomock.Any()).Do(func([]byte) { close(sent) }) qlogger.EXPECT().SentPacket(p.header, p.length, gomock.Any(), gomock.Any()) - Expect(sess.sendPackets()).To(Succeed()) + sess.scheduleSending() + Eventually(sent).Should(BeClosed()) }) It("sends a PING as a probe packet", func() { sph := mockackhandler.NewMockSentPacketHandler(mockCtrl) sph.EXPECT().GetLossDetectionTimeout().AnyTimes() - sph.EXPECT().TimeUntilSend() + sph.EXPECT().TimeUntilSend().AnyTimes() sph.EXPECT().SendMode().Return(sendMode) sph.EXPECT().ShouldSendNumPackets().Return(1) sph.EXPECT().QueueProbePacket(encLevel).Return(false) @@ -1117,9 +1149,12 @@ var _ = Describe("Session", func() { Expect(packet.PacketNumber).To(Equal(protocol.PacketNumber(123))) }) sess.sentPacketHandler = sph - mconn.EXPECT().Write(gomock.Any()) + runSession() + sent := make(chan struct{}) + mconn.EXPECT().Write(gomock.Any()).Do(func([]byte) { close(sent) }) qlogger.EXPECT().SentPacket(p.header, p.length, gomock.Any(), gomock.Any()) - Expect(sess.sendPackets()).To(Succeed()) + sess.scheduleSending() + Eventually(sent).Should(BeClosed()) // We're using a mock packet packer in this test. // We therefore need to test separately that the PING was actually queued. Expect(getFrame(1000)).To(BeAssignableToTypeOf(&wire.PingFrame{})) @@ -1581,12 +1616,7 @@ var _ = Describe("Session", func() { }) Context("transport parameters", func() { - It("process transport parameters received from the client", func() { - go func() { - defer GinkgoRecover() - cryptoSetup.EXPECT().RunHandshake().MaxTimes(1) - sess.run() - }() + It("processes transport parameters received from the client", func() { params := &wire.TransportParameters{ MaxIdleTimeout: 90 * time.Second, InitialMaxStreamDataBidiLocal: 0x5000, @@ -1604,19 +1634,6 @@ var _ = Describe("Session", func() { qlogger.EXPECT().ReceivedTransportParameters(params) sess.processTransportParameters(params) Expect(sess.earlySessionReady()).To(BeClosed()) - - // make the go routine return - streamManager.EXPECT().CloseWithError(gomock.Any()) - sessionRunner.EXPECT().ReplaceWithClosed(gomock.Any(), gomock.Any()).Do(func(_ protocol.ConnectionID, s packetHandler) { - Expect(s).To(BeAssignableToTypeOf(&closedLocalSession{})) - s.shutdown() - }).Times(4) // initial connection ID + initial client dest conn ID + 2 newly issued conn IDs - packer.EXPECT().PackConnectionClose(gomock.Any()).Return(&coalescedPacket{buffer: getPacketBuffer()}, nil) - cryptoSetup.EXPECT().Close() - mconn.EXPECT().Write(gomock.Any()) - qlogger.EXPECT().Export() - sess.shutdown() - Eventually(sess.Context().Done()).Should(BeClosed()) }) }) @@ -1766,7 +1783,6 @@ var _ = Describe("Session", func() { }() Consistently(sess.Context().Done()).ShouldNot(BeClosed()) // make the go routine return - sess.handshakeComplete = true expectReplaceWithClosed() cryptoSetup.EXPECT().Close() mconn.EXPECT().Write(gomock.Any()) @@ -2126,6 +2142,7 @@ var _ = Describe("Client Session", func() { defer GinkgoRecover() cryptoSetup.EXPECT().RunHandshake().MaxTimes(1) errChan <- sess.run() + close(errChan) }() }) @@ -2147,6 +2164,7 @@ var _ = Describe("Client Session", func() { expectClose() sess.shutdown() Eventually(sess.Context().Done()).Should(BeClosed()) + Eventually(errChan).Should(BeClosed()) }) It("uses the preferred_address connection ID", func() {