simplify version negotiation in the client

This commit is contained in:
Marten Seemann
2018-05-10 10:30:20 +09:00
parent 407a563c73
commit c98afd6625
2 changed files with 10 additions and 23 deletions

View File

@@ -23,8 +23,7 @@ type client struct {
conn connection conn connection
hostname string hostname string
versionNegotiationChan chan struct{} // the versionNegotiationChan is closed as soon as the server accepted the suggested version versionNegotiated bool // has the server accepted our version
versionNegotiated bool // has the server accepted our version
receivedVersionNegotiationPacket bool receivedVersionNegotiationPacket bool
negotiatedVersions []protocol.VersionNumber // the list of versions from the version negotiation packet negotiatedVersions []protocol.VersionNumber // the list of versions from the version negotiation packet
@@ -106,15 +105,14 @@ func Dial(
} }
} }
c := &client{ c := &client{
conn: &conn{pconn: pconn, currentAddr: remoteAddr}, conn: &conn{pconn: pconn, currentAddr: remoteAddr},
srcConnID: srcConnID, srcConnID: srcConnID,
destConnID: destConnID, destConnID: destConnID,
hostname: hostname, hostname: hostname,
tlsConf: tlsConf, tlsConf: tlsConf,
config: clientConfig, config: clientConfig,
version: version, version: version,
versionNegotiationChan: make(chan struct{}), logger: utils.DefaultLogger,
logger: utils.DefaultLogger,
} }
c.logger.Infof("Starting new connection to %s (%s -> %s), source connection ID %s, destination connection ID %s, version %s", hostname, c.conn.LocalAddr(), c.conn.RemoteAddr(), c.srcConnID, c.destConnID, c.version) c.logger.Infof("Starting new connection to %s (%s -> %s), source connection ID %s, destination connection ID %s, version %s", hostname, c.conn.LocalAddr(), c.conn.RemoteAddr(), c.srcConnID, c.destConnID, c.version)
@@ -256,13 +254,6 @@ func (c *client) establishSecureConnection() error {
} }
}() }()
// wait until the server accepts the QUIC version (or an error occurs)
select {
case <-errorChan:
return runErr
case <-c.versionNegotiationChan:
}
select { select {
case <-errorChan: case <-errorChan:
return runErr return runErr
@@ -361,7 +352,6 @@ func (c *client) handleIETFQUICPacket(hdr *wire.Header, packetData []byte, remot
// since it is not a Version Negotiation Packet, this means the server supports the suggested version // since it is not a Version Negotiation Packet, this means the server supports the suggested version
if !c.versionNegotiated { if !c.versionNegotiated {
c.versionNegotiated = true c.versionNegotiated = true
close(c.versionNegotiationChan)
} }
c.session.handlePacket(&receivedPacket{ c.session.handlePacket(&receivedPacket{
@@ -399,7 +389,6 @@ func (c *client) handleGQUICPacket(hdr *wire.Header, r *bytes.Reader, packetData
// since it is not a Version Negotiation Packet, this means the server supports the suggested version // since it is not a Version Negotiation Packet, this means the server supports the suggested version
if !c.versionNegotiated { if !c.versionNegotiated {
c.versionNegotiated = true c.versionNegotiated = true
close(c.versionNegotiationChan)
} }
c.session.handlePacket(&receivedPacket{ c.session.handlePacket(&receivedPacket{

View File

@@ -60,8 +60,7 @@ var _ = Describe("Client", func() {
session: sess, session: sess,
version: protocol.SupportedVersions[0], version: protocol.SupportedVersions[0],
conn: &conn{pconn: packetConn, currentAddr: addr}, conn: &conn{pconn: packetConn, currentAddr: addr},
versionNegotiationChan: make(chan struct{}), logger: utils.DefaultLogger,
logger: utils.DefaultLogger,
} }
}) })
@@ -382,7 +381,6 @@ var _ = Describe("Client", func() {
err = cl.handlePacket(nil, b.Bytes()) err = cl.handlePacket(nil, b.Bytes())
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
Expect(cl.versionNegotiated).To(BeTrue()) Expect(cl.versionNegotiated).To(BeTrue())
Expect(cl.versionNegotiationChan).To(BeClosed())
}) })
It("changes the version after receiving a version negotiation packet", func() { It("changes the version after receiving a version negotiation packet", func() {