From 79642d502e3828f0539720a21fb170afb45c8c9b Mon Sep 17 00:00:00 2001 From: Marten Seemann Date: Tue, 11 Jul 2017 21:09:04 +0700 Subject: [PATCH] don't close the client connection when the Public Header can't be parsed When receiving an unparseable packet with a spoofed remote address, we should not close the connection. --- client.go | 27 +++++++++++------------ client_test.go | 59 +++++++++++++++----------------------------------- 2 files changed, 31 insertions(+), 55 deletions(-) diff --git a/client.go b/client.go index 97f03bc6..2e18de80 100644 --- a/client.go +++ b/client.go @@ -211,22 +211,19 @@ func (c *client) listen() { } data = data[:n] - err = c.handlePacket(addr, data) - if err != nil { - utils.Errorf("error handling packet: %s", err.Error()) - c.session.Close(err) - break - } + c.handlePacket(addr, data) } } -func (c *client) handlePacket(remoteAddr net.Addr, packet []byte) error { +func (c *client) handlePacket(remoteAddr net.Addr, packet []byte) { rcvTime := time.Now() r := bytes.NewReader(packet) hdr, err := ParsePublicHeader(r, protocol.PerspectiveServer) if err != nil { - return qerr.Error(qerr.InvalidPacketHeader, err.Error()) + utils.Errorf("error parsing packet from %s: %s", remoteAddr.String(), err.Error()) + // drop this packet if we can't parse the Public Header + return } hdr.Raw = packet[:len(packet)-r.Len()] @@ -239,21 +236,21 @@ func (c *client) handlePacket(remoteAddr net.Addr, packet []byte) error { // otherwise this might be an attacker trying to inject a PUBLIC_RESET to kill the connection if cr.Network() != remoteAddr.Network() || cr.String() != remoteAddr.String() || hdr.ConnectionID != c.connectionID { utils.Infof("Received a spoofed Public Reset. Ignoring.") - return nil + return } pr, err := parsePublicReset(r) if err != nil { utils.Infof("Received a Public Reset for connection %x. An error occurred parsing the packet.") - return nil + return } utils.Infof("Received Public Reset, rejected packet number: %#x.", pr.rejectedPacketNumber) c.session.closeRemote(qerr.Error(qerr.PublicReset, fmt.Sprintf("Received a Public Reset for packet number %#x", pr.rejectedPacketNumber))) - return nil + return } // ignore delayed / duplicated version negotiation packets if c.versionNegotiated && hdr.VersionFlag { - return nil + return } // this is the first packet after the client sent a packet with the VersionFlag set @@ -264,7 +261,10 @@ func (c *client) handlePacket(remoteAddr net.Addr, packet []byte) error { if hdr.VersionFlag { // version negotiation packets have no payload - return c.handlePacketWithVersionFlag(hdr) + if err := c.handlePacketWithVersionFlag(hdr); err != nil { + c.session.Close(err) + } + return } c.session.handlePacket(&receivedPacket{ @@ -273,7 +273,6 @@ func (c *client) handlePacket(remoteAddr net.Addr, packet []byte) error { data: packet[len(packet)-r.Len():], rcvTime: rcvTime, }) - return nil } func (c *client) handlePacketWithVersionFlag(hdr *PublicHeader) error { diff --git a/client_test.go b/client_test.go index ce5d1856..1be1ba28 100644 --- a/client_test.go +++ b/client_test.go @@ -196,13 +196,6 @@ var _ = Describe("Client", func() { Expect(c.RequestConnectionIDTruncation).To(BeFalse()) }) - It("errors when receiving an invalid first packet from the server", func(done Done) { - packetConn.dataToRead = []byte{0xff} - _, err := Dial(packetConn, addr, "quic.clemente.io:1337", nil, config) - Expect(err).To(HaveOccurred()) - close(done) - }) - It("errors when receiving an error from the connection", func(done Done) { testErr := errors.New("connection error") packetConn.readErr = testErr @@ -238,8 +231,7 @@ var _ = Describe("Client", func() { b := &bytes.Buffer{} err := ph.Write(b, protocol.VersionWhatever, protocol.PerspectiveServer) Expect(err).ToNot(HaveOccurred()) - err = cl.handlePacket(nil, b.Bytes()) - Expect(err).ToNot(HaveOccurred()) + cl.handlePacket(nil, b.Bytes()) Expect(cl.versionNegotiated).To(BeTrue()) }) @@ -265,35 +257,34 @@ var _ = Describe("Client", func() { Expect(newVersion).ToNot(Equal(cl.version)) Expect(sess.packetCount).To(BeZero()) cl.connectionID = 0x1337 - err := cl.handlePacket(nil, composeVersionNegotiation(0x1337, []protocol.VersionNumber{newVersion})) - Expect(err).ToNot(HaveOccurred()) + cl.handlePacket(nil, composeVersionNegotiation(0x1337, []protocol.VersionNumber{newVersion})) Expect(cl.version).To(Equal(newVersion)) Expect(cl.versionNegotiated).To(BeTrue()) // it swapped the sessions // Expect(cl.session).ToNot(Equal(sess)) Expect(cl.connectionID).ToNot(Equal(0x1337)) // it generated a new connection ID - Expect(err).ToNot(HaveOccurred()) // it didn't pass the version negoation packet to the old session (since it has no payload) Expect(sess.packetCount).To(BeZero()) Expect(negotiatedVersions).To(Equal([]protocol.VersionNumber{newVersion})) }) It("errors if no matching version is found", func() { - err := cl.handlePacket(nil, composeVersionNegotiation(0x1337, []protocol.VersionNumber{1})) - Expect(err).To(MatchError(qerr.InvalidVersion)) + cl.handlePacket(nil, composeVersionNegotiation(0x1337, []protocol.VersionNumber{1})) + Expect(cl.session.(*mockSession).closed).To(BeTrue()) + Expect(cl.session.(*mockSession).closeReason).To(MatchError(qerr.InvalidVersion)) }) It("errors if the version is supported by quic-go, but disabled by the quic.Config", func() { v := protocol.SupportedVersions[1] Expect(v).ToNot(Equal(cl.version)) Expect(config.Versions).ToNot(ContainElement(v)) - err := cl.handlePacket(nil, composeVersionNegotiation(0x1337, []protocol.VersionNumber{v})) - Expect(err).To(MatchError(qerr.InvalidVersion)) + cl.handlePacket(nil, composeVersionNegotiation(0x1337, []protocol.VersionNumber{v})) + Expect(cl.session.(*mockSession).closed).To(BeTrue()) + Expect(cl.session.(*mockSession).closeReason).To(MatchError(qerr.InvalidVersion)) }) It("changes to the version preferred by the quic.Config", func() { - err := cl.handlePacket(nil, composeVersionNegotiation(0x1337, []protocol.VersionNumber{config.Versions[2], config.Versions[1]})) - Expect(err).ToNot(HaveOccurred()) + cl.handlePacket(nil, composeVersionNegotiation(0x1337, []protocol.VersionNumber{config.Versions[2], config.Versions[1]})) Expect(cl.version).To(Equal(config.Versions[1])) }) @@ -301,24 +292,22 @@ var _ = Describe("Client", func() { // if the version was not yet negotiated, handlePacket would return a VersionNegotiationMismatch error, see above test cl.versionNegotiated = true Expect(sess.packetCount).To(BeZero()) - err := cl.handlePacket(nil, composeVersionNegotiation(0x1337, []protocol.VersionNumber{1})) - Expect(err).ToNot(HaveOccurred()) + cl.handlePacket(nil, composeVersionNegotiation(0x1337, []protocol.VersionNumber{1})) Expect(cl.versionNegotiated).To(BeTrue()) Expect(sess.packetCount).To(BeZero()) }) It("drops version negotiation packets that contain the offered version", func() { ver := cl.version - err := cl.handlePacket(nil, composeVersionNegotiation(0x1337, []protocol.VersionNumber{ver})) - Expect(err).ToNot(HaveOccurred()) + cl.handlePacket(nil, composeVersionNegotiation(0x1337, []protocol.VersionNumber{ver})) Expect(cl.version).To(Equal(ver)) }) }) }) - It("errors on invalid public header", func() { - err := cl.handlePacket(nil, nil) - Expect(err.(*qerr.QuicError).ErrorCode).To(Equal(qerr.InvalidPacketHeader)) + It("ignores packets with an invalid public header", func() { + cl.handlePacket(addr, []byte("invalid packet")) + Expect(cl.session.(*mockSession).closed).To(BeFalse()) }) It("creates new sessions with the right parameters", func(done Done) { @@ -380,14 +369,6 @@ var _ = Describe("Client", func() { Consistently(func() bool { return stoppedListening }).Should(BeFalse()) }) - It("closes the session when encountering an error while handling a packet", func() { - Expect(sess.closeReason).ToNot(HaveOccurred()) - packetConn.dataToRead = []byte("invalid packet") - cl.listen() - Expect(sess.closed).To(BeTrue()) - Expect(sess.closeReason).To(HaveOccurred()) - }) - It("closes the session when encountering an error while reading from the connection", func() { testErr := errors.New("test error") packetConn.readErr = testErr @@ -399,32 +380,28 @@ var _ = Describe("Client", func() { Context("Public Reset handling", func() { It("closes the session when receiving a Public Reset", func() { - err := cl.handlePacket(addr, writePublicReset(cl.connectionID, 1, 0)) - Expect(err).ToNot(HaveOccurred()) + cl.handlePacket(addr, writePublicReset(cl.connectionID, 1, 0)) Expect(cl.session.(*mockSession).closed).To(BeTrue()) Expect(cl.session.(*mockSession).closedRemote).To(BeTrue()) Expect(cl.session.(*mockSession).closeReason.(*qerr.QuicError).ErrorCode).To(Equal(qerr.PublicReset)) }) It("ignores Public Resets with the wrong connection ID", func() { - err := cl.handlePacket(addr, writePublicReset(cl.connectionID+1, 1, 0)) - Expect(err).ToNot(HaveOccurred()) + cl.handlePacket(addr, writePublicReset(cl.connectionID+1, 1, 0)) Expect(cl.session.(*mockSession).closed).To(BeFalse()) Expect(cl.session.(*mockSession).closedRemote).To(BeFalse()) }) It("ignores Public Resets from the wrong remote address", func() { spoofedAddr := &net.UDPAddr{IP: net.IPv4(1, 2, 3, 4), Port: 5678} - err := cl.handlePacket(spoofedAddr, writePublicReset(cl.connectionID, 1, 0)) - Expect(err).ToNot(HaveOccurred()) + cl.handlePacket(spoofedAddr, writePublicReset(cl.connectionID, 1, 0)) Expect(cl.session.(*mockSession).closed).To(BeFalse()) Expect(cl.session.(*mockSession).closedRemote).To(BeFalse()) }) It("ignores unparseable Public Resets", func() { pr := writePublicReset(cl.connectionID, 1, 0) - err := cl.handlePacket(addr, pr[:len(pr)-5]) - Expect(err).ToNot(HaveOccurred()) + cl.handlePacket(addr, pr[:len(pr)-5]) Expect(cl.session.(*mockSession).closed).To(BeFalse()) Expect(cl.session.(*mockSession).closedRemote).To(BeFalse()) })