diff --git a/client.go b/client.go index ee2f6a3cd..63efd46ec 100644 --- a/client.go +++ b/client.go @@ -7,9 +7,7 @@ import ( "errors" "fmt" "net" - "strings" "sync" - "time" "github.com/lucas-clemente/quic-go/internal/handshake" "github.com/lucas-clemente/quic-go/internal/protocol" @@ -21,6 +19,7 @@ import ( type client struct { mutex sync.Mutex + pconn net.PacketConn conn connection hostname string @@ -83,15 +82,7 @@ func DialAddrContext( if err != nil { return nil, err } - c, err := newClient(udpConn, udpAddr, config, tlsConf, addr, nil) - if err != nil { - return nil, err - } - go c.listen() - if err := c.dial(ctx); err != nil { - return nil, err - } - return c.session, nil + return DialContext(ctx, udpConn, udpAddr, addr, tlsConf, config) } // Dial establishes a new QUIC connection to a server using a net.PacketConn. @@ -129,6 +120,11 @@ func DialContext( if err := multiplexer.AddHandler(pconn, c.srcConnID, c); err != nil { return nil, err } + if config.RequestConnectionIDOmission { + if err := multiplexer.AddHandler(pconn, protocol.ConnectionID{}, c); err != nil { + return nil, err + } + } if err := c.dial(ctx); err != nil { return nil, err } @@ -168,6 +164,7 @@ func newClient( onClose = closeCallback } c := &client{ + pconn: pconn, conn: &conn{pconn: pconn, currentAddr: remoteAddr}, hostname: hostname, tlsConf: tlsConf, @@ -350,58 +347,6 @@ func (c *client) establishSecureConnection(ctx context.Context) error { } } -// Listen listens on the underlying connection and passes packets on for handling. -// It returns when the connection is closed. -func (c *client) listen() { - var err error - - for { - var n int - var addr net.Addr - data := *getPacketBuffer() - data = data[:protocol.MaxReceivePacketSize] - // The packet size should not exceed protocol.MaxReceivePacketSize bytes - // If it does, we only read a truncated packet, which will then end up undecryptable - n, addr, err = c.conn.Read(data) - if err != nil { - if !strings.HasSuffix(err.Error(), "use of closed network connection") { - c.mutex.Lock() - if c.session != nil { - c.session.Close(err) - } - c.mutex.Unlock() - } - break - } - c.handleRead(addr, data[:n]) - } -} - -func (c *client) handleRead(remoteAddr net.Addr, packet []byte) { - rcvTime := time.Now() - - r := bytes.NewReader(packet) - iHdr, err := wire.ParseInvariantHeader(r, c.config.ConnectionIDLength) - // drop the packet if we can't parse the header - if err != nil { - c.logger.Errorf("error parsing invariant header: %s", err) - return - } - hdr, err := iHdr.Parse(r, protocol.PerspectiveServer, c.version) - if err != nil { - c.logger.Errorf("error parsing header: %s", err) - return - } - hdr.Raw = packet[:len(packet)-r.Len()] - packetData := packet[len(packet)-r.Len():] - c.handlePacket(&receivedPacket{ - remoteAddr: remoteAddr, - header: hdr, - data: packetData, - rcvTime: rcvTime, - }) -} - func (c *client) handlePacket(p *receivedPacket) { if err := c.handlePacketImpl(p); err != nil { c.logger.Errorf("error handling packet: %s", err) @@ -524,6 +469,9 @@ func (c *client) handleVersionNegotiationPacket(hdr *wire.Header) error { c.initialVersion = c.version c.version = newVersion c.generateConnectionIDs() + if err := getClientMultiplexer().AddHandler(c.pconn, c.srcConnID, c); err != nil { + return err + } c.logger.Infof("Switching to QUIC version %s. New connection ID: %s", newVersion, c.destConnID) c.session.Close(errCloseSessionForNewVersion) diff --git a/client_test.go b/client_test.go index f653efec1..329689bb9 100644 --- a/client_test.go +++ b/client_test.go @@ -45,6 +45,17 @@ var _ = Describe("Client", func() { return b.Bytes() } + composeVersionNegotiationPacket := func(connID protocol.ConnectionID, versions []protocol.VersionNumber) *receivedPacket { + return &receivedPacket{ + rcvTime: time.Now(), + header: &wire.Header{ + IsVersionNegotiation: true, + DestConnectionID: connID, + SupportedVersions: versions, + }, + } + } + BeforeEach(func() { connID = protocol.ConnectionID{0, 0, 0, 0, 0, 0, 0x13, 0x37} originalClientSessConstructor = newClientSession @@ -95,6 +106,10 @@ var _ = Describe("Client", func() { }) It("resolves the address", func() { + manager := NewMockPacketHandlerManager(mockCtrl) + mockMultiplexer.EXPECT().AddConn(gomock.Any(), gomock.Any()).Return(manager, nil) + mockMultiplexer.EXPECT().AddHandler(gomock.Any(), gomock.Any(), gomock.Any()) + if os.Getenv("APPVEYOR") == "True" { Skip("This test is flaky on AppVeyor.") } @@ -122,6 +137,10 @@ var _ = Describe("Client", func() { }) It("uses the tls.Config.ServerName as the hostname, if present", func() { + manager := NewMockPacketHandlerManager(mockCtrl) + mockMultiplexer.EXPECT().AddConn(gomock.Any(), gomock.Any()).Return(manager, nil) + mockMultiplexer.EXPECT().AddHandler(gomock.Any(), gomock.Any(), gomock.Any()) + hostnameChan := make(chan string, 1) newClientSession = func( _ connection, @@ -466,6 +485,8 @@ var _ = Describe("Client", func() { }) It("changes the version after receiving a version negotiation packet", func() { + mockMultiplexer.EXPECT().AddHandler(gomock.Any(), gomock.Any(), gomock.Any()) + version1 := protocol.Version39 version2 := protocol.Version39 + 1 Expect(version2.UsesTLS()).To(BeFalse()) @@ -502,11 +523,13 @@ var _ = Describe("Client", func() { close(dialed) }() Eventually(sessionChan).Should(HaveLen(1)) - cl.handleRead(nil, wire.ComposeGQUICVersionNegotiation(connID, []protocol.VersionNumber{version2})) + cl.handlePacket(composeVersionNegotiationPacket(connID, []protocol.VersionNumber{version2})) Eventually(sessionChan).Should(BeEmpty()) }) It("only accepts one version negotiation packet", func() { + mockMultiplexer.EXPECT().AddHandler(gomock.Any(), gomock.Any(), gomock.Any()) + version1 := protocol.Version39 version2 := protocol.Version39 + 1 version3 := protocol.Version39 + 2 @@ -545,10 +568,10 @@ var _ = Describe("Client", func() { close(dialed) }() Eventually(sessionChan).Should(HaveLen(1)) - cl.handleRead(nil, wire.ComposeGQUICVersionNegotiation(connID, []protocol.VersionNumber{version2})) + cl.handlePacket(composeVersionNegotiationPacket(connID, []protocol.VersionNumber{version2})) Eventually(sessionChan).Should(BeEmpty()) Expect(cl.version).To(Equal(version2)) - cl.handleRead(nil, wire.ComposeGQUICVersionNegotiation(connID, []protocol.VersionNumber{version3})) + cl.handlePacket(composeVersionNegotiationPacket(connID, []protocol.VersionNumber{version3})) Eventually(dialed).Should(BeClosed()) Expect(cl.version).To(Equal(version2)) }) @@ -558,7 +581,7 @@ var _ = Describe("Client", func() { sess.EXPECT().Close(gomock.Any()) cl.session = sess cl.config = &Config{Versions: protocol.SupportedVersions} - cl.handleRead(nil, wire.ComposeGQUICVersionNegotiation(connID, []protocol.VersionNumber{1})) + cl.handlePacket(composeVersionNegotiationPacket(connID, []protocol.VersionNumber{1})) }) It("errors if the version is supported by quic-go, but disabled by the quic.Config", func() { @@ -568,23 +591,24 @@ var _ = Describe("Client", func() { v := protocol.VersionNumber(1234) Expect(v).ToNot(Equal(cl.version)) cl.config = &Config{Versions: protocol.SupportedVersions} - cl.handleRead(nil, wire.ComposeGQUICVersionNegotiation(connID, []protocol.VersionNumber{v})) + cl.handlePacket(composeVersionNegotiationPacket(connID, []protocol.VersionNumber{v})) }) It("changes to the version preferred by the quic.Config", func() { + mockMultiplexer.EXPECT().AddHandler(gomock.Any(), gomock.Any(), gomock.Any()) sess := NewMockQuicSession(mockCtrl) sess.EXPECT().Close(errCloseSessionForNewVersion) cl.session = sess - config := &Config{Versions: []protocol.VersionNumber{1234, 4321}} - cl.config = config - cl.handleRead(nil, wire.ComposeGQUICVersionNegotiation(connID, []protocol.VersionNumber{4321, 1234})) + versions := []protocol.VersionNumber{1234, 4321} + cl.config = &Config{Versions: versions} + cl.handlePacket(composeVersionNegotiationPacket(connID, versions)) Expect(cl.version).To(Equal(protocol.VersionNumber(1234))) }) It("drops version negotiation packets that contain the offered version", func() { cl.config = &Config{} ver := cl.version - cl.handleRead(nil, wire.ComposeGQUICVersionNegotiation(connID, []protocol.VersionNumber{ver})) + cl.handlePacket(composeVersionNegotiationPacket(connID, []protocol.VersionNumber{ver})) Expect(cl.version).To(Equal(ver)) }) }) @@ -595,12 +619,6 @@ var _ = Describe("Client", func() { Expect(cl.GetVersion()).To(Equal(cl.version)) }) - It("ignores packets with an invalid public header", func() { - cl.config = &Config{} - cl.session = NewMockQuicSession(mockCtrl) // don't EXPECT any handlePacket calls - cl.handleRead(addr, []byte("invalid packet")) - }) - It("errors on packets that are smaller than the Payload Length in the packet header", func() { cl.session = NewMockQuicSession(mockCtrl) // don't EXPECT any handlePacket calls hdr := &wire.Header{ @@ -829,47 +847,6 @@ var _ = Describe("Client", func() { Eventually(done).Should(BeClosed()) }) - Context("handling packets", func() { - It("handles packets", func() { - cl.config = &Config{} - sess := NewMockQuicSession(mockCtrl) - sess.EXPECT().handlePacket(gomock.Any()) - cl.session = sess - ph := wire.Header{ - PacketNumber: 1, - PacketNumberLen: protocol.PacketNumberLen2, - DestConnectionID: connID, - } - b := &bytes.Buffer{} - err := ph.Write(b, protocol.PerspectiveServer, cl.version) - Expect(err).ToNot(HaveOccurred()) - packetConn.dataToRead <- b.Bytes() - - done := make(chan struct{}) - go func() { - defer GinkgoRecover() - cl.listen() - // it should continue listening when receiving valid packets - close(done) - }() - - Consistently(done).ShouldNot(BeClosed()) - // make the go routine return - sess.EXPECT().Close(gomock.Any()) - Expect(packetConn.Close()).To(Succeed()) - Eventually(done).Should(BeClosed()) - }) - - It("closes the session when encountering an error while reading from the connection", func() { - testErr := errors.New("test error") - sess := NewMockQuicSession(mockCtrl) - sess.EXPECT().Close(testErr) - cl.session = sess - packetConn.readErr = testErr - cl.listen() - }) - }) - Context("Public Reset handling", func() { var ( pr []byte @@ -895,7 +872,11 @@ var _ = Describe("Client", func() { Expect(err.(*qerr.QuicError).ErrorCode).To(Equal(qerr.PublicReset)) }) cl.session = sess - cl.handleRead(addr, wire.WritePublicReset(cl.destConnID, 1, 0)) + cl.handlePacketImpl(&receivedPacket{ + remoteAddr: addr, + header: hdr, + data: pr[len(pr)-hdrLen:], + }) }) It("ignores Public Resets from the wrong remote address", func() {