diff --git a/internal/qtls/cipher_suite_test.go b/internal/qtls/cipher_suite_test.go index 716d8217..df84ffcb 100644 --- a/internal/qtls/cipher_suite_test.go +++ b/internal/qtls/cipher_suite_test.go @@ -4,47 +4,46 @@ import ( "crypto/tls" "fmt" "net" + "testing" "github.com/quic-go/quic-go/internal/testdata" - . "github.com/onsi/ginkgo/v2" - . "github.com/onsi/gomega" + "github.com/stretchr/testify/require" ) -var _ = Describe("Setting the Cipher Suite", func() { - for _, cs := range []uint16{tls.TLS_AES_128_GCM_SHA256, tls.TLS_CHACHA20_POLY1305_SHA256, tls.TLS_AES_256_GCM_SHA384} { - cs := cs +func TestCipherSuiteSelection(t *testing.T) { + t.Run("TLS_AES_128_GCM_SHA256", func(t *testing.T) { testCipherSuiteSelection(t, tls.TLS_AES_128_GCM_SHA256) }) + t.Run("TLS_CHACHA20_POLY1305_SHA256", func(t *testing.T) { testCipherSuiteSelection(t, tls.TLS_CHACHA20_POLY1305_SHA256) }) + t.Run("TLS_AES_256_GCM_SHA384", func(t *testing.T) { testCipherSuiteSelection(t, tls.TLS_AES_256_GCM_SHA384) }) +} - It(fmt.Sprintf("selects %s", tls.CipherSuiteName(cs)), func() { - reset := SetCipherSuite(cs) - defer reset() +func testCipherSuiteSelection(t *testing.T, cs uint16) { + reset := SetCipherSuite(cs) + defer reset() - ln, err := tls.Listen("tcp4", "localhost:0", testdata.GetTLSConfig()) - Expect(err).ToNot(HaveOccurred()) - defer ln.Close() + ln, err := tls.Listen("tcp4", "localhost:0", testdata.GetTLSConfig()) + require.NoError(t, err) + defer ln.Close() - done := make(chan struct{}) - go func() { - defer GinkgoRecover() - defer close(done) - conn, err := ln.Accept() - Expect(err).ToNot(HaveOccurred()) - _, err = conn.Read(make([]byte, 10)) - Expect(err).ToNot(HaveOccurred()) - Expect(conn.(*tls.Conn).ConnectionState().CipherSuite).To(Equal(cs)) - }() + done := make(chan struct{}) + go func() { + defer close(done) + conn, err := ln.Accept() + require.NoError(t, err) + _, err = conn.Read(make([]byte, 10)) + require.NoError(t, err) + require.Equal(t, cs, conn.(*tls.Conn).ConnectionState().CipherSuite) + }() - conn, err := tls.Dial( - "tcp4", - fmt.Sprintf("localhost:%d", ln.Addr().(*net.TCPAddr).Port), - &tls.Config{RootCAs: testdata.GetRootCA()}, - ) - Expect(err).ToNot(HaveOccurred()) - _, err = conn.Write([]byte("foobar")) - Expect(err).ToNot(HaveOccurred()) - Expect(conn.ConnectionState().CipherSuite).To(Equal(cs)) - Expect(conn.Close()).To(Succeed()) - Eventually(done).Should(BeClosed()) - }) - } -}) + conn, err := tls.Dial( + "tcp4", + fmt.Sprintf("localhost:%d", ln.Addr().(*net.TCPAddr).Port), + &tls.Config{RootCAs: testdata.GetRootCA()}, + ) + require.NoError(t, err) + _, err = conn.Write([]byte("foobar")) + require.NoError(t, err) + require.Equal(t, cs, conn.ConnectionState().CipherSuite) + require.NoError(t, conn.Close()) + <-done +} diff --git a/internal/qtls/client_session_cache_test.go b/internal/qtls/client_session_cache_test.go index d3c747bf..97d069a5 100644 --- a/internal/qtls/client_session_cache_test.go +++ b/internal/qtls/client_session_cache_test.go @@ -4,81 +4,81 @@ import ( "crypto/tls" "fmt" "net" + "testing" "github.com/quic-go/quic-go/internal/testdata" - - . "github.com/onsi/ginkgo/v2" - . "github.com/onsi/gomega" + "github.com/stretchr/testify/require" ) -var _ = Describe("Client Session Cache", func() { - It("adds data to and restores data from a session ticket", func() { - ln, err := tls.Listen("tcp4", "localhost:0", testdata.GetTLSConfig()) - Expect(err).ToNot(HaveOccurred()) +func TestClientSessionCacheAddAndRestoreData(t *testing.T) { + ln, err := tls.Listen("tcp4", "localhost:0", testdata.GetTLSConfig()) + require.NoError(t, err) - done := make(chan struct{}) - go func() { - defer GinkgoRecover() - defer close(done) + done := make(chan struct{}) + go func() { + defer close(done) - for { - conn, err := ln.Accept() - if err != nil { - return - } - _, err = conn.Read(make([]byte, 10)) - Expect(err).ToNot(HaveOccurred()) - _, err = conn.Write([]byte("foobar")) - Expect(err).ToNot(HaveOccurred()) + for { + conn, err := ln.Accept() + if err != nil { + return } - }() - - restored := make(chan []byte, 1) - clientConf := &tls.Config{ - RootCAs: testdata.GetRootCA(), - ClientSessionCache: &clientSessionCache{ - wrapped: tls.NewLRUClientSessionCache(10), - getData: func(bool) []byte { return []byte("session") }, - setData: func(data []byte, earlyData bool) bool { - Expect(earlyData).To(BeFalse()) // running on top of TCP, we can only test non-0-RTT here - restored <- data - return true - }, - }, + _, err = conn.Read(make([]byte, 10)) + require.NoError(t, err) + _, err = conn.Write([]byte("foobar")) + require.NoError(t, err) } - conn, err := tls.Dial( - "tcp4", - fmt.Sprintf("localhost:%d", ln.Addr().(*net.TCPAddr).Port), - clientConf, - ) - Expect(err).ToNot(HaveOccurred()) - _, err = conn.Write([]byte("foobar")) - Expect(err).ToNot(HaveOccurred()) - Expect(conn.ConnectionState().DidResume).To(BeFalse()) - Expect(restored).To(HaveLen(0)) - _, err = conn.Read(make([]byte, 10)) - Expect(err).ToNot(HaveOccurred()) - Expect(conn.Close()).To(Succeed()) + }() - // make sure the cache can deal with nonsensical inputs - clientConf.ClientSessionCache.Put("foo", nil) - clientConf.ClientSessionCache.Put("bar", &tls.ClientSessionState{}) + restored := make(chan []byte, 1) + clientConf := &tls.Config{ + RootCAs: testdata.GetRootCA(), + ClientSessionCache: &clientSessionCache{ + wrapped: tls.NewLRUClientSessionCache(10), + getData: func(bool) []byte { return []byte("session") }, + setData: func(data []byte, earlyData bool) bool { + require.False(t, earlyData) // running on top of TCP, we can only test non-0-RTT here + restored <- data + return true + }, + }, + } + conn, err := tls.Dial( + "tcp4", + fmt.Sprintf("localhost:%d", ln.Addr().(*net.TCPAddr).Port), + clientConf, + ) + require.NoError(t, err) + _, err = conn.Write([]byte("foobar")) + require.NoError(t, err) + require.False(t, conn.ConnectionState().DidResume) + require.Len(t, restored, 0) + _, err = conn.Read(make([]byte, 10)) + require.NoError(t, err) + require.NoError(t, conn.Close()) - conn, err = tls.Dial( - "tcp4", - fmt.Sprintf("localhost:%d", ln.Addr().(*net.TCPAddr).Port), - clientConf, - ) - Expect(err).ToNot(HaveOccurred()) - _, err = conn.Write([]byte("foobar")) - Expect(err).ToNot(HaveOccurred()) - Expect(conn.ConnectionState().DidResume).To(BeTrue()) - var restoredData []byte - Expect(restored).To(Receive(&restoredData)) - Expect(restoredData).To(Equal([]byte("session"))) - Expect(conn.Close()).To(Succeed()) + // make sure the cache can deal with nonsensical inputs + clientConf.ClientSessionCache.Put("foo", nil) + clientConf.ClientSessionCache.Put("bar", &tls.ClientSessionState{}) - Expect(ln.Close()).To(Succeed()) - Eventually(done).Should(BeClosed()) - }) -}) + conn, err = tls.Dial( + "tcp4", + fmt.Sprintf("localhost:%d", ln.Addr().(*net.TCPAddr).Port), + clientConf, + ) + require.NoError(t, err) + _, err = conn.Write([]byte("foobar")) + require.NoError(t, err) + require.True(t, conn.ConnectionState().DidResume) + var restoredData []byte + select { + case restoredData = <-restored: + default: + t.Fatal("no data restored") + } + require.Equal(t, []byte("session"), restoredData) + require.NoError(t, conn.Close()) + + require.NoError(t, ln.Close()) + <-done +} diff --git a/internal/qtls/qtls_suite_test.go b/internal/qtls/qtls_suite_test.go deleted file mode 100644 index bde81e6c..00000000 --- a/internal/qtls/qtls_suite_test.go +++ /dev/null @@ -1,24 +0,0 @@ -package qtls - -import ( - "testing" - - . "github.com/onsi/ginkgo/v2" - . "github.com/onsi/gomega" - "go.uber.org/mock/gomock" -) - -func TestQTLS(t *testing.T) { - RegisterFailHandler(Fail) - RunSpecs(t, "qtls Suite") -} - -var mockCtrl *gomock.Controller - -var _ = BeforeEach(func() { - mockCtrl = gomock.NewController(GinkgoT()) -}) - -var _ = AfterEach(func() { - mockCtrl.Finish() -}) diff --git a/internal/qtls/qtls_test.go b/internal/qtls/qtls_test.go index b041af74..b201fc8e 100644 --- a/internal/qtls/qtls_test.go +++ b/internal/qtls/qtls_test.go @@ -4,125 +4,134 @@ import ( "crypto/tls" "net" "reflect" + "testing" "github.com/quic-go/quic-go/internal/protocol" - . "github.com/onsi/ginkgo/v2" - . "github.com/onsi/gomega" + "github.com/stretchr/testify/require" ) -var _ = Describe("interface go crypto/tls", func() { - It("converts to tls.EncryptionLevel", func() { - Expect(ToTLSEncryptionLevel(protocol.EncryptionInitial)).To(Equal(tls.QUICEncryptionLevelInitial)) - Expect(ToTLSEncryptionLevel(protocol.EncryptionHandshake)).To(Equal(tls.QUICEncryptionLevelHandshake)) - Expect(ToTLSEncryptionLevel(protocol.Encryption1RTT)).To(Equal(tls.QUICEncryptionLevelApplication)) - Expect(ToTLSEncryptionLevel(protocol.Encryption0RTT)).To(Equal(tls.QUICEncryptionLevelEarly)) - }) +func TestEncryptionLevelConversion(t *testing.T) { + testCases := []struct { + quicLevel protocol.EncryptionLevel + tlsLevel tls.QUICEncryptionLevel + }{ + {protocol.EncryptionInitial, tls.QUICEncryptionLevelInitial}, + {protocol.EncryptionHandshake, tls.QUICEncryptionLevelHandshake}, + {protocol.Encryption1RTT, tls.QUICEncryptionLevelApplication}, + {protocol.Encryption0RTT, tls.QUICEncryptionLevelEarly}, + } - It("converts from tls.EncryptionLevel", func() { - Expect(FromTLSEncryptionLevel(tls.QUICEncryptionLevelInitial)).To(Equal(protocol.EncryptionInitial)) - Expect(FromTLSEncryptionLevel(tls.QUICEncryptionLevelHandshake)).To(Equal(protocol.EncryptionHandshake)) - Expect(FromTLSEncryptionLevel(tls.QUICEncryptionLevelApplication)).To(Equal(protocol.Encryption1RTT)) - Expect(FromTLSEncryptionLevel(tls.QUICEncryptionLevelEarly)).To(Equal(protocol.Encryption0RTT)) - }) - - Context("setting up a tls.Config for the client", func() { - It("sets up a session cache if there's one present on the config", func() { - csc := tls.NewLRUClientSessionCache(1) - conf := &tls.QUICConfig{TLSConfig: &tls.Config{ClientSessionCache: csc}} - SetupConfigForClient(conf, nil, nil) - Expect(conf.TLSConfig.ClientSessionCache).ToNot(BeNil()) - Expect(conf.TLSConfig.ClientSessionCache).ToNot(Equal(csc)) + for _, tc := range testCases { + t.Run(tc.quicLevel.String(), func(t *testing.T) { + // conversion from QUIC to TLS encryption level + require.Equal(t, tc.tlsLevel, ToTLSEncryptionLevel(tc.quicLevel)) + // conversion from TLS to QUIC encryption level + require.Equal(t, tc.quicLevel, FromTLSEncryptionLevel(tc.tlsLevel)) }) + } +} - It("doesn't set up a session cache if there's none present on the config", func() { - conf := &tls.QUICConfig{TLSConfig: &tls.Config{}} - SetupConfigForClient(conf, nil, nil) - Expect(conf.TLSConfig.ClientSessionCache).To(BeNil()) - }) - }) +func TestSetupSessionCache(t *testing.T) { + // Test with a session cache present + csc := tls.NewLRUClientSessionCache(1) + confWithCache := &tls.QUICConfig{TLSConfig: &tls.Config{ClientSessionCache: csc}} + SetupConfigForClient(confWithCache, nil, nil) + require.NotNil(t, confWithCache.TLSConfig.ClientSessionCache) + require.NotEqual(t, csc, confWithCache.TLSConfig.ClientSessionCache) - Context("setting up a tls.Config for the server", func() { - var ( - local = &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 42} - remote = &net.UDPAddr{IP: net.IPv4(192, 168, 0, 1), Port: 1337} - ) + // Test without a session cache + confWithoutCache := &tls.QUICConfig{TLSConfig: &tls.Config{}} + SetupConfigForClient(confWithoutCache, nil, nil) + require.Nil(t, confWithoutCache.TLSConfig.ClientSessionCache) +} - It("sets the minimum TLS version to TLS 1.3", func() { - orig := &tls.Config{MinVersion: tls.VersionTLS12} - conf := SetupConfigForServer(orig, local, remote, nil, nil) - Expect(conf.MinVersion).To(BeEquivalentTo(tls.VersionTLS13)) - // check that the original config wasn't modified - Expect(orig.MinVersion).To(BeEquivalentTo(tls.VersionTLS12)) - }) +func TestMinimumTLSVersion(t *testing.T) { + local := &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 42} + remote := &net.UDPAddr{IP: net.IPv4(192, 168, 0, 1), Port: 1337} - It("wraps GetCertificate", func() { - var localAddr, remoteAddr net.Addr - tlsConf := &tls.Config{ - GetCertificate: func(info *tls.ClientHelloInfo) (*tls.Certificate, error) { - localAddr = info.Conn.LocalAddr() - remoteAddr = info.Conn.RemoteAddr() - return &tls.Certificate{}, nil - }, - } - conf := SetupConfigForServer(tlsConf, local, remote, nil, nil) - _, err := conf.GetCertificate(&tls.ClientHelloInfo{}) - Expect(err).ToNot(HaveOccurred()) - Expect(localAddr).To(Equal(local)) - Expect(remoteAddr).To(Equal(remote)) - }) + orig := &tls.Config{MinVersion: tls.VersionTLS12} + conf := SetupConfigForServer(orig, local, remote, nil, nil) + require.EqualValues(t, tls.VersionTLS13, conf.MinVersion) + // check that the original config wasn't modified + require.EqualValues(t, tls.VersionTLS12, orig.MinVersion) +} - It("wraps GetConfigForClient", func() { - var localAddr, remoteAddr net.Addr - tlsConf := SetupConfigForServer( - &tls.Config{ - GetConfigForClient: func(info *tls.ClientHelloInfo) (*tls.Config, error) { - localAddr = info.Conn.LocalAddr() - remoteAddr = info.Conn.RemoteAddr() - return &tls.Config{}, nil - }, - }, - local, - remote, - nil, - nil, - ) - conf, err := tlsConf.GetConfigForClient(&tls.ClientHelloInfo{}) - Expect(err).ToNot(HaveOccurred()) - Expect(localAddr).To(Equal(local)) - Expect(remoteAddr).To(Equal(remote)) - Expect(conf).ToNot(BeNil()) - Expect(conf.MinVersion).To(BeEquivalentTo(tls.VersionTLS13)) - }) +func TestServerConfigGetCertificate(t *testing.T) { + local := &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 42} + remote := &net.UDPAddr{IP: net.IPv4(192, 168, 0, 1), Port: 1337} - It("wraps GetConfigForClient, recursively", func() { - var localAddr, remoteAddr net.Addr - tlsConf := &tls.Config{} - var innerConf *tls.Config - getCert := func(info *tls.ClientHelloInfo) (*tls.Certificate, error) { //nolint:unparam + var localAddr, remoteAddr net.Addr + tlsConf := &tls.Config{ + GetCertificate: func(info *tls.ClientHelloInfo) (*tls.Certificate, error) { + localAddr = info.Conn.LocalAddr() + remoteAddr = info.Conn.RemoteAddr() + return &tls.Certificate{}, nil + }, + } + conf := SetupConfigForServer(tlsConf, local, remote, nil, nil) + _, err := conf.GetCertificate(&tls.ClientHelloInfo{}) + require.NoError(t, err) + require.Equal(t, local, localAddr) + require.Equal(t, remote, remoteAddr) +} + +func TestServerConfigGetConfigForClient(t *testing.T) { + local := &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 42} + remote := &net.UDPAddr{IP: net.IPv4(192, 168, 0, 1), Port: 1337} + + var localAddr, remoteAddr net.Addr + tlsConf := SetupConfigForServer( + &tls.Config{ + GetConfigForClient: func(info *tls.ClientHelloInfo) (*tls.Config, error) { localAddr = info.Conn.LocalAddr() remoteAddr = info.Conn.RemoteAddr() - return &tls.Certificate{}, nil - } - tlsConf.GetConfigForClient = func(info *tls.ClientHelloInfo) (*tls.Config, error) { - innerConf = tlsConf.Clone() - // set the MaxVersion, so we can check that quic-go doesn't overwrite the user's config - innerConf.MaxVersion = tls.VersionTLS12 - innerConf.GetCertificate = getCert - return innerConf, nil - } - tlsConf = SetupConfigForServer(tlsConf, local, remote, nil, nil) - conf, err := tlsConf.GetConfigForClient(&tls.ClientHelloInfo{}) - Expect(err).ToNot(HaveOccurred()) - Expect(conf).ToNot(BeNil()) - Expect(conf.MinVersion).To(BeEquivalentTo(tls.VersionTLS13)) - _, err = conf.GetCertificate(&tls.ClientHelloInfo{}) - Expect(err).ToNot(HaveOccurred()) - Expect(localAddr).To(Equal(local)) - Expect(remoteAddr).To(Equal(remote)) - // make sure that the tls.Config returned by GetConfigForClient isn't modified - Expect(reflect.ValueOf(innerConf.GetCertificate).Pointer() == reflect.ValueOf(getCert).Pointer()).To(BeTrue()) - Expect(innerConf.MaxVersion).To(BeEquivalentTo(tls.VersionTLS12)) - }) - }) -}) + return &tls.Config{}, nil + }, + }, + local, + remote, + nil, + nil, + ) + conf, err := tlsConf.GetConfigForClient(&tls.ClientHelloInfo{}) + require.NoError(t, err) + require.Equal(t, local, localAddr) + require.Equal(t, remote, remoteAddr) + require.NotNil(t, conf) + require.EqualValues(t, tls.VersionTLS13, conf.MinVersion) +} + +func TestServerConfigGetConfigForClientRecursively(t *testing.T) { + local := &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 42} + remote := &net.UDPAddr{IP: net.IPv4(192, 168, 0, 1), Port: 1337} + + var localAddr, remoteAddr net.Addr + tlsConf := &tls.Config{} + var innerConf *tls.Config + //nolint:unparam + getCert := func(info *tls.ClientHelloInfo) (*tls.Certificate, error) { + localAddr = info.Conn.LocalAddr() + remoteAddr = info.Conn.RemoteAddr() + return &tls.Certificate{}, nil + } + tlsConf.GetConfigForClient = func(info *tls.ClientHelloInfo) (*tls.Config, error) { + innerConf = tlsConf.Clone() + // set the MaxVersion, so we can check that quic-go doesn't overwrite the user's config + innerConf.MaxVersion = tls.VersionTLS12 + innerConf.GetCertificate = getCert + return innerConf, nil + } + tlsConf = SetupConfigForServer(tlsConf, local, remote, nil, nil) + conf, err := tlsConf.GetConfigForClient(&tls.ClientHelloInfo{}) + require.NoError(t, err) + require.NotNil(t, conf) + require.EqualValues(t, tls.VersionTLS13, conf.MinVersion) + _, err = conf.GetCertificate(&tls.ClientHelloInfo{}) + require.NoError(t, err) + require.Equal(t, local, localAddr) + require.Equal(t, remote, remoteAddr) + // make sure that the tls.Config returned by GetConfigForClient isn't modified + require.True(t, reflect.ValueOf(innerConf.GetCertificate).Pointer() == reflect.ValueOf(getCert).Pointer()) + require.EqualValues(t, tls.VersionTLS12, innerConf.MaxVersion) +}