diff --git a/session.go b/session.go index 55bc649b..3e4d3353 100644 --- a/session.go +++ b/session.go @@ -39,6 +39,11 @@ var ( // Once the callback has been called with isForwardSecure = true, it is guarantueed to not be called with isForwardSecure = false after that type cryptoChangeCallback func(session Session, isForwardSecure bool) +type closeError struct { + err error + remote bool +} + // A Session is a QUIC session type session struct { connectionID protocol.ConnectionID @@ -67,7 +72,7 @@ type session struct { receivedPackets chan *receivedPacket sendingScheduled chan struct{} // closeChan is used to notify the run loop that it should terminate. - closeChan chan error + closeChan chan closeError runClosed chan struct{} closed uint32 // atomic bool @@ -178,11 +183,11 @@ func (s *session) setup() { s.receivedPacketHandler = ackhandler.NewReceivedPacketHandler(s.ackAlarmChanged) s.receivedPackets = make(chan *receivedPacket, protocol.MaxSessionUnprocessedPackets) - s.closeChan = make(chan error, 1) + s.closeChan = make(chan closeError, 1) s.sendingScheduled = make(chan struct{}, 1) s.undecryptablePackets = make([]*receivedPacket, 0, protocol.MaxUndecryptablePackets) s.aeadChanged = make(chan protocol.EncryptionLevel, 2) - s.runClosed = make(chan struct{}, 1) + s.runClosed = make(chan struct{}) s.timer = time.NewTimer(0) s.lastNetworkActivityTime = now @@ -201,7 +206,7 @@ func (s *session) run() error { } }() - var closeErr error + var closeErr closeError aeadChanged := s.aeadChanged runLoop: @@ -215,7 +220,6 @@ runLoop: s.maybeResetTimer() - var err error select { case closeErr = <-s.closeChan: break runLoop @@ -227,10 +231,14 @@ runLoop: // We do all the interesting stuff after the switch statement, so // nothing to see here. case p := <-s.receivedPackets: - err = s.handlePacketImpl(p) - if qErr, ok := err.(*qerr.QuicError); ok && qErr.ErrorCode == qerr.DecryptionFailure { - s.tryQueueingUndecryptablePacket(p) - continue + err := s.handlePacketImpl(p) + if err != nil { + if qErr, ok := err.(*qerr.QuicError); ok && qErr.ErrorCode == qerr.DecryptionFailure { + s.tryQueueingUndecryptablePacket(p) + continue + } + s.close(err) + break runLoop } // This is a bit unclean, but works properly, since the packet always // begins with the public header and we never copy it. @@ -248,10 +256,6 @@ runLoop: } } - if err != nil { - s.close(err) - } - now := time.Now() if s.sentPacketHandler.GetAlarmTimeout().Before(now) { // This could cause packets to be retransmitted, so check it before trying @@ -274,8 +278,9 @@ runLoop: s.garbageCollectStreams() } - s.runClosed <- struct{}{} - return closeErr + s.handleCloseError(closeErr) + close(s.runClosed) + return closeErr.err } func (s *session) maybeResetTimer() { @@ -396,7 +401,7 @@ func (s *session) handleFrames(fs []frames.Frame) error { case *frames.AckFrame: err = s.handleAckFrame(frame) case *frames.ConnectionCloseFrame: - s.closeImpl(qerr.Error(frame.ErrorCode, frame.ReasonPhrase), true) + s.registerClose(qerr.Error(frame.ErrorCode, frame.ReasonPhrase), true) case *frames.GoawayFrame: err = errors.New("unimplemented: handling GOAWAY frames") case *frames.StopWaitingFrame: @@ -482,10 +487,29 @@ func (s *session) handleAckFrame(frame *frames.AckFrame) error { return s.sentPacketHandler.ReceivedAck(frame, s.lastRcvdPacketNumber, s.lastNetworkActivityTime) } +func (s *session) registerClose(e error, remoteClose bool) error { + // Only close once + if !atomic.CompareAndSwapUint32(&s.closed, 0, 1) { + return errSessionAlreadyClosed + } + + if e == nil { + e = qerr.PeerGoingAway + } + + if e == errCloseSessionForNewVersion { + s.streamsMap.CloseWithError(e) + s.closeStreamsWithError(e) + } + + s.closeChan <- closeError{err: e, remote: remoteClose} + return nil +} + // Close the connection. If err is nil it will be set to qerr.PeerGoingAway. // It waits until the run loop has stopped before returning func (s *session) Close(e error) error { - err := s.closeImpl(e, false) + err := s.registerClose(e, false) if err == errSessionAlreadyClosed { return nil } @@ -497,55 +521,42 @@ func (s *session) Close(e error) error { // close the connection. Use this when called from the run loop func (s *session) close(e error) error { - err := s.closeImpl(e, false) + err := s.registerClose(e, false) if err == errSessionAlreadyClosed { return nil } return err } -func (s *session) closeImpl(e error, remoteClose bool) error { - // Only close once - if !atomic.CompareAndSwapUint32(&s.closed, 0, 1) { - return errSessionAlreadyClosed +func (s *session) handleCloseError(closeErr closeError) error { + var quicErr *qerr.QuicError + var ok bool + if quicErr, ok = closeErr.err.(*qerr.QuicError); !ok { + quicErr = qerr.ToQuicError(closeErr.err) } - - if e == nil { - e = qerr.PeerGoingAway - } - - defer func() { - s.closeChan <- e - }() - - if e == errCloseSessionForNewVersion { - s.streamsMap.CloseWithError(e) - s.closeStreamsWithError(e) - return nil - } - - quicErr := qerr.ToQuicError(e) - // Don't log 'normal' reasons if quicErr.ErrorCode == qerr.PeerGoingAway || quicErr.ErrorCode == qerr.NetworkIdleTimeout { utils.Infof("Closing connection %x", s.connectionID) } else { - utils.Errorf("Closing session with error: %s", e.Error()) + utils.Errorf("Closing session with error: %s", closeErr.err.Error()) + } + + if closeErr.err == errCloseSessionForNewVersion { + return nil } s.streamsMap.CloseWithError(quicErr) s.closeStreamsWithError(quicErr) - if remoteClose { - // If this is a remote close we're done here + // If this is a remote close we're done here + if closeErr.remote { return nil } if quicErr.ErrorCode == qerr.DecryptionFailure || quicErr == handshake.ErrHOLExperiment { return s.sendPublicReset(s.lastRcvdPacketNumber) } - s.sendConnectionClose(quicErr) - return nil + return s.sendConnectionClose(quicErr) } func (s *session) closeStreamsWithError(err error) { diff --git a/session_test.go b/session_test.go index 94775913..4becada0 100644 --- a/session_test.go +++ b/session_test.go @@ -582,12 +582,15 @@ var _ = Describe("Session", func() { Expect(err).NotTo(HaveOccurred()) }) - It("handles CONNECTION_CLOSE frames", func() { + It("handles CONNECTION_CLOSE frames", func(done Done) { + go sess.run() str, _ := sess.GetOrOpenStream(5) err := sess.handleFrames([]frames.Frame{&frames.ConnectionCloseFrame{ErrorCode: 42, ReasonPhrase: "foobar"}}) Expect(err).NotTo(HaveOccurred()) + Eventually(sess.runClosed).Should(BeClosed()) _, err = str.Read([]byte{0}) Expect(err).To(MatchError(qerr.Error(42, "foobar"))) + close(done) }) Context("accepting streams", func() { @@ -621,16 +624,17 @@ var _ = Describe("Session", func() { }) It("stops accepting when the session is closed after version negotiation", func() { - testErr := errCloseSessionForNewVersion var err error go func() { _, err = sess.AcceptStream() }() go sess.run() Consistently(func() error { return err }).ShouldNot(HaveOccurred()) - sess.Close(testErr) + Expect(sess.runClosed).ToNot(BeClosed()) + sess.Close(errCloseSessionForNewVersion) Eventually(func() error { return err }).Should(HaveOccurred()) - Expect(err).To(MatchError(testErr)) + Expect(err).To(MatchError(errCloseSessionForNewVersion)) + Eventually(sess.runClosed).Should(BeClosed()) }) }) @@ -646,7 +650,7 @@ var _ = Describe("Session", func() { Eventually(areSessionsRunning).Should(BeFalse()) Expect(mconn.written).To(HaveLen(1)) Expect(mconn.written[0][len(mconn.written[0])-7:]).To(Equal([]byte{0x02, byte(qerr.PeerGoingAway), 0, 0, 0, 0, 0})) - Expect(sess.runClosed).ToNot(Receive()) // channel should be drained by Close() + Expect(sess.runClosed).To(BeClosed()) }) It("only closes once", func() { @@ -654,7 +658,7 @@ var _ = Describe("Session", func() { sess.Close(nil) Eventually(areSessionsRunning).Should(BeFalse()) Expect(mconn.written).To(HaveLen(1)) - Expect(sess.runClosed).ToNot(Receive()) // channel should be drained by Close() + Expect(sess.runClosed).To(BeClosed()) }) It("closes streams with proper error", func() { @@ -669,7 +673,7 @@ var _ = Describe("Session", func() { n, err = s.Write([]byte{0}) Expect(n).To(BeZero()) Expect(err.Error()).To(ContainSubstring(testErr.Error())) - Expect(sess.runClosed).ToNot(Receive()) // channel should be drained by Close() + Expect(sess.runClosed).To(BeClosed()) }) It("closes the session in order to replace it with another QUIC version", func() { @@ -683,7 +687,7 @@ var _ = Describe("Session", func() { sess.Close(handshake.ErrHOLExperiment) Expect(mconn.written).To(HaveLen(1)) Expect(mconn.written[0][0] & 0x02).ToNot(BeZero()) // Public Reset - Expect(sess.runClosed).ToNot(Receive()) // channel should be drained by Close() + Expect(sess.runClosed).To(BeClosed()) }) }) @@ -1230,7 +1234,7 @@ var _ = Describe("Session", func() { sess.scheduleSending() // wake up the run loop Eventually(func() [][]byte { return mconn.written }).Should(HaveLen(1)) Expect(mconn.written[0]).To(ContainSubstring(string([]byte("PRST")))) - Eventually(sess.runClosed).Should(Receive()) + Eventually(sess.runClosed).Should(BeClosed()) }) It("doesn't send a Public Reset if decrypting them suceeded during the timeout", func() { @@ -1249,8 +1253,7 @@ var _ = Describe("Session", func() { go sess.run() sendUndecryptablePackets() Consistently(sess.undecryptablePackets).Should(BeEmpty()) - sess.closeImpl(nil, true) - Eventually(sess.runClosed).Should(Receive()) + Expect(sess.Close(nil)).To(Succeed()) }) It("unqueues undecryptable packets for later decryption", func() { @@ -1305,7 +1308,7 @@ var _ = Describe("Session", func() { sess.lastNetworkActivityTime = time.Now().Add(-time.Hour) sess.run() // Would normally not return Expect(mconn.written[0]).To(ContainSubstring("No recent network activity.")) - Expect(sess.runClosed).To(Receive()) + Expect(sess.runClosed).To(BeClosed()) close(done) }) @@ -1313,7 +1316,7 @@ var _ = Describe("Session", func() { sess.sessionCreationTime = time.Now().Add(-time.Hour) sess.run() // Would normally not return Expect(mconn.written[0]).To(ContainSubstring("Crypto handshake did not complete in time.")) - Expect(sess.runClosed).To(Receive()) + Expect(sess.runClosed).To(BeClosed()) close(done) }) @@ -1323,7 +1326,7 @@ var _ = Describe("Session", func() { sess.packer.connectionParameters = sess.connectionParameters sess.run() // Would normally not return Expect(mconn.written[0]).To(ContainSubstring("No recent network activity.")) - Expect(sess.runClosed).To(Receive()) + Expect(sess.runClosed).To(BeClosed()) close(done) }) @@ -1333,7 +1336,7 @@ var _ = Describe("Session", func() { sess.packer.connectionParameters = sess.connectionParameters sess.run() // Would normally not return Expect(mconn.written[0]).To(ContainSubstring("No recent network activity.")) - Expect(sess.runClosed).To(Receive()) + Expect(sess.runClosed).To(BeClosed()) close(done) }) })