simplify the CryptoSetup.SetDiversificationNonce interface

Pass the diversification nonce via a channel instead of setting it
directly. That way there is no need to protect the diversificationNonce
member by a mutex. Also prevents a possible deadlock that occurred when
SetDiversificationNonce was called before maybeUpgradeCrypto returned.
This commit is contained in:
Marten Seemann
2017-05-11 10:15:47 +08:00
parent 8c4fe4ff80
commit adc4ef464a
6 changed files with 39 additions and 44 deletions

View File

@@ -37,7 +37,7 @@ type cryptoSetupClient struct {
lastSentCHLO []byte
certManager crypto.CertManager
divNonceChan chan struct{}
divNonceChan chan []byte
diversificationNonce []byte
clientHelloCounter int
@@ -85,7 +85,7 @@ func NewCryptoSetupClient(
nullAEAD: crypto.NewNullAEAD(protocol.PerspectiveClient, version),
aeadChanged: aeadChanged,
negotiatedVersions: negotiatedVersions,
divNonceChan: make(chan struct{}),
divNonceChan: make(chan []byte),
}, nil
}
@@ -123,7 +123,11 @@ func (h *cryptoSetupClient) HandleCryptoStream() error {
var message HandshakeMessage
select {
case <-h.divNonceChan:
case divNonce := <-h.divNonceChan:
if len(h.diversificationNonce) != 0 && !bytes.Equal(h.diversificationNonce, divNonce) {
return errConflictingDiversificationNonces
}
h.diversificationNonce = divNonce
// there's no message to process, but we should try upgrading the crypto again
continue
case message = <-messageChan:
@@ -372,19 +376,8 @@ func (h *cryptoSetupClient) DiversificationNonce() []byte {
panic("not needed for cryptoSetupClient")
}
func (h *cryptoSetupClient) SetDiversificationNonce(data []byte) error {
h.mutex.Lock()
defer h.mutex.Unlock()
if len(h.diversificationNonce) == 0 {
h.diversificationNonce = data
h.divNonceChan <- struct{}{}
return nil
}
if !bytes.Equal(h.diversificationNonce, data) {
return errConflictingDiversificationNonces
}
return nil
func (h *cryptoSetupClient) SetDiversificationNonce(data []byte) {
h.divNonceChan <- data
}
func (h *cryptoSetupClient) sendCHLO() error {

View File

@@ -644,13 +644,15 @@ var _ = Describe("Client Crypto Setup", func() {
})
It("tries to escalate the crypto after receiving a diversification nonce", func(done Done) {
go cs.HandleCryptoStream()
time.Sleep(50 * time.Millisecond) // wait for the first maybeUpgradeCrypto to finish
go func() {
defer GinkgoRecover()
cs.HandleCryptoStream()
Fail("HandleCryptoStream should not have returned")
}()
cs.diversificationNonce = nil
cs.serverVerified = true
Expect(cs.secureAEAD).To(BeNil())
err := cs.SetDiversificationNonce([]byte("div"))
Expect(err).ToNot(HaveOccurred())
cs.SetDiversificationNonce([]byte("div"))
Eventually(aeadChanged).Should(Receive(Equal(protocol.EncryptionSecure)))
Expect(cs.secureAEAD).ToNot(BeNil())
Expect(aeadChanged).ToNot(Receive())
@@ -784,34 +786,33 @@ var _ = Describe("Client Crypto Setup", func() {
})
Context("Diversification Nonces", func() {
BeforeEach(func() {
go cs.HandleCryptoStream()
time.Sleep(50 * time.Millisecond) // wait for the first maybeUpdateCrypto to finish
})
It("sets a diversification nonce", func() {
go cs.HandleCryptoStream()
nonce := []byte("foobar")
err := cs.SetDiversificationNonce(nonce)
Expect(err).ToNot(HaveOccurred())
Expect(cs.diversificationNonce).To(Equal(nonce))
cs.SetDiversificationNonce(nonce)
Eventually(func() []byte { return cs.diversificationNonce }).Should(Equal(nonce))
})
It("doesn't do anything when called multiple times with the same nonce", func() {
It("doesn't do anything when called multiple times with the same nonce", func(done Done) {
go cs.HandleCryptoStream()
nonce := []byte("foobar")
err := cs.SetDiversificationNonce(nonce)
Expect(err).ToNot(HaveOccurred())
err = cs.SetDiversificationNonce(nonce)
Expect(err).ToNot(HaveOccurred())
Expect(cs.diversificationNonce).To(Equal(nonce))
cs.SetDiversificationNonce(nonce)
cs.SetDiversificationNonce(nonce)
Eventually(func() []byte { return cs.diversificationNonce }).Should(Equal(nonce))
close(done)
})
It("rejects a different diversification nonce", func() {
var err error
go func() {
err = cs.HandleCryptoStream()
}()
nonce1 := []byte("foobar")
nonce2 := []byte("raboof")
err := cs.SetDiversificationNonce(nonce1)
Expect(err).ToNot(HaveOccurred())
err = cs.SetDiversificationNonce(nonce2)
Expect(err).To(MatchError(errConflictingDiversificationNonces))
cs.SetDiversificationNonce(nonce1)
cs.SetDiversificationNonce(nonce2)
Eventually(func() error { return err }).Should(MatchError(errConflictingDiversificationNonces))
})
})

View File

@@ -435,7 +435,7 @@ func (h *cryptoSetupServer) DiversificationNonce() []byte {
return h.diversificationNonce
}
func (h *cryptoSetupServer) SetDiversificationNonce(data []byte) error {
func (h *cryptoSetupServer) SetDiversificationNonce(data []byte) {
panic("not needed for cryptoSetupServer")
}

View File

@@ -10,8 +10,8 @@ type CryptoSetup interface {
Open(dst, src []byte, packetNumber protocol.PacketNumber, associatedData []byte) ([]byte, protocol.EncryptionLevel, error)
HandleCryptoStream() error
// TODO: clean up this interface
DiversificationNonce() []byte // only needed for cryptoSetupServer
SetDiversificationNonce([]byte) error // only needed for cryptoSetupClient
DiversificationNonce() []byte // only needed for cryptoSetupServer
SetDiversificationNonce([]byte) // only needed for cryptoSetupClient
GetSealer() (protocol.EncryptionLevel, Sealer)
GetSealerWithEncryptionLevel(protocol.EncryptionLevel) (Sealer, error)

View File

@@ -35,7 +35,7 @@ func (m *mockCryptoSetup) GetSealerWithEncryptionLevel(protocol.EncryptionLevel)
func (m *mockCryptoSetup) DiversificationNonce() []byte {
return m.divNonce
}
func (m *mockCryptoSetup) SetDiversificationNonce([]byte) error { panic("not implemented") }
func (m *mockCryptoSetup) SetDiversificationNonce([]byte) { panic("not implemented") }
var _ handshake.CryptoSetup = &mockCryptoSetup{}

View File

@@ -739,12 +739,13 @@ var _ = Describe("Session", func() {
It("passes the diversification nonce to the cryptoSetup, if it is a client", func() {
go clientSess.run()
time.Sleep(50 * time.Millisecond)
hdr.PacketNumber = 5
hdr.DiversificationNonce = []byte("foobar")
err := clientSess.handlePacketImpl(&receivedPacket{publicHeader: hdr})
Expect(err).ToNot(HaveOccurred())
Expect((*[]byte)(unsafe.Pointer(reflect.ValueOf(clientSess.cryptoSetup).Elem().FieldByName("diversificationNonce").UnsafeAddr()))).To(Equal(&hdr.DiversificationNonce))
Eventually(func() []byte {
return *(*[]byte)(unsafe.Pointer(reflect.ValueOf(clientSess.cryptoSetup).Elem().FieldByName("diversificationNonce").UnsafeAddr()))
}).Should(Equal(hdr.DiversificationNonce))
Expect(clientSess.Close(nil)).To(Succeed())
})