diff --git a/internal/protocol/protocol.go b/internal/protocol/protocol.go index c102691a4..24e568626 100644 --- a/internal/protocol/protocol.go +++ b/internal/protocol/protocol.go @@ -26,13 +26,13 @@ type PacketType uint8 const ( // PacketTypeInitial is the packet type of an Initial packet - PacketTypeInitial PacketType = 0x7f + PacketTypeInitial PacketType = 1 + iota // PacketTypeRetry is the packet type of a Retry packet - PacketTypeRetry PacketType = 0x7e + PacketTypeRetry // PacketTypeHandshake is the packet type of a Handshake packet - PacketTypeHandshake PacketType = 0x7d + PacketTypeHandshake // PacketType0RTT is the packet type of a 0-RTT packet - PacketType0RTT PacketType = 0x7c + PacketType0RTT ) func (t PacketType) String() string { diff --git a/internal/wire/extended_header.go b/internal/wire/extended_header.go index 21dd4b19f..46524133e 100644 --- a/internal/wire/extended_header.go +++ b/internal/wire/extended_header.go @@ -59,7 +59,18 @@ func (h *ExtendedHeader) Write(b *bytes.Buffer, ver protocol.VersionNumber) erro } func (h *ExtendedHeader) writeLongHeader(b *bytes.Buffer, v protocol.VersionNumber) error { - b.WriteByte(byte(0x80 | h.Type)) + var packetType uint8 + switch h.Type { + case protocol.PacketTypeInitial: + packetType = 0x7f + case protocol.PacketTypeRetry: + packetType = 0x7e + case protocol.PacketTypeHandshake: + packetType = 0x7d + case protocol.PacketType0RTT: + packetType = 0x7c + } + b.WriteByte(0x80 | packetType) utils.BigEndian.WriteUint32(b, uint32(h.Version)) connIDLen, err := encodeConnIDLen(h.DestConnectionID, h.SrcConnectionID) if err != nil { diff --git a/internal/wire/extended_header_test.go b/internal/wire/extended_header_test.go index 661aa1868..35b9158e8 100644 --- a/internal/wire/extended_header_test.go +++ b/internal/wire/extended_header_test.go @@ -38,13 +38,13 @@ var _ = Describe("Header", func() { SrcConnectionID: protocol.ConnectionID{0xde, 0xca, 0xfb, 0xad, 0x0, 0x0, 0x13, 0x37}, Version: 0x1020304, Length: 0xcafe, - Type: 0x5, + Type: protocol.PacketTypeHandshake, }, PacketNumber: 0xdecaf, PacketNumberLen: protocol.PacketNumberLen4, }).Write(buf, versionIETFHeader)).To(Succeed()) expected := []byte{ - 0x80 ^ 0x5, + 0x80 ^ 0x7d, 0x1, 0x2, 0x3, 0x4, // version number 0x35, // connection ID lengths 0xde, 0xad, 0xbe, 0xef, 0xca, 0xfe, // dest connection ID @@ -127,7 +127,7 @@ var _ = Describe("Header", func() { OrigDestConnectionID: protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8, 9}, }}).Write(buf, versionIETFHeader)).To(Succeed()) Expect(buf.Bytes()[:6]).To(Equal([]byte{ - 0x80 ^ uint8(protocol.PacketTypeRetry), + 0x80 | 0x7e, 0x1, 0x2, 0x3, 0x4, // version number 0x0, // connection ID lengths)) })) @@ -228,6 +228,7 @@ var _ = Describe("Header", func() { h := &ExtendedHeader{ Header: Header{ IsLongHeader: true, + Type: protocol.PacketTypeHandshake, DestConnectionID: protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8}, SrcConnectionID: protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8}, Length: 1, @@ -244,6 +245,7 @@ var _ = Describe("Header", func() { h := &ExtendedHeader{ Header: Header{ IsLongHeader: true, + Type: protocol.PacketTypeHandshake, DestConnectionID: protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8}, SrcConnectionID: protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8}, Length: 1500, @@ -260,10 +262,10 @@ var _ = Describe("Header", func() { h := &ExtendedHeader{ Header: Header{ IsLongHeader: true, + Type: protocol.PacketTypeInitial, DestConnectionID: protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8}, SrcConnectionID: protocol.ConnectionID{1, 2, 3, 4}, Length: 1500, - Type: protocol.PacketTypeInitial, }, PacketNumberLen: protocol.PacketNumberLen2, } diff --git a/internal/wire/header.go b/internal/wire/header.go index f687354f7..e665428fb 100644 --- a/internal/wire/header.go +++ b/internal/wire/header.go @@ -99,9 +99,17 @@ func (h *Header) parseLongHeader(b *bytes.Reader) error { return nil } - h.Type = protocol.PacketType(h.typeByte & 0x7f) - if h.Type != protocol.PacketTypeInitial && h.Type != protocol.PacketTypeRetry && h.Type != protocol.PacketType0RTT && h.Type != protocol.PacketTypeHandshake { - return qerr.Error(qerr.InvalidPacketHeader, fmt.Sprintf("Received packet with invalid packet type: %d", h.Type)) + switch h.typeByte & 0x7f { + case 0x7f: + h.Type = protocol.PacketTypeInitial + case 0x7e: + h.Type = protocol.PacketTypeRetry + case 0x7d: + h.Type = protocol.PacketTypeHandshake + case 0x7c: + h.Type = protocol.PacketType0RTT + default: + return qerr.Error(qerr.InvalidPacketHeader, fmt.Sprintf("Received packet with invalid packet type: %d", h.typeByte&0x7f)) } if h.Type == protocol.PacketTypeRetry { diff --git a/internal/wire/header_test.go b/internal/wire/header_test.go index 93100e6b3..7198a2465 100644 --- a/internal/wire/header_test.go +++ b/internal/wire/header_test.go @@ -73,7 +73,7 @@ var _ = Describe("Header Parsing", func() { It("parses a Long Header", func() { destConnID := protocol.ConnectionID{9, 8, 7, 6, 5, 4, 3, 2, 1} srcConnID := protocol.ConnectionID{0xde, 0xad, 0xbe, 0xef} - data := []byte{0x80 ^ uint8(protocol.PacketTypeInitial)} + data := []byte{0x80 ^ 0x7f} data = appendVersion(data, versionIETFFrames) data = append(data, 0x61) // connection ID lengths data = append(data, destConnID...) @@ -103,7 +103,7 @@ var _ = Describe("Header Parsing", func() { It("stops parsing when encountering an unsupported version", func() { data := []byte{ - 0x80 ^ uint8(protocol.PacketTypeInitial), + 0x80 ^ 0x7f, 0xde, 0xad, 0xbe, 0xef, 0x55, // connection ID length 0x1, 0x2, 0x3, 0x4, 0x5, 0x6, 0x7, 0x8, @@ -121,7 +121,7 @@ var _ = Describe("Header Parsing", func() { }) It("parses a Long Header without a destination connection ID", func() { - data := []byte{0x80 ^ uint8(protocol.PacketTypeHandshake)} + data := []byte{0x80 ^ 0x7d} data = appendVersion(data, versionIETFFrames) data = append(data, 0x01) // connection ID lengths data = append(data, []byte{0xde, 0xad, 0xbe, 0xef}...) // source connection ID @@ -134,7 +134,7 @@ var _ = Describe("Header Parsing", func() { }) It("parses a Long Header without a source connection ID", func() { - data := []byte{0x80 ^ uint8(protocol.PacketTypeHandshake)} + data := []byte{0x80 ^ 0x7d} data = appendVersion(data, versionIETFFrames) data = append(data, 0x70) // connection ID lengths data = append(data, []byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}...) // source connection ID @@ -147,7 +147,7 @@ var _ = Describe("Header Parsing", func() { }) It("parses a Long Header with a 2 byte packet number", func() { - data := []byte{0x80 ^ uint8(protocol.PacketTypeInitial)} + data := []byte{0x80 ^ 0x7f} data = appendVersion(data, versionIETFFrames) // version number data = append(data, 0x0) // connection ID lengths data = append(data, encodeVarInt(0)...) // token length @@ -165,7 +165,7 @@ var _ = Describe("Header Parsing", func() { }) It("parses a Retry packet", func() { - data := []byte{0x80 ^ uint8(protocol.PacketTypeRetry)} + data := []byte{0x80 ^ 0x7e} data = appendVersion(data, versionIETFFrames) data = append(data, 0x0) // connection ID lengths data = append(data, 0x97) // Orig Destination Connection ID length @@ -185,20 +185,22 @@ var _ = Describe("Header Parsing", func() { Expect((&ExtendedHeader{ Header: Header{ IsLongHeader: true, - Type: 42, + Type: protocol.PacketTypeHandshake, // will be overwritten later SrcConnectionID: srcConnID, Version: versionIETFFrames, }, PacketNumber: 1, PacketNumberLen: protocol.PacketNumberLen1, }).Write(buf, protocol.VersionTLS)).To(Succeed()) - b := bytes.NewReader(buf.Bytes()) + data := buf.Bytes() + data[0] = 0x80 | 42 + b := bytes.NewReader(data) _, err := ParseHeader(b, 0) Expect(err).To(MatchError("InvalidPacketHeader: Received packet with invalid packet type: 42")) }) It("errors if the token length is too large", func() { - data := []byte{0x80 ^ uint8(protocol.PacketTypeInitial)} + data := []byte{0x80 ^ 0x7e} data = appendVersion(data, versionIETFFrames) data = append(data, 0x0) // connection ID lengths data = append(data, encodeVarInt(4)...) // token length: 4 bytes (1 byte too long) @@ -211,7 +213,7 @@ var _ = Describe("Header Parsing", func() { }) It("errors on EOF, when parsing the header", func() { - data := []byte{0x80 ^ uint8(protocol.PacketTypeInitial)} + data := []byte{0x80 ^ 0x7f} data = appendVersion(data, versionIETFFrames) data = append(data, 0x55) // connection ID lengths data = append(data, []byte{0xde, 0xad, 0xbe, 0xef, 0xca, 0xfe, 0x13, 0x37}...) // destination connection ID @@ -223,7 +225,7 @@ var _ = Describe("Header Parsing", func() { }) It("errors on EOF, when parsing the extended header", func() { - data := []byte{0x80 ^ uint8(protocol.PacketTypeHandshake)} + data := []byte{0x80 ^ 0x7d} data = appendVersion(data, versionIETFFrames) data = append(data, 0x0) // connection ID lengths data = append(data, encodeVarInt(0x1337)...) @@ -239,7 +241,7 @@ var _ = Describe("Header Parsing", func() { }) It("errors on EOF, for a Retry packet", func() { - data := []byte{0x80 ^ uint8(protocol.PacketTypeRetry)} + data := []byte{0x80 ^ 0x7e} data = appendVersion(data, versionIETFFrames) data = append(data, 0x0) // connection ID lengths data = append(data, 0x97) // Orig Destination Connection ID length