diff --git a/internal/protocol/protocol.go b/internal/protocol/protocol.go index 130cb1bff..7c4d60716 100644 --- a/internal/protocol/protocol.go +++ b/internal/protocol/protocol.go @@ -73,3 +73,6 @@ const DefaultMaxAckDelay = 25 * time.Millisecond // MaxMaxAckDelay is the maximum max_ack_delay const MaxMaxAckDelay = 1 << 14 * time.Millisecond + +// MaxConnIDLen is the maximum length of the connection ID +const MaxConnIDLen = 20 diff --git a/internal/wire/extended_header.go b/internal/wire/extended_header.go index 4ab3951dc..5b85e573f 100644 --- a/internal/wire/extended_header.go +++ b/internal/wire/extended_header.go @@ -83,6 +83,15 @@ func (h *ExtendedHeader) readPacketNumber(b *bytes.Reader) error { // Write writes the Header. func (h *ExtendedHeader) Write(b *bytes.Buffer, ver protocol.VersionNumber) error { + if h.DestConnectionID.Len() > protocol.MaxConnIDLen { + return fmt.Errorf("invalid connection ID length: %d bytes", h.DestConnectionID.Len()) + } + if h.SrcConnectionID.Len() > protocol.MaxConnIDLen { + return fmt.Errorf("invalid connection ID length: %d bytes", h.SrcConnectionID.Len()) + } + if h.OrigDestConnectionID.Len() > protocol.MaxConnIDLen { + return fmt.Errorf("invalid connection ID length: %d bytes", h.OrigDestConnectionID.Len()) + } if h.IsLongHeader { return h.writeLongHeader(b, ver) } @@ -102,28 +111,21 @@ func (h *ExtendedHeader) writeLongHeader(b *bytes.Buffer, v protocol.VersionNumb packetType = 0x3 } firstByte := 0xc0 | packetType<<4 - if h.Type == protocol.PacketTypeRetry { - odcil, err := encodeSingleConnIDLen(h.OrigDestConnectionID) - if err != nil { - return err - } - firstByte |= odcil - } else { // Retry packets don't have a packet number + if h.Type != protocol.PacketTypeRetry { + // Retry packets don't have a packet number firstByte |= uint8(h.PacketNumberLen - 1) } b.WriteByte(firstByte) utils.BigEndian.WriteUint32(b, uint32(h.Version)) - connIDLen, err := encodeConnIDLen(h.DestConnectionID, h.SrcConnectionID) - if err != nil { - return err - } - b.WriteByte(connIDLen) + b.WriteByte(uint8(h.DestConnectionID.Len())) b.Write(h.DestConnectionID.Bytes()) + b.WriteByte(uint8(h.SrcConnectionID.Len())) b.Write(h.SrcConnectionID.Bytes()) switch h.Type { case protocol.PacketTypeRetry: + b.WriteByte(uint8(h.OrigDestConnectionID.Len())) b.Write(h.OrigDestConnectionID.Bytes()) b.Write(h.Token) return nil @@ -159,7 +161,7 @@ func (h *ExtendedHeader) writePacketNumber(b *bytes.Buffer) error { // GetLength determines the length of the Header. func (h *ExtendedHeader) GetLength(v protocol.VersionNumber) protocol.ByteCount { if h.IsLongHeader { - length := 1 /* type byte */ + 4 /* version */ + 1 /* conn id len byte */ + protocol.ByteCount(h.DestConnectionID.Len()+h.SrcConnectionID.Len()) + protocol.ByteCount(h.PacketNumberLen) + utils.VarIntLen(uint64(h.Length)) + length := 1 /* type byte */ + 4 /* version */ + 1 /* dest conn ID len */ + protocol.ByteCount(h.DestConnectionID.Len()) + 1 /* src conn ID len */ + protocol.ByteCount(h.SrcConnectionID.Len()) + protocol.ByteCount(h.PacketNumberLen) + utils.VarIntLen(uint64(h.Length)) if h.Type == protocol.PacketTypeInitial { length += utils.VarIntLen(uint64(len(h.Token))) + protocol.ByteCount(len(h.Token)) } @@ -191,26 +193,3 @@ func (h *ExtendedHeader) Log(logger utils.Logger) { logger.Debugf("\tShort Header{DestConnectionID: %s, PacketNumber: %#x, PacketNumberLen: %d, KeyPhase: %s}", h.DestConnectionID, h.PacketNumber, h.PacketNumberLen, h.KeyPhase) } } - -func encodeConnIDLen(dest, src protocol.ConnectionID) (byte, error) { - dcil, err := encodeSingleConnIDLen(dest) - if err != nil { - return 0, err - } - scil, err := encodeSingleConnIDLen(src) - if err != nil { - return 0, err - } - return scil | dcil<<4, nil -} - -func encodeSingleConnIDLen(id protocol.ConnectionID) (byte, error) { - len := id.Len() - if len == 0 { - return 0, nil - } - if len < 4 || len > 18 { - return 0, fmt.Errorf("invalid connection ID length: %d bytes", len) - } - return byte(len - 3), nil -} diff --git a/internal/wire/extended_header_test.go b/internal/wire/extended_header_test.go index 86f85d84d..e4a0b2010 100644 --- a/internal/wire/extended_header_test.go +++ b/internal/wire/extended_header_test.go @@ -40,8 +40,9 @@ var _ = Describe("Header", func() { expected := []byte{ 0xc0 | 0x2<<4 | 0x2, 0x1, 0x2, 0x3, 0x4, // version number - 0x35, // connection ID lengths + 0x6, // dest connection ID length 0xde, 0xad, 0xbe, 0xef, 0xca, 0xfe, // dest connection ID + 0x8, // src connection ID length 0xde, 0xca, 0xfb, 0xad, 0x0, 0x0, 0x13, 0x37, // source connection ID } expected = append(expected, encodeVarInt(0xcafe)...) // length @@ -49,42 +50,27 @@ var _ = Describe("Header", func() { Expect(buf.Bytes()).To(Equal(expected)) }) - It("refuses to write a header with a too short connection ID", func() { - err := (&ExtendedHeader{ - Header: Header{ - IsLongHeader: true, - SrcConnectionID: srcConnID, - DestConnectionID: protocol.ConnectionID{1, 2, 3}, // connection IDs must be at least 4 bytes long - Version: 0x1020304, - Type: 0x5, - }, - PacketNumber: 0xdecafbad, - PacketNumberLen: protocol.PacketNumberLen4, - }).Write(buf, versionIETFHeader) - Expect(err).To(MatchError("invalid connection ID length: 3 bytes")) - }) - It("refuses to write a header with a too long connection ID", func() { err := (&ExtendedHeader{ Header: Header{ IsLongHeader: true, SrcConnectionID: srcConnID, - DestConnectionID: protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19}, // connection IDs must be at most 18 bytes long + DestConnectionID: protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21}, // connection IDs must be at most 20 bytes long Version: 0x1020304, Type: 0x5, }, PacketNumber: 0xdecafbad, PacketNumberLen: protocol.PacketNumberLen4, }).Write(buf, versionIETFHeader) - Expect(err).To(MatchError("invalid connection ID length: 19 bytes")) + Expect(err).To(MatchError("invalid connection ID length: 21 bytes")) }) - It("writes a header with an 18 byte connection ID", func() { + It("writes a header with a 20 byte connection ID", func() { err := (&ExtendedHeader{ Header: Header{ IsLongHeader: true, SrcConnectionID: srcConnID, - DestConnectionID: protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18}, // connection IDs must be at most 18 bytes long + DestConnectionID: protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20}, // connection IDs must be at most 20 bytes long Version: 0x1020304, Type: 0x5, }, @@ -92,7 +78,7 @@ var _ = Describe("Header", func() { PacketNumberLen: protocol.PacketNumberLen4, }).Write(buf, versionIETFHeader) Expect(err).ToNot(HaveOccurred()) - Expect(buf.Bytes()).To(ContainSubstring(string([]byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18}))) + Expect(buf.Bytes()).To(ContainSubstring(string([]byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20}))) }) It("writes an Initial containing a token", func() { @@ -121,9 +107,11 @@ var _ = Describe("Header", func() { OrigDestConnectionID: protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8, 9}, }}).Write(buf, versionIETFHeader)).To(Succeed()) expected := []byte{ - 0xc0 | 0x3<<4 | 9 - 3, /* orig dest connection ID length */ - 0x1, 0x2, 0x3, 0x4, // version number - 0x0, // connection ID lengths)) + 0xc0 | 0x3<<4, + 0x1, 0x2, 0x3, 0x4, // version number + 0x0, // dest connection ID length + 0x0, // src connection ID length + 0x9, // orig dest connection ID length 1, 2, 3, 4, 5, 6, 7, 8, 9, // Orig Dest Connection ID } expected = append(expected, token...) @@ -136,9 +124,9 @@ var _ = Describe("Header", func() { Version: 0x1020304, Type: protocol.PacketTypeRetry, Token: []byte("foobar"), - OrigDestConnectionID: protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19}, // connection IDs must be at most 18 bytes long + OrigDestConnectionID: protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21}, // connection IDs must be at most 20 bytes long }}).Write(buf, versionIETFHeader) - Expect(err).To(MatchError("invalid connection ID length: 19 bytes")) + Expect(err).To(MatchError("invalid connection ID length: 21 bytes")) }) }) @@ -229,7 +217,7 @@ var _ = Describe("Header", func() { }, PacketNumberLen: protocol.PacketNumberLen1, } - expectedLen := 1 /* type byte */ + 4 /* version */ + 1 /* conn ID len */ + 8 /* dest conn id */ + 8 /* src conn id */ + 1 /* short len */ + 1 /* packet number */ + expectedLen := 1 /* type byte */ + 4 /* version */ + 1 /* dest conn ID len */ + 8 /* dest conn id */ + 1 /* src conn ID len */ + 8 /* src conn id */ + 1 /* short len */ + 1 /* packet number */ Expect(h.GetLength(versionIETFHeader)).To(BeEquivalentTo(expectedLen)) Expect(h.Write(buf, versionIETFHeader)).To(Succeed()) Expect(buf.Len()).To(Equal(expectedLen)) @@ -246,7 +234,7 @@ var _ = Describe("Header", func() { }, PacketNumberLen: protocol.PacketNumberLen2, } - expectedLen := 1 /* type byte */ + 4 /* version */ + 1 /* conn ID len */ + 8 /* dest conn id */ + 8 /* src conn id */ + 2 /* long len */ + 2 /* packet number */ + expectedLen := 1 /* type byte */ + 4 /* version */ + 1 /* dest conn id len */ + 8 /* dest conn id */ + 1 /* src conn ID len */ + 8 /* src conn id */ + 2 /* long len */ + 2 /* packet number */ Expect(h.GetLength(versionIETFHeader)).To(BeEquivalentTo(expectedLen)) Expect(h.Write(buf, versionIETFHeader)).To(Succeed()) Expect(buf.Len()).To(Equal(expectedLen)) @@ -263,7 +251,7 @@ var _ = Describe("Header", func() { }, PacketNumberLen: protocol.PacketNumberLen2, } - expectedLen := 1 /* type byte */ + 4 /* version */ + 1 /* conn ID len */ + 8 /* dest conn id */ + 4 /* src conn id */ + 1 /* token length */ + 2 /* long len */ + 2 /* packet number */ + expectedLen := 1 /* type byte */ + 4 /* version */ + 1 /* dest conn id len */ + 8 /* dest conn id */ + 1 /* src conn ID len */ + 4 /* src conn id */ + 1 /* token length */ + 2 /* long len */ + 2 /* packet number */ Expect(h.GetLength(versionIETFHeader)).To(BeEquivalentTo(expectedLen)) Expect(h.Write(buf, versionIETFHeader)).To(Succeed()) Expect(buf.Len()).To(Equal(expectedLen)) @@ -281,7 +269,7 @@ var _ = Describe("Header", func() { }, PacketNumberLen: protocol.PacketNumberLen2, } - expectedLen := 1 /* type byte */ + 4 /* version */ + 1 /* conn ID len */ + 8 /* dest conn id */ + 4 /* src conn id */ + 1 /* token length */ + 3 /* token */ + 2 /* long len */ + 2 /* packet number */ + expectedLen := 1 /* type byte */ + 4 /* version */ + 1 /* dest conn id len */ + 8 /* dest conn id */ + 1 /* src conn id len */ + 4 /* src conn id */ + 1 /* token length */ + 3 /* token */ + 2 /* long len */ + 2 /* packet number */ Expect(h.GetLength(versionIETFHeader)).To(BeEquivalentTo(expectedLen)) Expect(h.Write(buf, versionIETFHeader)).To(Succeed()) Expect(buf.Len()).To(Equal(expectedLen)) diff --git a/internal/wire/header.go b/internal/wire/header.go index 805198292..84ea33cee 100644 --- a/internal/wire/header.go +++ b/internal/wire/header.go @@ -27,7 +27,7 @@ func ParseConnectionID(data []byte, shortHeaderConnIDLen int) (protocol.Connecti if len(data) < 6 { return nil, io.EOF } - destConnIDLen, _ := decodeConnIDLen(data[5]) + destConnIDLen := int(data[5]) if len(data) < 6+destConnIDLen { return nil, io.EOF } @@ -139,16 +139,19 @@ func (h *Header) parseLongHeader(b *bytes.Reader) error { if h.Version != 0 && h.typeByte&0x40 == 0 { return errors.New("not a QUIC packet") } - connIDLenByte, err := b.ReadByte() + destConnIDLen, err := b.ReadByte() if err != nil { return err } - dcil, scil := decodeConnIDLen(connIDLenByte) - h.DestConnectionID, err = protocol.ReadConnectionID(b, dcil) + h.DestConnectionID, err = protocol.ReadConnectionID(b, int(destConnIDLen)) if err != nil { return err } - h.SrcConnectionID, err = protocol.ReadConnectionID(b, scil) + srcConnIDLen, err := b.ReadByte() + if err != nil { + return err + } + h.SrcConnectionID, err = protocol.ReadConnectionID(b, int(srcConnIDLen)) if err != nil { return err } @@ -172,8 +175,11 @@ func (h *Header) parseLongHeader(b *bytes.Reader) error { } if h.Type == protocol.PacketTypeRetry { - odcil := decodeSingleConnIDLen(h.typeByte & 0xf) - h.OrigDestConnectionID, err = protocol.ReadConnectionID(b, odcil) + origDestConnIDLen, err := b.ReadByte() + if err != nil { + return err + } + h.OrigDestConnectionID, err = protocol.ReadConnectionID(b, int(origDestConnIDLen)) if err != nil { return err } @@ -238,14 +244,3 @@ func (h *Header) ParseExtended(b *bytes.Reader, ver protocol.VersionNumber) (*Ex func (h *Header) toExtendedHeader() *ExtendedHeader { return &ExtendedHeader{Header: *h} } - -func decodeConnIDLen(enc byte) (int /*dest conn id len*/, int /*src conn id len*/) { - return decodeSingleConnIDLen(enc >> 4), decodeSingleConnIDLen(enc & 0xf) -} - -func decodeSingleConnIDLen(enc uint8) int { - if enc == 0 { - return 0 - } - return int(enc) + 3 -} diff --git a/internal/wire/header_test.go b/internal/wire/header_test.go index 6141a797c..3e93c98c7 100644 --- a/internal/wire/header_test.go +++ b/internal/wire/header_test.go @@ -158,8 +158,9 @@ var _ = Describe("Header Parsing", func() { srcConnID := protocol.ConnectionID{0xde, 0xad, 0xbe, 0xef} data := []byte{0xc0 ^ 0x3} data = appendVersion(data, versionIETFFrames) - data = append(data, 0x61) // connection ID lengths + data = append(data, 0x9) // dest conn id length data = append(data, destConnID...) + data = append(data, 0x4) // src conn id length data = append(data, srcConnID...) data = append(data, encodeVarInt(6)...) // token length data = append(data, []byte("foobar")...) // token @@ -204,9 +205,10 @@ var _ = Describe("Header Parsing", func() { data := []byte{ 0xc0, 0xde, 0xad, 0xbe, 0xef, - 0x55, // connection ID length - 0x1, 0x2, 0x3, 0x4, 0x5, 0x6, 0x7, 0x8, - 0x8, 0x7, 0x6, 0x5, 0x4, 0x3, 0x2, 0x1, + 0x8, // dest conn ID len + 0x1, 0x2, 0x3, 0x4, 0x5, 0x6, 0x7, 0x8, // dest conn ID + 0x8, // src conn ID len + 0x8, 0x7, 0x6, 0x5, 0x4, 0x3, 0x2, 0x1, // src conn ID 'f', 'o', 'o', 'b', 'a', 'r', // unspecified bytes } hdr, _, rest, err := ParsePacket(data, 0) @@ -221,7 +223,8 @@ var _ = Describe("Header Parsing", func() { It("parses a Long Header without a destination connection ID", func() { data := []byte{0xc0 ^ 0x1<<4} data = appendVersion(data, versionIETFFrames) - data = append(data, 0x01) // connection ID lengths + data = append(data, 0x0) // dest conn ID len + data = append(data, 0x4) // src conn ID len data = append(data, []byte{0xde, 0xad, 0xbe, 0xef}...) // source connection ID data = append(data, encodeVarInt(0)...) // length data = append(data, []byte{0xde, 0xca, 0xfb, 0xad}...) @@ -235,8 +238,9 @@ var _ = Describe("Header Parsing", func() { It("parses a Long Header without a source connection ID", func() { data := []byte{0xc0 ^ 0x2<<4} data = appendVersion(data, versionIETFFrames) - data = append(data, 0x70) // connection ID lengths - data = append(data, []byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}...) // source connection ID + data = append(data, 0xa) // dest conn ID len + data = append(data, []byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}...) // dest connection ID + data = append(data, 0x0) // src conn ID len data = append(data, encodeVarInt(0)...) // length data = append(data, []byte{0xde, 0xca, 0xfb, 0xad}...) hdr, _, _, err := ParsePacket(data, 0) @@ -248,7 +252,7 @@ var _ = Describe("Header Parsing", func() { It("parses a Long Header with a 2 byte packet number", func() { data := []byte{0xc0 ^ 0x1} data = appendVersion(data, versionIETFFrames) // version number - data = append(data, 0x0) // connection ID lengths + data = append(data, []byte{0x0, 0x0}...) // connection ID lengths data = append(data, encodeVarInt(0)...) // token length data = append(data, encodeVarInt(0)...) // length data = append(data, []byte{0x1, 0x23}...) @@ -266,7 +270,8 @@ var _ = Describe("Header Parsing", func() { It("parses a Retry packet", func() { data := []byte{0xc0 | 0x3<<4 | (10 - 3) /* connection ID length */} data = appendVersion(data, versionIETFFrames) - data = append(data, 0x0) // connection ID lengths + data = append(data, []byte{0x0, 0x0}...) // dest and src conn ID lengths + data = append(data, 0xa) // orig dest conn ID len data = append(data, []byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}...) // source connection ID data = append(data, []byte{'f', 'o', 'o', 'b', 'a', 'r'}...) // token hdr, pdata, rest, err := ParsePacket(data, 0) @@ -293,7 +298,7 @@ var _ = Describe("Header Parsing", func() { It("errors if the 5th or 6th bit are set", func() { data := []byte{0xc0 | 0x2<<4 | 0x8 /* set the 5th bit */ | 0x1 /* 2 byte packet number */} data = appendVersion(data, versionIETFFrames) - data = append(data, 0x0) // connection ID lengths + data = append(data, []byte{0x0, 0x0}...) // connection ID lengths data = append(data, encodeVarInt(2)...) // length data = append(data, []byte{0x12, 0x34}...) // packet number hdr, _, _, err := ParsePacket(data, 0) @@ -308,9 +313,10 @@ var _ = Describe("Header Parsing", func() { It("errors on EOF, when parsing the header", func() { data := []byte{0xc0 ^ 0x2<<4} data = appendVersion(data, versionIETFFrames) - data = append(data, 0x55) // connection ID lengths - data = append(data, []byte{0xde, 0xad, 0xbe, 0xef, 0xca, 0xfe, 0x13, 0x37}...) // destination connection ID - data = append(data, []byte{0xde, 0xad, 0xbe, 0xef, 0xca, 0xfe, 0x13, 0x37}...) // source connection ID + data = append(data, 0x8) // dest conn ID len + data = append(data, []byte{0xde, 0xad, 0xbe, 0xef, 0xca, 0xfe, 0x13, 0x37}...) // dest conn ID + data = append(data, 0x8) // src conn ID len + data = append(data, []byte{0xde, 0xad, 0xbe, 0xef, 0xca, 0xfe, 0x13, 0x37}...) // src conn ID for i := 0; i < len(data); i++ { _, _, _, err := ParsePacket(data[:i], 0) Expect(err).To(Equal(io.EOF)) @@ -320,8 +326,8 @@ var _ = Describe("Header Parsing", func() { It("errors on EOF, when parsing the extended header", func() { data := []byte{0xc0 | 0x2<<4 | 0x3} data = appendVersion(data, versionIETFFrames) - data = append(data, 0x0) // connection ID lengths - data = append(data, encodeVarInt(0)...) // length + data = append(data, []byte{0x0, 0x0}...) // connection ID lengths + data = append(data, encodeVarInt(0)...) // length hdrLen := len(data) data = append(data, []byte{0xde, 0xad, 0xbe, 0xef}...) // packet number for i := hdrLen; i < len(data); i++ { @@ -337,8 +343,8 @@ var _ = Describe("Header Parsing", func() { It("errors on EOF, for a Retry packet", func() { data := []byte{0xc0 ^ 0x3<<4} data = appendVersion(data, versionIETFFrames) - data = append(data, 0x0) // connection ID lengths - data = append(data, 0x97) // Orig Destination Connection ID length + data = append(data, []byte{0x0, 0x0}...) // connection ID lengths + data = append(data, 0xa) // Orig Destination Connection ID length data = append(data, []byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}...) // source connection ID hdrLen := len(data) for i := hdrLen; i < len(data); i++ { diff --git a/internal/wire/version_negotiation.go b/internal/wire/version_negotiation.go index bfc27af77..ae097f09d 100644 --- a/internal/wire/version_negotiation.go +++ b/internal/wire/version_negotiation.go @@ -17,12 +17,9 @@ func ComposeVersionNegotiation(destConnID, srcConnID protocol.ConnectionID, vers _, _ = rand.Read(r) // ignore the error here. It is not critical to have perfect random here. buf.WriteByte(r[0] | 0xc0) utils.BigEndian.WriteUint32(buf, 0) // version 0 - connIDLen, err := encodeConnIDLen(destConnID, srcConnID) - if err != nil { - return nil, err - } - buf.WriteByte(connIDLen) + buf.WriteByte(uint8(destConnID.Len())) buf.Write(destConnID) + buf.WriteByte(uint8(srcConnID.Len())) buf.Write(srcConnID) for _, v := range greasedVersions { utils.BigEndian.WriteUint32(buf, uint32(v))