forked from quic-go/quic-go
properly close the client
This commit is contained in:
22
client.go
22
client.go
@@ -7,7 +7,6 @@ import (
|
||||
"net"
|
||||
"strings"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/lucas-clemente/quic-go/protocol"
|
||||
@@ -27,7 +26,6 @@ type client struct {
|
||||
connectionID protocol.ConnectionID
|
||||
version protocol.VersionNumber
|
||||
versionNegotiated bool
|
||||
closed uint32 // atomic bool
|
||||
|
||||
tlsConfig *tls.Config
|
||||
cryptoChangeCallback CryptoChangeCallback
|
||||
@@ -83,7 +81,7 @@ func Dial(pconn net.PacketConn, remoteAddr net.Addr, host string, config *Config
|
||||
|
||||
utils.Infof("Starting new connection to %s (%s), connectionID %x, version %d", hostname, c.conn.RemoteAddr().String(), c.connectionID, c.version)
|
||||
|
||||
go c.Listen()
|
||||
go c.listen()
|
||||
|
||||
c.mutex.Lock()
|
||||
for !c.versionNegotiated {
|
||||
@@ -110,7 +108,7 @@ func DialAddr(hostname string, config *Config) (Session, error) {
|
||||
}
|
||||
|
||||
// Listen listens
|
||||
func (c *client) Listen() {
|
||||
func (c *client) listen() {
|
||||
for {
|
||||
data := getPacketBuffer()
|
||||
data = data[:protocol.MaxPacketSize]
|
||||
@@ -133,17 +131,6 @@ func (c *client) Listen() {
|
||||
}
|
||||
}
|
||||
|
||||
// Close closes the connection
|
||||
func (c *client) Close(e error) error {
|
||||
// Only close once
|
||||
if !atomic.CompareAndSwapUint32(&c.closed, 0, 1) {
|
||||
return nil
|
||||
}
|
||||
|
||||
_ = c.session.Close(e)
|
||||
return c.conn.Close()
|
||||
}
|
||||
|
||||
func (c *client) handlePacket(remoteAddr net.Addr, packet []byte) error {
|
||||
if protocol.ByteCount(len(packet)) > protocol.MaxPacketSize {
|
||||
return qerr.PacketTooLarge
|
||||
@@ -249,6 +236,7 @@ func (c *client) createNewSession(negotiatedVersions []protocol.VersionNumber) e
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *client) closeCallback(id protocol.ConnectionID) {
|
||||
utils.Infof("Connection %x closed.", id)
|
||||
func (c *client) closeCallback(_ protocol.ConnectionID) {
|
||||
utils.Infof("Connection %x closed.", c.connectionID)
|
||||
c.conn.Close()
|
||||
}
|
||||
|
||||
@@ -67,35 +67,29 @@ var _ = Describe("Client", func() {
|
||||
Expect(err).To(MatchError(qerr.PacketTooLarge))
|
||||
})
|
||||
|
||||
PIt("properly closes the client", func(done Done) {
|
||||
// this test requires a real session (because it calls the close callback) and a real UDP conn (because it unblocks and errors when it is closed)
|
||||
It("properly closes", func(done Done) {
|
||||
udpConn, err := net.ListenUDP("udp", &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1)})
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
cl.conn = &conn{pconn: udpConn}
|
||||
err = cl.createNewSession(nil)
|
||||
testErr := errors.New("test error")
|
||||
time.Sleep(10 * time.Millisecond) // Wait for old goroutines to finish
|
||||
numGoRoutines := runtime.NumGoroutine()
|
||||
|
||||
var stoppedListening bool
|
||||
go func() {
|
||||
cl.Listen()
|
||||
cl.listen()
|
||||
stoppedListening = true
|
||||
}()
|
||||
|
||||
err := cl.Close(testErr)
|
||||
err = cl.session.Close(testErr)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Eventually(sess.closed).Should(BeTrue())
|
||||
Expect(sess.closeReason).To(MatchError(testErr))
|
||||
Expect(cl.closed).To(Equal(uint32(1)))
|
||||
Eventually(func() bool { return stoppedListening }).Should(BeTrue())
|
||||
Eventually(runtime.NumGoroutine()).Should(Equal(numGoRoutines))
|
||||
close(done)
|
||||
}, 10)
|
||||
|
||||
It("only closes the client once", func() {
|
||||
cl.closed = 1
|
||||
err := cl.Close(errors.New("test error"))
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Eventually(sess.closed).Should(BeFalse())
|
||||
Expect(sess.closeReason).ToNot(HaveOccurred())
|
||||
})
|
||||
|
||||
It("creates new sessions with the right parameters", func() {
|
||||
cl.session = nil
|
||||
cl.hostname = "hostname"
|
||||
@@ -104,9 +98,6 @@ var _ = Describe("Client", func() {
|
||||
Expect(cl.session).ToNot(BeNil())
|
||||
Expect(cl.session.(*session).connectionID).To(Equal(cl.connectionID))
|
||||
Expect(cl.session.(*session).version).To(Equal(cl.version))
|
||||
|
||||
err = cl.Close(nil)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
})
|
||||
|
||||
Context("handling packets", func() {
|
||||
@@ -129,7 +120,7 @@ var _ = Describe("Client", func() {
|
||||
Expect(sess.packetCount).To(BeZero())
|
||||
var stoppedListening bool
|
||||
go func() {
|
||||
cl.Listen()
|
||||
cl.listen()
|
||||
// it should continue listening when receiving valid packets
|
||||
stoppedListening = true
|
||||
}()
|
||||
@@ -142,7 +133,7 @@ var _ = Describe("Client", func() {
|
||||
It("closes the session when encountering an error while handling a packet", func() {
|
||||
Expect(sess.closeReason).ToNot(HaveOccurred())
|
||||
packetConn.dataToRead = bytes.Repeat([]byte{0xff}, 100)
|
||||
cl.Listen()
|
||||
cl.listen()
|
||||
Expect(sess.closed).To(BeTrue())
|
||||
Expect(sess.closeReason).To(HaveOccurred())
|
||||
})
|
||||
@@ -150,7 +141,7 @@ var _ = Describe("Client", func() {
|
||||
It("closes the session when encountering an error while reading from the connection", func() {
|
||||
testErr := errors.New("test error")
|
||||
packetConn.readErr = testErr
|
||||
cl.Listen()
|
||||
cl.listen()
|
||||
Expect(sess.closed).To(BeTrue())
|
||||
Expect(sess.closeReason).To(MatchError(testErr))
|
||||
})
|
||||
@@ -204,9 +195,6 @@ var _ = Describe("Client", func() {
|
||||
// it didn't pass the version negoation packet to the session (since it has no payload)
|
||||
Expect(sess.packetCount).To(BeZero())
|
||||
Expect(*(*[]protocol.VersionNumber)(unsafe.Pointer(reflect.ValueOf(cl.session.(*session).cryptoSetup).Elem().FieldByName("negotiatedVersions").UnsafeAddr()))).To(Equal([]protocol.VersionNumber{35}))
|
||||
|
||||
err = cl.Close(nil)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
})
|
||||
|
||||
It("errors if no matching version is found", func() {
|
||||
|
||||
Reference in New Issue
Block a user