Merge pull request #595 from lucas-clemente/remove-handshake-complete

replace CryptoSetup.HandshakeComplete() by closing the aeadChanged channel
This commit is contained in:
Marten Seemann
2017-05-09 19:36:17 +08:00
committed by GitHub
8 changed files with 97 additions and 86 deletions

View File

@@ -47,7 +47,7 @@ type cryptoSetupClient struct {
nullAEAD crypto.AEAD
secureAEAD crypto.AEAD
forwardSecureAEAD crypto.AEAD
aeadChanged chan protocol.EncryptionLevel
aeadChanged chan<- protocol.EncryptionLevel
connectionParameters ConnectionParametersManager
}
@@ -68,7 +68,7 @@ func NewCryptoSetupClient(
cryptoStream io.ReadWriter,
tlsConfig *tls.Config,
connectionParameters ConnectionParametersManager,
aeadChanged chan protocol.EncryptionLevel,
aeadChanged chan<- protocol.EncryptionLevel,
negotiatedVersions []protocol.VersionNumber,
) (CryptoSetup, error) {
return &cryptoSetupClient{
@@ -251,6 +251,7 @@ func (h *cryptoSetupClient) handleSHLOMessage(cryptoData map[Tag][]byte) error {
}
h.aeadChanged <- protocol.EncryptionForwardSecure
close(h.aeadChanged)
return nil
}
@@ -370,13 +371,6 @@ func (h *cryptoSetupClient) SetDiversificationNonce(data []byte) error {
return nil
}
func (h *cryptoSetupClient) HandshakeComplete() bool {
h.mutex.RLock()
defer h.mutex.RUnlock()
return h.forwardSecureAEAD != nil
}
func (h *cryptoSetupClient) sendCHLO() error {
h.clientHelloCounter++
if h.clientHelloCounter > protocol.MaxClientHellos {

View File

@@ -72,11 +72,14 @@ func (m *mockCertManager) Verify(hostname string) error {
}
var _ = Describe("Client Crypto Setup", func() {
var cs *cryptoSetupClient
var certManager *mockCertManager
var stream *mockStream
var keyDerivationCalledWith *keyDerivationValues
var shloMap map[Tag][]byte
var (
cs *cryptoSetupClient
certManager *mockCertManager
stream *mockStream
keyDerivationCalledWith *keyDerivationValues
shloMap map[Tag][]byte
aeadChanged chan protocol.EncryptionLevel
)
BeforeEach(func() {
shloMap = map[Tag][]byte{
@@ -101,7 +104,8 @@ var _ = Describe("Client Crypto Setup", func() {
stream = &mockStream{}
certManager = &mockCertManager{}
version := protocol.Version36
csInt, err := NewCryptoSetupClient("hostname", 0, version, stream, nil, NewConnectionParamatersManager(protocol.PerspectiveClient, version), make(chan protocol.EncryptionLevel, 2), nil)
aeadChanged = make(chan protocol.EncryptionLevel, 2)
csInt, err := NewCryptoSetupClient("hostname", 0, version, stream, nil, NewConnectionParamatersManager(protocol.PerspectiveClient, version), aeadChanged, nil)
Expect(err).ToNot(HaveOccurred())
cs = csInt.(*cryptoSetupClient)
cs.certManager = certManager
@@ -369,22 +373,22 @@ var _ = Describe("Client Crypto Setup", func() {
cs.receivedSecurePacket = false
err := cs.handleSHLOMessage(shloMap)
Expect(err).To(MatchError(qerr.Error(qerr.CryptoEncryptionLevelIncorrect, "unencrypted SHLO message")))
Expect(cs.HandshakeComplete()).To(BeFalse())
Expect(cs.aeadChanged).ToNot(Receive())
Expect(aeadChanged).ToNot(Receive())
Expect(aeadChanged).ToNot(BeClosed())
})
It("rejects SHLOs without a PUBS", func() {
delete(shloMap, TagPUBS)
err := cs.handleSHLOMessage(shloMap)
Expect(err).To(MatchError(qerr.Error(qerr.CryptoMessageParameterNotFound, "PUBS")))
Expect(cs.HandshakeComplete()).To(BeFalse())
Expect(aeadChanged).ToNot(BeClosed())
})
It("rejects SHLOs without a version list", func() {
delete(shloMap, TagVER)
err := cs.handleSHLOMessage(shloMap)
Expect(err).To(MatchError(qerr.Error(qerr.InvalidCryptoMessageParameter, "server hello missing version list")))
Expect(cs.HandshakeComplete()).To(BeFalse())
Expect(aeadChanged).ToNot(BeClosed())
})
It("accepts a SHLO after a version negotiation", func() {
@@ -409,8 +413,8 @@ var _ = Describe("Client Crypto Setup", func() {
err := cs.handleSHLOMessage(shloMap)
Expect(err).ToNot(HaveOccurred())
Expect(cs.forwardSecureAEAD).ToNot(BeNil())
Expect(cs.HandshakeComplete()).To(BeTrue())
Expect(cs.aeadChanged).To(Receive())
Expect(aeadChanged).To(Receive(Equal(protocol.EncryptionForwardSecure)))
Expect(aeadChanged).To(BeClosed())
})
It("reads the connection paramaters", func() {
@@ -598,8 +602,9 @@ var _ = Describe("Client Crypto Setup", func() {
Expect(keyDerivationCalledWith.cert).To(Equal(certManager.leafCert))
Expect(keyDerivationCalledWith.divNonce).To(Equal(cs.diversificationNonce))
Expect(keyDerivationCalledWith.pers).To(Equal(protocol.PerspectiveClient))
Expect(cs.HandshakeComplete()).To(BeFalse())
Expect(cs.aeadChanged).To(Receive())
Expect(aeadChanged).To(Receive(Equal(protocol.EncryptionSecure)))
Expect(aeadChanged).ToNot(Receive())
Expect(aeadChanged).ToNot(BeClosed())
})
It("uses the server nonce, if the server sent one", func() {
@@ -609,22 +614,24 @@ var _ = Describe("Client Crypto Setup", func() {
Expect(err).ToNot(HaveOccurred())
Expect(cs.secureAEAD).ToNot(BeNil())
Expect(keyDerivationCalledWith.nonces).To(Equal(append(cs.nonc, cs.sno...)))
Expect(cs.HandshakeComplete()).To(BeFalse())
Expect(cs.aeadChanged).To(Receive())
Expect(aeadChanged).To(Receive())
Expect(aeadChanged).ToNot(Receive())
Expect(aeadChanged).ToNot(BeClosed())
})
It("doesn't create a secureAEAD if the certificate is not yet verified, even if it has all necessary values", func() {
err := cs.maybeUpgradeCrypto()
Expect(err).ToNot(HaveOccurred())
Expect(cs.secureAEAD).To(BeNil())
Expect(cs.aeadChanged).ToNot(Receive())
Expect(aeadChanged).ToNot(Receive())
cs.serverVerified = true
// make sure we really had all necessary values before, and only serverVerified was missing
err = cs.maybeUpgradeCrypto()
Expect(err).ToNot(HaveOccurred())
Expect(cs.secureAEAD).ToNot(BeNil())
Expect(cs.HandshakeComplete()).To(BeFalse())
Expect(cs.aeadChanged).To(Receive())
Expect(aeadChanged).To(Receive(Equal(protocol.EncryptionSecure)))
Expect(aeadChanged).ToNot(Receive())
Expect(aeadChanged).ToNot(BeClosed())
})
It("tries to escalate before reading a handshake message", func() {
@@ -635,7 +642,9 @@ var _ = Describe("Client Crypto Setup", func() {
// this is because the mockStream doesn't block if there's no data to read
Expect(err).To(MatchError(qerr.HandshakeFailed))
Expect(cs.secureAEAD).ToNot(BeNil())
Expect(cs.HandshakeComplete()).To(BeFalse())
Expect(aeadChanged).To(Receive(Equal(protocol.EncryptionSecure)))
Expect(aeadChanged).ToNot(Receive())
Expect(aeadChanged).ToNot(BeClosed())
})
It("tries to escalate the crypto after receiving a diversification nonce", func() {
@@ -645,8 +654,9 @@ var _ = Describe("Client Crypto Setup", func() {
err := cs.SetDiversificationNonce([]byte("div"))
Expect(err).ToNot(HaveOccurred())
Expect(cs.secureAEAD).ToNot(BeNil())
Expect(cs.aeadChanged).To(Receive())
Expect(cs.HandshakeComplete()).To(BeFalse())
Expect(aeadChanged).To(Receive(Equal(protocol.EncryptionSecure)))
Expect(aeadChanged).ToNot(Receive())
Expect(aeadChanged).ToNot(BeClosed())
})
Context("null encryption", func() {

View File

@@ -36,7 +36,7 @@ type cryptoSetupServer struct {
receivedForwardSecurePacket bool
sentSHLO bool
receivedSecurePacket bool
aeadChanged chan protocol.EncryptionLevel
aeadChanged chan<- protocol.EncryptionLevel
keyDerivation KeyDerivationFunction
keyExchange KeyExchangeFunction
@@ -64,7 +64,7 @@ func NewCryptoSetup(
cryptoStream io.ReadWriter,
connectionParametersManager ConnectionParametersManager,
supportedVersions []protocol.VersionNumber,
aeadChanged chan protocol.EncryptionLevel,
aeadChanged chan<- protocol.EncryptionLevel,
) (CryptoSetup, error) {
return &cryptoSetupServer{
connID: connID,
@@ -173,7 +173,10 @@ func (h *cryptoSetupServer) Open(dst, src []byte, packetNumber protocol.PacketNu
if h.forwardSecureAEAD != nil {
res, err := h.forwardSecureAEAD.Open(dst, src, packetNumber, associatedData)
if err == nil {
h.receivedForwardSecurePacket = true
if !h.receivedForwardSecurePacket { // this is the first forward secure packet we receive from the client
h.receivedForwardSecurePacket = true
close(h.aeadChanged)
}
return res, protocol.EncryptionForwardSecure, nil
}
if h.receivedForwardSecurePacket {
@@ -428,11 +431,6 @@ func (h *cryptoSetupServer) SetDiversificationNonce(data []byte) error {
panic("not needed for cryptoSetupServer")
}
// HandshakeComplete returns true after the first forward secure packet was received form the client.
func (h *cryptoSetupServer) HandshakeComplete() bool {
return h.receivedForwardSecurePacket
}
func (h *cryptoSetupServer) validateClientNonce(nonce []byte) error {
if len(nonce) != 32 {
return qerr.Error(qerr.InvalidCryptoMessageParameter, "invalid client nonce length")

View File

@@ -302,8 +302,11 @@ var _ = Describe("Server Crypto Setup", func() {
err := cs.HandleCryptoStream()
Expect(err).NotTo(HaveOccurred())
Expect(stream.dataWritten.Bytes()).To(HavePrefix("REJ"))
Expect(aeadChanged).To(Receive(Equal(protocol.EncryptionSecure)))
Expect(stream.dataWritten.Bytes()).To(ContainSubstring("SHLO"))
Expect(aeadChanged).To(Receive())
Expect(aeadChanged).To(Receive(Equal(protocol.EncryptionForwardSecure)))
Expect(aeadChanged).ToNot(Receive())
Expect(aeadChanged).ToNot(BeClosed())
})
It("rejects client nonces that have the wrong length", func() {
@@ -334,11 +337,8 @@ var _ = Describe("Server Crypto Setup", func() {
Expect(err).NotTo(HaveOccurred())
Expect(stream.dataWritten.Bytes()).To(HavePrefix("SHLO"))
Expect(stream.dataWritten.Bytes()).ToNot(ContainSubstring("REJ"))
var encLevel protocol.EncryptionLevel
Expect(aeadChanged).To(Receive(&encLevel))
Expect(encLevel).To(Equal(protocol.EncryptionSecure))
Expect(aeadChanged).To(Receive(&encLevel))
Expect(encLevel).To(Equal(protocol.EncryptionForwardSecure))
Expect(aeadChanged).To(Receive(Equal(protocol.EncryptionSecure)))
Expect(aeadChanged).To(Receive(Equal(protocol.EncryptionForwardSecure)))
})
It("recognizes inchoate CHLOs missing SCID", func() {
@@ -593,15 +593,11 @@ var _ = Describe("Server Crypto Setup", func() {
It("regards the handshake as complete once it receives a forward encrypted packet", func() {
doCHLO()
enc, seal := cs.GetSealer()
Expect(enc).To(Equal(protocol.EncryptionSecure))
_ = seal(nil, []byte("SHLO"), 0, []byte{})
enc, seal = cs.GetSealer()
Expect(enc).To(Equal(protocol.EncryptionForwardSecure))
_ = seal(nil, []byte("foobar"), 0, []byte{})
Expect(cs.HandshakeComplete()).To(BeFalse())
cs.receivedForwardSecurePacket = true
Expect(cs.HandshakeComplete()).To(BeTrue())
_, _, err := cs.Open(nil, []byte("forward secure encrypted"), 0, []byte{})
Expect(err).ToNot(HaveOccurred())
Expect(aeadChanged).To(Receive()) // consume the protocol.EncryptionSecure
Expect(aeadChanged).To(Receive()) // consume the protocol.EncryptionForwardSecure
Expect(aeadChanged).To(BeClosed())
})
})

View File

@@ -9,7 +9,6 @@ type Sealer func(dst, src []byte, packetNumber protocol.PacketNumber, associated
type CryptoSetup interface {
Open(dst, src []byte, packetNumber protocol.PacketNumber, associatedData []byte) ([]byte, protocol.EncryptionLevel, error)
HandleCryptoStream() error
HandshakeComplete() bool
// TODO: clean up this interface
DiversificationNonce() []byte // only needed for cryptoSetupServer
SetDiversificationNonce([]byte) error // only needed for cryptoSetupClient

View File

@@ -13,9 +13,8 @@ import (
)
type mockCryptoSetup struct {
divNonce []byte
handshakeComplete bool
encLevelSeal protocol.EncryptionLevel
divNonce []byte
encLevelSeal protocol.EncryptionLevel
}
func (m *mockCryptoSetup) HandleCryptoStream() error { return nil }
@@ -33,7 +32,6 @@ func (m *mockCryptoSetup) GetSealerWithEncryptionLevel(protocol.EncryptionLevel)
return append(src, bytes.Repeat([]byte{0}, 12)...)
}, nil
}
func (m *mockCryptoSetup) HandshakeComplete() bool { return m.handshakeComplete }
func (m *mockCryptoSetup) DiversificationNonce() []byte {
return m.divNonce
}

View File

@@ -76,7 +76,10 @@ type session struct {
undecryptablePackets []*receivedPacket
receivedTooManyUndecrytablePacketsTime time.Time
aeadChanged chan protocol.EncryptionLevel
// this channel is passed to the CryptoSetup and receives the current encryption level
// it is closed as soon as the handshake is complete
aeadChanged <-chan protocol.EncryptionLevel
handshakeComplete bool
nextAckScheduledTime time.Time
@@ -98,7 +101,7 @@ type session struct {
var _ Session = &session{}
// newSession makes a new session
func newSession(conn connection, v protocol.VersionNumber, connectionID protocol.ConnectionID, sCfg *handshake.ServerConfig, cryptoChangeCallback cryptoChangeCallback, supportedVersions []protocol.VersionNumber) (packetHandler, error) {
func newSession(conn connection, v protocol.VersionNumber, connectionID protocol.ConnectionID, sCfg *handshake.ServerConfig, cryptoChangeCallback cryptoChangeCallback, supportedVersions []protocol.VersionNumber) (packetHandler, error) {
s := &session{
conn: conn,
connectionID: connectionID,
@@ -118,8 +121,10 @@ func newSession(conn connection, v protocol.VersionNumber, connectionID protocol
} else {
sourceAddr = []byte(conn.RemoteAddr().String())
}
aeadChanged := make(chan protocol.EncryptionLevel, 2)
s.aeadChanged = aeadChanged
var err error
s.cryptoSetup, err = handshake.NewCryptoSetup(connectionID, sourceAddr, v, sCfg, cryptoStream, s.connectionParameters, supportedVersions, s.aeadChanged)
s.cryptoSetup, err = handshake.NewCryptoSetup(connectionID, sourceAddr, v, sCfg, cryptoStream, s.connectionParameters, supportedVersions, aeadChanged)
if err != nil {
return nil, err
}
@@ -144,9 +149,11 @@ func newClientSession(conn connection, hostname string, v protocol.VersionNumber
s.receivedPacketHandler = ackhandler.NewReceivedPacketHandler(s.ackAlarmChanged)
s.setup()
aeadChanged := make(chan protocol.EncryptionLevel, 2)
s.aeadChanged = aeadChanged
cryptoStream, _ := s.OpenStream()
var err error
s.cryptoSetup, err = handshake.NewCryptoSetupClient(hostname, connectionID, v, cryptoStream, tlsConfig, s.connectionParameters, s.aeadChanged, negotiatedVersions)
s.cryptoSetup, err = handshake.NewCryptoSetupClient(hostname, connectionID, v, cryptoStream, tlsConfig, s.connectionParameters, aeadChanged, negotiatedVersions)
if err != nil {
return nil, err
}
@@ -195,6 +202,8 @@ func (s *session) run() error {
}()
var closeErr error
aeadChanged := s.aeadChanged
runLoop:
for {
// Close immediately if requested
@@ -226,12 +235,17 @@ runLoop:
// This is a bit unclean, but works properly, since the packet always
// begins with the public header and we never copy it.
putPacketBuffer(p.publicHeader.Raw)
case l := <-s.aeadChanged:
if l == protocol.EncryptionForwardSecure {
s.packer.SetForwardSecure()
case l, ok := <-aeadChanged:
if !ok {
s.handshakeComplete = true
aeadChanged = nil // prevent this case from ever being selected again
} else {
if l == protocol.EncryptionForwardSecure {
s.packer.SetForwardSecure()
}
s.tryDecryptingQueuedPackets()
s.cryptoChangeCallback(s, l == protocol.EncryptionForwardSecure)
}
s.tryDecryptingQueuedPackets()
s.cryptoChangeCallback(s, l == protocol.EncryptionForwardSecure)
}
if err != nil {
@@ -254,7 +268,7 @@ runLoop:
if now.Sub(s.lastNetworkActivityTime) >= s.idleTimeout() {
s.close(qerr.Error(qerr.NetworkIdleTimeout, "No recent network activity."))
}
if !s.cryptoSetup.HandshakeComplete() && now.Sub(s.sessionCreationTime) >= protocol.MaxTimeForCryptoHandshake {
if !s.handshakeComplete && now.Sub(s.sessionCreationTime) >= protocol.MaxTimeForCryptoHandshake {
s.close(qerr.Error(qerr.NetworkIdleTimeout, "Crypto handshake did not complete in time."))
}
s.garbageCollectStreams()
@@ -273,7 +287,7 @@ func (s *session) maybeResetTimer() {
if lossTime := s.sentPacketHandler.GetAlarmTimeout(); !lossTime.IsZero() {
nextDeadline = utils.MinTime(nextDeadline, lossTime)
}
if !s.cryptoSetup.HandshakeComplete() {
if !s.handshakeComplete {
handshakeDeadline := s.sessionCreationTime.Add(protocol.MaxTimeForCryptoHandshake)
nextDeadline = utils.MinTime(nextDeadline, handshakeDeadline)
}
@@ -298,7 +312,7 @@ func (s *session) maybeResetTimer() {
}
func (s *session) idleTimeout() time.Duration {
if s.cryptoSetup.HandshakeComplete() {
if s.handshakeComplete {
return s.connectionParameters.GetIdleConnectionStateLifetime()
}
return protocol.InitialIdleTimeout
@@ -753,7 +767,7 @@ func (s *session) scheduleSending() {
}
func (s *session) tryQueueingUndecryptablePacket(p *receivedPacket) {
if s.cryptoSetup.HandshakeComplete() {
if s.handshakeComplete {
return
}
if len(s.undecryptablePackets)+1 > protocol.MaxUndecryptablePackets {

View File

@@ -125,11 +125,12 @@ func areSessionsRunning() bool {
var _ = Describe("Session", func() {
var (
sess *session
clientSess *session
scfg *handshake.ServerConfig
mconn *mockConnection
cpm *mockConnectionParametersManager
sess *session
clientSess *session
scfg *handshake.ServerConfig
mconn *mockConnection
cpm *mockConnectionParametersManager
aeadChanged chan<- protocol.EncryptionLevel
)
BeforeEach(func() {
@@ -154,6 +155,9 @@ var _ = Describe("Session", func() {
Expect(err).NotTo(HaveOccurred())
sess = pSess.(*session)
Expect(sess.streamsMap.openStreams).To(HaveLen(1)) // Crypto stream
// we need an aeadChanged chan that we can write to
// since type assertions on chans are not possible, we have to extract it from the CryptoSetup
aeadChanged = *(*chan<- protocol.EncryptionLevel)(unsafe.Pointer(reflect.ValueOf(sess.cryptoSetup).Elem().FieldByName("aeadChanged").UnsafeAddr()))
cpm = &mockConnectionParametersManager{idleTime: 60 * time.Second}
sess.connectionParameters = cpm
@@ -1156,9 +1160,9 @@ var _ = Describe("Session", func() {
It("tells the packetPacker when forward-secure encryption is used", func() {
go sess.run()
defer sess.Close(nil)
sess.aeadChanged <- protocol.EncryptionSecure
aeadChanged <- protocol.EncryptionSecure
Consistently(func() bool { return sess.packer.isForwardSecure }).Should(BeFalse())
sess.aeadChanged <- protocol.EncryptionForwardSecure
aeadChanged <- protocol.EncryptionForwardSecure
Eventually(func() bool { return sess.packer.isForwardSecure }).Should(BeTrue())
})
@@ -1241,7 +1245,7 @@ var _ = Describe("Session", func() {
})
It("ignores undecryptable packets after the handshake is complete", func() {
sess.cryptoSetup.(*mockCryptoSetup).handshakeComplete = true
close(aeadChanged)
go sess.run()
sendUndecryptablePackets()
Consistently(sess.undecryptablePackets).Should(BeEmpty())
@@ -1270,7 +1274,7 @@ var _ = Describe("Session", func() {
callbackSession = s
}
sess.cryptoChangeCallback = cb
sess.aeadChanged <- protocol.EncryptionSecure
aeadChanged <- protocol.EncryptionSecure
go sess.run()
defer sess.Close(nil)
Eventually(func() bool { return callbackCalled }).Should(BeTrue())
@@ -1287,8 +1291,8 @@ var _ = Describe("Session", func() {
callbackCalledWith = p
}
sess.cryptoChangeCallback = cb
sess.cryptoSetup = &mockCryptoSetup{handshakeComplete: true}
sess.aeadChanged <- protocol.EncryptionForwardSecure
aeadChanged <- protocol.EncryptionForwardSecure
close(aeadChanged)
go sess.run()
defer sess.Close(nil)
Eventually(func() bool { return callbackCalledWith }).Should(BeTrue())
@@ -1324,9 +1328,7 @@ var _ = Describe("Session", func() {
})
It("uses ICSL after handshake", func(done Done) {
// sess.lastNetworkActivityTime = time.Now().Add(-time.Minute)
*(*bool)(unsafe.Pointer(reflect.ValueOf(sess.cryptoSetup).Elem().FieldByName("receivedForwardSecurePacket").UnsafeAddr())) = true
*(*crypto.AEAD)(unsafe.Pointer(reflect.ValueOf(sess.cryptoSetup).Elem().FieldByName("forwardSecureAEAD").UnsafeAddr())) = crypto.NewNullAEAD(protocol.PerspectiveServer, protocol.VersionWhatever)
close(aeadChanged)
cpm.idleTime = 0 * time.Millisecond
sess.packer.connectionParameters = sess.connectionParameters
sess.run() // Would normally not return