From 32cf5e4129338a850cfa3eea03fe6b996d92a6ef Mon Sep 17 00:00:00 2001 From: Lucas Clemente Date: Sun, 8 May 2016 22:42:11 +0200 Subject: [PATCH] read tls.Config properly in RSA signer --- crypto/proof_rsa.go | 33 +++++++++++++++++++++--- crypto/proof_rsa_test.go | 54 ++++++++++++++++++++++++++++++++++++++++ 2 files changed, 84 insertions(+), 3 deletions(-) diff --git a/crypto/proof_rsa.go b/crypto/proof_rsa.go index 7b971417a..7211d6ad4 100644 --- a/crypto/proof_rsa.go +++ b/crypto/proof_rsa.go @@ -11,14 +11,16 @@ import ( "crypto/tls" "crypto/x509" "errors" + "strings" "github.com/lucas-clemente/quic-go/utils" ) // rsaSigner stores a key and a certificate for the server proof type rsaSigner struct { - key *rsa.PrivateKey - cert *x509.Certificate + key *rsa.PrivateKey + cert *x509.Certificate + config *tls.Config } // NewRSASigner loads the key and cert from files @@ -38,7 +40,7 @@ func NewRSASigner(tlsConfig *tls.Config) (Signer, error) { return nil, errors.New("Only RSA private keys are supported for now") } - return &rsaSigner{key: rsaKey, cert: x509Cert}, nil + return &rsaSigner{key: rsaKey, cert: x509Cert, config: tlsConfig}, nil } // SignServerProof signs CHLO and server config for use in the server proof @@ -83,3 +85,28 @@ func (kd *rsaSigner) GetCertCompressed(sni string) []byte { func (kd *rsaSigner) GetCertUncompressed(sni string) []byte { return kd.cert.Raw } + +func (kd *rsaSigner) getCertForSNI(sni string) (*tls.Certificate, error) { + if kd.config.GetCertificate != nil { + cert, err := kd.config.GetCertificate(&tls.ClientHelloInfo{ServerName: sni}) + if err != nil { + return nil, err + } + if cert != nil { + return cert, nil + } + } + if len(kd.config.NameToCertificate) != 0 { + if cert, ok := kd.config.NameToCertificate[sni]; ok { + return cert, nil + } + wildcardSNI := "*" + strings.TrimLeftFunc(sni, func(r rune) bool { return r != '.' }) + if cert, ok := kd.config.NameToCertificate[wildcardSNI]; ok { + return cert, nil + } + } + if len(kd.config.Certificates) != 0 { + return &kd.config.Certificates[0], nil + } + return nil, errors.New("no matching certificate found") +} diff --git a/crypto/proof_rsa_test.go b/crypto/proof_rsa_test.go index eb123ea37..f73e9f8bb 100644 --- a/crypto/proof_rsa_test.go +++ b/crypto/proof_rsa_test.go @@ -6,6 +6,7 @@ import ( "compress/zlib" "crypto" "crypto/rsa" + "crypto/tls" "crypto/x509" "github.com/lucas-clemente/quic-go/testdata" @@ -41,4 +42,57 @@ var _ = Describe("ProofRsa", func() { err = rsa.VerifyPSS(kd.(*rsaSigner).cert.PublicKey.(*rsa.PublicKey), crypto.SHA256, data, signature, &rsa.PSSOptions{SaltLength: 32}) Expect(err).ToNot(HaveOccurred()) }) + + Context("retrieving certificate", func() { + var ( + signer *rsaSigner + config *tls.Config + cert tls.Certificate + ) + + BeforeEach(func() { + cert = testdata.GetCertificate() + config = &tls.Config{} + signer = &rsaSigner{config: config} + }) + + It("uses first certificate in config.Certificates", func() { + config.Certificates = []tls.Certificate{cert} + cert, err := signer.getCertForSNI("") + Expect(err).ToNot(HaveOccurred()) + Expect(cert.PrivateKey).ToNot(BeNil()) + Expect(cert.Certificate[0]).ToNot(BeNil()) + }) + + It("uses NameToCertificate entries", func() { + config.NameToCertificate = map[string]*tls.Certificate{ + "quic.clemente.io": &cert, + } + cert, err := signer.getCertForSNI("quic.clemente.io") + Expect(err).ToNot(HaveOccurred()) + Expect(cert.PrivateKey).ToNot(BeNil()) + Expect(cert.Certificate[0]).ToNot(BeNil()) + }) + + It("uses NameToCertificate entries with wildcard", func() { + config.NameToCertificate = map[string]*tls.Certificate{ + "*.clemente.io": &cert, + } + cert, err := signer.getCertForSNI("quic.clemente.io") + Expect(err).ToNot(HaveOccurred()) + Expect(cert.PrivateKey).ToNot(BeNil()) + Expect(cert.Certificate[0]).ToNot(BeNil()) + }) + + It("uses GetCertificate", func() { + config.GetCertificate = func(clientHello *tls.ClientHelloInfo) (*tls.Certificate, error) { + Expect(clientHello.ServerName).To(Equal("quic.clemente.io")) + return &cert, nil + } + cert, err := signer.getCertForSNI("quic.clemente.io") + Expect(err).ToNot(HaveOccurred()) + Expect(cert.PrivateKey).ToNot(BeNil()) + Expect(cert.Certificate[0]).ToNot(BeNil()) + }) + }) })