From 0264fd6f4f082ca11d8240e64b98e83c4edecae0 Mon Sep 17 00:00:00 2001 From: Marten Seemann Date: Mon, 23 Dec 2024 19:04:26 +0800 Subject: [PATCH] migrate the config tests away from Ginkgo (#4790) --- config_test.go | 328 ++++++++++++++++++++++++------------------------- 1 file changed, 161 insertions(+), 167 deletions(-) diff --git a/config_test.go b/config_test.go index 6008a5273..ff4dd042d 100644 --- a/config_test.go +++ b/config_test.go @@ -3,190 +3,184 @@ package quic import ( "context" "errors" - "fmt" "reflect" + "testing" "time" "github.com/quic-go/quic-go/internal/protocol" "github.com/quic-go/quic-go/logging" "github.com/quic-go/quic-go/quicvarint" - . "github.com/onsi/ginkgo/v2" - . "github.com/onsi/gomega" + "github.com/stretchr/testify/require" ) -var _ = Describe("Config", func() { - Context("validating", func() { - It("validates a nil config", func() { - Expect(validateConfig(nil)).To(Succeed()) - }) - - It("validates a config with normal values", func() { - conf := populateConfig(&Config{ - MaxIncomingStreams: 5, - MaxStreamReceiveWindow: 10, - }) - Expect(validateConfig(conf)).To(Succeed()) - Expect(conf.MaxIncomingStreams).To(BeEquivalentTo(5)) - Expect(conf.MaxStreamReceiveWindow).To(BeEquivalentTo(10)) - }) - - It("clips too large values for the stream limits", func() { - conf := &Config{ - MaxIncomingStreams: 1<<60 + 1, - MaxIncomingUniStreams: 1<<60 + 2, - } - Expect(validateConfig(conf)).To(Succeed()) - Expect(conf.MaxIncomingStreams).To(BeEquivalentTo(int64(1 << 60))) - Expect(conf.MaxIncomingUniStreams).To(BeEquivalentTo(int64(1 << 60))) - }) - - It("clips too large values for the flow control windows", func() { - conf := &Config{ - MaxStreamReceiveWindow: quicvarint.Max + 1, - MaxConnectionReceiveWindow: quicvarint.Max + 2, - } - Expect(validateConfig(conf)).To(Succeed()) - Expect(conf.MaxStreamReceiveWindow).To(BeEquivalentTo(uint64(quicvarint.Max))) - Expect(conf.MaxConnectionReceiveWindow).To(BeEquivalentTo(uint64(quicvarint.Max))) - }) - - It("increases too small packet sizes", func() { - conf := &Config{InitialPacketSize: 10} - Expect(validateConfig(conf)).To(Succeed()) - Expect(conf.InitialPacketSize).To(BeEquivalentTo(1200)) - }) - - It("clips too large packet sizes", func() { - conf := &Config{InitialPacketSize: protocol.MaxPacketBufferSize + 1} - Expect(validateConfig(conf)).To(Succeed()) - Expect(conf.InitialPacketSize).To(BeEquivalentTo(protocol.MaxPacketBufferSize)) - }) - - It("doesn't modify the InitialPacketSize if it is unset", func() { - conf := &Config{InitialPacketSize: 0} - Expect(validateConfig(conf)).To(Succeed()) - Expect(conf.InitialPacketSize).To(BeZero()) - }) +func TestConfigValidation(t *testing.T) { + t.Run("nil config", func(t *testing.T) { + require.NoError(t, validateConfig(nil)) }) - configWithNonZeroNonFunctionFields := func() *Config { - c := &Config{} - v := reflect.ValueOf(c).Elem() + t.Run("config with a few values set", func(t *testing.T) { + conf := populateConfig(&Config{ + MaxIncomingStreams: 5, + MaxStreamReceiveWindow: 10, + }) + require.NoError(t, validateConfig(conf)) + require.Equal(t, int64(5), conf.MaxIncomingStreams) + require.Equal(t, uint64(10), conf.MaxStreamReceiveWindow) + }) - typ := v.Type() - for i := 0; i < typ.NumField(); i++ { - f := v.Field(i) - if !f.CanSet() { - // unexported field; not cloned. - continue - } - - switch fn := typ.Field(i).Name; fn { - case "GetConfigForClient", "RequireAddressValidation", "GetLogWriter", "AllowConnectionWindowIncrease", "Tracer": - // Can't compare functions. - case "Versions": - f.Set(reflect.ValueOf([]Version{1, 2, 3})) - case "ConnectionIDLength": - f.Set(reflect.ValueOf(8)) - case "ConnectionIDGenerator": - f.Set(reflect.ValueOf(&protocol.DefaultConnectionIDGenerator{ConnLen: protocol.DefaultConnectionIDLength})) - case "HandshakeIdleTimeout": - f.Set(reflect.ValueOf(time.Second)) - case "MaxIdleTimeout": - f.Set(reflect.ValueOf(time.Hour)) - case "TokenStore": - f.Set(reflect.ValueOf(NewLRUTokenStore(2, 3))) - case "InitialStreamReceiveWindow": - f.Set(reflect.ValueOf(uint64(1234))) - case "MaxStreamReceiveWindow": - f.Set(reflect.ValueOf(uint64(9))) - case "InitialConnectionReceiveWindow": - f.Set(reflect.ValueOf(uint64(4321))) - case "MaxConnectionReceiveWindow": - f.Set(reflect.ValueOf(uint64(10))) - case "MaxIncomingStreams": - f.Set(reflect.ValueOf(int64(11))) - case "MaxIncomingUniStreams": - f.Set(reflect.ValueOf(int64(12))) - case "StatelessResetKey": - f.Set(reflect.ValueOf(&StatelessResetKey{1, 2, 3, 4})) - case "KeepAlivePeriod": - f.Set(reflect.ValueOf(time.Second)) - case "EnableDatagrams": - f.Set(reflect.ValueOf(true)) - case "DisableVersionNegotiationPackets": - f.Set(reflect.ValueOf(true)) - case "InitialPacketSize": - f.Set(reflect.ValueOf(uint16(1350))) - case "DisablePathMTUDiscovery": - f.Set(reflect.ValueOf(true)) - case "Allow0RTT": - f.Set(reflect.ValueOf(true)) - default: - Fail(fmt.Sprintf("all fields must be accounted for, but saw unknown field %q", fn)) - } + t.Run("stream limits", func(t *testing.T) { + conf := &Config{ + MaxIncomingStreams: 1<<60 + 1, + MaxIncomingUniStreams: 1<<60 + 2, + } + require.NoError(t, validateConfig(conf)) + require.Equal(t, int64(1<<60), conf.MaxIncomingStreams) + require.Equal(t, int64(1<<60), conf.MaxIncomingUniStreams) + }) + + t.Run("flow control windows", func(t *testing.T) { + conf := &Config{ + MaxStreamReceiveWindow: quicvarint.Max + 1, + MaxConnectionReceiveWindow: quicvarint.Max + 2, + } + require.NoError(t, validateConfig(conf)) + require.Equal(t, uint64(quicvarint.Max), conf.MaxStreamReceiveWindow) + require.Equal(t, uint64(quicvarint.Max), conf.MaxConnectionReceiveWindow) + }) + + t.Run("initial packet size", func(t *testing.T) { + // not set + conf := &Config{InitialPacketSize: 0} + require.NoError(t, validateConfig(conf)) + require.Zero(t, conf.InitialPacketSize) + + // too small + conf = &Config{InitialPacketSize: 10} + require.NoError(t, validateConfig(conf)) + require.Equal(t, uint16(1200), conf.InitialPacketSize) + + // too large + conf = &Config{InitialPacketSize: protocol.MaxPacketBufferSize + 1} + require.NoError(t, validateConfig(conf)) + require.Equal(t, uint16(protocol.MaxPacketBufferSize), conf.InitialPacketSize) + }) +} + +func TestConfigHandshakeIdleTimeout(t *testing.T) { + c := &Config{HandshakeIdleTimeout: time.Second * 11 / 2} + require.Equal(t, 11*time.Second, c.handshakeTimeout()) +} + +func configWithNonZeroNonFunctionFields(t *testing.T) *Config { + t.Helper() + c := &Config{} + v := reflect.ValueOf(c).Elem() + + typ := v.Type() + for i := 0; i < typ.NumField(); i++ { + f := v.Field(i) + if !f.CanSet() { + // unexported field; not cloned. + continue + } + + switch fn := typ.Field(i).Name; fn { + case "GetConfigForClient", "RequireAddressValidation", "GetLogWriter", "AllowConnectionWindowIncrease", "Tracer": + // Can't compare functions. + case "Versions": + f.Set(reflect.ValueOf([]Version{1, 2, 3})) + case "ConnectionIDLength": + f.Set(reflect.ValueOf(8)) + case "ConnectionIDGenerator": + f.Set(reflect.ValueOf(&protocol.DefaultConnectionIDGenerator{ConnLen: protocol.DefaultConnectionIDLength})) + case "HandshakeIdleTimeout": + f.Set(reflect.ValueOf(time.Second)) + case "MaxIdleTimeout": + f.Set(reflect.ValueOf(time.Hour)) + case "TokenStore": + f.Set(reflect.ValueOf(NewLRUTokenStore(2, 3))) + case "InitialStreamReceiveWindow": + f.Set(reflect.ValueOf(uint64(1234))) + case "MaxStreamReceiveWindow": + f.Set(reflect.ValueOf(uint64(9))) + case "InitialConnectionReceiveWindow": + f.Set(reflect.ValueOf(uint64(4321))) + case "MaxConnectionReceiveWindow": + f.Set(reflect.ValueOf(uint64(10))) + case "MaxIncomingStreams": + f.Set(reflect.ValueOf(int64(11))) + case "MaxIncomingUniStreams": + f.Set(reflect.ValueOf(int64(12))) + case "StatelessResetKey": + f.Set(reflect.ValueOf(&StatelessResetKey{1, 2, 3, 4})) + case "KeepAlivePeriod": + f.Set(reflect.ValueOf(time.Second)) + case "EnableDatagrams": + f.Set(reflect.ValueOf(true)) + case "DisableVersionNegotiationPackets": + f.Set(reflect.ValueOf(true)) + case "InitialPacketSize": + f.Set(reflect.ValueOf(uint16(1350))) + case "DisablePathMTUDiscovery": + f.Set(reflect.ValueOf(true)) + case "Allow0RTT": + f.Set(reflect.ValueOf(true)) + default: + t.Fatalf("all fields must be accounted for, but saw unknown field %q", fn) } - return c } + return c +} - It("uses twice the handshake idle timeouts for the handshake timeout", func() { - c := &Config{HandshakeIdleTimeout: time.Second * 11 / 2} - Expect(c.handshakeTimeout()).To(Equal(11 * time.Second)) +func TestConfigCloning(t *testing.T) { + t.Run("function fields", func(t *testing.T) { + var calledAllowConnectionWindowIncrease, calledTracer bool + c1 := &Config{ + GetConfigForClient: func(info *ClientHelloInfo) (*Config, error) { return nil, errors.New("nope") }, + AllowConnectionWindowIncrease: func(Connection, uint64) bool { calledAllowConnectionWindowIncrease = true; return true }, + Tracer: func(context.Context, logging.Perspective, ConnectionID) *logging.ConnectionTracer { + calledTracer = true + return nil + }, + } + c2 := c1.Clone() + c2.AllowConnectionWindowIncrease(nil, 1234) + require.True(t, calledAllowConnectionWindowIncrease) + _, err := c2.GetConfigForClient(&ClientHelloInfo{}) + require.EqualError(t, err, "nope") + c2.Tracer(context.Background(), logging.PerspectiveClient, protocol.ConnectionID{}) + require.True(t, calledTracer) }) - Context("cloning", func() { - It("clones function fields", func() { - var calledAllowConnectionWindowIncrease, calledTracer bool - c1 := &Config{ - GetConfigForClient: func(info *ClientHelloInfo) (*Config, error) { return nil, errors.New("nope") }, - AllowConnectionWindowIncrease: func(Connection, uint64) bool { calledAllowConnectionWindowIncrease = true; return true }, - Tracer: func(context.Context, logging.Perspective, ConnectionID) *logging.ConnectionTracer { - calledTracer = true - return nil - }, - } - c2 := c1.Clone() - c2.AllowConnectionWindowIncrease(nil, 1234) - Expect(calledAllowConnectionWindowIncrease).To(BeTrue()) - _, err := c2.GetConfigForClient(&ClientHelloInfo{}) - Expect(err).To(MatchError("nope")) - c2.Tracer(context.Background(), logging.PerspectiveClient, protocol.ConnectionID{}) - Expect(calledTracer).To(BeTrue()) - }) - - It("clones non-function fields", func() { - c := configWithNonZeroNonFunctionFields() - Expect(c.Clone()).To(Equal(c)) - }) - - It("returns a copy", func() { - c1 := &Config{MaxIncomingStreams: 100} - c2 := c1.Clone() - c2.MaxIncomingStreams = 200 - - Expect(c1.MaxIncomingStreams).To(BeEquivalentTo(100)) - }) + t.Run("clones non-function fields", func(t *testing.T) { + c := configWithNonZeroNonFunctionFields(t) + require.Equal(t, c, c.Clone()) }) - Context("populating", func() { - It("copies non-function fields", func() { - c := configWithNonZeroNonFunctionFields() - Expect(populateConfig(c)).To(Equal(c)) - }) - - It("populates empty fields with default values", func() { - c := populateConfig(&Config{}) - Expect(c.Versions).To(Equal(protocol.SupportedVersions)) - Expect(c.HandshakeIdleTimeout).To(Equal(protocol.DefaultHandshakeIdleTimeout)) - Expect(c.InitialStreamReceiveWindow).To(BeEquivalentTo(protocol.DefaultInitialMaxStreamData)) - Expect(c.MaxStreamReceiveWindow).To(BeEquivalentTo(protocol.DefaultMaxReceiveStreamFlowControlWindow)) - Expect(c.InitialConnectionReceiveWindow).To(BeEquivalentTo(protocol.DefaultInitialMaxData)) - Expect(c.MaxConnectionReceiveWindow).To(BeEquivalentTo(protocol.DefaultMaxReceiveConnectionFlowControlWindow)) - Expect(c.MaxIncomingStreams).To(BeEquivalentTo(protocol.DefaultMaxIncomingStreams)) - Expect(c.MaxIncomingUniStreams).To(BeEquivalentTo(protocol.DefaultMaxIncomingUniStreams)) - Expect(c.DisablePathMTUDiscovery).To(BeFalse()) - Expect(c.GetConfigForClient).To(BeNil()) - }) + t.Run("returns a copy", func(t *testing.T) { + c1 := &Config{MaxIncomingStreams: 100} + c2 := c1.Clone() + c2.MaxIncomingStreams = 200 + require.EqualValues(t, 100, c1.MaxIncomingStreams) }) -}) +} + +func TestConfigDefaultValues(t *testing.T) { + // if set, the values should be copied + c := configWithNonZeroNonFunctionFields(t) + require.Equal(t, c, populateConfig(c)) + + // if not set, some fields use default values + c = populateConfig(&Config{}) + require.Equal(t, protocol.SupportedVersions, c.Versions) + require.Equal(t, protocol.DefaultHandshakeIdleTimeout, c.HandshakeIdleTimeout) + require.EqualValues(t, protocol.DefaultInitialMaxStreamData, c.InitialStreamReceiveWindow) + require.EqualValues(t, protocol.DefaultMaxReceiveStreamFlowControlWindow, c.MaxStreamReceiveWindow) + require.EqualValues(t, protocol.DefaultInitialMaxData, c.InitialConnectionReceiveWindow) + require.EqualValues(t, protocol.DefaultMaxReceiveConnectionFlowControlWindow, c.MaxConnectionReceiveWindow) + require.EqualValues(t, protocol.DefaultMaxIncomingStreams, c.MaxIncomingStreams) + require.EqualValues(t, protocol.DefaultMaxIncomingUniStreams, c.MaxIncomingUniStreams) + require.False(t, c.DisablePathMTUDiscovery) + require.Nil(t, c.GetConfigForClient) +}