From 156c23f2b7379795056a1ba8bc8a96db40d12377 Mon Sep 17 00:00:00 2001 From: Marten Seemann Date: Mon, 19 Dec 2016 20:35:21 +0700 Subject: [PATCH] wait until the run loop has stopped before returning Session.Close() fixes #371 --- session.go | 36 ++++++++++++++++++++++++++++-------- session_test.go | 8 ++++++-- 2 files changed, 34 insertions(+), 10 deletions(-) diff --git a/session.go b/session.go index a4b743fc..909bcdcd 100644 --- a/session.go +++ b/session.go @@ -70,6 +70,7 @@ type Session struct { // closeChan is used to notify the run loop that it should terminate. // If the value is not nil, the error is sent as a CONNECTION_CLOSE. closeChan chan *qerr.QuicError + runClosed chan struct{} closed uint32 // atomic bool undecryptablePackets []*receivedPacket @@ -124,6 +125,7 @@ func newSession(conn connection, v protocol.VersionNumber, connectionID protocol sendingScheduled: make(chan struct{}, 1), undecryptablePackets: make([]*receivedPacket, 0, protocol.MaxUndecryptablePackets), aeadChanged: make(chan struct{}, 1), + runClosed: make(chan struct{}, 1), // this channel will receive once the run loop has been stopped timer: time.NewTimer(0), lastNetworkActivityTime: now, @@ -151,10 +153,11 @@ func (s *Session) run() { // Start the crypto stream handler go func() { if err := s.cryptoSetup.HandleCryptoStream(); err != nil { - s.Close(err) + s.close(err) } }() +runLoop: for { // Close immediately if requested select { @@ -162,7 +165,7 @@ func (s *Session) run() { if errForConnClose != nil { s.sendConnectionClose(errForConnClose) } - return + break runLoop default: } @@ -174,7 +177,7 @@ func (s *Session) run() { if errForConnClose != nil { s.sendConnectionClose(errForConnClose) } - return + break runLoop case <-s.timer.C: s.timerRead = true // We do all the interesting stuff after the switch statement, so @@ -199,20 +202,23 @@ func (s *Session) run() { } if err != nil { - s.Close(err) + s.close(err) } if err := s.sendPacket(); err != nil { - s.Close(err) + s.close(err) } if time.Now().Sub(s.lastNetworkActivityTime) >= s.idleTimeout() { - s.Close(qerr.Error(qerr.NetworkIdleTimeout, "No recent network activity.")) + s.close(qerr.Error(qerr.NetworkIdleTimeout, "No recent network activity.")) } if !s.cryptoSetup.HandshakeComplete() && time.Now().Sub(s.sessionCreationTime) >= protocol.MaxTimeForCryptoHandshake { - s.Close(qerr.Error(qerr.NetworkIdleTimeout, "Crypto handshake did not complete in time.")) + s.close(qerr.Error(qerr.NetworkIdleTimeout, "Crypto handshake did not complete in time.")) } s.garbageCollectStreams() } + + s.closeCallback(s.connectionID) + s.runClosed <- struct{}{} } func (s *Session) maybeResetTimer() { @@ -407,7 +413,22 @@ func (s *Session) handleAckFrame(frame *frames.AckFrame) error { } // Close the connection. If err is nil it will be set to qerr.PeerGoingAway. +// It waits until the run loop has stopped before returning func (s *Session) Close(e error) error { + err := s.closeImpl(e, false) + + if atomic.LoadUint32(&s.closed) == 1 { + return err + } + + select { + case <-s.runClosed: + return err + } +} + +// close the connection. Use this when called from the run loop +func (s *Session) close(e error) error { return s.closeImpl(e, false) } @@ -431,7 +452,6 @@ func (s *Session) closeImpl(e error, remoteClose bool) error { } s.closeStreamsWithError(quicErr) - s.closeCallback(s.connectionID) if remoteClose { // If this is a remote close we don't need to send a CONNECTION_CLOSE diff --git a/session_test.go b/session_test.go index d7cd541e..bea09107 100644 --- a/session_test.go +++ b/session_test.go @@ -424,10 +424,10 @@ var _ = Describe("Session", func() { It("shuts down without error", func() { session.Close(nil) - Expect(closeCallbackCalled).To(BeTrue()) Eventually(func() int { return runtime.NumGoroutine() }).Should(Equal(nGoRoutinesBefore)) Expect(conn.written).To(HaveLen(1)) Expect(conn.written[0][len(conn.written[0])-7:]).To(Equal([]byte{0x02, byte(qerr.PeerGoingAway), 0, 0, 0, 0, 0})) + Expect(closeCallbackCalled).To(BeTrue()) }) It("only closes once", func() { @@ -442,8 +442,8 @@ var _ = Describe("Session", func() { s, err := session.GetOrOpenStream(5) Expect(err).NotTo(HaveOccurred()) session.Close(testErr) - Expect(closeCallbackCalled).To(BeTrue()) Eventually(func() int { return runtime.NumGoroutine() }).Should(Equal(nGoRoutinesBefore)) + Expect(closeCallbackCalled).To(BeTrue()) n, err := s.Read([]byte{0}) Expect(n).To(BeZero()) Expect(err.Error()).To(ContainSubstring(testErr.Error())) @@ -757,6 +757,7 @@ var _ = Describe("Session", func() { session.lastNetworkActivityTime = time.Now().Add(-time.Hour) session.run() // Would normally not return Expect(conn.written[0]).To(ContainSubstring("No recent network activity.")) + Expect(closeCallbackCalled).To(BeTrue()) close(done) }) @@ -764,6 +765,7 @@ var _ = Describe("Session", func() { session.sessionCreationTime = time.Now().Add(-time.Hour) session.run() // Would normally not return Expect(conn.written[0]).To(ContainSubstring("Crypto handshake did not complete in time.")) + Expect(closeCallbackCalled).To(BeTrue()) close(done) }) @@ -773,6 +775,7 @@ var _ = Describe("Session", func() { session.packer.connectionParameters = session.connectionParameters session.run() // Would normally not return Expect(conn.written[0]).To(ContainSubstring("No recent network activity.")) + Expect(closeCallbackCalled).To(BeTrue()) close(done) }) @@ -784,6 +787,7 @@ var _ = Describe("Session", func() { session.packer.connectionParameters = session.connectionParameters session.run() // Would normally not return Expect(conn.written[0]).To(ContainSubstring("No recent network activity.")) + Expect(closeCallbackCalled).To(BeTrue()) close(done) }) })