diff --git a/client.go b/client.go index 5cd4b52be..be0b78c4f 100644 --- a/client.go +++ b/client.go @@ -110,8 +110,8 @@ func (c *Client) OpenStream(id protocol.StreamID) (utils.Stream, error) { } // Close closes the connection -func (c *Client) Close() error { - _ = c.session.Close(nil) +func (c *Client) Close(e error) error { + _ = c.session.Close(e) return c.conn.Close() } diff --git a/client_test.go b/client_test.go index f350b333b..23e2913c4 100644 --- a/client_test.go +++ b/client_test.go @@ -3,7 +3,9 @@ package quic import ( "bytes" "encoding/binary" + "errors" "net" + "runtime" "github.com/lucas-clemente/quic-go/protocol" "github.com/lucas-clemente/quic-go/qerr" @@ -59,6 +61,7 @@ var _ = Describe("Client", func() { }) It("properly closes the client", func(done Done) { + numGoRoutines := runtime.NumGoroutine() startUDPConn() var stoppedListening bool go func() { @@ -67,11 +70,13 @@ var _ = Describe("Client", func() { stoppedListening = true }() - err := client.Close() + testErr := errors.New("test error") + err := client.Close(testErr) Expect(err).ToNot(HaveOccurred()) Eventually(session.closed).Should(BeTrue()) - Expect(session.closeReason).To(BeNil()) + Expect(session.closeReason).To(MatchError(testErr)) Eventually(func() bool { return stoppedListening }).Should(BeTrue()) + Eventually(runtime.NumGoroutine()).Should(Equal(numGoRoutines)) close(done) }) @@ -85,7 +90,7 @@ var _ = Describe("Client", func() { Expect(client.session.(*Session).connectionID).To(Equal(client.connectionID)) Expect(client.session.(*Session).version).To(Equal(client.version)) - err = client.Close() + err = client.Close(nil) Expect(err).ToNot(HaveOccurred()) }) @@ -130,7 +135,7 @@ var _ = Describe("Client", func() { Expect(session.closed).To(BeFalse()) Eventually(func() bool { return stoppedListening }).Should(BeFalse()) - err = client.Close() + err = client.Close(nil) Expect(err).ToNot(HaveOccurred()) close(done) }) @@ -202,7 +207,7 @@ var _ = Describe("Client", func() { // it didn't pass the version negoation packet to the session (since it has no payload) Expect(session.packetCount).To(BeZero()) - err = client.Close() + err = client.Close(nil) Expect(err).ToNot(HaveOccurred()) }) diff --git a/example/client/client.go b/example/client/client.go index d0966d82b..d6849c103 100644 --- a/example/client/client.go +++ b/example/client/client.go @@ -16,7 +16,7 @@ func main() { } err = client.Listen() - defer client.Close() + defer client.Close(nil) if err != nil { panic(err) } diff --git a/h2quic/client.go b/h2quic/client.go index d8bc3b0ed..5f12c9f2f 100644 --- a/h2quic/client.go +++ b/h2quic/client.go @@ -16,7 +16,7 @@ import ( type quicClient interface { OpenStream(protocol.StreamID) (utils.Stream, error) - Close() error + Close(error) error Listen() error } diff --git a/h2quic/client_test.go b/h2quic/client_test.go index b30873928..f078663cc 100644 --- a/h2quic/client_test.go +++ b/h2quic/client_test.go @@ -14,8 +14,8 @@ type mockQuicClient struct { streams map[protocol.StreamID]*mockStream } -func (m *mockQuicClient) Close() error { panic("not implemented") } -func (m *mockQuicClient) Listen() error { panic("not implemented") } +func (m *mockQuicClient) Close(error) error { panic("not implemented") } +func (m *mockQuicClient) Listen() error { panic("not implemented") } func (m *mockQuicClient) OpenStream(id protocol.StreamID) (utils.Stream, error) { _, ok := m.streams[id] if ok {