forked from quic-go/quic-go
Merge pull request #608 from lucas-clemente/better-close
send ConnectionCloses and PublicReset from the session run loop
This commit is contained in:
99
session.go
99
session.go
@@ -39,6 +39,11 @@ var (
|
||||
// Once the callback has been called with isForwardSecure = true, it is guarantueed to not be called with isForwardSecure = false after that
|
||||
type cryptoChangeCallback func(session Session, isForwardSecure bool)
|
||||
|
||||
type closeError struct {
|
||||
err error
|
||||
remote bool
|
||||
}
|
||||
|
||||
// A Session is a QUIC session
|
||||
type session struct {
|
||||
connectionID protocol.ConnectionID
|
||||
@@ -67,7 +72,7 @@ type session struct {
|
||||
receivedPackets chan *receivedPacket
|
||||
sendingScheduled chan struct{}
|
||||
// closeChan is used to notify the run loop that it should terminate.
|
||||
closeChan chan error
|
||||
closeChan chan closeError
|
||||
runClosed chan struct{}
|
||||
closed uint32 // atomic bool
|
||||
|
||||
@@ -178,11 +183,11 @@ func (s *session) setup() {
|
||||
s.receivedPacketHandler = ackhandler.NewReceivedPacketHandler(s.ackAlarmChanged)
|
||||
|
||||
s.receivedPackets = make(chan *receivedPacket, protocol.MaxSessionUnprocessedPackets)
|
||||
s.closeChan = make(chan error, 1)
|
||||
s.closeChan = make(chan closeError, 1)
|
||||
s.sendingScheduled = make(chan struct{}, 1)
|
||||
s.undecryptablePackets = make([]*receivedPacket, 0, protocol.MaxUndecryptablePackets)
|
||||
s.aeadChanged = make(chan protocol.EncryptionLevel, 2)
|
||||
s.runClosed = make(chan struct{}, 1)
|
||||
s.runClosed = make(chan struct{})
|
||||
|
||||
s.timer = time.NewTimer(0)
|
||||
s.lastNetworkActivityTime = now
|
||||
@@ -201,7 +206,7 @@ func (s *session) run() error {
|
||||
}
|
||||
}()
|
||||
|
||||
var closeErr error
|
||||
var closeErr closeError
|
||||
aeadChanged := s.aeadChanged
|
||||
|
||||
runLoop:
|
||||
@@ -215,7 +220,6 @@ runLoop:
|
||||
|
||||
s.maybeResetTimer()
|
||||
|
||||
var err error
|
||||
select {
|
||||
case closeErr = <-s.closeChan:
|
||||
break runLoop
|
||||
@@ -227,10 +231,14 @@ runLoop:
|
||||
// We do all the interesting stuff after the switch statement, so
|
||||
// nothing to see here.
|
||||
case p := <-s.receivedPackets:
|
||||
err = s.handlePacketImpl(p)
|
||||
if qErr, ok := err.(*qerr.QuicError); ok && qErr.ErrorCode == qerr.DecryptionFailure {
|
||||
s.tryQueueingUndecryptablePacket(p)
|
||||
continue
|
||||
err := s.handlePacketImpl(p)
|
||||
if err != nil {
|
||||
if qErr, ok := err.(*qerr.QuicError); ok && qErr.ErrorCode == qerr.DecryptionFailure {
|
||||
s.tryQueueingUndecryptablePacket(p)
|
||||
continue
|
||||
}
|
||||
s.close(err)
|
||||
break runLoop
|
||||
}
|
||||
// This is a bit unclean, but works properly, since the packet always
|
||||
// begins with the public header and we never copy it.
|
||||
@@ -248,10 +256,6 @@ runLoop:
|
||||
}
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
s.close(err)
|
||||
}
|
||||
|
||||
now := time.Now()
|
||||
if s.sentPacketHandler.GetAlarmTimeout().Before(now) {
|
||||
// This could cause packets to be retransmitted, so check it before trying
|
||||
@@ -274,8 +278,9 @@ runLoop:
|
||||
s.garbageCollectStreams()
|
||||
}
|
||||
|
||||
s.runClosed <- struct{}{}
|
||||
return closeErr
|
||||
s.handleCloseError(closeErr)
|
||||
close(s.runClosed)
|
||||
return closeErr.err
|
||||
}
|
||||
|
||||
func (s *session) maybeResetTimer() {
|
||||
@@ -396,7 +401,7 @@ func (s *session) handleFrames(fs []frames.Frame) error {
|
||||
case *frames.AckFrame:
|
||||
err = s.handleAckFrame(frame)
|
||||
case *frames.ConnectionCloseFrame:
|
||||
s.closeImpl(qerr.Error(frame.ErrorCode, frame.ReasonPhrase), true)
|
||||
s.registerClose(qerr.Error(frame.ErrorCode, frame.ReasonPhrase), true)
|
||||
case *frames.GoawayFrame:
|
||||
err = errors.New("unimplemented: handling GOAWAY frames")
|
||||
case *frames.StopWaitingFrame:
|
||||
@@ -482,10 +487,29 @@ func (s *session) handleAckFrame(frame *frames.AckFrame) error {
|
||||
return s.sentPacketHandler.ReceivedAck(frame, s.lastRcvdPacketNumber, s.lastNetworkActivityTime)
|
||||
}
|
||||
|
||||
func (s *session) registerClose(e error, remoteClose bool) error {
|
||||
// Only close once
|
||||
if !atomic.CompareAndSwapUint32(&s.closed, 0, 1) {
|
||||
return errSessionAlreadyClosed
|
||||
}
|
||||
|
||||
if e == nil {
|
||||
e = qerr.PeerGoingAway
|
||||
}
|
||||
|
||||
if e == errCloseSessionForNewVersion {
|
||||
s.streamsMap.CloseWithError(e)
|
||||
s.closeStreamsWithError(e)
|
||||
}
|
||||
|
||||
s.closeChan <- closeError{err: e, remote: remoteClose}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Close the connection. If err is nil it will be set to qerr.PeerGoingAway.
|
||||
// It waits until the run loop has stopped before returning
|
||||
func (s *session) Close(e error) error {
|
||||
err := s.closeImpl(e, false)
|
||||
err := s.registerClose(e, false)
|
||||
if err == errSessionAlreadyClosed {
|
||||
return nil
|
||||
}
|
||||
@@ -497,55 +521,42 @@ func (s *session) Close(e error) error {
|
||||
|
||||
// close the connection. Use this when called from the run loop
|
||||
func (s *session) close(e error) error {
|
||||
err := s.closeImpl(e, false)
|
||||
err := s.registerClose(e, false)
|
||||
if err == errSessionAlreadyClosed {
|
||||
return nil
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
func (s *session) closeImpl(e error, remoteClose bool) error {
|
||||
// Only close once
|
||||
if !atomic.CompareAndSwapUint32(&s.closed, 0, 1) {
|
||||
return errSessionAlreadyClosed
|
||||
func (s *session) handleCloseError(closeErr closeError) error {
|
||||
var quicErr *qerr.QuicError
|
||||
var ok bool
|
||||
if quicErr, ok = closeErr.err.(*qerr.QuicError); !ok {
|
||||
quicErr = qerr.ToQuicError(closeErr.err)
|
||||
}
|
||||
|
||||
if e == nil {
|
||||
e = qerr.PeerGoingAway
|
||||
}
|
||||
|
||||
defer func() {
|
||||
s.closeChan <- e
|
||||
}()
|
||||
|
||||
if e == errCloseSessionForNewVersion {
|
||||
s.streamsMap.CloseWithError(e)
|
||||
s.closeStreamsWithError(e)
|
||||
return nil
|
||||
}
|
||||
|
||||
quicErr := qerr.ToQuicError(e)
|
||||
|
||||
// Don't log 'normal' reasons
|
||||
if quicErr.ErrorCode == qerr.PeerGoingAway || quicErr.ErrorCode == qerr.NetworkIdleTimeout {
|
||||
utils.Infof("Closing connection %x", s.connectionID)
|
||||
} else {
|
||||
utils.Errorf("Closing session with error: %s", e.Error())
|
||||
utils.Errorf("Closing session with error: %s", closeErr.err.Error())
|
||||
}
|
||||
|
||||
if closeErr.err == errCloseSessionForNewVersion {
|
||||
return nil
|
||||
}
|
||||
|
||||
s.streamsMap.CloseWithError(quicErr)
|
||||
s.closeStreamsWithError(quicErr)
|
||||
|
||||
if remoteClose {
|
||||
// If this is a remote close we're done here
|
||||
// If this is a remote close we're done here
|
||||
if closeErr.remote {
|
||||
return nil
|
||||
}
|
||||
|
||||
if quicErr.ErrorCode == qerr.DecryptionFailure || quicErr == handshake.ErrHOLExperiment {
|
||||
return s.sendPublicReset(s.lastRcvdPacketNumber)
|
||||
}
|
||||
s.sendConnectionClose(quicErr)
|
||||
return nil
|
||||
return s.sendConnectionClose(quicErr)
|
||||
}
|
||||
|
||||
func (s *session) closeStreamsWithError(err error) {
|
||||
|
||||
@@ -582,12 +582,15 @@ var _ = Describe("Session", func() {
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
})
|
||||
|
||||
It("handles CONNECTION_CLOSE frames", func() {
|
||||
It("handles CONNECTION_CLOSE frames", func(done Done) {
|
||||
go sess.run()
|
||||
str, _ := sess.GetOrOpenStream(5)
|
||||
err := sess.handleFrames([]frames.Frame{&frames.ConnectionCloseFrame{ErrorCode: 42, ReasonPhrase: "foobar"}})
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
Eventually(sess.runClosed).Should(BeClosed())
|
||||
_, err = str.Read([]byte{0})
|
||||
Expect(err).To(MatchError(qerr.Error(42, "foobar")))
|
||||
close(done)
|
||||
})
|
||||
|
||||
Context("accepting streams", func() {
|
||||
@@ -621,16 +624,17 @@ var _ = Describe("Session", func() {
|
||||
})
|
||||
|
||||
It("stops accepting when the session is closed after version negotiation", func() {
|
||||
testErr := errCloseSessionForNewVersion
|
||||
var err error
|
||||
go func() {
|
||||
_, err = sess.AcceptStream()
|
||||
}()
|
||||
go sess.run()
|
||||
Consistently(func() error { return err }).ShouldNot(HaveOccurred())
|
||||
sess.Close(testErr)
|
||||
Expect(sess.runClosed).ToNot(BeClosed())
|
||||
sess.Close(errCloseSessionForNewVersion)
|
||||
Eventually(func() error { return err }).Should(HaveOccurred())
|
||||
Expect(err).To(MatchError(testErr))
|
||||
Expect(err).To(MatchError(errCloseSessionForNewVersion))
|
||||
Eventually(sess.runClosed).Should(BeClosed())
|
||||
})
|
||||
})
|
||||
|
||||
@@ -646,7 +650,7 @@ var _ = Describe("Session", func() {
|
||||
Eventually(areSessionsRunning).Should(BeFalse())
|
||||
Expect(mconn.written).To(HaveLen(1))
|
||||
Expect(mconn.written[0][len(mconn.written[0])-7:]).To(Equal([]byte{0x02, byte(qerr.PeerGoingAway), 0, 0, 0, 0, 0}))
|
||||
Expect(sess.runClosed).ToNot(Receive()) // channel should be drained by Close()
|
||||
Expect(sess.runClosed).To(BeClosed())
|
||||
})
|
||||
|
||||
It("only closes once", func() {
|
||||
@@ -654,7 +658,7 @@ var _ = Describe("Session", func() {
|
||||
sess.Close(nil)
|
||||
Eventually(areSessionsRunning).Should(BeFalse())
|
||||
Expect(mconn.written).To(HaveLen(1))
|
||||
Expect(sess.runClosed).ToNot(Receive()) // channel should be drained by Close()
|
||||
Expect(sess.runClosed).To(BeClosed())
|
||||
})
|
||||
|
||||
It("closes streams with proper error", func() {
|
||||
@@ -669,7 +673,7 @@ var _ = Describe("Session", func() {
|
||||
n, err = s.Write([]byte{0})
|
||||
Expect(n).To(BeZero())
|
||||
Expect(err.Error()).To(ContainSubstring(testErr.Error()))
|
||||
Expect(sess.runClosed).ToNot(Receive()) // channel should be drained by Close()
|
||||
Expect(sess.runClosed).To(BeClosed())
|
||||
})
|
||||
|
||||
It("closes the session in order to replace it with another QUIC version", func() {
|
||||
@@ -683,7 +687,7 @@ var _ = Describe("Session", func() {
|
||||
sess.Close(handshake.ErrHOLExperiment)
|
||||
Expect(mconn.written).To(HaveLen(1))
|
||||
Expect(mconn.written[0][0] & 0x02).ToNot(BeZero()) // Public Reset
|
||||
Expect(sess.runClosed).ToNot(Receive()) // channel should be drained by Close()
|
||||
Expect(sess.runClosed).To(BeClosed())
|
||||
})
|
||||
})
|
||||
|
||||
@@ -1230,7 +1234,7 @@ var _ = Describe("Session", func() {
|
||||
sess.scheduleSending() // wake up the run loop
|
||||
Eventually(func() [][]byte { return mconn.written }).Should(HaveLen(1))
|
||||
Expect(mconn.written[0]).To(ContainSubstring(string([]byte("PRST"))))
|
||||
Eventually(sess.runClosed).Should(Receive())
|
||||
Eventually(sess.runClosed).Should(BeClosed())
|
||||
})
|
||||
|
||||
It("doesn't send a Public Reset if decrypting them suceeded during the timeout", func() {
|
||||
@@ -1249,8 +1253,7 @@ var _ = Describe("Session", func() {
|
||||
go sess.run()
|
||||
sendUndecryptablePackets()
|
||||
Consistently(sess.undecryptablePackets).Should(BeEmpty())
|
||||
sess.closeImpl(nil, true)
|
||||
Eventually(sess.runClosed).Should(Receive())
|
||||
Expect(sess.Close(nil)).To(Succeed())
|
||||
})
|
||||
|
||||
It("unqueues undecryptable packets for later decryption", func() {
|
||||
@@ -1305,7 +1308,7 @@ var _ = Describe("Session", func() {
|
||||
sess.lastNetworkActivityTime = time.Now().Add(-time.Hour)
|
||||
sess.run() // Would normally not return
|
||||
Expect(mconn.written[0]).To(ContainSubstring("No recent network activity."))
|
||||
Expect(sess.runClosed).To(Receive())
|
||||
Expect(sess.runClosed).To(BeClosed())
|
||||
close(done)
|
||||
})
|
||||
|
||||
@@ -1313,7 +1316,7 @@ var _ = Describe("Session", func() {
|
||||
sess.sessionCreationTime = time.Now().Add(-time.Hour)
|
||||
sess.run() // Would normally not return
|
||||
Expect(mconn.written[0]).To(ContainSubstring("Crypto handshake did not complete in time."))
|
||||
Expect(sess.runClosed).To(Receive())
|
||||
Expect(sess.runClosed).To(BeClosed())
|
||||
close(done)
|
||||
})
|
||||
|
||||
@@ -1323,7 +1326,7 @@ var _ = Describe("Session", func() {
|
||||
sess.packer.connectionParameters = sess.connectionParameters
|
||||
sess.run() // Would normally not return
|
||||
Expect(mconn.written[0]).To(ContainSubstring("No recent network activity."))
|
||||
Expect(sess.runClosed).To(Receive())
|
||||
Expect(sess.runClosed).To(BeClosed())
|
||||
close(done)
|
||||
})
|
||||
|
||||
@@ -1333,7 +1336,7 @@ var _ = Describe("Session", func() {
|
||||
sess.packer.connectionParameters = sess.connectionParameters
|
||||
sess.run() // Would normally not return
|
||||
Expect(mconn.written[0]).To(ContainSubstring("No recent network activity."))
|
||||
Expect(sess.runClosed).To(Receive())
|
||||
Expect(sess.runClosed).To(BeClosed())
|
||||
close(done)
|
||||
})
|
||||
})
|
||||
|
||||
Reference in New Issue
Block a user