diff --git a/internal/handshake/crypto_setup_client_test.go b/internal/handshake/crypto_setup_client_test.go index 709acb66..e19593a8 100644 --- a/internal/handshake/crypto_setup_client_test.go +++ b/internal/handshake/crypto_setup_client_test.go @@ -130,10 +130,6 @@ var _ = Describe("Client Crypto Setup", func() { cs.cryptoStream = stream }) - AfterEach(func() { - close(stream.unblockRead) - }) - Context("Reading REJ", func() { var tagMap map[Tag][]byte @@ -158,8 +154,17 @@ var _ = Describe("Client Crypto Setup", func() { stk := []byte("foobar") tagMap[TagSTK] = stk HandshakeMessage{Tag: TagREJ, Data: tagMap}.Write(&stream.dataToRead) - go cs.HandleCryptoStream() + done := make(chan struct{}) + go func() { + defer GinkgoRecover() + err := cs.HandleCryptoStream() + Expect(err).To(MatchError(qerr.Error(qerr.HandshakeFailed, errMockStreamClosing.Error()))) + close(done) + }() Eventually(func() []byte { return cs.stk }).Should(Equal(stk)) + // make the go routine return + stream.close() + Eventually(done).Should(BeClosed()) }) It("saves the proof", func() { @@ -432,26 +437,36 @@ var _ = Describe("Client Crypto Setup", func() { It("closes the aeadChanged when receiving an SHLO", func() { HandshakeMessage{Tag: TagSHLO, Data: shloMap}.Write(&stream.dataToRead) + done := make(chan struct{}) go func() { defer GinkgoRecover() err := cs.HandleCryptoStream() - Expect(err).ToNot(HaveOccurred()) + Expect(err).To(MatchError(qerr.Error(qerr.HandshakeFailed, errMockStreamClosing.Error()))) + close(done) }() Eventually(aeadChanged).Should(Receive(Equal(protocol.EncryptionForwardSecure))) Eventually(aeadChanged).Should(BeClosed()) + // make the go routine return + stream.close() + Eventually(done).Should(BeClosed()) }) It("passes the transport parameters on the channel", func() { shloMap[TagSFCW] = []byte{0x0d, 0x00, 0xdf, 0xba} HandshakeMessage{Tag: TagSHLO, Data: shloMap}.Write(&stream.dataToRead) + done := make(chan struct{}) go func() { defer GinkgoRecover() err := cs.HandleCryptoStream() - Expect(err).ToNot(HaveOccurred()) + Expect(err).To(MatchError(qerr.Error(qerr.HandshakeFailed, errMockStreamClosing.Error()))) + close(done) }() var params TransportParameters Eventually(paramsChan).Should(Receive(¶ms)) Expect(params.StreamFlowControlWindow).To(Equal(protocol.ByteCount(0xbadf000d))) + // make the go routine return + stream.close() + Eventually(done).Should(BeClosed()) }) It("errors if it can't read a connection parameter", func() { @@ -672,18 +687,29 @@ var _ = Describe("Client Crypto Setup", func() { It("tries to escalate before reading a handshake message", func() { Expect(cs.secureAEAD).To(BeNil()) cs.serverVerified = true - go cs.HandleCryptoStream() + done := make(chan struct{}) + go func() { + defer GinkgoRecover() + err := cs.HandleCryptoStream() + Expect(err).To(MatchError(qerr.Error(qerr.HandshakeFailed, errMockStreamClosing.Error()))) + close(done) + }() Eventually(aeadChanged).Should(Receive(Equal(protocol.EncryptionSecure))) Expect(cs.secureAEAD).ToNot(BeNil()) Expect(aeadChanged).ToNot(Receive()) Expect(aeadChanged).ToNot(BeClosed()) + // make the go routine return + stream.close() + Eventually(done).Should(BeClosed()) }) - It("tries to escalate the crypto after receiving a diversification nonce", func(done Done) { + It("tries to escalate the crypto after receiving a diversification nonce", func() { + done := make(chan struct{}) go func() { defer GinkgoRecover() - cs.HandleCryptoStream() - Fail("HandleCryptoStream should not have returned") + err := cs.HandleCryptoStream() + Expect(err).To(MatchError(qerr.Error(qerr.HandshakeFailed, errMockStreamClosing.Error()))) + close(done) }() cs.diversificationNonce = nil cs.serverVerified = true @@ -693,7 +719,9 @@ var _ = Describe("Client Crypto Setup", func() { Expect(cs.secureAEAD).ToNot(BeNil()) Expect(aeadChanged).ToNot(Receive()) Expect(aeadChanged).ToNot(BeClosed()) - close(done) + // make the go routine return + stream.close() + Eventually(done).Should(BeClosed()) }) Context("null encryption", func() { @@ -862,32 +890,51 @@ var _ = Describe("Client Crypto Setup", func() { Context("Diversification Nonces", func() { It("sets a diversification nonce", func() { - go cs.HandleCryptoStream() + done := make(chan struct{}) + go func() { + defer GinkgoRecover() + err := cs.HandleCryptoStream() + Expect(err).To(MatchError(qerr.Error(qerr.HandshakeFailed, errMockStreamClosing.Error()))) + close(done) + }() nonce := []byte("foobar") cs.SetDiversificationNonce(nonce) Eventually(func() []byte { return cs.diversificationNonce }).Should(Equal(nonce)) + // make the go routine return + stream.close() + Eventually(done).Should(BeClosed()) }) - It("doesn't do anything when called multiple times with the same nonce", func(done Done) { - go cs.HandleCryptoStream() + It("doesn't do anything when called multiple times with the same nonce", func() { + done := make(chan struct{}) + go func() { + defer GinkgoRecover() + err := cs.HandleCryptoStream() + Expect(err).To(MatchError(qerr.Error(qerr.HandshakeFailed, errMockStreamClosing.Error()))) + close(done) + }() nonce := []byte("foobar") cs.SetDiversificationNonce(nonce) cs.SetDiversificationNonce(nonce) Eventually(func() []byte { return cs.diversificationNonce }).Should(Equal(nonce)) - close(done) + // make the go routine return + stream.close() + Eventually(done).Should(BeClosed()) }) It("rejects a different diversification nonce", func() { - var err error + done := make(chan struct{}) go func() { - err = cs.HandleCryptoStream() + defer GinkgoRecover() + err := cs.HandleCryptoStream() + Expect(err).To(MatchError(errConflictingDiversificationNonces)) + close(done) }() - nonce1 := []byte("foobar") nonce2 := []byte("raboof") cs.SetDiversificationNonce(nonce1) cs.SetDiversificationNonce(nonce2) - Eventually(func() error { return err }).Should(MatchError(errConflictingDiversificationNonces)) + Eventually(done).Should(BeClosed()) }) }) diff --git a/internal/handshake/crypto_setup_server_test.go b/internal/handshake/crypto_setup_server_test.go index 99caded1..344c2383 100644 --- a/internal/handshake/crypto_setup_server_test.go +++ b/internal/handshake/crypto_setup_server_test.go @@ -4,6 +4,7 @@ import ( "bytes" "encoding/binary" "errors" + "io" "net" "time" @@ -63,35 +64,36 @@ func mockQuicCryptoKeyDerivation(forwardSecure bool, sharedSecret, nonces []byte } type mockStream struct { - unblockRead chan struct{} // close this chan to unblock Read + unblockRead chan struct{} dataToRead bytes.Buffer dataWritten bytes.Buffer } +var _ io.ReadWriter = &mockStream{} + +var errMockStreamClosing = errors.New("mock stream closing") + func newMockStream() *mockStream { return &mockStream{unblockRead: make(chan struct{})} } +// call Close to make Read return func (s *mockStream) Read(p []byte) (int, error) { n, _ := s.dataToRead.Read(p) if n == 0 { // block if there's no data <-s.unblockRead + return 0, errMockStreamClosing } return n, nil // never return an EOF } -func (s *mockStream) ReadByte() (byte, error) { - return s.dataToRead.ReadByte() -} - func (s *mockStream) Write(p []byte) (int, error) { return s.dataWritten.Write(p) } -func (s *mockStream) Close() error { panic("not implemented") } -func (s *mockStream) Reset(error) { panic("not implemented") } -func (mockStream) CloseRemote(offset protocol.ByteCount) { panic("not implemented") } -func (s mockStream) StreamID() protocol.StreamID { panic("not implemented") } +func (s *mockStream) close() { + close(s.unblockRead) +} type mockCookieProtector struct { data []byte @@ -183,10 +185,6 @@ var _ = Describe("Server Crypto Setup", func() { cs.cryptoStream = stream }) - AfterEach(func() { - close(stream.unblockRead) - }) - Context("diversification nonce", func() { BeforeEach(func() { cs.secureAEAD = mockcrypto.NewMockAEAD(mockCtrl)