forked from quic-go/quic-go
reject packets that use the wrong or unknown packet types
This commit is contained in:
@@ -63,13 +63,13 @@ var _ = Describe("Header", func() {
|
|||||||
buf := &bytes.Buffer{}
|
buf := &bytes.Buffer{}
|
||||||
err := (&Header{
|
err := (&Header{
|
||||||
IsLongHeader: true,
|
IsLongHeader: true,
|
||||||
Type: 3,
|
Type: protocol.PacketType0RTT,
|
||||||
PacketNumber: 0x42,
|
PacketNumber: 0x42,
|
||||||
}).writeHeader(buf)
|
}).writeHeader(buf)
|
||||||
Expect(err).ToNot(HaveOccurred())
|
Expect(err).ToNot(HaveOccurred())
|
||||||
hdr, err := ParseHeader(bytes.NewReader(buf.Bytes()), protocol.PerspectiveClient, protocol.VersionUnknown)
|
hdr, err := ParseHeader(bytes.NewReader(buf.Bytes()), protocol.PerspectiveClient, protocol.VersionUnknown)
|
||||||
Expect(err).ToNot(HaveOccurred())
|
Expect(err).ToNot(HaveOccurred())
|
||||||
Expect(hdr.Type).To(BeEquivalentTo(3))
|
Expect(hdr.Type).To(Equal(protocol.PacketType0RTT))
|
||||||
Expect(hdr.PacketNumber).To(Equal(protocol.PacketNumber(0x42)))
|
Expect(hdr.PacketNumber).To(Equal(protocol.PacketNumber(0x42)))
|
||||||
Expect(hdr.isPublicHeader).To(BeFalse())
|
Expect(hdr.isPublicHeader).To(BeFalse())
|
||||||
})
|
})
|
||||||
|
|||||||
@@ -21,7 +21,7 @@ func parseHeader(b *bytes.Reader, packetSentBy protocol.Perspective) (*Header, e
|
|||||||
return parseShortHeader(b, typeByte)
|
return parseShortHeader(b, typeByte)
|
||||||
}
|
}
|
||||||
|
|
||||||
func parseLongHeader(b *bytes.Reader, packetSentBy protocol.Perspective, typeByte byte) (*Header, error) {
|
func parseLongHeader(b *bytes.Reader, sentBy protocol.Perspective, typeByte byte) (*Header, error) {
|
||||||
connID, err := utils.BigEndian.ReadUint64(b)
|
connID, err := utils.BigEndian.ReadUint64(b)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
@@ -34,18 +34,25 @@ func parseLongHeader(b *bytes.Reader, packetSentBy protocol.Perspective, typeByt
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
packetType := protocol.PacketType(typeByte & 0x7f)
|
||||||
|
if sentBy == protocol.PerspectiveClient && (packetType != protocol.PacketTypeInitial && packetType != protocol.PacketTypeCleartext && packetType != protocol.PacketType0RTT) {
|
||||||
|
if packetType == protocol.PacketTypeVersionNegotiation {
|
||||||
|
return nil, qerr.Error(qerr.InvalidVersionNegotiationPacket, "sent by the client")
|
||||||
|
}
|
||||||
|
return nil, qerr.Error(qerr.InvalidPacketHeader, fmt.Sprintf("Received packet with invalid packet type: %d", packetType))
|
||||||
|
}
|
||||||
|
if sentBy == protocol.PerspectiveServer && (packetType != protocol.PacketTypeVersionNegotiation && packetType != protocol.PacketTypeRetry && packetType != protocol.PacketTypeCleartext) {
|
||||||
|
return nil, qerr.Error(qerr.InvalidPacketHeader, fmt.Sprintf("Received packet with invalid packet type: %d", packetType))
|
||||||
|
}
|
||||||
h := &Header{
|
h := &Header{
|
||||||
Type: protocol.PacketType(typeByte & 0x7f),
|
Type: packetType,
|
||||||
IsLongHeader: true,
|
IsLongHeader: true,
|
||||||
ConnectionID: protocol.ConnectionID(connID),
|
ConnectionID: protocol.ConnectionID(connID),
|
||||||
PacketNumber: protocol.PacketNumber(pn),
|
PacketNumber: protocol.PacketNumber(pn),
|
||||||
PacketNumberLen: protocol.PacketNumberLen4,
|
PacketNumberLen: protocol.PacketNumberLen4,
|
||||||
Version: protocol.VersionNumber(v),
|
Version: protocol.VersionNumber(v),
|
||||||
}
|
}
|
||||||
if h.Type == 0x1 { // Version Negotiation Packet
|
if h.Type == protocol.PacketTypeVersionNegotiation {
|
||||||
if packetSentBy == protocol.PerspectiveClient {
|
|
||||||
return nil, qerr.Error(qerr.InvalidVersionNegotiationPacket, "sent by the client")
|
|
||||||
}
|
|
||||||
if b.Len() == 0 {
|
if b.Len() == 0 {
|
||||||
return nil, qerr.Error(qerr.InvalidVersionNegotiationPacket, "empty version list")
|
return nil, qerr.Error(qerr.InvalidVersionNegotiationPacket, "empty version list")
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -2,6 +2,7 @@ package wire
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
"log"
|
"log"
|
||||||
"os"
|
"os"
|
||||||
@@ -17,22 +18,20 @@ import (
|
|||||||
var _ = Describe("IETF draft Header", func() {
|
var _ = Describe("IETF draft Header", func() {
|
||||||
Context("parsing", func() {
|
Context("parsing", func() {
|
||||||
Context("long headers", func() {
|
Context("long headers", func() {
|
||||||
var data []byte
|
generatePacket := func(t protocol.PacketType) []byte {
|
||||||
|
return []byte{
|
||||||
BeforeEach(func() {
|
0x80 ^ uint8(t),
|
||||||
data = []byte{
|
|
||||||
0x80 ^ 0x3,
|
|
||||||
0xde, 0xad, 0xbe, 0xef, 0xca, 0xfe, 0x13, 0x37, // connection ID
|
0xde, 0xad, 0xbe, 0xef, 0xca, 0xfe, 0x13, 0x37, // connection ID
|
||||||
0xde, 0xca, 0xfb, 0xad, // packet number
|
0xde, 0xca, 0xfb, 0xad, // packet number
|
||||||
0x1, 0x2, 0x3, 0x4, // version number
|
0x1, 0x2, 0x3, 0x4, // version number
|
||||||
}
|
}
|
||||||
})
|
}
|
||||||
|
|
||||||
It("parses a long header", func() {
|
It("parses a long header", func() {
|
||||||
b := bytes.NewReader(data)
|
b := bytes.NewReader(generatePacket(protocol.PacketTypeInitial))
|
||||||
h, err := parseHeader(b, protocol.PerspectiveClient)
|
h, err := parseHeader(b, protocol.PerspectiveClient)
|
||||||
Expect(err).ToNot(HaveOccurred())
|
Expect(err).ToNot(HaveOccurred())
|
||||||
Expect(h.Type).To(Equal(protocol.PacketType(3)))
|
Expect(h.Type).To(Equal(protocol.PacketTypeInitial))
|
||||||
Expect(h.IsLongHeader).To(BeTrue())
|
Expect(h.IsLongHeader).To(BeTrue())
|
||||||
Expect(h.OmitConnectionID).To(BeFalse())
|
Expect(h.OmitConnectionID).To(BeFalse())
|
||||||
Expect(h.ConnectionID).To(Equal(protocol.ConnectionID(0xdeadbeefcafe1337)))
|
Expect(h.ConnectionID).To(Equal(protocol.ConnectionID(0xdeadbeefcafe1337)))
|
||||||
@@ -42,7 +41,26 @@ var _ = Describe("IETF draft Header", func() {
|
|||||||
Expect(b.Len()).To(BeZero())
|
Expect(b.Len()).To(BeZero())
|
||||||
})
|
})
|
||||||
|
|
||||||
|
It("rejects packets sent by the client that use packet types for packets sent by the server", func() {
|
||||||
|
b := bytes.NewReader(generatePacket(protocol.PacketTypeRetry))
|
||||||
|
_, err := parseHeader(b, protocol.PerspectiveClient)
|
||||||
|
Expect(err).To(MatchError(fmt.Sprintf("InvalidPacketHeader: Received packet with invalid packet type: %d", protocol.PacketTypeRetry)))
|
||||||
|
})
|
||||||
|
|
||||||
|
It("rejects packets sent by the client that use packet types for packets sent by the server", func() {
|
||||||
|
b := bytes.NewReader(generatePacket(protocol.PacketType0RTT))
|
||||||
|
_, err := parseHeader(b, protocol.PerspectiveServer)
|
||||||
|
Expect(err).To(MatchError(fmt.Sprintf("InvalidPacketHeader: Received packet with invalid packet type: %d", protocol.PacketType0RTT)))
|
||||||
|
})
|
||||||
|
|
||||||
|
It("rejects packets sent with an unknown packet type", func() {
|
||||||
|
b := bytes.NewReader(generatePacket(42))
|
||||||
|
_, err := parseHeader(b, protocol.PerspectiveServer)
|
||||||
|
Expect(err).To(MatchError("InvalidPacketHeader: Received packet with invalid packet type: 42"))
|
||||||
|
})
|
||||||
|
|
||||||
It("errors on EOF", func() {
|
It("errors on EOF", func() {
|
||||||
|
data := generatePacket(protocol.PacketTypeInitial)
|
||||||
for i := 0; i < len(data); i++ {
|
for i := 0; i < len(data); i++ {
|
||||||
_, err := parseHeader(bytes.NewReader(data[:i]), protocol.PerspectiveClient)
|
_, err := parseHeader(bytes.NewReader(data[:i]), protocol.PerspectiveClient)
|
||||||
Expect(err).To(Equal(io.EOF))
|
Expect(err).To(Equal(io.EOF))
|
||||||
@@ -50,14 +68,13 @@ var _ = Describe("IETF draft Header", func() {
|
|||||||
})
|
})
|
||||||
|
|
||||||
Context("Version Negotiation Packets", func() {
|
Context("Version Negotiation Packets", func() {
|
||||||
BeforeEach(func() {
|
|
||||||
data[0] = 0x80 ^ 0x1 // set the type byte to Version Negotiation Packet
|
|
||||||
})
|
|
||||||
|
|
||||||
It("parses", func() {
|
It("parses", func() {
|
||||||
data = append(data, []byte{
|
data := append(
|
||||||
0x22, 0x33, 0x44, 0x55,
|
generatePacket(protocol.PacketTypeVersionNegotiation),
|
||||||
0x33, 0x44, 0x55, 0x66}...,
|
[]byte{
|
||||||
|
0x22, 0x33, 0x44, 0x55,
|
||||||
|
0x33, 0x44, 0x55, 0x66,
|
||||||
|
}...,
|
||||||
)
|
)
|
||||||
b := bytes.NewReader(data)
|
b := bytes.NewReader(data)
|
||||||
h, err := parseHeader(b, protocol.PerspectiveServer)
|
h, err := parseHeader(b, protocol.PerspectiveServer)
|
||||||
@@ -70,21 +87,27 @@ var _ = Describe("IETF draft Header", func() {
|
|||||||
})
|
})
|
||||||
|
|
||||||
It("errors if it contains versions of the wrong length", func() {
|
It("errors if it contains versions of the wrong length", func() {
|
||||||
data = append(data, []byte{0x22, 0x33}...) // too short. Should be 4 bytes.
|
data := append(
|
||||||
|
generatePacket(protocol.PacketTypeVersionNegotiation),
|
||||||
|
[]byte{0x22, 0x33}..., // too short. Should be 4 bytes.
|
||||||
|
)
|
||||||
b := bytes.NewReader(data)
|
b := bytes.NewReader(data)
|
||||||
_, err := parseHeader(b, protocol.PerspectiveServer)
|
_, err := parseHeader(b, protocol.PerspectiveServer)
|
||||||
Expect(err).To(MatchError(qerr.InvalidVersionNegotiationPacket))
|
Expect(err).To(MatchError(qerr.InvalidVersionNegotiationPacket))
|
||||||
})
|
})
|
||||||
|
|
||||||
It("errors if it was sent by the client", func() {
|
It("errors if it was sent by the client", func() {
|
||||||
data = append(data, []byte{0x22, 0x33, 0x44, 0x55}...)
|
data := append(
|
||||||
|
generatePacket(protocol.PacketTypeVersionNegotiation),
|
||||||
|
[]byte{0x22, 0x33, 0x44, 0x55}...,
|
||||||
|
)
|
||||||
b := bytes.NewReader(data)
|
b := bytes.NewReader(data)
|
||||||
_, err := parseHeader(b, protocol.PerspectiveClient)
|
_, err := parseHeader(b, protocol.PerspectiveClient)
|
||||||
Expect(err).To(MatchError("InvalidVersionNegotiationPacket: sent by the client"))
|
Expect(err).To(MatchError("InvalidVersionNegotiationPacket: sent by the client"))
|
||||||
})
|
})
|
||||||
|
|
||||||
It("errors if the version list is emtpy", func() {
|
It("errors if the version list is emtpy", func() {
|
||||||
b := bytes.NewReader(data)
|
b := bytes.NewReader(generatePacket(protocol.PacketTypeVersionNegotiation))
|
||||||
_, err := parseHeader(b, protocol.PerspectiveServer)
|
_, err := parseHeader(b, protocol.PerspectiveServer)
|
||||||
Expect(err).To(MatchError("InvalidVersionNegotiationPacket: empty version list"))
|
Expect(err).To(MatchError("InvalidVersionNegotiationPacket: empty version list"))
|
||||||
})
|
})
|
||||||
|
|||||||
Reference in New Issue
Block a user