From 17634d2fe5d1179c15e914e056fced07ad90c395 Mon Sep 17 00:00:00 2001 From: Marten Seemann Date: Sat, 1 Jun 2019 15:29:20 +0800 Subject: [PATCH] error when receiving a post-handshake message with wrong encryption level --- internal/handshake/crypto_setup.go | 52 +++++++++------- internal/handshake/crypto_setup_test.go | 79 +++++++++++++++++++++---- 2 files changed, 96 insertions(+), 35 deletions(-) diff --git a/internal/handshake/crypto_setup.go b/internal/handshake/crypto_setup.go index 91b076a97..1d343f455 100644 --- a/internal/handshake/crypto_setup.go +++ b/internal/handshake/crypto_setup.go @@ -76,9 +76,8 @@ type cryptoSetup struct { runner handshakeRunner + closed bool alertChan chan uint8 - // HandleData() sends errors on the messageErrChan - messageErrChan chan error // handshakeDone is closed as soon as the go routine running qtls.Handshake() returns handshakeDone chan struct{} // is closed when Close() is called @@ -211,7 +210,6 @@ func newCryptoSetup( perspective: perspective, handshakeDone: make(chan struct{}), alertChan: make(chan uint8), - messageErrChan: make(chan error, 1), clientHelloWrittenChan: make(chan struct{}), messageChan: make(chan []byte, 100), receivedReadKey: make(chan struct{}), @@ -269,22 +267,29 @@ func (h *cryptoSetup) RunHandshake() { case <-handshakeComplete: // return when the handshake is done h.runner.OnHandshakeComplete() case <-h.closeChan: - // wait until the Handshake() go routine has returned close(h.messageChan) + // wait until the Handshake() go routine has returned + <-h.handshakeDone case alert := <-h.alertChan: handshakeErr := <-handshakeErrChan - h.runner.OnError(qerr.CryptoError(alert, handshakeErr.Error())) - case err := <-h.messageErrChan: - // If the handshake errored because of an error that occurred during HandleData(), - // that error message will be more useful than the error message generated by Handshake(). - // Close the message chan that qtls is receiving messages from. - // This will make qtls.Handshake() return. - close(h.messageChan) - h.runner.OnError(err) + h.onError(alert, handshakeErr.Error()) } } +func (h *cryptoSetup) onError(alert uint8, message string) { + + h.runner.OnError(qerr.CryptoError(alert, message)) +} + func (h *cryptoSetup) Close() error { + h.mutex.Lock() + defer h.mutex.Unlock() + + if h.closed { + return nil + } + h.closed = true + close(h.closeChan) // wait until qtls.Handshake() actually returned <-h.handshakeDone @@ -298,10 +303,13 @@ func (h *cryptoSetup) HandleMessage(data []byte, encLevel protocol.EncryptionLev msgType := messageType(data[0]) h.logger.Debugf("Received %s message (%d bytes, encryption level: %s)", msgType, len(data), encLevel) if err := h.checkEncryptionLevel(msgType, encLevel); err != nil { - h.messageErrChan <- err + h.onError(alertUnexpectedMessage, err.Error()) return false } h.messageChan <- data + if encLevel == protocol.Encryption1RTT { + h.handlePostHandshakeMessage(data) + } switch h.perspective { case protocol.PerspectiveClient: return h.handleMessageForClient(msgType) @@ -327,10 +335,10 @@ func (h *cryptoSetup) checkEncryptionLevel(msgType messageType, encLevel protoco case typeNewSessionTicket: expected = protocol.Encryption1RTT default: - return qerr.CryptoError(alertUnexpectedMessage, fmt.Sprintf("unexpected handshake message: %d", msgType)) + return fmt.Errorf("unexpected handshake message: %d", msgType) } if encLevel != expected { - return qerr.CryptoError(alertUnexpectedMessage, fmt.Sprintf("expected handshake message %s to have encryption level %s, has %s", msgType, expected, encLevel)) + return fmt.Errorf("expected handshake message %s to have encryption level %s, has %s", msgType, expected, encLevel) } return nil } @@ -380,7 +388,7 @@ func (h *cryptoSetup) handleMessageForServer(msgType messageType) bool { } return true default: - h.messageErrChan <- qerr.CryptoError(alertUnexpectedMessage, fmt.Sprintf("unexpected handshake message: %d", msgType)) + // unexpected message return false } } @@ -432,17 +440,15 @@ func (h *cryptoSetup) handleMessageForClient(msgType messageType) bool { return false } return true - case typeNewSessionTicket: - <-h.handshakeDone // don't process session tickets before the handshake has completed - h.handleNewSessionTicket() - return false default: - h.messageErrChan <- qerr.CryptoError(alertUnexpectedMessage, fmt.Sprintf("unexpected handshake message: %d", msgType)) return false } } -func (h *cryptoSetup) handleNewSessionTicket() { +func (h *cryptoSetup) handlePostHandshakeMessage(data []byte) { + // make sure the handshake has already completed + <-h.handshakeDone + done := make(chan struct{}) defer close(done) @@ -460,7 +466,7 @@ func (h *cryptoSetup) handleNewSessionTicket() { }() if err := h.conn.HandlePostHandshakeMessage(); err != nil { - h.runner.OnError(qerr.CryptoError(<-alertChan, err.Error())) + h.onError(<-alertChan, err.Error()) } } diff --git a/internal/handshake/crypto_setup_test.go b/internal/handshake/crypto_setup_test.go index f5549e676..1a23a2ba3 100644 --- a/internal/handshake/crypto_setup_test.go +++ b/internal/handshake/crypto_setup_test.go @@ -107,7 +107,7 @@ var _ = Describe("Crypto Setup TLS", func() { Expect(cconf.ReceivedExtensions).ToNot(BeNil()) }) - It("returns Handshake() when an error occurs", func() { + It("returns Handshake() when an error occurs in qtls", func() { sErrChan := make(chan error, 1) runner := NewMockHandshakeRunner(mockCtrl) runner.EXPECT().OnError(gomock.Any()).Do(func(e error) { sErrChan <- e }) @@ -144,7 +144,7 @@ var _ = Describe("Crypto Setup TLS", func() { Eventually(done).Should(BeClosed()) }) - It("returns Handshake() when a message is received at the wrong encryption level", func() { + It("errors when a message is received at the wrong encryption level", func() { sErrChan := make(chan error, 1) _, sInitialStream, sHandshakeStream := initStreams() runner := NewMockHandshakeRunner(mockCtrl) @@ -166,18 +166,20 @@ var _ = Describe("Crypto Setup TLS", func() { go func() { defer GinkgoRecover() server.RunHandshake() - var err error - Expect(sErrChan).To(Receive(&err)) - Expect(err).To(BeAssignableToTypeOf(&qerr.QuicError{})) - qerr := err.(*qerr.QuicError) - Expect(qerr.IsCryptoError()).To(BeTrue()) - Expect(qerr.ErrorCode).To(BeEquivalentTo(0x100 + int(alertUnexpectedMessage))) - Expect(err.Error()).To(ContainSubstring("expected handshake message ClientHello to have encryption level Initial, has Handshake")) close(done) }() fakeCH := append([]byte{byte(typeClientHello), 0, 0, 6}, []byte("foobar")...) server.HandleMessage(fakeCH, protocol.EncryptionHandshake) // wrong encryption level + Expect(sErrChan).To(Receive(&err)) + Expect(err).To(BeAssignableToTypeOf(&qerr.QuicError{})) + qerr := err.(*qerr.QuicError) + Expect(qerr.IsCryptoError()).To(BeTrue()) + Expect(qerr.ErrorCode).To(BeEquivalentTo(0x100 + int(alertUnexpectedMessage))) + Expect(err.Error()).To(ContainSubstring("expected handshake message ClientHello to have encryption level Initial, has Handshake")) + + // make the go routine return + Expect(server.Close()).To(Succeed()) Eventually(done).Should(BeClosed()) }) @@ -209,7 +211,6 @@ var _ = Describe("Crypto Setup TLS", func() { qerr := err.(*qerr.QuicError) Expect(qerr.IsCryptoError()).To(BeTrue()) Expect(qerr.ErrorCode).To(BeEquivalentTo(0x100 + int(alertUnexpectedMessage))) - Expect(err.Error()).To(ContainSubstring("unexpected handshake message")) close(done) }() @@ -403,8 +404,7 @@ var _ = Describe("Crypto Setup TLS", func() { Expect(len(ch.data) - 4).To(Equal(length)) // make the go routine return - runner.EXPECT().OnError(gomock.Any()) - client.HandleMessage([]byte{42 /* unknown handshake message type */, 0, 0, 1, 0}, protocol.EncryptionInitial) + Expect(client.Close()).To(Succeed()) Eventually(done).Should(BeClosed()) }) @@ -467,6 +467,61 @@ var _ = Describe("Crypto Setup TLS", func() { Expect(srvTP.IdleTimeout).To(Equal(sTransportParameters.IdleTimeout)) }) + It("errors when the NewSessionTicket is sent at the wrong encryption level", func() { + cChunkChan, cInitialStream, cHandshakeStream := initStreams() + cRunner := NewMockHandshakeRunner(mockCtrl) + cRunner.EXPECT().OnReceivedParams(gomock.Any()) + cRunner.EXPECT().OnHandshakeComplete() + client, _, err := NewCryptoSetupClient( + cInitialStream, + cHandshakeStream, + ioutil.Discard, + protocol.ConnectionID{}, + nil, + &TransportParameters{}, + cRunner, + clientConf, + utils.DefaultLogger.WithPrefix("client"), + ) + Expect(err).ToNot(HaveOccurred()) + + sChunkChan, sInitialStream, sHandshakeStream := initStreams() + sRunner := NewMockHandshakeRunner(mockCtrl) + sRunner.EXPECT().OnReceivedParams(gomock.Any()) + sRunner.EXPECT().OnHandshakeComplete() + server, err := NewCryptoSetupServer( + sInitialStream, + sHandshakeStream, + ioutil.Discard, + protocol.ConnectionID{}, + nil, + &TransportParameters{}, + sRunner, + testdata.GetTLSConfig(), + utils.DefaultLogger.WithPrefix("server"), + ) + Expect(err).ToNot(HaveOccurred()) + + done := make(chan struct{}) + go func() { + defer GinkgoRecover() + handshake(client, cChunkChan, server, sChunkChan) + close(done) + }() + Eventually(done).Should(BeClosed()) + + // inject an invalid session ticket + cRunner.EXPECT().OnError(gomock.Any()).Do(func(err error) { + Expect(err).To(BeAssignableToTypeOf(&qerr.QuicError{})) + qerr := err.(*qerr.QuicError) + Expect(qerr.IsCryptoError()).To(BeTrue()) + Expect(qerr.ErrorCode).To(BeEquivalentTo(0x100 + int(alertUnexpectedMessage))) + Expect(qerr.Error()).To(ContainSubstring("expected handshake message NewSessionTicket to have encryption level 1-RTT, has Handshake")) + }) + b := append([]byte{uint8(typeNewSessionTicket), 0, 0, 6}, []byte("foobar")...) + client.HandleMessage(b, protocol.EncryptionHandshake) + }) + It("errors when handling the NewSessionTicket fails", func() { cChunkChan, cInitialStream, cHandshakeStream := initStreams() cRunner := NewMockHandshakeRunner(mockCtrl)