From d309060cded7aea6c3620c1238095e9a72175ecc Mon Sep 17 00:00:00 2001 From: Marten Seemann Date: Fri, 27 Oct 2023 12:22:20 +0700 Subject: [PATCH] handshake: clone the tls.Config returned by GetConfigForClient (#4133) We modify this tls.Config, so we should clone it first. Otherwise, this could cause conflicts with how the application is using that config. --- internal/handshake/crypto_setup.go | 1 + internal/handshake/crypto_setup_test.go | 21 +++++++++++++-------- 2 files changed, 14 insertions(+), 8 deletions(-) diff --git a/internal/handshake/crypto_setup.go b/internal/handshake/crypto_setup.go index 7090fcca..e15aad54 100644 --- a/internal/handshake/crypto_setup.go +++ b/internal/handshake/crypto_setup.go @@ -147,6 +147,7 @@ func addConnToClientHelloInfo(conf *tls.Config, localAddr, remoteAddr net.Addr) info.Conn = &conn{localAddr: localAddr, remoteAddr: remoteAddr} c, err := gcfc(info) if c != nil { + c = c.Clone() // We're returning a tls.Config here, so we need to apply this recursively. addConnToClientHelloInfo(c, localAddr, remoteAddr) } diff --git a/internal/handshake/crypto_setup_test.go b/internal/handshake/crypto_setup_test.go index b32d12a3..21278339 100644 --- a/internal/handshake/crypto_setup_test.go +++ b/internal/handshake/crypto_setup_test.go @@ -8,6 +8,7 @@ import ( "crypto/x509/pkix" "math/big" "net" + "reflect" "runtime" "strings" "time" @@ -148,15 +149,17 @@ var _ = Describe("Crypto Setup TLS", func() { It("wraps GetConfigForClient, recursively", func() { var localAddr, remoteAddr net.Addr tlsConf := &tls.Config{} + var innerConf *tls.Config + getCert := func(info *tls.ClientHelloInfo) (*tls.Certificate, error) { //nolint:unparam + localAddr = info.Conn.LocalAddr() + remoteAddr = info.Conn.RemoteAddr() + cert := generateCert() + return &cert, nil + } tlsConf.GetConfigForClient = func(info *tls.ClientHelloInfo) (*tls.Config, error) { - conf := tlsConf.Clone() - conf.GetCertificate = func(info *tls.ClientHelloInfo) (*tls.Certificate, error) { - localAddr = info.Conn.LocalAddr() - remoteAddr = info.Conn.RemoteAddr() - cert := generateCert() - return &cert, nil - } - return conf, nil + innerConf = tlsConf.Clone() + innerConf.GetCertificate = getCert + return innerConf, nil } addConnToClientHelloInfo(tlsConf, local, remote) conf, err := tlsConf.GetConfigForClient(&tls.ClientHelloInfo{}) @@ -165,6 +168,8 @@ var _ = Describe("Crypto Setup TLS", func() { Expect(err).ToNot(HaveOccurred()) Expect(localAddr).To(Equal(local)) Expect(remoteAddr).To(Equal(remote)) + // make sure that the tls.Config returned by GetConfigForClient isn't modified + Expect(reflect.ValueOf(innerConf.GetCertificate).Pointer() == reflect.ValueOf(getCert).Pointer()).To(BeTrue()) }) })