forked from quic-go/quic-go
Simplify session closing
This commit is contained in:
73
session.go
73
session.go
@@ -4,7 +4,7 @@ import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"sync/atomic"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/lucas-clemente/quic-go/ackhandler"
|
||||
@@ -78,7 +78,7 @@ type session struct {
|
||||
// closeChan is used to notify the run loop that it should terminate.
|
||||
closeChan chan closeError
|
||||
runClosed chan struct{}
|
||||
closed uint32 // atomic bool
|
||||
closeOnce sync.Once
|
||||
|
||||
// when we receive too many undecryptable packets during the handshake, we send a Public reset
|
||||
// but only after a time of protocol.PublicResetTimeout has passed
|
||||
@@ -290,7 +290,7 @@ runLoop:
|
||||
s.tryQueueingUndecryptablePacket(p)
|
||||
continue
|
||||
}
|
||||
s.close(err)
|
||||
s.closeLocal(err)
|
||||
continue
|
||||
}
|
||||
// This is a bit unclean, but works properly, since the packet always
|
||||
@@ -319,16 +319,16 @@ runLoop:
|
||||
}
|
||||
|
||||
if err := s.sendPacket(); err != nil {
|
||||
s.close(err)
|
||||
s.closeLocal(err)
|
||||
}
|
||||
if !s.receivedTooManyUndecrytablePacketsTime.IsZero() && s.receivedTooManyUndecrytablePacketsTime.Add(protocol.PublicResetTimeout).Before(now) && len(s.undecryptablePackets) != 0 {
|
||||
s.close(qerr.Error(qerr.DecryptionFailure, "too many undecryptable packets received"))
|
||||
s.closeLocal(qerr.Error(qerr.DecryptionFailure, "too many undecryptable packets received"))
|
||||
}
|
||||
if now.Sub(s.lastNetworkActivityTime) >= s.idleTimeout() {
|
||||
s.close(qerr.Error(qerr.NetworkIdleTimeout, "No recent network activity."))
|
||||
s.closeLocal(qerr.Error(qerr.NetworkIdleTimeout, "No recent network activity."))
|
||||
}
|
||||
if !s.handshakeComplete && now.Sub(s.sessionCreationTime) >= s.config.HandshakeTimeout {
|
||||
s.close(qerr.Error(qerr.HandshakeTimeout, "Crypto handshake did not complete in time."))
|
||||
s.closeLocal(qerr.Error(qerr.HandshakeTimeout, "Crypto handshake did not complete in time."))
|
||||
}
|
||||
s.garbageCollectStreams()
|
||||
}
|
||||
@@ -462,7 +462,7 @@ func (s *session) handleFrames(fs []frames.Frame) error {
|
||||
case *frames.AckFrame:
|
||||
err = s.handleAckFrame(frame)
|
||||
case *frames.ConnectionCloseFrame:
|
||||
s.registerClose(qerr.Error(frame.ErrorCode, frame.ReasonPhrase), true)
|
||||
s.close(qerr.Error(frame.ErrorCode, frame.ReasonPhrase), true)
|
||||
case *frames.GoawayFrame:
|
||||
err = errors.New("unimplemented: handling GOAWAY frames")
|
||||
case *frames.StopWaitingFrame:
|
||||
@@ -548,48 +548,29 @@ func (s *session) handleAckFrame(frame *frames.AckFrame) error {
|
||||
return s.sentPacketHandler.ReceivedAck(frame, s.lastRcvdPacketNumber, s.lastNetworkActivityTime)
|
||||
}
|
||||
|
||||
func (s *session) registerClose(e error, remoteClose bool) error {
|
||||
// Only close once
|
||||
if !atomic.CompareAndSwapUint32(&s.closed, 0, 1) {
|
||||
return errSessionAlreadyClosed
|
||||
}
|
||||
func (s *session) close(e error, remoteClose bool) {
|
||||
s.closeOnce.Do(func() {
|
||||
s.closeChan <- closeError{err: e, remote: remoteClose}
|
||||
})
|
||||
}
|
||||
|
||||
if e == nil {
|
||||
e = qerr.PeerGoingAway
|
||||
}
|
||||
|
||||
if e == errCloseSessionForNewVersion {
|
||||
s.streamsMap.CloseWithError(e)
|
||||
s.closeStreamsWithError(e)
|
||||
}
|
||||
|
||||
s.closeChan <- closeError{err: e, remote: remoteClose}
|
||||
return nil
|
||||
func (s *session) closeLocal(e error) {
|
||||
s.close(e, false)
|
||||
}
|
||||
|
||||
// Close the connection. If err is nil it will be set to qerr.PeerGoingAway.
|
||||
// It waits until the run loop has stopped before returning
|
||||
func (s *session) Close(e error) error {
|
||||
err := s.registerClose(e, false)
|
||||
if err == errSessionAlreadyClosed {
|
||||
return nil
|
||||
}
|
||||
|
||||
// wait for the run loop to finish
|
||||
s.close(e, false)
|
||||
<-s.runClosed
|
||||
return err
|
||||
}
|
||||
|
||||
// close the connection. Use this when called from the run loop
|
||||
func (s *session) close(e error) error {
|
||||
err := s.registerClose(e, false)
|
||||
if err == errSessionAlreadyClosed {
|
||||
return nil
|
||||
}
|
||||
return err
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *session) handleCloseError(closeErr closeError) error {
|
||||
if closeErr.err == nil {
|
||||
closeErr.err = qerr.PeerGoingAway
|
||||
}
|
||||
|
||||
var quicErr *qerr.QuicError
|
||||
var ok bool
|
||||
if quicErr, ok = closeErr.err.(*qerr.QuicError); !ok {
|
||||
@@ -602,13 +583,12 @@ func (s *session) handleCloseError(closeErr closeError) error {
|
||||
utils.Errorf("Closing session with error: %s", closeErr.err.Error())
|
||||
}
|
||||
|
||||
s.streamsMap.CloseWithError(quicErr)
|
||||
|
||||
if closeErr.err == errCloseSessionForNewVersion {
|
||||
return nil
|
||||
}
|
||||
|
||||
s.streamsMap.CloseWithError(quicErr)
|
||||
s.closeStreamsWithError(quicErr)
|
||||
|
||||
// If this is a remote close we're done here
|
||||
if closeErr.remote {
|
||||
return nil
|
||||
@@ -620,13 +600,6 @@ func (s *session) handleCloseError(closeErr closeError) error {
|
||||
return s.sendConnectionClose(quicErr)
|
||||
}
|
||||
|
||||
func (s *session) closeStreamsWithError(err error) {
|
||||
s.streamsMap.Iterate(func(str *stream) (bool, error) {
|
||||
str.Cancel(err)
|
||||
return true, nil
|
||||
})
|
||||
}
|
||||
|
||||
func (s *session) sendPacket() error {
|
||||
// Repeatedly try sending until we don't have any more data, or run out of the congestion window
|
||||
for {
|
||||
|
||||
@@ -8,7 +8,6 @@ import (
|
||||
"net"
|
||||
"runtime/pprof"
|
||||
"strings"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
. "github.com/onsi/ginkgo"
|
||||
@@ -357,9 +356,9 @@ var _ = Describe("Session", func() {
|
||||
p := make([]byte, 4)
|
||||
_, err = str.Read(p)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
sess.closeStreamsWithError(testErr)
|
||||
sess.handleCloseError(closeError{err: testErr, remote: true})
|
||||
_, err = str.Read(p)
|
||||
Expect(err).To(MatchError(testErr))
|
||||
Expect(err).To(MatchError(qerr.Error(qerr.InternalError, testErr.Error())))
|
||||
sess.garbageCollectStreams()
|
||||
str, err = sess.streamsMap.GetOrOpenStream(5)
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
@@ -372,9 +371,9 @@ var _ = Describe("Session", func() {
|
||||
str, err := sess.streamsMap.GetOrOpenStream(5)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(str).ToNot(BeNil())
|
||||
sess.closeStreamsWithError(testErr)
|
||||
sess.handleCloseError(closeError{err: testErr, remote: true})
|
||||
_, err = str.Read([]byte{0})
|
||||
Expect(err).To(MatchError(testErr))
|
||||
Expect(err).To(MatchError(qerr.Error(qerr.InternalError, testErr.Error())))
|
||||
sess.garbageCollectStreams()
|
||||
str, err = sess.streamsMap.GetOrOpenStream(5)
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
@@ -714,7 +713,7 @@ var _ = Describe("Session", func() {
|
||||
Expect(sess.runClosed).ToNot(BeClosed())
|
||||
sess.Close(errCloseSessionForNewVersion)
|
||||
Eventually(func() error { return err }).Should(HaveOccurred())
|
||||
Expect(err).To(MatchError(errCloseSessionForNewVersion))
|
||||
Expect(err).To(MatchError(qerr.Error(qerr.InternalError, errCloseSessionForNewVersion.Error())))
|
||||
Eventually(sess.runClosed).Should(BeClosed())
|
||||
})
|
||||
})
|
||||
@@ -760,7 +759,6 @@ var _ = Describe("Session", func() {
|
||||
It("closes the session in order to replace it with another QUIC version", func() {
|
||||
sess.Close(errCloseSessionForNewVersion)
|
||||
Eventually(areSessionsRunning).Should(BeFalse())
|
||||
Expect(atomic.LoadUint32(&sess.closed) != 0).To(BeTrue())
|
||||
Expect(mconn.written).To(BeEmpty()) // no CONNECTION_CLOSE or PUBLIC_RESET sent
|
||||
})
|
||||
|
||||
|
||||
@@ -319,8 +319,11 @@ func (m *streamsMap) RemoveStream(id protocol.StreamID) error {
|
||||
|
||||
func (m *streamsMap) CloseWithError(err error) {
|
||||
m.mutex.Lock()
|
||||
defer m.mutex.Unlock()
|
||||
m.closeErr = err
|
||||
m.nextStreamOrErrCond.Broadcast()
|
||||
m.openStreamOrErrCond.Broadcast()
|
||||
m.mutex.Unlock()
|
||||
for _, s := range m.openStreams {
|
||||
m.streams[s].Cancel(err)
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user