diff --git a/client.go b/client.go index 3bf95b13..c1618bf3 100644 --- a/client.go +++ b/client.go @@ -284,28 +284,29 @@ func (c *client) listen() { } break } - c.handlePacket(addr, data[:n]) + if err := c.handlePacket(addr, data[:n]); err != nil { + c.logger.Errorf("error handling packet: %s", err.Error()) + } } } -func (c *client) handlePacket(remoteAddr net.Addr, packet []byte) { +func (c *client) handlePacket(remoteAddr net.Addr, packet []byte) error { rcvTime := time.Now() r := bytes.NewReader(packet) hdr, err := wire.ParseHeaderSentByServer(r, c.version) + // drop the packet if we can't parse the header if err != nil { - c.logger.Errorf("error parsing packet from %s: %s", remoteAddr.String(), err.Error()) - // drop this packet if we can't parse the header - return + return fmt.Errorf("error parsing packet from %s: %s", remoteAddr.String(), err.Error()) } // reject packets with truncated connection id if we didn't request truncation if hdr.OmitConnectionID && !c.config.RequestConnectionIDOmission { - return + return errors.New("received packet with truncated connection ID, but didn't request truncation") } hdr.Raw = packet[:len(packet)-r.Len()] if hdr.IsLongHeader && !hdr.DestConnectionID.Equal(hdr.SrcConnectionID) { - c.logger.Errorf("receiving packets with different destination and source connection IDs not supported") + return fmt.Errorf("receiving packets with different destination and source connection IDs not supported") } c.mutex.Lock() @@ -314,7 +315,7 @@ func (c *client) handlePacket(remoteAddr net.Addr, packet []byte) { // reject packets with the wrong connection ID // TODO(#1003): add support for server-chosen connection IDs if !hdr.OmitConnectionID && !hdr.DestConnectionID.Equal(c.connectionID) { - return + return fmt.Errorf("received a packet with an unexpected connection ID (%s, expected %s)", hdr.DestConnectionID, c.connectionID) } if hdr.ResetFlag { @@ -322,31 +323,29 @@ func (c *client) handlePacket(remoteAddr net.Addr, packet []byte) { // check if the remote address and the connection ID match // otherwise this might be an attacker trying to inject a PUBLIC_RESET to kill the connection if cr.Network() != remoteAddr.Network() || cr.String() != remoteAddr.String() || !hdr.DestConnectionID.Equal(c.connectionID) { - c.logger.Infof("Received a spoofed Public Reset. Ignoring.") - return + return errors.New("Received a spoofed Public Reset") } pr, err := wire.ParsePublicReset(r) if err != nil { - c.logger.Infof("Received a Public Reset. An error occurred parsing the packet: %s", err) - return + return fmt.Errorf("Received a Public Reset. An error occurred parsing the packet: %s", err) } - c.logger.Infof("Received Public Reset, rejected packet number: %#x.", pr.RejectedPacketNumber) c.session.closeRemote(qerr.Error(qerr.PublicReset, fmt.Sprintf("Received a Public Reset for packet number %#x", pr.RejectedPacketNumber))) - return + c.logger.Infof("Received Public Reset, rejected packet number: %#x", pr.RejectedPacketNumber) + return nil } // handle Version Negotiation Packets if hdr.IsVersionNegotiation { // ignore delayed / duplicated version negotiation packets if c.receivedVersionNegotiationPacket || c.versionNegotiated { - return + return errors.New("received a delayed Version Negotiation Packet") } // version negotiation packets have no payload if err := c.handleVersionNegotiationPacket(hdr); err != nil { c.session.Close(err) } - return + return nil } // this is the first packet we are receiving @@ -364,6 +363,7 @@ func (c *client) handlePacket(remoteAddr net.Addr, packet []byte) { data: packet[len(packet)-r.Len():], rcvTime: rcvTime, }) + return nil } func (c *client) handleVersionNegotiationPacket(hdr *wire.Header) error { @@ -389,7 +389,7 @@ func (c *client) handleVersionNegotiationPacket(hdr *wire.Header) error { c.initialVersion = c.version c.version = newVersion var err error - c.connectionID, err = protocol.GenerateConnectionID() + c.connectionID, err = generateConnectionID() if err != nil { return err } diff --git a/client_test.go b/client_test.go index b816fb17..6b424b16 100644 --- a/client_test.go +++ b/client_test.go @@ -4,6 +4,7 @@ import ( "bytes" "crypto/tls" "errors" + "fmt" "net" "os" "sync/atomic" @@ -302,7 +303,8 @@ var _ = Describe("Client", func() { b := &bytes.Buffer{} err := ph.Write(b, protocol.PerspectiveServer, protocol.VersionWhatever) Expect(err).ToNot(HaveOccurred()) - cl.handlePacket(nil, b.Bytes()) + err = cl.handlePacket(nil, b.Bytes()) + Expect(err).ToNot(HaveOccurred()) Expect(cl.versionNegotiated).To(BeTrue()) Expect(cl.versionNegotiationChan).To(BeClosed()) }) @@ -392,15 +394,18 @@ var _ = Describe("Client", func() { go cl.dial() Eventually(func() uint32 { return atomic.LoadUint32(&sessionCounter) }).Should(BeEquivalentTo(1)) cl.config = &Config{Versions: []protocol.VersionNumber{77, 78}} - cl.handlePacket(nil, wire.ComposeGQUICVersionNegotiation(connID, []protocol.VersionNumber{77})) + err := cl.handlePacket(nil, wire.ComposeGQUICVersionNegotiation(connID, []protocol.VersionNumber{77})) + Expect(err).ToNot(HaveOccurred()) Eventually(func() uint32 { return atomic.LoadUint32(&sessionCounter) }).Should(BeEquivalentTo(2)) - cl.handlePacket(nil, wire.ComposeGQUICVersionNegotiation(connID, []protocol.VersionNumber{78})) + err = cl.handlePacket(nil, wire.ComposeGQUICVersionNegotiation(connID, []protocol.VersionNumber{78})) + Expect(err).To(MatchError("received a delayed Version Negotiation Packet")) Consistently(func() uint32 { return atomic.LoadUint32(&sessionCounter) }).Should(BeEquivalentTo(2)) }) It("errors if no matching version is found", func() { cl.config = &Config{Versions: protocol.SupportedVersions} - cl.handlePacket(nil, wire.ComposeGQUICVersionNegotiation(connID, []protocol.VersionNumber{1})) + err := cl.handlePacket(nil, wire.ComposeGQUICVersionNegotiation(connID, []protocol.VersionNumber{1})) + Expect(err).ToNot(HaveOccurred()) Expect(cl.session.(*mockSession).closed).To(BeTrue()) Expect(cl.session.(*mockSession).closeReason).To(MatchError(qerr.InvalidVersion)) }) @@ -409,7 +414,8 @@ var _ = Describe("Client", func() { v := protocol.VersionNumber(1234) Expect(v).ToNot(Equal(cl.version)) cl.config = &Config{Versions: protocol.SupportedVersions} - cl.handlePacket(nil, wire.ComposeGQUICVersionNegotiation(connID, []protocol.VersionNumber{v})) + err := cl.handlePacket(nil, wire.ComposeGQUICVersionNegotiation(connID, []protocol.VersionNumber{v})) + Expect(err).ToNot(HaveOccurred()) Expect(cl.session.(*mockSession).closed).To(BeTrue()) Expect(cl.session.(*mockSession).closeReason).To(MatchError(qerr.InvalidVersion)) }) @@ -417,29 +423,24 @@ var _ = Describe("Client", func() { It("changes to the version preferred by the quic.Config", func() { config := &Config{Versions: []protocol.VersionNumber{1234, 4321}} cl.config = config - cl.handlePacket(nil, wire.ComposeGQUICVersionNegotiation(connID, []protocol.VersionNumber{4321, 1234})) + err := cl.handlePacket(nil, wire.ComposeGQUICVersionNegotiation(connID, []protocol.VersionNumber{4321, 1234})) + Expect(err).ToNot(HaveOccurred()) Expect(cl.version).To(Equal(protocol.VersionNumber(1234))) }) - It("ignores delayed version negotiation packets", func() { - // if the version was not yet negotiated, handlePacket would return a VersionNegotiationMismatch error, see above test - cl.versionNegotiated = true - Expect(sess.packetCount).To(BeZero()) - cl.handlePacket(nil, wire.ComposeGQUICVersionNegotiation(connID, []protocol.VersionNumber{1})) - Expect(cl.versionNegotiated).To(BeTrue()) - Expect(sess.packetCount).To(BeZero()) - }) - It("drops version negotiation packets that contain the offered version", func() { ver := cl.version - cl.handlePacket(nil, wire.ComposeGQUICVersionNegotiation(connID, []protocol.VersionNumber{ver})) + err := cl.handlePacket(nil, wire.ComposeGQUICVersionNegotiation(connID, []protocol.VersionNumber{ver})) + Expect(err).ToNot(HaveOccurred()) Expect(cl.version).To(Equal(ver)) }) }) }) It("ignores packets with an invalid public header", func() { - cl.handlePacket(addr, []byte("invalid packet")) + err := cl.handlePacket(addr, []byte("invalid packet")) + Expect(err).To(HaveOccurred()) + Expect(err.Error()).To(ContainSubstring("error parsing packet from")) Expect(sess.packetCount).To(BeZero()) Expect(sess.closed).To(BeFalse()) }) @@ -447,27 +448,33 @@ var _ = Describe("Client", func() { It("ignores packets without connection id, if it didn't request connection id trunctation", func() { cl.config = &Config{RequestConnectionIDOmission: false} buf := &bytes.Buffer{} - (&wire.Header{ + err := (&wire.Header{ OmitConnectionID: true, + SrcConnectionID: connID, + DestConnectionID: connID, PacketNumber: 1, PacketNumberLen: 1, - }).Write(buf, protocol.PerspectiveServer, protocol.VersionWhatever) - cl.handlePacket(addr, buf.Bytes()) + }).Write(buf, protocol.PerspectiveServer, versionGQUICFrames) + Expect(err).ToNot(HaveOccurred()) + err = cl.handlePacket(addr, buf.Bytes()) + Expect(err).To(MatchError("received packet with truncated connection ID, but didn't request truncation")) Expect(sess.packetCount).To(BeZero()) Expect(sess.closed).To(BeFalse()) }) It("ignores packets with the wrong connection ID", func() { buf := &bytes.Buffer{} - connID2 := protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7} + connID2 := protocol.ConnectionID{8, 7, 6, 5, 4, 3, 2, 1} Expect(connID).ToNot(Equal(connID2)) - (&wire.Header{ + err := (&wire.Header{ DestConnectionID: connID2, SrcConnectionID: connID2, PacketNumber: 1, PacketNumberLen: 1, }).Write(buf, protocol.PerspectiveServer, protocol.VersionWhatever) - cl.handlePacket(addr, buf.Bytes()) + Expect(err).ToNot(HaveOccurred()) + err = cl.handlePacket(addr, buf.Bytes()) + Expect(err).To(MatchError(fmt.Sprintf("received a packet with an unexpected connection ID (0x0807060504030201, expected %s)", connID))) Expect(sess.packetCount).To(BeZero()) Expect(sess.closed).To(BeFalse()) }) @@ -626,30 +633,26 @@ var _ = Describe("Client", func() { Context("Public Reset handling", func() { It("closes the session when receiving a Public Reset", func() { - cl.handlePacket(addr, wire.WritePublicReset(cl.connectionID, 1, 0)) + err := cl.handlePacket(addr, wire.WritePublicReset(cl.connectionID, 1, 0)) + Expect(err).ToNot(HaveOccurred()) Expect(cl.session.(*mockSession).closed).To(BeTrue()) Expect(cl.session.(*mockSession).closedRemote).To(BeTrue()) Expect(cl.session.(*mockSession).closeReason.(*qerr.QuicError).ErrorCode).To(Equal(qerr.PublicReset)) }) - It("ignores Public Resets with the wrong connection ID", func() { - connID2 := protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7} - Expect(connID).ToNot(Equal(connID2)) - cl.handlePacket(addr, wire.WritePublicReset(connID2, 1, 0)) - Expect(cl.session.(*mockSession).closed).To(BeFalse()) - Expect(cl.session.(*mockSession).closedRemote).To(BeFalse()) - }) - It("ignores Public Resets from the wrong remote address", func() { spoofedAddr := &net.UDPAddr{IP: net.IPv4(1, 2, 3, 4), Port: 5678} - cl.handlePacket(spoofedAddr, wire.WritePublicReset(cl.connectionID, 1, 0)) + err := cl.handlePacket(spoofedAddr, wire.WritePublicReset(cl.connectionID, 1, 0)) + Expect(err).To(MatchError("Received a spoofed Public Reset")) Expect(cl.session.(*mockSession).closed).To(BeFalse()) Expect(cl.session.(*mockSession).closedRemote).To(BeFalse()) }) It("ignores unparseable Public Resets", func() { pr := wire.WritePublicReset(cl.connectionID, 1, 0) - cl.handlePacket(addr, pr[:len(pr)-5]) + err := cl.handlePacket(addr, pr[:len(pr)-5]) + Expect(err).To(HaveOccurred()) + Expect(err.Error()).To(ContainSubstring("Received a Public Reset. An error occurred parsing the packet")) Expect(cl.session.(*mockSession).closed).To(BeFalse()) Expect(cl.session.(*mockSession).closedRemote).To(BeFalse()) })