From a4b4d520632c340a15a6e64198689a3c715c84df Mon Sep 17 00:00:00 2001 From: Marten Seemann Date: Mon, 10 Feb 2020 14:15:50 +0800 Subject: [PATCH] refactor packing of packets before and after the handshake is confirmed --- mock_packer_test.go | 23 +++++++++++--- packet_packer.go | 46 +++++++++++----------------- packet_packer_test.go | 71 +++++++++++++------------------------------ session.go | 14 +++++++-- session_test.go | 2 +- 5 files changed, 70 insertions(+), 86 deletions(-) diff --git a/mock_packer_test.go b/mock_packer_test.go index 6425c1ea..e257b877 100644 --- a/mock_packer_test.go +++ b/mock_packer_test.go @@ -49,18 +49,18 @@ func (mr *MockPackerMockRecorder) HandleTransportParameters(arg0 interface{}) *g } // MaybePackAckPacket mocks base method -func (m *MockPacker) MaybePackAckPacket() (*packedPacket, error) { +func (m *MockPacker) MaybePackAckPacket(arg0 bool) (*packedPacket, error) { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "MaybePackAckPacket") + ret := m.ctrl.Call(m, "MaybePackAckPacket", arg0) ret0, _ := ret[0].(*packedPacket) ret1, _ := ret[1].(error) return ret0, ret1 } // MaybePackAckPacket indicates an expected call of MaybePackAckPacket -func (mr *MockPackerMockRecorder) MaybePackAckPacket() *gomock.Call { +func (mr *MockPackerMockRecorder) MaybePackAckPacket(arg0 interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "MaybePackAckPacket", reflect.TypeOf((*MockPacker)(nil).MaybePackAckPacket)) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "MaybePackAckPacket", reflect.TypeOf((*MockPacker)(nil).MaybePackAckPacket), arg0) } // MaybePackProbePacket mocks base method @@ -78,6 +78,21 @@ func (mr *MockPackerMockRecorder) MaybePackProbePacket(arg0 interface{}) *gomock return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "MaybePackProbePacket", reflect.TypeOf((*MockPacker)(nil).MaybePackProbePacket), arg0) } +// PackAppDataPacket mocks base method +func (m *MockPacker) PackAppDataPacket() (*packedPacket, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "PackAppDataPacket") + ret0, _ := ret[0].(*packedPacket) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// PackAppDataPacket indicates an expected call of PackAppDataPacket +func (mr *MockPackerMockRecorder) PackAppDataPacket() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "PackAppDataPacket", reflect.TypeOf((*MockPacker)(nil).PackAppDataPacket)) +} + // PackConnectionClose mocks base method func (m *MockPacker) PackConnectionClose(arg0 *wire.ConnectionCloseFrame) (*packedPacket, error) { m.ctrl.T.Helper() diff --git a/packet_packer.go b/packet_packer.go index cf010b88..c96307cb 100644 --- a/packet_packer.go +++ b/packet_packer.go @@ -16,8 +16,9 @@ import ( type packer interface { PackPacket() (*packedPacket, error) + PackAppDataPacket() (*packedPacket, error) MaybePackProbePacket(protocol.EncryptionLevel) (*packedPacket, error) - MaybePackAckPacket() (*packedPacket, error) + MaybePackAckPacket(handshakeConfirmed bool) (*packedPacket, error) PackConnectionClose(*wire.ConnectionCloseFrame) (*packedPacket, error) HandleTransportParameters(*handshake.TransportParameters) @@ -138,10 +139,6 @@ type packetPacker struct { version protocol.VersionNumber cryptoSetup sealingManager - // Once both Initial and Handshake keys are dropped, we only send 1-RTT packets. - droppedInitial bool - droppedHandshake bool - initialStream cryptoStream handshakeStream cryptoStream @@ -188,10 +185,6 @@ func newPacketPacker( } } -func (p *packetPacker) handshakeConfirmed() bool { - return p.droppedInitial && p.droppedHandshake -} - // PackConnectionClose packs a packet that ONLY contains a ConnectionCloseFrame func (p *packetPacker) PackConnectionClose(ccf *wire.ConnectionCloseFrame) (*packedPacket, error) { payload := payload{ @@ -225,10 +218,10 @@ func (p *packetPacker) PackConnectionClose(ccf *wire.ConnectionCloseFrame) (*pac return p.writeAndSealPacket(hdr, payload, encLevel, sealer) } -func (p *packetPacker) MaybePackAckPacket() (*packedPacket, error) { +func (p *packetPacker) MaybePackAckPacket(handshakeConfirmed bool) (*packedPacket, error) { var encLevel protocol.EncryptionLevel var ack *wire.AckFrame - if !p.handshakeConfirmed() { + if !handshakeConfirmed { ack = p.acks.GetAckFrame(protocol.EncryptionInitial) if ack != nil { encLevel = protocol.EncryptionInitial @@ -261,38 +254,33 @@ func (p *packetPacker) MaybePackAckPacket() (*packedPacket, error) { return p.writeAndSealPacket(hdr, payload, encLevel, sealer) } -// PackPacket packs a new packet -// the other controlFrames are sent in the next packet, but might be queued and sent in the next packet if the packet would overflow MaxPacketSize otherwise +// PackPacket packs a new packet. +// It packs an Initial / Handshake if there is data to send in these packet number spaces. +// It should only be called before the handshake is confirmed. func (p *packetPacker) PackPacket() (*packedPacket, error) { - if !p.handshakeConfirmed() { - packet, err := p.maybePackCryptoPacket() - if err != nil { - return nil, err - } - if packet != nil { - return packet, nil - } + packet, err := p.maybePackCryptoPacket() + if err != nil || packet != nil { + return packet, err } + return p.maybePackAppDataPacket() +} +// PackAppDataPacket packs a packet in the application data packet number space. +// It should be called after the handshake is confirmed. +func (p *packetPacker) PackAppDataPacket() (*packedPacket, error) { return p.maybePackAppDataPacket() } func (p *packetPacker) maybePackCryptoPacket() (*packedPacket, error) { // Try packing an Initial packet. packet, err := p.maybePackInitialPacket() - if err == handshake.ErrKeysDropped { - p.droppedInitial = true - } else if err != nil || packet != nil { + if (err != nil && err != handshake.ErrKeysDropped) || packet != nil { return packet, err } // No Initial was packed. Try packing a Handshake packet. packet, err = p.maybePackHandshakePacket() - if err == handshake.ErrKeysDropped { - p.droppedHandshake = true - return nil, nil - } - if err == handshake.ErrKeysNotYetAvailable { + if err == handshake.ErrKeysDropped || err == handshake.ErrKeysNotYetAvailable { return nil, nil } return packet, err diff --git a/packet_packer_test.go b/packet_packer_test.go index 547c0640..06ce4905 100644 --- a/packet_packer_test.go +++ b/packet_packer_test.go @@ -206,7 +206,7 @@ var _ = Describe("Packet packer", func() { ackFramer.EXPECT().GetAckFrame(protocol.EncryptionInitial) ackFramer.EXPECT().GetAckFrame(protocol.EncryptionHandshake) ackFramer.EXPECT().GetAckFrame(protocol.Encryption1RTT) - p, err := packer.MaybePackAckPacket() + p, err := packer.MaybePackAckPacket(false) Expect(err).ToNot(HaveOccurred()) Expect(p).To(BeNil()) }) @@ -218,7 +218,7 @@ var _ = Describe("Packet packer", func() { ack := &wire.AckFrame{AckRanges: []wire.AckRange{{Smallest: 1, Largest: 10}}} ackFramer.EXPECT().GetAckFrame(protocol.EncryptionInitial) ackFramer.EXPECT().GetAckFrame(protocol.EncryptionHandshake).Return(ack) - p, err := packer.MaybePackAckPacket() + p, err := packer.MaybePackAckPacket(false) Expect(err).NotTo(HaveOccurred()) Expect(p).ToNot(BeNil()) Expect(p.EncryptionLevel()).To(Equal(protocol.EncryptionHandshake)) @@ -230,10 +230,8 @@ var _ = Describe("Packet packer", func() { pnManager.EXPECT().PopPacketNumber(protocol.Encryption1RTT).Return(protocol.PacketNumber(0x42)) sealingManager.EXPECT().Get1RTTSealer().Return(sealer, nil) ack := &wire.AckFrame{AckRanges: []wire.AckRange{{Smallest: 1, Largest: 10}}} - ackFramer.EXPECT().GetAckFrame(protocol.EncryptionInitial) - ackFramer.EXPECT().GetAckFrame(protocol.EncryptionHandshake) ackFramer.EXPECT().GetAckFrame(protocol.Encryption1RTT).Return(ack) - p, err := packer.MaybePackAckPacket() + p, err := packer.MaybePackAckPacket(true) Expect(err).NotTo(HaveOccurred()) Expect(p).ToNot(BeNil()) Expect(p.EncryptionLevel()).To(Equal(protocol.Encryption1RTT)) @@ -276,8 +274,6 @@ var _ = Describe("Packet packer", func() { Context("packing normal packets", func() { BeforeEach(func() { - sealingManager.EXPECT().GetInitialSealer().Return(nil, nil).AnyTimes() - sealingManager.EXPECT().GetHandshakeSealer().Return(nil, nil).AnyTimes() initialStream.EXPECT().HasData().AnyTimes() ackFramer.EXPECT().GetAckFrame(protocol.EncryptionInitial).AnyTimes() handshakeStream.EXPECT().HasData().AnyTimes() @@ -291,7 +287,7 @@ var _ = Describe("Packet packer", func() { ackFramer.EXPECT().GetAckFrame(protocol.Encryption1RTT) framer.EXPECT().AppendControlFrames(nil, gomock.Any()) framer.EXPECT().AppendStreamFrames(nil, gomock.Any()) - p, err := packer.PackPacket() + p, err := packer.PackAppDataPacket() Expect(p).To(BeNil()) Expect(err).ToNot(HaveOccurred()) }) @@ -307,7 +303,7 @@ var _ = Describe("Packet packer", func() { Data: []byte{0xde, 0xca, 0xfb, 0xad}, } expectAppendStreamFrames(ackhandler.Frame{Frame: f}) - p, err := packer.PackPacket() + p, err := packer.PackAppDataPacket() Expect(err).ToNot(HaveOccurred()) Expect(p).ToNot(BeNil()) b := &bytes.Buffer{} @@ -326,7 +322,7 @@ var _ = Describe("Packet packer", func() { StreamID: 5, Data: []byte("foobar"), }}) - p, err := packer.PackPacket() + p, err := packer.PackAppDataPacket() Expect(err).ToNot(HaveOccurred()) Expect(p.EncryptionLevel()).To(Equal(protocol.Encryption1RTT)) }) @@ -339,7 +335,7 @@ var _ = Describe("Packet packer", func() { sealingManager.EXPECT().Get1RTTSealer().Return(sealer, nil) expectAppendControlFrames() expectAppendStreamFrames() - p, err := packer.PackPacket() + p, err := packer.PackAppDataPacket() Expect(err).NotTo(HaveOccurred()) Expect(p).ToNot(BeNil()) Expect(p.ack).To(Equal(ack)) @@ -371,7 +367,7 @@ var _ = Describe("Packet packer", func() { } expectAppendControlFrames(frames...) expectAppendStreamFrames() - p, err := packer.PackPacket() + p, err := packer.PackAppDataPacket() Expect(p).ToNot(BeNil()) Expect(err).ToNot(HaveOccurred()) Expect(p.frames).To(Equal(frames)) @@ -393,7 +389,7 @@ var _ = Describe("Packet packer", func() { return fs, 0 }), ) - _, err := packer.PackPacket() + _, err := packer.PackAppDataPacket() Expect(err).ToNot(HaveOccurred()) }) @@ -409,7 +405,7 @@ var _ = Describe("Packet packer", func() { ackFramer.EXPECT().GetAckFrame(protocol.Encryption1RTT) expectAppendControlFrames() expectAppendStreamFrames(ackhandler.Frame{Frame: f}) - packet, err := packer.PackPacket() + packet, err := packer.PackAppDataPacket() Expect(err).ToNot(HaveOccurred()) // cut off the tag that the mock sealer added packet.raw = packet.raw[:len(packet.raw)-sealer.Overhead()] @@ -458,7 +454,7 @@ var _ = Describe("Packet packer", func() { ackFramer.EXPECT().GetAckFrame(protocol.Encryption1RTT) expectAppendControlFrames() expectAppendStreamFrames(ackhandler.Frame{Frame: f1}, ackhandler.Frame{Frame: f2}, ackhandler.Frame{Frame: f3}) - p, err := packer.PackPacket() + p, err := packer.PackAppDataPacket() Expect(p).ToNot(BeNil()) Expect(err).ToNot(HaveOccurred()) Expect(p.frames).To(HaveLen(3)) @@ -476,7 +472,7 @@ var _ = Describe("Packet packer", func() { ackFramer.EXPECT().GetAckFrame(protocol.Encryption1RTT).Return(&wire.AckFrame{AckRanges: []wire.AckRange{{Smallest: 1, Largest: 1}}}) expectAppendControlFrames() expectAppendStreamFrames() - p, err := packer.PackPacket() + p, err := packer.PackAppDataPacket() Expect(p).ToNot(BeNil()) Expect(err).ToNot(HaveOccurred()) Expect(p.ack).ToNot(BeNil()) @@ -492,7 +488,7 @@ var _ = Describe("Packet packer", func() { ackFramer.EXPECT().GetAckFrame(protocol.Encryption1RTT).Return(&wire.AckFrame{AckRanges: []wire.AckRange{{Smallest: 1, Largest: 1}}}) expectAppendControlFrames() expectAppendStreamFrames() - p, err := packer.PackPacket() + p, err := packer.PackAppDataPacket() Expect(p).ToNot(BeNil()) Expect(err).ToNot(HaveOccurred()) Expect(p.frames).To(ContainElement(ackhandler.Frame{Frame: &wire.PingFrame{}})) @@ -503,7 +499,7 @@ var _ = Describe("Packet packer", func() { ackFramer.EXPECT().GetAckFrame(protocol.Encryption1RTT).Return(&wire.AckFrame{AckRanges: []wire.AckRange{{Smallest: 1, Largest: 1}}}) expectAppendControlFrames() expectAppendStreamFrames() - p, err = packer.PackPacket() + p, err = packer.PackAppDataPacket() Expect(p).ToNot(BeNil()) Expect(err).ToNot(HaveOccurred()) Expect(p.ack).ToNot(BeNil()) @@ -518,7 +514,7 @@ var _ = Describe("Packet packer", func() { expectAppendControlFrames() expectAppendStreamFrames() ackFramer.EXPECT().GetAckFrame(protocol.Encryption1RTT) - p, err := packer.PackPacket() + p, err := packer.PackAppDataPacket() Expect(err).ToNot(HaveOccurred()) Expect(p).To(BeNil()) // now add some frame to send @@ -529,7 +525,7 @@ var _ = Describe("Packet packer", func() { sealingManager.EXPECT().Get1RTTSealer().Return(sealer, nil) ack := &wire.AckFrame{AckRanges: []wire.AckRange{{Smallest: 1, Largest: 1}}} ackFramer.EXPECT().GetAckFrame(protocol.Encryption1RTT).Return(ack) - p, err = packer.PackPacket() + p, err = packer.PackAppDataPacket() Expect(err).ToNot(HaveOccurred()) Expect(p.ack).To(Equal(ack)) Expect(p.frames).To(Equal([]ackhandler.Frame{{Frame: &wire.PingFrame{}}})) @@ -543,7 +539,7 @@ var _ = Describe("Packet packer", func() { ackFramer.EXPECT().GetAckFrame(protocol.Encryption1RTT) expectAppendStreamFrames() expectAppendControlFrames(ackhandler.Frame{Frame: &wire.MaxDataFrame{}}) - p, err := packer.PackPacket() + p, err := packer.PackAppDataPacket() Expect(err).ToNot(HaveOccurred()) Expect(p).ToNot(BeNil()) Expect(p.frames).ToNot(ContainElement(&wire.PingFrame{})) @@ -561,7 +557,7 @@ var _ = Describe("Packet packer", func() { return nil, 0 }) expectAppendStreamFrames() - _, err := packer.PackPacket() + _, err := packer.PackAppDataPacket() Expect(err).ToNot(HaveOccurred()) // now reduce the maxPacketSize packer.HandleTransportParameters(&handshake.TransportParameters{ @@ -572,7 +568,7 @@ var _ = Describe("Packet packer", func() { return nil, 0 }) expectAppendStreamFrames() - _, err = packer.PackPacket() + _, err = packer.PackAppDataPacket() Expect(err).ToNot(HaveOccurred()) }) @@ -586,7 +582,7 @@ var _ = Describe("Packet packer", func() { return nil, 0 }) expectAppendStreamFrames() - _, err := packer.PackPacket() + _, err := packer.PackAppDataPacket() Expect(err).ToNot(HaveOccurred()) // now try to increase the maxPacketSize packer.HandleTransportParameters(&handshake.TransportParameters{ @@ -597,7 +593,7 @@ var _ = Describe("Packet packer", func() { return nil, 0 }) expectAppendStreamFrames() - _, err = packer.PackPacket() + _, err = packer.PackAppDataPacket() Expect(err).ToNot(HaveOccurred()) }) }) @@ -737,31 +733,6 @@ var _ = Describe("Packet packer", func() { Expect(packet.ack).To(Equal(ack)) Expect(packet.frames).To(HaveLen(1)) }) - - It("stops packing crypto packets when the keys are dropped", func() { - sealingManager.EXPECT().GetInitialSealer().Return(nil, handshake.ErrKeysDropped) - sealingManager.EXPECT().GetHandshakeSealer().Return(nil, handshake.ErrKeysDropped) - sealingManager.EXPECT().Get1RTTSealer().Return(sealer, nil) - pnManager.EXPECT().PeekPacketNumber(protocol.Encryption1RTT).Return(protocol.PacketNumber(0x42), protocol.PacketNumberLen2) - pnManager.EXPECT().PopPacketNumber(protocol.Encryption1RTT).Return(protocol.PacketNumber(0x42)) - ackFramer.EXPECT().GetAckFrame(protocol.Encryption1RTT) - expectAppendControlFrames(ackhandler.Frame{Frame: &wire.PingFrame{}}) - expectAppendStreamFrames() - packet, err := packer.PackPacket() - Expect(err).ToNot(HaveOccurred()) - Expect(packet).ToNot(BeNil()) - - // now the packer should have realized that the handshake is confirmed - sealingManager.EXPECT().Get1RTTSealer().Return(sealer, nil) - pnManager.EXPECT().PeekPacketNumber(protocol.Encryption1RTT).Return(protocol.PacketNumber(0x43), protocol.PacketNumberLen2) - pnManager.EXPECT().PopPacketNumber(protocol.Encryption1RTT).Return(protocol.PacketNumber(0x43)) - ackFramer.EXPECT().GetAckFrame(protocol.Encryption1RTT) - expectAppendControlFrames(ackhandler.Frame{Frame: &wire.PingFrame{}}) - expectAppendStreamFrames() - packet, err = packer.PackPacket() - Expect(err).ToNot(HaveOccurred()) - Expect(packet).ToNot(BeNil()) - }) }) Context("packing probe packets", func() { diff --git a/session.go b/session.go index 1b065c30..e1b51152 100644 --- a/session.go +++ b/session.go @@ -164,6 +164,7 @@ type session struct { earlySessionReadyChan chan struct{} handshakeCompleteChan chan struct{} // is closed when the handshake completes handshakeComplete bool + handshakeConfirmed bool receivedRetry bool receivedFirstPacket bool @@ -1139,6 +1140,9 @@ func (s *session) handleCloseError(closeErr closeError) { } func (s *session) dropEncryptionLevel(encLevel protocol.EncryptionLevel) { + if encLevel == protocol.EncryptionHandshake { + s.handshakeConfirmed = true + } s.sentPacketHandler.DropPackets(encLevel) s.receivedPacketHandler.DropPackets(encLevel) } @@ -1247,7 +1251,7 @@ sendLoop: } func (s *session) maybeSendAckOnlyPacket() error { - packet, err := s.packer.MaybePackAckPacket() + packet, err := s.packer.MaybePackAckPacket(s.handshakeConfirmed) if err != nil { return err } @@ -1305,7 +1309,13 @@ func (s *session) sendPacket() (bool, error) { } s.windowUpdateQueue.QueueAll() - packet, err := s.packer.PackPacket() + var packet *packedPacket + var err error + if !s.handshakeConfirmed { + packet, err = s.packer.PackPacket() + } else { + packet, err = s.packer.PackAppDataPacket() + } if err != nil || packet == nil { return false, err } diff --git a/session_test.go b/session_test.go index 6e744ded..64aebdc0 100644 --- a/session_test.go +++ b/session_test.go @@ -905,7 +905,7 @@ var _ = Describe("Session", func() { sph.EXPECT().GetLossDetectionTimeout().AnyTimes() sph.EXPECT().SendMode().Return(ackhandler.SendAck) sph.EXPECT().ShouldSendNumPackets().Return(1000) - packer.EXPECT().MaybePackAckPacket() + packer.EXPECT().MaybePackAckPacket(false) sess.sentPacketHandler = sph Expect(sess.sendPackets()).To(Succeed()) })