From 0a6a9551f873801e67e8205eb463755818ce34cc Mon Sep 17 00:00:00 2001 From: Marten Seemann Date: Sun, 7 May 2017 22:00:57 +0800 Subject: [PATCH] close the aeadChanged chan when the handshake completes This allows us to remove HandshakeComplete() from the CryptoSetup interface. It also provides a signal to the session when the handshake completes. --- handshake/crypto_setup_client.go | 8 +------ handshake/crypto_setup_client_test.go | 32 ++++++++++++++++----------- handshake/crypto_setup_server.go | 10 ++++----- handshake/crypto_setup_server_test.go | 26 +++++++++------------- handshake/interface.go | 1 - packet_packer_test.go | 6 ++--- session.go | 32 ++++++++++++++++++--------- session_test.go | 8 +++---- 8 files changed, 61 insertions(+), 62 deletions(-) diff --git a/handshake/crypto_setup_client.go b/handshake/crypto_setup_client.go index fe57d645..44da0433 100644 --- a/handshake/crypto_setup_client.go +++ b/handshake/crypto_setup_client.go @@ -251,6 +251,7 @@ func (h *cryptoSetupClient) handleSHLOMessage(cryptoData map[Tag][]byte) error { } h.aeadChanged <- protocol.EncryptionForwardSecure + close(h.aeadChanged) return nil } @@ -370,13 +371,6 @@ func (h *cryptoSetupClient) SetDiversificationNonce(data []byte) error { return nil } -func (h *cryptoSetupClient) HandshakeComplete() bool { - h.mutex.RLock() - defer h.mutex.RUnlock() - - return h.forwardSecureAEAD != nil -} - func (h *cryptoSetupClient) sendCHLO() error { h.clientHelloCounter++ if h.clientHelloCounter > protocol.MaxClientHellos { diff --git a/handshake/crypto_setup_client_test.go b/handshake/crypto_setup_client_test.go index 26a309dc..e28ab12f 100644 --- a/handshake/crypto_setup_client_test.go +++ b/handshake/crypto_setup_client_test.go @@ -369,22 +369,22 @@ var _ = Describe("Client Crypto Setup", func() { cs.receivedSecurePacket = false err := cs.handleSHLOMessage(shloMap) Expect(err).To(MatchError(qerr.Error(qerr.CryptoEncryptionLevelIncorrect, "unencrypted SHLO message"))) - Expect(cs.HandshakeComplete()).To(BeFalse()) Expect(cs.aeadChanged).ToNot(Receive()) + Expect(cs.aeadChanged).ToNot(BeClosed()) }) It("rejects SHLOs without a PUBS", func() { delete(shloMap, TagPUBS) err := cs.handleSHLOMessage(shloMap) Expect(err).To(MatchError(qerr.Error(qerr.CryptoMessageParameterNotFound, "PUBS"))) - Expect(cs.HandshakeComplete()).To(BeFalse()) + Expect(cs.aeadChanged).ToNot(BeClosed()) }) It("rejects SHLOs without a version list", func() { delete(shloMap, TagVER) err := cs.handleSHLOMessage(shloMap) Expect(err).To(MatchError(qerr.Error(qerr.InvalidCryptoMessageParameter, "server hello missing version list"))) - Expect(cs.HandshakeComplete()).To(BeFalse()) + Expect(cs.aeadChanged).ToNot(BeClosed()) }) It("accepts a SHLO after a version negotiation", func() { @@ -409,8 +409,8 @@ var _ = Describe("Client Crypto Setup", func() { err := cs.handleSHLOMessage(shloMap) Expect(err).ToNot(HaveOccurred()) Expect(cs.forwardSecureAEAD).ToNot(BeNil()) - Expect(cs.HandshakeComplete()).To(BeTrue()) - Expect(cs.aeadChanged).To(Receive()) + Expect(cs.aeadChanged).To(Receive(Equal(protocol.EncryptionForwardSecure))) + Expect(cs.aeadChanged).To(BeClosed()) }) It("reads the connection paramaters", func() { @@ -598,8 +598,9 @@ var _ = Describe("Client Crypto Setup", func() { Expect(keyDerivationCalledWith.cert).To(Equal(certManager.leafCert)) Expect(keyDerivationCalledWith.divNonce).To(Equal(cs.diversificationNonce)) Expect(keyDerivationCalledWith.pers).To(Equal(protocol.PerspectiveClient)) - Expect(cs.HandshakeComplete()).To(BeFalse()) - Expect(cs.aeadChanged).To(Receive()) + Expect(cs.aeadChanged).To(Receive(Equal(protocol.EncryptionSecure))) + Expect(cs.aeadChanged).ToNot(Receive()) + Expect(cs.aeadChanged).ToNot(BeClosed()) }) It("uses the server nonce, if the server sent one", func() { @@ -609,8 +610,9 @@ var _ = Describe("Client Crypto Setup", func() { Expect(err).ToNot(HaveOccurred()) Expect(cs.secureAEAD).ToNot(BeNil()) Expect(keyDerivationCalledWith.nonces).To(Equal(append(cs.nonc, cs.sno...))) - Expect(cs.HandshakeComplete()).To(BeFalse()) Expect(cs.aeadChanged).To(Receive()) + Expect(cs.aeadChanged).ToNot(Receive()) + Expect(cs.aeadChanged).ToNot(BeClosed()) }) It("doesn't create a secureAEAD if the certificate is not yet verified, even if it has all necessary values", func() { @@ -623,8 +625,9 @@ var _ = Describe("Client Crypto Setup", func() { err = cs.maybeUpgradeCrypto() Expect(err).ToNot(HaveOccurred()) Expect(cs.secureAEAD).ToNot(BeNil()) - Expect(cs.HandshakeComplete()).To(BeFalse()) - Expect(cs.aeadChanged).To(Receive()) + Expect(cs.aeadChanged).To(Receive(Equal(protocol.EncryptionSecure))) + Expect(cs.aeadChanged).ToNot(Receive()) + Expect(cs.aeadChanged).ToNot(BeClosed()) }) It("tries to escalate before reading a handshake message", func() { @@ -635,7 +638,9 @@ var _ = Describe("Client Crypto Setup", func() { // this is because the mockStream doesn't block if there's no data to read Expect(err).To(MatchError(qerr.HandshakeFailed)) Expect(cs.secureAEAD).ToNot(BeNil()) - Expect(cs.HandshakeComplete()).To(BeFalse()) + Expect(cs.aeadChanged).To(Receive(Equal(protocol.EncryptionSecure))) + Expect(cs.aeadChanged).ToNot(Receive()) + Expect(cs.aeadChanged).ToNot(BeClosed()) }) It("tries to escalate the crypto after receiving a diversification nonce", func() { @@ -645,8 +650,9 @@ var _ = Describe("Client Crypto Setup", func() { err := cs.SetDiversificationNonce([]byte("div")) Expect(err).ToNot(HaveOccurred()) Expect(cs.secureAEAD).ToNot(BeNil()) - Expect(cs.aeadChanged).To(Receive()) - Expect(cs.HandshakeComplete()).To(BeFalse()) + Expect(cs.aeadChanged).To(Receive(Equal(protocol.EncryptionSecure))) + Expect(cs.aeadChanged).ToNot(Receive()) + Expect(cs.aeadChanged).ToNot(BeClosed()) }) Context("null encryption", func() { diff --git a/handshake/crypto_setup_server.go b/handshake/crypto_setup_server.go index 184ba609..c046c123 100644 --- a/handshake/crypto_setup_server.go +++ b/handshake/crypto_setup_server.go @@ -173,7 +173,10 @@ func (h *cryptoSetupServer) Open(dst, src []byte, packetNumber protocol.PacketNu if h.forwardSecureAEAD != nil { res, err := h.forwardSecureAEAD.Open(dst, src, packetNumber, associatedData) if err == nil { - h.receivedForwardSecurePacket = true + if !h.receivedForwardSecurePacket { // this is the first forward secure packet we receive from the client + h.receivedForwardSecurePacket = true + close(h.aeadChanged) + } return res, protocol.EncryptionForwardSecure, nil } if h.receivedForwardSecurePacket { @@ -428,11 +431,6 @@ func (h *cryptoSetupServer) SetDiversificationNonce(data []byte) error { panic("not needed for cryptoSetupServer") } -// HandshakeComplete returns true after the first forward secure packet was received form the client. -func (h *cryptoSetupServer) HandshakeComplete() bool { - return h.receivedForwardSecurePacket -} - func (h *cryptoSetupServer) validateClientNonce(nonce []byte) error { if len(nonce) != 32 { return qerr.Error(qerr.InvalidCryptoMessageParameter, "invalid client nonce length") diff --git a/handshake/crypto_setup_server_test.go b/handshake/crypto_setup_server_test.go index 1670c324..6f2ab0d9 100644 --- a/handshake/crypto_setup_server_test.go +++ b/handshake/crypto_setup_server_test.go @@ -302,8 +302,11 @@ var _ = Describe("Server Crypto Setup", func() { err := cs.HandleCryptoStream() Expect(err).NotTo(HaveOccurred()) Expect(stream.dataWritten.Bytes()).To(HavePrefix("REJ")) + Expect(aeadChanged).To(Receive(Equal(protocol.EncryptionSecure))) Expect(stream.dataWritten.Bytes()).To(ContainSubstring("SHLO")) - Expect(aeadChanged).To(Receive()) + Expect(aeadChanged).To(Receive(Equal(protocol.EncryptionForwardSecure))) + Expect(aeadChanged).ToNot(Receive()) + Expect(aeadChanged).ToNot(BeClosed()) }) It("rejects client nonces that have the wrong length", func() { @@ -334,11 +337,8 @@ var _ = Describe("Server Crypto Setup", func() { Expect(err).NotTo(HaveOccurred()) Expect(stream.dataWritten.Bytes()).To(HavePrefix("SHLO")) Expect(stream.dataWritten.Bytes()).ToNot(ContainSubstring("REJ")) - var encLevel protocol.EncryptionLevel - Expect(aeadChanged).To(Receive(&encLevel)) - Expect(encLevel).To(Equal(protocol.EncryptionSecure)) - Expect(aeadChanged).To(Receive(&encLevel)) - Expect(encLevel).To(Equal(protocol.EncryptionForwardSecure)) + Expect(aeadChanged).To(Receive(Equal(protocol.EncryptionSecure))) + Expect(aeadChanged).To(Receive(Equal(protocol.EncryptionForwardSecure))) }) It("recognizes inchoate CHLOs missing SCID", func() { @@ -593,15 +593,11 @@ var _ = Describe("Server Crypto Setup", func() { It("regards the handshake as complete once it receives a forward encrypted packet", func() { doCHLO() - enc, seal := cs.GetSealer() - Expect(enc).To(Equal(protocol.EncryptionSecure)) - _ = seal(nil, []byte("SHLO"), 0, []byte{}) - enc, seal = cs.GetSealer() - Expect(enc).To(Equal(protocol.EncryptionForwardSecure)) - _ = seal(nil, []byte("foobar"), 0, []byte{}) - Expect(cs.HandshakeComplete()).To(BeFalse()) - cs.receivedForwardSecurePacket = true - Expect(cs.HandshakeComplete()).To(BeTrue()) + _, _, err := cs.Open(nil, []byte("forward secure encrypted"), 0, []byte{}) + Expect(err).ToNot(HaveOccurred()) + Expect(cs.aeadChanged).To(Receive()) // consume the protocol.EncryptionSecure + Expect(cs.aeadChanged).To(Receive()) // consume the protocol.EncryptionForwardSecure + Expect(cs.aeadChanged).To(BeClosed()) }) }) diff --git a/handshake/interface.go b/handshake/interface.go index a6629cc4..fb5c86fc 100644 --- a/handshake/interface.go +++ b/handshake/interface.go @@ -9,7 +9,6 @@ type Sealer func(dst, src []byte, packetNumber protocol.PacketNumber, associated type CryptoSetup interface { Open(dst, src []byte, packetNumber protocol.PacketNumber, associatedData []byte) ([]byte, protocol.EncryptionLevel, error) HandleCryptoStream() error - HandshakeComplete() bool // TODO: clean up this interface DiversificationNonce() []byte // only needed for cryptoSetupServer SetDiversificationNonce([]byte) error // only needed for cryptoSetupClient diff --git a/packet_packer_test.go b/packet_packer_test.go index 4d1c6df8..efa50315 100644 --- a/packet_packer_test.go +++ b/packet_packer_test.go @@ -13,9 +13,8 @@ import ( ) type mockCryptoSetup struct { - divNonce []byte - handshakeComplete bool - encLevelSeal protocol.EncryptionLevel + divNonce []byte + encLevelSeal protocol.EncryptionLevel } func (m *mockCryptoSetup) HandleCryptoStream() error { return nil } @@ -33,7 +32,6 @@ func (m *mockCryptoSetup) GetSealerWithEncryptionLevel(protocol.EncryptionLevel) return append(src, bytes.Repeat([]byte{0}, 12)...) }, nil } -func (m *mockCryptoSetup) HandshakeComplete() bool { return m.handshakeComplete } func (m *mockCryptoSetup) DiversificationNonce() []byte { return m.divNonce } diff --git a/session.go b/session.go index 3b9ba3c7..6e9ecf66 100644 --- a/session.go +++ b/session.go @@ -76,7 +76,10 @@ type session struct { undecryptablePackets []*receivedPacket receivedTooManyUndecrytablePacketsTime time.Time - aeadChanged chan protocol.EncryptionLevel + // this channel is passed to the CryptoSetup and receives the current encryption level + // it is closed as soon as the handshake is complete + aeadChanged chan protocol.EncryptionLevel + handshakeComplete bool nextAckScheduledTime time.Time @@ -98,7 +101,7 @@ type session struct { var _ Session = &session{} // newSession makes a new session -func newSession(conn connection, v protocol.VersionNumber, connectionID protocol.ConnectionID, sCfg *handshake.ServerConfig, cryptoChangeCallback cryptoChangeCallback, supportedVersions []protocol.VersionNumber) (packetHandler, error) { +func newSession(conn connection, v protocol.VersionNumber, connectionID protocol.ConnectionID, sCfg *handshake.ServerConfig, cryptoChangeCallback cryptoChangeCallback, supportedVersions []protocol.VersionNumber) (packetHandler, error) { s := &session{ conn: conn, connectionID: connectionID, @@ -195,6 +198,8 @@ func (s *session) run() error { }() var closeErr error + aeadChanged := s.aeadChanged + runLoop: for { // Close immediately if requested @@ -226,12 +231,17 @@ runLoop: // This is a bit unclean, but works properly, since the packet always // begins with the public header and we never copy it. putPacketBuffer(p.publicHeader.Raw) - case l := <-s.aeadChanged: - if l == protocol.EncryptionForwardSecure { - s.packer.SetForwardSecure() + case l, ok := <-aeadChanged: + if !ok { + s.handshakeComplete = true + aeadChanged = nil // prevent this case from ever being selected again + } else { + if l == protocol.EncryptionForwardSecure { + s.packer.SetForwardSecure() + } + s.tryDecryptingQueuedPackets() + s.cryptoChangeCallback(s, l == protocol.EncryptionForwardSecure) } - s.tryDecryptingQueuedPackets() - s.cryptoChangeCallback(s, l == protocol.EncryptionForwardSecure) } if err != nil { @@ -254,7 +264,7 @@ runLoop: if now.Sub(s.lastNetworkActivityTime) >= s.idleTimeout() { s.close(qerr.Error(qerr.NetworkIdleTimeout, "No recent network activity.")) } - if !s.cryptoSetup.HandshakeComplete() && now.Sub(s.sessionCreationTime) >= protocol.MaxTimeForCryptoHandshake { + if !s.handshakeComplete && now.Sub(s.sessionCreationTime) >= protocol.MaxTimeForCryptoHandshake { s.close(qerr.Error(qerr.NetworkIdleTimeout, "Crypto handshake did not complete in time.")) } s.garbageCollectStreams() @@ -273,7 +283,7 @@ func (s *session) maybeResetTimer() { if lossTime := s.sentPacketHandler.GetAlarmTimeout(); !lossTime.IsZero() { nextDeadline = utils.MinTime(nextDeadline, lossTime) } - if !s.cryptoSetup.HandshakeComplete() { + if !s.handshakeComplete { handshakeDeadline := s.sessionCreationTime.Add(protocol.MaxTimeForCryptoHandshake) nextDeadline = utils.MinTime(nextDeadline, handshakeDeadline) } @@ -298,7 +308,7 @@ func (s *session) maybeResetTimer() { } func (s *session) idleTimeout() time.Duration { - if s.cryptoSetup.HandshakeComplete() { + if s.handshakeComplete { return s.connectionParameters.GetIdleConnectionStateLifetime() } return protocol.InitialIdleTimeout @@ -753,7 +763,7 @@ func (s *session) scheduleSending() { } func (s *session) tryQueueingUndecryptablePacket(p *receivedPacket) { - if s.cryptoSetup.HandshakeComplete() { + if s.handshakeComplete { return } if len(s.undecryptablePackets)+1 > protocol.MaxUndecryptablePackets { diff --git a/session_test.go b/session_test.go index 9fd25a12..702eed20 100644 --- a/session_test.go +++ b/session_test.go @@ -1251,7 +1251,7 @@ var _ = Describe("Session", func() { }) It("ignores undecryptable packets after the handshake is complete", func() { - sess.cryptoSetup.(*mockCryptoSetup).handshakeComplete = true + close(sess.aeadChanged) go sess.run() sendUndecryptablePackets() Consistently(sess.undecryptablePackets).Should(BeEmpty()) @@ -1297,8 +1297,8 @@ var _ = Describe("Session", func() { callbackCalledWith = p } sess.cryptoChangeCallback = cb - sess.cryptoSetup = &mockCryptoSetup{handshakeComplete: true} sess.aeadChanged <- protocol.EncryptionForwardSecure + close(sess.aeadChanged) go sess.run() defer sess.Close(nil) Eventually(func() bool { return callbackCalledWith }).Should(BeTrue()) @@ -1334,9 +1334,7 @@ var _ = Describe("Session", func() { }) It("uses ICSL after handshake", func(done Done) { - // sess.lastNetworkActivityTime = time.Now().Add(-time.Minute) - *(*bool)(unsafe.Pointer(reflect.ValueOf(sess.cryptoSetup).Elem().FieldByName("receivedForwardSecurePacket").UnsafeAddr())) = true - *(*crypto.AEAD)(unsafe.Pointer(reflect.ValueOf(sess.cryptoSetup).Elem().FieldByName("forwardSecureAEAD").UnsafeAddr())) = crypto.NewNullAEAD(protocol.PerspectiveServer, protocol.VersionWhatever) + close(sess.aeadChanged) cpm.idleTime = 0 * time.Millisecond sess.packer.connectionParameters = sess.connectionParameters sess.run() // Would normally not return