forked from quic-go/quic-go
fix version negotiation
This commit is contained in:
74
client.go
74
client.go
@@ -18,17 +18,17 @@ import (
|
||||
|
||||
type client struct {
|
||||
mutex sync.Mutex
|
||||
listenErr error
|
||||
|
||||
conn connection
|
||||
hostname string
|
||||
|
||||
errorChan chan struct{}
|
||||
handshakeChan <-chan handshakeEvent
|
||||
|
||||
versionNegotiationChan chan struct{} // the versionNegotiationChan is closed as soon as the server accepted the suggested version
|
||||
versionNegotiated bool // has version negotiation completed yet
|
||||
|
||||
tlsConf *tls.Config
|
||||
config *Config
|
||||
versionNegotiated bool // has version negotiation completed yet
|
||||
|
||||
connectionID protocol.ConnectionID
|
||||
version protocol.VersionNumber
|
||||
@@ -106,17 +106,15 @@ func DialNonFWSecure(
|
||||
tlsConf: tlsConf,
|
||||
config: clientConfig,
|
||||
version: clientConfig.Versions[0],
|
||||
errorChan: make(chan struct{}),
|
||||
}
|
||||
|
||||
err = c.createNewSession(nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
versionNegotiationChan: make(chan struct{}),
|
||||
}
|
||||
|
||||
utils.Infof("Starting new connection to %s (%s -> %s), connectionID %x, version %d", hostname, c.conn.LocalAddr().String(), c.conn.RemoteAddr().String(), c.connectionID, c.version)
|
||||
|
||||
return c.session.(NonFWSession), c.establishSecureConnection()
|
||||
if err := c.establishSecureConnection(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return c.session.(NonFWSession), nil
|
||||
}
|
||||
|
||||
// Dial establishes a new QUIC connection to a server using a net.PacketConn.
|
||||
@@ -132,8 +130,7 @@ func Dial(
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
err = sess.WaitUntilHandshakeComplete()
|
||||
if err != nil {
|
||||
if err := sess.WaitUntilHandshakeComplete(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return sess, nil
|
||||
@@ -181,11 +178,37 @@ func populateClientConfig(config *Config) *Config {
|
||||
|
||||
// establishSecureConnection returns as soon as the connection is secure (as opposed to forward-secure)
|
||||
func (c *client) establishSecureConnection() error {
|
||||
if err := c.createNewSession(nil); err != nil {
|
||||
return err
|
||||
}
|
||||
go c.listen()
|
||||
|
||||
var runErr error
|
||||
errorChan := make(chan struct{})
|
||||
go func() {
|
||||
// session.run() returns as soon as the session is closed
|
||||
for {
|
||||
runErr = c.session.run()
|
||||
if runErr == errCloseSessionForNewVersion {
|
||||
continue
|
||||
}
|
||||
break
|
||||
}
|
||||
close(errorChan)
|
||||
utils.Infof("Connection %x closed.", c.connectionID)
|
||||
c.conn.Close()
|
||||
}()
|
||||
|
||||
// wait until the server accepts the QUIC version (or an error occurs)
|
||||
select {
|
||||
case <-c.errorChan:
|
||||
return c.listenErr
|
||||
case <-errorChan:
|
||||
return runErr
|
||||
case <-c.versionNegotiationChan:
|
||||
}
|
||||
|
||||
select {
|
||||
case <-errorChan:
|
||||
return runErr
|
||||
case ev := <-c.handshakeChan:
|
||||
if ev.err != nil {
|
||||
return ev.err
|
||||
@@ -263,6 +286,7 @@ func (c *client) handlePacket(remoteAddr net.Addr, packet []byte) {
|
||||
// if the server doesn't send a version negotiation packet, it supports the suggested version
|
||||
if !hdr.VersionFlag && !c.versionNegotiated {
|
||||
c.versionNegotiated = true
|
||||
close(c.versionNegotiationChan)
|
||||
}
|
||||
|
||||
if hdr.VersionFlag {
|
||||
@@ -298,7 +322,6 @@ func (c *client) handlePacketWithVersionFlag(hdr *wire.PublicHeader) error {
|
||||
|
||||
// switch to negotiated version
|
||||
c.version = newVersion
|
||||
c.versionNegotiated = true
|
||||
var err error
|
||||
c.connectionID, err = utils.GenerateConnectionID()
|
||||
if err != nil {
|
||||
@@ -306,7 +329,10 @@ func (c *client) handlePacketWithVersionFlag(hdr *wire.PublicHeader) error {
|
||||
}
|
||||
utils.Infof("Switching to QUIC version %d. New connection ID: %x", newVersion, c.connectionID)
|
||||
|
||||
c.session.Close(errCloseSessionForNewVersion)
|
||||
// create a new session and close the old one
|
||||
// the new session must be created first to update client member variables
|
||||
oldSession := c.session
|
||||
defer oldSession.Close(errCloseSessionForNewVersion)
|
||||
return c.createNewSession(hdr.SupportedVersions)
|
||||
}
|
||||
|
||||
@@ -321,21 +347,5 @@ func (c *client) createNewSession(negotiatedVersions []protocol.VersionNumber) e
|
||||
c.config,
|
||||
negotiatedVersions,
|
||||
)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
go func() {
|
||||
// session.run() returns as soon as the session is closed
|
||||
err := c.session.run()
|
||||
if err == errCloseSessionForNewVersion {
|
||||
return
|
||||
}
|
||||
c.listenErr = err
|
||||
close(c.errorChan)
|
||||
|
||||
utils.Infof("Connection %x closed.", c.connectionID)
|
||||
c.conn.Close()
|
||||
}()
|
||||
return nil
|
||||
}
|
||||
|
||||
110
client_test.go
110
client_test.go
@@ -45,7 +45,7 @@ var _ = Describe("Client", func() {
|
||||
session: sess,
|
||||
version: protocol.SupportedVersions[0],
|
||||
conn: &conn{pconn: packetConn, currentAddr: addr},
|
||||
errorChan: make(chan struct{}),
|
||||
versionNegotiationChan: make(chan struct{}),
|
||||
}
|
||||
})
|
||||
|
||||
@@ -61,9 +61,11 @@ var _ = Describe("Client", func() {
|
||||
})
|
||||
|
||||
Context("Dialing", func() {
|
||||
var acceptClientVersionPacket []byte
|
||||
|
||||
BeforeEach(func() {
|
||||
newClientSession = func(
|
||||
_ connection,
|
||||
conn connection,
|
||||
_ string,
|
||||
_ protocol.VersionNumber,
|
||||
_ protocol.ConnectionID,
|
||||
@@ -71,11 +73,23 @@ var _ = Describe("Client", func() {
|
||||
_ *Config,
|
||||
_ []protocol.VersionNumber,
|
||||
) (packetHandler, <-chan handshakeEvent, error) {
|
||||
Expect(conn.Write([]byte("fake CHLO"))).To(Succeed())
|
||||
// Expect(err).ToNot(HaveOccurred())
|
||||
return sess, sess.handshakeChan, nil
|
||||
}
|
||||
// accept the QUIC version suggested by the client
|
||||
b := &bytes.Buffer{}
|
||||
err := (&wire.PublicHeader{
|
||||
ConnectionID: 0x1337,
|
||||
PacketNumber: 1,
|
||||
PacketNumberLen: 1,
|
||||
}).Write(b, protocol.VersionWhatever, protocol.PerspectiveServer)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
acceptClientVersionPacket = b.Bytes()
|
||||
})
|
||||
|
||||
It("dials non-forward-secure", func(done Done) {
|
||||
packetConn.dataToRead = acceptClientVersionPacket
|
||||
dialed := make(chan struct{})
|
||||
go func() {
|
||||
defer GinkgoRecover()
|
||||
@@ -91,10 +105,27 @@ var _ = Describe("Client", func() {
|
||||
})
|
||||
|
||||
It("dials a non-forward-secure address", func(done Done) {
|
||||
serverAddr, err := net.ResolveUDPAddr("udp", "127.0.0.1:0")
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
server, err := net.ListenUDP("udp", serverAddr)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
defer server.Close()
|
||||
go func() {
|
||||
defer GinkgoRecover()
|
||||
for {
|
||||
_, clientAddr, err := server.ReadFromUDP(make([]byte, 200))
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
_, err = server.WriteToUDP(acceptClientVersionPacket, clientAddr)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
}
|
||||
}()
|
||||
|
||||
dialed := make(chan struct{})
|
||||
go func() {
|
||||
defer GinkgoRecover()
|
||||
s, err := DialAddrNonFWSecure("localhost:18901", nil, config)
|
||||
s, err := DialAddrNonFWSecure(server.LocalAddr().String(), nil, config)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(s).ToNot(BeNil())
|
||||
close(dialed)
|
||||
@@ -106,6 +137,7 @@ var _ = Describe("Client", func() {
|
||||
})
|
||||
|
||||
It("Dial only returns after the handshake is complete", func(done Done) {
|
||||
packetConn.dataToRead = acceptClientVersionPacket
|
||||
dialed := make(chan struct{})
|
||||
go func() {
|
||||
defer GinkgoRecover()
|
||||
@@ -173,12 +205,23 @@ var _ = Describe("Client", func() {
|
||||
close(done)
|
||||
})
|
||||
|
||||
It("returns an error that occurs while waiting for the connection to become secure", func(done Done) {
|
||||
It("returns an error that occurs during version negotiation", func(done Done) {
|
||||
testErr := errors.New("early handshake error")
|
||||
var dialErr error
|
||||
go func() {
|
||||
defer GinkgoRecover()
|
||||
_, dialErr = Dial(packetConn, addr, "quic.clemente.io:1337", nil, config)
|
||||
_, dialErr := Dial(packetConn, addr, "quic.clemente.io:1337", nil, config)
|
||||
Expect(dialErr).To(MatchError(testErr))
|
||||
close(done)
|
||||
}()
|
||||
sess.Close(testErr)
|
||||
})
|
||||
|
||||
It("returns an error that occurs while waiting for the connection to become secure", func(done Done) {
|
||||
testErr := errors.New("early handshake error")
|
||||
packetConn.dataToRead = acceptClientVersionPacket
|
||||
go func() {
|
||||
defer GinkgoRecover()
|
||||
_, dialErr := Dial(packetConn, addr, "quic.clemente.io:1337", nil, config)
|
||||
Expect(dialErr).To(MatchError(testErr))
|
||||
close(done)
|
||||
}()
|
||||
@@ -187,9 +230,9 @@ var _ = Describe("Client", func() {
|
||||
|
||||
It("returns an error that occurs while waiting for the handshake to complete", func(done Done) {
|
||||
testErr := errors.New("late handshake error")
|
||||
var dialErr error
|
||||
packetConn.dataToRead = acceptClientVersionPacket
|
||||
go func() {
|
||||
_, dialErr = Dial(packetConn, addr, "quic.clemente.io:1337", nil, config)
|
||||
_, dialErr := Dial(packetConn, addr, "quic.clemente.io:1337", nil, config)
|
||||
Expect(dialErr).To(MatchError(testErr))
|
||||
close(done)
|
||||
}()
|
||||
@@ -254,10 +297,20 @@ var _ = Describe("Client", func() {
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
cl.handlePacket(nil, b.Bytes())
|
||||
Expect(cl.versionNegotiated).To(BeTrue())
|
||||
Expect(cl.versionNegotiationChan).To(BeClosed())
|
||||
})
|
||||
|
||||
It("changes the version after receiving a version negotiation packet", func() {
|
||||
var negotiatedVersions []protocol.VersionNumber
|
||||
newVersion := protocol.VersionNumber(77)
|
||||
Expect(newVersion).ToNot(Equal(cl.version))
|
||||
Expect(config.Versions).To(ContainElement(newVersion))
|
||||
packetConn.dataToRead = wire.ComposeVersionNegotiation(
|
||||
0x1337,
|
||||
[]protocol.VersionNumber{newVersion},
|
||||
)
|
||||
sessionChan := make(chan *mockSession)
|
||||
handshakeChan := make(chan handshakeEvent)
|
||||
newClientSession = func(
|
||||
_ connection,
|
||||
_ string,
|
||||
@@ -268,25 +321,38 @@ var _ = Describe("Client", func() {
|
||||
negotiatedVersionsP []protocol.VersionNumber,
|
||||
) (packetHandler, <-chan handshakeEvent, error) {
|
||||
negotiatedVersions = negotiatedVersionsP
|
||||
return &mockSession{
|
||||
// make the server accept the new version
|
||||
if len(negotiatedVersionsP) > 0 {
|
||||
packetConn.dataToRead = acceptClientVersionPacket
|
||||
}
|
||||
sess := &mockSession{
|
||||
connectionID: connectionID,
|
||||
}, nil, nil
|
||||
stopRunLoop: make(chan struct{}),
|
||||
}
|
||||
sessionChan <- sess
|
||||
return sess, handshakeChan, nil
|
||||
}
|
||||
|
||||
newVersion := protocol.VersionNumber(77)
|
||||
Expect(config.Versions).To(ContainElement(newVersion))
|
||||
Expect(newVersion).ToNot(Equal(cl.version))
|
||||
Expect(sess.packetCount).To(BeZero())
|
||||
cl.connectionID = 0x1337
|
||||
cl.handlePacket(nil, wire.ComposeVersionNegotiation(0x1337, []protocol.VersionNumber{newVersion}))
|
||||
Expect(cl.version).To(Equal(newVersion))
|
||||
Expect(cl.versionNegotiated).To(BeTrue())
|
||||
// it swapped the sessions
|
||||
// Expect(cl.session).ToNot(Equal(sess))
|
||||
Expect(cl.connectionID).ToNot(Equal(0x1337)) // it generated a new connection ID
|
||||
established := make(chan struct{})
|
||||
go func() {
|
||||
defer GinkgoRecover()
|
||||
err := cl.establishSecureConnection()
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
close(established)
|
||||
}()
|
||||
var firstSession, secondSession *mockSession
|
||||
Eventually(sessionChan).Should(Receive(&firstSession))
|
||||
Eventually(sessionChan).Should(Receive(&secondSession))
|
||||
// it didn't pass the version negoation packet to the old session (since it has no payload)
|
||||
Expect(sess.packetCount).To(BeZero())
|
||||
Expect(firstSession.packetCount).To(BeZero())
|
||||
Eventually(func() bool { return firstSession.closed }).Should(BeTrue())
|
||||
Expect(firstSession.closeReason).To(Equal(errCloseSessionForNewVersion))
|
||||
Consistently(func() bool { return secondSession.closed }).Should(BeFalse())
|
||||
Expect(cl.connectionID).ToNot(BeEquivalentTo(0x1337))
|
||||
Expect(negotiatedVersions).To(Equal([]protocol.VersionNumber{newVersion}))
|
||||
|
||||
handshakeChan <- handshakeEvent{encLevel: protocol.EncryptionSecure}
|
||||
Eventually(established).Should(BeClosed())
|
||||
})
|
||||
|
||||
It("errors if no matching version is found", func() {
|
||||
|
||||
Reference in New Issue
Block a user