forked from quic-go/quic-go
send ConnectionCloses and PublicReset from the session run loop
This avoids several race conditions that can occur when trying to pack packets and to write to the underlying connection concurrently.
This commit is contained in:
95
session.go
95
session.go
@@ -39,6 +39,11 @@ var (
|
|||||||
// Once the callback has been called with isForwardSecure = true, it is guarantueed to not be called with isForwardSecure = false after that
|
// Once the callback has been called with isForwardSecure = true, it is guarantueed to not be called with isForwardSecure = false after that
|
||||||
type cryptoChangeCallback func(session Session, isForwardSecure bool)
|
type cryptoChangeCallback func(session Session, isForwardSecure bool)
|
||||||
|
|
||||||
|
type closeError struct {
|
||||||
|
err error
|
||||||
|
remote bool
|
||||||
|
}
|
||||||
|
|
||||||
// A Session is a QUIC session
|
// A Session is a QUIC session
|
||||||
type session struct {
|
type session struct {
|
||||||
connectionID protocol.ConnectionID
|
connectionID protocol.ConnectionID
|
||||||
@@ -67,7 +72,7 @@ type session struct {
|
|||||||
receivedPackets chan *receivedPacket
|
receivedPackets chan *receivedPacket
|
||||||
sendingScheduled chan struct{}
|
sendingScheduled chan struct{}
|
||||||
// closeChan is used to notify the run loop that it should terminate.
|
// closeChan is used to notify the run loop that it should terminate.
|
||||||
closeChan chan error
|
closeChan chan closeError
|
||||||
runClosed chan struct{}
|
runClosed chan struct{}
|
||||||
closed uint32 // atomic bool
|
closed uint32 // atomic bool
|
||||||
|
|
||||||
@@ -178,7 +183,7 @@ func (s *session) setup() {
|
|||||||
s.receivedPacketHandler = ackhandler.NewReceivedPacketHandler(s.ackAlarmChanged)
|
s.receivedPacketHandler = ackhandler.NewReceivedPacketHandler(s.ackAlarmChanged)
|
||||||
|
|
||||||
s.receivedPackets = make(chan *receivedPacket, protocol.MaxSessionUnprocessedPackets)
|
s.receivedPackets = make(chan *receivedPacket, protocol.MaxSessionUnprocessedPackets)
|
||||||
s.closeChan = make(chan error, 1)
|
s.closeChan = make(chan closeError, 1)
|
||||||
s.sendingScheduled = make(chan struct{}, 1)
|
s.sendingScheduled = make(chan struct{}, 1)
|
||||||
s.undecryptablePackets = make([]*receivedPacket, 0, protocol.MaxUndecryptablePackets)
|
s.undecryptablePackets = make([]*receivedPacket, 0, protocol.MaxUndecryptablePackets)
|
||||||
s.aeadChanged = make(chan protocol.EncryptionLevel, 2)
|
s.aeadChanged = make(chan protocol.EncryptionLevel, 2)
|
||||||
@@ -201,7 +206,7 @@ func (s *session) run() error {
|
|||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
|
|
||||||
var closeErr error
|
var closeErr closeError
|
||||||
aeadChanged := s.aeadChanged
|
aeadChanged := s.aeadChanged
|
||||||
|
|
||||||
runLoop:
|
runLoop:
|
||||||
@@ -215,7 +220,6 @@ runLoop:
|
|||||||
|
|
||||||
s.maybeResetTimer()
|
s.maybeResetTimer()
|
||||||
|
|
||||||
var err error
|
|
||||||
select {
|
select {
|
||||||
case closeErr = <-s.closeChan:
|
case closeErr = <-s.closeChan:
|
||||||
break runLoop
|
break runLoop
|
||||||
@@ -227,10 +231,14 @@ runLoop:
|
|||||||
// We do all the interesting stuff after the switch statement, so
|
// We do all the interesting stuff after the switch statement, so
|
||||||
// nothing to see here.
|
// nothing to see here.
|
||||||
case p := <-s.receivedPackets:
|
case p := <-s.receivedPackets:
|
||||||
err = s.handlePacketImpl(p)
|
err := s.handlePacketImpl(p)
|
||||||
if qErr, ok := err.(*qerr.QuicError); ok && qErr.ErrorCode == qerr.DecryptionFailure {
|
if err != nil {
|
||||||
s.tryQueueingUndecryptablePacket(p)
|
if qErr, ok := err.(*qerr.QuicError); ok && qErr.ErrorCode == qerr.DecryptionFailure {
|
||||||
continue
|
s.tryQueueingUndecryptablePacket(p)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
s.close(err)
|
||||||
|
break runLoop
|
||||||
}
|
}
|
||||||
// This is a bit unclean, but works properly, since the packet always
|
// This is a bit unclean, but works properly, since the packet always
|
||||||
// begins with the public header and we never copy it.
|
// begins with the public header and we never copy it.
|
||||||
@@ -248,10 +256,6 @@ runLoop:
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if err != nil {
|
|
||||||
s.close(err)
|
|
||||||
}
|
|
||||||
|
|
||||||
now := time.Now()
|
now := time.Now()
|
||||||
if s.sentPacketHandler.GetAlarmTimeout().Before(now) {
|
if s.sentPacketHandler.GetAlarmTimeout().Before(now) {
|
||||||
// This could cause packets to be retransmitted, so check it before trying
|
// This could cause packets to be retransmitted, so check it before trying
|
||||||
@@ -274,8 +278,9 @@ runLoop:
|
|||||||
s.garbageCollectStreams()
|
s.garbageCollectStreams()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
s.handleCloseError(closeErr)
|
||||||
close(s.runClosed)
|
close(s.runClosed)
|
||||||
return closeErr
|
return closeErr.err
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *session) maybeResetTimer() {
|
func (s *session) maybeResetTimer() {
|
||||||
@@ -396,7 +401,7 @@ func (s *session) handleFrames(fs []frames.Frame) error {
|
|||||||
case *frames.AckFrame:
|
case *frames.AckFrame:
|
||||||
err = s.handleAckFrame(frame)
|
err = s.handleAckFrame(frame)
|
||||||
case *frames.ConnectionCloseFrame:
|
case *frames.ConnectionCloseFrame:
|
||||||
s.closeImpl(qerr.Error(frame.ErrorCode, frame.ReasonPhrase), true)
|
s.registerClose(qerr.Error(frame.ErrorCode, frame.ReasonPhrase), true)
|
||||||
case *frames.GoawayFrame:
|
case *frames.GoawayFrame:
|
||||||
err = errors.New("unimplemented: handling GOAWAY frames")
|
err = errors.New("unimplemented: handling GOAWAY frames")
|
||||||
case *frames.StopWaitingFrame:
|
case *frames.StopWaitingFrame:
|
||||||
@@ -482,10 +487,29 @@ func (s *session) handleAckFrame(frame *frames.AckFrame) error {
|
|||||||
return s.sentPacketHandler.ReceivedAck(frame, s.lastRcvdPacketNumber, s.lastNetworkActivityTime)
|
return s.sentPacketHandler.ReceivedAck(frame, s.lastRcvdPacketNumber, s.lastNetworkActivityTime)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *session) registerClose(e error, remoteClose bool) error {
|
||||||
|
// Only close once
|
||||||
|
if !atomic.CompareAndSwapUint32(&s.closed, 0, 1) {
|
||||||
|
return errSessionAlreadyClosed
|
||||||
|
}
|
||||||
|
|
||||||
|
if e == nil {
|
||||||
|
e = qerr.PeerGoingAway
|
||||||
|
}
|
||||||
|
|
||||||
|
if e == errCloseSessionForNewVersion {
|
||||||
|
s.streamsMap.CloseWithError(e)
|
||||||
|
s.closeStreamsWithError(e)
|
||||||
|
}
|
||||||
|
|
||||||
|
s.closeChan <- closeError{err: e, remote: remoteClose}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
// Close the connection. If err is nil it will be set to qerr.PeerGoingAway.
|
// Close the connection. If err is nil it will be set to qerr.PeerGoingAway.
|
||||||
// It waits until the run loop has stopped before returning
|
// It waits until the run loop has stopped before returning
|
||||||
func (s *session) Close(e error) error {
|
func (s *session) Close(e error) error {
|
||||||
err := s.closeImpl(e, false)
|
err := s.registerClose(e, false)
|
||||||
if err == errSessionAlreadyClosed {
|
if err == errSessionAlreadyClosed {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
@@ -497,55 +521,42 @@ func (s *session) Close(e error) error {
|
|||||||
|
|
||||||
// close the connection. Use this when called from the run loop
|
// close the connection. Use this when called from the run loop
|
||||||
func (s *session) close(e error) error {
|
func (s *session) close(e error) error {
|
||||||
err := s.closeImpl(e, false)
|
err := s.registerClose(e, false)
|
||||||
if err == errSessionAlreadyClosed {
|
if err == errSessionAlreadyClosed {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *session) closeImpl(e error, remoteClose bool) error {
|
func (s *session) handleCloseError(closeErr closeError) error {
|
||||||
// Only close once
|
var quicErr *qerr.QuicError
|
||||||
if !atomic.CompareAndSwapUint32(&s.closed, 0, 1) {
|
var ok bool
|
||||||
return errSessionAlreadyClosed
|
if quicErr, ok = closeErr.err.(*qerr.QuicError); !ok {
|
||||||
|
quicErr = qerr.ToQuicError(closeErr.err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if e == nil {
|
|
||||||
e = qerr.PeerGoingAway
|
|
||||||
}
|
|
||||||
|
|
||||||
defer func() {
|
|
||||||
s.closeChan <- e
|
|
||||||
}()
|
|
||||||
|
|
||||||
if e == errCloseSessionForNewVersion {
|
|
||||||
s.streamsMap.CloseWithError(e)
|
|
||||||
s.closeStreamsWithError(e)
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
quicErr := qerr.ToQuicError(e)
|
|
||||||
|
|
||||||
// Don't log 'normal' reasons
|
// Don't log 'normal' reasons
|
||||||
if quicErr.ErrorCode == qerr.PeerGoingAway || quicErr.ErrorCode == qerr.NetworkIdleTimeout {
|
if quicErr.ErrorCode == qerr.PeerGoingAway || quicErr.ErrorCode == qerr.NetworkIdleTimeout {
|
||||||
utils.Infof("Closing connection %x", s.connectionID)
|
utils.Infof("Closing connection %x", s.connectionID)
|
||||||
} else {
|
} else {
|
||||||
utils.Errorf("Closing session with error: %s", e.Error())
|
utils.Errorf("Closing session with error: %s", closeErr.err.Error())
|
||||||
|
}
|
||||||
|
|
||||||
|
if closeErr.err == errCloseSessionForNewVersion {
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
s.streamsMap.CloseWithError(quicErr)
|
s.streamsMap.CloseWithError(quicErr)
|
||||||
s.closeStreamsWithError(quicErr)
|
s.closeStreamsWithError(quicErr)
|
||||||
|
|
||||||
if remoteClose {
|
// If this is a remote close we're done here
|
||||||
// If this is a remote close we're done here
|
if closeErr.remote {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
if quicErr.ErrorCode == qerr.DecryptionFailure || quicErr == handshake.ErrHOLExperiment {
|
if quicErr.ErrorCode == qerr.DecryptionFailure || quicErr == handshake.ErrHOLExperiment {
|
||||||
return s.sendPublicReset(s.lastRcvdPacketNumber)
|
return s.sendPublicReset(s.lastRcvdPacketNumber)
|
||||||
}
|
}
|
||||||
s.sendConnectionClose(quicErr)
|
return s.sendConnectionClose(quicErr)
|
||||||
return nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *session) closeStreamsWithError(err error) {
|
func (s *session) closeStreamsWithError(err error) {
|
||||||
|
|||||||
@@ -582,12 +582,15 @@ var _ = Describe("Session", func() {
|
|||||||
Expect(err).NotTo(HaveOccurred())
|
Expect(err).NotTo(HaveOccurred())
|
||||||
})
|
})
|
||||||
|
|
||||||
It("handles CONNECTION_CLOSE frames", func() {
|
It("handles CONNECTION_CLOSE frames", func(done Done) {
|
||||||
|
go sess.run()
|
||||||
str, _ := sess.GetOrOpenStream(5)
|
str, _ := sess.GetOrOpenStream(5)
|
||||||
err := sess.handleFrames([]frames.Frame{&frames.ConnectionCloseFrame{ErrorCode: 42, ReasonPhrase: "foobar"}})
|
err := sess.handleFrames([]frames.Frame{&frames.ConnectionCloseFrame{ErrorCode: 42, ReasonPhrase: "foobar"}})
|
||||||
Expect(err).NotTo(HaveOccurred())
|
Expect(err).NotTo(HaveOccurred())
|
||||||
|
Eventually(sess.runClosed).Should(BeClosed())
|
||||||
_, err = str.Read([]byte{0})
|
_, err = str.Read([]byte{0})
|
||||||
Expect(err).To(MatchError(qerr.Error(42, "foobar")))
|
Expect(err).To(MatchError(qerr.Error(42, "foobar")))
|
||||||
|
close(done)
|
||||||
})
|
})
|
||||||
|
|
||||||
Context("accepting streams", func() {
|
Context("accepting streams", func() {
|
||||||
@@ -621,16 +624,17 @@ var _ = Describe("Session", func() {
|
|||||||
})
|
})
|
||||||
|
|
||||||
It("stops accepting when the session is closed after version negotiation", func() {
|
It("stops accepting when the session is closed after version negotiation", func() {
|
||||||
testErr := errCloseSessionForNewVersion
|
|
||||||
var err error
|
var err error
|
||||||
go func() {
|
go func() {
|
||||||
_, err = sess.AcceptStream()
|
_, err = sess.AcceptStream()
|
||||||
}()
|
}()
|
||||||
go sess.run()
|
go sess.run()
|
||||||
Consistently(func() error { return err }).ShouldNot(HaveOccurred())
|
Consistently(func() error { return err }).ShouldNot(HaveOccurred())
|
||||||
sess.Close(testErr)
|
Expect(sess.runClosed).ToNot(BeClosed())
|
||||||
|
sess.Close(errCloseSessionForNewVersion)
|
||||||
Eventually(func() error { return err }).Should(HaveOccurred())
|
Eventually(func() error { return err }).Should(HaveOccurred())
|
||||||
Expect(err).To(MatchError(testErr))
|
Expect(err).To(MatchError(errCloseSessionForNewVersion))
|
||||||
|
Eventually(sess.runClosed).Should(BeClosed())
|
||||||
})
|
})
|
||||||
})
|
})
|
||||||
|
|
||||||
@@ -1249,8 +1253,7 @@ var _ = Describe("Session", func() {
|
|||||||
go sess.run()
|
go sess.run()
|
||||||
sendUndecryptablePackets()
|
sendUndecryptablePackets()
|
||||||
Consistently(sess.undecryptablePackets).Should(BeEmpty())
|
Consistently(sess.undecryptablePackets).Should(BeEmpty())
|
||||||
sess.closeImpl(nil, true)
|
Expect(sess.Close(nil)).To(Succeed())
|
||||||
Eventually(sess.runClosed).Should(BeClosed())
|
|
||||||
})
|
})
|
||||||
|
|
||||||
It("unqueues undecryptable packets for later decryption", func() {
|
It("unqueues undecryptable packets for later decryption", func() {
|
||||||
|
|||||||
Reference in New Issue
Block a user