From 06b51871b1b7bc23966745ab14a3a814f2767a72 Mon Sep 17 00:00:00 2001 From: Marten Seemann Date: Sat, 30 Apr 2016 10:04:45 +0700 Subject: [PATCH] close session when receiving a ConnectionCloseFrame fixes #28 --- session.go | 16 +++++++++++----- session_test.go | 6 +++--- 2 files changed, 14 insertions(+), 8 deletions(-) diff --git a/session.go b/session.go index fefb5d21..eb0049d7 100644 --- a/session.go +++ b/session.go @@ -79,7 +79,7 @@ func NewSession(conn connection, v protocol.VersionNumber, connectionID protocol go func() { if err := cryptoSetup.HandleCryptoStream(); err != nil { - session.Close(err) + session.Close(err, true) } }() @@ -109,11 +109,11 @@ func (s *Session) Run() { case ackhandler.ErrDuplicateOrOutOfOrderAck: // Can happen when RST_STREAMs arrive early or late (?) case ackhandler.ErrMapAccess: - s.Close(err) // TODO: sent correct error code here + s.Close(err, true) // TODO: sent correct error code here case errRstStreamOnInvalidStream: fmt.Printf("Ignoring error in session: %s\n", err.Error()) default: - s.Close(err) + s.Close(err, true) } } @@ -155,6 +155,7 @@ func (s *Session) handlePacket(remoteAddr interface{}, publicHeader *PublicHeade // ToDo: send right error in ConnectionClose frame case *frames.ConnectionCloseFrame: fmt.Printf("%#v\n", frame) + s.Close(nil, false) case *frames.StopWaitingFrame: err = s.receivedPacketHandler.ReceivedStopWaiting(frame) fmt.Printf("\t<- %#v\n", frame) @@ -218,13 +219,18 @@ func (s *Session) handleRstStreamFrame(frame *frames.RstStreamFrame) error { return nil } -// Close the connection by sending a ConnectionClose frame -func (s *Session) Close(e error) error { +// Close the connection +func (s *Session) Close(e error, sendConnectionClose bool) error { if s.closed { return nil } s.closed = true s.closeChan <- struct{}{} + + if !sendConnectionClose { + return nil + } + if e == nil { e = protocol.NewQuicError(errorcodes.QUIC_PEER_GOING_AWAY, "peer going away") } diff --git a/session_test.go b/session_test.go index f0ccb195..fdc2afc8 100644 --- a/session_test.go +++ b/session_test.go @@ -194,7 +194,7 @@ var _ = Describe("Session", func() { }) It("shuts down without error", func() { - session.Close(nil) + session.Close(nil, true) time.Sleep(1 * time.Millisecond) Expect(runtime.NumGoroutine()).To(Equal(nGoRoutinesBefore)) }) @@ -203,7 +203,7 @@ var _ = Describe("Session", func() { testErr := errors.New("test error") s, err := session.NewStream(5) Expect(err).NotTo(HaveOccurred()) - session.Close(testErr) + session.Close(testErr, true) time.Sleep(1 * time.Millisecond) Expect(runtime.NumGoroutine()).To(Equal(nGoRoutinesBefore)) n, err := s.Read([]byte{0}) @@ -285,7 +285,7 @@ var _ = Describe("Session", func() { err = session.handlePacket(nil, hdr, r) Expect(err).To(HaveOccurred()) // Close() should send public reset - err = session.Close(err) + err = session.Close(err, true) Expect(err).NotTo(HaveOccurred()) Expect(conn.written).To(HaveLen(1)) Expect(conn.written[0]).To(ContainSubstring(string([]byte("PRST"))))