diff --git a/handshake/crypto_setup_client.go b/handshake/crypto_setup_client.go index 9eb4039e..9860daa2 100644 --- a/handshake/crypto_setup_client.go +++ b/handshake/crypto_setup_client.go @@ -29,14 +29,16 @@ type cryptoSetupClient struct { serverConfig *serverConfigClient - stk []byte - sno []byte - nonc []byte - proof []byte + stk []byte + sno []byte + nonc []byte + proof []byte + chloForSignature []byte + lastSentCHLO []byte + certManager crypto.CertManager + + divNonceChan chan struct{} diversificationNonce []byte - chloForSignature []byte - lastSentCHLO []byte - certManager crypto.CertManager clientHelloCounter int serverVerified bool // has the certificate chain and the proof already been verified @@ -83,20 +85,35 @@ func NewCryptoSetupClient( nullAEAD: crypto.NewNullAEAD(protocol.PerspectiveClient, version), aeadChanged: aeadChanged, negotiatedVersions: negotiatedVersions, + divNonceChan: make(chan struct{}), }, nil } func (h *cryptoSetupClient) HandleCryptoStream() error { + messageChan := make(chan HandshakeMessage) + errorChan := make(chan error) + + go func() { + for { + message, err := ParseHandshakeMessage(h.cryptoStream) + if err != nil { + errorChan <- qerr.Error(qerr.HandshakeFailed, err.Error()) + return + } + messageChan <- message + } + }() + for { err := h.maybeUpgradeCrypto() if err != nil { return err } - // send CHLOs until the forward secure encryption is established h.mutex.RLock() - sendCHLO := h.forwardSecureAEAD == nil + sendCHLO := h.secureAEAD == nil h.mutex.RUnlock() + if sendCHLO { err = h.sendCHLO() if err != nil { @@ -104,9 +121,14 @@ func (h *cryptoSetupClient) HandleCryptoStream() error { } } - message, err := ParseHandshakeMessage(h.cryptoStream) - if err != nil { - return qerr.HandshakeFailed + var message HandshakeMessage + select { + case <-h.divNonceChan: + // there's no message to process, but we should try upgrading the crypto again + continue + case message = <-messageChan: + case err = <-errorChan: + return err } utils.Debugf("Got %s", message) @@ -351,9 +373,13 @@ func (h *cryptoSetupClient) DiversificationNonce() []byte { } func (h *cryptoSetupClient) SetDiversificationNonce(data []byte) error { + h.mutex.Lock() + defer h.mutex.Unlock() + if len(h.diversificationNonce) == 0 { h.diversificationNonce = data - return h.maybeUpgradeCrypto() + h.divNonceChan <- struct{}{} + return nil } if !bytes.Equal(h.diversificationNonce, data) { return errConflictingDiversificationNonces @@ -458,7 +484,6 @@ func (h *cryptoSetupClient) maybeUpgradeCrypto() error { defer h.mutex.Unlock() leafCert := h.certManager.GetLeafCert() - if h.secureAEAD == nil && (h.serverConfig != nil && len(h.serverConfig.sharedSecret) > 0 && len(h.nonc) > 0 && len(leafCert) > 0 && len(h.diversificationNonce) > 0 && len(h.lastSentCHLO) > 0) { var err error var nonce []byte diff --git a/handshake/crypto_setup_client_test.go b/handshake/crypto_setup_client_test.go index 4005418f..5a7e936f 100644 --- a/handshake/crypto_setup_client_test.go +++ b/handshake/crypto_setup_client_test.go @@ -101,7 +101,7 @@ var _ = Describe("Client Crypto Setup", func() { return &mockAEAD{forwardSecure: forwardSecure, sharedSecret: sharedSecret}, nil } - stream = &mockStream{} + stream = newMockStream() certManager = &mockCertManager{} version := protocol.Version36 aeadChanged = make(chan protocol.EncryptionLevel, 2) @@ -113,6 +113,10 @@ var _ = Describe("Client Crypto Setup", func() { cs.keyExchange = func() crypto.KeyExchange { return &mockKEX{ephermal: true} } }) + AfterEach(func() { + close(stream.unblockRead) + }) + Context("Reading REJ", func() { var tagMap map[Tag][]byte @@ -127,23 +131,18 @@ var _ = Describe("Client Crypto Setup", func() { }) It("errors on invalid handshake messages", func() { - b := &bytes.Buffer{} - HandshakeMessage{Tag: TagCHLO, Data: tagMap}.Write(b) - stream.dataToRead.Write(b.Bytes()[:b.Len()-2]) // cut the handshake message + stream.dataToRead.Write([]byte("invalid message")) err := cs.HandleCryptoStream() - // note that if this was a complete handshake message, HandleCryptoStream would fail with a qerr.InvalidCryptoMessageType - Expect(err).To(MatchError(qerr.HandshakeFailed)) + Expect(err).To(HaveOccurred()) + Expect(err.(*qerr.QuicError).ErrorCode).To(Equal(qerr.HandshakeFailed)) }) It("passes the message on for parsing, and reads the source address token", func() { stk := []byte("foobar") tagMap[TagSTK] = stk HandshakeMessage{Tag: TagREJ, Data: tagMap}.Write(&stream.dataToRead) - // this will throw a qerr.HandshakeFailed due to an EOF in WriteHandshakeMessage - // this is because the mockStream doesn't block if there's no data to read - err := cs.HandleCryptoStream() - Expect(err).To(MatchError(qerr.HandshakeFailed)) - Expect(cs.stk).Should(Equal(stk)) + go cs.HandleCryptoStream() + Eventually(func() []byte { return cs.stk }).Should(Equal(stk)) }) It("saves the proof", func() { @@ -637,26 +636,26 @@ var _ = Describe("Client Crypto Setup", func() { It("tries to escalate before reading a handshake message", func() { Expect(cs.secureAEAD).To(BeNil()) cs.serverVerified = true - err := cs.HandleCryptoStream() - // this will throw a qerr.HandshakeFailed due to an EOF in WriteHandshakeMessage - // this is because the mockStream doesn't block if there's no data to read - Expect(err).To(MatchError(qerr.HandshakeFailed)) + go cs.HandleCryptoStream() + Eventually(aeadChanged).Should(Receive(Equal(protocol.EncryptionSecure))) Expect(cs.secureAEAD).ToNot(BeNil()) - Expect(aeadChanged).To(Receive(Equal(protocol.EncryptionSecure))) Expect(aeadChanged).ToNot(Receive()) Expect(aeadChanged).ToNot(BeClosed()) }) - It("tries to escalate the crypto after receiving a diversification nonce", func() { + It("tries to escalate the crypto after receiving a diversification nonce", func(done Done) { + go cs.HandleCryptoStream() + time.Sleep(50 * time.Millisecond) // wait for the first maybeUpgradeCrypto to finish cs.diversificationNonce = nil cs.serverVerified = true Expect(cs.secureAEAD).To(BeNil()) err := cs.SetDiversificationNonce([]byte("div")) Expect(err).ToNot(HaveOccurred()) + Eventually(aeadChanged).Should(Receive(Equal(protocol.EncryptionSecure))) Expect(cs.secureAEAD).ToNot(BeNil()) - Expect(aeadChanged).To(Receive(Equal(protocol.EncryptionSecure))) Expect(aeadChanged).ToNot(Receive()) Expect(aeadChanged).ToNot(BeClosed()) + close(done) }) Context("null encryption", func() { @@ -785,6 +784,11 @@ var _ = Describe("Client Crypto Setup", func() { }) Context("Diversification Nonces", func() { + BeforeEach(func() { + go cs.HandleCryptoStream() + time.Sleep(50 * time.Millisecond) // wait for the first maybeUpdateCrypto to finish + }) + It("sets a diversification nonce", func() { nonce := []byte("foobar") err := cs.SetDiversificationNonce(nonce) diff --git a/handshake/crypto_setup_server_test.go b/handshake/crypto_setup_server_test.go index d1b1513b..b43b77a6 100644 --- a/handshake/crypto_setup_server_test.go +++ b/handshake/crypto_setup_server_test.go @@ -95,12 +95,21 @@ func mockKeyDerivation(forwardSecure bool, sharedSecret, nonces []byte, connID p } type mockStream struct { + unblockRead chan struct{} // close this chan to unblock Read dataToRead bytes.Buffer dataWritten bytes.Buffer } +func newMockStream() *mockStream { + return &mockStream{unblockRead: make(chan struct{})} +} + func (s *mockStream) Read(p []byte) (int, error) { - return s.dataToRead.Read(p) + n, _ := s.dataToRead.Read(p) + if n == 0 { // block if there's no data + <-s.unblockRead + } + return n, nil // never return an EOF } func (s *mockStream) ReadByte() (byte, error) { @@ -168,7 +177,7 @@ var _ = Describe("Server Crypto Setup", func() { expectedInitialNonceLen = 32 expectedFSNonceLen = 64 aeadChanged = make(chan protocol.EncryptionLevel, 2) - stream = &mockStream{} + stream = newMockStream() kex = &mockKEX{} signer = &mockSigner{} scfg, err = NewServerConfig(kex, signer) @@ -190,6 +199,10 @@ var _ = Describe("Server Crypto Setup", func() { cs.keyExchange = func() crypto.KeyExchange { return &mockKEX{ephermal: true} } }) + AfterEach(func() { + close(stream.unblockRead) + }) + Context("diversification nonce", func() { BeforeEach(func() { cs.version = protocol.Version35 @@ -485,6 +498,7 @@ var _ = Describe("Server Crypto Setup", func() { }) It("errors with invalid message", func() { + stream.dataToRead.Write([]byte("invalid message")) err := cs.HandleCryptoStream() Expect(err).To(MatchError(qerr.HandshakeFailed)) }) diff --git a/session_test.go b/session_test.go index 4becada0..1697c1d2 100644 --- a/session_test.go +++ b/session_test.go @@ -163,7 +163,7 @@ var _ = Describe("Session", func() { sess.connectionParameters = cpm clientSess, err = newClientSession( - nil, + mconn, "hostname", protocol.Version35, 0, @@ -738,11 +738,14 @@ var _ = Describe("Session", func() { }) It("passes the diversification nonce to the cryptoSetup, if it is a client", func() { + go clientSess.run() + time.Sleep(50 * time.Millisecond) hdr.PacketNumber = 5 hdr.DiversificationNonce = []byte("foobar") err := clientSess.handlePacketImpl(&receivedPacket{publicHeader: hdr}) Expect(err).ToNot(HaveOccurred()) Expect((*[]byte)(unsafe.Pointer(reflect.ValueOf(clientSess.cryptoSetup).Elem().FieldByName("diversificationNonce").UnsafeAddr()))).To(Equal(&hdr.DiversificationNonce)) + Expect(clientSess.Close(nil)).To(Succeed()) }) Context("updating the remote address", func() {