diff --git a/h2quic/server.go b/h2quic/server.go index cff0a62e..eb416bd8 100644 --- a/h2quic/server.go +++ b/h2quic/server.go @@ -16,7 +16,7 @@ import ( type streamCreator interface { GetOrOpenStream(protocol.StreamID) (utils.Stream, error) - Close(error, bool) error + Close(error) error } // Server is a HTTP2 server listening for QUIC connections @@ -111,7 +111,7 @@ func (s *Server) handleRequest(session streamCreator, headerStream utils.Stream, } if s.CloseAfterFirstRequest { time.Sleep(100 * time.Millisecond) - session.Close(nil, true) + session.Close(nil) } }() diff --git a/h2quic/server_test.go b/h2quic/server_test.go index 25b0efbf..1d71f159 100644 --- a/h2quic/server_test.go +++ b/h2quic/server_test.go @@ -22,7 +22,7 @@ func (s *mockSession) GetOrOpenStream(id protocol.StreamID) (utils.Stream, error return &mockStream{}, nil } -func (s *mockSession) Close(error, bool) error { s.closed = true; return nil } +func (s *mockSession) Close(error) error { s.closed = true; return nil } var _ = Describe("H2 server", func() { var ( diff --git a/session.go b/session.go index 9ac0c782..47a7b84b 100644 --- a/session.go +++ b/session.go @@ -5,6 +5,7 @@ import ( "errors" "fmt" "sync" + "sync/atomic" "time" "github.com/lucas-clemente/quic-go/ackhandler" @@ -58,7 +59,7 @@ type Session struct { receivedPackets chan receivedPacket sendingScheduled chan struct{} closeChan chan struct{} - closed bool + closed uint32 // atomic bool undecryptablePackets []receivedPacket aeadChanged chan struct{} @@ -102,7 +103,7 @@ func newSession(conn connection, v protocol.VersionNumber, connectionID protocol go func() { if err := cryptoSetup.HandleCryptoStream(); err != nil { - session.Close(err, true) + session.Close(err) } }() @@ -160,22 +161,20 @@ func (s *Session) run() { case <-s.aeadChanged: s.tryDecryptingQueuedPackets() case <-time.After(s.connectionParametersManager.GetIdleConnectionStateLifetime()): - s.Close(qerr.Error(qerr.NetworkIdleTimeout, "No recent network activity."), true) + s.Close(qerr.Error(qerr.NetworkIdleTimeout, "No recent network activity.")) } if err != nil { switch err { - // Can happen e.g. when packets thought missing arrive late case ackhandler.ErrDuplicateOrOutOfOrderAck: - // Can happen when RST_STREAMs arrive early or late (?) - case ackhandler.ErrMapAccess: - s.Close(err, true) // TODO: sent correct error code here + // Can happen e.g. when packets thought missing arrive late case errRstStreamOnInvalidStream: + // Can happen when RST_STREAMs arrive early or late (?) utils.Errorf("Ignoring error in session: %s", err.Error()) - // Can happen when we already sent the last StreamFrame with the FinBit, but the client already sent a WindowUpdate for this Stream case errWindowUpdateOnClosedStream: + // Can happen when we already sent the last StreamFrame with the FinBit, but the client already sent a WindowUpdate for this Stream default: - s.Close(err, true) + s.Close(err) } } @@ -216,9 +215,8 @@ func (s *Session) handlePacketImpl(remoteAddr interface{}, hdr *publicHeader, da case *frames.AckFrame: err = s.handleAckFrame(frame) case *frames.ConnectionCloseFrame: - // ToDo: send right error in ConnectionClose frame utils.Debugf("\t<- %#v", frame) - s.Close(nil, false) + s.closeImpl(qerr.Error(frame.ErrorCode, frame.ReasonPhrase), true) case *frames.StopWaitingFrame: utils.Debugf("\t<- %#v", frame) err = s.receivedPacketHandler.ReceivedStopWaiting(frame) @@ -348,25 +346,28 @@ func (s *Session) handleAckFrame(frame *frames.AckFrame) error { } // Close the connection -func (s *Session) Close(e error, sendConnectionClose bool) error { - if s.closed { +func (s *Session) Close(e error) error { + return s.closeImpl(e, false) +} + +func (s *Session) closeImpl(e error, remoteClose bool) error { + // Only close once + if !atomic.CompareAndSwapUint32(&s.closed, 0, 1) { return nil } - s.closed = true s.closeChan <- struct{}{} - s.closeCallback(s.connectionID) - - if !sendConnectionClose { - return nil - } - if e == nil { e = qerr.PeerGoingAway } utils.Errorf("Closing session with error: %s", e.Error()) s.closeStreamsWithError(e) + s.closeCallback(s.connectionID) + + if remoteClose { + return nil + } quicErr := qerr.ToQuicError(e) if quicErr.ErrorCode == qerr.DecryptionFailure { @@ -622,7 +623,7 @@ func (s *Session) congestionAllowsSending() bool { func (s *Session) tryQueueingUndecryptablePacket(p receivedPacket) { utils.Debugf("Queueing packet 0x%x for later decryption", p.publicHeader.PacketNumber) if len(s.undecryptablePackets)+1 >= protocol.MaxUndecryptablePackets { - s.Close(qerr.Error(qerr.DecryptionFailure, "too many undecryptable packets received"), true) + s.Close(qerr.Error(qerr.DecryptionFailure, "too many undecryptable packets received")) } s.undecryptablePackets = append(s.undecryptablePackets, p) } diff --git a/session_test.go b/session_test.go index d39f4064..ceb6b13b 100644 --- a/session_test.go +++ b/session_test.go @@ -5,6 +5,7 @@ import ( "errors" "io" "runtime" + "sync/atomic" "time" . "github.com/onsi/ginkgo" @@ -328,16 +329,25 @@ var _ = Describe("Session", func() { }) It("shuts down without error", func() { - session.Close(nil, true) + session.Close(nil) Expect(closed).To(BeTrue()) Eventually(func() int { return runtime.NumGoroutine() }).Should(Equal(nGoRoutinesBefore)) + Expect(conn.written).To(HaveLen(1)) + Expect(conn.written[0][len(conn.written[0])-7:]).To(Equal([]byte{0x02, byte(qerr.PeerGoingAway), 0, 0, 0, 0, 0})) + }) + + It("only closes once", func() { + session.Close(nil) + session.Close(nil) + Eventually(func() int { return runtime.NumGoroutine() }).Should(Equal(nGoRoutinesBefore)) + Expect(conn.written).To(HaveLen(1)) }) It("closes streams with proper error", func() { testErr := errors.New("test error") s, err := session.OpenStream(5) Expect(err).NotTo(HaveOccurred()) - session.Close(testErr, true) + session.Close(testErr) Expect(closed).To(BeTrue()) Eventually(func() int { return runtime.NumGoroutine() }).Should(Equal(nGoRoutinesBefore)) n, err := s.Read([]byte{0}) @@ -554,7 +564,7 @@ var _ = Describe("Session", func() { Data: []byte("4242\x00\x00\x00\x00"), }) Expect(err).NotTo(HaveOccurred()) - Eventually(func() bool { return session.closed }).Should(BeTrue()) + Eventually(func() bool { return atomic.LoadUint32(&session.closed) != 0 }).Should(BeTrue()) _, err = s.Write([]byte{}) Expect(err).To(MatchError(qerr.InvalidCryptoMessageType)) })