From 5f14b03135033777d3b5f535cdde6bbc84eecb86 Mon Sep 17 00:00:00 2001 From: Marten Seemann Date: Mon, 11 Nov 2019 10:41:53 +0700 Subject: [PATCH] refactor packet packer --- packet_packer.go | 100 +++++++++++++++++++++++++++--------------- packet_packer_test.go | 54 +++++++++-------------- 2 files changed, 84 insertions(+), 70 deletions(-) diff --git a/packet_packer.go b/packet_packer.go index 357e77dc..e2d06f33 100644 --- a/packet_packer.go +++ b/packet_packer.go @@ -134,8 +134,9 @@ type packetPacker struct { version protocol.VersionNumber cryptoSetup sealingManager - // Once the handshake is confirmed, we only need to send 1-RTT packets. - handshakeConfirmed bool + // Once both Initial and Handshake keys are dropped, we only send 1-RTT packets. + droppedInitial bool + droppedHandshake bool initialStream cryptoStream handshakeStream cryptoStream @@ -183,6 +184,10 @@ func newPacketPacker( } } +func (p *packetPacker) handshakeConfirmed() bool { + return p.droppedInitial && p.droppedHandshake +} + // PackConnectionClose packs a packet that ONLY contains a ConnectionCloseFrame func (p *packetPacker) PackConnectionClose(ccf *wire.ConnectionCloseFrame) (*packedPacket, error) { payload := payload{ @@ -219,7 +224,7 @@ func (p *packetPacker) PackConnectionClose(ccf *wire.ConnectionCloseFrame) (*pac func (p *packetPacker) MaybePackAckPacket() (*packedPacket, error) { var encLevel protocol.EncryptionLevel var ack *wire.AckFrame - if !p.handshakeConfirmed { + if !p.handshakeConfirmed() { ack = p.acks.GetAckFrame(protocol.EncryptionInitial) if ack != nil { encLevel = protocol.EncryptionInitial @@ -255,7 +260,7 @@ func (p *packetPacker) MaybePackAckPacket() (*packedPacket, error) { // PackPacket packs a new packet // the other controlFrames are sent in the next packet, but might be queued and sent in the next packet if the packet would overflow MaxPacketSize otherwise func (p *packetPacker) PackPacket() (*packedPacket, error) { - if !p.handshakeConfirmed { + if !p.handshakeConfirmed() { packet, err := p.maybePackCryptoPacket() if err != nil { return nil, err @@ -297,44 +302,67 @@ func (p *packetPacker) PackPacket() (*packedPacket, error) { } func (p *packetPacker) maybePackCryptoPacket() (*packedPacket, error) { - var s cryptoStream - var encLevel protocol.EncryptionLevel + packet, err := p.maybePackInitialPacket() + if err != nil { + return nil, err + } + if packet != nil { + return packet, nil + } + return p.maybePackHandshakePacket() +} - initialSealer, errInitialSealer := p.cryptoSetup.GetInitialSealer() - handshakeSealer, errHandshakeSealer := p.cryptoSetup.GetHandshakeSealer() - - if errInitialSealer == handshake.ErrKeysDropped && - errHandshakeSealer == handshake.ErrKeysDropped { - p.handshakeConfirmed = true +func (p *packetPacker) maybePackInitialPacket() (*packedPacket, error) { + sealer, err := p.cryptoSetup.GetInitialSealer() + if err == handshake.ErrKeysDropped { + p.droppedInitial = true + return nil, nil + } + if err != nil { + return nil, err } - hasData := p.initialStream.HasData() hasRetransmission := p.retransmissionQueue.HasInitialData() ack := p.acks.GetAckFrame(protocol.EncryptionInitial) - var sealer handshake.LongHeaderSealer - if hasData || hasRetransmission || ack != nil { - s = p.initialStream - encLevel = protocol.EncryptionInitial - sealer = initialSealer - if errInitialSealer != nil { - return nil, fmt.Errorf("PacketPacker BUG: no Initial sealer: %s", errInitialSealer) - } - } else { - hasData = p.handshakeStream.HasData() - hasRetransmission = p.retransmissionQueue.HasHandshakeData() - ack = p.acks.GetAckFrame(protocol.EncryptionHandshake) - if hasData || hasRetransmission || ack != nil { - s = p.handshakeStream - encLevel = protocol.EncryptionHandshake - sealer = handshakeSealer - if errHandshakeSealer != nil { - return nil, fmt.Errorf("PacketPacker BUG: no Handshake sealer: %s", errHandshakeSealer) - } - } - } - if s == nil { + if !p.initialStream.HasData() && !hasRetransmission && ack == nil { + // nothing to send return nil, nil } + return p.packCryptoPacket(protocol.EncryptionInitial, sealer, ack, hasRetransmission) +} + +func (p *packetPacker) maybePackHandshakePacket() (*packedPacket, error) { + sealer, err := p.cryptoSetup.GetHandshakeSealer() + if err == handshake.ErrKeysDropped { + p.droppedHandshake = true + return nil, nil + } + if err == handshake.ErrKeysNotYetAvailable { + return nil, nil + } + if err != nil { + return nil, err + } + + hasRetransmission := p.retransmissionQueue.HasHandshakeData() + ack := p.acks.GetAckFrame(protocol.EncryptionHandshake) + if !p.handshakeStream.HasData() && !hasRetransmission && ack == nil { + // nothing to send + return nil, nil + } + return p.packCryptoPacket(protocol.EncryptionHandshake, sealer, ack, hasRetransmission) +} + +func (p *packetPacker) packCryptoPacket( + encLevel protocol.EncryptionLevel, + sealer handshake.LongHeaderSealer, + ack *wire.AckFrame, + hasRetransmission bool, +) (*packedPacket, error) { + s := p.initialStream + if encLevel == protocol.EncryptionHandshake { + s = p.handshakeStream + } var payload payload if ack != nil { @@ -360,7 +388,7 @@ func (p *packetPacker) maybePackCryptoPacket() (*packedPacket, error) { payload.frames = append(payload.frames, ackhandler.Frame{Frame: f}) payload.length += f.Length(p.version) } - } else if hasData { + } else if s.HasData() { cf := s.PopCryptoFrame(p.maxPacketSize - hdrLen - protocol.ByteCount(sealer.Overhead()) - payload.length) payload.frames = []ackhandler.Frame{{Frame: cf}} payload.length += cf.Length(p.version) diff --git a/packet_packer_test.go b/packet_packer_test.go index d8a95002..5ce161a0 100644 --- a/packet_packer_test.go +++ b/packet_packer_test.go @@ -2,7 +2,6 @@ package quic import ( "bytes" - "errors" "math/rand" "net" "time" @@ -580,25 +579,24 @@ var _ = Describe("Packet packer", func() { Data: []byte("foobar"), } ackFramer.EXPECT().GetAckFrame(protocol.EncryptionInitial) - initialStream.EXPECT().HasData().Return(true) + initialStream.EXPECT().HasData().Return(true).AnyTimes() initialStream.EXPECT().PopCryptoFrame(gomock.Any()).Return(f) sealingManager.EXPECT().GetInitialSealer().Return(sealer, nil) - sealingManager.EXPECT().GetHandshakeSealer().Return(nil, errors.New("no sealer")) p, err := packer.PackPacket() Expect(err).ToNot(HaveOccurred()) checkLength(p.raw) }) - It("packs a maximum size crypto packet", func() { + It("packs a maximum size Handshake packet", func() { var f *wire.CryptoFrame pnManager.EXPECT().PeekPacketNumber(protocol.EncryptionHandshake).Return(protocol.PacketNumber(0x42), protocol.PacketNumberLen2) pnManager.EXPECT().PopPacketNumber(protocol.EncryptionHandshake).Return(protocol.PacketNumber(0x42)) - sealingManager.EXPECT().GetInitialSealer().Return(nil, errors.New("no sealer")) + sealingManager.EXPECT().GetInitialSealer().Return(mocks.NewMockShortHeaderSealer(mockCtrl), nil) sealingManager.EXPECT().GetHandshakeSealer().Return(sealer, nil) ackFramer.EXPECT().GetAckFrame(protocol.EncryptionInitial) ackFramer.EXPECT().GetAckFrame(protocol.EncryptionHandshake) initialStream.EXPECT().HasData() - handshakeStream.EXPECT().HasData().Return(true) + handshakeStream.EXPECT().HasData().Return(true).Times(2) handshakeStream.EXPECT().PopCryptoFrame(gomock.Any()).DoAndReturn(func(size protocol.ByteCount) *wire.CryptoFrame { f = &wire.CryptoFrame{Offset: 0x1337} f.Data = bytes.Repeat([]byte{'f'}, int(size-f.Length(packer.version)-1)) @@ -620,7 +618,6 @@ var _ = Describe("Packet packer", func() { pnManager.EXPECT().PeekPacketNumber(protocol.EncryptionInitial).Return(protocol.PacketNumber(0x42), protocol.PacketNumberLen2) pnManager.EXPECT().PopPacketNumber(protocol.EncryptionInitial).Return(protocol.PacketNumber(0x42)) sealingManager.EXPECT().GetInitialSealer().Return(sealer, nil) - sealingManager.EXPECT().GetHandshakeSealer().Return(nil, errors.New("no sealer")) ackFramer.EXPECT().GetAckFrame(protocol.EncryptionInitial) initialStream.EXPECT().HasData() p, err := packer.PackPacket() @@ -634,9 +631,8 @@ var _ = Describe("Packet packer", func() { It("sends an Initial packet containing only an ACK", func() { ack := &wire.AckFrame{AckRanges: []wire.AckRange{{Smallest: 10, Largest: 20}}} ackFramer.EXPECT().GetAckFrame(protocol.EncryptionInitial).Return(ack) - initialStream.EXPECT().HasData() + initialStream.EXPECT().HasData().Times(2) sealingManager.EXPECT().GetInitialSealer().Return(sealer, nil) - sealingManager.EXPECT().GetHandshakeSealer().Return(nil, errors.New("no sealer")) pnManager.EXPECT().PeekPacketNumber(protocol.EncryptionInitial).Return(protocol.PacketNumber(0x42), protocol.PacketNumberLen2) pnManager.EXPECT().PopPacketNumber(protocol.EncryptionInitial).Return(protocol.PacketNumber(0x42)) p, err := packer.PackPacket() @@ -644,13 +640,24 @@ var _ = Describe("Packet packer", func() { Expect(p.ack).To(Equal(ack)) }) + It("doesn't pack anything if there's nothing to send at Initial and Handshake keys are not yet available", func() { + sealingManager.EXPECT().GetInitialSealer().Return(sealer, nil) + sealingManager.EXPECT().GetHandshakeSealer().Return(nil, handshake.ErrKeysNotYetAvailable) + sealingManager.EXPECT().Get1RTTSealer().Return(nil, handshake.ErrKeysNotYetAvailable) + initialStream.EXPECT().HasData() + ackFramer.EXPECT().GetAckFrame(protocol.EncryptionInitial) + p, err := packer.PackPacket() + Expect(err).ToNot(HaveOccurred()) + Expect(p).To(BeNil()) + }) + It("sends a Handshake packet containing only an ACK", func() { ack := &wire.AckFrame{AckRanges: []wire.AckRange{{Smallest: 10, Largest: 20}}} ackFramer.EXPECT().GetAckFrame(protocol.EncryptionInitial) ackFramer.EXPECT().GetAckFrame(protocol.EncryptionHandshake).Return(ack) initialStream.EXPECT().HasData() - handshakeStream.EXPECT().HasData() - sealingManager.EXPECT().GetInitialSealer().Return(nil, errors.New("no sealer")) + handshakeStream.EXPECT().HasData().Times(2) + sealingManager.EXPECT().GetInitialSealer().Return(mocks.NewMockShortHeaderSealer(mockCtrl), nil) sealingManager.EXPECT().GetHandshakeSealer().Return(sealer, nil) pnManager.EXPECT().PeekPacketNumber(protocol.EncryptionHandshake).Return(protocol.PacketNumber(0x42), protocol.PacketNumberLen2) pnManager.EXPECT().PopPacketNumber(protocol.EncryptionHandshake).Return(protocol.PacketNumber(0x42)) @@ -666,9 +673,8 @@ var _ = Describe("Packet packer", func() { pnManager.EXPECT().PeekPacketNumber(protocol.EncryptionInitial).Return(protocol.PacketNumber(0x42), protocol.PacketNumberLen2) pnManager.EXPECT().PopPacketNumber(protocol.EncryptionInitial).Return(protocol.PacketNumber(0x42)) sealingManager.EXPECT().GetInitialSealer().Return(sealer, nil) - sealingManager.EXPECT().GetHandshakeSealer().Return(nil, errors.New("no sealer")) ackFramer.EXPECT().GetAckFrame(protocol.EncryptionInitial) - initialStream.EXPECT().HasData().Return(true) + initialStream.EXPECT().HasData().Return(true).Times(2) initialStream.EXPECT().PopCryptoFrame(gomock.Any()).Return(f) packer.perspective = protocol.PerspectiveClient packet, err := packer.PackPacket() @@ -678,21 +684,6 @@ var _ = Describe("Packet packer", func() { Expect(packet.frames).To(HaveLen(1)) cf := packet.frames[0].Frame.(*wire.CryptoFrame) Expect(cf.Data).To(Equal([]byte("foobar"))) - }) - - It("sets the correct length for an Initial packet", func() { - pnManager.EXPECT().PeekPacketNumber(protocol.EncryptionInitial).Return(protocol.PacketNumber(0x42), protocol.PacketNumberLen2) - pnManager.EXPECT().PopPacketNumber(protocol.EncryptionInitial).Return(protocol.PacketNumber(0x42)) - sealingManager.EXPECT().GetInitialSealer().Return(sealer, nil) - sealingManager.EXPECT().GetHandshakeSealer().Return(nil, errors.New("no sealer")) - ackFramer.EXPECT().GetAckFrame(protocol.EncryptionInitial) - initialStream.EXPECT().HasData().Return(true) - initialStream.EXPECT().PopCryptoFrame(gomock.Any()).Return(&wire.CryptoFrame{ - Data: []byte("foobar"), - }) - packer.perspective = protocol.PerspectiveClient - packet, err := packer.PackPacket() - Expect(err).ToNot(HaveOccurred()) checkLength(packet.raw) }) @@ -702,9 +693,8 @@ var _ = Describe("Packet packer", func() { pnManager.EXPECT().PeekPacketNumber(protocol.EncryptionInitial).Return(protocol.PacketNumber(0x42), protocol.PacketNumberLen2) pnManager.EXPECT().PopPacketNumber(protocol.EncryptionInitial).Return(protocol.PacketNumber(0x42)) sealingManager.EXPECT().GetInitialSealer().Return(sealer, nil) - sealingManager.EXPECT().GetHandshakeSealer().Return(nil, errors.New("no sealer")) ackFramer.EXPECT().GetAckFrame(protocol.EncryptionInitial).Return(ack) - initialStream.EXPECT().HasData().Return(true) + initialStream.EXPECT().HasData().Return(true).Times(2) initialStream.EXPECT().PopCryptoFrame(gomock.Any()).Return(f) packer.version = protocol.VersionTLS packer.perspective = protocol.PerspectiveClient @@ -718,11 +708,7 @@ var _ = Describe("Packet packer", func() { It("stops packing crypto packets when the keys are dropped", func() { sealingManager.EXPECT().GetInitialSealer().Return(nil, handshake.ErrKeysDropped) sealingManager.EXPECT().GetHandshakeSealer().Return(nil, handshake.ErrKeysDropped) - initialStream.EXPECT().HasData() - handshakeStream.EXPECT().HasData() sealingManager.EXPECT().Get1RTTSealer().Return(sealer, nil) - ackFramer.EXPECT().GetAckFrame(protocol.EncryptionInitial) - ackFramer.EXPECT().GetAckFrame(protocol.EncryptionHandshake) pnManager.EXPECT().PeekPacketNumber(protocol.Encryption1RTT).Return(protocol.PacketNumber(0x42), protocol.PacketNumberLen2) pnManager.EXPECT().PopPacketNumber(protocol.Encryption1RTT).Return(protocol.PacketNumber(0x42)) ackFramer.EXPECT().GetAckFrame(protocol.Encryption1RTT)