forked from quic-go/quic-go
make the connection ID length configurable
This commit is contained in:
29
client.go
29
client.go
@@ -74,6 +74,7 @@ func DialAddrContext(
|
||||
tlsConf *tls.Config,
|
||||
config *Config,
|
||||
) (Session, error) {
|
||||
config = populateClientConfig(config, false)
|
||||
udpAddr, err := net.ResolveUDPAddr("udp", addr)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@@ -115,8 +116,12 @@ func DialContext(
|
||||
tlsConf *tls.Config,
|
||||
config *Config,
|
||||
) (Session, error) {
|
||||
config = populateClientConfig(config, true)
|
||||
multiplexer := getClientMultiplexer()
|
||||
manager := multiplexer.AddConn(pconn)
|
||||
manager, err := multiplexer.AddConn(pconn, config.ConnectionIDLength)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
c, err := newClient(pconn, remoteAddr, config, tlsConf, host, manager.Remove)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@@ -138,9 +143,6 @@ func newClient(
|
||||
host string,
|
||||
closeCallback func(protocol.ConnectionID),
|
||||
) (*client, error) {
|
||||
clientConfig := populateClientConfig(config)
|
||||
version := clientConfig.Versions[0]
|
||||
|
||||
var hostname string
|
||||
if tlsConf != nil {
|
||||
hostname = tlsConf.ServerName
|
||||
@@ -169,8 +171,8 @@ func newClient(
|
||||
conn: &conn{pconn: pconn, currentAddr: remoteAddr},
|
||||
hostname: hostname,
|
||||
tlsConf: tlsConf,
|
||||
config: clientConfig,
|
||||
version: version,
|
||||
config: config,
|
||||
version: config.Versions[0],
|
||||
handshakeChan: make(chan struct{}),
|
||||
closeCallback: onClose,
|
||||
logger: utils.DefaultLogger.WithPrefix("client"),
|
||||
@@ -180,7 +182,7 @@ func newClient(
|
||||
|
||||
// populateClientConfig populates fields in the quic.Config with their default values, if none are set
|
||||
// it may be called with nil
|
||||
func populateClientConfig(config *Config) *Config {
|
||||
func populateClientConfig(config *Config, onPacketConn bool) *Config {
|
||||
if config == nil {
|
||||
config = &Config{}
|
||||
}
|
||||
@@ -218,12 +220,17 @@ func populateClientConfig(config *Config) *Config {
|
||||
} else if maxIncomingUniStreams < 0 {
|
||||
maxIncomingUniStreams = 0
|
||||
}
|
||||
connIDLen := config.ConnectionIDLength
|
||||
if connIDLen == 0 && onPacketConn {
|
||||
connIDLen = protocol.DefaultConnectionIDLength
|
||||
}
|
||||
|
||||
return &Config{
|
||||
Versions: versions,
|
||||
HandshakeTimeout: handshakeTimeout,
|
||||
IdleTimeout: idleTimeout,
|
||||
RequestConnectionIDOmission: config.RequestConnectionIDOmission,
|
||||
ConnectionIDLength: connIDLen,
|
||||
MaxReceiveStreamFlowControlWindow: maxReceiveStreamFlowControlWindow,
|
||||
MaxReceiveConnectionFlowControlWindow: maxReceiveConnectionFlowControlWindow,
|
||||
MaxIncomingStreams: maxIncomingStreams,
|
||||
@@ -233,7 +240,11 @@ func populateClientConfig(config *Config) *Config {
|
||||
}
|
||||
|
||||
func (c *client) generateConnectionIDs() error {
|
||||
srcConnID, err := generateConnectionID(protocol.ConnectionIDLenGQUIC)
|
||||
connIDLen := protocol.ConnectionIDLenGQUIC
|
||||
if c.version.UsesTLS() {
|
||||
connIDLen = c.config.ConnectionIDLength
|
||||
}
|
||||
srcConnID, err := generateConnectionID(connIDLen)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -370,7 +381,7 @@ func (c *client) handleRead(remoteAddr net.Addr, packet []byte) {
|
||||
rcvTime := time.Now()
|
||||
|
||||
r := bytes.NewReader(packet)
|
||||
iHdr, err := wire.ParseInvariantHeader(r, protocol.ConnectionIDLenGQUIC)
|
||||
iHdr, err := wire.ParseInvariantHeader(r, c.config.ConnectionIDLength)
|
||||
// drop the packet if we can't parse the header
|
||||
if err != nil {
|
||||
c.logger.Errorf("error parsing invariant header: %s", err)
|
||||
|
||||
Reference in New Issue
Block a user