diff --git a/client.go b/client.go index 5e1d2f519..8ec242347 100644 --- a/client.go +++ b/client.go @@ -22,7 +22,7 @@ type Client struct { connectionID protocol.ConnectionID version protocol.VersionNumber - session *Session + session packetHandler } var errHostname = errors.New("Invalid hostname") @@ -97,6 +97,12 @@ func (c *Client) Listen() error { } } +// Close closes the connection +func (c *Client) Close() error { + _ = c.session.Close(nil) + return c.conn.Close() +} + func (c *Client) handlePacket(packet []byte) error { if protocol.ByteCount(len(packet)) > protocol.MaxPacketSize { return qerr.PacketTooLarge @@ -123,5 +129,4 @@ func (c *Client) handlePacket(packet []byte) error { func (c *Client) closeCallback(id protocol.ConnectionID) { utils.Infof("Connection %x closed.", id) - c.conn.Close() } diff --git a/client_test.go b/client_test.go index bff1a8099..37118cd8b 100644 --- a/client_test.go +++ b/client_test.go @@ -2,6 +2,7 @@ package quic import ( "bytes" + "net" "github.com/lucas-clemente/quic-go/protocol" "github.com/lucas-clemente/quic-go/qerr" @@ -26,4 +27,15 @@ var _ = Describe("Client", func() { err := client.handlePacket(bytes.Repeat([]byte{'a'}, int(protocol.MaxPacketSize)+1)) Expect(err).To(MatchError(qerr.PacketTooLarge)) }) + + It("closes sessions when Close is called", func() { + addr, err := net.ResolveUDPAddr("udp", "127.0.0.1:0") + Expect(err).ToNot(HaveOccurred()) + client.conn, err = net.ListenUDP("udp", addr) + Expect(err).ToNot(HaveOccurred()) + client.session = &mockSession{} + err = client.Close() + Expect(err).ToNot(HaveOccurred()) + Expect(client.session.(*mockSession).closed).To(BeTrue()) + }) }) diff --git a/example/client/client.go b/example/client/client.go index 24b03dbd9..4898fb61c 100644 --- a/example/client/client.go +++ b/example/client/client.go @@ -16,6 +16,7 @@ func main() { } err = client.Listen() + defer client.Close() if err != nil { panic(err) }