From e6aeb143a78f0c90ac6a3629d40e0eac6e2ff569 Mon Sep 17 00:00:00 2001 From: Marten Seemann Date: Mon, 8 May 2017 21:34:14 +0800 Subject: [PATCH] simplify the blocking logic for the non-forward-secure session --- session.go | 27 +++++++++++++-------------- 1 file changed, 13 insertions(+), 14 deletions(-) diff --git a/session.go b/session.go index 4662e9079..87ab49b58 100644 --- a/session.go +++ b/session.go @@ -83,10 +83,11 @@ type session struct { // this channel is passed to the CryptoSetup and receives the current encryption level // it is closed as soon as the handshake is complete - aeadChanged <-chan protocol.EncryptionLevel - handshakeComplete bool - handshakeChan chan struct{} // will be closed as soon as the handshake completes - handshakeErrorChan chan error + aeadChanged <-chan protocol.EncryptionLevel + handshakeComplete bool + // will be closed as soon as the handshake completes, and receive any error that might occur until then + // it is used to block WaitUntilHandshakeComplete() + handshakeCompleteChan chan error nextAckScheduledTime time.Time @@ -217,8 +218,7 @@ func (s *session) setup() { s.undecryptablePackets = make([]*receivedPacket, 0, protocol.MaxUndecryptablePackets) s.aeadChanged = make(chan protocol.EncryptionLevel, 2) s.runClosed = make(chan struct{}) - s.handshakeChan = make(chan struct{}) - s.handshakeErrorChan = make(chan error, 1) + s.handshakeCompleteChan = make(chan error, 1) s.timer = time.NewTimer(0) s.lastNetworkActivityTime = now @@ -278,7 +278,7 @@ runLoop: if !ok { s.handshakeComplete = true aeadChanged = nil // prevent this case from ever being selected again - close(s.handshakeChan) + close(s.handshakeCompleteChan) } else { if l == protocol.EncryptionForwardSecure { s.packer.SetForwardSecure() @@ -310,7 +310,11 @@ runLoop: s.garbageCollectStreams() } - s.handshakeErrorChan <- closeErr.err + // only send the error the handshakeChan when the handshake is not completed yet + // otherwise this chan will already be closed + if !s.handshakeComplete { + s.handshakeCompleteChan <- closeErr.err + } s.handleCloseError(closeErr) close(s.runClosed) return closeErr.err @@ -758,12 +762,7 @@ func (s *session) OpenStreamSync() (Stream, error) { } func (s *session) WaitUntilHandshakeComplete() error { - select { - case <-s.handshakeChan: - return nil - case err := <-s.handshakeErrorChan: - return err - } + return <-s.handshakeCompleteChan } func (s *session) queueResetStreamFrame(id protocol.StreamID, offset protocol.ByteCount) {