wait until the run loop has stopped before returning Session.Close()

fixes #371
This commit is contained in:
Marten Seemann
2016-12-19 20:35:21 +07:00
parent facaab7ff7
commit 156c23f2b7
2 changed files with 34 additions and 10 deletions

View File

@@ -70,6 +70,7 @@ type Session struct {
// closeChan is used to notify the run loop that it should terminate.
// If the value is not nil, the error is sent as a CONNECTION_CLOSE.
closeChan chan *qerr.QuicError
runClosed chan struct{}
closed uint32 // atomic bool
undecryptablePackets []*receivedPacket
@@ -124,6 +125,7 @@ func newSession(conn connection, v protocol.VersionNumber, connectionID protocol
sendingScheduled: make(chan struct{}, 1),
undecryptablePackets: make([]*receivedPacket, 0, protocol.MaxUndecryptablePackets),
aeadChanged: make(chan struct{}, 1),
runClosed: make(chan struct{}, 1), // this channel will receive once the run loop has been stopped
timer: time.NewTimer(0),
lastNetworkActivityTime: now,
@@ -151,10 +153,11 @@ func (s *Session) run() {
// Start the crypto stream handler
go func() {
if err := s.cryptoSetup.HandleCryptoStream(); err != nil {
s.Close(err)
s.close(err)
}
}()
runLoop:
for {
// Close immediately if requested
select {
@@ -162,7 +165,7 @@ func (s *Session) run() {
if errForConnClose != nil {
s.sendConnectionClose(errForConnClose)
}
return
break runLoop
default:
}
@@ -174,7 +177,7 @@ func (s *Session) run() {
if errForConnClose != nil {
s.sendConnectionClose(errForConnClose)
}
return
break runLoop
case <-s.timer.C:
s.timerRead = true
// We do all the interesting stuff after the switch statement, so
@@ -199,20 +202,23 @@ func (s *Session) run() {
}
if err != nil {
s.Close(err)
s.close(err)
}
if err := s.sendPacket(); err != nil {
s.Close(err)
s.close(err)
}
if time.Now().Sub(s.lastNetworkActivityTime) >= s.idleTimeout() {
s.Close(qerr.Error(qerr.NetworkIdleTimeout, "No recent network activity."))
s.close(qerr.Error(qerr.NetworkIdleTimeout, "No recent network activity."))
}
if !s.cryptoSetup.HandshakeComplete() && time.Now().Sub(s.sessionCreationTime) >= protocol.MaxTimeForCryptoHandshake {
s.Close(qerr.Error(qerr.NetworkIdleTimeout, "Crypto handshake did not complete in time."))
s.close(qerr.Error(qerr.NetworkIdleTimeout, "Crypto handshake did not complete in time."))
}
s.garbageCollectStreams()
}
s.closeCallback(s.connectionID)
s.runClosed <- struct{}{}
}
func (s *Session) maybeResetTimer() {
@@ -407,7 +413,22 @@ func (s *Session) handleAckFrame(frame *frames.AckFrame) error {
}
// Close the connection. If err is nil it will be set to qerr.PeerGoingAway.
// It waits until the run loop has stopped before returning
func (s *Session) Close(e error) error {
err := s.closeImpl(e, false)
if atomic.LoadUint32(&s.closed) == 1 {
return err
}
select {
case <-s.runClosed:
return err
}
}
// close the connection. Use this when called from the run loop
func (s *Session) close(e error) error {
return s.closeImpl(e, false)
}
@@ -431,7 +452,6 @@ func (s *Session) closeImpl(e error, remoteClose bool) error {
}
s.closeStreamsWithError(quicErr)
s.closeCallback(s.connectionID)
if remoteClose {
// If this is a remote close we don't need to send a CONNECTION_CLOSE

View File

@@ -424,10 +424,10 @@ var _ = Describe("Session", func() {
It("shuts down without error", func() {
session.Close(nil)
Expect(closeCallbackCalled).To(BeTrue())
Eventually(func() int { return runtime.NumGoroutine() }).Should(Equal(nGoRoutinesBefore))
Expect(conn.written).To(HaveLen(1))
Expect(conn.written[0][len(conn.written[0])-7:]).To(Equal([]byte{0x02, byte(qerr.PeerGoingAway), 0, 0, 0, 0, 0}))
Expect(closeCallbackCalled).To(BeTrue())
})
It("only closes once", func() {
@@ -442,8 +442,8 @@ var _ = Describe("Session", func() {
s, err := session.GetOrOpenStream(5)
Expect(err).NotTo(HaveOccurred())
session.Close(testErr)
Expect(closeCallbackCalled).To(BeTrue())
Eventually(func() int { return runtime.NumGoroutine() }).Should(Equal(nGoRoutinesBefore))
Expect(closeCallbackCalled).To(BeTrue())
n, err := s.Read([]byte{0})
Expect(n).To(BeZero())
Expect(err.Error()).To(ContainSubstring(testErr.Error()))
@@ -757,6 +757,7 @@ var _ = Describe("Session", func() {
session.lastNetworkActivityTime = time.Now().Add(-time.Hour)
session.run() // Would normally not return
Expect(conn.written[0]).To(ContainSubstring("No recent network activity."))
Expect(closeCallbackCalled).To(BeTrue())
close(done)
})
@@ -764,6 +765,7 @@ var _ = Describe("Session", func() {
session.sessionCreationTime = time.Now().Add(-time.Hour)
session.run() // Would normally not return
Expect(conn.written[0]).To(ContainSubstring("Crypto handshake did not complete in time."))
Expect(closeCallbackCalled).To(BeTrue())
close(done)
})
@@ -773,6 +775,7 @@ var _ = Describe("Session", func() {
session.packer.connectionParameters = session.connectionParameters
session.run() // Would normally not return
Expect(conn.written[0]).To(ContainSubstring("No recent network activity."))
Expect(closeCallbackCalled).To(BeTrue())
close(done)
})
@@ -784,6 +787,7 @@ var _ = Describe("Session", func() {
session.packer.connectionParameters = session.connectionParameters
session.run() // Would normally not return
Expect(conn.written[0]).To(ContainSubstring("No recent network activity."))
Expect(closeCallbackCalled).To(BeTrue())
close(done)
})
})