diff --git a/internal/handshake/crypto_setup.go b/internal/handshake/crypto_setup.go index 509c13c3..6d584c7d 100644 --- a/internal/handshake/crypto_setup.go +++ b/internal/handshake/crypto_setup.go @@ -552,6 +552,9 @@ func (h *cryptoSetup) GetInitialSealer() (LongHeaderSealer, error) { h.mutex.Lock() defer h.mutex.Unlock() + if h.initialSealer == nil { + return nil, ErrKeysDropped + } return h.initialSealer, nil } @@ -560,6 +563,9 @@ func (h *cryptoSetup) GetHandshakeSealer() (LongHeaderSealer, error) { defer h.mutex.Unlock() if h.handshakeSealer == nil { + if h.initialSealer == nil { + return nil, ErrKeysDropped + } return nil, errors.New("CryptoSetup: no sealer with encryption level Handshake") } return h.handshakeSealer, nil diff --git a/packet_packer.go b/packet_packer.go index 61e7d857..9c437445 100644 --- a/packet_packer.go +++ b/packet_packer.go @@ -119,6 +119,9 @@ type packetPacker struct { version protocol.VersionNumber cryptoSetup sealingManager + // Once the handshake is confirmed, we only need to send 1-RTT packets. + handshakeConfirmed bool + initialStream cryptoStream handshakeStream cryptoStream @@ -313,12 +316,14 @@ func (p *packetPacker) PackRetransmission(packet *ackhandler.Packet) ([]*packedP // PackPacket packs a new packet // the other controlFrames are sent in the next packet, but might be queued and sent in the next packet if the packet would overflow MaxPacketSize otherwise func (p *packetPacker) PackPacket() (*packedPacket, error) { - packet, err := p.maybePackCryptoPacket() - if err != nil { - return nil, err - } - if packet != nil { - return packet, nil + if !p.handshakeConfirmed { + packet, err := p.maybePackCryptoPacket() + if err != nil { + return nil, err + } + if packet != nil { + return packet, nil + } } sealer, err := p.cryptoSetup.Get1RTTSealer() @@ -359,26 +364,36 @@ func (p *packetPacker) maybePackCryptoPacket() (*packedPacket, error) { var s cryptoStream var encLevel protocol.EncryptionLevel + initialSealer, errInitialSealer := p.cryptoSetup.GetInitialSealer() + handshakeSealer, errHandshakeSealer := p.cryptoSetup.GetHandshakeSealer() + + if errInitialSealer == handshake.ErrKeysDropped && + errHandshakeSealer == handshake.ErrKeysDropped { + p.handshakeConfirmed = true + } + hasData := p.initialStream.HasData() ack := p.acks.GetAckFrame(protocol.EncryptionInitial) var sealer handshake.LongHeaderSealer - var err error if hasData || ack != nil { s = p.initialStream encLevel = protocol.EncryptionInitial - sealer, err = p.cryptoSetup.GetInitialSealer() + sealer = initialSealer + if errInitialSealer != nil { + return nil, fmt.Errorf("PacketPacker BUG: no Initial sealer: %s", errInitialSealer) + } } else { hasData = p.handshakeStream.HasData() ack = p.acks.GetAckFrame(protocol.EncryptionHandshake) if hasData || ack != nil { s = p.handshakeStream encLevel = protocol.EncryptionHandshake - sealer, err = p.cryptoSetup.GetHandshakeSealer() + sealer = handshakeSealer + if errHandshakeSealer != nil { + return nil, fmt.Errorf("PacketPacker BUG: no Handshake sealer: %s", errHandshakeSealer) + } } } - if err != nil { - return nil, err - } if s == nil { return nil, nil } diff --git a/packet_packer_test.go b/packet_packer_test.go index 3904201f..44605369 100644 --- a/packet_packer_test.go +++ b/packet_packer_test.go @@ -2,6 +2,7 @@ package quic import ( "bytes" + "errors" "math/rand" "net" @@ -176,6 +177,8 @@ var _ = Describe("Packet packer", func() { } }), ) + sealingManager.EXPECT().GetInitialSealer().Return(nil, nil) + sealingManager.EXPECT().GetHandshakeSealer().Return(nil, nil) sealingManager.EXPECT().Get1RTTSealer().Return(sealer, nil) ackFramer.EXPECT().GetAckFrame(protocol.EncryptionInitial) ackFramer.EXPECT().GetAckFrame(protocol.EncryptionHandshake) @@ -211,6 +214,8 @@ var _ = Describe("Packet packer", func() { Context("packing normal packets", func() { BeforeEach(func() { + sealingManager.EXPECT().GetInitialSealer().Return(nil, nil).AnyTimes() + sealingManager.EXPECT().GetHandshakeSealer().Return(nil, nil).AnyTimes() initialStream.EXPECT().HasData().AnyTimes() ackFramer.EXPECT().GetAckFrame(protocol.EncryptionInitial).AnyTimes() handshakeStream.EXPECT().HasData().AnyTimes() @@ -327,6 +332,41 @@ var _ = Describe("Packet packer", func() { Expect(err).ToNot(HaveOccurred()) }) + It("pads if payload length + packet number length is smaller than 4", func() { + f := &wire.StreamFrame{ + StreamID: 0x10, // small stream ID, such that only a single byte is consumed + FinBit: true, + } + Expect(f.Length(packer.version)).To(BeEquivalentTo(2)) + pnManager.EXPECT().PeekPacketNumber(protocol.Encryption1RTT).Return(protocol.PacketNumber(0x42), protocol.PacketNumberLen1) + pnManager.EXPECT().PopPacketNumber(protocol.Encryption1RTT).Return(protocol.PacketNumber(0x42)) + sealingManager.EXPECT().Get1RTTSealer().Return(sealer, nil) + ackFramer.EXPECT().GetAckFrame(protocol.Encryption1RTT) + framer.EXPECT().AppendControlFrames(gomock.Any(), gomock.Any()) + framer.EXPECT().AppendStreamFrames(gomock.Any(), gomock.Any()).Return([]wire.Frame{f}, f.Length(packer.version)) + packet, err := packer.PackPacket() + Expect(err).ToNot(HaveOccurred()) + // cut off the tag that the mock sealer added + packet.raw = packet.raw[:len(packet.raw)-sealer.Overhead()] + hdr, _, _, err := wire.ParsePacket(packet.raw, len(packer.destConnID)) + Expect(err).ToNot(HaveOccurred()) + r := bytes.NewReader(packet.raw) + extHdr, err := hdr.ParseExtended(r, packer.version) + Expect(err).ToNot(HaveOccurred()) + Expect(extHdr.PacketNumberLen).To(Equal(protocol.PacketNumberLen1)) + Expect(r.Len()).To(Equal(4 - 1 /* packet number length */)) + // the first byte of the payload should be a PADDING frame... + firstPayloadByte, err := r.ReadByte() + Expect(err).ToNot(HaveOccurred()) + Expect(firstPayloadByte).To(Equal(byte(0))) + // ... followed by the STREAM frame + frameParser := wire.NewFrameParser(packer.version) + frame, err := frameParser.ParseNext(r, protocol.Encryption1RTT) + Expect(err).ToNot(HaveOccurred()) + Expect(frame).To(Equal(f)) + Expect(r.Len()).To(BeZero()) + }) + Context("packing ACK packets", func() { It("doesn't pack a packet if there's no ACK to send", func() { ackFramer.EXPECT().GetAckFrame(protocol.Encryption1RTT) @@ -719,6 +759,7 @@ var _ = Describe("Packet packer", func() { initialStream.EXPECT().HasData().Return(true) initialStream.EXPECT().PopCryptoFrame(gomock.Any()).Return(f) sealingManager.EXPECT().GetInitialSealer().Return(sealer, nil) + sealingManager.EXPECT().GetHandshakeSealer().Return(nil, errors.New("no sealer")) p, err := packer.PackPacket() Expect(err).ToNot(HaveOccurred()) checkLength(p.raw) @@ -728,6 +769,7 @@ var _ = Describe("Packet packer", func() { var f *wire.CryptoFrame pnManager.EXPECT().PeekPacketNumber(protocol.EncryptionHandshake).Return(protocol.PacketNumber(0x42), protocol.PacketNumberLen2) pnManager.EXPECT().PopPacketNumber(protocol.EncryptionHandshake).Return(protocol.PacketNumber(0x42)) + sealingManager.EXPECT().GetInitialSealer().Return(nil, errors.New("no sealer")) sealingManager.EXPECT().GetHandshakeSealer().Return(sealer, nil) ackFramer.EXPECT().GetAckFrame(protocol.EncryptionInitial) ackFramer.EXPECT().GetAckFrame(protocol.EncryptionHandshake) @@ -753,6 +795,7 @@ var _ = Describe("Packet packer", func() { ackFramer.EXPECT().GetAckFrame(protocol.EncryptionInitial).Return(ack) initialStream.EXPECT().HasData() sealingManager.EXPECT().GetInitialSealer().Return(sealer, nil) + sealingManager.EXPECT().GetHandshakeSealer().Return(nil, errors.New("no sealer")) pnManager.EXPECT().PeekPacketNumber(protocol.EncryptionInitial).Return(protocol.PacketNumber(0x42), protocol.PacketNumberLen2) pnManager.EXPECT().PopPacketNumber(protocol.EncryptionInitial).Return(protocol.PacketNumber(0x42)) p, err := packer.PackPacket() @@ -766,6 +809,7 @@ var _ = Describe("Packet packer", func() { ackFramer.EXPECT().GetAckFrame(protocol.EncryptionHandshake).Return(ack) initialStream.EXPECT().HasData() handshakeStream.EXPECT().HasData() + sealingManager.EXPECT().GetInitialSealer().Return(nil, errors.New("no sealer")) sealingManager.EXPECT().GetHandshakeSealer().Return(sealer, nil) pnManager.EXPECT().PeekPacketNumber(protocol.EncryptionHandshake).Return(protocol.PacketNumber(0x42), protocol.PacketNumberLen2) pnManager.EXPECT().PopPacketNumber(protocol.EncryptionHandshake).Return(protocol.PacketNumber(0x42)) @@ -781,6 +825,7 @@ var _ = Describe("Packet packer", func() { pnManager.EXPECT().PeekPacketNumber(protocol.EncryptionInitial).Return(protocol.PacketNumber(0x42), protocol.PacketNumberLen2) pnManager.EXPECT().PopPacketNumber(protocol.EncryptionInitial).Return(protocol.PacketNumber(0x42)) sealingManager.EXPECT().GetInitialSealer().Return(sealer, nil) + sealingManager.EXPECT().GetHandshakeSealer().Return(nil, errors.New("no sealer")) ackFramer.EXPECT().GetAckFrame(protocol.EncryptionInitial) initialStream.EXPECT().HasData().Return(true) initialStream.EXPECT().PopCryptoFrame(gomock.Any()).Return(f) @@ -794,49 +839,11 @@ var _ = Describe("Packet packer", func() { Expect(cf.Data).To(Equal([]byte("foobar"))) }) - It("pads if payload length + packet number length is smaller than 4", func() { - f := &wire.StreamFrame{ - StreamID: 0x10, // small stream ID, such that only a single byte is consumed - FinBit: true, - } - Expect(f.Length(packer.version)).To(BeEquivalentTo(2)) - pnManager.EXPECT().PeekPacketNumber(protocol.Encryption1RTT).Return(protocol.PacketNumber(0x42), protocol.PacketNumberLen1) - pnManager.EXPECT().PopPacketNumber(protocol.Encryption1RTT).Return(protocol.PacketNumber(0x42)) - sealingManager.EXPECT().Get1RTTSealer().Return(sealer, nil) - ackFramer.EXPECT().GetAckFrame(protocol.EncryptionInitial) - ackFramer.EXPECT().GetAckFrame(protocol.EncryptionHandshake) - ackFramer.EXPECT().GetAckFrame(protocol.Encryption1RTT) - initialStream.EXPECT().HasData() - handshakeStream.EXPECT().HasData() - framer.EXPECT().AppendControlFrames(gomock.Any(), gomock.Any()) - framer.EXPECT().AppendStreamFrames(gomock.Any(), gomock.Any()).Return([]wire.Frame{f}, f.Length(packer.version)) - packet, err := packer.PackPacket() - Expect(err).ToNot(HaveOccurred()) - // cut off the tag that the mock sealer added - packet.raw = packet.raw[:len(packet.raw)-sealer.Overhead()] - hdr, _, _, err := wire.ParsePacket(packet.raw, len(packer.destConnID)) - Expect(err).ToNot(HaveOccurred()) - r := bytes.NewReader(packet.raw) - extHdr, err := hdr.ParseExtended(r, packer.version) - Expect(err).ToNot(HaveOccurred()) - Expect(extHdr.PacketNumberLen).To(Equal(protocol.PacketNumberLen1)) - Expect(r.Len()).To(Equal(4 - 1 /* packet number length */)) - // the first byte of the payload should be a PADDING frame... - firstPayloadByte, err := r.ReadByte() - Expect(err).ToNot(HaveOccurred()) - Expect(firstPayloadByte).To(Equal(byte(0))) - // ... followed by the STREAM frame - frameParser := wire.NewFrameParser(packer.version) - frame, err := frameParser.ParseNext(r, protocol.Encryption1RTT) - Expect(err).ToNot(HaveOccurred()) - Expect(frame).To(Equal(f)) - Expect(r.Len()).To(BeZero()) - }) - It("sets the correct length for an Initial packet", func() { pnManager.EXPECT().PeekPacketNumber(protocol.EncryptionInitial).Return(protocol.PacketNumber(0x42), protocol.PacketNumberLen2) pnManager.EXPECT().PopPacketNumber(protocol.EncryptionInitial).Return(protocol.PacketNumber(0x42)) sealingManager.EXPECT().GetInitialSealer().Return(sealer, nil) + sealingManager.EXPECT().GetHandshakeSealer().Return(nil, errors.New("no sealer")) ackFramer.EXPECT().GetAckFrame(protocol.EncryptionInitial) initialStream.EXPECT().HasData().Return(true) initialStream.EXPECT().PopCryptoFrame(gomock.Any()).Return(&wire.CryptoFrame{ @@ -854,6 +861,7 @@ var _ = Describe("Packet packer", func() { pnManager.EXPECT().PeekPacketNumber(protocol.EncryptionInitial).Return(protocol.PacketNumber(0x42), protocol.PacketNumberLen2) pnManager.EXPECT().PopPacketNumber(protocol.EncryptionInitial).Return(protocol.PacketNumber(0x42)) sealingManager.EXPECT().GetInitialSealer().Return(sealer, nil) + sealingManager.EXPECT().GetHandshakeSealer().Return(nil, errors.New("no sealer")) ackFramer.EXPECT().GetAckFrame(protocol.EncryptionInitial).Return(ack) initialStream.EXPECT().HasData().Return(true) initialStream.EXPECT().PopCryptoFrame(gomock.Any()).Return(f) @@ -866,6 +874,35 @@ var _ = Describe("Packet packer", func() { Expect(packet.frames).To(HaveLen(1)) }) + It("stops packing crypto packets when the keys are dropped", func() { + sealingManager.EXPECT().GetInitialSealer().Return(nil, handshake.ErrKeysDropped) + sealingManager.EXPECT().GetHandshakeSealer().Return(nil, handshake.ErrKeysDropped) + initialStream.EXPECT().HasData() + handshakeStream.EXPECT().HasData() + sealingManager.EXPECT().Get1RTTSealer().Return(sealer, nil) + ackFramer.EXPECT().GetAckFrame(protocol.EncryptionInitial) + ackFramer.EXPECT().GetAckFrame(protocol.EncryptionHandshake) + pnManager.EXPECT().PeekPacketNumber(protocol.Encryption1RTT).Return(protocol.PacketNumber(0x42), protocol.PacketNumberLen2) + pnManager.EXPECT().PopPacketNumber(protocol.Encryption1RTT).Return(protocol.PacketNumber(0x42)) + ackFramer.EXPECT().GetAckFrame(protocol.Encryption1RTT) + expectAppendControlFrames(&wire.PingFrame{}) + expectAppendStreamFrames() + packet, err := packer.PackPacket() + Expect(err).ToNot(HaveOccurred()) + Expect(packet).ToNot(BeNil()) + + // now the packer should have realized that the handshake is confirmed + sealingManager.EXPECT().Get1RTTSealer().Return(sealer, nil) + pnManager.EXPECT().PeekPacketNumber(protocol.Encryption1RTT).Return(protocol.PacketNumber(0x43), protocol.PacketNumberLen2) + pnManager.EXPECT().PopPacketNumber(protocol.Encryption1RTT).Return(protocol.PacketNumber(0x43)) + ackFramer.EXPECT().GetAckFrame(protocol.Encryption1RTT) + expectAppendControlFrames(&wire.PingFrame{}) + expectAppendStreamFrames() + packet, err = packer.PackPacket() + Expect(err).ToNot(HaveOccurred()) + Expect(packet).ToNot(BeNil()) + }) + Context("retransmitions", func() { cf := &wire.CryptoFrame{Data: []byte("foo")}