diff --git a/packet_packer.go b/packet_packer.go index c8bef1e2..f616db17 100644 --- a/packet_packer.go +++ b/packet_packer.go @@ -460,6 +460,9 @@ func (p *packetPacker) getHeader(encLevel protocol.EncryptionLevel) *wire.Header if p.version.UsesTLS() && encLevel != protocol.EncryptionForwardSecure { header.PacketNumberLen = protocol.PacketNumberLen4 header.IsLongHeader = true + // Set the payload len to maximum size. + // Since it is encoded as a varint, this guarantees us that the header will end up at most as big as GetLength() returns. + header.PayloadLen = p.maxPacketSize if !p.hasSentPacket && p.perspective == protocol.PerspectiveClient { header.Type = protocol.PacketTypeInitial } else { @@ -494,6 +497,20 @@ func (p *packetPacker) writeAndSealPacket( raw := *getPacketBuffer() buffer := bytes.NewBuffer(raw[:0]) + // the payload length is only needed for Long Headers + if header.IsLongHeader { + if header.Type == protocol.PacketTypeInitial { + headerLen, _ := header.GetLength(p.perspective, p.version) + header.PayloadLen = protocol.ByteCount(protocol.MinInitialPacketSize) - headerLen + } else { + payloadLen := protocol.ByteCount(sealer.Overhead()) + for _, frame := range payloadFrames { + payloadLen += frame.Length(p.version) + } + header.PayloadLen = payloadLen + } + } + if err := header.Write(buffer, p.perspective, p.version); err != nil { return nil, err } diff --git a/packet_packer_test.go b/packet_packer_test.go index 74406b90..a67c3b51 100644 --- a/packet_packer_test.go +++ b/packet_packer_test.go @@ -247,6 +247,24 @@ var _ = Describe("Packet packer", func() { }) }) + It("sets the payload length for packets containing crypto data", func() { + packer.version = versionIETFFrames + f := &wire.StreamFrame{ + StreamID: packer.version.CryptoStreamID(), + Offset: 0x1337, + Data: []byte("foobar"), + } + mockStreamFramer.EXPECT().HasCryptoStreamData().Return(true) + mockStreamFramer.EXPECT().PopCryptoStreamFrame(gomock.Any()).Return(f) + p, err := packer.PackPacket() + Expect(err).ToNot(HaveOccurred()) + // parse the packet + r := bytes.NewReader(p.raw) + hdr, err := wire.ParseHeaderSentByServer(r, packer.version) + Expect(err).ToNot(HaveOccurred()) + Expect(hdr.PayloadLen).To(BeEquivalentTo(r.Len())) + }) + It("packs a CONNECTION_CLOSE", func() { ccf := wire.ConnectionCloseFrame{ ErrorCode: 0x1337, @@ -571,6 +589,31 @@ var _ = Describe("Packet packer", func() { Expect(p).To(BeNil()) }) + It("packs a maximum size crypto packet", func() { + var f *wire.StreamFrame + packer.version = versionIETFFrames + mockStreamFramer.EXPECT().HasCryptoStreamData().Return(true) + mockStreamFramer.EXPECT().PopCryptoStreamFrame(gomock.Any()).DoAndReturn(func(size protocol.ByteCount) *wire.StreamFrame { + f = &wire.StreamFrame{ + StreamID: packer.version.CryptoStreamID(), + Offset: 0x1337, + } + f.Data = bytes.Repeat([]byte{'f'}, int(size-f.Length(packer.version))) + return f + }) + p, err := packer.PackPacket() + Expect(err).ToNot(HaveOccurred()) + Expect(p.frames).To(HaveLen(1)) + expectedPacketLen := packer.maxPacketSize - protocol.NonForwardSecurePacketSizeReduction + Expect(p.raw).To(HaveLen(int(expectedPacketLen))) + Expect(p.header.IsLongHeader).To(BeTrue()) + // parse the packet + r := bytes.NewReader(p.raw) + hdr, err := wire.ParseHeaderSentByServer(r, packer.version) + Expect(err).ToNot(HaveOccurred()) + Expect(hdr.PayloadLen).To(BeEquivalentTo(r.Len())) + }) + It("sends unencrypted stream data on the crypto stream", func() { f := &wire.StreamFrame{ StreamID: packer.version.CryptoStreamID(), @@ -733,6 +776,25 @@ var _ = Describe("Packet packer", func() { Expect(sf.DataLenPresent).To(BeTrue()) }) + It("set the correct payload length for an Initial packet", func() { + mockStreamFramer.EXPECT().HasCryptoStreamData().Return(true) + mockStreamFramer.EXPECT().PopCryptoStreamFrame(gomock.Any()).Return(&wire.StreamFrame{ + StreamID: packer.version.CryptoStreamID(), + Data: []byte("foobar"), + }) + packer.version = protocol.VersionTLS + packer.hasSentPacket = false + packer.perspective = protocol.PerspectiveClient + packer.cryptoSetup.(*mockCryptoSetup).encLevelSealCrypto = protocol.EncryptionUnencrypted + packet, err := packer.PackPacket() + Expect(err).ToNot(HaveOccurred()) + // parse the header and check the values + r := bytes.NewReader(packet.raw) + hdr, err := wire.ParseHeaderSentByClient(r) + Expect(err).ToNot(HaveOccurred()) + Expect(hdr.PayloadLen).To(BeEquivalentTo(r.Len())) + }) + It("packs a retransmission for an Initial packet", func() { packer.version = versionIETFFrames packer.perspective = protocol.PerspectiveClient diff --git a/server_tls.go b/server_tls.go index b5cbea85..7424a409 100644 --- a/server_tls.go +++ b/server_tls.go @@ -179,18 +179,19 @@ func (s *serverTLS) handleUnpackedInitial(remoteAddr net.Addr, hdr *wire.Header, if alert == mint.AlertStatelessRetry { // the HelloRetryRequest was written to the bufferConn // Take that data and write send a Retry packet + f := &wire.StreamFrame{ + StreamID: version.CryptoStreamID(), + Data: bc.GetDataForWriting(), + } replyHdr := &wire.Header{ IsLongHeader: true, Type: protocol.PacketTypeRetry, DestConnectionID: hdr.SrcConnectionID, SrcConnectionID: hdr.DestConnectionID, + PayloadLen: f.Length(version) + protocol.ByteCount(aead.Overhead()), PacketNumber: hdr.PacketNumber, // echo the client's packet number Version: version, } - f := &wire.StreamFrame{ - StreamID: version.CryptoStreamID(), - Data: bc.GetDataForWriting(), - } data, err := packUnencryptedPacket(aead, replyHdr, f, protocol.PerspectiveServer, s.logger) if err != nil { return nil, err diff --git a/server_tls_test.go b/server_tls_test.go index 90f6521f..b1246ab6 100644 --- a/server_tls_test.go +++ b/server_tls_test.go @@ -117,11 +117,13 @@ var _ = Describe("Stateless TLS handling", func() { hdr, data := getPacket(&wire.StreamFrame{Data: []byte("Client Hello")}) server.HandleInitial(nil, hdr, data) Expect(conn.dataWritten.Len()).ToNot(BeZero()) - replyHdr, err := wire.ParseHeaderSentByServer(bytes.NewReader(conn.dataWritten.Bytes()), protocol.VersionTLS) + r := bytes.NewReader(conn.dataWritten.Bytes()) + replyHdr, err := wire.ParseHeaderSentByServer(r, protocol.VersionTLS) Expect(err).ToNot(HaveOccurred()) Expect(replyHdr.Type).To(Equal(protocol.PacketTypeRetry)) Expect(replyHdr.SrcConnectionID).To(Equal(hdr.DestConnectionID)) Expect(replyHdr.DestConnectionID).To(Equal(hdr.SrcConnectionID)) + Expect(replyHdr.PayloadLen).To(BeEquivalentTo(r.Len())) Expect(sessionChan).ToNot(Receive()) })