diff --git a/packet_packer.go b/packet_packer.go index 98d91bf6..27696501 100644 --- a/packet_packer.go +++ b/packet_packer.go @@ -130,7 +130,7 @@ func (p *packetPacker) packPacket(stopWaitingFrame *frames.StopWaitingFrame, con } var raw bytes.Buffer - if err := responsePublicHeader.WritePublicHeader(&raw); err != nil { + if err := responsePublicHeader.WritePublicHeader(&raw, p.version); err != nil { return nil, err } diff --git a/public_header.go b/public_header.go index 5526f197..37e85c42 100644 --- a/public_header.go +++ b/public_header.go @@ -29,10 +29,11 @@ type publicHeader struct { QuicVersion uint32 PacketNumberLen protocol.PacketNumberLen PacketNumber protocol.PacketNumber + DiversificationNonce []byte } // WritePublicHeader writes a public header -func (h *publicHeader) WritePublicHeader(b *bytes.Buffer) error { +func (h *publicHeader) WritePublicHeader(b *bytes.Buffer, version protocol.VersionNumber) error { publicFlagByte := uint8(0x00) if h.VersionFlag && h.ResetFlag { return errResetAndVersionFlagSet @@ -44,8 +45,17 @@ func (h *publicHeader) WritePublicHeader(b *bytes.Buffer) error { publicFlagByte |= 0x02 } if !h.TruncateConnectionID { - // TODO: Change this once we support version 33 properly - publicFlagByte |= 0x0c + if version < protocol.VersionNumber(33) { + publicFlagByte |= 0x0c + } else { + publicFlagByte |= 0x08 + } + } + if len(h.DiversificationNonce) > 0 { + if len(h.DiversificationNonce) != 32 { + return errors.New("invalid diversification nonce length") + } + publicFlagByte |= 0x04 } if !h.ResetFlag && !h.VersionFlag { @@ -67,6 +77,10 @@ func (h *publicHeader) WritePublicHeader(b *bytes.Buffer) error { utils.WriteUint64(b, uint64(h.ConnectionID)) } + if len(h.DiversificationNonce) > 0 { + b.Write(h.DiversificationNonce) + } + if !h.ResetFlag && !h.VersionFlag { switch h.PacketNumberLen { case protocol.PacketNumberLen1: diff --git a/public_header_test.go b/public_header_test.go index 21b6bd75..7826acc9 100644 --- a/public_header_test.go +++ b/public_header_test.go @@ -84,7 +84,7 @@ var _ = Describe("Public Header", func() { PacketNumber: 2, PacketNumberLen: protocol.PacketNumberLen6, } - hdr.WritePublicHeader(b) + hdr.WritePublicHeader(b, protocol.VersionNumber(32)) Expect(b.Bytes()).To(Equal([]byte{0x38 | 0x04, 0xf6, 0x19, 0x86, 0x66, 0x9b, 0x9f, 0xfa, 0x4c, 2, 0, 0, 0, 0, 0})) }) @@ -96,7 +96,7 @@ var _ = Describe("Public Header", func() { PacketNumber: 2, PacketNumberLen: protocol.PacketNumberLen6, } - hdr.WritePublicHeader(b) + hdr.WritePublicHeader(b, protocol.VersionNumber(32)) // must be the first assertion Expect(b.Len()).To(Equal(1 + 8)) // 1 FlagByte + 8 ConnectionID firstByte, _ := b.ReadByte() @@ -111,7 +111,7 @@ var _ = Describe("Public Header", func() { PacketNumber: 2, PacketNumberLen: protocol.PacketNumberLen6, } - hdr.WritePublicHeader(b) + hdr.WritePublicHeader(b, protocol.VersionNumber(32)) // must be the first assertion Expect(b.Len()).To(Equal(1 + 8)) // 1 FlagByte + 8 ConnectionID firstByte, _ := b.ReadByte() @@ -127,7 +127,7 @@ var _ = Describe("Public Header", func() { PacketNumber: 2, PacketNumberLen: protocol.PacketNumberLen6, } - err := hdr.WritePublicHeader(b) + err := hdr.WritePublicHeader(b, protocol.VersionNumber(32)) Expect(err).To(MatchError(errResetAndVersionFlagSet)) }) @@ -139,11 +139,40 @@ var _ = Describe("Public Header", func() { PacketNumberLen: protocol.PacketNumberLen6, PacketNumber: 1, } - err := hdr.WritePublicHeader(b) + err := hdr.WritePublicHeader(b, protocol.VersionNumber(32)) Expect(err).ToNot(HaveOccurred()) Expect(b.Bytes()).To(Equal([]byte{0x30, 0x1, 0x0, 0x0, 0x0, 0x0, 0x0})) }) + It("writes proper v33 packets", func() { + b := &bytes.Buffer{} + hdr := publicHeader{ + ConnectionID: 0x4cfa9f9b668619f6, + PacketNumber: 1, + PacketNumberLen: protocol.PacketNumberLen1, + } + err := hdr.WritePublicHeader(b, protocol.VersionNumber(33)) + Expect(err).ToNot(HaveOccurred()) + Expect(b.Bytes()).To(Equal([]byte{0x08, 0xf6, 0x19, 0x86, 0x66, 0x9b, 0x9f, 0xfa, 0x4c, 0x01})) + }) + + It("writes diversification nonces", func() { + b := &bytes.Buffer{} + hdr := publicHeader{ + ConnectionID: 0x4cfa9f9b668619f6, + PacketNumber: 1, + PacketNumberLen: protocol.PacketNumberLen1, + DiversificationNonce: bytes.Repeat([]byte{1}, 32), + } + err := hdr.WritePublicHeader(b, protocol.VersionNumber(33)) + Expect(err).ToNot(HaveOccurred()) + Expect(b.Bytes()).To(Equal([]byte{ + 0x0c, 0xf6, 0x19, 0x86, 0x66, 0x9b, 0x9f, 0xfa, 0x4c, + 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, + 0x01, + })) + }) + Context("GetLength", func() { It("errors when calling GetLength for Version Negotiation packets", func() { hdr := publicHeader{VersionFlag: true} @@ -208,7 +237,7 @@ var _ = Describe("Public Header", func() { ConnectionID: 0x4cfa9f9b668619f6, PacketNumber: 0xDECAFBAD, } - err := hdr.WritePublicHeader(b) + err := hdr.WritePublicHeader(b, protocol.VersionNumber(32)) Expect(err).To(MatchError(errPacketNumberLenNotSet)) }) @@ -219,7 +248,7 @@ var _ = Describe("Public Header", func() { PacketNumber: 0xDECAFBAD, PacketNumberLen: protocol.PacketNumberLen1, } - err := hdr.WritePublicHeader(b) + err := hdr.WritePublicHeader(b, protocol.VersionNumber(32)) Expect(err).ToNot(HaveOccurred()) Expect(b.Bytes()).To(Equal([]byte{0x08 | 0x04, 0xf6, 0x19, 0x86, 0x66, 0x9b, 0x9f, 0xfa, 0x4c, 0xAD})) }) @@ -231,7 +260,7 @@ var _ = Describe("Public Header", func() { PacketNumber: 0xDECAFBAD, PacketNumberLen: protocol.PacketNumberLen2, } - err := hdr.WritePublicHeader(b) + err := hdr.WritePublicHeader(b, protocol.VersionNumber(32)) Expect(err).ToNot(HaveOccurred()) Expect(b.Bytes()).To(Equal([]byte{0x18 | 0x04, 0xf6, 0x19, 0x86, 0x66, 0x9b, 0x9f, 0xfa, 0x4c, 0xAD, 0xFB})) }) @@ -243,7 +272,7 @@ var _ = Describe("Public Header", func() { PacketNumber: 0x13DECAFBAD, PacketNumberLen: protocol.PacketNumberLen4, } - err := hdr.WritePublicHeader(b) + err := hdr.WritePublicHeader(b, protocol.VersionNumber(32)) Expect(err).ToNot(HaveOccurred()) Expect(b.Bytes()).To(Equal([]byte{0x28 | 0x04, 0xf6, 0x19, 0x86, 0x66, 0x9b, 0x9f, 0xfa, 0x4c, 0xAD, 0xFB, 0xCA, 0xDE})) }) @@ -255,7 +284,7 @@ var _ = Describe("Public Header", func() { PacketNumber: 0xBE1337DECAFBAD, PacketNumberLen: protocol.PacketNumberLen6, } - err := hdr.WritePublicHeader(b) + err := hdr.WritePublicHeader(b, protocol.VersionNumber(32)) Expect(err).ToNot(HaveOccurred()) Expect(b.Bytes()).To(Equal([]byte{0x38 | 0x04, 0xf6, 0x19, 0x86, 0x66, 0x9b, 0x9f, 0xfa, 0x4c, 0xAD, 0xFB, 0xCA, 0xDE, 0x37, 0x13})) }) diff --git a/server.go b/server.go index 89e9b8bc..f5c1aab6 100644 --- a/server.go +++ b/server.go @@ -167,7 +167,8 @@ func composeVersionNegotiation(connectionID protocol.ConnectionID) []byte { PacketNumber: 1, VersionFlag: true, } - err := responsePublicHeader.WritePublicHeader(fullReply) + // TODO: Update version number + err := responsePublicHeader.WritePublicHeader(fullReply, protocol.VersionNumber(32)) if err != nil { utils.Errorf("error composing version negotiation packet: %s", err.Error()) }