From a7afd7279564926318316d7b22806496a8d723f5 Mon Sep 17 00:00:00 2001 From: Marten Seemann Date: Thu, 15 Dec 2016 13:35:55 +0700 Subject: [PATCH] add a callback that is called when the crypto level changes --- client.go | 6 ++++-- example/client/client.go | 2 +- packet_packer_test.go | 5 +++-- session.go | 14 +++++++++++--- session_test.go | 30 ++++++++++++++++++++++++++++++ 5 files changed, 49 insertions(+), 8 deletions(-) diff --git a/client.go b/client.go index 30bd71ed..0e3e873e 100644 --- a/client.go +++ b/client.go @@ -24,6 +24,7 @@ type Client struct { version protocol.VersionNumber versionNegotiated bool + cryptoChangeCallback CryptoChangeCallback versionNegotiateCallback VersionNegotiateCallback session packetHandler @@ -40,7 +41,7 @@ var ( ) // NewClient makes a new client -func NewClient(addr string, versionNegotiateCallback VersionNegotiateCallback) (*Client, error) { +func NewClient(addr string, cryptoChangeCallback CryptoChangeCallback, versionNegotiateCallback VersionNegotiateCallback) (*Client, error) { hostname, err := utils.HostnameFromAddr(addr) if err != nil || len(hostname) == 0 { return nil, errHostname @@ -72,6 +73,7 @@ func NewClient(addr string, versionNegotiateCallback VersionNegotiateCallback) ( hostname: hostname, version: protocol.SupportedVersions[len(protocol.SupportedVersions)-1], // use the highest supported version by default connectionID: connectionID, + cryptoChangeCallback: cryptoChangeCallback, versionNegotiateCallback: versionNegotiateCallback, } @@ -186,7 +188,7 @@ func (c *Client) handlePacket(packet []byte) error { func (c *Client) createNewSession() error { var err error - c.session, err = newClientSession(c.conn, c.addr, c.hostname, c.version, c.connectionID, c.streamCallback, c.closeCallback) + c.session, err = newClientSession(c.conn, c.addr, c.hostname, c.version, c.connectionID, c.streamCallback, c.closeCallback, c.cryptoChangeCallback) if err != nil { return err } diff --git a/example/client/client.go b/example/client/client.go index 3cf20ea8..d0966d82 100644 --- a/example/client/client.go +++ b/example/client/client.go @@ -10,7 +10,7 @@ func main() { utils.SetLogLevel(utils.LogLevelDebug) - client, err := quic.NewClient(addr, func() error { return nil }) + client, err := quic.NewClient(addr, func(bool) {}, func() error { return nil }) if err != nil { panic(err) } diff --git a/packet_packer_test.go b/packet_packer_test.go index da219322..4372beee 100644 --- a/packet_packer_test.go +++ b/packet_packer_test.go @@ -11,9 +11,10 @@ import ( type mockCryptoSetup struct { diversificationNonce []byte + handshakeComplete bool } -func (m *mockCryptoSetup) HandleCryptoStream() error { panic("not implemented") } +func (m *mockCryptoSetup) HandleCryptoStream() error { return nil } func (m *mockCryptoSetup) Open(dst, src []byte, packetNumber protocol.PacketNumber, associatedData []byte) ([]byte, error) { return nil, nil @@ -23,7 +24,7 @@ func (m *mockCryptoSetup) Seal(dst, src []byte, packetNumber protocol.PacketNumb } func (m *mockCryptoSetup) LockForSealing() {} func (m *mockCryptoSetup) UnlockForSealing() {} -func (m *mockCryptoSetup) HandshakeComplete() bool { panic("not implemented") } +func (m *mockCryptoSetup) HandshakeComplete() bool { return m.handshakeComplete } func (m *mockCryptoSetup) DiversificationNonce() []byte { return m.diversificationNonce } func (m *mockCryptoSetup) SetDiversificationNonce([]byte) error { panic("not implemented") } diff --git a/session.go b/session.go index 5727291b..33bfa5b4 100644 --- a/session.go +++ b/session.go @@ -37,6 +37,10 @@ var ( // StreamCallback gets a stream frame and returns a reply frame type StreamCallback func(*Session, utils.Stream) +// CryptoChangeCallback is called every time the encryption level changes +// Once the callback has been called with isForwardSecure = true, it is guarantueed to not be called with isForwardSecure = false after that +type CryptoChangeCallback func(isForwardSecure bool) + // closeCallback is called when a session is closed type closeCallback func(id protocol.ConnectionID) @@ -46,8 +50,9 @@ type Session struct { perspective protocol.Perspective version protocol.VersionNumber - streamCallback StreamCallback - closeCallback closeCallback + streamCallback StreamCallback + closeCallback closeCallback + cryptoChangeCallback CryptoChangeCallback conn connection @@ -104,6 +109,7 @@ func newSession(conn connection, v protocol.VersionNumber, connectionID protocol streamCallback: streamCallback, closeCallback: closeCallback, + cryptoChangeCallback: func(bool) {}, connectionParameters: handshake.NewConnectionParamatersManager(protocol.PerspectiveServer, v), } @@ -121,7 +127,7 @@ func newSession(conn connection, v protocol.VersionNumber, connectionID protocol return session, err } -func newClientSession(conn *net.UDPConn, addr *net.UDPAddr, hostname string, v protocol.VersionNumber, connectionID protocol.ConnectionID, streamCallback StreamCallback, closeCallback closeCallback) (*Session, error) { +func newClientSession(conn *net.UDPConn, addr *net.UDPAddr, hostname string, v protocol.VersionNumber, connectionID protocol.ConnectionID, streamCallback StreamCallback, closeCallback closeCallback, cryptoChangeCallback CryptoChangeCallback) (*Session, error) { session := &Session{ conn: &udpConn{conn: conn, currentAddr: addr}, connectionID: connectionID, @@ -130,6 +136,7 @@ func newClientSession(conn *net.UDPConn, addr *net.UDPAddr, hostname string, v p streamCallback: streamCallback, closeCallback: closeCallback, + cryptoChangeCallback: cryptoChangeCallback, connectionParameters: handshake.NewConnectionParamatersManager(protocol.PerspectiveClient, v), } @@ -226,6 +233,7 @@ runLoop: putPacketBuffer(p.publicHeader.Raw) case <-s.aeadChanged: s.tryDecryptingQueuedPackets() + s.cryptoChangeCallback(s.cryptoSetup.HandshakeComplete()) } if err != nil { diff --git a/session_test.go b/session_test.go index 2e6004eb..5c0bf394 100644 --- a/session_test.go +++ b/session_test.go @@ -157,6 +157,7 @@ var _ = Describe("Session", func() { 0, func(*Session, utils.Stream) { streamCallbackCalled = true }, func(protocol.ConnectionID) { closeCallbackCalled = true }, + func(isForwardSecure bool) {}, ) Expect(err).ToNot(HaveOccurred()) Expect(clientSession.streamsMap.openStreams).To(HaveLen(1)) // Crypto stream @@ -1031,6 +1032,35 @@ var _ = Describe("Session", func() { Expect(session.receivedPackets).To(Receive()) }) + It("calls the cryptoChangeCallback when the AEAD changes", func(done Done) { + var callbackCalled bool + var callbackCalledWith bool + cb := func(p bool) { + callbackCalled = true + callbackCalledWith = p + } + session.cryptoChangeCallback = cb + session.cryptoSetup = &mockCryptoSetup{handshakeComplete: false} + session.aeadChanged <- struct{}{} + go session.run() + Eventually(func() bool { return callbackCalled }).Should(BeTrue()) + Expect(callbackCalledWith).To(BeFalse()) + close(done) + }) + + It("calls the cryptoChangeCallback when the AEAD changes to forward secure encryption", func(done Done) { + var callbackCalledWith bool + cb := func(p bool) { + callbackCalledWith = p + } + session.cryptoChangeCallback = cb + session.cryptoSetup = &mockCryptoSetup{handshakeComplete: true} + session.aeadChanged <- struct{}{} + go session.run() + Eventually(func() bool { return callbackCalledWith }).Should(BeTrue()) + close(done) + }) + Context("timeouts", func() { It("times out due to no network activity", func(done Done) { session.lastNetworkActivityTime = time.Now().Add(-time.Hour)