diff --git a/client.go b/client.go index 882a1b580..30bd71ed4 100644 --- a/client.go +++ b/client.go @@ -24,9 +24,14 @@ type Client struct { version protocol.VersionNumber versionNegotiated bool + versionNegotiateCallback VersionNegotiateCallback + session packetHandler } +// VersionNegotiateCallback is called once the client has a negotiated version +type VersionNegotiateCallback func() error + var errHostname = errors.New("Invalid hostname") var ( @@ -35,7 +40,7 @@ var ( ) // NewClient makes a new client -func NewClient(addr string) (*Client, error) { +func NewClient(addr string, versionNegotiateCallback VersionNegotiateCallback) (*Client, error) { hostname, err := utils.HostnameFromAddr(addr) if err != nil || len(hostname) == 0 { return nil, errHostname @@ -62,11 +67,12 @@ func NewClient(addr string) (*Client, error) { connectionID := protocol.ConnectionID(rand.Int63()) client := &Client{ - addr: udpAddr, - conn: conn, - hostname: hostname, - version: protocol.SupportedVersions[len(protocol.SupportedVersions)-1], // use the highest supported version by default - connectionID: connectionID, + addr: udpAddr, + conn: conn, + hostname: hostname, + version: protocol.SupportedVersions[len(protocol.SupportedVersions)-1], // use the highest supported version by default + connectionID: connectionID, + versionNegotiateCallback: versionNegotiateCallback, } utils.Infof("Starting new connection to %s (%s), connectionID %x, version %d", host, udpAddr.String(), connectionID, client.version) @@ -133,6 +139,10 @@ func (c *Client) handlePacket(packet []byte) error { // if the server doesn't send a version negotiation packet, it supports the suggested version if !hdr.VersionFlag && !c.versionNegotiated { c.versionNegotiated = true + err = c.versionNegotiateCallback() + if err != nil { + return err + } } if hdr.VersionFlag { @@ -151,11 +161,16 @@ func (c *Client) handlePacket(packet []byte) error { utils.Infof("Switching to QUIC version %d", highestSupportedVersion) c.version = highestSupportedVersion c.versionNegotiated = true + c.session.Close(errCloseSessionForNewVersion) err = c.createNewSession() if err != nil { return err } + err = c.versionNegotiateCallback() + if err != nil { + return err + } return nil // version negotiation packets have no payload } diff --git a/client_test.go b/client_test.go index b546c1d1c..381473e60 100644 --- a/client_test.go +++ b/client_test.go @@ -13,11 +13,20 @@ import ( ) var _ = Describe("Client", func() { - var client *Client - var session *mockSession + var ( + client *Client + session *mockSession + versionNegotiateCallbackCalled bool + ) BeforeEach(func() { - client = &Client{} + versionNegotiateCallbackCalled = false + client = &Client{ + versionNegotiateCallback: func() error { + versionNegotiateCallbackCalled = true + return nil + }, + } session = &mockSession{connectionID: 0x1337} client.connectionID = 0x1337 client.session = session @@ -162,6 +171,7 @@ var _ = Describe("Client", func() { err = client.handlePacket(b.Bytes()) Expect(err).ToNot(HaveOccurred()) Expect(client.versionNegotiated).To(BeTrue()) + Expect(versionNegotiateCallbackCalled).To(BeTrue()) }) It("changes the version after receiving a version negotiation packet", func() { @@ -172,6 +182,7 @@ var _ = Describe("Client", func() { err := client.handlePacket(getVersionNegotiation([]protocol.VersionNumber{newVersion})) Expect(client.version).To(Equal(newVersion)) Expect(client.versionNegotiated).To(BeTrue()) + Expect(versionNegotiateCallbackCalled).To(BeTrue()) // it swapped the sessions Expect(client.session).ToNot(Equal(session)) Expect(err).ToNot(HaveOccurred()) @@ -195,6 +206,7 @@ var _ = Describe("Client", func() { Expect(err).ToNot(HaveOccurred()) Expect(client.versionNegotiated).To(BeTrue()) Expect(session.packetCount).To(BeZero()) + Expect(versionNegotiateCallbackCalled).To(BeFalse()) }) It("errors if the server should have accepted the offered version", func() { diff --git a/example/client/client.go b/example/client/client.go index 4898fb61c..3cf20ea8f 100644 --- a/example/client/client.go +++ b/example/client/client.go @@ -10,7 +10,7 @@ func main() { utils.SetLogLevel(utils.LogLevelDebug) - client, err := quic.NewClient(addr) + client, err := quic.NewClient(addr, func() error { return nil }) if err != nil { panic(err) }