read tls.Config properly in RSA signer

This commit is contained in:
Lucas Clemente
2016-05-08 22:42:11 +02:00
parent 6738f0eadf
commit 32cf5e4129
2 changed files with 84 additions and 3 deletions

View File

@@ -11,14 +11,16 @@ import (
"crypto/tls"
"crypto/x509"
"errors"
"strings"
"github.com/lucas-clemente/quic-go/utils"
)
// rsaSigner stores a key and a certificate for the server proof
type rsaSigner struct {
key *rsa.PrivateKey
cert *x509.Certificate
key *rsa.PrivateKey
cert *x509.Certificate
config *tls.Config
}
// NewRSASigner loads the key and cert from files
@@ -38,7 +40,7 @@ func NewRSASigner(tlsConfig *tls.Config) (Signer, error) {
return nil, errors.New("Only RSA private keys are supported for now")
}
return &rsaSigner{key: rsaKey, cert: x509Cert}, nil
return &rsaSigner{key: rsaKey, cert: x509Cert, config: tlsConfig}, nil
}
// SignServerProof signs CHLO and server config for use in the server proof
@@ -83,3 +85,28 @@ func (kd *rsaSigner) GetCertCompressed(sni string) []byte {
func (kd *rsaSigner) GetCertUncompressed(sni string) []byte {
return kd.cert.Raw
}
func (kd *rsaSigner) getCertForSNI(sni string) (*tls.Certificate, error) {
if kd.config.GetCertificate != nil {
cert, err := kd.config.GetCertificate(&tls.ClientHelloInfo{ServerName: sni})
if err != nil {
return nil, err
}
if cert != nil {
return cert, nil
}
}
if len(kd.config.NameToCertificate) != 0 {
if cert, ok := kd.config.NameToCertificate[sni]; ok {
return cert, nil
}
wildcardSNI := "*" + strings.TrimLeftFunc(sni, func(r rune) bool { return r != '.' })
if cert, ok := kd.config.NameToCertificate[wildcardSNI]; ok {
return cert, nil
}
}
if len(kd.config.Certificates) != 0 {
return &kd.config.Certificates[0], nil
}
return nil, errors.New("no matching certificate found")
}

View File

@@ -6,6 +6,7 @@ import (
"compress/zlib"
"crypto"
"crypto/rsa"
"crypto/tls"
"crypto/x509"
"github.com/lucas-clemente/quic-go/testdata"
@@ -41,4 +42,57 @@ var _ = Describe("ProofRsa", func() {
err = rsa.VerifyPSS(kd.(*rsaSigner).cert.PublicKey.(*rsa.PublicKey), crypto.SHA256, data, signature, &rsa.PSSOptions{SaltLength: 32})
Expect(err).ToNot(HaveOccurred())
})
Context("retrieving certificate", func() {
var (
signer *rsaSigner
config *tls.Config
cert tls.Certificate
)
BeforeEach(func() {
cert = testdata.GetCertificate()
config = &tls.Config{}
signer = &rsaSigner{config: config}
})
It("uses first certificate in config.Certificates", func() {
config.Certificates = []tls.Certificate{cert}
cert, err := signer.getCertForSNI("")
Expect(err).ToNot(HaveOccurred())
Expect(cert.PrivateKey).ToNot(BeNil())
Expect(cert.Certificate[0]).ToNot(BeNil())
})
It("uses NameToCertificate entries", func() {
config.NameToCertificate = map[string]*tls.Certificate{
"quic.clemente.io": &cert,
}
cert, err := signer.getCertForSNI("quic.clemente.io")
Expect(err).ToNot(HaveOccurred())
Expect(cert.PrivateKey).ToNot(BeNil())
Expect(cert.Certificate[0]).ToNot(BeNil())
})
It("uses NameToCertificate entries with wildcard", func() {
config.NameToCertificate = map[string]*tls.Certificate{
"*.clemente.io": &cert,
}
cert, err := signer.getCertForSNI("quic.clemente.io")
Expect(err).ToNot(HaveOccurred())
Expect(cert.PrivateKey).ToNot(BeNil())
Expect(cert.Certificate[0]).ToNot(BeNil())
})
It("uses GetCertificate", func() {
config.GetCertificate = func(clientHello *tls.ClientHelloInfo) (*tls.Certificate, error) {
Expect(clientHello.ServerName).To(Equal("quic.clemente.io"))
return &cert, nil
}
cert, err := signer.getCertForSNI("quic.clemente.io")
Expect(err).ToNot(HaveOccurred())
Expect(cert.PrivateKey).ToNot(BeNil())
Expect(cert.Certificate[0]).ToNot(BeNil())
})
})
})