// Forked from quic-go/quic-go (Go, 348 lines, 9.1 KiB).

package self_test
|
|
|
|
import (
|
|
"bytes"
|
|
"context"
|
|
"crypto/rand"
|
|
"errors"
|
|
"fmt"
|
|
"io"
|
|
mrand "math/rand/v2"
|
|
"net"
|
|
"runtime"
|
|
"testing"
|
|
"time"
|
|
|
|
"github.com/quic-go/quic-go"
|
|
|
|
"github.com/stretchr/testify/require"
|
|
)
|
|
|
|
func runMultiplexTestServer(t *testing.T, ln *quic.Listener) {
|
|
t.Helper()
|
|
for {
|
|
conn, err := ln.Accept(context.Background())
|
|
if err != nil {
|
|
return
|
|
}
|
|
str, err := conn.OpenUniStream()
|
|
require.NoError(t, err)
|
|
go func() {
|
|
defer str.Close()
|
|
_, err = str.Write(PRData)
|
|
require.NoError(t, err)
|
|
}()
|
|
|
|
t.Cleanup(func() { conn.CloseWithError(0, "") })
|
|
}
|
|
}
|
|
|
|
func dialAndReceiveData(tr *quic.Transport, addr net.Addr) error {
|
|
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
|
|
defer cancel()
|
|
conn, err := tr.Dial(ctx, addr, getTLSClientConfig(), getQuicConfig(nil))
|
|
if err != nil {
|
|
return fmt.Errorf("error dialing: %w", err)
|
|
}
|
|
str, err := conn.AcceptUniStream(ctx)
|
|
if err != nil {
|
|
return fmt.Errorf("error accepting stream: %w", err)
|
|
}
|
|
data, err := io.ReadAll(str)
|
|
if err != nil {
|
|
return fmt.Errorf("error reading data: %w", err)
|
|
}
|
|
if !bytes.Equal(data, PRData) {
|
|
return fmt.Errorf("data mismatch: got %q, expected %q", data, PRData)
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func TestMultiplexesConnectionsToSameServer(t *testing.T) {
|
|
server, err := quic.Listen(newUDPConnLocalhost(t), getTLSConfig(), getQuicConfig(nil))
|
|
require.NoError(t, err)
|
|
defer server.Close()
|
|
go runMultiplexTestServer(t, server)
|
|
|
|
tr := &quic.Transport{Conn: newUDPConnLocalhost(t)}
|
|
addTracer(tr)
|
|
defer tr.Close()
|
|
|
|
errChan1 := make(chan error, 1)
|
|
go func() { errChan1 <- dialAndReceiveData(tr, server.Addr()) }()
|
|
errChan2 := make(chan error, 1)
|
|
go func() { errChan2 <- dialAndReceiveData(tr, server.Addr()) }()
|
|
|
|
select {
|
|
case err := <-errChan1:
|
|
require.NoError(t, err, "error dialing server 1")
|
|
case <-time.After(5 * time.Second):
|
|
t.Error("timeout waiting for done1 to close")
|
|
}
|
|
select {
|
|
case err := <-errChan2:
|
|
require.NoError(t, err)
|
|
case <-time.After(5 * time.Second):
|
|
t.Error("timeout waiting for done2 to close")
|
|
}
|
|
}
|
|
|
|
func TestMultiplexingToDifferentServers(t *testing.T) {
|
|
server1, err := quic.Listen(newUDPConnLocalhost(t), getTLSConfig(), getQuicConfig(nil))
|
|
require.NoError(t, err)
|
|
defer server1.Close()
|
|
go runMultiplexTestServer(t, server1)
|
|
|
|
server2, err := quic.Listen(newUDPConnLocalhost(t), getTLSConfig(), getQuicConfig(nil))
|
|
require.NoError(t, err)
|
|
defer server2.Close()
|
|
go runMultiplexTestServer(t, server2)
|
|
|
|
tr := &quic.Transport{Conn: newUDPConnLocalhost(t)}
|
|
addTracer(tr)
|
|
defer tr.Close()
|
|
|
|
errChan1 := make(chan error, 1)
|
|
go func() { errChan1 <- dialAndReceiveData(tr, server1.Addr()) }()
|
|
errChan2 := make(chan error, 1)
|
|
go func() { errChan2 <- dialAndReceiveData(tr, server2.Addr()) }()
|
|
|
|
select {
|
|
case err := <-errChan1:
|
|
require.NoError(t, err, "error dialing server 1")
|
|
case <-time.After(5 * time.Second):
|
|
t.Error("timeout waiting for done1 to close")
|
|
}
|
|
select {
|
|
case err := <-errChan2:
|
|
require.NoError(t, err, "error dialing server 2")
|
|
case <-time.After(5 * time.Second):
|
|
t.Error("timeout waiting for done2 to close")
|
|
}
|
|
}
|
|
|
|
func TestMultiplexingConnectToSelf(t *testing.T) {
|
|
tr := &quic.Transport{Conn: newUDPConnLocalhost(t)}
|
|
addTracer(tr)
|
|
defer tr.Close()
|
|
|
|
server, err := tr.Listen(getTLSConfig(), getQuicConfig(nil))
|
|
require.NoError(t, err)
|
|
defer server.Close()
|
|
go runMultiplexTestServer(t, server)
|
|
|
|
errChan := make(chan error, 1)
|
|
go func() { errChan <- dialAndReceiveData(tr, server.Addr()) }()
|
|
|
|
select {
|
|
case err := <-errChan:
|
|
require.NoError(t, err, "error dialing server")
|
|
case <-time.After(5 * time.Second):
|
|
t.Error("timeout waiting for connection to close")
|
|
}
|
|
}
|
|
|
|
func TestMultiplexingServerAndClientOnSameConn(t *testing.T) {
|
|
if runtime.GOOS == "linux" {
|
|
t.Skip("This test requires setting of iptables rules on Linux, see https://stackoverflow.com/questions/23859164/linux-udp-socket-sendto-operation-not-permitted.")
|
|
}
|
|
|
|
tr1 := &quic.Transport{Conn: newUDPConnLocalhost(t)}
|
|
addTracer(tr1)
|
|
defer tr1.Close()
|
|
server1, err := tr1.Listen(getTLSConfig(), getQuicConfig(nil))
|
|
require.NoError(t, err)
|
|
defer server1.Close()
|
|
|
|
tr2 := &quic.Transport{Conn: newUDPConnLocalhost(t)}
|
|
addTracer(tr2)
|
|
defer tr2.Close()
|
|
server2, err := tr2.Listen(getTLSConfig(), getQuicConfig(nil))
|
|
require.NoError(t, err)
|
|
defer server2.Close()
|
|
|
|
done1 := make(chan struct{})
|
|
go func() {
|
|
defer close(done1)
|
|
dialAndReceiveData(tr2, server1.Addr())
|
|
}()
|
|
|
|
done2 := make(chan struct{})
|
|
go func() {
|
|
defer close(done2)
|
|
dialAndReceiveData(tr1, server2.Addr())
|
|
}()
|
|
|
|
select {
|
|
case <-done1:
|
|
case <-time.After(5 * time.Second):
|
|
t.Error("timeout waiting for done1 to close")
|
|
}
|
|
select {
|
|
case <-done2:
|
|
case <-time.After(time.Second):
|
|
t.Error("timeout waiting for done2 to close")
|
|
}
|
|
}
|
|
|
|
func TestMultiplexingNonQUICPackets(t *testing.T) {
|
|
const numPackets = 100
|
|
|
|
tr1 := &quic.Transport{Conn: newUDPConnLocalhost(t)}
|
|
defer tr1.Close()
|
|
addTracer(tr1)
|
|
server, err := tr1.Listen(getTLSConfig(), getQuicConfig(nil))
|
|
require.NoError(t, err)
|
|
defer server.Close()
|
|
|
|
tr2 := &quic.Transport{Conn: newUDPConnLocalhost(t)}
|
|
defer tr2.Close()
|
|
addTracer(tr2)
|
|
|
|
type nonQUICPacket struct {
|
|
b []byte
|
|
addr net.Addr
|
|
err error
|
|
}
|
|
rcvdPackets := make(chan nonQUICPacket, numPackets)
|
|
receiveCtx, receiveCancel := context.WithCancel(context.Background())
|
|
defer receiveCancel()
|
|
// start receiving non-QUIC packets
|
|
go func() {
|
|
for {
|
|
b := make([]byte, 1024)
|
|
n, addr, err := tr2.ReadNonQUICPacket(receiveCtx, b)
|
|
if errors.Is(err, context.Canceled) {
|
|
return
|
|
}
|
|
rcvdPackets <- nonQUICPacket{b: b[:n], addr: addr, err: err}
|
|
}
|
|
}()
|
|
|
|
ctx2, cancel2 := context.WithTimeout(context.Background(), time.Second)
|
|
defer cancel2()
|
|
conn, err := tr2.Dial(ctx2, server.Addr(), getTLSClientConfig(), getQuicConfig(nil))
|
|
require.NoError(t, err)
|
|
defer conn.CloseWithError(0, "")
|
|
|
|
serverConn, err := server.Accept(ctx2)
|
|
require.NoError(t, err)
|
|
serverStr, err := serverConn.OpenUniStream()
|
|
require.NoError(t, err)
|
|
|
|
// send a non-QUIC packet every 100µs
|
|
const packetLen = 128
|
|
errChanNonQUIC := make(chan error, 1)
|
|
sendNonQUICPacket := make(chan struct{}, 1)
|
|
go func() {
|
|
var seed [32]byte
|
|
rand.Read(seed[:])
|
|
random := mrand.NewChaCha8(seed)
|
|
defer close(errChanNonQUIC)
|
|
var sentPackets int
|
|
for range sendNonQUICPacket {
|
|
b := make([]byte, packetLen)
|
|
random.Read(b[1:]) // keep the first byte set to 0, so it's not classified as a QUIC packet
|
|
_, err := tr1.WriteTo(b, tr2.Conn.LocalAddr())
|
|
// The first sendmsg call on a new UDP socket sometimes errors on Linux.
|
|
// It's not clear why this happens.
|
|
// See https://github.com/golang/go/issues/63322.
|
|
if err != nil && sentPackets == 0 && runtime.GOOS == "linux" && isPermissionError(err) {
|
|
_, err = tr1.WriteTo(b, tr2.Conn.LocalAddr())
|
|
}
|
|
if err != nil {
|
|
errChanNonQUIC <- err
|
|
return
|
|
}
|
|
sentPackets++
|
|
}
|
|
}()
|
|
|
|
sendQUICPacket := make(chan struct{}, 1)
|
|
errChanQUIC := make(chan error, 1)
|
|
var dataSent []byte
|
|
go func() {
|
|
defer close(errChanQUIC)
|
|
defer serverStr.Close()
|
|
|
|
var seed [32]byte
|
|
rand.Read(seed[:])
|
|
random := mrand.NewChaCha8(seed)
|
|
for range sendQUICPacket {
|
|
b := make([]byte, 1024)
|
|
random.Read(b)
|
|
if _, err := serverStr.Write(b); err != nil {
|
|
errChanQUIC <- err
|
|
return
|
|
}
|
|
dataSent = append(dataSent, b...)
|
|
}
|
|
}()
|
|
|
|
dataChan := make(chan []byte, 1)
|
|
readErr := make(chan error, 1)
|
|
go func() {
|
|
str, err := conn.AcceptUniStream(ctx2)
|
|
if err != nil {
|
|
readErr <- err
|
|
return
|
|
}
|
|
data, err := io.ReadAll(str)
|
|
if err != nil {
|
|
readErr <- err
|
|
return
|
|
}
|
|
dataChan <- data
|
|
}()
|
|
|
|
ticker := time.NewTicker(scaleDuration(200 * time.Microsecond))
|
|
defer ticker.Stop()
|
|
for range numPackets {
|
|
sendNonQUICPacket <- struct{}{}
|
|
sendQUICPacket <- struct{}{}
|
|
<-ticker.C
|
|
}
|
|
close(sendNonQUICPacket)
|
|
close(sendQUICPacket)
|
|
|
|
select {
|
|
case err := <-errChanNonQUIC:
|
|
require.NoError(t, err, "error sending non-QUIC packets")
|
|
case <-time.After(time.Second):
|
|
t.Fatalf("timeout waiting for non-QUIC packets to be sent")
|
|
}
|
|
select {
|
|
case err := <-errChanQUIC:
|
|
require.NoError(t, err, "error sending QUIC packets")
|
|
case <-time.After(time.Second):
|
|
t.Fatalf("timeout waiting for QUIC packets to be sent")
|
|
}
|
|
select {
|
|
case err := <-readErr:
|
|
require.NoError(t, err, "error reading stream data")
|
|
case dataRcvd := <-dataChan:
|
|
require.Equal(t, dataSent, dataRcvd, "stream data mismatch")
|
|
case <-time.After(time.Second):
|
|
t.Fatalf("timeout waiting for stream data to be read")
|
|
}
|
|
|
|
// make sure we don't overflow the capacity of the channel
|
|
require.LessOrEqual(t, numPackets, cap(rcvdPackets), "too many non-QUIC packets sent: %d > %d", numPackets, cap(rcvdPackets))
|
|
|
|
// now receive these packets
|
|
minExpected := numPackets * 4 / 5
|
|
timeout := time.After(time.Second)
|
|
var counter int
|
|
for counter < minExpected {
|
|
select {
|
|
case p := <-rcvdPackets:
|
|
require.Equal(t, tr1.Conn.LocalAddr(), p.addr, "non-QUIC packet received from wrong address")
|
|
require.Equal(t, packetLen, len(p.b), "non-QUIC packet incorrect length")
|
|
require.NoError(t, p.err, "error receiving non-QUIC packet")
|
|
counter++
|
|
case <-timeout:
|
|
t.Fatalf("didn't receive enough non-QUIC packets: %d < %d", counter, minExpected)
|
|
}
|
|
}
|
|
}
|