diff --git a/integrationtests/self/handshake_test.go b/integrationtests/self/handshake_test.go index 0df8f993..e4a53991 100644 --- a/integrationtests/self/handshake_test.go +++ b/integrationtests/self/handshake_test.go @@ -305,6 +305,28 @@ var _ = Describe("Handshake tests", func() { checkContextFromChan(tracerContextChan, false) }) + It("fails the handshake when tls.Config.GetConfigForClient errors", func() { + laddr, err := net.ResolveUDPAddr("udp", "localhost:0") + Expect(err).ToNot(HaveOccurred()) + udpConn, err := net.ListenUDP("udp", laddr) + Expect(err).ToNot(HaveOccurred()) + tr := &quic.Transport{Conn: udpConn} + addTracer(tr) + defer tr.Close() + tlsConf := &tls.Config{} + tlsConf.GetConfigForClient = func(info *tls.ClientHelloInfo) (*tls.Config, error) { + return nil, errors.New("nope") + } + ln, err := tr.Listen(tlsConf, getQuicConfig(nil)) + Expect(err).ToNot(HaveOccurred()) + defer ln.Close() + + _, err = quic.DialAddr(context.Background(), ln.Addr().String(), getTLSClientConfig(), getQuicConfig(nil)) + var transportErr *quic.TransportError + Expect(errors.As(err, &transportErr)).To(BeTrue()) + Expect(transportErr.ErrorCode.IsCryptoError()).To(BeTrue()) + }) + Context("using different cipher suites", func() { for n, id := range map[string]uint16{ "TLS_AES_128_GCM_SHA256": tls.TLS_AES_128_GCM_SHA256, @@ -881,10 +903,10 @@ var _ = Describe("Handshake tests", func() { tlsConf, nil, ) - Expect(err).To(MatchError(&qerr.TransportError{ - ErrorCode: qerr.InternalError, - ErrorMessage: "tls: invalid NextProtos value", - })) + var transportErr *quic.TransportError + Expect(errors.As(err, &transportErr)).To(BeTrue()) + Expect(transportErr.ErrorCode.IsCryptoError()).To(BeTrue()) + Expect(err.Error()).To(ContainSubstring("tls: invalid NextProtos value")) Consistently(packetChan).ShouldNot(Receive()) ln.Close() Eventually(done).Should(BeClosed()) diff --git a/internal/handshake/crypto_setup.go b/internal/handshake/crypto_setup.go index 9347eef6..0fb75dc8 100644 --- a/internal/handshake/crypto_setup.go +++ b/internal/handshake/crypto_setup.go @@ -624,8 +624,7 @@ func (h *cryptoSetup) ConnectionState() ConnectionState { } func wrapError(err error) error { - // alert 80 is an internal error - if alertErr := tls.AlertError(0); errors.As(err, &alertErr) && alertErr != 80 { + if alertErr := tls.AlertError(0); errors.As(err, &alertErr) { return qerr.NewLocalCryptoError(uint8(alertErr), err) } return &qerr.TransportError{ErrorCode: qerr.InternalError, ErrorMessage: err.Error()} diff --git a/internal/handshake/crypto_setup_test.go b/internal/handshake/crypto_setup_test.go index 140238d9..18edceb1 100644 --- a/internal/handshake/crypto_setup_test.go +++ b/internal/handshake/crypto_setup_test.go @@ -7,6 +7,7 @@ import ( "crypto/tls" "crypto/x509" "crypto/x509/pkix" + "errors" "math/big" "net" "time" @@ -96,10 +97,11 @@ var _ = Describe("Crypto Setup TLS", func() { protocol.Version1, ) - Expect(cl.StartHandshake(context.Background())).To(MatchError(&qerr.TransportError{ - ErrorCode: qerr.InternalError, - ErrorMessage: "tls: invalid NextProtos value", - })) + var terr *qerr.TransportError + err := cl.StartHandshake(context.Background()) + Expect(errors.As(err, &terr)).To(BeTrue()) + Expect(terr.ErrorCode).To(BeEquivalentTo(0x100 + 0x50)) + Expect(err.Error()).To(ContainSubstring("tls: invalid NextProtos value")) }) It("errors when a message is received at the wrong encryption level", func() {