diff --git a/connection.go b/connection.go index 14b632b93..50e39349f 100644 --- a/connection.go +++ b/connection.go @@ -670,7 +670,7 @@ runLoop: } else { idleTimeoutStartTime := c.idleTimeoutStartTime() if (!c.handshakeComplete && now.Sub(idleTimeoutStartTime) >= c.config.HandshakeIdleTimeout) || - (c.handshakeComplete && now.After(c.nextIdleTimeoutTime())) { + (c.handshakeComplete && !now.Before(c.nextIdleTimeoutTime())) { c.destroyImpl(qerr.ErrIdleTimeout) break runLoop } @@ -878,7 +878,7 @@ func (c *Conn) maybeResetTimer() { func (c *Conn) idleTimeoutStartTime() monotime.Time { startTime := c.lastPacketReceivedTime - if t := c.firstAckElicitingPacketAfterIdleSentTime; t.After(startTime) { + if t := c.firstAckElicitingPacketAfterIdleSentTime; !t.IsZero() && t.After(startTime) { startTime = t } return startTime diff --git a/connection_test.go b/connection_test.go index 65802f61b..bdf1f2726 100644 --- a/connection_test.go +++ b/connection_test.go @@ -21,6 +21,7 @@ import ( "github.com/quic-go/quic-go/internal/monotime" "github.com/quic-go/quic-go/internal/protocol" "github.com/quic-go/quic-go/internal/qerr" + "github.com/quic-go/quic-go/internal/synctest" "github.com/quic-go/quic-go/internal/utils" "github.com/quic-go/quic-go/internal/wire" "github.com/quic-go/quic-go/logging" @@ -818,54 +819,64 @@ func TestConnectionRemoteClose(t *testing.T) { } func TestConnectionIdleTimeoutDuringHandshake(t *testing.T) { - mockCtrl := gomock.NewController(t) - tr, tracer := mocklogging.NewMockConnectionTracer(mockCtrl) - tc := newServerTestConnection(t, - mockCtrl, - &Config{HandshakeIdleTimeout: scaleDuration(25 * time.Millisecond)}, - false, - connectionOptTracer(tr), - ) - tc.packer.EXPECT().PackCoalescedPacket(false, gomock.Any(), gomock.Any(), protocol.Version1).AnyTimes() - tc.connRunner.EXPECT().Remove(gomock.Any()).AnyTimes() - gomock.InOrder( - tracer.EXPECT().ClosedConnection(&IdleTimeoutError{}), - tracer.EXPECT().Close(), - ) - errChan := make(chan error, 1) - go func() { errChan <- tc.conn.run() }() - select { - case err := <-errChan: - require.ErrorIs(t, err, &IdleTimeoutError{}) - case <-time.After(time.Second): - t.Fatal("timeout") - } + synctest.Test(t, func(t *testing.T) { + const timeout = 7 * time.Second + mockCtrl := gomock.NewController(t) + tr, tracer := mocklogging.NewMockConnectionTracer(mockCtrl) + tc := newServerTestConnection(t, + mockCtrl, + &Config{HandshakeIdleTimeout: timeout}, + false, + connectionOptTracer(tr), + ) + tc.packer.EXPECT().PackCoalescedPacket(false, gomock.Any(), gomock.Any(), protocol.Version1).AnyTimes() + tc.connRunner.EXPECT().Remove(gomock.Any()).AnyTimes() + gomock.InOrder( + tracer.EXPECT().ClosedConnection(&IdleTimeoutError{}), + tracer.EXPECT().Close(), + ) + start := monotime.Now() + errChan := make(chan error, 1) + go func() { errChan <- tc.conn.run() }() + + synctest.Wait() + + select { + case err := <-errChan: + require.ErrorIs(t, err, &IdleTimeoutError{}) + require.Equal(t, timeout, monotime.Since(start)) + case <-time.After(timeout + time.Nanosecond): + t.Fatal("timeout") + } + }) } func TestConnectionHandshakeIdleTimeout(t *testing.T) { - mockCtrl := gomock.NewController(t) - tr, tracer := mocklogging.NewMockConnectionTracer(mockCtrl) - tc := newServerTestConnection(t, - mockCtrl, - &Config{HandshakeIdleTimeout: scaleDuration(25 * time.Millisecond)}, - false, - connectionOptTracer(tr), - func(c *Conn) { c.creationTime = monotime.Now().Add(-10 * time.Second) }, - ) - tc.packer.EXPECT().PackCoalescedPacket(false, gomock.Any(), gomock.Any(), protocol.Version1).AnyTimes() - tc.connRunner.EXPECT().Remove(gomock.Any()).AnyTimes() - gomock.InOrder( - tracer.EXPECT().ClosedConnection(&HandshakeTimeoutError{}), - tracer.EXPECT().Close(), - ) - errChan := make(chan error, 1) - go func() { errChan <- tc.conn.run() }() - select { - case err := <-errChan: - require.ErrorIs(t, err, &HandshakeTimeoutError{}) - case <-time.After(time.Second): - t.Fatal("timeout") - } + synctest.Test(t, func(t *testing.T) { + mockCtrl := gomock.NewController(t) + tr, tracer := mocklogging.NewMockConnectionTracer(mockCtrl) + tc := newServerTestConnection(t, + mockCtrl, + &Config{HandshakeIdleTimeout: 7 * time.Second}, + false, + connectionOptTracer(tr), + func(c *Conn) { c.creationTime = monotime.Now().Add(-20 * time.Second) }, + ) + tc.packer.EXPECT().PackCoalescedPacket(false, gomock.Any(), gomock.Any(), protocol.Version1).AnyTimes() + tc.connRunner.EXPECT().Remove(gomock.Any()).AnyTimes() + gomock.InOrder( + tracer.EXPECT().ClosedConnection(&HandshakeTimeoutError{}), + tracer.EXPECT().Close(), + ) + errChan := make(chan error, 1) + go func() { errChan <- tc.conn.run() }() + select { + case err := <-errChan: + require.ErrorIs(t, err, &HandshakeTimeoutError{}) + case <-time.After(time.Second): + t.Fatal("timeout") + } + }) } func TestConnectionTransportParameters(t *testing.T) { @@ -905,67 +916,74 @@ func TestConnectionTransportParameters(t *testing.T) { } func TestConnectionHandleMaxStreamsFrame(t *testing.T) { - mockCtrl := gomock.NewController(t) - connFC := flowcontrol.NewConnectionFlowController(0, 0, nil, &utils.RTTStats{}, utils.DefaultLogger) - tc := newServerTestConnection(t, mockCtrl, nil, false, connectionOptConnFlowController(connFC)) - tc.conn.handleTransportParameters(&wire.TransportParameters{}) + synctest.Test(t, func(t *testing.T) { + mockCtrl := gomock.NewController(t) + connFC := flowcontrol.NewConnectionFlowController(0, 0, nil, &utils.RTTStats{}, utils.DefaultLogger) + tc := newServerTestConnection(t, mockCtrl, nil, false, connectionOptConnFlowController(connFC)) + tc.conn.handleTransportParameters(&wire.TransportParameters{}) - ctx, cancel := context.WithTimeout(context.Background(), time.Second) - defer cancel() - uniStreamChan := make(chan error) - go func() { - _, err := tc.conn.OpenUniStreamSync(ctx) - uniStreamChan <- err - }() - bidiStreamChan := make(chan error) - go func() { - _, err := tc.conn.OpenStreamSync(ctx) - bidiStreamChan <- err - }() + ctx, cancel := context.WithTimeout(context.Background(), time.Second) + defer cancel() + uniStreamChan := make(chan error) + go func() { + _, err := tc.conn.OpenUniStreamSync(ctx) + uniStreamChan <- err + }() + bidiStreamChan := make(chan error) + go func() { + _, err := tc.conn.OpenStreamSync(ctx) + bidiStreamChan <- err + }() - select { - case <-uniStreamChan: - t.Fatal("uni stream should be blocked") - case <-bidiStreamChan: - t.Fatal("bidi stream should be blocked") - case <-time.After(scaleDuration(10 * time.Millisecond)): - } + synctest.Wait() + select { + case <-uniStreamChan: + t.Fatal("uni stream should be blocked") + case <-bidiStreamChan: + t.Fatal("bidi stream should be blocked") + default: + } - // MAX_STREAMS frame for bidirectional stream - _, err := tc.conn.handleFrame( - &wire.MaxStreamsFrame{Type: protocol.StreamTypeBidi, MaxStreamNum: 10}, - protocol.Encryption1RTT, - protocol.ConnectionID{}, - monotime.Now(), - ) - require.NoError(t, err) - - select { - case <-uniStreamChan: - t.Fatal("uni stream should be blocked") - case <-time.After(scaleDuration(10 * time.Millisecond)): - } - select { - case err := <-bidiStreamChan: + // MAX_STREAMS frame for bidirectional stream + _, err := tc.conn.handleFrame( + &wire.MaxStreamsFrame{Type: protocol.StreamTypeBidi, MaxStreamNum: 10}, + protocol.Encryption1RTT, + protocol.ConnectionID{}, + monotime.Now(), + ) require.NoError(t, err) - case <-time.After(time.Second): - t.Fatal("timeout") - } - // MAX_STREAMS frame for bidirectional stream - _, err = tc.conn.handleFrame( - &wire.MaxStreamsFrame{Type: protocol.StreamTypeUni, MaxStreamNum: 10}, - protocol.Encryption1RTT, - protocol.ConnectionID{}, - monotime.Now(), - ) - require.NoError(t, err) - select { - case err := <-uniStreamChan: + synctest.Wait() + + select { + case <-uniStreamChan: + t.Fatal("uni stream should be blocked") + default: + } + select { + case err := <-bidiStreamChan: + require.NoError(t, err) + default: + t.Fatal("bidi stream should be unblocked") + } + + // MAX_STREAMS frame for bidirectional stream + _, err = tc.conn.handleFrame( + &wire.MaxStreamsFrame{Type: protocol.StreamTypeUni, MaxStreamNum: 10}, + protocol.Encryption1RTT, + protocol.ConnectionID{}, + monotime.Now(), + ) require.NoError(t, err) - case <-time.After(time.Second): - t.Fatal("timeout") - } + + synctest.Wait() + select { + case err := <-uniStreamChan: + require.NoError(t, err) + default: + t.Fatal("timeout") + } + }) } func TestConnectionTransportParameterValidationFailureServer(t *testing.T) { @@ -1550,221 +1568,228 @@ func TestConnectionPacketBuffering(t *testing.T) { } func TestConnectionPacketPacing(t *testing.T) { - mockCtrl := gomock.NewController(t) - sph := mockackhandler.NewMockSentPacketHandler(mockCtrl) - sender := NewMockSender(mockCtrl) + synctest.Test(t, func(t *testing.T) { + mockCtrl := gomock.NewController(t) + sph := mockackhandler.NewMockSentPacketHandler(mockCtrl) + sender := NewMockSender(mockCtrl) - tc := newServerTestConnection(t, - mockCtrl, - nil, - false, - connectionOptSentPacketHandler(sph), - connectionOptSender(sender), - connectionOptHandshakeConfirmed(), - // set a fixed RTT, so that the idle timeout doesn't interfere with this test - connectionOptRTT(10*time.Second), - ) - sender.EXPECT().Run() + tc := newServerTestConnection(t, + mockCtrl, + nil, + false, + connectionOptSentPacketHandler(sph), + connectionOptSender(sender), + connectionOptHandshakeConfirmed(), + ) + sender.EXPECT().Run() - step := scaleDuration(50 * time.Millisecond) + const step = 50 * time.Millisecond - sph.EXPECT().GetLossDetectionTimeout().Return(monotime.Now().Add(time.Hour)).AnyTimes() - gomock.InOrder( - // 1. allow 2 packets to be sent - sph.EXPECT().SendMode(gomock.Any()).Return(ackhandler.SendAny), - sph.EXPECT().SentPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()), - sph.EXPECT().SendMode(gomock.Any()).Return(ackhandler.SendAny), - sph.EXPECT().SentPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()), - sph.EXPECT().SendMode(gomock.Any()).Return(ackhandler.SendPacingLimited), - // 2. become pacing limited for 25ms - sph.EXPECT().TimeUntilSend().DoAndReturn(func() monotime.Time { return monotime.Now().Add(step) }), - // 3. send another packet - sph.EXPECT().SendMode(gomock.Any()).Return(ackhandler.SendAny), - sph.EXPECT().SentPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()), - sph.EXPECT().SendMode(gomock.Any()).Return(ackhandler.SendPacingLimited), - // 4. become pacing limited for 25ms... - sph.EXPECT().TimeUntilSend().DoAndReturn(func() monotime.Time { return monotime.Now().Add(step) }), - // ... but this time we're still pacing limited when waking up. - // In this case, we can only send an ACK. - sph.EXPECT().SendMode(gomock.Any()).Return(ackhandler.SendPacingLimited), - // 5. stop the test by becoming pacing limited forever - sph.EXPECT().TimeUntilSend().Return(monotime.Now().Add(time.Hour)), - sph.EXPECT().SentPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()), - ) - sph.EXPECT().ECNMode(gomock.Any()).AnyTimes() - for i := 0; i < 3; i++ { - tc.packer.EXPECT().AppendPacket(gomock.Any(), gomock.Any(), gomock.Any(), Version1).DoAndReturn( - func(buf *packetBuffer, _ protocol.ByteCount, _ monotime.Time, _ protocol.Version) (shortHeaderPacket, error) { - buf.Data = append(buf.Data, []byte("packet"+strconv.Itoa(i+1))...) - return shortHeaderPacket{PacketNumber: protocol.PacketNumber(i + 1)}, nil + sph.EXPECT().GetLossDetectionTimeout().Return(monotime.Now().Add(time.Hour)).AnyTimes() + gomock.InOrder( + // 1. allow 2 packets to be sent + sph.EXPECT().SendMode(gomock.Any()).Return(ackhandler.SendAny), + sph.EXPECT().SentPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()), + sph.EXPECT().SendMode(gomock.Any()).Return(ackhandler.SendAny), + sph.EXPECT().SentPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()), + sph.EXPECT().SendMode(gomock.Any()).Return(ackhandler.SendPacingLimited), + // 2. become pacing limited for 25ms + sph.EXPECT().TimeUntilSend().DoAndReturn(func() monotime.Time { return monotime.Now().Add(step) }), + // 3. send another packet + sph.EXPECT().SendMode(gomock.Any()).Return(ackhandler.SendAny), + sph.EXPECT().SentPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()), + sph.EXPECT().SendMode(gomock.Any()).Return(ackhandler.SendPacingLimited), + // 4. become pacing limited for 25ms... + sph.EXPECT().TimeUntilSend().DoAndReturn(func() monotime.Time { return monotime.Now().Add(step) }), + // ... but this time we're still pacing limited when waking up. + // In this case, we can only send an ACK. + sph.EXPECT().SendMode(gomock.Any()).Return(ackhandler.SendPacingLimited), + // 5. stop the test by becoming pacing limited forever + sph.EXPECT().TimeUntilSend().Return(monotime.Now().Add(time.Hour)), + sph.EXPECT().SentPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()), + ) + sph.EXPECT().ECNMode(gomock.Any()).AnyTimes() + for i := range 3 { + tc.packer.EXPECT().AppendPacket(gomock.Any(), gomock.Any(), gomock.Any(), Version1).DoAndReturn( + func(buf *packetBuffer, _ protocol.ByteCount, _ monotime.Time, _ protocol.Version) (shortHeaderPacket, error) { + buf.Data = append(buf.Data, []byte("packet"+strconv.Itoa(i+1))...) + return shortHeaderPacket{PacketNumber: protocol.PacketNumber(i + 1)}, nil + }, + ) + } + tc.packer.EXPECT().PackAckOnlyPacket(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn( + func(_ protocol.ByteCount, _ monotime.Time, _ protocol.Version) (shortHeaderPacket, *packetBuffer, error) { + buf := getPacketBuffer() + buf.Data = []byte("ack") + return shortHeaderPacket{PacketNumber: 1}, buf, nil }, ) - } - tc.packer.EXPECT().PackAckOnlyPacket(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn( - func(_ protocol.ByteCount, _ monotime.Time, _ protocol.Version) (shortHeaderPacket, *packetBuffer, error) { - buf := getPacketBuffer() - buf.Data = []byte("ack") - return shortHeaderPacket{PacketNumber: 1}, buf, nil - }, - ) - sender.EXPECT().WouldBlock().AnyTimes() + sender.EXPECT().WouldBlock().AnyTimes() - type sentPacket struct { - time monotime.Time - data []byte - } - sendChan := make(chan sentPacket, 10) - sender.EXPECT().Send(gomock.Any(), gomock.Any(), gomock.Any()).Do(func(b *packetBuffer, _ uint16, _ protocol.ECN) { - sendChan <- sentPacket{time: monotime.Now(), data: b.Data} - }).Times(4) + type sentPacket struct { + time monotime.Time + data []byte + } + sendChan := make(chan sentPacket, 10) + sender.EXPECT().Send(gomock.Any(), gomock.Any(), gomock.Any()).Do(func(b *packetBuffer, _ uint16, _ protocol.ECN) { + sendChan <- sentPacket{time: monotime.Now(), data: b.Data} + }).Times(4) - errChan := make(chan error, 1) - go func() { errChan <- tc.conn.run() }() - tc.conn.scheduleSending() + errChan := make(chan error, 1) + go func() { errChan <- tc.conn.run() }() + tc.conn.scheduleSending() - var times []monotime.Time - for i := 0; i < 3; i++ { + synctest.Wait() + + var times []monotime.Time + for i := range 3 { + select { + case b := <-sendChan: + require.Equal(t, []byte("packet"+strconv.Itoa(i+1)), b.data) + times = append(times, b.time) + case <-time.After(time.Hour): + t.Fatal("should have sent a packet") + } + } select { case b := <-sendChan: - require.Equal(t, []byte("packet"+strconv.Itoa(i+1)), b.data) + require.Equal(t, []byte("ack"), b.data) times = append(times, b.time) - case <-time.After(scaleDuration(time.Second)): + case <-time.After(time.Second): t.Fatal("timeout") } - } - select { - case b := <-sendChan: - require.Equal(t, []byte("ack"), b.data) - times = append(times, b.time) - case <-time.After(scaleDuration(time.Second)): - t.Fatal("timeout") - } - require.InDelta(t, times[0].Sub(times[1]).Seconds(), 0, scaleDuration(10*time.Millisecond).Seconds()) - require.InDelta(t, times[2].Sub(times[1]).Seconds(), step.Seconds(), scaleDuration(20*time.Millisecond).Seconds()) - require.InDelta(t, times[3].Sub(times[2]).Seconds(), step.Seconds(), scaleDuration(20*time.Millisecond).Seconds()) + require.Equal(t, times[0], times[1]) + require.Equal(t, times[2], times[1].Add(step)) + require.Equal(t, times[3], times[2].Add(step)) - time.Sleep(scaleDuration(step)) // make sure that no more packets are sent - require.True(t, mockCtrl.Satisfied()) + synctest.Wait() // make sure that no more packets are sent + require.True(t, mockCtrl.Satisfied()) - // test teardown - sender.EXPECT().Close() - tc.connRunner.EXPECT().Remove(gomock.Any()).AnyTimes() - tc.conn.destroy(nil) - select { - case <-sendChan: - t.Fatal("should not have sent any more packets") - case err := <-errChan: - require.NoError(t, err) - case <-time.After(3 * time.Second): - t.Fatal("timeout") - } + // test teardown + sender.EXPECT().Close() + tc.connRunner.EXPECT().Remove(gomock.Any()).AnyTimes() + tc.conn.destroy(nil) + + synctest.Wait() + + select { + case <-sendChan: + t.Fatal("should not have sent any more packets") + case err := <-errChan: + require.NoError(t, err) + default: + t.Fatal("should have timed out") + } + }) } // When the send queue blocks, we need to reset the pacing timer, otherwise the run loop might busy-loop. // See https://github.com/quic-go/quic-go/pull/4943 for more details. func TestConnectionPacingAndSendQueue(t *testing.T) { - mockCtrl := gomock.NewController(t) - sph := mockackhandler.NewMockSentPacketHandler(mockCtrl) - sender := NewMockSender(mockCtrl) + synctest.Test(t, func(t *testing.T) { + mockCtrl := gomock.NewController(t) + sph := mockackhandler.NewMockSentPacketHandler(mockCtrl) + sender := NewMockSender(mockCtrl) - tc := newServerTestConnection(t, - mockCtrl, - nil, - false, - connectionOptSentPacketHandler(sph), - connectionOptSender(sender), - connectionOptHandshakeConfirmed(), - // set a fixed RTT, so that the idle timeout doesn't interfere with this test - connectionOptRTT(10*time.Second), - ) - sender.EXPECT().Run() + tc := newServerTestConnection(t, + mockCtrl, + nil, + false, + connectionOptSentPacketHandler(sph), + connectionOptSender(sender), + connectionOptHandshakeConfirmed(), + ) + sender.EXPECT().Run() - sendQueueAvailable := make(chan struct{}) - pacingDeadline := monotime.Now().Add(-time.Millisecond) - var counter int - // allow exactly one packet to be sent, then become blocked - sender.EXPECT().WouldBlock().Return(false) - sender.EXPECT().WouldBlock().DoAndReturn(func() bool { counter++; return true }).AnyTimes() - sender.EXPECT().Available().Return(sendQueueAvailable).AnyTimes() - sph.EXPECT().GetLossDetectionTimeout().Return(monotime.Now().Add(time.Hour)).AnyTimes() - sph.EXPECT().SendMode(gomock.Any()).Return(ackhandler.SendPacingLimited).AnyTimes() - sph.EXPECT().TimeUntilSend().Return(pacingDeadline).AnyTimes() - sph.EXPECT().ECNMode(gomock.Any()).Return(protocol.ECNNon).AnyTimes() - tc.packer.EXPECT().PackAckOnlyPacket(gomock.Any(), gomock.Any(), gomock.Any()).Return( - shortHeaderPacket{}, nil, errNothingToPack, - ) + sendQueueAvailable := make(chan struct{}) + pacingDeadline := monotime.Now().Add(-time.Millisecond) + var counter int + // allow exactly one packet to be sent, then become blocked + sender.EXPECT().WouldBlock().Return(false) + sender.EXPECT().WouldBlock().DoAndReturn(func() bool { counter++; return true }).AnyTimes() + sender.EXPECT().Available().Return(sendQueueAvailable).AnyTimes() + sph.EXPECT().GetLossDetectionTimeout().Return(monotime.Now().Add(time.Hour)).AnyTimes() + sph.EXPECT().SendMode(gomock.Any()).Return(ackhandler.SendPacingLimited).AnyTimes() + sph.EXPECT().TimeUntilSend().Return(pacingDeadline).AnyTimes() + sph.EXPECT().ECNMode(gomock.Any()).Return(protocol.ECNNon).AnyTimes() + tc.packer.EXPECT().PackAckOnlyPacket(gomock.Any(), gomock.Any(), gomock.Any()).Return( + shortHeaderPacket{}, nil, errNothingToPack, + ) - errChan := make(chan error, 1) - go func() { errChan <- tc.conn.run() }() - tc.conn.scheduleSending() + errChan := make(chan error, 1) + go func() { errChan <- tc.conn.run() }() + tc.conn.scheduleSending() - time.Sleep(scaleDuration(10 * time.Millisecond)) + synctest.Wait() - // test teardown - tc.connRunner.EXPECT().Remove(gomock.Any()).AnyTimes() - sender.EXPECT().Close() - tc.conn.destroy(nil) - select { - case err := <-errChan: - require.NoError(t, err) - case <-time.After(time.Second): - t.Fatal("timeout") - } + // test teardown + tc.connRunner.EXPECT().Remove(gomock.Any()).AnyTimes() + sender.EXPECT().Close() + tc.conn.destroy(nil) - // make sure the run loop didn't do too many iterations - require.Less(t, counter, 3) + synctest.Wait() + select { + case err := <-errChan: + require.NoError(t, err) + default: + t.Fatal("should have timed out") + } + + // make sure the run loop didn't do too many iterations + require.Less(t, counter, 3) + }) } func TestConnectionIdleTimeout(t *testing.T) { - mockCtrl := gomock.NewController(t) - sph := mockackhandler.NewMockSentPacketHandler(mockCtrl) - tc := newServerTestConnection(t, - mockCtrl, - &Config{MaxIdleTimeout: time.Second}, - false, - connectionOptHandshakeConfirmed(), - connectionOptSentPacketHandler(sph), - connectionOptRTT(time.Millisecond), - ) - // the idle timeout is set when the transport parameters are received - idleTimeout := scaleDuration(50 * time.Millisecond) - require.NoError(t, tc.conn.handleTransportParameters(&wire.TransportParameters{ - MaxIdleTimeout: idleTimeout, - })) - - sph.EXPECT().GetLossDetectionTimeout().AnyTimes() - sph.EXPECT().SendMode(gomock.Any()).Return(ackhandler.SendAny).AnyTimes() - sph.EXPECT().SentPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()) - sph.EXPECT().ECNMode(gomock.Any()).AnyTimes() - var lastSendTime monotime.Time - tc.packer.EXPECT().AppendPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn( - func(buf *packetBuffer, _ protocol.ByteCount, _ monotime.Time, _ protocol.Version) (shortHeaderPacket, error) { - buf.Data = append(buf.Data, []byte("foobar")...) - lastSendTime = monotime.Now() - return shortHeaderPacket{Frames: []ackhandler.Frame{{Frame: &wire.PingFrame{}}}, Length: 6}, nil - }, - ) - tc.packer.EXPECT().AppendPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(shortHeaderPacket{}, errNothingToPack) - tc.sendConn.EXPECT().Write(gomock.Any(), gomock.Any(), gomock.Any()) - tc.connRunner.EXPECT().Remove(gomock.Any()).AnyTimes() - - errChan := make(chan error, 1) - go func() { errChan <- tc.conn.run() }() - tc.conn.scheduleSending() - - select { - case err := <-errChan: - require.ErrorIs(t, err, &IdleTimeoutError{}) - require.NotZero(t, lastSendTime) - require.InDelta(t, - monotime.Since(lastSendTime).Seconds(), - idleTimeout.Seconds(), - scaleDuration(10*time.Millisecond).Seconds(), + synctest.Test(t, func(t *testing.T) { + mockCtrl := gomock.NewController(t) + sph := mockackhandler.NewMockSentPacketHandler(mockCtrl) + tc := newServerTestConnection(t, + mockCtrl, + &Config{MaxIdleTimeout: time.Minute}, + false, + connectionOptHandshakeConfirmed(), + connectionOptSentPacketHandler(sph), + connectionOptRTT(time.Millisecond), ) - case <-time.After(3 * time.Second): - t.Fatal("timeout") - } + // the idle timeout is set when the transport parameters are received + const idleTimeout = 500 * time.Millisecond + require.NoError(t, tc.conn.handleTransportParameters(&wire.TransportParameters{ + MaxIdleTimeout: idleTimeout, + })) + + sph.EXPECT().GetLossDetectionTimeout().AnyTimes() + sph.EXPECT().SendMode(gomock.Any()).Return(ackhandler.SendAny).AnyTimes() + sph.EXPECT().SentPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()) + sph.EXPECT().ECNMode(gomock.Any()).AnyTimes() + var lastSendTime monotime.Time + tc.packer.EXPECT().AppendPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn( + func(buf *packetBuffer, _ protocol.ByteCount, _ monotime.Time, _ protocol.Version) (shortHeaderPacket, error) { + buf.Data = append(buf.Data, []byte("foobar")...) + lastSendTime = monotime.Now() + return shortHeaderPacket{Frames: []ackhandler.Frame{{Frame: &wire.PingFrame{}}}, Length: 6}, nil + }, + ) + tc.packer.EXPECT().AppendPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(shortHeaderPacket{}, errNothingToPack) + tc.sendConn.EXPECT().Write(gomock.Any(), gomock.Any(), gomock.Any()) + tc.connRunner.EXPECT().Remove(gomock.Any()).AnyTimes() + + errChan := make(chan error, 1) + go func() { errChan <- tc.conn.run() }() + tc.conn.scheduleSending() + + synctest.Wait() + + select { + case err := <-errChan: + require.ErrorIs(t, err, &IdleTimeoutError{}) + require.NotZero(t, lastSendTime) + require.Equal(t, idleTimeout, monotime.Since(lastSendTime)) + case <-time.After(time.Hour): + t.Fatal("should have timed out") + } + }) } func TestConnectionKeepAlive(t *testing.T) { @@ -1778,381 +1803,407 @@ func TestConnectionKeepAlive(t *testing.T) { } func testConnectionKeepAlive(t *testing.T, enable, expectKeepAlive bool) { - var keepAlivePeriod time.Duration - if enable { - keepAlivePeriod = time.Second - } + synctest.Test(t, func(t *testing.T) { + var keepAlivePeriod time.Duration + if enable { + keepAlivePeriod = time.Second + } - mockCtrl := gomock.NewController(t) - unpacker := NewMockUnpacker(mockCtrl) - tc := newServerTestConnection(t, - mockCtrl, - &Config{MaxIdleTimeout: time.Second, KeepAlivePeriod: keepAlivePeriod}, - false, - connectionOptUnpacker(unpacker), - connectionOptHandshakeConfirmed(), - connectionOptRTT(time.Millisecond), - ) - // the idle timeout is set when the transport parameters are received - idleTimeout := scaleDuration(50 * time.Millisecond) - require.NoError(t, tc.conn.handleTransportParameters(&wire.TransportParameters{ - MaxIdleTimeout: idleTimeout, - })) + mockCtrl := gomock.NewController(t) + unpacker := NewMockUnpacker(mockCtrl) + tc := newServerTestConnection(t, + mockCtrl, + &Config{MaxIdleTimeout: time.Second, KeepAlivePeriod: keepAlivePeriod}, + false, + connectionOptUnpacker(unpacker), + connectionOptHandshakeConfirmed(), + connectionOptRTT(time.Millisecond), + ) + // the idle timeout is set when the transport parameters are received + const idleTimeout = 50 * time.Millisecond + require.NoError(t, tc.conn.handleTransportParameters(&wire.TransportParameters{ + MaxIdleTimeout: idleTimeout, + })) - // Receive a packet. This starts the keep-alive timer. - buf := getPacketBuffer() - var err error - buf.Data, err = wire.AppendShortHeader(buf.Data, tc.srcConnID, 1, protocol.PacketNumberLen1, protocol.KeyPhaseZero) - require.NoError(t, err) - buf.Data = append(buf.Data, []byte("packet")...) + // Receive a packet. This starts the keep-alive timer. + buf := getPacketBuffer() + var err error + buf.Data, err = wire.AppendShortHeader(buf.Data, tc.srcConnID, 1, protocol.PacketNumberLen1, protocol.KeyPhaseZero) + require.NoError(t, err) + buf.Data = append(buf.Data, []byte("packet")...) - errChan := make(chan error, 1) - go func() { errChan <- tc.conn.run() }() + errChan := make(chan error, 1) + go func() { errChan <- tc.conn.run() }() - var unpackTime, packTime monotime.Time - done := make(chan struct{}) - unpacker.EXPECT().UnpackShortHeader(gomock.Any(), gomock.Any()).DoAndReturn( - func(t monotime.Time, bytes []byte) (protocol.PacketNumber, protocol.PacketNumberLen, protocol.KeyPhaseBit, []byte, error) { - unpackTime = monotime.Now() - return protocol.PacketNumber(1), protocol.PacketNumberLen1, protocol.KeyPhaseZero, []byte{0} /* PADDING */, nil - }, - ) - tc.packer.EXPECT().AppendPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(shortHeaderPacket{}, errNothingToPack) - - switch expectKeepAlive { - case true: - // record the time of the keep-alive is sent - tc.packer.EXPECT().AppendPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn( - func(buffer *packetBuffer, count protocol.ByteCount, t monotime.Time, version protocol.Version) (shortHeaderPacket, error) { - packTime = monotime.Now() - close(done) - return shortHeaderPacket{}, errNothingToPack + var unpackTime, packTime monotime.Time + done := make(chan struct{}) + unpacker.EXPECT().UnpackShortHeader(gomock.Any(), gomock.Any()).DoAndReturn( + func(t monotime.Time, bytes []byte) (protocol.PacketNumber, protocol.PacketNumberLen, protocol.KeyPhaseBit, []byte, error) { + unpackTime = monotime.Now() + return protocol.PacketNumber(1), protocol.PacketNumberLen1, protocol.KeyPhaseZero, []byte{0} /* PADDING */, nil }, ) - tc.conn.handlePacket(receivedPacket{data: buf.Data, buffer: buf, rcvTime: monotime.Now(), remoteAddr: tc.remoteAddr}) - select { - case <-done: - // the keep-alive packet should be sent after half the idle timeout - diff := packTime.Sub(unpackTime) - require.InDelta(t, diff.Seconds(), idleTimeout.Seconds()/2, scaleDuration(10*time.Millisecond).Seconds()) - case <-time.After(idleTimeout): - t.Fatal("timeout") - } - case false: // if keep-alives are disabled, the connection will run into an idle timeout - tc.connRunner.EXPECT().Remove(gomock.Any()).AnyTimes() - tc.conn.handlePacket(receivedPacket{data: buf.Data, buffer: buf, rcvTime: monotime.Now(), remoteAddr: tc.remoteAddr}) - select { - case <-time.After(3 * time.Second): - t.Fatal("timeout") - case <-time.After(idleTimeout): - } - } + tc.packer.EXPECT().AppendPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(shortHeaderPacket{}, errNothingToPack) - // test teardown - if expectKeepAlive { - tc.connRunner.EXPECT().Remove(gomock.Any()).AnyTimes() - tc.conn.destroy(nil) - } - select { - case err := <-errChan: - if expectKeepAlive { - require.NoError(t, err) - } else { - require.ErrorIs(t, err, &IdleTimeoutError{}) + switch expectKeepAlive { + case true: + // record the time of the keep-alive is sent + tc.packer.EXPECT().AppendPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn( + func(buffer *packetBuffer, count protocol.ByteCount, t monotime.Time, version protocol.Version) (shortHeaderPacket, error) { + packTime = monotime.Now() + close(done) + return shortHeaderPacket{}, errNothingToPack + }, + ) + tc.conn.handlePacket(receivedPacket{data: buf.Data, buffer: buf, rcvTime: monotime.Now(), remoteAddr: tc.remoteAddr}) + select { + case <-done: + // the keep-alive packet should be sent after half the idle timeout + require.Equal(t, unpackTime.Add(idleTimeout/2), packTime) + case <-time.After(idleTimeout): + t.Fatal("timeout") + } + case false: // if keep-alives are disabled, the connection will run into an idle timeout + tc.connRunner.EXPECT().Remove(gomock.Any()).AnyTimes() + tc.conn.handlePacket(receivedPacket{data: buf.Data, buffer: buf, rcvTime: monotime.Now(), remoteAddr: tc.remoteAddr}) } - case <-time.After(3 * time.Second): - t.Fatal("timeout") - } + + // test teardown + if expectKeepAlive { + tc.connRunner.EXPECT().Remove(gomock.Any()).AnyTimes() + tc.conn.destroy(nil) + } + + synctest.Wait() + + select { + case err := <-errChan: + if expectKeepAlive { + require.NoError(t, err) + } else { + require.ErrorIs(t, err, &IdleTimeoutError{}) + } + case <-time.After(time.Hour): + t.Fatal("timeout") + } + }) } func TestConnectionACKTimer(t *testing.T) { - mockCtrl := gomock.NewController(t) - rph := mockackhandler.NewMockReceivedPacketHandler(mockCtrl) - sph := mockackhandler.NewMockSentPacketHandler(mockCtrl) - tc := newServerTestConnection(t, - mockCtrl, - &Config{MaxIdleTimeout: time.Second}, - false, - connectionOptHandshakeConfirmed(), - connectionOptReceivedPacketHandler(rph), - connectionOptSentPacketHandler(sph), - connectionOptRTT(10*time.Second), - ) - alarmTimeout := scaleDuration(50 * time.Millisecond) + synctest.Test(t, func(t *testing.T) { + mockCtrl := gomock.NewController(t) + rph := mockackhandler.NewMockReceivedPacketHandler(mockCtrl) + sph := mockackhandler.NewMockSentPacketHandler(mockCtrl) + tc := newServerTestConnection(t, + mockCtrl, + &Config{MaxIdleTimeout: time.Second}, + false, + connectionOptHandshakeConfirmed(), + connectionOptReceivedPacketHandler(rph), + connectionOptSentPacketHandler(sph), + ) + const alarmTimeout = 500 * time.Millisecond - sph.EXPECT().GetLossDetectionTimeout().AnyTimes() - sph.EXPECT().SendMode(gomock.Any()).Return(ackhandler.SendAny).AnyTimes() - sph.EXPECT().SentPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).AnyTimes() - sph.EXPECT().ECNMode(gomock.Any()).AnyTimes() - rph.EXPECT().GetAlarmTimeout().Return(monotime.Now().Add(time.Hour)) - tc.sendConn.EXPECT().Write(gomock.Any(), gomock.Any(), gomock.Any()).AnyTimes() + sph.EXPECT().GetLossDetectionTimeout().AnyTimes() + sph.EXPECT().SendMode(gomock.Any()).Return(ackhandler.SendAny).AnyTimes() + sph.EXPECT().SentPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).AnyTimes() + sph.EXPECT().ECNMode(gomock.Any()).AnyTimes() + rph.EXPECT().GetAlarmTimeout().Return(monotime.Now().Add(time.Hour)) + tc.sendConn.EXPECT().Write(gomock.Any(), gomock.Any(), gomock.Any()).AnyTimes() - var times []monotime.Time - done := make(chan struct{}, 5) - var calls []any - for i := 0; i < 2; i++ { - calls = append(calls, tc.packer.EXPECT().AppendPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn( - func(buf *packetBuffer, _ protocol.ByteCount, _ monotime.Time, _ protocol.Version) (shortHeaderPacket, error) { - buf.Data = append(buf.Data, []byte("foobar")...) - times = append(times, monotime.Now()) - return shortHeaderPacket{Frames: []ackhandler.Frame{{Frame: &wire.PingFrame{}}}, Length: 6}, nil - }, - )) - calls = append(calls, tc.packer.EXPECT().AppendPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn( - func(buf *packetBuffer, _ protocol.ByteCount, _ monotime.Time, _ protocol.Version) (shortHeaderPacket, error) { - done <- struct{}{} - return shortHeaderPacket{}, errNothingToPack - }, - )) - if i == 0 { - calls = append(calls, rph.EXPECT().GetAlarmTimeout().Return(monotime.Now().Add(alarmTimeout))) - } else { - calls = append(calls, rph.EXPECT().GetAlarmTimeout().Return(monotime.Now().Add(time.Hour)).MaxTimes(1)) + var times []monotime.Time + done := make(chan struct{}, 5) + var calls []any + for i := range 2 { + calls = append(calls, tc.packer.EXPECT().AppendPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn( + func(buf *packetBuffer, _ protocol.ByteCount, _ monotime.Time, _ protocol.Version) (shortHeaderPacket, error) { + buf.Data = append(buf.Data, []byte("foobar")...) + times = append(times, monotime.Now()) + return shortHeaderPacket{Frames: []ackhandler.Frame{{Frame: &wire.PingFrame{}}}, Length: 6}, nil + }, + )) + calls = append(calls, tc.packer.EXPECT().AppendPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn( + func(*packetBuffer, protocol.ByteCount, monotime.Time, protocol.Version) (shortHeaderPacket, error) { + done <- struct{}{} + return shortHeaderPacket{}, errNothingToPack + }, + )) + if i == 0 { + calls = append(calls, rph.EXPECT().GetAlarmTimeout().Return(monotime.Now().Add(alarmTimeout))) + } else { + calls = append(calls, rph.EXPECT().GetAlarmTimeout().Return(monotime.Now().Add(time.Hour))) + } } - } - gomock.InOrder(calls...) - errChan := make(chan error, 1) - go func() { errChan <- tc.conn.run() }() - tc.conn.scheduleSending() + gomock.InOrder(calls...) + errChan := make(chan error, 1) + go func() { errChan <- tc.conn.run() }() + tc.conn.scheduleSending() - for i := 0; i < 2; i++ { + for range 2 { + synctest.Wait() + + select { + case <-done: + case <-time.After(time.Hour): + t.Fatal("timeout") + } + } + + assert.Len(t, times, 2) + require.Equal(t, times[0].Add(alarmTimeout), times[1]) + + // test teardown + tc.connRunner.EXPECT().Remove(gomock.Any()).AnyTimes() + tc.conn.destroy(nil) + + synctest.Wait() select { - case <-done: - case <-time.After(3 * time.Second): - t.Fatal("timeout") + case err := <-errChan: + require.NoError(t, err) + default: + t.Fatal("should have timed out") } - } - - assert.Len(t, times, 2) - require.InDelta(t, times[1].Sub(times[0]).Seconds(), alarmTimeout.Seconds(), scaleDuration(10*time.Millisecond).Seconds()) - - // test teardown - tc.connRunner.EXPECT().Remove(gomock.Any()).AnyTimes() - tc.conn.destroy(nil) - select { - case err := <-errChan: - require.NoError(t, err) - case <-time.After(3 * time.Second): - t.Fatal("timeout") - } + }) } // Send a GSO batch, until we have no more data to send. func TestConnectionGSOBatch(t *testing.T) { - mockCtrl := gomock.NewController(t) - sph := mockackhandler.NewMockSentPacketHandler(mockCtrl) - tc := newServerTestConnection(t, - mockCtrl, - nil, - true, - connectionOptHandshakeConfirmed(), - connectionOptSentPacketHandler(sph), - ) - - // allow packets to be sent - sph.EXPECT().SendMode(gomock.Any()).Return(ackhandler.SendAny).AnyTimes() - sph.EXPECT().TimeUntilSend().AnyTimes() - sph.EXPECT().SentPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).AnyTimes() - sph.EXPECT().GetLossDetectionTimeout().AnyTimes() - sph.EXPECT().ECNMode(gomock.Any()).Return(protocol.ECT1).AnyTimes() - - maxPacketSize := tc.conn.maxPacketSize() - var expectedData []byte - for i := 0; i < 4; i++ { - data := bytes.Repeat([]byte{byte(i)}, int(maxPacketSize)) - expectedData = append(expectedData, data...) - - tc.packer.EXPECT().AppendPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn( - func(buffer *packetBuffer, count protocol.ByteCount, t monotime.Time, version protocol.Version) (shortHeaderPacket, error) { - buffer.Data = append(buffer.Data, data...) - return shortHeaderPacket{PacketNumber: protocol.PacketNumber(i)}, nil - }, + synctest.Test(t, func(t *testing.T) { + mockCtrl := gomock.NewController(t) + sph := mockackhandler.NewMockSentPacketHandler(mockCtrl) + tc := newServerTestConnection(t, + mockCtrl, + nil, + true, + connectionOptHandshakeConfirmed(), + connectionOptSentPacketHandler(sph), ) - } - done := make(chan struct{}) - tc.packer.EXPECT().AppendPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(shortHeaderPacket{}, errNothingToPack) - tc.sendConn.EXPECT().Write(expectedData, uint16(maxPacketSize), protocol.ECT1).DoAndReturn( - func([]byte, uint16, protocol.ECN) error { close(done); return nil }, - ) - errChan := make(chan error, 1) - go func() { errChan <- tc.conn.run() }() - tc.conn.scheduleSending() + // allow packets to be sent + sph.EXPECT().SendMode(gomock.Any()).Return(ackhandler.SendAny).AnyTimes() + sph.EXPECT().TimeUntilSend().AnyTimes() + sph.EXPECT().SentPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).AnyTimes() + sph.EXPECT().GetLossDetectionTimeout().AnyTimes() + sph.EXPECT().ECNMode(gomock.Any()).Return(protocol.ECT1).AnyTimes() - select { - case <-done: - case <-time.After(time.Second): - t.Fatal("timeout") - } + maxPacketSize := tc.conn.maxPacketSize() + var expectedData []byte + for i := range 4 { + data := bytes.Repeat([]byte{byte(i)}, int(maxPacketSize)) + expectedData = append(expectedData, data...) - // test teardown - tc.connRunner.EXPECT().Remove(gomock.Any()).AnyTimes() - tc.conn.destroy(nil) - select { - case err := <-errChan: - require.NoError(t, err) - case <-time.After(3 * time.Second): - t.Fatal("timeout") - } + tc.packer.EXPECT().AppendPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn( + func(buffer *packetBuffer, count protocol.ByteCount, t monotime.Time, version protocol.Version) (shortHeaderPacket, error) { + buffer.Data = append(buffer.Data, data...) + return shortHeaderPacket{PacketNumber: protocol.PacketNumber(i)}, nil + }, + ) + } + done := make(chan struct{}) + tc.packer.EXPECT().AppendPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(shortHeaderPacket{}, errNothingToPack) + tc.sendConn.EXPECT().Write(expectedData, uint16(maxPacketSize), protocol.ECT1).DoAndReturn( + func([]byte, uint16, protocol.ECN) error { close(done); return nil }, + ) + + errChan := make(chan error, 1) + go func() { errChan <- tc.conn.run() }() + tc.conn.scheduleSending() + + synctest.Wait() + + select { + case <-done: + default: + t.Fatal("should have sent a packet") + } + + // test teardown + tc.connRunner.EXPECT().Remove(gomock.Any()).AnyTimes() + tc.conn.destroy(nil) + + synctest.Wait() + + select { + case err := <-errChan: + require.NoError(t, err) + default: + t.Fatal("should have timed out") + } + }) } // Send a GSO batch, until a packet smaller than the maximum size is packed func TestConnectionGSOBatchPacketSize(t *testing.T) { - mockCtrl := gomock.NewController(t) - sph := mockackhandler.NewMockSentPacketHandler(mockCtrl) - tc := newServerTestConnection(t, - mockCtrl, - nil, - true, - connectionOptHandshakeConfirmed(), - connectionOptSentPacketHandler(sph), - ) + synctest.Test(t, func(t *testing.T) { + mockCtrl := gomock.NewController(t) + sph := mockackhandler.NewMockSentPacketHandler(mockCtrl) + tc := newServerTestConnection(t, + mockCtrl, + nil, + true, + connectionOptHandshakeConfirmed(), + connectionOptSentPacketHandler(sph), + ) - // allow packets to be sent - sph.EXPECT().SendMode(gomock.Any()).Return(ackhandler.SendAny).AnyTimes() - sph.EXPECT().TimeUntilSend().AnyTimes() - sph.EXPECT().SentPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).AnyTimes() - sph.EXPECT().GetLossDetectionTimeout().AnyTimes() - sph.EXPECT().ECNMode(gomock.Any()).Return(protocol.ECT1).AnyTimes() + // allow packets to be sent + sph.EXPECT().SendMode(gomock.Any()).Return(ackhandler.SendAny).AnyTimes() + sph.EXPECT().TimeUntilSend().AnyTimes() + sph.EXPECT().SentPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).AnyTimes() + sph.EXPECT().GetLossDetectionTimeout().AnyTimes() + sph.EXPECT().ECNMode(gomock.Any()).Return(protocol.ECT1).AnyTimes() - maxPacketSize := tc.conn.maxPacketSize() - var expectedData []byte - var calls []any - for i := 0; i < 4; i++ { - var data []byte - if i == 3 { - data = bytes.Repeat([]byte{byte(i)}, int(maxPacketSize-1)) - } else { - data = bytes.Repeat([]byte{byte(i)}, int(maxPacketSize)) + maxPacketSize := tc.conn.maxPacketSize() + var expectedData []byte + var calls []any + for i := range 4 { + var data []byte + if i == 3 { + data = bytes.Repeat([]byte{byte(i)}, int(maxPacketSize-1)) + } else { + data = bytes.Repeat([]byte{byte(i)}, int(maxPacketSize)) + } + expectedData = append(expectedData, data...) + + calls = append(calls, tc.packer.EXPECT().AppendPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn( + func(buffer *packetBuffer, count protocol.ByteCount, t monotime.Time, version protocol.Version) (shortHeaderPacket, error) { + buffer.Data = append(buffer.Data, data...) + return shortHeaderPacket{PacketNumber: protocol.PacketNumber(10 + i)}, nil + }, + )) } - expectedData = append(expectedData, data...) + // The smaller (fourth) packet concluded this GSO batch, but the send loop will immediately start composing the next batch. + // We therefore send a "foobar", so we can check that we're actually generating two GSO batches. + calls = append(calls, + tc.packer.EXPECT().AppendPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn( + func(buffer *packetBuffer, count protocol.ByteCount, t monotime.Time, version protocol.Version) (shortHeaderPacket, error) { + buffer.Data = append(buffer.Data, []byte("foobar")...) + return shortHeaderPacket{PacketNumber: protocol.PacketNumber(14)}, nil + }, + ), + ) + calls = append(calls, + tc.packer.EXPECT().AppendPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(shortHeaderPacket{}, errNothingToPack), + ) + gomock.InOrder(calls...) - calls = append(calls, tc.packer.EXPECT().AppendPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn( - func(buffer *packetBuffer, count protocol.ByteCount, t monotime.Time, version protocol.Version) (shortHeaderPacket, error) { - buffer.Data = append(buffer.Data, data...) - return shortHeaderPacket{PacketNumber: protocol.PacketNumber(10 + i)}, nil - }, - )) - } - // The smaller (fourth) packet concluded this GSO batch, but the send loop will immediately start composing the next batch. - // We therefore send a "foobar", so we can check that we're actually generating two GSO batches. - calls = append(calls, - tc.packer.EXPECT().AppendPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn( - func(buffer *packetBuffer, count protocol.ByteCount, t monotime.Time, version protocol.Version) (shortHeaderPacket, error) { - buffer.Data = append(buffer.Data, []byte("foobar")...) - return shortHeaderPacket{PacketNumber: protocol.PacketNumber(14)}, nil - }, - ), - ) - calls = append(calls, - tc.packer.EXPECT().AppendPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(shortHeaderPacket{}, errNothingToPack), - ) - gomock.InOrder(calls...) + done := make(chan struct{}) + gomock.InOrder( + tc.sendConn.EXPECT().Write(expectedData, uint16(maxPacketSize), protocol.ECT1), + tc.sendConn.EXPECT().Write([]byte("foobar"), uint16(maxPacketSize), protocol.ECT1).DoAndReturn( + func([]byte, uint16, protocol.ECN) error { close(done); return nil }, + ), + ) + errChan := make(chan error, 1) + go func() { errChan <- tc.conn.run() }() + tc.conn.scheduleSending() - done := make(chan struct{}) - gomock.InOrder( - tc.sendConn.EXPECT().Write(expectedData, uint16(maxPacketSize), protocol.ECT1), - tc.sendConn.EXPECT().Write([]byte("foobar"), uint16(maxPacketSize), protocol.ECT1).DoAndReturn( - func([]byte, uint16, protocol.ECN) error { close(done); return nil }, - ), - ) - errChan := make(chan error, 1) - go func() { errChan <- tc.conn.run() }() - tc.conn.scheduleSending() - select { - case <-done: - case <-time.After(time.Second): - t.Fatal("timeout") - } + synctest.Wait() - // test teardown - tc.connRunner.EXPECT().Remove(gomock.Any()).AnyTimes() - tc.conn.destroy(nil) - select { - case err := <-errChan: - require.NoError(t, err) - case <-time.After(3 * time.Second): - t.Fatal("timeout") - } + select { + case <-done: + default: + t.Fatal("should have sent a packet") + } + + // test teardown + tc.connRunner.EXPECT().Remove(gomock.Any()).AnyTimes() + tc.conn.destroy(nil) + + synctest.Wait() + + select { + case err := <-errChan: + require.NoError(t, err) + default: + t.Fatal("should have timed out") + } + }) } func TestConnectionGSOBatchECN(t *testing.T) { - mockCtrl := gomock.NewController(t) - sph := mockackhandler.NewMockSentPacketHandler(mockCtrl) - tc := newServerTestConnection(t, - mockCtrl, - nil, - true, - connectionOptHandshakeConfirmed(), - connectionOptSentPacketHandler(sph), - ) + synctest.Test(t, func(t *testing.T) { + mockCtrl := gomock.NewController(t) + sph := mockackhandler.NewMockSentPacketHandler(mockCtrl) + tc := newServerTestConnection(t, + mockCtrl, + nil, + true, + connectionOptHandshakeConfirmed(), + connectionOptSentPacketHandler(sph), + ) - // allow packets to be sent - ecnMode := protocol.ECT1 - sph.EXPECT().SendMode(gomock.Any()).Return(ackhandler.SendAny).AnyTimes() - sph.EXPECT().TimeUntilSend().AnyTimes() - sph.EXPECT().SentPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).AnyTimes() - sph.EXPECT().GetLossDetectionTimeout().AnyTimes() - sph.EXPECT().ECNMode(gomock.Any()).DoAndReturn(func(bool) protocol.ECN { return ecnMode }).AnyTimes() + // allow packets to be sent + ecnMode := protocol.ECT1 + sph.EXPECT().SendMode(gomock.Any()).Return(ackhandler.SendAny).AnyTimes() + sph.EXPECT().TimeUntilSend().AnyTimes() + sph.EXPECT().SentPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).AnyTimes() + sph.EXPECT().GetLossDetectionTimeout().AnyTimes() + sph.EXPECT().ECNMode(gomock.Any()).DoAndReturn(func(bool) protocol.ECN { return ecnMode }).AnyTimes() - // 3. Send a GSO batch, until the ECN marking changes. - var expectedData []byte - var calls []any - maxPacketSize := tc.conn.maxPacketSize() - for i := 0; i < 3; i++ { - data := bytes.Repeat([]byte{byte(i)}, int(maxPacketSize)) - expectedData = append(expectedData, data...) + // 3. Send a GSO batch, until the ECN marking changes. + var expectedData []byte + var calls []any + maxPacketSize := tc.conn.maxPacketSize() + for i := range 3 { + data := bytes.Repeat([]byte{byte(i)}, int(maxPacketSize)) + expectedData = append(expectedData, data...) - calls = append(calls, tc.packer.EXPECT().AppendPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn( - func(buffer *packetBuffer, count protocol.ByteCount, t monotime.Time, version protocol.Version) (shortHeaderPacket, error) { - buffer.Data = append(buffer.Data, data...) - if i == 2 { - ecnMode = protocol.ECNCE - } - return shortHeaderPacket{PacketNumber: protocol.PacketNumber(20 + i)}, nil - }, - )) - } - // The smaller (fourth) packet concluded this GSO batch, but the send loop will immediately start composing the next batch. - // We therefore send a "foobar", so we can check that we're actually generating two GSO batches. - calls = append(calls, - tc.packer.EXPECT().AppendPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn( - func(buffer *packetBuffer, count protocol.ByteCount, t monotime.Time, version protocol.Version) (shortHeaderPacket, error) { - buffer.Data = append(buffer.Data, []byte("foobar")...) - return shortHeaderPacket{PacketNumber: protocol.PacketNumber(24)}, nil - }, - ), - ) - calls = append(calls, - tc.packer.EXPECT().AppendPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(shortHeaderPacket{}, errNothingToPack), - ) - gomock.InOrder(calls...) + calls = append(calls, tc.packer.EXPECT().AppendPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn( + func(buffer *packetBuffer, count protocol.ByteCount, t monotime.Time, version protocol.Version) (shortHeaderPacket, error) { + buffer.Data = append(buffer.Data, data...) + if i == 2 { + ecnMode = protocol.ECNCE + } + return shortHeaderPacket{PacketNumber: protocol.PacketNumber(20 + i)}, nil + }, + )) + } + // The smaller (fourth) packet concluded this GSO batch, but the send loop will immediately start composing the next batch. + // We therefore send a "foobar", so we can check that we're actually generating two GSO batches. + calls = append(calls, + tc.packer.EXPECT().AppendPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn( + func(buffer *packetBuffer, count protocol.ByteCount, t monotime.Time, version protocol.Version) (shortHeaderPacket, error) { + buffer.Data = append(buffer.Data, []byte("foobar")...) + return shortHeaderPacket{PacketNumber: protocol.PacketNumber(24)}, nil + }, + ), + ) + calls = append(calls, + tc.packer.EXPECT().AppendPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(shortHeaderPacket{}, errNothingToPack), + ) + gomock.InOrder(calls...) - done3 := make(chan struct{}) - tc.sendConn.EXPECT().Write(expectedData, uint16(maxPacketSize), protocol.ECT1) - tc.sendConn.EXPECT().Write([]byte("foobar"), uint16(maxPacketSize), protocol.ECNCE).DoAndReturn( - func([]byte, uint16, protocol.ECN) error { close(done3); return nil }, - ) + done3 := make(chan struct{}) + tc.sendConn.EXPECT().Write(expectedData, uint16(maxPacketSize), protocol.ECT1) + tc.sendConn.EXPECT().Write([]byte("foobar"), uint16(maxPacketSize), protocol.ECNCE).DoAndReturn( + func([]byte, uint16, protocol.ECN) error { close(done3); return nil }, + ) - errChan := make(chan error, 1) - go func() { errChan <- tc.conn.run() }() - tc.conn.scheduleSending() + errChan := make(chan error, 1) + go func() { errChan <- tc.conn.run() }() + tc.conn.scheduleSending() - select { - case <-done3: - case <-time.After(time.Second): - t.Fatal("timeout") - } + synctest.Wait() - // test teardown - tc.connRunner.EXPECT().Remove(gomock.Any()).AnyTimes() - tc.conn.destroy(nil) - select { - case err := <-errChan: - require.NoError(t, err) - case <-time.After(3 * time.Second): - t.Fatal("timeout") - } + select { + case <-done3: + default: + t.Fatal("should have sent a packet") + } + + // test teardown + tc.connRunner.EXPECT().Remove(gomock.Any()).AnyTimes() + tc.conn.destroy(nil) + + synctest.Wait() + + select { + case err := <-errChan: + require.NoError(t, err) + default: + t.Fatal("should have timed out") + } + }) } func TestConnectionPTOProbePackets(t *testing.T) { @@ -2168,141 +2219,153 @@ func TestConnectionPTOProbePackets(t *testing.T) { } func testConnectionPTOProbePackets(t *testing.T, encLevel protocol.EncryptionLevel) { - mockCtrl := gomock.NewController(t) - sph := mockackhandler.NewMockSentPacketHandler(mockCtrl) - tc := newServerTestConnection(t, - mockCtrl, - nil, - false, - connectionOptSentPacketHandler(sph), - ) + synctest.Test(t, func(t *testing.T) { + mockCtrl := gomock.NewController(t) + sph := mockackhandler.NewMockSentPacketHandler(mockCtrl) + tc := newServerTestConnection(t, + mockCtrl, + nil, + false, + connectionOptSentPacketHandler(sph), + ) - var sendMode ackhandler.SendMode - switch encLevel { - case protocol.EncryptionInitial: - sendMode = ackhandler.SendPTOInitial - case protocol.EncryptionHandshake: - sendMode = ackhandler.SendPTOHandshake - case protocol.Encryption1RTT: - sendMode = ackhandler.SendPTOAppData - } + var sendMode ackhandler.SendMode + switch encLevel { + case protocol.EncryptionInitial: + sendMode = ackhandler.SendPTOInitial + case protocol.EncryptionHandshake: + sendMode = ackhandler.SendPTOHandshake + case protocol.Encryption1RTT: + sendMode = ackhandler.SendPTOAppData + } - sph.EXPECT().GetLossDetectionTimeout().AnyTimes() - sph.EXPECT().TimeUntilSend().AnyTimes() - sph.EXPECT().SendMode(gomock.Any()).Return(sendMode) - sph.EXPECT().SendMode(gomock.Any()).Return(ackhandler.SendNone) - sph.EXPECT().ECNMode(gomock.Any()) - sph.EXPECT().QueueProbePacket(encLevel).Return(false) - sph.EXPECT().SentPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()) + sph.EXPECT().GetLossDetectionTimeout().AnyTimes() + sph.EXPECT().TimeUntilSend().AnyTimes() + sph.EXPECT().SendMode(gomock.Any()).Return(sendMode) + sph.EXPECT().SendMode(gomock.Any()).Return(ackhandler.SendNone) + sph.EXPECT().ECNMode(gomock.Any()) + sph.EXPECT().QueueProbePacket(encLevel).Return(false) + sph.EXPECT().SentPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()) - tc.packer.EXPECT().PackPTOProbePacket(encLevel, gomock.Any(), true, gomock.Any(), protocol.Version1).DoAndReturn( - func(protocol.EncryptionLevel, protocol.ByteCount, bool, monotime.Time, protocol.Version) (*coalescedPacket, error) { - return &coalescedPacket{ - buffer: getPacketBuffer(), - shortHdrPacket: &shortHeaderPacket{PacketNumber: 1}, - }, nil - }, - ) - done := make(chan struct{}) - tc.sendConn.EXPECT().Write(gomock.Any(), gomock.Any(), gomock.Any()).Do( - func([]byte, uint16, protocol.ECN) error { close(done); return nil }, - ) + tc.packer.EXPECT().PackPTOProbePacket(encLevel, gomock.Any(), true, gomock.Any(), protocol.Version1).DoAndReturn( + func(protocol.EncryptionLevel, protocol.ByteCount, bool, monotime.Time, protocol.Version) (*coalescedPacket, error) { + return &coalescedPacket{ + buffer: getPacketBuffer(), + shortHdrPacket: &shortHeaderPacket{PacketNumber: 1}, + }, nil + }, + ) + done := make(chan struct{}) + tc.sendConn.EXPECT().Write(gomock.Any(), gomock.Any(), gomock.Any()).Do( + func([]byte, uint16, protocol.ECN) error { close(done); return nil }, + ) - errChan := make(chan error, 1) - go func() { errChan <- tc.conn.run() }() - tc.conn.scheduleSending() + errChan := make(chan error, 1) + go func() { errChan <- tc.conn.run() }() + tc.conn.scheduleSending() - select { - case <-done: - case <-time.After(time.Second): - t.Fatal("timeout") - } + select { + case <-done: + case <-time.After(time.Second): + t.Fatal("timeout") + } - // test teardown - tc.connRunner.EXPECT().Remove(gomock.Any()).AnyTimes() - tc.conn.destroy(nil) - select { - case err := <-errChan: - require.NoError(t, err) - case <-time.After(3 * time.Second): - t.Fatal("timeout") - } + // test teardown + tc.connRunner.EXPECT().Remove(gomock.Any()).AnyTimes() + tc.conn.destroy(nil) + + synctest.Wait() + + select { + case err := <-errChan: + require.NoError(t, err) + default: + t.Fatal("should have timed out") + } + }) } func TestConnectionCongestionControl(t *testing.T) { - mockCtrl := gomock.NewController(t) - sph := mockackhandler.NewMockSentPacketHandler(mockCtrl) - tc := newServerTestConnection(t, - mockCtrl, - nil, - false, - connectionOptHandshakeConfirmed(), - connectionOptSentPacketHandler(sph), - connectionOptRTT(10*time.Second), - ) + synctest.Test(t, func(t *testing.T) { + mockCtrl := gomock.NewController(t) + sph := mockackhandler.NewMockSentPacketHandler(mockCtrl) + tc := newServerTestConnection(t, + mockCtrl, + nil, + false, + connectionOptHandshakeConfirmed(), + connectionOptSentPacketHandler(sph), + ) - sph.EXPECT().TimeUntilSend().AnyTimes() - sph.EXPECT().GetLossDetectionTimeout().AnyTimes() - sph.EXPECT().ECNMode(true).AnyTimes() - sph.EXPECT().SendMode(gomock.Any()).Return(ackhandler.SendAny).Times(2) - sph.EXPECT().SendMode(gomock.Any()).Return(ackhandler.SendAck).MaxTimes(1) - sph.EXPECT().SentPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Times(2) - // Since we're already sending out packets, we don't expect any calls to PackAckOnlyPacket - for i := 0; i < 2; i++ { - tc.packer.EXPECT().AppendPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn( - func(buffer *packetBuffer, count protocol.ByteCount, t monotime.Time, version protocol.Version) (shortHeaderPacket, error) { - buffer.Data = append(buffer.Data, []byte("foobar")...) - return shortHeaderPacket{PacketNumber: protocol.PacketNumber(i)}, nil + sph.EXPECT().TimeUntilSend().AnyTimes() + sph.EXPECT().GetLossDetectionTimeout().AnyTimes() + sph.EXPECT().ECNMode(true).AnyTimes() + sph.EXPECT().SendMode(gomock.Any()).Return(ackhandler.SendAny).Times(2) + sph.EXPECT().SendMode(gomock.Any()).Return(ackhandler.SendAck).MaxTimes(1) + sph.EXPECT().SentPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Times(2) + // Since we're already sending out packets, we don't expect any calls to PackAckOnlyPacket + for i := range 2 { + tc.packer.EXPECT().AppendPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn( + func(buffer *packetBuffer, count protocol.ByteCount, t monotime.Time, version protocol.Version) (shortHeaderPacket, error) { + buffer.Data = append(buffer.Data, []byte("foobar")...) + return shortHeaderPacket{PacketNumber: protocol.PacketNumber(i)}, nil + }, + ) + } + tc.sendConn.EXPECT().Write(gomock.Any(), gomock.Any(), gomock.Any()) + done1 := make(chan struct{}) + tc.sendConn.EXPECT().Write(gomock.Any(), gomock.Any(), gomock.Any()).Do( + func([]byte, uint16, protocol.ECN) error { close(done1); return nil }, + ) + + errChan := make(chan error, 1) + go func() { errChan <- tc.conn.run() }() + tc.conn.scheduleSending() + + synctest.Wait() + + select { + case <-done1: + default: + t.Fatal("should have sent a packet") + } + require.True(t, mockCtrl.Satisfied()) + + // Now that we're congestion limited, we can only send an ack-only packet + done2 := make(chan struct{}) + sph.EXPECT().SendMode(gomock.Any()).Return(ackhandler.SendAck) + tc.packer.EXPECT().PackAckOnlyPacket(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn( + func(protocol.ByteCount, monotime.Time, protocol.Version) (shortHeaderPacket, *packetBuffer, error) { + close(done2) + return shortHeaderPacket{}, nil, errNothingToPack }, ) - } - tc.sendConn.EXPECT().Write(gomock.Any(), gomock.Any(), gomock.Any()) - done1 := make(chan struct{}) - tc.sendConn.EXPECT().Write(gomock.Any(), gomock.Any(), gomock.Any()).Do( - func([]byte, uint16, protocol.ECN) error { close(done1); return nil }, - ) + tc.conn.scheduleSending() - errChan := make(chan error, 1) - go func() { errChan <- tc.conn.run() }() - tc.conn.scheduleSending() - select { - case <-done1: - case <-time.After(time.Second): - t.Fatal("timeout") - } - require.True(t, mockCtrl.Satisfied()) + synctest.Wait() - // Now that we're congestion limited, we can only send an ack-only packet - done2 := make(chan struct{}) - sph.EXPECT().SendMode(gomock.Any()).Return(ackhandler.SendAck) - tc.packer.EXPECT().PackAckOnlyPacket(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn( - func(protocol.ByteCount, monotime.Time, protocol.Version) (shortHeaderPacket, *packetBuffer, error) { - close(done2) - return shortHeaderPacket{}, nil, errNothingToPack - }, - ) - tc.conn.scheduleSending() - select { - case <-done2: - case <-time.After(time.Second): - t.Fatal("timeout") - } - require.True(t, mockCtrl.Satisfied()) + select { + case <-done2: + default: + t.Fatal("should have sent an ack-only packet") + } + require.True(t, mockCtrl.Satisfied()) - // If the send mode is "none", we can't even send an ack-only packet - sph.EXPECT().SendMode(gomock.Any()).Return(ackhandler.SendNone) - tc.conn.scheduleSending() - time.Sleep(scaleDuration(10 * time.Millisecond)) // make sure there are no calls to the packer + // If the send mode is "none", we can't even send an ack-only packet + sph.EXPECT().SendMode(gomock.Any()).Return(ackhandler.SendNone) + tc.conn.scheduleSending() + synctest.Wait() // make sure there are no calls to the packer - // test teardown - tc.connRunner.EXPECT().Remove(gomock.Any()).AnyTimes() - tc.conn.destroy(nil) - select { - case err := <-errChan: - require.NoError(t, err) - case <-time.After(3 * time.Second): - t.Fatal("timeout") - } + // test teardown + tc.connRunner.EXPECT().Remove(gomock.Any()).AnyTimes() + tc.conn.destroy(nil) + select { + case err := <-errChan: + require.NoError(t, err) + default: + t.Fatal("timeout") + } + }) } func TestConnectionSendQueue(t *testing.T) { @@ -2315,75 +2378,85 @@ func TestConnectionSendQueue(t *testing.T) { } func testConnectionSendQueue(t *testing.T, enableGSO bool) { - mockCtrl := gomock.NewController(t) - sph := mockackhandler.NewMockSentPacketHandler(mockCtrl) - sender := NewMockSender(mockCtrl) - tc := newServerTestConnection(t, - mockCtrl, - nil, - enableGSO, - connectionOptSender(sender), - connectionOptHandshakeConfirmed(), - connectionOptSentPacketHandler(sph), - ) + synctest.Test(t, func(t *testing.T) { + mockCtrl := gomock.NewController(t) + sph := mockackhandler.NewMockSentPacketHandler(mockCtrl) + sender := NewMockSender(mockCtrl) + tc := newServerTestConnection(t, + mockCtrl, + nil, + enableGSO, + connectionOptSender(sender), + connectionOptHandshakeConfirmed(), + connectionOptSentPacketHandler(sph), + ) - sender.EXPECT().Run().MaxTimes(1) - sender.EXPECT().WouldBlock() - sender.EXPECT().WouldBlock().Return(true).Times(2) - available := make(chan struct{}) - blocked := make(chan struct{}) - sender.EXPECT().Available().DoAndReturn( - func() <-chan struct{} { - close(blocked) - return available - }, - ) - sph.EXPECT().GetLossDetectionTimeout().AnyTimes() - sph.EXPECT().SentPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()) - sph.EXPECT().SendMode(gomock.Any()).Return(ackhandler.SendAny).AnyTimes() - sph.EXPECT().ECNMode(gomock.Any()).AnyTimes() - tc.packer.EXPECT().AppendPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return( - shortHeaderPacket{PacketNumber: protocol.PacketNumber(1)}, nil, - ) - sender.EXPECT().Send(gomock.Any(), gomock.Any(), gomock.Any()) + sender.EXPECT().Run().MaxTimes(1) + sender.EXPECT().WouldBlock() + sender.EXPECT().WouldBlock().Return(true).Times(2) + available := make(chan struct{}) + blocked := make(chan struct{}) + sender.EXPECT().Available().DoAndReturn( + func() <-chan struct{} { + close(blocked) + return available + }, + ) + sph.EXPECT().GetLossDetectionTimeout().AnyTimes() + sph.EXPECT().SentPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()) + sph.EXPECT().SendMode(gomock.Any()).Return(ackhandler.SendAny).AnyTimes() + sph.EXPECT().ECNMode(gomock.Any()).AnyTimes() + tc.packer.EXPECT().AppendPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return( + shortHeaderPacket{PacketNumber: protocol.PacketNumber(1)}, nil, + ) + sender.EXPECT().Send(gomock.Any(), gomock.Any(), gomock.Any()) - errChan := make(chan error, 1) - go func() { errChan <- tc.conn.run() }() - tc.conn.scheduleSending() + errChan := make(chan error, 1) + go func() { errChan <- tc.conn.run() }() + tc.conn.scheduleSending() - select { - case <-blocked: - case <-time.After(time.Second): - t.Fatal("timeout") - } - require.True(t, mockCtrl.Satisfied()) + synctest.Wait() - // now make room in the send queue - sender.EXPECT().WouldBlock().AnyTimes() - unblocked := make(chan struct{}) - tc.packer.EXPECT().AppendPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn( - func(*packetBuffer, protocol.ByteCount, monotime.Time, protocol.Version) (shortHeaderPacket, error) { - close(unblocked) - return shortHeaderPacket{}, errNothingToPack - }, - ) - available <- struct{}{} - select { - case <-unblocked: - case <-time.After(time.Second): - t.Fatal("timeout") - } + select { + case <-blocked: + default: + t.Fatal("should have blocked") + } + require.True(t, mockCtrl.Satisfied()) - // test teardown - sender.EXPECT().Close() - tc.connRunner.EXPECT().Remove(gomock.Any()).AnyTimes() - tc.conn.destroy(nil) - select { - case err := <-errChan: - require.NoError(t, err) - case <-time.After(3 * time.Second): - t.Fatal("timeout") - } + // now make room in the send queue + sender.EXPECT().WouldBlock().AnyTimes() + unblocked := make(chan struct{}) + tc.packer.EXPECT().AppendPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn( + func(*packetBuffer, protocol.ByteCount, monotime.Time, protocol.Version) (shortHeaderPacket, error) { + close(unblocked) + return shortHeaderPacket{}, errNothingToPack + }, + ) + available <- struct{}{} + + synctest.Wait() + + select { + case <-unblocked: + default: + t.Fatal("should have unblocked") + } + + // test teardown + sender.EXPECT().Close() + tc.connRunner.EXPECT().Remove(gomock.Any()).AnyTimes() + tc.conn.destroy(nil) + + synctest.Wait() + + select { + case err := <-errChan: + require.NoError(t, err) + default: + t.Fatal("timeout") + } + }) } func getVersionNegotiationPacket(src, dest protocol.ConnectionID, versions []protocol.Version) receivedPacket { @@ -2400,81 +2473,89 @@ func getVersionNegotiationPacket(src, dest protocol.ConnectionID, versions []pro } func TestConnectionVersionNegotiation(t *testing.T) { - mockCtrl := gomock.NewController(t) - tr, tracer := mocklogging.NewMockConnectionTracer(mockCtrl) - tc := newClientTestConnection(t, - mockCtrl, - nil, - false, - connectionOptTracer(tr), - ) + synctest.Test(t, func(t *testing.T) { + mockCtrl := gomock.NewController(t) + tr, tracer := mocklogging.NewMockConnectionTracer(mockCtrl) + tc := newClientTestConnection(t, + mockCtrl, + nil, + false, + connectionOptTracer(tr), + ) - tc.packer.EXPECT().PackCoalescedPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(nil, nil).AnyTimes() - var tracerVersions []logging.Version - gomock.InOrder( - tracer.EXPECT().ReceivedVersionNegotiationPacket(gomock.Any(), gomock.Any(), gomock.Any()).Do(func(_, _ protocol.ArbitraryLenConnectionID, versions []logging.Version) { - tracerVersions = versions - }), - tracer.EXPECT().NegotiatedVersion(protocol.Version2, gomock.Any(), gomock.Any()), - tc.connRunner.EXPECT().Remove(gomock.Any()), - ) + tc.packer.EXPECT().PackCoalescedPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(nil, nil).AnyTimes() + var tracerVersions []logging.Version + gomock.InOrder( + tracer.EXPECT().ReceivedVersionNegotiationPacket(gomock.Any(), gomock.Any(), gomock.Any()).Do(func(_, _ protocol.ArbitraryLenConnectionID, versions []logging.Version) { + tracerVersions = versions + }), + tracer.EXPECT().NegotiatedVersion(protocol.Version2, gomock.Any(), gomock.Any()), + tc.connRunner.EXPECT().Remove(gomock.Any()), + ) - errChan := make(chan error, 1) - go func() { errChan <- tc.conn.run() }() - tc.conn.handlePacket(getVersionNegotiationPacket( - tc.destConnID, - tc.srcConnID, - []protocol.Version{1234, protocol.Version2}, - )) + errChan := make(chan error, 1) + go func() { errChan <- tc.conn.run() }() + tc.conn.handlePacket(getVersionNegotiationPacket( + tc.destConnID, + tc.srcConnID, + []protocol.Version{1234, protocol.Version2}, + )) - select { - case err := <-errChan: - var rerr *errCloseForRecreating - require.ErrorAs(t, err, &rerr) - require.Equal(t, rerr.nextVersion, protocol.Version2) - case <-time.After(time.Second): - t.Fatal("timeout") - } - require.Contains(t, tracerVersions, protocol.Version(1234)) - require.Contains(t, tracerVersions, protocol.Version2) + synctest.Wait() + + select { + case err := <-errChan: + var rerr *errCloseForRecreating + require.ErrorAs(t, err, &rerr) + require.Equal(t, rerr.nextVersion, protocol.Version2) + default: + t.Fatal("should have received a Version Negotiation packet") + } + require.Contains(t, tracerVersions, protocol.Version(1234)) + require.Contains(t, tracerVersions, protocol.Version2) + }) } func TestConnectionVersionNegotiationNoMatch(t *testing.T) { - mockCtrl := gomock.NewController(t) - tr, tracer := mocklogging.NewMockConnectionTracer(mockCtrl) - tc := newClientTestConnection(t, - mockCtrl, - &Config{Versions: []protocol.Version{protocol.Version1}}, - false, - connectionOptTracer(tr), - ) + synctest.Test(t, func(t *testing.T) { + mockCtrl := gomock.NewController(t) + tr, tracer := mocklogging.NewMockConnectionTracer(mockCtrl) + tc := newClientTestConnection(t, + mockCtrl, + &Config{Versions: []protocol.Version{protocol.Version1}}, + false, + connectionOptTracer(tr), + ) - tc.packer.EXPECT().PackCoalescedPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(nil, nil).AnyTimes() - var tracerVersions []logging.Version - tracer.EXPECT().ReceivedVersionNegotiationPacket(gomock.Any(), gomock.Any(), gomock.Any()).Do( - func(_, _ protocol.ArbitraryLenConnectionID, versions []logging.Version) { tracerVersions = versions }, - ) - tracer.EXPECT().ClosedConnection(gomock.Any()) - tracer.EXPECT().Close() - tc.connRunner.EXPECT().Remove(gomock.Any()) + tc.packer.EXPECT().PackCoalescedPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(nil, nil).AnyTimes() + var tracerVersions []logging.Version + tracer.EXPECT().ReceivedVersionNegotiationPacket(gomock.Any(), gomock.Any(), gomock.Any()).Do( + func(_, _ protocol.ArbitraryLenConnectionID, versions []logging.Version) { tracerVersions = versions }, + ) + tracer.EXPECT().ClosedConnection(gomock.Any()) + tracer.EXPECT().Close() + tc.connRunner.EXPECT().Remove(gomock.Any()) - errChan := make(chan error, 1) - go func() { errChan <- tc.conn.run() }() - tc.conn.handlePacket(getVersionNegotiationPacket( - tc.destConnID, - tc.srcConnID, - []protocol.Version{protocol.Version2}, - )) + errChan := make(chan error, 1) + go func() { errChan <- tc.conn.run() }() + tc.conn.handlePacket(getVersionNegotiationPacket( + tc.destConnID, + tc.srcConnID, + []protocol.Version{protocol.Version2}, + )) - select { - case err := <-errChan: - var verr *VersionNegotiationError - require.ErrorAs(t, err, &verr) - require.Contains(t, verr.Theirs, protocol.Version2) - case <-time.After(time.Second): - t.Fatal("timeout") - } - require.Contains(t, tracerVersions, protocol.Version2) + synctest.Wait() + + select { + case err := <-errChan: + var verr *VersionNegotiationError + require.ErrorAs(t, err, &verr) + require.Contains(t, verr.Theirs, protocol.Version2) + default: + t.Fatal("should have received a Version Negotiation packet") + } + require.Contains(t, tracerVersions, protocol.Version2) + }) } func TestConnectionVersionNegotiationInvalidPackets(t *testing.T) { @@ -2606,122 +2687,135 @@ func TestConnectionConnectionIDChanges(t *testing.T) { } func testConnectionConnectionIDChanges(t *testing.T, sendRetry bool) { - makeInitialPacket := func(t *testing.T, hdr *wire.ExtendedHeader) []byte { - t.Helper() - data, err := hdr.Append(nil, protocol.Version1) - require.NoError(t, err) - data = append(data, make([]byte, hdr.Length-protocol.ByteCount(hdr.PacketNumberLen))...) - return data - } - - mockCtrl := gomock.NewController(t) - tr, tracer := mocklogging.NewMockConnectionTracer(mockCtrl) - unpacker := NewMockUnpacker(mockCtrl) - tc := newClientTestConnection(t, - mockCtrl, - nil, - false, - connectionOptTracer(tr), - connectionOptUnpacker(unpacker), - ) - - dstConnID := tc.destConnID - b := make([]byte, 3*10) - rand.Read(b) - newConnID := protocol.ParseConnectionID(b[:11]) - newConnID2 := protocol.ParseConnectionID(b[11:20]) - - tracer.EXPECT().NegotiatedVersion(gomock.Any(), gomock.Any(), gomock.Any()) - tc.packer.EXPECT().PackCoalescedPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(nil, nil).AnyTimes() - - errChan := make(chan error, 1) - go func() { errChan <- tc.conn.run() }() - - require.Equal(t, dstConnID, tc.conn.connIDManager.Get()) - - var retryConnID protocol.ConnectionID - if sendRetry { - retryConnID = protocol.ParseConnectionID(b[20:30]) - hdrChan := make(chan *wire.Header) - tracer.EXPECT().ReceivedRetry(gomock.Any()).Do(func(hdr *wire.Header) { hdrChan <- hdr }) - tc.packer.EXPECT().SetToken([]byte("foobar")) - - tc.conn.handlePacket(getRetryPacket(t, retryConnID, tc.srcConnID, tc.destConnID, []byte("foobar"))) - select { - case hdr := <-hdrChan: - assert.Equal(t, retryConnID, hdr.SrcConnectionID) - assert.Equal(t, []byte("foobar"), hdr.Token) - require.Equal(t, retryConnID, tc.conn.connIDManager.Get()) - case <-time.After(time.Second): - t.Fatal("timeout") + synctest.Test(t, func(t *testing.T) { + makeInitialPacket := func(t *testing.T, hdr *wire.ExtendedHeader) []byte { + t.Helper() + data, err := hdr.Append(nil, protocol.Version1) + require.NoError(t, err) + data = append(data, make([]byte, hdr.Length-protocol.ByteCount(hdr.PacketNumberLen))...) + return data } - } - // Send the first packet. The server changes the connection ID to newConnID. - hdr1 := wire.ExtendedHeader{ - Header: wire.Header{ - SrcConnectionID: newConnID, - DestConnectionID: tc.srcConnID, - Type: protocol.PacketTypeInitial, - Length: 200, - Version: protocol.Version1, - }, - PacketNumber: 1, - PacketNumberLen: protocol.PacketNumberLen2, - } - hdr2 := hdr1 - hdr2.SrcConnectionID = newConnID2 + mockCtrl := gomock.NewController(t) + tr, tracer := mocklogging.NewMockConnectionTracer(mockCtrl) + unpacker := NewMockUnpacker(mockCtrl) + tc := newClientTestConnection(t, + mockCtrl, + nil, + false, + connectionOptTracer(tr), + connectionOptUnpacker(unpacker), + ) - receivedFirst := make(chan struct{}) - gomock.InOrder( - unpacker.EXPECT().UnpackLongHeader(gomock.Any(), gomock.Any()).Return( - &unpackedPacket{ - hdr: &hdr1, - encryptionLevel: protocol.EncryptionInitial, - }, nil, - ), - tracer.EXPECT().ReceivedLongHeaderPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Do( - func(*wire.ExtendedHeader, protocol.ByteCount, protocol.ECN, []logging.Frame) { close(receivedFirst) }, - ), - ) + dstConnID := tc.destConnID + b := make([]byte, 3*10) + rand.Read(b) + newConnID := protocol.ParseConnectionID(b[:11]) + newConnID2 := protocol.ParseConnectionID(b[11:20]) - tc.conn.handlePacket(receivedPacket{data: makeInitialPacket(t, &hdr1), buffer: getPacketBuffer(), rcvTime: monotime.Now(), remoteAddr: tc.remoteAddr}) + tracer.EXPECT().NegotiatedVersion(gomock.Any(), gomock.Any(), gomock.Any()) + tc.packer.EXPECT().PackCoalescedPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(nil, nil).AnyTimes() - select { - case <-receivedFirst: - require.Equal(t, newConnID, tc.conn.connIDManager.Get()) - case <-time.After(time.Second): - t.Fatal("timeout") - } + errChan := make(chan error, 1) + go func() { errChan <- tc.conn.run() }() - // Send the second packet. We refuse to accept it, because the connection ID is changed again. - dropped := make(chan struct{}) - tracer.EXPECT().DroppedPacket(logging.PacketTypeInitial, gomock.Any(), gomock.Any(), logging.PacketDropUnknownConnectionID).Do( - func(logging.PacketType, protocol.PacketNumber, protocol.ByteCount, logging.PacketDropReason) { - close(dropped) - }, - ) + require.Equal(t, dstConnID, tc.conn.connIDManager.Get()) - tc.conn.handlePacket(receivedPacket{data: makeInitialPacket(t, &hdr2), buffer: getPacketBuffer(), rcvTime: monotime.Now(), remoteAddr: tc.remoteAddr}) - select { - case <-dropped: - // the connection ID should not have changed - require.Equal(t, newConnID, tc.conn.connIDManager.Get()) - case <-time.After(time.Second): - t.Fatal("timeout") - } + var retryConnID protocol.ConnectionID + if sendRetry { + retryConnID = protocol.ParseConnectionID(b[20:30]) + hdrChan := make(chan *wire.Header) + tracer.EXPECT().ReceivedRetry(gomock.Any()).Do(func(hdr *wire.Header) { hdrChan <- hdr }) + tc.packer.EXPECT().SetToken([]byte("foobar")) - // test teardown - tracer.EXPECT().ClosedConnection(gomock.Any()) - tracer.EXPECT().Close() - tc.connRunner.EXPECT().Remove(gomock.Any()) - tc.conn.destroy(nil) - select { - case err := <-errChan: - require.NoError(t, err) - case <-time.After(time.Second): - t.Fatal("timeout") - } + tc.conn.handlePacket(getRetryPacket(t, retryConnID, tc.srcConnID, tc.destConnID, []byte("foobar"))) + + synctest.Wait() + + select { + case hdr := <-hdrChan: + assert.Equal(t, retryConnID, hdr.SrcConnectionID) + assert.Equal(t, []byte("foobar"), hdr.Token) + require.Equal(t, retryConnID, tc.conn.connIDManager.Get()) + default: + t.Fatal("should have received the retry packet") + } + } + + // Send the first packet. The server changes the connection ID to newConnID. + hdr1 := wire.ExtendedHeader{ + Header: wire.Header{ + SrcConnectionID: newConnID, + DestConnectionID: tc.srcConnID, + Type: protocol.PacketTypeInitial, + Length: 200, + Version: protocol.Version1, + }, + PacketNumber: 1, + PacketNumberLen: protocol.PacketNumberLen2, + } + hdr2 := hdr1 + hdr2.SrcConnectionID = newConnID2 + + receivedFirst := make(chan struct{}) + gomock.InOrder( + unpacker.EXPECT().UnpackLongHeader(gomock.Any(), gomock.Any()).Return( + &unpackedPacket{ + hdr: &hdr1, + encryptionLevel: protocol.EncryptionInitial, + }, nil, + ), + tracer.EXPECT().ReceivedLongHeaderPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Do( + func(*wire.ExtendedHeader, protocol.ByteCount, protocol.ECN, []logging.Frame) { close(receivedFirst) }, + ), + ) + + tc.conn.handlePacket(receivedPacket{data: makeInitialPacket(t, &hdr1), buffer: getPacketBuffer(), rcvTime: monotime.Now(), remoteAddr: tc.remoteAddr}) + + synctest.Wait() + + select { + case <-receivedFirst: + require.Equal(t, newConnID, tc.conn.connIDManager.Get()) + default: + t.Fatal("should have received the first packet") + } + + // Send the second packet. We refuse to accept it, because the connection ID is changed again. + dropped := make(chan struct{}) + tracer.EXPECT().DroppedPacket(logging.PacketTypeInitial, gomock.Any(), gomock.Any(), logging.PacketDropUnknownConnectionID).Do( + func(logging.PacketType, protocol.PacketNumber, protocol.ByteCount, logging.PacketDropReason) { + close(dropped) + }, + ) + + tc.conn.handlePacket(receivedPacket{data: makeInitialPacket(t, &hdr2), buffer: getPacketBuffer(), rcvTime: monotime.Now(), remoteAddr: tc.remoteAddr}) + + synctest.Wait() + + select { + case <-dropped: + // the connection ID should not have changed + require.Equal(t, newConnID, tc.conn.connIDManager.Get()) + default: + t.Fatal("should have dropped the packet") + } + + // test teardown + tracer.EXPECT().ClosedConnection(gomock.Any()) + tracer.EXPECT().Close() + tc.connRunner.EXPECT().Remove(gomock.Any()) + tc.conn.destroy(nil) + + synctest.Wait() + + select { + case err := <-errChan: + require.NoError(t, err) + default: + t.Fatal("should have shut down") + } + }) } // When the connection is closed before sending the first packet, @@ -2729,38 +2823,42 @@ func testConnectionConnectionIDChanges(t *testing.T, sendRetry bool) { // This can happen if there's something wrong the tls.Config, and // crypto/tls refuses to start the handshake. func TestConnectionEarlyClose(t *testing.T) { - mockCtrl := gomock.NewController(t) - tr, tracer := mocklogging.NewMockConnectionTracer(mockCtrl) - cryptoSetup := mocks.NewMockCryptoSetup(mockCtrl) - tc := newClientTestConnection(t, - mockCtrl, - nil, - false, - connectionOptTracer(tr), - connectionOptCryptoSetup(cryptoSetup), - ) + synctest.Test(t, func(t *testing.T) { + mockCtrl := gomock.NewController(t) + tr, tracer := mocklogging.NewMockConnectionTracer(mockCtrl) + cryptoSetup := mocks.NewMockCryptoSetup(mockCtrl) + tc := newClientTestConnection(t, + mockCtrl, + nil, + false, + connectionOptTracer(tr), + connectionOptCryptoSetup(cryptoSetup), + ) - tc.conn.sentFirstPacket = false - tracer.EXPECT().ClosedConnection(gomock.Any()) - tracer.EXPECT().Close() - cryptoSetup.EXPECT().StartHandshake(gomock.Any()).Do(func(context.Context) error { - tc.conn.closeLocal(errors.New("early error")) - return nil + tc.conn.sentFirstPacket = false + tracer.EXPECT().ClosedConnection(gomock.Any()) + tracer.EXPECT().Close() + cryptoSetup.EXPECT().StartHandshake(gomock.Any()).Do(func(context.Context) error { + tc.conn.closeLocal(errors.New("early error")) + return nil + }) + cryptoSetup.EXPECT().NextEvent().Return(handshake.Event{Kind: handshake.EventNoEvent}) + cryptoSetup.EXPECT().Close() + tc.connRunner.EXPECT().Remove(gomock.Any()) + + errChan := make(chan error, 1) + go func() { errChan <- tc.conn.run() }() + + synctest.Wait() + + select { + case err := <-errChan: + require.Error(t, err) + require.ErrorContains(t, err, "early error") + default: + t.Fatal("should have shut down") + } }) - cryptoSetup.EXPECT().NextEvent().Return(handshake.Event{Kind: handshake.EventNoEvent}) - cryptoSetup.EXPECT().Close() - tc.connRunner.EXPECT().Remove(gomock.Any()) - - errChan := make(chan error, 1) - go func() { errChan <- tc.conn.run() }() - - select { - case err := <-errChan: - require.Error(t, err) - require.ErrorContains(t, err, "early error") - case <-time.After(time.Second): - t.Fatal("timeout") - } } func TestConnectionPathValidation(t *testing.T) { @@ -2774,144 +2872,156 @@ func TestConnectionPathValidation(t *testing.T) { } func testConnectionPathValidation(t *testing.T, isNATRebinding bool) { - mockCtrl := gomock.NewController(t) - unpacker := NewMockUnpacker(mockCtrl) - tc := newServerTestConnection( - t, - mockCtrl, - nil, - false, - connectionOptUnpacker(unpacker), - connectionOptHandshakeConfirmed(), - connectionOptRTT(time.Second), - ) - require.NoError(t, tc.conn.handleTransportParameters(&wire.TransportParameters{MaxUDPPayloadSize: 1456})) - - newRemoteAddr := &net.UDPAddr{IP: net.IPv4(192, 168, 1, 1), Port: 1234} - require.NotEqual(t, tc.remoteAddr, newRemoteAddr) - - errChan := make(chan error, 1) - go func() { errChan <- tc.conn.run() }() - - probeSent := make(chan struct{}) - var pathChallenge *wire.PathChallengeFrame - payload := []byte{0} // PADDING frame - if isNATRebinding { - payload = []byte{1} // PING frame - } - gomock.InOrder( - unpacker.EXPECT().UnpackShortHeader(gomock.Any(), gomock.Any()).Return( - protocol.PacketNumber(10), protocol.PacketNumberLen2, protocol.KeyPhaseZero, payload, nil, - ), - tc.packer.EXPECT().PackPathProbePacket(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn( - func(_ protocol.ConnectionID, frames []ackhandler.Frame, _ protocol.Version) (shortHeaderPacket, *packetBuffer, error) { - pathChallenge = frames[0].Frame.(*wire.PathChallengeFrame) - return shortHeaderPacket{IsPathProbePacket: true}, getPacketBuffer(), nil - }, - ), - tc.sendConn.EXPECT().WriteTo(gomock.Any(), newRemoteAddr).DoAndReturn( - func([]byte, net.Addr) error { close(probeSent); return nil }, - ), - tc.packer.EXPECT().AppendPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return( - shortHeaderPacket{}, errNothingToPack, - ), - ) - - tc.conn.handlePacket(receivedPacket{ - data: make([]byte, 10), - buffer: getPacketBuffer(), - remoteAddr: newRemoteAddr, - rcvTime: monotime.Now(), - }) - - select { - case <-probeSent: - case <-time.After(time.Second): - t.Fatal("timeout") - } - - // Receive a packed containing a PATH_RESPONSE frame. - // Only if the first packet received on the path was a probing packet - // (i.e. we're dealing with a NAT rebinding), this makes us switch to the new path. - migrated := make(chan struct{}) - data, err := (&wire.PathResponseFrame{Data: pathChallenge.Data}).Append(nil, protocol.Version1) - require.NoError(t, err) - calls := []any{ - unpacker.EXPECT().UnpackShortHeader(gomock.Any(), gomock.Any()).Return( - protocol.PacketNumber(11), protocol.PacketNumberLen2, protocol.KeyPhaseZero, data, nil, - ), - } - if isNATRebinding { - calls = append(calls, - tc.sendConn.EXPECT().ChangeRemoteAddr(newRemoteAddr, gomock.Any()).Do( - func(net.Addr, packetInfo) { close(migrated) }, - ), + synctest.Test(t, func(t *testing.T) { + mockCtrl := gomock.NewController(t) + unpacker := NewMockUnpacker(mockCtrl) + tc := newServerTestConnection( + t, + mockCtrl, + nil, + false, + connectionOptUnpacker(unpacker), + connectionOptHandshakeConfirmed(), + connectionOptRTT(time.Second), ) - } - calls = append(calls, - tc.packer.EXPECT().AppendPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return( - shortHeaderPacket{}, errNothingToPack, - ).MaxTimes(1), - ) - gomock.InOrder(calls...) - require.Equal(t, tc.remoteAddr, tc.conn.RemoteAddr()) - // the PATH_RESPONSE can be sent on the old path, if the client is just probing the new path - addr := tc.remoteAddr - if isNATRebinding { - addr = newRemoteAddr - } - tc.conn.handlePacket(receivedPacket{ - data: make([]byte, 100), - buffer: getPacketBuffer(), - remoteAddr: addr, - rcvTime: monotime.Now(), - }) + require.NoError(t, tc.conn.handleTransportParameters(&wire.TransportParameters{MaxUDPPayloadSize: 1456})) - if !isNATRebinding { - // If the first packet was a probing packet, we only switch to the new path when we - // receive a non-probing packet on that path. - select { - case <-migrated: - t.Fatal("didn't expect a migration yet") - case <-time.After(scaleDuration(10 * time.Millisecond)): + newRemoteAddr := &net.UDPAddr{IP: net.IPv4(192, 168, 1, 1), Port: 1234} + require.NotEqual(t, tc.remoteAddr, newRemoteAddr) + + errChan := make(chan error, 1) + go func() { errChan <- tc.conn.run() }() + + probeSent := make(chan struct{}) + var pathChallenge *wire.PathChallengeFrame + payload := []byte{0} // PADDING frame + if isNATRebinding { + payload = []byte{1} // PING frame } - - payload := []byte{1} // PING frame - payload, err = (&wire.PathResponseFrame{Data: pathChallenge.Data}).Append(payload, protocol.Version1) - require.NoError(t, err) gomock.InOrder( unpacker.EXPECT().UnpackShortHeader(gomock.Any(), gomock.Any()).Return( - protocol.PacketNumber(12), protocol.PacketNumberLen2, protocol.KeyPhaseZero, payload, nil, + protocol.PacketNumber(10), protocol.PacketNumberLen2, protocol.KeyPhaseZero, payload, nil, ), - tc.sendConn.EXPECT().ChangeRemoteAddr(newRemoteAddr, gomock.Any()).Do( - func(net.Addr, packetInfo) { close(migrated) }, + tc.packer.EXPECT().PackPathProbePacket(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn( + func(_ protocol.ConnectionID, frames []ackhandler.Frame, _ protocol.Version) (shortHeaderPacket, *packetBuffer, error) { + pathChallenge = frames[0].Frame.(*wire.PathChallengeFrame) + return shortHeaderPacket{IsPathProbePacket: true}, getPacketBuffer(), nil + }, + ), + tc.sendConn.EXPECT().WriteTo(gomock.Any(), newRemoteAddr).DoAndReturn( + func([]byte, net.Addr) error { close(probeSent); return nil }, ), tc.packer.EXPECT().AppendPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return( shortHeaderPacket{}, errNothingToPack, - ).MaxTimes(1), + ), ) + tc.conn.handlePacket(receivedPacket{ - data: make([]byte, 100), + data: make([]byte, 10), buffer: getPacketBuffer(), remoteAddr: newRemoteAddr, rcvTime: monotime.Now(), }) - } - select { - case <-migrated: - case <-time.After(time.Second): - t.Fatal("timeout") - } + synctest.Wait() - // test teardown - tc.connRunner.EXPECT().Remove(gomock.Any()).AnyTimes() - tc.conn.destroy(nil) - select { - case <-errChan: - case <-time.After(time.Second): - t.Fatal("timeout") - } + select { + case <-probeSent: + case <-time.After(time.Second): + t.Fatal("timeout") + } + + // Receive a packed containing a PATH_RESPONSE frame. + // Only if the first packet received on the path was a probing packet + // (i.e. we're dealing with a NAT rebinding), this makes us switch to the new path. + migrated := make(chan struct{}) + data, err := (&wire.PathResponseFrame{Data: pathChallenge.Data}).Append(nil, protocol.Version1) + require.NoError(t, err) + calls := []any{ + unpacker.EXPECT().UnpackShortHeader(gomock.Any(), gomock.Any()).Return( + protocol.PacketNumber(11), protocol.PacketNumberLen2, protocol.KeyPhaseZero, data, nil, + ), + } + if isNATRebinding { + calls = append(calls, + tc.sendConn.EXPECT().ChangeRemoteAddr(newRemoteAddr, gomock.Any()).Do( + func(net.Addr, packetInfo) { close(migrated) }, + ), + ) + } + calls = append(calls, + tc.packer.EXPECT().AppendPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return( + shortHeaderPacket{}, errNothingToPack, + ).MaxTimes(1), + ) + gomock.InOrder(calls...) + require.Equal(t, tc.remoteAddr, tc.conn.RemoteAddr()) + // the PATH_RESPONSE can be sent on the old path, if the client is just probing the new path + addr := tc.remoteAddr + if isNATRebinding { + addr = newRemoteAddr + } + tc.conn.handlePacket(receivedPacket{ + data: make([]byte, 100), + buffer: getPacketBuffer(), + remoteAddr: addr, + rcvTime: monotime.Now(), + }) + + synctest.Wait() + + if !isNATRebinding { + // If the first packet was a probing packet, we only switch to the new path when we + // receive a non-probing packet on that path. + select { + case <-migrated: + t.Fatal("didn't expect a migration yet") + default: + } + + payload := []byte{1} // PING frame + payload, err = (&wire.PathResponseFrame{Data: pathChallenge.Data}).Append(payload, protocol.Version1) + require.NoError(t, err) + gomock.InOrder( + unpacker.EXPECT().UnpackShortHeader(gomock.Any(), gomock.Any()).Return( + protocol.PacketNumber(12), protocol.PacketNumberLen2, protocol.KeyPhaseZero, payload, nil, + ), + tc.sendConn.EXPECT().ChangeRemoteAddr(newRemoteAddr, gomock.Any()).Do( + func(net.Addr, packetInfo) { close(migrated) }, + ), + tc.packer.EXPECT().AppendPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return( + shortHeaderPacket{}, errNothingToPack, + ).MaxTimes(1), + ) + tc.conn.handlePacket(receivedPacket{ + data: make([]byte, 100), + buffer: getPacketBuffer(), + remoteAddr: newRemoteAddr, + rcvTime: monotime.Now(), + }) + } + + synctest.Wait() + + select { + case <-migrated: + default: + t.Fatal("should have migrated") + } + + // test teardown + tc.connRunner.EXPECT().Remove(gomock.Any()).AnyTimes() + tc.conn.destroy(nil) + + synctest.Wait() + + select { + case err := <-errChan: + require.NoError(t, err) + default: + t.Fatal("should have shut down") + } + }) } func TestConnectionMigrationServer(t *testing.T) {