diff --git a/interface.go b/interface.go index b2fa3465c..8f4dda5a4 100644 --- a/interface.go +++ b/interface.go @@ -38,6 +38,13 @@ type Session interface { Close(error) error } +// A NonFWSession is a QUIC connection between two peers half-way through the handshake. +// The communication is encrypted, but not yet forward secure. +type NonFWSession interface { + Session + WaitUntilHandshakeComplete() error +} + // Config contains all configuration data needed for a QUIC server or client. // More config parameters (such as timeouts) will be added soon, see e.g. https://github.com/lucas-clemente/quic-go/issues/441. type Config struct { diff --git a/packet_packer_test.go b/packet_packer_test.go index 4918d2745..7b347aabb 100644 --- a/packet_packer_test.go +++ b/packet_packer_test.go @@ -13,12 +13,14 @@ import ( ) type mockCryptoSetup struct { + handleErr error divNonce []byte encLevelSeal protocol.EncryptionLevel } -func (m *mockCryptoSetup) HandleCryptoStream() error { return nil } - +func (m *mockCryptoSetup) HandleCryptoStream() error { + return m.handleErr +} func (m *mockCryptoSetup) Open(dst, src []byte, packetNumber protocol.PacketNumber, associatedData []byte) ([]byte, protocol.EncryptionLevel, error) { return nil, protocol.EncryptionUnspecified, nil } diff --git a/session.go b/session.go index 3fbc9179c..4662e9079 100644 --- a/session.go +++ b/session.go @@ -83,8 +83,10 @@ type session struct { // this channel is passed to the CryptoSetup and receives the current encryption level // it is closed as soon as the handshake is complete - aeadChanged <-chan protocol.EncryptionLevel - handshakeComplete bool + aeadChanged <-chan protocol.EncryptionLevel + handshakeComplete bool + handshakeChan chan struct{} // will be closed as soon as the handshake completes + handshakeErrorChan chan error nextAckScheduledTime time.Time @@ -215,6 +217,8 @@ func (s *session) setup() { s.undecryptablePackets = make([]*receivedPacket, 0, protocol.MaxUndecryptablePackets) s.aeadChanged = make(chan protocol.EncryptionLevel, 2) s.runClosed = make(chan struct{}) + s.handshakeChan = make(chan struct{}) + s.handshakeErrorChan = make(chan error, 1) s.timer = time.NewTimer(0) s.lastNetworkActivityTime = now @@ -274,6 +278,7 @@ runLoop: if !ok { s.handshakeComplete = true aeadChanged = nil // prevent this case from ever being selected again + close(s.handshakeChan) } else { if l == protocol.EncryptionForwardSecure { s.packer.SetForwardSecure() @@ -305,6 +310,7 @@ runLoop: s.garbageCollectStreams() } + s.handshakeErrorChan <- closeErr.err s.handleCloseError(closeErr) close(s.runClosed) return closeErr.err @@ -751,6 +757,15 @@ func (s *session) OpenStreamSync() (Stream, error) { return s.streamsMap.OpenStreamSync() } +func (s *session) WaitUntilHandshakeComplete() error { + select { + case <-s.handshakeChan: + return nil + case err := <-s.handshakeErrorChan: + return err + } +} + func (s *session) queueResetStreamFrame(id protocol.StreamID, offset protocol.ByteCount) { s.packer.QueueControlFrameForNextPacket(&frames.RstStreamFrame{ StreamID: id, diff --git a/session_test.go b/session_test.go index 7f4ec81f5..97bd04a3e 100644 --- a/session_test.go +++ b/session_test.go @@ -593,6 +593,59 @@ var _ = Describe("Session", func() { close(done) }) + Context("waiting until the handshake completes", func() { + It("waits until the handshake is complete", func(done Done) { + go sess.run() + + var waitReturned bool + go func() { + defer GinkgoRecover() + err := sess.WaitUntilHandshakeComplete() + Expect(err).ToNot(HaveOccurred()) + waitReturned = true + }() + aeadChanged <- protocol.EncryptionForwardSecure + Consistently(func() bool { return waitReturned }).Should(BeFalse()) + close(aeadChanged) + Eventually(func() bool { return waitReturned }).Should(BeTrue()) + Expect(sess.Close(nil)).To(Succeed()) + close(done) + }) + + It("errors if the handshake fails", func(done Done) { + testErr := errors.New("crypto error") + sess.cryptoSetup = &mockCryptoSetup{handleErr: testErr} + go sess.run() + err := sess.WaitUntilHandshakeComplete() + Expect(err).To(MatchError(testErr)) + close(done) + }, 0.5) + + It("returns when Close is called", func(done Done) { + testErr := errors.New("close error") + go sess.run() + var waitReturned bool + go func() { + defer GinkgoRecover() + err := sess.WaitUntilHandshakeComplete() + Expect(err).To(MatchError(testErr)) + waitReturned = true + }() + sess.Close(testErr) + Eventually(func() bool { return waitReturned }).Should(BeTrue()) + close(done) + }) + + It("doesn't wait if the handshake is already completed", func(done Done) { + go sess.run() + close(aeadChanged) + err := sess.WaitUntilHandshakeComplete() + Expect(err).ToNot(HaveOccurred()) + Expect(sess.Close(nil)).To(Succeed()) + close(done) + }) + }) + Context("accepting streams", func() { It("waits for new streams", func() { var str Stream