From ab613c3ab0d1f7749a86f683cb8a76fd48fa4577 Mon Sep 17 00:00:00 2001 From: Marten Seemann Date: Mon, 8 May 2017 20:41:05 +0800 Subject: [PATCH] use a directed aeadChanged chan in the session --- session.go | 10 +++++++--- session_test.go | 28 ++++++++++++++++------------ 2 files changed, 23 insertions(+), 15 deletions(-) diff --git a/session.go b/session.go index 6e9ecf66..55bc649b 100644 --- a/session.go +++ b/session.go @@ -78,7 +78,7 @@ type session struct { // this channel is passed to the CryptoSetup and receives the current encryption level // it is closed as soon as the handshake is complete - aeadChanged chan protocol.EncryptionLevel + aeadChanged <-chan protocol.EncryptionLevel handshakeComplete bool nextAckScheduledTime time.Time @@ -121,8 +121,10 @@ func newSession(conn connection, v protocol.VersionNumber, connectionID protocol } else { sourceAddr = []byte(conn.RemoteAddr().String()) } + aeadChanged := make(chan protocol.EncryptionLevel, 2) + s.aeadChanged = aeadChanged var err error - s.cryptoSetup, err = handshake.NewCryptoSetup(connectionID, sourceAddr, v, sCfg, cryptoStream, s.connectionParameters, supportedVersions, s.aeadChanged) + s.cryptoSetup, err = handshake.NewCryptoSetup(connectionID, sourceAddr, v, sCfg, cryptoStream, s.connectionParameters, supportedVersions, aeadChanged) if err != nil { return nil, err } @@ -147,9 +149,11 @@ func newClientSession(conn connection, hostname string, v protocol.VersionNumber s.receivedPacketHandler = ackhandler.NewReceivedPacketHandler(s.ackAlarmChanged) s.setup() + aeadChanged := make(chan protocol.EncryptionLevel, 2) + s.aeadChanged = aeadChanged cryptoStream, _ := s.OpenStream() var err error - s.cryptoSetup, err = handshake.NewCryptoSetupClient(hostname, connectionID, v, cryptoStream, tlsConfig, s.connectionParameters, s.aeadChanged, negotiatedVersions) + s.cryptoSetup, err = handshake.NewCryptoSetupClient(hostname, connectionID, v, cryptoStream, tlsConfig, s.connectionParameters, aeadChanged, negotiatedVersions) if err != nil { return nil, err } diff --git a/session_test.go b/session_test.go index 702eed20..2a4a7eeb 100644 --- a/session_test.go +++ b/session_test.go @@ -125,11 +125,12 @@ func areSessionsRunning() bool { var _ = Describe("Session", func() { var ( - sess *session - clientSess *session - scfg *handshake.ServerConfig - mconn *mockConnection - cpm *mockConnectionParametersManager + sess *session + clientSess *session + scfg *handshake.ServerConfig + mconn *mockConnection + cpm *mockConnectionParametersManager + aeadChanged chan<- protocol.EncryptionLevel ) BeforeEach(func() { @@ -154,6 +155,9 @@ var _ = Describe("Session", func() { Expect(err).NotTo(HaveOccurred()) sess = pSess.(*session) Expect(sess.streamsMap.openStreams).To(HaveLen(1)) // Crypto stream + // we need an aeadChanged chan that we can write to + // since type assertions on chans are not possible, we have to extract it from the CryptoSetup + aeadChanged = *(*chan<- protocol.EncryptionLevel)(unsafe.Pointer(reflect.ValueOf(sess.cryptoSetup).Elem().FieldByName("aeadChanged").UnsafeAddr())) cpm = &mockConnectionParametersManager{idleTime: 60 * time.Second} sess.connectionParameters = cpm @@ -1166,9 +1170,9 @@ var _ = Describe("Session", func() { It("tells the packetPacker when forward-secure encryption is used", func() { go sess.run() defer sess.Close(nil) - sess.aeadChanged <- protocol.EncryptionSecure + aeadChanged <- protocol.EncryptionSecure Consistently(func() bool { return sess.packer.isForwardSecure }).Should(BeFalse()) - sess.aeadChanged <- protocol.EncryptionForwardSecure + aeadChanged <- protocol.EncryptionForwardSecure Eventually(func() bool { return sess.packer.isForwardSecure }).Should(BeTrue()) }) @@ -1251,7 +1255,7 @@ var _ = Describe("Session", func() { }) It("ignores undecryptable packets after the handshake is complete", func() { - close(sess.aeadChanged) + close(aeadChanged) go sess.run() sendUndecryptablePackets() Consistently(sess.undecryptablePackets).Should(BeEmpty()) @@ -1280,7 +1284,7 @@ var _ = Describe("Session", func() { callbackSession = s } sess.cryptoChangeCallback = cb - sess.aeadChanged <- protocol.EncryptionSecure + aeadChanged <- protocol.EncryptionSecure go sess.run() defer sess.Close(nil) Eventually(func() bool { return callbackCalled }).Should(BeTrue()) @@ -1297,8 +1301,8 @@ var _ = Describe("Session", func() { callbackCalledWith = p } sess.cryptoChangeCallback = cb - sess.aeadChanged <- protocol.EncryptionForwardSecure - close(sess.aeadChanged) + aeadChanged <- protocol.EncryptionForwardSecure + close(aeadChanged) go sess.run() defer sess.Close(nil) Eventually(func() bool { return callbackCalledWith }).Should(BeTrue()) @@ -1334,7 +1338,7 @@ var _ = Describe("Session", func() { }) It("uses ICSL after handshake", func(done Done) { - close(sess.aeadChanged) + close(aeadChanged) cpm.idleTime = 0 * time.Millisecond sess.packer.connectionParameters = sess.connectionParameters sess.run() // Would normally not return