add a callback that is called when the crypto level changes

This commit is contained in:
Marten Seemann
2016-12-15 13:35:55 +07:00
parent 4130c3de32
commit a7afd72795
5 changed files with 49 additions and 8 deletions

View File

@@ -24,6 +24,7 @@ type Client struct {
version protocol.VersionNumber
versionNegotiated bool
cryptoChangeCallback CryptoChangeCallback
versionNegotiateCallback VersionNegotiateCallback
session packetHandler
@@ -40,7 +41,7 @@ var (
)
// NewClient makes a new client
func NewClient(addr string, versionNegotiateCallback VersionNegotiateCallback) (*Client, error) {
func NewClient(addr string, cryptoChangeCallback CryptoChangeCallback, versionNegotiateCallback VersionNegotiateCallback) (*Client, error) {
hostname, err := utils.HostnameFromAddr(addr)
if err != nil || len(hostname) == 0 {
return nil, errHostname
@@ -72,6 +73,7 @@ func NewClient(addr string, versionNegotiateCallback VersionNegotiateCallback) (
hostname: hostname,
version: protocol.SupportedVersions[len(protocol.SupportedVersions)-1], // use the highest supported version by default
connectionID: connectionID,
cryptoChangeCallback: cryptoChangeCallback,
versionNegotiateCallback: versionNegotiateCallback,
}
@@ -186,7 +188,7 @@ func (c *Client) handlePacket(packet []byte) error {
func (c *Client) createNewSession() error {
var err error
c.session, err = newClientSession(c.conn, c.addr, c.hostname, c.version, c.connectionID, c.streamCallback, c.closeCallback)
c.session, err = newClientSession(c.conn, c.addr, c.hostname, c.version, c.connectionID, c.streamCallback, c.closeCallback, c.cryptoChangeCallback)
if err != nil {
return err
}

View File

@@ -10,7 +10,7 @@ func main() {
utils.SetLogLevel(utils.LogLevelDebug)
client, err := quic.NewClient(addr, func() error { return nil })
client, err := quic.NewClient(addr, func(bool) {}, func() error { return nil })
if err != nil {
panic(err)
}

View File

@@ -11,9 +11,10 @@ import (
type mockCryptoSetup struct {
diversificationNonce []byte
handshakeComplete bool
}
func (m *mockCryptoSetup) HandleCryptoStream() error { panic("not implemented") }
func (m *mockCryptoSetup) HandleCryptoStream() error { return nil }
func (m *mockCryptoSetup) Open(dst, src []byte, packetNumber protocol.PacketNumber, associatedData []byte) ([]byte, error) {
return nil, nil
@@ -23,7 +24,7 @@ func (m *mockCryptoSetup) Seal(dst, src []byte, packetNumber protocol.PacketNumb
}
func (m *mockCryptoSetup) LockForSealing() {}
func (m *mockCryptoSetup) UnlockForSealing() {}
func (m *mockCryptoSetup) HandshakeComplete() bool { panic("not implemented") }
func (m *mockCryptoSetup) HandshakeComplete() bool { return m.handshakeComplete }
func (m *mockCryptoSetup) DiversificationNonce() []byte { return m.diversificationNonce }
func (m *mockCryptoSetup) SetDiversificationNonce([]byte) error { panic("not implemented") }

View File

@@ -37,6 +37,10 @@ var (
// StreamCallback gets a stream frame and returns a reply frame
type StreamCallback func(*Session, utils.Stream)
// CryptoChangeCallback is called every time the encryption level changes
// Once the callback has been called with isForwardSecure = true, it is guarantueed to not be called with isForwardSecure = false after that
type CryptoChangeCallback func(isForwardSecure bool)
// closeCallback is called when a session is closed
type closeCallback func(id protocol.ConnectionID)
@@ -46,8 +50,9 @@ type Session struct {
perspective protocol.Perspective
version protocol.VersionNumber
streamCallback StreamCallback
closeCallback closeCallback
streamCallback StreamCallback
closeCallback closeCallback
cryptoChangeCallback CryptoChangeCallback
conn connection
@@ -104,6 +109,7 @@ func newSession(conn connection, v protocol.VersionNumber, connectionID protocol
streamCallback: streamCallback,
closeCallback: closeCallback,
cryptoChangeCallback: func(bool) {},
connectionParameters: handshake.NewConnectionParamatersManager(protocol.PerspectiveServer, v),
}
@@ -121,7 +127,7 @@ func newSession(conn connection, v protocol.VersionNumber, connectionID protocol
return session, err
}
func newClientSession(conn *net.UDPConn, addr *net.UDPAddr, hostname string, v protocol.VersionNumber, connectionID protocol.ConnectionID, streamCallback StreamCallback, closeCallback closeCallback) (*Session, error) {
func newClientSession(conn *net.UDPConn, addr *net.UDPAddr, hostname string, v protocol.VersionNumber, connectionID protocol.ConnectionID, streamCallback StreamCallback, closeCallback closeCallback, cryptoChangeCallback CryptoChangeCallback) (*Session, error) {
session := &Session{
conn: &udpConn{conn: conn, currentAddr: addr},
connectionID: connectionID,
@@ -130,6 +136,7 @@ func newClientSession(conn *net.UDPConn, addr *net.UDPAddr, hostname string, v p
streamCallback: streamCallback,
closeCallback: closeCallback,
cryptoChangeCallback: cryptoChangeCallback,
connectionParameters: handshake.NewConnectionParamatersManager(protocol.PerspectiveClient, v),
}
@@ -226,6 +233,7 @@ runLoop:
putPacketBuffer(p.publicHeader.Raw)
case <-s.aeadChanged:
s.tryDecryptingQueuedPackets()
s.cryptoChangeCallback(s.cryptoSetup.HandshakeComplete())
}
if err != nil {

View File

@@ -157,6 +157,7 @@ var _ = Describe("Session", func() {
0,
func(*Session, utils.Stream) { streamCallbackCalled = true },
func(protocol.ConnectionID) { closeCallbackCalled = true },
func(isForwardSecure bool) {},
)
Expect(err).ToNot(HaveOccurred())
Expect(clientSession.streamsMap.openStreams).To(HaveLen(1)) // Crypto stream
@@ -1031,6 +1032,35 @@ var _ = Describe("Session", func() {
Expect(session.receivedPackets).To(Receive())
})
It("calls the cryptoChangeCallback when the AEAD changes", func(done Done) {
var callbackCalled bool
var callbackCalledWith bool
cb := func(p bool) {
callbackCalled = true
callbackCalledWith = p
}
session.cryptoChangeCallback = cb
session.cryptoSetup = &mockCryptoSetup{handshakeComplete: false}
session.aeadChanged <- struct{}{}
go session.run()
Eventually(func() bool { return callbackCalled }).Should(BeTrue())
Expect(callbackCalledWith).To(BeFalse())
close(done)
})
It("calls the cryptoChangeCallback when the AEAD changes to forward secure encryption", func(done Done) {
var callbackCalledWith bool
cb := func(p bool) {
callbackCalledWith = p
}
session.cryptoChangeCallback = cb
session.cryptoSetup = &mockCryptoSetup{handshakeComplete: true}
session.aeadChanged <- struct{}{}
go session.run()
Eventually(func() bool { return callbackCalledWith }).Should(BeTrue())
close(done)
})
Context("timeouts", func() {
It("times out due to no network activity", func(done Done) {
session.lastNetworkActivityTime = time.Now().Add(-time.Hour)