forked from quic-go/quic-go
send ConnectionCloses and PublicReset from the session run loop
This avoids several race conditions that can occur when trying to pack packets and to write to the underlying connection concurrently.
This commit is contained in:
95
session.go
95
session.go
@@ -39,6 +39,11 @@ var (
|
||||
// Once the callback has been called with isForwardSecure = true, it is guarantueed to not be called with isForwardSecure = false after that
|
||||
type cryptoChangeCallback func(session Session, isForwardSecure bool)
|
||||
|
||||
type closeError struct {
|
||||
err error
|
||||
remote bool
|
||||
}
|
||||
|
||||
// A Session is a QUIC session
|
||||
type session struct {
|
||||
connectionID protocol.ConnectionID
|
||||
@@ -67,7 +72,7 @@ type session struct {
|
||||
receivedPackets chan *receivedPacket
|
||||
sendingScheduled chan struct{}
|
||||
// closeChan is used to notify the run loop that it should terminate.
|
||||
closeChan chan error
|
||||
closeChan chan closeError
|
||||
runClosed chan struct{}
|
||||
closed uint32 // atomic bool
|
||||
|
||||
@@ -178,7 +183,7 @@ func (s *session) setup() {
|
||||
s.receivedPacketHandler = ackhandler.NewReceivedPacketHandler(s.ackAlarmChanged)
|
||||
|
||||
s.receivedPackets = make(chan *receivedPacket, protocol.MaxSessionUnprocessedPackets)
|
||||
s.closeChan = make(chan error, 1)
|
||||
s.closeChan = make(chan closeError, 1)
|
||||
s.sendingScheduled = make(chan struct{}, 1)
|
||||
s.undecryptablePackets = make([]*receivedPacket, 0, protocol.MaxUndecryptablePackets)
|
||||
s.aeadChanged = make(chan protocol.EncryptionLevel, 2)
|
||||
@@ -201,7 +206,7 @@ func (s *session) run() error {
|
||||
}
|
||||
}()
|
||||
|
||||
var closeErr error
|
||||
var closeErr closeError
|
||||
aeadChanged := s.aeadChanged
|
||||
|
||||
runLoop:
|
||||
@@ -215,7 +220,6 @@ runLoop:
|
||||
|
||||
s.maybeResetTimer()
|
||||
|
||||
var err error
|
||||
select {
|
||||
case closeErr = <-s.closeChan:
|
||||
break runLoop
|
||||
@@ -227,10 +231,14 @@ runLoop:
|
||||
// We do all the interesting stuff after the switch statement, so
|
||||
// nothing to see here.
|
||||
case p := <-s.receivedPackets:
|
||||
err = s.handlePacketImpl(p)
|
||||
if qErr, ok := err.(*qerr.QuicError); ok && qErr.ErrorCode == qerr.DecryptionFailure {
|
||||
s.tryQueueingUndecryptablePacket(p)
|
||||
continue
|
||||
err := s.handlePacketImpl(p)
|
||||
if err != nil {
|
||||
if qErr, ok := err.(*qerr.QuicError); ok && qErr.ErrorCode == qerr.DecryptionFailure {
|
||||
s.tryQueueingUndecryptablePacket(p)
|
||||
continue
|
||||
}
|
||||
s.close(err)
|
||||
break runLoop
|
||||
}
|
||||
// This is a bit unclean, but works properly, since the packet always
|
||||
// begins with the public header and we never copy it.
|
||||
@@ -248,10 +256,6 @@ runLoop:
|
||||
}
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
s.close(err)
|
||||
}
|
||||
|
||||
now := time.Now()
|
||||
if s.sentPacketHandler.GetAlarmTimeout().Before(now) {
|
||||
// This could cause packets to be retransmitted, so check it before trying
|
||||
@@ -274,8 +278,9 @@ runLoop:
|
||||
s.garbageCollectStreams()
|
||||
}
|
||||
|
||||
s.handleCloseError(closeErr)
|
||||
close(s.runClosed)
|
||||
return closeErr
|
||||
return closeErr.err
|
||||
}
|
||||
|
||||
func (s *session) maybeResetTimer() {
|
||||
@@ -396,7 +401,7 @@ func (s *session) handleFrames(fs []frames.Frame) error {
|
||||
case *frames.AckFrame:
|
||||
err = s.handleAckFrame(frame)
|
||||
case *frames.ConnectionCloseFrame:
|
||||
s.closeImpl(qerr.Error(frame.ErrorCode, frame.ReasonPhrase), true)
|
||||
s.registerClose(qerr.Error(frame.ErrorCode, frame.ReasonPhrase), true)
|
||||
case *frames.GoawayFrame:
|
||||
err = errors.New("unimplemented: handling GOAWAY frames")
|
||||
case *frames.StopWaitingFrame:
|
||||
@@ -482,10 +487,29 @@ func (s *session) handleAckFrame(frame *frames.AckFrame) error {
|
||||
return s.sentPacketHandler.ReceivedAck(frame, s.lastRcvdPacketNumber, s.lastNetworkActivityTime)
|
||||
}
|
||||
|
||||
func (s *session) registerClose(e error, remoteClose bool) error {
|
||||
// Only close once
|
||||
if !atomic.CompareAndSwapUint32(&s.closed, 0, 1) {
|
||||
return errSessionAlreadyClosed
|
||||
}
|
||||
|
||||
if e == nil {
|
||||
e = qerr.PeerGoingAway
|
||||
}
|
||||
|
||||
if e == errCloseSessionForNewVersion {
|
||||
s.streamsMap.CloseWithError(e)
|
||||
s.closeStreamsWithError(e)
|
||||
}
|
||||
|
||||
s.closeChan <- closeError{err: e, remote: remoteClose}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Close the connection. If err is nil it will be set to qerr.PeerGoingAway.
|
||||
// It waits until the run loop has stopped before returning
|
||||
func (s *session) Close(e error) error {
|
||||
err := s.closeImpl(e, false)
|
||||
err := s.registerClose(e, false)
|
||||
if err == errSessionAlreadyClosed {
|
||||
return nil
|
||||
}
|
||||
@@ -497,55 +521,42 @@ func (s *session) Close(e error) error {
|
||||
|
||||
// close the connection. Use this when called from the run loop
|
||||
func (s *session) close(e error) error {
|
||||
err := s.closeImpl(e, false)
|
||||
err := s.registerClose(e, false)
|
||||
if err == errSessionAlreadyClosed {
|
||||
return nil
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
func (s *session) closeImpl(e error, remoteClose bool) error {
|
||||
// Only close once
|
||||
if !atomic.CompareAndSwapUint32(&s.closed, 0, 1) {
|
||||
return errSessionAlreadyClosed
|
||||
func (s *session) handleCloseError(closeErr closeError) error {
|
||||
var quicErr *qerr.QuicError
|
||||
var ok bool
|
||||
if quicErr, ok = closeErr.err.(*qerr.QuicError); !ok {
|
||||
quicErr = qerr.ToQuicError(closeErr.err)
|
||||
}
|
||||
|
||||
if e == nil {
|
||||
e = qerr.PeerGoingAway
|
||||
}
|
||||
|
||||
defer func() {
|
||||
s.closeChan <- e
|
||||
}()
|
||||
|
||||
if e == errCloseSessionForNewVersion {
|
||||
s.streamsMap.CloseWithError(e)
|
||||
s.closeStreamsWithError(e)
|
||||
return nil
|
||||
}
|
||||
|
||||
quicErr := qerr.ToQuicError(e)
|
||||
|
||||
// Don't log 'normal' reasons
|
||||
if quicErr.ErrorCode == qerr.PeerGoingAway || quicErr.ErrorCode == qerr.NetworkIdleTimeout {
|
||||
utils.Infof("Closing connection %x", s.connectionID)
|
||||
} else {
|
||||
utils.Errorf("Closing session with error: %s", e.Error())
|
||||
utils.Errorf("Closing session with error: %s", closeErr.err.Error())
|
||||
}
|
||||
|
||||
if closeErr.err == errCloseSessionForNewVersion {
|
||||
return nil
|
||||
}
|
||||
|
||||
s.streamsMap.CloseWithError(quicErr)
|
||||
s.closeStreamsWithError(quicErr)
|
||||
|
||||
if remoteClose {
|
||||
// If this is a remote close we're done here
|
||||
// If this is a remote close we're done here
|
||||
if closeErr.remote {
|
||||
return nil
|
||||
}
|
||||
|
||||
if quicErr.ErrorCode == qerr.DecryptionFailure || quicErr == handshake.ErrHOLExperiment {
|
||||
return s.sendPublicReset(s.lastRcvdPacketNumber)
|
||||
}
|
||||
s.sendConnectionClose(quicErr)
|
||||
return nil
|
||||
return s.sendConnectionClose(quicErr)
|
||||
}
|
||||
|
||||
func (s *session) closeStreamsWithError(err error) {
|
||||
|
||||
@@ -582,12 +582,15 @@ var _ = Describe("Session", func() {
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
})
|
||||
|
||||
It("handles CONNECTION_CLOSE frames", func() {
|
||||
It("handles CONNECTION_CLOSE frames", func(done Done) {
|
||||
go sess.run()
|
||||
str, _ := sess.GetOrOpenStream(5)
|
||||
err := sess.handleFrames([]frames.Frame{&frames.ConnectionCloseFrame{ErrorCode: 42, ReasonPhrase: "foobar"}})
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
Eventually(sess.runClosed).Should(BeClosed())
|
||||
_, err = str.Read([]byte{0})
|
||||
Expect(err).To(MatchError(qerr.Error(42, "foobar")))
|
||||
close(done)
|
||||
})
|
||||
|
||||
Context("accepting streams", func() {
|
||||
@@ -621,16 +624,17 @@ var _ = Describe("Session", func() {
|
||||
})
|
||||
|
||||
It("stops accepting when the session is closed after version negotiation", func() {
|
||||
testErr := errCloseSessionForNewVersion
|
||||
var err error
|
||||
go func() {
|
||||
_, err = sess.AcceptStream()
|
||||
}()
|
||||
go sess.run()
|
||||
Consistently(func() error { return err }).ShouldNot(HaveOccurred())
|
||||
sess.Close(testErr)
|
||||
Expect(sess.runClosed).ToNot(BeClosed())
|
||||
sess.Close(errCloseSessionForNewVersion)
|
||||
Eventually(func() error { return err }).Should(HaveOccurred())
|
||||
Expect(err).To(MatchError(testErr))
|
||||
Expect(err).To(MatchError(errCloseSessionForNewVersion))
|
||||
Eventually(sess.runClosed).Should(BeClosed())
|
||||
})
|
||||
})
|
||||
|
||||
@@ -1249,8 +1253,7 @@ var _ = Describe("Session", func() {
|
||||
go sess.run()
|
||||
sendUndecryptablePackets()
|
||||
Consistently(sess.undecryptablePackets).Should(BeEmpty())
|
||||
sess.closeImpl(nil, true)
|
||||
Eventually(sess.runClosed).Should(BeClosed())
|
||||
Expect(sess.Close(nil)).To(Succeed())
|
||||
})
|
||||
|
||||
It("unqueues undecryptable packets for later decryption", func() {
|
||||
|
||||
Reference in New Issue
Block a user