diff --git a/integrationtests/self/multiplex_test.go b/integrationtests/self/multiplex_test.go index 91048400..2320706e 100644 --- a/integrationtests/self/multiplex_test.go +++ b/integrationtests/self/multiplex_test.go @@ -8,7 +8,6 @@ import ( "io" "net" "runtime" - "sync/atomic" "testing" "time" @@ -187,24 +186,25 @@ func TestMultiplexingServerAndClientOnSameConn(t *testing.T) { } func TestMultiplexingNonQUICPackets(t *testing.T) { + const numPackets = 100 + tr1 := &quic.Transport{Conn: newUPDConnLocalhost(t)} defer tr1.Close() addTracer(tr1) + server, err := tr1.Listen(getTLSConfig(), getQuicConfig(nil)) + require.NoError(t, err) + defer server.Close() tr2 := &quic.Transport{Conn: newUPDConnLocalhost(t)} defer tr2.Close() addTracer(tr2) - server, err := tr1.Listen(getTLSConfig(), getQuicConfig(nil)) - require.NoError(t, err) - defer server.Close() - type nonQUICPacket struct { b []byte addr net.Addr err error } - rcvdPackets := make(chan nonQUICPacket, 1024) + rcvdPackets := make(chan nonQUICPacket, numPackets) receiveCtx, receiveCancel := context.WithCancel(context.Background()) defer receiveCancel() // start receiving non-QUIC packets @@ -219,42 +219,6 @@ func TestMultiplexingNonQUICPackets(t *testing.T) { } }() - // send a non-QUIC packet every 100µs - const packetLen = 128 - var sentPackets atomic.Int64 - errChan := make(chan error, 1) - done := make(chan struct{}) - stopSending := make(chan struct{}) - go func() { - defer close(done) - ticker := time.NewTicker(time.Millisecond / 10) - defer ticker.Stop() - - var wroteFirstPacket bool - for { - select { - case <-ticker.C: - b := make([]byte, packetLen) - rand.Read(b[1:]) // keep the first byte set to 0, so it's not classified as a QUIC packet - _, err := tr1.WriteTo(b, tr2.Conn.LocalAddr()) - // The first sendmsg call on a new UDP socket sometimes errors on Linux. - // It's not clear why this happens. - // See https://github.com/golang/go/issues/63322. - if err != nil && !wroteFirstPacket && runtime.GOOS == "linux" && isPermissionError(err) { - _, err = tr1.WriteTo(b, tr2.Conn.LocalAddr()) - } - if err != nil { - errChan <- err - return - } - sentPackets.Add(1) - wroteFirstPacket = true - case <-stopSending: - return - } - } - }() - ctx2, cancel2 := context.WithTimeout(context.Background(), time.Second) defer cancel2() conn, err := tr2.Dial(ctx2, server.Addr(), getTLSClientConfig(), getQuicConfig(nil)) @@ -265,33 +229,101 @@ func TestMultiplexingNonQUICPackets(t *testing.T) { require.NoError(t, err) serverStr, err := serverConn.OpenUniStream() require.NoError(t, err) + + // send a non-QUIC packet every 100µs + const packetLen = 128 + errChanNonQUIC := make(chan error, 1) + sendNonQUICPacket := make(chan struct{}, 1) go func() { - defer serverStr.Close() - _, _ = serverStr.Write(PRData) + defer close(errChanNonQUIC) + var sentPackets int + for range sendNonQUICPacket { + b := make([]byte, packetLen) + rand.Read(b[1:]) // keep the first byte set to 0, so it's not classified as a QUIC packet + _, err := tr1.WriteTo(b, tr2.Conn.LocalAddr()) + // The first sendmsg call on a new UDP socket sometimes errors on Linux. + // It's not clear why this happens. + // See https://github.com/golang/go/issues/63322. + if err != nil && sentPackets == 0 && runtime.GOOS == "linux" && isPermissionError(err) { + _, err = tr1.WriteTo(b, tr2.Conn.LocalAddr()) + } + if err != nil { + errChanNonQUIC <- err + return + } + sentPackets++ + } }() - str, err := conn.AcceptUniStream(ctx2) - require.NoError(t, err) - data, err := io.ReadAll(str) - require.NoError(t, err) - require.Equal(t, PRData, data) + sendQUICPacket := make(chan struct{}, 1) + errChanQUIC := make(chan error, 1) + var dataSent []byte + go func() { + defer close(errChanQUIC) + defer serverStr.Close() + for range sendQUICPacket { + b := make([]byte, 1024) + rand.Read(b) + if _, err := serverStr.Write(b); err != nil { + errChanQUIC <- err + return + } + dataSent = append(dataSent, b...) + } + }() - // stop sending non-QUIC packets - close(stopSending) + dataChan := make(chan []byte, 1) + readErr := make(chan error, 1) + go func() { + str, err := conn.AcceptUniStream(ctx2) + if err != nil { + readErr <- err + return + } + data, err := io.ReadAll(str) + if err != nil { + readErr <- err + return + } + dataChan <- data + }() + + ticker := time.NewTicker(scaleDuration(200 * time.Microsecond)) + defer ticker.Stop() + for range numPackets { + sendNonQUICPacket <- struct{}{} + sendQUICPacket <- struct{}{} + <-ticker.C + } + close(sendNonQUICPacket) + close(sendQUICPacket) select { - case err := <-errChan: - t.Fatalf("error sending non-QUIC packets: %v", err) - case <-done: + case err := <-errChanNonQUIC: + require.NoError(t, err, "error sending non-QUIC packets") + case <-time.After(time.Second): + t.Fatalf("timeout waiting for non-QUIC packets to be sent") + } + select { + case err := <-errChanQUIC: + require.NoError(t, err, "error sending QUIC packets") + case <-time.After(time.Second): + t.Fatalf("timeout waiting for QUIC packets to be sent") + } + select { + case err := <-readErr: + require.NoError(t, err, "error reading stream data") + case dataRcvd := <-dataChan: + require.Equal(t, dataSent, dataRcvd, "stream data mismatch") + case <-time.After(time.Second): + t.Fatalf("timeout waiting for stream data to be read") } - sent := int(sentPackets.Load()) - require.Greater(t, sent, 10, "not enough non-QUIC packets sent: %d", sent) // make sure we don't overflow the capacity of the channel - require.LessOrEqual(t, sent, cap(rcvdPackets), "too many non-QUIC packets sent: %d > %d", sent, cap(rcvdPackets)) + require.LessOrEqual(t, numPackets, cap(rcvdPackets), "too many non-QUIC packets sent: %d > %d", numPackets, cap(rcvdPackets)) // now receive these packets - minExpected := sent * 4 / 5 + minExpected := numPackets * 4 / 5 timeout := time.After(time.Second) var counter int for {