forked from quic-go/quic-go
add a callback that is called when the crypto level changes
This commit is contained in:
@@ -24,6 +24,7 @@ type Client struct {
|
||||
version protocol.VersionNumber
|
||||
versionNegotiated bool
|
||||
|
||||
cryptoChangeCallback CryptoChangeCallback
|
||||
versionNegotiateCallback VersionNegotiateCallback
|
||||
|
||||
session packetHandler
|
||||
@@ -40,7 +41,7 @@ var (
|
||||
)
|
||||
|
||||
// NewClient makes a new client
|
||||
func NewClient(addr string, versionNegotiateCallback VersionNegotiateCallback) (*Client, error) {
|
||||
func NewClient(addr string, cryptoChangeCallback CryptoChangeCallback, versionNegotiateCallback VersionNegotiateCallback) (*Client, error) {
|
||||
hostname, err := utils.HostnameFromAddr(addr)
|
||||
if err != nil || len(hostname) == 0 {
|
||||
return nil, errHostname
|
||||
@@ -72,6 +73,7 @@ func NewClient(addr string, versionNegotiateCallback VersionNegotiateCallback) (
|
||||
hostname: hostname,
|
||||
version: protocol.SupportedVersions[len(protocol.SupportedVersions)-1], // use the highest supported version by default
|
||||
connectionID: connectionID,
|
||||
cryptoChangeCallback: cryptoChangeCallback,
|
||||
versionNegotiateCallback: versionNegotiateCallback,
|
||||
}
|
||||
|
||||
@@ -186,7 +188,7 @@ func (c *Client) handlePacket(packet []byte) error {
|
||||
|
||||
func (c *Client) createNewSession() error {
|
||||
var err error
|
||||
c.session, err = newClientSession(c.conn, c.addr, c.hostname, c.version, c.connectionID, c.streamCallback, c.closeCallback)
|
||||
c.session, err = newClientSession(c.conn, c.addr, c.hostname, c.version, c.connectionID, c.streamCallback, c.closeCallback, c.cryptoChangeCallback)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -10,7 +10,7 @@ func main() {
|
||||
|
||||
utils.SetLogLevel(utils.LogLevelDebug)
|
||||
|
||||
client, err := quic.NewClient(addr, func() error { return nil })
|
||||
client, err := quic.NewClient(addr, func(bool) {}, func() error { return nil })
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
@@ -11,9 +11,10 @@ import (
|
||||
|
||||
type mockCryptoSetup struct {
|
||||
diversificationNonce []byte
|
||||
handshakeComplete bool
|
||||
}
|
||||
|
||||
func (m *mockCryptoSetup) HandleCryptoStream() error { panic("not implemented") }
|
||||
func (m *mockCryptoSetup) HandleCryptoStream() error { return nil }
|
||||
|
||||
func (m *mockCryptoSetup) Open(dst, src []byte, packetNumber protocol.PacketNumber, associatedData []byte) ([]byte, error) {
|
||||
return nil, nil
|
||||
@@ -23,7 +24,7 @@ func (m *mockCryptoSetup) Seal(dst, src []byte, packetNumber protocol.PacketNumb
|
||||
}
|
||||
func (m *mockCryptoSetup) LockForSealing() {}
|
||||
func (m *mockCryptoSetup) UnlockForSealing() {}
|
||||
func (m *mockCryptoSetup) HandshakeComplete() bool { panic("not implemented") }
|
||||
func (m *mockCryptoSetup) HandshakeComplete() bool { return m.handshakeComplete }
|
||||
func (m *mockCryptoSetup) DiversificationNonce() []byte { return m.diversificationNonce }
|
||||
func (m *mockCryptoSetup) SetDiversificationNonce([]byte) error { panic("not implemented") }
|
||||
|
||||
|
||||
14
session.go
14
session.go
@@ -37,6 +37,10 @@ var (
|
||||
// StreamCallback gets a stream frame and returns a reply frame
|
||||
type StreamCallback func(*Session, utils.Stream)
|
||||
|
||||
// CryptoChangeCallback is called every time the encryption level changes
|
||||
// Once the callback has been called with isForwardSecure = true, it is guarantueed to not be called with isForwardSecure = false after that
|
||||
type CryptoChangeCallback func(isForwardSecure bool)
|
||||
|
||||
// closeCallback is called when a session is closed
|
||||
type closeCallback func(id protocol.ConnectionID)
|
||||
|
||||
@@ -46,8 +50,9 @@ type Session struct {
|
||||
perspective protocol.Perspective
|
||||
version protocol.VersionNumber
|
||||
|
||||
streamCallback StreamCallback
|
||||
closeCallback closeCallback
|
||||
streamCallback StreamCallback
|
||||
closeCallback closeCallback
|
||||
cryptoChangeCallback CryptoChangeCallback
|
||||
|
||||
conn connection
|
||||
|
||||
@@ -104,6 +109,7 @@ func newSession(conn connection, v protocol.VersionNumber, connectionID protocol
|
||||
|
||||
streamCallback: streamCallback,
|
||||
closeCallback: closeCallback,
|
||||
cryptoChangeCallback: func(bool) {},
|
||||
connectionParameters: handshake.NewConnectionParamatersManager(protocol.PerspectiveServer, v),
|
||||
}
|
||||
|
||||
@@ -121,7 +127,7 @@ func newSession(conn connection, v protocol.VersionNumber, connectionID protocol
|
||||
return session, err
|
||||
}
|
||||
|
||||
func newClientSession(conn *net.UDPConn, addr *net.UDPAddr, hostname string, v protocol.VersionNumber, connectionID protocol.ConnectionID, streamCallback StreamCallback, closeCallback closeCallback) (*Session, error) {
|
||||
func newClientSession(conn *net.UDPConn, addr *net.UDPAddr, hostname string, v protocol.VersionNumber, connectionID protocol.ConnectionID, streamCallback StreamCallback, closeCallback closeCallback, cryptoChangeCallback CryptoChangeCallback) (*Session, error) {
|
||||
session := &Session{
|
||||
conn: &udpConn{conn: conn, currentAddr: addr},
|
||||
connectionID: connectionID,
|
||||
@@ -130,6 +136,7 @@ func newClientSession(conn *net.UDPConn, addr *net.UDPAddr, hostname string, v p
|
||||
|
||||
streamCallback: streamCallback,
|
||||
closeCallback: closeCallback,
|
||||
cryptoChangeCallback: cryptoChangeCallback,
|
||||
connectionParameters: handshake.NewConnectionParamatersManager(protocol.PerspectiveClient, v),
|
||||
}
|
||||
|
||||
@@ -226,6 +233,7 @@ runLoop:
|
||||
putPacketBuffer(p.publicHeader.Raw)
|
||||
case <-s.aeadChanged:
|
||||
s.tryDecryptingQueuedPackets()
|
||||
s.cryptoChangeCallback(s.cryptoSetup.HandshakeComplete())
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
|
||||
@@ -157,6 +157,7 @@ var _ = Describe("Session", func() {
|
||||
0,
|
||||
func(*Session, utils.Stream) { streamCallbackCalled = true },
|
||||
func(protocol.ConnectionID) { closeCallbackCalled = true },
|
||||
func(isForwardSecure bool) {},
|
||||
)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(clientSession.streamsMap.openStreams).To(HaveLen(1)) // Crypto stream
|
||||
@@ -1031,6 +1032,35 @@ var _ = Describe("Session", func() {
|
||||
Expect(session.receivedPackets).To(Receive())
|
||||
})
|
||||
|
||||
It("calls the cryptoChangeCallback when the AEAD changes", func(done Done) {
|
||||
var callbackCalled bool
|
||||
var callbackCalledWith bool
|
||||
cb := func(p bool) {
|
||||
callbackCalled = true
|
||||
callbackCalledWith = p
|
||||
}
|
||||
session.cryptoChangeCallback = cb
|
||||
session.cryptoSetup = &mockCryptoSetup{handshakeComplete: false}
|
||||
session.aeadChanged <- struct{}{}
|
||||
go session.run()
|
||||
Eventually(func() bool { return callbackCalled }).Should(BeTrue())
|
||||
Expect(callbackCalledWith).To(BeFalse())
|
||||
close(done)
|
||||
})
|
||||
|
||||
It("calls the cryptoChangeCallback when the AEAD changes to forward secure encryption", func(done Done) {
|
||||
var callbackCalledWith bool
|
||||
cb := func(p bool) {
|
||||
callbackCalledWith = p
|
||||
}
|
||||
session.cryptoChangeCallback = cb
|
||||
session.cryptoSetup = &mockCryptoSetup{handshakeComplete: true}
|
||||
session.aeadChanged <- struct{}{}
|
||||
go session.run()
|
||||
Eventually(func() bool { return callbackCalledWith }).Should(BeTrue())
|
||||
close(done)
|
||||
})
|
||||
|
||||
Context("timeouts", func() {
|
||||
It("times out due to no network activity", func(done Done) {
|
||||
session.lastNetworkActivityTime = time.Now().Add(-time.Hour)
|
||||
|
||||
Reference in New Issue
Block a user