diff --git a/h2quic/server_test.go b/h2quic/server_test.go index 1bac4b92c..527ee9bad 100644 --- a/h2quic/server_test.go +++ b/h2quic/server_test.go @@ -62,6 +62,7 @@ func (s *mockSession) LocalAddr() net.Addr { func (s *mockSession) RemoteAddr() net.Addr { return &net.UDPAddr{IP: []byte{127, 0, 0, 1}, Port: 42} } +func (s *mockSession) WaitUntilClosed() { panic("not implemented") } var _ = Describe("H2 server", func() { var ( diff --git a/interface.go b/interface.go index 20ee6cf02..9cf52d613 100644 --- a/interface.go +++ b/interface.go @@ -37,6 +37,9 @@ type Session interface { RemoteAddr() net.Addr // Close closes the connection. The error will be sent to the remote peer in a CONNECTION_CLOSE frame. An error value of nil is allowed and will cause a normal PeerGoingAway to be sent. Close(error) error + // WaitUntilClosed() blocks until the session is closed. + // Warning: This API should not be considered stable and might change soon. + WaitUntilClosed() } // A NonFWSession is a QUIC connection between two peers half-way through the handshake. diff --git a/server_test.go b/server_test.go index acb8dece2..6810b7eba 100644 --- a/server_test.go +++ b/server_test.go @@ -39,6 +39,9 @@ func (s *mockSession) run() error { func (s *mockSession) WaitUntilHandshakeComplete() error { return <-s.handshakeComplete } +func (*mockSession) WaitUntilClosed() { + panic("not implemented") +} func (s *mockSession) Close(e error) error { if s.closed { return nil diff --git a/session.go b/session.go index bce7e67a8..bc9534fda 100644 --- a/session.go +++ b/session.go @@ -76,6 +76,8 @@ type session struct { sendingScheduled chan struct{} // closeChan is used to notify the run loop that it should terminate. closeChan chan closeError + // runClosed is closed once the run loop exits + // it is used to block Close() and WaitUntilClosed() runClosed chan struct{} closeOnce sync.Once @@ -323,6 +325,10 @@ runLoop: return closeErr.err } +func (s *session) WaitUntilClosed() { + <-s.runClosed +} + func (s *session) maybeResetTimer() { deadline := s.lastNetworkActivityTime.Add(s.idleTimeout()) diff --git a/session_test.go b/session_test.go index 5ba27cc6e..1e14f6224 100644 --- a/session_test.go +++ b/session_test.go @@ -772,6 +772,17 @@ var _ = Describe("Session", func() { Expect(mconn.written[0][0] & 0x02).ToNot(BeZero()) // Public Reset Expect(sess.runClosed).To(BeClosed()) }) + + It("unblocks WaitUntilClosed when the run loop exists", func() { + returned := make(chan struct{}) + go func() { + sess.WaitUntilClosed() + close(returned) + }() + Consistently(returned).ShouldNot(BeClosed()) + sess.Close(nil) + Eventually(returned).Should(BeClosed()) + }) }) Context("receiving packets", func() {