diff --git a/public_reset.go b/public_reset.go index 7cceb2e6..b1f60d41 100644 --- a/public_reset.go +++ b/public_reset.go @@ -2,12 +2,19 @@ package quic import ( "bytes" + "encoding/binary" + "errors" "github.com/lucas-clemente/quic-go/handshake" "github.com/lucas-clemente/quic-go/protocol" "github.com/lucas-clemente/quic-go/utils" ) +type publicReset struct { + rejectedPacketNumber protocol.PacketNumber + nonce uint64 +} + func writePublicReset(connectionID protocol.ConnectionID, rejectedPacketNumber protocol.PacketNumber, nonceProof uint64) []byte { b := &bytes.Buffer{} b.WriteByte(0x0a) @@ -22,3 +29,34 @@ func writePublicReset(connectionID protocol.ConnectionID, rejectedPacketNumber p utils.WriteUint64(b, uint64(rejectedPacketNumber)) return b.Bytes() } + +func parsePublicReset(r *bytes.Reader) (*publicReset, error) { + pr := publicReset{} + tag, tagMap, err := handshake.ParseHandshakeMessage(r) + if err != nil { + return nil, err + } + if tag != handshake.TagPRST { + return nil, errors.New("wrong public reset tag") + } + + rseq, ok := tagMap[handshake.TagRSEQ] + if !ok { + return nil, errors.New("RSEQ missing") + } + if len(rseq) != 8 { + return nil, errors.New("invalid RSEQ tag") + } + pr.rejectedPacketNumber = protocol.PacketNumber(binary.LittleEndian.Uint64(rseq)) + + rnon, ok := tagMap[handshake.TagRNON] + if !ok { + return nil, errors.New("RNON missing") + } + if len(rnon) != 8 { + return nil, errors.New("invalid RNON tag") + } + pr.nonce = binary.LittleEndian.Uint64(rnon) + + return &pr, nil +} diff --git a/public_reset_test.go b/public_reset_test.go index 446b4e7d..0df859b5 100644 --- a/public_reset_test.go +++ b/public_reset_test.go @@ -1,6 +1,11 @@ package quic import ( + "bytes" + "io" + + "github.com/lucas-clemente/quic-go/handshake" + "github.com/lucas-clemente/quic-go/protocol" . "github.com/onsi/ginkgo" . "github.com/onsi/gomega" ) @@ -22,4 +27,69 @@ var _ = Describe("public reset", func() { })) }) }) + + Context("parsing", func() { + var b *bytes.Buffer + + BeforeEach(func() { + b = &bytes.Buffer{} + }) + + It("parses a public reset", func() { + packet := writePublicReset(0xdeadbeef, 0x8badf00d, 0xdecafbad) + pr, err := parsePublicReset(bytes.NewReader(packet[9:])) // 1 byte Public Flag, 8 bytes connection ID + Expect(err).ToNot(HaveOccurred()) + Expect(pr.nonce).To(Equal(uint64(0xdecafbad))) + Expect(pr.rejectedPacketNumber).To(Equal(protocol.PacketNumber(0x8badf00d))) + }) + + It("rejects packets that it can't parse", func() { + _, err := parsePublicReset(bytes.NewReader([]byte{})) + Expect(err).To(MatchError(io.EOF)) + }) + + It("rejects packets with the wrong tag", func() { + handshake.WriteHandshakeMessage(b, handshake.TagREJ, nil) + _, err := parsePublicReset(bytes.NewReader(b.Bytes())) + Expect(err).To(MatchError("wrong public reset tag")) + }) + + It("rejects packets missing the nonce", func() { + data := map[handshake.Tag][]byte{ + handshake.TagRSEQ: []byte{0xde, 0xad, 0xbe, 0xef, 0xca, 0xfe, 0x13, 0x37}, + } + handshake.WriteHandshakeMessage(b, handshake.TagPRST, data) + _, err := parsePublicReset(bytes.NewReader(b.Bytes())) + Expect(err).To(MatchError("RNON missing")) + }) + + It("rejects packets with a wrong length nonce", func() { + data := map[handshake.Tag][]byte{ + handshake.TagRSEQ: []byte{0xde, 0xad, 0xbe, 0xef, 0xca, 0xfe, 0x13, 0x37}, + handshake.TagRNON: []byte{0xde, 0xad, 0xbe, 0xef, 0xca, 0xfe, 0x13}, + } + handshake.WriteHandshakeMessage(b, handshake.TagPRST, data) + _, err := parsePublicReset(bytes.NewReader(b.Bytes())) + Expect(err).To(MatchError("invalid RNON tag")) + }) + + It("rejects packets missing the rejected packet number", func() { + data := map[handshake.Tag][]byte{ + handshake.TagRNON: []byte{0xde, 0xad, 0xbe, 0xef, 0xca, 0xfe, 0x13, 0x37}, + } + handshake.WriteHandshakeMessage(b, handshake.TagPRST, data) + _, err := parsePublicReset(bytes.NewReader(b.Bytes())) + Expect(err).To(MatchError("RSEQ missing")) + }) + + It("rejects packets with a wrong length rejected packet number", func() { + data := map[handshake.Tag][]byte{ + handshake.TagRSEQ: []byte{0xde, 0xad, 0xbe, 0xef, 0xca, 0xfe, 0x13}, + handshake.TagRNON: []byte{0xde, 0xad, 0xbe, 0xef, 0xca, 0xfe, 0x13, 0x37}, + } + handshake.WriteHandshakeMessage(b, handshake.TagPRST, data) + _, err := parsePublicReset(bytes.NewReader(b.Bytes())) + Expect(err).To(MatchError("invalid RSEQ tag")) + }) + }) })