properly close the client

This commit is contained in:
Marten Seemann
2017-02-22 15:16:59 +07:00
parent 8247454b0f
commit 8fd2ddf81c
2 changed files with 16 additions and 40 deletions

View File

@@ -7,7 +7,6 @@ import (
"net"
"strings"
"sync"
"sync/atomic"
"time"
"github.com/lucas-clemente/quic-go/protocol"
@@ -27,7 +26,6 @@ type client struct {
connectionID protocol.ConnectionID
version protocol.VersionNumber
versionNegotiated bool
closed uint32 // atomic bool
tlsConfig *tls.Config
cryptoChangeCallback CryptoChangeCallback
@@ -83,7 +81,7 @@ func Dial(pconn net.PacketConn, remoteAddr net.Addr, host string, config *Config
utils.Infof("Starting new connection to %s (%s), connectionID %x, version %d", hostname, c.conn.RemoteAddr().String(), c.connectionID, c.version)
go c.Listen()
go c.listen()
c.mutex.Lock()
for !c.versionNegotiated {
@@ -110,7 +108,7 @@ func DialAddr(hostname string, config *Config) (Session, error) {
}
// Listen listens
func (c *client) Listen() {
func (c *client) listen() {
for {
data := getPacketBuffer()
data = data[:protocol.MaxPacketSize]
@@ -133,17 +131,6 @@ func (c *client) Listen() {
}
}
// Close closes the connection
func (c *client) Close(e error) error {
// Only close once
if !atomic.CompareAndSwapUint32(&c.closed, 0, 1) {
return nil
}
_ = c.session.Close(e)
return c.conn.Close()
}
func (c *client) handlePacket(remoteAddr net.Addr, packet []byte) error {
if protocol.ByteCount(len(packet)) > protocol.MaxPacketSize {
return qerr.PacketTooLarge
@@ -249,6 +236,7 @@ func (c *client) createNewSession(negotiatedVersions []protocol.VersionNumber) e
return nil
}
func (c *client) closeCallback(id protocol.ConnectionID) {
utils.Infof("Connection %x closed.", id)
func (c *client) closeCallback(_ protocol.ConnectionID) {
utils.Infof("Connection %x closed.", c.connectionID)
c.conn.Close()
}

View File

@@ -67,35 +67,29 @@ var _ = Describe("Client", func() {
Expect(err).To(MatchError(qerr.PacketTooLarge))
})
PIt("properly closes the client", func(done Done) {
// this test requires a real session (because it calls the close callback) and a real UDP conn (because it unblocks and errors when it is closed)
It("properly closes", func(done Done) {
udpConn, err := net.ListenUDP("udp", &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1)})
Expect(err).ToNot(HaveOccurred())
cl.conn = &conn{pconn: udpConn}
err = cl.createNewSession(nil)
testErr := errors.New("test error")
time.Sleep(10 * time.Millisecond) // Wait for old goroutines to finish
numGoRoutines := runtime.NumGoroutine()
var stoppedListening bool
go func() {
cl.Listen()
cl.listen()
stoppedListening = true
}()
err := cl.Close(testErr)
err = cl.session.Close(testErr)
Expect(err).ToNot(HaveOccurred())
Eventually(sess.closed).Should(BeTrue())
Expect(sess.closeReason).To(MatchError(testErr))
Expect(cl.closed).To(Equal(uint32(1)))
Eventually(func() bool { return stoppedListening }).Should(BeTrue())
Eventually(runtime.NumGoroutine()).Should(Equal(numGoRoutines))
close(done)
}, 10)
It("only closes the client once", func() {
cl.closed = 1
err := cl.Close(errors.New("test error"))
Expect(err).ToNot(HaveOccurred())
Eventually(sess.closed).Should(BeFalse())
Expect(sess.closeReason).ToNot(HaveOccurred())
})
It("creates new sessions with the right parameters", func() {
cl.session = nil
cl.hostname = "hostname"
@@ -104,9 +98,6 @@ var _ = Describe("Client", func() {
Expect(cl.session).ToNot(BeNil())
Expect(cl.session.(*session).connectionID).To(Equal(cl.connectionID))
Expect(cl.session.(*session).version).To(Equal(cl.version))
err = cl.Close(nil)
Expect(err).ToNot(HaveOccurred())
})
Context("handling packets", func() {
@@ -129,7 +120,7 @@ var _ = Describe("Client", func() {
Expect(sess.packetCount).To(BeZero())
var stoppedListening bool
go func() {
cl.Listen()
cl.listen()
// it should continue listening when receiving valid packets
stoppedListening = true
}()
@@ -142,7 +133,7 @@ var _ = Describe("Client", func() {
It("closes the session when encountering an error while handling a packet", func() {
Expect(sess.closeReason).ToNot(HaveOccurred())
packetConn.dataToRead = bytes.Repeat([]byte{0xff}, 100)
cl.Listen()
cl.listen()
Expect(sess.closed).To(BeTrue())
Expect(sess.closeReason).To(HaveOccurred())
})
@@ -150,7 +141,7 @@ var _ = Describe("Client", func() {
It("closes the session when encountering an error while reading from the connection", func() {
testErr := errors.New("test error")
packetConn.readErr = testErr
cl.Listen()
cl.listen()
Expect(sess.closed).To(BeTrue())
Expect(sess.closeReason).To(MatchError(testErr))
})
@@ -204,9 +195,6 @@ var _ = Describe("Client", func() {
// it didn't pass the version negoation packet to the session (since it has no payload)
Expect(sess.packetCount).To(BeZero())
Expect(*(*[]protocol.VersionNumber)(unsafe.Pointer(reflect.ValueOf(cl.session.(*session).cryptoSetup).Elem().FieldByName("negotiatedVersions").UnsafeAddr()))).To(Equal([]protocol.VersionNumber{35}))
err = cl.Close(nil)
Expect(err).ToNot(HaveOccurred())
})
It("errors if no matching version is found", func() {