diff --git a/integrationtests/self/multiplex_test.go b/integrationtests/self/multiplex_test.go index b912b6e3..91048400 100644 --- a/integrationtests/self/multiplex_test.go +++ b/integrationtests/self/multiplex_test.go @@ -204,20 +204,18 @@ func TestMultiplexingNonQUICPackets(t *testing.T) { addr net.Addr err error } - done := make(chan struct{}) - var rcvdPackets []nonQUICPacket - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() + rcvdPackets := make(chan nonQUICPacket, 1024) + receiveCtx, receiveCancel := context.WithCancel(context.Background()) + defer receiveCancel() // start receiving non-QUIC packets go func() { - defer close(done) for { b := make([]byte, 1024) - n, addr, err := tr2.ReadNonQUICPacket(ctx, b) + n, addr, err := tr2.ReadNonQUICPacket(receiveCtx, b) if errors.Is(err, context.Canceled) { return } - rcvdPackets = append(rcvdPackets, nonQUICPacket{b: b[:n], addr: addr, err: err}) + rcvdPackets <- nonQUICPacket{b: b[:n], addr: addr, err: err} } }() @@ -225,7 +223,10 @@ func TestMultiplexingNonQUICPackets(t *testing.T) { const packetLen = 128 var sentPackets atomic.Int64 errChan := make(chan error, 1) + done := make(chan struct{}) + stopSending := make(chan struct{}) go func() { + defer close(done) ticker := time.NewTicker(time.Millisecond / 10) defer ticker.Stop() @@ -248,22 +249,19 @@ func TestMultiplexingNonQUICPackets(t *testing.T) { } sentPackets.Add(1) wroteFirstPacket = true - case <-ctx.Done(): + case <-stopSending: return } } }() - conn, err := tr2.Dial( - context.Background(), - server.Addr(), - getTLSClientConfig(), - getQuicConfig(nil), - ) + ctx2, cancel2 := context.WithTimeout(context.Background(), time.Second) + defer cancel2() + conn, err := tr2.Dial(ctx2, server.Addr(), getTLSClientConfig(), getQuicConfig(nil)) require.NoError(t, err) defer conn.CloseWithError(0, "") - serverConn, err := server.Accept(context.Background()) + serverConn, err := server.Accept(ctx2) require.NoError(t, err) serverStr, err := serverConn.OpenUniStream() require.NoError(t, err) @@ -272,14 +270,14 @@ func TestMultiplexingNonQUICPackets(t *testing.T) { _, _ = serverStr.Write(PRData) }() - str, err := conn.AcceptUniStream(context.Background()) + str, err := conn.AcceptUniStream(ctx2) require.NoError(t, err) data, err := io.ReadAll(str) require.NoError(t, err) require.Equal(t, PRData, data) // stop sending non-QUIC packets - cancel() + close(stopSending) select { case err := <-errChan: @@ -289,13 +287,25 @@ func TestMultiplexingNonQUICPackets(t *testing.T) { sent := int(sentPackets.Load()) require.Greater(t, sent, 10, "not enough non-QUIC packets sent: %d", sent) - rcvd := len(rcvdPackets) - minExpected := sent * 4 / 5 - require.GreaterOrEqual(t, rcvd, minExpected, "not enough packets received. got: %d, expected at least: %d", rcvd, minExpected) + // make sure we don't overflow the capacity of the channel + require.LessOrEqual(t, sent, cap(rcvdPackets), "too many non-QUIC packets sent: %d > %d", sent, cap(rcvdPackets)) - for _, p := range rcvdPackets { - require.Equal(t, tr1.Conn.LocalAddr(), p.addr, "non-QUIC packet received from wrong address") - require.Equal(t, packetLen, len(p.b), "non-QUIC packet incorrect length") - require.NoError(t, p.err, "error receiving non-QUIC packet") + // now receive these packets + minExpected := sent * 4 / 5 + timeout := time.After(time.Second) + var counter int + for { + if counter >= minExpected { + break + } + select { + case p := <-rcvdPackets: + require.Equal(t, tr1.Conn.LocalAddr(), p.addr, "non-QUIC packet received from wrong address") + require.Equal(t, packetLen, len(p.b), "non-QUIC packet incorrect length") + require.NoError(t, p.err, "error receiving non-QUIC packet") + counter++ + case <-timeout: + t.Fatalf("didn't receive enough non-QUIC packets: %d < %d", counter, minExpected) + } } }