From 72c2f9464caec3f46f351ab7e27a6dc2c1f89941 Mon Sep 17 00:00:00 2001 From: Marten Seemann Date: Sun, 9 Feb 2020 19:23:57 +0800 Subject: [PATCH] add a Clone() function to the Config --- config.go | 76 +++++++++++++++++++++++++++++++++++++++++++++ config_test.go | 83 ++++++++++++++++++++++++++++++++++++++++++++++++++ server.go | 67 ---------------------------------------- 3 files changed, 159 insertions(+), 67 deletions(-) create mode 100644 config.go create mode 100644 config_test.go diff --git a/config.go b/config.go new file mode 100644 index 000000000..5d05f3345 --- /dev/null +++ b/config.go @@ -0,0 +1,76 @@ +package quic + +import "github.com/lucas-clemente/quic-go/internal/protocol" + +// Clone clones a Config +func (c *Config) Clone() *Config { + copy := *c + return © +} + +// populateServerConfig populates fields in the quic.Config with their default values, if none are set +// it may be called with nil +func populateServerConfig(config *Config) *Config { + config = populateConfig(config) + if config.ConnectionIDLength == 0 { + config.ConnectionIDLength = protocol.DefaultConnectionIDLength + } + if config.AcceptToken == nil { + config.AcceptToken = defaultAcceptToken + } + return config +} + +func populateConfig(config *Config) *Config { + if config == nil { + config = &Config{} + } + versions := config.Versions + if len(versions) == 0 { + versions = protocol.SupportedVersions + } + handshakeTimeout := protocol.DefaultHandshakeTimeout + if config.HandshakeTimeout != 0 { + handshakeTimeout = config.HandshakeTimeout + } + idleTimeout := protocol.DefaultIdleTimeout + if config.MaxIdleTimeout != 0 { + idleTimeout = config.MaxIdleTimeout + } + maxReceiveStreamFlowControlWindow := config.MaxReceiveStreamFlowControlWindow + if maxReceiveStreamFlowControlWindow == 0 { + maxReceiveStreamFlowControlWindow = protocol.DefaultMaxReceiveStreamFlowControlWindow + } + maxReceiveConnectionFlowControlWindow := config.MaxReceiveConnectionFlowControlWindow + if maxReceiveConnectionFlowControlWindow == 0 { + maxReceiveConnectionFlowControlWindow = protocol.DefaultMaxReceiveConnectionFlowControlWindow + } + maxIncomingStreams := config.MaxIncomingStreams + if maxIncomingStreams == 0 { + maxIncomingStreams = protocol.DefaultMaxIncomingStreams + } else if maxIncomingStreams < 0 { + maxIncomingStreams = 0 + } + maxIncomingUniStreams := config.MaxIncomingUniStreams + if maxIncomingUniStreams == 0 { + maxIncomingUniStreams = protocol.DefaultMaxIncomingUniStreams + } else if maxIncomingUniStreams < 0 { + maxIncomingUniStreams = 0 + } + + return &Config{ + Versions: versions, + HandshakeTimeout: handshakeTimeout, + MaxIdleTimeout: idleTimeout, + AcceptToken: config.AcceptToken, + KeepAlive: config.KeepAlive, + MaxReceiveStreamFlowControlWindow: maxReceiveStreamFlowControlWindow, + MaxReceiveConnectionFlowControlWindow: maxReceiveConnectionFlowControlWindow, + MaxIncomingStreams: maxIncomingStreams, + MaxIncomingUniStreams: maxIncomingUniStreams, + ConnectionIDLength: config.ConnectionIDLength, + StatelessResetKey: config.StatelessResetKey, + TokenStore: config.TokenStore, + QuicTracer: config.QuicTracer, + } +} diff --git a/config_test.go b/config_test.go new file mode 100644 index 000000000..239c1cb8b --- /dev/null +++ b/config_test.go @@ -0,0 +1,83 @@ +package quic + +import ( + "fmt" + "net" + "reflect" + "time" + + "github.com/lucas-clemente/quic-go/quictrace" + + . "github.com/onsi/ginkgo" + . "github.com/onsi/gomega" +) + +var _ = Describe("Config", func() { + It("clones function fields", func() { + var called bool + c1 := &Config{AcceptToken: func(_ net.Addr, _ *Token) bool { called = true; return true }} + c2 := c1.Clone() + c2.AcceptToken(&net.UDPAddr{}, &Token{}) + Expect(called).To(BeTrue()) + }) + + It("clones non-function fields", func() { + c1 := &Config{} + v := reflect.ValueOf(c1).Elem() + + typ := v.Type() + for i := 0; i < typ.NumField(); i++ { + f := v.Field(i) + if !f.CanSet() { + // unexported field; not cloned. + continue + } + + switch fn := typ.Field(i).Name; fn { + case "AcceptToken": + // Can't compare functions. + case "Versions": + f.Set(reflect.ValueOf([]VersionNumber{1, 2, 3})) + case "ConnectionIDLength": + f.Set(reflect.ValueOf(8)) + case "HandshakeTimeout": + f.Set(reflect.ValueOf(time.Second)) + case "MaxIdleTimeout": + f.Set(reflect.ValueOf(time.Hour)) + case "TokenStore": + f.Set(reflect.ValueOf(NewLRUTokenStore(2, 3))) + case "MaxReceiveStreamFlowControlWindow": + f.Set(reflect.ValueOf(uint64(9))) + case "MaxReceiveConnectionFlowControlWindow": + f.Set(reflect.ValueOf(uint64(10))) + case "MaxIncomingStreams": + f.Set(reflect.ValueOf(11)) + case "MaxIncomingUniStreams": + f.Set(reflect.ValueOf(12)) + case "StatelessResetKey": + f.Set(reflect.ValueOf([]byte{1, 2, 3, 4})) + case "KeepAlive": + f.Set(reflect.ValueOf(true)) + case "QuicTracer": + f.Set(reflect.ValueOf(quictrace.NewTracer())) + default: + Fail(fmt.Sprintf("all fields must be accounted for, but saw unknown field %q", fn)) + } + } + + Expect(c1.Clone()).To(Equal(c1)) + }) + + It("returns a copy", func() { + c1 := &Config{ + MaxIncomingStreams: 100, + AcceptToken: func(_ net.Addr, _ *Token) bool { return true }, + } + c2 := c1.Clone() + c2.MaxIncomingStreams = 200 + c2.AcceptToken = func(_ net.Addr, _ *Token) bool { return false } + + Expect(c1.MaxIncomingStreams).To(BeEquivalentTo(100)) + Expect(c1.AcceptToken(&net.UDPAddr{}, nil)).To(BeTrue()) + }) +}) diff --git a/server.go b/server.go index 01d851da8..08dd035dc 100644 --- a/server.go +++ b/server.go @@ -224,73 +224,6 @@ var defaultAcceptToken = func(clientAddr net.Addr, token *Token) bool { return sourceAddr == token.RemoteAddr } -// populateServerConfig populates fields in the quic.Config with their default values, if none are set -// it may be called with nil -func populateServerConfig(config *Config) *Config { - config = populateConfig(config) - if config.ConnectionIDLength == 0 { - config.ConnectionIDLength = protocol.DefaultConnectionIDLength - } - if config.AcceptToken == nil { - config.AcceptToken = defaultAcceptToken - } - return config -} - -func populateConfig(config *Config) *Config { - if config == nil { - config = &Config{} - } - versions := config.Versions - if len(versions) == 0 { - versions = protocol.SupportedVersions - } - handshakeTimeout := protocol.DefaultHandshakeTimeout - if config.HandshakeTimeout != 0 { - handshakeTimeout = config.HandshakeTimeout - } - idleTimeout := protocol.DefaultIdleTimeout - if config.MaxIdleTimeout != 0 { - idleTimeout = config.MaxIdleTimeout - } - maxReceiveStreamFlowControlWindow := config.MaxReceiveStreamFlowControlWindow - if maxReceiveStreamFlowControlWindow == 0 { - maxReceiveStreamFlowControlWindow = protocol.DefaultMaxReceiveStreamFlowControlWindow - } - maxReceiveConnectionFlowControlWindow := config.MaxReceiveConnectionFlowControlWindow - if maxReceiveConnectionFlowControlWindow == 0 { - maxReceiveConnectionFlowControlWindow = protocol.DefaultMaxReceiveConnectionFlowControlWindow - } - maxIncomingStreams := config.MaxIncomingStreams - if maxIncomingStreams == 0 { - maxIncomingStreams = protocol.DefaultMaxIncomingStreams - } else if maxIncomingStreams < 0 { - maxIncomingStreams = 0 - } - maxIncomingUniStreams := config.MaxIncomingUniStreams - if maxIncomingUniStreams == 0 { - maxIncomingUniStreams = protocol.DefaultMaxIncomingUniStreams - } else if maxIncomingUniStreams < 0 { - maxIncomingUniStreams = 0 - } - - return &Config{ - Versions: versions, - HandshakeTimeout: handshakeTimeout, - MaxIdleTimeout: idleTimeout, - AcceptToken: config.AcceptToken, - KeepAlive: config.KeepAlive, - MaxReceiveStreamFlowControlWindow: maxReceiveStreamFlowControlWindow, - MaxReceiveConnectionFlowControlWindow: maxReceiveConnectionFlowControlWindow, - MaxIncomingStreams: maxIncomingStreams, - MaxIncomingUniStreams: maxIncomingUniStreams, - ConnectionIDLength: config.ConnectionIDLength, - StatelessResetKey: config.StatelessResetKey, - TokenStore: config.TokenStore, - QuicTracer: config.QuicTracer, - } -} - // Accept returns sessions that already completed the handshake. // It is only valid if acceptEarlySessions is false. func (s *baseServer) Accept(ctx context.Context) (Session, error) {