diff --git a/crypto/cert_chain.go b/crypto/cert_chain.go index 292ceca29..0fd905a77 100644 --- a/crypto/cert_chain.go +++ b/crypto/cert_chain.go @@ -20,9 +20,11 @@ type certChain struct { var _ CertChain = &certChain{} +var errNoMatchingCertificate = errors.New("no matching certificate found") + // NewCertChain loads the key and cert from files -func NewCertChain(tlsConfig *tls.Config) (CertChain, error) { - return &certChain{config: tlsConfig}, nil +func NewCertChain(tlsConfig *tls.Config) CertChain { + return &certChain{config: tlsConfig} } // SignServerProof signs CHLO and server config for use in the server proof @@ -78,5 +80,5 @@ func (c *certChain) getCertForSNI(sni string) (*tls.Certificate, error) { return &c.config.Certificates[0], nil } - return nil, errors.New("no matching certificate found") + return nil, errNoMatchingCertificate } diff --git a/crypto/cert_chain_test.go b/crypto/cert_chain_test.go index 670beaf5e..6ad5ada8d 100644 --- a/crypto/cert_chain_test.go +++ b/crypto/cert_chain_test.go @@ -6,97 +6,124 @@ import ( "compress/zlib" "crypto/tls" + "github.com/lucas-clemente/quic-go/testdata" + . "github.com/onsi/ginkgo" . "github.com/onsi/gomega" ) var _ = Describe("Proof", func() { + var ( + cc *certChain + config *tls.Config + cert tls.Certificate + ) - It("compresses certs", func() { - cert := []byte{0xde, 0xca, 0xfb, 0xad} - certZlib := &bytes.Buffer{} - z, err := zlib.NewWriterLevelDict(certZlib, flate.BestCompression, certDictZlib) - Expect(err).ToNot(HaveOccurred()) - z.Write([]byte{0x04, 0x00, 0x00, 0x00}) - z.Write(cert) - z.Close() - kd := &certChain{ - config: &tls.Config{ - Certificates: []tls.Certificate{ - {Certificate: [][]byte{cert}}, - }, - }, - } - certCompressed, err := kd.GetCertsCompressed("", nil, nil) - Expect(err).ToNot(HaveOccurred()) - Expect(certCompressed).To(Equal(append([]byte{ - 0x01, 0x00, - 0x08, 0x00, 0x00, 0x00, - }, certZlib.Bytes()...))) + BeforeEach(func() { + cert = testdata.GetCertificate() + config = &tls.Config{} + cc = NewCertChain(config).(*certChain) }) - // Context("retrieving certificate", func() { - // var ( - // signer *proofSource - // config *tls.Config - // cert tls.Certificate - // ) - // - // BeforeEach(func() { - // cert = testdata.GetCertificate() - // config = &tls.Config{} - // signer = &proofSource{config: config} - // }) - // - // It("errors without certificates", func() { - // _, err := signer.getCertForSNI("") - // Expect(err).To(MatchError("no matching certificate found")) - // }) - // - // It("uses first certificate in config.Certificates", func() { - // config.Certificates = []tls.Certificate{cert} - // cert, err := signer.getCertForSNI("") - // Expect(err).ToNot(HaveOccurred()) - // Expect(cert.PrivateKey).ToNot(BeNil()) - // Expect(cert.Certificate[0]).ToNot(BeNil()) - // }) - // - // It("uses NameToCertificate entries", func() { - // config.NameToCertificate = map[string]*tls.Certificate{ - // "quic.clemente.io": &cert, - // } - // cert, err := signer.getCertForSNI("quic.clemente.io") - // Expect(err).ToNot(HaveOccurred()) - // Expect(cert.PrivateKey).ToNot(BeNil()) - // Expect(cert.Certificate[0]).ToNot(BeNil()) - // }) - // - // It("uses NameToCertificate entries with wildcard", func() { - // config.NameToCertificate = map[string]*tls.Certificate{ - // "*.clemente.io": &cert, - // } - // cert, err := signer.getCertForSNI("quic.clemente.io") - // Expect(err).ToNot(HaveOccurred()) - // Expect(cert.PrivateKey).ToNot(BeNil()) - // Expect(cert.Certificate[0]).ToNot(BeNil()) - // }) - // - // It("uses GetCertificate", func() { - // config.GetCertificate = func(clientHello *tls.ClientHelloInfo) (*tls.Certificate, error) { - // Expect(clientHello.ServerName).To(Equal("quic.clemente.io")) - // return &cert, nil - // } - // cert, err := signer.getCertForSNI("quic.clemente.io") - // Expect(err).ToNot(HaveOccurred()) - // Expect(cert.PrivateKey).ToNot(BeNil()) - // Expect(cert.Certificate[0]).ToNot(BeNil()) - // }) - // - // It("gets leaf certificates", func() { - // config.Certificates = []tls.Certificate{cert} - // cert2, err := signer.GetLeafCert("") - // Expect(err).ToNot(HaveOccurred()) - // Expect(cert2).To(Equal(cert.Certificate[0])) - // }) - // }) + Context("certificate compression", func() { + It("compresses certs", func() { + cert := []byte{0xde, 0xca, 0xfb, 0xad} + certZlib := &bytes.Buffer{} + z, err := zlib.NewWriterLevelDict(certZlib, flate.BestCompression, certDictZlib) + Expect(err).ToNot(HaveOccurred()) + z.Write([]byte{0x04, 0x00, 0x00, 0x00}) + z.Write(cert) + z.Close() + kd := &certChain{ + config: &tls.Config{ + Certificates: []tls.Certificate{ + {Certificate: [][]byte{cert}}, + }, + }, + } + certCompressed, err := kd.GetCertsCompressed("", nil, nil) + Expect(err).ToNot(HaveOccurred()) + Expect(certCompressed).To(Equal(append([]byte{ + 0x01, 0x00, + 0x08, 0x00, 0x00, 0x00, + }, certZlib.Bytes()...))) + }) + + It("errors when it can't retrieve a certificate", func() { + _, err := cc.GetCertsCompressed("invalid domain", nil, nil) + Expect(err).To(MatchError(errNoMatchingCertificate)) + }) + }) + + Context("signing server configs", func() { + It("errors when it can't retrieve a certificate for the requested SNI", func() { + _, err := cc.SignServerProof("invalid", []byte("chlo"), []byte("scfg")) + Expect(err).To(MatchError(errNoMatchingCertificate)) + }) + + It("signs the server config", func() { + config.Certificates = []tls.Certificate{cert} + proof, err := cc.SignServerProof("", []byte("chlo"), []byte("scfg")) + Expect(err).ToNot(HaveOccurred()) + Expect(proof).ToNot(BeEmpty()) + }) + }) + + Context("retrieving certificates", func() { + It("errors without certificates", func() { + _, err := cc.getCertForSNI("") + Expect(err).To(MatchError(errNoMatchingCertificate)) + }) + + It("uses first certificate in config.Certificates", func() { + config.Certificates = []tls.Certificate{cert} + cert, err := cc.getCertForSNI("") + Expect(err).ToNot(HaveOccurred()) + Expect(cert.PrivateKey).ToNot(BeNil()) + Expect(cert.Certificate[0]).ToNot(BeNil()) + }) + + It("uses NameToCertificate entries", func() { + config.NameToCertificate = map[string]*tls.Certificate{ + "quic.clemente.io": &cert, + } + cert, err := cc.getCertForSNI("quic.clemente.io") + Expect(err).ToNot(HaveOccurred()) + Expect(cert.PrivateKey).ToNot(BeNil()) + Expect(cert.Certificate[0]).ToNot(BeNil()) + }) + + It("uses NameToCertificate entries with wildcard", func() { + config.NameToCertificate = map[string]*tls.Certificate{ + "*.clemente.io": &cert, + } + cert, err := cc.getCertForSNI("quic.clemente.io") + Expect(err).ToNot(HaveOccurred()) + Expect(cert.PrivateKey).ToNot(BeNil()) + Expect(cert.Certificate[0]).ToNot(BeNil()) + }) + + It("uses GetCertificate", func() { + config.GetCertificate = func(clientHello *tls.ClientHelloInfo) (*tls.Certificate, error) { + Expect(clientHello.ServerName).To(Equal("quic.clemente.io")) + return &cert, nil + } + cert, err := cc.getCertForSNI("quic.clemente.io") + Expect(err).ToNot(HaveOccurred()) + Expect(cert.PrivateKey).ToNot(BeNil()) + Expect(cert.Certificate[0]).ToNot(BeNil()) + }) + + It("gets leaf certificates", func() { + config.Certificates = []tls.Certificate{cert} + cert2, err := cc.GetLeafCert("") + Expect(err).ToNot(HaveOccurred()) + Expect(cert2).To(Equal(cert.Certificate[0])) + }) + + It("errors when it can't retrieve a leaf certificate", func() { + _, err := cc.GetLeafCert("invalid domain") + Expect(err).To(MatchError(errNoMatchingCertificate)) + }) + }) }) diff --git a/server.go b/server.go index aa1a0e663..fb8588cf4 100644 --- a/server.go +++ b/server.go @@ -44,10 +44,7 @@ type Server struct { // NewServer makes a new server func NewServer(addr string, tlsConfig *tls.Config, cb StreamCallback) (*Server, error) { - certChain, err := crypto.NewCertChain(tlsConfig) - if err != nil { - return nil, err - } + certChain := crypto.NewCertChain(tlsConfig) kex, err := crypto.NewCurve25519KEX() if err != nil { diff --git a/session_test.go b/session_test.go index e9473f09a..9bf488e9b 100644 --- a/session_test.go +++ b/session_test.go @@ -129,8 +129,7 @@ var _ = Describe("Session", func() { streamCallbackCalled = false closeCallbackCalled = false - certChain, err := crypto.NewCertChain(testdata.GetTLSConfig()) - Expect(err).ToNot(HaveOccurred()) + certChain := crypto.NewCertChain(testdata.GetTLSConfig()) kex, err := crypto.NewCurve25519KEX() Expect(err).NotTo(HaveOccurred()) scfg, err := handshake.NewServerConfig(kex, certChain)