send ConnectionCloses and PublicReset from the session run loop

This avoids several race conditions that can occur when trying to pack
packets and to write to the underlying connection concurrently.
This commit is contained in:
Marten Seemann
2017-05-10 14:22:30 +08:00
parent 75308ab56e
commit 69d5f3c262
2 changed files with 62 additions and 48 deletions

View File

@@ -39,6 +39,11 @@ var (
// Once the callback has been called with isForwardSecure = true, it is guarantueed to not be called with isForwardSecure = false after that // Once the callback has been called with isForwardSecure = true, it is guarantueed to not be called with isForwardSecure = false after that
type cryptoChangeCallback func(session Session, isForwardSecure bool) type cryptoChangeCallback func(session Session, isForwardSecure bool)
type closeError struct {
err error
remote bool
}
// A Session is a QUIC session // A Session is a QUIC session
type session struct { type session struct {
connectionID protocol.ConnectionID connectionID protocol.ConnectionID
@@ -67,7 +72,7 @@ type session struct {
receivedPackets chan *receivedPacket receivedPackets chan *receivedPacket
sendingScheduled chan struct{} sendingScheduled chan struct{}
// closeChan is used to notify the run loop that it should terminate. // closeChan is used to notify the run loop that it should terminate.
closeChan chan error closeChan chan closeError
runClosed chan struct{} runClosed chan struct{}
closed uint32 // atomic bool closed uint32 // atomic bool
@@ -178,7 +183,7 @@ func (s *session) setup() {
s.receivedPacketHandler = ackhandler.NewReceivedPacketHandler(s.ackAlarmChanged) s.receivedPacketHandler = ackhandler.NewReceivedPacketHandler(s.ackAlarmChanged)
s.receivedPackets = make(chan *receivedPacket, protocol.MaxSessionUnprocessedPackets) s.receivedPackets = make(chan *receivedPacket, protocol.MaxSessionUnprocessedPackets)
s.closeChan = make(chan error, 1) s.closeChan = make(chan closeError, 1)
s.sendingScheduled = make(chan struct{}, 1) s.sendingScheduled = make(chan struct{}, 1)
s.undecryptablePackets = make([]*receivedPacket, 0, protocol.MaxUndecryptablePackets) s.undecryptablePackets = make([]*receivedPacket, 0, protocol.MaxUndecryptablePackets)
s.aeadChanged = make(chan protocol.EncryptionLevel, 2) s.aeadChanged = make(chan protocol.EncryptionLevel, 2)
@@ -201,7 +206,7 @@ func (s *session) run() error {
} }
}() }()
var closeErr error var closeErr closeError
aeadChanged := s.aeadChanged aeadChanged := s.aeadChanged
runLoop: runLoop:
@@ -215,7 +220,6 @@ runLoop:
s.maybeResetTimer() s.maybeResetTimer()
var err error
select { select {
case closeErr = <-s.closeChan: case closeErr = <-s.closeChan:
break runLoop break runLoop
@@ -227,10 +231,14 @@ runLoop:
// We do all the interesting stuff after the switch statement, so // We do all the interesting stuff after the switch statement, so
// nothing to see here. // nothing to see here.
case p := <-s.receivedPackets: case p := <-s.receivedPackets:
err = s.handlePacketImpl(p) err := s.handlePacketImpl(p)
if qErr, ok := err.(*qerr.QuicError); ok && qErr.ErrorCode == qerr.DecryptionFailure { if err != nil {
s.tryQueueingUndecryptablePacket(p) if qErr, ok := err.(*qerr.QuicError); ok && qErr.ErrorCode == qerr.DecryptionFailure {
continue s.tryQueueingUndecryptablePacket(p)
continue
}
s.close(err)
break runLoop
} }
// This is a bit unclean, but works properly, since the packet always // This is a bit unclean, but works properly, since the packet always
// begins with the public header and we never copy it. // begins with the public header and we never copy it.
@@ -248,10 +256,6 @@ runLoop:
} }
} }
if err != nil {
s.close(err)
}
now := time.Now() now := time.Now()
if s.sentPacketHandler.GetAlarmTimeout().Before(now) { if s.sentPacketHandler.GetAlarmTimeout().Before(now) {
// This could cause packets to be retransmitted, so check it before trying // This could cause packets to be retransmitted, so check it before trying
@@ -274,8 +278,9 @@ runLoop:
s.garbageCollectStreams() s.garbageCollectStreams()
} }
s.handleCloseError(closeErr)
close(s.runClosed) close(s.runClosed)
return closeErr return closeErr.err
} }
func (s *session) maybeResetTimer() { func (s *session) maybeResetTimer() {
@@ -396,7 +401,7 @@ func (s *session) handleFrames(fs []frames.Frame) error {
case *frames.AckFrame: case *frames.AckFrame:
err = s.handleAckFrame(frame) err = s.handleAckFrame(frame)
case *frames.ConnectionCloseFrame: case *frames.ConnectionCloseFrame:
s.closeImpl(qerr.Error(frame.ErrorCode, frame.ReasonPhrase), true) s.registerClose(qerr.Error(frame.ErrorCode, frame.ReasonPhrase), true)
case *frames.GoawayFrame: case *frames.GoawayFrame:
err = errors.New("unimplemented: handling GOAWAY frames") err = errors.New("unimplemented: handling GOAWAY frames")
case *frames.StopWaitingFrame: case *frames.StopWaitingFrame:
@@ -482,10 +487,29 @@ func (s *session) handleAckFrame(frame *frames.AckFrame) error {
return s.sentPacketHandler.ReceivedAck(frame, s.lastRcvdPacketNumber, s.lastNetworkActivityTime) return s.sentPacketHandler.ReceivedAck(frame, s.lastRcvdPacketNumber, s.lastNetworkActivityTime)
} }
func (s *session) registerClose(e error, remoteClose bool) error {
// Only close once
if !atomic.CompareAndSwapUint32(&s.closed, 0, 1) {
return errSessionAlreadyClosed
}
if e == nil {
e = qerr.PeerGoingAway
}
if e == errCloseSessionForNewVersion {
s.streamsMap.CloseWithError(e)
s.closeStreamsWithError(e)
}
s.closeChan <- closeError{err: e, remote: remoteClose}
return nil
}
// Close the connection. If err is nil it will be set to qerr.PeerGoingAway. // Close the connection. If err is nil it will be set to qerr.PeerGoingAway.
// It waits until the run loop has stopped before returning // It waits until the run loop has stopped before returning
func (s *session) Close(e error) error { func (s *session) Close(e error) error {
err := s.closeImpl(e, false) err := s.registerClose(e, false)
if err == errSessionAlreadyClosed { if err == errSessionAlreadyClosed {
return nil return nil
} }
@@ -497,55 +521,42 @@ func (s *session) Close(e error) error {
// close the connection. Use this when called from the run loop // close the connection. Use this when called from the run loop
func (s *session) close(e error) error { func (s *session) close(e error) error {
err := s.closeImpl(e, false) err := s.registerClose(e, false)
if err == errSessionAlreadyClosed { if err == errSessionAlreadyClosed {
return nil return nil
} }
return err return err
} }
func (s *session) closeImpl(e error, remoteClose bool) error { func (s *session) handleCloseError(closeErr closeError) error {
// Only close once var quicErr *qerr.QuicError
if !atomic.CompareAndSwapUint32(&s.closed, 0, 1) { var ok bool
return errSessionAlreadyClosed if quicErr, ok = closeErr.err.(*qerr.QuicError); !ok {
quicErr = qerr.ToQuicError(closeErr.err)
} }
if e == nil {
e = qerr.PeerGoingAway
}
defer func() {
s.closeChan <- e
}()
if e == errCloseSessionForNewVersion {
s.streamsMap.CloseWithError(e)
s.closeStreamsWithError(e)
return nil
}
quicErr := qerr.ToQuicError(e)
// Don't log 'normal' reasons // Don't log 'normal' reasons
if quicErr.ErrorCode == qerr.PeerGoingAway || quicErr.ErrorCode == qerr.NetworkIdleTimeout { if quicErr.ErrorCode == qerr.PeerGoingAway || quicErr.ErrorCode == qerr.NetworkIdleTimeout {
utils.Infof("Closing connection %x", s.connectionID) utils.Infof("Closing connection %x", s.connectionID)
} else { } else {
utils.Errorf("Closing session with error: %s", e.Error()) utils.Errorf("Closing session with error: %s", closeErr.err.Error())
}
if closeErr.err == errCloseSessionForNewVersion {
return nil
} }
s.streamsMap.CloseWithError(quicErr) s.streamsMap.CloseWithError(quicErr)
s.closeStreamsWithError(quicErr) s.closeStreamsWithError(quicErr)
if remoteClose { // If this is a remote close we're done here
// If this is a remote close we're done here if closeErr.remote {
return nil return nil
} }
if quicErr.ErrorCode == qerr.DecryptionFailure || quicErr == handshake.ErrHOLExperiment { if quicErr.ErrorCode == qerr.DecryptionFailure || quicErr == handshake.ErrHOLExperiment {
return s.sendPublicReset(s.lastRcvdPacketNumber) return s.sendPublicReset(s.lastRcvdPacketNumber)
} }
s.sendConnectionClose(quicErr) return s.sendConnectionClose(quicErr)
return nil
} }
func (s *session) closeStreamsWithError(err error) { func (s *session) closeStreamsWithError(err error) {

View File

@@ -582,12 +582,15 @@ var _ = Describe("Session", func() {
Expect(err).NotTo(HaveOccurred()) Expect(err).NotTo(HaveOccurred())
}) })
It("handles CONNECTION_CLOSE frames", func() { It("handles CONNECTION_CLOSE frames", func(done Done) {
go sess.run()
str, _ := sess.GetOrOpenStream(5) str, _ := sess.GetOrOpenStream(5)
err := sess.handleFrames([]frames.Frame{&frames.ConnectionCloseFrame{ErrorCode: 42, ReasonPhrase: "foobar"}}) err := sess.handleFrames([]frames.Frame{&frames.ConnectionCloseFrame{ErrorCode: 42, ReasonPhrase: "foobar"}})
Expect(err).NotTo(HaveOccurred()) Expect(err).NotTo(HaveOccurred())
Eventually(sess.runClosed).Should(BeClosed())
_, err = str.Read([]byte{0}) _, err = str.Read([]byte{0})
Expect(err).To(MatchError(qerr.Error(42, "foobar"))) Expect(err).To(MatchError(qerr.Error(42, "foobar")))
close(done)
}) })
Context("accepting streams", func() { Context("accepting streams", func() {
@@ -621,16 +624,17 @@ var _ = Describe("Session", func() {
}) })
It("stops accepting when the session is closed after version negotiation", func() { It("stops accepting when the session is closed after version negotiation", func() {
testErr := errCloseSessionForNewVersion
var err error var err error
go func() { go func() {
_, err = sess.AcceptStream() _, err = sess.AcceptStream()
}() }()
go sess.run() go sess.run()
Consistently(func() error { return err }).ShouldNot(HaveOccurred()) Consistently(func() error { return err }).ShouldNot(HaveOccurred())
sess.Close(testErr) Expect(sess.runClosed).ToNot(BeClosed())
sess.Close(errCloseSessionForNewVersion)
Eventually(func() error { return err }).Should(HaveOccurred()) Eventually(func() error { return err }).Should(HaveOccurred())
Expect(err).To(MatchError(testErr)) Expect(err).To(MatchError(errCloseSessionForNewVersion))
Eventually(sess.runClosed).Should(BeClosed())
}) })
}) })
@@ -1249,8 +1253,7 @@ var _ = Describe("Session", func() {
go sess.run() go sess.run()
sendUndecryptablePackets() sendUndecryptablePackets()
Consistently(sess.undecryptablePackets).Should(BeEmpty()) Consistently(sess.undecryptablePackets).Should(BeEmpty())
sess.closeImpl(nil, true) Expect(sess.Close(nil)).To(Succeed())
Eventually(sess.runClosed).Should(BeClosed())
}) })
It("unqueues undecryptable packets for later decryption", func() { It("unqueues undecryptable packets for later decryption", func() {