fix version negotiation

This commit is contained in:
Marten Seemann
2017-09-15 16:21:29 +07:00
parent 9029d6e7d7
commit 7003450d2b
2 changed files with 140 additions and 64 deletions

View File

@@ -18,17 +18,17 @@ import (
type client struct {
mutex sync.Mutex
listenErr error
conn connection
hostname string
errorChan chan struct{}
handshakeChan <-chan handshakeEvent
versionNegotiationChan chan struct{} // the versionNegotiationChan is closed as soon as the server accepted the suggested version
versionNegotiated bool // has version negotiation completed yet
tlsConf *tls.Config
config *Config
versionNegotiated bool // has version negotiation completed yet
connectionID protocol.ConnectionID
version protocol.VersionNumber
@@ -106,17 +106,15 @@ func DialNonFWSecure(
tlsConf: tlsConf,
config: clientConfig,
version: clientConfig.Versions[0],
errorChan: make(chan struct{}),
}
err = c.createNewSession(nil)
if err != nil {
return nil, err
versionNegotiationChan: make(chan struct{}),
}
utils.Infof("Starting new connection to %s (%s -> %s), connectionID %x, version %d", hostname, c.conn.LocalAddr().String(), c.conn.RemoteAddr().String(), c.connectionID, c.version)
return c.session.(NonFWSession), c.establishSecureConnection()
if err := c.establishSecureConnection(); err != nil {
return nil, err
}
return c.session.(NonFWSession), nil
}
// Dial establishes a new QUIC connection to a server using a net.PacketConn.
@@ -132,8 +130,7 @@ func Dial(
if err != nil {
return nil, err
}
err = sess.WaitUntilHandshakeComplete()
if err != nil {
if err := sess.WaitUntilHandshakeComplete(); err != nil {
return nil, err
}
return sess, nil
@@ -181,11 +178,37 @@ func populateClientConfig(config *Config) *Config {
// establishSecureConnection returns as soon as the connection is secure (as opposed to forward-secure)
func (c *client) establishSecureConnection() error {
if err := c.createNewSession(nil); err != nil {
return err
}
go c.listen()
var runErr error
errorChan := make(chan struct{})
go func() {
// session.run() returns as soon as the session is closed
for {
runErr = c.session.run()
if runErr == errCloseSessionForNewVersion {
continue
}
break
}
close(errorChan)
utils.Infof("Connection %x closed.", c.connectionID)
c.conn.Close()
}()
// wait until the server accepts the QUIC version (or an error occurs)
select {
case <-c.errorChan:
return c.listenErr
case <-errorChan:
return runErr
case <-c.versionNegotiationChan:
}
select {
case <-errorChan:
return runErr
case ev := <-c.handshakeChan:
if ev.err != nil {
return ev.err
@@ -263,6 +286,7 @@ func (c *client) handlePacket(remoteAddr net.Addr, packet []byte) {
// if the server doesn't send a version negotiation packet, it supports the suggested version
if !hdr.VersionFlag && !c.versionNegotiated {
c.versionNegotiated = true
close(c.versionNegotiationChan)
}
if hdr.VersionFlag {
@@ -298,7 +322,6 @@ func (c *client) handlePacketWithVersionFlag(hdr *wire.PublicHeader) error {
// switch to negotiated version
c.version = newVersion
c.versionNegotiated = true
var err error
c.connectionID, err = utils.GenerateConnectionID()
if err != nil {
@@ -306,7 +329,10 @@ func (c *client) handlePacketWithVersionFlag(hdr *wire.PublicHeader) error {
}
utils.Infof("Switching to QUIC version %d. New connection ID: %x", newVersion, c.connectionID)
c.session.Close(errCloseSessionForNewVersion)
// create a new session and close the old one
// the new session must be created first to update client member variables
oldSession := c.session
defer oldSession.Close(errCloseSessionForNewVersion)
return c.createNewSession(hdr.SupportedVersions)
}
@@ -321,21 +347,5 @@ func (c *client) createNewSession(negotiatedVersions []protocol.VersionNumber) e
c.config,
negotiatedVersions,
)
if err != nil {
return err
}
go func() {
// session.run() returns as soon as the session is closed
err := c.session.run()
if err == errCloseSessionForNewVersion {
return
}
c.listenErr = err
close(c.errorChan)
utils.Infof("Connection %x closed.", c.connectionID)
c.conn.Close()
}()
return nil
}

View File

@@ -45,7 +45,7 @@ var _ = Describe("Client", func() {
session: sess,
version: protocol.SupportedVersions[0],
conn: &conn{pconn: packetConn, currentAddr: addr},
errorChan: make(chan struct{}),
versionNegotiationChan: make(chan struct{}),
}
})
@@ -61,9 +61,11 @@ var _ = Describe("Client", func() {
})
Context("Dialing", func() {
var acceptClientVersionPacket []byte
BeforeEach(func() {
newClientSession = func(
_ connection,
conn connection,
_ string,
_ protocol.VersionNumber,
_ protocol.ConnectionID,
@@ -71,11 +73,23 @@ var _ = Describe("Client", func() {
_ *Config,
_ []protocol.VersionNumber,
) (packetHandler, <-chan handshakeEvent, error) {
Expect(conn.Write([]byte("fake CHLO"))).To(Succeed())
// Expect(err).ToNot(HaveOccurred())
return sess, sess.handshakeChan, nil
}
// accept the QUIC version suggested by the client
b := &bytes.Buffer{}
err := (&wire.PublicHeader{
ConnectionID: 0x1337,
PacketNumber: 1,
PacketNumberLen: 1,
}).Write(b, protocol.VersionWhatever, protocol.PerspectiveServer)
Expect(err).ToNot(HaveOccurred())
acceptClientVersionPacket = b.Bytes()
})
It("dials non-forward-secure", func(done Done) {
packetConn.dataToRead = acceptClientVersionPacket
dialed := make(chan struct{})
go func() {
defer GinkgoRecover()
@@ -91,10 +105,27 @@ var _ = Describe("Client", func() {
})
It("dials a non-forward-secure address", func(done Done) {
serverAddr, err := net.ResolveUDPAddr("udp", "127.0.0.1:0")
Expect(err).ToNot(HaveOccurred())
server, err := net.ListenUDP("udp", serverAddr)
Expect(err).ToNot(HaveOccurred())
defer server.Close()
go func() {
defer GinkgoRecover()
for {
_, clientAddr, err := server.ReadFromUDP(make([]byte, 200))
if err != nil {
return
}
_, err = server.WriteToUDP(acceptClientVersionPacket, clientAddr)
Expect(err).ToNot(HaveOccurred())
}
}()
dialed := make(chan struct{})
go func() {
defer GinkgoRecover()
s, err := DialAddrNonFWSecure("localhost:18901", nil, config)
s, err := DialAddrNonFWSecure(server.LocalAddr().String(), nil, config)
Expect(err).ToNot(HaveOccurred())
Expect(s).ToNot(BeNil())
close(dialed)
@@ -106,6 +137,7 @@ var _ = Describe("Client", func() {
})
It("Dial only returns after the handshake is complete", func(done Done) {
packetConn.dataToRead = acceptClientVersionPacket
dialed := make(chan struct{})
go func() {
defer GinkgoRecover()
@@ -173,12 +205,23 @@ var _ = Describe("Client", func() {
close(done)
})
It("returns an error that occurs while waiting for the connection to become secure", func(done Done) {
It("returns an error that occurs during version negotiation", func(done Done) {
testErr := errors.New("early handshake error")
var dialErr error
go func() {
defer GinkgoRecover()
_, dialErr = Dial(packetConn, addr, "quic.clemente.io:1337", nil, config)
_, dialErr := Dial(packetConn, addr, "quic.clemente.io:1337", nil, config)
Expect(dialErr).To(MatchError(testErr))
close(done)
}()
sess.Close(testErr)
})
It("returns an error that occurs while waiting for the connection to become secure", func(done Done) {
testErr := errors.New("early handshake error")
packetConn.dataToRead = acceptClientVersionPacket
go func() {
defer GinkgoRecover()
_, dialErr := Dial(packetConn, addr, "quic.clemente.io:1337", nil, config)
Expect(dialErr).To(MatchError(testErr))
close(done)
}()
@@ -187,9 +230,9 @@ var _ = Describe("Client", func() {
It("returns an error that occurs while waiting for the handshake to complete", func(done Done) {
testErr := errors.New("late handshake error")
var dialErr error
packetConn.dataToRead = acceptClientVersionPacket
go func() {
_, dialErr = Dial(packetConn, addr, "quic.clemente.io:1337", nil, config)
_, dialErr := Dial(packetConn, addr, "quic.clemente.io:1337", nil, config)
Expect(dialErr).To(MatchError(testErr))
close(done)
}()
@@ -254,10 +297,20 @@ var _ = Describe("Client", func() {
Expect(err).ToNot(HaveOccurred())
cl.handlePacket(nil, b.Bytes())
Expect(cl.versionNegotiated).To(BeTrue())
Expect(cl.versionNegotiationChan).To(BeClosed())
})
It("changes the version after receiving a version negotiation packet", func() {
var negotiatedVersions []protocol.VersionNumber
newVersion := protocol.VersionNumber(77)
Expect(newVersion).ToNot(Equal(cl.version))
Expect(config.Versions).To(ContainElement(newVersion))
packetConn.dataToRead = wire.ComposeVersionNegotiation(
0x1337,
[]protocol.VersionNumber{newVersion},
)
sessionChan := make(chan *mockSession)
handshakeChan := make(chan handshakeEvent)
newClientSession = func(
_ connection,
_ string,
@@ -268,25 +321,38 @@ var _ = Describe("Client", func() {
negotiatedVersionsP []protocol.VersionNumber,
) (packetHandler, <-chan handshakeEvent, error) {
negotiatedVersions = negotiatedVersionsP
return &mockSession{
// make the server accept the new version
if len(negotiatedVersionsP) > 0 {
packetConn.dataToRead = acceptClientVersionPacket
}
sess := &mockSession{
connectionID: connectionID,
}, nil, nil
stopRunLoop: make(chan struct{}),
}
sessionChan <- sess
return sess, handshakeChan, nil
}
newVersion := protocol.VersionNumber(77)
Expect(config.Versions).To(ContainElement(newVersion))
Expect(newVersion).ToNot(Equal(cl.version))
Expect(sess.packetCount).To(BeZero())
cl.connectionID = 0x1337
cl.handlePacket(nil, wire.ComposeVersionNegotiation(0x1337, []protocol.VersionNumber{newVersion}))
Expect(cl.version).To(Equal(newVersion))
Expect(cl.versionNegotiated).To(BeTrue())
// it swapped the sessions
// Expect(cl.session).ToNot(Equal(sess))
Expect(cl.connectionID).ToNot(Equal(0x1337)) // it generated a new connection ID
established := make(chan struct{})
go func() {
defer GinkgoRecover()
err := cl.establishSecureConnection()
Expect(err).ToNot(HaveOccurred())
close(established)
}()
var firstSession, secondSession *mockSession
Eventually(sessionChan).Should(Receive(&firstSession))
Eventually(sessionChan).Should(Receive(&secondSession))
// it didn't pass the version negoation packet to the old session (since it has no payload)
Expect(sess.packetCount).To(BeZero())
Expect(firstSession.packetCount).To(BeZero())
Eventually(func() bool { return firstSession.closed }).Should(BeTrue())
Expect(firstSession.closeReason).To(Equal(errCloseSessionForNewVersion))
Consistently(func() bool { return secondSession.closed }).Should(BeFalse())
Expect(cl.connectionID).ToNot(BeEquivalentTo(0x1337))
Expect(negotiatedVersions).To(Equal([]protocol.VersionNumber{newVersion}))
handshakeChan <- handshakeEvent{encLevel: protocol.EncryptionSecure}
Eventually(established).Should(BeClosed())
})
It("errors if no matching version is found", func() {