forked from quic-go/quic-go
fix version negotiation
This commit is contained in:
94
client.go
94
client.go
@@ -17,18 +17,18 @@ import (
|
||||
)
|
||||
|
||||
type client struct {
|
||||
mutex sync.Mutex
|
||||
listenErr error
|
||||
mutex sync.Mutex
|
||||
|
||||
conn connection
|
||||
hostname string
|
||||
|
||||
errorChan chan struct{}
|
||||
handshakeChan <-chan handshakeEvent
|
||||
|
||||
tlsConf *tls.Config
|
||||
config *Config
|
||||
versionNegotiated bool // has version negotiation completed yet
|
||||
versionNegotiationChan chan struct{} // the versionNegotiationChan is closed as soon as the server accepted the suggested version
|
||||
versionNegotiated bool // has version negotiation completed yet
|
||||
|
||||
tlsConf *tls.Config
|
||||
config *Config
|
||||
|
||||
connectionID protocol.ConnectionID
|
||||
version protocol.VersionNumber
|
||||
@@ -100,23 +100,21 @@ func DialNonFWSecure(
|
||||
|
||||
clientConfig := populateClientConfig(config)
|
||||
c := &client{
|
||||
conn: &conn{pconn: pconn, currentAddr: remoteAddr},
|
||||
connectionID: connID,
|
||||
hostname: hostname,
|
||||
tlsConf: tlsConf,
|
||||
config: clientConfig,
|
||||
version: clientConfig.Versions[0],
|
||||
errorChan: make(chan struct{}),
|
||||
}
|
||||
|
||||
err = c.createNewSession(nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
conn: &conn{pconn: pconn, currentAddr: remoteAddr},
|
||||
connectionID: connID,
|
||||
hostname: hostname,
|
||||
tlsConf: tlsConf,
|
||||
config: clientConfig,
|
||||
version: clientConfig.Versions[0],
|
||||
versionNegotiationChan: make(chan struct{}),
|
||||
}
|
||||
|
||||
utils.Infof("Starting new connection to %s (%s -> %s), connectionID %x, version %d", hostname, c.conn.LocalAddr().String(), c.conn.RemoteAddr().String(), c.connectionID, c.version)
|
||||
|
||||
return c.session.(NonFWSession), c.establishSecureConnection()
|
||||
if err := c.establishSecureConnection(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return c.session.(NonFWSession), nil
|
||||
}
|
||||
|
||||
// Dial establishes a new QUIC connection to a server using a net.PacketConn.
|
||||
@@ -132,8 +130,7 @@ func Dial(
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
err = sess.WaitUntilHandshakeComplete()
|
||||
if err != nil {
|
||||
if err := sess.WaitUntilHandshakeComplete(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return sess, nil
|
||||
@@ -181,11 +178,37 @@ func populateClientConfig(config *Config) *Config {
|
||||
|
||||
// establishSecureConnection returns as soon as the connection is secure (as opposed to forward-secure)
|
||||
func (c *client) establishSecureConnection() error {
|
||||
if err := c.createNewSession(nil); err != nil {
|
||||
return err
|
||||
}
|
||||
go c.listen()
|
||||
|
||||
var runErr error
|
||||
errorChan := make(chan struct{})
|
||||
go func() {
|
||||
// session.run() returns as soon as the session is closed
|
||||
for {
|
||||
runErr = c.session.run()
|
||||
if runErr == errCloseSessionForNewVersion {
|
||||
continue
|
||||
}
|
||||
break
|
||||
}
|
||||
close(errorChan)
|
||||
utils.Infof("Connection %x closed.", c.connectionID)
|
||||
c.conn.Close()
|
||||
}()
|
||||
|
||||
// wait until the server accepts the QUIC version (or an error occurs)
|
||||
select {
|
||||
case <-c.errorChan:
|
||||
return c.listenErr
|
||||
case <-errorChan:
|
||||
return runErr
|
||||
case <-c.versionNegotiationChan:
|
||||
}
|
||||
|
||||
select {
|
||||
case <-errorChan:
|
||||
return runErr
|
||||
case ev := <-c.handshakeChan:
|
||||
if ev.err != nil {
|
||||
return ev.err
|
||||
@@ -263,6 +286,7 @@ func (c *client) handlePacket(remoteAddr net.Addr, packet []byte) {
|
||||
// if the server doesn't send a version negotiation packet, it supports the suggested version
|
||||
if !hdr.VersionFlag && !c.versionNegotiated {
|
||||
c.versionNegotiated = true
|
||||
close(c.versionNegotiationChan)
|
||||
}
|
||||
|
||||
if hdr.VersionFlag {
|
||||
@@ -298,7 +322,6 @@ func (c *client) handlePacketWithVersionFlag(hdr *wire.PublicHeader) error {
|
||||
|
||||
// switch to negotiated version
|
||||
c.version = newVersion
|
||||
c.versionNegotiated = true
|
||||
var err error
|
||||
c.connectionID, err = utils.GenerateConnectionID()
|
||||
if err != nil {
|
||||
@@ -306,7 +329,10 @@ func (c *client) handlePacketWithVersionFlag(hdr *wire.PublicHeader) error {
|
||||
}
|
||||
utils.Infof("Switching to QUIC version %d. New connection ID: %x", newVersion, c.connectionID)
|
||||
|
||||
c.session.Close(errCloseSessionForNewVersion)
|
||||
// create a new session and close the old one
|
||||
// the new session must be created first to update client member variables
|
||||
oldSession := c.session
|
||||
defer oldSession.Close(errCloseSessionForNewVersion)
|
||||
return c.createNewSession(hdr.SupportedVersions)
|
||||
}
|
||||
|
||||
@@ -321,21 +347,5 @@ func (c *client) createNewSession(negotiatedVersions []protocol.VersionNumber) e
|
||||
c.config,
|
||||
negotiatedVersions,
|
||||
)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
go func() {
|
||||
// session.run() returns as soon as the session is closed
|
||||
err := c.session.run()
|
||||
if err == errCloseSessionForNewVersion {
|
||||
return
|
||||
}
|
||||
c.listenErr = err
|
||||
close(c.errorChan)
|
||||
|
||||
utils.Infof("Connection %x closed.", c.connectionID)
|
||||
c.conn.Close()
|
||||
}()
|
||||
return nil
|
||||
return err
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user