Merge pull request #608 from lucas-clemente/better-close

send ConnectionCloses and PublicReset from the session run loop
This commit is contained in:
Marten Seemann
2017-05-10 18:57:01 +08:00
committed by GitHub
2 changed files with 73 additions and 59 deletions

View File

@@ -39,6 +39,11 @@ var (
// Once the callback has been called with isForwardSecure = true, it is guarantueed to not be called with isForwardSecure = false after that
type cryptoChangeCallback func(session Session, isForwardSecure bool)
type closeError struct {
err error
remote bool
}
// A Session is a QUIC session
type session struct {
connectionID protocol.ConnectionID
@@ -67,7 +72,7 @@ type session struct {
receivedPackets chan *receivedPacket
sendingScheduled chan struct{}
// closeChan is used to notify the run loop that it should terminate.
closeChan chan error
closeChan chan closeError
runClosed chan struct{}
closed uint32 // atomic bool
@@ -178,11 +183,11 @@ func (s *session) setup() {
s.receivedPacketHandler = ackhandler.NewReceivedPacketHandler(s.ackAlarmChanged)
s.receivedPackets = make(chan *receivedPacket, protocol.MaxSessionUnprocessedPackets)
s.closeChan = make(chan error, 1)
s.closeChan = make(chan closeError, 1)
s.sendingScheduled = make(chan struct{}, 1)
s.undecryptablePackets = make([]*receivedPacket, 0, protocol.MaxUndecryptablePackets)
s.aeadChanged = make(chan protocol.EncryptionLevel, 2)
s.runClosed = make(chan struct{}, 1)
s.runClosed = make(chan struct{})
s.timer = time.NewTimer(0)
s.lastNetworkActivityTime = now
@@ -201,7 +206,7 @@ func (s *session) run() error {
}
}()
var closeErr error
var closeErr closeError
aeadChanged := s.aeadChanged
runLoop:
@@ -215,7 +220,6 @@ runLoop:
s.maybeResetTimer()
var err error
select {
case closeErr = <-s.closeChan:
break runLoop
@@ -227,10 +231,14 @@ runLoop:
// We do all the interesting stuff after the switch statement, so
// nothing to see here.
case p := <-s.receivedPackets:
err = s.handlePacketImpl(p)
if qErr, ok := err.(*qerr.QuicError); ok && qErr.ErrorCode == qerr.DecryptionFailure {
s.tryQueueingUndecryptablePacket(p)
continue
err := s.handlePacketImpl(p)
if err != nil {
if qErr, ok := err.(*qerr.QuicError); ok && qErr.ErrorCode == qerr.DecryptionFailure {
s.tryQueueingUndecryptablePacket(p)
continue
}
s.close(err)
break runLoop
}
// This is a bit unclean, but works properly, since the packet always
// begins with the public header and we never copy it.
@@ -248,10 +256,6 @@ runLoop:
}
}
if err != nil {
s.close(err)
}
now := time.Now()
if s.sentPacketHandler.GetAlarmTimeout().Before(now) {
// This could cause packets to be retransmitted, so check it before trying
@@ -274,8 +278,9 @@ runLoop:
s.garbageCollectStreams()
}
s.runClosed <- struct{}{}
return closeErr
s.handleCloseError(closeErr)
close(s.runClosed)
return closeErr.err
}
func (s *session) maybeResetTimer() {
@@ -396,7 +401,7 @@ func (s *session) handleFrames(fs []frames.Frame) error {
case *frames.AckFrame:
err = s.handleAckFrame(frame)
case *frames.ConnectionCloseFrame:
s.closeImpl(qerr.Error(frame.ErrorCode, frame.ReasonPhrase), true)
s.registerClose(qerr.Error(frame.ErrorCode, frame.ReasonPhrase), true)
case *frames.GoawayFrame:
err = errors.New("unimplemented: handling GOAWAY frames")
case *frames.StopWaitingFrame:
@@ -482,10 +487,29 @@ func (s *session) handleAckFrame(frame *frames.AckFrame) error {
return s.sentPacketHandler.ReceivedAck(frame, s.lastRcvdPacketNumber, s.lastNetworkActivityTime)
}
func (s *session) registerClose(e error, remoteClose bool) error {
// Only close once
if !atomic.CompareAndSwapUint32(&s.closed, 0, 1) {
return errSessionAlreadyClosed
}
if e == nil {
e = qerr.PeerGoingAway
}
if e == errCloseSessionForNewVersion {
s.streamsMap.CloseWithError(e)
s.closeStreamsWithError(e)
}
s.closeChan <- closeError{err: e, remote: remoteClose}
return nil
}
// Close the connection. If err is nil it will be set to qerr.PeerGoingAway.
// It waits until the run loop has stopped before returning
func (s *session) Close(e error) error {
err := s.closeImpl(e, false)
err := s.registerClose(e, false)
if err == errSessionAlreadyClosed {
return nil
}
@@ -497,55 +521,42 @@ func (s *session) Close(e error) error {
// close the connection. Use this when called from the run loop
func (s *session) close(e error) error {
err := s.closeImpl(e, false)
err := s.registerClose(e, false)
if err == errSessionAlreadyClosed {
return nil
}
return err
}
func (s *session) closeImpl(e error, remoteClose bool) error {
// Only close once
if !atomic.CompareAndSwapUint32(&s.closed, 0, 1) {
return errSessionAlreadyClosed
func (s *session) handleCloseError(closeErr closeError) error {
var quicErr *qerr.QuicError
var ok bool
if quicErr, ok = closeErr.err.(*qerr.QuicError); !ok {
quicErr = qerr.ToQuicError(closeErr.err)
}
if e == nil {
e = qerr.PeerGoingAway
}
defer func() {
s.closeChan <- e
}()
if e == errCloseSessionForNewVersion {
s.streamsMap.CloseWithError(e)
s.closeStreamsWithError(e)
return nil
}
quicErr := qerr.ToQuicError(e)
// Don't log 'normal' reasons
if quicErr.ErrorCode == qerr.PeerGoingAway || quicErr.ErrorCode == qerr.NetworkIdleTimeout {
utils.Infof("Closing connection %x", s.connectionID)
} else {
utils.Errorf("Closing session with error: %s", e.Error())
utils.Errorf("Closing session with error: %s", closeErr.err.Error())
}
if closeErr.err == errCloseSessionForNewVersion {
return nil
}
s.streamsMap.CloseWithError(quicErr)
s.closeStreamsWithError(quicErr)
if remoteClose {
// If this is a remote close we're done here
// If this is a remote close we're done here
if closeErr.remote {
return nil
}
if quicErr.ErrorCode == qerr.DecryptionFailure || quicErr == handshake.ErrHOLExperiment {
return s.sendPublicReset(s.lastRcvdPacketNumber)
}
s.sendConnectionClose(quicErr)
return nil
return s.sendConnectionClose(quicErr)
}
func (s *session) closeStreamsWithError(err error) {

View File

@@ -582,12 +582,15 @@ var _ = Describe("Session", func() {
Expect(err).NotTo(HaveOccurred())
})
It("handles CONNECTION_CLOSE frames", func() {
It("handles CONNECTION_CLOSE frames", func(done Done) {
go sess.run()
str, _ := sess.GetOrOpenStream(5)
err := sess.handleFrames([]frames.Frame{&frames.ConnectionCloseFrame{ErrorCode: 42, ReasonPhrase: "foobar"}})
Expect(err).NotTo(HaveOccurred())
Eventually(sess.runClosed).Should(BeClosed())
_, err = str.Read([]byte{0})
Expect(err).To(MatchError(qerr.Error(42, "foobar")))
close(done)
})
Context("accepting streams", func() {
@@ -621,16 +624,17 @@ var _ = Describe("Session", func() {
})
It("stops accepting when the session is closed after version negotiation", func() {
testErr := errCloseSessionForNewVersion
var err error
go func() {
_, err = sess.AcceptStream()
}()
go sess.run()
Consistently(func() error { return err }).ShouldNot(HaveOccurred())
sess.Close(testErr)
Expect(sess.runClosed).ToNot(BeClosed())
sess.Close(errCloseSessionForNewVersion)
Eventually(func() error { return err }).Should(HaveOccurred())
Expect(err).To(MatchError(testErr))
Expect(err).To(MatchError(errCloseSessionForNewVersion))
Eventually(sess.runClosed).Should(BeClosed())
})
})
@@ -646,7 +650,7 @@ var _ = Describe("Session", func() {
Eventually(areSessionsRunning).Should(BeFalse())
Expect(mconn.written).To(HaveLen(1))
Expect(mconn.written[0][len(mconn.written[0])-7:]).To(Equal([]byte{0x02, byte(qerr.PeerGoingAway), 0, 0, 0, 0, 0}))
Expect(sess.runClosed).ToNot(Receive()) // channel should be drained by Close()
Expect(sess.runClosed).To(BeClosed())
})
It("only closes once", func() {
@@ -654,7 +658,7 @@ var _ = Describe("Session", func() {
sess.Close(nil)
Eventually(areSessionsRunning).Should(BeFalse())
Expect(mconn.written).To(HaveLen(1))
Expect(sess.runClosed).ToNot(Receive()) // channel should be drained by Close()
Expect(sess.runClosed).To(BeClosed())
})
It("closes streams with proper error", func() {
@@ -669,7 +673,7 @@ var _ = Describe("Session", func() {
n, err = s.Write([]byte{0})
Expect(n).To(BeZero())
Expect(err.Error()).To(ContainSubstring(testErr.Error()))
Expect(sess.runClosed).ToNot(Receive()) // channel should be drained by Close()
Expect(sess.runClosed).To(BeClosed())
})
It("closes the session in order to replace it with another QUIC version", func() {
@@ -683,7 +687,7 @@ var _ = Describe("Session", func() {
sess.Close(handshake.ErrHOLExperiment)
Expect(mconn.written).To(HaveLen(1))
Expect(mconn.written[0][0] & 0x02).ToNot(BeZero()) // Public Reset
Expect(sess.runClosed).ToNot(Receive()) // channel should be drained by Close()
Expect(sess.runClosed).To(BeClosed())
})
})
@@ -1230,7 +1234,7 @@ var _ = Describe("Session", func() {
sess.scheduleSending() // wake up the run loop
Eventually(func() [][]byte { return mconn.written }).Should(HaveLen(1))
Expect(mconn.written[0]).To(ContainSubstring(string([]byte("PRST"))))
Eventually(sess.runClosed).Should(Receive())
Eventually(sess.runClosed).Should(BeClosed())
})
It("doesn't send a Public Reset if decrypting them suceeded during the timeout", func() {
@@ -1249,8 +1253,7 @@ var _ = Describe("Session", func() {
go sess.run()
sendUndecryptablePackets()
Consistently(sess.undecryptablePackets).Should(BeEmpty())
sess.closeImpl(nil, true)
Eventually(sess.runClosed).Should(Receive())
Expect(sess.Close(nil)).To(Succeed())
})
It("unqueues undecryptable packets for later decryption", func() {
@@ -1305,7 +1308,7 @@ var _ = Describe("Session", func() {
sess.lastNetworkActivityTime = time.Now().Add(-time.Hour)
sess.run() // Would normally not return
Expect(mconn.written[0]).To(ContainSubstring("No recent network activity."))
Expect(sess.runClosed).To(Receive())
Expect(sess.runClosed).To(BeClosed())
close(done)
})
@@ -1313,7 +1316,7 @@ var _ = Describe("Session", func() {
sess.sessionCreationTime = time.Now().Add(-time.Hour)
sess.run() // Would normally not return
Expect(mconn.written[0]).To(ContainSubstring("Crypto handshake did not complete in time."))
Expect(sess.runClosed).To(Receive())
Expect(sess.runClosed).To(BeClosed())
close(done)
})
@@ -1323,7 +1326,7 @@ var _ = Describe("Session", func() {
sess.packer.connectionParameters = sess.connectionParameters
sess.run() // Would normally not return
Expect(mconn.written[0]).To(ContainSubstring("No recent network activity."))
Expect(sess.runClosed).To(Receive())
Expect(sess.runClosed).To(BeClosed())
close(done)
})
@@ -1333,7 +1336,7 @@ var _ = Describe("Session", func() {
sess.packer.connectionParameters = sess.connectionParameters
sess.run() // Would normally not return
Expect(mconn.written[0]).To(ContainSubstring("No recent network activity."))
Expect(sess.runClosed).To(Receive())
Expect(sess.runClosed).To(BeClosed())
close(done)
})
})