From 5102294991f7ee83e51da4d8df7213726a97009e Mon Sep 17 00:00:00 2001 From: Marten Seemann Date: Tue, 2 Oct 2018 18:25:16 -0700 Subject: [PATCH] use separate chans to signal handshake events and handshake completion --- internal/handshake/crypto_setup_client.go | 31 ++--- .../handshake/crypto_setup_client_test.go | 7 +- internal/handshake/crypto_setup_server.go | 11 +- .../handshake/crypto_setup_server_test.go | 9 +- internal/handshake/crypto_setup_tls.go | 37 +++--- internal/handshake/crypto_setup_tls_test.go | 9 +- session.go | 117 ++++++++++-------- session_test.go | 41 +++--- 8 files changed, 149 insertions(+), 113 deletions(-) diff --git a/internal/handshake/crypto_setup_client.go b/internal/handshake/crypto_setup_client.go index 655449c7..3443161b 100644 --- a/internal/handshake/crypto_setup_client.go +++ b/internal/handshake/crypto_setup_client.go @@ -50,8 +50,9 @@ type cryptoSetupClient struct { secureAEAD crypto.AEAD forwardSecureAEAD crypto.AEAD - paramsChan chan<- TransportParameters - handshakeEvent chan<- struct{} + paramsChan chan<- TransportParameters + handshakeEvent chan<- struct{} + handshakeComplete chan<- struct{} params *TransportParameters @@ -75,6 +76,7 @@ func NewCryptoSetupClient( params *TransportParameters, paramsChan chan<- TransportParameters, handshakeEvent chan<- struct{}, + handshakeComplete chan<- struct{}, initialVersion protocol.VersionNumber, negotiatedVersions []protocol.VersionNumber, logger utils.Logger, @@ -85,17 +87,18 @@ func NewCryptoSetupClient( } divNonceChan := make(chan struct{}) cs := &cryptoSetupClient{ - cryptoStream: cryptoStream, - hostname: tlsConf.ServerName, - connID: connID, - version: version, - certManager: crypto.NewCertManager(tlsConf), - params: params, - keyDerivation: crypto.DeriveQuicCryptoAESKeys, - nullAEAD: nullAEAD, - paramsChan: paramsChan, - handshakeEvent: handshakeEvent, - initialVersion: initialVersion, + cryptoStream: cryptoStream, + hostname: tlsConf.ServerName, + connID: connID, + version: version, + certManager: crypto.NewCertManager(tlsConf), + params: params, + keyDerivation: crypto.DeriveQuicCryptoAESKeys, + nullAEAD: nullAEAD, + paramsChan: paramsChan, + handshakeEvent: handshakeEvent, + handshakeComplete: handshakeComplete, + initialVersion: initialVersion, // The server might have sent greased versions in the Version Negotiation packet. // We need strip those from the list, since they won't be included in the handshake tag. negotiatedVersions: protocol.StripGreasedVersions(negotiatedVersions), @@ -158,7 +161,7 @@ func (h *cryptoSetupClient) HandleCryptoStream() error { // blocks until the session has received the parameters h.paramsChan <- *params h.handshakeEvent <- struct{}{} - close(h.handshakeEvent) + close(h.handshakeComplete) default: return qerr.InvalidCryptoMessageType } diff --git a/internal/handshake/crypto_setup_client_test.go b/internal/handshake/crypto_setup_client_test.go index 0da07224..aceb3286 100644 --- a/internal/handshake/crypto_setup_client_test.go +++ b/internal/handshake/crypto_setup_client_test.go @@ -91,6 +91,7 @@ var _ = Describe("Client Crypto Setup", func() { keyDerivationCalledWith *keyDerivationValues shloMap map[Tag][]byte handshakeEvent chan struct{} + handshakeComplete chan struct{} paramsChan chan TransportParameters ) @@ -120,6 +121,7 @@ var _ = Describe("Client Crypto Setup", func() { // use a buffered channel here, so that we can parse a SHLO without having to receive the TransportParameters to avoid blocking paramsChan = make(chan TransportParameters, 1) handshakeEvent = make(chan struct{}, 2) + handshakeComplete = make(chan struct{}) csInt, err := NewCryptoSetupClient( stream, protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8}, @@ -128,6 +130,7 @@ var _ = Describe("Client Crypto Setup", func() { &TransportParameters{IdleTimeout: protocol.DefaultIdleTimeout}, paramsChan, handshakeEvent, + handshakeComplete, protocol.Version39, nil, utils.DefaultLogger, @@ -445,7 +448,7 @@ var _ = Describe("Client Crypto Setup", func() { Expect(params.IdleTimeout).To(Equal(13 * time.Second)) }) - It("closes the handshakeEvent chan when receiving an SHLO", func() { + It("closes the handshakeComplete chan when receiving an SHLO", func() { HandshakeMessage{Tag: TagSHLO, Data: shloMap}.Write(&stream.dataToRead) done := make(chan struct{}) go func() { @@ -455,7 +458,7 @@ var _ = Describe("Client Crypto Setup", func() { close(done) }() Eventually(handshakeEvent).Should(Receive()) - Eventually(handshakeEvent).Should(BeClosed()) + Eventually(handshakeComplete).Should(BeClosed()) // make the go routine return stream.close() Eventually(done).Should(BeClosed()) diff --git a/internal/handshake/crypto_setup_server.go b/internal/handshake/crypto_setup_server.go index 552e8297..b7a46480 100644 --- a/internal/handshake/crypto_setup_server.go +++ b/internal/handshake/crypto_setup_server.go @@ -42,9 +42,10 @@ type cryptoSetupServer struct { receivedSecurePacket bool sentSHLO chan struct{} // this channel is closed as soon as the SHLO has been written - receivedParams bool - paramsChan chan<- TransportParameters - handshakeEvent chan<- struct{} + receivedParams bool + paramsChan chan<- TransportParameters + handshakeEvent chan<- struct{} + handshakeComplete chan<- struct{} keyDerivation QuicCryptoKeyDerivationFunction keyExchange KeyExchangeFunction @@ -77,6 +78,7 @@ func NewCryptoSetup( acceptSTK func(net.Addr, *Cookie) bool, paramsChan chan<- TransportParameters, handshakeEvent chan<- struct{}, + handshakeComplete chan<- struct{}, logger utils.Logger, ) (CryptoSetup, error) { nullAEAD, err := crypto.NewNullAEAD(protocol.PerspectiveServer, connID, version) @@ -99,6 +101,7 @@ func NewCryptoSetup( sentSHLO: make(chan struct{}), paramsChan: paramsChan, handshakeEvent: handshakeEvent, + handshakeComplete: handshakeComplete, logger: logger, }, nil } @@ -210,7 +213,7 @@ func (h *cryptoSetupServer) Open(dst, src []byte, packetNumber protocol.PacketNu h.receivedForwardSecurePacket = true // wait for the send on the handshakeEvent chan <-h.sentSHLO - close(h.handshakeEvent) + close(h.handshakeComplete) } return res, protocol.EncryptionForwardSecure, nil } diff --git a/internal/handshake/crypto_setup_server_test.go b/internal/handshake/crypto_setup_server_test.go index 7b4a0e94..0e7133be 100644 --- a/internal/handshake/crypto_setup_server_test.go +++ b/internal/handshake/crypto_setup_server_test.go @@ -122,6 +122,7 @@ var _ = Describe("Server Crypto Setup", func() { stream *mockStream paramsChan chan TransportParameters handshakeEvent chan struct{} + handshakeComplete chan struct{} nonce32 []byte versionTag []byte validSTK []byte @@ -144,6 +145,7 @@ var _ = Describe("Server Crypto Setup", func() { // use a buffered channel here, so that we can parse a CHLO without having to receive the TransportParameters to avoid blocking paramsChan = make(chan TransportParameters, 1) handshakeEvent = make(chan struct{}, 2) + handshakeComplete = make(chan struct{}) stream = newMockStream() kex = &mockKEX{} signer = &mockSigner{} @@ -169,6 +171,7 @@ var _ = Describe("Server Crypto Setup", func() { nil, paramsChan, handshakeEvent, + handshakeComplete, utils.DefaultLogger, ) Expect(err).NotTo(HaveOccurred()) @@ -318,7 +321,7 @@ var _ = Describe("Server Crypto Setup", func() { Expect(handshakeEvent).To(Receive()) // for the switch to secure Expect(stream.dataWritten.Bytes()).To(ContainSubstring("SHLO")) Expect(handshakeEvent).To(Receive()) // for the switch to forward secure - Expect(handshakeEvent).ToNot(BeClosed()) + Expect(handshakeComplete).ToNot(BeClosed()) }) It("rejects client nonces that have the wrong length", func() { @@ -351,7 +354,7 @@ var _ = Describe("Server Crypto Setup", func() { Expect(stream.dataWritten.Bytes()).ToNot(ContainSubstring("REJ")) Expect(handshakeEvent).To(Receive()) // for the switch to secure Expect(handshakeEvent).To(Receive()) // for the switch to forward secure - Expect(handshakeEvent).ToNot(BeClosed()) + Expect(handshakeComplete).ToNot(BeClosed()) }) It("recognizes inchoate CHLOs missing SCID", func() { @@ -629,7 +632,7 @@ var _ = Describe("Server Crypto Setup", func() { cs.forwardSecureAEAD.(*mockcrypto.MockAEAD).EXPECT().Open(nil, []byte("forward secure encrypted"), protocol.PacketNumber(200), []byte{}) _, _, err := cs.Open(nil, []byte("forward secure encrypted"), 200, []byte{}) Expect(err).ToNot(HaveOccurred()) - Expect(handshakeEvent).To(BeClosed()) + Expect(handshakeComplete).To(BeClosed()) }) }) diff --git a/internal/handshake/crypto_setup_tls.go b/internal/handshake/crypto_setup_tls.go index 2e5dd025..389ecbd1 100644 --- a/internal/handshake/crypto_setup_tls.go +++ b/internal/handshake/crypto_setup_tls.go @@ -23,9 +23,10 @@ type cryptoSetupTLS struct { nullAEAD crypto.AEAD aead crypto.AEAD - tls mintTLS - conn *cryptoStreamConn - handshakeEvent chan<- struct{} + tls mintTLS + conn *cryptoStreamConn + handshakeEvent chan<- struct{} + handshakeComplete chan<- struct{} } var _ CryptoSetupTLS = &cryptoSetupTLS{} @@ -36,6 +37,7 @@ func NewCryptoSetupTLSServer( connID protocol.ConnectionID, config *mint.Config, handshakeEvent chan<- struct{}, + handshakeComplete chan<- struct{}, version protocol.VersionNumber, ) (CryptoSetupTLS, error) { nullAEAD, err := crypto.NewNullAEAD(protocol.PerspectiveServer, connID, version) @@ -45,12 +47,13 @@ func NewCryptoSetupTLSServer( conn := newCryptoStreamConn(cryptoStream) tls := mint.Server(conn, config) return &cryptoSetupTLS{ - tls: tls, - conn: conn, - nullAEAD: nullAEAD, - perspective: protocol.PerspectiveServer, - keyDerivation: crypto.DeriveAESKeys, - handshakeEvent: handshakeEvent, + tls: tls, + conn: conn, + nullAEAD: nullAEAD, + perspective: protocol.PerspectiveServer, + keyDerivation: crypto.DeriveAESKeys, + handshakeEvent: handshakeEvent, + handshakeComplete: handshakeComplete, }, nil } @@ -60,6 +63,7 @@ func NewCryptoSetupTLSClient( connID protocol.ConnectionID, config *mint.Config, handshakeEvent chan<- struct{}, + handshakeComplete chan<- struct{}, version protocol.VersionNumber, ) (CryptoSetupTLS, error) { nullAEAD, err := crypto.NewNullAEAD(protocol.PerspectiveClient, connID, version) @@ -69,12 +73,13 @@ func NewCryptoSetupTLSClient( conn := newCryptoStreamConn(cryptoStream) tls := mint.Client(conn, config) return &cryptoSetupTLS{ - tls: tls, - conn: conn, - perspective: protocol.PerspectiveClient, - nullAEAD: nullAEAD, - keyDerivation: crypto.DeriveAESKeys, - handshakeEvent: handshakeEvent, + tls: tls, + conn: conn, + perspective: protocol.PerspectiveClient, + nullAEAD: nullAEAD, + keyDerivation: crypto.DeriveAESKeys, + handshakeEvent: handshakeEvent, + handshakeComplete: handshakeComplete, }, nil } @@ -101,7 +106,7 @@ func (h *cryptoSetupTLS) HandleCryptoStream() error { h.mutex.Unlock() h.handshakeEvent <- struct{}{} - close(h.handshakeEvent) + close(h.handshakeComplete) return nil } diff --git a/internal/handshake/crypto_setup_tls_test.go b/internal/handshake/crypto_setup_tls_test.go index 85c087ea..12b2e9cd 100644 --- a/internal/handshake/crypto_setup_tls_test.go +++ b/internal/handshake/crypto_setup_tls_test.go @@ -20,17 +20,20 @@ func mockKeyDerivation(crypto.TLSExporter, protocol.Perspective) (crypto.AEAD, e var _ = Describe("TLS Crypto Setup", func() { var ( - cs *cryptoSetupTLS - handshakeEvent chan struct{} + cs *cryptoSetupTLS + handshakeEvent chan struct{} + handshakeComplete chan struct{} ) BeforeEach(func() { handshakeEvent = make(chan struct{}, 2) + handshakeComplete = make(chan struct{}) css, err := NewCryptoSetupTLSServer( newCryptoStreamConn(bytes.NewBuffer([]byte{})), protocol.ConnectionID{}, &mint.Config{}, handshakeEvent, + handshakeComplete, protocol.VersionTLS, ) Expect(err).ToNot(HaveOccurred()) @@ -54,7 +57,7 @@ var _ = Describe("TLS Crypto Setup", func() { err := cs.HandleCryptoStream() Expect(err).ToNot(HaveOccurred()) Expect(handshakeEvent).To(Receive()) - Expect(handshakeEvent).To(BeClosed()) + Expect(handshakeComplete).To(BeClosed()) }) It("handshakes until it is connected", func() { diff --git a/session.go b/session.go index 0c33d2a0..d236c527 100644 --- a/session.go +++ b/session.go @@ -119,8 +119,9 @@ type session struct { paramsChan <-chan handshake.TransportParameters // the handshakeEvent channel is passed to the CryptoSetup. // It receives when it makes sense to try decrypting undecryptable packets. - handshakeEvent <-chan struct{} - handshakeComplete bool + handshakeEvent <-chan struct{} + handshakeCompleteChan <-chan struct{} // is closed when the handshake completes + handshakeComplete bool receivedFirstPacket bool // since packet numbers start at 0, we can't use largestRcvdPacketNumber != 0 for this receivedFirstForwardSecurePacket bool @@ -162,17 +163,19 @@ func newSession( logger.Debugf("Creating new session. Destination Connection ID: %s, Source Connection ID: %s", destConnID, srcConnID) paramsChan := make(chan handshake.TransportParameters) handshakeEvent := make(chan struct{}, 1) + handshakeCompleteChan := make(chan struct{}) s := &session{ - conn: conn, - sessionRunner: sessionRunner, - srcConnID: srcConnID, - destConnID: destConnID, - perspective: protocol.PerspectiveServer, - version: v, - config: config, - handshakeEvent: handshakeEvent, - paramsChan: paramsChan, - logger: logger, + conn: conn, + sessionRunner: sessionRunner, + srcConnID: srcConnID, + destConnID: destConnID, + perspective: protocol.PerspectiveServer, + version: v, + config: config, + handshakeEvent: handshakeEvent, + handshakeCompleteChan: handshakeCompleteChan, + paramsChan: paramsChan, + logger: logger, } s.preSetup() transportParams := &handshake.TransportParameters{ @@ -197,6 +200,7 @@ func newSession( s.config.AcceptCookie, paramsChan, handshakeEvent, + handshakeCompleteChan, s.logger, ) if err != nil { @@ -238,17 +242,19 @@ var newClientSession = func( logger.Debugf("Creating new session. Destination Connection ID: %s, Source Connection ID: %s", destConnID, srcConnID) paramsChan := make(chan handshake.TransportParameters) handshakeEvent := make(chan struct{}, 1) + handshakeCompleteChan := make(chan struct{}) s := &session{ - conn: conn, - sessionRunner: sessionRunner, - srcConnID: srcConnID, - destConnID: destConnID, - perspective: protocol.PerspectiveClient, - version: v, - config: config, - handshakeEvent: handshakeEvent, - paramsChan: paramsChan, - logger: logger, + conn: conn, + sessionRunner: sessionRunner, + srcConnID: srcConnID, + destConnID: destConnID, + perspective: protocol.PerspectiveClient, + version: v, + config: config, + handshakeEvent: handshakeEvent, + handshakeCompleteChan: handshakeCompleteChan, + paramsChan: paramsChan, + logger: logger, } s.preSetup() transportParams := &handshake.TransportParameters{ @@ -266,6 +272,7 @@ var newClientSession = func( transportParams, paramsChan, handshakeEvent, + handshakeCompleteChan, initialVersion, negotiatedVersions, s.logger, @@ -307,16 +314,18 @@ func newTLSServerSession( v protocol.VersionNumber, ) (quicSession, error) { handshakeEvent := make(chan struct{}, 1) + handshakeCompleteChan := make(chan struct{}) s := &session{ - conn: conn, - sessionRunner: runner, - config: config, - srcConnID: srcConnID, - destConnID: destConnID, - perspective: protocol.PerspectiveServer, - version: v, - handshakeEvent: handshakeEvent, - logger: logger, + conn: conn, + sessionRunner: runner, + config: config, + srcConnID: srcConnID, + destConnID: destConnID, + perspective: protocol.PerspectiveServer, + version: v, + handshakeEvent: handshakeEvent, + handshakeCompleteChan: handshakeCompleteChan, + logger: logger, } s.preSetup() cs, err := handshake.NewCryptoSetupTLSServer( @@ -324,6 +333,7 @@ func newTLSServerSession( origConnID, mintConf, handshakeEvent, + handshakeCompleteChan, v, ) if err != nil { @@ -370,17 +380,19 @@ var newTLSClientSession = func( v protocol.VersionNumber, ) (quicSession, error) { handshakeEvent := make(chan struct{}, 1) + handshakeCompleteChan := make(chan struct{}) s := &session{ - conn: conn, - sessionRunner: runner, - config: conf, - srcConnID: srcConnID, - destConnID: destConnID, - perspective: protocol.PerspectiveClient, - version: v, - handshakeEvent: handshakeEvent, - paramsChan: paramsChan, - logger: logger, + conn: conn, + sessionRunner: runner, + config: conf, + srcConnID: srcConnID, + destConnID: destConnID, + perspective: protocol.PerspectiveClient, + version: v, + handshakeEvent: handshakeEvent, + handshakeCompleteChan: handshakeCompleteChan, + paramsChan: paramsChan, + logger: logger, } s.preSetup() cs, err := handshake.NewCryptoSetupTLSClient( @@ -388,6 +400,7 @@ var newTLSClientSession = func( s.destConnID, mintConf, handshakeEvent, + handshakeCompleteChan, v, ) if err != nil { @@ -463,9 +476,10 @@ runLoop: select { case closeErr = <-s.closeChan: break runLoop - case _, ok := <-s.handshakeEvent: - // when the handshake is completed, the channel will be closed - s.handleHandshakeEvent(!ok) + case <-s.handshakeEvent: + s.tryDecryptingQueuedPackets() + case <-s.handshakeCompleteChan: + s.handleHandshakeComplete() default: } @@ -497,9 +511,10 @@ runLoop: case p := <-s.paramsChan: s.processTransportParameters(&p) continue - case _, ok := <-s.handshakeEvent: - // when the handshake is completed, the channel will be closed - s.handleHandshakeEvent(!ok) + case <-s.handshakeEvent: + s.tryDecryptingQueuedPackets() + case <-s.handshakeCompleteChan: + s.handleHandshakeComplete() } now := time.Now() @@ -590,13 +605,9 @@ func (s *session) maybeResetTimer() { s.timer.Reset(deadline) } -func (s *session) handleHandshakeEvent(completed bool) { - if !completed { - s.tryDecryptingQueuedPackets() - return - } +func (s *session) handleHandshakeComplete() { s.handshakeComplete = true - s.handshakeEvent = nil // prevent this case from ever being selected again + s.handshakeCompleteChan = nil // prevent this case from ever being selected again s.sessionRunner.onHandshakeComplete(s) // In gQUIC, the server completes the handshake first (after sending the SHLO). diff --git a/session_test.go b/session_test.go index 026acd48..b78bf59f 100644 --- a/session_test.go +++ b/session_test.go @@ -93,14 +93,15 @@ func areSessionsRunning() bool { var _ = Describe("Session", func() { var ( - sess *session - sessionRunner *MockSessionRunner - scfg *handshake.ServerConfig - mconn *mockConnection - cryptoSetup *mockCryptoSetup - streamManager *MockStreamManager - packer *MockPacker - handshakeChan chan<- struct{} + sess *session + sessionRunner *MockSessionRunner + scfg *handshake.ServerConfig + mconn *mockConnection + cryptoSetup *mockCryptoSetup + streamManager *MockStreamManager + packer *MockPacker + handshakeChan chan<- struct{} + handshakeCompleteChan chan<- struct{} ) BeforeEach(func() { @@ -119,9 +120,11 @@ var _ = Describe("Session", func() { _ func(net.Addr, *Cookie) bool, _ chan<- handshake.TransportParameters, handshakeChanP chan<- struct{}, + handshakeCompleteChanP chan<- struct{}, _ utils.Logger, ) (handshake.CryptoSetup, error) { handshakeChan = handshakeChanP + handshakeCompleteChan = handshakeCompleteChanP return cryptoSetup, nil } @@ -178,6 +181,7 @@ var _ = Describe("Session", func() { cookieFunc func(net.Addr, *Cookie) bool, _ chan<- handshake.TransportParameters, _ chan<- struct{}, + _ chan<- struct{}, _ utils.Logger, ) (handshake.CryptoSetup, error) { cookieVerify = cookieFunc @@ -1255,7 +1259,7 @@ var _ = Describe("Session", func() { }) It("calls the onHandshakeComplete callback when the handshake completes", func() { - close(handshakeChan) + close(handshakeCompleteChan) sessionRunner.EXPECT().onHandshakeComplete(gomock.Any()) go func() { defer GinkgoRecover() @@ -1469,7 +1473,7 @@ var _ = Describe("Session", func() { return &packedPacket{}, nil }) sess.config.IdleTimeout = 0 - close(handshakeChan) + close(handshakeCompleteChan) done := make(chan struct{}) go func() { defer GinkgoRecover() @@ -1578,11 +1582,11 @@ var _ = Describe("Session", func() { var _ = Describe("Client Session", func() { var ( - sess *session - sessionRunner *MockSessionRunner - packer *MockPacker - mconn *mockConnection - handshakeChan chan<- struct{} + sess *session + sessionRunner *MockSessionRunner + packer *MockPacker + mconn *mockConnection + handshakeCompleteChan chan<- struct{} cryptoSetup *mockCryptoSetup ) @@ -1598,12 +1602,13 @@ var _ = Describe("Client Session", func() { _ *tls.Config, _ *handshake.TransportParameters, _ chan<- handshake.TransportParameters, - handshakeChanP chan<- struct{}, + _ chan<- struct{}, + handshakeCompleteChanP chan<- struct{}, _ protocol.VersionNumber, _ []protocol.VersionNumber, _ utils.Logger, ) (handshake.CryptoSetup, error) { - handshakeChan = handshakeChanP + handshakeCompleteChan = handshakeCompleteChanP return cryptoSetup, nil } @@ -1641,7 +1646,7 @@ var _ = Describe("Client Session", func() { }), packer.EXPECT().PackPacket().AnyTimes(), ) - close(handshakeChan) + close(handshakeCompleteChan) go func() { defer GinkgoRecover() sess.run()