From f53055b9a0c2dd745a0e29504460b40e52069f66 Mon Sep 17 00:00:00 2001 From: Marten Seemann Date: Thu, 23 Feb 2017 11:56:53 +0700 Subject: [PATCH] return packet handling and connection errors in Dial --- client.go | 32 ++++++++++++++++++++-------- client_test.go | 58 ++++++++++++++++++++++++++++++++------------------ 2 files changed, 60 insertions(+), 30 deletions(-) diff --git a/client.go b/client.go index c3766509..d2b4661d 100644 --- a/client.go +++ b/client.go @@ -14,8 +14,9 @@ import ( ) type client struct { - mutex sync.Mutex - connStateChangeCond sync.Cond + mutex sync.Mutex + connStateChangeOrErrCond sync.Cond + listenErr error conn connection hostname string @@ -55,7 +56,7 @@ func Dial(pconn net.PacketConn, remoteAddr net.Addr, host string, config *Config version: protocol.SupportedVersions[len(protocol.SupportedVersions)-1], // use the highest supported version by default } - c.connStateChangeCond.L = &c.mutex + c.connStateChangeOrErrCond.L = &c.mutex err = c.createNewSession(nil) if err != nil { @@ -67,16 +68,20 @@ func Dial(pconn net.PacketConn, remoteAddr net.Addr, host string, config *Config go c.listen() c.mutex.Lock() + defer c.mutex.Unlock() + for { + if c.listenErr != nil { + return nil, c.listenErr + } if c.config.ConnState != nil && c.connState >= ConnStateVersionNegotiated { break } if c.config.ConnState == nil && c.connState == ConnStateForwardSecure { break } - c.connStateChangeCond.Wait() + c.connStateChangeOrErrCond.Wait() } - c.mutex.Unlock() return c.session, nil } @@ -98,16 +103,20 @@ func DialAddr(hostname string, config *Config) (Session, error) { // Listen listens func (c *client) listen() { + var err error + for { + var n int + var addr net.Addr data := getPacketBuffer() data = data[:protocol.MaxPacketSize] - n, addr, err := c.conn.Read(data) + n, addr, err = c.conn.Read(data) if err != nil { if !strings.HasSuffix(err.Error(), "use of closed network connection") { c.session.Close(err) } - return + break } data = data[:n] @@ -115,9 +124,14 @@ func (c *client) listen() { if err != nil { utils.Errorf("error handling packet: %s", err.Error()) c.session.Close(err) - return + break } } + + c.mutex.Lock() + c.listenErr = err + c.connStateChangeOrErrCond.Signal() + c.mutex.Unlock() } func (c *client) handlePacket(remoteAddr net.Addr, packet []byte) error { @@ -145,7 +159,7 @@ func (c *client) handlePacket(remoteAddr net.Addr, packet []byte) error { if !hdr.VersionFlag && c.connState == ConnStateInitial { c.mutex.Lock() c.connState = ConnStateVersionNegotiated - c.connStateChangeCond.Signal() + c.connStateChangeOrErrCond.Signal() c.mutex.Unlock() if c.config.ConnState != nil { go c.config.ConnState(c.session, ConnStateVersionNegotiated) diff --git a/client_test.go b/client_test.go index 32b20c6c..f6c6a2d9 100644 --- a/client_test.go +++ b/client_test.go @@ -48,28 +48,44 @@ var _ = Describe("Client", func() { } }) - It("creates a new client", func() { - packetConn.dataToRead = []byte{0x0, 0x1, 0x0} - var err error - sess, err := Dial(packetConn, addr, "quic.clemente.io:1337", config) - Expect(err).ToNot(HaveOccurred()) - Expect(*(*[]protocol.VersionNumber)(unsafe.Pointer(reflect.ValueOf(sess.(*session).cryptoSetup).Elem().FieldByName("negotiatedVersions").UnsafeAddr()))).To(BeNil()) - Expect(*(*string)(unsafe.Pointer(reflect.ValueOf(sess.(*session).cryptoSetup).Elem().FieldByName("hostname").UnsafeAddr()))).To(Equal("quic.clemente.io")) - }) - - // TODO: actually test this - // now we're only testing that Dial doesn't return directly after version negotiation - It("only returns once a forward-secure connection is established if no ConnState is defined", func() { - packetConn.dataToRead = []byte{0x0, 0x1, 0x0} - config.ConnState = nil - var dialReturned bool - go func() { - defer GinkgoRecover() - _, err := Dial(packetConn, addr, "quic.clemente.io:1337", config) + Context("Dialing", func() { + It("creates a new client", func() { + packetConn.dataToRead = []byte{0x0, 0x1, 0x0} + var err error + sess, err := Dial(packetConn, addr, "quic.clemente.io:1337", config) Expect(err).ToNot(HaveOccurred()) - dialReturned = true - }() - Consistently(func() bool { return dialReturned }).Should(BeFalse()) + Expect(*(*[]protocol.VersionNumber)(unsafe.Pointer(reflect.ValueOf(sess.(*session).cryptoSetup).Elem().FieldByName("negotiatedVersions").UnsafeAddr()))).To(BeNil()) + Expect(*(*string)(unsafe.Pointer(reflect.ValueOf(sess.(*session).cryptoSetup).Elem().FieldByName("hostname").UnsafeAddr()))).To(Equal("quic.clemente.io")) + }) + + It("errors when receiving an invalid first packet from the server", func() { + packetConn.dataToRead = []byte{0xff} + sess, err := Dial(packetConn, addr, "quic.clemente.io:1337", config) + Expect(err).To(HaveOccurred()) + Expect(sess).To(BeNil()) + }) + + It("errors when receiving an error from the connection", func() { + testErr := errors.New("connection error") + packetConn.readErr = testErr + _, err := Dial(packetConn, addr, "quic.clemente.io:1337", config) + Expect(err).To(MatchError(testErr)) + }) + + // TODO: actually test this + // now we're only testing that Dial doesn't return directly after version negotiation + It("only returns once a forward-secure connection is established if no ConnState is defined", func() { + packetConn.dataToRead = []byte{0x0, 0x1, 0x0} + config.ConnState = nil + var dialReturned bool + go func() { + defer GinkgoRecover() + _, err := Dial(packetConn, addr, "quic.clemente.io:1337", config) + Expect(err).ToNot(HaveOccurred()) + dialReturned = true + }() + Consistently(func() bool { return dialReturned }).Should(BeFalse()) + }) }) It("errors on invalid public header", func() {