fix version negotiation

This commit is contained in:
Marten Seemann
2017-09-15 16:21:29 +07:00
parent 9029d6e7d7
commit 7003450d2b
2 changed files with 140 additions and 64 deletions

View File

@@ -17,18 +17,18 @@ import (
)
type client struct {
mutex sync.Mutex
listenErr error
mutex sync.Mutex
conn connection
hostname string
errorChan chan struct{}
handshakeChan <-chan handshakeEvent
tlsConf *tls.Config
config *Config
versionNegotiated bool // has version negotiation completed yet
versionNegotiationChan chan struct{} // the versionNegotiationChan is closed as soon as the server accepted the suggested version
versionNegotiated bool // has version negotiation completed yet
tlsConf *tls.Config
config *Config
connectionID protocol.ConnectionID
version protocol.VersionNumber
@@ -100,23 +100,21 @@ func DialNonFWSecure(
clientConfig := populateClientConfig(config)
c := &client{
conn: &conn{pconn: pconn, currentAddr: remoteAddr},
connectionID: connID,
hostname: hostname,
tlsConf: tlsConf,
config: clientConfig,
version: clientConfig.Versions[0],
errorChan: make(chan struct{}),
}
err = c.createNewSession(nil)
if err != nil {
return nil, err
conn: &conn{pconn: pconn, currentAddr: remoteAddr},
connectionID: connID,
hostname: hostname,
tlsConf: tlsConf,
config: clientConfig,
version: clientConfig.Versions[0],
versionNegotiationChan: make(chan struct{}),
}
utils.Infof("Starting new connection to %s (%s -> %s), connectionID %x, version %d", hostname, c.conn.LocalAddr().String(), c.conn.RemoteAddr().String(), c.connectionID, c.version)
return c.session.(NonFWSession), c.establishSecureConnection()
if err := c.establishSecureConnection(); err != nil {
return nil, err
}
return c.session.(NonFWSession), nil
}
// Dial establishes a new QUIC connection to a server using a net.PacketConn.
@@ -132,8 +130,7 @@ func Dial(
if err != nil {
return nil, err
}
err = sess.WaitUntilHandshakeComplete()
if err != nil {
if err := sess.WaitUntilHandshakeComplete(); err != nil {
return nil, err
}
return sess, nil
@@ -181,11 +178,37 @@ func populateClientConfig(config *Config) *Config {
// establishSecureConnection returns as soon as the connection is secure (as opposed to forward-secure)
func (c *client) establishSecureConnection() error {
if err := c.createNewSession(nil); err != nil {
return err
}
go c.listen()
var runErr error
errorChan := make(chan struct{})
go func() {
// session.run() returns as soon as the session is closed
for {
runErr = c.session.run()
if runErr == errCloseSessionForNewVersion {
continue
}
break
}
close(errorChan)
utils.Infof("Connection %x closed.", c.connectionID)
c.conn.Close()
}()
// wait until the server accepts the QUIC version (or an error occurs)
select {
case <-c.errorChan:
return c.listenErr
case <-errorChan:
return runErr
case <-c.versionNegotiationChan:
}
select {
case <-errorChan:
return runErr
case ev := <-c.handshakeChan:
if ev.err != nil {
return ev.err
@@ -263,6 +286,7 @@ func (c *client) handlePacket(remoteAddr net.Addr, packet []byte) {
// if the server doesn't send a version negotiation packet, it supports the suggested version
if !hdr.VersionFlag && !c.versionNegotiated {
c.versionNegotiated = true
close(c.versionNegotiationChan)
}
if hdr.VersionFlag {
@@ -298,7 +322,6 @@ func (c *client) handlePacketWithVersionFlag(hdr *wire.PublicHeader) error {
// switch to negotiated version
c.version = newVersion
c.versionNegotiated = true
var err error
c.connectionID, err = utils.GenerateConnectionID()
if err != nil {
@@ -306,7 +329,10 @@ func (c *client) handlePacketWithVersionFlag(hdr *wire.PublicHeader) error {
}
utils.Infof("Switching to QUIC version %d. New connection ID: %x", newVersion, c.connectionID)
c.session.Close(errCloseSessionForNewVersion)
// create a new session and close the old one
// the new session must be created first to update client member variables
oldSession := c.session
defer oldSession.Close(errCloseSessionForNewVersion)
return c.createNewSession(hdr.SupportedVersions)
}
@@ -321,21 +347,5 @@ func (c *client) createNewSession(negotiatedVersions []protocol.VersionNumber) e
c.config,
negotiatedVersions,
)
if err != nil {
return err
}
go func() {
// session.run() returns as soon as the session is closed
err := c.session.run()
if err == errCloseSessionForNewVersion {
return
}
c.listenErr = err
close(c.errorChan)
utils.Infof("Connection %x closed.", c.connectionID)
c.conn.Close()
}()
return nil
return err
}