From 382094ba68ee58203b937a5cea18e83c3c6c9314 Mon Sep 17 00:00:00 2001 From: Marten Seemann Date: Sat, 3 Aug 2019 16:32:01 +0700 Subject: [PATCH] add unit tests for session resumption --- internal/handshake/crypto_setup_test.go | 307 +++++++++++++++--------- 1 file changed, 187 insertions(+), 120 deletions(-) diff --git a/internal/handshake/crypto_setup_test.go b/internal/handshake/crypto_setup_test.go index 366fcb23..43b4654e 100644 --- a/internal/handshake/crypto_setup_test.go +++ b/internal/handshake/crypto_setup_test.go @@ -55,11 +55,12 @@ func (s *stream) Write(b []byte) (int, error) { var _ = Describe("Crypto Setup TLS", func() { var clientConf, serverConf *tls.Config - initStreams := func() (chan chunk, *stream /* initial */, *stream /* handshake */) { + initStreams := func() (chan chunk, *stream /* initial */, *stream /* handshake */, *stream /* 1-RTT */) { chunkChan := make(chan chunk, 100) initialStream := newStream(chunkChan, protocol.EncryptionInitial) handshakeStream := newStream(chunkChan, protocol.EncryptionHandshake) - return chunkChan, initialStream, handshakeStream + oneRTTStream := newStream(chunkChan, protocol.Encryption1RTT) + return chunkChan, initialStream, handshakeStream, oneRTTStream } BeforeEach(func() { @@ -116,7 +117,7 @@ var _ = Describe("Crypto Setup TLS", func() { sErrChan := make(chan error, 1) runner := NewMockHandshakeRunner(mockCtrl) runner.EXPECT().OnError(gomock.Any()).Do(func(e error) { sErrChan <- e }) - _, sInitialStream, sHandshakeStream := initStreams() + _, sInitialStream, sHandshakeStream, _ := initStreams() server, err := NewCryptoSetupServer( sInitialStream, sHandshakeStream, @@ -152,7 +153,7 @@ var _ = Describe("Crypto Setup TLS", func() { It("errors when a message is received at the wrong encryption level", func() { sErrChan := make(chan error, 1) - _, sInitialStream, sHandshakeStream := initStreams() + _, sInitialStream, sHandshakeStream, _ := initStreams() runner := NewMockHandshakeRunner(mockCtrl) runner.EXPECT().OnError(gomock.Any()).Do(func(e error) { sErrChan <- e }) server, err := NewCryptoSetupServer( @@ -192,7 +193,7 @@ var _ = Describe("Crypto Setup TLS", func() { It("returns Handshake() when handling a message fails", func() { sErrChan := make(chan error, 1) - _, sInitialStream, sHandshakeStream := initStreams() + _, sInitialStream, sHandshakeStream, _ := initStreams() runner := NewMockHandshakeRunner(mockCtrl) runner.EXPECT().OnError(gomock.Any()).Do(func(e error) { sErrChan <- e }) server, err := NewCryptoSetupServer( @@ -228,7 +229,7 @@ var _ = Describe("Crypto Setup TLS", func() { }) It("returns Handshake() when it is closed", func() { - _, sInitialStream, sHandshakeStream := initStreams() + _, sInitialStream, sHandshakeStream, _ := initStreams() server, err := NewCryptoSetupServer( sInitialStream, sHandshakeStream, @@ -254,6 +255,8 @@ var _ = Describe("Crypto Setup TLS", func() { }) Context("doing the handshake", func() { + var testDone chan struct{} + generateCert := func() tls.Certificate { priv, err := rsa.GenerateKey(rand.Reader, 2048) Expect(err).ToNot(HaveOccurred()) @@ -273,6 +276,14 @@ var _ = Describe("Crypto Setup TLS", func() { } } + BeforeEach(func() { + testDone = make(chan struct{}) + }) + + AfterEach(func() { + close(testDone) + }) + handshake := func(client CryptoSetup, cChunkChan <-chan chunk, server CryptoSetup, sChunkChan <-chan chunk) { done := make(chan struct{}) @@ -284,7 +295,7 @@ var _ = Describe("Crypto Setup TLS", func() { server.HandleMessage(c.data, c.encLevel) case c := <-sChunkChan: client.HandleMessage(c.data, c.encLevel) - case <-done: // handshake complete + case <-testDone: // handshake complete return } } @@ -300,9 +311,9 @@ var _ = Describe("Crypto Setup TLS", func() { Eventually(done).Should(BeClosed()) } - handshakeWithTLSConf := func(clientConf, serverConf *tls.Config) (error /* client error */, error /* server error */) { + handshakeWithTLSConf := func(clientConf, serverConf *tls.Config) (CryptoSetup /* client */, error /* client error */, CryptoSetup /* server */, error /* server error */) { var cHandshakeComplete bool - cChunkChan, cInitialStream, cHandshakeStream := initStreams() + cChunkChan, cInitialStream, cHandshakeStream, cOneRTTStream := initStreams() cErrChan := make(chan error, 1) cRunner := NewMockHandshakeRunner(mockCtrl) cRunner.EXPECT().OnReceivedParams(gomock.Any()) @@ -311,7 +322,7 @@ var _ = Describe("Crypto Setup TLS", func() { client, _, err := NewCryptoSetupClient( cInitialStream, cHandshakeStream, - ioutil.Discard, + cOneRTTStream, protocol.ConnectionID{}, nil, &TransportParameters{}, @@ -323,7 +334,7 @@ var _ = Describe("Crypto Setup TLS", func() { Expect(err).ToNot(HaveOccurred()) var sHandshakeComplete bool - sChunkChan, sInitialStream, sHandshakeStream := initStreams() + sChunkChan, sInitialStream, sHandshakeStream, sOneRTTStream := initStreams() sErrChan := make(chan error, 1) sRunner := NewMockHandshakeRunner(mockCtrl) sRunner.EXPECT().OnReceivedParams(gomock.Any()) @@ -333,7 +344,7 @@ var _ = Describe("Crypto Setup TLS", func() { server, err := NewCryptoSetupServer( sInitialStream, sHandshakeStream, - ioutil.Discard, + sOneRTTStream, protocol.ConnectionID{}, nil, &TransportParameters{StatelessResetToken: &token}, @@ -356,18 +367,18 @@ var _ = Describe("Crypto Setup TLS", func() { default: Expect(cHandshakeComplete).To(BeTrue()) } - return cErr, sErr + return client, cErr, server, sErr } It("handshakes", func() { - clientErr, serverErr := handshakeWithTLSConf(clientConf, serverConf) + _, clientErr, _, serverErr := handshakeWithTLSConf(clientConf, serverConf) Expect(clientErr).ToNot(HaveOccurred()) Expect(serverErr).ToNot(HaveOccurred()) }) It("performs a HelloRetryRequst", func() { serverConf.CurvePreferences = []tls.CurveID{tls.CurveP384} - clientErr, serverErr := handshakeWithTLSConf(clientConf, serverConf) + _, clientErr, _, serverErr := handshakeWithTLSConf(clientConf, serverConf) Expect(clientErr).ToNot(HaveOccurred()) Expect(serverErr).ToNot(HaveOccurred()) }) @@ -375,14 +386,14 @@ var _ = Describe("Crypto Setup TLS", func() { It("handshakes with client auth", func() { clientConf.Certificates = []tls.Certificate{generateCert()} serverConf.ClientAuth = qtls.RequireAnyClientCert - clientErr, serverErr := handshakeWithTLSConf(clientConf, serverConf) + _, clientErr, _, serverErr := handshakeWithTLSConf(clientConf, serverConf) Expect(clientErr).ToNot(HaveOccurred()) Expect(serverErr).ToNot(HaveOccurred()) }) It("signals when it has written the ClientHello", func() { runner := NewMockHandshakeRunner(mockCtrl) - cChunkChan, cInitialStream, cHandshakeStream := initStreams() + cChunkChan, cInitialStream, cHandshakeStream, _ := initStreams() client, chChan, err := NewCryptoSetupClient( cInitialStream, cHandshakeStream, @@ -419,7 +430,7 @@ var _ = Describe("Crypto Setup TLS", func() { It("receives transport parameters", func() { var cTransportParametersRcvd, sTransportParametersRcvd []byte - cChunkChan, cInitialStream, cHandshakeStream := initStreams() + cChunkChan, cInitialStream, cHandshakeStream, _ := initStreams() cTransportParameters := &TransportParameters{IdleTimeout: 0x42 * time.Second} cRunner := NewMockHandshakeRunner(mockCtrl) cRunner.EXPECT().OnReceivedParams(gomock.Any()).Do(func(b []byte) { sTransportParametersRcvd = b }) @@ -438,7 +449,7 @@ var _ = Describe("Crypto Setup TLS", func() { ) Expect(err).ToNot(HaveOccurred()) - sChunkChan, sInitialStream, sHandshakeStream := initStreams() + sChunkChan, sInitialStream, sHandshakeStream, _ := initStreams() var token [16]byte sRunner := NewMockHandshakeRunner(mockCtrl) sRunner.EXPECT().OnReceivedParams(gomock.Any()).Do(func(b []byte) { cTransportParametersRcvd = b }) @@ -478,116 +489,172 @@ var _ = Describe("Crypto Setup TLS", func() { Expect(srvTP.IdleTimeout).To(Equal(sTransportParameters.IdleTimeout)) }) - It("errors when the NewSessionTicket is sent at the wrong encryption level", func() { - cChunkChan, cInitialStream, cHandshakeStream := initStreams() - cRunner := NewMockHandshakeRunner(mockCtrl) - cRunner.EXPECT().OnReceivedParams(gomock.Any()) - cRunner.EXPECT().OnHandshakeComplete() - client, _, err := NewCryptoSetupClient( - cInitialStream, - cHandshakeStream, - ioutil.Discard, - protocol.ConnectionID{}, - nil, - &TransportParameters{}, - cRunner, - clientConf, - &congestion.RTTStats{}, - utils.DefaultLogger.WithPrefix("client"), - ) - Expect(err).ToNot(HaveOccurred()) + Context("with session tickets", func() { + It("errors when the NewSessionTicket is sent at the wrong encryption level", func() { + cChunkChan, cInitialStream, cHandshakeStream, _ := initStreams() + cRunner := NewMockHandshakeRunner(mockCtrl) + cRunner.EXPECT().OnReceivedParams(gomock.Any()) + cRunner.EXPECT().OnHandshakeComplete() + client, _, err := NewCryptoSetupClient( + cInitialStream, + cHandshakeStream, + ioutil.Discard, + protocol.ConnectionID{}, + nil, + &TransportParameters{}, + cRunner, + clientConf, + &congestion.RTTStats{}, + utils.DefaultLogger.WithPrefix("client"), + ) + Expect(err).ToNot(HaveOccurred()) - sChunkChan, sInitialStream, sHandshakeStream := initStreams() - sRunner := NewMockHandshakeRunner(mockCtrl) - sRunner.EXPECT().OnReceivedParams(gomock.Any()) - sRunner.EXPECT().OnHandshakeComplete() - server, err := NewCryptoSetupServer( - sInitialStream, - sHandshakeStream, - ioutil.Discard, - protocol.ConnectionID{}, - nil, - &TransportParameters{}, - sRunner, - serverConf, - &congestion.RTTStats{}, - utils.DefaultLogger.WithPrefix("server"), - ) - Expect(err).ToNot(HaveOccurred()) + sChunkChan, sInitialStream, sHandshakeStream, _ := initStreams() + sRunner := NewMockHandshakeRunner(mockCtrl) + sRunner.EXPECT().OnReceivedParams(gomock.Any()) + sRunner.EXPECT().OnHandshakeComplete() + server, err := NewCryptoSetupServer( + sInitialStream, + sHandshakeStream, + ioutil.Discard, + protocol.ConnectionID{}, + nil, + &TransportParameters{}, + sRunner, + serverConf, + &congestion.RTTStats{}, + utils.DefaultLogger.WithPrefix("server"), + ) + Expect(err).ToNot(HaveOccurred()) - done := make(chan struct{}) - go func() { - defer GinkgoRecover() - handshake(client, cChunkChan, server, sChunkChan) - close(done) - }() - Eventually(done).Should(BeClosed()) + done := make(chan struct{}) + go func() { + defer GinkgoRecover() + handshake(client, cChunkChan, server, sChunkChan) + close(done) + }() + Eventually(done).Should(BeClosed()) - // inject an invalid session ticket - cRunner.EXPECT().OnError(gomock.Any()).Do(func(err error) { - Expect(err).To(BeAssignableToTypeOf(&qerr.QuicError{})) - qerr := err.(*qerr.QuicError) - Expect(qerr.IsCryptoError()).To(BeTrue()) - Expect(qerr.ErrorCode).To(BeEquivalentTo(0x100 + int(alertUnexpectedMessage))) - Expect(qerr.Error()).To(ContainSubstring("expected handshake message NewSessionTicket to have encryption level 1-RTT, has Handshake")) + // inject an invalid session ticket + cRunner.EXPECT().OnError(gomock.Any()).Do(func(err error) { + Expect(err).To(BeAssignableToTypeOf(&qerr.QuicError{})) + qerr := err.(*qerr.QuicError) + Expect(qerr.IsCryptoError()).To(BeTrue()) + Expect(qerr.ErrorCode).To(BeEquivalentTo(0x100 + int(alertUnexpectedMessage))) + Expect(qerr.Error()).To(ContainSubstring("expected handshake message NewSessionTicket to have encryption level 1-RTT, has Handshake")) + }) + b := append([]byte{uint8(typeNewSessionTicket), 0, 0, 6}, []byte("foobar")...) + client.HandleMessage(b, protocol.EncryptionHandshake) }) - b := append([]byte{uint8(typeNewSessionTicket), 0, 0, 6}, []byte("foobar")...) - client.HandleMessage(b, protocol.EncryptionHandshake) - }) - It("errors when handling the NewSessionTicket fails", func() { - cChunkChan, cInitialStream, cHandshakeStream := initStreams() - cRunner := NewMockHandshakeRunner(mockCtrl) - cRunner.EXPECT().OnReceivedParams(gomock.Any()) - cRunner.EXPECT().OnHandshakeComplete() - client, _, err := NewCryptoSetupClient( - cInitialStream, - cHandshakeStream, - ioutil.Discard, - protocol.ConnectionID{}, - nil, - &TransportParameters{}, - cRunner, - clientConf, - &congestion.RTTStats{}, - utils.DefaultLogger.WithPrefix("client"), - ) - Expect(err).ToNot(HaveOccurred()) + It("errors when handling the NewSessionTicket fails", func() { + cChunkChan, cInitialStream, cHandshakeStream, _ := initStreams() + cRunner := NewMockHandshakeRunner(mockCtrl) + cRunner.EXPECT().OnReceivedParams(gomock.Any()) + cRunner.EXPECT().OnHandshakeComplete() + client, _, err := NewCryptoSetupClient( + cInitialStream, + cHandshakeStream, + ioutil.Discard, + protocol.ConnectionID{}, + nil, + &TransportParameters{}, + cRunner, + clientConf, + &congestion.RTTStats{}, + utils.DefaultLogger.WithPrefix("client"), + ) + Expect(err).ToNot(HaveOccurred()) - sChunkChan, sInitialStream, sHandshakeStream := initStreams() - sRunner := NewMockHandshakeRunner(mockCtrl) - sRunner.EXPECT().OnReceivedParams(gomock.Any()) - sRunner.EXPECT().OnHandshakeComplete() - server, err := NewCryptoSetupServer( - sInitialStream, - sHandshakeStream, - ioutil.Discard, - protocol.ConnectionID{}, - nil, - &TransportParameters{}, - sRunner, - serverConf, - &congestion.RTTStats{}, - utils.DefaultLogger.WithPrefix("server"), - ) - Expect(err).ToNot(HaveOccurred()) + sChunkChan, sInitialStream, sHandshakeStream, _ := initStreams() + sRunner := NewMockHandshakeRunner(mockCtrl) + sRunner.EXPECT().OnReceivedParams(gomock.Any()) + sRunner.EXPECT().OnHandshakeComplete() + server, err := NewCryptoSetupServer( + sInitialStream, + sHandshakeStream, + ioutil.Discard, + protocol.ConnectionID{}, + nil, + &TransportParameters{}, + sRunner, + serverConf, + &congestion.RTTStats{}, + utils.DefaultLogger.WithPrefix("server"), + ) + Expect(err).ToNot(HaveOccurred()) - done := make(chan struct{}) - go func() { - defer GinkgoRecover() - handshake(client, cChunkChan, server, sChunkChan) - close(done) - }() - Eventually(done).Should(BeClosed()) + done := make(chan struct{}) + go func() { + defer GinkgoRecover() + handshake(client, cChunkChan, server, sChunkChan) + close(done) + }() + Eventually(done).Should(BeClosed()) - // inject an invalid session ticket - cRunner.EXPECT().OnError(gomock.Any()).Do(func(err error) { - Expect(err).To(BeAssignableToTypeOf(&qerr.QuicError{})) - qerr := err.(*qerr.QuicError) - Expect(qerr.IsCryptoError()).To(BeTrue()) + // inject an invalid session ticket + cRunner.EXPECT().OnError(gomock.Any()).Do(func(err error) { + Expect(err).To(BeAssignableToTypeOf(&qerr.QuicError{})) + qerr := err.(*qerr.QuicError) + Expect(qerr.IsCryptoError()).To(BeTrue()) + }) + b := append([]byte{uint8(typeNewSessionTicket), 0, 0, 6}, []byte("foobar")...) + client.HandleMessage(b, protocol.Encryption1RTT) + }) + + It("uses session resumption", func() { + csc := NewMockClientSessionCache(mockCtrl) + var state *tls.ClientSessionState + receivedSessionTicket := make(chan struct{}) + csc.EXPECT().Get(gomock.Any()) + csc.EXPECT().Put(gomock.Any(), gomock.Any()).Do(func(_ string, css *tls.ClientSessionState) { + state = css + close(receivedSessionTicket) + }) + clientConf.ClientSessionCache = csc + client, clientErr, server, serverErr := handshakeWithTLSConf(clientConf, serverConf) + Expect(clientErr).ToNot(HaveOccurred()) + Expect(serverErr).ToNot(HaveOccurred()) + Eventually(receivedSessionTicket).Should(BeClosed()) + Expect(server.ConnectionState().DidResume).To(BeFalse()) + Expect(client.ConnectionState().DidResume).To(BeFalse()) + + csc.EXPECT().Get(gomock.Any()).Return(state, true) + csc.EXPECT().Put(gomock.Any(), gomock.Any()).MaxTimes(1) + client, clientErr, server, serverErr = handshakeWithTLSConf(clientConf, serverConf) + Expect(clientErr).ToNot(HaveOccurred()) + Expect(serverErr).ToNot(HaveOccurred()) + Eventually(receivedSessionTicket).Should(BeClosed()) + Expect(server.ConnectionState().DidResume).To(BeTrue()) + Expect(client.ConnectionState().DidResume).To(BeTrue()) + }) + + It("doesn't use session resumption if the server disabled it", func() { + csc := NewMockClientSessionCache(mockCtrl) + var state *tls.ClientSessionState + receivedSessionTicket := make(chan struct{}) + csc.EXPECT().Get(gomock.Any()) + csc.EXPECT().Put(gomock.Any(), gomock.Any()).Do(func(_ string, css *tls.ClientSessionState) { + state = css + close(receivedSessionTicket) + }) + clientConf.ClientSessionCache = csc + client, clientErr, server, serverErr := handshakeWithTLSConf(clientConf, serverConf) + Expect(clientErr).ToNot(HaveOccurred()) + Expect(serverErr).ToNot(HaveOccurred()) + Eventually(receivedSessionTicket).Should(BeClosed()) + Expect(server.ConnectionState().DidResume).To(BeFalse()) + Expect(client.ConnectionState().DidResume).To(BeFalse()) + + serverConf.SessionTicketsDisabled = true + csc.EXPECT().Get(gomock.Any()).Return(state, true) + client, clientErr, server, serverErr = handshakeWithTLSConf(clientConf, serverConf) + Expect(clientErr).ToNot(HaveOccurred()) + Expect(serverErr).ToNot(HaveOccurred()) + Eventually(receivedSessionTicket).Should(BeClosed()) + Expect(server.ConnectionState().DidResume).To(BeFalse()) + Expect(client.ConnectionState().DidResume).To(BeFalse()) }) - b := append([]byte{uint8(typeNewSessionTicket), 0, 0, 6}, []byte("foobar")...) - client.HandleMessage(b, protocol.Encryption1RTT) }) }) })