From 8fd2ddf81cfe8da965276780084af7068a21c71a Mon Sep 17 00:00:00 2001 From: Marten Seemann Date: Wed, 22 Feb 2017 15:16:59 +0700 Subject: [PATCH] properly close the client --- client.go | 22 +++++----------------- client_test.go | 34 +++++++++++----------------------- 2 files changed, 16 insertions(+), 40 deletions(-) diff --git a/client.go b/client.go index eef01db92..9aa4fd8b1 100644 --- a/client.go +++ b/client.go @@ -7,7 +7,6 @@ import ( "net" "strings" "sync" - "sync/atomic" "time" "github.com/lucas-clemente/quic-go/protocol" @@ -27,7 +26,6 @@ type client struct { connectionID protocol.ConnectionID version protocol.VersionNumber versionNegotiated bool - closed uint32 // atomic bool tlsConfig *tls.Config cryptoChangeCallback CryptoChangeCallback @@ -83,7 +81,7 @@ func Dial(pconn net.PacketConn, remoteAddr net.Addr, host string, config *Config utils.Infof("Starting new connection to %s (%s), connectionID %x, version %d", hostname, c.conn.RemoteAddr().String(), c.connectionID, c.version) - go c.Listen() + go c.listen() c.mutex.Lock() for !c.versionNegotiated { @@ -110,7 +108,7 @@ func DialAddr(hostname string, config *Config) (Session, error) { } // Listen listens -func (c *client) Listen() { +func (c *client) listen() { for { data := getPacketBuffer() data = data[:protocol.MaxPacketSize] @@ -133,17 +131,6 @@ func (c *client) Listen() { } } -// Close closes the connection -func (c *client) Close(e error) error { - // Only close once - if !atomic.CompareAndSwapUint32(&c.closed, 0, 1) { - return nil - } - - _ = c.session.Close(e) - return c.conn.Close() -} - func (c *client) handlePacket(remoteAddr net.Addr, packet []byte) error { if protocol.ByteCount(len(packet)) > protocol.MaxPacketSize { return qerr.PacketTooLarge @@ -249,6 +236,7 @@ func (c *client) createNewSession(negotiatedVersions []protocol.VersionNumber) e return nil } -func (c *client) closeCallback(id protocol.ConnectionID) { - utils.Infof("Connection %x closed.", id) +func (c *client) closeCallback(_ protocol.ConnectionID) { + utils.Infof("Connection %x closed.", c.connectionID) + c.conn.Close() } diff --git a/client_test.go b/client_test.go index 70f9a7add..b9f0e3923 100644 --- a/client_test.go +++ b/client_test.go @@ -67,35 +67,29 @@ var _ = Describe("Client", func() { Expect(err).To(MatchError(qerr.PacketTooLarge)) }) - PIt("properly closes the client", func(done Done) { + // this test requires a real session (because it calls the close callback) and a real UDP conn (because it unblocks and errors when it is closed) + It("properly closes", func(done Done) { + udpConn, err := net.ListenUDP("udp", &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1)}) + Expect(err).ToNot(HaveOccurred()) + cl.conn = &conn{pconn: udpConn} + err = cl.createNewSession(nil) testErr := errors.New("test error") time.Sleep(10 * time.Millisecond) // Wait for old goroutines to finish numGoRoutines := runtime.NumGoroutine() var stoppedListening bool go func() { - cl.Listen() + cl.listen() stoppedListening = true }() - err := cl.Close(testErr) + err = cl.session.Close(testErr) Expect(err).ToNot(HaveOccurred()) - Eventually(sess.closed).Should(BeTrue()) - Expect(sess.closeReason).To(MatchError(testErr)) - Expect(cl.closed).To(Equal(uint32(1))) Eventually(func() bool { return stoppedListening }).Should(BeTrue()) Eventually(runtime.NumGoroutine()).Should(Equal(numGoRoutines)) close(done) }, 10) - It("only closes the client once", func() { - cl.closed = 1 - err := cl.Close(errors.New("test error")) - Expect(err).ToNot(HaveOccurred()) - Eventually(sess.closed).Should(BeFalse()) - Expect(sess.closeReason).ToNot(HaveOccurred()) - }) - It("creates new sessions with the right parameters", func() { cl.session = nil cl.hostname = "hostname" @@ -104,9 +98,6 @@ var _ = Describe("Client", func() { Expect(cl.session).ToNot(BeNil()) Expect(cl.session.(*session).connectionID).To(Equal(cl.connectionID)) Expect(cl.session.(*session).version).To(Equal(cl.version)) - - err = cl.Close(nil) - Expect(err).ToNot(HaveOccurred()) }) Context("handling packets", func() { @@ -129,7 +120,7 @@ var _ = Describe("Client", func() { Expect(sess.packetCount).To(BeZero()) var stoppedListening bool go func() { - cl.Listen() + cl.listen() // it should continue listening when receiving valid packets stoppedListening = true }() @@ -142,7 +133,7 @@ var _ = Describe("Client", func() { It("closes the session when encountering an error while handling a packet", func() { Expect(sess.closeReason).ToNot(HaveOccurred()) packetConn.dataToRead = bytes.Repeat([]byte{0xff}, 100) - cl.Listen() + cl.listen() Expect(sess.closed).To(BeTrue()) Expect(sess.closeReason).To(HaveOccurred()) }) @@ -150,7 +141,7 @@ var _ = Describe("Client", func() { It("closes the session when encountering an error while reading from the connection", func() { testErr := errors.New("test error") packetConn.readErr = testErr - cl.Listen() + cl.listen() Expect(sess.closed).To(BeTrue()) Expect(sess.closeReason).To(MatchError(testErr)) }) @@ -204,9 +195,6 @@ var _ = Describe("Client", func() { // it didn't pass the version negoation packet to the session (since it has no payload) Expect(sess.packetCount).To(BeZero()) Expect(*(*[]protocol.VersionNumber)(unsafe.Pointer(reflect.ValueOf(cl.session.(*session).cryptoSetup).Elem().FieldByName("negotiatedVersions").UnsafeAddr()))).To(Equal([]protocol.VersionNumber{35})) - - err = cl.Close(nil) - Expect(err).ToNot(HaveOccurred()) }) It("errors if no matching version is found", func() {