forked from quic-go/quic-go
Merge pull request #373 from lucas-clemente/fix-371
wait until the run loop has stopped before returning Session.Close()
This commit is contained in:
43
session.go
43
session.go
@@ -32,6 +32,7 @@ type receivedPacket struct {
|
||||
var (
|
||||
errRstStreamOnInvalidStream = errors.New("RST_STREAM received for unknown stream")
|
||||
errWindowUpdateOnClosedStream = errors.New("WINDOW_UPDATE received for an already closed stream")
|
||||
errSessionAlreadyClosed = errors.New("Cannot close Session. It was already closed before.")
|
||||
)
|
||||
|
||||
// StreamCallback gets a stream frame and returns a reply frame
|
||||
@@ -70,6 +71,7 @@ type Session struct {
|
||||
// closeChan is used to notify the run loop that it should terminate.
|
||||
// If the value is not nil, the error is sent as a CONNECTION_CLOSE.
|
||||
closeChan chan *qerr.QuicError
|
||||
runClosed chan struct{}
|
||||
closed uint32 // atomic bool
|
||||
|
||||
undecryptablePackets []*receivedPacket
|
||||
@@ -124,6 +126,7 @@ func newSession(conn connection, v protocol.VersionNumber, connectionID protocol
|
||||
sendingScheduled: make(chan struct{}, 1),
|
||||
undecryptablePackets: make([]*receivedPacket, 0, protocol.MaxUndecryptablePackets),
|
||||
aeadChanged: make(chan struct{}, 1),
|
||||
runClosed: make(chan struct{}, 1), // this channel will receive once the run loop has been stopped
|
||||
|
||||
timer: time.NewTimer(0),
|
||||
lastNetworkActivityTime: now,
|
||||
@@ -155,6 +158,7 @@ func (s *Session) run() {
|
||||
}
|
||||
}()
|
||||
|
||||
runLoop:
|
||||
for {
|
||||
// Close immediately if requested
|
||||
select {
|
||||
@@ -162,7 +166,7 @@ func (s *Session) run() {
|
||||
if errForConnClose != nil {
|
||||
s.sendConnectionClose(errForConnClose)
|
||||
}
|
||||
return
|
||||
break runLoop
|
||||
default:
|
||||
}
|
||||
|
||||
@@ -174,7 +178,7 @@ func (s *Session) run() {
|
||||
if errForConnClose != nil {
|
||||
s.sendConnectionClose(errForConnClose)
|
||||
}
|
||||
return
|
||||
break runLoop
|
||||
case <-s.timer.C:
|
||||
s.timerRead = true
|
||||
// We do all the interesting stuff after the switch statement, so
|
||||
@@ -199,20 +203,23 @@ func (s *Session) run() {
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
s.Close(err)
|
||||
s.close(err)
|
||||
}
|
||||
|
||||
if err := s.sendPacket(); err != nil {
|
||||
s.Close(err)
|
||||
s.close(err)
|
||||
}
|
||||
if time.Now().Sub(s.lastNetworkActivityTime) >= s.idleTimeout() {
|
||||
s.Close(qerr.Error(qerr.NetworkIdleTimeout, "No recent network activity."))
|
||||
s.close(qerr.Error(qerr.NetworkIdleTimeout, "No recent network activity."))
|
||||
}
|
||||
if !s.cryptoSetup.HandshakeComplete() && time.Now().Sub(s.sessionCreationTime) >= protocol.MaxTimeForCryptoHandshake {
|
||||
s.Close(qerr.Error(qerr.NetworkIdleTimeout, "Crypto handshake did not complete in time."))
|
||||
s.close(qerr.Error(qerr.NetworkIdleTimeout, "Crypto handshake did not complete in time."))
|
||||
}
|
||||
s.garbageCollectStreams()
|
||||
}
|
||||
|
||||
s.closeCallback(s.connectionID)
|
||||
s.runClosed <- struct{}{}
|
||||
}
|
||||
|
||||
func (s *Session) maybeResetTimer() {
|
||||
@@ -407,14 +414,31 @@ func (s *Session) handleAckFrame(frame *frames.AckFrame) error {
|
||||
}
|
||||
|
||||
// Close the connection. If err is nil it will be set to qerr.PeerGoingAway.
|
||||
// It waits until the run loop has stopped before returning
|
||||
func (s *Session) Close(e error) error {
|
||||
return s.closeImpl(e, false)
|
||||
err := s.closeImpl(e, false)
|
||||
if err == errSessionAlreadyClosed {
|
||||
return nil
|
||||
}
|
||||
|
||||
// wait for the run loop to finish
|
||||
<-s.runClosed
|
||||
return err
|
||||
}
|
||||
|
||||
// close the connection. Use this when called from the run loop
|
||||
func (s *Session) close(e error) error {
|
||||
err := s.closeImpl(e, false)
|
||||
if err == errSessionAlreadyClosed {
|
||||
return nil
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
func (s *Session) closeImpl(e error, remoteClose bool) error {
|
||||
// Only close once
|
||||
if !atomic.CompareAndSwapUint32(&s.closed, 0, 1) {
|
||||
return nil
|
||||
return errSessionAlreadyClosed
|
||||
}
|
||||
|
||||
if e == nil {
|
||||
@@ -431,7 +455,6 @@ func (s *Session) closeImpl(e error, remoteClose bool) error {
|
||||
}
|
||||
|
||||
s.closeStreamsWithError(quicErr)
|
||||
s.closeCallback(s.connectionID)
|
||||
|
||||
if remoteClose {
|
||||
// If this is a remote close we don't need to send a CONNECTION_CLOSE
|
||||
@@ -651,7 +674,7 @@ func (s *Session) tryQueueingUndecryptablePacket(p *receivedPacket) {
|
||||
}
|
||||
utils.Infof("Queueing packet 0x%x for later decryption", p.publicHeader.PacketNumber)
|
||||
if len(s.undecryptablePackets)+1 >= protocol.MaxUndecryptablePackets {
|
||||
s.Close(qerr.Error(qerr.DecryptionFailure, "too many undecryptable packets received"))
|
||||
s.close(qerr.Error(qerr.DecryptionFailure, "too many undecryptable packets received"))
|
||||
}
|
||||
s.undecryptablePackets = append(s.undecryptablePackets, p)
|
||||
}
|
||||
|
||||
@@ -424,10 +424,11 @@ var _ = Describe("Session", func() {
|
||||
|
||||
It("shuts down without error", func() {
|
||||
session.Close(nil)
|
||||
Expect(closeCallbackCalled).To(BeTrue())
|
||||
Eventually(func() int { return runtime.NumGoroutine() }).Should(Equal(nGoRoutinesBefore))
|
||||
Expect(conn.written).To(HaveLen(1))
|
||||
Expect(conn.written[0][len(conn.written[0])-7:]).To(Equal([]byte{0x02, byte(qerr.PeerGoingAway), 0, 0, 0, 0, 0}))
|
||||
Expect(closeCallbackCalled).To(BeTrue())
|
||||
Expect(session.runClosed).ToNot(Receive()) // channel should be drained by Close()
|
||||
})
|
||||
|
||||
It("only closes once", func() {
|
||||
@@ -435,6 +436,7 @@ var _ = Describe("Session", func() {
|
||||
session.Close(nil)
|
||||
Eventually(func() int { return runtime.NumGoroutine() }).Should(Equal(nGoRoutinesBefore))
|
||||
Expect(conn.written).To(HaveLen(1))
|
||||
Expect(session.runClosed).ToNot(Receive()) // channel should be drained by Close()
|
||||
})
|
||||
|
||||
It("closes streams with proper error", func() {
|
||||
@@ -442,14 +444,15 @@ var _ = Describe("Session", func() {
|
||||
s, err := session.GetOrOpenStream(5)
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
session.Close(testErr)
|
||||
Expect(closeCallbackCalled).To(BeTrue())
|
||||
Eventually(func() int { return runtime.NumGoroutine() }).Should(Equal(nGoRoutinesBefore))
|
||||
Expect(closeCallbackCalled).To(BeTrue())
|
||||
n, err := s.Read([]byte{0})
|
||||
Expect(n).To(BeZero())
|
||||
Expect(err.Error()).To(ContainSubstring(testErr.Error()))
|
||||
n, err = s.Write([]byte{0})
|
||||
Expect(n).To(BeZero())
|
||||
Expect(err.Error()).To(ContainSubstring(testErr.Error()))
|
||||
Expect(session.runClosed).ToNot(Receive()) // channel should be drained by Close()
|
||||
})
|
||||
})
|
||||
|
||||
@@ -727,6 +730,7 @@ var _ = Describe("Session", func() {
|
||||
|
||||
Expect(conn.written).To(HaveLen(1))
|
||||
Expect(conn.written[0]).To(ContainSubstring(string([]byte("PRST"))))
|
||||
Expect(session.runClosed).To(Receive())
|
||||
})
|
||||
|
||||
It("ignores undecryptable packets after the handshake is complete", func() {
|
||||
@@ -740,6 +744,7 @@ var _ = Describe("Session", func() {
|
||||
go session.run()
|
||||
Consistently(session.undecryptablePackets).Should(HaveLen(0))
|
||||
session.closeImpl(nil, true)
|
||||
Eventually(session.runClosed).Should(Receive())
|
||||
})
|
||||
|
||||
It("unqueues undecryptable packets for later decryption", func() {
|
||||
@@ -757,6 +762,8 @@ var _ = Describe("Session", func() {
|
||||
session.lastNetworkActivityTime = time.Now().Add(-time.Hour)
|
||||
session.run() // Would normally not return
|
||||
Expect(conn.written[0]).To(ContainSubstring("No recent network activity."))
|
||||
Expect(closeCallbackCalled).To(BeTrue())
|
||||
Expect(session.runClosed).To(Receive())
|
||||
close(done)
|
||||
})
|
||||
|
||||
@@ -764,6 +771,8 @@ var _ = Describe("Session", func() {
|
||||
session.sessionCreationTime = time.Now().Add(-time.Hour)
|
||||
session.run() // Would normally not return
|
||||
Expect(conn.written[0]).To(ContainSubstring("Crypto handshake did not complete in time."))
|
||||
Expect(closeCallbackCalled).To(BeTrue())
|
||||
Expect(session.runClosed).To(Receive())
|
||||
close(done)
|
||||
})
|
||||
|
||||
@@ -773,6 +782,8 @@ var _ = Describe("Session", func() {
|
||||
session.packer.connectionParameters = session.connectionParameters
|
||||
session.run() // Would normally not return
|
||||
Expect(conn.written[0]).To(ContainSubstring("No recent network activity."))
|
||||
Expect(closeCallbackCalled).To(BeTrue())
|
||||
Expect(session.runClosed).To(Receive())
|
||||
close(done)
|
||||
})
|
||||
|
||||
@@ -784,6 +795,8 @@ var _ = Describe("Session", func() {
|
||||
session.packer.connectionParameters = session.connectionParameters
|
||||
session.run() // Would normally not return
|
||||
Expect(conn.written[0]).To(ContainSubstring("No recent network activity."))
|
||||
Expect(closeCallbackCalled).To(BeTrue())
|
||||
Expect(session.runClosed).To(Receive())
|
||||
close(done)
|
||||
})
|
||||
})
|
||||
|
||||
Reference in New Issue
Block a user