diff --git a/client.go b/client.go index e408e5a0d..d13dd8148 100644 --- a/client.go +++ b/client.go @@ -249,7 +249,7 @@ func (c *client) handlePacket(remoteAddr net.Addr, packet []byte) { rcvTime := time.Now() r := bytes.NewReader(packet) - hdr, err := wire.ParseHeader(r, protocol.PerspectiveServer, c.version) + hdr, err := wire.ParseHeaderSentByServer(r, c.version) if err != nil { utils.Errorf("error parsing packet from %s: %s", remoteAddr.String(), err.Error()) // drop this packet if we can't parse the header diff --git a/internal/wire/header.go b/internal/wire/header.go index 05c9ba4b7..5bcabb364 100644 --- a/internal/wire/header.go +++ b/internal/wire/header.go @@ -4,7 +4,6 @@ import ( "bytes" "github.com/lucas-clemente/quic-go/internal/protocol" - "github.com/lucas-clemente/quic-go/internal/utils" ) // Header is the header of a QUIC packet. @@ -32,59 +31,56 @@ type Header struct { isPublicHeader bool } -// ParseHeader parses the header. -func ParseHeader(b *bytes.Reader, sentBy protocol.Perspective, version protocol.VersionNumber) (*Header, error) { - var typeByte uint8 - if version == protocol.VersionUnknown { - var err error - typeByte, err = b.ReadByte() - if err != nil { - return nil, err - } - _ = b.UnreadByte() // unread the type byte - } - - // There are two conditions this is a header in the IETF Header format: - // 1. We already know the version (because this is a packet that belongs to an exisitng session). - // 2. If this is a new packet, it must have the Long Format, which has the 0x80 bit set (which is always 0 in gQUIC). - // There's a third option: This could be a packet with Short Format that arrives after a server lost state. - // In that case, we'll try parsing the header as a gQUIC Public Header. - if version.UsesTLS() || (version == protocol.VersionUnknown && typeByte&0x80 > 0) { - return parseHeader(b, sentBy) - } - - // This is a gQUIC Public Header. - hdr, err := parsePublicHeader(b, sentBy, version) +// ParseHeaderSentByServer parses the header for a packet that was sent by the server. +func ParseHeaderSentByServer(b *bytes.Reader, version protocol.VersionNumber) (*Header, error) { + typeByte, err := b.ReadByte() if err != nil { return nil, err } - hdr.isPublicHeader = true // save that this is a Public Header, so we can log it correctly later - return hdr, nil + _ = b.UnreadByte() // unread the type byte + + var isPublicHeader bool + // As a client, we know the version of the packet that the server sent, except for Version Negotiation Packets. + // Both gQUIC and IETF QUIC Version Negotiation Packets have 0x1 set. + if typeByte&0x1 > 0 { + // IETF QUIC Version Negotiation Packets are sent with the Long Header (indicated by the 0x80 bit) + // gQUIC always has 0x80 unset + isPublicHeader = typeByte&0x80 == 0 + } else { + // For all packets that are not Version Negotiation Packets, the client knows the version that this packet was sent with + isPublicHeader = !version.UsesTLS() + } + return parsePacketHeader(b, protocol.PerspectiveServer, isPublicHeader) } -// PeekConnectionID parses the connection ID from a QUIC packet's public header, sent by the client. -// This function should not be called for packets sent by the server, since on these packets the Connection ID could be omitted. -// If no error occurs, it restores the read position in the bytes.Reader. -func PeekConnectionID(b *bytes.Reader) (protocol.ConnectionID, error) { - var connectionID protocol.ConnectionID - if _, err := b.ReadByte(); err != nil { - return 0, err - } - // unread the public flag byte - defer b.UnreadByte() - - // Assume that the packet contains the Connection ID. - // This is a valid assumption for all packets sent by the client, because the server doesn't allow the ommision of the Connection ID. - connID, err := utils.BigEndian.ReadUint64(b) +// ParseHeaderSentByClient parses the header for a packet that was sent by the client. +func ParseHeaderSentByClient(b *bytes.Reader) (*Header, error) { + typeByte, err := b.ReadByte() if err != nil { - return 0, err + return nil, err } - connectionID = protocol.ConnectionID(connID) - // unread the connection ID - for i := 0; i < 8; i++ { - b.UnreadByte() + _ = b.UnreadByte() // unread the type byte + + // If this is a gQUIC header 0x80 and 0x40 will be set to 0. + // If this is an IETF QUIC header there are two options: + // * either 0x80 will be 1 (for the Long Header) + // * or 0x40 (the Connection ID Flag) will be 0 (for the Short Header), since we don't the client to omit it + isPublicHeader := typeByte&0xc0 == 0 + + return parsePacketHeader(b, protocol.PerspectiveClient, isPublicHeader) +} + +func parsePacketHeader(b *bytes.Reader, sentBy protocol.Perspective, isPublicHeader bool) (*Header, error) { + // This is a gQUIC Public Header. + if isPublicHeader { + hdr, err := parsePublicHeader(b, sentBy) + if err != nil { + return nil, err + } + hdr.isPublicHeader = true // save that this is a Public Header, so we can log it correctly later + return hdr, nil } - return connectionID, nil + return parseHeader(b, sentBy) } // Write writes the Header. diff --git a/internal/wire/header_test.go b/internal/wire/header_test.go index 195c64d73..2e501bb96 100644 --- a/internal/wire/header_test.go +++ b/internal/wire/header_test.go @@ -18,29 +18,6 @@ var _ = Describe("Header", func() { versionIETFHeader = protocol.VersionTLS // a QUIC version taht uses the IETF Header format ) - Context("peeking the connection ID", func() { - It("gets the connection ID", func() { - b := bytes.NewReader([]byte{0x09, 0x4c, 0xfa, 0x9f, 0x9b, 0x66, 0x86, 0x19, 0xf6, 0x51, 0x30, 0x33, 0x34, 0x01}) - len := b.Len() - connID, err := PeekConnectionID(b) - Expect(err).ToNot(HaveOccurred()) - Expect(connID).To(Equal(protocol.ConnectionID(0x4cfa9f9b668619f6))) - Expect(b.Len()).To(Equal(len)) - }) - - It("errors if the header is too short", func() { - b := bytes.NewReader([]byte{0x09, 0xf6, 0x19, 0x86, 0x66, 0x9b}) - _, err := PeekConnectionID(b) - Expect(err).To(HaveOccurred()) - }) - - It("errors if the header is empty", func() { - b := bytes.NewReader([]byte{}) - _, err := PeekConnectionID(b) - Expect(err).To(HaveOccurred()) - }) - }) - Context("parsing", func() { It("parses an IETF draft header, when the QUIC version supports TLS", func() { buf := &bytes.Buffer{} @@ -52,7 +29,7 @@ var _ = Describe("Header", func() { PacketNumberLen: protocol.PacketNumberLen2, }).writeHeader(buf) Expect(err).ToNot(HaveOccurred()) - hdr, err := ParseHeader(bytes.NewReader(buf.Bytes()), protocol.PerspectiveClient, versionIETFHeader) + hdr, err := ParseHeaderSentByClient(bytes.NewReader(buf.Bytes())) Expect(err).ToNot(HaveOccurred()) Expect(hdr.KeyPhase).To(BeEquivalentTo(1)) Expect(hdr.PacketNumber).To(Equal(protocol.PacketNumber(0x42))) @@ -67,7 +44,7 @@ var _ = Describe("Header", func() { PacketNumber: 0x42, }).writeHeader(buf) Expect(err).ToNot(HaveOccurred()) - hdr, err := ParseHeader(bytes.NewReader(buf.Bytes()), protocol.PerspectiveClient, protocol.VersionUnknown) + hdr, err := ParseHeaderSentByClient(bytes.NewReader(buf.Bytes())) Expect(err).ToNot(HaveOccurred()) Expect(hdr.Type).To(Equal(protocol.PacketType0RTT)) Expect(hdr.PacketNumber).To(Equal(protocol.PacketNumber(0x42))) @@ -84,7 +61,7 @@ var _ = Describe("Header", func() { PacketNumberLen: protocol.PacketNumberLen6, }).writePublicHeader(buf, protocol.PerspectiveClient, versionPublicHeader) Expect(err).ToNot(HaveOccurred()) - hdr, err := ParseHeader(bytes.NewReader(buf.Bytes()), protocol.PerspectiveClient, protocol.VersionUnknown) + hdr, err := ParseHeaderSentByClient(bytes.NewReader(buf.Bytes())) Expect(err).ToNot(HaveOccurred()) Expect(hdr.PacketNumber).To(Equal(protocol.PacketNumber(0x1337))) Expect(hdr.Version).To(Equal(versionPublicHeader)) @@ -100,7 +77,7 @@ var _ = Describe("Header", func() { DiversificationNonce: bytes.Repeat([]byte{'f'}, 32), }).writePublicHeader(buf, protocol.PerspectiveServer, versionPublicHeader) Expect(err).ToNot(HaveOccurred()) - hdr, err := ParseHeader(bytes.NewReader(buf.Bytes()), protocol.PerspectiveServer, versionPublicHeader) + hdr, err := ParseHeaderSentByServer(bytes.NewReader(buf.Bytes()), versionPublicHeader) Expect(err).ToNot(HaveOccurred()) Expect(hdr.PacketNumber).To(Equal(protocol.PacketNumber(0x1337))) Expect(hdr.DiversificationNonce).To(HaveLen(32)) @@ -117,13 +94,37 @@ var _ = Describe("Header", func() { PacketNumberLen: protocol.PacketNumberLen6, }).writePublicHeader(buf, protocol.PerspectiveClient, versionPublicHeader) Expect(err).ToNot(HaveOccurred()) - _, err = ParseHeader(bytes.NewReader(buf.Bytes()[0:12]), protocol.PerspectiveClient, protocol.VersionUnknown) + _, err = ParseHeaderSentByClient(bytes.NewReader(buf.Bytes()[0:12])) Expect(err).To(MatchError(io.EOF)) }) It("errors when given no data", func() { - _, err := ParseHeader(bytes.NewReader([]byte{}), protocol.PerspectiveClient, protocol.VersionUnknown) + _, err := ParseHeaderSentByServer(bytes.NewReader([]byte{}), protocol.VersionUnknown) Expect(err).To(MatchError(io.EOF)) + _, err = ParseHeaderSentByClient(bytes.NewReader([]byte{})) + Expect(err).To(MatchError(io.EOF)) + }) + + It("parses a gQUIC Version Negotiation Packet", func() { + versions := []protocol.VersionNumber{0x13, 0x37} + data := ComposeGQUICVersionNegotiation(0x42, versions) + hdr, err := ParseHeaderSentByServer(bytes.NewReader(data), protocol.VersionUnknown) + Expect(err).ToNot(HaveOccurred()) + Expect(hdr.isPublicHeader).To(BeTrue()) + Expect(hdr.ConnectionID).To(Equal(protocol.ConnectionID(0x42))) + Expect(hdr.SupportedVersions).To(Equal(versions)) + }) + + It("parses a gQUIC Version Negotiation Packet", func() { + versions := []protocol.VersionNumber{0x13, 0x37} + data := ComposeVersionNegotiation(0x42, 0x77, versions) + hdr, err := ParseHeaderSentByServer(bytes.NewReader(data), protocol.VersionUnknown) + Expect(err).ToNot(HaveOccurred()) + Expect(hdr.isPublicHeader).To(BeFalse()) + Expect(hdr.ConnectionID).To(Equal(protocol.ConnectionID(0x42))) + Expect(hdr.PacketNumber).To(Equal(protocol.PacketNumber(0x77))) + Expect(hdr.SupportedVersions).To(Equal(versions)) + Expect(hdr.Type).To(Equal(protocol.PacketTypeVersionNegotiation)) }) }) @@ -137,7 +138,7 @@ var _ = Describe("Header", func() { } err := hdr.Write(buf, protocol.PerspectiveServer, versionPublicHeader) Expect(err).ToNot(HaveOccurred()) - _, err = parsePublicHeader(bytes.NewReader(buf.Bytes()), protocol.PerspectiveServer, versionPublicHeader) + _, err = parsePublicHeader(bytes.NewReader(buf.Bytes()), protocol.PerspectiveServer) Expect(err).ToNot(HaveOccurred()) Expect(hdr.isPublicHeader).To(BeTrue()) }) diff --git a/internal/wire/public_header.go b/internal/wire/public_header.go index 202052a63..ba5c8e695 100644 --- a/internal/wire/public_header.go +++ b/internal/wire/public_header.go @@ -12,9 +12,6 @@ import ( ) var ( - // ErrPacketWithUnknownVersion occurs when a packet with an unknown version is parsed. - // This can happen when the server is restarted. The client will send a packet without a version number. - ErrPacketWithUnknownVersion = errors.New("PublicHeader: Received a packet without version number, that we don't know the version for") errResetAndVersionFlagSet = errors.New("PublicHeader: Reset Flag and Version Flag should not be set at the same time") errReceivedOmittedConnectionID = qerr.Error(qerr.InvalidPacketHeader, "receiving packets with omitted ConnectionID is not supported") errInvalidConnectionID = qerr.Error(qerr.InvalidPacketHeader, "connection ID cannot be 0") @@ -90,7 +87,7 @@ func (h *Header) writePublicHeader(b *bytes.Buffer, pers protocol.Perspective, v // parsePublicHeader parses a QUIC packet's Public Header. // The packetSentBy is the perspective of the peer that sent this PublicHeader, i.e. if we're the server, packetSentBy should be PerspectiveClient. -func parsePublicHeader(b *bytes.Reader, packetSentBy protocol.Perspective, version protocol.VersionNumber) (*Header, error) { +func parsePublicHeader(b *bytes.Reader, packetSentBy protocol.Perspective) (*Header, error) { header := &Header{} // First byte @@ -100,9 +97,6 @@ func parsePublicHeader(b *bytes.Reader, packetSentBy protocol.Perspective, versi } header.ResetFlag = publicFlagByte&0x02 > 0 header.VersionFlag = publicFlagByte&0x01 > 0 - if version == protocol.VersionUnknown && !(header.VersionFlag || header.ResetFlag) { - return nil, ErrPacketWithUnknownVersion - } // TODO: activate this check once Chrome sends the correct value // see https://github.com/lucas-clemente/quic-go/issues/232 @@ -181,12 +175,11 @@ func parsePublicHeader(b *bytes.Reader, packetSentBy protocol.Perspective, versi return nil, err } header.Version = protocol.VersionNumber(versionTag) - version = header.Version } // Packet number if header.hasPacketNumber(packetSentBy) { - packetNumber, err := utils.GetByteOrder(version).ReadUintN(b, uint8(header.PacketNumberLen)) + packetNumber, err := utils.BigEndian.ReadUintN(b, uint8(header.PacketNumberLen)) if err != nil { return nil, err } diff --git a/internal/wire/public_header_test.go b/internal/wire/public_header_test.go index d62f5cafd..55f507509 100644 --- a/internal/wire/public_header_test.go +++ b/internal/wire/public_header_test.go @@ -19,7 +19,7 @@ var _ = Describe("Public Header", func() { ver := make([]byte, 4) binary.BigEndian.PutUint32(ver, uint32(protocol.SupportedVersions[0])) b := bytes.NewReader(append(append([]byte{0x09, 0x4c, 0xfa, 0x9f, 0x9b, 0x66, 0x86, 0x19, 0xf6}, ver...), 0x01)) - hdr, err := parsePublicHeader(b, protocol.PerspectiveClient, protocol.VersionUnknown) + hdr, err := parsePublicHeader(b, protocol.PerspectiveClient) Expect(err).ToNot(HaveOccurred()) Expect(hdr.VersionFlag).To(BeTrue()) Expect(hdr.ResetFlag).To(BeFalse()) @@ -32,13 +32,13 @@ var _ = Describe("Public Header", func() { It("does not accept an omittedd connection ID as a server", func() { b := bytes.NewReader([]byte{0x00, 0x01}) - _, err := parsePublicHeader(b, protocol.PerspectiveClient, protocol.VersionWhatever) + _, err := parsePublicHeader(b, protocol.PerspectiveClient) Expect(err).To(MatchError(errReceivedOmittedConnectionID)) }) It("accepts aan d connection ID as a client", func() { b := bytes.NewReader([]byte{0x00, 0x01}) - hdr, err := parsePublicHeader(b, protocol.PerspectiveServer, protocol.VersionWhatever) + hdr, err := parsePublicHeader(b, protocol.PerspectiveServer) Expect(err).ToNot(HaveOccurred()) Expect(hdr.OmitConnectionID).To(BeTrue()) Expect(hdr.ConnectionID).To(BeZero()) @@ -47,13 +47,13 @@ var _ = Describe("Public Header", func() { It("rejects 0 as a connection ID", func() { b := bytes.NewReader([]byte{0x09, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x51, 0x30, 0x33, 0x30, 0x01}) - _, err := parsePublicHeader(b, protocol.PerspectiveClient, protocol.VersionUnknown) + _, err := parsePublicHeader(b, protocol.PerspectiveClient) Expect(err).To(MatchError(errInvalidConnectionID)) }) It("reads a PublicReset packet", func() { b := bytes.NewReader([]byte{0xa, 0x1, 0x2, 0x3, 0x4, 0x5, 0x6, 0x7, 0x8}) - hdr, err := parsePublicHeader(b, protocol.PerspectiveServer, protocol.VersionUnknown) + hdr, err := parsePublicHeader(b, protocol.PerspectiveServer) Expect(err).ToNot(HaveOccurred()) Expect(hdr.ResetFlag).To(BeTrue()) Expect(hdr.ConnectionID).ToNot(BeZero()) @@ -61,7 +61,7 @@ var _ = Describe("Public Header", func() { It("parses a public reset packet", func() { b := bytes.NewReader([]byte{0xa, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08}) - hdr, err := parsePublicHeader(b, protocol.PerspectiveServer, protocol.VersionUnknown) + hdr, err := parsePublicHeader(b, protocol.PerspectiveServer) Expect(err).ToNot(HaveOccurred()) Expect(hdr.ResetFlag).To(BeTrue()) Expect(hdr.VersionFlag).To(BeFalse()) @@ -72,25 +72,19 @@ var _ = Describe("Public Header", func() { divNonce := []byte{0x0, 0x1, 0x2, 0x3, 0x4, 0x5, 0x6, 0x7, 0x8, 0x9, 0xa, 0xb, 0xc, 0xd, 0xe, 0xf, 0x10, 0x11, 0x12, 0x13, 0x14, 0x15, 0x16, 0x17, 0x18, 0x19, 0x1a, 0x1b, 0x1c, 0x1d, 0x1e, 0x1f} Expect(divNonce).To(HaveLen(32)) b := bytes.NewReader(append(append([]byte{0x0c, 0xf6, 0x19, 0x86, 0x66, 0x9b, 0x9f, 0xfa, 0x4c}, divNonce...), 0x37)) - hdr, err := parsePublicHeader(b, protocol.PerspectiveServer, protocol.VersionWhatever) + hdr, err := parsePublicHeader(b, protocol.PerspectiveServer) Expect(err).ToNot(HaveOccurred()) Expect(hdr.ConnectionID).To(Not(BeZero())) Expect(hdr.DiversificationNonce).To(Equal(divNonce)) Expect(b.Len()).To(BeZero()) }) - It("returns an unknown version error when receiving a packet without a version for which the version is not given", func() { - b := bytes.NewReader([]byte{0x10, 0x1, 0x2, 0x3, 0x4, 0x5, 0x6, 0x7, 0x8, 0xef}) - _, err := parsePublicHeader(b, protocol.PerspectiveServer, protocol.VersionUnknown) - Expect(err).To(MatchError(ErrPacketWithUnknownVersion)) - }) - PIt("rejects diversification nonces sent by the client", func() { b := bytes.NewReader([]byte{0x0c, 0x4c, 0xfa, 0x9f, 0x9b, 0x66, 0x86, 0x19, 0xf6, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1, 0x01, }) - _, err := parsePublicHeader(b, protocol.PerspectiveClient, protocol.VersionWhatever) + _, err := parsePublicHeader(b, protocol.PerspectiveClient) Expect(err).To(MatchError("diversification nonces should only be sent by servers")) }) @@ -103,7 +97,7 @@ var _ = Describe("Public Header", func() { It("parses version negotiation packets sent by the server", func() { b := bytes.NewReader(ComposeGQUICVersionNegotiation(0x1337, protocol.SupportedVersions)) - hdr, err := parsePublicHeader(b, protocol.PerspectiveServer, protocol.VersionUnknown) + hdr, err := parsePublicHeader(b, protocol.PerspectiveServer) Expect(err).ToNot(HaveOccurred()) Expect(hdr.VersionFlag).To(BeTrue()) Expect(hdr.Version).To(BeZero()) // unitialized @@ -113,7 +107,7 @@ var _ = Describe("Public Header", func() { It("errors if it doesn't contain any versions", func() { b := bytes.NewReader([]byte{0x9, 0xf6, 0x19, 0x86, 0x66, 0x9b, 0x9f, 0xfa, 0x4c}) - _, err := parsePublicHeader(b, protocol.PerspectiveServer, protocol.VersionUnknown) + _, err := parsePublicHeader(b, protocol.PerspectiveServer) Expect(err).To(MatchError("InvalidVersionNegotiationPacket: empty version list")) }) @@ -123,7 +117,7 @@ var _ = Describe("Public Header", func() { data = appendVersion(data, protocol.SupportedVersions[0]) data = appendVersion(data, 99) // unsupported version b := bytes.NewReader(data) - hdr, err := parsePublicHeader(b, protocol.PerspectiveServer, protocol.VersionUnknown) + hdr, err := parsePublicHeader(b, protocol.PerspectiveServer) Expect(err).ToNot(HaveOccurred()) Expect(hdr.VersionFlag).To(BeTrue()) Expect(hdr.SupportedVersions).To(Equal([]protocol.VersionNumber{1, protocol.SupportedVersions[0], 99})) @@ -134,54 +128,46 @@ var _ = Describe("Public Header", func() { data := ComposeGQUICVersionNegotiation(0x1337, protocol.SupportedVersions) data = append(data, []byte{0x13, 0x37}...) b := bytes.NewReader(data) - _, err := parsePublicHeader(b, protocol.PerspectiveServer, protocol.VersionUnknown) + _, err := parsePublicHeader(b, protocol.PerspectiveServer) Expect(err).To(MatchError(qerr.InvalidVersionNegotiationPacket)) }) }) Context("Packet Number lengths", func() { - Context("in big endian encoding", func() { - version := protocol.Version39 + It("accepts 1-byte packet numbers", func() { + b := bytes.NewReader([]byte{0x08, 0x4c, 0xfa, 0x9f, 0x9b, 0x66, 0x86, 0x19, 0xf6, 0xde}) + hdr, err := parsePublicHeader(b, protocol.PerspectiveClient) + Expect(err).ToNot(HaveOccurred()) + Expect(hdr.PacketNumber).To(Equal(protocol.PacketNumber(0xde))) + Expect(hdr.PacketNumberLen).To(Equal(protocol.PacketNumberLen1)) + Expect(b.Len()).To(BeZero()) + }) - BeforeEach(func() { - Expect(utils.GetByteOrder(version)).To(Equal(utils.BigEndian)) - }) + It("accepts 2-byte packet numbers", func() { + b := bytes.NewReader([]byte{0x18, 0x4c, 0xfa, 0x9f, 0x9b, 0x66, 0x86, 0x19, 0xf6, 0xde, 0xca}) + hdr, err := parsePublicHeader(b, protocol.PerspectiveClient) + Expect(err).ToNot(HaveOccurred()) + Expect(hdr.PacketNumber).To(Equal(protocol.PacketNumber(0xdeca))) + Expect(hdr.PacketNumberLen).To(Equal(protocol.PacketNumberLen2)) + Expect(b.Len()).To(BeZero()) + }) - It("accepts 1-byte packet numbers", func() { - b := bytes.NewReader([]byte{0x08, 0x4c, 0xfa, 0x9f, 0x9b, 0x66, 0x86, 0x19, 0xf6, 0xde}) - hdr, err := parsePublicHeader(b, protocol.PerspectiveClient, version) - Expect(err).ToNot(HaveOccurred()) - Expect(hdr.PacketNumber).To(Equal(protocol.PacketNumber(0xde))) - Expect(hdr.PacketNumberLen).To(Equal(protocol.PacketNumberLen1)) - Expect(b.Len()).To(BeZero()) - }) + It("accepts 4-byte packet numbers", func() { + b := bytes.NewReader([]byte{0x28, 0x4c, 0xfa, 0x9f, 0x9b, 0x66, 0x86, 0x19, 0xf6, 0xad, 0xfb, 0xca, 0xde}) + hdr, err := parsePublicHeader(b, protocol.PerspectiveClient) + Expect(err).ToNot(HaveOccurred()) + Expect(hdr.PacketNumber).To(Equal(protocol.PacketNumber(0xadfbcade))) + Expect(hdr.PacketNumberLen).To(Equal(protocol.PacketNumberLen4)) + Expect(b.Len()).To(BeZero()) + }) - It("accepts 2-byte packet numbers", func() { - b := bytes.NewReader([]byte{0x18, 0x4c, 0xfa, 0x9f, 0x9b, 0x66, 0x86, 0x19, 0xf6, 0xde, 0xca}) - hdr, err := parsePublicHeader(b, protocol.PerspectiveClient, version) - Expect(err).ToNot(HaveOccurred()) - Expect(hdr.PacketNumber).To(Equal(protocol.PacketNumber(0xdeca))) - Expect(hdr.PacketNumberLen).To(Equal(protocol.PacketNumberLen2)) - Expect(b.Len()).To(BeZero()) - }) - - It("accepts 4-byte packet numbers", func() { - b := bytes.NewReader([]byte{0x28, 0x4c, 0xfa, 0x9f, 0x9b, 0x66, 0x86, 0x19, 0xf6, 0xad, 0xfb, 0xca, 0xde}) - hdr, err := parsePublicHeader(b, protocol.PerspectiveClient, version) - Expect(err).ToNot(HaveOccurred()) - Expect(hdr.PacketNumber).To(Equal(protocol.PacketNumber(0xadfbcade))) - Expect(hdr.PacketNumberLen).To(Equal(protocol.PacketNumberLen4)) - Expect(b.Len()).To(BeZero()) - }) - - It("accepts 6-byte packet numbers", func() { - b := bytes.NewReader([]byte{0x38, 0x4c, 0xfa, 0x9f, 0x9b, 0x66, 0x86, 0x19, 0xf6, 0x23, 0x42, 0xad, 0xfb, 0xca, 0xde}) - hdr, err := parsePublicHeader(b, protocol.PerspectiveClient, version) - Expect(err).ToNot(HaveOccurred()) - Expect(hdr.PacketNumber).To(Equal(protocol.PacketNumber(0x2342adfbcade))) - Expect(hdr.PacketNumberLen).To(Equal(protocol.PacketNumberLen6)) - Expect(b.Len()).To(BeZero()) - }) + It("accepts 6-byte packet numbers", func() { + b := bytes.NewReader([]byte{0x38, 0x4c, 0xfa, 0x9f, 0x9b, 0x66, 0x86, 0x19, 0xf6, 0x23, 0x42, 0xad, 0xfb, 0xca, 0xde}) + hdr, err := parsePublicHeader(b, protocol.PerspectiveClient) + Expect(err).ToNot(HaveOccurred()) + Expect(hdr.PacketNumber).To(Equal(protocol.PacketNumber(0x2342adfbcade))) + Expect(hdr.PacketNumberLen).To(Equal(protocol.PacketNumberLen6)) + Expect(b.Len()).To(BeZero()) }) }) }) diff --git a/internal/wire/version_negotiation_test.go b/internal/wire/version_negotiation_test.go index 869d008be..ca16431d5 100644 --- a/internal/wire/version_negotiation_test.go +++ b/internal/wire/version_negotiation_test.go @@ -12,7 +12,7 @@ var _ = Describe("Version Negotiation Packets", func() { It("writes for gQUIC", func() { versions := []protocol.VersionNumber{1001, 1003} data := ComposeGQUICVersionNegotiation(0x1337, versions) - hdr, err := parsePublicHeader(bytes.NewReader(data), protocol.PerspectiveServer, protocol.VersionUnknown) + hdr, err := parsePublicHeader(bytes.NewReader(data), protocol.PerspectiveServer) Expect(err).ToNot(HaveOccurred()) Expect(hdr.VersionFlag).To(BeTrue()) Expect(hdr.ConnectionID).To(Equal(protocol.ConnectionID(0x1337))) diff --git a/server.go b/server.go index 98fcd10dc..10c55805e 100644 --- a/server.go +++ b/server.go @@ -219,38 +219,25 @@ func (s *server) handlePacket(pconn net.PacketConn, remoteAddr net.Addr, packet rcvTime := time.Now() r := bytes.NewReader(packet) - connID, err := wire.PeekConnectionID(r) - if err != nil { - return qerr.Error(qerr.InvalidPacketHeader, err.Error()) - } - - s.sessionsMutex.RLock() - session, ok := s.sessions[connID] - s.sessionsMutex.RUnlock() - - if ok && session == nil { - // Late packet for closed session - return nil - } - - version := protocol.VersionUnknown - if ok { - version = session.GetVersion() - } - - hdr, err := wire.ParseHeader(r, protocol.PerspectiveClient, version) - if err == wire.ErrPacketWithUnknownVersion { - _, err = pconn.WriteTo(wire.WritePublicReset(connID, 0, 0), remoteAddr) - return err - } + hdr, err := wire.ParseHeaderSentByClient(r) if err != nil { return qerr.Error(qerr.InvalidPacketHeader, err.Error()) } hdr.Raw = packet[:len(packet)-r.Len()] + connID := hdr.ConnectionID + + s.sessionsMutex.RLock() + session, sessionKnown := s.sessions[connID] + s.sessionsMutex.RUnlock() + + if sessionKnown && session == nil { + // Late packet for closed session + return nil + } // ignore all Public Reset packets if hdr.ResetFlag { - if ok { + if sessionKnown { var pr *wire.PublicReset pr, err = wire.ParsePublicReset(r) if err != nil { @@ -264,10 +251,18 @@ func (s *server) handlePacket(pconn net.PacketConn, remoteAddr net.Addr, packet return nil } + // If we don't have a session for this connection, and this packet cannot open a new connection, send a Public Reset + // This should only happen after a server restart, when we still receive packets for connections that we lost the state for. + // TODO(#943): implement sending of IETF draft style stateless resets + if !sessionKnown && (!hdr.VersionFlag && hdr.Type != protocol.PacketTypeInitial) { + _, err = pconn.WriteTo(wire.WritePublicReset(connID, 0, 0), remoteAddr) + return err + } + // a session is only created once the client sent a supported version // if we receive a packet for a connection that already has session, it's probably an old packet that was sent by the client before the version was negotiated // it is safe to drop it - if ok && hdr.VersionFlag && !protocol.IsSupportedVersion(s.config.Versions, hdr.Version) { + if sessionKnown && hdr.VersionFlag && !protocol.IsSupportedVersion(s.config.Versions, hdr.Version) { return nil } @@ -289,7 +284,7 @@ func (s *server) handlePacket(pconn net.PacketConn, remoteAddr net.Addr, packet return err } - if !ok { + if !sessionKnown { version := hdr.Version if !protocol.IsSupportedVersion(s.config.Versions, version) { return errors.New("Server BUG: negotiated version not supported") diff --git a/server_test.go b/server_test.go index 0688380e8..64869d6a3 100644 --- a/server_test.go +++ b/server_test.go @@ -422,7 +422,7 @@ var _ = Describe("Server", func() { Eventually(func() int { return conn.dataWritten.Len() }).ShouldNot(BeZero()) Expect(conn.dataWrittenTo).To(Equal(udpAddr)) r := bytes.NewReader(conn.dataWritten.Bytes()) - packet, err := wire.ParseHeader(r, protocol.PerspectiveServer, protocol.VersionUnknown) + packet, err := wire.ParseHeaderSentByServer(r, protocol.VersionUnknown) Expect(err).ToNot(HaveOccurred()) Expect(packet.VersionFlag).To(BeTrue()) Expect(packet.ConnectionID).To(Equal(protocol.ConnectionID(0x1337))) @@ -455,7 +455,7 @@ var _ = Describe("Server", func() { Eventually(func() int { return conn.dataWritten.Len() }).ShouldNot(BeZero()) Expect(conn.dataWrittenTo).To(Equal(udpAddr)) r := bytes.NewReader(conn.dataWritten.Bytes()) - packet, err := wire.ParseHeader(r, protocol.PerspectiveServer, protocol.VersionUnknown) + packet, err := wire.ParseHeaderSentByServer(r, protocol.VersionUnknown) Expect(err).ToNot(HaveOccurred()) Expect(packet.Type).To(Equal(protocol.PacketTypeVersionNegotiation)) Expect(packet.ConnectionID).To(Equal(protocol.ConnectionID(0x1337)))