From f02dc92a32a1b590ceb52b5ded1fdd42aca66c55 Mon Sep 17 00:00:00 2001 From: Marten Seemann Date: Sun, 1 Jul 2018 16:03:33 +0700 Subject: [PATCH] make the connection ID length configurable --- Changelog.md | 4 +++ client.go | 29 +++++++++++----- client_multiplexer.go | 40 +++++++++++++--------- client_multiplexer_test.go | 25 +++++++++----- client_test.go | 46 +++++++++++++++++--------- interface.go | 7 ++++ internal/protocol/server_parameters.go | 4 +++ mock_multiplexer_test.go | 11 +++--- server.go | 7 +++- server_test.go | 8 +++++ server_tls.go | 4 +-- session_test.go | 2 +- 12 files changed, 131 insertions(+), 56 deletions(-) diff --git a/Changelog.md b/Changelog.md index 7f22ae3ad..490541375 100644 --- a/Changelog.md +++ b/Changelog.md @@ -1,5 +1,9 @@ # Changelog +## v0.9.0 (unreleased) + +- Add a `quic.Config` option for the length of the connection ID (for IETF QUIC). + ## v0.8.0 (2018-06-26) - Add support for unidirectional streams (for IETF QUIC). diff --git a/client.go b/client.go index 1a3d9286a..ee2f6a3cd 100644 --- a/client.go +++ b/client.go @@ -74,6 +74,7 @@ func DialAddrContext( tlsConf *tls.Config, config *Config, ) (Session, error) { + config = populateClientConfig(config, false) udpAddr, err := net.ResolveUDPAddr("udp", addr) if err != nil { return nil, err @@ -115,8 +116,12 @@ func DialContext( tlsConf *tls.Config, config *Config, ) (Session, error) { + config = populateClientConfig(config, true) multiplexer := getClientMultiplexer() - manager := multiplexer.AddConn(pconn) + manager, err := multiplexer.AddConn(pconn, config.ConnectionIDLength) + if err != nil { + return nil, err + } c, err := newClient(pconn, remoteAddr, config, tlsConf, host, manager.Remove) if err != nil { return nil, err @@ -138,9 +143,6 @@ func newClient( host string, closeCallback func(protocol.ConnectionID), ) (*client, error) { - clientConfig := populateClientConfig(config) - version := clientConfig.Versions[0] - var hostname string if tlsConf != nil { hostname = tlsConf.ServerName @@ -169,8 +171,8 @@ func newClient( conn: &conn{pconn: pconn, currentAddr: remoteAddr}, hostname: hostname, tlsConf: tlsConf, - config: clientConfig, - version: version, + config: config, + version: config.Versions[0], handshakeChan: make(chan struct{}), closeCallback: onClose, logger: utils.DefaultLogger.WithPrefix("client"), @@ -180,7 +182,7 @@ func newClient( // populateClientConfig populates fields in the quic.Config with their default values, if none are set // it may be called with nil -func populateClientConfig(config *Config) *Config { +func populateClientConfig(config *Config, onPacketConn bool) *Config { if config == nil { config = &Config{} } @@ -218,12 +220,17 @@ func populateClientConfig(config *Config) *Config { } else if maxIncomingUniStreams < 0 { maxIncomingUniStreams = 0 } + connIDLen := config.ConnectionIDLength + if connIDLen == 0 && onPacketConn { + connIDLen = protocol.DefaultConnectionIDLength + } return &Config{ Versions: versions, HandshakeTimeout: handshakeTimeout, IdleTimeout: idleTimeout, RequestConnectionIDOmission: config.RequestConnectionIDOmission, + ConnectionIDLength: connIDLen, MaxReceiveStreamFlowControlWindow: maxReceiveStreamFlowControlWindow, MaxReceiveConnectionFlowControlWindow: maxReceiveConnectionFlowControlWindow, MaxIncomingStreams: maxIncomingStreams, @@ -233,7 +240,11 @@ func populateClientConfig(config *Config) *Config { } func (c *client) generateConnectionIDs() error { - srcConnID, err := generateConnectionID(protocol.ConnectionIDLenGQUIC) + connIDLen := protocol.ConnectionIDLenGQUIC + if c.version.UsesTLS() { + connIDLen = c.config.ConnectionIDLength + } + srcConnID, err := generateConnectionID(connIDLen) if err != nil { return err } @@ -370,7 +381,7 @@ func (c *client) handleRead(remoteAddr net.Addr, packet []byte) { rcvTime := time.Now() r := bytes.NewReader(packet) - iHdr, err := wire.ParseInvariantHeader(r, protocol.ConnectionIDLenGQUIC) + iHdr, err := wire.ParseInvariantHeader(r, c.config.ConnectionIDLength) // drop the packet if we can't parse the header if err != nil { c.logger.Errorf("error parsing invariant header: %s", err) diff --git a/client_multiplexer.go b/client_multiplexer.go index 1b9a28824..b53c315f7 100644 --- a/client_multiplexer.go +++ b/client_multiplexer.go @@ -3,6 +3,7 @@ package quic import ( "bytes" "errors" + "fmt" "net" "strings" "sync" @@ -19,16 +20,21 @@ var ( ) type multiplexer interface { - AddConn(net.PacketConn) packetHandlerManager + AddConn(net.PacketConn, int) (packetHandlerManager, error) AddHandler(net.PacketConn, protocol.ConnectionID, packetHandler) error } +type connManager struct { + connIDLen int + manager packetHandlerManager +} + // The clientMultiplexer listens on multiple net.PacketConns and dispatches // incoming packets to the session handler. type clientMultiplexer struct { mutex sync.Mutex - conns map[net.PacketConn]packetHandlerManager + conns map[net.PacketConn]connManager newPacketHandlerManager func() packetHandlerManager // so it can be replaced in the tests logger utils.Logger @@ -39,7 +45,7 @@ var _ multiplexer = &clientMultiplexer{} func getClientMultiplexer() multiplexer { clientMuxerOnce.Do(func() { clientMuxer = &clientMultiplexer{ - conns: make(map[net.PacketConn]packetHandlerManager), + conns: make(map[net.PacketConn]connManager), logger: utils.DefaultLogger.WithPrefix("client muxer"), newPacketHandlerManager: newPacketHandlerMap, } @@ -47,30 +53,34 @@ func getClientMultiplexer() multiplexer { return clientMuxer } -func (m *clientMultiplexer) AddConn(c net.PacketConn) packetHandlerManager { +func (m *clientMultiplexer) AddConn(c net.PacketConn, connIDLen int) (packetHandlerManager, error) { m.mutex.Lock() defer m.mutex.Unlock() - sessions, ok := m.conns[c] + p, ok := m.conns[c] if !ok { - sessions = m.newPacketHandlerManager() - m.conns[c] = sessions + manager := m.newPacketHandlerManager() + p = connManager{connIDLen: connIDLen, manager: manager} + m.conns[c] = p // If we didn't know this packet conn before, listen for incoming packets // and dispatch them to the right sessions. - go m.listen(c, sessions) + go m.listen(c, p) } - return sessions + if p.connIDLen != connIDLen { + return nil, fmt.Errorf("cannot use %d byte connection IDs on a connection that is already using %d byte connction IDs", connIDLen, p.connIDLen) + } + return p.manager, nil } func (m *clientMultiplexer) AddHandler(c net.PacketConn, connID protocol.ConnectionID, handler packetHandler) error { - sessions, ok := m.conns[c] + p, ok := m.conns[c] if !ok { return errors.New("unknown packet conn %s") } - sessions.Add(connID, handler) + p.manager.Add(connID, handler) return nil } -func (m *clientMultiplexer) listen(c net.PacketConn, sessions packetHandlerManager) { +func (m *clientMultiplexer) listen(c net.PacketConn, p connManager) { for { data := *getPacketBuffer() data = data[:protocol.MaxReceivePacketSize] @@ -79,7 +89,7 @@ func (m *clientMultiplexer) listen(c net.PacketConn, sessions packetHandlerManag n, addr, err := c.ReadFrom(data) if err != nil { if !strings.HasSuffix(err.Error(), "use of closed network connection") { - sessions.Close(err) + p.manager.Close(err) } return } @@ -87,13 +97,13 @@ func (m *clientMultiplexer) listen(c net.PacketConn, sessions packetHandlerManag rcvTime := time.Now() r := bytes.NewReader(data) - iHdr, err := wire.ParseInvariantHeader(r, protocol.ConnectionIDLenGQUIC) + iHdr, err := wire.ParseInvariantHeader(r, p.connIDLen) // drop the packet if we can't parse the header if err != nil { m.logger.Debugf("error parsing invariant header from %s: %s", addr, err) continue } - client, ok := sessions.Get(iHdr.DestConnectionID) + client, ok := p.manager.Get(iHdr.DestConnectionID) if !ok { m.logger.Debugf("received a packet with an unexpected connection ID %s", iHdr.DestConnectionID) continue diff --git a/client_multiplexer_test.go b/client_multiplexer_test.go index 87109572e..b5f1d4051 100644 --- a/client_multiplexer_test.go +++ b/client_multiplexer_test.go @@ -29,11 +29,12 @@ var _ = Describe("Client Multiplexer", func() { connID := protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8} packetHandler := NewMockQuicSession(mockCtrl) handledPacket := make(chan struct{}) - packetHandler.EXPECT().handlePacket(gomock.Any()).Do(func(_ *receivedPacket) { + packetHandler.EXPECT().handlePacket(gomock.Any()).Do(func(p *receivedPacket) { + Expect(p.header.DestConnectionID).To(Equal(connID)) close(handledPacket) }) packetHandler.EXPECT().GetVersion() - getClientMultiplexer().AddConn(conn) + getClientMultiplexer().AddConn(conn, 8) err := getClientMultiplexer().AddHandler(conn, connID, packetHandler) Expect(err).ToNot(HaveOccurred()) conn.dataToRead <- getPacket(connID) @@ -43,6 +44,14 @@ var _ = Describe("Client Multiplexer", func() { close(conn.dataToRead) }) + It("errors when adding an existing conn with a different connection ID length", func() { + conn := newMockPacketConn() + _, err := getClientMultiplexer().AddConn(conn, 5) + Expect(err).ToNot(HaveOccurred()) + _, err = getClientMultiplexer().AddConn(conn, 6) + Expect(err).To(MatchError("cannot use 6 byte connection IDs on a connection that is already using 5 byte connction IDs")) + }) + It("errors when adding a handler for an unknown conn", func() { conn := newMockPacketConn() err := getClientMultiplexer().AddHandler(conn, protocol.ConnectionID{1, 2, 3, 4}, NewMockQuicSession(mockCtrl)) @@ -67,7 +76,7 @@ var _ = Describe("Client Multiplexer", func() { close(handledPacket2) }) packetHandler2.EXPECT().GetVersion() - getClientMultiplexer().AddConn(conn) + getClientMultiplexer().AddConn(conn, connID1.Len()) Expect(getClientMultiplexer().AddHandler(conn, connID1, packetHandler1)).To(Succeed()) Expect(getClientMultiplexer().AddHandler(conn, connID2, packetHandler2)).To(Succeed()) @@ -84,10 +93,10 @@ var _ = Describe("Client Multiplexer", func() { It("drops unparseable packets", func() { conn := newMockPacketConn() - connID := protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8} + connID := protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7} conn.dataToRead <- []byte("invalid header") packetHandler := NewMockQuicSession(mockCtrl) - getClientMultiplexer().AddConn(conn) + getClientMultiplexer().AddConn(conn, 7) Expect(getClientMultiplexer().AddHandler(conn, connID, packetHandler)).To(Succeed()) time.Sleep(100 * time.Millisecond) // give the listen go routine some time to process the packet packetHandler.EXPECT().Close(gomock.Any()).AnyTimes() @@ -106,7 +115,7 @@ var _ = Describe("Client Multiplexer", func() { connID := protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8} done := make(chan struct{}) manager.EXPECT().Get(connID).Do(func(protocol.ConnectionID) { close(done) }).Return(nil, true) - getClientMultiplexer().AddConn(conn) + getClientMultiplexer().AddConn(conn, 8) conn.dataToRead <- getPacket(connID) Eventually(done).Should(BeClosed()) // makes the listen go routine return @@ -118,7 +127,7 @@ var _ = Describe("Client Multiplexer", func() { conn := newMockPacketConn() conn.dataToRead <- getPacket(protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8}) packetHandler := NewMockQuicSession(mockCtrl) - getClientMultiplexer().AddConn(conn) + getClientMultiplexer().AddConn(conn, 8) Expect(getClientMultiplexer().AddHandler(conn, protocol.ConnectionID{8, 7, 6, 5, 4, 3, 2, 1}, packetHandler)).To(Succeed()) time.Sleep(100 * time.Millisecond) // give the listen go routine some time to process the packet // makes the listen go routine return @@ -135,7 +144,7 @@ var _ = Describe("Client Multiplexer", func() { packetHandler.EXPECT().Close(testErr).Do(func(error) { close(done) }) - getClientMultiplexer().AddConn(conn) + getClientMultiplexer().AddConn(conn, 8) Expect(getClientMultiplexer().AddHandler(conn, protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8}, packetHandler)).To(Succeed()) Eventually(done).Should(BeClosed()) }) diff --git a/client_test.go b/client_test.go index d0856ab69..f653efec1 100644 --- a/client_test.go +++ b/client_test.go @@ -147,7 +147,7 @@ var _ = Describe("Client", func() { It("returns after the handshake is complete", func() { manager := NewMockPacketHandlerManager(mockCtrl) - mockMultiplexer.EXPECT().AddConn(packetConn).Return(manager) + mockMultiplexer.EXPECT().AddConn(packetConn, gomock.Any()).Return(manager, nil) mockMultiplexer.EXPECT().AddHandler(packetConn, gomock.Any(), gomock.Any()) run := make(chan struct{}) @@ -176,7 +176,7 @@ var _ = Describe("Client", func() { It("returns an error that occurs while waiting for the connection to become secure", func() { manager := NewMockPacketHandlerManager(mockCtrl) - mockMultiplexer.EXPECT().AddConn(packetConn).Return(manager) + mockMultiplexer.EXPECT().AddConn(packetConn, gomock.Any()).Return(manager, nil) mockMultiplexer.EXPECT().AddHandler(packetConn, gomock.Any(), gomock.Any()) testErr := errors.New("early handshake error") @@ -203,7 +203,7 @@ var _ = Describe("Client", func() { It("closes the session when the context is canceled", func() { manager := NewMockPacketHandlerManager(mockCtrl) - mockMultiplexer.EXPECT().AddConn(packetConn).Return(manager) + mockMultiplexer.EXPECT().AddConn(packetConn, gomock.Any()).Return(manager, nil) mockMultiplexer.EXPECT().AddHandler(packetConn, gomock.Any(), gomock.Any()) sessionRunning := make(chan struct{}) @@ -243,7 +243,7 @@ var _ = Describe("Client", func() { It("removes closed sessions from the multiplexer", func() { manager := NewMockPacketHandlerManager(mockCtrl) manager.EXPECT().Remove(connID) - mockMultiplexer.EXPECT().AddConn(packetConn).Return(manager) + mockMultiplexer.EXPECT().AddConn(packetConn, gomock.Any()).Return(manager, nil) mockMultiplexer.EXPECT().AddHandler(packetConn, gomock.Any(), gomock.Any()) var runner sessionRunner @@ -279,18 +279,20 @@ var _ = Describe("Client", func() { RequestConnectionIDOmission: true, MaxIncomingStreams: 1234, MaxIncomingUniStreams: 4321, + ConnectionIDLength: 13, } - c := populateClientConfig(config) + c := populateClientConfig(config, false) Expect(c.HandshakeTimeout).To(Equal(1337 * time.Minute)) Expect(c.IdleTimeout).To(Equal(42 * time.Hour)) Expect(c.RequestConnectionIDOmission).To(BeTrue()) Expect(c.MaxIncomingStreams).To(Equal(1234)) Expect(c.MaxIncomingUniStreams).To(Equal(4321)) + Expect(c.ConnectionIDLength).To(Equal(13)) }) It("errors when the Config contains an invalid version", func() { manager := NewMockPacketHandlerManager(mockCtrl) - mockMultiplexer.EXPECT().AddConn(packetConn).Return(manager) + mockMultiplexer.EXPECT().AddConn(packetConn, gomock.Any()).Return(manager, nil) version := protocol.VersionNumber(0x1234) _, err := Dial(packetConn, nil, "localhost:1234", &tls.Config{}, &Config{Versions: []protocol.VersionNumber{version}}) @@ -302,7 +304,7 @@ var _ = Describe("Client", func() { MaxIncomingStreams: -1, MaxIncomingUniStreams: 4321, } - c := populateClientConfig(config) + c := populateClientConfig(config, false) Expect(c.MaxIncomingStreams).To(BeZero()) Expect(c.MaxIncomingUniStreams).To(Equal(4321)) }) @@ -312,13 +314,25 @@ var _ = Describe("Client", func() { MaxIncomingStreams: 1234, MaxIncomingUniStreams: -1, } - c := populateClientConfig(config) + c := populateClientConfig(config, false) Expect(c.MaxIncomingStreams).To(Equal(1234)) Expect(c.MaxIncomingUniStreams).To(BeZero()) }) + It("uses 0-byte connection IDs when dialing an address", func() { + config := &Config{} + c := populateClientConfig(config, false) + Expect(c.ConnectionIDLength).To(BeZero()) + }) + + It("doesn't use 0-byte connection IDs when dialing an address", func() { + config := &Config{} + c := populateClientConfig(config, true) + Expect(c.ConnectionIDLength).To(Equal(protocol.DefaultConnectionIDLength)) + }) + It("fills in default values if options are not set in the Config", func() { - c := populateClientConfig(&Config{}) + c := populateClientConfig(&Config{}, false) Expect(c.Versions).To(Equal(protocol.SupportedVersions)) Expect(c.HandshakeTimeout).To(Equal(protocol.DefaultHandshakeTimeout)) Expect(c.IdleTimeout).To(Equal(protocol.DefaultIdleTimeout)) @@ -329,7 +343,7 @@ var _ = Describe("Client", func() { Context("gQUIC", func() { It("errors if it can't create a session", func() { manager := NewMockPacketHandlerManager(mockCtrl) - mockMultiplexer.EXPECT().AddConn(packetConn).Return(manager) + mockMultiplexer.EXPECT().AddConn(packetConn, gomock.Any()).Return(manager, nil) mockMultiplexer.EXPECT().AddHandler(packetConn, gomock.Any(), gomock.Any()) testErr := errors.New("error creating session") @@ -355,7 +369,7 @@ var _ = Describe("Client", func() { Context("IETF QUIC", func() { It("creates new TLS sessions with the right parameters", func() { manager := NewMockPacketHandlerManager(mockCtrl) - mockMultiplexer.EXPECT().AddConn(packetConn).Return(manager) + mockMultiplexer.EXPECT().AddConn(packetConn, gomock.Any()).Return(manager, nil) mockMultiplexer.EXPECT().AddHandler(packetConn, gomock.Any(), gomock.Any()) config := &Config{Versions: []protocol.VersionNumber{protocol.VersionTLS}} @@ -411,7 +425,7 @@ var _ = Describe("Client", func() { It("returns an error that occurs during version negotiation", func() { manager := NewMockPacketHandlerManager(mockCtrl) - mockMultiplexer.EXPECT().AddConn(packetConn).Return(manager) + mockMultiplexer.EXPECT().AddConn(packetConn, gomock.Any()).Return(manager, nil) mockMultiplexer.EXPECT().AddHandler(packetConn, gomock.Any(), gomock.Any()) testErr := errors.New("early handshake error") @@ -568,6 +582,7 @@ var _ = Describe("Client", func() { }) It("drops version negotiation packets that contain the offered version", func() { + cl.config = &Config{} ver := cl.version cl.handleRead(nil, wire.ComposeGQUICVersionNegotiation(connID, []protocol.VersionNumber{ver})) Expect(cl.version).To(Equal(ver)) @@ -581,6 +596,7 @@ var _ = Describe("Client", func() { }) It("ignores packets with an invalid public header", func() { + cl.config = &Config{} cl.session = NewMockQuicSession(mockCtrl) // don't EXPECT any handlePacket calls cl.handleRead(addr, []byte("invalid packet")) }) @@ -682,7 +698,7 @@ var _ = Describe("Client", func() { It("creates new gQUIC sessions with the right parameters", func() { manager := NewMockPacketHandlerManager(mockCtrl) - mockMultiplexer.EXPECT().AddConn(packetConn).Return(manager) + mockMultiplexer.EXPECT().AddConn(packetConn, gomock.Any()).Return(manager, nil) mockMultiplexer.EXPECT().AddHandler(packetConn, gomock.Any(), gomock.Any()) config := &Config{Versions: protocol.SupportedVersions} @@ -723,7 +739,7 @@ var _ = Describe("Client", func() { It("creates a new session when the server performs a retry", func() { manager := NewMockPacketHandlerManager(mockCtrl) - mockMultiplexer.EXPECT().AddConn(packetConn).Return(manager) + mockMultiplexer.EXPECT().AddConn(packetConn, gomock.Any()).Return(manager, nil) mockMultiplexer.EXPECT().AddHandler(packetConn, gomock.Any(), gomock.Any()) config := &Config{Versions: []protocol.VersionNumber{protocol.VersionTLS}} @@ -757,7 +773,7 @@ var _ = Describe("Client", func() { It("only accepts one Retry packet", func() { manager := NewMockPacketHandlerManager(mockCtrl) - mockMultiplexer.EXPECT().AddConn(packetConn).Return(manager) + mockMultiplexer.EXPECT().AddConn(packetConn, gomock.Any()).Return(manager, nil) mockMultiplexer.EXPECT().AddHandler(packetConn, gomock.Any(), gomock.Any()) config := &Config{Versions: []protocol.VersionNumber{protocol.VersionTLS}} diff --git a/interface.go b/interface.go index 0fdb534ec..b1f19e523 100644 --- a/interface.go +++ b/interface.go @@ -165,6 +165,13 @@ type Config struct { // This saves 8 bytes in the Public Header in every packet. However, if the IP address of the server changes, the connection cannot be migrated. // Currently only valid for the client. RequestConnectionIDOmission bool + // The length of the connection ID in bytes. Only valid for IETF QUIC. + // It can be 0, or any value between 4 and 18. + // If not set, the interpretation depends on where the Config is used: + // If used for dialing an address, a 0 byte connection ID will be used. + // If used for a server, or dialing on a packet conn, a 4 byte connection ID will be used. + // When dialing on a packet conn, the ConnectionIDLength value must be the same for every Dial call. + ConnectionIDLength int // HandshakeTimeout is the maximum duration that the cryptographic handshake may take. // If the timeout is exceeded, the connection is closed. // If this value is zero, the timeout is set to 10 seconds. diff --git a/internal/protocol/server_parameters.go b/internal/protocol/server_parameters.go index c8696fdff..aa92c8223 100644 --- a/internal/protocol/server_parameters.go +++ b/internal/protocol/server_parameters.go @@ -145,3 +145,7 @@ const MaxAckFrameSize ByteCount = 1000 // If the packet packing frequency is higher, multiple packets might be sent at once. // Example: For a packet pacing delay of 20 microseconds, we would send 5 packets at once, wait for 100 microseconds, and so forth. const MinPacingDelay time.Duration = 100 * time.Microsecond + +// DefaultConnectionIDLength is the connection ID length that is used for multiplexed connections +// if no other value is configured. +const DefaultConnectionIDLength = 4 diff --git a/mock_multiplexer_test.go b/mock_multiplexer_test.go index d1e34a787..aa2f33d87 100644 --- a/mock_multiplexer_test.go +++ b/mock_multiplexer_test.go @@ -36,15 +36,16 @@ func (m *MockMultiplexer) EXPECT() *MockMultiplexerMockRecorder { } // AddConn mocks base method -func (m *MockMultiplexer) AddConn(arg0 net.PacketConn) packetHandlerManager { - ret := m.ctrl.Call(m, "AddConn", arg0) +func (m *MockMultiplexer) AddConn(arg0 net.PacketConn, arg1 int) (packetHandlerManager, error) { + ret := m.ctrl.Call(m, "AddConn", arg0, arg1) ret0, _ := ret[0].(packetHandlerManager) - return ret0 + ret1, _ := ret[1].(error) + return ret0, ret1 } // AddConn indicates an expected call of AddConn -func (mr *MockMultiplexerMockRecorder) AddConn(arg0 interface{}) *gomock.Call { - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AddConn", reflect.TypeOf((*MockMultiplexer)(nil).AddConn), arg0) +func (mr *MockMultiplexerMockRecorder) AddConn(arg0, arg1 interface{}) *gomock.Call { + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AddConn", reflect.TypeOf((*MockMultiplexer)(nil).AddConn), arg0, arg1) } // AddHandler mocks base method diff --git a/server.go b/server.go index fe6158fb0..8b560c998 100644 --- a/server.go +++ b/server.go @@ -241,6 +241,10 @@ func populateServerConfig(config *Config) *Config { } else if maxIncomingUniStreams < 0 { maxIncomingUniStreams = 0 } + connIDLen := config.ConnectionIDLength + if connIDLen == 0 { + connIDLen = protocol.DefaultConnectionIDLength + } return &Config{ Versions: versions, @@ -252,6 +256,7 @@ func populateServerConfig(config *Config) *Config { MaxReceiveConnectionFlowControlWindow: maxReceiveConnectionFlowControlWindow, MaxIncomingStreams: maxIncomingStreams, MaxIncomingUniStreams: maxIncomingUniStreams, + ConnectionIDLength: connIDLen, } } @@ -304,7 +309,7 @@ func (s *server) handlePacket(remoteAddr net.Addr, packet []byte) error { rcvTime := time.Now() r := bytes.NewReader(packet) - iHdr, err := wire.ParseInvariantHeader(r, protocol.ConnectionIDLenGQUIC) + iHdr, err := wire.ParseInvariantHeader(r, s.config.ConnectionIDLength) if err != nil { return qerr.Error(qerr.InvalidPacketHeader, err.Error()) } diff --git a/server_test.go b/server_test.go index 2e10e7d7d..00fe8e526 100644 --- a/server_test.go +++ b/server_test.go @@ -48,6 +48,7 @@ var _ = Describe("Server", func() { RequestConnectionIDOmission: true, MaxIncomingStreams: 1234, MaxIncomingUniStreams: 4321, + ConnectionIDLength: 12, } c := populateServerConfig(config) Expect(c.HandshakeTimeout).To(Equal(1337 * time.Minute)) @@ -55,6 +56,7 @@ var _ = Describe("Server", func() { Expect(c.RequestConnectionIDOmission).To(BeFalse()) Expect(c.MaxIncomingStreams).To(Equal(1234)) Expect(c.MaxIncomingUniStreams).To(Equal(4321)) + Expect(c.ConnectionIDLength).To(Equal(12)) }) It("disables bidirectional streams", func() { @@ -76,6 +78,12 @@ var _ = Describe("Server", func() { Expect(c.MaxIncomingStreams).To(Equal(1234)) Expect(c.MaxIncomingUniStreams).To(BeZero()) }) + + It("doesn't use 0-byte connection IDs", func() { + config := &Config{} + c := populateClientConfig(config, true) + Expect(c.ConnectionIDLength).To(Equal(protocol.DefaultConnectionIDLength)) + }) }) Context("with mock session", func() { diff --git a/server_tls.go b/server_tls.go index 5d303b0c8..4089af04d 100644 --- a/server_tls.go +++ b/server_tls.go @@ -194,7 +194,7 @@ func (s *serverTLS) handleUnpackedInitial(remoteAddr net.Addr, hdr *wire.Header, StreamID: version.CryptoStreamID(), Data: bc.GetDataForWriting(), } - srcConnID, err := protocol.GenerateConnectionID(protocol.ConnectionIDLenGQUIC) + srcConnID, err := protocol.GenerateConnectionID(s.config.ConnectionIDLength) if err != nil { return nil, nil, err } @@ -228,7 +228,7 @@ func (s *serverTLS) handleUnpackedInitial(remoteAddr net.Addr, hdr *wire.Header, return nil, nil, fmt.Errorf("Expected mint state to be %s, got %s", mint.StateServerWaitFlight2, tls.State()) } params := <-paramsChan - connID, err := protocol.GenerateConnectionID(protocol.ConnectionIDLenGQUIC) + connID, err := protocol.GenerateConnectionID(s.config.ConnectionIDLength) if err != nil { return nil, nil, err } diff --git a/session_test.go b/session_test.go index 12ad17eab..f6225692e 100644 --- a/session_test.go +++ b/session_test.go @@ -1769,7 +1769,7 @@ var _ = Describe("Client Session", func() { protocol.Version39, protocol.ConnectionID{8, 7, 6, 5, 4, 3, 2, 1}, nil, - populateClientConfig(&Config{}), + populateClientConfig(&Config{}, false), protocol.VersionWhatever, nil, utils.DefaultLogger,