fix detection if the Session was already closed before

This commit is contained in:
Marten Seemann
2016-12-20 10:04:42 +07:00
parent 156c23f2b7
commit 0de6368bbb
2 changed files with 23 additions and 11 deletions

View File

@@ -32,6 +32,7 @@ type receivedPacket struct {
var (
errRstStreamOnInvalidStream = errors.New("RST_STREAM received for unknown stream")
errWindowUpdateOnClosedStream = errors.New("WINDOW_UPDATE received for an already closed stream")
errSessionAlreadyClosed = errors.New("Cannot close Session. It was already closed before.")
)
// StreamCallback gets a stream frame and returns a reply frame
@@ -153,7 +154,7 @@ func (s *Session) run() {
// Start the crypto stream handler
go func() {
if err := s.cryptoSetup.HandleCryptoStream(); err != nil {
s.close(err)
s.Close(err)
}
}()
@@ -416,26 +417,28 @@ func (s *Session) handleAckFrame(frame *frames.AckFrame) error {
// It waits until the run loop has stopped before returning
func (s *Session) Close(e error) error {
err := s.closeImpl(e, false)
if atomic.LoadUint32(&s.closed) == 1 {
return err
if err == errSessionAlreadyClosed {
return nil
}
select {
case <-s.runClosed:
return err
}
// wait for the run loop to finish
<-s.runClosed
return err
}
// close the connection. Use this when called from the run loop
func (s *Session) close(e error) error {
return s.closeImpl(e, false)
err := s.closeImpl(e, false)
if err == errSessionAlreadyClosed {
return nil
}
return err
}
func (s *Session) closeImpl(e error, remoteClose bool) error {
// Only close once
if !atomic.CompareAndSwapUint32(&s.closed, 0, 1) {
return nil
return errSessionAlreadyClosed
}
if e == nil {
@@ -671,7 +674,7 @@ func (s *Session) tryQueueingUndecryptablePacket(p *receivedPacket) {
}
utils.Infof("Queueing packet 0x%x for later decryption", p.publicHeader.PacketNumber)
if len(s.undecryptablePackets)+1 >= protocol.MaxUndecryptablePackets {
s.Close(qerr.Error(qerr.DecryptionFailure, "too many undecryptable packets received"))
s.close(qerr.Error(qerr.DecryptionFailure, "too many undecryptable packets received"))
}
s.undecryptablePackets = append(s.undecryptablePackets, p)
}

View File

@@ -428,6 +428,7 @@ var _ = Describe("Session", func() {
Expect(conn.written).To(HaveLen(1))
Expect(conn.written[0][len(conn.written[0])-7:]).To(Equal([]byte{0x02, byte(qerr.PeerGoingAway), 0, 0, 0, 0, 0}))
Expect(closeCallbackCalled).To(BeTrue())
Expect(session.runClosed).ToNot(Receive()) // channel should be drained by Close()
})
It("only closes once", func() {
@@ -435,6 +436,7 @@ var _ = Describe("Session", func() {
session.Close(nil)
Eventually(func() int { return runtime.NumGoroutine() }).Should(Equal(nGoRoutinesBefore))
Expect(conn.written).To(HaveLen(1))
Expect(session.runClosed).ToNot(Receive()) // channel should be drained by Close()
})
It("closes streams with proper error", func() {
@@ -450,6 +452,7 @@ var _ = Describe("Session", func() {
n, err = s.Write([]byte{0})
Expect(n).To(BeZero())
Expect(err.Error()).To(ContainSubstring(testErr.Error()))
Expect(session.runClosed).ToNot(Receive()) // channel should be drained by Close()
})
})
@@ -727,6 +730,7 @@ var _ = Describe("Session", func() {
Expect(conn.written).To(HaveLen(1))
Expect(conn.written[0]).To(ContainSubstring(string([]byte("PRST"))))
Expect(session.runClosed).To(Receive())
})
It("ignores undecryptable packets after the handshake is complete", func() {
@@ -740,6 +744,7 @@ var _ = Describe("Session", func() {
go session.run()
Consistently(session.undecryptablePackets).Should(HaveLen(0))
session.closeImpl(nil, true)
Eventually(session.runClosed).Should(Receive())
})
It("unqueues undecryptable packets for later decryption", func() {
@@ -758,6 +763,7 @@ var _ = Describe("Session", func() {
session.run() // Would normally not return
Expect(conn.written[0]).To(ContainSubstring("No recent network activity."))
Expect(closeCallbackCalled).To(BeTrue())
Expect(session.runClosed).To(Receive())
close(done)
})
@@ -766,6 +772,7 @@ var _ = Describe("Session", func() {
session.run() // Would normally not return
Expect(conn.written[0]).To(ContainSubstring("Crypto handshake did not complete in time."))
Expect(closeCallbackCalled).To(BeTrue())
Expect(session.runClosed).To(Receive())
close(done)
})
@@ -776,6 +783,7 @@ var _ = Describe("Session", func() {
session.run() // Would normally not return
Expect(conn.written[0]).To(ContainSubstring("No recent network activity."))
Expect(closeCallbackCalled).To(BeTrue())
Expect(session.runClosed).To(Receive())
close(done)
})
@@ -788,6 +796,7 @@ var _ = Describe("Session", func() {
session.run() // Would normally not return
Expect(conn.written[0]).To(ContainSubstring("No recent network activity."))
Expect(closeCallbackCalled).To(BeTrue())
Expect(session.runClosed).To(Receive())
close(done)
})
})