diff --git a/crypto/cert_chain.go b/crypto/cert_chain.go index 69c00d5f..db7b5c35 100644 --- a/crypto/cert_chain.go +++ b/crypto/cert_chain.go @@ -57,6 +57,15 @@ func (c *certChain) GetLeafCert(sni string) ([]byte, error) { func (cc *certChain) getCertForSNI(sni string) (*tls.Certificate, error) { c := cc.config + if c.GetConfigForClient != nil { + var err error + c, err = c.GetConfigForClient(&tls.ClientHelloInfo{ + ServerName: sni, + }) + if err != nil { + return nil, err + } + } // The rest of this function is mostly copied from crypto/tls.getCertificate if c.GetCertificate != nil { diff --git a/crypto/cert_chain_test.go b/crypto/cert_chain_test.go index 8216d8d3..d78f7982 100644 --- a/crypto/cert_chain_test.go +++ b/crypto/cert_chain_test.go @@ -127,5 +127,16 @@ var _ = Describe("Proof", func() { _, err := cc.GetLeafCert("invalid domain") Expect(err).To(MatchError(errNoMatchingCertificate)) }) + + It("respects GetConfigForClient", func() { + nestedConfig := &tls.Config{Certificates: []tls.Certificate{cert}} + config.GetConfigForClient = func(chi *tls.ClientHelloInfo) (*tls.Config, error) { + Expect(chi.ServerName).To(Equal("quic.clemente.io")) + return nestedConfig, nil + } + resultCert, err := cc.getCertForSNI("quic.clemente.io") + Expect(err).NotTo(HaveOccurred()) + Expect(*resultCert).To(Equal(cert)) + }) }) })