diff --git a/session.go b/session.go index a4b743fc..af61cd96 100644 --- a/session.go +++ b/session.go @@ -32,6 +32,7 @@ type receivedPacket struct { var ( errRstStreamOnInvalidStream = errors.New("RST_STREAM received for unknown stream") errWindowUpdateOnClosedStream = errors.New("WINDOW_UPDATE received for an already closed stream") + errSessionAlreadyClosed = errors.New("Cannot close Session. It was already closed before.") ) // StreamCallback gets a stream frame and returns a reply frame @@ -70,6 +71,7 @@ type Session struct { // closeChan is used to notify the run loop that it should terminate. // If the value is not nil, the error is sent as a CONNECTION_CLOSE. closeChan chan *qerr.QuicError + runClosed chan struct{} closed uint32 // atomic bool undecryptablePackets []*receivedPacket @@ -124,6 +126,7 @@ func newSession(conn connection, v protocol.VersionNumber, connectionID protocol sendingScheduled: make(chan struct{}, 1), undecryptablePackets: make([]*receivedPacket, 0, protocol.MaxUndecryptablePackets), aeadChanged: make(chan struct{}, 1), + runClosed: make(chan struct{}, 1), // this channel will receive once the run loop has been stopped timer: time.NewTimer(0), lastNetworkActivityTime: now, @@ -155,6 +158,7 @@ func (s *Session) run() { } }() +runLoop: for { // Close immediately if requested select { @@ -162,7 +166,7 @@ func (s *Session) run() { if errForConnClose != nil { s.sendConnectionClose(errForConnClose) } - return + break runLoop default: } @@ -174,7 +178,7 @@ func (s *Session) run() { if errForConnClose != nil { s.sendConnectionClose(errForConnClose) } - return + break runLoop case <-s.timer.C: s.timerRead = true // We do all the interesting stuff after the switch statement, so @@ -199,20 +203,23 @@ func (s *Session) run() { } if err != nil { - s.Close(err) + s.close(err) } if err := s.sendPacket(); err != nil { - s.Close(err) + s.close(err) } if time.Now().Sub(s.lastNetworkActivityTime) >= s.idleTimeout() { - s.Close(qerr.Error(qerr.NetworkIdleTimeout, "No recent network activity.")) + s.close(qerr.Error(qerr.NetworkIdleTimeout, "No recent network activity.")) } if !s.cryptoSetup.HandshakeComplete() && time.Now().Sub(s.sessionCreationTime) >= protocol.MaxTimeForCryptoHandshake { - s.Close(qerr.Error(qerr.NetworkIdleTimeout, "Crypto handshake did not complete in time.")) + s.close(qerr.Error(qerr.NetworkIdleTimeout, "Crypto handshake did not complete in time.")) } s.garbageCollectStreams() } + + s.closeCallback(s.connectionID) + s.runClosed <- struct{}{} } func (s *Session) maybeResetTimer() { @@ -407,14 +414,31 @@ func (s *Session) handleAckFrame(frame *frames.AckFrame) error { } // Close the connection. If err is nil it will be set to qerr.PeerGoingAway. +// It waits until the run loop has stopped before returning func (s *Session) Close(e error) error { - return s.closeImpl(e, false) + err := s.closeImpl(e, false) + if err == errSessionAlreadyClosed { + return nil + } + + // wait for the run loop to finish + <-s.runClosed + return err +} + +// close the connection. Use this when called from the run loop +func (s *Session) close(e error) error { + err := s.closeImpl(e, false) + if err == errSessionAlreadyClosed { + return nil + } + return err } func (s *Session) closeImpl(e error, remoteClose bool) error { // Only close once if !atomic.CompareAndSwapUint32(&s.closed, 0, 1) { - return nil + return errSessionAlreadyClosed } if e == nil { @@ -431,7 +455,6 @@ func (s *Session) closeImpl(e error, remoteClose bool) error { } s.closeStreamsWithError(quicErr) - s.closeCallback(s.connectionID) if remoteClose { // If this is a remote close we don't need to send a CONNECTION_CLOSE @@ -651,7 +674,7 @@ func (s *Session) tryQueueingUndecryptablePacket(p *receivedPacket) { } utils.Infof("Queueing packet 0x%x for later decryption", p.publicHeader.PacketNumber) if len(s.undecryptablePackets)+1 >= protocol.MaxUndecryptablePackets { - s.Close(qerr.Error(qerr.DecryptionFailure, "too many undecryptable packets received")) + s.close(qerr.Error(qerr.DecryptionFailure, "too many undecryptable packets received")) } s.undecryptablePackets = append(s.undecryptablePackets, p) } diff --git a/session_test.go b/session_test.go index d7cd541e..ae41c2be 100644 --- a/session_test.go +++ b/session_test.go @@ -424,10 +424,11 @@ var _ = Describe("Session", func() { It("shuts down without error", func() { session.Close(nil) - Expect(closeCallbackCalled).To(BeTrue()) Eventually(func() int { return runtime.NumGoroutine() }).Should(Equal(nGoRoutinesBefore)) Expect(conn.written).To(HaveLen(1)) Expect(conn.written[0][len(conn.written[0])-7:]).To(Equal([]byte{0x02, byte(qerr.PeerGoingAway), 0, 0, 0, 0, 0})) + Expect(closeCallbackCalled).To(BeTrue()) + Expect(session.runClosed).ToNot(Receive()) // channel should be drained by Close() }) It("only closes once", func() { @@ -435,6 +436,7 @@ var _ = Describe("Session", func() { session.Close(nil) Eventually(func() int { return runtime.NumGoroutine() }).Should(Equal(nGoRoutinesBefore)) Expect(conn.written).To(HaveLen(1)) + Expect(session.runClosed).ToNot(Receive()) // channel should be drained by Close() }) It("closes streams with proper error", func() { @@ -442,14 +444,15 @@ var _ = Describe("Session", func() { s, err := session.GetOrOpenStream(5) Expect(err).NotTo(HaveOccurred()) session.Close(testErr) - Expect(closeCallbackCalled).To(BeTrue()) Eventually(func() int { return runtime.NumGoroutine() }).Should(Equal(nGoRoutinesBefore)) + Expect(closeCallbackCalled).To(BeTrue()) n, err := s.Read([]byte{0}) Expect(n).To(BeZero()) Expect(err.Error()).To(ContainSubstring(testErr.Error())) n, err = s.Write([]byte{0}) Expect(n).To(BeZero()) Expect(err.Error()).To(ContainSubstring(testErr.Error())) + Expect(session.runClosed).ToNot(Receive()) // channel should be drained by Close() }) }) @@ -727,6 +730,7 @@ var _ = Describe("Session", func() { Expect(conn.written).To(HaveLen(1)) Expect(conn.written[0]).To(ContainSubstring(string([]byte("PRST")))) + Expect(session.runClosed).To(Receive()) }) It("ignores undecryptable packets after the handshake is complete", func() { @@ -740,6 +744,7 @@ var _ = Describe("Session", func() { go session.run() Consistently(session.undecryptablePackets).Should(HaveLen(0)) session.closeImpl(nil, true) + Eventually(session.runClosed).Should(Receive()) }) It("unqueues undecryptable packets for later decryption", func() { @@ -757,6 +762,8 @@ var _ = Describe("Session", func() { session.lastNetworkActivityTime = time.Now().Add(-time.Hour) session.run() // Would normally not return Expect(conn.written[0]).To(ContainSubstring("No recent network activity.")) + Expect(closeCallbackCalled).To(BeTrue()) + Expect(session.runClosed).To(Receive()) close(done) }) @@ -764,6 +771,8 @@ var _ = Describe("Session", func() { session.sessionCreationTime = time.Now().Add(-time.Hour) session.run() // Would normally not return Expect(conn.written[0]).To(ContainSubstring("Crypto handshake did not complete in time.")) + Expect(closeCallbackCalled).To(BeTrue()) + Expect(session.runClosed).To(Receive()) close(done) }) @@ -773,6 +782,8 @@ var _ = Describe("Session", func() { session.packer.connectionParameters = session.connectionParameters session.run() // Would normally not return Expect(conn.written[0]).To(ContainSubstring("No recent network activity.")) + Expect(closeCallbackCalled).To(BeTrue()) + Expect(session.runClosed).To(Receive()) close(done) }) @@ -784,6 +795,8 @@ var _ = Describe("Session", func() { session.packer.connectionParameters = session.connectionParameters session.run() // Would normally not return Expect(conn.written[0]).To(ContainSubstring("No recent network activity.")) + Expect(closeCallbackCalled).To(BeTrue()) + Expect(session.runClosed).To(Receive()) close(done) }) })