diff --git a/client.go b/client.go index be0b78c4f..1b6756407 100644 --- a/client.go +++ b/client.go @@ -6,6 +6,7 @@ import ( "math/rand" "net" "strings" + "sync/atomic" "time" "github.com/lucas-clemente/quic-go/protocol" @@ -22,6 +23,7 @@ type Client struct { connectionID protocol.ConnectionID version protocol.VersionNumber versionNegotiated bool + closed uint32 // atomic bool cryptoChangeCallback CryptoChangeCallback versionNegotiateCallback VersionNegotiateCallback @@ -111,6 +113,11 @@ func (c *Client) OpenStream(id protocol.StreamID) (utils.Stream, error) { // Close closes the connection func (c *Client) Close(e error) error { + // Only close once + if !atomic.CompareAndSwapUint32(&c.closed, 0, 1) { + return nil + } + _ = c.session.Close(e) return c.conn.Close() } diff --git a/client_test.go b/client_test.go index be2bb95eb..4941c30ed 100644 --- a/client_test.go +++ b/client_test.go @@ -77,11 +77,20 @@ var _ = Describe("Client", func() { Expect(err).ToNot(HaveOccurred()) Eventually(session.closed).Should(BeTrue()) Expect(session.closeReason).To(MatchError(testErr)) + Expect(client.closed).To(Equal(uint32(1))) Eventually(func() bool { return stoppedListening }).Should(BeTrue()) Eventually(runtime.NumGoroutine()).Should(Equal(numGoRoutines)) close(done) }) + It("only closes the client once", func() { + client.closed = 1 + err := client.Close(errors.New("test error")) + Expect(err).ToNot(HaveOccurred()) + Eventually(session.closed).Should(BeFalse()) + Expect(session.closeReason).ToNot(HaveOccurred()) + }) + It("creates new sessions with the right parameters", func() { startUDPConn() client.session = nil