From e39251c8b595fb198d2953fe23315b49689f53f7 Mon Sep 17 00:00:00 2001 From: Marten Seemann Date: Sat, 14 Jul 2018 22:34:52 +0200 Subject: [PATCH] implement writing and parsing of Initial packets containing a Token --- internal/wire/header.go | 24 ++++++++- internal/wire/header_parser.go | 16 +++++- internal/wire/header_parser_test.go | 22 +++++++++ internal/wire/header_test.go | 77 +++++++++++++++++++++++++++++ 4 files changed, 136 insertions(+), 3 deletions(-) diff --git a/internal/wire/header.go b/internal/wire/header.go index 99efab75..9df77a23 100644 --- a/internal/wire/header.go +++ b/internal/wire/header.go @@ -37,6 +37,7 @@ type Header struct { IsLongHeader bool KeyPhase int PayloadLen protocol.ByteCount + Token []byte } var errInvalidPacketNumberLen6 = errors.New("invalid packet number length: 6 bytes") @@ -65,7 +66,14 @@ func (h *Header) writeLongHeader(b *bytes.Buffer) error { b.WriteByte(connIDLen) b.Write(h.DestConnectionID.Bytes()) b.Write(h.SrcConnectionID.Bytes()) + + if h.Type == protocol.PacketTypeInitial { + utils.WriteVarInt(b, uint64(len(h.Token))) + b.Write(h.Token) + } + utils.WriteVarInt(b, uint64(h.PayloadLen)) + return utils.WriteVarIntPacketNumber(b, h.PacketNumber, h.PacketNumberLen) } @@ -149,7 +157,11 @@ func (h *Header) GetLength(version protocol.VersionNumber) (protocol.ByteCount, func (h *Header) getHeaderLength() (protocol.ByteCount, error) { if h.IsLongHeader { - return 1 /* type byte */ + 4 /* version */ + 1 /* conn id len byte */ + protocol.ByteCount(h.DestConnectionID.Len()+h.SrcConnectionID.Len()) + utils.VarIntLen(uint64(h.PayloadLen)) + protocol.ByteCount(h.PacketNumberLen), nil + length := 1 /* type byte */ + 4 /* version */ + 1 /* conn id len byte */ + protocol.ByteCount(h.DestConnectionID.Len()+h.SrcConnectionID.Len()) + utils.VarIntLen(uint64(h.PayloadLen)) + protocol.ByteCount(h.PacketNumberLen) + if h.Type == protocol.PacketTypeInitial { + length += utils.VarIntLen(uint64(len(h.Token))) + protocol.ByteCount(len(h.Token)) + } + return length, nil } length := protocol.ByteCount(1 /* type byte */ + h.DestConnectionID.Len()) @@ -194,7 +206,15 @@ func (h *Header) logHeader(logger utils.Logger) { if h.Version == 0 { logger.Debugf("\tVersionNegotiationPacket{DestConnectionID: %s, SrcConnectionID: %s, SupportedVersions: %s}", h.DestConnectionID, h.SrcConnectionID, h.SupportedVersions) } else { - logger.Debugf("\tLong Header{Type: %s, DestConnectionID: %s, SrcConnectionID: %s, PacketNumber: %#x, PacketNumberLen: %d, PayloadLen: %d, Version: %s}", h.Type, h.DestConnectionID, h.SrcConnectionID, h.PacketNumber, h.PacketNumberLen, h.PayloadLen, h.Version) + var token string + if h.Type == protocol.PacketTypeInitial { + if len(h.Token) == 0 { + token = "Token: (empty), " + } else { + token = fmt.Sprintf("Token: %#x, ", h.Token) + } + } + logger.Debugf("\tLong Header{Type: %s, DestConnectionID: %s, SrcConnectionID: %s, %sPacketNumber: %#x, PacketNumberLen: %d, PayloadLen: %d, Version: %s}", h.Type, h.DestConnectionID, h.SrcConnectionID, token, h.PacketNumber, h.PacketNumberLen, h.PayloadLen, h.Version) } } else { logger.Debugf("\tShort Header{DestConnectionID: %s, PacketNumber: %#x, PacketNumberLen: %d, KeyPhase: %d}", h.DestConnectionID, h.PacketNumber, h.PacketNumberLen, h.KeyPhase) diff --git a/internal/wire/header_parser.go b/internal/wire/header_parser.go index e712f76a..bcb43b86 100644 --- a/internal/wire/header_parser.go +++ b/internal/wire/header_parser.go @@ -123,6 +123,21 @@ func (iv *InvariantHeader) parseVersionNegotiationPacket(b *bytes.Reader) (*Head func (iv *InvariantHeader) parseLongHeader(b *bytes.Reader) (*Header, error) { h := iv.toHeader() + h.Type = protocol.PacketType(iv.typeByte & 0x7f) + + if h.Type == protocol.PacketTypeInitial { + tokenLen, err := utils.ReadVarInt(b) + if err != nil { + return nil, err + } + if tokenLen > uint64(b.Len()) { + return nil, io.EOF + } + h.Token = make([]byte, tokenLen) + if _, err := io.ReadFull(b, h.Token); err != nil { + return nil, err + } + } pl, err := utils.ReadVarInt(b) if err != nil { @@ -135,7 +150,6 @@ func (iv *InvariantHeader) parseLongHeader(b *bytes.Reader) (*Header, error) { } h.PacketNumber = pn h.PacketNumberLen = pnLen - h.Type = protocol.PacketType(iv.typeByte & 0x7f) if h.Type != protocol.PacketTypeInitial && h.Type != protocol.PacketTypeRetry && h.Type != protocol.PacketType0RTT && h.Type != protocol.PacketTypeHandshake { return nil, qerr.Error(qerr.InvalidPacketHeader, fmt.Sprintf("Received packet with invalid packet type: %d", h.Type)) diff --git a/internal/wire/header_parser_test.go b/internal/wire/header_parser_test.go index 769abf1f..bef49119 100644 --- a/internal/wire/header_parser_test.go +++ b/internal/wire/header_parser_test.go @@ -81,6 +81,8 @@ var _ = Describe("Header Parsing", func() { } data = append(data, destConnID...) data = append(data, srcConnID...) + data = append(data, encodeVarInt(6)...) // token length + data = append(data, []byte("foobar")...) // token data = append(data, encodeVarInt(0x1337)...) // payload length // packet number data = appendPacketNumber(data, 0xbeef, protocol.PacketNumberLen4) @@ -97,6 +99,7 @@ var _ = Describe("Header Parsing", func() { Expect(hdr.IsLongHeader).To(BeTrue()) Expect(hdr.DestConnectionID).To(Equal(destConnID)) Expect(hdr.SrcConnectionID).To(Equal(srcConnID)) + Expect(hdr.Token).To(Equal([]byte("foobar"))) Expect(hdr.PayloadLen).To(Equal(protocol.ByteCount(0x1337))) Expect(hdr.PacketNumber).To(Equal(protocol.PacketNumber(0xbeef))) Expect(hdr.PacketNumberLen).To(Equal(protocol.PacketNumberLen4)) @@ -143,8 +146,10 @@ var _ = Describe("Header Parsing", func() { 0x1, 0x2, 0x3, 0x4, // version number 0x0, // connection ID lengths } + data = append(data, encodeVarInt(0)...) // token length data = append(data, encodeVarInt(0x42)...) // payload length data = appendPacketNumber(data, 0x123, protocol.PacketNumberLen2) + b := bytes.NewReader(data) iHdr, err := ParseInvariantHeader(b, 0) Expect(err).ToNot(HaveOccurred()) @@ -173,6 +178,23 @@ var _ = Describe("Header Parsing", func() { Expect(err).To(MatchError("InvalidPacketHeader: Received packet with invalid packet type: 42")) }) + It("errors if the token length is too large", func() { + data := []byte{ + 0x80 ^ uint8(protocol.PacketTypeInitial), + 0x1, 0x2, 0x3, 0x4, // version number + 0x0, // connection ID lengths + } + data = append(data, encodeVarInt(4)...) // token length: 4 bytes (1 byte too long) + data = append(data, encodeVarInt(0x42)...) // payload length, 1 byte + data = appendPacketNumber(data, 0x123, protocol.PacketNumberLen2) // 2 bytes + + b := bytes.NewReader(data) + iHdr, err := ParseInvariantHeader(b, 0) + Expect(err).ToNot(HaveOccurred()) + _, err = iHdr.Parse(b, protocol.PerspectiveServer, versionIETFHeader) + Expect(err).To(MatchError(io.EOF)) + }) + It("errors on EOF, when parsing the invariant header", func() { data := []byte{ 0x80 ^ uint8(protocol.PacketTypeInitial), diff --git a/internal/wire/header_test.go b/internal/wire/header_test.go index 0ac2987c..b0b2bbab 100644 --- a/internal/wire/header_test.go +++ b/internal/wire/header_test.go @@ -97,6 +97,21 @@ var _ = Describe("Header", func() { Expect(err).ToNot(HaveOccurred()) Expect(buf.Bytes()).To(ContainSubstring(string([]byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18}))) }) + + It("writes an Initial containing a token", func() { + token := []byte("Lorem ipsum dolor sit amet, consectetur adipiscing elit, sed do eiusmod tempor incididunt ut labore et dolore magna aliqua.") + err := (&Header{ + IsLongHeader: true, + Type: protocol.PacketTypeInitial, + Token: token, + PacketNumber: 0xdecafbad, + PacketNumberLen: protocol.PacketNumberLen4, + Version: 0x1020304, + }).Write(buf, protocol.PerspectiveServer, versionIETFHeader) + Expect(err).ToNot(HaveOccurred()) + expectedSubstring := append(encodeVarInt(uint64(len(token))), token...) + Expect(buf.Bytes()).To(ContainSubstring(string(expectedSubstring))) + }) }) Context("short header", func() { @@ -388,6 +403,39 @@ var _ = Describe("Header", func() { Expect(buf.Len()).To(Equal(expectedLen)) }) + It("has the right length for an Initial not containing a Token", func() { + h := &Header{ + Type: protocol.PacketTypeInitial, + IsLongHeader: true, + PayloadLen: 1500, + DestConnectionID: protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8}, + SrcConnectionID: protocol.ConnectionID{1, 2, 3, 4}, + PacketNumberLen: protocol.PacketNumberLen2, + } + expectedLen := 1 /* type byte */ + 4 /* version */ + 1 /* conn ID len */ + 8 /* dest conn id */ + 4 /* src conn id */ + 1 /* token length */ + 2 /* long payload len */ + 2 /* packet number */ + Expect(h.GetLength(versionIETFHeader)).To(BeEquivalentTo(expectedLen)) + err := h.Write(buf, protocol.PerspectiveServer, versionIETFHeader) + Expect(err).ToNot(HaveOccurred()) + Expect(buf.Len()).To(Equal(expectedLen)) + }) + + It("has the right length for an Initial containing a Token", func() { + h := &Header{ + Type: protocol.PacketTypeInitial, + IsLongHeader: true, + PayloadLen: 1500, + DestConnectionID: protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8}, + SrcConnectionID: protocol.ConnectionID{1, 2, 3, 4}, + PacketNumberLen: protocol.PacketNumberLen2, + Token: []byte("foo"), + } + expectedLen := 1 /* type byte */ + 4 /* version */ + 1 /* conn ID len */ + 8 /* dest conn id */ + 4 /* src conn id */ + 1 /* token length */ + 3 /* token */ + 2 /* long payload len */ + 2 /* packet number */ + Expect(h.GetLength(versionIETFHeader)).To(BeEquivalentTo(expectedLen)) + err := h.Write(buf, protocol.PerspectiveServer, versionIETFHeader) + Expect(err).ToNot(HaveOccurred()) + Expect(buf.Len()).To(Equal(expectedLen)) + }) + It("has the right length for a Short Header containing a connection ID", func() { h := &Header{ PacketNumberLen: protocol.PacketNumberLen1, @@ -546,6 +594,35 @@ var _ = Describe("Header", func() { Expect(buf.String()).To(ContainSubstring("Long Header{Type: Handshake, DestConnectionID: 0xdeadbeefcafe1337, SrcConnectionID: 0xdecafbad13371337, PacketNumber: 0x1337, PacketNumberLen: 2, PayloadLen: 54321, Version: 0xfeed}")) }) + It("logs Initial Packets with a Token", func() { + (&Header{ + IsLongHeader: true, + Type: protocol.PacketTypeInitial, + Token: []byte{0xde, 0xad, 0xbe, 0xef}, + PacketNumber: 0x42, + PacketNumberLen: protocol.PacketNumberLen2, + PayloadLen: 100, + DestConnectionID: protocol.ConnectionID{0xca, 0xfe, 0x13, 0x37}, + SrcConnectionID: protocol.ConnectionID{0xde, 0xca, 0xfb, 0xad}, + Version: 0xfeed, + }).Log(logger) + Expect(buf.String()).To(ContainSubstring("Long Header{Type: Initial, DestConnectionID: 0xcafe1337, SrcConnectionID: 0xdecafbad, Token: 0xdeadbeef, PacketNumber: 0x42, PacketNumberLen: 2, PayloadLen: 100, Version: 0xfeed}")) + }) + + It("logs Initial Packets without a Token", func() { + (&Header{ + IsLongHeader: true, + Type: protocol.PacketTypeInitial, + PacketNumber: 0x42, + PacketNumberLen: protocol.PacketNumberLen2, + PayloadLen: 100, + DestConnectionID: protocol.ConnectionID{0xca, 0xfe, 0x13, 0x37}, + SrcConnectionID: protocol.ConnectionID{0xde, 0xca, 0xfb, 0xad}, + Version: 0xfeed, + }).Log(logger) + Expect(buf.String()).To(ContainSubstring("Long Header{Type: Initial, DestConnectionID: 0xcafe1337, SrcConnectionID: 0xdecafbad, Token: (empty), PacketNumber: 0x42, PacketNumberLen: 2, PayloadLen: 100, Version: 0xfeed}")) + }) + It("logs Short Headers containing a connection ID", func() { (&Header{ KeyPhase: 1,