diff --git a/session.go b/session.go index 509894f6..845e98e2 100644 --- a/session.go +++ b/session.go @@ -135,8 +135,10 @@ type session struct { connectionClosePacket *packedPacket packetsReceivedAfterClose int - ctx context.Context - ctxCancel context.CancelFunc + ctx context.Context + ctxCancel context.CancelFunc + handshakeCtx context.Context + handshakeCtxCancel context.CancelFunc undecryptablePackets []*receivedPacket @@ -356,6 +358,7 @@ func (s *session) postSetup() error { s.sendingScheduled = make(chan struct{}, 1) s.undecryptablePackets = make([]*receivedPacket, 0, protocol.MaxUndecryptablePackets) s.ctx, s.ctxCancel = context.WithCancel(context.Background()) + s.handshakeCtx, s.handshakeCtxCancel = context.WithCancel(context.Background()) s.timer = utils.NewTimer() now := time.Now() @@ -465,6 +468,10 @@ runLoop: return closeErr.err } +func (s *session) HandshakeComplete() context.Context { + return s.handshakeCtx +} + func (s *session) Context() context.Context { return s.ctx } @@ -505,6 +512,7 @@ func (s *session) idleTimeoutStartTime() time.Time { func (s *session) handleHandshakeComplete() { s.handshakeComplete = true s.handshakeCompleteChan = nil // prevent this case from ever being selected again + s.handshakeCtxCancel() s.sessionRunner.OnHandshakeComplete(s) // The client completes the handshake first (after sending the CFIN). diff --git a/session_test.go b/session_test.go index 3a5a4d02..f46b7581 100644 --- a/session_test.go +++ b/session_test.go @@ -1153,6 +1153,48 @@ var _ = Describe("Session", func() { }) }) + It("cancels the HandshakeComplete context when the handshake completes", func() { + sessionRunner.EXPECT().OnHandshakeComplete(gomock.Any()) + packer.EXPECT().PackPacket().AnyTimes() + finishHandshake := make(chan struct{}) + go func() { + defer GinkgoRecover() + <-finishHandshake + cryptoSetup.EXPECT().RunHandshake() + close(sess.handshakeCompleteChan) + sess.run() + }() + handshakeCtx := sess.HandshakeComplete() + Consistently(handshakeCtx.Done()).ShouldNot(BeClosed()) + close(finishHandshake) + Eventually(handshakeCtx.Done()).Should(BeClosed()) + //make sure the go routine returns + streamManager.EXPECT().CloseWithError(gomock.Any()) + sessionRunner.EXPECT().Retire(gomock.Any()) + packer.EXPECT().PackConnectionClose(gomock.Any()).Return(&packedPacket{}, nil) + cryptoSetup.EXPECT().Close() + Expect(sess.Close()).To(Succeed()) + Eventually(sess.Context().Done()).Should(BeClosed()) + }) + + It("doesn't cancel the HandshakeComplete context when the handshake fails", func() { + packer.EXPECT().PackPacket().AnyTimes() + streamManager.EXPECT().CloseWithError(gomock.Any()) + sessionRunner.EXPECT().Retire(gomock.Any()) + packer.EXPECT().PackConnectionClose(gomock.Any()).Return(&packedPacket{}, nil) + cryptoSetup.EXPECT().Close() + go func() { + defer GinkgoRecover() + cryptoSetup.EXPECT().RunHandshake() + sess.run() + }() + handshakeCtx := sess.HandshakeComplete() + Consistently(handshakeCtx.Done()).ShouldNot(BeClosed()) + sess.closeLocal(errors.New("handshake error")) + Consistently(handshakeCtx.Done()).ShouldNot(BeClosed()) + Eventually(sess.Context().Done()).Should(BeClosed()) + }) + It("sends a 1-RTT packet when the handshake completes", func() { done := make(chan struct{}) gomock.InOrder(