implement a session method to tell if the handshake is already complete

This commit is contained in:
Marten Seemann
2019-07-06 11:05:00 +07:00
parent 264eaf2a7b
commit 6eb72f712d
2 changed files with 52 additions and 2 deletions

View File

@@ -135,8 +135,10 @@ type session struct {
connectionClosePacket *packedPacket
packetsReceivedAfterClose int
ctx context.Context
ctxCancel context.CancelFunc
ctx context.Context
ctxCancel context.CancelFunc
handshakeCtx context.Context
handshakeCtxCancel context.CancelFunc
undecryptablePackets []*receivedPacket
@@ -356,6 +358,7 @@ func (s *session) postSetup() error {
s.sendingScheduled = make(chan struct{}, 1)
s.undecryptablePackets = make([]*receivedPacket, 0, protocol.MaxUndecryptablePackets)
s.ctx, s.ctxCancel = context.WithCancel(context.Background())
s.handshakeCtx, s.handshakeCtxCancel = context.WithCancel(context.Background())
s.timer = utils.NewTimer()
now := time.Now()
@@ -465,6 +468,10 @@ runLoop:
return closeErr.err
}
func (s *session) HandshakeComplete() context.Context {
return s.handshakeCtx
}
func (s *session) Context() context.Context {
return s.ctx
}
@@ -505,6 +512,7 @@ func (s *session) idleTimeoutStartTime() time.Time {
func (s *session) handleHandshakeComplete() {
s.handshakeComplete = true
s.handshakeCompleteChan = nil // prevent this case from ever being selected again
s.handshakeCtxCancel()
s.sessionRunner.OnHandshakeComplete(s)
// The client completes the handshake first (after sending the CFIN).

View File

@@ -1153,6 +1153,48 @@ var _ = Describe("Session", func() {
})
})
It("cancels the HandshakeComplete context when the handshake completes", func() {
sessionRunner.EXPECT().OnHandshakeComplete(gomock.Any())
packer.EXPECT().PackPacket().AnyTimes()
finishHandshake := make(chan struct{})
go func() {
defer GinkgoRecover()
<-finishHandshake
cryptoSetup.EXPECT().RunHandshake()
close(sess.handshakeCompleteChan)
sess.run()
}()
handshakeCtx := sess.HandshakeComplete()
Consistently(handshakeCtx.Done()).ShouldNot(BeClosed())
close(finishHandshake)
Eventually(handshakeCtx.Done()).Should(BeClosed())
//make sure the go routine returns
streamManager.EXPECT().CloseWithError(gomock.Any())
sessionRunner.EXPECT().Retire(gomock.Any())
packer.EXPECT().PackConnectionClose(gomock.Any()).Return(&packedPacket{}, nil)
cryptoSetup.EXPECT().Close()
Expect(sess.Close()).To(Succeed())
Eventually(sess.Context().Done()).Should(BeClosed())
})
It("doesn't cancel the HandshakeComplete context when the handshake fails", func() {
packer.EXPECT().PackPacket().AnyTimes()
streamManager.EXPECT().CloseWithError(gomock.Any())
sessionRunner.EXPECT().Retire(gomock.Any())
packer.EXPECT().PackConnectionClose(gomock.Any()).Return(&packedPacket{}, nil)
cryptoSetup.EXPECT().Close()
go func() {
defer GinkgoRecover()
cryptoSetup.EXPECT().RunHandshake()
sess.run()
}()
handshakeCtx := sess.HandshakeComplete()
Consistently(handshakeCtx.Done()).ShouldNot(BeClosed())
sess.closeLocal(errors.New("handshake error"))
Consistently(handshakeCtx.Done()).ShouldNot(BeClosed())
Eventually(sess.Context().Done()).Should(BeClosed())
})
It("sends a 1-RTT packet when the handshake completes", func() {
done := make(chan struct{})
gomock.InOrder(