send ConnectionCloses and PublicReset from the session run loop

This avoids several race conditions that can occur when trying to pack
packets and to write to the underlying connection concurrently.
This commit is contained in:
Marten Seemann
2017-05-10 14:22:30 +08:00
parent 75308ab56e
commit 69d5f3c262
2 changed files with 62 additions and 48 deletions

View File

@@ -39,6 +39,11 @@ var (
// Once the callback has been called with isForwardSecure = true, it is guarantueed to not be called with isForwardSecure = false after that
type cryptoChangeCallback func(session Session, isForwardSecure bool)
type closeError struct {
err error
remote bool
}
// A Session is a QUIC session
type session struct {
connectionID protocol.ConnectionID
@@ -67,7 +72,7 @@ type session struct {
receivedPackets chan *receivedPacket
sendingScheduled chan struct{}
// closeChan is used to notify the run loop that it should terminate.
closeChan chan error
closeChan chan closeError
runClosed chan struct{}
closed uint32 // atomic bool
@@ -178,7 +183,7 @@ func (s *session) setup() {
s.receivedPacketHandler = ackhandler.NewReceivedPacketHandler(s.ackAlarmChanged)
s.receivedPackets = make(chan *receivedPacket, protocol.MaxSessionUnprocessedPackets)
s.closeChan = make(chan error, 1)
s.closeChan = make(chan closeError, 1)
s.sendingScheduled = make(chan struct{}, 1)
s.undecryptablePackets = make([]*receivedPacket, 0, protocol.MaxUndecryptablePackets)
s.aeadChanged = make(chan protocol.EncryptionLevel, 2)
@@ -201,7 +206,7 @@ func (s *session) run() error {
}
}()
var closeErr error
var closeErr closeError
aeadChanged := s.aeadChanged
runLoop:
@@ -215,7 +220,6 @@ runLoop:
s.maybeResetTimer()
var err error
select {
case closeErr = <-s.closeChan:
break runLoop
@@ -227,10 +231,14 @@ runLoop:
// We do all the interesting stuff after the switch statement, so
// nothing to see here.
case p := <-s.receivedPackets:
err = s.handlePacketImpl(p)
if qErr, ok := err.(*qerr.QuicError); ok && qErr.ErrorCode == qerr.DecryptionFailure {
s.tryQueueingUndecryptablePacket(p)
continue
err := s.handlePacketImpl(p)
if err != nil {
if qErr, ok := err.(*qerr.QuicError); ok && qErr.ErrorCode == qerr.DecryptionFailure {
s.tryQueueingUndecryptablePacket(p)
continue
}
s.close(err)
break runLoop
}
// This is a bit unclean, but works properly, since the packet always
// begins with the public header and we never copy it.
@@ -248,10 +256,6 @@ runLoop:
}
}
if err != nil {
s.close(err)
}
now := time.Now()
if s.sentPacketHandler.GetAlarmTimeout().Before(now) {
// This could cause packets to be retransmitted, so check it before trying
@@ -274,8 +278,9 @@ runLoop:
s.garbageCollectStreams()
}
s.handleCloseError(closeErr)
close(s.runClosed)
return closeErr
return closeErr.err
}
func (s *session) maybeResetTimer() {
@@ -396,7 +401,7 @@ func (s *session) handleFrames(fs []frames.Frame) error {
case *frames.AckFrame:
err = s.handleAckFrame(frame)
case *frames.ConnectionCloseFrame:
s.closeImpl(qerr.Error(frame.ErrorCode, frame.ReasonPhrase), true)
s.registerClose(qerr.Error(frame.ErrorCode, frame.ReasonPhrase), true)
case *frames.GoawayFrame:
err = errors.New("unimplemented: handling GOAWAY frames")
case *frames.StopWaitingFrame:
@@ -482,10 +487,29 @@ func (s *session) handleAckFrame(frame *frames.AckFrame) error {
return s.sentPacketHandler.ReceivedAck(frame, s.lastRcvdPacketNumber, s.lastNetworkActivityTime)
}
func (s *session) registerClose(e error, remoteClose bool) error {
// Only close once
if !atomic.CompareAndSwapUint32(&s.closed, 0, 1) {
return errSessionAlreadyClosed
}
if e == nil {
e = qerr.PeerGoingAway
}
if e == errCloseSessionForNewVersion {
s.streamsMap.CloseWithError(e)
s.closeStreamsWithError(e)
}
s.closeChan <- closeError{err: e, remote: remoteClose}
return nil
}
// Close the connection. If err is nil it will be set to qerr.PeerGoingAway.
// It waits until the run loop has stopped before returning
func (s *session) Close(e error) error {
err := s.closeImpl(e, false)
err := s.registerClose(e, false)
if err == errSessionAlreadyClosed {
return nil
}
@@ -497,55 +521,42 @@ func (s *session) Close(e error) error {
// close the connection. Use this when called from the run loop
func (s *session) close(e error) error {
err := s.closeImpl(e, false)
err := s.registerClose(e, false)
if err == errSessionAlreadyClosed {
return nil
}
return err
}
func (s *session) closeImpl(e error, remoteClose bool) error {
// Only close once
if !atomic.CompareAndSwapUint32(&s.closed, 0, 1) {
return errSessionAlreadyClosed
func (s *session) handleCloseError(closeErr closeError) error {
var quicErr *qerr.QuicError
var ok bool
if quicErr, ok = closeErr.err.(*qerr.QuicError); !ok {
quicErr = qerr.ToQuicError(closeErr.err)
}
if e == nil {
e = qerr.PeerGoingAway
}
defer func() {
s.closeChan <- e
}()
if e == errCloseSessionForNewVersion {
s.streamsMap.CloseWithError(e)
s.closeStreamsWithError(e)
return nil
}
quicErr := qerr.ToQuicError(e)
// Don't log 'normal' reasons
if quicErr.ErrorCode == qerr.PeerGoingAway || quicErr.ErrorCode == qerr.NetworkIdleTimeout {
utils.Infof("Closing connection %x", s.connectionID)
} else {
utils.Errorf("Closing session with error: %s", e.Error())
utils.Errorf("Closing session with error: %s", closeErr.err.Error())
}
if closeErr.err == errCloseSessionForNewVersion {
return nil
}
s.streamsMap.CloseWithError(quicErr)
s.closeStreamsWithError(quicErr)
if remoteClose {
// If this is a remote close we're done here
// If this is a remote close we're done here
if closeErr.remote {
return nil
}
if quicErr.ErrorCode == qerr.DecryptionFailure || quicErr == handshake.ErrHOLExperiment {
return s.sendPublicReset(s.lastRcvdPacketNumber)
}
s.sendConnectionClose(quicErr)
return nil
return s.sendConnectionClose(quicErr)
}
func (s *session) closeStreamsWithError(err error) {

View File

@@ -582,12 +582,15 @@ var _ = Describe("Session", func() {
Expect(err).NotTo(HaveOccurred())
})
It("handles CONNECTION_CLOSE frames", func() {
It("handles CONNECTION_CLOSE frames", func(done Done) {
go sess.run()
str, _ := sess.GetOrOpenStream(5)
err := sess.handleFrames([]frames.Frame{&frames.ConnectionCloseFrame{ErrorCode: 42, ReasonPhrase: "foobar"}})
Expect(err).NotTo(HaveOccurred())
Eventually(sess.runClosed).Should(BeClosed())
_, err = str.Read([]byte{0})
Expect(err).To(MatchError(qerr.Error(42, "foobar")))
close(done)
})
Context("accepting streams", func() {
@@ -621,16 +624,17 @@ var _ = Describe("Session", func() {
})
It("stops accepting when the session is closed after version negotiation", func() {
testErr := errCloseSessionForNewVersion
var err error
go func() {
_, err = sess.AcceptStream()
}()
go sess.run()
Consistently(func() error { return err }).ShouldNot(HaveOccurred())
sess.Close(testErr)
Expect(sess.runClosed).ToNot(BeClosed())
sess.Close(errCloseSessionForNewVersion)
Eventually(func() error { return err }).Should(HaveOccurred())
Expect(err).To(MatchError(testErr))
Expect(err).To(MatchError(errCloseSessionForNewVersion))
Eventually(sess.runClosed).Should(BeClosed())
})
})
@@ -1249,8 +1253,7 @@ var _ = Describe("Session", func() {
go sess.run()
sendUndecryptablePackets()
Consistently(sess.undecryptablePackets).Should(BeEmpty())
sess.closeImpl(nil, true)
Eventually(sess.runClosed).Should(BeClosed())
Expect(sess.Close(nil)).To(Succeed())
})
It("unqueues undecryptable packets for later decryption", func() {