From e361d3c5cd07d76ab4a74b493de19c2e7f97a153 Mon Sep 17 00:00:00 2001 From: Marten Seemann Date: Fri, 31 May 2019 16:31:50 +0800 Subject: [PATCH] use a callback to signal completion of the handshake --- internal/handshake/crypto_setup.go | 1 + internal/handshake/crypto_setup_test.go | 8 +++ internal/handshake/interface.go | 1 + .../handshake/mock_handshake_runner_test.go | 12 ++++ session.go | 29 ++++---- session_test.go | 72 ++++++++----------- 6 files changed, 65 insertions(+), 58 deletions(-) diff --git a/internal/handshake/crypto_setup.go b/internal/handshake/crypto_setup.go index e6c5e55d..d1e91ab5 100644 --- a/internal/handshake/crypto_setup.go +++ b/internal/handshake/crypto_setup.go @@ -267,6 +267,7 @@ func (h *cryptoSetup) RunHandshake() { select { case <-handshakeComplete: // return when the handshake is done + h.runner.OnHandshakeComplete() case <-h.closeChan: // wait until the Handshake() go routine has returned close(h.messageChan) diff --git a/internal/handshake/crypto_setup_test.go b/internal/handshake/crypto_setup_test.go index 8c99469a..05adec5f 100644 --- a/internal/handshake/crypto_setup_test.go +++ b/internal/handshake/crypto_setup_test.go @@ -291,11 +291,13 @@ var _ = Describe("Crypto Setup TLS", func() { } handshakeWithTLSConf := func(clientConf, serverConf *tls.Config) (error /* client error */, error /* server error */) { + var cHandshakeComplete bool cChunkChan, cInitialStream, cHandshakeStream := initStreams() cErrChan := make(chan error, 1) cRunner := NewMockHandshakeRunner(mockCtrl) cRunner.EXPECT().OnReceivedParams(gomock.Any()) cRunner.EXPECT().OnError(gomock.Any()).Do(func(e error) { cErrChan <- e }).MaxTimes(1) + cRunner.EXPECT().OnHandshakeComplete().Do(func() { cHandshakeComplete = true }).MaxTimes(1) client, _, err := NewCryptoSetupClient( cInitialStream, cHandshakeStream, @@ -309,11 +311,13 @@ var _ = Describe("Crypto Setup TLS", func() { ) Expect(err).ToNot(HaveOccurred()) + var sHandshakeComplete bool sChunkChan, sInitialStream, sHandshakeStream := initStreams() sErrChan := make(chan error, 1) sRunner := NewMockHandshakeRunner(mockCtrl) sRunner.EXPECT().OnReceivedParams(gomock.Any()) sRunner.EXPECT().OnError(gomock.Any()).Do(func(e error) { sErrChan <- e }).MaxTimes(1) + sRunner.EXPECT().OnHandshakeComplete().Do(func() { sHandshakeComplete = true }).MaxTimes(1) var token [16]byte server, err := NewCryptoSetupServer( sInitialStream, @@ -333,10 +337,12 @@ var _ = Describe("Crypto Setup TLS", func() { select { case sErr = <-sErrChan: default: + Expect(sHandshakeComplete).To(BeTrue()) } select { case cErr = <-cErrChan: default: + Expect(cHandshakeComplete).To(BeTrue()) } return cErr, sErr } @@ -408,6 +414,7 @@ var _ = Describe("Crypto Setup TLS", func() { cTransportParameters := &TransportParameters{IdleTimeout: 0x42 * time.Second} cRunner := NewMockHandshakeRunner(mockCtrl) cRunner.EXPECT().OnReceivedParams(gomock.Any()).Do(func(b []byte) { sTransportParametersRcvd = b }) + cRunner.EXPECT().OnHandshakeComplete() client, _, err := NewCryptoSetupClient( cInitialStream, cHandshakeStream, @@ -425,6 +432,7 @@ var _ = Describe("Crypto Setup TLS", func() { var token [16]byte sRunner := NewMockHandshakeRunner(mockCtrl) sRunner.EXPECT().OnReceivedParams(gomock.Any()).Do(func(b []byte) { cTransportParametersRcvd = b }) + sRunner.EXPECT().OnHandshakeComplete() sTransportParameters := &TransportParameters{ IdleTimeout: 0x1337 * time.Second, StatelessResetToken: &token, diff --git a/internal/handshake/interface.go b/internal/handshake/interface.go index f34421ab..09fdcf1a 100644 --- a/internal/handshake/interface.go +++ b/internal/handshake/interface.go @@ -30,6 +30,7 @@ type tlsExtensionHandler interface { type handshakeRunner interface { OnReceivedParams([]byte) + OnHandshakeComplete() OnError(error) DropKeys(protocol.EncryptionLevel) } diff --git a/internal/handshake/mock_handshake_runner_test.go b/internal/handshake/mock_handshake_runner_test.go index 37fb494e..7911e69d 100644 --- a/internal/handshake/mock_handshake_runner_test.go +++ b/internal/handshake/mock_handshake_runner_test.go @@ -58,6 +58,18 @@ func (mr *MockHandshakeRunnerMockRecorder) OnError(arg0 interface{}) *gomock.Cal return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "OnError", reflect.TypeOf((*MockHandshakeRunner)(nil).OnError), arg0) } +// OnHandshakeComplete mocks base method +func (m *MockHandshakeRunner) OnHandshakeComplete() { + m.ctrl.T.Helper() + m.ctrl.Call(m, "OnHandshakeComplete") +} + +// OnHandshakeComplete indicates an expected call of OnHandshakeComplete +func (mr *MockHandshakeRunnerMockRecorder) OnHandshakeComplete() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "OnHandshakeComplete", reflect.TypeOf((*MockHandshakeRunner)(nil).OnHandshakeComplete)) +} + // OnReceivedParams mocks base method func (m *MockHandshakeRunner) OnReceivedParams(arg0 []byte) { m.ctrl.T.Helper() diff --git a/session.go b/session.go index 987a0423..9223ba7a 100644 --- a/session.go +++ b/session.go @@ -72,14 +72,16 @@ func (p *receivedPacket) Clone() *receivedPacket { } type handshakeRunner struct { - onReceivedParams func([]byte) - onError func(error) - dropKeys func(protocol.EncryptionLevel) + onReceivedParams func([]byte) + onError func(error) + dropKeys func(protocol.EncryptionLevel) + onHandshakeComplete func() } func (r *handshakeRunner) OnReceivedParams(b []byte) { r.onReceivedParams(b) } func (r *handshakeRunner) OnError(e error) { r.onError(e) } func (r *handshakeRunner) DropKeys(el protocol.EncryptionLevel) { r.dropKeys(el) } +func (r *handshakeRunner) OnHandshakeComplete() { r.onHandshakeComplete() } type closeError struct { err error @@ -209,9 +211,10 @@ var newSession = func( conn.RemoteAddr(), params, &handshakeRunner{ - onReceivedParams: s.processTransportParameters, - onError: s.closeLocal, - dropKeys: s.dropEncryptionLevel, + onReceivedParams: s.processTransportParameters, + onError: s.closeLocal, + dropKeys: s.dropEncryptionLevel, + onHandshakeComplete: func() { close(s.handshakeCompleteChan) }, }, tlsConf, logger, @@ -281,9 +284,10 @@ var newClientSession = func( conn.RemoteAddr(), params, &handshakeRunner{ - onReceivedParams: s.processTransportParameters, - onError: s.closeLocal, - dropKeys: s.dropEncryptionLevel, + onReceivedParams: s.processTransportParameters, + onError: s.closeLocal, + dropKeys: s.dropEncryptionLevel, + onHandshakeComplete: func() { close(s.handshakeCompleteChan) }, }, tlsConf, logger, @@ -353,11 +357,8 @@ func (s *session) postSetup() error { func (s *session) run() error { defer s.ctxCancel() - go func() { - s.cryptoStreamHandler.RunHandshake() - // If an error occurred during the handshake, the crypto setup will already have called the close callback. - close(s.handshakeCompleteChan) - }() + go s.cryptoStreamHandler.RunHandshake() + if s.perspective == protocol.PerspectiveClient { select { case <-s.clientHelloWritten: diff --git a/session_test.go b/session_test.go index d38a22e0..0b0aa376 100644 --- a/session_test.go +++ b/session_test.go @@ -328,7 +328,7 @@ var _ = Describe("Session", func() { go func() { defer GinkgoRecover() - cryptoSetup.EXPECT().RunHandshake().Do(func() { <-sess.Context().Done() }) + cryptoSetup.EXPECT().RunHandshake().MaxTimes(1) Expect(sess.run()).To(MatchError(testErr)) }() ccf := &wire.ConnectionCloseFrame{ @@ -347,7 +347,7 @@ var _ = Describe("Session", func() { go func() { defer GinkgoRecover() - cryptoSetup.EXPECT().RunHandshake().Do(func() { <-sess.Context().Done() }) + cryptoSetup.EXPECT().RunHandshake().MaxTimes(1) Expect(sess.run()).To(MatchError(testErr)) }() ccf := &wire.ConnectionCloseFrame{ @@ -383,7 +383,7 @@ var _ = Describe("Session", func() { Eventually(areSessionsRunning).Should(BeFalse()) go func() { defer GinkgoRecover() - cryptoSetup.EXPECT().RunHandshake().Do(func() { <-sess.Context().Done() }) + cryptoSetup.EXPECT().RunHandshake().MaxTimes(1) runErr = sess.run() }() Eventually(areSessionsRunning).Should(BeTrue()) @@ -565,7 +565,7 @@ var _ = Describe("Session", func() { packer.EXPECT().PackConnectionClose(gomock.Any()).Return(&packedPacket{}, nil) go func() { defer GinkgoRecover() - cryptoSetup.EXPECT().RunHandshake().Do(func() { <-sess.Context().Done() }) + cryptoSetup.EXPECT().RunHandshake().MaxTimes(1) sess.run() }() sessionRunner.EXPECT().Retire(gomock.Any()) @@ -590,7 +590,7 @@ var _ = Describe("Session", func() { done := make(chan struct{}) go func() { defer GinkgoRecover() - cryptoSetup.EXPECT().RunHandshake().Do(func() { <-sess.Context().Done() }) + cryptoSetup.EXPECT().RunHandshake().MaxTimes(1) err := sess.run() Expect(err).To(MatchError("PROTOCOL_VIOLATION: empty packet")) close(done) @@ -938,7 +938,7 @@ var _ = Describe("Session", func() { done := make(chan struct{}) go func() { defer GinkgoRecover() - cryptoSetup.EXPECT().RunHandshake().Do(func() { <-sess.Context().Done() }) + cryptoSetup.EXPECT().RunHandshake().MaxTimes(1) sess.run() close(done) }() @@ -965,7 +965,7 @@ var _ = Describe("Session", func() { done := make(chan struct{}) go func() { defer GinkgoRecover() - cryptoSetup.EXPECT().RunHandshake().Do(func() { <-sess.Context().Done() }) + cryptoSetup.EXPECT().RunHandshake().MaxTimes(1) sess.run() close(done) }() @@ -993,7 +993,7 @@ var _ = Describe("Session", func() { done := make(chan struct{}) go func() { defer GinkgoRecover() - cryptoSetup.EXPECT().RunHandshake().Do(func() { <-sess.Context().Done() }) + cryptoSetup.EXPECT().RunHandshake().MaxTimes(1) sess.run() close(done) }() @@ -1021,7 +1021,7 @@ var _ = Describe("Session", func() { done := make(chan struct{}) go func() { defer GinkgoRecover() - cryptoSetup.EXPECT().RunHandshake().Do(func() { <-sess.Context().Done() }) + cryptoSetup.EXPECT().RunHandshake().MaxTimes(1) sess.run() close(done) }() @@ -1043,7 +1043,7 @@ var _ = Describe("Session", func() { done := make(chan struct{}) go func() { defer GinkgoRecover() - cryptoSetup.EXPECT().RunHandshake().Do(func() { <-sess.Context().Done() }) + cryptoSetup.EXPECT().RunHandshake().MaxTimes(1) sess.run() close(done) }() @@ -1071,7 +1071,7 @@ var _ = Describe("Session", func() { go func() { defer GinkgoRecover() - cryptoSetup.EXPECT().RunHandshake().Do(func() { <-sess.Context().Done() }) + cryptoSetup.EXPECT().RunHandshake().MaxTimes(1) sess.run() }() Consistently(mconn.written).ShouldNot(Receive()) @@ -1106,7 +1106,7 @@ var _ = Describe("Session", func() { go func() { defer GinkgoRecover() - cryptoSetup.EXPECT().RunHandshake().Do(func() { <-sess.Context().Done() }) + cryptoSetup.EXPECT().RunHandshake().MaxTimes(1) sess.run() }() Eventually(mconn.written).Should(Receive()) @@ -1121,25 +1121,7 @@ var _ = Describe("Session", func() { }) }) - It("calls the onHandshakeComplete callback when the handshake completes", func() { - packer.EXPECT().PackPacket().AnyTimes() - go func() { - defer GinkgoRecover() - sessionRunner.EXPECT().OnHandshakeComplete(gomock.Any()) - cryptoSetup.EXPECT().RunHandshake() - sess.run() - }() - Consistently(sess.Context().Done()).ShouldNot(BeClosed()) - // make sure the go routine returns - sessionRunner.EXPECT().Retire(gomock.Any()) - streamManager.EXPECT().CloseWithError(gomock.Any()) - packer.EXPECT().PackConnectionClose(gomock.Any()).Return(&packedPacket{}, nil) - cryptoSetup.EXPECT().Close() - Expect(sess.Close()).To(Succeed()) - Eventually(sess.Context().Done()).Should(BeClosed()) - }) - - It("sends a forward-secure packet when the handshake completes", func() { + It("sends a 1-RTT packet when the handshake completes", func() { done := make(chan struct{}) gomock.InOrder( sessionRunner.EXPECT().OnHandshakeComplete(gomock.Any()), @@ -1155,6 +1137,7 @@ var _ = Describe("Session", func() { go func() { defer GinkgoRecover() cryptoSetup.EXPECT().RunHandshake() + close(sess.handshakeCompleteChan) sess.run() }() Eventually(done).Should(BeClosed()) @@ -1171,7 +1154,7 @@ var _ = Describe("Session", func() { done := make(chan struct{}) go func() { defer GinkgoRecover() - cryptoSetup.EXPECT().RunHandshake().Do(func() { <-sess.Context().Done() }) + cryptoSetup.EXPECT().RunHandshake().MaxTimes(1) Expect(sess.run()).To(Succeed()) close(done) }() @@ -1188,7 +1171,7 @@ var _ = Describe("Session", func() { done := make(chan struct{}) go func() { defer GinkgoRecover() - cryptoSetup.EXPECT().RunHandshake().Do(func() { <-sess.Context().Done() }) + cryptoSetup.EXPECT().RunHandshake().MaxTimes(1) err := sess.run() Expect(err).To(MatchError(qerr.Error(0x1337, testErr.Error()))) close(done) @@ -1205,7 +1188,7 @@ var _ = Describe("Session", func() { It("errors if it can't unmarshal the TransportParameters", func() { go func() { defer GinkgoRecover() - cryptoSetup.EXPECT().RunHandshake().Do(func() { <-sess.Context().Done() }) + cryptoSetup.EXPECT().RunHandshake().MaxTimes(1) err := sess.run() Expect(err).To(HaveOccurred()) Expect(err.Error()).To(ContainSubstring("transport parameter")) @@ -1221,7 +1204,7 @@ var _ = Describe("Session", func() { It("process transport parameters received from the client", func() { go func() { defer GinkgoRecover() - cryptoSetup.EXPECT().RunHandshake().Do(func() { <-sess.Context().Done() }) + cryptoSetup.EXPECT().RunHandshake().MaxTimes(1) sess.run() }() params := &handshake.TransportParameters{ @@ -1265,7 +1248,7 @@ var _ = Describe("Session", func() { done := make(chan struct{}) go func() { defer GinkgoRecover() - cryptoSetup.EXPECT().RunHandshake().Do(func() { <-sess.Context().Done() }) + cryptoSetup.EXPECT().RunHandshake().MaxTimes(1) sess.run() close(done) }() @@ -1286,7 +1269,7 @@ var _ = Describe("Session", func() { done := make(chan struct{}) go func() { defer GinkgoRecover() - cryptoSetup.EXPECT().RunHandshake().Do(func() { <-sess.Context().Done() }) + cryptoSetup.EXPECT().RunHandshake().MaxTimes(1) sess.run() close(done) }() @@ -1307,7 +1290,7 @@ var _ = Describe("Session", func() { done := make(chan struct{}) go func() { defer GinkgoRecover() - cryptoSetup.EXPECT().RunHandshake().Do(func() { <-sess.Context().Done() }) + cryptoSetup.EXPECT().RunHandshake().MaxTimes(1) sess.run() close(done) }() @@ -1335,7 +1318,7 @@ var _ = Describe("Session", func() { cryptoSetup.EXPECT().Close() go func() { defer GinkgoRecover() - cryptoSetup.EXPECT().RunHandshake().Do(func() { <-sess.Context().Done() }) + cryptoSetup.EXPECT().RunHandshake().MaxTimes(1) err := sess.run() nerr, ok := err.(net.Error) Expect(ok).To(BeTrue()) @@ -1353,7 +1336,7 @@ var _ = Describe("Session", func() { done := make(chan struct{}) go func() { defer GinkgoRecover() - cryptoSetup.EXPECT().RunHandshake().Do(func() { <-sess.Context().Done() }) + cryptoSetup.EXPECT().RunHandshake().MaxTimes(1) err := sess.run() nerr, ok := err.(net.Error) Expect(ok).To(BeTrue()) @@ -1375,7 +1358,7 @@ var _ = Describe("Session", func() { // and not on the last network activity go func() { defer GinkgoRecover() - cryptoSetup.EXPECT().RunHandshake().Do(func() { <-sess.Context().Done() }) + cryptoSetup.EXPECT().RunHandshake().MaxTimes(1) sess.run() }() Consistently(sess.Context().Done()).ShouldNot(BeClosed()) @@ -1396,6 +1379,7 @@ var _ = Describe("Session", func() { defer GinkgoRecover() sessionRunner.EXPECT().OnHandshakeComplete(sess) cryptoSetup.EXPECT().RunHandshake() + close(sess.handshakeCompleteChan) err := sess.run() nerr, ok := err.(net.Error) Expect(ok).To(BeTrue()) @@ -1413,7 +1397,7 @@ var _ = Describe("Session", func() { sess.config.IdleTimeout = 30 * time.Second go func() { defer GinkgoRecover() - cryptoSetup.EXPECT().RunHandshake().Do(func() { <-sess.Context().Done() }) + cryptoSetup.EXPECT().RunHandshake().MaxTimes(1) sess.run() }() Consistently(sess.Context().Done()).ShouldNot(BeClosed()) @@ -1553,7 +1537,7 @@ var _ = Describe("Client Session", func() { sess.unpacker = unpacker go func() { defer GinkgoRecover() - cryptoSetup.EXPECT().RunHandshake().Do(func() { <-sess.Context().Done() }).AnyTimes() + cryptoSetup.EXPECT().RunHandshake().MaxTimes(1) sess.run() }() newConnID := protocol.ConnectionID{1, 3, 3, 7, 1, 3, 3, 7} @@ -1630,7 +1614,7 @@ var _ = Describe("Client Session", func() { It("errors if it can't unmarshal the TransportParameters", func() { go func() { defer GinkgoRecover() - cryptoSetup.EXPECT().RunHandshake().Do(func() { <-sess.Context().Done() }) + cryptoSetup.EXPECT().RunHandshake().MaxTimes(1) err := sess.run() Expect(err).To(HaveOccurred()) Expect(err.Error()).To(ContainSubstring("transport parameter"))