From f2959aa74a4564742b9750a39448d4388c04c987 Mon Sep 17 00:00:00 2001 From: Lucas Clemente Date: Wed, 7 Jun 2017 09:37:43 +0200 Subject: [PATCH] Simplify session closing --- session.go | 73 ++++++++++++++++--------------------------------- session_test.go | 12 ++++---- streams_map.go | 5 +++- 3 files changed, 32 insertions(+), 58 deletions(-) diff --git a/session.go b/session.go index 564ce6f53..45e243451 100644 --- a/session.go +++ b/session.go @@ -4,7 +4,7 @@ import ( "errors" "fmt" "net" - "sync/atomic" + "sync" "time" "github.com/lucas-clemente/quic-go/ackhandler" @@ -78,7 +78,7 @@ type session struct { // closeChan is used to notify the run loop that it should terminate. closeChan chan closeError runClosed chan struct{} - closed uint32 // atomic bool + closeOnce sync.Once // when we receive too many undecryptable packets during the handshake, we send a Public reset // but only after a time of protocol.PublicResetTimeout has passed @@ -290,7 +290,7 @@ runLoop: s.tryQueueingUndecryptablePacket(p) continue } - s.close(err) + s.closeLocal(err) continue } // This is a bit unclean, but works properly, since the packet always @@ -319,16 +319,16 @@ runLoop: } if err := s.sendPacket(); err != nil { - s.close(err) + s.closeLocal(err) } if !s.receivedTooManyUndecrytablePacketsTime.IsZero() && s.receivedTooManyUndecrytablePacketsTime.Add(protocol.PublicResetTimeout).Before(now) && len(s.undecryptablePackets) != 0 { - s.close(qerr.Error(qerr.DecryptionFailure, "too many undecryptable packets received")) + s.closeLocal(qerr.Error(qerr.DecryptionFailure, "too many undecryptable packets received")) } if now.Sub(s.lastNetworkActivityTime) >= s.idleTimeout() { - s.close(qerr.Error(qerr.NetworkIdleTimeout, "No recent network activity.")) + s.closeLocal(qerr.Error(qerr.NetworkIdleTimeout, "No recent network activity.")) } if !s.handshakeComplete && now.Sub(s.sessionCreationTime) >= s.config.HandshakeTimeout { - s.close(qerr.Error(qerr.HandshakeTimeout, "Crypto handshake did not complete in time.")) + s.closeLocal(qerr.Error(qerr.HandshakeTimeout, "Crypto handshake did not complete in time.")) } s.garbageCollectStreams() } @@ -462,7 +462,7 @@ func (s *session) handleFrames(fs []frames.Frame) error { case *frames.AckFrame: err = s.handleAckFrame(frame) case *frames.ConnectionCloseFrame: - s.registerClose(qerr.Error(frame.ErrorCode, frame.ReasonPhrase), true) + s.close(qerr.Error(frame.ErrorCode, frame.ReasonPhrase), true) case *frames.GoawayFrame: err = errors.New("unimplemented: handling GOAWAY frames") case *frames.StopWaitingFrame: @@ -548,48 +548,29 @@ func (s *session) handleAckFrame(frame *frames.AckFrame) error { return s.sentPacketHandler.ReceivedAck(frame, s.lastRcvdPacketNumber, s.lastNetworkActivityTime) } -func (s *session) registerClose(e error, remoteClose bool) error { - // Only close once - if !atomic.CompareAndSwapUint32(&s.closed, 0, 1) { - return errSessionAlreadyClosed - } +func (s *session) close(e error, remoteClose bool) { + s.closeOnce.Do(func() { + s.closeChan <- closeError{err: e, remote: remoteClose} + }) +} - if e == nil { - e = qerr.PeerGoingAway - } - - if e == errCloseSessionForNewVersion { - s.streamsMap.CloseWithError(e) - s.closeStreamsWithError(e) - } - - s.closeChan <- closeError{err: e, remote: remoteClose} - return nil +func (s *session) closeLocal(e error) { + s.close(e, false) } // Close the connection. If err is nil it will be set to qerr.PeerGoingAway. // It waits until the run loop has stopped before returning func (s *session) Close(e error) error { - err := s.registerClose(e, false) - if err == errSessionAlreadyClosed { - return nil - } - - // wait for the run loop to finish + s.close(e, false) <-s.runClosed - return err -} - -// close the connection. Use this when called from the run loop -func (s *session) close(e error) error { - err := s.registerClose(e, false) - if err == errSessionAlreadyClosed { - return nil - } - return err + return nil } func (s *session) handleCloseError(closeErr closeError) error { + if closeErr.err == nil { + closeErr.err = qerr.PeerGoingAway + } + var quicErr *qerr.QuicError var ok bool if quicErr, ok = closeErr.err.(*qerr.QuicError); !ok { @@ -602,13 +583,12 @@ func (s *session) handleCloseError(closeErr closeError) error { utils.Errorf("Closing session with error: %s", closeErr.err.Error()) } + s.streamsMap.CloseWithError(quicErr) + if closeErr.err == errCloseSessionForNewVersion { return nil } - s.streamsMap.CloseWithError(quicErr) - s.closeStreamsWithError(quicErr) - // If this is a remote close we're done here if closeErr.remote { return nil @@ -620,13 +600,6 @@ func (s *session) handleCloseError(closeErr closeError) error { return s.sendConnectionClose(quicErr) } -func (s *session) closeStreamsWithError(err error) { - s.streamsMap.Iterate(func(str *stream) (bool, error) { - str.Cancel(err) - return true, nil - }) -} - func (s *session) sendPacket() error { // Repeatedly try sending until we don't have any more data, or run out of the congestion window for { diff --git a/session_test.go b/session_test.go index 7265504b0..228aef3af 100644 --- a/session_test.go +++ b/session_test.go @@ -8,7 +8,6 @@ import ( "net" "runtime/pprof" "strings" - "sync/atomic" "time" . "github.com/onsi/ginkgo" @@ -357,9 +356,9 @@ var _ = Describe("Session", func() { p := make([]byte, 4) _, err = str.Read(p) Expect(err).ToNot(HaveOccurred()) - sess.closeStreamsWithError(testErr) + sess.handleCloseError(closeError{err: testErr, remote: true}) _, err = str.Read(p) - Expect(err).To(MatchError(testErr)) + Expect(err).To(MatchError(qerr.Error(qerr.InternalError, testErr.Error()))) sess.garbageCollectStreams() str, err = sess.streamsMap.GetOrOpenStream(5) Expect(err).NotTo(HaveOccurred()) @@ -372,9 +371,9 @@ var _ = Describe("Session", func() { str, err := sess.streamsMap.GetOrOpenStream(5) Expect(err).ToNot(HaveOccurred()) Expect(str).ToNot(BeNil()) - sess.closeStreamsWithError(testErr) + sess.handleCloseError(closeError{err: testErr, remote: true}) _, err = str.Read([]byte{0}) - Expect(err).To(MatchError(testErr)) + Expect(err).To(MatchError(qerr.Error(qerr.InternalError, testErr.Error()))) sess.garbageCollectStreams() str, err = sess.streamsMap.GetOrOpenStream(5) Expect(err).NotTo(HaveOccurred()) @@ -714,7 +713,7 @@ var _ = Describe("Session", func() { Expect(sess.runClosed).ToNot(BeClosed()) sess.Close(errCloseSessionForNewVersion) Eventually(func() error { return err }).Should(HaveOccurred()) - Expect(err).To(MatchError(errCloseSessionForNewVersion)) + Expect(err).To(MatchError(qerr.Error(qerr.InternalError, errCloseSessionForNewVersion.Error()))) Eventually(sess.runClosed).Should(BeClosed()) }) }) @@ -760,7 +759,6 @@ var _ = Describe("Session", func() { It("closes the session in order to replace it with another QUIC version", func() { sess.Close(errCloseSessionForNewVersion) Eventually(areSessionsRunning).Should(BeFalse()) - Expect(atomic.LoadUint32(&sess.closed) != 0).To(BeTrue()) Expect(mconn.written).To(BeEmpty()) // no CONNECTION_CLOSE or PUBLIC_RESET sent }) diff --git a/streams_map.go b/streams_map.go index d29ac80b2..31cfe5a84 100644 --- a/streams_map.go +++ b/streams_map.go @@ -319,8 +319,11 @@ func (m *streamsMap) RemoveStream(id protocol.StreamID) error { func (m *streamsMap) CloseWithError(err error) { m.mutex.Lock() + defer m.mutex.Unlock() m.closeErr = err m.nextStreamOrErrCond.Broadcast() m.openStreamOrErrCond.Broadcast() - m.mutex.Unlock() + for _, s := range m.openStreams { + m.streams[s].Cancel(err) + } }