diff --git a/client.go b/client.go index 9aa4fd8b1..c37665099 100644 --- a/client.go +++ b/client.go @@ -2,7 +2,6 @@ package quic import ( "bytes" - "crypto/tls" "errors" "net" "strings" @@ -21,14 +20,11 @@ type client struct { conn connection hostname string - config *Config + config *Config + connState ConnState - connectionID protocol.ConnectionID - version protocol.VersionNumber - versionNegotiated bool - - tlsConfig *tls.Config - cryptoChangeCallback CryptoChangeCallback + connectionID protocol.ConnectionID + version protocol.VersionNumber session packetHandler } @@ -61,19 +57,6 @@ func Dial(pconn net.PacketConn, remoteAddr net.Addr, host string, config *Config c.connStateChangeCond.L = &c.mutex - c.cryptoChangeCallback = func(isForwardSecure bool) { - var state ConnState - if isForwardSecure { - state = ConnStateForwardSecure - } else { - state = ConnStateSecure - } - - if c.config.ConnState != nil { - go config.ConnState(c.session, state) - } - } - err = c.createNewSession(nil) if err != nil { return nil, err @@ -84,7 +67,13 @@ func Dial(pconn net.PacketConn, remoteAddr net.Addr, host string, config *Config go c.listen() c.mutex.Lock() - for !c.versionNegotiated { + for { + if c.config.ConnState != nil && c.connState >= ConnStateVersionNegotiated { + break + } + if c.config.ConnState == nil && c.connState == ConnStateForwardSecure { + break + } c.connStateChangeCond.Wait() } c.mutex.Unlock() @@ -147,15 +136,15 @@ func (c *client) handlePacket(remoteAddr net.Addr, packet []byte) error { hdr.Raw = packet[:len(packet)-r.Len()] // ignore delayed / duplicated version negotiation packets - if c.versionNegotiated && hdr.VersionFlag { + if c.connState >= ConnStateVersionNegotiated && hdr.VersionFlag { return nil } // this is the first packet after the client sent a packet with the VersionFlag set // if the server doesn't send a version negotiation packet, it supports the suggested version - if !hdr.VersionFlag && !c.versionNegotiated { + if !hdr.VersionFlag && c.connState == ConnStateInitial { c.mutex.Lock() - c.versionNegotiated = true + c.connState = ConnStateVersionNegotiated c.connStateChangeCond.Signal() c.mutex.Unlock() if c.config.ConnState != nil { @@ -186,7 +175,7 @@ func (c *client) handlePacket(remoteAddr net.Addr, packet []byte) error { // switch to negotiated version c.version = highestSupportedVersion - c.versionNegotiated = true + c.connState = ConnStateVersionNegotiated c.connectionID, err = utils.GenerateConnectionID() if err != nil { return err @@ -217,6 +206,19 @@ func (c *client) handlePacket(remoteAddr net.Addr, packet []byte) error { return nil } +func (c *client) cryptoChangeCallback(isForwardSecure bool) { + var state ConnState + if isForwardSecure { + state = ConnStateForwardSecure + } else { + state = ConnStateSecure + } + + if c.config.ConnState != nil { + go c.config.ConnState(c.session, state) + } +} + func (c *client) createNewSession(negotiatedVersions []protocol.VersionNumber) error { var err error c.session, err = newClientSession( diff --git a/client_test.go b/client_test.go index b9f0e3923..32b20c6c6 100644 --- a/client_test.go +++ b/client_test.go @@ -57,6 +57,21 @@ var _ = Describe("Client", func() { Expect(*(*string)(unsafe.Pointer(reflect.ValueOf(sess.(*session).cryptoSetup).Elem().FieldByName("hostname").UnsafeAddr()))).To(Equal("quic.clemente.io")) }) + // TODO: actually test this + // now we're only testing that Dial doesn't return directly after version negotiation + It("only returns once a forward-secure connection is established if no ConnState is defined", func() { + packetConn.dataToRead = []byte{0x0, 0x1, 0x0} + config.ConnState = nil + var dialReturned bool + go func() { + defer GinkgoRecover() + _, err := Dial(packetConn, addr, "quic.clemente.io:1337", config) + Expect(err).ToNot(HaveOccurred()) + dialReturned = true + }() + Consistently(func() bool { return dialReturned }).Should(BeFalse()) + }) + It("errors on invalid public header", func() { err := cl.handlePacket(nil, nil) Expect(err.(*qerr.QuicError).ErrorCode).To(Equal(qerr.InvalidPacketHeader)) @@ -175,7 +190,7 @@ var _ = Describe("Client", func() { Expect(err).ToNot(HaveOccurred()) err = cl.handlePacket(nil, b.Bytes()) Expect(err).ToNot(HaveOccurred()) - Expect(cl.versionNegotiated).To(BeTrue()) + Expect(cl.connState).To(Equal(ConnStateVersionNegotiated)) Eventually(func() bool { return versionNegotiateConnStateCalled }).Should(BeTrue()) }) @@ -186,7 +201,7 @@ var _ = Describe("Client", func() { cl.connectionID = 0x1337 err := cl.handlePacket(nil, getVersionNegotiation([]protocol.VersionNumber{newVersion})) Expect(cl.version).To(Equal(newVersion)) - Expect(cl.versionNegotiated).To(BeTrue()) + Expect(cl.connState).To(Equal(ConnStateVersionNegotiated)) Eventually(func() bool { return versionNegotiateConnStateCalled }).Should(BeTrue()) // it swapped the sessions Expect(cl.session).ToNot(Equal(sess)) @@ -204,11 +219,11 @@ var _ = Describe("Client", func() { It("ignores delayed version negotiation packets", func() { // if the version was not yet negotiated, handlePacket would return a VersionNegotiationMismatch error, see above test - cl.versionNegotiated = true + cl.connState = ConnStateVersionNegotiated Expect(sess.packetCount).To(BeZero()) err := cl.handlePacket(nil, getVersionNegotiation([]protocol.VersionNumber{1})) Expect(err).ToNot(HaveOccurred()) - Expect(cl.versionNegotiated).To(BeTrue()) + Expect(cl.connState).To(Equal(ConnStateVersionNegotiated)) Expect(sess.packetCount).To(BeZero()) Consistently(func() bool { return versionNegotiateConnStateCalled }).Should(BeFalse()) }) diff --git a/interface.go b/interface.go index f71f649de..24a0b5d49 100644 --- a/interface.go +++ b/interface.go @@ -36,8 +36,10 @@ type Session interface { type ConnState int const ( + // ConnStateInitial is the initial state + ConnStateInitial ConnState = iota // ConnStateVersionNegotiated means that version negotiation is complete - ConnStateVersionNegotiated ConnState = iota + ConnStateVersionNegotiated // ConnStateSecure means that the connection is encrypted ConnStateSecure // ConnStateForwardSecure means that the connection is forward secure diff --git a/session.go b/session.go index 966ababe0..7f751aaeb 100644 --- a/session.go +++ b/session.go @@ -35,9 +35,9 @@ var ( errSessionAlreadyClosed = errors.New("Cannot close session. It was already closed before.") ) -// CryptoChangeCallback is called every time the encryption level changes +// cryptoChangeCallback is called every time the encryption level changes // Once the callback has been called with isForwardSecure = true, it is guarantueed to not be called with isForwardSecure = false after that -type CryptoChangeCallback func(isForwardSecure bool) +type cryptoChangeCallback func(isForwardSecure bool) // closeCallback is called when a session is closed type closeCallback func(id protocol.ConnectionID) @@ -49,7 +49,7 @@ type session struct { version protocol.VersionNumber closeCallback closeCallback - cryptoChangeCallback CryptoChangeCallback + cryptoChangeCallback cryptoChangeCallback conn connection @@ -132,7 +132,7 @@ func newSession(conn connection, v protocol.VersionNumber, connectionID protocol return s, err } -func newClientSession(conn connection, hostname string, v protocol.VersionNumber, connectionID protocol.ConnectionID, tlsConfig *tls.Config, closeCallback closeCallback, cryptoChangeCallback CryptoChangeCallback, negotiatedVersions []protocol.VersionNumber) (*session, error) { +func newClientSession(conn connection, hostname string, v protocol.VersionNumber, connectionID protocol.ConnectionID, tlsConfig *tls.Config, closeCallback closeCallback, cryptoChangeCallback cryptoChangeCallback, negotiatedVersions []protocol.VersionNumber) (*session, error) { s := &session{ conn: conn, connectionID: connectionID,