From 723f86c725c75463724d54f3db46f72a50f16ea1 Mon Sep 17 00:00:00 2001 From: Lucas Clemente Date: Thu, 2 Mar 2017 10:40:20 +0100 Subject: [PATCH] Don't use GetConfigForClient on go < 1.8 --- crypto/cert_chain.go | 11 +++-------- crypto/cert_chain_test.go | 8 +++++++- crypto/config_for_client_1.8.go | 14 ++++++++++++++ crypto/config_for_client_pre1.8.go | 9 +++++++++ 4 files changed, 33 insertions(+), 9 deletions(-) create mode 100644 crypto/config_for_client_1.8.go create mode 100644 crypto/config_for_client_pre1.8.go diff --git a/crypto/cert_chain.go b/crypto/cert_chain.go index db7b5c35..96d0ecd0 100644 --- a/crypto/cert_chain.go +++ b/crypto/cert_chain.go @@ -57,14 +57,9 @@ func (c *certChain) GetLeafCert(sni string) ([]byte, error) { func (cc *certChain) getCertForSNI(sni string) (*tls.Certificate, error) { c := cc.config - if c.GetConfigForClient != nil { - var err error - c, err = c.GetConfigForClient(&tls.ClientHelloInfo{ - ServerName: sni, - }) - if err != nil { - return nil, err - } + c, err := maybeGetConfigForClient(c, sni) + if err != nil { + return nil, err } // The rest of this function is mostly copied from crypto/tls.getCertificate diff --git a/crypto/cert_chain_test.go b/crypto/cert_chain_test.go index d78f7982..a4ee57a3 100644 --- a/crypto/cert_chain_test.go +++ b/crypto/cert_chain_test.go @@ -5,6 +5,7 @@ import ( "compress/flate" "compress/zlib" "crypto/tls" + "reflect" "github.com/lucas-clemente/quic-go/testdata" @@ -129,11 +130,16 @@ var _ = Describe("Proof", func() { }) It("respects GetConfigForClient", func() { + if !reflect.ValueOf(tls.Config{}).FieldByName("GetConfigForClient").IsValid() { + // Pre 1.8, we don't have to do anything + return + } nestedConfig := &tls.Config{Certificates: []tls.Certificate{cert}} - config.GetConfigForClient = func(chi *tls.ClientHelloInfo) (*tls.Config, error) { + l := func(chi *tls.ClientHelloInfo) (*tls.Config, error) { Expect(chi.ServerName).To(Equal("quic.clemente.io")) return nestedConfig, nil } + reflect.ValueOf(config).Elem().FieldByName("GetConfigForClient").Set(reflect.ValueOf(l)) resultCert, err := cc.getCertForSNI("quic.clemente.io") Expect(err).NotTo(HaveOccurred()) Expect(*resultCert).To(Equal(cert)) diff --git a/crypto/config_for_client_1.8.go b/crypto/config_for_client_1.8.go new file mode 100644 index 00000000..452fe0e0 --- /dev/null +++ b/crypto/config_for_client_1.8.go @@ -0,0 +1,14 @@ +// +build go1.8 + +package crypto + +import "crypto/tls" + +func maybeGetConfigForClient(c *tls.Config, sni string) (*tls.Config, error) { + if c.GetConfigForClient == nil { + return c, nil + } + return c.GetConfigForClient(&tls.ClientHelloInfo{ + ServerName: sni, + }) +} diff --git a/crypto/config_for_client_pre1.8.go b/crypto/config_for_client_pre1.8.go new file mode 100644 index 00000000..612b94a1 --- /dev/null +++ b/crypto/config_for_client_pre1.8.go @@ -0,0 +1,9 @@ +// +build !go1.8 + +package crypto + +import "crypto/tls" + +func maybeGetConfigForClient(c *tls.Config, sni string) (*tls.Config, error) { + return c, nil +}