Simplify session closing

This commit is contained in:
Lucas Clemente
2017-06-07 09:37:43 +02:00
parent 6be03b54d2
commit f2959aa74a
3 changed files with 32 additions and 58 deletions

View File

@@ -4,7 +4,7 @@ import (
"errors"
"fmt"
"net"
"sync/atomic"
"sync"
"time"
"github.com/lucas-clemente/quic-go/ackhandler"
@@ -78,7 +78,7 @@ type session struct {
// closeChan is used to notify the run loop that it should terminate.
closeChan chan closeError
runClosed chan struct{}
closed uint32 // atomic bool
closeOnce sync.Once
// when we receive too many undecryptable packets during the handshake, we send a Public reset
// but only after a time of protocol.PublicResetTimeout has passed
@@ -290,7 +290,7 @@ runLoop:
s.tryQueueingUndecryptablePacket(p)
continue
}
s.close(err)
s.closeLocal(err)
continue
}
// This is a bit unclean, but works properly, since the packet always
@@ -319,16 +319,16 @@ runLoop:
}
if err := s.sendPacket(); err != nil {
s.close(err)
s.closeLocal(err)
}
if !s.receivedTooManyUndecrytablePacketsTime.IsZero() && s.receivedTooManyUndecrytablePacketsTime.Add(protocol.PublicResetTimeout).Before(now) && len(s.undecryptablePackets) != 0 {
s.close(qerr.Error(qerr.DecryptionFailure, "too many undecryptable packets received"))
s.closeLocal(qerr.Error(qerr.DecryptionFailure, "too many undecryptable packets received"))
}
if now.Sub(s.lastNetworkActivityTime) >= s.idleTimeout() {
s.close(qerr.Error(qerr.NetworkIdleTimeout, "No recent network activity."))
s.closeLocal(qerr.Error(qerr.NetworkIdleTimeout, "No recent network activity."))
}
if !s.handshakeComplete && now.Sub(s.sessionCreationTime) >= s.config.HandshakeTimeout {
s.close(qerr.Error(qerr.HandshakeTimeout, "Crypto handshake did not complete in time."))
s.closeLocal(qerr.Error(qerr.HandshakeTimeout, "Crypto handshake did not complete in time."))
}
s.garbageCollectStreams()
}
@@ -462,7 +462,7 @@ func (s *session) handleFrames(fs []frames.Frame) error {
case *frames.AckFrame:
err = s.handleAckFrame(frame)
case *frames.ConnectionCloseFrame:
s.registerClose(qerr.Error(frame.ErrorCode, frame.ReasonPhrase), true)
s.close(qerr.Error(frame.ErrorCode, frame.ReasonPhrase), true)
case *frames.GoawayFrame:
err = errors.New("unimplemented: handling GOAWAY frames")
case *frames.StopWaitingFrame:
@@ -548,48 +548,29 @@ func (s *session) handleAckFrame(frame *frames.AckFrame) error {
return s.sentPacketHandler.ReceivedAck(frame, s.lastRcvdPacketNumber, s.lastNetworkActivityTime)
}
func (s *session) registerClose(e error, remoteClose bool) error {
// Only close once
if !atomic.CompareAndSwapUint32(&s.closed, 0, 1) {
return errSessionAlreadyClosed
}
func (s *session) close(e error, remoteClose bool) {
s.closeOnce.Do(func() {
s.closeChan <- closeError{err: e, remote: remoteClose}
})
}
if e == nil {
e = qerr.PeerGoingAway
}
if e == errCloseSessionForNewVersion {
s.streamsMap.CloseWithError(e)
s.closeStreamsWithError(e)
}
s.closeChan <- closeError{err: e, remote: remoteClose}
return nil
func (s *session) closeLocal(e error) {
s.close(e, false)
}
// Close the connection. If err is nil it will be set to qerr.PeerGoingAway.
// It waits until the run loop has stopped before returning
func (s *session) Close(e error) error {
err := s.registerClose(e, false)
if err == errSessionAlreadyClosed {
return nil
}
// wait for the run loop to finish
s.close(e, false)
<-s.runClosed
return err
}
// close the connection. Use this when called from the run loop
func (s *session) close(e error) error {
err := s.registerClose(e, false)
if err == errSessionAlreadyClosed {
return nil
}
return err
return nil
}
func (s *session) handleCloseError(closeErr closeError) error {
if closeErr.err == nil {
closeErr.err = qerr.PeerGoingAway
}
var quicErr *qerr.QuicError
var ok bool
if quicErr, ok = closeErr.err.(*qerr.QuicError); !ok {
@@ -602,13 +583,12 @@ func (s *session) handleCloseError(closeErr closeError) error {
utils.Errorf("Closing session with error: %s", closeErr.err.Error())
}
s.streamsMap.CloseWithError(quicErr)
if closeErr.err == errCloseSessionForNewVersion {
return nil
}
s.streamsMap.CloseWithError(quicErr)
s.closeStreamsWithError(quicErr)
// If this is a remote close we're done here
if closeErr.remote {
return nil
@@ -620,13 +600,6 @@ func (s *session) handleCloseError(closeErr closeError) error {
return s.sendConnectionClose(quicErr)
}
func (s *session) closeStreamsWithError(err error) {
s.streamsMap.Iterate(func(str *stream) (bool, error) {
str.Cancel(err)
return true, nil
})
}
func (s *session) sendPacket() error {
// Repeatedly try sending until we don't have any more data, or run out of the congestion window
for {

View File

@@ -8,7 +8,6 @@ import (
"net"
"runtime/pprof"
"strings"
"sync/atomic"
"time"
. "github.com/onsi/ginkgo"
@@ -357,9 +356,9 @@ var _ = Describe("Session", func() {
p := make([]byte, 4)
_, err = str.Read(p)
Expect(err).ToNot(HaveOccurred())
sess.closeStreamsWithError(testErr)
sess.handleCloseError(closeError{err: testErr, remote: true})
_, err = str.Read(p)
Expect(err).To(MatchError(testErr))
Expect(err).To(MatchError(qerr.Error(qerr.InternalError, testErr.Error())))
sess.garbageCollectStreams()
str, err = sess.streamsMap.GetOrOpenStream(5)
Expect(err).NotTo(HaveOccurred())
@@ -372,9 +371,9 @@ var _ = Describe("Session", func() {
str, err := sess.streamsMap.GetOrOpenStream(5)
Expect(err).ToNot(HaveOccurred())
Expect(str).ToNot(BeNil())
sess.closeStreamsWithError(testErr)
sess.handleCloseError(closeError{err: testErr, remote: true})
_, err = str.Read([]byte{0})
Expect(err).To(MatchError(testErr))
Expect(err).To(MatchError(qerr.Error(qerr.InternalError, testErr.Error())))
sess.garbageCollectStreams()
str, err = sess.streamsMap.GetOrOpenStream(5)
Expect(err).NotTo(HaveOccurred())
@@ -714,7 +713,7 @@ var _ = Describe("Session", func() {
Expect(sess.runClosed).ToNot(BeClosed())
sess.Close(errCloseSessionForNewVersion)
Eventually(func() error { return err }).Should(HaveOccurred())
Expect(err).To(MatchError(errCloseSessionForNewVersion))
Expect(err).To(MatchError(qerr.Error(qerr.InternalError, errCloseSessionForNewVersion.Error())))
Eventually(sess.runClosed).Should(BeClosed())
})
})
@@ -760,7 +759,6 @@ var _ = Describe("Session", func() {
It("closes the session in order to replace it with another QUIC version", func() {
sess.Close(errCloseSessionForNewVersion)
Eventually(areSessionsRunning).Should(BeFalse())
Expect(atomic.LoadUint32(&sess.closed) != 0).To(BeTrue())
Expect(mconn.written).To(BeEmpty()) // no CONNECTION_CLOSE or PUBLIC_RESET sent
})

View File

@@ -319,8 +319,11 @@ func (m *streamsMap) RemoveStream(id protocol.StreamID) error {
func (m *streamsMap) CloseWithError(err error) {
m.mutex.Lock()
defer m.mutex.Unlock()
m.closeErr = err
m.nextStreamOrErrCond.Broadcast()
m.openStreamOrErrCond.Broadcast()
m.mutex.Unlock()
for _, s := range m.openStreams {
m.streams[s].Cancel(err)
}
}