From e33f7d0fb95d488b6ccb8d675baed840eb34e831 Mon Sep 17 00:00:00 2001 From: Marten Seemann Date: Mon, 11 May 2020 15:58:40 +0700 Subject: [PATCH] add integration tests using a very long certificate chain This will trigger the amplification protection. --- integrationtests/self/handshake_drop_test.go | 91 ++++++------ integrationtests/self/handshake_test.go | 46 ++++--- integrationtests/self/self_suite_test.go | 138 ++++++++++++++----- 3 files changed, 185 insertions(+), 90 deletions(-) diff --git a/integrationtests/self/handshake_drop_test.go b/integrationtests/self/handshake_drop_test.go index 92e9951ef..df062f93d 100644 --- a/integrationtests/self/handshake_drop_test.go +++ b/integrationtests/self/handshake_drop_test.go @@ -2,6 +2,7 @@ package self_test import ( "context" + "crypto/tls" "fmt" mrand "math/rand" "net" @@ -32,7 +33,7 @@ var _ = Describe("Handshake drop tests", func() { const timeout = 10 * time.Minute - startListenerAndProxy := func(dropCallback quicproxy.DropCallback, doRetry bool, version protocol.VersionNumber) { + startListenerAndProxy := func(dropCallback quicproxy.DropCallback, doRetry bool, longCertChain bool, version protocol.VersionNumber) { conf := getQuicConfigForServer(&quic.Config{ MaxIdleTimeout: timeout, HandshakeTimeout: timeout, @@ -41,8 +42,14 @@ var _ = Describe("Handshake drop tests", func() { if !doRetry { conf.AcceptToken = func(net.Addr, *quic.Token) bool { return true } } + var tlsConf *tls.Config + if longCertChain { + tlsConf = getTLSConfigWithLongCertChain() + } else { + tlsConf = getTLSConfig() + } var err error - ln, err = quic.ListenAddr("localhost:0", getTLSConfig(), conf) + ln, err = quic.ListenAddr("localhost:0", tlsConf, conf) Expect(err).ToNot(HaveOccurred()) serverPort := ln.Addr().(*net.UDPAddr).Port proxy, err = quicproxy.NewQuicProxy("localhost:0", &quicproxy.Opts{ @@ -184,46 +191,52 @@ var _ = Describe("Handshake drop tests", func() { } Context(desc, func() { - for _, a := range []*applicationProtocol{clientSpeaksFirst, serverSpeaksFirst, nobodySpeaks} { - app := a + for _, lcc := range []bool{false, true} { + longCertChain := lcc - Context(app.name, func() { - It(fmt.Sprintf("establishes a connection when the first packet is lost in %s direction", direction), func() { - var incoming, outgoing int32 - startListenerAndProxy(func(d quicproxy.Direction, _ []byte) bool { - var p int32 - switch d { - case quicproxy.DirectionIncoming: - p = atomic.AddInt32(&incoming, 1) - case quicproxy.DirectionOutgoing: - p = atomic.AddInt32(&outgoing, 1) - } - return p == 1 && d.Is(direction) - }, doRetry, version) - app.run(version) - }) + Context(fmt.Sprintf("using a long certificate chain: %t", longCertChain), func() { + for _, a := range []*applicationProtocol{clientSpeaksFirst, serverSpeaksFirst, nobodySpeaks} { + app := a - It(fmt.Sprintf("establishes a connection when the second packet is lost in %s direction", direction), func() { - var incoming, outgoing int32 - startListenerAndProxy(func(d quicproxy.Direction, _ []byte) bool { - var p int32 - switch d { - case quicproxy.DirectionIncoming: - p = atomic.AddInt32(&incoming, 1) - case quicproxy.DirectionOutgoing: - p = atomic.AddInt32(&outgoing, 1) - } - return p == 2 && d.Is(direction) - }, doRetry, version) - app.run(version) - }) + Context(app.name, func() { + It(fmt.Sprintf("establishes a connection when the first packet is lost in %s direction", direction), func() { + var incoming, outgoing int32 + startListenerAndProxy(func(d quicproxy.Direction, _ []byte) bool { + var p int32 + switch d { + case quicproxy.DirectionIncoming: + p = atomic.AddInt32(&incoming, 1) + case quicproxy.DirectionOutgoing: + p = atomic.AddInt32(&outgoing, 1) + } + return p == 1 && d.Is(direction) + }, doRetry, longCertChain, version) + app.run(version) + }) - It(fmt.Sprintf("establishes a connection when 1/3 of the packets are lost in %s direction", direction), func() { - startListenerAndProxy(func(d quicproxy.Direction, _ []byte) bool { - return d.Is(direction) && stochasticDropper(3) - }, doRetry, version) - app.run(version) - }) + It(fmt.Sprintf("establishes a connection when the second packet is lost in %s direction", direction), func() { + var incoming, outgoing int32 + startListenerAndProxy(func(d quicproxy.Direction, _ []byte) bool { + var p int32 + switch d { + case quicproxy.DirectionIncoming: + p = atomic.AddInt32(&incoming, 1) + case quicproxy.DirectionOutgoing: + p = atomic.AddInt32(&outgoing, 1) + } + return p == 2 && d.Is(direction) + }, doRetry, longCertChain, version) + app.run(version) + }) + + It(fmt.Sprintf("establishes a connection when 1/3 of the packets are lost in %s direction", direction), func() { + startListenerAndProxy(func(d quicproxy.Direction, _ []byte) bool { + return d.Is(direction) && stochasticDropper(3) + }, doRetry, longCertChain, version) + app.run(version) + }) + }) + } }) } }) diff --git a/integrationtests/self/handshake_test.go b/integrationtests/self/handshake_test.go index ff5c9455a..e86e25958 100644 --- a/integrationtests/self/handshake_test.go +++ b/integrationtests/self/handshake_test.go @@ -51,14 +51,12 @@ var _ = Describe("Handshake tests", func() { server quic.Listener serverConfig *quic.Config acceptStopped chan struct{} - tlsServerConf *tls.Config ) BeforeEach(func() { server = nil acceptStopped = make(chan struct{}) serverConfig = getQuicConfigForServer(nil) - tlsServerConf = getTLSConfig() }) AfterEach(func() { @@ -68,10 +66,10 @@ var _ = Describe("Handshake tests", func() { } }) - runServer := func() quic.Listener { + runServer := func(tlsConf *tls.Config) { var err error // start the server - server, err = quic.ListenAddr("localhost:0", tlsServerConf, serverConfig) + server, err = quic.ListenAddr("localhost:0", tlsConf, serverConfig) Expect(err).ToNot(HaveOccurred()) go func() { @@ -83,7 +81,6 @@ var _ = Describe("Handshake tests", func() { } } }() - return server } if !israce.Enabled { @@ -103,7 +100,7 @@ var _ = Describe("Handshake tests", func() { // the server doesn't support the highest supported version, which is the first one the client will try // but it supports a bunch of versions that the client doesn't speak serverConfig.Versions = []protocol.VersionNumber{7, 8, protocol.SupportedVersions[0], 9} - server := runServer() + runServer(getTLSConfig()) defer server.Close() sess, err := quic.DialAddr( fmt.Sprintf("localhost:%d", server.Addr().(*net.UDPAddr).Port), @@ -119,7 +116,7 @@ var _ = Describe("Handshake tests", func() { // the server doesn't support the highest supported version, which is the first one the client will try // but it supports a bunch of versions that the client doesn't speak serverConfig.Versions = supportedVersions - server := runServer() + runServer(getTLSConfig()) defer server.Close() sess, err := quic.DialAddr( fmt.Sprintf("localhost:%d", server.Addr().(*net.UDPAddr).Port), @@ -145,9 +142,11 @@ var _ = Describe("Handshake tests", func() { suiteID := id It(fmt.Sprintf("using %s", name), func() { - tlsServerConf.CipherSuites = []uint16{suiteID} - ln, err := quic.ListenAddr("localhost:0", tlsServerConf, serverConfig) + tlsConf := getTLSConfig() + tlsConf.CipherSuites = []uint16{suiteID} + ln, err := quic.ListenAddr("localhost:0", tlsConf, serverConfig) Expect(err).ToNot(HaveOccurred()) + defer ln.Close() go func() { defer GinkgoRecover() @@ -177,7 +176,7 @@ var _ = Describe("Handshake tests", func() { } }) - Context("Certifiate validation", func() { + Context("Certificate validation", func() { for _, v := range protocol.SupportedVersions { version := v @@ -189,11 +188,8 @@ var _ = Describe("Handshake tests", func() { clientConfig = getQuicConfigForClient(&quic.Config{Versions: []protocol.VersionNumber{version}}) }) - JustBeforeEach(func() { - runServer() - }) - It("accepts the certificate", func() { + runServer(getTLSConfig()) _, err := quic.DialAddr( fmt.Sprintf("localhost:%d", server.Addr().(*net.UDPAddr).Port), getTLSClientConfig(), @@ -202,7 +198,18 @@ var _ = Describe("Handshake tests", func() { Expect(err).ToNot(HaveOccurred()) }) + It("works with a long certificate chain", func() { + runServer(getTLSConfigWithLongCertChain()) + _, err := quic.DialAddr( + fmt.Sprintf("localhost:%d", server.Addr().(*net.UDPAddr).Port), + getTLSClientConfig(), + getQuicConfigForClient(&quic.Config{Versions: []protocol.VersionNumber{version}}), + ) + Expect(err).ToNot(HaveOccurred()) + }) + It("errors if the server name doesn't match", func() { + runServer(getTLSConfig()) _, err := quic.DialAddr( fmt.Sprintf("127.0.0.1:%d", server.Addr().(*net.UDPAddr).Port), getTLSClientConfig(), @@ -212,7 +219,10 @@ var _ = Describe("Handshake tests", func() { }) It("fails the handshake if the client fails to provide the requested client cert", func() { - tlsServerConf.ClientAuth = tls.RequireAndVerifyClientCert + tlsConf := getTLSConfig() + tlsConf.ClientAuth = tls.RequireAndVerifyClientCert + runServer(tlsConf) + sess, err := quic.DialAddr( fmt.Sprintf("localhost:%d", server.Addr().(*net.UDPAddr).Port), getTLSClientConfig(), @@ -234,6 +244,7 @@ var _ = Describe("Handshake tests", func() { }) It("uses the ServerName in the tls.Config", func() { + runServer(getTLSConfig()) tlsConf := getTLSClientConfig() tlsConf.ServerName = "localhost" _, err := quic.DialAddr( @@ -350,7 +361,7 @@ var _ = Describe("Handshake tests", func() { Context("ALPN", func() { It("negotiates an application protocol", func() { - ln, err := quic.ListenAddr("localhost:0", tlsServerConf, serverConfig) + ln, err := quic.ListenAddr("localhost:0", getTLSConfig(), serverConfig) Expect(err).ToNot(HaveOccurred()) done := make(chan struct{}) @@ -379,7 +390,7 @@ var _ = Describe("Handshake tests", func() { }) It("errors if application protocol negotiation fails", func() { - server := runServer() + runServer(getTLSConfig()) tlsConf := getTLSClientConfig() tlsConf.NextProtos = []string{"foobar"} @@ -391,7 +402,6 @@ var _ = Describe("Handshake tests", func() { Expect(err).To(HaveOccurred()) Expect(err.Error()).To(ContainSubstring("CRYPTO_ERROR")) Expect(err.Error()).To(ContainSubstring("no application protocol")) - Expect(server.Close()).To(Succeed()) }) }) diff --git a/integrationtests/self/self_suite_test.go b/integrationtests/self/self_suite_test.go index e5e71bac5..cd0821476 100644 --- a/integrationtests/self/self_suite_test.go +++ b/integrationtests/self/self_suite_test.go @@ -85,10 +85,9 @@ var ( logBuf *syncedBuffer enableQlog bool - caPrivateKey *rsa.PrivateKey - ca *x509.Certificate - leafPrivateKey *rsa.PrivateKey - leafCert *x509.Certificate + tlsConfig *tls.Config + tlsConfigLongChain *tls.Config + tlsClientConfig *tls.Config ) // read the logfile command line flag @@ -97,16 +96,37 @@ func init() { flag.StringVar(&logFileName, "logfile", "", "log file") flag.BoolVar(&enableQlog, "qlog", false, "enable qlog") - if err := generateCA(); err != nil { + ca, caPrivateKey, err := generateCA() + if err != nil { panic(err) } - if err := generateCertChain(); err != nil { + leafCert, leafPrivateKey, err := generateLeafCert(ca, caPrivateKey) + if err != nil { panic(err) } + tlsConfig = &tls.Config{ + Certificates: []tls.Certificate{tls.Certificate{ + Certificate: [][]byte{leafCert.Raw}, + PrivateKey: leafPrivateKey, + }}, + NextProtos: []string{alpn}, + } + tlsConfLongChain, err := generateTLSConfigWithLongCertChain(ca, caPrivateKey) + if err != nil { + panic(err) + } + tlsConfigLongChain = tlsConfLongChain + + root := x509.NewCertPool() + root.AddCert(ca) + tlsClientConfig = &tls.Config{ + RootCAs: root, + NextProtos: []string{alpn}, + } } -func generateCA() error { - caCert := &x509.Certificate{ +func generateCA() (*x509.Certificate, *rsa.PrivateKey, error) { + certTempl := &x509.Certificate{ SerialNumber: big.NewInt(2019), Subject: pkix.Name{}, NotBefore: time.Now(), @@ -116,21 +136,23 @@ func generateCA() error { KeyUsage: x509.KeyUsageDigitalSignature | x509.KeyUsageCertSign, BasicConstraintsValid: true, } - var err error - caPrivateKey, err = rsa.GenerateKey(rand.Reader, 2048) + caPrivateKey, err := rsa.GenerateKey(rand.Reader, 2048) if err != nil { - return err + return nil, nil, err } - caBytes, err := x509.CreateCertificate(rand.Reader, caCert, caCert, &caPrivateKey.PublicKey, caPrivateKey) + caBytes, err := x509.CreateCertificate(rand.Reader, certTempl, certTempl, &caPrivateKey.PublicKey, caPrivateKey) if err != nil { - return err + return nil, nil, err } - ca, err = x509.ParseCertificate(caBytes) - return err + ca, err := x509.ParseCertificate(caBytes) + if err != nil { + return nil, nil, err + } + return ca, caPrivateKey, nil } -func generateCertChain() error { - cert := &x509.Certificate{ +func generateLeafCert(ca *x509.Certificate, caPrivateKey *rsa.PrivateKey) (*x509.Certificate, *rsa.PrivateKey, error) { + certTempl := &x509.Certificate{ SerialNumber: big.NewInt(1), DNSNames: []string{"localhost"}, NotBefore: time.Now(), @@ -138,36 +160,86 @@ func generateCertChain() error { ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageClientAuth, x509.ExtKeyUsageServerAuth}, KeyUsage: x509.KeyUsageDigitalSignature, } - var err error - leafPrivateKey, err = rsa.GenerateKey(rand.Reader, 2048) + privKey, err := rsa.GenerateKey(rand.Reader, 2048) if err != nil { - return err + return nil, nil, err } - certBytes, err := x509.CreateCertificate(rand.Reader, cert, ca, &leafPrivateKey.PublicKey, caPrivateKey) + certBytes, err := x509.CreateCertificate(rand.Reader, certTempl, ca, &privKey.PublicKey, caPrivateKey) if err != nil { - return err + return nil, nil, err } - leafCert, err = x509.ParseCertificate(certBytes) - return err + cert, err := x509.ParseCertificate(certBytes) + if err != nil { + return nil, nil, err + } + return cert, privKey, nil } -func getTLSConfig() *tls.Config { +// getTLSConfigWithLongCertChain generates a tls.Config that uses a long certificate chain. +// The Root CA used is the same as for the config returned from getTLSConfig(). +func generateTLSConfigWithLongCertChain(ca *x509.Certificate, caPrivateKey *rsa.PrivateKey) (*tls.Config, error) { + const chainLen = 7 + certTempl := &x509.Certificate{ + SerialNumber: big.NewInt(2019), + Subject: pkix.Name{}, + NotBefore: time.Now(), + NotAfter: time.Now().Add(24 * time.Hour), + IsCA: true, + ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageClientAuth, x509.ExtKeyUsageServerAuth}, + KeyUsage: x509.KeyUsageDigitalSignature | x509.KeyUsageCertSign, + BasicConstraintsValid: true, + } + + lastCA := ca + lastCAPrivKey := caPrivateKey + privKey, err := rsa.GenerateKey(rand.Reader, 2048) + if err != nil { + return nil, err + } + certs := make([]*x509.Certificate, chainLen) + for i := 0; i < chainLen; i++ { + caBytes, err := x509.CreateCertificate(rand.Reader, certTempl, lastCA, &privKey.PublicKey, lastCAPrivKey) + if err != nil { + return nil, err + } + ca, err := x509.ParseCertificate(caBytes) + if err != nil { + return nil, err + } + certs[i] = ca + lastCA = ca + lastCAPrivKey = privKey + } + leafCert, leafPrivateKey, err := generateLeafCert(lastCA, lastCAPrivKey) + if err != nil { + return nil, err + } + + rawCerts := make([][]byte, chainLen+1) + for i, cert := range certs { + rawCerts[chainLen-i] = cert.Raw + } + rawCerts[0] = leafCert.Raw + return &tls.Config{ Certificates: []tls.Certificate{tls.Certificate{ - Certificate: [][]byte{leafCert.Raw}, + Certificate: rawCerts, PrivateKey: leafPrivateKey, }}, NextProtos: []string{alpn}, - } + }, nil +} + +func getTLSConfig() *tls.Config { + return tlsConfig.Clone() +} + +func getTLSConfigWithLongCertChain() *tls.Config { + return tlsConfigLongChain.Clone() } func getTLSClientConfig() *tls.Config { - root := x509.NewCertPool() - root.AddCert(ca) - return &tls.Config{ - RootCAs: root, - NextProtos: []string{alpn}, - } + return tlsClientConfig.Clone() } func getQuicConfigForClient(conf *quic.Config) *quic.Config {