From 524ecb5827324fb83b5eb30241442c0352091641 Mon Sep 17 00:00:00 2001 From: Marten Seemann Date: Tue, 29 Aug 2017 23:58:27 +0700 Subject: [PATCH] move the Public Reset to the wire package --- client.go | 6 ++--- client_test.go | 8 +++---- .../wire/public_reset.go | 21 +++++++++------- .../wire/public_reset_test.go | 24 +++++++++---------- server.go | 8 +++---- server_test.go | 6 ++--- session.go | 2 +- 7 files changed, 39 insertions(+), 36 deletions(-) rename public_reset.go => internal/wire/public_reset.go (73%) rename public_reset_test.go => internal/wire/public_reset_test.go (79%) diff --git a/client.go b/client.go index b0e4d2557..9fb2b8255 100644 --- a/client.go +++ b/client.go @@ -244,13 +244,13 @@ func (c *client) handlePacket(remoteAddr net.Addr, packet []byte) { utils.Infof("Received a spoofed Public Reset. Ignoring.") return } - pr, err := parsePublicReset(r) + pr, err := wire.ParsePublicReset(r) if err != nil { utils.Infof("Received a Public Reset for connection %x. An error occurred parsing the packet.") return } - utils.Infof("Received Public Reset, rejected packet number: %#x.", pr.rejectedPacketNumber) - c.session.closeRemote(qerr.Error(qerr.PublicReset, fmt.Sprintf("Received a Public Reset for packet number %#x", pr.rejectedPacketNumber))) + utils.Infof("Received Public Reset, rejected packet number: %#x.", pr.RejectedPacketNumber) + c.session.closeRemote(qerr.Error(qerr.PublicReset, fmt.Sprintf("Received a Public Reset for packet number %#x", pr.RejectedPacketNumber))) return } diff --git a/client_test.go b/client_test.go index 9607427fb..9c99a62a6 100644 --- a/client_test.go +++ b/client_test.go @@ -386,27 +386,27 @@ var _ = Describe("Client", func() { Context("Public Reset handling", func() { It("closes the session when receiving a Public Reset", func() { - cl.handlePacket(addr, writePublicReset(cl.connectionID, 1, 0)) + cl.handlePacket(addr, wire.WritePublicReset(cl.connectionID, 1, 0)) Expect(cl.session.(*mockSession).closed).To(BeTrue()) Expect(cl.session.(*mockSession).closedRemote).To(BeTrue()) Expect(cl.session.(*mockSession).closeReason.(*qerr.QuicError).ErrorCode).To(Equal(qerr.PublicReset)) }) It("ignores Public Resets with the wrong connection ID", func() { - cl.handlePacket(addr, writePublicReset(cl.connectionID+1, 1, 0)) + cl.handlePacket(addr, wire.WritePublicReset(cl.connectionID+1, 1, 0)) Expect(cl.session.(*mockSession).closed).To(BeFalse()) Expect(cl.session.(*mockSession).closedRemote).To(BeFalse()) }) It("ignores Public Resets from the wrong remote address", func() { spoofedAddr := &net.UDPAddr{IP: net.IPv4(1, 2, 3, 4), Port: 5678} - cl.handlePacket(spoofedAddr, writePublicReset(cl.connectionID, 1, 0)) + cl.handlePacket(spoofedAddr, wire.WritePublicReset(cl.connectionID, 1, 0)) Expect(cl.session.(*mockSession).closed).To(BeFalse()) Expect(cl.session.(*mockSession).closedRemote).To(BeFalse()) }) It("ignores unparseable Public Resets", func() { - pr := writePublicReset(cl.connectionID, 1, 0) + pr := wire.WritePublicReset(cl.connectionID, 1, 0) cl.handlePacket(addr, pr[:len(pr)-5]) Expect(cl.session.(*mockSession).closed).To(BeFalse()) Expect(cl.session.(*mockSession).closedRemote).To(BeFalse()) diff --git a/public_reset.go b/internal/wire/public_reset.go similarity index 73% rename from public_reset.go rename to internal/wire/public_reset.go index 163e3bbd5..7e8c162e5 100644 --- a/public_reset.go +++ b/internal/wire/public_reset.go @@ -1,4 +1,4 @@ -package quic +package wire import ( "bytes" @@ -10,12 +10,14 @@ import ( "github.com/lucas-clemente/quic-go/protocol" ) -type publicReset struct { - rejectedPacketNumber protocol.PacketNumber - nonce uint64 +// A PublicReset is a PUBLIC_RESET +type PublicReset struct { + RejectedPacketNumber protocol.PacketNumber + Nonce uint64 } -func writePublicReset(connectionID protocol.ConnectionID, rejectedPacketNumber protocol.PacketNumber, nonceProof uint64) []byte { +// WritePublicReset writes a Public Reset +func WritePublicReset(connectionID protocol.ConnectionID, rejectedPacketNumber protocol.PacketNumber, nonceProof uint64) []byte { b := &bytes.Buffer{} b.WriteByte(0x0a) utils.LittleEndian.WriteUint64(b, uint64(connectionID)) @@ -30,8 +32,9 @@ func writePublicReset(connectionID protocol.ConnectionID, rejectedPacketNumber p return b.Bytes() } -func parsePublicReset(r *bytes.Reader) (*publicReset, error) { - pr := publicReset{} +// ParsePublicReset parses a Public Reset +func ParsePublicReset(r *bytes.Reader) (*PublicReset, error) { + pr := PublicReset{} msg, err := handshake.ParseHandshakeMessage(r) if err != nil { return nil, err @@ -47,7 +50,7 @@ func parsePublicReset(r *bytes.Reader) (*publicReset, error) { if len(rseq) != 8 { return nil, errors.New("invalid RSEQ tag") } - pr.rejectedPacketNumber = protocol.PacketNumber(binary.LittleEndian.Uint64(rseq)) + pr.RejectedPacketNumber = protocol.PacketNumber(binary.LittleEndian.Uint64(rseq)) rnon, ok := msg.Data[handshake.TagRNON] if !ok { @@ -56,7 +59,7 @@ func parsePublicReset(r *bytes.Reader) (*publicReset, error) { if len(rnon) != 8 { return nil, errors.New("invalid RNON tag") } - pr.nonce = binary.LittleEndian.Uint64(rnon) + pr.Nonce = binary.LittleEndian.Uint64(rnon) return &pr, nil } diff --git a/public_reset_test.go b/internal/wire/public_reset_test.go similarity index 79% rename from public_reset_test.go rename to internal/wire/public_reset_test.go index f1344ee88..a8d751b13 100644 --- a/public_reset_test.go +++ b/internal/wire/public_reset_test.go @@ -1,4 +1,4 @@ -package quic +package wire import ( "bytes" @@ -13,7 +13,7 @@ import ( var _ = Describe("public reset", func() { Context("writing", func() { It("writes public reset packets", func() { - Expect(writePublicReset(0xdeadbeef, 0x8badf00d, 0xdecafbad)).To(Equal([]byte{ + Expect(WritePublicReset(0xdeadbeef, 0x8badf00d, 0xdecafbad)).To(Equal([]byte{ 0x0a, 0xef, 0xbe, 0xad, 0xde, 0x00, 0x00, 0x00, 0x00, 'P', 'R', 'S', 'T', @@ -36,21 +36,21 @@ var _ = Describe("public reset", func() { }) It("parses a public reset", func() { - packet := writePublicReset(0xdeadbeef, 0x8badf00d, 0xdecafbad) - pr, err := parsePublicReset(bytes.NewReader(packet[9:])) // 1 byte Public Flag, 8 bytes connection ID + packet := WritePublicReset(0xdeadbeef, 0x8badf00d, 0xdecafbad) + pr, err := ParsePublicReset(bytes.NewReader(packet[9:])) // 1 byte Public Flag, 8 bytes connection ID Expect(err).ToNot(HaveOccurred()) - Expect(pr.nonce).To(Equal(uint64(0xdecafbad))) - Expect(pr.rejectedPacketNumber).To(Equal(protocol.PacketNumber(0x8badf00d))) + Expect(pr.Nonce).To(Equal(uint64(0xdecafbad))) + Expect(pr.RejectedPacketNumber).To(Equal(protocol.PacketNumber(0x8badf00d))) }) It("rejects packets that it can't parse", func() { - _, err := parsePublicReset(bytes.NewReader([]byte{})) + _, err := ParsePublicReset(bytes.NewReader([]byte{})) Expect(err).To(MatchError(io.EOF)) }) It("rejects packets with the wrong tag", func() { handshake.HandshakeMessage{Tag: handshake.TagREJ, Data: nil}.Write(b) - _, err := parsePublicReset(bytes.NewReader(b.Bytes())) + _, err := ParsePublicReset(bytes.NewReader(b.Bytes())) Expect(err).To(MatchError("wrong public reset tag")) }) @@ -59,7 +59,7 @@ var _ = Describe("public reset", func() { handshake.TagRSEQ: []byte{0xde, 0xad, 0xbe, 0xef, 0xca, 0xfe, 0x13, 0x37}, } handshake.HandshakeMessage{Tag: handshake.TagPRST, Data: data}.Write(b) - _, err := parsePublicReset(bytes.NewReader(b.Bytes())) + _, err := ParsePublicReset(bytes.NewReader(b.Bytes())) Expect(err).To(MatchError("RNON missing")) }) @@ -69,7 +69,7 @@ var _ = Describe("public reset", func() { handshake.TagRNON: []byte{0xde, 0xad, 0xbe, 0xef, 0xca, 0xfe, 0x13}, } handshake.HandshakeMessage{Tag: handshake.TagPRST, Data: data}.Write(b) - _, err := parsePublicReset(bytes.NewReader(b.Bytes())) + _, err := ParsePublicReset(bytes.NewReader(b.Bytes())) Expect(err).To(MatchError("invalid RNON tag")) }) @@ -78,7 +78,7 @@ var _ = Describe("public reset", func() { handshake.TagRNON: []byte{0xde, 0xad, 0xbe, 0xef, 0xca, 0xfe, 0x13, 0x37}, } handshake.HandshakeMessage{Tag: handshake.TagPRST, Data: data}.Write(b) - _, err := parsePublicReset(bytes.NewReader(b.Bytes())) + _, err := ParsePublicReset(bytes.NewReader(b.Bytes())) Expect(err).To(MatchError("RSEQ missing")) }) @@ -88,7 +88,7 @@ var _ = Describe("public reset", func() { handshake.TagRNON: []byte{0xde, 0xad, 0xbe, 0xef, 0xca, 0xfe, 0x13, 0x37}, } handshake.HandshakeMessage{Tag: handshake.TagPRST, Data: data}.Write(b) - _, err := parsePublicReset(bytes.NewReader(b.Bytes())) + _, err := ParsePublicReset(bytes.NewReader(b.Bytes())) Expect(err).To(MatchError("invalid RSEQ tag")) }) }) diff --git a/server.go b/server.go index 4c7c6d3c1..e27f7f1c8 100644 --- a/server.go +++ b/server.go @@ -234,7 +234,7 @@ func (s *server) handlePacket(pconn net.PacketConn, remoteAddr net.Addr, packet hdr, err := wire.ParsePublicHeader(r, protocol.PerspectiveClient, version) if err == wire.ErrPacketWithUnknownVersion { - _, err = pconn.WriteTo(writePublicReset(connID, 0, 0), remoteAddr) + _, err = pconn.WriteTo(wire.WritePublicReset(connID, 0, 0), remoteAddr) return err } if err != nil { @@ -245,12 +245,12 @@ func (s *server) handlePacket(pconn net.PacketConn, remoteAddr net.Addr, packet // ignore all Public Reset packets if hdr.ResetFlag { if ok { - var pr *publicReset - pr, err = parsePublicReset(r) + var pr *wire.PublicReset + pr, err = wire.ParsePublicReset(r) if err != nil { utils.Infof("Received a Public Reset for connection %x. An error occurred parsing the packet.") } else { - utils.Infof("Received a Public Reset for connection %x, rejected packet number: 0x%x.", hdr.ConnectionID, pr.rejectedPacketNumber) + utils.Infof("Received a Public Reset for connection %x, rejected packet number: 0x%x.", hdr.ConnectionID, pr.RejectedPacketNumber) } } else { utils.Infof("Received Public Reset for unknown connection %x.", hdr.ConnectionID) diff --git a/server_test.go b/server_test.go index e0a1dcadb..e86634def 100644 --- a/server_test.go +++ b/server_test.go @@ -299,7 +299,7 @@ var _ = Describe("Server", func() { }) It("ignores public resets for unknown connections", func() { - err := serv.handlePacket(nil, nil, writePublicReset(999, 1, 1337)) + err := serv.handlePacket(nil, nil, wire.WritePublicReset(999, 1, 1337)) Expect(err).ToNot(HaveOccurred()) Expect(serv.sessions).To(BeEmpty()) }) @@ -308,7 +308,7 @@ var _ = Describe("Server", func() { err := serv.handlePacket(nil, nil, firstPacket) Expect(serv.sessions).To(HaveLen(1)) Expect(serv.sessions[connID].(*mockSession).packetCount).To(Equal(1)) - err = serv.handlePacket(nil, nil, writePublicReset(connID, 1, 1337)) + err = serv.handlePacket(nil, nil, wire.WritePublicReset(connID, 1, 1337)) Expect(err).ToNot(HaveOccurred()) Expect(serv.sessions).To(HaveLen(1)) Expect(serv.sessions[connID].(*mockSession).packetCount).To(Equal(1)) @@ -318,7 +318,7 @@ var _ = Describe("Server", func() { err := serv.handlePacket(nil, nil, firstPacket) Expect(serv.sessions).To(HaveLen(1)) Expect(serv.sessions[connID].(*mockSession).packetCount).To(Equal(1)) - data := writePublicReset(connID, 1, 1337) + data := wire.WritePublicReset(connID, 1, 1337) err = serv.handlePacket(nil, nil, data[:len(data)-2]) Expect(err).ToNot(HaveOccurred()) Expect(serv.sessions).To(HaveLen(1)) diff --git a/session.go b/session.go index 33a8e3dd6..61a14164d 100644 --- a/session.go +++ b/session.go @@ -794,7 +794,7 @@ func (s *session) garbageCollectStreams() { func (s *session) sendPublicReset(rejectedPacketNumber protocol.PacketNumber) error { utils.Infof("Sending public reset for connection %x, packet number %d", s.connectionID, rejectedPacketNumber) - return s.conn.Write(writePublicReset(s.connectionID, rejectedPacketNumber, 0)) + return s.conn.Write(wire.WritePublicReset(s.connectionID, rejectedPacketNumber, 0)) } // scheduleSending signals that we have data for sending