forked from quic-go/quic-go
fix flaky multiplex integration test (#4760)
This commit is contained in:
@@ -204,20 +204,18 @@ func TestMultiplexingNonQUICPackets(t *testing.T) {
|
||||
addr net.Addr
|
||||
err error
|
||||
}
|
||||
done := make(chan struct{})
|
||||
var rcvdPackets []nonQUICPacket
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
rcvdPackets := make(chan nonQUICPacket, 1024)
|
||||
receiveCtx, receiveCancel := context.WithCancel(context.Background())
|
||||
defer receiveCancel()
|
||||
// start receiving non-QUIC packets
|
||||
go func() {
|
||||
defer close(done)
|
||||
for {
|
||||
b := make([]byte, 1024)
|
||||
n, addr, err := tr2.ReadNonQUICPacket(ctx, b)
|
||||
n, addr, err := tr2.ReadNonQUICPacket(receiveCtx, b)
|
||||
if errors.Is(err, context.Canceled) {
|
||||
return
|
||||
}
|
||||
rcvdPackets = append(rcvdPackets, nonQUICPacket{b: b[:n], addr: addr, err: err})
|
||||
rcvdPackets <- nonQUICPacket{b: b[:n], addr: addr, err: err}
|
||||
}
|
||||
}()
|
||||
|
||||
@@ -225,7 +223,10 @@ func TestMultiplexingNonQUICPackets(t *testing.T) {
|
||||
const packetLen = 128
|
||||
var sentPackets atomic.Int64
|
||||
errChan := make(chan error, 1)
|
||||
done := make(chan struct{})
|
||||
stopSending := make(chan struct{})
|
||||
go func() {
|
||||
defer close(done)
|
||||
ticker := time.NewTicker(time.Millisecond / 10)
|
||||
defer ticker.Stop()
|
||||
|
||||
@@ -248,22 +249,19 @@ func TestMultiplexingNonQUICPackets(t *testing.T) {
|
||||
}
|
||||
sentPackets.Add(1)
|
||||
wroteFirstPacket = true
|
||||
case <-ctx.Done():
|
||||
case <-stopSending:
|
||||
return
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
conn, err := tr2.Dial(
|
||||
context.Background(),
|
||||
server.Addr(),
|
||||
getTLSClientConfig(),
|
||||
getQuicConfig(nil),
|
||||
)
|
||||
ctx2, cancel2 := context.WithTimeout(context.Background(), time.Second)
|
||||
defer cancel2()
|
||||
conn, err := tr2.Dial(ctx2, server.Addr(), getTLSClientConfig(), getQuicConfig(nil))
|
||||
require.NoError(t, err)
|
||||
defer conn.CloseWithError(0, "")
|
||||
|
||||
serverConn, err := server.Accept(context.Background())
|
||||
serverConn, err := server.Accept(ctx2)
|
||||
require.NoError(t, err)
|
||||
serverStr, err := serverConn.OpenUniStream()
|
||||
require.NoError(t, err)
|
||||
@@ -272,14 +270,14 @@ func TestMultiplexingNonQUICPackets(t *testing.T) {
|
||||
_, _ = serverStr.Write(PRData)
|
||||
}()
|
||||
|
||||
str, err := conn.AcceptUniStream(context.Background())
|
||||
str, err := conn.AcceptUniStream(ctx2)
|
||||
require.NoError(t, err)
|
||||
data, err := io.ReadAll(str)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, PRData, data)
|
||||
|
||||
// stop sending non-QUIC packets
|
||||
cancel()
|
||||
close(stopSending)
|
||||
|
||||
select {
|
||||
case err := <-errChan:
|
||||
@@ -289,13 +287,25 @@ func TestMultiplexingNonQUICPackets(t *testing.T) {
|
||||
|
||||
sent := int(sentPackets.Load())
|
||||
require.Greater(t, sent, 10, "not enough non-QUIC packets sent: %d", sent)
|
||||
rcvd := len(rcvdPackets)
|
||||
minExpected := sent * 4 / 5
|
||||
require.GreaterOrEqual(t, rcvd, minExpected, "not enough packets received. got: %d, expected at least: %d", rcvd, minExpected)
|
||||
// make sure we don't overflow the capacity of the channel
|
||||
require.LessOrEqual(t, sent, cap(rcvdPackets), "too many non-QUIC packets sent: %d > %d", sent, cap(rcvdPackets))
|
||||
|
||||
for _, p := range rcvdPackets {
|
||||
require.Equal(t, tr1.Conn.LocalAddr(), p.addr, "non-QUIC packet received from wrong address")
|
||||
require.Equal(t, packetLen, len(p.b), "non-QUIC packet incorrect length")
|
||||
require.NoError(t, p.err, "error receiving non-QUIC packet")
|
||||
// now receive these packets
|
||||
minExpected := sent * 4 / 5
|
||||
timeout := time.After(time.Second)
|
||||
var counter int
|
||||
for {
|
||||
if counter >= minExpected {
|
||||
break
|
||||
}
|
||||
select {
|
||||
case p := <-rcvdPackets:
|
||||
require.Equal(t, tr1.Conn.LocalAddr(), p.addr, "non-QUIC packet received from wrong address")
|
||||
require.Equal(t, packetLen, len(p.b), "non-QUIC packet incorrect length")
|
||||
require.NoError(t, p.err, "error receiving non-QUIC packet")
|
||||
counter++
|
||||
case <-timeout:
|
||||
t.Fatalf("didn't receive enough non-QUIC packets: %d < %d", counter, minExpected)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user