diff --git a/connection.go b/connection.go index 5177eff22..e98a13449 100644 --- a/connection.go +++ b/connection.go @@ -297,6 +297,7 @@ var newConnection = func( s.tracer, s.logger, ) + s.mtuDiscoverer = newMTUDiscoverer(s.rttStats, getMaxPacketSize(s.conn.RemoteAddr()), s.sentPacketHandler.SetMaxDatagramSize) initialStream := newCryptoStream() handshakeStream := newCryptoStream() params := &wire.TransportParameters{ @@ -353,7 +354,7 @@ var newConnection = func( s.version, ) s.cryptoStreamHandler = cs - s.packer = newPacketPacker(srcConnID, s.connIDManager.Get, initialStream, handshakeStream, s.sentPacketHandler, s.retransmissionQueue, s.RemoteAddr(), cs, s.framer, s.receivedPacketHandler, s.datagramQueue, s.perspective) + s.packer = newPacketPacker(srcConnID, s.connIDManager.Get, initialStream, handshakeStream, s.sentPacketHandler, s.retransmissionQueue, cs, s.framer, s.receivedPacketHandler, s.datagramQueue, s.perspective) s.unpacker = newPacketUnpacker(cs, s.srcConnIDLen) s.cryptoStreamManager = newCryptoStreamManager(cs, initialStream, handshakeStream, s.oneRTTStream) return s @@ -418,6 +419,7 @@ var newClientConnection = func( s.tracer, s.logger, ) + s.mtuDiscoverer = newMTUDiscoverer(s.rttStats, getMaxPacketSize(s.conn.RemoteAddr()), s.sentPacketHandler.SetMaxDatagramSize) initialStream := newCryptoStream() handshakeStream := newCryptoStream() params := &wire.TransportParameters{ @@ -471,7 +473,7 @@ var newClientConnection = func( s.cryptoStreamHandler = cs s.cryptoStreamManager = newCryptoStreamManager(cs, initialStream, handshakeStream, newCryptoStream()) s.unpacker = newPacketUnpacker(cs, s.srcConnIDLen) - s.packer = newPacketPacker(srcConnID, s.connIDManager.Get, initialStream, handshakeStream, s.sentPacketHandler, s.retransmissionQueue, s.RemoteAddr(), cs, s.framer, s.receivedPacketHandler, s.datagramQueue, s.perspective) + s.packer = newPacketPacker(srcConnID, s.connIDManager.Get, initialStream, handshakeStream, s.sentPacketHandler, s.retransmissionQueue, cs, s.framer, s.receivedPacketHandler, s.datagramQueue, s.perspective) if len(tlsConf.ServerName) > 0 { s.tokenStoreKey = tlsConf.ServerName } else { @@ -523,14 +525,6 @@ func (s *connection) preSetup() { s.windowUpdateQueue = newWindowUpdateQueue(s.streamsMap, s.connFlowController, s.framer.QueueControlFrame) s.datagramQueue = newDatagramQueue(s.scheduleSending, s.logger) - s.mtuDiscoverer = newMTUDiscoverer( - s.rttStats, - getMaxPacketSize(s.conn.RemoteAddr()), - func(size protocol.ByteCount) { - s.sentPacketHandler.SetMaxDatagramSize(size) - s.packer.SetMaxPacketSize(size) - }, - ) s.connState.Version = s.version } @@ -1725,7 +1719,6 @@ func (s *connection) applyTransportParameters() { s.idleTimeout = utils.MinNonZeroDuration(s.config.MaxIdleTimeout, params.MaxIdleTimeout) s.keepAliveInterval = utils.Min(s.config.KeepAlivePeriod, utils.Min(s.idleTimeout/2, protocol.MaxKeepAliveInterval)) s.streamsMap.UpdateLimits(params) - s.packer.HandleTransportParameters(params) s.frameParser.SetAckDelayExponent(params.AckDelayExponent) s.connFlowController.UpdateSendWindow(params.InitialMaxData) s.rttStats.SetMaxAckDelay(params.MaxAckDelay) @@ -1808,7 +1801,7 @@ func (s *connection) sendPackets() error { func (s *connection) maybeSendAckOnlyPacket() error { if !s.handshakeConfirmed { - packet, err := s.packer.PackCoalescedPacket(true, s.version) + packet, err := s.packer.PackCoalescedPacket(true, s.mtuDiscoverer.CurrentSize(), s.version) if err != nil { return err } @@ -1820,7 +1813,7 @@ func (s *connection) maybeSendAckOnlyPacket() error { } now := time.Now() - p, buffer, err := s.packer.PackPacket(true, now, s.version) + p, buffer, err := s.packer.PackPacket(true, now, s.mtuDiscoverer.CurrentSize(), s.version) if err != nil { if err == errNothingToPack { return nil @@ -1841,7 +1834,7 @@ func (s *connection) sendProbePacket(encLevel protocol.EncryptionLevel) error { break } var err error - packet, err = s.packer.MaybePackProbePacket(encLevel, s.version) + packet, err = s.packer.MaybePackProbePacket(encLevel, s.mtuDiscoverer.CurrentSize(), s.version) if err != nil { return err } @@ -1862,7 +1855,7 @@ func (s *connection) sendProbePacket(encLevel protocol.EncryptionLevel) error { panic("unexpected encryption level") } var err error - packet, err = s.packer.MaybePackProbePacket(encLevel, s.version) + packet, err = s.packer.MaybePackProbePacket(encLevel, s.mtuDiscoverer.CurrentSize(), s.version) if err != nil { return err } @@ -1882,7 +1875,7 @@ func (s *connection) sendPacket() (bool, error) { now := time.Now() if !s.handshakeConfirmed { - packet, err := s.packer.PackCoalescedPacket(false, s.version) + packet, err := s.packer.PackCoalescedPacket(false, s.mtuDiscoverer.CurrentSize(), s.version) if err != nil || packet == nil { return false, err } @@ -1899,7 +1892,7 @@ func (s *connection) sendPacket() (bool, error) { s.sendPackedShortHeaderPacket(buffer, p.Packet, now) return true, nil } - p, buffer, err := s.packer.PackPacket(false, now, s.version) + p, buffer, err := s.packer.PackPacket(false, now, s.mtuDiscoverer.CurrentSize(), s.version) if err != nil { if err == errNothingToPack { return false, nil @@ -1945,14 +1938,14 @@ func (s *connection) sendConnectionClose(e error) ([]byte, error) { var transportErr *qerr.TransportError var applicationErr *qerr.ApplicationError if errors.As(e, &transportErr) { - packet, err = s.packer.PackConnectionClose(transportErr, s.version) + packet, err = s.packer.PackConnectionClose(transportErr, s.mtuDiscoverer.CurrentSize(), s.version) } else if errors.As(e, &applicationErr) { - packet, err = s.packer.PackApplicationClose(applicationErr, s.version) + packet, err = s.packer.PackApplicationClose(applicationErr, s.mtuDiscoverer.CurrentSize(), s.version) } else { packet, err = s.packer.PackConnectionClose(&qerr.TransportError{ ErrorCode: qerr.InternalError, ErrorMessage: fmt.Sprintf("connection BUG: unspecified error type (msg: %s)", e.Error()), - }, s.version) + }, s.mtuDiscoverer.CurrentSize(), s.version) } if err != nil { return nil, err diff --git a/connection_test.go b/connection_test.go index ea79c9a82..d66b33e32 100644 --- a/connection_test.go +++ b/connection_test.go @@ -442,7 +442,7 @@ var _ = Describe("Connection", func() { cryptoSetup.EXPECT().Close() buffer := getPacketBuffer() buffer.Data = append(buffer.Data, []byte("connection close")...) - packer.EXPECT().PackApplicationClose(gomock.Any(), conn.version).DoAndReturn(func(e *qerr.ApplicationError, _ protocol.VersionNumber) (*coalescedPacket, error) { + packer.EXPECT().PackApplicationClose(gomock.Any(), gomock.Any(), conn.version).DoAndReturn(func(e *qerr.ApplicationError, _ protocol.ByteCount, _ protocol.VersionNumber) (*coalescedPacket, error) { Expect(e.ErrorCode).To(BeEquivalentTo(qerr.NoError)) Expect(e.ErrorMessage).To(BeEmpty()) return &coalescedPacket{buffer: buffer}, nil @@ -467,7 +467,7 @@ var _ = Describe("Connection", func() { streamManager.EXPECT().CloseWithError(gomock.Any()) expectReplaceWithClosed() cryptoSetup.EXPECT().Close() - packer.EXPECT().PackApplicationClose(gomock.Any(), conn.version).Return(&coalescedPacket{buffer: getPacketBuffer()}, nil) + packer.EXPECT().PackApplicationClose(gomock.Any(), gomock.Any(), conn.version).Return(&coalescedPacket{buffer: getPacketBuffer()}, nil) mconn.EXPECT().Write(gomock.Any()) tracer.EXPECT().ClosedConnection(gomock.Any()) tracer.EXPECT().Close() @@ -486,7 +486,7 @@ var _ = Describe("Connection", func() { streamManager.EXPECT().CloseWithError(expectedErr) expectReplaceWithClosed() cryptoSetup.EXPECT().Close() - packer.EXPECT().PackApplicationClose(expectedErr, conn.version).Return(&coalescedPacket{buffer: getPacketBuffer()}, nil) + packer.EXPECT().PackApplicationClose(expectedErr, gomock.Any(), conn.version).Return(&coalescedPacket{buffer: getPacketBuffer()}, nil) mconn.EXPECT().Write(gomock.Any()) gomock.InOrder( tracer.EXPECT().ClosedConnection(expectedErr), @@ -507,7 +507,7 @@ var _ = Describe("Connection", func() { streamManager.EXPECT().CloseWithError(expectedErr) expectReplaceWithClosed() cryptoSetup.EXPECT().Close() - packer.EXPECT().PackConnectionClose(expectedErr, conn.version).Return(&coalescedPacket{buffer: getPacketBuffer()}, nil) + packer.EXPECT().PackConnectionClose(expectedErr, gomock.Any(), conn.version).Return(&coalescedPacket{buffer: getPacketBuffer()}, nil) mconn.EXPECT().Write(gomock.Any()) gomock.InOrder( tracer.EXPECT().ClosedConnection(expectedErr), @@ -547,7 +547,7 @@ var _ = Describe("Connection", func() { streamManager.EXPECT().CloseWithError(gomock.Any()) expectReplaceWithClosed() cryptoSetup.EXPECT().Close() - packer.EXPECT().PackApplicationClose(gomock.Any(), conn.version).Return(&coalescedPacket{buffer: getPacketBuffer()}, nil) + packer.EXPECT().PackApplicationClose(gomock.Any(), gomock.Any(), conn.version).Return(&coalescedPacket{buffer: getPacketBuffer()}, nil) returned := make(chan struct{}) go func() { defer GinkgoRecover() @@ -615,8 +615,8 @@ var _ = Describe("Connection", func() { cryptoSetup.EXPECT().Close() conn.sentPacketHandler = sph p, buffer := getShortHeaderPacket(1) - packer.EXPECT().PackPacket(false, gomock.Any(), conn.version).Return(p, buffer, nil) - packer.EXPECT().PackPacket(false, gomock.Any(), conn.version).Return(shortHeaderPacket{}, nil, errNothingToPack).AnyTimes() + packer.EXPECT().PackPacket(false, gomock.Any(), gomock.Any(), conn.version).Return(p, buffer, nil) + packer.EXPECT().PackPacket(false, gomock.Any(), gomock.Any(), conn.version).Return(shortHeaderPacket{}, nil, errNothingToPack).AnyTimes() runConn() conn.queueControlFrame(&wire.PingFrame{}) conn.scheduleSending() @@ -805,7 +805,7 @@ var _ = Describe("Connection", func() { unpacker.EXPECT().UnpackLongHeader(gomock.Any(), gomock.Any(), gomock.Any(), conn.version).Return(nil, handshake.ErrDecryptionFailed) streamManager.EXPECT().CloseWithError(gomock.Any()) cryptoSetup.EXPECT().Close() - packer.EXPECT().PackConnectionClose(gomock.Any(), conn.version).Return(&coalescedPacket{buffer: getPacketBuffer()}, nil) + packer.EXPECT().PackConnectionClose(gomock.Any(), gomock.Any(), conn.version).Return(&coalescedPacket{buffer: getPacketBuffer()}, nil) go func() { defer GinkgoRecover() cryptoSetup.EXPECT().RunHandshake().MaxTimes(1) @@ -842,7 +842,7 @@ var _ = Describe("Connection", func() { }).Times(3) tracer.EXPECT().ReceivedShortHeaderPacket(gomock.Any(), gomock.Any(), gomock.Any()).Do(func(hdr *logging.ShortHeader, _ protocol.ByteCount, _ []logging.Frame) { }).Times(3) - packer.EXPECT().PackCoalescedPacket(false, conn.version) // only expect a single call + packer.EXPECT().PackCoalescedPacket(false, gomock.Any(), conn.version) // only expect a single call for i := 0; i < 3; i++ { conn.handlePacket(getShortHeaderPacket(srcConnID, 0x1337+protocol.PacketNumber(i), []byte("foobar"))) @@ -858,7 +858,7 @@ var _ = Describe("Connection", func() { // make the go routine return streamManager.EXPECT().CloseWithError(gomock.Any()) cryptoSetup.EXPECT().Close() - packer.EXPECT().PackConnectionClose(gomock.Any(), conn.version).Return(&coalescedPacket{buffer: getPacketBuffer()}, nil) + packer.EXPECT().PackConnectionClose(gomock.Any(), gomock.Any(), conn.version).Return(&coalescedPacket{buffer: getPacketBuffer()}, nil) expectReplaceWithClosed() tracer.EXPECT().ClosedConnection(gomock.Any()) tracer.EXPECT().Close() @@ -877,7 +877,7 @@ var _ = Describe("Connection", func() { }).Times(3) tracer.EXPECT().ReceivedShortHeaderPacket(gomock.Any(), gomock.Any(), gomock.Any()).Do(func(hdr *logging.ShortHeader, _ protocol.ByteCount, _ []logging.Frame) { }).Times(3) - packer.EXPECT().PackCoalescedPacket(false, conn.version).Times(3) // only expect a single call + packer.EXPECT().PackCoalescedPacket(false, gomock.Any(), conn.version).Times(3) // only expect a single call for i := 0; i < 3; i++ { conn.handlePacket(getShortHeaderPacket(srcConnID, 0x1337+protocol.PacketNumber(i), []byte("foobar"))) @@ -893,7 +893,7 @@ var _ = Describe("Connection", func() { // make the go routine return streamManager.EXPECT().CloseWithError(gomock.Any()) cryptoSetup.EXPECT().Close() - packer.EXPECT().PackConnectionClose(gomock.Any(), conn.version).Return(&coalescedPacket{buffer: getPacketBuffer()}, nil) + packer.EXPECT().PackConnectionClose(gomock.Any(), gomock.Any(), conn.version).Return(&coalescedPacket{buffer: getPacketBuffer()}, nil) expectReplaceWithClosed() tracer.EXPECT().ClosedConnection(gomock.Any()) tracer.EXPECT().Close() @@ -906,7 +906,7 @@ var _ = Describe("Connection", func() { unpacker.EXPECT().UnpackShortHeader(gomock.Any(), gomock.Any()).Return(protocol.PacketNumber(0), protocol.PacketNumberLen(0), protocol.KeyPhaseBit(0), nil, wire.ErrInvalidReservedBits) streamManager.EXPECT().CloseWithError(gomock.Any()) cryptoSetup.EXPECT().Close() - packer.EXPECT().PackConnectionClose(gomock.Any(), conn.version).Return(&coalescedPacket{buffer: getPacketBuffer()}, nil) + packer.EXPECT().PackConnectionClose(gomock.Any(), gomock.Any(), conn.version).Return(&coalescedPacket{buffer: getPacketBuffer()}, nil) done := make(chan struct{}) go func() { defer GinkgoRecover() @@ -942,7 +942,7 @@ var _ = Describe("Connection", func() { conn.handlePacket(getShortHeaderPacket(srcConnID, 0x42, nil)) Consistently(runErr).ShouldNot(Receive()) // make the go routine return - packer.EXPECT().PackApplicationClose(gomock.Any(), conn.version).Return(&coalescedPacket{buffer: getPacketBuffer()}, nil) + packer.EXPECT().PackApplicationClose(gomock.Any(), gomock.Any(), conn.version).Return(&coalescedPacket{buffer: getPacketBuffer()}, nil) tracer.EXPECT().ClosedConnection(gomock.Any()) tracer.EXPECT().Close() mconn.EXPECT().Write(gomock.Any()) @@ -954,7 +954,7 @@ var _ = Describe("Connection", func() { unpacker.EXPECT().UnpackShortHeader(gomock.Any(), gomock.Any()).Return(protocol.PacketNumber(0), protocol.PacketNumberLen(0), protocol.KeyPhaseBit(0), nil, &qerr.TransportError{ErrorCode: qerr.ConnectionIDLimitError}) streamManager.EXPECT().CloseWithError(gomock.Any()) cryptoSetup.EXPECT().Close() - packer.EXPECT().PackConnectionClose(gomock.Any(), conn.version).Return(&coalescedPacket{buffer: getPacketBuffer()}, nil) + packer.EXPECT().PackConnectionClose(gomock.Any(), gomock.Any(), conn.version).Return(&coalescedPacket{buffer: getPacketBuffer()}, nil) done := make(chan struct{}) go func() { defer GinkgoRecover() @@ -1179,7 +1179,7 @@ var _ = Describe("Connection", func() { AfterEach(func() { streamManager.EXPECT().CloseWithError(gomock.Any()) - packer.EXPECT().PackApplicationClose(gomock.Any(), conn.version).Return(&coalescedPacket{buffer: getPacketBuffer()}, nil) + packer.EXPECT().PackApplicationClose(gomock.Any(), gomock.Any(), conn.version).Return(&coalescedPacket{buffer: getPacketBuffer()}, nil) expectReplaceWithClosed() cryptoSetup.EXPECT().Close() mconn.EXPECT().Write(gomock.Any()) @@ -1211,8 +1211,8 @@ var _ = Describe("Connection", func() { conn.sentPacketHandler = sph runConn() p, buffer := getShortHeaderPacket(1) - packer.EXPECT().PackPacket(false, gomock.Any(), conn.version).Return(p, buffer, nil) - packer.EXPECT().PackPacket(false, gomock.Any(), conn.version).Return(shortHeaderPacket{}, nil, errNothingToPack).AnyTimes() + packer.EXPECT().PackPacket(false, gomock.Any(), gomock.Any(), conn.version).Return(p, buffer, nil) + packer.EXPECT().PackPacket(false, gomock.Any(), gomock.Any(), conn.version).Return(shortHeaderPacket{}, nil, errNothingToPack).AnyTimes() sent := make(chan struct{}) sender.EXPECT().WouldBlock().AnyTimes() sender.EXPECT().Send(gomock.Any()).Do(func(packet *packetBuffer) { close(sent) }) @@ -1229,7 +1229,7 @@ var _ = Describe("Connection", func() { It("doesn't send packets if there's nothing to send", func() { conn.handshakeConfirmed = true runConn() - packer.EXPECT().PackPacket(false, gomock.Any(), conn.version).Return(shortHeaderPacket{}, nil, errNothingToPack).AnyTimes() + packer.EXPECT().PackPacket(false, gomock.Any(), gomock.Any(), conn.version).Return(shortHeaderPacket{}, nil, errNothingToPack).AnyTimes() conn.receivedPacketHandler.ReceivedPacket(0x035e, protocol.ECNNon, protocol.Encryption1RTT, time.Now(), true) conn.scheduleSending() time.Sleep(50 * time.Millisecond) // make sure there are no calls to mconn.Write() @@ -1241,7 +1241,7 @@ var _ = Describe("Connection", func() { sph.EXPECT().GetLossDetectionTimeout().AnyTimes() sph.EXPECT().SendMode().Return(ackhandler.SendAck) done := make(chan struct{}) - packer.EXPECT().PackCoalescedPacket(true, conn.version).Do(func(bool, protocol.VersionNumber) { close(done) }) + packer.EXPECT().PackCoalescedPacket(true, gomock.Any(), conn.version).Do(func(bool, protocol.ByteCount, protocol.VersionNumber) { close(done) }) conn.sentPacketHandler = sph runConn() conn.scheduleSending() @@ -1261,8 +1261,8 @@ var _ = Describe("Connection", func() { fc.EXPECT().IsNewlyBlocked().Return(true, protocol.ByteCount(1337)) fc.EXPECT().IsNewlyBlocked() p, buffer := getShortHeaderPacket(1) - packer.EXPECT().PackPacket(false, gomock.Any(), conn.version).Return(p, buffer, nil) - packer.EXPECT().PackPacket(false, gomock.Any(), conn.version).Return(shortHeaderPacket{}, nil, errNothingToPack).AnyTimes() + packer.EXPECT().PackPacket(false, gomock.Any(), gomock.Any(), conn.version).Return(p, buffer, nil) + packer.EXPECT().PackPacket(false, gomock.Any(), gomock.Any(), conn.version).Return(shortHeaderPacket{}, nil, errNothingToPack).AnyTimes() conn.connFlowController = fc runConn() sent := make(chan struct{}) @@ -1315,7 +1315,7 @@ var _ = Describe("Connection", func() { sph.EXPECT().SendMode().Return(ackhandler.SendNone) sph.EXPECT().QueueProbePacket(encLevel) p := getCoalescedPacket(123, enc != protocol.Encryption1RTT) - packer.EXPECT().MaybePackProbePacket(encLevel, conn.version).Return(p, nil) + packer.EXPECT().MaybePackProbePacket(encLevel, gomock.Any(), conn.version).Return(p, nil) sph.EXPECT().SentPacket(gomock.Any()).Do(func(packet *ackhandler.Packet) { Expect(packet.PacketNumber).To(Equal(protocol.PacketNumber(123))) }) @@ -1340,7 +1340,7 @@ var _ = Describe("Connection", func() { sph.EXPECT().SendMode().Return(ackhandler.SendNone) sph.EXPECT().QueueProbePacket(encLevel).Return(false) p := getCoalescedPacket(123, enc != protocol.Encryption1RTT) - packer.EXPECT().MaybePackProbePacket(encLevel, conn.version).Return(p, nil) + packer.EXPECT().MaybePackProbePacket(encLevel, gomock.Any(), conn.version).Return(p, nil) sph.EXPECT().SentPacket(gomock.Any()).Do(func(packet *ackhandler.Packet) { Expect(packet.PacketNumber).To(Equal(protocol.PacketNumber(123))) }) @@ -1384,7 +1384,7 @@ var _ = Describe("Connection", func() { AfterEach(func() { // make the go routine return - packer.EXPECT().PackApplicationClose(gomock.Any(), conn.version).Return(&coalescedPacket{buffer: getPacketBuffer()}, nil) + packer.EXPECT().PackApplicationClose(gomock.Any(), gomock.Any(), conn.version).Return(&coalescedPacket{buffer: getPacketBuffer()}, nil) expectReplaceWithClosed() cryptoSetup.EXPECT().Close() mconn.EXPECT().Write(gomock.Any()) @@ -1402,9 +1402,9 @@ var _ = Describe("Connection", func() { sph.EXPECT().TimeUntilSend().Return(time.Now().Add(time.Hour)) sph.EXPECT().SendMode().Return(ackhandler.SendAny).Times(3) p, buffer := getShortHeaderPacket(10) - packer.EXPECT().PackPacket(false, gomock.Any(), conn.version).Return(p, buffer, nil) + packer.EXPECT().PackPacket(false, gomock.Any(), gomock.Any(), conn.version).Return(p, buffer, nil) p, buffer = getShortHeaderPacket(11) - packer.EXPECT().PackPacket(false, gomock.Any(), conn.version).Return(p, buffer, nil) + packer.EXPECT().PackPacket(false, gomock.Any(), gomock.Any(), conn.version).Return(p, buffer, nil) sender.EXPECT().WouldBlock().AnyTimes() sender.EXPECT().Send(gomock.Any()).Times(2) go func() { @@ -1421,8 +1421,8 @@ var _ = Describe("Connection", func() { sph.EXPECT().HasPacingBudget().Return(true).AnyTimes() sph.EXPECT().SendMode().Return(ackhandler.SendAny).Times(2) p, buffer := getShortHeaderPacket(10) - packer.EXPECT().PackPacket(false, gomock.Any(), conn.version).Return(p, buffer, nil) - packer.EXPECT().PackPacket(false, gomock.Any(), conn.version).Return(shortHeaderPacket{}, nil, errNothingToPack) + packer.EXPECT().PackPacket(false, gomock.Any(), gomock.Any(), conn.version).Return(p, buffer, nil) + packer.EXPECT().PackPacket(false, gomock.Any(), gomock.Any(), conn.version).Return(shortHeaderPacket{}, nil, errNothingToPack) sender.EXPECT().WouldBlock().AnyTimes() sender.EXPECT().Send(gomock.Any()) go func() { @@ -1440,7 +1440,7 @@ var _ = Describe("Connection", func() { sph.EXPECT().TimeUntilSend().Return(time.Now().Add(time.Hour)) sph.EXPECT().SendMode().Return(ackhandler.SendAny) p, buffer := getShortHeaderPacket(10) - packer.EXPECT().PackPacket(true, gomock.Any(), conn.version).Return(p, buffer, nil) + packer.EXPECT().PackPacket(true, gomock.Any(), gomock.Any(), conn.version).Return(p, buffer, nil) sender.EXPECT().WouldBlock().AnyTimes() sender.EXPECT().Send(gomock.Any()) @@ -1461,7 +1461,7 @@ var _ = Describe("Connection", func() { sph.EXPECT().SendMode().Return(ackhandler.SendAny) sph.EXPECT().SendMode().Return(ackhandler.SendAck) p, buffer := getShortHeaderPacket(100) - packer.EXPECT().PackPacket(false, gomock.Any(), conn.version).Return(p, buffer, nil) + packer.EXPECT().PackPacket(false, gomock.Any(), gomock.Any(), conn.version).Return(p, buffer, nil) sender.EXPECT().WouldBlock().AnyTimes() sender.EXPECT().Send(gomock.Any()) go func() { @@ -1480,12 +1480,12 @@ var _ = Describe("Connection", func() { p2, buffer2 := getShortHeaderPacket(101) gomock.InOrder( sph.EXPECT().HasPacingBudget().Return(true), - packer.EXPECT().PackPacket(false, gomock.Any(), conn.version).Return(p1, buffer1, nil), + packer.EXPECT().PackPacket(false, gomock.Any(), gomock.Any(), conn.version).Return(p1, buffer1, nil), sph.EXPECT().SentPacket(gomock.Any()), sph.EXPECT().HasPacingBudget(), sph.EXPECT().TimeUntilSend().Return(time.Now().Add(pacingDelay)), sph.EXPECT().HasPacingBudget().Return(true), - packer.EXPECT().PackPacket(false, gomock.Any(), conn.version).Return(p2, buffer2, nil), + packer.EXPECT().PackPacket(false, gomock.Any(), gomock.Any(), conn.version).Return(p2, buffer2, nil), sph.EXPECT().SentPacket(gomock.Any()), sph.EXPECT().HasPacingBudget(), sph.EXPECT().TimeUntilSend().Return(time.Now().Add(time.Hour)), @@ -1512,7 +1512,7 @@ var _ = Describe("Connection", func() { sph.EXPECT().SendMode().Return(ackhandler.SendAny).Times(4) for pn := protocol.PacketNumber(1000); pn < 1003; pn++ { p, buffer := getShortHeaderPacket(pn) - packer.EXPECT().PackPacket(false, gomock.Any(), conn.version).Return(p, buffer, nil) + packer.EXPECT().PackPacket(false, gomock.Any(), gomock.Any(), conn.version).Return(p, buffer, nil) } written := make(chan struct{}, 3) sender.EXPECT().WouldBlock().AnyTimes() @@ -1544,8 +1544,8 @@ var _ = Describe("Connection", func() { sph.EXPECT().HasPacingBudget().Return(true).AnyTimes() sph.EXPECT().SendMode().Return(ackhandler.SendAny).AnyTimes() p, buffer := getShortHeaderPacket(1000) - packer.EXPECT().PackPacket(false, gomock.Any(), conn.version).Return(p, buffer, nil) - packer.EXPECT().PackPacket(false, gomock.Any(), conn.version).Return(shortHeaderPacket{}, nil, errNothingToPack) + packer.EXPECT().PackPacket(false, gomock.Any(), gomock.Any(), conn.version).Return(p, buffer, nil) + packer.EXPECT().PackPacket(false, gomock.Any(), gomock.Any(), conn.version).Return(shortHeaderPacket{}, nil, errNothingToPack) sender.EXPECT().Send(gomock.Any()).DoAndReturn(func(p *packetBuffer) { close(written) }) available <- struct{}{} Eventually(written).Should(BeClosed()) @@ -1568,8 +1568,8 @@ var _ = Describe("Connection", func() { sph.EXPECT().HasPacingBudget().Return(true).AnyTimes() sph.EXPECT().SendMode().Return(ackhandler.SendAny).AnyTimes() p, buffer := getShortHeaderPacket(1000) - packer.EXPECT().PackPacket(false, gomock.Any(), conn.version).Return(p, buffer, nil) - packer.EXPECT().PackPacket(false, gomock.Any(), conn.version).Return(shortHeaderPacket{}, nil, errNothingToPack) + packer.EXPECT().PackPacket(false, gomock.Any(), gomock.Any(), conn.version).Return(p, buffer, nil) + packer.EXPECT().PackPacket(false, gomock.Any(), gomock.Any(), conn.version).Return(shortHeaderPacket{}, nil, errNothingToPack) sender.EXPECT().Send(gomock.Any()).DoAndReturn(func(p *packetBuffer) { close(written) }) conn.scheduleSending() @@ -1583,7 +1583,7 @@ var _ = Describe("Connection", func() { sph.EXPECT().HasPacingBudget().Return(true).AnyTimes() sph.EXPECT().SendMode().Return(ackhandler.SendAny) p, buffer := getShortHeaderPacket(1000) - packer.EXPECT().PackPacket(false, gomock.Any(), conn.version).Return(p, buffer, nil) + packer.EXPECT().PackPacket(false, gomock.Any(), gomock.Any(), conn.version).Return(p, buffer, nil) written := make(chan struct{}, 1) sender.EXPECT().WouldBlock() sender.EXPECT().WouldBlock().Return(true).Times(2) @@ -1605,8 +1605,8 @@ var _ = Describe("Connection", func() { sph.EXPECT().SendMode().Return(ackhandler.SendAny).AnyTimes() sender.EXPECT().WouldBlock().AnyTimes() p, buffer = getShortHeaderPacket(1001) - packer.EXPECT().PackPacket(false, gomock.Any(), conn.version).Return(p, buffer, nil) - packer.EXPECT().PackPacket(false, gomock.Any(), conn.version).Return(shortHeaderPacket{}, nil, errNothingToPack) + packer.EXPECT().PackPacket(false, gomock.Any(), gomock.Any(), conn.version).Return(p, buffer, nil) + packer.EXPECT().PackPacket(false, gomock.Any(), gomock.Any(), conn.version).Return(shortHeaderPacket{}, nil, errNothingToPack) sender.EXPECT().Send(gomock.Any()).DoAndReturn(func(p *packetBuffer) { written <- struct{}{} }) available <- struct{}{} Eventually(written).Should(Receive()) @@ -1620,7 +1620,7 @@ var _ = Describe("Connection", func() { sph.EXPECT().HasPacingBudget().Return(true) sph.EXPECT().SendMode().Return(ackhandler.SendAny).AnyTimes() sender.EXPECT().WouldBlock().AnyTimes() - packer.EXPECT().PackPacket(false, gomock.Any(), conn.version).Return(shortHeaderPacket{}, nil, errNothingToPack) + packer.EXPECT().PackPacket(false, gomock.Any(), gomock.Any(), conn.version).Return(shortHeaderPacket{}, nil, errNothingToPack) // don't EXPECT any calls to mconn.Write() go func() { defer GinkgoRecover() @@ -1654,6 +1654,7 @@ var _ = Describe("Connection", func() { }() conn.scheduleSending() Eventually(written).Should(Receive()) + mtuDiscoverer.EXPECT().CurrentSize().Return(protocol.ByteCount(1234)) }) }) @@ -1672,7 +1673,7 @@ var _ = Describe("Connection", func() { // make the go routine return expectReplaceWithClosed() streamManager.EXPECT().CloseWithError(gomock.Any()) - packer.EXPECT().PackApplicationClose(gomock.Any(), conn.version).Return(&coalescedPacket{buffer: getPacketBuffer()}, nil) + packer.EXPECT().PackApplicationClose(gomock.Any(), gomock.Any(), conn.version).Return(&coalescedPacket{buffer: getPacketBuffer()}, nil) cryptoSetup.EXPECT().Close() mconn.EXPECT().Write(gomock.Any()) sender.EXPECT().Close() @@ -1691,8 +1692,8 @@ var _ = Describe("Connection", func() { sph.EXPECT().SentPacket(gomock.Any()) conn.sentPacketHandler = sph p, buffer := getShortHeaderPacket(1) - packer.EXPECT().PackPacket(false, gomock.Any(), conn.version).Return(p, buffer, nil) - packer.EXPECT().PackPacket(false, gomock.Any(), conn.version).Return(shortHeaderPacket{}, nil, errNothingToPack) + packer.EXPECT().PackPacket(false, gomock.Any(), gomock.Any(), conn.version).Return(p, buffer, nil) + packer.EXPECT().PackPacket(false, gomock.Any(), gomock.Any(), conn.version).Return(shortHeaderPacket{}, nil, errNothingToPack) go func() { defer GinkgoRecover() @@ -1711,8 +1712,8 @@ var _ = Describe("Connection", func() { It("sets the timer to the ack timer", func() { p, buffer := getShortHeaderPacket(1234) - packer.EXPECT().PackPacket(false, gomock.Any(), conn.version).Return(p, buffer, nil) - packer.EXPECT().PackPacket(false, gomock.Any(), conn.version).Return(shortHeaderPacket{}, nil, errNothingToPack) + packer.EXPECT().PackPacket(false, gomock.Any(), gomock.Any(), conn.version).Return(p, buffer, nil) + packer.EXPECT().PackPacket(false, gomock.Any(), gomock.Any(), conn.version).Return(shortHeaderPacket{}, nil, errNothingToPack) sph := mockackhandler.NewMockSentPacketHandler(mockCtrl) sph.EXPECT().GetLossDetectionTimeout().AnyTimes() sph.EXPECT().SendMode().Return(ackhandler.SendAny).AnyTimes() @@ -1746,7 +1747,7 @@ var _ = Describe("Connection", func() { conn.sentPacketHandler = sph buffer := getPacketBuffer() buffer.Data = append(buffer.Data, []byte("foobar")...) - packer.EXPECT().PackCoalescedPacket(false, conn.version).Return(&coalescedPacket{ + packer.EXPECT().PackCoalescedPacket(false, gomock.Any(), conn.version).Return(&coalescedPacket{ buffer: buffer, longHdrPackets: []*longHeaderPacket{ { @@ -1765,7 +1766,7 @@ var _ = Describe("Connection", func() { }, }, }, nil) - packer.EXPECT().PackCoalescedPacket(false, conn.version).AnyTimes() + packer.EXPECT().PackCoalescedPacket(false, gomock.Any(), conn.version).AnyTimes() sph.EXPECT().GetLossDetectionTimeout().AnyTimes() sph.EXPECT().SendMode().Return(ackhandler.SendAny).AnyTimes() @@ -1806,7 +1807,7 @@ var _ = Describe("Connection", func() { // make sure the go routine returns streamManager.EXPECT().CloseWithError(gomock.Any()) expectReplaceWithClosed() - packer.EXPECT().PackApplicationClose(gomock.Any(), conn.version).Return(&coalescedPacket{buffer: getPacketBuffer()}, nil) + packer.EXPECT().PackApplicationClose(gomock.Any(), gomock.Any(), conn.version).Return(&coalescedPacket{buffer: getPacketBuffer()}, nil) cryptoSetup.EXPECT().Close() mconn.EXPECT().Write(gomock.Any()) tracer.EXPECT().ClosedConnection(gomock.Any()) @@ -1816,7 +1817,7 @@ var _ = Describe("Connection", func() { }) It("cancels the HandshakeComplete context when the handshake completes", func() { - packer.EXPECT().PackCoalescedPacket(false, conn.version).AnyTimes() + packer.EXPECT().PackCoalescedPacket(false, gomock.Any(), conn.version).AnyTimes() finishHandshake := make(chan struct{}) sph := mockackhandler.NewMockSentPacketHandler(mockCtrl) conn.sentPacketHandler = sph @@ -1841,7 +1842,7 @@ var _ = Describe("Connection", func() { // make sure the go routine returns streamManager.EXPECT().CloseWithError(gomock.Any()) expectReplaceWithClosed() - packer.EXPECT().PackApplicationClose(gomock.Any(), conn.version).Return(&coalescedPacket{buffer: getPacketBuffer()}, nil) + packer.EXPECT().PackApplicationClose(gomock.Any(), gomock.Any(), conn.version).Return(&coalescedPacket{buffer: getPacketBuffer()}, nil) cryptoSetup.EXPECT().Close() mconn.EXPECT().Write(gomock.Any()) tracer.EXPECT().ClosedConnection(gomock.Any()) @@ -1852,7 +1853,7 @@ var _ = Describe("Connection", func() { It("sends a connection ticket when the handshake completes", func() { const size = protocol.MaxPostHandshakeCryptoFrameSize * 3 / 2 - packer.EXPECT().PackCoalescedPacket(false, conn.version).AnyTimes() + packer.EXPECT().PackCoalescedPacket(false, gomock.Any(), conn.version).AnyTimes() finishHandshake := make(chan struct{}) connRunner.EXPECT().Retire(clientDestConnID) go func() { @@ -1886,7 +1887,7 @@ var _ = Describe("Connection", func() { // make sure the go routine returns streamManager.EXPECT().CloseWithError(gomock.Any()) expectReplaceWithClosed() - packer.EXPECT().PackApplicationClose(gomock.Any(), conn.version).Return(&coalescedPacket{buffer: getPacketBuffer()}, nil) + packer.EXPECT().PackApplicationClose(gomock.Any(), gomock.Any(), conn.version).Return(&coalescedPacket{buffer: getPacketBuffer()}, nil) cryptoSetup.EXPECT().Close() mconn.EXPECT().Write(gomock.Any()) tracer.EXPECT().ClosedConnection(gomock.Any()) @@ -1896,10 +1897,10 @@ var _ = Describe("Connection", func() { }) It("doesn't cancel the HandshakeComplete context when the handshake fails", func() { - packer.EXPECT().PackCoalescedPacket(false, conn.version).AnyTimes() + packer.EXPECT().PackCoalescedPacket(false, gomock.Any(), conn.version).AnyTimes() streamManager.EXPECT().CloseWithError(gomock.Any()) expectReplaceWithClosed() - packer.EXPECT().PackConnectionClose(gomock.Any(), conn.version).Return(&coalescedPacket{buffer: getPacketBuffer()}, nil) + packer.EXPECT().PackConnectionClose(gomock.Any(), gomock.Any(), conn.version).Return(&coalescedPacket{buffer: getPacketBuffer()}, nil) cryptoSetup.EXPECT().Close() tracer.EXPECT().ClosedConnection(gomock.Any()) tracer.EXPECT().Close() @@ -1929,14 +1930,14 @@ var _ = Describe("Connection", func() { conn.sentPacketHandler = sph done := make(chan struct{}) connRunner.EXPECT().Retire(clientDestConnID) - packer.EXPECT().PackPacket(false, gomock.Any(), conn.version).DoAndReturn(func(_ bool, _ time.Time, v protocol.VersionNumber) (shortHeaderPacket, *packetBuffer, error) { + packer.EXPECT().PackPacket(false, gomock.Any(), gomock.Any(), conn.version).DoAndReturn(func(_ bool, _ time.Time, _ protocol.ByteCount, v protocol.VersionNumber) (shortHeaderPacket, *packetBuffer, error) { frames, _ := conn.framer.AppendControlFrames(nil, protocol.MaxByteCount, v) Expect(frames).ToNot(BeEmpty()) Expect(frames[0].Frame).To(BeEquivalentTo(&wire.HandshakeDoneFrame{})) defer close(done) return shortHeaderPacket{Packet: &ackhandler.Packet{}}, getPacketBuffer(), nil }) - packer.EXPECT().PackPacket(false, gomock.Any(), conn.version).Return(shortHeaderPacket{}, nil, errNothingToPack).AnyTimes() + packer.EXPECT().PackPacket(false, gomock.Any(), gomock.Any(), conn.version).Return(shortHeaderPacket{}, nil, errNothingToPack).AnyTimes() go func() { defer GinkgoRecover() cryptoSetup.EXPECT().RunHandshake() @@ -1950,7 +1951,7 @@ var _ = Describe("Connection", func() { // make sure the go routine returns streamManager.EXPECT().CloseWithError(gomock.Any()) expectReplaceWithClosed() - packer.EXPECT().PackApplicationClose(gomock.Any(), conn.version).Return(&coalescedPacket{buffer: getPacketBuffer()}, nil) + packer.EXPECT().PackApplicationClose(gomock.Any(), gomock.Any(), conn.version).Return(&coalescedPacket{buffer: getPacketBuffer()}, nil) cryptoSetup.EXPECT().Close() tracer.EXPECT().ClosedConnection(gomock.Any()) tracer.EXPECT().Close() @@ -1968,7 +1969,7 @@ var _ = Describe("Connection", func() { }() streamManager.EXPECT().CloseWithError(gomock.Any()) expectReplaceWithClosed() - packer.EXPECT().PackApplicationClose(gomock.Any(), conn.version).Return(&coalescedPacket{buffer: getPacketBuffer()}, nil) + packer.EXPECT().PackApplicationClose(gomock.Any(), gomock.Any(), conn.version).Return(&coalescedPacket{buffer: getPacketBuffer()}, nil) cryptoSetup.EXPECT().Close() mconn.EXPECT().Write(gomock.Any()) tracer.EXPECT().ClosedConnection(gomock.Any()) @@ -1992,7 +1993,7 @@ var _ = Describe("Connection", func() { }() streamManager.EXPECT().CloseWithError(gomock.Any()) expectReplaceWithClosed() - packer.EXPECT().PackApplicationClose(gomock.Any(), conn.version).Return(&coalescedPacket{buffer: getPacketBuffer()}, nil) + packer.EXPECT().PackApplicationClose(gomock.Any(), gomock.Any(), conn.version).Return(&coalescedPacket{buffer: getPacketBuffer()}, nil) cryptoSetup.EXPECT().Close() mconn.EXPECT().Write(gomock.Any()) tracer.EXPECT().ClosedConnection(gomock.Any()) @@ -2013,8 +2014,7 @@ var _ = Describe("Connection", func() { InitialSourceConnectionID: destConnID, } streamManager.EXPECT().UpdateLimits(params) - packer.EXPECT().HandleTransportParameters(params) - packer.EXPECT().PackCoalescedPacket(false, conn.version).MaxTimes(3) + packer.EXPECT().PackCoalescedPacket(false, gomock.Any(), conn.version).MaxTimes(3) Expect(conn.earlyConnReady()).ToNot(BeClosed()) tracer.EXPECT().ReceivedTransportParameters(params) conn.handleTransportParameters(params) @@ -2025,7 +2025,6 @@ var _ = Describe("Connection", func() { Context("keep-alives", func() { setRemoteIdleTimeout := func(t time.Duration) { streamManager.EXPECT().UpdateLimits(gomock.Any()) - packer.EXPECT().HandleTransportParameters(gomock.Any()) tracer.EXPECT().ReceivedTransportParameters(gomock.Any()) conn.handleTransportParameters(&wire.TransportParameters{ MaxIdleTimeout: t, @@ -2051,7 +2050,7 @@ var _ = Describe("Connection", func() { // make the go routine return expectReplaceWithClosed() streamManager.EXPECT().CloseWithError(gomock.Any()) - packer.EXPECT().PackApplicationClose(gomock.Any(), conn.version).Return(&coalescedPacket{buffer: getPacketBuffer()}, nil) + packer.EXPECT().PackApplicationClose(gomock.Any(), gomock.Any(), conn.version).Return(&coalescedPacket{buffer: getPacketBuffer()}, nil) cryptoSetup.EXPECT().Close() mconn.EXPECT().Write(gomock.Any()) tracer.EXPECT().ClosedConnection(gomock.Any()) @@ -2064,7 +2063,7 @@ var _ = Describe("Connection", func() { setRemoteIdleTimeout(5 * time.Second) conn.lastPacketReceivedTime = time.Now().Add(-5 * time.Second / 2) sent := make(chan struct{}) - packer.EXPECT().PackCoalescedPacket(false, conn.version).Do(func(bool, protocol.VersionNumber) (*coalescedPacket, error) { + packer.EXPECT().PackCoalescedPacket(false, gomock.Any(), conn.version).Do(func(bool, protocol.ByteCount, protocol.VersionNumber) (*coalescedPacket, error) { close(sent) return nil, nil }) @@ -2077,7 +2076,7 @@ var _ = Describe("Connection", func() { setRemoteIdleTimeout(time.Hour) conn.lastPacketReceivedTime = time.Now().Add(-protocol.MaxKeepAliveInterval).Add(-time.Millisecond) sent := make(chan struct{}) - packer.EXPECT().PackCoalescedPacket(false, conn.version).Do(func(bool, protocol.VersionNumber) (*coalescedPacket, error) { + packer.EXPECT().PackCoalescedPacket(false, gomock.Any(), conn.version).Do(func(bool, protocol.ByteCount, protocol.VersionNumber) (*coalescedPacket, error) { close(sent) return nil, nil }) @@ -2165,7 +2164,7 @@ var _ = Describe("Connection", func() { conn.config.HandshakeIdleTimeout = 9999 * time.Second conn.config.MaxIdleTimeout = 9999 * time.Second conn.lastPacketReceivedTime = time.Now().Add(-time.Minute) - packer.EXPECT().PackApplicationClose(gomock.Any(), conn.version).DoAndReturn(func(e *qerr.ApplicationError, _ protocol.VersionNumber) (*coalescedPacket, error) { + packer.EXPECT().PackApplicationClose(gomock.Any(), gomock.Any(), conn.version).DoAndReturn(func(e *qerr.ApplicationError, _ protocol.ByteCount, _ protocol.VersionNumber) (*coalescedPacket, error) { Expect(e.ErrorCode).To(BeZero()) return &coalescedPacket{buffer: getPacketBuffer()}, nil }) @@ -2196,7 +2195,7 @@ var _ = Describe("Connection", func() { It("closes the connection due to the idle timeout before handshake", func() { conn.config.HandshakeIdleTimeout = 0 - packer.EXPECT().PackCoalescedPacket(false, conn.version).AnyTimes() + packer.EXPECT().PackCoalescedPacket(false, gomock.Any(), conn.version).AnyTimes() connRunner.EXPECT().Remove(gomock.Any()).AnyTimes() cryptoSetup.EXPECT().Close() gomock.InOrder( @@ -2222,7 +2221,7 @@ var _ = Describe("Connection", func() { }) It("closes the connection due to the idle timeout after handshake", func() { - packer.EXPECT().PackCoalescedPacket(false, conn.version).AnyTimes() + packer.EXPECT().PackCoalescedPacket(false, gomock.Any(), conn.version).AnyTimes() gomock.InOrder( connRunner.EXPECT().Retire(clientDestConnID), connRunner.EXPECT().Remove(gomock.Any()), @@ -2263,7 +2262,7 @@ var _ = Describe("Connection", func() { }() Consistently(conn.Context().Done()).ShouldNot(BeClosed()) // make the go routine return - packer.EXPECT().PackApplicationClose(gomock.Any(), conn.version).Return(&coalescedPacket{buffer: getPacketBuffer()}, nil) + packer.EXPECT().PackApplicationClose(gomock.Any(), gomock.Any(), conn.version).Return(&coalescedPacket{buffer: getPacketBuffer()}, nil) expectReplaceWithClosed() cryptoSetup.EXPECT().Close() mconn.EXPECT().Write(gomock.Any()) @@ -2449,7 +2448,7 @@ var _ = Describe("Client Connection", func() { tracer.EXPECT().ReceivedLongHeaderPacket(gomock.Any(), p.Size(), []logging.Frame{}) Expect(conn.handlePacketImpl(p)).To(BeTrue()) // make sure the go routine returns - packer.EXPECT().PackApplicationClose(gomock.Any(), conn.version).Return(&coalescedPacket{buffer: getPacketBuffer()}, nil) + packer.EXPECT().PackApplicationClose(gomock.Any(), gomock.Any(), conn.version).Return(&coalescedPacket{buffer: getPacketBuffer()}, nil) expectReplaceWithClosed() cryptoSetup.EXPECT().Close() mconn.EXPECT().Write(gomock.Any()) @@ -2708,9 +2707,9 @@ var _ = Describe("Client Connection", func() { if !closed { connRunner.EXPECT().ReplaceWithClosed(gomock.Any(), gomock.Any(), gomock.Any()) if applicationClose { - packer.EXPECT().PackApplicationClose(gomock.Any(), conn.version).Return(&coalescedPacket{buffer: getPacketBuffer()}, nil).MaxTimes(1) + packer.EXPECT().PackApplicationClose(gomock.Any(), gomock.Any(), conn.version).Return(&coalescedPacket{buffer: getPacketBuffer()}, nil).MaxTimes(1) } else { - packer.EXPECT().PackConnectionClose(gomock.Any(), conn.version).Return(&coalescedPacket{buffer: getPacketBuffer()}, nil).MaxTimes(1) + packer.EXPECT().PackConnectionClose(gomock.Any(), gomock.Any(), conn.version).Return(&coalescedPacket{buffer: getPacketBuffer()}, nil).MaxTimes(1) } cryptoSetup.EXPECT().Close() mconn.EXPECT().Write(gomock.Any()) @@ -2739,8 +2738,7 @@ var _ = Describe("Client Connection", func() { StatelessResetToken: protocol.StatelessResetToken{16, 15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1}, }, } - packer.EXPECT().HandleTransportParameters(gomock.Any()) - packer.EXPECT().PackCoalescedPacket(false, conn.version).MaxTimes(1) + packer.EXPECT().PackCoalescedPacket(false, gomock.Any(), conn.version).MaxTimes(1) tracer.EXPECT().ReceivedTransportParameters(params) conn.handleTransportParameters(params) conn.handleHandshakeComplete() @@ -2761,7 +2759,6 @@ var _ = Describe("Client Connection", func() { InitialSourceConnectionID: destConnID, MaxIdleTimeout: 18 * time.Second, } - packer.EXPECT().HandleTransportParameters(gomock.Any()) tracer.EXPECT().ReceivedTransportParameters(params) conn.handleTransportParameters(params) conn.handleHandshakeComplete() diff --git a/mock_packer_test.go b/mock_packer_test.go index 60d084c1f..2b08fb158 100644 --- a/mock_packer_test.go +++ b/mock_packer_test.go @@ -12,7 +12,6 @@ import ( ackhandler "github.com/quic-go/quic-go/internal/ackhandler" protocol "github.com/quic-go/quic-go/internal/protocol" qerr "github.com/quic-go/quic-go/internal/qerr" - wire "github.com/quic-go/quic-go/internal/wire" ) // MockPacker is a mock of Packer interface. @@ -38,76 +37,64 @@ func (m *MockPacker) EXPECT() *MockPackerMockRecorder { return m.recorder } -// HandleTransportParameters mocks base method. -func (m *MockPacker) HandleTransportParameters(arg0 *wire.TransportParameters) { - m.ctrl.T.Helper() - m.ctrl.Call(m, "HandleTransportParameters", arg0) -} - -// HandleTransportParameters indicates an expected call of HandleTransportParameters. -func (mr *MockPackerMockRecorder) HandleTransportParameters(arg0 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "HandleTransportParameters", reflect.TypeOf((*MockPacker)(nil).HandleTransportParameters), arg0) -} - // MaybePackProbePacket mocks base method. -func (m *MockPacker) MaybePackProbePacket(arg0 protocol.EncryptionLevel, arg1 protocol.VersionNumber) (*coalescedPacket, error) { +func (m *MockPacker) MaybePackProbePacket(arg0 protocol.EncryptionLevel, arg1 protocol.ByteCount, arg2 protocol.VersionNumber) (*coalescedPacket, error) { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "MaybePackProbePacket", arg0, arg1) + ret := m.ctrl.Call(m, "MaybePackProbePacket", arg0, arg1, arg2) ret0, _ := ret[0].(*coalescedPacket) ret1, _ := ret[1].(error) return ret0, ret1 } // MaybePackProbePacket indicates an expected call of MaybePackProbePacket. -func (mr *MockPackerMockRecorder) MaybePackProbePacket(arg0, arg1 interface{}) *gomock.Call { +func (mr *MockPackerMockRecorder) MaybePackProbePacket(arg0, arg1, arg2 interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "MaybePackProbePacket", reflect.TypeOf((*MockPacker)(nil).MaybePackProbePacket), arg0, arg1) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "MaybePackProbePacket", reflect.TypeOf((*MockPacker)(nil).MaybePackProbePacket), arg0, arg1, arg2) } // PackApplicationClose mocks base method. -func (m *MockPacker) PackApplicationClose(arg0 *qerr.ApplicationError, arg1 protocol.VersionNumber) (*coalescedPacket, error) { +func (m *MockPacker) PackApplicationClose(arg0 *qerr.ApplicationError, arg1 protocol.ByteCount, arg2 protocol.VersionNumber) (*coalescedPacket, error) { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "PackApplicationClose", arg0, arg1) + ret := m.ctrl.Call(m, "PackApplicationClose", arg0, arg1, arg2) ret0, _ := ret[0].(*coalescedPacket) ret1, _ := ret[1].(error) return ret0, ret1 } // PackApplicationClose indicates an expected call of PackApplicationClose. -func (mr *MockPackerMockRecorder) PackApplicationClose(arg0, arg1 interface{}) *gomock.Call { +func (mr *MockPackerMockRecorder) PackApplicationClose(arg0, arg1, arg2 interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "PackApplicationClose", reflect.TypeOf((*MockPacker)(nil).PackApplicationClose), arg0, arg1) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "PackApplicationClose", reflect.TypeOf((*MockPacker)(nil).PackApplicationClose), arg0, arg1, arg2) } // PackCoalescedPacket mocks base method. -func (m *MockPacker) PackCoalescedPacket(arg0 bool, arg1 protocol.VersionNumber) (*coalescedPacket, error) { +func (m *MockPacker) PackCoalescedPacket(arg0 bool, arg1 protocol.ByteCount, arg2 protocol.VersionNumber) (*coalescedPacket, error) { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "PackCoalescedPacket", arg0, arg1) + ret := m.ctrl.Call(m, "PackCoalescedPacket", arg0, arg1, arg2) ret0, _ := ret[0].(*coalescedPacket) ret1, _ := ret[1].(error) return ret0, ret1 } // PackCoalescedPacket indicates an expected call of PackCoalescedPacket. -func (mr *MockPackerMockRecorder) PackCoalescedPacket(arg0, arg1 interface{}) *gomock.Call { +func (mr *MockPackerMockRecorder) PackCoalescedPacket(arg0, arg1, arg2 interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "PackCoalescedPacket", reflect.TypeOf((*MockPacker)(nil).PackCoalescedPacket), arg0, arg1) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "PackCoalescedPacket", reflect.TypeOf((*MockPacker)(nil).PackCoalescedPacket), arg0, arg1, arg2) } // PackConnectionClose mocks base method. -func (m *MockPacker) PackConnectionClose(arg0 *qerr.TransportError, arg1 protocol.VersionNumber) (*coalescedPacket, error) { +func (m *MockPacker) PackConnectionClose(arg0 *qerr.TransportError, arg1 protocol.ByteCount, arg2 protocol.VersionNumber) (*coalescedPacket, error) { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "PackConnectionClose", arg0, arg1) + ret := m.ctrl.Call(m, "PackConnectionClose", arg0, arg1, arg2) ret0, _ := ret[0].(*coalescedPacket) ret1, _ := ret[1].(error) return ret0, ret1 } // PackConnectionClose indicates an expected call of PackConnectionClose. -func (mr *MockPackerMockRecorder) PackConnectionClose(arg0, arg1 interface{}) *gomock.Call { +func (mr *MockPackerMockRecorder) PackConnectionClose(arg0, arg1, arg2 interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "PackConnectionClose", reflect.TypeOf((*MockPacker)(nil).PackConnectionClose), arg0, arg1) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "PackConnectionClose", reflect.TypeOf((*MockPacker)(nil).PackConnectionClose), arg0, arg1, arg2) } // PackMTUProbePacket mocks base method. @@ -127,9 +114,9 @@ func (mr *MockPackerMockRecorder) PackMTUProbePacket(arg0, arg1, arg2, arg3 inte } // PackPacket mocks base method. -func (m *MockPacker) PackPacket(arg0 bool, arg1 time.Time, arg2 protocol.VersionNumber) (shortHeaderPacket, *packetBuffer, error) { +func (m *MockPacker) PackPacket(arg0 bool, arg1 time.Time, arg2 protocol.ByteCount, arg3 protocol.VersionNumber) (shortHeaderPacket, *packetBuffer, error) { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "PackPacket", arg0, arg1, arg2) + ret := m.ctrl.Call(m, "PackPacket", arg0, arg1, arg2, arg3) ret0, _ := ret[0].(shortHeaderPacket) ret1, _ := ret[1].(*packetBuffer) ret2, _ := ret[2].(error) @@ -137,21 +124,9 @@ func (m *MockPacker) PackPacket(arg0 bool, arg1 time.Time, arg2 protocol.Version } // PackPacket indicates an expected call of PackPacket. -func (mr *MockPackerMockRecorder) PackPacket(arg0, arg1, arg2 interface{}) *gomock.Call { +func (mr *MockPackerMockRecorder) PackPacket(arg0, arg1, arg2, arg3 interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "PackPacket", reflect.TypeOf((*MockPacker)(nil).PackPacket), arg0, arg1, arg2) -} - -// SetMaxPacketSize mocks base method. -func (m *MockPacker) SetMaxPacketSize(arg0 protocol.ByteCount) { - m.ctrl.T.Helper() - m.ctrl.Call(m, "SetMaxPacketSize", arg0) -} - -// SetMaxPacketSize indicates an expected call of SetMaxPacketSize. -func (mr *MockPackerMockRecorder) SetMaxPacketSize(arg0 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetMaxPacketSize", reflect.TypeOf((*MockPacker)(nil).SetMaxPacketSize), arg0) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "PackPacket", reflect.TypeOf((*MockPacker)(nil).PackPacket), arg0, arg1, arg2, arg3) } // SetToken mocks base method. diff --git a/mtu_discoverer.go b/mtu_discoverer.go index 6dbc22997..957abd37f 100644 --- a/mtu_discoverer.go +++ b/mtu_discoverer.go @@ -1,6 +1,7 @@ package quic import ( + "net" "time" "github.com/quic-go/quic-go/internal/ackhandler" @@ -26,6 +27,20 @@ const ( mtuProbeDelay = 5 ) +func getMaxPacketSize(addr net.Addr) protocol.ByteCount { + maxSize := protocol.ByteCount(protocol.MinInitialPacketSize) + // If this is not a UDP address, we don't know anything about the MTU. + // Use the minimum size of an Initial packet as the max packet size. + if udpAddr, ok := addr.(*net.UDPAddr); ok { + if utils.IsIPv4(udpAddr.IP) { + maxSize = protocol.InitialPacketSizeIPv4 + } else { + maxSize = protocol.InitialPacketSizeIPv6 + } + } + return maxSize +} + type mtuFinder struct { lastProbeTime time.Time probeInFlight bool diff --git a/packet_packer.go b/packet_packer.go index 7e06040c5..b8549fe49 100644 --- a/packet_packer.go +++ b/packet_packer.go @@ -3,30 +3,25 @@ package quic import ( "errors" "fmt" - "net" "time" "github.com/quic-go/quic-go/internal/ackhandler" "github.com/quic-go/quic-go/internal/handshake" "github.com/quic-go/quic-go/internal/protocol" "github.com/quic-go/quic-go/internal/qerr" - "github.com/quic-go/quic-go/internal/utils" "github.com/quic-go/quic-go/internal/wire" ) var errNothingToPack = errors.New("nothing to pack") type packer interface { - PackCoalescedPacket(onlyAck bool, v protocol.VersionNumber) (*coalescedPacket, error) - PackPacket(onlyAck bool, now time.Time, v protocol.VersionNumber) (shortHeaderPacket, *packetBuffer, error) - MaybePackProbePacket(protocol.EncryptionLevel, protocol.VersionNumber) (*coalescedPacket, error) - PackConnectionClose(*qerr.TransportError, protocol.VersionNumber) (*coalescedPacket, error) - PackApplicationClose(*qerr.ApplicationError, protocol.VersionNumber) (*coalescedPacket, error) - - SetMaxPacketSize(protocol.ByteCount) + PackCoalescedPacket(onlyAck bool, maxPacketSize protocol.ByteCount, v protocol.VersionNumber) (*coalescedPacket, error) + PackPacket(onlyAck bool, now time.Time, maxPacketSize protocol.ByteCount, v protocol.VersionNumber) (shortHeaderPacket, *packetBuffer, error) + MaybePackProbePacket(protocol.EncryptionLevel, protocol.ByteCount, protocol.VersionNumber) (*coalescedPacket, error) + PackConnectionClose(*qerr.TransportError, protocol.ByteCount, protocol.VersionNumber) (*coalescedPacket, error) + PackApplicationClose(*qerr.ApplicationError, protocol.ByteCount, protocol.VersionNumber) (*coalescedPacket, error) PackMTUProbePacket(ping ackhandler.Frame, size protocol.ByteCount, now time.Time, v protocol.VersionNumber) (shortHeaderPacket, *packetBuffer, error) - HandleTransportParameters(*wire.TransportParameters) SetToken([]byte) } @@ -118,20 +113,6 @@ func (p *longHeaderPacket) ToAckHandlerPacket(now time.Time, q *retransmissionQu return ap } -func getMaxPacketSize(addr net.Addr) protocol.ByteCount { - maxSize := protocol.ByteCount(protocol.MinInitialPacketSize) - // If this is not a UDP address, we don't know anything about the MTU. - // Use the minimum size of an Initial packet as the max packet size. - if udpAddr, ok := addr.(*net.UDPAddr); ok { - if utils.IsIPv4(udpAddr.IP) { - maxSize = protocol.InitialPacketSizeIPv4 - } else { - maxSize = protocol.InitialPacketSizeIPv6 - } - } - return maxSize -} - type packetNumberManager interface { PeekPacketNumber(protocol.EncryptionLevel) (protocol.PacketNumber, protocol.PacketNumberLen) PopPacketNumber(protocol.EncryptionLevel) protocol.PacketNumber @@ -172,13 +153,23 @@ type packetPacker struct { datagramQueue *datagramQueue retransmissionQueue *retransmissionQueue - maxPacketSize protocol.ByteCount numNonAckElicitingAcks int } var _ packer = &packetPacker{} -func newPacketPacker(srcConnID protocol.ConnectionID, getDestConnID func() protocol.ConnectionID, initialStream cryptoStream, handshakeStream cryptoStream, packetNumberManager packetNumberManager, retransmissionQueue *retransmissionQueue, remoteAddr net.Addr, cryptoSetup sealingManager, framer frameSource, acks ackFrameSource, datagramQueue *datagramQueue, perspective protocol.Perspective) *packetPacker { +func newPacketPacker( + srcConnID protocol.ConnectionID, + getDestConnID func() protocol.ConnectionID, + initialStream, handshakeStream cryptoStream, + packetNumberManager packetNumberManager, + retransmissionQueue *retransmissionQueue, + cryptoSetup sealingManager, + framer frameSource, + acks ackFrameSource, + datagramQueue *datagramQueue, + perspective protocol.Perspective, +) *packetPacker { return &packetPacker{ cryptoSetup: cryptoSetup, getDestConnID: getDestConnID, @@ -191,23 +182,22 @@ func newPacketPacker(srcConnID protocol.ConnectionID, getDestConnID func() proto framer: framer, acks: acks, pnManager: packetNumberManager, - maxPacketSize: getMaxPacketSize(remoteAddr), } } // PackConnectionClose packs a packet that closes the connection with a transport error. -func (p *packetPacker) PackConnectionClose(e *qerr.TransportError, v protocol.VersionNumber) (*coalescedPacket, error) { +func (p *packetPacker) PackConnectionClose(e *qerr.TransportError, maxPacketSize protocol.ByteCount, v protocol.VersionNumber) (*coalescedPacket, error) { var reason string // don't send details of crypto errors if !e.ErrorCode.IsCryptoError() { reason = e.ErrorMessage } - return p.packConnectionClose(false, uint64(e.ErrorCode), e.FrameType, reason, v) + return p.packConnectionClose(false, uint64(e.ErrorCode), e.FrameType, reason, maxPacketSize, v) } // PackApplicationClose packs a packet that closes the connection with an application error. -func (p *packetPacker) PackApplicationClose(e *qerr.ApplicationError, v protocol.VersionNumber) (*coalescedPacket, error) { - return p.packConnectionClose(true, uint64(e.ErrorCode), 0, e.ErrorMessage, v) +func (p *packetPacker) PackApplicationClose(e *qerr.ApplicationError, maxPacketSize protocol.ByteCount, v protocol.VersionNumber) (*coalescedPacket, error) { + return p.packConnectionClose(true, uint64(e.ErrorCode), 0, e.ErrorMessage, maxPacketSize, v) } func (p *packetPacker) packConnectionClose( @@ -215,6 +205,7 @@ func (p *packetPacker) packConnectionClose( errorCode uint64, frameType uint64, reason string, + maxPacketSize protocol.ByteCount, v protocol.VersionNumber, ) (*coalescedPacket, error) { var sealers [4]sealer @@ -296,10 +287,10 @@ func (p *packetPacker) packConnectionClose( } var paddingLen protocol.ByteCount if encLevel == protocol.EncryptionInitial { - paddingLen = p.initialPaddingLen(payloads[i].frames, size) + paddingLen = p.initialPaddingLen(payloads[i].frames, size, maxPacketSize) } if encLevel == protocol.Encryption1RTT { - ap, ack, err := p.appendShortHeaderPacket(buffer, connID, oneRTTPacketNumber, oneRTTPacketNumberLen, keyPhase, payloads[i], paddingLen, sealers[i], false, v) + ap, ack, err := p.appendShortHeaderPacket(buffer, connID, oneRTTPacketNumber, oneRTTPacketNumberLen, keyPhase, payloads[i], paddingLen, maxPacketSize, sealers[i], false, v) if err != nil { return nil, err } @@ -345,25 +336,21 @@ func (p *packetPacker) shortHeaderPacketLength(connID protocol.ConnectionID, pnL } // size is the expected size of the packet, if no padding was applied. -func (p *packetPacker) initialPaddingLen(frames []ackhandler.Frame, size protocol.ByteCount) protocol.ByteCount { +func (p *packetPacker) initialPaddingLen(frames []ackhandler.Frame, currentSize, maxPacketSize protocol.ByteCount) protocol.ByteCount { // For the server, only ack-eliciting Initial packets need to be padded. if p.perspective == protocol.PerspectiveServer && !ackhandler.HasAckElicitingFrames(frames) { return 0 } - if size >= p.maxPacketSize { + if currentSize >= maxPacketSize { return 0 } - return p.maxPacketSize - size + return maxPacketSize - currentSize } // PackCoalescedPacket packs a new packet. // It packs an Initial / Handshake if there is data to send in these packet number spaces. // It should only be called before the handshake is confirmed. -func (p *packetPacker) PackCoalescedPacket(onlyAck bool, v protocol.VersionNumber) (*coalescedPacket, error) { - maxPacketSize := p.maxPacketSize - if p.perspective == protocol.PerspectiveClient { - maxPacketSize = protocol.MinInitialPacketSize - } +func (p *packetPacker) PackCoalescedPacket(onlyAck bool, maxPacketSize protocol.ByteCount, v protocol.VersionNumber) (*coalescedPacket, error) { var ( initialHdr, handshakeHdr, zeroRTTHdr *wire.ExtendedHeader initialPayload, handshakePayload, zeroRTTPayload, oneRTTPayload payload @@ -445,7 +432,7 @@ func (p *packetPacker) PackCoalescedPacket(onlyAck bool, v protocol.VersionNumbe longHdrPackets: make([]*longHeaderPacket, 0, 3), } if initialPayload.length > 0 { - padding := p.initialPaddingLen(initialPayload.frames, size) + padding := p.initialPaddingLen(initialPayload.frames, size, maxPacketSize) cont, err := p.appendLongHeaderPacket(buffer, initialHdr, initialPayload, padding, protocol.EncryptionInitial, initialSealer, v) if err != nil { return nil, err @@ -466,7 +453,7 @@ func (p *packetPacker) PackCoalescedPacket(onlyAck bool, v protocol.VersionNumbe } packet.longHdrPackets = append(packet.longHdrPackets, longHdrPacket) } else if oneRTTPayload.length > 0 { - ap, ack, err := p.appendShortHeaderPacket(buffer, connID, oneRTTPacketNumber, oneRTTPacketNumberLen, kp, oneRTTPayload, 0, oneRTTSealer, false, v) + ap, ack, err := p.appendShortHeaderPacket(buffer, connID, oneRTTPacketNumber, oneRTTPacketNumberLen, kp, oneRTTPayload, 0, maxPacketSize, oneRTTSealer, false, v) if err != nil { return nil, err } @@ -483,7 +470,7 @@ func (p *packetPacker) PackCoalescedPacket(onlyAck bool, v protocol.VersionNumbe // PackPacket packs a packet in the application data packet number space. // It should be called after the handshake is confirmed. -func (p *packetPacker) PackPacket(onlyAck bool, now time.Time, v protocol.VersionNumber) (shortHeaderPacket, *packetBuffer, error) { +func (p *packetPacker) PackPacket(onlyAck bool, now time.Time, maxPacketSize protocol.ByteCount, v protocol.VersionNumber) (shortHeaderPacket, *packetBuffer, error) { sealer, err := p.cryptoSetup.Get1RTTSealer() if err != nil { return shortHeaderPacket{}, nil, err @@ -491,13 +478,13 @@ func (p *packetPacker) PackPacket(onlyAck bool, now time.Time, v protocol.Versio pn, pnLen := p.pnManager.PeekPacketNumber(protocol.Encryption1RTT) connID := p.getDestConnID() hdrLen := wire.ShortHeaderLen(connID, pnLen) - pl := p.maybeGetShortHeaderPacket(sealer, hdrLen, p.maxPacketSize, onlyAck, true, v) + pl := p.maybeGetShortHeaderPacket(sealer, hdrLen, maxPacketSize, onlyAck, true, v) if pl.length == 0 { return shortHeaderPacket{}, nil, errNothingToPack } kp := sealer.KeyPhase() buffer := getPacketBuffer() - ap, ack, err := p.appendShortHeaderPacket(buffer, connID, pn, pnLen, kp, pl, 0, sealer, false, v) + ap, ack, err := p.appendShortHeaderPacket(buffer, connID, pn, pnLen, kp, pl, 0, maxPacketSize, sealer, false, v) if err != nil { return shortHeaderPacket{}, nil, err } @@ -687,7 +674,7 @@ func (p *packetPacker) composeNextPacket(maxFrameSize protocol.ByteCount, onlyAc return pl } -func (p *packetPacker) MaybePackProbePacket(encLevel protocol.EncryptionLevel, v protocol.VersionNumber) (*coalescedPacket, error) { +func (p *packetPacker) MaybePackProbePacket(encLevel protocol.EncryptionLevel, maxPacketSize protocol.ByteCount, v protocol.VersionNumber) (*coalescedPacket, error) { if encLevel == protocol.Encryption1RTT { s, err := p.cryptoSetup.Get1RTTSealer() if err != nil { @@ -697,13 +684,13 @@ func (p *packetPacker) MaybePackProbePacket(encLevel protocol.EncryptionLevel, v connID := p.getDestConnID() pn, pnLen := p.pnManager.PeekPacketNumber(protocol.Encryption1RTT) hdrLen := wire.ShortHeaderLen(connID, pnLen) - pl := p.maybeGetAppDataPacket(p.maxPacketSize-protocol.ByteCount(s.Overhead())-hdrLen, false, true, v) + pl := p.maybeGetAppDataPacket(maxPacketSize-protocol.ByteCount(s.Overhead())-hdrLen, false, true, v) if pl.length == 0 { return nil, nil } buffer := getPacketBuffer() packet := &coalescedPacket{buffer: buffer} - ap, ack, err := p.appendShortHeaderPacket(buffer, connID, pn, pnLen, kp, pl, 0, s, false, v) + ap, ack, err := p.appendShortHeaderPacket(buffer, connID, pn, pnLen, kp, pl, 0, maxPacketSize, s, false, v) if err != nil { return nil, err } @@ -728,14 +715,14 @@ func (p *packetPacker) MaybePackProbePacket(encLevel protocol.EncryptionLevel, v if err != nil { return nil, err } - hdr, pl = p.maybeGetCryptoPacket(p.maxPacketSize-protocol.ByteCount(sealer.Overhead()), protocol.EncryptionInitial, false, true, v) + hdr, pl = p.maybeGetCryptoPacket(maxPacketSize-protocol.ByteCount(sealer.Overhead()), protocol.EncryptionInitial, false, true, v) case protocol.EncryptionHandshake: var err error sealer, err = p.cryptoSetup.GetHandshakeSealer() if err != nil { return nil, err } - hdr, pl = p.maybeGetCryptoPacket(p.maxPacketSize-protocol.ByteCount(sealer.Overhead()), protocol.EncryptionHandshake, false, true, v) + hdr, pl = p.maybeGetCryptoPacket(maxPacketSize-protocol.ByteCount(sealer.Overhead()), protocol.EncryptionHandshake, false, true, v) default: panic("unknown encryption level") } @@ -748,7 +735,7 @@ func (p *packetPacker) MaybePackProbePacket(encLevel protocol.EncryptionLevel, v size := p.longHeaderPacketLength(hdr, pl, v) + protocol.ByteCount(sealer.Overhead()) var padding protocol.ByteCount if encLevel == protocol.EncryptionInitial { - padding = p.initialPaddingLen(pl.frames, size) + padding = p.initialPaddingLen(pl.frames, size, maxPacketSize) } longHdrPacket, err := p.appendLongHeaderPacket(buffer, hdr, pl, padding, encLevel, sealer, v) @@ -773,7 +760,7 @@ func (p *packetPacker) PackMTUProbePacket(ping ackhandler.Frame, size protocol.B pn, pnLen := p.pnManager.PeekPacketNumber(protocol.Encryption1RTT) padding := size - p.shortHeaderPacketLength(connID, pnLen, pl) - protocol.ByteCount(s.Overhead()) kp := s.KeyPhase() - ap, ack, err := p.appendShortHeaderPacket(buffer, connID, pn, pnLen, kp, pl, padding, s, true, v) + ap, ack, err := p.appendShortHeaderPacket(buffer, connID, pn, pnLen, kp, pl, padding, size, s, true, v) if err != nil { return shortHeaderPacket{}, nil, err } @@ -854,7 +841,7 @@ func (p *packetPacker) appendShortHeaderPacket( pnLen protocol.PacketNumberLen, kp protocol.KeyPhaseBit, pl payload, - padding protocol.ByteCount, + padding, maxPacketSize protocol.ByteCount, sealer sealer, isMTUProbePacket bool, v protocol.VersionNumber, @@ -882,8 +869,8 @@ func (p *packetPacker) appendShortHeaderPacket( return nil, nil, err } if !isMTUProbePacket { - if size := protocol.ByteCount(len(raw) + sealer.Overhead()); size > p.maxPacketSize { - return nil, nil, fmt.Errorf("PacketPacker BUG: packet too large (%d bytes, allowed %d bytes)", size, p.maxPacketSize) + if size := protocol.ByteCount(len(raw) + sealer.Overhead()); size > maxPacketSize { + return nil, nil, fmt.Errorf("PacketPacker BUG: packet too large (%d bytes, allowed %d bytes)", size, maxPacketSize) } } raw = p.encryptPacket(raw, sealer, pn, payloadOffset, protocol.ByteCount(pnLen)) @@ -959,16 +946,3 @@ func (p *packetPacker) encryptPacket(raw []byte, sealer sealer, pn protocol.Pack func (p *packetPacker) SetToken(token []byte) { p.token = token } - -// When a higher MTU is discovered, use it. -func (p *packetPacker) SetMaxPacketSize(s protocol.ByteCount) { - p.maxPacketSize = s -} - -// If the peer sets a max_packet_size that's smaller than the size we're currently using, -// we need to reduce the size of packets we send. -func (p *packetPacker) HandleTransportParameters(params *wire.TransportParameters) { - if params.MaxUDPPayloadSize != 0 { - p.maxPacketSize = utils.Min(p.maxPacketSize, params.MaxUDPPayloadSize) - } -} diff --git a/packet_packer_test.go b/packet_packer_test.go index 4fa2ce147..ffedeeb4f 100644 --- a/packet_packer_test.go +++ b/packet_packer_test.go @@ -96,8 +96,7 @@ var _ = Describe("Packet packer", func() { pnManager = mockackhandler.NewMockSentPacketHandler(mockCtrl) datagramQueue = newDatagramQueue(func() {}, utils.DefaultLogger) - packer = newPacketPacker(protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8}), func() protocol.ConnectionID { return connID }, initialStream, handshakeStream, pnManager, retransmissionQueue, &net.TCPAddr{}, sealingManager, framer, ackFramer, datagramQueue, protocol.PerspectiveServer) - packer.maxPacketSize = maxPacketSize + packer = newPacketPacker(protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8}), func() protocol.ConnectionID { return connID }, initialStream, handshakeStream, pnManager, retransmissionQueue, sealingManager, framer, ackFramer, datagramQueue, protocol.PerspectiveServer) }) Context("determining the maximum packet size", func() { @@ -169,7 +168,7 @@ var _ = Describe("Packet packer", func() { expectAppendControlFrames() f := &wire.StreamFrame{Data: []byte{0xde, 0xca, 0xfb, 0xad}} expectAppendStreamFrames(ackhandler.StreamFrame{Frame: f}) - p, err := packer.PackCoalescedPacket(false, protocol.Version1) + p, err := packer.PackCoalescedPacket(false, maxPacketSize, protocol.Version1) Expect(err).ToNot(HaveOccurred()) Expect(p).ToNot(BeNil()) Expect(p.longHdrPackets).To(BeEmpty()) @@ -208,7 +207,7 @@ var _ = Describe("Packet packer", func() { ackFramer.EXPECT().GetAckFrame(protocol.EncryptionInitial, true) ackFramer.EXPECT().GetAckFrame(protocol.EncryptionHandshake, true) ackFramer.EXPECT().GetAckFrame(protocol.Encryption1RTT, true) - p, err := packer.PackCoalescedPacket(true, protocol.Version1) + p, err := packer.PackCoalescedPacket(true, maxPacketSize, protocol.Version1) Expect(err).ToNot(HaveOccurred()) Expect(p).To(BeNil()) }) @@ -220,14 +219,14 @@ var _ = Describe("Packet packer", func() { sealingManager.EXPECT().GetInitialSealer().Return(getSealer(), nil) ack := &wire.AckFrame{AckRanges: []wire.AckRange{{Smallest: 1, Largest: 10}}} ackFramer.EXPECT().GetAckFrame(protocol.EncryptionInitial, true).Return(ack) - p, err := packer.PackCoalescedPacket(true, protocol.Version1) + p, err := packer.PackCoalescedPacket(true, maxPacketSize, protocol.Version1) Expect(err).NotTo(HaveOccurred()) Expect(p).ToNot(BeNil()) Expect(p.longHdrPackets).To(HaveLen(1)) Expect(p.longHdrPackets[0].EncryptionLevel()).To(Equal(protocol.EncryptionInitial)) Expect(p.longHdrPackets[0].ack).To(Equal(ack)) Expect(p.longHdrPackets[0].frames).To(BeEmpty()) - Expect(p.buffer.Len()).To(BeEquivalentTo(packer.maxPacketSize)) + Expect(p.buffer.Len()).To(BeEquivalentTo(maxPacketSize)) parsePacket(p.buffer.Data) }) @@ -237,7 +236,7 @@ var _ = Describe("Packet packer", func() { sealingManager.EXPECT().GetInitialSealer().Return(getSealer(), nil) ack := &wire.AckFrame{AckRanges: []wire.AckRange{{Smallest: 1, Largest: 10}}} ackFramer.EXPECT().GetAckFrame(protocol.EncryptionInitial, true).Return(ack) - p, err := packer.PackCoalescedPacket(true, protocol.Version1) + p, err := packer.PackCoalescedPacket(true, maxPacketSize, protocol.Version1) Expect(err).NotTo(HaveOccurred()) Expect(p).ToNot(BeNil()) Expect(p.longHdrPackets).To(HaveLen(1)) @@ -257,7 +256,7 @@ var _ = Describe("Packet packer", func() { ack := &wire.AckFrame{AckRanges: []wire.AckRange{{Smallest: 1, Largest: 10}}} ackFramer.EXPECT().GetAckFrame(protocol.EncryptionHandshake, true) ackFramer.EXPECT().GetAckFrame(protocol.Encryption1RTT, true).Return(ack) - p, err := packer.PackCoalescedPacket(true, protocol.Version1) + p, err := packer.PackCoalescedPacket(true, maxPacketSize, protocol.Version1) Expect(err).NotTo(HaveOccurred()) Expect(p).ToNot(BeNil()) Expect(p.longHdrPackets).To(BeEmpty()) @@ -273,7 +272,7 @@ var _ = Describe("Packet packer", func() { sealingManager.EXPECT().Get1RTTSealer().Return(getSealer(), nil) ack := &wire.AckFrame{AckRanges: []wire.AckRange{{Smallest: 1, Largest: 10}}} ackFramer.EXPECT().GetAckFrame(protocol.Encryption1RTT, true).Return(ack) - p, buffer, err := packer.PackPacket(true, time.Now(), protocol.Version1) + p, buffer, err := packer.PackPacket(true, time.Now(), maxPacketSize, protocol.Version1) Expect(err).NotTo(HaveOccurred()) Expect(p).ToNot(BeNil()) Expect(p.Ack).To(Equal(ack)) @@ -309,7 +308,7 @@ var _ = Describe("Packet packer", func() { framer.EXPECT().AppendStreamFrames(gomock.Any(), gomock.Any(), protocol.Version1).DoAndReturn(func(frames []ackhandler.StreamFrame, _ protocol.ByteCount, _ protocol.VersionNumber) ([]ackhandler.StreamFrame, protocol.ByteCount) { return frames, 0 }) - p, err := packer.PackCoalescedPacket(false, protocol.Version1) + p, err := packer.PackCoalescedPacket(false, maxPacketSize, protocol.Version1) Expect(p).ToNot(BeNil()) Expect(err).ToNot(HaveOccurred()) Expect(p.longHdrPackets).To(HaveLen(1)) @@ -319,7 +318,7 @@ var _ = Describe("Packet packer", func() { }) It("doesn't add an ACK-only 0-RTT packet", func() { // ACK frames cannot be sent in 0-RTT packets - p, err := packer.PackCoalescedPacket(true, protocol.Version1) + p, err := packer.PackCoalescedPacket(true, protocol.MaxByteCount, protocol.Version1) Expect(err).ToNot(HaveOccurred()) Expect(p).To(BeNil()) }) @@ -334,7 +333,7 @@ var _ = Describe("Packet packer", func() { sealingManager.EXPECT().Get1RTTSealer().Return(nil, handshake.ErrKeysNotYetAvailable) quicErr := qerr.NewLocalCryptoError(0x42, "crypto error") quicErr.FrameType = 0x1234 - p, err := packer.PackConnectionClose(quicErr, protocol.Version1) + p, err := packer.PackConnectionClose(quicErr, maxPacketSize, protocol.Version1) Expect(err).ToNot(HaveOccurred()) Expect(p.longHdrPackets).To(HaveLen(1)) Expect(p.longHdrPackets[0].header.Type).To(Equal(protocol.PacketTypeHandshake)) @@ -357,7 +356,7 @@ var _ = Describe("Packet packer", func() { p, err := packer.PackConnectionClose(&qerr.TransportError{ ErrorCode: qerr.CryptoBufferExceeded, ErrorMessage: "test error", - }, protocol.Version1) + }, maxPacketSize, protocol.Version1) Expect(err).ToNot(HaveOccurred()) Expect(p.longHdrPackets).To(BeEmpty()) Expect(p.shortHdrPacket.Frames).To(HaveLen(1)) @@ -381,7 +380,7 @@ var _ = Describe("Packet packer", func() { p, err := packer.PackApplicationClose(&qerr.ApplicationError{ ErrorCode: 0x1337, ErrorMessage: "test error", - }, protocol.Version1) + }, maxPacketSize, protocol.Version1) Expect(err).ToNot(HaveOccurred()) Expect(p.longHdrPackets).To(HaveLen(2)) Expect(p.longHdrPackets[0].header.Type).To(Equal(protocol.PacketTypeInitial)) @@ -423,7 +422,7 @@ var _ = Describe("Packet packer", func() { p, err := packer.PackApplicationClose(&qerr.ApplicationError{ ErrorCode: 0x1337, ErrorMessage: "test error", - }, protocol.Version1) + }, maxPacketSize, protocol.Version1) Expect(err).ToNot(HaveOccurred()) Expect(p.longHdrPackets).To(HaveLen(1)) Expect(p.buffer.Len()).To(BeNumerically("<", protocol.MinInitialPacketSize)) @@ -458,7 +457,7 @@ var _ = Describe("Packet packer", func() { p, err := packer.PackApplicationClose(&qerr.ApplicationError{ ErrorCode: 0x1337, ErrorMessage: "test error", - }, protocol.Version1) + }, maxPacketSize, protocol.Version1) Expect(err).ToNot(HaveOccurred()) Expect(p.longHdrPackets).To(HaveLen(2)) Expect(p.buffer.Len()).To(BeNumerically(">=", protocol.MinInitialPacketSize)) @@ -494,7 +493,7 @@ var _ = Describe("Packet packer", func() { sealingManager.EXPECT().Get1RTTSealer().Return(getSealer(), nil) ackFramer.EXPECT().GetAckFrame(protocol.Encryption1RTT, true) framer.EXPECT().HasData() - _, _, err := packer.PackPacket(false, time.Now(), protocol.Version1) + _, _, err := packer.PackPacket(false, time.Now(), maxPacketSize, protocol.Version1) Expect(err).To(MatchError(errNothingToPack)) }) @@ -510,7 +509,7 @@ var _ = Describe("Packet packer", func() { Data: []byte{0xde, 0xca, 0xfb, 0xad}, } expectAppendStreamFrames(ackhandler.StreamFrame{Frame: f}) - p, buffer, err := packer.PackPacket(false, time.Now(), protocol.Version1) + p, buffer, err := packer.PackPacket(false, time.Now(), maxPacketSize, protocol.Version1) Expect(err).ToNot(HaveOccurred()) Expect(p).ToNot(BeNil()) b, err := f.Append(nil, protocol.Version1) @@ -528,7 +527,7 @@ var _ = Describe("Packet packer", func() { framer.EXPECT().HasData() ackFramer.EXPECT().GetAckFrame(protocol.Encryption1RTT, true).Return(ack) sealingManager.EXPECT().Get1RTTSealer().Return(getSealer(), nil) - p, _, err := packer.PackPacket(false, time.Now(), protocol.Version1) + p, _, err := packer.PackPacket(false, time.Now(), maxPacketSize, protocol.Version1) Expect(err).NotTo(HaveOccurred()) Expect(p).ToNot(BeNil()) Expect(p.Ack).To(Equal(ack)) @@ -546,7 +545,7 @@ var _ = Describe("Packet packer", func() { } expectAppendControlFrames(frames...) expectAppendStreamFrames() - p, buffer, err := packer.PackPacket(false, time.Now(), protocol.Version1) + p, buffer, err := packer.PackPacket(false, time.Now(), maxPacketSize, protocol.Version1) Expect(p).ToNot(BeNil()) Expect(err).ToNot(HaveOccurred()) Expect(p.Frames).To(HaveLen(2)) @@ -575,7 +574,7 @@ var _ = Describe("Packet packer", func() { time.Sleep(scaleDuration(20 * time.Millisecond)) framer.EXPECT().HasData() - p, buffer, err := packer.PackPacket(false, time.Now(), protocol.Version1) + p, buffer, err := packer.PackPacket(false, time.Now(), maxPacketSize, protocol.Version1) Expect(p).ToNot(BeNil()) Expect(err).ToNot(HaveOccurred()) Expect(p.Frames).To(HaveLen(1)) @@ -603,7 +602,7 @@ var _ = Describe("Packet packer", func() { time.Sleep(scaleDuration(20 * time.Millisecond)) framer.EXPECT().HasData() - p, buffer, err := packer.PackPacket(false, time.Now(), protocol.Version1) + p, buffer, err := packer.PackPacket(false, time.Now(), maxPacketSize, protocol.Version1) Expect(p).ToNot(BeNil()) Expect(err).ToNot(HaveOccurred()) Expect(p.Ack).ToNot(BeNil()) @@ -629,7 +628,7 @@ var _ = Describe("Packet packer", func() { return fs, 0 }), ) - _, _, err := packer.PackPacket(false, time.Now(), protocol.Version1) + _, _, err := packer.PackPacket(false, time.Now(), maxPacketSize, protocol.Version1) Expect(err).To(MatchError(errNothingToPack)) }) @@ -643,7 +642,7 @@ var _ = Describe("Packet packer", func() { packer.retransmissionQueue.AddHandshake(&wire.PingFrame{}) handshakeStream.EXPECT().HasData() ackFramer.EXPECT().GetAckFrame(protocol.EncryptionHandshake, false) - packet, err := packer.PackCoalescedPacket(false, protocol.Version1) + packet, err := packer.PackCoalescedPacket(false, maxPacketSize, protocol.Version1) Expect(err).ToNot(HaveOccurred()) Expect(packet).ToNot(BeNil()) Expect(packet.longHdrPackets).To(HaveLen(1)) @@ -686,7 +685,7 @@ var _ = Describe("Packet packer", func() { ackFramer.EXPECT().GetAckFrame(protocol.Encryption1RTT, false) expectAppendControlFrames() expectAppendStreamFrames(ackhandler.StreamFrame{Frame: f}) - _, buffer, err := packer.PackPacket(false, time.Now(), protocol.Version1) + _, buffer, err := packer.PackPacket(false, time.Now(), maxPacketSize, protocol.Version1) Expect(err).ToNot(HaveOccurred()) // cut off the tag that the mock sealer added buffer.Data = buffer.Data[:buffer.Len()-protocol.ByteCount(sealer.Overhead())] @@ -735,7 +734,7 @@ var _ = Describe("Packet packer", func() { ackFramer.EXPECT().GetAckFrame(protocol.Encryption1RTT, false) expectAppendControlFrames() expectAppendStreamFrames(ackhandler.StreamFrame{Frame: f1}, ackhandler.StreamFrame{Frame: f2}, ackhandler.StreamFrame{Frame: f3}) - p, _, err := packer.PackPacket(false, time.Now(), protocol.Version1) + p, _, err := packer.PackPacket(false, time.Now(), maxPacketSize, protocol.Version1) Expect(p).ToNot(BeNil()) Expect(err).ToNot(HaveOccurred()) Expect(p.Frames).To(BeEmpty()) @@ -755,7 +754,7 @@ var _ = Describe("Packet packer", func() { ackFramer.EXPECT().GetAckFrame(protocol.Encryption1RTT, false).Return(&wire.AckFrame{AckRanges: []wire.AckRange{{Smallest: 1, Largest: 1}}}) expectAppendControlFrames() expectAppendStreamFrames() - p, _, err := packer.PackPacket(false, time.Now(), protocol.Version1) + p, _, err := packer.PackPacket(false, time.Now(), maxPacketSize, protocol.Version1) Expect(p).ToNot(BeNil()) Expect(err).ToNot(HaveOccurred()) Expect(p.Ack).ToNot(BeNil()) @@ -772,7 +771,7 @@ var _ = Describe("Packet packer", func() { ackFramer.EXPECT().GetAckFrame(protocol.Encryption1RTT, false).Return(&wire.AckFrame{AckRanges: []wire.AckRange{{Smallest: 1, Largest: 1}}}) expectAppendControlFrames() expectAppendStreamFrames() - p, _, err := packer.PackPacket(false, time.Now(), protocol.Version1) + p, _, err := packer.PackPacket(false, time.Now(), maxPacketSize, protocol.Version1) Expect(p).ToNot(BeNil()) Expect(err).ToNot(HaveOccurred()) var hasPing bool @@ -791,7 +790,7 @@ var _ = Describe("Packet packer", func() { ackFramer.EXPECT().GetAckFrame(protocol.Encryption1RTT, false).Return(&wire.AckFrame{AckRanges: []wire.AckRange{{Smallest: 1, Largest: 1}}}) expectAppendControlFrames() expectAppendStreamFrames() - p, _, err = packer.PackPacket(false, time.Now(), protocol.Version1) + p, _, err = packer.PackPacket(false, time.Now(), maxPacketSize, protocol.Version1) Expect(p).ToNot(BeNil()) Expect(err).ToNot(HaveOccurred()) Expect(p.Ack).ToNot(BeNil()) @@ -807,7 +806,7 @@ var _ = Describe("Packet packer", func() { expectAppendControlFrames() expectAppendStreamFrames() ackFramer.EXPECT().GetAckFrame(protocol.Encryption1RTT, false) - _, _, err := packer.PackPacket(false, time.Now(), protocol.Version1) + _, _, err := packer.PackPacket(false, time.Now(), maxPacketSize, protocol.Version1) Expect(err).To(MatchError(errNothingToPack)) // now add some frame to send expectAppendControlFrames() @@ -818,7 +817,7 @@ var _ = Describe("Packet packer", func() { framer.EXPECT().HasData().Return(true) ack := &wire.AckFrame{AckRanges: []wire.AckRange{{Smallest: 1, Largest: 1}}} ackFramer.EXPECT().GetAckFrame(protocol.Encryption1RTT, false).Return(ack) - p, _, err := packer.PackPacket(false, time.Now(), protocol.Version1) + p, _, err := packer.PackPacket(false, time.Now(), maxPacketSize, protocol.Version1) Expect(err).ToNot(HaveOccurred()) Expect(p.Ack).To(Equal(ack)) var hasPing bool @@ -840,93 +839,12 @@ var _ = Describe("Packet packer", func() { ackFramer.EXPECT().GetAckFrame(protocol.Encryption1RTT, false) expectAppendStreamFrames() expectAppendControlFrames(ackhandler.Frame{Frame: &wire.MaxDataFrame{}}) - p, _, err := packer.PackPacket(false, time.Now(), protocol.Version1) + p, _, err := packer.PackPacket(false, time.Now(), maxPacketSize, protocol.Version1) Expect(err).ToNot(HaveOccurred()) Expect(p).ToNot(BeNil()) Expect(p.Frames).ToNot(ContainElement(&wire.PingFrame{})) }) }) - - Context("handling transport parameters", func() { - It("lowers the maximum packet size", func() { - pnManager.EXPECT().PeekPacketNumber(protocol.Encryption1RTT).Return(protocol.PacketNumber(0x42), protocol.PacketNumberLen2).Times(2) - sealingManager.EXPECT().Get1RTTSealer().Return(getSealer(), nil).Times(2) - framer.EXPECT().HasData().Return(true).Times(2) - ackFramer.EXPECT().GetAckFrame(protocol.Encryption1RTT, false).Times(2) - var initialMaxPacketSize protocol.ByteCount - framer.EXPECT().AppendControlFrames(gomock.Any(), gomock.Any(), protocol.Version1).Do(func(_ []ackhandler.Frame, maxLen protocol.ByteCount, _ protocol.VersionNumber) ([]ackhandler.Frame, protocol.ByteCount) { - initialMaxPacketSize = maxLen - return nil, 0 - }) - expectAppendStreamFrames() - _, _, err := packer.PackPacket(false, time.Now(), protocol.Version1) - Expect(err).To(MatchError(errNothingToPack)) - // now reduce the maxPacketSize - packer.HandleTransportParameters(&wire.TransportParameters{ - MaxUDPPayloadSize: maxPacketSize - 10, - }) - framer.EXPECT().AppendControlFrames(gomock.Any(), gomock.Any(), protocol.Version1).Do(func(_ []ackhandler.Frame, maxLen protocol.ByteCount, _ protocol.VersionNumber) ([]ackhandler.Frame, protocol.ByteCount) { - Expect(maxLen).To(Equal(initialMaxPacketSize - 10)) - return nil, 0 - }) - expectAppendStreamFrames() - _, _, err = packer.PackPacket(false, time.Now(), protocol.Version1) - Expect(err).To(MatchError(errNothingToPack)) - }) - - It("doesn't increase the max packet size", func() { - pnManager.EXPECT().PeekPacketNumber(protocol.Encryption1RTT).Return(protocol.PacketNumber(0x42), protocol.PacketNumberLen2).Times(2) - sealingManager.EXPECT().Get1RTTSealer().Return(getSealer(), nil).Times(2) - framer.EXPECT().HasData().Return(true).Times(2) - ackFramer.EXPECT().GetAckFrame(protocol.Encryption1RTT, false).Times(2) - var initialMaxPacketSize protocol.ByteCount - framer.EXPECT().AppendControlFrames(gomock.Any(), gomock.Any(), protocol.Version1).Do(func(_ []ackhandler.Frame, maxLen protocol.ByteCount, _ protocol.VersionNumber) ([]ackhandler.Frame, protocol.ByteCount) { - initialMaxPacketSize = maxLen - return nil, 0 - }) - expectAppendStreamFrames() - _, _, err := packer.PackPacket(false, time.Now(), protocol.Version1) - Expect(err).To(MatchError(errNothingToPack)) - // now try to increase the maxPacketSize - packer.HandleTransportParameters(&wire.TransportParameters{ - MaxUDPPayloadSize: maxPacketSize + 10, - }) - framer.EXPECT().AppendControlFrames(gomock.Any(), gomock.Any(), protocol.Version1).Do(func(_ []ackhandler.Frame, maxLen protocol.ByteCount, _ protocol.VersionNumber) ([]ackhandler.Frame, protocol.ByteCount) { - Expect(maxLen).To(Equal(initialMaxPacketSize)) - return nil, 0 - }) - expectAppendStreamFrames() - _, _, err = packer.PackPacket(false, time.Now(), protocol.Version1) - Expect(err).To(MatchError(errNothingToPack)) - }) - }) - - Context("max packet size", func() { - It("increases the max packet size", func() { - pnManager.EXPECT().PeekPacketNumber(protocol.Encryption1RTT).Return(protocol.PacketNumber(0x42), protocol.PacketNumberLen2).Times(2) - sealingManager.EXPECT().Get1RTTSealer().Return(getSealer(), nil).Times(2) - framer.EXPECT().HasData().Return(true).Times(2) - ackFramer.EXPECT().GetAckFrame(protocol.Encryption1RTT, false).Times(2) - var initialMaxPacketSize protocol.ByteCount - framer.EXPECT().AppendControlFrames(gomock.Any(), gomock.Any(), protocol.Version1).Do(func(_ []ackhandler.Frame, maxLen protocol.ByteCount, _ protocol.VersionNumber) ([]ackhandler.Frame, protocol.ByteCount) { - initialMaxPacketSize = maxLen - return nil, 0 - }) - expectAppendStreamFrames() - _, _, err := packer.PackPacket(false, time.Now(), protocol.Version1) - Expect(err).To(MatchError(errNothingToPack)) - // now reduce the maxPacketSize - const packetSizeIncrease = 50 - packer.SetMaxPacketSize(maxPacketSize + packetSizeIncrease) - framer.EXPECT().AppendControlFrames(gomock.Any(), gomock.Any(), protocol.Version1).Do(func(_ []ackhandler.Frame, maxLen protocol.ByteCount, _ protocol.VersionNumber) ([]ackhandler.Frame, protocol.ByteCount) { - Expect(maxLen).To(Equal(initialMaxPacketSize + packetSizeIncrease)) - return nil, 0 - }) - expectAppendStreamFrames() - _, _, err = packer.PackPacket(false, time.Now(), protocol.Version1) - Expect(err).To(MatchError(errNothingToPack)) - }) - }) }) Context("packing crypto packets", func() { @@ -943,7 +861,7 @@ var _ = Describe("Packet packer", func() { sealingManager.EXPECT().GetInitialSealer().Return(nil, handshake.ErrKeysDropped) sealingManager.EXPECT().GetHandshakeSealer().Return(getSealer(), nil) sealingManager.EXPECT().Get1RTTSealer().Return(nil, handshake.ErrKeysNotYetAvailable) - p, err := packer.PackCoalescedPacket(false, protocol.Version1) + p, err := packer.PackCoalescedPacket(false, maxPacketSize, protocol.Version1) Expect(err).ToNot(HaveOccurred()) Expect(p).ToNot(BeNil()) parsePacket(p.buffer.Data) @@ -962,7 +880,7 @@ var _ = Describe("Packet packer", func() { initialStream.EXPECT().PopCryptoFrame(gomock.Any()).DoAndReturn(func(size protocol.ByteCount) *wire.CryptoFrame { return &wire.CryptoFrame{Offset: 0x42, Data: []byte("initial")} }) - p, err := packer.PackCoalescedPacket(false, protocol.Version1) + p, err := packer.PackCoalescedPacket(false, maxPacketSize, protocol.Version1) Expect(err).ToNot(HaveOccurred()) Expect(p.buffer.Len()).To(BeNumerically(">=", protocol.MinInitialPacketSize)) Expect(p.buffer.Len()).To(BeEquivalentTo(maxPacketSize)) @@ -992,11 +910,11 @@ var _ = Describe("Packet packer", func() { Expect(f.Length(protocol.Version1)).To(Equal(size)) return f }) - p, err := packer.PackCoalescedPacket(false, protocol.Version1) + p, err := packer.PackCoalescedPacket(false, maxPacketSize, protocol.Version1) Expect(err).ToNot(HaveOccurred()) Expect(p.longHdrPackets).To(HaveLen(1)) Expect(p.longHdrPackets[0].frames).To(HaveLen(1)) - Expect(p.buffer.Len()).To(BeEquivalentTo(packer.maxPacketSize)) + Expect(p.buffer.Len()).To(BeEquivalentTo(maxPacketSize)) parsePacket(p.buffer.Data) }) @@ -1018,9 +936,9 @@ var _ = Describe("Packet packer", func() { handshakeStream.EXPECT().PopCryptoFrame(gomock.Any()).DoAndReturn(func(size protocol.ByteCount) *wire.CryptoFrame { return &wire.CryptoFrame{Offset: 0x1337, Data: []byte("handshake")} }) - p, err := packer.PackCoalescedPacket(false, protocol.Version1) + p, err := packer.PackCoalescedPacket(false, maxPacketSize, protocol.Version1) Expect(err).ToNot(HaveOccurred()) - Expect(p.buffer.Len()).To(BeEquivalentTo(packer.maxPacketSize)) + Expect(p.buffer.Len()).To(BeEquivalentTo(maxPacketSize)) Expect(p.longHdrPackets).To(HaveLen(2)) Expect(p.shortHdrPacket).To(BeNil()) Expect(p.longHdrPackets[0].EncryptionLevel()).To(Equal(protocol.EncryptionInitial)) @@ -1052,9 +970,9 @@ var _ = Describe("Packet packer", func() { }) handshakeStream.EXPECT().HasData() packer.retransmissionQueue.AddHandshake(&wire.PingFrame{}) - p, err := packer.PackCoalescedPacket(false, protocol.Version1) + p, err := packer.PackCoalescedPacket(false, maxPacketSize, protocol.Version1) Expect(err).ToNot(HaveOccurred()) - Expect(p.buffer.Len()).To(BeEquivalentTo(packer.maxPacketSize)) + Expect(p.buffer.Len()).To(BeEquivalentTo(maxPacketSize)) Expect(p.longHdrPackets).To(HaveLen(2)) Expect(p.shortHdrPacket).To(BeNil()) Expect(p.longHdrPackets[0].EncryptionLevel()).To(Equal(protocol.EncryptionInitial)) @@ -1083,9 +1001,9 @@ var _ = Describe("Packet packer", func() { handshakeStream.EXPECT().HasData() packer.retransmissionQueue.AddInitial(&wire.PingFrame{}) packer.retransmissionQueue.AddHandshake(&wire.PingFrame{}) - p, err := packer.PackCoalescedPacket(false, protocol.Version1) + p, err := packer.PackCoalescedPacket(false, maxPacketSize, protocol.Version1) Expect(err).ToNot(HaveOccurred()) - Expect(p.buffer.Len()).To(BeEquivalentTo(packer.maxPacketSize)) + Expect(p.buffer.Len()).To(BeEquivalentTo(maxPacketSize)) Expect(p.longHdrPackets).To(HaveLen(2)) Expect(p.longHdrPackets[0].EncryptionLevel()).To(Equal(protocol.EncryptionInitial)) Expect(p.longHdrPackets[0].frames).To(HaveLen(1)) @@ -1117,9 +1035,9 @@ var _ = Describe("Packet packer", func() { expectAppendStreamFrames() framer.EXPECT().HasData().Return(true) packer.retransmissionQueue.AddAppData(&wire.PingFrame{}) - p, err := packer.PackCoalescedPacket(false, protocol.Version1) + p, err := packer.PackCoalescedPacket(false, maxPacketSize, protocol.Version1) Expect(err).ToNot(HaveOccurred()) - Expect(p.buffer.Len()).To(BeEquivalentTo(packer.maxPacketSize)) + Expect(p.buffer.Len()).To(BeEquivalentTo(maxPacketSize)) Expect(p.longHdrPackets).To(HaveLen(1)) Expect(p.longHdrPackets[0].EncryptionLevel()).To(Equal(protocol.EncryptionInitial)) Expect(p.longHdrPackets[0].frames).To(HaveLen(1)) @@ -1153,7 +1071,7 @@ var _ = Describe("Packet packer", func() { }) expectAppendControlFrames() expectAppendStreamFrames(ackhandler.StreamFrame{Frame: &wire.StreamFrame{Data: []byte("foobar")}}) - p, err := packer.PackCoalescedPacket(false, protocol.Version1) + p, err := packer.PackCoalescedPacket(false, maxPacketSize, protocol.Version1) Expect(err).ToNot(HaveOccurred()) Expect(p.buffer.Len()).To(BeNumerically(">=", protocol.MinInitialPacketSize)) Expect(p.buffer.Len()).To(BeEquivalentTo(maxPacketSize)) @@ -1190,7 +1108,7 @@ var _ = Describe("Packet packer", func() { }) expectAppendControlFrames() expectAppendStreamFrames(ackhandler.StreamFrame{Frame: &wire.StreamFrame{Data: []byte("foobar")}}) - p, err := packer.PackCoalescedPacket(false, protocol.Version1) + p, err := packer.PackCoalescedPacket(false, maxPacketSize, protocol.Version1) Expect(err).ToNot(HaveOccurred()) Expect(p.buffer.Len()).To(BeNumerically("<", 100)) Expect(p.longHdrPackets).To(HaveLen(1)) @@ -1223,7 +1141,7 @@ var _ = Describe("Packet packer", func() { Expect(f.Length(protocol.Version1)).To(Equal(s)) return f }) - p, err := packer.PackCoalescedPacket(false, protocol.Version1) + p, err := packer.PackCoalescedPacket(false, maxPacketSize, protocol.Version1) Expect(err).ToNot(HaveOccurred()) Expect(p.longHdrPackets).To(HaveLen(1)) Expect(p.shortHdrPacket).To(BeNil()) @@ -1242,7 +1160,7 @@ var _ = Describe("Packet packer", func() { packer.retransmissionQueue.AddHandshake(&wire.PingFrame{}) handshakeStream.EXPECT().HasData() ackFramer.EXPECT().GetAckFrame(protocol.EncryptionHandshake, false) - packet, err := packer.PackCoalescedPacket(false, protocol.Version1) + packet, err := packer.PackCoalescedPacket(false, maxPacketSize, protocol.Version1) Expect(err).ToNot(HaveOccurred()) Expect(packet).ToNot(BeNil()) Expect(packet.longHdrPackets).To(HaveLen(1)) @@ -1283,7 +1201,7 @@ var _ = Describe("Packet packer", func() { sealingManager.EXPECT().Get1RTTSealer().Return(nil, handshake.ErrKeysNotYetAvailable) ackFramer.EXPECT().GetAckFrame(protocol.EncryptionInitial, false) initialStream.EXPECT().HasData() - p, err := packer.PackCoalescedPacket(false, protocol.Version1) + p, err := packer.PackCoalescedPacket(false, maxPacketSize, protocol.Version1) Expect(err).ToNot(HaveOccurred()) Expect(p.longHdrPackets).To(HaveLen(1)) Expect(p.longHdrPackets[0].EncryptionLevel()).To(Equal(protocol.EncryptionInitial)) @@ -1299,7 +1217,7 @@ var _ = Describe("Packet packer", func() { sealingManager.EXPECT().Get1RTTSealer().Return(nil, handshake.ErrKeysNotYetAvailable) pnManager.EXPECT().PeekPacketNumber(protocol.EncryptionInitial).Return(protocol.PacketNumber(0x42), protocol.PacketNumberLen2) pnManager.EXPECT().PopPacketNumber(protocol.EncryptionInitial).Return(protocol.PacketNumber(0x42)) - p, err := packer.PackCoalescedPacket(false, protocol.Version1) + p, err := packer.PackCoalescedPacket(false, maxPacketSize, protocol.Version1) Expect(err).ToNot(HaveOccurred()) Expect(p.longHdrPackets).To(HaveLen(1)) Expect(p.longHdrPackets[0].ack).To(Equal(ack)) @@ -1311,7 +1229,7 @@ var _ = Describe("Packet packer", func() { sealingManager.EXPECT().Get1RTTSealer().Return(nil, handshake.ErrKeysNotYetAvailable) initialStream.EXPECT().HasData() ackFramer.EXPECT().GetAckFrame(protocol.EncryptionInitial, true) - p, err := packer.PackCoalescedPacket(false, protocol.Version1) + p, err := packer.PackCoalescedPacket(false, maxPacketSize, protocol.Version1) Expect(err).ToNot(HaveOccurred()) Expect(p).To(BeNil()) }) @@ -1327,7 +1245,7 @@ var _ = Describe("Packet packer", func() { sealingManager.EXPECT().Get1RTTSealer().Return(nil, handshake.ErrKeysNotYetAvailable) pnManager.EXPECT().PeekPacketNumber(protocol.EncryptionHandshake).Return(protocol.PacketNumber(0x42), protocol.PacketNumberLen2) pnManager.EXPECT().PopPacketNumber(protocol.EncryptionHandshake).Return(protocol.PacketNumber(0x42)) - p, err := packer.PackCoalescedPacket(false, protocol.Version1) + p, err := packer.PackCoalescedPacket(false, maxPacketSize, protocol.Version1) Expect(err).ToNot(HaveOccurred()) Expect(p.longHdrPackets).To(HaveLen(1)) Expect(p.longHdrPackets[0].ack).To(Equal(ack)) @@ -1350,7 +1268,7 @@ var _ = Describe("Packet packer", func() { initialStream.EXPECT().HasData().Return(true).Times(2) initialStream.EXPECT().PopCryptoFrame(gomock.Any()).Return(f) packer.perspective = protocol.PerspectiveClient - p, err := packer.PackCoalescedPacket(false, protocol.Version1) + p, err := packer.PackCoalescedPacket(false, maxPacketSize, protocol.Version1) Expect(err).ToNot(HaveOccurred()) Expect(p.buffer.Len()).To(BeNumerically(">=", protocol.MinInitialPacketSize)) Expect(p.buffer.Len()).To(BeEquivalentTo(maxPacketSize)) @@ -1375,7 +1293,7 @@ var _ = Describe("Packet packer", func() { initialStream.EXPECT().HasData().Return(true).Times(2) initialStream.EXPECT().PopCryptoFrame(gomock.Any()).Return(f) packer.perspective = protocol.PerspectiveClient - p, err := packer.PackCoalescedPacket(false, protocol.Version1) + p, err := packer.PackCoalescedPacket(false, maxPacketSize, protocol.Version1) Expect(err).ToNot(HaveOccurred()) Expect(p.longHdrPackets).To(HaveLen(1)) Expect(p.longHdrPackets[0].ack).To(Equal(ack)) @@ -1398,7 +1316,7 @@ var _ = Describe("Packet packer", func() { pnManager.EXPECT().PeekPacketNumber(protocol.EncryptionInitial).Return(protocol.PacketNumber(0x42), protocol.PacketNumberLen2) pnManager.EXPECT().PopPacketNumber(protocol.EncryptionInitial).Return(protocol.PacketNumber(0x42)) - p, err := packer.MaybePackProbePacket(protocol.EncryptionInitial, protocol.Version1) + p, err := packer.MaybePackProbePacket(protocol.EncryptionInitial, maxPacketSize, protocol.Version1) Expect(err).ToNot(HaveOccurred()) Expect(p).ToNot(BeNil()) Expect(p.longHdrPackets).To(HaveLen(1)) @@ -1420,7 +1338,7 @@ var _ = Describe("Packet packer", func() { pnManager.EXPECT().PeekPacketNumber(protocol.EncryptionInitial).Return(protocol.PacketNumber(0x42), protocol.PacketNumberLen1) pnManager.EXPECT().PopPacketNumber(protocol.EncryptionInitial).Return(protocol.PacketNumber(0x42)) - p, err := packer.MaybePackProbePacket(protocol.EncryptionInitial, protocol.Version1) + p, err := packer.MaybePackProbePacket(protocol.EncryptionInitial, maxPacketSize, protocol.Version1) Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred()) Expect(p).ToNot(BeNil()) @@ -1444,7 +1362,7 @@ var _ = Describe("Packet packer", func() { pnManager.EXPECT().PeekPacketNumber(protocol.EncryptionHandshake).Return(protocol.PacketNumber(0x42), protocol.PacketNumberLen2) pnManager.EXPECT().PopPacketNumber(protocol.EncryptionHandshake).Return(protocol.PacketNumber(0x42)) - p, err := packer.MaybePackProbePacket(protocol.EncryptionHandshake, protocol.Version1) + p, err := packer.MaybePackProbePacket(protocol.EncryptionHandshake, maxPacketSize, protocol.Version1) Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred()) Expect(p).ToNot(BeNil()) @@ -1465,7 +1383,7 @@ var _ = Describe("Packet packer", func() { pnManager.EXPECT().PeekPacketNumber(protocol.EncryptionHandshake).Return(protocol.PacketNumber(0x42), protocol.PacketNumberLen2) pnManager.EXPECT().PopPacketNumber(protocol.EncryptionHandshake).Return(protocol.PacketNumber(0x42)) - p, err := packer.MaybePackProbePacket(protocol.EncryptionHandshake, protocol.Version1) + p, err := packer.MaybePackProbePacket(protocol.EncryptionHandshake, maxPacketSize, protocol.Version1) Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred()) Expect(p).ToNot(BeNil()) @@ -1489,7 +1407,7 @@ var _ = Describe("Packet packer", func() { expectAppendControlFrames() expectAppendStreamFrames(ackhandler.StreamFrame{Frame: f}) - p, err := packer.MaybePackProbePacket(protocol.Encryption1RTT, protocol.Version1) + p, err := packer.MaybePackProbePacket(protocol.Encryption1RTT, maxPacketSize, protocol.Version1) Expect(err).ToNot(HaveOccurred()) Expect(p).ToNot(BeNil()) Expect(p.longHdrPackets).To(BeEmpty()) @@ -1515,7 +1433,7 @@ var _ = Describe("Packet packer", func() { return append(fs, ackhandler.StreamFrame{Frame: sf}), sf.Length(v) }) - p, err := packer.MaybePackProbePacket(protocol.Encryption1RTT, protocol.Version1) + p, err := packer.MaybePackProbePacket(protocol.Encryption1RTT, maxPacketSize, protocol.Version1) Expect(err).ToNot(HaveOccurred()) Expect(p).ToNot(BeNil()) Expect(p.longHdrPackets).To(BeEmpty()) @@ -1532,7 +1450,7 @@ var _ = Describe("Packet packer", func() { ackFramer.EXPECT().GetAckFrame(protocol.Encryption1RTT, true) framer.EXPECT().HasData() - packet, err := packer.MaybePackProbePacket(protocol.Encryption1RTT, protocol.Version1) + packet, err := packer.MaybePackProbePacket(protocol.Encryption1RTT, maxPacketSize, protocol.Version1) Expect(err).ToNot(HaveOccurred()) Expect(packet).To(BeNil()) })