diff --git a/session.go b/session.go index 081cb7ca..3e4d3353 100644 --- a/session.go +++ b/session.go @@ -39,6 +39,11 @@ var ( // Once the callback has been called with isForwardSecure = true, it is guarantueed to not be called with isForwardSecure = false after that type cryptoChangeCallback func(session Session, isForwardSecure bool) +type closeError struct { + err error + remote bool +} + // A Session is a QUIC session type session struct { connectionID protocol.ConnectionID @@ -67,7 +72,7 @@ type session struct { receivedPackets chan *receivedPacket sendingScheduled chan struct{} // closeChan is used to notify the run loop that it should terminate. - closeChan chan error + closeChan chan closeError runClosed chan struct{} closed uint32 // atomic bool @@ -178,7 +183,7 @@ func (s *session) setup() { s.receivedPacketHandler = ackhandler.NewReceivedPacketHandler(s.ackAlarmChanged) s.receivedPackets = make(chan *receivedPacket, protocol.MaxSessionUnprocessedPackets) - s.closeChan = make(chan error, 1) + s.closeChan = make(chan closeError, 1) s.sendingScheduled = make(chan struct{}, 1) s.undecryptablePackets = make([]*receivedPacket, 0, protocol.MaxUndecryptablePackets) s.aeadChanged = make(chan protocol.EncryptionLevel, 2) @@ -201,7 +206,7 @@ func (s *session) run() error { } }() - var closeErr error + var closeErr closeError aeadChanged := s.aeadChanged runLoop: @@ -215,7 +220,6 @@ runLoop: s.maybeResetTimer() - var err error select { case closeErr = <-s.closeChan: break runLoop @@ -227,10 +231,14 @@ runLoop: // We do all the interesting stuff after the switch statement, so // nothing to see here. case p := <-s.receivedPackets: - err = s.handlePacketImpl(p) - if qErr, ok := err.(*qerr.QuicError); ok && qErr.ErrorCode == qerr.DecryptionFailure { - s.tryQueueingUndecryptablePacket(p) - continue + err := s.handlePacketImpl(p) + if err != nil { + if qErr, ok := err.(*qerr.QuicError); ok && qErr.ErrorCode == qerr.DecryptionFailure { + s.tryQueueingUndecryptablePacket(p) + continue + } + s.close(err) + break runLoop } // This is a bit unclean, but works properly, since the packet always // begins with the public header and we never copy it. @@ -248,10 +256,6 @@ runLoop: } } - if err != nil { - s.close(err) - } - now := time.Now() if s.sentPacketHandler.GetAlarmTimeout().Before(now) { // This could cause packets to be retransmitted, so check it before trying @@ -274,8 +278,9 @@ runLoop: s.garbageCollectStreams() } + s.handleCloseError(closeErr) close(s.runClosed) - return closeErr + return closeErr.err } func (s *session) maybeResetTimer() { @@ -396,7 +401,7 @@ func (s *session) handleFrames(fs []frames.Frame) error { case *frames.AckFrame: err = s.handleAckFrame(frame) case *frames.ConnectionCloseFrame: - s.closeImpl(qerr.Error(frame.ErrorCode, frame.ReasonPhrase), true) + s.registerClose(qerr.Error(frame.ErrorCode, frame.ReasonPhrase), true) case *frames.GoawayFrame: err = errors.New("unimplemented: handling GOAWAY frames") case *frames.StopWaitingFrame: @@ -482,10 +487,29 @@ func (s *session) handleAckFrame(frame *frames.AckFrame) error { return s.sentPacketHandler.ReceivedAck(frame, s.lastRcvdPacketNumber, s.lastNetworkActivityTime) } +func (s *session) registerClose(e error, remoteClose bool) error { + // Only close once + if !atomic.CompareAndSwapUint32(&s.closed, 0, 1) { + return errSessionAlreadyClosed + } + + if e == nil { + e = qerr.PeerGoingAway + } + + if e == errCloseSessionForNewVersion { + s.streamsMap.CloseWithError(e) + s.closeStreamsWithError(e) + } + + s.closeChan <- closeError{err: e, remote: remoteClose} + return nil +} + // Close the connection. If err is nil it will be set to qerr.PeerGoingAway. // It waits until the run loop has stopped before returning func (s *session) Close(e error) error { - err := s.closeImpl(e, false) + err := s.registerClose(e, false) if err == errSessionAlreadyClosed { return nil } @@ -497,55 +521,42 @@ func (s *session) Close(e error) error { // close the connection. Use this when called from the run loop func (s *session) close(e error) error { - err := s.closeImpl(e, false) + err := s.registerClose(e, false) if err == errSessionAlreadyClosed { return nil } return err } -func (s *session) closeImpl(e error, remoteClose bool) error { - // Only close once - if !atomic.CompareAndSwapUint32(&s.closed, 0, 1) { - return errSessionAlreadyClosed +func (s *session) handleCloseError(closeErr closeError) error { + var quicErr *qerr.QuicError + var ok bool + if quicErr, ok = closeErr.err.(*qerr.QuicError); !ok { + quicErr = qerr.ToQuicError(closeErr.err) } - - if e == nil { - e = qerr.PeerGoingAway - } - - defer func() { - s.closeChan <- e - }() - - if e == errCloseSessionForNewVersion { - s.streamsMap.CloseWithError(e) - s.closeStreamsWithError(e) - return nil - } - - quicErr := qerr.ToQuicError(e) - // Don't log 'normal' reasons if quicErr.ErrorCode == qerr.PeerGoingAway || quicErr.ErrorCode == qerr.NetworkIdleTimeout { utils.Infof("Closing connection %x", s.connectionID) } else { - utils.Errorf("Closing session with error: %s", e.Error()) + utils.Errorf("Closing session with error: %s", closeErr.err.Error()) + } + + if closeErr.err == errCloseSessionForNewVersion { + return nil } s.streamsMap.CloseWithError(quicErr) s.closeStreamsWithError(quicErr) - if remoteClose { - // If this is a remote close we're done here + // If this is a remote close we're done here + if closeErr.remote { return nil } if quicErr.ErrorCode == qerr.DecryptionFailure || quicErr == handshake.ErrHOLExperiment { return s.sendPublicReset(s.lastRcvdPacketNumber) } - s.sendConnectionClose(quicErr) - return nil + return s.sendConnectionClose(quicErr) } func (s *session) closeStreamsWithError(err error) { diff --git a/session_test.go b/session_test.go index c70365b8..4becada0 100644 --- a/session_test.go +++ b/session_test.go @@ -582,12 +582,15 @@ var _ = Describe("Session", func() { Expect(err).NotTo(HaveOccurred()) }) - It("handles CONNECTION_CLOSE frames", func() { + It("handles CONNECTION_CLOSE frames", func(done Done) { + go sess.run() str, _ := sess.GetOrOpenStream(5) err := sess.handleFrames([]frames.Frame{&frames.ConnectionCloseFrame{ErrorCode: 42, ReasonPhrase: "foobar"}}) Expect(err).NotTo(HaveOccurred()) + Eventually(sess.runClosed).Should(BeClosed()) _, err = str.Read([]byte{0}) Expect(err).To(MatchError(qerr.Error(42, "foobar"))) + close(done) }) Context("accepting streams", func() { @@ -621,16 +624,17 @@ var _ = Describe("Session", func() { }) It("stops accepting when the session is closed after version negotiation", func() { - testErr := errCloseSessionForNewVersion var err error go func() { _, err = sess.AcceptStream() }() go sess.run() Consistently(func() error { return err }).ShouldNot(HaveOccurred()) - sess.Close(testErr) + Expect(sess.runClosed).ToNot(BeClosed()) + sess.Close(errCloseSessionForNewVersion) Eventually(func() error { return err }).Should(HaveOccurred()) - Expect(err).To(MatchError(testErr)) + Expect(err).To(MatchError(errCloseSessionForNewVersion)) + Eventually(sess.runClosed).Should(BeClosed()) }) }) @@ -1249,8 +1253,7 @@ var _ = Describe("Session", func() { go sess.run() sendUndecryptablePackets() Consistently(sess.undecryptablePackets).Should(BeEmpty()) - sess.closeImpl(nil, true) - Eventually(sess.runClosed).Should(BeClosed()) + Expect(sess.Close(nil)).To(Succeed()) }) It("unqueues undecryptable packets for later decryption", func() {