diff --git a/internal/handshake/crypto_setup.go b/internal/handshake/crypto_setup.go index 1ad306be0..601a23fe0 100644 --- a/internal/handshake/crypto_setup.go +++ b/internal/handshake/crypto_setup.go @@ -389,7 +389,8 @@ func (h *cryptoSetup) handleMessageForServer(msgType messageType) bool { } return true default: - panic("unexpected handshake message") + h.messageErrChan <- qerr.CryptoError(alertUnexpectedMessage, fmt.Sprintf("unexpected handshake message: %d", msgType)) + return false } } @@ -445,7 +446,8 @@ func (h *cryptoSetup) handleMessageForClient(msgType messageType) bool { h.conn.HandlePostHandshakeMessage() return false default: - panic("unexpected handshake message: ") + h.messageErrChan <- qerr.CryptoError(alertUnexpectedMessage, fmt.Sprintf("unexpected handshake message: %d", msgType)) + return false } } diff --git a/internal/handshake/crypto_setup_test.go b/internal/handshake/crypto_setup_test.go index cd807a8dd..2a988e1f8 100644 --- a/internal/handshake/crypto_setup_test.go +++ b/internal/handshake/crypto_setup_test.go @@ -142,7 +142,7 @@ var _ = Describe("Crypto Setup TLS", func() { Eventually(done).Should(BeClosed()) }) - It("returns Handshake() when handling a message fails", func() { + It("returns Handshake() when a message is received at the wrong encryption level", func() { _, sInitialStream, sHandshakeStream := initStreams() server, err := NewCryptoSetupServer( sInitialStream, @@ -175,6 +175,39 @@ var _ = Describe("Crypto Setup TLS", func() { Eventually(done).Should(BeClosed()) }) + It("returns Handshake() when handling a message fails", func() { + _, sInitialStream, sHandshakeStream := initStreams() + server, err := NewCryptoSetupServer( + sInitialStream, + sHandshakeStream, + ioutil.Discard, + protocol.ConnectionID{}, + nil, + &TransportParameters{}, + func([]byte) {}, + func(protocol.EncryptionLevel) {}, + testdata.GetTLSConfig(), + utils.DefaultLogger.WithPrefix("server"), + ) + Expect(err).ToNot(HaveOccurred()) + + done := make(chan struct{}) + go func() { + defer GinkgoRecover() + err := server.RunHandshake() + Expect(err).To(BeAssignableToTypeOf(&qerr.QuicError{})) + qerr := err.(*qerr.QuicError) + Expect(qerr.IsCryptoError()).To(BeTrue()) + Expect(qerr.ErrorCode).To(BeEquivalentTo(0x100 + int(alertUnexpectedMessage))) + Expect(err.Error()).To(ContainSubstring("unexpected handshake message")) + close(done) + }() + + fakeCH := append([]byte{byte(typeServerHello), 0, 0, 6}, []byte("foobar")...) + server.HandleMessage(fakeCH, protocol.EncryptionInitial) // wrong encryption level + Eventually(done).Should(BeClosed()) + }) + It("returns Handshake() when it is closed", func() { _, sInitialStream, sHandshakeStream := initStreams() server, err := NewCryptoSetupServer(