From a8db148dbf9c525bdd2cb21f419240e5b636d527 Mon Sep 17 00:00:00 2001 From: Marten Seemann Date: Tue, 27 Nov 2018 13:39:44 +0700 Subject: [PATCH] don't lock the client mutex when handling regular packets --- client.go | 15 ++++++++------- client_test.go | 2 +- 2 files changed, 9 insertions(+), 8 deletions(-) diff --git a/client.go b/client.go index f057331d..c4fae8a6 100644 --- a/client.go +++ b/client.go @@ -27,7 +27,7 @@ type client struct { token []byte - versionNegotiated bool // has the server accepted our version + versionNegotiated utils.AtomicBool // has the server accepted our version receivedVersionNegotiationPacket bool negotiatedVersions []protocol.VersionNumber // the list of versions from the version negotiation packet @@ -295,28 +295,29 @@ func (c *client) handlePacket(p *receivedPacket) { } func (c *client) handlePacketImpl(p *receivedPacket) error { - c.mutex.Lock() - defer c.mutex.Unlock() - // handle Version Negotiation Packets if p.hdr.IsVersionNegotiation() { + c.mutex.Lock() err := c.handleVersionNegotiationPacket(p.hdr) if err != nil { c.session.destroy(err) } + c.mutex.Unlock() // version negotiation packets have no payload return err } if p.hdr.Type == protocol.PacketTypeRetry { + c.mutex.Lock() c.handleRetryPacket(p.hdr) + c.mutex.Unlock() return nil } // this is the first packet we are receiving // since it is not a Version Negotiation Packet, this means the server supports the suggested version - if !c.versionNegotiated { - c.versionNegotiated = true + if !c.versionNegotiated.Get() { + c.versionNegotiated.Set(true) } c.session.handlePacket(p) @@ -325,7 +326,7 @@ func (c *client) handlePacketImpl(p *receivedPacket) error { func (c *client) handleVersionNegotiationPacket(hdr *wire.Header) error { // ignore delayed / duplicated version negotiation packets - if c.receivedVersionNegotiationPacket || c.versionNegotiated { + if c.receivedVersionNegotiationPacket || c.versionNegotiated.Get() { c.logger.Debugf("Received a delayed Version Negotiation Packet.") return nil } diff --git a/client_test.go b/client_test.go index 5cd8b1ab..3d07e2a8 100644 --- a/client_test.go +++ b/client_test.go @@ -687,7 +687,7 @@ var _ = Describe("Client", func() { }, }) Expect(err).ToNot(HaveOccurred()) - Expect(cl.versionNegotiated).To(BeTrue()) + Expect(cl.versionNegotiated.Get()).To(BeTrue()) }) It("errors if no matching version is found", func() {