migrate integration tests away from Ginkgo (#4736)

* use require in benchmark tests

* translate the QLOGDIR test

* translate handshake tests

* translate the handshake RTT tests

* translate the early data test

* translate the MTU tests

* translate the key update test

* translate the stateless reset tests

* translate the packetization test

* translate the close test

* translate the resumption test

* translate the tracer test

* translate the connection ID length test

* translate the RTT tests

* translate the multiplexing tests

* translate the drop tests

* translate the handshake drop tests

* translate the 0-RTT tests

* translate the hotswap test

* translate the stream test

* translate the unidirectional stream test

* translate the timeout tests

* translate the MITM test

* rewrite the datagram tests

* translate the cancellation tests

* translate the deadline tests

* translate the test helpers
Author: Marten Seemann
Date: 2024-12-16 23:43:59 +08:00
Committed by: GitHub
Parent: 5e198b0db1
Commit: 691086db7f
29 changed files with 5775 additions and 6143 deletions
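The translation pattern applied across these files is mostly mechanical: each Ginkgo It block becomes a standalone Test function, Expect(err).ToNot(HaveOccurred()) becomes require.NoError(t, err), and Eventually assertions become explicit select statements with timeouts. A minimal sketch of the pattern (the test body and the doSomething helper are illustrative, not taken from the diff):

package self_test

import (
	"testing"
	"time"

	"github.com/stretchr/testify/require"
)

func doSomething() error { return nil } // hypothetical stand-in

// A Ginkgo block such as
//
//	It("closes the channel", func() {
//		Expect(doSomething()).To(Succeed())
//		Eventually(done).Should(BeClosed())
//	})
//
// becomes a plain test function; Eventually turns into a select with a timeout.
func TestClosesChannel(t *testing.T) {
	done := make(chan struct{})
	go func() { close(done) }()

	require.NoError(t, doSomething())
	select {
	case <-done:
	case <-time.After(time.Second):
		t.Fatal("timeout waiting for channel to close")
	}
}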


@@ -7,15 +7,15 @@ import (
"testing"
"github.com/quic-go/quic-go"
"github.com/stretchr/testify/require"
)
func BenchmarkHandshake(b *testing.B) {
b.ReportAllocs()
ln, err := quic.ListenAddr("localhost:0", tlsConfig, nil)
if err != nil {
b.Fatal(err)
}
require.NoError(b, err)
defer ln.Close()
connChan := make(chan quic.Connection, 1)
@@ -30,9 +30,7 @@ func BenchmarkHandshake(b *testing.B) {
}()
conn, err := net.ListenUDP("udp", nil)
if err != nil {
b.Fatal(err)
}
require.NoError(b, err)
defer conn.Close()
tr := &quic.Transport{Conn: conn}
defer tr.Close()
@@ -40,10 +38,9 @@ func BenchmarkHandshake(b *testing.B) {
b.ResetTimer()
for i := 0; i < b.N; i++ {
c, err := tr.Dial(context.Background(), ln.Addr(), tlsClientConfig, nil)
if err != nil {
b.Fatal(err)
}
<-connChan
require.NoError(b, err)
serverConn := <-connChan
serverConn.CloseWithError(0, "")
c.CloseWithError(0, "")
}
}
@@ -52,21 +49,20 @@ func BenchmarkStreamChurn(b *testing.B) {
b.ReportAllocs()
ln, err := quic.ListenAddr("localhost:0", tlsConfig, &quic.Config{MaxIncomingStreams: 1e10})
if err != nil {
b.Fatal(err)
}
require.NoError(b, err)
defer ln.Close()
errChan := make(chan error, 1)
conn, err := quic.DialAddr(context.Background(), fmt.Sprintf("localhost:%d", ln.Addr().(*net.UDPAddr).Port), tlsClientConfig, nil)
require.NoError(b, err)
defer conn.CloseWithError(0, "")
serverConn, err := ln.Accept(context.Background())
require.NoError(b, err)
defer serverConn.CloseWithError(0, "")
go func() {
conn, err := ln.Accept(context.Background())
if err != nil {
errChan <- err
return
}
close(errChan)
for {
str, err := conn.AcceptStream(context.Background())
str, err := serverConn.AcceptStream(context.Background())
if err != nil {
return
}
@@ -74,22 +70,10 @@ func BenchmarkStreamChurn(b *testing.B) {
}
}()
c, err := quic.DialAddr(context.Background(), fmt.Sprintf("localhost:%d", ln.Addr().(*net.UDPAddr).Port), tlsClientConfig, nil)
if err != nil {
b.Fatal(err)
}
if err := <-errChan; err != nil {
b.Fatal(err)
}
b.ResetTimer()
for i := 0; i < b.N; i++ {
str, err := c.OpenStreamSync(context.Background())
if err != nil {
b.Fatal(err)
}
if err := str.Close(); err != nil {
b.Fatal(err)
}
str, err := conn.OpenStreamSync(context.Background())
require.NoError(b, err)
require.NoError(b, str.Close())
}
}
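These replacements work because testify's require functions accept a testing.TB, the interface implemented by both *testing.T and *testing.B, so a failed require.NoError(b, err) aborts the benchmark just as the removed b.Fatal(err) did. A minimal sketch (the benchmark and setup helper are illustrative):

package self_test

import (
	"testing"

	"github.com/stretchr/testify/require"
)

func setupBench() error { return nil } // hypothetical stand-in

// require.NoError accepts a testing.TB, so the same assertion style
// works in tests and benchmarks alike.
func BenchmarkIllustrative(b *testing.B) {
	require.NoError(b, setupBench())
	b.ResetTimer()
	for i := 0; i < b.N; i++ {
		// code under measurement
	}
}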

File diff suppressed because it is too large


@@ -5,74 +5,80 @@ import (
"fmt"
"net"
"sync/atomic"
"testing"
"time"
"github.com/quic-go/quic-go"
quicproxy "github.com/quic-go/quic-go/integrationtests/tools/proxy"
. "github.com/onsi/ginkgo/v2"
. "github.com/onsi/gomega"
"github.com/stretchr/testify/require"
)
var _ = Describe("Connection ID lengths tests", func() {
It("retransmits the CONNECTION_CLOSE packet", func() {
server, err := quic.ListenAddr(
"localhost:0",
getTLSConfig(),
getQuicConfig(&quic.Config{
DisablePathMTUDiscovery: true,
}),
)
Expect(err).ToNot(HaveOccurred())
defer server.Close()
func TestConnectionCloseRetransmission(t *testing.T) {
server, err := quic.ListenAddr(
"localhost:0",
getTLSConfig(),
getQuicConfig(&quic.Config{DisablePathMTUDiscovery: true}),
)
require.NoError(t, err)
defer server.Close()
var drop atomic.Bool
dropped := make(chan []byte, 100)
proxy, err := quicproxy.NewQuicProxy("localhost:0", &quicproxy.Opts{
RemoteAddr: fmt.Sprintf("localhost:%d", server.Addr().(*net.UDPAddr).Port),
DelayPacket: func(dir quicproxy.Direction, _ []byte) time.Duration {
return 5 * time.Millisecond // 10ms RTT
},
DropPacket: func(dir quicproxy.Direction, b []byte) bool {
if drop := drop.Load(); drop && dir == quicproxy.DirectionOutgoing {
dropped <- b
return true
}
return false
},
})
Expect(err).ToNot(HaveOccurred())
defer proxy.Close()
conn, err := quic.DialAddr(
context.Background(),
fmt.Sprintf("localhost:%d", proxy.LocalPort()),
getTLSClientConfig(),
getQuicConfig(nil),
)
Expect(err).ToNot(HaveOccurred())
defer conn.CloseWithError(0, "")
sconn, err := server.Accept(context.Background())
Expect(err).ToNot(HaveOccurred())
time.Sleep(100 * time.Millisecond)
drop.Store(true)
sconn.CloseWithError(1337, "closing")
// send 100 packets
for i := 0; i < 100; i++ {
str, err := conn.OpenStream()
Expect(err).ToNot(HaveOccurred())
_, err = str.Write([]byte("foobar"))
Expect(err).ToNot(HaveOccurred())
time.Sleep(time.Millisecond)
}
// Expect retransmissions of the CONNECTION_CLOSE for the
// 1st, 2nd, 4th, 8th, 16th, 32nd, 64th packet: 7 in total (+1 for the original packet)
Eventually(dropped).Should(HaveLen(8))
first := <-dropped
for len(dropped) > 0 {
Expect(<-dropped).To(Equal(first)) // these packets are all identical
}
var drop atomic.Bool
dropped := make(chan []byte, 100)
proxy, err := quicproxy.NewQuicProxy("localhost:0", &quicproxy.Opts{
RemoteAddr: fmt.Sprintf("localhost:%d", server.Addr().(*net.UDPAddr).Port),
DelayPacket: func(quicproxy.Direction, []byte) time.Duration {
return 5 * time.Millisecond // 10ms RTT
},
DropPacket: func(dir quicproxy.Direction, b []byte) bool {
if drop := drop.Load(); drop && dir == quicproxy.DirectionOutgoing {
dropped <- b
return true
}
return false
},
})
})
require.NoError(t, err)
defer proxy.Close()
conn, err := quic.DialAddr(
context.Background(),
fmt.Sprintf("localhost:%d", proxy.LocalPort()),
getTLSClientConfig(),
getQuicConfig(nil),
)
require.NoError(t, err)
defer conn.CloseWithError(0, "")
sconn, err := server.Accept(context.Background())
require.NoError(t, err)
time.Sleep(100 * time.Millisecond)
drop.Store(true)
sconn.CloseWithError(1337, "closing")
// send 100 packets
for i := 0; i < 100; i++ {
str, err := conn.OpenStream()
require.NoError(t, err)
_, err = str.Write([]byte("foobar"))
require.NoError(t, err)
time.Sleep(time.Millisecond)
}
// Expect retransmissions of the CONNECTION_CLOSE for the
// 1st, 2nd, 4th, 8th, 16th, 32nd, 64th packet: 7 in total (+1 for the original packet)
var packets [][]byte
for i := 0; i < 8; i++ {
select {
case p := <-dropped:
packets = append(packets, p)
case <-time.After(time.Second):
t.Fatal("timeout waiting for CONNECTION_CLOSE retransmission")
}
}
// verify all retransmitted packets were identical
for i := 1; i < len(packets); i++ {
require.Equal(t, packets[0], packets[i])
}
}
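The expected count of 8 packets encodes the backoff described in the comment: one original CONNECTION_CLOSE, plus a retransmission in response to each power-of-two-numbered packet that arrives after the close. A sketch of the arithmetic, assuming that schedule (the helper is illustrative, not part of the diff):

package self_test

// expectedConnectionClosePackets counts the original CONNECTION_CLOSE plus
// one retransmission per power-of-two-numbered incoming packet, assuming
// n packets arrive after the connection was closed.
func expectedConnectionClosePackets(n int) int {
	count := 1 // the original CONNECTION_CLOSE
	for i := 1; i <= n; i *= 2 {
		count++ // retransmitted in response to packet number i
	}
	return count
}

// For the 100 packets sent above: responses to packets 1, 2, 4, 8, 16, 32
// and 64, plus the original, i.e. expectedConnectionClosePackets(100) == 8.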


@@ -7,124 +7,118 @@ import (
"io"
mrand "math/rand"
"net"
"testing"
"github.com/quic-go/quic-go"
"github.com/quic-go/quic-go/internal/protocol"
. "github.com/onsi/ginkgo/v2"
. "github.com/onsi/gomega"
"github.com/stretchr/testify/require"
)
type connIDGenerator struct {
length int
Length int
}
var _ quic.ConnectionIDGenerator = &connIDGenerator{}
func (c *connIDGenerator) GenerateConnectionID() (quic.ConnectionID, error) {
b := make([]byte, c.length)
b := make([]byte, c.Length)
if _, err := rand.Read(b); err != nil {
fmt.Fprintf(GinkgoWriter, "generating conn ID failed: %s", err)
return quic.ConnectionID{}, fmt.Errorf("generating conn ID failed: %w", err)
}
return protocol.ParseConnectionID(b), nil
}
func (c *connIDGenerator) ConnectionIDLen() int {
return c.length
func (c *connIDGenerator) ConnectionIDLen() int { return c.Length }
func randomConnIDLen() int { return 2 + int(mrand.Int31n(19)) }
func TestConnectionIDsZeroLength(t *testing.T) {
testTransferWithConnectionIDs(t, randomConnIDLen(), 0, nil, nil)
}
var _ = Describe("Connection ID lengths tests", func() {
randomConnIDLen := func() int { return 4 + int(mrand.Int31n(15)) }
func TestConnectionIDsRandomLengths(t *testing.T) {
testTransferWithConnectionIDs(t, randomConnIDLen(), randomConnIDLen(), nil, nil)
}
// connIDLen is ignored when connIDGenerator is set
runServer := func(connIDLen int, connIDGenerator quic.ConnectionIDGenerator) (*quic.Listener, func()) {
if connIDGenerator != nil {
GinkgoWriter.Write([]byte(fmt.Sprintf("Using %d byte connection ID generator for the server\n", connIDGenerator.ConnectionIDLen())))
} else {
GinkgoWriter.Write([]byte(fmt.Sprintf("Using %d byte connection ID for the server\n", connIDLen)))
}
addr, err := net.ResolveUDPAddr("udp", "localhost:0")
Expect(err).ToNot(HaveOccurred())
conn, err := net.ListenUDP("udp", addr)
Expect(err).ToNot(HaveOccurred())
tr := &quic.Transport{
Conn: conn,
ConnectionIDLength: connIDLen,
ConnectionIDGenerator: connIDGenerator,
}
addTracer(tr)
ln, err := tr.Listen(getTLSConfig(), getQuicConfig(nil))
Expect(err).ToNot(HaveOccurred())
go func() {
defer GinkgoRecover()
for {
conn, err := ln.Accept(context.Background())
if err != nil {
return
}
go func() {
defer GinkgoRecover()
str, err := conn.OpenStream()
Expect(err).ToNot(HaveOccurred())
defer str.Close()
_, err = str.Write(PRData)
Expect(err).ToNot(HaveOccurred())
}()
}
}()
return ln, func() {
ln.Close()
tr.Close()
}
func TestConnectionIDsCustomGenerator(t *testing.T) {
testTransferWithConnectionIDs(t, 0, 0,
&connIDGenerator{Length: randomConnIDLen()},
&connIDGenerator{Length: randomConnIDLen()},
)
}
// connIDLen is ignored when connIDGenerator is set
func testTransferWithConnectionIDs(
t *testing.T,
serverConnIDLen, clientConnIDLen int,
serverConnIDGenerator, clientConnIDGenerator quic.ConnectionIDGenerator,
) {
t.Helper()
if serverConnIDGenerator != nil {
t.Logf("using %d byte connection ID generator for the server", serverConnIDGenerator.ConnectionIDLen())
} else {
t.Logf("using %d byte connection ID for the server", serverConnIDLen)
}
if clientConnIDGenerator != nil {
t.Logf("using %d byte connection ID generator for the client", clientConnIDGenerator.ConnectionIDLen())
} else {
t.Logf("using %d byte connection ID for the client", clientConnIDLen)
}
// connIDLen is ignored when connIDGenerator is set
runClient := func(addr net.Addr, connIDLen int, connIDGenerator quic.ConnectionIDGenerator) {
if connIDGenerator != nil {
GinkgoWriter.Write([]byte(fmt.Sprintf("Using %d byte connection ID generator for the client\n", connIDGenerator.ConnectionIDLen())))
} else {
GinkgoWriter.Write([]byte(fmt.Sprintf("Using %d byte connection ID for the client\n", connIDLen)))
}
laddr, err := net.ResolveUDPAddr("udp", "localhost:0")
Expect(err).ToNot(HaveOccurred())
conn, err := net.ListenUDP("udp", laddr)
Expect(err).ToNot(HaveOccurred())
defer conn.Close()
tr := &quic.Transport{
Conn: conn,
ConnectionIDLength: connIDLen,
ConnectionIDGenerator: connIDGenerator,
}
addTracer(tr)
defer tr.Close()
cl, err := tr.Dial(
context.Background(),
&net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: addr.(*net.UDPAddr).Port},
getTLSClientConfig(),
getQuicConfig(nil),
)
Expect(err).ToNot(HaveOccurred())
defer cl.CloseWithError(0, "")
str, err := cl.AcceptStream(context.Background())
Expect(err).ToNot(HaveOccurred())
data, err := io.ReadAll(str)
Expect(err).ToNot(HaveOccurred())
Expect(data).To(Equal(PRData))
// setup server
addr, err := net.ResolveUDPAddr("udp", "localhost:0")
require.NoError(t, err)
conn, err := net.ListenUDP("udp", addr)
require.NoError(t, err)
t.Cleanup(func() { conn.Close() })
serverTr := &quic.Transport{
Conn: conn,
ConnectionIDLength: serverConnIDLen,
ConnectionIDGenerator: serverConnIDGenerator,
}
t.Cleanup(func() { serverTr.Close() })
addTracer(serverTr)
ln, err := serverTr.Listen(getTLSConfig(), getQuicConfig(nil))
require.NoError(t, err)
It("downloads a file using a 0-byte connection ID for the client", func() {
ln, closeFn := runServer(randomConnIDLen(), nil)
defer closeFn()
runClient(ln.Addr(), 0, nil)
})
// setup client
laddr, err := net.ResolveUDPAddr("udp", "localhost:0")
require.NoError(t, err)
clientConn, err := net.ListenUDP("udp", laddr)
require.NoError(t, err)
t.Cleanup(func() { clientConn.Close() })
clientTr := &quic.Transport{
Conn: clientConn,
ConnectionIDLength: clientConnIDLen,
ConnectionIDGenerator: clientConnIDGenerator,
}
t.Cleanup(func() { clientTr.Close() })
addTracer(clientTr)
It("downloads a file when both client and server use a random connection ID length", func() {
ln, closeFn := runServer(randomConnIDLen(), nil)
defer closeFn()
runClient(ln.Addr(), randomConnIDLen(), nil)
})
cl, err := clientTr.Dial(
context.Background(),
&net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: ln.Addr().(*net.UDPAddr).Port},
getTLSClientConfig(),
getQuicConfig(nil),
)
require.NoError(t, err)
t.Cleanup(func() { cl.CloseWithError(0, "") })
It("downloads a file when both client and server use a custom connection ID generator", func() {
ln, closeFn := runServer(0, &connIDGenerator{length: randomConnIDLen()})
defer closeFn()
runClient(ln.Addr(), 0, &connIDGenerator{length: randomConnIDLen()})
})
})
serverConn, err := ln.Accept(context.Background())
require.NoError(t, err)
serverStr, err := serverConn.OpenStream()
require.NoError(t, err)
t.Cleanup(func() { serverConn.CloseWithError(0, "") })
go func() {
serverStr.Write(PRData)
serverStr.Close()
}()
str, err := cl.AcceptStream(context.Background())
require.NoError(t, err)
data, err := io.ReadAll(str)
require.NoError(t, err)
require.Equal(t, PRData, data)
}
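As the comment notes, Transport.ConnectionIDLength is ignored once a ConnectionIDGenerator is set. A common reason to supply a generator is to embed routing information in connection IDs; a minimal sketch under that assumption (the prefix scheme is hypothetical, not part of the diff):

package self_test

import (
	"crypto/rand"
	"fmt"

	"github.com/quic-go/quic-go"
	"github.com/quic-go/quic-go/internal/protocol"
)

// prefixedConnIDGenerator prefixes every connection ID with a fixed byte,
// e.g. a server identifier, so a load balancer could route packets without
// keeping per-connection state.
type prefixedConnIDGenerator struct {
	prefix byte
}

func (g *prefixedConnIDGenerator) GenerateConnectionID() (quic.ConnectionID, error) {
	b := make([]byte, 8)
	b[0] = g.prefix
	if _, err := rand.Read(b[1:]); err != nil {
		return quic.ConnectionID{}, fmt.Errorf("generating conn ID failed: %w", err)
	}
	return protocol.ParseConnectionID(b), nil
}

func (g *prefixedConnIDGenerator) ConnectionIDLen() int { return 8 }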


@@ -1,187 +1,244 @@
package self_test
import (
"bytes"
"context"
"encoding/binary"
"fmt"
mrand "math/rand"
mrand "math/rand/v2"
"net"
"sync"
"sync/atomic"
"testing"
"time"
"github.com/quic-go/quic-go"
quicproxy "github.com/quic-go/quic-go/integrationtests/tools/proxy"
"github.com/quic-go/quic-go/internal/wire"
. "github.com/onsi/ginkgo/v2"
. "github.com/onsi/gomega"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
var _ = Describe("Datagram test", func() {
const concurrentSends = 100
const maxDatagramSize = 250
func TestDatagramNegotiation(t *testing.T) {
t.Run("server enable, client enable", func(t *testing.T) {
testDatagramNegotiation(t, true, true)
})
t.Run("server enable, client disable", func(t *testing.T) {
testDatagramNegotiation(t, true, false)
})
t.Run("server disable, client enable", func(t *testing.T) {
testDatagramNegotiation(t, false, true)
})
t.Run("server disable, client disable", func(t *testing.T) {
testDatagramNegotiation(t, false, false)
})
}
var (
serverConn, clientConn *net.UDPConn
dropped, total atomic.Int32
func testDatagramNegotiation(t *testing.T, serverEnableDatagram, clientEnableDatagram bool) {
udpConn, err := net.ListenUDP("udp", &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 0})
require.NoError(t, err)
defer udpConn.Close()
server, err := quic.Listen(
udpConn,
getTLSConfig(),
getQuicConfig(&quic.Config{EnableDatagrams: serverEnableDatagram}),
)
require.NoError(t, err)
defer server.Close()
startServerAndProxy := func(enableDatagram, expectDatagramSupport bool) (port int, closeFn func()) {
addr, err := net.ResolveUDPAddr("udp", "localhost:0")
Expect(err).ToNot(HaveOccurred())
serverConn, err = net.ListenUDP("udp", addr)
Expect(err).ToNot(HaveOccurred())
ln, err := quic.Listen(
serverConn,
getTLSConfig(),
getQuicConfig(&quic.Config{EnableDatagrams: enableDatagram}),
)
Expect(err).ToNot(HaveOccurred())
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
defer cancel()
clientConn, err := quic.DialAddr(
ctx,
server.Addr().String(),
getTLSClientConfig(),
getQuicConfig(&quic.Config{EnableDatagrams: clientEnableDatagram}),
)
require.NoError(t, err)
defer clientConn.CloseWithError(0, "")
accepted := make(chan struct{})
go func() {
defer GinkgoRecover()
defer close(accepted)
conn, err := ln.Accept(context.Background())
Expect(err).ToNot(HaveOccurred())
serverConn, err := server.Accept(ctx)
require.NoError(t, err)
defer serverConn.CloseWithError(0, "")
if expectDatagramSupport {
Expect(conn.ConnectionState().SupportsDatagrams).To(BeTrue())
if enableDatagram {
f := &wire.DatagramFrame{DataLenPresent: true}
var wg sync.WaitGroup
wg.Add(concurrentSends)
for i := 0; i < concurrentSends; i++ {
go func(i int) {
defer GinkgoRecover()
defer wg.Done()
b := make([]byte, 8)
binary.BigEndian.PutUint64(b, uint64(i))
Expect(conn.SendDatagram(b)).To(Succeed())
}(i)
}
maxDatagramMessageSize := f.MaxDataLen(maxDatagramSize, conn.ConnectionState().Version)
b := make([]byte, maxDatagramMessageSize+1)
Expect(conn.SendDatagram(b)).To(MatchError(&quic.DatagramTooLargeError{
MaxDatagramPayloadSize: int64(maxDatagramMessageSize),
}))
wg.Wait()
}
} else {
Expect(conn.ConnectionState().SupportsDatagrams).To(BeFalse())
}
}()
serverPort := ln.Addr().(*net.UDPAddr).Port
proxy, err := quicproxy.NewQuicProxy("localhost:0", &quicproxy.Opts{
RemoteAddr: fmt.Sprintf("localhost:%d", serverPort),
// drop 10% of Short Header packets sent from the server
DropPacket: func(dir quicproxy.Direction, packet []byte) bool {
if dir == quicproxy.DirectionIncoming {
return false
}
// don't drop Long Header packets
if wire.IsLongHeaderPacket(packet[0]) {
return false
}
drop := mrand.Int()%10 == 0
if drop {
dropped.Add(1)
}
total.Add(1)
return drop
},
})
Expect(err).ToNot(HaveOccurred())
return proxy.LocalPort(), func() {
Eventually(accepted).Should(BeClosed())
proxy.Close()
ln.Close()
}
if clientEnableDatagram {
require.True(t, serverConn.ConnectionState().SupportsDatagrams)
require.NoError(t, serverConn.SendDatagram([]byte("foo")))
datagram, err := clientConn.ReceiveDatagram(ctx)
require.NoError(t, err)
require.Equal(t, []byte("foo"), datagram)
} else {
require.False(t, serverConn.ConnectionState().SupportsDatagrams)
require.Error(t, serverConn.SendDatagram([]byte("foo")))
}
BeforeEach(func() {
addr, err := net.ResolveUDPAddr("udp", "localhost:0")
Expect(err).ToNot(HaveOccurred())
clientConn, err = net.ListenUDP("udp", addr)
Expect(err).ToNot(HaveOccurred())
})
if serverEnableDatagram {
require.True(t, clientConn.ConnectionState().SupportsDatagrams)
require.NoError(t, clientConn.SendDatagram([]byte("bar")))
datagram, err := serverConn.ReceiveDatagram(ctx)
require.NoError(t, err)
require.Equal(t, []byte("bar"), datagram)
} else {
require.False(t, clientConn.ConnectionState().SupportsDatagrams)
require.Error(t, clientConn.SendDatagram([]byte("bar")))
}
}
It("sends datagrams", func() {
oldMaxDatagramSize := wire.MaxDatagramSize
wire.MaxDatagramSize = maxDatagramSize
proxyPort, close := startServerAndProxy(true, true)
defer close()
raddr, err := net.ResolveUDPAddr("udp", fmt.Sprintf("localhost:%d", proxyPort))
Expect(err).ToNot(HaveOccurred())
conn, err := quic.Dial(
context.Background(),
clientConn,
raddr,
getTLSClientConfig(),
getQuicConfig(&quic.Config{EnableDatagrams: true}),
)
Expect(err).ToNot(HaveOccurred())
Expect(conn.ConnectionState().SupportsDatagrams).To(BeTrue())
var counter int
for {
// Close the connection if no message is received for 100 ms.
timer := time.AfterFunc(scaleDuration(100*time.Millisecond), func() { conn.CloseWithError(0, "") })
if _, err := conn.ReceiveDatagram(context.Background()); err != nil {
break
func TestDatagramSizeLimit(t *testing.T) {
const maxDatagramSize = 456
originalMaxDatagramSize := wire.MaxDatagramSize
wire.MaxDatagramSize = maxDatagramSize
t.Cleanup(func() { wire.MaxDatagramSize = originalMaxDatagramSize })
udpConn, err := net.ListenUDP("udp", &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 0})
require.NoError(t, err)
defer udpConn.Close()
server, err := quic.Listen(
udpConn,
getTLSConfig(),
getQuicConfig(&quic.Config{EnableDatagrams: true}),
)
require.NoError(t, err)
defer server.Close()
clientConn, err := quic.DialAddr(
context.Background(),
server.Addr().String(),
getTLSClientConfig(),
getQuicConfig(&quic.Config{EnableDatagrams: true}),
)
require.NoError(t, err)
defer clientConn.CloseWithError(0, "")
err = clientConn.SendDatagram(bytes.Repeat([]byte("a"), maxDatagramSize+100)) // definitely too large
require.Error(t, err)
var sizeErr *quic.DatagramTooLargeError
require.ErrorAs(t, err, &sizeErr)
require.InDelta(t, sizeErr.MaxDatagramPayloadSize, maxDatagramSize, 10)
require.NoError(t, clientConn.SendDatagram(bytes.Repeat([]byte("b"), int(sizeErr.MaxDatagramPayloadSize))))
require.Error(t, clientConn.SendDatagram(bytes.Repeat([]byte("c"), int(sizeErr.MaxDatagramPayloadSize+1))))
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
defer cancel()
serverConn, err := server.Accept(ctx)
require.NoError(t, err)
defer serverConn.CloseWithError(0, "")
datagram, err := serverConn.ReceiveDatagram(ctx)
require.NoError(t, err)
require.Equal(t, bytes.Repeat([]byte("b"), int(sizeErr.MaxDatagramPayloadSize)), datagram)
}
func TestDatagramLoss(t *testing.T) {
const rtt = 10 * time.Millisecond
const numDatagrams = 100
const datagramSize = 500
udpConn, err := net.ListenUDP("udp", &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 0})
require.NoError(t, err)
defer udpConn.Close()
server, err := quic.Listen(
udpConn,
getTLSConfig(),
getQuicConfig(&quic.Config{EnableDatagrams: true}),
)
require.NoError(t, err)
defer server.Close()
var droppedIncoming, droppedOutgoing, total atomic.Int32
proxy, err := quicproxy.NewQuicProxy("localhost:0", &quicproxy.Opts{
RemoteAddr: fmt.Sprintf("localhost:%d", server.Addr().(*net.UDPAddr).Port),
DropPacket: func(dir quicproxy.Direction, packet []byte) bool {
if wire.IsLongHeaderPacket(packet[0]) { // don't drop Long Header packets
return false
}
timer.Stop()
counter++
if len(packet) < datagramSize { // don't drop ACK-only packets
return false
}
total.Add(1)
if mrand.Int()%10 == 0 {
switch dir {
case quicproxy.DirectionIncoming:
droppedIncoming.Add(1)
case quicproxy.DirectionOutgoing:
droppedOutgoing.Add(1)
}
return true
}
return false
},
DelayPacket: func(dir quicproxy.Direction, packet []byte) time.Duration { return rtt / 2 },
})
require.NoError(t, err)
defer proxy.Close()
clientConn, err := quic.DialAddr(
context.Background(),
fmt.Sprintf("localhost:%d", proxy.LocalPort()),
getTLSClientConfig(),
getQuicConfig(&quic.Config{EnableDatagrams: true}),
)
require.NoError(t, err)
defer clientConn.CloseWithError(0, "")
ctx, cancel := context.WithTimeout(context.Background(), scaleDuration(numDatagrams*time.Millisecond))
defer cancel()
serverConn, err := server.Accept(ctx)
require.NoError(t, err)
defer serverConn.CloseWithError(0, "")
var clientDatagrams, serverDatagrams int
clientErrChan := make(chan error, 1)
go func() {
defer close(clientErrChan)
for {
if _, err := clientConn.ReceiveDatagram(ctx); err != nil {
clientErrChan <- err
return
}
clientDatagrams++
}
}()
numDropped := int(dropped.Load())
expVal := concurrentSends - numDropped
fmt.Fprintf(GinkgoWriter, "Dropped %d out of %d packets.\n", numDropped, total.Load())
fmt.Fprintf(GinkgoWriter, "Received %d out of %d sent datagrams.\n", counter, concurrentSends)
Expect(counter).To(And(
BeNumerically(">", expVal*9/10),
BeNumerically("<", concurrentSends),
))
Eventually(conn.Context().Done).Should(BeClosed())
wire.MaxDatagramSize = oldMaxDatagramSize
})
for i := 0; i < numDatagrams; i++ {
payload := bytes.Repeat([]byte{uint8(i)}, datagramSize)
require.NoError(t, clientConn.SendDatagram(payload))
require.NoError(t, serverConn.SendDatagram(payload))
time.Sleep(scaleDuration(time.Millisecond / 2))
}
It("server can disable datagram", func() {
proxyPort, close := startServerAndProxy(false, true)
raddr, err := net.ResolveUDPAddr("udp", fmt.Sprintf("localhost:%d", proxyPort))
Expect(err).ToNot(HaveOccurred())
conn, err := quic.Dial(
context.Background(),
clientConn,
raddr,
getTLSClientConfig(),
getQuicConfig(&quic.Config{EnableDatagrams: true}),
)
Expect(err).ToNot(HaveOccurred())
Expect(conn.ConnectionState().SupportsDatagrams).To(BeFalse())
serverErrChan := make(chan error, 1)
go func() {
defer close(serverErrChan)
for {
if _, err := serverConn.ReceiveDatagram(ctx); err != nil {
serverErrChan <- err
return
}
serverDatagrams++
}
}()
close()
conn.CloseWithError(0, "")
})
select {
case err := <-clientErrChan:
require.ErrorIs(t, err, context.DeadlineExceeded)
case <-time.After(scaleDuration(5 * numDatagrams * time.Millisecond)):
t.Fatal("timeout")
}
select {
case err := <-serverErrChan:
require.ErrorIs(t, err, context.DeadlineExceeded)
case <-time.After(scaleDuration(5 * numDatagrams * time.Millisecond)):
t.Fatal("timeout")
}
It("client can disable datagram", func() {
proxyPort, close := startServerAndProxy(false, true)
raddr, err := net.ResolveUDPAddr("udp", fmt.Sprintf("localhost:%d", proxyPort))
Expect(err).ToNot(HaveOccurred())
conn, err := quic.Dial(
context.Background(),
clientConn,
raddr,
getTLSClientConfig(),
getQuicConfig(&quic.Config{EnableDatagrams: true}),
)
Expect(err).ToNot(HaveOccurred())
Expect(conn.ConnectionState().SupportsDatagrams).To(BeFalse())
Expect(conn.SendDatagram([]byte{0})).To(HaveOccurred())
close()
conn.CloseWithError(0, "")
})
})
numDroppedIncoming := droppedIncoming.Load()
numDroppedOutgoing := droppedOutgoing.Load()
t.Logf("dropped %d incoming and %d outgoing out of %d packets", numDroppedIncoming, numDroppedOutgoing, total.Load())
assert.NotZero(t, numDroppedIncoming)
assert.NotZero(t, numDroppedOutgoing)
t.Logf("server received %d out of %d sent datagrams", serverDatagrams, numDatagrams)
assert.InDelta(t, numDatagrams-numDroppedIncoming, serverDatagrams, numDatagrams/20, "datagrams received by the server")
t.Logf("client received %d out of %d sent datagrams", clientDatagrams, numDatagrams)
assert.InDelta(t, numDatagrams-numDroppedOutgoing, clientDatagrams, numDatagrams/20, "datagrams received by the client")
}
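TestDatagramSizeLimit relies on a pattern applications can reuse: a too-large SendDatagram fails with a *quic.DatagramTooLargeError carrying the current maximum payload size, which the sender can use to resize. A condensed sketch of that pattern (the helper is illustrative):

package self_test

import (
	"errors"

	"github.com/quic-go/quic-go"
)

// sendDatagramFitted tries to send the payload and, if it exceeds the
// current limit, truncates it to the size reported in the error and
// retries once.
func sendDatagramFitted(conn quic.Connection, payload []byte) error {
	err := conn.SendDatagram(payload)
	var tooLarge *quic.DatagramTooLargeError
	if errors.As(err, &tooLarge) {
		return conn.SendDatagram(payload[:tooLarge.MaxDatagramPayloadSize])
	}
	return err
}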


@@ -1,217 +1,240 @@
package self_test
import (
"bytes"
"context"
"fmt"
"io"
"net"
"testing"
"time"
"github.com/quic-go/quic-go"
. "github.com/onsi/ginkgo/v2"
. "github.com/onsi/gomega"
"github.com/stretchr/testify/require"
)
var _ = Describe("Stream deadline tests", func() {
setup := func() (serverStr, clientStr quic.Stream, close func()) {
server, err := quic.ListenAddr("localhost:0", getTLSConfig(), getQuicConfig(nil))
Expect(err).ToNot(HaveOccurred())
strChan := make(chan quic.SendStream)
go func() {
defer GinkgoRecover()
conn, err := server.Accept(context.Background())
Expect(err).ToNot(HaveOccurred())
str, err := conn.AcceptStream(context.Background())
Expect(err).ToNot(HaveOccurred())
_, err = str.Read([]byte{0})
Expect(err).ToNot(HaveOccurred())
strChan <- str
}()
func setupDeadlineTest(t *testing.T) (serverStr, clientStr quic.Stream) {
t.Helper()
server, err := quic.ListenAddr("localhost:0", getTLSConfig(), getQuicConfig(nil))
require.NoError(t, err)
t.Cleanup(func() { server.Close() })
conn, err := quic.DialAddr(
context.Background(),
fmt.Sprintf("localhost:%d", server.Addr().(*net.UDPAddr).Port),
getTLSClientConfig(),
getQuicConfig(nil),
)
Expect(err).ToNot(HaveOccurred())
clientStr, err = conn.OpenStream()
Expect(err).ToNot(HaveOccurred())
_, err = clientStr.Write([]byte{0}) // need to write one byte so the server learns about the stream
Expect(err).ToNot(HaveOccurred())
Eventually(strChan).Should(Receive(&serverStr))
return serverStr, clientStr, func() {
Expect(server.Close()).To(Succeed())
Expect(conn.CloseWithError(0, "")).To(Succeed())
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
defer cancel()
conn, err := quic.DialAddr(
ctx,
fmt.Sprintf("localhost:%d", server.Addr().(*net.UDPAddr).Port),
getTLSClientConfig(),
getQuicConfig(nil),
)
require.NoError(t, err)
t.Cleanup(func() { conn.CloseWithError(0, "") })
clientStr, err = conn.OpenStream()
require.NoError(t, err)
_, err = clientStr.Write([]byte{0}) // need to write one byte so the server learns about the stream
require.NoError(t, err)
serverConn, err := server.Accept(ctx)
require.NoError(t, err)
t.Cleanup(func() { serverConn.CloseWithError(0, "") })
serverStr, err = serverConn.AcceptStream(ctx)
require.NoError(t, err)
_, err = serverStr.Read([]byte{0})
require.NoError(t, err)
return serverStr, clientStr
}
func TestReadDeadlineSync(t *testing.T) {
serverStr, clientStr := setupDeadlineTest(t)
const timeout = time.Millisecond
errChan := make(chan error, 1)
go func() {
_, err := serverStr.Write(PRDataLong)
errChan <- err
}()
var bytesRead int
var timeoutCounter int
buf := make([]byte, 1<<10)
data := make([]byte, len(PRDataLong))
clientStr.SetReadDeadline(time.Now().Add(timeout))
for bytesRead < len(PRDataLong) {
n, err := clientStr.Read(buf)
if nerr, ok := err.(net.Error); ok && nerr.Timeout() {
timeoutCounter++
clientStr.SetReadDeadline(time.Now().Add(timeout))
} else {
require.NoError(t, err)
}
copy(data[bytesRead:], buf[:n])
bytesRead += n
}
require.Equal(t, PRDataLong, data)
// make sure the test actually worked and Read actually ran into the deadline a few times
t.Logf("ran into deadline %d times", timeoutCounter)
require.GreaterOrEqual(t, timeoutCounter, 10)
select {
case err := <-errChan:
require.NoError(t, err)
case <-time.After(time.Second):
t.Fatal("timeout")
}
}
func TestReadDeadlineAsync(t *testing.T) {
serverStr, clientStr := setupDeadlineTest(t)
const timeout = time.Millisecond
errChan := make(chan error, 1)
go func() {
_, err := serverStr.Write(PRDataLong)
errChan <- err
}()
var bytesRead int
var timeoutCounter int
buf := make([]byte, 1<<10)
data := make([]byte, len(PRDataLong))
received := make(chan struct{})
go func() {
for {
select {
case <-received:
return
default:
time.Sleep(timeout)
}
clientStr.SetReadDeadline(time.Now().Add(timeout))
}
}()
for bytesRead < len(PRDataLong) {
n, err := clientStr.Read(buf)
if nerr, ok := err.(net.Error); ok && nerr.Timeout() {
timeoutCounter++
} else {
require.NoError(t, err)
}
copy(data[bytesRead:], buf[:n])
bytesRead += n
}
Context("read deadlines", func() {
It("completes a transfer when the deadline is set", func() {
serverStr, clientStr, closeFn := setup()
defer closeFn()
require.Equal(t, PRDataLong, data)
close(received)
const timeout = time.Millisecond
done := make(chan struct{})
go func() {
defer GinkgoRecover()
_, err := serverStr.Write(PRDataLong)
Expect(err).ToNot(HaveOccurred())
close(done)
}()
// make sure the test actually worked and Read actually ran into the deadline a few times
t.Logf("ran into deadline %d times", timeoutCounter)
require.GreaterOrEqual(t, timeoutCounter, 10)
select {
case err := <-errChan:
require.NoError(t, err)
case <-time.After(time.Second):
t.Fatal("timeout")
}
}
var bytesRead int
var timeoutCounter int
buf := make([]byte, 1<<10)
data := make([]byte, len(PRDataLong))
clientStr.SetReadDeadline(time.Now().Add(timeout))
for bytesRead < len(PRDataLong) {
n, err := clientStr.Read(buf)
if nerr, ok := err.(net.Error); ok && nerr.Timeout() {
timeoutCounter++
clientStr.SetReadDeadline(time.Now().Add(timeout))
} else {
Expect(err).ToNot(HaveOccurred())
}
copy(data[bytesRead:], buf[:n])
bytesRead += n
}
Expect(data).To(Equal(PRDataLong))
// make sure the test actually worked and Read actually ran into the deadline a few times
Expect(timeoutCounter).To(BeNumerically(">=", 10))
Eventually(done).Should(BeClosed())
})
func TestWriteDeadlineSync(t *testing.T) {
serverStr, clientStr := setupDeadlineTest(t)
It("completes a transfer when the deadline is set concurrently", func() {
serverStr, clientStr, closeFn := setup()
defer closeFn()
const timeout = time.Millisecond
const timeout = time.Millisecond
go func() {
defer GinkgoRecover()
_, err := serverStr.Write(PRDataLong)
Expect(err).ToNot(HaveOccurred())
}()
errChan := make(chan error, 1)
go func() {
defer close(errChan)
data, err := io.ReadAll(serverStr)
if err != nil {
errChan <- err
}
if !bytes.Equal(PRDataLong, data) {
errChan <- fmt.Errorf("data mismatch")
}
}()
var bytesRead int
var timeoutCounter int
buf := make([]byte, 1<<10)
data := make([]byte, len(PRDataLong))
clientStr.SetReadDeadline(time.Now().Add(timeout))
deadlineDone := make(chan struct{})
received := make(chan struct{})
go func() {
defer close(deadlineDone)
for {
select {
case <-received:
return
default:
time.Sleep(timeout)
}
clientStr.SetReadDeadline(time.Now().Add(timeout))
}
}()
for bytesRead < len(PRDataLong) {
n, err := clientStr.Read(buf)
if nerr, ok := err.(net.Error); ok && nerr.Timeout() {
timeoutCounter++
} else {
Expect(err).ToNot(HaveOccurred())
}
copy(data[bytesRead:], buf[:n])
bytesRead += n
}
close(received)
Expect(data).To(Equal(PRDataLong))
// make sure the test actually worked and Read actually ran into the deadline a few times
Expect(timeoutCounter).To(BeNumerically(">=", 10))
Eventually(deadlineDone).Should(BeClosed())
})
})
Context("write deadlines", func() {
It("completes a transfer when the deadline is set", func() {
serverStr, clientStr, closeFn := setup()
defer closeFn()
const timeout = time.Millisecond
done := make(chan struct{})
go func() {
defer GinkgoRecover()
data, err := io.ReadAll(serverStr)
Expect(err).ToNot(HaveOccurred())
Expect(data).To(Equal(PRDataLong))
close(done)
}()
var bytesWritten int
var timeoutCounter int
var bytesWritten int
var timeoutCounter int
clientStr.SetWriteDeadline(time.Now().Add(timeout))
for bytesWritten < len(PRDataLong) {
n, err := clientStr.Write(PRDataLong[bytesWritten:])
if nerr, ok := err.(net.Error); ok && nerr.Timeout() {
timeoutCounter++
clientStr.SetWriteDeadline(time.Now().Add(timeout))
for bytesWritten < len(PRDataLong) {
n, err := clientStr.Write(PRDataLong[bytesWritten:])
if nerr, ok := err.(net.Error); ok && nerr.Timeout() {
timeoutCounter++
clientStr.SetWriteDeadline(time.Now().Add(timeout))
} else {
Expect(err).ToNot(HaveOccurred())
}
bytesWritten += n
} else {
require.NoError(t, err)
}
bytesWritten += n
}
clientStr.Close()
// make sure the test actually worked and Write actually ran into the deadline a few times
t.Logf("ran into deadline %d times", timeoutCounter)
require.GreaterOrEqual(t, timeoutCounter, 10)
select {
case err := <-errChan:
require.NoError(t, err)
case <-time.After(time.Second):
t.Fatal("timeout")
}
}
func TestWriteDeadlineAsync(t *testing.T) {
serverStr, clientStr := setupDeadlineTest(t)
const timeout = time.Millisecond
errChan := make(chan error, 1)
go func() {
defer close(errChan)
data, err := io.ReadAll(serverStr)
if err != nil {
errChan <- err
}
if !bytes.Equal(PRDataLong, data) {
errChan <- fmt.Errorf("data mismatch")
}
}()
clientStr.SetWriteDeadline(time.Now().Add(timeout))
readDone := make(chan struct{})
deadlineDone := make(chan struct{})
go func() {
defer close(deadlineDone)
for {
select {
case <-readDone:
return
default:
time.Sleep(timeout)
}
clientStr.Close()
// make sure the test actually worked and Write actually ran into the deadline a few times
Expect(timeoutCounter).To(BeNumerically(">=", 10))
Eventually(done).Should(BeClosed())
})
It("completes a transfer when the deadline is set concurrently", func() {
serverStr, clientStr, closeFn := setup()
defer closeFn()
const timeout = time.Millisecond
readDone := make(chan struct{})
go func() {
defer GinkgoRecover()
data, err := io.ReadAll(serverStr)
Expect(err).ToNot(HaveOccurred())
Expect(data).To(Equal(PRDataLong))
close(readDone)
}()
clientStr.SetWriteDeadline(time.Now().Add(timeout))
deadlineDone := make(chan struct{})
go func() {
defer close(deadlineDone)
for {
select {
case <-readDone:
return
default:
time.Sleep(timeout)
}
clientStr.SetWriteDeadline(time.Now().Add(timeout))
}
}()
}
}()
var bytesWritten int
var timeoutCounter int
clientStr.SetWriteDeadline(time.Now().Add(timeout))
for bytesWritten < len(PRDataLong) {
n, err := clientStr.Write(PRDataLong[bytesWritten:])
if nerr, ok := err.(net.Error); ok && nerr.Timeout() {
timeoutCounter++
} else {
Expect(err).ToNot(HaveOccurred())
}
bytesWritten += n
}
clientStr.Close()
// make sure the test actually worked and Write actually ran into the deadline a few times
Expect(timeoutCounter).To(BeNumerically(">=", 10))
Eventually(readDone).Should(BeClosed())
Eventually(deadlineDone).Should(BeClosed())
})
})
})
var bytesWritten int
var timeoutCounter int
clientStr.SetWriteDeadline(time.Now().Add(timeout))
for bytesWritten < len(PRDataLong) {
n, err := clientStr.Write(PRDataLong[bytesWritten:])
if nerr, ok := err.(net.Error); ok && nerr.Timeout() {
timeoutCounter++
} else {
require.NoError(t, err)
}
bytesWritten += n
}
clientStr.Close()
close(readDone)
// make sure the test actually worked and Write actually ran into the deadline a few times
t.Logf("ran into deadline %d times", timeoutCounter)
require.GreaterOrEqual(t, timeoutCounter, 10)
select {
case err := <-errChan:
require.NoError(t, err)
case <-time.After(time.Second):
t.Fatal("timeout")
}
}
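All four deadline tests rest on the same property: an expired deadline surfaces as a net.Error whose Timeout() reports true, and the stream remains usable once the deadline is extended. The retry loop used above, factored out for clarity (this helper is illustrative, not part of the diff):

package self_test

import (
	"net"
	"time"

	"github.com/quic-go/quic-go"
)

// readFullWithDeadline reads len(buf) bytes from str, renewing a short
// read deadline whenever it expires; timeout errors are retried, any
// other error is returned.
func readFullWithDeadline(str quic.Stream, buf []byte, timeout time.Duration) error {
	var read int
	str.SetReadDeadline(time.Now().Add(timeout))
	for read < len(buf) {
		n, err := str.Read(buf[read:])
		read += n
		if nerr, ok := err.(net.Error); ok && nerr.Timeout() {
			str.SetReadDeadline(time.Now().Add(timeout))
			continue
		}
		if err != nil {
			return err
		}
	}
	return nil
}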


@@ -3,121 +3,92 @@ package self_test
import (
"context"
"fmt"
"math/rand"
"net"
"sync/atomic"
"testing"
"time"
"github.com/quic-go/quic-go"
quicproxy "github.com/quic-go/quic-go/integrationtests/tools/proxy"
"github.com/quic-go/quic-go/internal/wire"
. "github.com/onsi/ginkgo/v2"
. "github.com/onsi/gomega"
"github.com/stretchr/testify/require"
)
func randomDuration(min, max time.Duration) time.Duration {
return min + time.Duration(rand.Int63n(int64(max-min)))
}
func TestDropTests(t *testing.T) {
for _, direction := range []quicproxy.Direction{quicproxy.DirectionIncoming, quicproxy.DirectionOutgoing} {
t.Run(fmt.Sprintf("in %s direction", direction), func(t *testing.T) {
const numMessages = 15
const rtt = 10 * time.Millisecond
var _ = Describe("Drop Tests", func() {
var (
proxy *quicproxy.QuicProxy
ln *quic.Listener
)
messageInterval := randomDuration(10*time.Millisecond, 100*time.Millisecond)
dropDuration := randomDuration(messageInterval*3/2, 2*time.Second)
dropDelay := randomDuration(25*time.Millisecond, numMessages*messageInterval/2)
t.Logf("sending a message every %s, %d times", messageInterval, numMessages)
t.Logf("dropping packets for %s, after a delay of %s", dropDuration, dropDelay)
startTime := time.Now()
startListenerAndProxy := func(dropCallback quicproxy.DropCallback) {
var err error
ln, err = quic.ListenAddr(
"localhost:0",
getTLSConfig(),
getQuicConfig(nil),
)
Expect(err).ToNot(HaveOccurred())
serverPort := ln.Addr().(*net.UDPAddr).Port
proxy, err = quicproxy.NewQuicProxy("localhost:0", &quicproxy.Opts{
RemoteAddr: fmt.Sprintf("localhost:%d", serverPort),
DelayPacket: func(dir quicproxy.Direction, _ []byte) time.Duration {
return 5 * time.Millisecond // 10ms RTT
},
DropPacket: dropCallback,
},
)
Expect(err).ToNot(HaveOccurred())
}
ln, err := quic.ListenAddr("localhost:0", getTLSConfig(), getQuicConfig(nil))
require.NoError(t, err)
defer ln.Close()
AfterEach(func() {
Expect(proxy.Close()).To(Succeed())
Expect(ln.Close()).To(Succeed())
})
for _, d := range directions {
direction := d
// The purpose of this test is to create a lot of tails, by sending 1 byte messages.
// The interval, the length of the drop period, and the time when the drop period starts are randomized.
// To cover different scenarios, repeat this test a few times.
for rep := 0; rep < 3; rep++ {
It(fmt.Sprintf("sends short messages, dropping packets in %s direction", direction), func() {
const numMessages = 15
messageInterval := randomDuration(10*time.Millisecond, 100*time.Millisecond)
dropDuration := randomDuration(messageInterval*3/2, 2*time.Second)
dropDelay := randomDuration(25*time.Millisecond, numMessages*messageInterval/2) // makes sure we don't interfere with the handshake
fmt.Fprintf(GinkgoWriter, "Sending a message every %s, %d times.\n", messageInterval, numMessages)
fmt.Fprintf(GinkgoWriter, "Dropping packets for %s, after a delay of %s\n", dropDuration, dropDelay)
startTime := time.Now()
var numDroppedPackets atomic.Int32
startListenerAndProxy(func(d quicproxy.Direction, _ []byte) bool {
var numDroppedPackets atomic.Int32
proxy, err := quicproxy.NewQuicProxy("localhost:0", &quicproxy.Opts{
RemoteAddr: fmt.Sprintf("localhost:%d", ln.Addr().(*net.UDPAddr).Port),
DelayPacket: func(dir quicproxy.Direction, _ []byte) time.Duration { return rtt / 2 },
DropPacket: func(d quicproxy.Direction, b []byte) bool {
if !d.Is(direction) {
return false
}
if wire.IsLongHeaderPacket(b[0]) { // don't interfere with the handshake
return false
}
drop := time.Now().After(startTime.Add(dropDelay)) && time.Now().Before(startTime.Add(dropDelay).Add(dropDuration))
if drop {
numDroppedPackets.Add(1)
}
return drop
})
done := make(chan struct{})
go func() {
defer GinkgoRecover()
conn, err := ln.Accept(context.Background())
Expect(err).ToNot(HaveOccurred())
str, err := conn.OpenStream()
Expect(err).ToNot(HaveOccurred())
for i := uint8(1); i <= numMessages; i++ {
n, err := str.Write([]byte{i})
Expect(err).ToNot(HaveOccurred())
Expect(n).To(Equal(1))
time.Sleep(messageInterval)
}
<-done
Expect(conn.CloseWithError(0, "")).To(Succeed())
}()
conn, err := quic.DialAddr(
context.Background(),
fmt.Sprintf("localhost:%d", proxy.LocalPort()),
getTLSClientConfig(),
getQuicConfig(nil),
)
Expect(err).ToNot(HaveOccurred())
defer conn.CloseWithError(0, "")
str, err := conn.AcceptStream(context.Background())
Expect(err).ToNot(HaveOccurred())
for i := uint8(1); i <= numMessages; i++ {
b := []byte{0}
n, err := str.Read(b)
Expect(err).ToNot(HaveOccurred())
Expect(n).To(Equal(1))
Expect(b[0]).To(Equal(i))
}
close(done)
numDropped := numDroppedPackets.Load()
fmt.Fprintf(GinkgoWriter, "Dropped %d packets.\n", numDropped)
Expect(numDropped).To(BeNumerically(">", 0))
},
})
}
require.NoError(t, err)
defer proxy.Close()
conn, err := quic.DialAddr(
context.Background(),
fmt.Sprintf("localhost:%d", proxy.LocalPort()),
getTLSClientConfig(),
getQuicConfig(nil),
)
require.NoError(t, err)
defer conn.CloseWithError(0, "")
serverConn, err := ln.Accept(context.Background())
require.NoError(t, err)
serverStr, err := serverConn.OpenUniStream()
require.NoError(t, err)
errChan := make(chan error, 1)
go func() {
for i := uint8(1); i <= numMessages; i++ {
if _, err := serverStr.Write([]byte{i}); err != nil {
errChan <- err
return
}
time.Sleep(messageInterval)
}
}()
str, err := conn.AcceptUniStream(context.Background())
require.NoError(t, err)
for i := uint8(1); i <= numMessages; i++ {
b := []byte{0}
n, err := str.Read(b)
require.NoError(t, err)
require.Equal(t, 1, n)
require.Equal(t, i, b[0])
}
numDropped := numDroppedPackets.Load()
t.Logf("dropped %d packets", numDropped)
require.NotZero(t, numDropped)
})
}
})
}
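The drop predicate above is a plain time window that opens dropDelay after the start and closes dropDuration later. Factored out (illustrative, not a helper from the diff):

package self_test

import "time"

// inDropWindow reports whether a packet seen at time now falls into the
// window that opens delay after start and stays open for duration; this
// mirrors the predicate inlined in the DropPacket callback above.
func inDropWindow(start time.Time, delay, duration time.Duration, now time.Time) bool {
	open := start.Add(delay)
	return now.After(open) && now.Before(open.Add(duration))
}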


@@ -5,64 +5,71 @@ import (
"fmt"
"io"
"net"
"testing"
"time"
"github.com/quic-go/quic-go"
quicproxy "github.com/quic-go/quic-go/integrationtests/tools/proxy"
. "github.com/onsi/ginkgo/v2"
. "github.com/onsi/gomega"
"github.com/stretchr/testify/require"
)
var _ = Describe("early data", func() {
func TestEarlyData(t *testing.T) {
const rtt = 80 * time.Millisecond
ln, err := quic.ListenAddrEarly("localhost:0", getTLSConfig(), getQuicConfig(nil))
require.NoError(t, err)
defer ln.Close()
It("sends 0.5-RTT data", func() {
ln, err := quic.ListenAddrEarly(
"localhost:0",
getTLSConfig(),
getQuicConfig(nil),
)
Expect(err).ToNot(HaveOccurred())
defer ln.Close()
done := make(chan struct{})
go func() {
defer GinkgoRecover()
defer close(done)
conn, err := ln.Accept(context.Background())
Expect(err).ToNot(HaveOccurred())
str, err := conn.OpenUniStream()
Expect(err).ToNot(HaveOccurred())
_, err = str.Write([]byte("early data"))
Expect(err).ToNot(HaveOccurred())
Expect(str.Close()).To(Succeed())
// make sure the Write finished before the handshake completed
Expect(conn.HandshakeComplete()).ToNot(BeClosed())
Eventually(conn.Context().Done()).Should(BeClosed())
}()
serverPort := ln.Addr().(*net.UDPAddr).Port
proxy, err := quicproxy.NewQuicProxy("localhost:0", &quicproxy.Opts{
RemoteAddr: fmt.Sprintf("localhost:%d", serverPort),
DelayPacket: func(quicproxy.Direction, []byte) time.Duration {
return rtt / 2
},
})
Expect(err).ToNot(HaveOccurred())
defer proxy.Close()
conn, err := quic.DialAddr(
context.Background(),
fmt.Sprintf("localhost:%d", proxy.LocalPort()),
getTLSClientConfig(),
getQuicConfig(nil),
)
Expect(err).ToNot(HaveOccurred())
str, err := conn.AcceptUniStream(context.Background())
Expect(err).ToNot(HaveOccurred())
data, err := io.ReadAll(str)
Expect(err).ToNot(HaveOccurred())
Expect(data).To(Equal([]byte("early data")))
conn.CloseWithError(0, "")
Eventually(done).Should(BeClosed())
proxy, err := quicproxy.NewQuicProxy("localhost:0", &quicproxy.Opts{
RemoteAddr: fmt.Sprintf("localhost:%d", ln.Addr().(*net.UDPAddr).Port),
DelayPacket: func(quicproxy.Direction, []byte) time.Duration { return rtt / 2 },
})
})
require.NoError(t, err)
defer proxy.Close()
connChan := make(chan quic.EarlyConnection)
errChan := make(chan error)
go func() {
conn, err := ln.Accept(context.Background())
if err != nil {
errChan <- err
return
}
connChan <- conn
}()
clientConn, err := quic.DialAddr(
context.Background(),
fmt.Sprintf("localhost:%d", proxy.LocalPort()),
getTLSClientConfig(),
getQuicConfig(nil),
)
require.NoError(t, err)
var serverConn quic.EarlyConnection
select {
case serverConn = <-connChan:
case err := <-errChan:
t.Fatalf("error accepting connection: %s", err)
}
str, err := serverConn.OpenUniStream()
require.NoError(t, err)
_, err = str.Write([]byte("early data"))
require.NoError(t, err)
require.NoError(t, str.Close())
// the write should have completed before the handshake
select {
case <-serverConn.HandshakeComplete():
t.Fatal("handshake shouldn't be completed yet")
default:
}
clientStr, err := clientConn.AcceptUniStream(context.Background())
require.NoError(t, err)
data, err := io.ReadAll(clientStr)
require.NoError(t, err)
require.Equal(t, []byte("early data"), data)
clientConn.CloseWithError(0, "")
<-serverConn.Context().Done()
}
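The non-blocking select on HandshakeComplete() replaces Ginkgo's Expect(conn.HandshakeComplete()).ToNot(BeClosed()): handshake completion is exposed as a channel that can be polled without blocking. The same check as a reusable helper (illustrative):

package self_test

import (
	"testing"

	"github.com/quic-go/quic-go"
)

// requireHandshakeNotComplete fails the test if the handshake has already
// completed, without blocking on the channel.
func requireHandshakeNotComplete(t *testing.T, conn quic.EarlyConnection) {
	t.Helper()
	select {
	case <-conn.HandshakeComplete():
		t.Fatal("handshake shouldn't be completed yet")
	default:
	}
}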


@@ -0,0 +1,264 @@
package self_test
import (
"context"
"crypto/tls"
"errors"
"fmt"
"net"
"testing"
"time"
"github.com/quic-go/quic-go"
"github.com/quic-go/quic-go/logging"
"github.com/stretchr/testify/require"
)
func TestHandshakeContextTimeout(t *testing.T) {
ctx, cancel := context.WithTimeout(context.Background(), scaleDuration(20*time.Millisecond))
defer cancel()
errChan := make(chan error, 1)
go func() {
_, err := quic.DialAddr(
ctx,
"localhost:1234", // nobody is listening on this port, but we're going to cancel this dial anyway
getTLSClientConfig(),
getQuicConfig(nil),
)
errChan <- err
}()
require.ErrorIs(t, <-errChan, context.DeadlineExceeded)
}
func TestHandshakeCancellationError(t *testing.T) {
ctx, cancel := context.WithCancelCause(context.Background())
errChan := make(chan error, 1)
go func() {
_, err := quic.DialAddr(
ctx,
"localhost:1234", // nobody is listening on this port, but we're going to cancel this dial anyway
getTLSClientConfig(),
getQuicConfig(nil),
)
errChan <- err
}()
cancel(errors.New("application cancelled"))
require.EqualError(t, <-errChan, "application cancelled")
}
func TestConnContextOnServerSide(t *testing.T) {
tlsGetConfigForClientContextChan := make(chan context.Context, 1)
tlsGetCertificateContextChan := make(chan context.Context, 1)
tracerContextChan := make(chan context.Context, 1)
connContextChan := make(chan context.Context, 1)
streamContextChan := make(chan context.Context, 1)
conn, err := net.ListenUDP("udp", &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 0})
require.NoError(t, err)
defer conn.Close()
tr := &quic.Transport{
Conn: conn,
ConnContext: func(ctx context.Context) context.Context {
return context.WithValue(ctx, "foo", "bar") //nolint:staticcheck
},
}
defer tr.Close()
server, err := tr.Listen(
&tls.Config{
GetConfigForClient: func(info *tls.ClientHelloInfo) (*tls.Config, error) {
tlsGetConfigForClientContextChan <- info.Context()
tlsConf := getTLSConfig()
tlsConf.GetCertificate = func(info *tls.ClientHelloInfo) (*tls.Certificate, error) {
tlsGetCertificateContextChan <- info.Context()
return &tlsConf.Certificates[0], nil
}
return tlsConf, nil
},
},
getQuicConfig(&quic.Config{
Tracer: func(ctx context.Context, _ logging.Perspective, _ quic.ConnectionID) *logging.ConnectionTracer {
tracerContextChan <- ctx
return nil
},
}),
)
require.NoError(t, err)
defer server.Close()
c, err := quic.DialAddr(
context.Background(),
fmt.Sprintf("localhost:%d", server.Addr().(*net.UDPAddr).Port),
getTLSClientConfig(),
getQuicConfig(nil),
)
require.NoError(t, err)
serverConn, err := server.Accept(context.Background())
require.NoError(t, err)
connContextChan <- serverConn.Context()
str, err := serverConn.OpenUniStream()
require.NoError(t, err)
streamContextChan <- str.Context()
str.Write([]byte{1, 2, 3})
_, err = c.AcceptUniStream(context.Background())
require.NoError(t, err)
c.CloseWithError(1337, "bye")
checkContext := func(c <-chan context.Context, checkCancellationCause bool) {
t.Helper()
var ctx context.Context
select {
case ctx = <-c:
case <-time.After(time.Second):
t.Fatal("timeout waiting for context")
}
val := ctx.Value("foo")
require.NotNil(t, val)
v := val.(string)
require.Equal(t, "bar", v)
select {
case <-ctx.Done():
case <-time.After(time.Second):
t.Fatal("timeout waiting for context to be done")
}
if !checkCancellationCause {
return
}
ctxErr := context.Cause(ctx)
var appErr *quic.ApplicationError
require.ErrorAs(t, ctxErr, &appErr)
require.Equal(t, quic.ApplicationErrorCode(1337), appErr.ErrorCode)
}
checkContext(connContextChan, true)
checkContext(tracerContextChan, true)
checkContext(streamContextChan, true)
// crypto/tls cancels the context when the TLS handshake completes.
checkContext(tlsGetConfigForClientContextChan, false)
checkContext(tlsGetCertificateContextChan, false)
}
// Users are not supposed to return a fresh context from ConnContext, but we should handle it gracefully.
func TestConnContextFreshContext(t *testing.T) {
conn, err := net.ListenUDP("udp", &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 0})
require.NoError(t, err)
defer conn.Close()
tr := &quic.Transport{
Conn: conn,
ConnContext: func(ctx context.Context) context.Context { return context.Background() },
}
defer tr.Close()
server, err := tr.Listen(getTLSConfig(), getQuicConfig(nil))
require.NoError(t, err)
defer server.Close()
errChan := make(chan error, 1)
go func() {
conn, err := server.Accept(context.Background())
if err != nil {
errChan <- err
return
}
conn.CloseWithError(1337, "bye")
}()
c, err := quic.DialAddr(
context.Background(),
fmt.Sprintf("localhost:%d", server.Addr().(*net.UDPAddr).Port),
getTLSClientConfig(),
getQuicConfig(nil),
)
require.NoError(t, err)
select {
case <-c.Context().Done():
case err := <-errChan:
t.Fatalf("accept failed: %v", err)
case <-time.After(time.Second):
t.Fatal("timeout")
}
}
func TestContextOnClientSide(t *testing.T) {
tlsServerConf := getTLSConfig()
tlsServerConf.ClientAuth = tls.RequestClientCert
server, err := quic.ListenAddr("localhost:0", tlsServerConf, getQuicConfig(nil))
require.NoError(t, err)
defer server.Close()
tlsContextChan := make(chan context.Context, 1)
tracerContextChan := make(chan context.Context, 1)
tlsConf := getTLSClientConfig()
tlsConf.GetClientCertificate = func(info *tls.CertificateRequestInfo) (*tls.Certificate, error) {
tlsContextChan <- info.Context()
return &tlsServerConf.Certificates[0], nil
}
ctx, cancel := context.WithCancel(context.WithValue(context.Background(), "foo", "bar")) //nolint:staticcheck
conn, err := quic.DialAddr(
ctx,
fmt.Sprintf("localhost:%d", server.Addr().(*net.UDPAddr).Port),
tlsConf,
getQuicConfig(&quic.Config{
Tracer: func(ctx context.Context, _ logging.Perspective, _ quic.ConnectionID) *logging.ConnectionTracer {
tracerContextChan <- ctx
return nil
},
}),
)
require.NoError(t, err)
cancel()
// Make sure the connection context is not cancelled (even though derived from the ctx passed to Dial)
select {
case <-conn.Context().Done():
t.Fatal("context should not be cancelled")
default:
}
checkContext := func(ctx context.Context, checkCancellationCause bool) {
t.Helper()
val := ctx.Value("foo")
require.NotNil(t, val)
require.Equal(t, "bar", val.(string))
if !checkCancellationCause {
return
}
ctxErr := context.Cause(ctx)
var appErr *quic.ApplicationError
require.ErrorAs(t, ctxErr, &appErr)
require.EqualValues(t, 1337, appErr.ErrorCode)
}
checkContextFromChan := func(c <-chan context.Context, checkCancellationCause bool) {
t.Helper()
var ctx context.Context
select {
case ctx = <-c:
case <-time.After(time.Second):
t.Fatal("timeout waiting for context")
}
checkContext(ctx, checkCancellationCause)
}
str, err := conn.OpenUniStream()
require.NoError(t, err)
conn.CloseWithError(1337, "bye")
checkContext(conn.Context(), true)
checkContext(str.Context(), true)
// crypto/tls cancels the context when the TLS handshake completes
checkContextFromChan(tlsContextChan, false)
checkContextFromChan(tracerContextChan, false)
}
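Transport.ConnContext, exercised above, is the hook for attaching per-connection values that later appear on conn.Context(), stream contexts, and tracer contexts. A minimal sketch using a typed key, which avoids the staticcheck warning the test suppresses with //nolint (the constructor is illustrative):

package self_test

import (
	"context"
	"net"

	"github.com/quic-go/quic-go"
)

type connTagKey struct{}

// newTaggingTransport tags the context of every connection accepted on
// this transport; the value can later be read from conn.Context().
func newTaggingTransport(conn net.PacketConn) *quic.Transport {
	return &quic.Transport{
		Conn: conn,
		ConnContext: func(ctx context.Context) context.Context {
			return context.WithValue(ctx, connTagKey{}, "bar")
		},
	}
}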


@@ -1,6 +1,7 @@
package self_test
import (
"bytes"
"context"
"crypto/tls"
"fmt"
@@ -9,6 +10,7 @@ import (
"net"
"sync"
"sync/atomic"
"testing"
"time"
"github.com/quic-go/quic-go"
@@ -16,282 +18,266 @@ import (
"github.com/quic-go/quic-go/internal/wire"
"github.com/quic-go/quic-go/quicvarint"
. "github.com/onsi/ginkgo/v2"
. "github.com/onsi/gomega"
"github.com/onsi/gomega/gbytes"
"github.com/stretchr/testify/require"
)
var directions = []quicproxy.Direction{quicproxy.DirectionIncoming, quicproxy.DirectionOutgoing, quicproxy.DirectionBoth}
func startDropTestListenerAndProxy(t *testing.T, rtt, timeout time.Duration, dropCallback quicproxy.DropCallback, doRetry bool, longCertChain bool) (_ *quic.Listener, proxyPort int) {
t.Helper()
conf := getQuicConfig(&quic.Config{
MaxIdleTimeout: timeout,
HandshakeIdleTimeout: timeout,
DisablePathMTUDiscovery: true,
})
var tlsConf *tls.Config
if longCertChain {
tlsConf = getTLSConfigWithLongCertChain()
} else {
tlsConf = getTLSConfig()
}
conn, err := net.ListenUDP("udp", &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 0})
require.NoError(t, err)
t.Cleanup(func() { conn.Close() })
tr := &quic.Transport{
Conn: conn,
VerifySourceAddress: func(net.Addr) bool { return doRetry },
}
t.Cleanup(func() { tr.Close() })
ln, err := tr.Listen(tlsConf, conf)
require.NoError(t, err)
t.Cleanup(func() { ln.Close() })
type applicationProtocol struct {
name string
run func(ln *quic.Listener, port int)
proxy, err := quicproxy.NewQuicProxy("localhost:0", &quicproxy.Opts{
RemoteAddr: fmt.Sprintf("localhost:%d", ln.Addr().(*net.UDPAddr).Port),
DropPacket: dropCallback,
DelayPacket: func(dir quicproxy.Direction, packet []byte) time.Duration { return rtt / 2 },
})
require.NoError(t, err)
t.Cleanup(func() { proxy.Close() })
return ln, proxy.LocalPort()
}
var _ = Describe("Handshake drop tests", func() {
func dropTestProtocolClientSpeaksFirst(t *testing.T, ln *quic.Listener, port int, timeout time.Duration, data []byte) {
conn, err := quic.DialAddr(
context.Background(),
fmt.Sprintf("localhost:%d", port),
getTLSClientConfig(),
getQuicConfig(&quic.Config{
MaxIdleTimeout: timeout,
HandshakeIdleTimeout: timeout,
DisablePathMTUDiscovery: true,
}),
)
require.NoError(t, err)
defer conn.CloseWithError(0, "")
str, err := conn.OpenUniStream()
require.NoError(t, err)
errChan := make(chan error, 1)
go func() {
defer str.Close()
_, err := str.Write(data)
errChan <- err
}()
ctx, cancel := context.WithTimeout(context.Background(), timeout)
defer cancel()
serverConn, err := ln.Accept(ctx)
require.NoError(t, err)
serverStr, err := serverConn.AcceptUniStream(ctx)
require.NoError(t, err)
b, err := io.ReadAll(&readerWithTimeout{Reader: serverStr, Timeout: timeout})
require.NoError(t, err)
require.Equal(t, data, b)
serverConn.CloseWithError(0, "")
}
func dropTestProtocolServerSpeaksFirst(t *testing.T, ln *quic.Listener, port int, timeout time.Duration, data []byte) {
conn, err := quic.DialAddr(
context.Background(),
fmt.Sprintf("localhost:%d", port),
getTLSClientConfig(),
getQuicConfig(&quic.Config{
MaxIdleTimeout: timeout,
HandshakeIdleTimeout: timeout,
DisablePathMTUDiscovery: true,
}),
)
require.NoError(t, err)
errChan := make(chan error, 1)
go func() {
defer close(errChan)
defer conn.CloseWithError(0, "")
ctx, cancel := context.WithTimeout(context.Background(), timeout)
defer cancel()
str, err := conn.AcceptUniStream(ctx)
if err != nil {
errChan <- err
return
}
b, err := io.ReadAll(&readerWithTimeout{Reader: str, Timeout: timeout})
if err != nil {
errChan <- err
return
}
if !bytes.Equal(b, data) {
errChan <- fmt.Errorf("data mismatch: %x != %x", b, data)
return
}
}()
serverConn, err := ln.Accept(context.Background())
require.NoError(t, err)
serverStr, err := serverConn.OpenUniStream()
require.NoError(t, err)
_, err = serverStr.Write(data)
require.NoError(t, err)
require.NoError(t, serverStr.Close())
select {
case err := <-errChan:
require.NoError(t, err)
case <-time.After(timeout):
t.Fatal("server connection not closed")
}
select {
case <-conn.Context().Done():
case <-time.After(timeout):
t.Fatal("server connection not closed")
}
}
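readerWithTimeout, used in both helpers above, comes from the shared test helpers (translated in the last commit of this series). A plausible sketch of such a reader, assuming it bounds each Read with a deadline; this is an assumption, not necessarily the diff's implementation:

package self_test

import (
	"fmt"
	"io"
	"time"
)

// readerWithTimeout wraps an io.Reader so that each Read fails if it
// doesn't return within Timeout. A timed-out Read leaks its goroutine
// until the underlying Read returns, which is acceptable in tests.
type readerWithTimeout struct {
	io.Reader
	Timeout time.Duration
}

func (r *readerWithTimeout) Read(p []byte) (int, error) {
	type result struct {
		n   int
		err error
	}
	resChan := make(chan result, 1)
	go func() {
		n, err := r.Reader.Read(p)
		resChan <- result{n, err}
	}()
	select {
	case res := <-resChan:
		return res.n, res.err
	case <-time.After(r.Timeout):
		return 0, fmt.Errorf("read timed out after %s", r.Timeout)
	}
}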
func dropTestProtocolNobodySpeaks(t *testing.T, ln *quic.Listener, port int, timeout time.Duration) {
conn, err := quic.DialAddr(
context.Background(),
fmt.Sprintf("localhost:%d", port),
getTLSClientConfig(),
getQuicConfig(&quic.Config{
MaxIdleTimeout: timeout,
HandshakeIdleTimeout: timeout,
DisablePathMTUDiscovery: true,
}),
)
require.NoError(t, err)
serverConn, err := ln.Accept(context.Background())
require.NoError(t, err)
conn.CloseWithError(0, "")
serverConn.CloseWithError(0, "")
}
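// dropCallbackDropNthPacket returns a callback that drops the n-th packet sent in the given direction.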
func dropCallbackDropNthPacket(direction quicproxy.Direction, n int) quicproxy.DropCallback {
var incoming, outgoing atomic.Int32
return func(d quicproxy.Direction, packet []byte) bool {
var p int32
switch d {
case quicproxy.DirectionIncoming:
p = incoming.Add(1)
case quicproxy.DirectionOutgoing:
p = outgoing.Add(1)
}
return p == int32(n) && d.Is(direction)
}
}
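// dropCallbackDropOneThird returns a callback that randomly drops roughly one third of the packets
// in the given direction, but never more than 10 in a row.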
func dropCallbackDropOneThird(direction quicproxy.Direction) quicproxy.DropCallback {
const maxSequentiallyDropped = 10
var mx sync.Mutex
var incoming, outgoing int
return func(d quicproxy.Direction, _ []byte) bool {
drop := mrand.Int63n(int64(3)) == 0
mx.Lock()
defer mx.Unlock()
// never drop more than 10 consecutive packets
if d.Is(quicproxy.DirectionIncoming) {
if drop {
incoming++
if incoming > maxSequentiallyDropped {
drop = false
}
}
if !drop {
incoming = 0
}
}
if d.Is(quicproxy.DirectionOutgoing) {
if drop {
outgoing++
if outgoing > maxSequentiallyDropped {
drop = false
}
}
if !drop {
outgoing = 0
}
}
return drop
}
}
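// TestHandshakeWithPacketLoss runs handshakes through a lossy proxy, covering every combination of
// drop direction, drop pattern, Retry usage and certificate chain length.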
func TestHandshakeWithPacketLoss(t *testing.T) {
data := GeneratePRData(5000)
const timeout = 2 * time.Minute
const rtt = 20 * time.Millisecond
type dropPattern struct {
name string
fn quicproxy.DropCallback
}
type serverConfig struct {
longCertChain bool
doRetry bool
}
for _, direction := range []quicproxy.Direction{quicproxy.DirectionIncoming, quicproxy.DirectionOutgoing, quicproxy.DirectionBoth} {
for _, dropPattern := range []dropPattern{
{name: "drop 1st packet", fn: dropCallbackDropNthPacket(direction, 1)},
{name: "drop 2nd packet", fn: dropCallbackDropNthPacket(direction, 2)},
{name: "drop 1/3 of packets", fn: dropCallbackDropOneThird(direction)},
} {
t.Run(fmt.Sprintf("%s in %s direction", dropPattern.name, direction), func(t *testing.T) {
for _, conf := range []serverConfig{
{longCertChain: false, doRetry: true},
{longCertChain: false, doRetry: false},
{longCertChain: true, doRetry: false},
} {
t.Run(fmt.Sprintf("retry: %t", conf.doRetry), func(t *testing.T) {
t.Run("client speaks first", func(t *testing.T) {
ln, proxyPort := startDropTestListenerAndProxy(t, rtt, timeout, dropPattern.fn, conf.doRetry, conf.longCertChain)
dropTestProtocolClientSpeaksFirst(t, ln, proxyPort, timeout, data)
})
t.Run("server speaks first", func(t *testing.T) {
ln, proxyPort := startDropTestListenerAndProxy(t, rtt, timeout, dropPattern.fn, conf.doRetry, conf.longCertChain)
dropTestProtocolServerSpeaksFirst(t, ln, proxyPort, timeout, data)
})
t.Run("nobody speaks", func(t *testing.T) {
ln, proxyPort := startDropTestListenerAndProxy(t, rtt, timeout, dropPattern.fn, conf.doRetry, conf.longCertChain)
dropTestProtocolNobodySpeaks(t, ln, proxyPort, timeout)
})
})
}
})
}
}
}
func TestPostQuantumClientHello(t *testing.T) {
origAdditionalTransportParametersClient := wire.AdditionalTransportParametersClient
t.Cleanup(func() { wire.AdditionalTransportParametersClient = origAdditionalTransportParametersClient })
b := make([]byte, 2500) // the ClientHello will now span across 3 packets
mrand.New(mrand.NewSource(time.Now().UnixNano())).Read(b)
wire.AdditionalTransportParametersClient = map[uint64][]byte{
// Avoid random collisions with the greased transport parameters.
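// Reserved transport parameter IDs are of the form 31*N+27 (RFC 9000), so this parameter is
// guaranteed to be ignored by the server; randomizing N avoids hitting the one quic-go greases itself.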
uint64(27+31*(1000+mrand.Int63()/31)) % quicvarint.Max: b,
}
ln, proxyPort := startDropTestListenerAndProxy(t, 10*time.Millisecond, 20*time.Second, dropCallbackDropOneThird(quicproxy.DirectionIncoming), false, false)
dropTestProtocolClientSpeaksFirst(t, ln, proxyPort, time.Minute, GeneratePRData(5000))
}

View File

@@ -6,192 +6,175 @@ import (
"fmt"
"io"
"net"
"testing"
"time"
"github.com/quic-go/quic-go"
quicproxy "github.com/quic-go/quic-go/integrationtests/tools/proxy"
. "github.com/onsi/ginkgo/v2"
. "github.com/onsi/gomega"
"github.com/stretchr/testify/require"
)
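// handshakeWithRTT dials the server through a proxy that introduces the given RTT,
// and returns the established connection.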
var _ = Describe("Handshake RTT tests", func() {
var (
proxy *quicproxy.QuicProxy
serverConfig *quic.Config
serverTLSConfig *tls.Config
func handshakeWithRTT(t *testing.T, serverAddr net.Addr, tlsConf *tls.Config, quicConf *quic.Config, rtt time.Duration) quic.Connection {
t.Helper()
proxy, err := quicproxy.NewQuicProxy("localhost:0", &quicproxy.Opts{
RemoteAddr: serverAddr.String(),
DelayPacket: func(_ quicproxy.Direction, _ []byte) time.Duration { return rtt / 2 },
})
require.NoError(t, err)
t.Cleanup(func() { proxy.Close() })
conn, err := quic.DialAddr(
context.Background(),
fmt.Sprintf("localhost:%d", proxy.LocalAddr().(*net.UDPAddr).Port),
tlsConf,
quicConf,
)
require.NoError(t, err)
t.Cleanup(func() { conn.CloseWithError(0, "") })
return conn
}
func TestHandshakeRTTWithoutRetry(t *testing.T) {
ln, err := quic.ListenAddr("localhost:0", getTLSConfig(), getQuicConfig(nil))
require.NoError(t, err)
defer ln.Close()
clientConfig := getQuicConfig(&quic.Config{
GetConfigForClient: func(info *quic.ClientHelloInfo) (*quic.Config, error) {
require.False(t, info.AddrVerified)
return nil, nil
},
})
const rtt = 400 * time.Millisecond
start := time.Now()
handshakeWithRTT(t, ln.Addr(), getTLSClientConfig(), clientConfig, rtt)
rtts := time.Since(start).Seconds() / rtt.Seconds()
require.GreaterOrEqual(t, rtts, float64(1))
require.Less(t, rtts, float64(2))
}
func TestHandshakeRTTWithRetry(t *testing.T) {
udpConn, err := net.ListenUDP("udp", &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 0})
require.NoError(t, err)
defer udpConn.Close()
tr := &quic.Transport{
Conn: udpConn,
VerifySourceAddress: func(net.Addr) bool { return true },
}
addTracer(tr)
defer tr.Close()
ln, err := tr.Listen(getTLSConfig(), getQuicConfig(nil))
require.NoError(t, err)
defer ln.Close()
clientConfig := getQuicConfig(&quic.Config{
GetConfigForClient: func(info *quic.ClientHelloInfo) (*quic.Config, error) {
require.True(t, info.AddrVerified)
return nil, nil
},
})
const rtt = 400 * time.Millisecond
start := time.Now()
handshakeWithRTT(t, ln.Addr(), getTLSClientConfig(), clientConfig, rtt)
rtts := time.Since(start).Seconds() / rtt.Seconds()
require.GreaterOrEqual(t, rtts, float64(2))
require.Less(t, rtts, float64(3))
}
func TestHandshakeRTTWithHelloRetryRequest(t *testing.T) {
tlsConf := getTLSConfig()
tlsConf.CurvePreferences = []tls.CurveID{tls.CurveP384}
ln, err := quic.ListenAddr("localhost:0", tlsConf, getQuicConfig(nil))
require.NoError(t, err)
defer ln.Close()
const rtt = 400 * time.Millisecond
start := time.Now()
handshakeWithRTT(t, ln.Addr(), getTLSClientConfig(), getQuicConfig(nil), rtt)
rtts := time.Since(start).Seconds() / rtt.Seconds()
require.GreaterOrEqual(t, rtts, float64(2))
require.Less(t, rtts, float64(3))
}
func TestHandshakeRTTReceiveMessage(t *testing.T) {
sendAndReceive := func(t *testing.T, serverConn, clientConn quic.Connection) {
t.Helper()
serverStr, err := serverConn.OpenUniStream()
require.NoError(t, err)
_, err = serverStr.Write([]byte("foobar"))
require.NoError(t, err)
require.NoError(t, serverStr.Close())
str, err := clientConn.AcceptUniStream(context.Background())
require.NoError(t, err)
data, err := io.ReadAll(str)
require.NoError(t, err)
require.Equal(t, []byte("foobar"), data)
}
t.Run("using ListenAddr", func(t *testing.T) {
ln, err := quic.ListenAddr("localhost:0", getTLSConfig(), getQuicConfig(nil))
require.NoError(t, err)
defer ln.Close()
connChan := make(chan quic.Connection, 1)
go func() {
conn, err := ln.Accept(context.Background())
if err != nil {
t.Logf("failed to accept connection: %s", err)
close(connChan)
return
}
connChan <- conn
}()
const rtt = 400 * time.Millisecond
start := time.Now()
conn := handshakeWithRTT(t, ln.Addr(), getTLSClientConfig(), getQuicConfig(nil), rtt)
serverConn := <-connChan
if serverConn == nil {
t.Fatal("serverConn is nil")
}
sendAndReceive(t, serverConn, conn)
rtts := time.Since(start).Seconds() / rtt.Seconds()
require.GreaterOrEqual(t, rtts, float64(2))
require.Less(t, rtts, float64(3))
})
It("establishes a connection in 1 RTT when the server doesn't require a token", func() {
ln, err := quic.ListenAddr("localhost:0", serverTLSConfig, serverConfig)
Expect(err).ToNot(HaveOccurred())
t.Run("using ListenAddrEarly", func(t *testing.T) {
ln, err := quic.ListenAddrEarly("localhost:0", getTLSConfig(), getQuicConfig(nil))
require.NoError(t, err)
defer ln.Close()
connChan := make(chan quic.Connection, 1)
go func() {
conn, err := ln.Accept(context.Background())
if err != nil {
t.Logf("failed to accept connection: %s", err)
close(connChan)
return
}
connChan <- conn
}()
const rtt = 400 * time.Millisecond
start := time.Now()
conn := handshakeWithRTT(t, ln.Addr(), getTLSClientConfig(), getQuicConfig(nil), rtt)
serverConn := <-connChan
if serverConn == nil {
t.Fatal("serverConn is nil")
}
sendAndReceive(t, serverConn, conn)
took := time.Since(start)
rtts := float64(took) / float64(rtt)
require.GreaterOrEqual(t, rtts, float64(1))
require.Less(t, rtts, float64(2))
})
It("receives the first message from the server after 1 RTT, when the server uses ListenAddrEarly", func() {
ln, err := quic.ListenAddrEarly("localhost:0", serverTLSConfig, serverConfig)
Expect(err).ToNot(HaveOccurred())
go func() {
defer GinkgoRecover()
conn, err := ln.Accept(context.Background())
Expect(err).ToNot(HaveOccurred())
// Check the ALPN now. This is probably what an application would do.
// It makes sure that ConnectionState does not block until the handshake completes.
Expect(conn.ConnectionState().TLS.NegotiatedProtocol).To(Equal(alpn))
str, err := conn.OpenUniStream()
Expect(err).ToNot(HaveOccurred())
_, err = str.Write([]byte("foobar"))
Expect(err).ToNot(HaveOccurred())
Expect(str.Close()).To(Succeed())
}()
defer ln.Close()
runProxy(ln.Addr())
startTime := time.Now()
conn, err := quic.DialAddr(
context.Background(),
fmt.Sprintf("localhost:%d", proxy.LocalAddr().(*net.UDPAddr).Port),
getTLSClientConfig(),
getQuicConfig(nil),
)
Expect(err).ToNot(HaveOccurred())
defer conn.CloseWithError(0, "")
str, err := conn.AcceptUniStream(context.Background())
Expect(err).ToNot(HaveOccurred())
data, err := io.ReadAll(str)
Expect(err).ToNot(HaveOccurred())
Expect(data).To(Equal([]byte("foobar")))
expectDurationInRTTs(startTime, 1)
})
})
}

File diff suppressed because it is too large

View File

@@ -1,170 +0,0 @@
package self_test
import (
"context"
"io"
"net"
"net/http"
"strconv"
"sync/atomic"
"time"
"github.com/quic-go/quic-go"
"github.com/quic-go/quic-go/http3"
. "github.com/onsi/ginkgo/v2"
. "github.com/onsi/gomega"
"github.com/onsi/gomega/gbytes"
)
type listenerWrapper struct {
http3.QUICEarlyListener
listenerClosed bool
count atomic.Int32
}
func (ln *listenerWrapper) Close() error {
ln.listenerClosed = true
return ln.QUICEarlyListener.Close()
}
func (ln *listenerWrapper) Faker() *fakeClosingListener {
ln.count.Add(1)
ctx, cancel := context.WithCancel(context.Background())
return &fakeClosingListener{
listenerWrapper: ln,
ctx: ctx,
cancel: cancel,
}
}
type fakeClosingListener struct {
*listenerWrapper
closed atomic.Bool
ctx context.Context
cancel context.CancelFunc
}
func (ln *fakeClosingListener) Accept(ctx context.Context) (quic.EarlyConnection, error) {
return ln.listenerWrapper.Accept(ln.ctx)
}
func (ln *fakeClosingListener) Close() error {
if ln.closed.CompareAndSwap(false, true) {
ln.cancel()
if ln.listenerWrapper.count.Add(-1) == 0 {
ln.listenerWrapper.Close()
}
}
return nil
}
var _ = Describe("HTTP3 Server hotswap test", func() {
var (
mux1 *http.ServeMux
mux2 *http.ServeMux
client *http.Client
rt *http3.Transport
server1 *http3.Server
server2 *http3.Server
ln *listenerWrapper
port string
)
BeforeEach(func() {
mux1 = http.NewServeMux()
mux1.HandleFunc("/hello1", func(w http.ResponseWriter, r *http.Request) {
defer GinkgoRecover()
io.WriteString(w, "Hello, World 1!\n") // don't check the error here. Stream may be reset.
})
mux2 = http.NewServeMux()
mux2.HandleFunc("/hello2", func(w http.ResponseWriter, r *http.Request) {
defer GinkgoRecover()
io.WriteString(w, "Hello, World 2!\n") // don't check the error here. Stream may be reset.
})
server1 = &http3.Server{
Handler: mux1,
QUICConfig: getQuicConfig(nil),
}
server2 = &http3.Server{
Handler: mux2,
QUICConfig: getQuicConfig(nil),
}
tlsConf := http3.ConfigureTLSConfig(getTLSConfig())
quicln, err := quic.ListenAddrEarly("0.0.0.0:0", tlsConf, getQuicConfig(nil))
ln = &listenerWrapper{QUICEarlyListener: quicln}
Expect(err).NotTo(HaveOccurred())
port = strconv.Itoa(ln.Addr().(*net.UDPAddr).Port)
})
AfterEach(func() {
Expect(rt.Close()).NotTo(HaveOccurred())
Expect(ln.Close()).NotTo(HaveOccurred())
})
BeforeEach(func() {
rt = &http3.Transport{
TLSClientConfig: getTLSClientConfig(),
DisableCompression: true,
QUICConfig: getQuicConfig(&quic.Config{MaxIdleTimeout: 10 * time.Second}),
}
client = &http.Client{Transport: rt}
})
It("hotswap works", func() {
// open first server and make single request to it
fake1 := ln.Faker()
stoppedServing1 := make(chan struct{})
go func() {
defer GinkgoRecover()
server1.ServeListener(fake1)
close(stoppedServing1)
}()
resp, err := client.Get("https://localhost:" + port + "/hello1")
Expect(err).ToNot(HaveOccurred())
Expect(resp.StatusCode).To(Equal(200))
body, err := io.ReadAll(gbytes.TimeoutReader(resp.Body, 3*time.Second))
Expect(err).ToNot(HaveOccurred())
Expect(string(body)).To(Equal("Hello, World 1!\n"))
// open second server with same underlying listener,
// make sure it opened and both servers are currently running
fake2 := ln.Faker()
stoppedServing2 := make(chan struct{})
go func() {
defer GinkgoRecover()
server2.ServeListener(fake2)
close(stoppedServing2)
}()
Consistently(stoppedServing1).ShouldNot(BeClosed())
Consistently(stoppedServing2).ShouldNot(BeClosed())
// now close first server, no errors should occur here
// and only the fake listener should be closed
Expect(server1.Close()).NotTo(HaveOccurred())
Eventually(stoppedServing1).Should(BeClosed())
Expect(fake1.closed.Load()).To(BeTrue())
Expect(fake2.closed.Load()).To(BeFalse())
Expect(ln.listenerClosed).ToNot(BeTrue())
Expect(client.Transport.(*http3.Transport).Close()).NotTo(HaveOccurred())
// verify that new connections are being initiated from the second server now
resp, err = client.Get("https://localhost:" + port + "/hello2")
Expect(err).ToNot(HaveOccurred())
Expect(resp.StatusCode).To(Equal(200))
body, err = io.ReadAll(gbytes.TimeoutReader(resp.Body, 3*time.Second))
Expect(err).ToNot(HaveOccurred())
Expect(string(body)).To(Equal("Hello, World 2!\n"))
// close the other server - both the fake and the actual listeners must close now
Expect(server2.Close()).NotTo(HaveOccurred())
Eventually(stoppedServing2).Should(BeClosed())
Expect(fake2.closed.Load()).To(BeTrue())
Expect(ln.listenerClosed).To(BeTrue())
})
})

View File

@@ -0,0 +1,160 @@
package self_test
import (
"context"
"io"
"net"
"net/http"
"strconv"
"sync/atomic"
"testing"
"time"
"github.com/quic-go/quic-go"
"github.com/quic-go/quic-go/http3"
"github.com/stretchr/testify/require"
)
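// listenerWrapper records when the underlying QUIC listener is closed,
// and counts the fake listeners handed out via Faker.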
type listenerWrapper struct {
http3.QUICEarlyListener
listenerClosed bool
count atomic.Int32
}
func (ln *listenerWrapper) Close() error {
ln.listenerClosed = true
return ln.QUICEarlyListener.Close()
}
func (ln *listenerWrapper) Faker() *fakeClosingListener {
ln.count.Add(1)
ctx, cancel := context.WithCancel(context.Background())
return &fakeClosingListener{
listenerWrapper: ln,
ctx: ctx,
cancel: cancel,
}
}
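// fakeClosingListener cancels its Accept context when closed, and only closes
// the shared underlying listener once the last fake listener is gone.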
type fakeClosingListener struct {
*listenerWrapper
closed atomic.Bool
ctx context.Context
cancel context.CancelFunc
}
func (ln *fakeClosingListener) Accept(ctx context.Context) (quic.EarlyConnection, error) {
return ln.listenerWrapper.Accept(ln.ctx)
}
func (ln *fakeClosingListener) Close() error {
if ln.closed.CompareAndSwap(false, true) {
ln.cancel()
if ln.listenerWrapper.count.Add(-1) == 0 {
ln.listenerWrapper.Close()
}
}
return nil
}
func TestHTTP3ServerHotswap(t *testing.T) {
mux1 := http.NewServeMux()
mux1.HandleFunc("/hello1", func(w http.ResponseWriter, r *http.Request) {
io.WriteString(w, "Hello, World 1!\n") // don't check the error here. Stream may be reset.
})
mux2 := http.NewServeMux()
mux2.HandleFunc("/hello2", func(w http.ResponseWriter, r *http.Request) {
io.WriteString(w, "Hello, World 2!\n") // don't check the error here. Stream may be reset.
})
server1 := &http3.Server{
Handler: mux1,
QUICConfig: getQuicConfig(nil),
}
server2 := &http3.Server{
Handler: mux2,
QUICConfig: getQuicConfig(nil),
}
tlsConf := http3.ConfigureTLSConfig(getTLSConfig())
quicln, err := quic.ListenAddrEarly("0.0.0.0:0", tlsConf, getQuicConfig(nil))
require.NoError(t, err)
ln := &listenerWrapper{QUICEarlyListener: quicln}
port := strconv.Itoa(ln.Addr().(*net.UDPAddr).Port)
rt := &http3.Transport{
TLSClientConfig: getTLSClientConfig(),
DisableCompression: true,
QUICConfig: getQuicConfig(&quic.Config{MaxIdleTimeout: 10 * time.Second}),
}
client := &http.Client{Transport: rt}
defer func() {
require.NoError(t, rt.Close())
require.NoError(t, ln.Close())
}()
// open first server and make single request to it
fake1 := ln.Faker()
stoppedServing1 := make(chan struct{})
go func() {
server1.ServeListener(fake1)
close(stoppedServing1)
}()
resp, err := client.Get("https://localhost:" + port + "/hello1")
require.NoError(t, err)
require.Equal(t, http.StatusOK, resp.StatusCode)
body, err := io.ReadAll(resp.Body)
require.NoError(t, err)
require.Equal(t, "Hello, World 1!\n", string(body))
// open second server with same underlying listener
fake2 := ln.Faker()
stoppedServing2 := make(chan struct{})
go func() {
server2.ServeListener(fake2)
close(stoppedServing2)
}()
// Verify both servers are running by waiting a bit and checking channels aren't closed
time.Sleep(50 * time.Millisecond)
select {
case <-stoppedServing1:
t.Fatal("server1 stopped unexpectedly")
case <-stoppedServing2:
t.Fatal("server2 stopped unexpectedly")
default:
}
// now close first server
require.NoError(t, server1.Close())
select {
case <-stoppedServing1:
case <-time.After(time.Second):
t.Fatal("timed out waiting for server1 to stop")
}
require.True(t, fake1.closed.Load())
require.False(t, fake2.closed.Load())
require.False(t, ln.listenerClosed)
require.NoError(t, client.Transport.(*http3.Transport).Close())
// verify that new connections are being initiated from the second server now
resp, err = client.Get("https://localhost:" + port + "/hello2")
require.NoError(t, err)
require.Equal(t, http.StatusOK, resp.StatusCode)
body, err = io.ReadAll(resp.Body)
require.NoError(t, err)
require.Equal(t, "Hello, World 2!\n", string(body))
// close the other server - both the fake and the actual listeners must close now
require.NoError(t, server2.Close())
select {
case <-stoppedServing2:
case <-time.After(time.Second):
t.Fatal("timed out waiting for server2 to stop")
}
require.True(t, fake2.closed.Load())
require.True(t, ln.listenerClosed)
}

View File

@@ -5,88 +5,94 @@ import (
"fmt"
"io"
"net"
"testing"
"github.com/quic-go/quic-go"
"github.com/quic-go/quic-go/internal/handshake"
"github.com/quic-go/quic-go/internal/protocol"
"github.com/quic-go/quic-go/logging"
. "github.com/onsi/ginkgo/v2"
. "github.com/onsi/gomega"
"github.com/stretchr/testify/require"
)
func TestKeyUpdates(t *testing.T) {
origKeyUpdateInterval := handshake.KeyUpdateInterval
t.Cleanup(func() { handshake.KeyUpdateInterval = origKeyUpdateInterval })
handshake.KeyUpdateInterval = 1 // update keys as frequently as possible
var sentHeaders []*logging.ShortHeader
var receivedHeaders []*logging.ShortHeader
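// countKeyPhases counts the number of key updates by tracking how often the
// key phase bit flips on sent and received short header packets.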
countKeyPhases := func() (sent, received int) {
lastKeyPhase := protocol.KeyPhaseOne
for _, hdr := range sentHeaders {
if hdr.KeyPhase != lastKeyPhase {
sent++
lastKeyPhase = hdr.KeyPhase
}
}
lastKeyPhase = protocol.KeyPhaseOne
for _, hdr := range receivedHeaders {
if hdr.KeyPhase != lastKeyPhase {
received++
lastKeyPhase = hdr.KeyPhase
}
}
return
}
server, err := quic.ListenAddr("localhost:0", getTLSConfig(), nil)
require.NoError(t, err)
defer server.Close()
serverErrChan := make(chan error, 1)
go func() {
conn, err := server.Accept(context.Background())
if err != nil {
serverErrChan <- err
return
}
str, err := conn.OpenUniStream()
if err != nil {
serverErrChan <- err
return
}
defer str.Close()
if _, err := str.Write(PRDataLong); err != nil {
serverErrChan <- err
return
}
close(serverErrChan)
}()
conn, err := quic.DialAddr(
context.Background(),
fmt.Sprintf("localhost:%d", server.Addr().(*net.UDPAddr).Port),
getTLSClientConfig(),
getQuicConfig(&quic.Config{Tracer: func(context.Context, logging.Perspective, quic.ConnectionID) *logging.ConnectionTracer {
return &logging.ConnectionTracer{
SentShortHeaderPacket: func(hdr *logging.ShortHeader, _ logging.ByteCount, _ logging.ECN, _ *logging.AckFrame, _ []logging.Frame) {
sentHeaders = append(sentHeaders, hdr)
},
ReceivedShortHeaderPacket: func(hdr *logging.ShortHeader, _ logging.ByteCount, _ logging.ECN, _ []logging.Frame) {
receivedHeaders = append(receivedHeaders, hdr)
},
}
}}),
)
require.NoError(t, err)
defer conn.CloseWithError(0, "")
str, err := conn.AcceptUniStream(context.Background())
require.NoError(t, err)
data, err := io.ReadAll(str)
require.NoError(t, err)
require.Equal(t, PRDataLong, data)
require.NoError(t, conn.CloseWithError(0, ""))
require.NoError(t, <-serverErrChan)
keyPhasesSent, keyPhasesReceived := countKeyPhases()
t.Logf("Used %d key phases on outgoing and %d key phases on incoming packets.", keyPhasesSent, keyPhasesReceived)
require.Greater(t, keyPhasesReceived, 10)
require.InDelta(t, keyPhasesSent, keyPhasesReceived, 2)
}

View File

@@ -7,7 +7,9 @@ import (
"io"
"math"
"net"
"sync"
"sync/atomic"
"testing"
"time"
"golang.org/x/exp/rand"
@@ -18,455 +20,401 @@ import (
"github.com/quic-go/quic-go/internal/wire"
"github.com/quic-go/quic-go/testutils"
. "github.com/onsi/ginkgo/v2"
. "github.com/onsi/gomega"
"github.com/stretchr/testify/require"
)
var _ = Describe("MITM test", func() {
const connIDLen = 6 // explicitly set the connection ID length, so the proxy can parse it
const mitmTestConnIDLen = 6
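// getTransportsForMITMTest sets up a server and a client transport with a fixed
// connection ID length, so the proxy can parse the packets it forwards.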
func getTransportsForMITMTest(t *testing.T) (serverTransport, clientTransport *quic.Transport) {
serverUDPConn, err := net.ListenUDP("udp", &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 0})
require.NoError(t, err)
serverTransport = &quic.Transport{
Conn: serverUDPConn,
ConnectionIDLength: mitmTestConnIDLen,
}
addTracer(serverTransport)
t.Cleanup(func() { serverTransport.Close() })
clientUDPConn, err := net.ListenUDP("udp", &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 0})
require.NoError(t, err)
clientTransport = &quic.Transport{
Conn: clientUDPConn,
ConnectionIDLength: mitmTestConnIDLen,
}
addTracer(clientTransport)
t.Cleanup(func() { clientTransport.Close() })
return serverTransport, clientTransport
}
func TestMITMInjectRandomPackets(t *testing.T) {
t.Run("towards the server", func(t *testing.T) {
testMITMInjectRandomPackets(t, quicproxy.DirectionIncoming)
})
t.Run("towards the client", func(t *testing.T) {
testMITMInjectRandomPackets(t, quicproxy.DirectionOutgoing)
})
}
func TestMITMDuplicatePackets(t *testing.T) {
t.Run("towards the server", func(t *testing.T) {
testMITMDuplicatePackets(t, quicproxy.DirectionIncoming)
})
t.Run("towards the client", func(t *testing.T) {
testMITMDuplicatePackets(t, quicproxy.DirectionOutgoing)
})
}
func TestMITMInjectCorruptedPackets(t *testing.T) {
t.Run("towards the server", func(t *testing.T) {
testMITMInjectCorruptedPackets(t, quicproxy.DirectionIncoming)
})
t.Run("towards the client", func(t *testing.T) {
testMITMInjectCorruptedPackets(t, quicproxy.DirectionOutgoing)
})
}
func testMITMInjectRandomPackets(t *testing.T, direction quicproxy.Direction) {
createRandomPacketOfSameType := func(b []byte) []byte {
if wire.IsLongHeaderPacket(b[0]) {
hdr, _, _, err := wire.ParsePacket(b)
if err != nil {
return nil
}
replyHdr := &wire.ExtendedHeader{
Header: wire.Header{
DestConnectionID: hdr.DestConnectionID,
SrcConnectionID: hdr.SrcConnectionID,
Type: hdr.Type,
Version: hdr.Version,
},
PacketNumber: protocol.PacketNumber(rand.Int31n(math.MaxInt32 / 4)),
PacketNumberLen: protocol.PacketNumberLen(rand.Int31n(4) + 1),
}
payloadLen := rand.Int31n(100)
replyHdr.Length = protocol.ByteCount(rand.Int31n(payloadLen + 1))
data, err := replyHdr.Append(nil, hdr.Version)
if err != nil {
panic("failed to append header: " + err.Error())
}
r := make([]byte, payloadLen)
rand.Read(r)
return append(data, r...)
}
// short header packet
connID, err := wire.ParseConnectionID(b, mitmTestConnIDLen)
if err != nil {
return nil
}
_, pn, pnLen, _, err := wire.ParseShortHeader(b, mitmTestConnIDLen)
if err != nil && !errors.Is(err, wire.ErrInvalidReservedBits) { // normally, ParseShortHeader is called after decrypting the header
panic("failed to parse short header: " + err.Error())
}
data, err := wire.AppendShortHeader(nil, connID, pn, pnLen, protocol.KeyPhaseBit(rand.Intn(2)))
if err != nil {
return nil
}
payloadLen := rand.Int31n(100)
r := make([]byte, payloadLen)
rand.Read(r)
return append(data, r...)
}
rtt := scaleDuration(10 * time.Millisecond)
serverTransport, clientTransport := getTransportsForMITMTest(t)
dropCallback := func(dir quicproxy.Direction, b []byte) bool {
if dir != direction {
return false
}
go func() {
ticker := time.NewTicker(rtt / 10)
defer ticker.Stop()
for i := 0; i < 10; i++ {
switch direction {
case quicproxy.DirectionIncoming:
clientTransport.WriteTo(createRandomPacketOfSameType(b), serverTransport.Conn.LocalAddr())
case quicproxy.DirectionOutgoing:
serverTransport.WriteTo(createRandomPacketOfSameType(b), clientTransport.Conn.LocalAddr())
}
<-ticker.C
}
}()
return false
}
runMITMTest(t, serverTransport, clientTransport, direction, rtt, dropCallback)
}
func testMITMDuplicatePackets(t *testing.T, direction quicproxy.Direction) {
serverTransport, clientTransport := getTransportsForMITMTest(t)
rtt := scaleDuration(10 * time.Millisecond)
dropCallback := func(dir quicproxy.Direction, b []byte) bool {
if dir != direction {
return false
}
switch direction {
case quicproxy.DirectionIncoming:
clientTransport.WriteTo(b, serverTransport.Conn.LocalAddr())
case quicproxy.DirectionOutgoing:
serverTransport.WriteTo(b, clientTransport.Conn.LocalAddr())
}
return false
}
runMITMTest(t, serverTransport, clientTransport, direction, rtt, dropCallback)
}
func testMITMInjectCorruptedPackets(t *testing.T, direction quicproxy.Direction) {
serverTransport, clientTransport := getTransportsForMITMTest(t)
rtt := scaleDuration(10 * time.Millisecond)
var numCorrupted atomic.Int32
const interval = 4
dropCallback := func(dir quicproxy.Direction, b []byte) bool {
if dir != direction {
return false
}
if rand.Intn(interval) == 0 {
numCorrupted.Add(1)
pos := rand.Intn(len(b))
b[pos] = byte(rand.Intn(256))
return true
}
switch direction {
case quicproxy.DirectionIncoming:
clientTransport.WriteTo(b, serverTransport.Conn.LocalAddr())
case quicproxy.DirectionOutgoing:
serverTransport.WriteTo(b, clientTransport.Conn.LocalAddr())
}
return false
}
runMITMTest(t, serverTransport, clientTransport, direction, rtt, dropCallback)
require.NotZero(t, int(numCorrupted.Load()))
}
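// runMITMTest transfers PRData from server to client through a proxy running the
// given drop callback, and expects the transfer to succeed despite the attack.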
func runMITMTest(t *testing.T,
serverTransport, clientTransport *quic.Transport,
direction quicproxy.Direction,
rtt time.Duration,
dropCallback quicproxy.DropCallback,
) {
ln, err := serverTransport.Listen(getTLSConfig(), getQuicConfig(nil))
require.NoError(t, err)
defer ln.Close()
serverPort := ln.Addr().(*net.UDPAddr).Port
proxy, err := quicproxy.NewQuicProxy("localhost:0", &quicproxy.Opts{
RemoteAddr: fmt.Sprintf("localhost:%d", serverPort),
DelayPacket: func(dir quicproxy.Direction, b []byte) time.Duration { return rtt / 2 },
DropPacket: dropCallback,
})
require.NoError(t, err)
defer proxy.Close()
conn, err := clientTransport.Dial(
context.Background(),
&net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: proxy.LocalPort()},
getTLSClientConfig(),
getQuicConfig(nil),
)
require.NoError(t, err)
serverConn, err := ln.Accept(context.Background())
require.NoError(t, err)
serverStr, err := serverConn.OpenUniStream()
require.NoError(t, err)
errChan := make(chan error, 1)
go func() {
defer close(errChan)
if _, err := serverStr.Write(PRData); err != nil {
errChan <- err
return
}
serverStr.Close()
}()
require.NoError(t, <-errChan)
str, err := conn.AcceptUniStream(context.Background())
require.NoError(t, err)
data, err := io.ReadAll(str)
require.NoError(t, err)
require.Equal(t, PRData, data)
require.NoError(t, conn.CloseWithError(0, ""))
}
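// fails immediately, because the client closes the connection when it can't find a compatible version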
func TestMITMForgedVersionNegotiationPacket(t *testing.T) {
serverTransport, clientTransport := getTransportsForMITMTest(t)
rtt := scaleDuration(10 * time.Millisecond)
const supportedVersion protocol.Version = 42
var once sync.Once
delayCb := func(dir quicproxy.Direction, raw []byte) time.Duration {
if dir != quicproxy.DirectionIncoming {
return rtt / 2
}
once.Do(func() {
hdr, _, _, err := wire.ParsePacket(raw)
if err != nil {
panic("failed to parse packet: " + err.Error())
}
// create fake version negotiation packet with a fake supported version
packet := wire.ComposeVersionNegotiation(
protocol.ArbitraryLenConnectionID(hdr.SrcConnectionID.Bytes()),
protocol.ArbitraryLenConnectionID(hdr.DestConnectionID.Bytes()),
[]protocol.Version{supportedVersion},
)
if _, err := serverTransport.WriteTo(packet, clientTransport.Conn.LocalAddr()); err != nil {
panic("failed to write packet: " + err.Error())
}
})
return rtt / 2
}
err := runMITMTestSuccessful(t, serverTransport, clientTransport, delayCb)
var vnErr *quic.VersionNegotiationError
require.ErrorAs(t, err, &vnErr)
require.Contains(t, vnErr.Theirs, supportedVersion) // might contain greased versions
}
// times out, because client doesn't accept subsequent real retry packets from server
// as it has already accepted a retry.
// TODO: determine behavior when server does not send Retry packets
func TestMITMForgedRetryPacket(t *testing.T) {
serverTransport, clientTransport := getTransportsForMITMTest(t)
serverTransport.VerifySourceAddress = func(net.Addr) bool { return true }
rtt := scaleDuration(10 * time.Millisecond)
var once sync.Once
delayCb := func(dir quicproxy.Direction, raw []byte) time.Duration {
hdr, _, _, err := wire.ParsePacket(raw)
if err != nil {
panic("failed to parse packet: " + err.Error())
}
if dir == quicproxy.DirectionIncoming && hdr.Type == protocol.PacketTypeInitial {
once.Do(func() {
fakeSrcConnID := protocol.ParseConnectionID([]byte{0x12, 0x12, 0x12, 0x12, 0x12, 0x12, 0x12, 0x12})
retryPacket := testutils.ComposeRetryPacket(fakeSrcConnID, hdr.SrcConnectionID, hdr.DestConnectionID, []byte("token"), hdr.Version)
if _, err := serverTransport.WriteTo(retryPacket, clientTransport.Conn.LocalAddr()); err != nil {
panic("failed to write packet: " + err.Error())
}
})
}
return rtt / 2
}
err := runMITMTestSuccessful(t, serverTransport, clientTransport, delayCb)
var nerr net.Error
require.ErrorAs(t, err, &nerr)
require.True(t, nerr.Timeout())
}
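// times out, because the client doesn't accept real packets from the server,
// as it has already accepted an injected Initial.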
func TestMITMForgedInitialPacket(t *testing.T) {
serverTransport, clientTransport := getTransportsForMITMTest(t)
rtt := scaleDuration(10 * time.Millisecond)
var once sync.Once
delayCb := func(dir quicproxy.Direction, raw []byte) time.Duration {
if dir == quicproxy.DirectionIncoming {
hdr, _, _, err := wire.ParsePacket(raw)
if err != nil {
panic("failed to parse packet: " + err.Error())
}
if hdr.Type != protocol.PacketTypeInitial {
return 0
}
once.Do(func() {
initialPacket := testutils.ComposeInitialPacket(
hdr.DestConnectionID,
hdr.SrcConnectionID,
hdr.DestConnectionID,
nil,
nil,
protocol.PerspectiveServer,
hdr.Version,
)
if _, err := serverTransport.WriteTo(initialPacket, clientTransport.Conn.LocalAddr()); err != nil {
panic("failed to write packet: " + err.Error())
}
})
}
return rtt / 2
}
err := runMITMTestSuccessful(t, serverTransport, clientTransport, delayCb)
var nerr net.Error
require.ErrorAs(t, err, &nerr)
require.True(t, nerr.Timeout())
}
Context("duplicating packets", func() {
It("downloads a message when packets are duplicated towards the server", func() {
dropCb := func(dir quicproxy.Direction, raw []byte) bool {
defer GinkgoRecover()
if dir == quicproxy.DirectionIncoming {
_, err := clientTransport.WriteTo(raw, serverTransport.Conn.LocalAddr())
Expect(err).ToNot(HaveOccurred())
}
return false
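// the client connection closes immediately on receiving an ACK for an unsent packet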
func TestMITMForgedInitialPacketWithAck(t *testing.T) {
serverTransport, clientTransport := getTransportsForMITMTest(t)
rtt := scaleDuration(10 * time.Millisecond)
var once sync.Once
delayCb := func(dir quicproxy.Direction, raw []byte) time.Duration {
if dir == quicproxy.DirectionIncoming {
hdr, _, _, err := wire.ParsePacket(raw)
if err != nil {
panic("failed to parse packet: " + err.Error())
}
if hdr.Type != protocol.PacketTypeInitial {
return 0
}
once.Do(func() {
// Fake Initial with ACK for packet 2 (unsent)
ack := &wire.AckFrame{AckRanges: []wire.AckRange{{Smallest: 2, Largest: 2}}}
initialPacket := testutils.ComposeInitialPacket(
hdr.DestConnectionID,
hdr.SrcConnectionID,
hdr.DestConnectionID,
nil,
[]wire.Frame{ack},
protocol.PerspectiveServer,
hdr.Version,
)
if _, err := serverTransport.WriteTo(initialPacket, clientTransport.Conn.LocalAddr()); err != nil {
panic("failed to write packet: " + err.Error())
}
})
}
return rtt / 2
}
err := runMITMTestSuccessful(t, serverTransport, clientTransport, delayCb)
var transportErr *quic.TransportError
require.ErrorAs(t, err, &transportErr)
require.Equal(t, quic.ProtocolViolation, transportErr.ErrorCode)
require.Contains(t, transportErr.ErrorMessage, "received ACK for an unsent packet")
}
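// runMITMTestSuccessful dials through a proxy running the given delay callback and returns
// the dial error; it is used for attacks that are expected to break the handshake.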
func runMITMTestSuccessful(t *testing.T, serverTransport, clientTransport *quic.Transport, delayCb quicproxy.DelayCallback) error {
t.Helper()
ln, err := serverTransport.Listen(getTLSConfig(), getQuicConfig(nil))
require.NoError(t, err)
defer ln.Close()
if hdr.Type != protocol.PacketTypeInitial {
return 0
}
// Create fake version negotiation packet with no supported versions
versions := []protocol.Version{}
packet := wire.ComposeVersionNegotiation(
protocol.ArbitraryLenConnectionID(hdr.SrcConnectionID.Bytes()),
protocol.ArbitraryLenConnectionID(hdr.DestConnectionID.Bytes()),
versions,
)
// Send the packet
_, err = serverTransport.WriteTo(packet, clientTransport.Conn.LocalAddr())
Expect(err).ToNot(HaveOccurred())
close(done)
}
return rtt / 2
}
proxyPort, serverCloseFn := startServerAndProxy(delayCb, nil, false)
defer serverCloseFn()
closeFn, err := runTest(proxyPort)
defer closeFn()
Expect(err).To(HaveOccurred())
vnErr := &quic.VersionNegotiationError{}
Expect(errors.As(err, &vnErr)).To(BeTrue())
Eventually(done).Should(BeClosed())
})
// times out, because client doesn't accept subsequent real retry packets from server
// as it has already accepted a retry.
// TODO: determine behavior when server does not send Retry packets
It("fails when a forged Retry packet with modified Source Connection ID is sent to client", func() {
var initialPacketIntercepted bool
done := make(chan struct{})
delayCb := func(dir quicproxy.Direction, raw []byte) time.Duration {
if dir == quicproxy.DirectionIncoming && !initialPacketIntercepted {
defer GinkgoRecover()
defer close(done)
hdr, _, _, err := wire.ParsePacket(raw)
Expect(err).ToNot(HaveOccurred())
if hdr.Type != protocol.PacketTypeInitial {
return 0
}
initialPacketIntercepted = true
fakeSrcConnID := protocol.ParseConnectionID([]byte{0x12, 0x12, 0x12, 0x12, 0x12, 0x12, 0x12, 0x12})
retryPacket := testutils.ComposeRetryPacket(fakeSrcConnID, hdr.SrcConnectionID, hdr.DestConnectionID, []byte("token"), hdr.Version)
_, err = serverTransport.WriteTo(retryPacket, clientTransport.Conn.LocalAddr())
Expect(err).ToNot(HaveOccurred())
}
return rtt / 2
}
proxyPort, serverCloseFn := startServerAndProxy(delayCb, nil, true)
defer serverCloseFn()
closeFn, err := runTest(proxyPort)
defer closeFn()
Expect(err).To(HaveOccurred())
Expect(err.(net.Error).Timeout()).To(BeTrue())
Eventually(done).Should(BeClosed())
})
// times out, because the client doesn't accept real Retry packets from the server,
// as it has already accepted an Initial.
// TODO: determine behavior when server does not send Retry packets
It("fails when a forged initial packet is sent to client", func() {
done := make(chan struct{})
var injected bool
delayCb := func(dir quicproxy.Direction, raw []byte) time.Duration {
if dir == quicproxy.DirectionIncoming {
defer GinkgoRecover()
hdr, _, _, err := wire.ParsePacket(raw)
Expect(err).ToNot(HaveOccurred())
if hdr.Type != protocol.PacketTypeInitial || injected {
return 0
}
defer close(done)
injected = true
initialPacket := testutils.ComposeInitialPacket(hdr.DestConnectionID, hdr.SrcConnectionID, hdr.DestConnectionID, nil, nil, protocol.PerspectiveServer, hdr.Version)
_, err = serverTransport.WriteTo(initialPacket, clientTransport.Conn.LocalAddr())
Expect(err).ToNot(HaveOccurred())
}
return rtt
}
proxyPort, serverCloseFn := startServerAndProxy(delayCb, nil, false)
defer serverCloseFn()
closeFn, err := runTest(proxyPort)
defer closeFn()
Expect(err).To(HaveOccurred())
Expect(err.(net.Error).Timeout()).To(BeTrue())
Eventually(done).Should(BeClosed())
})
// the client connection closes immediately on receiving an ACK for an unsent packet
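// RFC 9000 (Section 13.1) allows an endpoint to treat an ACK for a packet it never sent
// as a connection error of type PROTOCOL_VIOLATION, which is what quic-go does here.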
It("fails when a forged initial packet with ack for unsent packet is sent to client", func() {
done := make(chan struct{})
var injected bool
delayCb := func(dir quicproxy.Direction, raw []byte) time.Duration {
if dir == quicproxy.DirectionIncoming {
defer GinkgoRecover()
hdr, _, _, err := wire.ParsePacket(raw)
Expect(err).ToNot(HaveOccurred())
if hdr.Type != protocol.PacketTypeInitial || injected {
return 0
}
defer close(done)
injected = true
// Fake Initial with ACK for packet 2 (unsent)
ack := &wire.AckFrame{AckRanges: []wire.AckRange{{Smallest: 2, Largest: 2}}}
initialPacket := testutils.ComposeInitialPacket(hdr.DestConnectionID, hdr.SrcConnectionID, hdr.DestConnectionID, nil, []wire.Frame{ack}, protocol.PerspectiveServer, hdr.Version)
_, err = serverTransport.WriteTo(initialPacket, clientTransport.Conn.LocalAddr())
Expect(err).ToNot(HaveOccurred())
}
return rtt
}
proxyPort, serverCloseFn := startServerAndProxy(delayCb, nil, false)
defer serverCloseFn()
closeFn, err := runTest(proxyPort)
defer closeFn()
Expect(err).To(HaveOccurred())
var transportErr *quic.TransportError
Expect(errors.As(err, &transportErr)).To(BeTrue())
Expect(transportErr.ErrorCode).To(Equal(quic.ProtocolViolation))
Expect(transportErr.ErrorMessage).To(ContainSubstring("received ACK for an unsent packet"))
Eventually(done).Should(BeClosed())
})
serverPort := ln.Addr().(*net.UDPAddr).Port
proxy, err := quicproxy.NewQuicProxy("localhost:0", &quicproxy.Opts{
RemoteAddr: fmt.Sprintf("localhost:%d", serverPort),
DelayPacket: delayCb,
})
})
require.NoError(t, err)
defer proxy.Close()
ctx, cancel := context.WithTimeout(context.Background(), scaleDuration(50*time.Millisecond))
defer cancel()
_, err = clientTransport.Dial(
ctx,
&net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: proxy.LocalPort()},
getTLSClientConfig(),
getQuicConfig(nil),
)
require.Error(t, err)
return err
}

View File

@@ -1,11 +1,13 @@
package self_test
import (
"bytes"
"context"
"fmt"
"io"
"net"
"sync"
"testing"
"time"
"github.com/quic-go/quic-go"
@@ -13,158 +15,188 @@ import (
"github.com/quic-go/quic-go/internal/protocol"
"github.com/quic-go/quic-go/logging"
. "github.com/onsi/ginkgo/v2"
. "github.com/onsi/gomega"
"github.com/stretchr/testify/require"
)
var _ = Describe("DPLPMTUD", func() {
It("discovers the MTU", func() {
rtt := scaleDuration(5 * time.Millisecond)
const mtu = 1400
func TestInitialPacketSize(t *testing.T) {
server, err := net.ListenUDP("udp4", &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 0})
require.NoError(t, err)
defer server.Close()
client, err := net.ListenUDP("udp4", &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 0})
require.NoError(t, err)
defer client.Close()
ln, err := quic.ListenAddr(
"localhost:0",
getTLSConfig(),
getQuicConfig(&quic.Config{
InitialPacketSize: 1234,
DisablePathMTUDiscovery: true,
EnableDatagrams: true,
}),
)
Expect(err).ToNot(HaveOccurred())
defer ln.Close()
go func() {
defer GinkgoRecover()
conn, err := ln.Accept(context.Background())
Expect(err).ToNot(HaveOccurred())
str, err := conn.AcceptStream(context.Background())
Expect(err).ToNot(HaveOccurred())
_, err = io.Copy(str, str)
Expect(err).ToNot(HaveOccurred())
str.Close()
}()
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
done := make(chan struct{})
go func() {
defer close(done)
quic.Dial(ctx, client, server.LocalAddr(), getTLSClientConfig(), getQuicConfig(&quic.Config{
InitialPacketSize: 1337,
}))
}()
var mx sync.Mutex
var maxPacketSizeServer int
var clientPacketSizes []int
serverPort := ln.Addr().(*net.UDPAddr).Port
proxy, err := quicproxy.NewQuicProxy("localhost:0", &quicproxy.Opts{
RemoteAddr: fmt.Sprintf("localhost:%d", serverPort),
DelayPacket: func(quicproxy.Direction, []byte) time.Duration { return rtt / 2 },
DropPacket: func(dir quicproxy.Direction, packet []byte) bool {
if len(packet) > mtu {
return true
}
mx.Lock()
defer mx.Unlock()
switch dir {
case quicproxy.DirectionIncoming:
clientPacketSizes = append(clientPacketSizes, len(packet))
case quicproxy.DirectionOutgoing:
if len(packet) > maxPacketSizeServer {
maxPacketSizeServer = len(packet)
}
}
return false
},
})
Expect(err).ToNot(HaveOccurred())
defer proxy.Close()
buf := make([]byte, 2000)
n, _, err := server.ReadFrom(buf)
require.NoError(t, err)
require.Equal(t, 1337, n)
// Make sure to use v4-only socket here.
// We can't reliably set the DF bit on dual-stack sockets on macOS before Sequoia (macOS 15).
udpConn, err := net.ListenUDP("udp4", &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 0})
Expect(err).ToNot(HaveOccurred())
defer udpConn.Close()
tr := &quic.Transport{Conn: udpConn}
defer tr.Close()
var mtus []logging.ByteCount
mtuTracer := &logging.ConnectionTracer{
UpdatedMTU: func(mtu logging.ByteCount, _ bool) {
mtus = append(mtus, mtu)
},
cancel()
<-done
}
func TestPathMTUDiscovery(t *testing.T) {
rtt := scaleDuration(5 * time.Millisecond)
const mtu = 1400
ln, err := quic.ListenAddr(
"localhost:0",
getTLSConfig(),
getQuicConfig(&quic.Config{
InitialPacketSize: 1234,
DisablePathMTUDiscovery: true,
EnableDatagrams: true,
}),
)
require.NoError(t, err)
defer ln.Close()
serverErrChan := make(chan error, 1)
go func() {
conn, err := ln.Accept(context.Background())
if err != nil {
serverErrChan <- err
return
}
conn, err := tr.Dial(
context.Background(),
proxy.LocalAddr(),
getTLSClientConfig(),
getQuicConfig(&quic.Config{
InitialPacketSize: protocol.MinInitialPacketSize,
EnableDatagrams: true,
Tracer: func(context.Context, logging.Perspective, quic.ConnectionID) *logging.ConnectionTracer {
return mtuTracer
},
}),
)
Expect(err).ToNot(HaveOccurred())
defer conn.CloseWithError(0, "")
str, err := conn.OpenStream()
Expect(err).ToNot(HaveOccurred())
done := make(chan struct{})
go func() {
defer close(done)
defer GinkgoRecover()
data, err := io.ReadAll(str)
Expect(err).ToNot(HaveOccurred())
Expect(data).To(Equal(PRDataLong))
}()
err = conn.SendDatagram(make([]byte, 2000))
Expect(err).To(BeAssignableToTypeOf(&quic.DatagramTooLargeError{}))
initialMaxDatagramSize := err.(*quic.DatagramTooLargeError).MaxDatagramPayloadSize
_, err = str.Write(PRDataLong)
Expect(err).ToNot(HaveOccurred())
str.Close()
Eventually(done, 20*time.Second).Should(BeClosed())
err = conn.SendDatagram(make([]byte, 2000))
Expect(err).To(BeAssignableToTypeOf(&quic.DatagramTooLargeError{}))
finalMaxDatagramSize := err.(*quic.DatagramTooLargeError).MaxDatagramPayloadSize
str, err := conn.AcceptStream(context.Background())
if err != nil {
serverErrChan <- err
return
}
defer str.Close()
if _, err := io.Copy(str, str); err != nil {
serverErrChan <- err
return
}
}()
mx.Lock()
defer mx.Unlock()
Expect(mtus).ToNot(BeEmpty())
maxPacketSizeClient := int(mtus[len(mtus)-1])
fmt.Fprintf(GinkgoWriter, "max client packet size: %d, MTU: %d\n", maxPacketSizeClient, mtu)
fmt.Fprintf(GinkgoWriter, "max datagram size: initial: %d, final: %d\n", initialMaxDatagramSize, finalMaxDatagramSize)
fmt.Fprintf(GinkgoWriter, "max server packet size: %d, MTU: %d\n", maxPacketSizeServer, mtu)
Expect(maxPacketSizeClient).To(BeNumerically(">=", mtu-25))
const maxDiff = 40 // this includes the 21 bytes for the short header, 16 bytes for the encryption tag, and framing overhead
Expect(initialMaxDatagramSize).To(BeNumerically(">=", protocol.MinInitialPacketSize-maxDiff))
Expect(finalMaxDatagramSize).To(BeNumerically(">=", maxPacketSizeClient-maxDiff))
// MTU discovery was disabled on the server side
Expect(maxPacketSizeServer).To(Equal(1234))
var numPacketsLargerThanDiscoveredMTU int
for _, s := range clientPacketSizes {
if s > maxPacketSizeClient {
numPacketsLargerThanDiscoveredMTU++
var mx sync.Mutex
var maxPacketSizeServer int
var clientPacketSizes []int
serverPort := ln.Addr().(*net.UDPAddr).Port
proxy, err := quicproxy.NewQuicProxy("localhost:0", &quicproxy.Opts{
RemoteAddr: fmt.Sprintf("localhost:%d", serverPort),
DelayPacket: func(quicproxy.Direction, []byte) time.Duration { return rtt / 2 },
DropPacket: func(dir quicproxy.Direction, packet []byte) bool {
if len(packet) > mtu {
return true
}
mx.Lock()
defer mx.Unlock()
switch dir {
case quicproxy.DirectionIncoming:
clientPacketSizes = append(clientPacketSizes, len(packet))
case quicproxy.DirectionOutgoing:
if len(packet) > maxPacketSizeServer {
maxPacketSizeServer = len(packet)
}
}
return false
},
})
require.NoError(t, err)
defer proxy.Close()
// Make sure to use v4-only socket here.
// We can't reliably set the DF bit on dual-stack sockets on older versions of macOS (before Sequoia).
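// Path MTU probes only work if oversized packets are dropped rather than
// fragmented by the kernel, which requires the DF (Don't Fragment) bit.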
udpConn, err := net.ListenUDP("udp4", &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 0})
require.NoError(t, err)
defer udpConn.Close()
tr := &quic.Transport{Conn: udpConn}
defer tr.Close()
var mtus []logging.ByteCount
conn, err := tr.Dial(
context.Background(),
proxy.LocalAddr(),
getTLSClientConfig(),
getQuicConfig(&quic.Config{
InitialPacketSize: protocol.MinInitialPacketSize,
EnableDatagrams: true,
Tracer: func(context.Context, logging.Perspective, quic.ConnectionID) *logging.ConnectionTracer {
return &logging.ConnectionTracer{
UpdatedMTU: func(mtu logging.ByteCount, _ bool) { mtus = append(mtus, mtu) },
}
},
}),
)
require.NoError(t, err)
defer conn.CloseWithError(0, "")
err = conn.SendDatagram(make([]byte, 2000))
require.Error(t, err)
var datagramErr *quic.DatagramTooLargeError
require.ErrorAs(t, err, &datagramErr)
initialMaxDatagramSize := datagramErr.MaxDatagramPayloadSize
str, err := conn.OpenStream()
require.NoError(t, err)
clientErrChan := make(chan error, 1)
go func() {
data, err := io.ReadAll(str)
if err != nil {
clientErrChan <- err
return
}
// The client shouldn't have sent any packets larger than the MTU it discovered,
// except for at most one MTU probe packet.
Expect(numPacketsLargerThanDiscoveredMTU).To(BeNumerically("<=", 1))
})
if !bytes.Equal(data, PRDataLong) {
clientErrChan <- fmt.Errorf("echoed data doesn't match: %x", data)
return
}
clientErrChan <- nil
}()
It("uses the initial packet size", func() {
c, err := net.ListenUDP("udp4", &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 0})
Expect(err).ToNot(HaveOccurred())
defer c.Close()
_, err = str.Write(PRDataLong)
require.NoError(t, err)
str.Close()
cconn, err := net.ListenUDP("udp4", &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 0})
Expect(err).ToNot(HaveOccurred())
defer cconn.Close()
select {
case err := <-clientErrChan:
require.NoError(t, err)
case err := <-serverErrChan:
t.Fatalf("server error: %v", err)
case <-time.After(20 * time.Second):
t.Fatal("timeout")
}
ctx, cancel := context.WithCancel(context.Background())
done := make(chan struct{})
go func() {
defer close(done)
quic.Dial(ctx, cconn, c.LocalAddr(), getTLSClientConfig(), getQuicConfig(&quic.Config{InitialPacketSize: 1337}))
}()
err = conn.SendDatagram(make([]byte, 2000))
require.Error(t, err)
require.ErrorAs(t, err, &datagramErr)
finalMaxDatagramSize := datagramErr.MaxDatagramPayloadSize
b := make([]byte, 2000)
n, _, err := c.ReadFrom(b)
Expect(err).ToNot(HaveOccurred())
Expect(n).To(Equal(1337))
cancel()
Eventually(done).Should(BeClosed())
})
})
mx.Lock()
defer mx.Unlock()
require.NotEmpty(t, mtus)
maxPacketSizeClient := int(mtus[len(mtus)-1])
t.Logf("max client packet size: %d, MTU: %d", maxPacketSizeClient, mtu)
t.Logf("max datagram size: initial: %d, final: %d", initialMaxDatagramSize, finalMaxDatagramSize)
t.Logf("max server packet size: %d, MTU: %d", maxPacketSizeServer, mtu)
require.GreaterOrEqual(t, maxPacketSizeClient, mtu-25)
const maxDiff = 40 // this includes the 21 bytes for the short header, 16 bytes for the encryption tag, and framing overhead
require.GreaterOrEqual(t, int(initialMaxDatagramSize), protocol.MinInitialPacketSize-maxDiff)
require.GreaterOrEqual(t, int(finalMaxDatagramSize), maxPacketSizeClient-maxDiff)
// MTU discovery was disabled on the server side
require.Equal(t, 1234, maxPacketSizeServer)
var numPacketsLargerThanDiscoveredMTU int
for _, s := range clientPacketSizes {
if s > maxPacketSizeClient {
numPacketsLargerThanDiscoveredMTU++
}
}
// The client shouldn't have sent any packets larger than the MTU it discovered,
// except for at most one MTU probe packet.
require.LessOrEqual(t, numPacketsLargerThanDiscoveredMTU, 1)
}

View File

@@ -1,296 +1,323 @@
package self_test
import (
"bytes"
"context"
"crypto/rand"
"errors"
"fmt"
"io"
"net"
"runtime"
"sync/atomic"
"testing"
"time"
"golang.org/x/exp/rand"
"github.com/quic-go/quic-go"
. "github.com/onsi/ginkgo/v2"
. "github.com/onsi/gomega"
"github.com/stretchr/testify/require"
)
var _ = Describe("Multiplexing", func() {
runServer := func(ln *quic.Listener) {
func runMultiplexTestServer(t *testing.T, ln *quic.Listener) {
t.Helper()
for {
conn, err := ln.Accept(context.Background())
if err != nil {
return
}
str, err := conn.OpenUniStream()
require.NoError(t, err)
go func() {
defer GinkgoRecover()
for {
conn, err := ln.Accept(context.Background())
if err != nil {
return
}
go func() {
defer GinkgoRecover()
str, err := conn.OpenStream()
Expect(err).ToNot(HaveOccurred())
defer str.Close()
_, err = str.Write(PRData)
Expect(err).ToNot(HaveOccurred())
}()
}
defer str.Close()
_, err = str.Write(PRData)
require.NoError(t, err)
}()
t.Cleanup(func() { conn.CloseWithError(0, "") })
}
}
func dialAndReceiveData(tr *quic.Transport, addr net.Addr) error {
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
defer cancel()
conn, err := tr.Dial(ctx, addr, getTLSClientConfig(), getQuicConfig(nil))
if err != nil {
return fmt.Errorf("error dialing: %w", err)
}
str, err := conn.AcceptUniStream(ctx)
if err != nil {
return fmt.Errorf("error accepting stream: %w", err)
}
data, err := io.ReadAll(str)
if err != nil {
return fmt.Errorf("error reading data: %w", err)
}
if !bytes.Equal(data, PRData) {
return fmt.Errorf("data mismatch: got %q, expected %q", data, PRData)
}
return nil
}
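// A single quic.Transport (i.e. a single UDP socket) can carry many QUIC connections
// at the same time; incoming datagrams are demultiplexed by connection ID.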
func TestMultiplexesConnectionsToSameServer(t *testing.T) {
server, err := quic.ListenAddr("localhost:0", getTLSConfig(), getQuicConfig(nil))
require.NoError(t, err)
defer server.Close()
go runMultiplexTestServer(t, server)
conn, err := net.ListenUDP("udp", &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 0})
require.NoError(t, err)
defer conn.Close()
tr := &quic.Transport{Conn: conn}
addTracer(tr)
defer tr.Close()
errChan1 := make(chan error, 1)
go func() { errChan1 <- dialAndReceiveData(tr, server.Addr()) }()
errChan2 := make(chan error, 1)
go func() { errChan2 <- dialAndReceiveData(tr, server.Addr()) }()
select {
case err := <-errChan1:
require.NoError(t, err, "first connection failed")
case <-time.After(5 * time.Second):
t.Error("timeout waiting for first connection")
}
select {
case err := <-errChan2:
require.NoError(t, err, "second connection failed")
case <-time.After(5 * time.Second):
t.Error("timeout waiting for second connection")
}
}
func TestMultiplexingToDifferentServers(t *testing.T) {
server1, err := quic.ListenAddr("localhost:0", getTLSConfig(), getQuicConfig(nil))
require.NoError(t, err)
defer server1.Close()
go runMultiplexTestServer(t, server1)
server2, err := quic.ListenAddr("localhost:0", getTLSConfig(), getQuicConfig(nil))
require.NoError(t, err)
defer server2.Close()
go runMultiplexTestServer(t, server2)
conn, err := net.ListenUDP("udp", &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 0})
require.NoError(t, err)
defer conn.Close()
tr := &quic.Transport{Conn: conn}
addTracer(tr)
defer tr.Close()
errChan1 := make(chan error, 1)
go func() { errChan1 <- dialAndReceiveData(tr, server1.Addr()) }()
errChan2 := make(chan error, 1)
go func() { errChan2 <- dialAndReceiveData(tr, server2.Addr()) }()
select {
case err := <-errChan1:
require.NoError(t, err, "error dialing server 1")
case <-time.After(5 * time.Second):
t.Error("timeout waiting for connection to server 1")
}
select {
case err := <-errChan2:
require.NoError(t, err, "error dialing server 2")
case <-time.After(5 * time.Second):
t.Error("timeout waiting for connection to server 2")
}
}
func TestMultiplexingConnectToSelf(t *testing.T) {
conn, err := net.ListenUDP("udp", &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 0})
require.NoError(t, err)
defer conn.Close()
tr := &quic.Transport{Conn: conn}
addTracer(tr)
defer tr.Close()
server, err := tr.Listen(getTLSConfig(), getQuicConfig(nil))
require.NoError(t, err)
defer server.Close()
go runMultiplexTestServer(t, server)
errChan := make(chan error, 1)
go func() { errChan <- dialAndReceiveData(tr, server.Addr()) }()
select {
case err := <-errChan:
require.NoError(t, err, "error dialing server")
case <-time.After(5 * time.Second):
t.Error("timeout waiting for connection to close")
}
}
func TestMultiplexingServerAndClientOnSameConn(t *testing.T) {
if runtime.GOOS == "linux" {
t.Skip("This test requires setting of iptables rules on Linux, see https://stackoverflow.com/questions/23859164/linux-udp-socket-sendto-operation-not-permitted.")
}
dial := func(tr *quic.Transport, addr net.Addr) {
conn, err := tr.Dial(
context.Background(),
addr,
getTLSClientConfig(),
getQuicConfig(nil),
)
Expect(err).ToNot(HaveOccurred())
defer conn.CloseWithError(0, "")
str, err := conn.AcceptStream(context.Background())
Expect(err).ToNot(HaveOccurred())
data, err := io.ReadAll(str)
Expect(err).ToNot(HaveOccurred())
Expect(data).To(Equal(PRData))
conn1, err := net.ListenUDP("udp", &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 0})
require.NoError(t, err)
defer conn1.Close()
tr1 := &quic.Transport{Conn: conn1}
addTracer(tr1)
defer tr1.Close()
server1, err := tr1.Listen(getTLSConfig(), getQuicConfig(nil))
require.NoError(t, err)
defer server1.Close()
conn2, err := net.ListenUDP("udp", &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 0})
require.NoError(t, err)
defer conn2.Close()
tr2 := &quic.Transport{Conn: conn2}
addTracer(tr2)
defer tr2.Close()
server2, err := tr2.Listen(getTLSConfig(), getQuicConfig(nil))
require.NoError(t, err)
defer server2.Close()
done1 := make(chan struct{})
go func() {
defer close(done1)
dialAndReceiveData(tr2, server1.Addr())
}()
done2 := make(chan struct{})
go func() {
defer close(done2)
dialAndReceiveData(tr1, server2.Addr())
}()
select {
case <-done1:
case <-time.After(5 * time.Second):
t.Error("timeout waiting for done1 to close")
}
select {
case <-done2:
case <-time.After(5 * time.Second):
t.Error("timeout waiting for done2 to close")
}
}
Context("multiplexing clients on the same conn", func() {
getListener := func() *quic.Listener {
ln, err := quic.ListenAddr(
"localhost:0",
getTLSConfig(),
getQuicConfig(nil),
)
Expect(err).ToNot(HaveOccurred())
return ln
func TestMultiplexingNonQUICPackets(t *testing.T) {
conn1, err := net.ListenUDP("udp", &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 0})
require.NoError(t, err)
defer conn1.Close()
tr1 := &quic.Transport{Conn: conn1}
defer tr1.Close()
addTracer(tr1)
conn2, err := net.ListenUDP("udp", &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 0})
require.NoError(t, err)
defer conn2.Close()
tr2 := &quic.Transport{Conn: conn2}
defer tr2.Close()
addTracer(tr2)
server, err := tr1.Listen(getTLSConfig(), getQuicConfig(nil))
require.NoError(t, err)
defer server.Close()
type nonQUICPacket struct {
b []byte
addr net.Addr
err error
}
done := make(chan struct{})
var rcvdPackets []nonQUICPacket
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
// start receiving non-QUIC packets
go func() {
defer close(done)
for {
b := make([]byte, 1024)
n, addr, err := tr2.ReadNonQUICPacket(ctx, b)
if errors.Is(err, context.Canceled) {
return
}
rcvdPackets = append(rcvdPackets, nonQUICPacket{b: b[:n], addr: addr, err: err})
}
}()
It("multiplexes connections to the same server", func() {
server := getListener()
runServer(server)
defer server.Close()
// send a non-QUIC packet every 100µs
const packetLen = 128
var sentPackets atomic.Int64
errChan := make(chan error, 1)
go func() {
ticker := time.NewTicker(time.Millisecond / 10)
defer ticker.Stop()
addr, err := net.ResolveUDPAddr("udp", "localhost:0")
Expect(err).ToNot(HaveOccurred())
conn, err := net.ListenUDP("udp", addr)
Expect(err).ToNot(HaveOccurred())
defer conn.Close()
tr := &quic.Transport{Conn: conn}
addTracer(tr)
done1 := make(chan struct{})
done2 := make(chan struct{})
go func() {
defer GinkgoRecover()
dial(tr, server.Addr())
close(done1)
}()
go func() {
defer GinkgoRecover()
dial(tr, server.Addr())
close(done2)
}()
timeout := 30 * time.Second
if debugLog() {
timeout = time.Minute
}
Eventually(done1, timeout).Should(BeClosed())
Eventually(done2, timeout).Should(BeClosed())
})
It("multiplexes connections to different servers", func() {
server1 := getListener()
runServer(server1)
defer server1.Close()
server2 := getListener()
runServer(server2)
defer server2.Close()
addr, err := net.ResolveUDPAddr("udp", "localhost:0")
Expect(err).ToNot(HaveOccurred())
conn, err := net.ListenUDP("udp", addr)
Expect(err).ToNot(HaveOccurred())
defer conn.Close()
tr := &quic.Transport{Conn: conn}
addTracer(tr)
done1 := make(chan struct{})
done2 := make(chan struct{})
go func() {
defer GinkgoRecover()
dial(tr, server1.Addr())
close(done1)
}()
go func() {
defer GinkgoRecover()
dial(tr, server2.Addr())
close(done2)
}()
timeout := 30 * time.Second
if debugLog() {
timeout = time.Minute
}
Eventually(done1, timeout).Should(BeClosed())
Eventually(done2, timeout).Should(BeClosed())
})
})
Context("multiplexing server and client on the same conn", func() {
It("connects to itself", func() {
addr, err := net.ResolveUDPAddr("udp", "localhost:0")
Expect(err).ToNot(HaveOccurred())
conn, err := net.ListenUDP("udp", addr)
Expect(err).ToNot(HaveOccurred())
defer conn.Close()
tr := &quic.Transport{Conn: conn}
addTracer(tr)
server, err := tr.Listen(
getTLSConfig(),
getQuicConfig(nil),
)
Expect(err).ToNot(HaveOccurred())
runServer(server)
done := make(chan struct{})
go func() {
defer GinkgoRecover()
dial(tr, server.Addr())
close(done)
}()
timeout := 30 * time.Second
if debugLog() {
timeout = time.Minute
}
Eventually(done, timeout).Should(BeClosed())
})
// This test would require setting of iptables rules, see https://stackoverflow.com/questions/23859164/linux-udp-socket-sendto-operation-not-permitted.
if runtime.GOOS != "linux" {
It("runs a server and client on the same conn", func() {
addr1, err := net.ResolveUDPAddr("udp", "localhost:0")
Expect(err).ToNot(HaveOccurred())
conn1, err := net.ListenUDP("udp", addr1)
Expect(err).ToNot(HaveOccurred())
defer conn1.Close()
tr1 := &quic.Transport{Conn: conn1}
addTracer(tr1)
addr2, err := net.ResolveUDPAddr("udp", "localhost:0")
Expect(err).ToNot(HaveOccurred())
conn2, err := net.ListenUDP("udp", addr2)
Expect(err).ToNot(HaveOccurred())
defer conn2.Close()
tr2 := &quic.Transport{Conn: conn2}
addTracer(tr2)
server1, err := tr1.Listen(
getTLSConfig(),
getQuicConfig(nil),
)
Expect(err).ToNot(HaveOccurred())
runServer(server1)
defer server1.Close()
server2, err := tr2.Listen(
getTLSConfig(),
getQuicConfig(nil),
)
Expect(err).ToNot(HaveOccurred())
runServer(server2)
defer server2.Close()
done1 := make(chan struct{})
done2 := make(chan struct{})
go func() {
defer GinkgoRecover()
dial(tr2, server1.Addr())
close(done1)
}()
go func() {
defer GinkgoRecover()
dial(tr1, server2.Addr())
close(done2)
}()
timeout := 30 * time.Second
if debugLog() {
timeout = time.Minute
}
Eventually(done1, timeout).Should(BeClosed())
Eventually(done2, timeout).Should(BeClosed())
})
}
})
It("sends and receives non-QUIC packets", func() {
addr1, err := net.ResolveUDPAddr("udp", "localhost:0")
Expect(err).ToNot(HaveOccurred())
conn1, err := net.ListenUDP("udp", addr1)
Expect(err).ToNot(HaveOccurred())
defer conn1.Close()
tr1 := &quic.Transport{Conn: conn1}
addTracer(tr1)
addr2, err := net.ResolveUDPAddr("udp", "localhost:0")
Expect(err).ToNot(HaveOccurred())
conn2, err := net.ListenUDP("udp", addr2)
Expect(err).ToNot(HaveOccurred())
defer conn2.Close()
tr2 := &quic.Transport{Conn: conn2}
addTracer(tr2)
server, err := tr1.Listen(getTLSConfig(), getQuicConfig(nil))
Expect(err).ToNot(HaveOccurred())
runServer(server)
defer server.Close()
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
var sentPackets, rcvdPackets atomic.Int64
const packetLen = 128
// send a non-QUIC packet every 100µs
go func() {
defer GinkgoRecover()
ticker := time.NewTicker(time.Millisecond / 10)
defer ticker.Stop()
var wroteFirstPacket bool
for {
select {
case <-ticker.C:
case <-ctx.Done():
return
}
var wroteFirstPacket bool
for {
select {
case <-ticker.C:
b := make([]byte, packetLen)
rand.Read(b[1:]) // keep the first byte set to 0, so it's not classified as a QUIC packet
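// Every QUIC packet sets the fixed bit (0x40), and long header packets additionally set
// the high bit (0x80), see RFC 9000, Section 17. A rough sketch of the classification
// (not the actual demultiplexer):
//
//	isQUIC := b[0]&0x80 != 0 || b[0]&0x40 != 0
//
// A zero first byte matches neither, so the datagram is routed to ReadNonQUICPacket.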
_, err := tr1.WriteTo(b, tr2.Conn.LocalAddr())
// The first sendmsg call on a new UDP socket sometimes errors on Linux.
// It's not clear why this happens.
// See https://github.com/golang/go/issues/63322.
if err != nil && !wroteFirstPacket && isPermissionError(err) {
if err != nil && !wroteFirstPacket && runtime.GOOS == "linux" && isPermissionError(err) {
_, err = tr1.WriteTo(b, tr2.Conn.LocalAddr())
}
if ctx.Err() != nil { // ctx canceled while Read was executing
if err != nil {
errChan <- err
return
}
Expect(err).ToNot(HaveOccurred())
sentPackets.Add(1)
wroteFirstPacket = true
case <-ctx.Done():
return
}
}()
}
}()
// receive and count non-QUIC packets
go func() {
defer GinkgoRecover()
for {
b := make([]byte, 1024)
n, addr, err := tr2.ReadNonQUICPacket(ctx, b)
if err != nil {
Expect(err).To(MatchError(context.Canceled))
return
}
Expect(addr).To(Equal(tr1.Conn.LocalAddr()))
Expect(n).To(Equal(packetLen))
rcvdPackets.Add(1)
}
}()
dial(tr2, server.Addr())
Eventually(func() int64 { return sentPackets.Load() }).Should(BeNumerically(">", 10))
Eventually(func() int64 { return rcvdPackets.Load() }).Should(BeNumerically(">=", sentPackets.Load()*4/5))
})
})
conn, err := tr2.Dial(
context.Background(),
server.Addr(),
getTLSClientConfig(),
getQuicConfig(nil),
)
require.NoError(t, err)
defer conn.CloseWithError(0, "")
serverConn, err := server.Accept(context.Background())
require.NoError(t, err)
serverStr, err := serverConn.OpenUniStream()
require.NoError(t, err)
go func() {
defer serverStr.Close()
_, _ = serverStr.Write(PRData)
}()
str, err := conn.AcceptUniStream(context.Background())
require.NoError(t, err)
data, err := io.ReadAll(str)
require.NoError(t, err)
require.Equal(t, PRData, data)
// stop sending non-QUIC packets
cancel()
select {
case err := <-errChan:
t.Fatalf("error sending non-QUIC packets: %v", err)
case <-done:
}
sent := int(sentPackets.Load())
require.Greater(t, sent, 10, "not enough non-QUIC packets sent: %d", sent)
rcvd := len(rcvdPackets)
minExpected := sent * 4 / 5
require.GreaterOrEqual(t, rcvd, minExpected, "not enough packets received. got: %d, expected at least: %d", rcvd, minExpected)
for _, p := range rcvdPackets {
require.Equal(t, tr1.Conn.LocalAddr(), p.addr, "non-QUIC packet received from wrong address")
require.Equal(t, packetLen, len(p.b), "non-QUIC packet incorrect length")
require.NoError(t, p.err, "error receiving non-QUIC packet")
}
}

View File

@@ -4,117 +4,120 @@ import (
"context"
"fmt"
"net"
"testing"
"time"
"github.com/quic-go/quic-go"
quicproxy "github.com/quic-go/quic-go/integrationtests/tools/proxy"
"github.com/quic-go/quic-go/logging"
. "github.com/onsi/ginkgo/v2"
. "github.com/onsi/gomega"
"github.com/stretchr/testify/require"
)
var _ = Describe("Packetization", func() {
// In this test, the client sends 100 small messages. The server echoes these messages.
// This means that every endpoint will send 100 ack-eliciting packets in short succession.
// The test then checks that no more than 110 packets are sent in each direction, making sure that ACKs are bundled.
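// Since the echo traffic gives each endpoint a steady supply of outgoing stream data,
// a (slightly delayed) ACK can almost always be bundled into the same packet as a
// STREAM frame, which is what the >90% thresholds below assert.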
It("bundles ACKs", func() {
const numMsg = 100
func TestACKBundling(t *testing.T) {
const numMsg = 100
serverCounter, serverTracer := newPacketTracer()
server, err := quic.ListenAddr(
"localhost:0",
getTLSConfig(),
getQuicConfig(&quic.Config{
DisablePathMTUDiscovery: true,
Tracer: newTracer(serverTracer),
}),
)
Expect(err).ToNot(HaveOccurred())
defer server.Close()
serverAddr := fmt.Sprintf("localhost:%d", server.Addr().(*net.UDPAddr).Port)
serverCounter, serverTracer := newPacketTracer()
server, err := quic.ListenAddr(
"localhost:0",
getTLSConfig(),
getQuicConfig(&quic.Config{
DisablePathMTUDiscovery: true,
Tracer: newTracer(serverTracer),
}),
)
require.NoError(t, err)
defer server.Close()
serverAddr := fmt.Sprintf("localhost:%d", server.Addr().(*net.UDPAddr).Port)
proxy, err := quicproxy.NewQuicProxy("localhost:0", &quicproxy.Opts{
RemoteAddr: serverAddr,
DelayPacket: func(dir quicproxy.Direction, _ []byte) time.Duration {
return 5 * time.Millisecond
},
})
Expect(err).ToNot(HaveOccurred())
defer proxy.Close()
proxy, err := quicproxy.NewQuicProxy("localhost:0", &quicproxy.Opts{
RemoteAddr: serverAddr,
DelayPacket: func(dir quicproxy.Direction, _ []byte) time.Duration {
return 5 * time.Millisecond
},
})
require.NoError(t, err)
defer proxy.Close()
clientCounter, clientTracer := newPacketTracer()
conn, err := quic.DialAddr(
context.Background(),
fmt.Sprintf("localhost:%d", proxy.LocalPort()),
getTLSClientConfig(),
getQuicConfig(&quic.Config{
DisablePathMTUDiscovery: true,
Tracer: newTracer(clientTracer),
}),
)
Expect(err).ToNot(HaveOccurred())
defer conn.CloseWithError(0, "")
clientCounter, clientTracer := newPacketTracer()
conn, err := quic.DialAddr(
context.Background(),
fmt.Sprintf("localhost:%d", proxy.LocalPort()),
getTLSClientConfig(),
getQuicConfig(&quic.Config{
DisablePathMTUDiscovery: true,
Tracer: newTracer(clientTracer),
}),
)
require.NoError(t, err)
defer conn.CloseWithError(0, "")
go func() {
defer GinkgoRecover()
conn, err := server.Accept(context.Background())
Expect(err).ToNot(HaveOccurred())
str, err := conn.AcceptStream(context.Background())
Expect(err).ToNot(HaveOccurred())
b := make([]byte, 1)
// Echo every byte received from the client.
for {
if _, err := str.Read(b); err != nil {
break
}
_, err = str.Write(b)
Expect(err).ToNot(HaveOccurred())
}
}()
str, err := conn.OpenStreamSync(context.Background())
Expect(err).ToNot(HaveOccurred())
b := make([]byte, 1)
// Send numMsg 1-byte messages.
for i := 0; i < numMsg; i++ {
_, err = str.Write([]byte{uint8(i)})
Expect(err).ToNot(HaveOccurred())
_, err = str.Read(b)
Expect(err).ToNot(HaveOccurred())
Expect(b[0]).To(Equal(uint8(i)))
}
Expect(conn.CloseWithError(0, "")).To(Succeed())
countBundledPackets := func(packets []shortHeaderPacket) (numBundled int) {
for _, p := range packets {
var hasAck, hasStreamFrame bool
for _, f := range p.frames {
switch f.(type) {
case *logging.AckFrame:
hasAck = true
case *logging.StreamFrame:
hasStreamFrame = true
}
}
if hasAck && hasStreamFrame {
numBundled++
}
}
serverErrChan := make(chan error, 1)
go func() {
defer close(serverErrChan)
conn, err := server.Accept(context.Background())
if err != nil {
serverErrChan <- fmt.Errorf("accept failed: %w", err)
return
}
str, err := conn.AcceptStream(context.Background())
if err != nil {
serverErrChan <- fmt.Errorf("accept stream failed: %w", err)
return
}
b := make([]byte, 1)
// Echo every byte received from the client.
for {
if _, err := str.Read(b); err != nil {
break
}
_, err = str.Write(b)
if err != nil {
serverErrChan <- fmt.Errorf("write failed: %w", err)
return
}
}
}()
numBundledIncoming := countBundledPackets(clientCounter.getRcvdShortHeaderPackets())
numBundledOutgoing := countBundledPackets(serverCounter.getRcvdShortHeaderPackets())
fmt.Fprintf(GinkgoWriter, "bundled incoming packets: %d / %d\n", numBundledIncoming, numMsg)
fmt.Fprintf(GinkgoWriter, "bundled outgoing packets: %d / %d\n", numBundledOutgoing, numMsg)
Expect(numBundledIncoming).To(And(
BeNumerically("<=", numMsg),
BeNumerically(">", numMsg*9/10),
))
Expect(numBundledOutgoing).To(And(
BeNumerically("<=", numMsg),
BeNumerically(">", numMsg*9/10),
))
})
})
str, err := conn.OpenStreamSync(context.Background())
require.NoError(t, err)
b := make([]byte, 1)
// Send numMsg 1-byte messages.
for i := 0; i < numMsg; i++ {
_, err = str.Write([]byte{uint8(i)})
require.NoError(t, err)
_, err = str.Read(b)
require.NoError(t, err)
require.Equal(t, uint8(i), b[0])
}
require.NoError(t, conn.CloseWithError(0, ""))
require.NoError(t, <-serverErrChan)
countBundledPackets := func(packets []shortHeaderPacket) (numBundled int) {
for _, p := range packets {
var hasAck, hasStreamFrame bool
for _, f := range p.frames {
switch f.(type) {
case *logging.AckFrame:
hasAck = true
case *logging.StreamFrame:
hasStreamFrame = true
}
}
if hasAck && hasStreamFrame {
numBundled++
}
}
return
}
numBundledIncoming := countBundledPackets(clientCounter.getRcvdShortHeaderPackets())
numBundledOutgoing := countBundledPackets(serverCounter.getRcvdShortHeaderPackets())
t.Logf("bundled incoming packets: %d / %d", numBundledIncoming, numMsg)
t.Logf("bundled outgoing packets: %d / %d", numBundledOutgoing, numMsg)
require.LessOrEqual(t, numBundledIncoming, numMsg)
require.Greater(t, numBundledIncoming, numMsg*9/10)
require.LessOrEqual(t, numBundledOutgoing, numMsg)
require.Greater(t, numBundledOutgoing, numMsg*9/10)
}

View File

@@ -5,87 +5,77 @@ import (
"os"
"path"
"regexp"
"testing"
. "github.com/onsi/ginkgo/v2"
. "github.com/onsi/gomega"
"github.com/quic-go/quic-go"
"github.com/quic-go/quic-go/qlog"
"github.com/stretchr/testify/require"
)
var _ = Describe("qlog dir tests", Serial, func() {
var originalQlogDirValue string
var tempTestDirPath string
BeforeEach(func() {
originalQlogDirValue = os.Getenv("QLOGDIR")
var err error
tempTestDirPath, err = os.MkdirTemp("", "temp_test_dir")
Expect(err).ToNot(HaveOccurred())
func TestQlogDirEnvironmentVariable(t *testing.T) {
originalQlogDirValue := os.Getenv("QLOGDIR")
tempTestDirPath, err := os.MkdirTemp("", "temp_test_dir")
require.NoError(t, err)
t.Cleanup(func() {
require.NoError(t, os.Setenv("QLOGDIR", originalQlogDirValue))
require.NoError(t, os.RemoveAll(tempTestDirPath))
})
qlogDir := path.Join(tempTestDirPath, "qlogs")
require.NoError(t, os.Setenv("QLOGDIR", qlogDir))
AfterEach(func() {
err := os.Setenv("QLOGDIR", originalQlogDirValue)
Expect(err).ToNot(HaveOccurred())
err = os.RemoveAll(tempTestDirPath)
Expect(err).ToNot(HaveOccurred())
})
serverStopped := make(chan struct{})
server, err := quic.ListenAddr(
"localhost:0",
getTLSConfig(),
&quic.Config{
Tracer: qlog.DefaultConnectionTracer,
},
)
require.NoError(t, err)
handshake := func() {
serverStopped := make(chan struct{})
server, err := quic.ListenAddr(
"localhost:0",
getTLSConfig(),
&quic.Config{
Tracer: qlog.DefaultConnectionTracer,
},
)
Expect(err).ToNot(HaveOccurred())
go func() {
defer GinkgoRecover()
defer close(serverStopped)
for {
if _, err := server.Accept(context.Background()); err != nil {
return
}
go func() {
defer close(serverStopped)
for {
if _, err := server.Accept(context.Background()); err != nil {
return
}
}()
}
}()
conn, err := quic.DialAddr(
context.Background(),
server.Addr().String(),
getTLSClientConfig(),
&quic.Config{
Tracer: qlog.DefaultConnectionTracer,
},
)
Expect(err).ToNot(HaveOccurred())
conn.CloseWithError(0, "")
server.Close()
<-serverStopped
conn, err := quic.DialAddr(
context.Background(),
server.Addr().String(),
getTLSClientConfig(),
&quic.Config{
Tracer: qlog.DefaultConnectionTracer,
},
)
require.NoError(t, err)
conn.CloseWithError(0, "")
server.Close()
<-serverStopped
_, err = os.Stat(tempTestDirPath)
qlogDirCreated := !os.IsNotExist(err)
require.True(t, qlogDirCreated)
childs, err := os.ReadDir(qlogDir)
require.NoError(t, err)
require.Len(t, childs, 2)
odcids := make([]string, 0, 2)
vantagePoints := make([]string, 0, 2)
qlogFileNameRegexp := regexp.MustCompile(`^([0-9a-f]+)_(client|server)\.sqlog$`)
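// qlog files are named <ODCID>_<vantage point>.sqlog, where the ODCID is the original
// destination connection ID chosen by the client. It is the same for both endpoints of
// a connection, which is what the equality check below relies on.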
for _, child := range childs {
matches := qlogFileNameRegexp.FindStringSubmatch(child.Name())
require.Len(t, matches, 3)
odcids = append(odcids, matches[1])
vantagePoints = append(vantagePoints, matches[2])
}
It("environment variable is set", func() {
qlogDir := path.Join(tempTestDirPath, "qlogs")
err := os.Setenv("QLOGDIR", qlogDir)
Expect(err).ToNot(HaveOccurred())
handshake()
_, err = os.Stat(tempTestDirPath)
qlogDirCreated := !os.IsNotExist(err)
Expect(qlogDirCreated).To(BeTrue())
childs, err := os.ReadDir(qlogDir)
Expect(err).ToNot(HaveOccurred())
Expect(len(childs)).To(Equal(2))
odcids := make([]string, 0)
vantagePoints := make([]string, 0)
qlogFileNameRegexp := regexp.MustCompile(`^([0-f]+)_(client|server).sqlog$`)
for _, child := range childs {
matches := qlogFileNameRegexp.FindStringSubmatch(child.Name())
Expect(matches).To(HaveLen(3))
odcids = append(odcids, matches[1])
vantagePoints = append(vantagePoints, matches[2])
}
Expect(odcids[0]).To(Equal(odcids[1]))
Expect(vantagePoints).To(ContainElements("client", "server"))
})
})
require.Equal(t, odcids[0], odcids[1])
require.Contains(t, vantagePoints, "client")
require.Contains(t, vantagePoints, "server")
}

View File

@@ -5,19 +5,18 @@ import (
"crypto/tls"
"fmt"
"net"
"testing"
"time"
"github.com/quic-go/quic-go"
. "github.com/onsi/ginkgo/v2"
. "github.com/onsi/gomega"
"github.com/stretchr/testify/require"
)
type clientSessionCache struct {
cache tls.ClientSessionCache
gets chan<- string
puts chan<- string
}
func newClientSessionCache(cache tls.ClientSessionCache, gets, puts chan<- string) *clientSessionCache {
@@ -45,94 +44,18 @@ func (c *clientSessionCache) Put(sessionKey string, cs *tls.ClientSessionState)
}
}
var _ = Describe("TLS session resumption", func() {
It("uses session resumption", func() {
server, err := quic.ListenAddr("localhost:0", getTLSConfig(), getQuicConfig(nil))
Expect(err).ToNot(HaveOccurred())
defer server.Close()
gets := make(chan string, 100)
puts := make(chan string, 100)
cache := newClientSessionCache(tls.NewLRUClientSessionCache(10), gets, puts)
tlsConf := getTLSClientConfig()
tlsConf.ClientSessionCache = cache
conn1, err := quic.DialAddr(
context.Background(),
fmt.Sprintf("localhost:%d", server.Addr().(*net.UDPAddr).Port),
tlsConf,
getQuicConfig(nil),
)
Expect(err).ToNot(HaveOccurred())
defer conn1.CloseWithError(0, "")
var sessionKey string
Eventually(puts).Should(Receive(&sessionKey))
Expect(conn1.ConnectionState().TLS.DidResume).To(BeFalse())
serverConn, err := server.Accept(context.Background())
Expect(err).ToNot(HaveOccurred())
Expect(serverConn.ConnectionState().TLS.DidResume).To(BeFalse())
conn2, err := quic.DialAddr(
context.Background(),
fmt.Sprintf("localhost:%d", server.Addr().(*net.UDPAddr).Port),
tlsConf,
getQuicConfig(nil),
)
Expect(err).ToNot(HaveOccurred())
Expect(gets).To(Receive(Equal(sessionKey)))
Expect(conn2.ConnectionState().TLS.DidResume).To(BeTrue())
serverConn, err = server.Accept(context.Background())
Expect(err).ToNot(HaveOccurred())
Expect(serverConn.ConnectionState().TLS.DidResume).To(BeTrue())
conn2.CloseWithError(0, "")
func TestTLSSessionResumption(t *testing.T) {
t.Run("uses session resumption", func(t *testing.T) {
handshakeWithSessionResumption(t, getTLSConfig(), true)
})
It("doesn't use session resumption, if the config disables it", func() {
t.Run("disabled in tls.Config", func(t *testing.T) {
sConf := getTLSConfig()
sConf.SessionTicketsDisabled = true
server, err := quic.ListenAddr("localhost:0", sConf, getQuicConfig(nil))
Expect(err).ToNot(HaveOccurred())
defer server.Close()
gets := make(chan string, 100)
puts := make(chan string, 100)
cache := newClientSessionCache(tls.NewLRUClientSessionCache(10), gets, puts)
tlsConf := getTLSClientConfig()
tlsConf.ClientSessionCache = cache
conn1, err := quic.DialAddr(
context.Background(),
fmt.Sprintf("localhost:%d", server.Addr().(*net.UDPAddr).Port),
tlsConf,
getQuicConfig(nil),
)
Expect(err).ToNot(HaveOccurred())
defer conn1.CloseWithError(0, "")
Consistently(puts).ShouldNot(Receive())
Expect(conn1.ConnectionState().TLS.DidResume).To(BeFalse())
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
defer cancel()
serverConn, err := server.Accept(ctx)
Expect(err).ToNot(HaveOccurred())
Expect(serverConn.ConnectionState().TLS.DidResume).To(BeFalse())
conn2, err := quic.DialAddr(
context.Background(),
fmt.Sprintf("localhost:%d", server.Addr().(*net.UDPAddr).Port),
tlsConf,
getQuicConfig(nil),
)
Expect(err).ToNot(HaveOccurred())
Expect(conn2.ConnectionState().TLS.DidResume).To(BeFalse())
defer conn2.CloseWithError(0, "")
serverConn, err = server.Accept(context.Background())
Expect(err).ToNot(HaveOccurred())
Expect(serverConn.ConnectionState().TLS.DidResume).To(BeFalse())
handshakeWithSessionResumption(t, sConf, false)
})
It("doesn't use session resumption, if the config returned by GetConfigForClient disables it", func() {
t.Run("disabled in tls.Config.GetConfigForClient", func(t *testing.T) {
sConf := &tls.Config{
GetConfigForClient: func(*tls.ClientHelloInfo) (*tls.Config, error) {
conf := getTLSConfig()
@@ -140,45 +63,78 @@ var _ = Describe("TLS session resumption", func() {
return conf, nil
},
}
server, err := quic.ListenAddr("localhost:0", sConf, getQuicConfig(nil))
Expect(err).ToNot(HaveOccurred())
defer server.Close()
gets := make(chan string, 100)
puts := make(chan string, 100)
cache := newClientSessionCache(tls.NewLRUClientSessionCache(10), gets, puts)
tlsConf := getTLSClientConfig()
tlsConf.ClientSessionCache = cache
conn1, err := quic.DialAddr(
context.Background(),
fmt.Sprintf("localhost:%d", server.Addr().(*net.UDPAddr).Port),
tlsConf,
getQuicConfig(nil),
)
Expect(err).ToNot(HaveOccurred())
Consistently(puts).ShouldNot(Receive())
Expect(conn1.ConnectionState().TLS.DidResume).To(BeFalse())
defer conn1.CloseWithError(0, "")
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
defer cancel()
serverConn, err := server.Accept(ctx)
Expect(err).ToNot(HaveOccurred())
Expect(serverConn.ConnectionState().TLS.DidResume).To(BeFalse())
conn2, err := quic.DialAddr(
context.Background(),
fmt.Sprintf("localhost:%d", server.Addr().(*net.UDPAddr).Port),
tlsConf,
getQuicConfig(nil),
)
Expect(err).ToNot(HaveOccurred())
Expect(conn2.ConnectionState().TLS.DidResume).To(BeFalse())
defer conn2.CloseWithError(0, "")
serverConn, err = server.Accept(context.Background())
Expect(err).ToNot(HaveOccurred())
Expect(serverConn.ConnectionState().TLS.DidResume).To(BeFalse())
handshakeWithSessionResumption(t, sConf, false)
})
})
}
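// handshakeWithSessionResumption dials the same server twice, sharing a ClientSessionCache
// between both connections. It checks that a session ticket is stored (put) on the first
// handshake and retrieved (get) on the second, and that DidResume is reported consistently
// on both client and server, depending on whether resumption is expected to work.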
func handshakeWithSessionResumption(t *testing.T, serverTLSConf *tls.Config, expectSessionTicket bool) {
server, err := quic.ListenAddr("localhost:0", serverTLSConf, getQuicConfig(nil))
require.NoError(t, err)
defer server.Close()
gets := make(chan string, 100)
puts := make(chan string, 100)
cache := newClientSessionCache(tls.NewLRUClientSessionCache(10), gets, puts)
tlsConf := getTLSClientConfig()
tlsConf.ClientSessionCache = cache
// first connection - doesn't use resumption
conn1, err := quic.DialAddr(
context.Background(),
fmt.Sprintf("localhost:%d", server.Addr().(*net.UDPAddr).Port),
tlsConf,
getQuicConfig(nil),
)
require.NoError(t, err)
defer conn1.CloseWithError(0, "")
require.False(t, conn1.ConnectionState().TLS.DidResume)
var sessionKey string
select {
case sessionKey = <-puts:
if !expectSessionTicket {
t.Fatal("unexpected session ticket")
}
case <-time.After(scaleDuration(50 * time.Millisecond)):
if expectSessionTicket {
t.Fatal("timeout waiting for session ticket")
}
}
serverConn, err := server.Accept(context.Background())
require.NoError(t, err)
require.False(t, serverConn.ConnectionState().TLS.DidResume)
// second connection - will use resumption, if enabled
conn2, err := quic.DialAddr(
context.Background(),
fmt.Sprintf("localhost:%d", server.Addr().(*net.UDPAddr).Port),
tlsConf,
getQuicConfig(nil),
)
require.NoError(t, err)
defer conn2.CloseWithError(0, "")
select {
case k := <-gets:
if expectSessionTicket {
// we can only perform this check if we got a session ticket before
require.Equal(t, sessionKey, k)
}
case <-time.After(scaleDuration(50 * time.Millisecond)):
if expectSessionTicket {
t.Fatal("timeout waiting for retrieval of session ticket")
}
}
serverConn, err = server.Accept(context.Background())
require.NoError(t, err)
if expectSessionTicket {
require.True(t, conn2.ConnectionState().TLS.DidResume)
require.True(t, serverConn.ConnectionState().TLS.DidResume)
} else {
require.False(t, conn2.ConnectionState().TLS.DidResume)
require.False(t, serverConn.ConnectionState().TLS.DidResume)
}
}

View File

@@ -5,109 +5,134 @@ import (
"fmt"
"io"
"net"
"testing"
"time"
"github.com/quic-go/quic-go"
quicproxy "github.com/quic-go/quic-go/integrationtests/tools/proxy"
. "github.com/onsi/ginkgo/v2"
. "github.com/onsi/gomega"
"github.com/stretchr/testify/require"
)
var _ = Describe("non-zero RTT", func() {
runServer := func() *quic.Listener {
ln, err := quic.ListenAddr(
"localhost:0",
getTLSConfig(),
getQuicConfig(nil),
)
Expect(err).ToNot(HaveOccurred())
go func() {
defer GinkgoRecover()
func runServerForRTTTest(t *testing.T) (net.Addr, <-chan error) {
ln, err := quic.ListenAddr(
"localhost:0",
getTLSConfig(),
getQuicConfig(nil),
)
require.NoError(t, err)
t.Cleanup(func() { ln.Close() })
errChan := make(chan error, 1)
go func() {
defer close(errChan)
for {
conn, err := ln.Accept(context.Background())
Expect(err).ToNot(HaveOccurred())
if err != nil {
errChan <- fmt.Errorf("accept error: %w", err)
return
}
str, err := conn.OpenStream()
Expect(err).ToNot(HaveOccurred())
if err != nil {
errChan <- fmt.Errorf("open stream error: %w", err)
return
}
_, err = str.Write(PRData)
Expect(err).ToNot(HaveOccurred())
if err != nil {
errChan <- fmt.Errorf("write error: %w", err)
return
}
str.Close()
}()
return ln
}
}
}()
downloadFile := func(port int) {
conn, err := quic.DialAddr(
context.Background(),
fmt.Sprintf("localhost:%d", port),
getTLSClientConfig(),
getQuicConfig(nil),
)
Expect(err).ToNot(HaveOccurred())
str, err := conn.AcceptStream(context.Background())
Expect(err).ToNot(HaveOccurred())
data, err := io.ReadAll(str)
Expect(err).ToNot(HaveOccurred())
Expect(data).To(Equal(PRData))
conn.CloseWithError(0, "")
}
return ln.Addr(), errChan
}
for _, r := range [...]time.Duration{
func TestDownloadWithFixedRTT(t *testing.T) {
addr, errChan := runServerForRTTTest(t)
for _, rtt := range []time.Duration{
10 * time.Millisecond,
50 * time.Millisecond,
100 * time.Millisecond,
200 * time.Millisecond,
250 * time.Millisecond,
} {
rtt := r
It(fmt.Sprintf("downloads a message with %s RTT", rtt), func() {
ln := runServer()
defer ln.Close()
serverPort := ln.Addr().(*net.UDPAddr).Port
proxy, err := quicproxy.NewQuicProxy("localhost:0", &quicproxy.Opts{
RemoteAddr: fmt.Sprintf("localhost:%d", serverPort),
DelayPacket: func(quicproxy.Direction, []byte) time.Duration {
return rtt / 2
},
t.Run(fmt.Sprintf("RTT %s", rtt), func(t *testing.T) {
t.Cleanup(func() {
select {
case err := <-errChan:
t.Errorf("server error: %v", err)
default:
}
})
Expect(err).ToNot(HaveOccurred())
defer proxy.Close()
proxy, err := quicproxy.NewQuicProxy("localhost:0", &quicproxy.Opts{
RemoteAddr: fmt.Sprintf("localhost:%d", addr.(*net.UDPAddr).Port),
DelayPacket: func(quicproxy.Direction, []byte) time.Duration { return rtt / 2 },
})
require.NoError(t, err)
t.Cleanup(func() { proxy.Close() })
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
defer cancel()
conn, err := quic.DialAddr(
context.Background(),
ctx,
fmt.Sprintf("localhost:%d", proxy.LocalPort()),
getTLSClientConfig(),
getQuicConfig(nil),
)
Expect(err).ToNot(HaveOccurred())
str, err := conn.AcceptStream(context.Background())
Expect(err).ToNot(HaveOccurred())
require.NoError(t, err)
defer conn.CloseWithError(0, "")
str, err := conn.AcceptStream(ctx)
require.NoError(t, err)
data, err := io.ReadAll(str)
Expect(err).ToNot(HaveOccurred())
Expect(data).To(Equal(PRData))
conn.CloseWithError(0, "")
require.NoError(t, err)
require.Equal(t, PRData, data)
})
}
}
for _, r := range [...]time.Duration{
10 * time.Millisecond,
40 * time.Millisecond,
func TestDownloadWithReordering(t *testing.T) {
addr, errChan := runServerForRTTTest(t)
for _, rtt := range []time.Duration{
5 * time.Millisecond,
30 * time.Millisecond,
} {
rtt := r
t.Run(fmt.Sprintf("RTT %s", rtt), func(t *testing.T) {
t.Cleanup(func() {
select {
case err := <-errChan:
t.Errorf("server error: %v", err)
default:
}
})
It(fmt.Sprintf("downloads a message with %s RTT, with reordering", rtt), func() {
ln := runServer()
defer ln.Close()
serverPort := ln.Addr().(*net.UDPAddr).Port
proxy, err := quicproxy.NewQuicProxy("localhost:0", &quicproxy.Opts{
RemoteAddr: fmt.Sprintf("localhost:%d", serverPort),
RemoteAddr: fmt.Sprintf("localhost:%d", addr.(*net.UDPAddr).Port),
DelayPacket: func(quicproxy.Direction, []byte) time.Duration {
return randomDuration(rtt/2, rtt*3/2) / 2
},
})
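// Drawing an independent random delay for every packet means that later packets can
// overtake earlier ones on the way through the proxy, simulating reordering on the path.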
Expect(err).ToNot(HaveOccurred())
defer proxy.Close()
require.NoError(t, err)
t.Cleanup(func() { proxy.Close() })
downloadFile(proxy.LocalPort())
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
defer cancel()
conn, err := quic.DialAddr(
ctx,
fmt.Sprintf("localhost:%d", proxy.LocalPort()),
getTLSClientConfig(),
getQuicConfig(nil),
)
require.NoError(t, err)
defer conn.CloseWithError(0, "")
str, err := conn.AcceptStream(ctx)
require.NoError(t, err)
data, err := io.ReadAll(str)
require.NoError(t, err)
require.Equal(t, PRData, data)
})
}
})
}

View File

@@ -8,24 +8,19 @@ import (
"flag"
"fmt"
"io"
"log"
"math/rand"
"os"
"runtime/pprof"
"strconv"
"strings"
"sync"
"testing"
"time"
"github.com/quic-go/quic-go"
"github.com/quic-go/quic-go/integrationtests/tools"
"github.com/quic-go/quic-go/internal/protocol"
"github.com/quic-go/quic-go/internal/utils"
"github.com/quic-go/quic-go/internal/wire"
"github.com/quic-go/quic-go/logging"
. "github.com/onsi/ginkgo/v2"
. "github.com/onsi/gomega"
)
const alpn = tools.ALPN
@@ -53,56 +48,17 @@ func GeneratePRData(l int) []byte {
return res
}
const logBufSize = 100 * 1 << 20 // initial size of the log buffer: 100 MB
type syncedBuffer struct {
mutex sync.Mutex
*bytes.Buffer
}
func (b *syncedBuffer) Write(p []byte) (int, error) {
b.mutex.Lock()
n, err := b.Buffer.Write(p)
b.mutex.Unlock()
return n, err
}
func (b *syncedBuffer) Bytes() []byte {
b.mutex.Lock()
p := b.Buffer.Bytes()
b.mutex.Unlock()
return p
}
func (b *syncedBuffer) Reset() {
b.mutex.Lock()
b.Buffer.Reset()
b.mutex.Unlock()
}
var (
logFileName string // the log file set in the ginkgo flags
logBufOnce sync.Once
logBuf *syncedBuffer
versionParam string
version quic.Version
enableQlog bool
version quic.Version
tlsConfig *tls.Config
tlsConfigLongChain *tls.Config
tlsClientConfig *tls.Config
tlsClientConfigWithoutServerName *tls.Config
)
// reads the logfile command line flag;
// to set it, call: ginkgo -- -logfile=log.txt
func init() {
flag.StringVar(&logFileName, "logfile", "", "log file")
flag.StringVar(&versionParam, "version", "1", "QUIC version")
flag.BoolVar(&enableQlog, "qlog", false, "enable qlog")
ca, caPrivateKey, err := tools.GenerateCA()
if err != nil {
panic(err)
@@ -137,31 +93,9 @@ func init() {
}
}
var _ = BeforeSuite(func() {
switch versionParam {
case "1":
version = quic.Version1
case "2":
version = quic.Version2
default:
Fail(fmt.Sprintf("unknown QUIC version: %s", versionParam))
}
fmt.Printf("Using QUIC version: %s\n", version)
protocol.SupportedVersions = []quic.Version{version}
})
func getTLSConfig() *tls.Config {
return tlsConfig.Clone()
}
func getTLSConfigWithLongCertChain() *tls.Config {
return tlsConfigLongChain.Clone()
}
func getTLSClientConfig() *tls.Config {
return tlsClientConfig.Clone()
}
func getTLSConfig() *tls.Config { return tlsConfig.Clone() }
func getTLSConfigWithLongCertChain() *tls.Config { return tlsConfigLongChain.Clone() }
func getTLSClientConfig() *tls.Config { return tlsClientConfig.Clone() }
func getTLSClientConfigWithoutServerName() *tls.Config {
return tlsClientConfigWithoutServerName.Clone()
}
@@ -178,7 +112,7 @@ func getQuicConfig(conf *quic.Config) *quic.Config {
if conf.Tracer == nil {
conf.Tracer = func(ctx context.Context, p logging.Perspective, connID quic.ConnectionID) *logging.ConnectionTracer {
return logging.NewMultiplexedConnectionTracer(
tools.NewQlogConnectionTracer(GinkgoWriter)(ctx, p, connID),
tools.NewQlogConnectionTracer(os.Stdout)(ctx, p, connID),
// multiplex it with an empty tracer to check that we're correctly ignoring unset callbacks everywhere
&logging.ConnectionTracer{},
)
@@ -188,7 +122,7 @@ func getQuicConfig(conf *quic.Config) *quic.Config {
origTracer := conf.Tracer
conf.Tracer = func(ctx context.Context, p logging.Perspective, connID quic.ConnectionID) *logging.ConnectionTracer {
tr := origTracer(ctx, p, connID)
qlogger := tools.NewQlogConnectionTracer(GinkgoWriter)(ctx, p, connID)
qlogger := tools.NewQlogConnectionTracer(os.Stdout)(ctx, p, connID)
if tr == nil {
return qlogger
}
@@ -203,7 +137,7 @@ func addTracer(tr *quic.Transport) {
}
if tr.Tracer == nil {
tr.Tracer = logging.NewMultiplexedTracer(
tools.QlogTracer(GinkgoWriter),
tools.QlogTracer(os.Stdout),
// multiplex it with an empty tracer to check that we're correctly ignoring unset callbacks everywhere
&logging.Tracer{},
)
@@ -211,23 +145,11 @@ func addTracer(tr *quic.Transport) {
}
origTracer := tr.Tracer
tr.Tracer = logging.NewMultiplexedTracer(
tools.QlogTracer(GinkgoWriter),
tools.QlogTracer(os.Stdout),
origTracer,
)
}
var _ = BeforeEach(func() {
log.SetFlags(log.Ldate | log.Ltime | log.Lmicroseconds)
if debugLog() {
logBufOnce.Do(func() {
logBuf = &syncedBuffer{Buffer: bytes.NewBuffer(make([]byte, 0, logBufSize))}
})
utils.DefaultLogger.SetLogLevel(utils.LogLevelDebug)
log.SetOutput(logBuf)
}
})
func areHandshakesRunning() bool {
var b bytes.Buffer
pprof.Lookup("goroutine").WriteTo(&b, 1)
@@ -240,22 +162,36 @@ func areTransportsRunning() bool {
return strings.Contains(b.String(), "quic-go.(*Transport).listen")
}
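// These helpers dump all goroutine stacks via pprof and search them for quic-go
// functions: any match after the test run means a handshake or transport
// goroutine was leaked.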
var _ = AfterEach(func() {
Expect(areHandshakesRunning()).To(BeFalse())
Eventually(areTransportsRunning).Should(BeFalse())
func TestMain(m *testing.M) {
var versionParam string
flag.StringVar(&versionParam, "version", "1", "QUIC version")
flag.BoolVar(&enableQlog, "qlog", false, "enable qlog")
flag.Parse()
if debugLog() {
logFile, err := os.Create(logFileName)
Expect(err).ToNot(HaveOccurred())
logFile.Write(logBuf.Bytes())
logFile.Close()
logBuf.Reset()
switch versionParam {
case "1":
version = quic.Version1
case "2":
version = quic.Version2
default:
fmt.Printf("unknown QUIC version: %s\n", versionParam)
os.Exit(1)
}
})
fmt.Printf("using QUIC version: %s\n", version)
// debugLog reports whether this test run is being logged to a file
func debugLog() bool {
return len(logFileName) > 0
status := m.Run()
if status != 0 {
os.Exit(status)
}
if areHandshakesRunning() {
fmt.Println("stray handshake goroutines found")
os.Exit(1)
}
if areTransportsRunning() {
fmt.Println("stray transport goroutines found")
os.Exit(1)
}
os.Exit(status)
}
func scaleDuration(d time.Duration) time.Duration {
@@ -301,6 +237,17 @@ func (t *packetCounter) getRcvdLongHeaderPackets() []packet {
return t.rcvdLongHdr
}
func (t *packetCounter) getRcvd0RTTPacketNumbers() []protocol.PacketNumber {
packets := t.getRcvdLongHeaderPackets()
var zeroRTTPackets []protocol.PacketNumber
for _, p := range packets {
if p.hdr.Type == protocol.PacketType0RTT {
zeroRTTPackets = append(zeroRTTPackets, p.hdr.PacketNumber)
}
}
return zeroRTTPackets
}
func (t *packetCounter) getRcvdShortHeaderPackets() []shortHeaderPacket {
<-t.closed
return t.rcvdShortHdr
@@ -345,7 +292,25 @@ func (r *readerWithTimeout) Read(p []byte) (n int, err error) {
}
}
func TestSelf(t *testing.T) {
RegisterFailHandler(Fail)
RunSpecs(t, "Self integration tests")
func randomDuration(min, max time.Duration) time.Duration {
return min + time.Duration(rand.Int63n(int64(max-min)))
}
// contains0RTTPacket reports whether a packet contains a 0-RTT long header packet.
// It correctly handles coalesced packets.
func contains0RTTPacket(data []byte) bool {
for len(data) > 0 {
if !wire.IsLongHeaderPacket(data[0]) {
return false
}
hdr, _, rest, err := wire.ParsePacket(data)
if err != nil {
return false
}
if hdr.Type == protocol.PacketType0RTT {
return true
}
data = rest
}
return false
}
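Helpers like contains0RTTPacket are meant to be plugged into the proxy's DropPacket callback. A hedged sketch of that wiring, modeled on the proxy setup used by the tests below (serverAddr is assumed to be defined; DirectionIncoming is taken to be the counterpart of the DirectionOutgoing constant used later in this diff):

proxy, err := quicproxy.NewQuicProxy("localhost:0", &quicproxy.Opts{
	RemoteAddr: serverAddr,
	DropPacket: func(dir quicproxy.Direction, data []byte) bool {
		// drop every datagram (coalesced or not) that carries a 0-RTT packet
		return dir == quicproxy.DirectionIncoming && contains0RTTPacket(data)
	},
})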

View File

@@ -6,121 +6,119 @@ import (
"fmt"
"net"
"sync/atomic"
"testing"
"time"
"github.com/quic-go/quic-go"
quicproxy "github.com/quic-go/quic-go/integrationtests/tools/proxy"
. "github.com/onsi/ginkgo/v2"
. "github.com/onsi/gomega"
"github.com/stretchr/testify/require"
)
var _ = Describe("Stateless Resets", func() {
connIDLens := []int{0, 10}
func TestStatelessResets(t *testing.T) {
t.Run("0 byte connection IDs", func(t *testing.T) {
testStatelessReset(t, 0)
})
t.Run("10 byte connection IDs", func(t *testing.T) {
testStatelessReset(t, 10)
})
}
for i := range connIDLens {
connIDLen := connIDLens[i]
func testStatelessReset(t *testing.T, connIDLen int) {
var statelessResetKey quic.StatelessResetKey
rand.Read(statelessResetKey[:])
It(fmt.Sprintf("sends and recognizes stateless resets, for %d byte connection IDs", connIDLen), func() {
var statelessResetKey quic.StatelessResetKey
rand.Read(statelessResetKey[:])
c, err := net.ListenUDP("udp", nil)
Expect(err).ToNot(HaveOccurred())
tr := &quic.Transport{
Conn: c,
StatelessResetKey: &statelessResetKey,
ConnectionIDLength: connIDLen,
}
defer tr.Close()
ln, err := tr.Listen(getTLSConfig(), getQuicConfig(nil))
Expect(err).ToNot(HaveOccurred())
serverPort := ln.Addr().(*net.UDPAddr).Port
closeServer := make(chan struct{})
go func() {
defer GinkgoRecover()
conn, err := ln.Accept(context.Background())
Expect(err).ToNot(HaveOccurred())
str, err := conn.OpenStream()
Expect(err).ToNot(HaveOccurred())
_, err = str.Write([]byte("foobar"))
Expect(err).ToNot(HaveOccurred())
<-closeServer
Expect(ln.Close()).To(Succeed())
Expect(tr.Close()).To(Succeed())
}()
var drop atomic.Bool
proxy, err := quicproxy.NewQuicProxy("localhost:0", &quicproxy.Opts{
RemoteAddr: fmt.Sprintf("localhost:%d", serverPort),
DropPacket: func(quicproxy.Direction, []byte) bool {
return drop.Load()
},
})
Expect(err).ToNot(HaveOccurred())
defer proxy.Close()
addr, err := net.ResolveUDPAddr("udp", "localhost:0")
Expect(err).ToNot(HaveOccurred())
udpConn, err := net.ListenUDP("udp", addr)
Expect(err).ToNot(HaveOccurred())
defer udpConn.Close()
cl := &quic.Transport{
Conn: udpConn,
ConnectionIDLength: connIDLen,
}
defer cl.Close()
conn, err := cl.Dial(
context.Background(),
&net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: proxy.LocalPort()},
getTLSClientConfig(),
getQuicConfig(&quic.Config{MaxIdleTimeout: 2 * time.Second}),
)
Expect(err).ToNot(HaveOccurred())
str, err := conn.AcceptStream(context.Background())
Expect(err).ToNot(HaveOccurred())
data := make([]byte, 6)
_, err = str.Read(data)
Expect(err).ToNot(HaveOccurred())
Expect(data).To(Equal([]byte("foobar")))
// make sure that the CONNECTION_CLOSE is dropped
drop.Store(true)
close(closeServer)
time.Sleep(100 * time.Millisecond)
// We need to create a new Transport here, since the old one is still sending out
// CONNECTION_CLOSE packets for (recently) closed connections.
tr2 := &quic.Transport{
Conn: c,
ConnectionIDLength: connIDLen,
StatelessResetKey: &statelessResetKey,
}
defer tr2.Close()
ln2, err := tr2.Listen(getTLSConfig(), getQuicConfig(nil))
Expect(err).ToNot(HaveOccurred())
drop.Store(false)
acceptStopped := make(chan struct{})
go func() {
defer GinkgoRecover()
_, err := ln2.Accept(context.Background())
Expect(err).To(HaveOccurred())
close(acceptStopped)
}()
// Trigger something (not too small) to be sent, so that we receive the stateless reset.
// If the client already sent another packet, it might already have received a packet.
_, serr := str.Write([]byte("Lorem ipsum dolor sit amet."))
if serr == nil {
_, serr = str.Read([]byte{0})
}
Expect(serr).To(HaveOccurred())
Expect(serr).To(BeAssignableToTypeOf(&quic.StatelessResetError{}))
Expect(ln2.Close()).To(Succeed())
Eventually(acceptStopped).Should(BeClosed())
})
c, err := net.ListenUDP("udp", nil)
require.NoError(t, err)
tr := &quic.Transport{
Conn: c,
StatelessResetKey: &statelessResetKey,
ConnectionIDLength: connIDLen,
}
})
defer tr.Close()
ln, err := tr.Listen(getTLSConfig(), getQuicConfig(nil))
require.NoError(t, err)
serverPort := ln.Addr().(*net.UDPAddr).Port
serverErr := make(chan error, 1)
go func() {
conn, err := ln.Accept(context.Background())
if err != nil {
serverErr <- err
return
}
str, err := conn.OpenStream()
if err != nil {
serverErr <- err
return
}
_, err = str.Write([]byte("foobar"))
if err != nil {
serverErr <- err
return
}
close(serverErr)
}()
var drop atomic.Bool
proxy, err := quicproxy.NewQuicProxy("localhost:0", &quicproxy.Opts{
RemoteAddr: fmt.Sprintf("localhost:%d", serverPort),
DropPacket: func(quicproxy.Direction, []byte) bool {
return drop.Load()
},
})
require.NoError(t, err)
defer proxy.Close()
addr, err := net.ResolveUDPAddr("udp", "localhost:0")
require.NoError(t, err)
udpConn, err := net.ListenUDP("udp", addr)
require.NoError(t, err)
defer udpConn.Close()
cl := &quic.Transport{
Conn: udpConn,
ConnectionIDLength: connIDLen,
}
defer cl.Close()
conn, err := cl.Dial(
context.Background(),
&net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: proxy.LocalPort()},
getTLSClientConfig(),
getQuicConfig(&quic.Config{MaxIdleTimeout: 2 * time.Second}),
)
require.NoError(t, err)
str, err := conn.AcceptStream(context.Background())
require.NoError(t, err)
data := make([]byte, 6)
_, err = str.Read(data)
require.NoError(t, err)
require.Equal(t, []byte("foobar"), data)
// make sure that the CONNECTION_CLOSE is dropped
drop.Store(true)
require.NoError(t, ln.Close())
require.NoError(t, tr.Close())
require.NoError(t, <-serverErr)
time.Sleep(100 * time.Millisecond)
// We need to create a new Transport here, since the old one is still sending out
// CONNECTION_CLOSE packets for (recently) closed connections.
tr2 := &quic.Transport{
Conn: c,
ConnectionIDLength: connIDLen,
StatelessResetKey: &statelessResetKey,
}
defer tr2.Close()
ln2, err := tr2.Listen(getTLSConfig(), getQuicConfig(nil))
require.NoError(t, err)
drop.Store(false)
// Trigger something (not too small) to be sent, so that we receive the stateless reset.
// If the client already sent another packet, it might already have received a packet.
_, serr := str.Write([]byte("Lorem ipsum dolor sit amet."))
if serr == nil {
_, serr = str.Read([]byte{0})
}
require.Error(t, serr)
require.IsType(t, &quic.StatelessResetError{}, serr)
require.NoError(t, ln2.Close())
}
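On the application side the reset surfaces as a *quic.StatelessResetError, exactly as asserted above. A minimal hedged sketch of handling it outside a test (handleStreamError is a hypothetical helper; assumes the errors package):

// handleStreamError decides whether a failed stream operation is worth
// retrying on the same connection.
func handleStreamError(err error) {
	var resetErr *quic.StatelessResetError
	if errors.As(err, &resetErr) {
		// The peer has lost all state for this connection:
		// close it and dial a fresh one.
		return
	}
	// other errors: idle timeout, application close, etc.
}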

View File

@@ -1,156 +1,321 @@
package self_test
import (
"bytes"
"context"
"fmt"
"io"
"net"
"sync"
"testing"
"time"
. "github.com/onsi/ginkgo/v2"
. "github.com/onsi/gomega"
"github.com/quic-go/quic-go"
"golang.org/x/sync/errgroup"
"github.com/stretchr/testify/require"
)
var _ = Describe("Bidirectional streams", func() {
const numStreams = 300
func TestBidirectionalStreamMultiplexing(t *testing.T) {
const numStreams = 75
var (
server *quic.Listener
serverAddr string
)
BeforeEach(func() {
var err error
server, err = quic.ListenAddr("localhost:0", getTLSConfig(), getQuicConfig(nil))
Expect(err).ToNot(HaveOccurred())
serverAddr = fmt.Sprintf("localhost:%d", server.Addr().(*net.UDPAddr).Port)
})
AfterEach(func() {
server.Close()
})
runSendingPeer := func(conn quic.Connection) {
var wg sync.WaitGroup
wg.Add(numStreams)
runSendingPeer := func(conn quic.Connection) error {
g := new(errgroup.Group)
for i := 0; i < numStreams; i++ {
str, err := conn.OpenStreamSync(context.Background())
Expect(err).ToNot(HaveOccurred())
data := GeneratePRData(25 * i)
go func() {
defer GinkgoRecover()
_, err := str.Write(data)
Expect(err).ToNot(HaveOccurred())
Expect(str.Close()).To(Succeed())
}()
go func() {
defer GinkgoRecover()
defer wg.Done()
if err != nil {
return err
}
data := GeneratePRData(50 * i)
g.Go(func() error {
if _, err := str.Write(data); err != nil {
return err
}
return str.Close()
})
g.Go(func() error {
dataRead, err := io.ReadAll(str)
Expect(err).ToNot(HaveOccurred())
Expect(dataRead).To(Equal(data))
}()
if err != nil {
return err
}
if !bytes.Equal(dataRead, data) {
return fmt.Errorf("data mismatch: %q != %q", dataRead, data)
}
return nil
})
}
wg.Wait()
return g.Wait()
}
runReceivingPeer := func(conn quic.Connection) {
var wg sync.WaitGroup
wg.Add(numStreams)
runReceivingPeer := func(conn quic.Connection) error {
g := new(errgroup.Group)
for i := 0; i < numStreams; i++ {
str, err := conn.AcceptStream(context.Background())
Expect(err).ToNot(HaveOccurred())
go func() {
defer GinkgoRecover()
defer wg.Done()
if err != nil {
return err
}
g.Go(func() error {
// Don't use io.Copy here: read from the stream as early as possible,
// to free up flow control credit.
data, err := io.ReadAll(str)
Expect(err).ToNot(HaveOccurred())
_, err = str.Write(data)
Expect(err).ToNot(HaveOccurred())
Expect(str.Close()).To(Succeed())
}()
if err != nil {
return err
}
if _, err := str.Write(data); err != nil {
return err
}
return str.Close()
})
}
wg.Wait()
return g.Wait()
}
It(fmt.Sprintf("client opening %d streams to a server", numStreams), func() {
var conn quic.Connection
go func() {
defer GinkgoRecover()
var err error
conn, err = server.Accept(context.Background())
Expect(err).ToNot(HaveOccurred())
runReceivingPeer(conn)
}()
t.Run("client -> server", func(t *testing.T) {
ln, err := quic.ListenAddr(
"localhost:0",
getTLSConfig(),
getQuicConfig(&quic.Config{
MaxIncomingStreams: 10,
InitialStreamReceiveWindow: 10000,
InitialConnectionReceiveWindow: 5000,
}),
)
require.NoError(t, err)
defer ln.Close()
client, err := quic.DialAddr(
context.Background(),
serverAddr,
fmt.Sprintf("localhost:%d", ln.Addr().(*net.UDPAddr).Port),
getTLSClientConfig(),
getQuicConfig(nil),
getQuicConfig(&quic.Config{InitialConnectionReceiveWindow: 2000}),
)
Expect(err).ToNot(HaveOccurred())
runSendingPeer(client)
require.NoError(t, err)
conn, err := ln.Accept(context.Background())
require.NoError(t, err)
errChan := make(chan error, 1)
go func() { errChan <- runReceivingPeer(conn) }()
require.NoError(t, runSendingPeer(client))
client.CloseWithError(0, "")
<-conn.Context().Done()
select {
case err := <-errChan:
require.NoError(t, err)
case <-time.After(time.Second):
require.Fail(t, "timeout")
}
select {
case <-conn.Context().Done():
case <-time.After(time.Second):
require.Fail(t, "timeout")
}
})
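// Note: the deliberately small, asymmetric windows above (10000-byte stream and
// 5000-byte connection window on the server, 2000-byte connection window on the
// client) presumably force the echoed streams to repeatedly exhaust and replenish
// flow control credit rather than complete within a single window.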
It(fmt.Sprintf("server opening %d streams to a client", numStreams), func() {
t.Run("client <-> server", func(t *testing.T) {
ln, err := quic.ListenAddr(
"localhost:0",
getTLSConfig(),
getQuicConfig(&quic.Config{
MaxIncomingStreams: 30,
InitialStreamReceiveWindow: 25000,
InitialConnectionReceiveWindow: 50000,
}),
)
require.NoError(t, err)
defer ln.Close()
client, err := quic.DialAddr(
context.Background(),
fmt.Sprintf("localhost:%d", ln.Addr().(*net.UDPAddr).Port),
getTLSClientConfig(),
getQuicConfig(&quic.Config{InitialConnectionReceiveWindow: 2000}),
)
require.NoError(t, err)
conn, err := ln.Accept(context.Background())
require.NoError(t, err)
errChan1 := make(chan error, 1)
errChan2 := make(chan error, 1)
errChan3 := make(chan error, 1)
errChan4 := make(chan error, 1)
go func() { errChan1 <- runReceivingPeer(conn) }()
go func() { errChan2 <- runSendingPeer(conn) }()
go func() { errChan3 <- runReceivingPeer(client) }()
go func() { errChan4 <- runSendingPeer(client) }()
for _, ch := range []chan error{errChan1, errChan2, errChan3, errChan4} {
select {
case err := <-ch:
require.NoError(t, err)
case <-time.After(time.Second):
require.Fail(t, "timeout")
}
}
client.CloseWithError(0, "")
select {
case <-conn.Context().Done():
case <-time.After(time.Second):
require.Fail(t, "timeout")
}
})
}
func TestUnidirectionalStreams(t *testing.T) {
const numStreams = 500
dataForStream := func(id uint64) []byte { return GeneratePRData(10 * int(id)) }
runSendingPeer := func(conn quic.Connection) error {
g := new(errgroup.Group)
for i := 0; i < numStreams; i++ {
str, err := conn.OpenUniStreamSync(context.Background())
if err != nil {
return err
}
g.Go(func() error {
if _, err := str.Write(dataForStream(uint64(str.StreamID()))); err != nil {
return err
}
return str.Close()
})
}
return g.Wait()
}
runReceivingPeer := func(conn quic.Connection) error {
g := new(errgroup.Group)
for i := 0; i < numStreams; i++ {
str, err := conn.AcceptUniStream(context.Background())
if err != nil {
return err
}
g.Go(func() error {
data, err := io.ReadAll(str)
if err != nil {
return err
}
if !bytes.Equal(data, dataForStream(uint64(str.StreamID()))) {
return fmt.Errorf("data mismatch")
}
return nil
})
}
return g.Wait()
}
t.Run("client -> server", func(t *testing.T) {
ln, err := quic.ListenAddr(
"localhost:0",
getTLSConfig(),
getQuicConfig(nil),
)
require.NoError(t, err)
defer ln.Close()
client, err := quic.DialAddr(
context.Background(),
fmt.Sprintf("localhost:%d", ln.Addr().(*net.UDPAddr).Port),
getTLSClientConfig(),
getQuicConfig(nil),
)
require.NoError(t, err)
serverConn, err := ln.Accept(context.Background())
require.NoError(t, err)
errChan := make(chan error, 1)
go func() { errChan <- runSendingPeer(client) }()
require.NoError(t, runReceivingPeer(serverConn))
serverConn.CloseWithError(0, "")
select {
case err := <-errChan:
require.NoError(t, err)
case <-time.After(time.Second):
require.Fail(t, "timeout")
}
})
t.Run("server -> client", func(t *testing.T) {
ln, err := quic.ListenAddr(
"localhost:0",
getTLSConfig(),
getQuicConfig(nil),
)
require.NoError(t, err)
defer ln.Close()
client, err := quic.DialAddr(
context.Background(),
fmt.Sprintf("localhost:%d", ln.Addr().(*net.UDPAddr).Port),
getTLSClientConfig(),
getQuicConfig(nil),
)
require.NoError(t, err)
serverConn, err := ln.Accept(context.Background())
require.NoError(t, err)
errChan := make(chan error, 1)
go func() { errChan <- runSendingPeer(serverConn) }()
require.NoError(t, runReceivingPeer(client))
client.CloseWithError(0, "")
select {
case err := <-errChan:
require.NoError(t, err)
case <-time.After(time.Second):
require.Fail(t, "timeout")
}
})
t.Run("client <-> server", func(t *testing.T) {
ln, err := quic.ListenAddr(
"localhost:0",
getTLSConfig(),
getQuicConfig(nil),
)
require.NoError(t, err)
defer ln.Close()
errChan1 := make(chan error, 1)
errChan2 := make(chan error, 1)
go func() {
defer GinkgoRecover()
conn, err := server.Accept(context.Background())
Expect(err).ToNot(HaveOccurred())
runSendingPeer(conn)
conn.CloseWithError(0, "")
conn, err := ln.Accept(context.Background())
if err != nil {
errChan1 <- err
errChan2 <- err
return
}
errChan1 <- runReceivingPeer(conn)
errChan2 <- runSendingPeer(conn)
}()
client, err := quic.DialAddr(
context.Background(),
serverAddr,
fmt.Sprintf("localhost:%d", ln.Addr().(*net.UDPAddr).Port),
getTLSClientConfig(),
getQuicConfig(nil),
)
Expect(err).ToNot(HaveOccurred())
runReceivingPeer(client)
Eventually(client.Context().Done()).Should(BeClosed())
})
require.NoError(t, err)
It(fmt.Sprintf("client and server opening %d each and sending data to the peer", numStreams), func() {
done1 := make(chan struct{})
errChan3 := make(chan error, 1)
go func() {
defer GinkgoRecover()
conn, err := server.Accept(context.Background())
Expect(err).ToNot(HaveOccurred())
done := make(chan struct{})
go func() {
defer GinkgoRecover()
runReceivingPeer(conn)
close(done)
}()
runSendingPeer(conn)
<-done
close(done1)
errChan3 <- runSendingPeer(client)
}()
require.NoError(t, runReceivingPeer(client))
client, err := quic.DialAddr(
context.Background(),
serverAddr,
getTLSClientConfig(),
getQuicConfig(nil),
)
Expect(err).ToNot(HaveOccurred())
done2 := make(chan struct{})
go func() {
defer GinkgoRecover()
runSendingPeer(client)
close(done2)
}()
runReceivingPeer(client)
<-done1
<-done2
for _, ch := range []chan error{errChan1, errChan2, errChan3} {
select {
case err := <-ch:
require.NoError(t, err)
case <-time.After(time.Second):
require.Fail(t, "timeout")
}
}
client.CloseWithError(0, "")
})
})
}
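Both stream tests lean on the same errgroup pattern that replaced the GinkgoRecover goroutines: per-stream work is spawned with g.Go, and g.Wait surfaces the first non-nil error. A minimal standalone sketch of the pattern as a fragment inside a test function (assumes golang.org/x/sync/errgroup):

g := new(errgroup.Group)
for i := 0; i < numStreams; i++ {
	g.Go(func() error {
		// per-stream work; any non-nil return is remembered
		return nil
	})
}
// Wait blocks until all goroutines have returned and reports the first error.
if err := g.Wait(); err != nil {
	t.Fatal(err)
}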

View File

@@ -1,22 +1,366 @@
package self_test
import (
"bytes"
"context"
"errors"
"fmt"
"io"
mrand "math/rand"
"net"
"runtime"
"sync/atomic"
"testing"
"time"
"github.com/quic-go/quic-go"
quicproxy "github.com/quic-go/quic-go/integrationtests/tools/proxy"
"github.com/quic-go/quic-go/internal/protocol"
"github.com/quic-go/quic-go/logging"
. "github.com/onsi/ginkgo/v2"
. "github.com/onsi/gomega"
"github.com/stretchr/testify/require"
)
func requireIdleTimeoutError(t *testing.T, err error) {
t.Helper()
require.Error(t, err)
var idleTimeoutErr *quic.IdleTimeoutError
require.ErrorAs(t, err, &idleTimeoutErr)
require.True(t, idleTimeoutErr.Timeout())
var nerr net.Error
require.True(t, errors.As(err, &nerr))
require.True(t, nerr.Timeout())
}
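Most callers won't need the full requireIdleTimeoutError contract; the net.Error half alone is usually enough to detect timeouts in application code. A hedged sketch (isTimeout is a hypothetical helper):

// isTimeout reports whether err is a timeout in the net.Error sense,
// which covers quic.IdleTimeoutError as exercised above.
func isTimeout(err error) bool {
	var nerr net.Error
	return errors.As(err, &nerr) && nerr.Timeout()
}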
func TestHandshakeIdleTimeout(t *testing.T) {
conn, err := net.ListenUDP("udp", &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 0})
require.NoError(t, err)
defer conn.Close()
errChan := make(chan error, 1)
go func() {
_, err := quic.DialAddr(
context.Background(),
fmt.Sprintf("localhost:%d", conn.LocalAddr().(*net.UDPAddr).Port),
getTLSClientConfig(),
getQuicConfig(&quic.Config{HandshakeIdleTimeout: scaleDuration(50 * time.Millisecond)}),
)
errChan <- err
}()
select {
case err := <-errChan:
requireIdleTimeoutError(t, err)
case <-time.After(5 * time.Second):
t.Fatal("timeout waiting for dial error")
}
}
func TestHandshakeTimeoutContext(t *testing.T) {
conn, err := net.ListenUDP("udp", &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 0})
require.NoError(t, err)
defer conn.Close()
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Millisecond)
defer cancel()
errChan := make(chan error)
go func() {
_, err := quic.DialAddr(
ctx,
fmt.Sprintf("localhost:%d", conn.LocalAddr().(*net.UDPAddr).Port),
getTLSClientConfig(),
getQuicConfig(nil),
)
errChan <- err
}()
select {
case err := <-errChan:
require.ErrorIs(t, err, context.DeadlineExceeded)
case <-time.After(5 * time.Second):
t.Fatal("timeout waiting for dial error")
}
}
func TestHandshakeTimeout0RTTContext(t *testing.T) {
conn, err := net.ListenUDP("udp", &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 0})
require.NoError(t, err)
defer conn.Close()
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Millisecond)
defer cancel()
errChan := make(chan error)
go func() {
_, err := quic.DialAddrEarly(
ctx,
fmt.Sprintf("localhost:%d", conn.LocalAddr().(*net.UDPAddr).Port),
getTLSClientConfig(),
getQuicConfig(nil),
)
errChan <- err
}()
select {
case err := <-errChan:
require.ErrorIs(t, err, context.DeadlineExceeded)
case <-time.After(5 * time.Second):
t.Fatal("timeout waiting for dial error")
}
}
func TestIdleTimeout(t *testing.T) {
idleTimeout := scaleDuration(200 * time.Millisecond)
server, err := quic.ListenAddr(
"localhost:0",
getTLSConfig(),
getQuicConfig(&quic.Config{DisablePathMTUDiscovery: true}),
)
require.NoError(t, err)
defer server.Close()
var drop atomic.Bool
proxy, err := quicproxy.NewQuicProxy("localhost:0", &quicproxy.Opts{
RemoteAddr: fmt.Sprintf("localhost:%d", server.Addr().(*net.UDPAddr).Port),
DropPacket: func(quicproxy.Direction, []byte) bool {
return drop.Load()
},
})
require.NoError(t, err)
defer proxy.Close()
conn, err := quic.DialAddr(
context.Background(),
fmt.Sprintf("localhost:%d", proxy.LocalPort()),
getTLSClientConfig(),
getQuicConfig(&quic.Config{DisablePathMTUDiscovery: true, MaxIdleTimeout: idleTimeout}),
)
require.NoError(t, err)
serverConn, err := server.Accept(context.Background())
require.NoError(t, err)
str, err := serverConn.OpenStream()
require.NoError(t, err)
_, err = str.Write([]byte("foobar"))
require.NoError(t, err)
strIn, err := conn.AcceptStream(context.Background())
require.NoError(t, err)
strOut, err := conn.OpenStream()
require.NoError(t, err)
_, err = strIn.Read(make([]byte, 6))
require.NoError(t, err)
drop.Store(true)
time.Sleep(2 * idleTimeout)
_, err = strIn.Write([]byte("test"))
requireIdleTimeoutError(t, err)
_, err = strIn.Read([]byte{0})
requireIdleTimeoutError(t, err)
_, err = strOut.Write([]byte("test"))
requireIdleTimeoutError(t, err)
_, err = strOut.Read([]byte{0})
requireIdleTimeoutError(t, err)
_, err = conn.OpenStream()
requireIdleTimeoutError(t, err)
_, err = conn.OpenUniStream()
requireIdleTimeoutError(t, err)
_, err = conn.AcceptStream(context.Background())
requireIdleTimeoutError(t, err)
_, err = conn.AcceptUniStream(context.Background())
requireIdleTimeoutError(t, err)
}
func TestKeepAlive(t *testing.T) {
idleTimeout := scaleDuration(150 * time.Millisecond)
server, err := quic.ListenAddr(
"localhost:0",
getTLSConfig(),
getQuicConfig(&quic.Config{DisablePathMTUDiscovery: true}),
)
require.NoError(t, err)
defer server.Close()
var drop atomic.Bool
proxy, err := quicproxy.NewQuicProxy("localhost:0", &quicproxy.Opts{
RemoteAddr: fmt.Sprintf("localhost:%d", server.Addr().(*net.UDPAddr).Port),
DropPacket: func(quicproxy.Direction, []byte) bool { return drop.Load() },
})
require.NoError(t, err)
defer proxy.Close()
conn, err := quic.DialAddr(
context.Background(),
fmt.Sprintf("localhost:%d", proxy.LocalPort()),
getTLSClientConfig(),
getQuicConfig(&quic.Config{
MaxIdleTimeout: idleTimeout,
KeepAlivePeriod: idleTimeout / 2,
DisablePathMTUDiscovery: true,
}),
)
require.NoError(t, err)
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
defer cancel()
serverConn, err := server.Accept(ctx)
require.NoError(t, err)
// wait longer than the idle timeout
time.Sleep(3 * idleTimeout)
str, err := conn.OpenUniStream()
require.NoError(t, err)
_, err = str.Write([]byte("foobar"))
require.NoError(t, err)
// verify connection is still alive
select {
case <-serverConn.Context().Done():
t.Fatal("server connection closed unexpectedly")
default:
}
// idle timeout will still kick in if PINGs are dropped
drop.Store(true)
time.Sleep(2 * idleTimeout)
_, err = str.Write([]byte("foobar"))
var nerr net.Error
require.True(t, errors.As(err, &nerr))
require.True(t, nerr.Timeout())
// can't rely on the server connection closing, since we impose a minimum idle timeout of 5s,
// see https://github.com/quic-go/quic-go/issues/4751
serverConn.CloseWithError(0, "")
}
func TestTimeoutAfterInactivity(t *testing.T) {
idleTimeout := scaleDuration(150 * time.Millisecond)
server, err := quic.ListenAddr(
"localhost:0",
getTLSConfig(),
getQuicConfig(&quic.Config{DisablePathMTUDiscovery: true}),
)
require.NoError(t, err)
defer server.Close()
counter, tr := newPacketTracer()
conn, err := quic.DialAddr(
context.Background(),
fmt.Sprintf("localhost:%d", server.Addr().(*net.UDPAddr).Port),
getTLSClientConfig(),
getQuicConfig(&quic.Config{
MaxIdleTimeout: idleTimeout,
Tracer: newTracer(tr),
DisablePathMTUDiscovery: true,
}),
)
require.NoError(t, err)
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
defer cancel()
serverConn, err := server.Accept(ctx)
require.NoError(t, err)
ctx, cancel = context.WithTimeout(context.Background(), 2*idleTimeout)
defer cancel()
_, err = conn.AcceptStream(ctx)
requireIdleTimeoutError(t, err)
var lastAckElicitingPacketSentAt time.Time
for _, p := range counter.getSentShortHeaderPackets() {
var hasAckElicitingFrame bool
for _, f := range p.frames {
if _, ok := f.(*logging.AckFrame); ok {
continue
}
hasAckElicitingFrame = true
break
}
if hasAckElicitingFrame {
lastAckElicitingPacketSentAt = p.time
}
}
rcvdPackets := counter.getRcvdShortHeaderPackets()
lastPacketRcvdAt := rcvdPackets[len(rcvdPackets)-1].time
// We're ignoring here that only the first ack-eliciting packet sent resets the idle timeout.
// This is ok since we're dealing with a lossless connection here,
// and we'd expect to receive an ACK for any additional ack-eliciting packets sent.
timeSinceLastAckEliciting := time.Since(lastAckElicitingPacketSentAt)
timeSinceLastRcvd := time.Since(lastPacketRcvdAt)
maxDuration := max(timeSinceLastAckEliciting, timeSinceLastRcvd)
require.GreaterOrEqual(t, maxDuration, idleTimeout)
require.Less(t, maxDuration, idleTimeout*6/5)
select {
case <-serverConn.Context().Done():
t.Fatal("server connection closed unexpectedly")
default:
}
serverConn.CloseWithError(0, "")
}
func TestTimeoutAfterSendingPacket(t *testing.T) {
if runtime.GOOS == "windows" {
t.Skip("This test is flaky on Windows due to low timer precision.")
}
idleTimeout := scaleDuration(150 * time.Millisecond)
server, err := quic.ListenAddr(
"localhost:0",
getTLSConfig(),
getQuicConfig(&quic.Config{DisablePathMTUDiscovery: true}),
)
require.NoError(t, err)
defer server.Close()
var drop atomic.Bool
proxy, err := quicproxy.NewQuicProxy("localhost:0", &quicproxy.Opts{
RemoteAddr: fmt.Sprintf("localhost:%d", server.Addr().(*net.UDPAddr).Port),
DropPacket: func(d quicproxy.Direction, _ []byte) bool { return d == quicproxy.DirectionOutgoing && drop.Load() },
})
require.NoError(t, err)
defer proxy.Close()
conn, err := quic.DialAddr(
context.Background(),
fmt.Sprintf("localhost:%d", proxy.LocalPort()),
getTLSClientConfig(),
getQuicConfig(&quic.Config{MaxIdleTimeout: idleTimeout, DisablePathMTUDiscovery: true}),
)
require.NoError(t, err)
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
defer cancel()
serverConn, err := server.Accept(ctx)
require.NoError(t, err)
// wait half the idle timeout, then send a packet
time.Sleep(idleTimeout / 2)
drop.Store(true)
str, err := conn.OpenUniStream()
require.NoError(t, err)
_, err = str.Write([]byte("foobar"))
require.NoError(t, err)
// now make sure that the idle timeout is based on this packet
startTime := time.Now()
ctx, cancel = context.WithTimeout(context.Background(), 2*idleTimeout)
defer cancel()
_, err = conn.AcceptStream(ctx)
requireIdleTimeoutError(t, err)
dur := time.Since(startTime)
require.GreaterOrEqual(t, dur, idleTimeout)
require.Less(t, dur, idleTimeout*12/10)
// Verify server connection is still open
select {
case <-serverConn.Context().Done():
t.Fatal("server connection closed unexpectedly")
default:
}
serverConn.CloseWithError(0, "")
}
type faultyConn struct {
net.PacketConn
@@ -41,490 +385,136 @@ func (c *faultyConn) WriteTo(p []byte, addr net.Addr) (int, error) {
return 0, io.ErrClosedPipe
}
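The hunk above truncates the wrapper. A sketch of what such a PacketConn looks like in full, based on the fields visible in this diff (the counter field and its type are assumptions; faultyConnSketch is named to avoid clashing with the real type):

type faultyConnSketch struct {
	net.PacketConn
	MaxPackets int32
	counter    atomic.Int32
}

func (c *faultyConnSketch) ReadFrom(p []byte) (int, net.Addr, error) {
	if c.counter.Add(1) > c.MaxPackets {
		return 0, nil, io.ErrClosedPipe
	}
	return c.PacketConn.ReadFrom(p)
}

func (c *faultyConnSketch) WriteTo(p []byte, addr net.Addr) (int, error) {
	if c.counter.Add(1) > c.MaxPackets {
		return 0, io.ErrClosedPipe
	}
	return c.PacketConn.WriteTo(p, addr)
}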
var _ = Describe("Timeout tests", func() {
checkTimeoutError := func(err error) {
ExpectWithOffset(1, err).To(MatchError(&quic.IdleTimeoutError{}))
nerr, ok := err.(net.Error)
ExpectWithOffset(1, ok).To(BeTrue())
ExpectWithOffset(1, nerr.Timeout()).To(BeTrue())
func TestFaultyPacketConn(t *testing.T) {
t.Run("client", func(t *testing.T) {
testFaultyPacketConn(t, protocol.PerspectiveClient)
})
t.Run("server", func(t *testing.T) {
testFaultyPacketConn(t, protocol.PerspectiveServer)
})
}
func testFaultyPacketConn(t *testing.T, pers protocol.Perspective) {
handshakeTimeout := scaleDuration(100 * time.Millisecond)
runServer := func(ln *quic.Listener) error {
conn, err := ln.Accept(context.Background())
if err != nil {
return err
}
str, err := conn.OpenUniStream()
if err != nil {
return err
}
defer str.Close()
_, err = str.Write(PRData)
return err
}
It("returns net.Error timeout errors when dialing", func() {
errChan := make(chan error)
go func() {
_, err := quic.DialAddr(
context.Background(),
"localhost:12345",
getTLSClientConfig(),
getQuicConfig(&quic.Config{HandshakeIdleTimeout: scaleDuration(50 * time.Millisecond)}),
)
errChan <- err
}()
var err error
Eventually(errChan).Should(Receive(&err))
checkTimeoutError(err)
})
runClient := func(conn quic.Connection) error {
str, err := conn.AcceptUniStream(context.Background())
if err != nil {
return err
}
data, err := io.ReadAll(str)
if err != nil {
return err
}
if !bytes.Equal(data, PRData) {
return fmt.Errorf("wrong data: %q vs %q", data, PRData)
}
return conn.CloseWithError(0, "done")
}
It("returns the context error when the context expires", func() {
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Millisecond)
defer cancel()
errChan := make(chan error)
go func() {
_, err := quic.DialAddr(
ctx,
"localhost:12345",
getTLSClientConfig(),
getQuicConfig(nil),
)
errChan <- err
}()
var err error
Eventually(errChan).Should(Receive(&err))
// This is not a net.Error timeout error
Expect(err).To(MatchError(context.DeadlineExceeded))
})
var cconn, sconn net.PacketConn
var err error
cconn, err = net.ListenUDP("udp", &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 0})
require.NoError(t, err)
defer cconn.Close()
sconn, err = net.ListenUDP("udp", &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 0})
require.NoError(t, err)
defer sconn.Close()
It("returns the context error when the context expires with 0RTT enabled", func() {
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Millisecond)
defer cancel()
errChan := make(chan error)
go func() {
_, err := quic.DialAddrEarly(
ctx,
"localhost:12345",
getTLSClientConfig(),
getQuicConfig(nil),
)
errChan <- err
}()
var err error
Eventually(errChan).Should(Receive(&err))
// This is not a net.Error timeout error
Expect(err).To(MatchError(context.DeadlineExceeded))
})
maxPackets := mrand.Int31n(25)
t.Logf("blocking %s's connection after %d packets", pers, maxPackets)
switch pers {
case protocol.PerspectiveClient:
cconn = &faultyConn{PacketConn: cconn, MaxPackets: maxPackets}
case protocol.PerspectiveServer:
sconn = &faultyConn{PacketConn: sconn, MaxPackets: maxPackets}
}
It("returns net.Error timeout errors when an idle timeout occurs", func() {
const idleTimeout = 500 * time.Millisecond
ln, err := quic.Listen(
sconn,
getTLSConfig(),
getQuicConfig(&quic.Config{
HandshakeIdleTimeout: handshakeTimeout,
MaxIdleTimeout: handshakeTimeout,
KeepAlivePeriod: handshakeTimeout / 2,
DisablePathMTUDiscovery: true,
}),
)
require.NoError(t, err)
defer ln.Close()
server, err := quic.ListenAddr(
"localhost:0",
getTLSConfig(),
getQuicConfig(&quic.Config{DisablePathMTUDiscovery: true}),
)
Expect(err).ToNot(HaveOccurred())
defer server.Close()
serverErrChan := make(chan error, 1)
go func() { serverErrChan <- runServer(ln) }()
go func() {
defer GinkgoRecover()
conn, err := server.Accept(context.Background())
Expect(err).ToNot(HaveOccurred())
str, err := conn.OpenStream()
Expect(err).ToNot(HaveOccurred())
_, err = str.Write([]byte("foobar"))
Expect(err).ToNot(HaveOccurred())
}()
var drop atomic.Bool
proxy, err := quicproxy.NewQuicProxy("localhost:0", &quicproxy.Opts{
RemoteAddr: fmt.Sprintf("localhost:%d", server.Addr().(*net.UDPAddr).Port),
DropPacket: func(quicproxy.Direction, []byte) bool {
return drop.Load()
},
})
Expect(err).ToNot(HaveOccurred())
defer proxy.Close()
conn, err := quic.DialAddr(
clientErrChan := make(chan error, 1)
go func() {
conn, err := quic.Dial(
context.Background(),
fmt.Sprintf("localhost:%d", proxy.LocalPort()),
getTLSClientConfig(),
getQuicConfig(&quic.Config{DisablePathMTUDiscovery: true, MaxIdleTimeout: idleTimeout}),
)
Expect(err).ToNot(HaveOccurred())
strIn, err := conn.AcceptStream(context.Background())
Expect(err).ToNot(HaveOccurred())
strOut, err := conn.OpenStream()
Expect(err).ToNot(HaveOccurred())
_, err = strIn.Read(make([]byte, 6))
Expect(err).ToNot(HaveOccurred())
drop.Store(true)
time.Sleep(2 * idleTimeout)
_, err = strIn.Write([]byte("test"))
checkTimeoutError(err)
_, err = strIn.Read([]byte{0})
checkTimeoutError(err)
_, err = strOut.Write([]byte("test"))
checkTimeoutError(err)
_, err = strOut.Read([]byte{0})
checkTimeoutError(err)
_, err = conn.OpenStream()
checkTimeoutError(err)
_, err = conn.OpenUniStream()
checkTimeoutError(err)
_, err = conn.AcceptStream(context.Background())
checkTimeoutError(err)
_, err = conn.AcceptUniStream(context.Background())
checkTimeoutError(err)
})
Context("timing out at the right time", func() {
var idleTimeout time.Duration
BeforeEach(func() {
idleTimeout = scaleDuration(500 * time.Millisecond)
})
It("times out after inactivity", func() {
server, err := quic.ListenAddr(
"localhost:0",
getTLSConfig(),
getQuicConfig(&quic.Config{DisablePathMTUDiscovery: true}),
)
Expect(err).ToNot(HaveOccurred())
defer server.Close()
serverConnChan := make(chan quic.Connection, 1)
serverConnClosed := make(chan struct{})
go func() {
defer GinkgoRecover()
conn, err := server.Accept(context.Background())
Expect(err).ToNot(HaveOccurred())
serverConnChan <- conn
conn.AcceptStream(context.Background()) // blocks until the connection is closed
close(serverConnClosed)
}()
counter, tr := newPacketTracer()
conn, err := quic.DialAddr(
context.Background(),
fmt.Sprintf("localhost:%d", server.Addr().(*net.UDPAddr).Port),
getTLSClientConfig(),
getQuicConfig(&quic.Config{
MaxIdleTimeout: idleTimeout,
Tracer: newTracer(tr),
DisablePathMTUDiscovery: true,
}),
)
Expect(err).ToNot(HaveOccurred())
done := make(chan struct{})
go func() {
defer GinkgoRecover()
_, err := conn.AcceptStream(context.Background())
checkTimeoutError(err)
close(done)
}()
Eventually(done, 2*idleTimeout).Should(BeClosed())
var lastAckElicitingPacketSentAt time.Time
for _, p := range counter.getSentShortHeaderPackets() {
var hasAckElicitingFrame bool
for _, f := range p.frames {
if _, ok := f.(*logging.AckFrame); ok {
continue
}
hasAckElicitingFrame = true
break
}
if hasAckElicitingFrame {
lastAckElicitingPacketSentAt = p.time
}
}
rcvdPackets := counter.getRcvdShortHeaderPackets()
lastPacketRcvdAt := rcvdPackets[len(rcvdPackets)-1].time
// We're ignoring here that only the first ack-eliciting packet sent resets the idle timeout.
// This is ok since we're dealing with a lossless connection here,
// and we'd expect to receive an ACK for any additional ack-eliciting packets sent.
Expect(max(time.Since(lastAckElicitingPacketSentAt), time.Since(lastPacketRcvdAt))).To(And(
BeNumerically(">=", idleTimeout),
BeNumerically("<", idleTimeout*6/5),
))
Consistently(serverConnClosed).ShouldNot(BeClosed())
// make the go routine return
(<-serverConnChan).CloseWithError(0, "")
Eventually(serverConnClosed).Should(BeClosed())
})
It("times out after sending a packet", func() {
server, err := quic.ListenAddr(
"localhost:0",
getTLSConfig(),
getQuicConfig(&quic.Config{DisablePathMTUDiscovery: true}),
)
Expect(err).ToNot(HaveOccurred())
defer server.Close()
var drop atomic.Bool
proxy, err := quicproxy.NewQuicProxy("localhost:0", &quicproxy.Opts{
RemoteAddr: fmt.Sprintf("localhost:%d", server.Addr().(*net.UDPAddr).Port),
DropPacket: func(dir quicproxy.Direction, _ []byte) bool {
if dir == quicproxy.DirectionOutgoing {
return drop.Load()
}
return false
},
})
Expect(err).ToNot(HaveOccurred())
defer proxy.Close()
serverConnChan := make(chan quic.Connection, 1)
serverConnClosed := make(chan struct{})
go func() {
defer GinkgoRecover()
conn, err := server.Accept(context.Background())
Expect(err).ToNot(HaveOccurred())
serverConnChan <- conn
<-conn.Context().Done() // block until the connection is closed
close(serverConnClosed)
}()
conn, err := quic.DialAddr(
context.Background(),
fmt.Sprintf("localhost:%d", proxy.LocalPort()),
getTLSClientConfig(),
getQuicConfig(&quic.Config{MaxIdleTimeout: idleTimeout, DisablePathMTUDiscovery: true}),
)
Expect(err).ToNot(HaveOccurred())
// wait half the idle timeout, then send a packet
time.Sleep(idleTimeout / 2)
drop.Store(true)
str, err := conn.OpenUniStream()
Expect(err).ToNot(HaveOccurred())
_, err = str.Write([]byte("foobar"))
Expect(err).ToNot(HaveOccurred())
// now make sure that the idle timeout is based on this packet
startTime := time.Now()
done := make(chan struct{})
go func() {
defer GinkgoRecover()
_, err := conn.AcceptStream(context.Background())
checkTimeoutError(err)
close(done)
}()
Eventually(done, 2*idleTimeout).Should(BeClosed())
dur := time.Since(startTime)
Expect(dur).To(And(
BeNumerically(">=", idleTimeout),
BeNumerically("<", idleTimeout*12/10),
))
Consistently(serverConnClosed).ShouldNot(BeClosed())
// make the go routine return
(<-serverConnChan).CloseWithError(0, "")
Eventually(serverConnClosed).Should(BeClosed())
})
})
It("does not time out if keepalive is set", func() {
const idleTimeout = 500 * time.Millisecond
server, err := quic.ListenAddr(
"localhost:0",
getTLSConfig(),
getQuicConfig(&quic.Config{DisablePathMTUDiscovery: true}),
)
Expect(err).ToNot(HaveOccurred())
defer server.Close()
serverConnChan := make(chan quic.Connection, 1)
serverConnClosed := make(chan struct{})
go func() {
defer GinkgoRecover()
conn, err := server.Accept(context.Background())
Expect(err).ToNot(HaveOccurred())
serverConnChan <- conn
conn.AcceptStream(context.Background()) // blocks until the connection is closed
close(serverConnClosed)
}()
var drop atomic.Bool
proxy, err := quicproxy.NewQuicProxy("localhost:0", &quicproxy.Opts{
RemoteAddr: fmt.Sprintf("localhost:%d", server.Addr().(*net.UDPAddr).Port),
DropPacket: func(quicproxy.Direction, []byte) bool {
return drop.Load()
},
})
Expect(err).ToNot(HaveOccurred())
defer proxy.Close()
conn, err := quic.DialAddr(
context.Background(),
fmt.Sprintf("localhost:%d", proxy.LocalPort()),
cconn,
ln.Addr(),
getTLSClientConfig(),
getQuicConfig(&quic.Config{
MaxIdleTimeout: idleTimeout,
KeepAlivePeriod: idleTimeout / 2,
HandshakeIdleTimeout: handshakeTimeout,
MaxIdleTimeout: handshakeTimeout,
KeepAlivePeriod: handshakeTimeout / 2,
DisablePathMTUDiscovery: true,
}),
)
Expect(err).ToNot(HaveOccurred())
// wait longer than the idle timeout
time.Sleep(3 * idleTimeout)
str, err := conn.OpenUniStream()
Expect(err).ToNot(HaveOccurred())
_, err = str.Write([]byte("foobar"))
Expect(err).ToNot(HaveOccurred())
Consistently(serverConnClosed).ShouldNot(BeClosed())
// idle timeout will still kick in if pings are dropped
drop.Store(true)
time.Sleep(2 * idleTimeout)
_, err = str.Write([]byte("foobar"))
checkTimeoutError(err)
(<-serverConnChan).CloseWithError(0, "")
Eventually(serverConnClosed).Should(BeClosed())
})
Context("faulty packet conns", func() {
const handshakeTimeout = time.Second / 2
runServer := func(ln *quic.Listener) error {
conn, err := ln.Accept(context.Background())
if err != nil {
return err
}
str, err := conn.OpenUniStream()
if err != nil {
return err
}
defer str.Close()
_, err = str.Write(PRData)
return err
if err != nil {
clientErrChan <- err
return
}
clientErrChan <- runClient(conn)
}()
runClient := func(conn quic.Connection) error {
str, err := conn.AcceptUniStream(context.Background())
if err != nil {
return err
}
data, err := io.ReadAll(str)
if err != nil {
return err
}
Expect(data).To(Equal(PRData))
return conn.CloseWithError(0, "done")
var clientErr error
select {
case clientErr = <-clientErrChan:
case <-time.After(5 * handshakeTimeout):
t.Fatal("timeout waiting for client error")
}
require.Error(t, clientErr)
if pers == protocol.PerspectiveClient {
require.Contains(t, clientErr.Error(), io.ErrClosedPipe.Error())
} else {
var nerr net.Error
require.True(t, errors.As(clientErr, &nerr))
require.True(t, nerr.Timeout())
}
require.Eventually(t, func() bool { return !areHandshakesRunning() }, 5*handshakeTimeout, 5*time.Millisecond)
select {
case serverErr := <-serverErrChan: // The handshake completed on the server side.
require.Error(t, serverErr)
if pers == protocol.PerspectiveServer {
require.Contains(t, serverErr.Error(), io.ErrClosedPipe.Error())
} else {
var nerr net.Error
require.True(t, errors.As(serverErr, &nerr))
require.True(t, nerr.Timeout())
}
It("deals with an erroring packet conn, on the server side", func() {
addr, err := net.ResolveUDPAddr("udp", "localhost:0")
Expect(err).ToNot(HaveOccurred())
conn, err := net.ListenUDP("udp", addr)
Expect(err).ToNot(HaveOccurred())
maxPackets := mrand.Int31n(25)
fmt.Fprintf(GinkgoWriter, "blocking connection after %d packets\n", maxPackets)
ln, err := quic.Listen(
&faultyConn{PacketConn: conn, MaxPackets: maxPackets},
getTLSConfig(),
getQuicConfig(&quic.Config{DisablePathMTUDiscovery: true}),
)
Expect(err).ToNot(HaveOccurred())
serverErrChan := make(chan error, 1)
go func() {
defer GinkgoRecover()
serverErrChan <- runServer(ln)
}()
clientErrChan := make(chan error, 1)
go func() {
defer GinkgoRecover()
conn, err := quic.DialAddr(
context.Background(),
fmt.Sprintf("localhost:%d", ln.Addr().(*net.UDPAddr).Port),
getTLSClientConfig(),
getQuicConfig(&quic.Config{
HandshakeIdleTimeout: handshakeTimeout,
MaxIdleTimeout: handshakeTimeout,
DisablePathMTUDiscovery: true,
}),
)
if err != nil {
clientErrChan <- err
return
}
clientErrChan <- runClient(conn)
}()
var clientErr error
Eventually(clientErrChan, 5*handshakeTimeout).Should(Receive(&clientErr))
Expect(clientErr).To(HaveOccurred())
nErr, ok := clientErr.(net.Error)
Expect(ok).To(BeTrue())
Expect(nErr.Timeout()).To(BeTrue())
select {
case serverErr := <-serverErrChan:
Expect(serverErr).To(HaveOccurred())
Expect(serverErr.Error()).To(ContainSubstring(io.ErrClosedPipe.Error()))
defer ln.Close()
default:
Expect(ln.Close()).To(Succeed())
Eventually(serverErrChan).Should(Receive())
}
})
It("deals with an erroring packet conn, on the client side", func() {
ln, err := quic.ListenAddr(
"localhost:0",
getTLSConfig(),
getQuicConfig(&quic.Config{
HandshakeIdleTimeout: handshakeTimeout,
MaxIdleTimeout: handshakeTimeout,
KeepAlivePeriod: handshakeTimeout / 2,
DisablePathMTUDiscovery: true,
}),
)
Expect(err).ToNot(HaveOccurred())
defer ln.Close()
serverErrChan := make(chan error, 1)
go func() {
defer GinkgoRecover()
serverErrChan <- runServer(ln)
}()
addr, err := net.ResolveUDPAddr("udp", "localhost:0")
Expect(err).ToNot(HaveOccurred())
conn, err := net.ListenUDP("udp", addr)
Expect(err).ToNot(HaveOccurred())
maxPackets := mrand.Int31n(25)
fmt.Fprintf(GinkgoWriter, "blocking connection after %d packets\n", maxPackets)
clientErrChan := make(chan error, 1)
go func() {
defer GinkgoRecover()
conn, err := quic.Dial(
context.Background(),
&faultyConn{PacketConn: conn, MaxPackets: maxPackets},
ln.Addr(),
getTLSClientConfig(),
getQuicConfig(&quic.Config{DisablePathMTUDiscovery: true}),
)
if err != nil {
clientErrChan <- err
return
}
clientErrChan <- runClient(conn)
}()
var clientErr error
Eventually(clientErrChan, 5*handshakeTimeout).Should(Receive(&clientErr))
Expect(clientErr).To(HaveOccurred())
Expect(clientErr.Error()).To(ContainSubstring(io.ErrClosedPipe.Error()))
Eventually(areHandshakesRunning, 5*handshakeTimeout).Should(BeFalse())
select {
case serverErr := <-serverErrChan: // The handshake completed on the server side.
Expect(serverErr).To(HaveOccurred())
nErr, ok := serverErr.(net.Error)
Expect(ok).To(BeTrue())
Expect(nErr.Timeout()).To(BeTrue())
default: // The handshake didn't complete
Expect(ln.Close()).To(Succeed())
Eventually(serverErrChan).Should(Receive())
}
})
})
})
default: // The handshake didn't complete
require.NoError(t, ln.Close())
select {
case <-serverErrChan:
case <-time.After(time.Second):
t.Fatal("timeout waiting for server to close")
}
}
}

View File

@@ -9,6 +9,7 @@ import (
mrand "math/rand"
"net"
"sync"
"testing"
"github.com/quic-go/quic-go"
"github.com/quic-go/quic-go/internal/protocol"
@@ -17,30 +18,29 @@ import (
"github.com/quic-go/quic-go/metrics"
"github.com/quic-go/quic-go/qlog"
. "github.com/onsi/ginkgo/v2"
. "github.com/onsi/gomega"
"github.com/stretchr/testify/require"
)
var _ = Describe("Tracer tests", func() {
func TestTracerHandshake(t *testing.T) {
addTracers := func(pers protocol.Perspective, conf *quic.Config) *quic.Config {
enableQlog := mrand.Int()%2 != 0
enableMetrcis := mrand.Int()%2 != 0
enableMetrics := mrand.Int()%2 != 0
enableCustomTracer := mrand.Int()%2 != 0
fmt.Fprintf(GinkgoWriter, "%s using qlog: %t, metrics: %t, custom: %t\n", pers, enableQlog, enableMetrcis, enableCustomTracer)
t.Logf("%s using qlog: %t, metrics: %t, custom: %t", pers, enableQlog, enableMetrics, enableCustomTracer)
var tracerConstructors []func(context.Context, logging.Perspective, quic.ConnectionID) *logging.ConnectionTracer
if enableQlog {
tracerConstructors = append(tracerConstructors, func(_ context.Context, p logging.Perspective, connID quic.ConnectionID) *logging.ConnectionTracer {
if mrand.Int()%2 == 0 { // simulate that a qlog collector might only want to log some connections
fmt.Fprintf(GinkgoWriter, "%s qlog tracer deciding to not trace connection %s\n", p, connID)
t.Logf("%s qlog tracer deciding to not trace connection %s", p, connID)
return nil
}
fmt.Fprintf(GinkgoWriter, "%s qlog tracing connection %s\n", p, connID)
t.Logf("%s qlog tracing connection %s", p, connID)
return qlog.NewConnectionTracer(utils.NewBufferedWriteCloser(bufio.NewWriter(&bytes.Buffer{}), io.NopCloser(nil)), p, connID)
})
}
if enableMetrcis {
if enableMetrics {
tracerConstructors = append(tracerConstructors, metrics.DefaultConnectionTracer)
}
if enableCustomTracer {
@@ -62,39 +62,21 @@ var _ = Describe("Tracer tests", func() {
}
for i := 0; i < 3; i++ {
It("handshakes with a random combination of tracers", func() {
t.Run(fmt.Sprintf("run %d", i+1), func(t *testing.T) {
if enableQlog {
Skip("This test sets tracers and won't produce any qlogs.")
t.Skip("This test sets tracers and won't produce any qlogs.")
}
quicClientConf := addTracers(protocol.PerspectiveClient, getQuicConfig(nil))
quicServerConf := addTracers(protocol.PerspectiveServer, getQuicConfig(nil))
serverChan := make(chan *quic.Listener)
serverDone := make(chan struct{})
go func() {
defer GinkgoRecover()
defer close(serverDone)
ln, err := quic.ListenAddr("localhost:0", getTLSConfig(), quicServerConf)
Expect(err).ToNot(HaveOccurred())
serverChan <- ln
for {
conn, err := ln.Accept(context.Background())
if err != nil {
return
}
str, err := conn.OpenUniStream()
Expect(err).ToNot(HaveOccurred())
_, err = str.Write(PRData)
Expect(err).ToNot(HaveOccurred())
Expect(str.Close()).To(Succeed())
}
}()
ln := <-serverChan
ln, err := quic.ListenAddr("localhost:0", getTLSConfig(), quicServerConf)
require.NoError(t, err)
t.Cleanup(func() { require.NoError(t, ln.Close()) })
var wg sync.WaitGroup
wg.Add(3)
for i := 0; i < 3; i++ {
for j := 0; j < 3; j++ {
wg.Add(1)
go func() {
defer wg.Done()
conn, err := quic.DialAddr(
@@ -103,18 +85,27 @@ var _ = Describe("Tracer tests", func() {
getTLSClientConfig(),
quicClientConf,
)
Expect(err).ToNot(HaveOccurred())
require.NoError(t, err)
defer conn.CloseWithError(0, "")
sconn, err := ln.Accept(context.Background())
if err != nil {
return
}
sstr, err := sconn.OpenUniStream()
require.NoError(t, err)
_, err = sstr.Write(PRData)
require.NoError(t, err)
require.NoError(t, sstr.Close())
str, err := conn.AcceptUniStream(context.Background())
Expect(err).ToNot(HaveOccurred())
require.NoError(t, err)
data, err := io.ReadAll(str)
Expect(err).ToNot(HaveOccurred())
Expect(data).To(Equal(PRData))
require.NoError(t, err)
require.Equal(t, PRData, data)
}()
}
wg.Wait()
ln.Close()
Eventually(serverDone).Should(BeClosed())
})
}
})
}
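The custom tracer itself is elided in this hunk; a constructor of the required shape can be as small as this hedged sketch (only one callback is set, and the ClosedConnection field name is taken from the logging package's struct-of-callbacks design):

func newCustomTracer() func(context.Context, logging.Perspective, quic.ConnectionID) *logging.ConnectionTracer {
	return func(context.Context, logging.Perspective, quic.ConnectionID) *logging.ConnectionTracer {
		return &logging.ConnectionTracer{
			// unset callbacks stay nil; the multiplexer must tolerate that,
			// which is exactly what these tests exercise
			ClosedConnection: func(err error) {},
		}
	}
}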

View File

@@ -1,147 +0,0 @@
package self_test
import (
"context"
"fmt"
"io"
"net"
"sync"
"github.com/quic-go/quic-go"
"github.com/quic-go/quic-go/internal/protocol"
. "github.com/onsi/ginkgo/v2"
. "github.com/onsi/gomega"
)
var _ = Describe("Unidirectional Streams", func() {
const numStreams = 500
var (
server *quic.Listener
serverAddr string
)
BeforeEach(func() {
var err error
server, err = quic.ListenAddr("localhost:0", getTLSConfig(), getQuicConfig(nil))
Expect(err).ToNot(HaveOccurred())
serverAddr = fmt.Sprintf("localhost:%d", server.Addr().(*net.UDPAddr).Port)
})
AfterEach(func() {
server.Close()
})
dataForStream := func(id protocol.StreamID) []byte {
return GeneratePRData(10 * int(id))
}
runSendingPeer := func(conn quic.Connection) {
for i := 0; i < numStreams; i++ {
str, err := conn.OpenUniStreamSync(context.Background())
Expect(err).ToNot(HaveOccurred())
go func() {
defer GinkgoRecover()
_, err := str.Write(dataForStream(str.StreamID()))
Expect(err).ToNot(HaveOccurred())
Expect(str.Close()).To(Succeed())
}()
}
}
runReceivingPeer := func(conn quic.Connection) {
var wg sync.WaitGroup
wg.Add(numStreams)
for i := 0; i < numStreams; i++ {
str, err := conn.AcceptUniStream(context.Background())
Expect(err).ToNot(HaveOccurred())
go func() {
defer GinkgoRecover()
defer wg.Done()
data, err := io.ReadAll(str)
Expect(err).ToNot(HaveOccurred())
Expect(data).To(Equal(dataForStream(str.StreamID())))
}()
}
wg.Wait()
}
It(fmt.Sprintf("client opening %d streams to a server", numStreams), func() {
go func() {
defer GinkgoRecover()
conn, err := server.Accept(context.Background())
Expect(err).ToNot(HaveOccurred())
runReceivingPeer(conn)
conn.CloseWithError(0, "")
}()
client, err := quic.DialAddr(
context.Background(),
serverAddr,
getTLSClientConfig(),
getQuicConfig(nil),
)
Expect(err).ToNot(HaveOccurred())
runSendingPeer(client)
<-client.Context().Done()
})
It(fmt.Sprintf("server opening %d streams to a client", numStreams), func() {
done := make(chan struct{})
go func() {
defer GinkgoRecover()
defer close(done)
conn, err := server.Accept(context.Background())
Expect(err).ToNot(HaveOccurred())
runSendingPeer(conn)
<-conn.Context().Done()
}()
client, err := quic.DialAddr(
context.Background(),
serverAddr,
getTLSClientConfig(),
getQuicConfig(nil),
)
Expect(err).ToNot(HaveOccurred())
runReceivingPeer(client)
client.CloseWithError(0, "")
})
It(fmt.Sprintf("client and server opening %d streams each and sending data to the peer", numStreams), func() {
done1 := make(chan struct{})
go func() {
defer GinkgoRecover()
conn, err := server.Accept(context.Background())
Expect(err).ToNot(HaveOccurred())
done := make(chan struct{})
go func() {
defer GinkgoRecover()
runReceivingPeer(conn)
close(done)
}()
runSendingPeer(conn)
<-done
close(done1)
}()
client, err := quic.DialAddr(
context.Background(),
serverAddr,
getTLSClientConfig(),
getQuicConfig(nil),
)
Expect(err).ToNot(HaveOccurred())
done2 := make(chan struct{})
go func() {
defer GinkgoRecover()
runSendingPeer(client)
close(done2)
}()
runReceivingPeer(client)
<-done1
<-done2
client.CloseWithError(0, "")
})
})

File diff suppressed because it is too large