From c79dfe77b8e175ec7a50ce1fc1e0b3e170d89c80 Mon Sep 17 00:00:00 2001 From: Marten Seemann Date: Sat, 10 Aug 2019 13:02:47 +0700 Subject: [PATCH] implement packing of 0-RTT packets --- packet_packer.go | 41 ++++++++++++++++++++++++++++++++++------- packet_packer_test.go | 33 +++++++++++++++++++++++++++++++++ 2 files changed, 67 insertions(+), 7 deletions(-) diff --git a/packet_packer.go b/packet_packer.go index d31253ef..eb532cdf 100644 --- a/packet_packer.go +++ b/packet_packer.go @@ -52,6 +52,8 @@ func (p *packedPacket) EncryptionLevel() protocol.EncryptionLevel { return protocol.EncryptionInitial case protocol.PacketTypeHandshake: return protocol.EncryptionHandshake + case protocol.PacketType0RTT: + return protocol.Encryption0RTT default: return protocol.EncryptionUnspecified } @@ -115,6 +117,7 @@ type packetNumberManager interface { type sealingManager interface { GetInitialSealer() (handshake.LongHeaderSealer, error) GetHandshakeSealer() (handshake.LongHeaderSealer, error) + Get0RTTSealer() (handshake.LongHeaderSealer, error) Get1RTTSealer() (handshake.ShortHeaderSealer, error) } @@ -370,12 +373,26 @@ func (p *packetPacker) packCryptoPacket( } func (p *packetPacker) maybePackAppDataPacket() (*packedPacket, error) { - sealer, err := p.cryptoSetup.Get1RTTSealer() - if err != nil { - // sealer not yet available - return nil, nil + var sealer sealer + var header *wire.ExtendedHeader + var encLevel protocol.EncryptionLevel + oneRTTSealer, err := p.cryptoSetup.Get1RTTSealer() + if err == nil { + encLevel = protocol.Encryption1RTT + sealer = oneRTTSealer + header = p.getShortHeader(oneRTTSealer.KeyPhase()) + } else { + // 1-RTT sealer not yet available + if p.perspective != protocol.PerspectiveClient { + return nil, nil + } + sealer, err = p.cryptoSetup.Get0RTTSealer() + if sealer == nil || err != nil { + return nil, nil + } + encLevel = protocol.Encryption0RTT + header = p.getLongHeader(protocol.Encryption0RTT) } - header := p.getShortHeader(sealer.KeyPhase()) headerLen := header.GetLength(p.version) maxSize := p.maxPacketSize - protocol.ByteCount(sealer.Overhead()) - headerLen @@ -398,12 +415,13 @@ func (p *packetPacker) maybePackAppDataPacket() (*packedPacket, error) { p.numNonAckElicitingAcks = 0 } - return p.writeAndSealPacket(header, payload, protocol.Encryption1RTT, sealer) + return p.writeAndSealPacket(header, payload, encLevel, sealer) } func (p *packetPacker) composeNextPacket(maxFrameSize protocol.ByteCount) payload { var payload payload + // TODO: we don't need to request ACKs when sending 0-RTT packets if ack := p.acks.GetAckFrame(protocol.Encryption1RTT); ack != nil { payload.ack = ack payload.length += ack.Length(p.version) @@ -453,6 +471,13 @@ func (p *packetPacker) getSealerAndHeader(encLevel protocol.EncryptionLevel) (se } hdr := p.getLongHeader(protocol.EncryptionInitial) return sealer, hdr, nil + case protocol.Encryption0RTT: + sealer, err := p.cryptoSetup.Get0RTTSealer() + if err != nil { + return nil, nil, err + } + hdr := p.getLongHeader(protocol.Encryption0RTT) + return sealer, hdr, nil case protocol.EncryptionHandshake: sealer, err := p.cryptoSetup.GetHandshakeSealer() if err != nil { @@ -495,11 +520,13 @@ func (p *packetPacker) getLongHeader(encLevel protocol.EncryptionLevel) *wire.Ex hdr.Token = p.token case protocol.EncryptionHandshake: hdr.Type = protocol.PacketTypeHandshake + case protocol.Encryption0RTT: + hdr.Type = protocol.PacketType0RTT } hdr.Version = p.version hdr.IsLongHeader = true - // Always send Initial and Handshake packets with the maximum packet number length. + // Always send long header packets with the maximum packet number length. // This simplifies retransmissions: Since the header can't get any larger, // we don't need to split CRYPTO frames. hdr.PacketNumberLen = protocol.PacketNumberLen4 diff --git a/packet_packer_test.go b/packet_packer_test.go index 683badff..c4bb50c4 100644 --- a/packet_packer_test.go +++ b/packet_packer_test.go @@ -241,6 +241,39 @@ var _ = Describe("Packet packer", func() { }) }) + Context("packing 0-RTT packets", func() { + BeforeEach(func() { + packer.perspective = protocol.PerspectiveClient + sealingManager.EXPECT().GetInitialSealer().Return(nil, nil).AnyTimes() + sealingManager.EXPECT().GetHandshakeSealer().Return(nil, nil).AnyTimes() + sealingManager.EXPECT().Get1RTTSealer().Return(nil, handshake.ErrKeysNotYetAvailable).AnyTimes() + initialStream.EXPECT().HasData().AnyTimes() + ackFramer.EXPECT().GetAckFrame(protocol.EncryptionInitial).AnyTimes() + handshakeStream.EXPECT().HasData().AnyTimes() + ackFramer.EXPECT().GetAckFrame(protocol.EncryptionHandshake).AnyTimes() + ackFramer.EXPECT().GetAckFrame(protocol.Encryption1RTT).AnyTimes() + }) + + It("packs a 0-RTT packet", func() { + sealingManager.EXPECT().Get0RTTSealer().Return(sealer, nil).AnyTimes() + pnManager.EXPECT().PeekPacketNumber(protocol.Encryption0RTT).Return(protocol.PacketNumber(0x42), protocol.PacketNumberLen2) + pnManager.EXPECT().PopPacketNumber(protocol.Encryption0RTT).Return(protocol.PacketNumber(0x42)) + cf := ackhandler.Frame{Frame: &wire.MaxDataFrame{ByteOffset: 0x1337}} + framer.EXPECT().AppendControlFrames(nil, gomock.Any()).DoAndReturn(func(frames []ackhandler.Frame, _ protocol.ByteCount) ([]ackhandler.Frame, protocol.ByteCount) { + return append(frames, cf), cf.Length(packer.version) + }) + framer.EXPECT().AppendStreamFrames(gomock.Any(), gomock.Any()).DoAndReturn(func(frames []ackhandler.Frame, _ protocol.ByteCount) ([]ackhandler.Frame, protocol.ByteCount) { + return frames, 0 + }) + p, err := packer.PackPacket() + Expect(p).ToNot(BeNil()) + Expect(err).ToNot(HaveOccurred()) + Expect(p.header.Type).To(Equal(protocol.PacketType0RTT)) + Expect(p.EncryptionLevel()).To(Equal(protocol.Encryption0RTT)) + Expect(p.frames).To(Equal([]ackhandler.Frame{cf})) + }) + }) + Context("packing normal packets", func() { BeforeEach(func() { sealingManager.EXPECT().GetInitialSealer().Return(nil, nil).AnyTimes()