add tests for certChain, simplify constructor

This commit is contained in:
Marten Seemann
2016-11-18 13:38:33 +07:00
parent bb1af0db1e
commit 6913f5ae75
4 changed files with 119 additions and 94 deletions

View File

@@ -20,9 +20,11 @@ type certChain struct {
var _ CertChain = &certChain{}
var errNoMatchingCertificate = errors.New("no matching certificate found")
// NewCertChain loads the key and cert from files
func NewCertChain(tlsConfig *tls.Config) (CertChain, error) {
return &certChain{config: tlsConfig}, nil
func NewCertChain(tlsConfig *tls.Config) CertChain {
return &certChain{config: tlsConfig}
}
// SignServerProof signs CHLO and server config for use in the server proof
@@ -78,5 +80,5 @@ func (c *certChain) getCertForSNI(sni string) (*tls.Certificate, error) {
return &c.config.Certificates[0], nil
}
return nil, errors.New("no matching certificate found")
return nil, errNoMatchingCertificate
}

View File

@@ -6,97 +6,124 @@ import (
"compress/zlib"
"crypto/tls"
"github.com/lucas-clemente/quic-go/testdata"
. "github.com/onsi/ginkgo"
. "github.com/onsi/gomega"
)
var _ = Describe("Proof", func() {
var (
cc *certChain
config *tls.Config
cert tls.Certificate
)
It("compresses certs", func() {
cert := []byte{0xde, 0xca, 0xfb, 0xad}
certZlib := &bytes.Buffer{}
z, err := zlib.NewWriterLevelDict(certZlib, flate.BestCompression, certDictZlib)
Expect(err).ToNot(HaveOccurred())
z.Write([]byte{0x04, 0x00, 0x00, 0x00})
z.Write(cert)
z.Close()
kd := &certChain{
config: &tls.Config{
Certificates: []tls.Certificate{
{Certificate: [][]byte{cert}},
},
},
}
certCompressed, err := kd.GetCertsCompressed("", nil, nil)
Expect(err).ToNot(HaveOccurred())
Expect(certCompressed).To(Equal(append([]byte{
0x01, 0x00,
0x08, 0x00, 0x00, 0x00,
}, certZlib.Bytes()...)))
BeforeEach(func() {
cert = testdata.GetCertificate()
config = &tls.Config{}
cc = NewCertChain(config).(*certChain)
})
// Context("retrieving certificate", func() {
// var (
// signer *proofSource
// config *tls.Config
// cert tls.Certificate
// )
//
// BeforeEach(func() {
// cert = testdata.GetCertificate()
// config = &tls.Config{}
// signer = &proofSource{config: config}
// })
//
// It("errors without certificates", func() {
// _, err := signer.getCertForSNI("")
// Expect(err).To(MatchError("no matching certificate found"))
// })
//
// It("uses first certificate in config.Certificates", func() {
// config.Certificates = []tls.Certificate{cert}
// cert, err := signer.getCertForSNI("")
// Expect(err).ToNot(HaveOccurred())
// Expect(cert.PrivateKey).ToNot(BeNil())
// Expect(cert.Certificate[0]).ToNot(BeNil())
// })
//
// It("uses NameToCertificate entries", func() {
// config.NameToCertificate = map[string]*tls.Certificate{
// "quic.clemente.io": &cert,
// }
// cert, err := signer.getCertForSNI("quic.clemente.io")
// Expect(err).ToNot(HaveOccurred())
// Expect(cert.PrivateKey).ToNot(BeNil())
// Expect(cert.Certificate[0]).ToNot(BeNil())
// })
//
// It("uses NameToCertificate entries with wildcard", func() {
// config.NameToCertificate = map[string]*tls.Certificate{
// "*.clemente.io": &cert,
// }
// cert, err := signer.getCertForSNI("quic.clemente.io")
// Expect(err).ToNot(HaveOccurred())
// Expect(cert.PrivateKey).ToNot(BeNil())
// Expect(cert.Certificate[0]).ToNot(BeNil())
// })
//
// It("uses GetCertificate", func() {
// config.GetCertificate = func(clientHello *tls.ClientHelloInfo) (*tls.Certificate, error) {
// Expect(clientHello.ServerName).To(Equal("quic.clemente.io"))
// return &cert, nil
// }
// cert, err := signer.getCertForSNI("quic.clemente.io")
// Expect(err).ToNot(HaveOccurred())
// Expect(cert.PrivateKey).ToNot(BeNil())
// Expect(cert.Certificate[0]).ToNot(BeNil())
// })
//
// It("gets leaf certificates", func() {
// config.Certificates = []tls.Certificate{cert}
// cert2, err := signer.GetLeafCert("")
// Expect(err).ToNot(HaveOccurred())
// Expect(cert2).To(Equal(cert.Certificate[0]))
// })
// })
Context("certificate compression", func() {
It("compresses certs", func() {
cert := []byte{0xde, 0xca, 0xfb, 0xad}
certZlib := &bytes.Buffer{}
z, err := zlib.NewWriterLevelDict(certZlib, flate.BestCompression, certDictZlib)
Expect(err).ToNot(HaveOccurred())
z.Write([]byte{0x04, 0x00, 0x00, 0x00})
z.Write(cert)
z.Close()
kd := &certChain{
config: &tls.Config{
Certificates: []tls.Certificate{
{Certificate: [][]byte{cert}},
},
},
}
certCompressed, err := kd.GetCertsCompressed("", nil, nil)
Expect(err).ToNot(HaveOccurred())
Expect(certCompressed).To(Equal(append([]byte{
0x01, 0x00,
0x08, 0x00, 0x00, 0x00,
}, certZlib.Bytes()...)))
})
It("errors when it can't retrieve a certificate", func() {
_, err := cc.GetCertsCompressed("invalid domain", nil, nil)
Expect(err).To(MatchError(errNoMatchingCertificate))
})
})
Context("signing server configs", func() {
It("errors when it can't retrieve a certificate for the requested SNI", func() {
_, err := cc.SignServerProof("invalid", []byte("chlo"), []byte("scfg"))
Expect(err).To(MatchError(errNoMatchingCertificate))
})
It("signs the server config", func() {
config.Certificates = []tls.Certificate{cert}
proof, err := cc.SignServerProof("", []byte("chlo"), []byte("scfg"))
Expect(err).ToNot(HaveOccurred())
Expect(proof).ToNot(BeEmpty())
})
})
Context("retrieving certificates", func() {
It("errors without certificates", func() {
_, err := cc.getCertForSNI("")
Expect(err).To(MatchError(errNoMatchingCertificate))
})
It("uses first certificate in config.Certificates", func() {
config.Certificates = []tls.Certificate{cert}
cert, err := cc.getCertForSNI("")
Expect(err).ToNot(HaveOccurred())
Expect(cert.PrivateKey).ToNot(BeNil())
Expect(cert.Certificate[0]).ToNot(BeNil())
})
It("uses NameToCertificate entries", func() {
config.NameToCertificate = map[string]*tls.Certificate{
"quic.clemente.io": &cert,
}
cert, err := cc.getCertForSNI("quic.clemente.io")
Expect(err).ToNot(HaveOccurred())
Expect(cert.PrivateKey).ToNot(BeNil())
Expect(cert.Certificate[0]).ToNot(BeNil())
})
It("uses NameToCertificate entries with wildcard", func() {
config.NameToCertificate = map[string]*tls.Certificate{
"*.clemente.io": &cert,
}
cert, err := cc.getCertForSNI("quic.clemente.io")
Expect(err).ToNot(HaveOccurred())
Expect(cert.PrivateKey).ToNot(BeNil())
Expect(cert.Certificate[0]).ToNot(BeNil())
})
It("uses GetCertificate", func() {
config.GetCertificate = func(clientHello *tls.ClientHelloInfo) (*tls.Certificate, error) {
Expect(clientHello.ServerName).To(Equal("quic.clemente.io"))
return &cert, nil
}
cert, err := cc.getCertForSNI("quic.clemente.io")
Expect(err).ToNot(HaveOccurred())
Expect(cert.PrivateKey).ToNot(BeNil())
Expect(cert.Certificate[0]).ToNot(BeNil())
})
It("gets leaf certificates", func() {
config.Certificates = []tls.Certificate{cert}
cert2, err := cc.GetLeafCert("")
Expect(err).ToNot(HaveOccurred())
Expect(cert2).To(Equal(cert.Certificate[0]))
})
It("errors when it can't retrieve a leaf certificate", func() {
_, err := cc.GetLeafCert("invalid domain")
Expect(err).To(MatchError(errNoMatchingCertificate))
})
})
})

View File

@@ -44,10 +44,7 @@ type Server struct {
// NewServer makes a new server
func NewServer(addr string, tlsConfig *tls.Config, cb StreamCallback) (*Server, error) {
certChain, err := crypto.NewCertChain(tlsConfig)
if err != nil {
return nil, err
}
certChain := crypto.NewCertChain(tlsConfig)
kex, err := crypto.NewCurve25519KEX()
if err != nil {

View File

@@ -129,8 +129,7 @@ var _ = Describe("Session", func() {
streamCallbackCalled = false
closeCallbackCalled = false
certChain, err := crypto.NewCertChain(testdata.GetTLSConfig())
Expect(err).ToNot(HaveOccurred())
certChain := crypto.NewCertChain(testdata.GetTLSConfig())
kex, err := crypto.NewCurve25519KEX()
Expect(err).NotTo(HaveOccurred())
scfg, err := handshake.NewServerConfig(kex, certChain)