diff --git a/connection.go b/connection.go index d161b183..ddc60183 100644 --- a/connection.go +++ b/connection.go @@ -1799,7 +1799,8 @@ func (s *connection) sendPackets() error { return err } s.logShortHeaderPacket(p.DestConnID, p.Ack, p.Frames, p.StreamFrames, p.PacketNumber, p.PacketNumberLen, p.KeyPhase, buffer.Len(), false) - s.sendPackedShortHeaderPacket(buffer, p.Packet, now) + s.registerPackedShortHeaderPacket(p.Packet, now) + s.sendQueue.Send(buffer) // This is kind of a hack. We need to trigger sending again somehow. s.pacingDeadline = deadlineSendImmediately return nil @@ -1827,7 +1828,8 @@ func (s *connection) sendPackets() error { } for { - sent, err := s.sendPacket(now) + buf := getPacketBuffer() + sent, err := s.appendPacket(buf, now) if err != nil || !sent { return err } @@ -1881,7 +1883,8 @@ func (s *connection) maybeSendAckOnlyPacket() error { return err } s.logShortHeaderPacket(p.DestConnID, p.Ack, p.Frames, p.StreamFrames, p.PacketNumber, p.PacketNumberLen, p.KeyPhase, buffer.Len(), false) - s.sendPackedShortHeaderPacket(buffer, p.Packet, now) + s.registerPackedShortHeaderPacket(p.Packet, now) + s.sendQueue.Send(buffer) return nil } @@ -1927,27 +1930,27 @@ func (s *connection) sendProbePacket(encLevel protocol.EncryptionLevel) error { return nil } -func (s *connection) sendPacket(now time.Time) (bool, error) { - p, buffer, err := s.packer.PackPacket(s.mtuDiscoverer.CurrentSize(), s.version) +func (s *connection) appendPacket(buf *packetBuffer, now time.Time) (bool, error) { + p, err := s.packer.AppendPacket(buf, s.mtuDiscoverer.CurrentSize(), s.version) if err != nil { if err == errNothingToPack { return false, nil } return false, err } - s.logShortHeaderPacket(p.DestConnID, p.Ack, p.Frames, p.StreamFrames, p.PacketNumber, p.PacketNumberLen, p.KeyPhase, buffer.Len(), false) - s.sendPackedShortHeaderPacket(buffer, p.Packet, now) + s.logShortHeaderPacket(p.DestConnID, p.Ack, p.Frames, p.StreamFrames, p.PacketNumber, p.PacketNumberLen, p.KeyPhase, buf.Len(), false) + s.registerPackedShortHeaderPacket(p.Packet, now) + s.sendQueue.Send(buf) return true, nil } -func (s *connection) sendPackedShortHeaderPacket(buffer *packetBuffer, p *ackhandler.Packet, now time.Time) { +func (s *connection) registerPackedShortHeaderPacket(p *ackhandler.Packet, now time.Time) { if s.firstAckElicitingPacketAfterIdleSentTime.IsZero() && (len(p.StreamFrames) > 0 || ackhandler.HasAckElicitingFrames(p.Frames)) { s.firstAckElicitingPacketAfterIdleSentTime = now } s.sentPacketHandler.SentPacket(p) s.connIDManager.SentPacket() - s.sendQueue.Send(buffer) } func (s *connection) sendPackedCoalescedPacket(packet *coalescedPacket, now time.Time) { diff --git a/connection_test.go b/connection_test.go index 454afe40..1b7a7bbe 100644 --- a/connection_test.go +++ b/connection_test.go @@ -53,10 +53,10 @@ var _ = Describe("Connection", func() { destConnID := protocol.ParseConnectionID([]byte{8, 7, 6, 5, 4, 3, 2, 1}) clientDestConnID := protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}) - getShortHeaderPacket := func(pn protocol.PacketNumber) (shortHeaderPacket, *packetBuffer) { + getShortHeaderPacket := func(pn protocol.PacketNumber) shortHeaderPacket { buffer := getPacketBuffer() buffer.Data = append(buffer.Data, []byte("foobar")...) - return shortHeaderPacket{Packet: &ackhandler.Packet{PacketNumber: pn}}, buffer + return shortHeaderPacket{Packet: &ackhandler.Packet{PacketNumber: pn}} } getCoalescedPacket := func(pn protocol.PacketNumber, isLongHeader bool) *coalescedPacket { @@ -613,9 +613,9 @@ var _ = Describe("Connection", func() { connRunner.EXPECT().Remove(gomock.Any()).AnyTimes() cryptoSetup.EXPECT().Close() conn.sentPacketHandler = sph - p, buffer := getShortHeaderPacket(1) - packer.EXPECT().PackPacket(gomock.Any(), conn.version).Return(p, buffer, nil) - packer.EXPECT().PackPacket(gomock.Any(), conn.version).Return(shortHeaderPacket{}, nil, errNothingToPack).AnyTimes() + p := getShortHeaderPacket(1) + packer.EXPECT().AppendPacket(gomock.Any(), gomock.Any(), conn.version).Return(p, nil) + packer.EXPECT().AppendPacket(gomock.Any(), gomock.Any(), conn.version).Return(shortHeaderPacket{}, errNothingToPack).AnyTimes() runConn() conn.queueControlFrame(&wire.PingFrame{}) conn.scheduleSending() @@ -1208,9 +1208,9 @@ var _ = Describe("Connection", func() { sph.EXPECT().SentPacket(gomock.Any()) conn.sentPacketHandler = sph runConn() - p, buffer := getShortHeaderPacket(1) - packer.EXPECT().PackPacket(gomock.Any(), conn.version).Return(p, buffer, nil) - packer.EXPECT().PackPacket(gomock.Any(), conn.version).Return(shortHeaderPacket{}, nil, errNothingToPack).AnyTimes() + p := getShortHeaderPacket(1) + packer.EXPECT().AppendPacket(gomock.Any(), gomock.Any(), conn.version).Return(p, nil) + packer.EXPECT().AppendPacket(gomock.Any(), gomock.Any(), conn.version).Return(shortHeaderPacket{}, errNothingToPack).AnyTimes() sent := make(chan struct{}) sender.EXPECT().WouldBlock().AnyTimes() sender.EXPECT().Send(gomock.Any()).Do(func(packet *packetBuffer) { close(sent) }) @@ -1219,7 +1219,7 @@ var _ = Describe("Connection", func() { PacketNumber: p.PacketNumber, PacketNumberLen: p.PacketNumberLen, KeyPhase: p.KeyPhase, - }, buffer.Len(), nil, []logging.Frame{}) + }, gomock.Any(), nil, []logging.Frame{}) conn.scheduleSending() Eventually(sent).Should(BeClosed()) }) @@ -1227,7 +1227,7 @@ var _ = Describe("Connection", func() { It("doesn't send packets if there's nothing to send", func() { conn.handshakeConfirmed = true runConn() - packer.EXPECT().PackPacket(gomock.Any(), conn.version).Return(shortHeaderPacket{}, nil, errNothingToPack).AnyTimes() + packer.EXPECT().AppendPacket(gomock.Any(), gomock.Any(), conn.version).Return(shortHeaderPacket{}, errNothingToPack).AnyTimes() conn.receivedPacketHandler.ReceivedPacket(0x035e, protocol.ECNNon, protocol.Encryption1RTT, time.Now(), true) conn.scheduleSending() time.Sleep(50 * time.Millisecond) // make sure there are no calls to mconn.Write() @@ -1256,14 +1256,14 @@ var _ = Describe("Connection", func() { conn.sentPacketHandler = sph fc := mocks.NewMockConnectionFlowController(mockCtrl) fc.EXPECT().IsNewlyBlocked().Return(true, protocol.ByteCount(1337)) - p, buffer := getShortHeaderPacket(1) - packer.EXPECT().PackPacket(gomock.Any(), conn.version).Return(p, buffer, nil) - packer.EXPECT().PackPacket(gomock.Any(), conn.version).Return(shortHeaderPacket{}, nil, errNothingToPack).AnyTimes() + p := getShortHeaderPacket(1) + packer.EXPECT().AppendPacket(gomock.Any(), gomock.Any(), conn.version).Return(p, nil) + packer.EXPECT().AppendPacket(gomock.Any(), gomock.Any(), conn.version).Return(shortHeaderPacket{}, errNothingToPack).AnyTimes() conn.connFlowController = fc runConn() sent := make(chan struct{}) sender.EXPECT().Send(gomock.Any()).Do(func(packet *packetBuffer) { close(sent) }) - tracer.EXPECT().SentShortHeaderPacket(gomock.Any(), buffer.Len(), nil, []logging.Frame{}) + tracer.EXPECT().SentShortHeaderPacket(gomock.Any(), gomock.Any(), nil, []logging.Frame{}) conn.scheduleSending() Eventually(sent).Should(BeClosed()) frames, _ := conn.framer.AppendControlFrames(nil, 1000, protocol.Version1) @@ -1396,10 +1396,10 @@ var _ = Describe("Connection", func() { sph.EXPECT().SendMode().Return(ackhandler.SendAny).Times(2) sph.EXPECT().SendMode().Return(ackhandler.SendPacingLimited) sph.EXPECT().TimeUntilSend().Return(time.Now().Add(time.Hour)) - p, buffer := getShortHeaderPacket(10) - packer.EXPECT().PackPacket(gomock.Any(), conn.version).Return(p, buffer, nil) - p, buffer = getShortHeaderPacket(11) - packer.EXPECT().PackPacket(gomock.Any(), conn.version).Return(p, buffer, nil) + p := getShortHeaderPacket(10) + packer.EXPECT().AppendPacket(gomock.Any(), gomock.Any(), conn.version).Return(p, nil) + p = getShortHeaderPacket(11) + packer.EXPECT().AppendPacket(gomock.Any(), gomock.Any(), conn.version).Return(p, nil) sender.EXPECT().WouldBlock().AnyTimes() sender.EXPECT().Send(gomock.Any()).Times(2) go func() { @@ -1414,9 +1414,9 @@ var _ = Describe("Connection", func() { It("sends multiple packets, when the pacer allows immediate sending", func() { sph.EXPECT().SentPacket(gomock.Any()) sph.EXPECT().SendMode().Return(ackhandler.SendAny).Times(2) - p, buffer := getShortHeaderPacket(10) - packer.EXPECT().PackPacket(gomock.Any(), conn.version).Return(p, buffer, nil) - packer.EXPECT().PackPacket(gomock.Any(), conn.version).Return(shortHeaderPacket{}, nil, errNothingToPack) + p := getShortHeaderPacket(10) + packer.EXPECT().AppendPacket(gomock.Any(), gomock.Any(), conn.version).Return(p, nil) + packer.EXPECT().AppendPacket(gomock.Any(), gomock.Any(), conn.version).Return(shortHeaderPacket{}, errNothingToPack) sender.EXPECT().WouldBlock().AnyTimes() sender.EXPECT().Send(gomock.Any()) go func() { @@ -1432,8 +1432,8 @@ var _ = Describe("Connection", func() { sph.EXPECT().SentPacket(gomock.Any()) sph.EXPECT().TimeUntilSend().Return(time.Now().Add(time.Hour)) sph.EXPECT().SendMode().Return(ackhandler.SendPacingLimited) - p, buffer := getShortHeaderPacket(10) - packer.EXPECT().PackAckOnlyPacket(gomock.Any(), conn.version).Return(p, buffer, nil) + p := getShortHeaderPacket(10) + packer.EXPECT().PackAckOnlyPacket(gomock.Any(), conn.version).Return(p, getPacketBuffer(), nil) sender.EXPECT().WouldBlock().AnyTimes() sender.EXPECT().Send(gomock.Any()) @@ -1452,8 +1452,8 @@ var _ = Describe("Connection", func() { sph.EXPECT().SentPacket(gomock.Any()) sph.EXPECT().SendMode().Return(ackhandler.SendAny) sph.EXPECT().SendMode().Return(ackhandler.SendAck) - p, buffer := getShortHeaderPacket(100) - packer.EXPECT().PackPacket(gomock.Any(), conn.version).Return(p, buffer, nil) + p := getShortHeaderPacket(100) + packer.EXPECT().AppendPacket(gomock.Any(), gomock.Any(), conn.version).Return(p, nil) sender.EXPECT().WouldBlock().AnyTimes() sender.EXPECT().Send(gomock.Any()) go func() { @@ -1467,16 +1467,16 @@ var _ = Describe("Connection", func() { It("paces packets", func() { pacingDelay := scaleDuration(100 * time.Millisecond) - p1, buffer1 := getShortHeaderPacket(100) - p2, buffer2 := getShortHeaderPacket(101) + p1 := getShortHeaderPacket(100) + p2 := getShortHeaderPacket(101) gomock.InOrder( sph.EXPECT().SendMode().Return(ackhandler.SendAny), - packer.EXPECT().PackPacket(gomock.Any(), conn.version).Return(p1, buffer1, nil), + packer.EXPECT().AppendPacket(gomock.Any(), gomock.Any(), conn.version).Return(p1, nil), sph.EXPECT().SentPacket(gomock.Any()), sph.EXPECT().SendMode().Return(ackhandler.SendPacingLimited), sph.EXPECT().TimeUntilSend().Return(time.Now().Add(pacingDelay)), sph.EXPECT().SendMode().Return(ackhandler.SendAny), - packer.EXPECT().PackPacket(gomock.Any(), conn.version).Return(p2, buffer2, nil), + packer.EXPECT().AppendPacket(gomock.Any(), gomock.Any(), conn.version).Return(p2, nil), sph.EXPECT().SentPacket(gomock.Any()), sph.EXPECT().SendMode().Return(ackhandler.SendPacingLimited), sph.EXPECT().TimeUntilSend().Return(time.Now().Add(time.Hour)), @@ -1501,8 +1501,8 @@ var _ = Describe("Connection", func() { sph.EXPECT().SendMode().Return(ackhandler.SendPacingLimited) sph.EXPECT().TimeUntilSend().Return(time.Now().Add(time.Hour)) for pn := protocol.PacketNumber(1000); pn < 1003; pn++ { - p, buffer := getShortHeaderPacket(pn) - packer.EXPECT().PackPacket(gomock.Any(), conn.version).Return(p, buffer, nil) + p := getShortHeaderPacket(pn) + packer.EXPECT().AppendPacket(gomock.Any(), gomock.Any(), conn.version).Return(p, nil) } written := make(chan struct{}, 3) sender.EXPECT().WouldBlock().AnyTimes() @@ -1532,9 +1532,9 @@ var _ = Describe("Connection", func() { sender.EXPECT().WouldBlock().AnyTimes() sph.EXPECT().SentPacket(gomock.Any()) sph.EXPECT().SendMode().Return(ackhandler.SendAny).AnyTimes() - p, buffer := getShortHeaderPacket(1000) - packer.EXPECT().PackPacket(gomock.Any(), conn.version).Return(p, buffer, nil) - packer.EXPECT().PackPacket(gomock.Any(), conn.version).Return(shortHeaderPacket{}, nil, errNothingToPack) + p := getShortHeaderPacket(1000) + packer.EXPECT().AppendPacket(gomock.Any(), gomock.Any(), conn.version).Return(p, nil) + packer.EXPECT().AppendPacket(gomock.Any(), gomock.Any(), conn.version).Return(shortHeaderPacket{}, errNothingToPack) sender.EXPECT().Send(gomock.Any()).DoAndReturn(func(p *packetBuffer) { close(written) }) available <- struct{}{} Eventually(written).Should(BeClosed()) @@ -1555,9 +1555,9 @@ var _ = Describe("Connection", func() { conn.handlePacket(&receivedPacket{buffer: getPacketBuffer()}) }) sph.EXPECT().SendMode().Return(ackhandler.SendAny).AnyTimes() - p, buffer := getShortHeaderPacket(1000) - packer.EXPECT().PackPacket(gomock.Any(), conn.version).Return(p, buffer, nil) - packer.EXPECT().PackPacket(gomock.Any(), conn.version).Return(shortHeaderPacket{}, nil, errNothingToPack) + p := getShortHeaderPacket(1000) + packer.EXPECT().AppendPacket(gomock.Any(), gomock.Any(), conn.version).Return(p, nil) + packer.EXPECT().AppendPacket(gomock.Any(), gomock.Any(), conn.version).Return(shortHeaderPacket{}, errNothingToPack) sender.EXPECT().Send(gomock.Any()).DoAndReturn(func(p *packetBuffer) { close(written) }) conn.scheduleSending() @@ -1569,8 +1569,8 @@ var _ = Describe("Connection", func() { It("stops sending when the send queue is full", func() { sph.EXPECT().SentPacket(gomock.Any()) sph.EXPECT().SendMode().Return(ackhandler.SendAny) - p, buffer := getShortHeaderPacket(1000) - packer.EXPECT().PackPacket(gomock.Any(), conn.version).Return(p, buffer, nil) + p := getShortHeaderPacket(1000) + packer.EXPECT().AppendPacket(gomock.Any(), gomock.Any(), conn.version).Return(p, nil) written := make(chan struct{}, 1) sender.EXPECT().WouldBlock() sender.EXPECT().WouldBlock().Return(true).Times(2) @@ -1590,9 +1590,9 @@ var _ = Describe("Connection", func() { sph.EXPECT().SentPacket(gomock.Any()) sph.EXPECT().SendMode().Return(ackhandler.SendAny).AnyTimes() sender.EXPECT().WouldBlock().AnyTimes() - p, buffer = getShortHeaderPacket(1001) - packer.EXPECT().PackPacket(gomock.Any(), conn.version).Return(p, buffer, nil) - packer.EXPECT().PackPacket(gomock.Any(), conn.version).Return(shortHeaderPacket{}, nil, errNothingToPack) + p = getShortHeaderPacket(1001) + packer.EXPECT().AppendPacket(gomock.Any(), gomock.Any(), conn.version).Return(p, nil) + packer.EXPECT().AppendPacket(gomock.Any(), gomock.Any(), conn.version).Return(shortHeaderPacket{}, errNothingToPack) sender.EXPECT().Send(gomock.Any()).DoAndReturn(func(p *packetBuffer) { written <- struct{}{} }) available <- struct{}{} Eventually(written).Should(Receive()) @@ -1605,7 +1605,7 @@ var _ = Describe("Connection", func() { It("doesn't set a pacing timer when there is no data to send", func() { sph.EXPECT().SendMode().Return(ackhandler.SendAny).AnyTimes() sender.EXPECT().WouldBlock().AnyTimes() - packer.EXPECT().PackPacket(gomock.Any(), conn.version).Return(shortHeaderPacket{}, nil, errNothingToPack) + packer.EXPECT().AppendPacket(gomock.Any(), gomock.Any(), conn.version).Return(shortHeaderPacket{}, errNothingToPack) // don't EXPECT any calls to mconn.Write() go func() { defer GinkgoRecover() @@ -1629,8 +1629,8 @@ var _ = Describe("Connection", func() { mtuDiscoverer.EXPECT().ShouldSendProbe(gomock.Any()).Return(true) ping := ackhandler.Frame{Frame: &wire.PingFrame{}} mtuDiscoverer.EXPECT().GetPing().Return(ping, protocol.ByteCount(1234)) - p, buffer := getShortHeaderPacket(1) - packer.EXPECT().PackMTUProbePacket(ping, protocol.ByteCount(1234), gomock.Any(), conn.version).Return(p, buffer, nil) + p := getShortHeaderPacket(1) + packer.EXPECT().PackMTUProbePacket(ping, protocol.ByteCount(1234), gomock.Any(), conn.version).Return(p, getPacketBuffer(), nil) go func() { defer GinkgoRecover() cryptoSetup.EXPECT().RunHandshake().MaxTimes(1) @@ -1675,9 +1675,9 @@ var _ = Describe("Connection", func() { sph.EXPECT().SentPacket(gomock.Any()) conn.sentPacketHandler = sph - p, buffer := getShortHeaderPacket(1) - packer.EXPECT().PackPacket(gomock.Any(), conn.version).Return(p, buffer, nil) - packer.EXPECT().PackPacket(gomock.Any(), conn.version).Return(shortHeaderPacket{}, nil, errNothingToPack) + p := getShortHeaderPacket(1) + packer.EXPECT().AppendPacket(gomock.Any(), gomock.Any(), conn.version).Return(p, nil) + packer.EXPECT().AppendPacket(gomock.Any(), gomock.Any(), conn.version).Return(shortHeaderPacket{}, errNothingToPack) go func() { defer GinkgoRecover() @@ -1695,9 +1695,9 @@ var _ = Describe("Connection", func() { }) It("sets the timer to the ack timer", func() { - p, buffer := getShortHeaderPacket(1234) - packer.EXPECT().PackPacket(gomock.Any(), conn.version).Return(p, buffer, nil) - packer.EXPECT().PackPacket(gomock.Any(), conn.version).Return(shortHeaderPacket{}, nil, errNothingToPack) + p := getShortHeaderPacket(1234) + packer.EXPECT().AppendPacket(gomock.Any(), gomock.Any(), conn.version).Return(p, nil) + packer.EXPECT().AppendPacket(gomock.Any(), gomock.Any(), conn.version).Return(shortHeaderPacket{}, errNothingToPack) sph := mockackhandler.NewMockSentPacketHandler(mockCtrl) sph.EXPECT().GetLossDetectionTimeout().AnyTimes() sph.EXPECT().SendMode().Return(ackhandler.SendAny).AnyTimes() @@ -1912,14 +1912,14 @@ var _ = Describe("Connection", func() { conn.sentPacketHandler = sph done := make(chan struct{}) connRunner.EXPECT().Retire(clientDestConnID) - packer.EXPECT().PackPacket(gomock.Any(), conn.version).DoAndReturn(func(_ protocol.ByteCount, v protocol.VersionNumber) (shortHeaderPacket, *packetBuffer, error) { + packer.EXPECT().AppendPacket(gomock.Any(), gomock.Any(), conn.version).DoAndReturn(func(_ *packetBuffer, _ protocol.ByteCount, v protocol.VersionNumber) (shortHeaderPacket, error) { frames, _ := conn.framer.AppendControlFrames(nil, protocol.MaxByteCount, v) Expect(frames).ToNot(BeEmpty()) Expect(frames[0].Frame).To(BeEquivalentTo(&wire.HandshakeDoneFrame{})) defer close(done) - return shortHeaderPacket{Packet: &ackhandler.Packet{}}, getPacketBuffer(), nil + return shortHeaderPacket{Packet: &ackhandler.Packet{}}, nil }) - packer.EXPECT().PackPacket(gomock.Any(), conn.version).Return(shortHeaderPacket{}, nil, errNothingToPack).AnyTimes() + packer.EXPECT().AppendPacket(gomock.Any(), gomock.Any(), conn.version).Return(shortHeaderPacket{}, errNothingToPack).AnyTimes() go func() { defer GinkgoRecover() cryptoSetup.EXPECT().RunHandshake() diff --git a/mock_packer_test.go b/mock_packer_test.go index 6c2e03ed..f539d255 100644 --- a/mock_packer_test.go +++ b/mock_packer_test.go @@ -37,6 +37,21 @@ func (m *MockPacker) EXPECT() *MockPackerMockRecorder { return m.recorder } +// AppendPacket mocks base method. +func (m *MockPacker) AppendPacket(arg0 *packetBuffer, arg1 protocol.ByteCount, arg2 protocol.VersionNumber) (shortHeaderPacket, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "AppendPacket", arg0, arg1, arg2) + ret0, _ := ret[0].(shortHeaderPacket) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// AppendPacket indicates an expected call of AppendPacket. +func (mr *MockPackerMockRecorder) AppendPacket(arg0, arg1, arg2 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AppendPacket", reflect.TypeOf((*MockPacker)(nil).AppendPacket), arg0, arg1, arg2) +} + // MaybePackProbePacket mocks base method. func (m *MockPacker) MaybePackProbePacket(arg0 protocol.EncryptionLevel, arg1 protocol.ByteCount, arg2 protocol.VersionNumber) (*coalescedPacket, error) { m.ctrl.T.Helper() @@ -129,22 +144,6 @@ func (mr *MockPackerMockRecorder) PackMTUProbePacket(arg0, arg1, arg2, arg3 inte return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "PackMTUProbePacket", reflect.TypeOf((*MockPacker)(nil).PackMTUProbePacket), arg0, arg1, arg2, arg3) } -// PackPacket mocks base method. -func (m *MockPacker) PackPacket(arg0 protocol.ByteCount, arg1 protocol.VersionNumber) (shortHeaderPacket, *packetBuffer, error) { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "PackPacket", arg0, arg1) - ret0, _ := ret[0].(shortHeaderPacket) - ret1, _ := ret[1].(*packetBuffer) - ret2, _ := ret[2].(error) - return ret0, ret1, ret2 -} - -// PackPacket indicates an expected call of PackPacket. -func (mr *MockPackerMockRecorder) PackPacket(arg0, arg1 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "PackPacket", reflect.TypeOf((*MockPacker)(nil).PackPacket), arg0, arg1) -} - // SetToken mocks base method. func (m *MockPacker) SetToken(arg0 []byte) { m.ctrl.T.Helper() diff --git a/packet_packer.go b/packet_packer.go index e6d8df6b..6f6e6d9c 100644 --- a/packet_packer.go +++ b/packet_packer.go @@ -17,7 +17,7 @@ var errNothingToPack = errors.New("nothing to pack") type packer interface { PackCoalescedPacket(onlyAck bool, maxPacketSize protocol.ByteCount, v protocol.VersionNumber) (*coalescedPacket, error) PackAckOnlyPacket(maxPacketSize protocol.ByteCount, v protocol.VersionNumber) (shortHeaderPacket, *packetBuffer, error) - PackPacket(maxPacketSize protocol.ByteCount, v protocol.VersionNumber) (shortHeaderPacket, *packetBuffer, error) + AppendPacket(buf *packetBuffer, maxPacketSize protocol.ByteCount, v protocol.VersionNumber) (shortHeaderPacket, error) MaybePackProbePacket(protocol.EncryptionLevel, protocol.ByteCount, protocol.VersionNumber) (*coalescedPacket, error) PackConnectionClose(*qerr.TransportError, protocol.ByteCount, protocol.VersionNumber) (*coalescedPacket, error) PackApplicationClose(*qerr.ApplicationError, protocol.ByteCount, protocol.VersionNumber) (*coalescedPacket, error) @@ -472,32 +472,34 @@ func (p *packetPacker) PackCoalescedPacket(onlyAck bool, maxPacketSize protocol. // PackAckOnlyPacket packs a packet containing only an ACK in the application data packet number space. // It should be called after the handshake is confirmed. func (p *packetPacker) PackAckOnlyPacket(maxPacketSize protocol.ByteCount, v protocol.VersionNumber) (shortHeaderPacket, *packetBuffer, error) { - return p.packPacket(true, maxPacketSize, v) + buf := getPacketBuffer() + packet, err := p.appendPacket(buf, true, maxPacketSize, v) + return packet, buf, err } -// PackPacket packs a packet in the application data packet number space. +// AppendPacket packs a packet in the application data packet number space. // It should be called after the handshake is confirmed. -func (p *packetPacker) PackPacket(maxPacketSize protocol.ByteCount, v protocol.VersionNumber) (shortHeaderPacket, *packetBuffer, error) { - return p.packPacket(false, maxPacketSize, v) +func (p *packetPacker) AppendPacket(buf *packetBuffer, maxPacketSize protocol.ByteCount, v protocol.VersionNumber) (shortHeaderPacket, error) { + return p.appendPacket(buf, false, maxPacketSize, v) } -func (p *packetPacker) packPacket(onlyAck bool, maxPacketSize protocol.ByteCount, v protocol.VersionNumber) (shortHeaderPacket, *packetBuffer, error) { +func (p *packetPacker) appendPacket(buf *packetBuffer, onlyAck bool, maxPacketSize protocol.ByteCount, v protocol.VersionNumber) (shortHeaderPacket, error) { sealer, err := p.cryptoSetup.Get1RTTSealer() if err != nil { - return shortHeaderPacket{}, nil, err + return shortHeaderPacket{}, err } pn, pnLen := p.pnManager.PeekPacketNumber(protocol.Encryption1RTT) connID := p.getDestConnID() hdrLen := wire.ShortHeaderLen(connID, pnLen) pl := p.maybeGetShortHeaderPacket(sealer, hdrLen, maxPacketSize, onlyAck, true, v) if pl.length == 0 { - return shortHeaderPacket{}, nil, errNothingToPack + return shortHeaderPacket{}, errNothingToPack } kp := sealer.KeyPhase() - buffer := getPacketBuffer() - ap, ack, err := p.appendShortHeaderPacket(buffer, connID, pn, pnLen, kp, pl, 0, maxPacketSize, sealer, false, v) + + ap, ack, err := p.appendShortHeaderPacket(buf, connID, pn, pnLen, kp, pl, 0, maxPacketSize, sealer, false, v) if err != nil { - return shortHeaderPacket{}, nil, err + return shortHeaderPacket{}, err } return shortHeaderPacket{ Packet: ap, @@ -505,7 +507,7 @@ func (p *packetPacker) packPacket(onlyAck bool, maxPacketSize protocol.ByteCount Ack: ack, PacketNumberLen: pnLen, KeyPhase: kp, - }, buffer, nil + }, nil } func (p *packetPacker) maybeGetCryptoPacket(maxPacketSize protocol.ByteCount, encLevel protocol.EncryptionLevel, onlyAck, ackAllowed bool, v protocol.VersionNumber) (*wire.ExtendedHeader, payload) { diff --git a/packet_packer_test.go b/packet_packer_test.go index 5db6d94f..8f09dc9b 100644 --- a/packet_packer_test.go +++ b/packet_packer_test.go @@ -493,11 +493,11 @@ var _ = Describe("Packet packer", func() { sealingManager.EXPECT().Get1RTTSealer().Return(getSealer(), nil) ackFramer.EXPECT().GetAckFrame(protocol.Encryption1RTT, true) framer.EXPECT().HasData() - _, _, err := packer.PackPacket(maxPacketSize, protocol.Version1) + _, err := packer.AppendPacket(getPacketBuffer(), maxPacketSize, protocol.Version1) Expect(err).To(MatchError(errNothingToPack)) }) - It("packs single packets", func() { + It("appends single packets", func() { pnManager.EXPECT().PeekPacketNumber(protocol.Encryption1RTT).Return(protocol.PacketNumber(0x42), protocol.PacketNumberLen2) pnManager.EXPECT().PopPacketNumber(protocol.Encryption1RTT).Return(protocol.PacketNumber(0x42)) sealingManager.EXPECT().Get1RTTSealer().Return(getSealer(), nil) @@ -509,7 +509,9 @@ var _ = Describe("Packet packer", func() { Data: []byte{0xde, 0xca, 0xfb, 0xad}, } expectAppendStreamFrames(ackhandler.StreamFrame{Frame: f}) - p, buffer, err := packer.PackPacket(maxPacketSize, protocol.Version1) + buffer := getPacketBuffer() + buffer.Data = append(buffer.Data, []byte("foobar")...) + p, err := packer.AppendPacket(buffer, maxPacketSize, protocol.Version1) Expect(err).ToNot(HaveOccurred()) Expect(p).ToNot(BeNil()) b, err := f.Append(nil, protocol.Version1) @@ -517,6 +519,7 @@ var _ = Describe("Packet packer", func() { Expect(p.Frames).To(BeEmpty()) Expect(p.StreamFrames).To(HaveLen(1)) Expect(p.StreamFrames[0].Frame.StreamID).To(Equal(f.StreamID)) + Expect(buffer.Data[:6]).To(Equal([]byte("foobar"))) // make sure the packet was actually appended Expect(buffer.Data).To(ContainSubstring(string(b))) }) @@ -527,7 +530,7 @@ var _ = Describe("Packet packer", func() { framer.EXPECT().HasData() ackFramer.EXPECT().GetAckFrame(protocol.Encryption1RTT, true).Return(ack) sealingManager.EXPECT().Get1RTTSealer().Return(getSealer(), nil) - p, _, err := packer.PackPacket(maxPacketSize, protocol.Version1) + p, err := packer.AppendPacket(getPacketBuffer(), maxPacketSize, protocol.Version1) Expect(err).NotTo(HaveOccurred()) Expect(p).ToNot(BeNil()) Expect(p.Ack).To(Equal(ack)) @@ -545,7 +548,8 @@ var _ = Describe("Packet packer", func() { } expectAppendControlFrames(frames...) expectAppendStreamFrames() - p, buffer, err := packer.PackPacket(maxPacketSize, protocol.Version1) + buffer := getPacketBuffer() + p, err := packer.AppendPacket(buffer, maxPacketSize, protocol.Version1) Expect(p).ToNot(BeNil()) Expect(err).ToNot(HaveOccurred()) Expect(p.Frames).To(HaveLen(2)) @@ -574,7 +578,8 @@ var _ = Describe("Packet packer", func() { time.Sleep(scaleDuration(20 * time.Millisecond)) framer.EXPECT().HasData() - p, buffer, err := packer.PackPacket(maxPacketSize, protocol.Version1) + buffer := getPacketBuffer() + p, err := packer.AppendPacket(buffer, maxPacketSize, protocol.Version1) Expect(p).ToNot(BeNil()) Expect(err).ToNot(HaveOccurred()) Expect(p.Frames).To(HaveLen(1)) @@ -602,7 +607,8 @@ var _ = Describe("Packet packer", func() { time.Sleep(scaleDuration(20 * time.Millisecond)) framer.EXPECT().HasData() - p, buffer, err := packer.PackPacket(maxPacketSize, protocol.Version1) + buffer := getPacketBuffer() + p, err := packer.AppendPacket(buffer, maxPacketSize, protocol.Version1) Expect(p).ToNot(BeNil()) Expect(err).ToNot(HaveOccurred()) Expect(p.Ack).ToNot(BeNil()) @@ -628,7 +634,7 @@ var _ = Describe("Packet packer", func() { return fs, 0 }), ) - _, _, err := packer.PackPacket(maxPacketSize, protocol.Version1) + _, err := packer.AppendPacket(getPacketBuffer(), maxPacketSize, protocol.Version1) Expect(err).To(MatchError(errNothingToPack)) }) @@ -685,7 +691,8 @@ var _ = Describe("Packet packer", func() { ackFramer.EXPECT().GetAckFrame(protocol.Encryption1RTT, false) expectAppendControlFrames() expectAppendStreamFrames(ackhandler.StreamFrame{Frame: f}) - _, buffer, err := packer.PackPacket(maxPacketSize, protocol.Version1) + buffer := getPacketBuffer() + _, err := packer.AppendPacket(buffer, maxPacketSize, protocol.Version1) Expect(err).ToNot(HaveOccurred()) // cut off the tag that the mock sealer added buffer.Data = buffer.Data[:buffer.Len()-protocol.ByteCount(sealer.Overhead())] @@ -734,7 +741,7 @@ var _ = Describe("Packet packer", func() { ackFramer.EXPECT().GetAckFrame(protocol.Encryption1RTT, false) expectAppendControlFrames() expectAppendStreamFrames(ackhandler.StreamFrame{Frame: f1}, ackhandler.StreamFrame{Frame: f2}, ackhandler.StreamFrame{Frame: f3}) - p, _, err := packer.PackPacket(maxPacketSize, protocol.Version1) + p, err := packer.AppendPacket(getPacketBuffer(), maxPacketSize, protocol.Version1) Expect(p).ToNot(BeNil()) Expect(err).ToNot(HaveOccurred()) Expect(p.Frames).To(BeEmpty()) @@ -754,7 +761,7 @@ var _ = Describe("Packet packer", func() { ackFramer.EXPECT().GetAckFrame(protocol.Encryption1RTT, false).Return(&wire.AckFrame{AckRanges: []wire.AckRange{{Smallest: 1, Largest: 1}}}) expectAppendControlFrames() expectAppendStreamFrames() - p, _, err := packer.PackPacket(maxPacketSize, protocol.Version1) + p, err := packer.AppendPacket(getPacketBuffer(), maxPacketSize, protocol.Version1) Expect(p).ToNot(BeNil()) Expect(err).ToNot(HaveOccurred()) Expect(p.Ack).ToNot(BeNil()) @@ -771,7 +778,7 @@ var _ = Describe("Packet packer", func() { ackFramer.EXPECT().GetAckFrame(protocol.Encryption1RTT, false).Return(&wire.AckFrame{AckRanges: []wire.AckRange{{Smallest: 1, Largest: 1}}}) expectAppendControlFrames() expectAppendStreamFrames() - p, _, err := packer.PackPacket(maxPacketSize, protocol.Version1) + p, err := packer.AppendPacket(getPacketBuffer(), maxPacketSize, protocol.Version1) Expect(p).ToNot(BeNil()) Expect(err).ToNot(HaveOccurred()) var hasPing bool @@ -790,7 +797,7 @@ var _ = Describe("Packet packer", func() { ackFramer.EXPECT().GetAckFrame(protocol.Encryption1RTT, false).Return(&wire.AckFrame{AckRanges: []wire.AckRange{{Smallest: 1, Largest: 1}}}) expectAppendControlFrames() expectAppendStreamFrames() - p, _, err = packer.PackPacket(maxPacketSize, protocol.Version1) + p, err = packer.AppendPacket(getPacketBuffer(), maxPacketSize, protocol.Version1) Expect(p).ToNot(BeNil()) Expect(err).ToNot(HaveOccurred()) Expect(p.Ack).ToNot(BeNil()) @@ -806,7 +813,7 @@ var _ = Describe("Packet packer", func() { expectAppendControlFrames() expectAppendStreamFrames() ackFramer.EXPECT().GetAckFrame(protocol.Encryption1RTT, false) - _, _, err := packer.PackPacket(maxPacketSize, protocol.Version1) + _, err := packer.AppendPacket(getPacketBuffer(), maxPacketSize, protocol.Version1) Expect(err).To(MatchError(errNothingToPack)) // now add some frame to send expectAppendControlFrames() @@ -817,7 +824,7 @@ var _ = Describe("Packet packer", func() { framer.EXPECT().HasData().Return(true) ack := &wire.AckFrame{AckRanges: []wire.AckRange{{Smallest: 1, Largest: 1}}} ackFramer.EXPECT().GetAckFrame(protocol.Encryption1RTT, false).Return(ack) - p, _, err := packer.PackPacket(maxPacketSize, protocol.Version1) + p, err := packer.AppendPacket(getPacketBuffer(), maxPacketSize, protocol.Version1) Expect(err).ToNot(HaveOccurred()) Expect(p.Ack).To(Equal(ack)) var hasPing bool @@ -839,7 +846,7 @@ var _ = Describe("Packet packer", func() { ackFramer.EXPECT().GetAckFrame(protocol.Encryption1RTT, false) expectAppendStreamFrames() expectAppendControlFrames(ackhandler.Frame{Frame: &wire.MaxDataFrame{}}) - p, _, err := packer.PackPacket(maxPacketSize, protocol.Version1) + p, err := packer.AppendPacket(getPacketBuffer(), maxPacketSize, protocol.Version1) Expect(err).ToNot(HaveOccurred()) Expect(p).ToNot(BeNil()) Expect(p.Frames).ToNot(ContainElement(&wire.PingFrame{}))