forked from quic-go/quic-go
migrate integration tests away from Ginkgo (#4736)
* use require in benchmark tests * translate the QLOGDIR test * translate handshake tests * translate the handshake RTT tests * translate the early data test * translate the MTU tests * translate the key update test * translate the stateless reset tests * translate the packetization test * translate the close test * translate the resumption test * translate the tracer test * translate the connection ID length test * translate the RTT tests * translate the multiplexing tests * translate the drop tests * translate the handshake drop tests * translate the 0-RTT tests * translate the hotswap test * translate the stream test * translate the unidirectional stream test * translate the timeout tests * translate the MITM test * rewrite the datagram tests * translate the cancellation tests * translate the deadline tests * translate the test helpers
This commit is contained in:
@@ -7,15 +7,15 @@ import (
|
||||
"testing"
|
||||
|
||||
"github.com/quic-go/quic-go"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func BenchmarkHandshake(b *testing.B) {
|
||||
b.ReportAllocs()
|
||||
|
||||
ln, err := quic.ListenAddr("localhost:0", tlsConfig, nil)
|
||||
if err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
require.NoError(b, err)
|
||||
defer ln.Close()
|
||||
|
||||
connChan := make(chan quic.Connection, 1)
|
||||
@@ -30,9 +30,7 @@ func BenchmarkHandshake(b *testing.B) {
|
||||
}()
|
||||
|
||||
conn, err := net.ListenUDP("udp", nil)
|
||||
if err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
require.NoError(b, err)
|
||||
defer conn.Close()
|
||||
tr := &quic.Transport{Conn: conn}
|
||||
defer tr.Close()
|
||||
@@ -40,10 +38,9 @@ func BenchmarkHandshake(b *testing.B) {
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
c, err := tr.Dial(context.Background(), ln.Addr(), tlsClientConfig, nil)
|
||||
if err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
<-connChan
|
||||
require.NoError(b, err)
|
||||
serverConn := <-connChan
|
||||
serverConn.CloseWithError(0, "")
|
||||
c.CloseWithError(0, "")
|
||||
}
|
||||
}
|
||||
@@ -52,21 +49,20 @@ func BenchmarkStreamChurn(b *testing.B) {
|
||||
b.ReportAllocs()
|
||||
|
||||
ln, err := quic.ListenAddr("localhost:0", tlsConfig, &quic.Config{MaxIncomingStreams: 1e10})
|
||||
if err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
require.NoError(b, err)
|
||||
defer ln.Close()
|
||||
|
||||
errChan := make(chan error, 1)
|
||||
conn, err := quic.DialAddr(context.Background(), fmt.Sprintf("localhost:%d", ln.Addr().(*net.UDPAddr).Port), tlsClientConfig, nil)
|
||||
require.NoError(b, err)
|
||||
defer conn.CloseWithError(0, "")
|
||||
|
||||
serverConn, err := ln.Accept(context.Background())
|
||||
require.NoError(b, err)
|
||||
defer serverConn.CloseWithError(0, "")
|
||||
|
||||
go func() {
|
||||
conn, err := ln.Accept(context.Background())
|
||||
if err != nil {
|
||||
errChan <- err
|
||||
return
|
||||
}
|
||||
close(errChan)
|
||||
for {
|
||||
str, err := conn.AcceptStream(context.Background())
|
||||
str, err := serverConn.AcceptStream(context.Background())
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
@@ -74,22 +70,10 @@ func BenchmarkStreamChurn(b *testing.B) {
|
||||
}
|
||||
}()
|
||||
|
||||
c, err := quic.DialAddr(context.Background(), fmt.Sprintf("localhost:%d", ln.Addr().(*net.UDPAddr).Port), tlsClientConfig, nil)
|
||||
if err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
if err := <-errChan; err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
str, err := c.OpenStreamSync(context.Background())
|
||||
if err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
if err := str.Close(); err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
str, err := conn.OpenStreamSync(context.Background())
|
||||
require.NoError(b, err)
|
||||
require.NoError(b, str.Close())
|
||||
}
|
||||
}
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -5,74 +5,80 @@ import (
|
||||
"fmt"
|
||||
"net"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/quic-go/quic-go"
|
||||
quicproxy "github.com/quic-go/quic-go/integrationtests/tools/proxy"
|
||||
|
||||
. "github.com/onsi/ginkgo/v2"
|
||||
. "github.com/onsi/gomega"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
var _ = Describe("Connection ID lengths tests", func() {
|
||||
It("retransmits the CONNECTION_CLOSE packet", func() {
|
||||
server, err := quic.ListenAddr(
|
||||
"localhost:0",
|
||||
getTLSConfig(),
|
||||
getQuicConfig(&quic.Config{
|
||||
DisablePathMTUDiscovery: true,
|
||||
}),
|
||||
)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
defer server.Close()
|
||||
func TestConnectionCloseRetransmission(t *testing.T) {
|
||||
server, err := quic.ListenAddr(
|
||||
"localhost:0",
|
||||
getTLSConfig(),
|
||||
getQuicConfig(&quic.Config{DisablePathMTUDiscovery: true}),
|
||||
)
|
||||
require.NoError(t, err)
|
||||
defer server.Close()
|
||||
|
||||
var drop atomic.Bool
|
||||
dropped := make(chan []byte, 100)
|
||||
proxy, err := quicproxy.NewQuicProxy("localhost:0", &quicproxy.Opts{
|
||||
RemoteAddr: fmt.Sprintf("localhost:%d", server.Addr().(*net.UDPAddr).Port),
|
||||
DelayPacket: func(dir quicproxy.Direction, _ []byte) time.Duration {
|
||||
return 5 * time.Millisecond // 10ms RTT
|
||||
},
|
||||
DropPacket: func(dir quicproxy.Direction, b []byte) bool {
|
||||
if drop := drop.Load(); drop && dir == quicproxy.DirectionOutgoing {
|
||||
dropped <- b
|
||||
return true
|
||||
}
|
||||
return false
|
||||
},
|
||||
})
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
defer proxy.Close()
|
||||
|
||||
conn, err := quic.DialAddr(
|
||||
context.Background(),
|
||||
fmt.Sprintf("localhost:%d", proxy.LocalPort()),
|
||||
getTLSClientConfig(),
|
||||
getQuicConfig(nil),
|
||||
)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
defer conn.CloseWithError(0, "")
|
||||
|
||||
sconn, err := server.Accept(context.Background())
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
drop.Store(true)
|
||||
sconn.CloseWithError(1337, "closing")
|
||||
|
||||
// send 100 packets
|
||||
for i := 0; i < 100; i++ {
|
||||
str, err := conn.OpenStream()
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
_, err = str.Write([]byte("foobar"))
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
time.Sleep(time.Millisecond)
|
||||
}
|
||||
// Expect retransmissions of the CONNECTION_CLOSE for the
|
||||
// 1st, 2nd, 4th, 8th, 16th, 32th, 64th packet: 7 in total (+1 for the original packet)
|
||||
Eventually(dropped).Should(HaveLen(8))
|
||||
first := <-dropped
|
||||
for len(dropped) > 0 {
|
||||
Expect(<-dropped).To(Equal(first)) // these packets are all identical
|
||||
}
|
||||
var drop atomic.Bool
|
||||
dropped := make(chan []byte, 100)
|
||||
proxy, err := quicproxy.NewQuicProxy("localhost:0", &quicproxy.Opts{
|
||||
RemoteAddr: fmt.Sprintf("localhost:%d", server.Addr().(*net.UDPAddr).Port),
|
||||
DelayPacket: func(quicproxy.Direction, []byte) time.Duration {
|
||||
return 5 * time.Millisecond // 10ms RTT
|
||||
},
|
||||
DropPacket: func(dir quicproxy.Direction, b []byte) bool {
|
||||
if drop := drop.Load(); drop && dir == quicproxy.DirectionOutgoing {
|
||||
dropped <- b
|
||||
return true
|
||||
}
|
||||
return false
|
||||
},
|
||||
})
|
||||
})
|
||||
require.NoError(t, err)
|
||||
defer proxy.Close()
|
||||
|
||||
conn, err := quic.DialAddr(
|
||||
context.Background(),
|
||||
fmt.Sprintf("localhost:%d", proxy.LocalPort()),
|
||||
getTLSClientConfig(),
|
||||
getQuicConfig(nil),
|
||||
)
|
||||
require.NoError(t, err)
|
||||
defer conn.CloseWithError(0, "")
|
||||
|
||||
sconn, err := server.Accept(context.Background())
|
||||
require.NoError(t, err)
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
drop.Store(true)
|
||||
sconn.CloseWithError(1337, "closing")
|
||||
|
||||
// send 100 packets
|
||||
for i := 0; i < 100; i++ {
|
||||
str, err := conn.OpenStream()
|
||||
require.NoError(t, err)
|
||||
_, err = str.Write([]byte("foobar"))
|
||||
require.NoError(t, err)
|
||||
time.Sleep(time.Millisecond)
|
||||
}
|
||||
|
||||
// Expect retransmissions of the CONNECTION_CLOSE for the
|
||||
// 1st, 2nd, 4th, 8th, 16th, 32th, 64th packet: 7 in total (+1 for the original packet)
|
||||
var packets [][]byte
|
||||
for i := 0; i < 8; i++ {
|
||||
select {
|
||||
case p := <-dropped:
|
||||
packets = append(packets, p)
|
||||
case <-time.After(time.Second):
|
||||
t.Fatal("timeout waiting for CONNECTION_CLOSE retransmission")
|
||||
}
|
||||
}
|
||||
|
||||
// verify all retransmitted packets were identical
|
||||
for i := 1; i < len(packets); i++ {
|
||||
require.Equal(t, packets[0], packets[i])
|
||||
}
|
||||
}
|
||||
|
||||
@@ -7,124 +7,118 @@ import (
|
||||
"io"
|
||||
mrand "math/rand"
|
||||
"net"
|
||||
"testing"
|
||||
|
||||
"github.com/quic-go/quic-go"
|
||||
"github.com/quic-go/quic-go/internal/protocol"
|
||||
|
||||
. "github.com/onsi/ginkgo/v2"
|
||||
. "github.com/onsi/gomega"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
type connIDGenerator struct {
|
||||
length int
|
||||
Length int
|
||||
}
|
||||
|
||||
var _ quic.ConnectionIDGenerator = &connIDGenerator{}
|
||||
|
||||
func (c *connIDGenerator) GenerateConnectionID() (quic.ConnectionID, error) {
|
||||
b := make([]byte, c.length)
|
||||
b := make([]byte, c.Length)
|
||||
if _, err := rand.Read(b); err != nil {
|
||||
fmt.Fprintf(GinkgoWriter, "generating conn ID failed: %s", err)
|
||||
return quic.ConnectionID{}, fmt.Errorf("generating conn ID failed: %w", err)
|
||||
}
|
||||
return protocol.ParseConnectionID(b), nil
|
||||
}
|
||||
|
||||
func (c *connIDGenerator) ConnectionIDLen() int {
|
||||
return c.length
|
||||
func (c *connIDGenerator) ConnectionIDLen() int { return c.Length }
|
||||
|
||||
func randomConnIDLen() int { return 2 + int(mrand.Int31n(19)) }
|
||||
|
||||
func TestConnectionIDsZeroLength(t *testing.T) {
|
||||
testTransferWithConnectionIDs(t, randomConnIDLen(), 0, nil, nil)
|
||||
}
|
||||
|
||||
var _ = Describe("Connection ID lengths tests", func() {
|
||||
randomConnIDLen := func() int { return 4 + int(mrand.Int31n(15)) }
|
||||
func TestConnectionIDsRandomLengths(t *testing.T) {
|
||||
testTransferWithConnectionIDs(t, randomConnIDLen(), randomConnIDLen(), nil, nil)
|
||||
}
|
||||
|
||||
// connIDLen is ignored when connIDGenerator is set
|
||||
runServer := func(connIDLen int, connIDGenerator quic.ConnectionIDGenerator) (*quic.Listener, func()) {
|
||||
if connIDGenerator != nil {
|
||||
GinkgoWriter.Write([]byte(fmt.Sprintf("Using %d byte connection ID generator for the server\n", connIDGenerator.ConnectionIDLen())))
|
||||
} else {
|
||||
GinkgoWriter.Write([]byte(fmt.Sprintf("Using %d byte connection ID for the server\n", connIDLen)))
|
||||
}
|
||||
addr, err := net.ResolveUDPAddr("udp", "localhost:0")
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
conn, err := net.ListenUDP("udp", addr)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
tr := &quic.Transport{
|
||||
Conn: conn,
|
||||
ConnectionIDLength: connIDLen,
|
||||
ConnectionIDGenerator: connIDGenerator,
|
||||
}
|
||||
addTracer(tr)
|
||||
ln, err := tr.Listen(getTLSConfig(), getQuicConfig(nil))
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
go func() {
|
||||
defer GinkgoRecover()
|
||||
for {
|
||||
conn, err := ln.Accept(context.Background())
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
go func() {
|
||||
defer GinkgoRecover()
|
||||
str, err := conn.OpenStream()
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
defer str.Close()
|
||||
_, err = str.Write(PRData)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
}()
|
||||
}
|
||||
}()
|
||||
return ln, func() {
|
||||
ln.Close()
|
||||
tr.Close()
|
||||
}
|
||||
func TestConnectionIDsCustomGenerator(t *testing.T) {
|
||||
testTransferWithConnectionIDs(t, 0, 0,
|
||||
&connIDGenerator{Length: randomConnIDLen()},
|
||||
&connIDGenerator{Length: randomConnIDLen()},
|
||||
)
|
||||
}
|
||||
|
||||
// connIDLen is ignored when connIDGenerator is set
|
||||
func testTransferWithConnectionIDs(
|
||||
t *testing.T,
|
||||
serverConnIDLen, clientConnIDLen int,
|
||||
serverConnIDGenerator, clientConnIDGenerator quic.ConnectionIDGenerator,
|
||||
) {
|
||||
t.Helper()
|
||||
|
||||
if serverConnIDGenerator != nil {
|
||||
t.Logf("using %d byte connection ID generator for the server", serverConnIDGenerator.ConnectionIDLen())
|
||||
} else {
|
||||
t.Logf("using %d byte connection ID for the server", serverConnIDLen)
|
||||
}
|
||||
if clientConnIDGenerator != nil {
|
||||
t.Logf("using %d byte connection ID generator for the client", clientConnIDGenerator.ConnectionIDLen())
|
||||
} else {
|
||||
t.Logf("using %d byte connection ID for the client", clientConnIDLen)
|
||||
}
|
||||
|
||||
// connIDLen is ignored when connIDGenerator is set
|
||||
runClient := func(addr net.Addr, connIDLen int, connIDGenerator quic.ConnectionIDGenerator) {
|
||||
if connIDGenerator != nil {
|
||||
GinkgoWriter.Write([]byte(fmt.Sprintf("Using %d byte connection ID generator for the client\n", connIDGenerator.ConnectionIDLen())))
|
||||
} else {
|
||||
GinkgoWriter.Write([]byte(fmt.Sprintf("Using %d byte connection ID for the client\n", connIDLen)))
|
||||
}
|
||||
laddr, err := net.ResolveUDPAddr("udp", "localhost:0")
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
conn, err := net.ListenUDP("udp", laddr)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
defer conn.Close()
|
||||
tr := &quic.Transport{
|
||||
Conn: conn,
|
||||
ConnectionIDLength: connIDLen,
|
||||
ConnectionIDGenerator: connIDGenerator,
|
||||
}
|
||||
addTracer(tr)
|
||||
defer tr.Close()
|
||||
cl, err := tr.Dial(
|
||||
context.Background(),
|
||||
&net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: addr.(*net.UDPAddr).Port},
|
||||
getTLSClientConfig(),
|
||||
getQuicConfig(nil),
|
||||
)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
defer cl.CloseWithError(0, "")
|
||||
str, err := cl.AcceptStream(context.Background())
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
data, err := io.ReadAll(str)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(data).To(Equal(PRData))
|
||||
// setup server
|
||||
addr, err := net.ResolveUDPAddr("udp", "localhost:0")
|
||||
require.NoError(t, err)
|
||||
conn, err := net.ListenUDP("udp", addr)
|
||||
require.NoError(t, err)
|
||||
t.Cleanup(func() { conn.Close() })
|
||||
serverTr := &quic.Transport{
|
||||
Conn: conn,
|
||||
ConnectionIDLength: serverConnIDLen,
|
||||
ConnectionIDGenerator: serverConnIDGenerator,
|
||||
}
|
||||
t.Cleanup(func() { serverTr.Close() })
|
||||
addTracer(serverTr)
|
||||
ln, err := serverTr.Listen(getTLSConfig(), getQuicConfig(nil))
|
||||
require.NoError(t, err)
|
||||
|
||||
It("downloads a file using a 0-byte connection ID for the client", func() {
|
||||
ln, closeFn := runServer(randomConnIDLen(), nil)
|
||||
defer closeFn()
|
||||
runClient(ln.Addr(), 0, nil)
|
||||
})
|
||||
// setup client
|
||||
laddr, err := net.ResolveUDPAddr("udp", "localhost:0")
|
||||
require.NoError(t, err)
|
||||
clientConn, err := net.ListenUDP("udp", laddr)
|
||||
require.NoError(t, err)
|
||||
t.Cleanup(func() { clientConn.Close() })
|
||||
clientTr := &quic.Transport{
|
||||
Conn: clientConn,
|
||||
ConnectionIDLength: clientConnIDLen,
|
||||
ConnectionIDGenerator: clientConnIDGenerator,
|
||||
}
|
||||
t.Cleanup(func() { clientTr.Close() })
|
||||
addTracer(clientTr)
|
||||
|
||||
It("downloads a file when both client and server use a random connection ID length", func() {
|
||||
ln, closeFn := runServer(randomConnIDLen(), nil)
|
||||
defer closeFn()
|
||||
runClient(ln.Addr(), randomConnIDLen(), nil)
|
||||
})
|
||||
cl, err := clientTr.Dial(
|
||||
context.Background(),
|
||||
&net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: ln.Addr().(*net.UDPAddr).Port},
|
||||
getTLSClientConfig(),
|
||||
getQuicConfig(nil),
|
||||
)
|
||||
require.NoError(t, err)
|
||||
t.Cleanup(func() { cl.CloseWithError(0, "") })
|
||||
|
||||
It("downloads a file when both client and server use a custom connection ID generator", func() {
|
||||
ln, closeFn := runServer(0, &connIDGenerator{length: randomConnIDLen()})
|
||||
defer closeFn()
|
||||
runClient(ln.Addr(), 0, &connIDGenerator{length: randomConnIDLen()})
|
||||
})
|
||||
})
|
||||
serverConn, err := ln.Accept(context.Background())
|
||||
require.NoError(t, err)
|
||||
serverStr, err := serverConn.OpenStream()
|
||||
require.NoError(t, err)
|
||||
t.Cleanup(func() { serverConn.CloseWithError(0, "") })
|
||||
|
||||
go func() {
|
||||
serverStr.Write(PRData)
|
||||
serverStr.Close()
|
||||
}()
|
||||
|
||||
str, err := cl.AcceptStream(context.Background())
|
||||
require.NoError(t, err)
|
||||
data, err := io.ReadAll(str)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, PRData, data)
|
||||
}
|
||||
|
||||
@@ -1,187 +1,244 @@
|
||||
package self_test
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/binary"
|
||||
"fmt"
|
||||
mrand "math/rand"
|
||||
mrand "math/rand/v2"
|
||||
"net"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/quic-go/quic-go"
|
||||
quicproxy "github.com/quic-go/quic-go/integrationtests/tools/proxy"
|
||||
"github.com/quic-go/quic-go/internal/wire"
|
||||
|
||||
. "github.com/onsi/ginkgo/v2"
|
||||
. "github.com/onsi/gomega"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
var _ = Describe("Datagram test", func() {
|
||||
const concurrentSends = 100
|
||||
const maxDatagramSize = 250
|
||||
func TestDatagramNegotiation(t *testing.T) {
|
||||
t.Run("server enable, client enable", func(t *testing.T) {
|
||||
testDatagramNegotiation(t, true, true)
|
||||
})
|
||||
t.Run("server enable, client disable", func(t *testing.T) {
|
||||
testDatagramNegotiation(t, true, false)
|
||||
})
|
||||
t.Run("server disable, client enable", func(t *testing.T) {
|
||||
testDatagramNegotiation(t, false, true)
|
||||
})
|
||||
t.Run("server disable, client disable", func(t *testing.T) {
|
||||
testDatagramNegotiation(t, false, false)
|
||||
})
|
||||
}
|
||||
|
||||
var (
|
||||
serverConn, clientConn *net.UDPConn
|
||||
dropped, total atomic.Int32
|
||||
func testDatagramNegotiation(t *testing.T, serverEnableDatagram, clientEnableDatagram bool) {
|
||||
udpConn, err := net.ListenUDP("udp", &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 0})
|
||||
require.NoError(t, err)
|
||||
defer udpConn.Close()
|
||||
server, err := quic.Listen(
|
||||
udpConn,
|
||||
getTLSConfig(),
|
||||
getQuicConfig(&quic.Config{EnableDatagrams: serverEnableDatagram}),
|
||||
)
|
||||
require.NoError(t, err)
|
||||
defer server.Close()
|
||||
|
||||
startServerAndProxy := func(enableDatagram, expectDatagramSupport bool) (port int, closeFn func()) {
|
||||
addr, err := net.ResolveUDPAddr("udp", "localhost:0")
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
serverConn, err = net.ListenUDP("udp", addr)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
ln, err := quic.Listen(
|
||||
serverConn,
|
||||
getTLSConfig(),
|
||||
getQuicConfig(&quic.Config{EnableDatagrams: enableDatagram}),
|
||||
)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
|
||||
defer cancel()
|
||||
clientConn, err := quic.DialAddr(
|
||||
ctx,
|
||||
server.Addr().String(),
|
||||
getTLSClientConfig(),
|
||||
getQuicConfig(&quic.Config{EnableDatagrams: clientEnableDatagram}),
|
||||
)
|
||||
require.NoError(t, err)
|
||||
defer clientConn.CloseWithError(0, "")
|
||||
|
||||
accepted := make(chan struct{})
|
||||
go func() {
|
||||
defer GinkgoRecover()
|
||||
defer close(accepted)
|
||||
conn, err := ln.Accept(context.Background())
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
serverConn, err := server.Accept(ctx)
|
||||
require.NoError(t, err)
|
||||
defer serverConn.CloseWithError(0, "")
|
||||
|
||||
if expectDatagramSupport {
|
||||
Expect(conn.ConnectionState().SupportsDatagrams).To(BeTrue())
|
||||
if enableDatagram {
|
||||
f := &wire.DatagramFrame{DataLenPresent: true}
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(concurrentSends)
|
||||
for i := 0; i < concurrentSends; i++ {
|
||||
go func(i int) {
|
||||
defer GinkgoRecover()
|
||||
defer wg.Done()
|
||||
b := make([]byte, 8)
|
||||
binary.BigEndian.PutUint64(b, uint64(i))
|
||||
Expect(conn.SendDatagram(b)).To(Succeed())
|
||||
}(i)
|
||||
}
|
||||
maxDatagramMessageSize := f.MaxDataLen(maxDatagramSize, conn.ConnectionState().Version)
|
||||
b := make([]byte, maxDatagramMessageSize+1)
|
||||
Expect(conn.SendDatagram(b)).To(MatchError(&quic.DatagramTooLargeError{
|
||||
MaxDatagramPayloadSize: int64(maxDatagramMessageSize),
|
||||
}))
|
||||
wg.Wait()
|
||||
}
|
||||
} else {
|
||||
Expect(conn.ConnectionState().SupportsDatagrams).To(BeFalse())
|
||||
}
|
||||
}()
|
||||
|
||||
serverPort := ln.Addr().(*net.UDPAddr).Port
|
||||
proxy, err := quicproxy.NewQuicProxy("localhost:0", &quicproxy.Opts{
|
||||
RemoteAddr: fmt.Sprintf("localhost:%d", serverPort),
|
||||
// drop 10% of Short Header packets sent from the server
|
||||
DropPacket: func(dir quicproxy.Direction, packet []byte) bool {
|
||||
if dir == quicproxy.DirectionIncoming {
|
||||
return false
|
||||
}
|
||||
// don't drop Long Header packets
|
||||
if wire.IsLongHeaderPacket(packet[0]) {
|
||||
return false
|
||||
}
|
||||
drop := mrand.Int()%10 == 0
|
||||
if drop {
|
||||
dropped.Add(1)
|
||||
}
|
||||
total.Add(1)
|
||||
return drop
|
||||
},
|
||||
})
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
return proxy.LocalPort(), func() {
|
||||
Eventually(accepted).Should(BeClosed())
|
||||
proxy.Close()
|
||||
ln.Close()
|
||||
}
|
||||
if clientEnableDatagram {
|
||||
require.True(t, serverConn.ConnectionState().SupportsDatagrams)
|
||||
require.NoError(t, serverConn.SendDatagram([]byte("foo")))
|
||||
datagram, err := clientConn.ReceiveDatagram(ctx)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, []byte("foo"), datagram)
|
||||
} else {
|
||||
require.False(t, serverConn.ConnectionState().SupportsDatagrams)
|
||||
require.Error(t, serverConn.SendDatagram([]byte("foo")))
|
||||
}
|
||||
|
||||
BeforeEach(func() {
|
||||
addr, err := net.ResolveUDPAddr("udp", "localhost:0")
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
clientConn, err = net.ListenUDP("udp", addr)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
})
|
||||
if serverEnableDatagram {
|
||||
require.True(t, clientConn.ConnectionState().SupportsDatagrams)
|
||||
require.NoError(t, clientConn.SendDatagram([]byte("bar")))
|
||||
datagram, err := serverConn.ReceiveDatagram(ctx)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, []byte("bar"), datagram)
|
||||
} else {
|
||||
require.False(t, clientConn.ConnectionState().SupportsDatagrams)
|
||||
require.Error(t, clientConn.SendDatagram([]byte("bar")))
|
||||
}
|
||||
}
|
||||
|
||||
It("sends datagrams", func() {
|
||||
oldMaxDatagramSize := wire.MaxDatagramSize
|
||||
wire.MaxDatagramSize = maxDatagramSize
|
||||
proxyPort, close := startServerAndProxy(true, true)
|
||||
defer close()
|
||||
raddr, err := net.ResolveUDPAddr("udp", fmt.Sprintf("localhost:%d", proxyPort))
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
conn, err := quic.Dial(
|
||||
context.Background(),
|
||||
clientConn,
|
||||
raddr,
|
||||
getTLSClientConfig(),
|
||||
getQuicConfig(&quic.Config{EnableDatagrams: true}),
|
||||
)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(conn.ConnectionState().SupportsDatagrams).To(BeTrue())
|
||||
var counter int
|
||||
for {
|
||||
// Close the connection if no message is received for 100 ms.
|
||||
timer := time.AfterFunc(scaleDuration(100*time.Millisecond), func() { conn.CloseWithError(0, "") })
|
||||
if _, err := conn.ReceiveDatagram(context.Background()); err != nil {
|
||||
break
|
||||
func TestDatagramSizeLimit(t *testing.T) {
|
||||
const maxDatagramSize = 456
|
||||
originalMaxDatagramSize := wire.MaxDatagramSize
|
||||
wire.MaxDatagramSize = maxDatagramSize
|
||||
t.Cleanup(func() { wire.MaxDatagramSize = originalMaxDatagramSize })
|
||||
|
||||
udpConn, err := net.ListenUDP("udp", &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 0})
|
||||
require.NoError(t, err)
|
||||
defer udpConn.Close()
|
||||
server, err := quic.Listen(
|
||||
udpConn,
|
||||
getTLSConfig(),
|
||||
getQuicConfig(&quic.Config{EnableDatagrams: true}),
|
||||
)
|
||||
require.NoError(t, err)
|
||||
defer server.Close()
|
||||
|
||||
clientConn, err := quic.DialAddr(
|
||||
context.Background(),
|
||||
server.Addr().String(),
|
||||
getTLSClientConfig(),
|
||||
getQuicConfig(&quic.Config{EnableDatagrams: true}),
|
||||
)
|
||||
require.NoError(t, err)
|
||||
defer clientConn.CloseWithError(0, "")
|
||||
|
||||
err = clientConn.SendDatagram(bytes.Repeat([]byte("a"), maxDatagramSize+100)) // definitely too large
|
||||
require.Error(t, err)
|
||||
var sizeErr *quic.DatagramTooLargeError
|
||||
require.ErrorAs(t, err, &sizeErr)
|
||||
require.InDelta(t, sizeErr.MaxDatagramPayloadSize, maxDatagramSize, 10)
|
||||
|
||||
require.NoError(t, clientConn.SendDatagram(bytes.Repeat([]byte("b"), int(sizeErr.MaxDatagramPayloadSize))))
|
||||
require.Error(t, clientConn.SendDatagram(bytes.Repeat([]byte("c"), int(sizeErr.MaxDatagramPayloadSize+1))))
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
|
||||
defer cancel()
|
||||
serverConn, err := server.Accept(ctx)
|
||||
require.NoError(t, err)
|
||||
defer serverConn.CloseWithError(0, "")
|
||||
datagram, err := serverConn.ReceiveDatagram(ctx)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, bytes.Repeat([]byte("b"), int(sizeErr.MaxDatagramPayloadSize)), datagram)
|
||||
}
|
||||
|
||||
func TestDatagramLoss(t *testing.T) {
|
||||
const rtt = 10 * time.Millisecond
|
||||
const numDatagrams = 100
|
||||
const datagramSize = 500
|
||||
|
||||
udpConn, err := net.ListenUDP("udp", &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 0})
|
||||
require.NoError(t, err)
|
||||
defer udpConn.Close()
|
||||
server, err := quic.Listen(
|
||||
udpConn,
|
||||
getTLSConfig(),
|
||||
getQuicConfig(&quic.Config{EnableDatagrams: true}),
|
||||
)
|
||||
require.NoError(t, err)
|
||||
defer server.Close()
|
||||
|
||||
var droppedIncoming, droppedOutgoing, total atomic.Int32
|
||||
proxy, err := quicproxy.NewQuicProxy("localhost:0", &quicproxy.Opts{
|
||||
RemoteAddr: fmt.Sprintf("localhost:%d", server.Addr().(*net.UDPAddr).Port),
|
||||
DropPacket: func(dir quicproxy.Direction, packet []byte) bool {
|
||||
if wire.IsLongHeaderPacket(packet[0]) { // don't drop Long Header packets
|
||||
return false
|
||||
}
|
||||
timer.Stop()
|
||||
counter++
|
||||
if len(packet) < datagramSize { // don't drop ACK-only packets
|
||||
return false
|
||||
}
|
||||
total.Add(1)
|
||||
if mrand.Int()%10 == 0 {
|
||||
switch dir {
|
||||
case quicproxy.DirectionIncoming:
|
||||
droppedIncoming.Add(1)
|
||||
case quicproxy.DirectionOutgoing:
|
||||
droppedOutgoing.Add(1)
|
||||
}
|
||||
return true
|
||||
}
|
||||
return false
|
||||
},
|
||||
DelayPacket: func(dir quicproxy.Direction, packet []byte) time.Duration { return rtt / 2 },
|
||||
})
|
||||
require.NoError(t, err)
|
||||
defer proxy.Close()
|
||||
|
||||
clientConn, err := quic.DialAddr(
|
||||
context.Background(),
|
||||
fmt.Sprintf("localhost:%d", proxy.LocalPort()),
|
||||
getTLSClientConfig(),
|
||||
getQuicConfig(&quic.Config{EnableDatagrams: true}),
|
||||
)
|
||||
require.NoError(t, err)
|
||||
defer clientConn.CloseWithError(0, "")
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), scaleDuration(numDatagrams*time.Millisecond))
|
||||
defer cancel()
|
||||
|
||||
serverConn, err := server.Accept(ctx)
|
||||
require.NoError(t, err)
|
||||
defer serverConn.CloseWithError(0, "")
|
||||
|
||||
var clientDatagrams, serverDatagrams int
|
||||
clientErrChan := make(chan error, 1)
|
||||
go func() {
|
||||
defer close(clientErrChan)
|
||||
for {
|
||||
if _, err := clientConn.ReceiveDatagram(ctx); err != nil {
|
||||
clientErrChan <- err
|
||||
return
|
||||
}
|
||||
clientDatagrams++
|
||||
}
|
||||
}()
|
||||
|
||||
numDropped := int(dropped.Load())
|
||||
expVal := concurrentSends - numDropped
|
||||
fmt.Fprintf(GinkgoWriter, "Dropped %d out of %d packets.\n", numDropped, total.Load())
|
||||
fmt.Fprintf(GinkgoWriter, "Received %d out of %d sent datagrams.\n", counter, concurrentSends)
|
||||
Expect(counter).To(And(
|
||||
BeNumerically(">", expVal*9/10),
|
||||
BeNumerically("<", concurrentSends),
|
||||
))
|
||||
Eventually(conn.Context().Done).Should(BeClosed())
|
||||
wire.MaxDatagramSize = oldMaxDatagramSize
|
||||
})
|
||||
for i := 0; i < numDatagrams; i++ {
|
||||
payload := bytes.Repeat([]byte{uint8(i)}, datagramSize)
|
||||
require.NoError(t, clientConn.SendDatagram(payload))
|
||||
require.NoError(t, serverConn.SendDatagram(payload))
|
||||
time.Sleep(scaleDuration(time.Millisecond / 2))
|
||||
}
|
||||
|
||||
It("server can disable datagram", func() {
|
||||
proxyPort, close := startServerAndProxy(false, true)
|
||||
raddr, err := net.ResolveUDPAddr("udp", fmt.Sprintf("localhost:%d", proxyPort))
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
conn, err := quic.Dial(
|
||||
context.Background(),
|
||||
clientConn,
|
||||
raddr,
|
||||
getTLSClientConfig(),
|
||||
getQuicConfig(&quic.Config{EnableDatagrams: true}),
|
||||
)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(conn.ConnectionState().SupportsDatagrams).To(BeFalse())
|
||||
serverErrChan := make(chan error, 1)
|
||||
go func() {
|
||||
defer close(serverErrChan)
|
||||
for {
|
||||
if _, err := serverConn.ReceiveDatagram(ctx); err != nil {
|
||||
serverErrChan <- err
|
||||
return
|
||||
}
|
||||
serverDatagrams++
|
||||
}
|
||||
}()
|
||||
|
||||
close()
|
||||
conn.CloseWithError(0, "")
|
||||
})
|
||||
select {
|
||||
case err := <-clientErrChan:
|
||||
require.ErrorIs(t, err, context.DeadlineExceeded)
|
||||
case <-time.After(scaleDuration(5 * numDatagrams * time.Millisecond)):
|
||||
t.Fatal("timeout")
|
||||
}
|
||||
select {
|
||||
case err := <-serverErrChan:
|
||||
require.ErrorIs(t, err, context.DeadlineExceeded)
|
||||
case <-time.After(scaleDuration(5 * numDatagrams * time.Millisecond)):
|
||||
t.Fatal("timeout")
|
||||
}
|
||||
|
||||
It("client can disable datagram", func() {
|
||||
proxyPort, close := startServerAndProxy(false, true)
|
||||
raddr, err := net.ResolveUDPAddr("udp", fmt.Sprintf("localhost:%d", proxyPort))
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
conn, err := quic.Dial(
|
||||
context.Background(),
|
||||
clientConn,
|
||||
raddr,
|
||||
getTLSClientConfig(),
|
||||
getQuicConfig(&quic.Config{EnableDatagrams: true}),
|
||||
)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(conn.ConnectionState().SupportsDatagrams).To(BeFalse())
|
||||
|
||||
Expect(conn.SendDatagram([]byte{0})).To(HaveOccurred())
|
||||
|
||||
close()
|
||||
conn.CloseWithError(0, "")
|
||||
})
|
||||
})
|
||||
numDroppedIncoming := droppedIncoming.Load()
|
||||
numDroppedOutgoing := droppedOutgoing.Load()
|
||||
t.Logf("dropped %d incoming and %d outgoing out of %d packets", numDroppedIncoming, numDroppedOutgoing, total.Load())
|
||||
assert.NotZero(t, numDroppedIncoming)
|
||||
assert.NotZero(t, numDroppedOutgoing)
|
||||
t.Logf("server received %d out of %d sent datagrams", serverDatagrams, numDatagrams)
|
||||
assert.InDelta(t, numDatagrams-numDroppedIncoming, serverDatagrams, numDatagrams/20, "datagrams received by the server")
|
||||
t.Logf("client received %d out of %d sent datagrams", clientDatagrams, numDatagrams)
|
||||
assert.InDelta(t, numDatagrams-numDroppedOutgoing, clientDatagrams, numDatagrams/20, "datagrams received by the client")
|
||||
}
|
||||
|
||||
@@ -1,217 +1,240 @@
|
||||
package self_test
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/quic-go/quic-go"
|
||||
|
||||
. "github.com/onsi/ginkgo/v2"
|
||||
. "github.com/onsi/gomega"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
var _ = Describe("Stream deadline tests", func() {
|
||||
setup := func() (serverStr, clientStr quic.Stream, close func()) {
|
||||
server, err := quic.ListenAddr("localhost:0", getTLSConfig(), getQuicConfig(nil))
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
strChan := make(chan quic.SendStream)
|
||||
go func() {
|
||||
defer GinkgoRecover()
|
||||
conn, err := server.Accept(context.Background())
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
str, err := conn.AcceptStream(context.Background())
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
_, err = str.Read([]byte{0})
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
strChan <- str
|
||||
}()
|
||||
func setupDeadlineTest(t *testing.T) (serverStr, clientStr quic.Stream) {
|
||||
t.Helper()
|
||||
server, err := quic.ListenAddr("localhost:0", getTLSConfig(), getQuicConfig(nil))
|
||||
require.NoError(t, err)
|
||||
t.Cleanup(func() { server.Close() })
|
||||
|
||||
conn, err := quic.DialAddr(
|
||||
context.Background(),
|
||||
fmt.Sprintf("localhost:%d", server.Addr().(*net.UDPAddr).Port),
|
||||
getTLSClientConfig(),
|
||||
getQuicConfig(nil),
|
||||
)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
clientStr, err = conn.OpenStream()
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
_, err = clientStr.Write([]byte{0}) // need to write one byte so the server learns about the stream
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Eventually(strChan).Should(Receive(&serverStr))
|
||||
return serverStr, clientStr, func() {
|
||||
Expect(server.Close()).To(Succeed())
|
||||
Expect(conn.CloseWithError(0, "")).To(Succeed())
|
||||
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
|
||||
defer cancel()
|
||||
conn, err := quic.DialAddr(
|
||||
ctx,
|
||||
fmt.Sprintf("localhost:%d", server.Addr().(*net.UDPAddr).Port),
|
||||
getTLSClientConfig(),
|
||||
getQuicConfig(nil),
|
||||
)
|
||||
require.NoError(t, err)
|
||||
t.Cleanup(func() { conn.CloseWithError(0, "") })
|
||||
clientStr, err = conn.OpenStream()
|
||||
require.NoError(t, err)
|
||||
_, err = clientStr.Write([]byte{0}) // need to write one byte so the server learns about the stream
|
||||
require.NoError(t, err)
|
||||
|
||||
serverConn, err := server.Accept(ctx)
|
||||
require.NoError(t, err)
|
||||
t.Cleanup(func() { serverConn.CloseWithError(0, "") })
|
||||
serverStr, err = serverConn.AcceptStream(ctx)
|
||||
require.NoError(t, err)
|
||||
|
||||
_, err = serverStr.Read([]byte{0})
|
||||
require.NoError(t, err)
|
||||
return serverStr, clientStr
|
||||
}
|
||||
|
||||
func TestReadDeadlineSync(t *testing.T) {
|
||||
serverStr, clientStr := setupDeadlineTest(t)
|
||||
|
||||
const timeout = time.Millisecond
|
||||
errChan := make(chan error, 1)
|
||||
go func() {
|
||||
_, err := serverStr.Write(PRDataLong)
|
||||
errChan <- err
|
||||
}()
|
||||
|
||||
var bytesRead int
|
||||
var timeoutCounter int
|
||||
buf := make([]byte, 1<<10)
|
||||
data := make([]byte, len(PRDataLong))
|
||||
clientStr.SetReadDeadline(time.Now().Add(timeout))
|
||||
for bytesRead < len(PRDataLong) {
|
||||
n, err := clientStr.Read(buf)
|
||||
if nerr, ok := err.(net.Error); ok && nerr.Timeout() {
|
||||
timeoutCounter++
|
||||
clientStr.SetReadDeadline(time.Now().Add(timeout))
|
||||
} else {
|
||||
require.NoError(t, err)
|
||||
}
|
||||
copy(data[bytesRead:], buf[:n])
|
||||
bytesRead += n
|
||||
}
|
||||
require.Equal(t, PRDataLong, data)
|
||||
// make sure the test actually worked and Read actually ran into the deadline a few times
|
||||
t.Logf("ran into deadline %d times", timeoutCounter)
|
||||
require.GreaterOrEqual(t, timeoutCounter, 10)
|
||||
select {
|
||||
case err := <-errChan:
|
||||
require.NoError(t, err)
|
||||
case <-time.After(time.Second):
|
||||
t.Fatal("timeout")
|
||||
}
|
||||
}
|
||||
|
||||
func TestReadDeadlineAsync(t *testing.T) {
|
||||
serverStr, clientStr := setupDeadlineTest(t)
|
||||
|
||||
const timeout = time.Millisecond
|
||||
errChan := make(chan error, 1)
|
||||
go func() {
|
||||
_, err := serverStr.Write(PRDataLong)
|
||||
errChan <- err
|
||||
}()
|
||||
|
||||
var bytesRead int
|
||||
var timeoutCounter int
|
||||
buf := make([]byte, 1<<10)
|
||||
data := make([]byte, len(PRDataLong))
|
||||
received := make(chan struct{})
|
||||
go func() {
|
||||
for {
|
||||
select {
|
||||
case <-received:
|
||||
return
|
||||
default:
|
||||
time.Sleep(timeout)
|
||||
}
|
||||
clientStr.SetReadDeadline(time.Now().Add(timeout))
|
||||
}
|
||||
}()
|
||||
|
||||
for bytesRead < len(PRDataLong) {
|
||||
n, err := clientStr.Read(buf)
|
||||
if nerr, ok := err.(net.Error); ok && nerr.Timeout() {
|
||||
timeoutCounter++
|
||||
} else {
|
||||
require.NoError(t, err)
|
||||
}
|
||||
copy(data[bytesRead:], buf[:n])
|
||||
bytesRead += n
|
||||
}
|
||||
|
||||
Context("read deadlines", func() {
|
||||
It("completes a transfer when the deadline is set", func() {
|
||||
serverStr, clientStr, closeFn := setup()
|
||||
defer closeFn()
|
||||
require.Equal(t, PRDataLong, data)
|
||||
close(received)
|
||||
|
||||
const timeout = time.Millisecond
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
defer GinkgoRecover()
|
||||
_, err := serverStr.Write(PRDataLong)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
close(done)
|
||||
}()
|
||||
// make sure the test actually worked and Read actually ran into the deadline a few times
|
||||
t.Logf("ran into deadline %d times", timeoutCounter)
|
||||
require.GreaterOrEqual(t, timeoutCounter, 10)
|
||||
select {
|
||||
case err := <-errChan:
|
||||
require.NoError(t, err)
|
||||
case <-time.After(time.Second):
|
||||
t.Fatal("timeout")
|
||||
}
|
||||
}
|
||||
|
||||
var bytesRead int
|
||||
var timeoutCounter int
|
||||
buf := make([]byte, 1<<10)
|
||||
data := make([]byte, len(PRDataLong))
|
||||
clientStr.SetReadDeadline(time.Now().Add(timeout))
|
||||
for bytesRead < len(PRDataLong) {
|
||||
n, err := clientStr.Read(buf)
|
||||
if nerr, ok := err.(net.Error); ok && nerr.Timeout() {
|
||||
timeoutCounter++
|
||||
clientStr.SetReadDeadline(time.Now().Add(timeout))
|
||||
} else {
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
}
|
||||
copy(data[bytesRead:], buf[:n])
|
||||
bytesRead += n
|
||||
}
|
||||
Expect(data).To(Equal(PRDataLong))
|
||||
// make sure the test actually worked and Read actually ran into the deadline a few times
|
||||
Expect(timeoutCounter).To(BeNumerically(">=", 10))
|
||||
Eventually(done).Should(BeClosed())
|
||||
})
|
||||
func TestWriteDeadlineSync(t *testing.T) {
|
||||
serverStr, clientStr := setupDeadlineTest(t)
|
||||
|
||||
It("completes a transfer when the deadline is set concurrently", func() {
|
||||
serverStr, clientStr, closeFn := setup()
|
||||
defer closeFn()
|
||||
const timeout = time.Millisecond
|
||||
|
||||
const timeout = time.Millisecond
|
||||
go func() {
|
||||
defer GinkgoRecover()
|
||||
_, err := serverStr.Write(PRDataLong)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
}()
|
||||
errChan := make(chan error, 1)
|
||||
go func() {
|
||||
defer close(errChan)
|
||||
data, err := io.ReadAll(serverStr)
|
||||
if err != nil {
|
||||
errChan <- err
|
||||
}
|
||||
if !bytes.Equal(PRDataLong, data) {
|
||||
errChan <- fmt.Errorf("data mismatch")
|
||||
}
|
||||
}()
|
||||
|
||||
var bytesRead int
|
||||
var timeoutCounter int
|
||||
buf := make([]byte, 1<<10)
|
||||
data := make([]byte, len(PRDataLong))
|
||||
clientStr.SetReadDeadline(time.Now().Add(timeout))
|
||||
deadlineDone := make(chan struct{})
|
||||
received := make(chan struct{})
|
||||
go func() {
|
||||
defer close(deadlineDone)
|
||||
for {
|
||||
select {
|
||||
case <-received:
|
||||
return
|
||||
default:
|
||||
time.Sleep(timeout)
|
||||
}
|
||||
clientStr.SetReadDeadline(time.Now().Add(timeout))
|
||||
}
|
||||
}()
|
||||
|
||||
for bytesRead < len(PRDataLong) {
|
||||
n, err := clientStr.Read(buf)
|
||||
if nerr, ok := err.(net.Error); ok && nerr.Timeout() {
|
||||
timeoutCounter++
|
||||
} else {
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
}
|
||||
copy(data[bytesRead:], buf[:n])
|
||||
bytesRead += n
|
||||
}
|
||||
close(received)
|
||||
Expect(data).To(Equal(PRDataLong))
|
||||
// make sure the test actually worked an Read actually ran into the deadline a few times
|
||||
Expect(timeoutCounter).To(BeNumerically(">=", 10))
|
||||
Eventually(deadlineDone).Should(BeClosed())
|
||||
})
|
||||
})
|
||||
|
||||
Context("write deadlines", func() {
|
||||
It("completes a transfer when the deadline is set", func() {
|
||||
serverStr, clientStr, closeFn := setup()
|
||||
defer closeFn()
|
||||
|
||||
const timeout = time.Millisecond
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
defer GinkgoRecover()
|
||||
data, err := io.ReadAll(serverStr)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(data).To(Equal(PRDataLong))
|
||||
close(done)
|
||||
}()
|
||||
|
||||
var bytesWritten int
|
||||
var timeoutCounter int
|
||||
var bytesWritten int
|
||||
var timeoutCounter int
|
||||
clientStr.SetWriteDeadline(time.Now().Add(timeout))
|
||||
for bytesWritten < len(PRDataLong) {
|
||||
n, err := clientStr.Write(PRDataLong[bytesWritten:])
|
||||
if nerr, ok := err.(net.Error); ok && nerr.Timeout() {
|
||||
timeoutCounter++
|
||||
clientStr.SetWriteDeadline(time.Now().Add(timeout))
|
||||
for bytesWritten < len(PRDataLong) {
|
||||
n, err := clientStr.Write(PRDataLong[bytesWritten:])
|
||||
if nerr, ok := err.(net.Error); ok && nerr.Timeout() {
|
||||
timeoutCounter++
|
||||
clientStr.SetWriteDeadline(time.Now().Add(timeout))
|
||||
} else {
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
}
|
||||
bytesWritten += n
|
||||
} else {
|
||||
require.NoError(t, err)
|
||||
}
|
||||
bytesWritten += n
|
||||
}
|
||||
clientStr.Close()
|
||||
|
||||
// make sure the test actually worked and Write actually ran into the deadline a few times
|
||||
t.Logf("ran into deadline %d times", timeoutCounter)
|
||||
require.GreaterOrEqual(t, timeoutCounter, 10)
|
||||
select {
|
||||
case err := <-errChan:
|
||||
require.NoError(t, err)
|
||||
case <-time.After(time.Second):
|
||||
t.Fatal("timeout")
|
||||
}
|
||||
}
|
||||
|
||||
func TestWriteDeadlineAsync(t *testing.T) {
|
||||
serverStr, clientStr := setupDeadlineTest(t)
|
||||
|
||||
const timeout = time.Millisecond
|
||||
|
||||
errChan := make(chan error, 1)
|
||||
go func() {
|
||||
defer close(errChan)
|
||||
data, err := io.ReadAll(serverStr)
|
||||
if err != nil {
|
||||
errChan <- err
|
||||
}
|
||||
if !bytes.Equal(PRDataLong, data) {
|
||||
errChan <- fmt.Errorf("data mismatch")
|
||||
}
|
||||
}()
|
||||
|
||||
clientStr.SetWriteDeadline(time.Now().Add(timeout))
|
||||
readDone := make(chan struct{})
|
||||
deadlineDone := make(chan struct{})
|
||||
go func() {
|
||||
defer close(deadlineDone)
|
||||
for {
|
||||
select {
|
||||
case <-readDone:
|
||||
return
|
||||
default:
|
||||
time.Sleep(timeout)
|
||||
}
|
||||
clientStr.Close()
|
||||
// make sure the test actually worked an Read actually ran into the deadline a few times
|
||||
Expect(timeoutCounter).To(BeNumerically(">=", 10))
|
||||
Eventually(done).Should(BeClosed())
|
||||
})
|
||||
|
||||
It("completes a transfer when the deadline is set concurrently", func() {
|
||||
serverStr, clientStr, closeFn := setup()
|
||||
defer closeFn()
|
||||
|
||||
const timeout = time.Millisecond
|
||||
readDone := make(chan struct{})
|
||||
go func() {
|
||||
defer GinkgoRecover()
|
||||
data, err := io.ReadAll(serverStr)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(data).To(Equal(PRDataLong))
|
||||
close(readDone)
|
||||
}()
|
||||
|
||||
clientStr.SetWriteDeadline(time.Now().Add(timeout))
|
||||
deadlineDone := make(chan struct{})
|
||||
go func() {
|
||||
defer close(deadlineDone)
|
||||
for {
|
||||
select {
|
||||
case <-readDone:
|
||||
return
|
||||
default:
|
||||
time.Sleep(timeout)
|
||||
}
|
||||
clientStr.SetWriteDeadline(time.Now().Add(timeout))
|
||||
}
|
||||
}()
|
||||
}
|
||||
}()
|
||||
|
||||
var bytesWritten int
|
||||
var timeoutCounter int
|
||||
clientStr.SetWriteDeadline(time.Now().Add(timeout))
|
||||
for bytesWritten < len(PRDataLong) {
|
||||
n, err := clientStr.Write(PRDataLong[bytesWritten:])
|
||||
if nerr, ok := err.(net.Error); ok && nerr.Timeout() {
|
||||
timeoutCounter++
|
||||
} else {
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
}
|
||||
bytesWritten += n
|
||||
}
|
||||
clientStr.Close()
|
||||
// make sure the test actually worked an Read actually ran into the deadline a few times
|
||||
Expect(timeoutCounter).To(BeNumerically(">=", 10))
|
||||
Eventually(readDone).Should(BeClosed())
|
||||
Eventually(deadlineDone).Should(BeClosed())
|
||||
})
|
||||
})
|
||||
})
|
||||
var bytesWritten int
|
||||
var timeoutCounter int
|
||||
clientStr.SetWriteDeadline(time.Now().Add(timeout))
|
||||
for bytesWritten < len(PRDataLong) {
|
||||
n, err := clientStr.Write(PRDataLong[bytesWritten:])
|
||||
if nerr, ok := err.(net.Error); ok && nerr.Timeout() {
|
||||
timeoutCounter++
|
||||
} else {
|
||||
require.NoError(t, err)
|
||||
}
|
||||
bytesWritten += n
|
||||
}
|
||||
clientStr.Close()
|
||||
|
||||
close(readDone)
|
||||
|
||||
// make sure the test actually worked and Write actually ran into the deadline a few times
|
||||
t.Logf("ran into deadline %d times", timeoutCounter)
|
||||
require.GreaterOrEqual(t, timeoutCounter, 10)
|
||||
select {
|
||||
case err := <-errChan:
|
||||
require.NoError(t, err)
|
||||
case <-time.After(time.Second):
|
||||
t.Fatal("timeout")
|
||||
}
|
||||
}
|
||||
|
||||
@@ -3,121 +3,92 @@ package self_test
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"math/rand"
|
||||
"net"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/quic-go/quic-go"
|
||||
quicproxy "github.com/quic-go/quic-go/integrationtests/tools/proxy"
|
||||
"github.com/quic-go/quic-go/internal/wire"
|
||||
|
||||
. "github.com/onsi/ginkgo/v2"
|
||||
. "github.com/onsi/gomega"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func randomDuration(min, max time.Duration) time.Duration {
|
||||
return min + time.Duration(rand.Int63n(int64(max-min)))
|
||||
}
|
||||
func TestDropTests(t *testing.T) {
|
||||
for _, direction := range []quicproxy.Direction{quicproxy.DirectionIncoming, quicproxy.DirectionOutgoing} {
|
||||
t.Run(fmt.Sprintf("in %s direction", direction), func(t *testing.T) {
|
||||
const numMessages = 15
|
||||
const rtt = 10 * time.Millisecond
|
||||
|
||||
var _ = Describe("Drop Tests", func() {
|
||||
var (
|
||||
proxy *quicproxy.QuicProxy
|
||||
ln *quic.Listener
|
||||
)
|
||||
messageInterval := randomDuration(10*time.Millisecond, 100*time.Millisecond)
|
||||
dropDuration := randomDuration(messageInterval*3/2, 2*time.Second)
|
||||
dropDelay := randomDuration(25*time.Millisecond, numMessages*messageInterval/2)
|
||||
t.Logf("sending a message every %s, %d times", messageInterval, numMessages)
|
||||
t.Logf("dropping packets for %s, after a delay of %s", dropDuration, dropDelay)
|
||||
startTime := time.Now()
|
||||
|
||||
startListenerAndProxy := func(dropCallback quicproxy.DropCallback) {
|
||||
var err error
|
||||
ln, err = quic.ListenAddr(
|
||||
"localhost:0",
|
||||
getTLSConfig(),
|
||||
getQuicConfig(nil),
|
||||
)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
serverPort := ln.Addr().(*net.UDPAddr).Port
|
||||
proxy, err = quicproxy.NewQuicProxy("localhost:0", &quicproxy.Opts{
|
||||
RemoteAddr: fmt.Sprintf("localhost:%d", serverPort),
|
||||
DelayPacket: func(dir quicproxy.Direction, _ []byte) time.Duration {
|
||||
return 5 * time.Millisecond // 10ms RTT
|
||||
},
|
||||
DropPacket: dropCallback,
|
||||
},
|
||||
)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
}
|
||||
ln, err := quic.ListenAddr("localhost:0", getTLSConfig(), getQuicConfig(nil))
|
||||
require.NoError(t, err)
|
||||
defer ln.Close()
|
||||
|
||||
AfterEach(func() {
|
||||
Expect(proxy.Close()).To(Succeed())
|
||||
Expect(ln.Close()).To(Succeed())
|
||||
})
|
||||
|
||||
for _, d := range directions {
|
||||
direction := d
|
||||
|
||||
// The purpose of this test is to create a lot of tails, by sending 1 byte messages.
|
||||
// The interval, the length of the drop period, and the time when the drop period starts are randomized.
|
||||
// To cover different scenarios, repeat this test a few times.
|
||||
for rep := 0; rep < 3; rep++ {
|
||||
It(fmt.Sprintf("sends short messages, dropping packets in %s direction", direction), func() {
|
||||
const numMessages = 15
|
||||
|
||||
messageInterval := randomDuration(10*time.Millisecond, 100*time.Millisecond)
|
||||
dropDuration := randomDuration(messageInterval*3/2, 2*time.Second)
|
||||
dropDelay := randomDuration(25*time.Millisecond, numMessages*messageInterval/2) // makes sure we don't interfere with the handshake
|
||||
fmt.Fprintf(GinkgoWriter, "Sending a message every %s, %d times.\n", messageInterval, numMessages)
|
||||
fmt.Fprintf(GinkgoWriter, "Dropping packets for %s, after a delay of %s\n", dropDuration, dropDelay)
|
||||
startTime := time.Now()
|
||||
|
||||
var numDroppedPackets atomic.Int32
|
||||
startListenerAndProxy(func(d quicproxy.Direction, _ []byte) bool {
|
||||
var numDroppedPackets atomic.Int32
|
||||
proxy, err := quicproxy.NewQuicProxy("localhost:0", &quicproxy.Opts{
|
||||
RemoteAddr: fmt.Sprintf("localhost:%d", ln.Addr().(*net.UDPAddr).Port),
|
||||
DelayPacket: func(dir quicproxy.Direction, _ []byte) time.Duration { return rtt / 2 },
|
||||
DropPacket: func(d quicproxy.Direction, b []byte) bool {
|
||||
if !d.Is(direction) {
|
||||
return false
|
||||
}
|
||||
if wire.IsLongHeaderPacket(b[0]) { // don't interfere with the handshake
|
||||
return false
|
||||
}
|
||||
drop := time.Now().After(startTime.Add(dropDelay)) && time.Now().Before(startTime.Add(dropDelay).Add(dropDuration))
|
||||
if drop {
|
||||
numDroppedPackets.Add(1)
|
||||
}
|
||||
return drop
|
||||
})
|
||||
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
defer GinkgoRecover()
|
||||
conn, err := ln.Accept(context.Background())
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
str, err := conn.OpenStream()
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
for i := uint8(1); i <= numMessages; i++ {
|
||||
n, err := str.Write([]byte{i})
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(n).To(Equal(1))
|
||||
time.Sleep(messageInterval)
|
||||
}
|
||||
<-done
|
||||
Expect(conn.CloseWithError(0, "")).To(Succeed())
|
||||
}()
|
||||
|
||||
conn, err := quic.DialAddr(
|
||||
context.Background(),
|
||||
fmt.Sprintf("localhost:%d", proxy.LocalPort()),
|
||||
getTLSClientConfig(),
|
||||
getQuicConfig(nil),
|
||||
)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
defer conn.CloseWithError(0, "")
|
||||
str, err := conn.AcceptStream(context.Background())
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
for i := uint8(1); i <= numMessages; i++ {
|
||||
b := []byte{0}
|
||||
n, err := str.Read(b)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(n).To(Equal(1))
|
||||
Expect(b[0]).To(Equal(i))
|
||||
}
|
||||
close(done)
|
||||
numDropped := numDroppedPackets.Load()
|
||||
fmt.Fprintf(GinkgoWriter, "Dropped %d packets.\n", numDropped)
|
||||
Expect(numDropped).To(BeNumerically(">", 0))
|
||||
},
|
||||
})
|
||||
}
|
||||
require.NoError(t, err)
|
||||
defer proxy.Close()
|
||||
|
||||
conn, err := quic.DialAddr(
|
||||
context.Background(),
|
||||
fmt.Sprintf("localhost:%d", proxy.LocalPort()),
|
||||
getTLSClientConfig(),
|
||||
getQuicConfig(nil),
|
||||
)
|
||||
require.NoError(t, err)
|
||||
defer conn.CloseWithError(0, "")
|
||||
|
||||
serverConn, err := ln.Accept(context.Background())
|
||||
require.NoError(t, err)
|
||||
serverStr, err := serverConn.OpenUniStream()
|
||||
require.NoError(t, err)
|
||||
errChan := make(chan error, 1)
|
||||
go func() {
|
||||
for i := uint8(1); i <= numMessages; i++ {
|
||||
if _, err := serverStr.Write([]byte{i}); err != nil {
|
||||
errChan <- err
|
||||
return
|
||||
}
|
||||
time.Sleep(messageInterval)
|
||||
}
|
||||
}()
|
||||
|
||||
str, err := conn.AcceptUniStream(context.Background())
|
||||
require.NoError(t, err)
|
||||
for i := uint8(1); i <= numMessages; i++ {
|
||||
b := []byte{0}
|
||||
n, err := str.Read(b)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, 1, n)
|
||||
require.Equal(t, i, b[0])
|
||||
}
|
||||
numDropped := numDroppedPackets.Load()
|
||||
t.Logf("dropped %d packets", numDropped)
|
||||
require.NotZero(t, numDropped)
|
||||
})
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
@@ -5,64 +5,71 @@ import (
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/quic-go/quic-go"
|
||||
quicproxy "github.com/quic-go/quic-go/integrationtests/tools/proxy"
|
||||
|
||||
. "github.com/onsi/ginkgo/v2"
|
||||
. "github.com/onsi/gomega"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
var _ = Describe("early data", func() {
|
||||
func TestEarlyData(t *testing.T) {
|
||||
const rtt = 80 * time.Millisecond
|
||||
ln, err := quic.ListenAddrEarly("localhost:0", getTLSConfig(), getQuicConfig(nil))
|
||||
require.NoError(t, err)
|
||||
defer ln.Close()
|
||||
|
||||
It("sends 0.5-RTT data", func() {
|
||||
ln, err := quic.ListenAddrEarly(
|
||||
"localhost:0",
|
||||
getTLSConfig(),
|
||||
getQuicConfig(nil),
|
||||
)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
defer ln.Close()
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
defer GinkgoRecover()
|
||||
defer close(done)
|
||||
conn, err := ln.Accept(context.Background())
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
str, err := conn.OpenUniStream()
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
_, err = str.Write([]byte("early data"))
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(str.Close()).To(Succeed())
|
||||
// make sure the Write finished before the handshake completed
|
||||
Expect(conn.HandshakeComplete()).ToNot(BeClosed())
|
||||
Eventually(conn.Context().Done()).Should(BeClosed())
|
||||
}()
|
||||
serverPort := ln.Addr().(*net.UDPAddr).Port
|
||||
proxy, err := quicproxy.NewQuicProxy("localhost:0", &quicproxy.Opts{
|
||||
RemoteAddr: fmt.Sprintf("localhost:%d", serverPort),
|
||||
DelayPacket: func(quicproxy.Direction, []byte) time.Duration {
|
||||
return rtt / 2
|
||||
},
|
||||
})
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
defer proxy.Close()
|
||||
|
||||
conn, err := quic.DialAddr(
|
||||
context.Background(),
|
||||
fmt.Sprintf("localhost:%d", proxy.LocalPort()),
|
||||
getTLSClientConfig(),
|
||||
getQuicConfig(nil),
|
||||
)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
str, err := conn.AcceptUniStream(context.Background())
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
data, err := io.ReadAll(str)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(data).To(Equal([]byte("early data")))
|
||||
conn.CloseWithError(0, "")
|
||||
Eventually(done).Should(BeClosed())
|
||||
proxy, err := quicproxy.NewQuicProxy("localhost:0", &quicproxy.Opts{
|
||||
RemoteAddr: fmt.Sprintf("localhost:%d", ln.Addr().(*net.UDPAddr).Port),
|
||||
DelayPacket: func(quicproxy.Direction, []byte) time.Duration { return rtt / 2 },
|
||||
})
|
||||
})
|
||||
require.NoError(t, err)
|
||||
defer proxy.Close()
|
||||
|
||||
connChan := make(chan quic.EarlyConnection)
|
||||
errChan := make(chan error)
|
||||
go func() {
|
||||
conn, err := ln.Accept(context.Background())
|
||||
if err != nil {
|
||||
errChan <- err
|
||||
return
|
||||
}
|
||||
connChan <- conn
|
||||
}()
|
||||
|
||||
clientConn, err := quic.DialAddr(
|
||||
context.Background(),
|
||||
fmt.Sprintf("localhost:%d", proxy.LocalPort()),
|
||||
getTLSClientConfig(),
|
||||
getQuicConfig(nil),
|
||||
)
|
||||
require.NoError(t, err)
|
||||
|
||||
var serverConn quic.EarlyConnection
|
||||
select {
|
||||
case serverConn = <-connChan:
|
||||
case err := <-errChan:
|
||||
t.Fatalf("error accepting connection: %s", err)
|
||||
}
|
||||
str, err := serverConn.OpenUniStream()
|
||||
require.NoError(t, err)
|
||||
_, err = str.Write([]byte("early data"))
|
||||
require.NoError(t, err)
|
||||
require.NoError(t, str.Close())
|
||||
// the write should have completed before the handshake
|
||||
select {
|
||||
case <-serverConn.HandshakeComplete():
|
||||
t.Fatal("handshake shouldn't be completed yet")
|
||||
default:
|
||||
}
|
||||
|
||||
clientStr, err := clientConn.AcceptUniStream(context.Background())
|
||||
require.NoError(t, err)
|
||||
data, err := io.ReadAll(clientStr)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, []byte("early data"), data)
|
||||
|
||||
clientConn.CloseWithError(0, "")
|
||||
<-serverConn.Context().Done()
|
||||
}
|
||||
|
||||
264
integrationtests/self/handshake_context_test.go
Normal file
264
integrationtests/self/handshake_context_test.go
Normal file
@@ -0,0 +1,264 @@
|
||||
package self_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/quic-go/quic-go"
|
||||
"github.com/quic-go/quic-go/logging"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestHandshakeContextTimeout(t *testing.T) {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), scaleDuration(20*time.Millisecond))
|
||||
defer cancel()
|
||||
|
||||
errChan := make(chan error, 1)
|
||||
go func() {
|
||||
_, err := quic.DialAddr(
|
||||
ctx,
|
||||
"localhost:1234", // nobody is listening on this port, but we're going to cancel this dial anyway
|
||||
getTLSClientConfig(),
|
||||
getQuicConfig(nil),
|
||||
)
|
||||
errChan <- err
|
||||
}()
|
||||
|
||||
require.ErrorIs(t, <-errChan, context.DeadlineExceeded)
|
||||
}
|
||||
|
||||
func TestHandshakeCancellationError(t *testing.T) {
|
||||
ctx, cancel := context.WithCancelCause(context.Background())
|
||||
errChan := make(chan error, 1)
|
||||
go func() {
|
||||
_, err := quic.DialAddr(
|
||||
ctx,
|
||||
"localhost:1234", // nobody is listening on this port, but we're going to cancel this dial anyway
|
||||
getTLSClientConfig(),
|
||||
getQuicConfig(nil),
|
||||
)
|
||||
errChan <- err
|
||||
}()
|
||||
|
||||
cancel(errors.New("application cancelled"))
|
||||
require.EqualError(t, <-errChan, "application cancelled")
|
||||
}
|
||||
|
||||
func TestConnContextOnServerSide(t *testing.T) {
|
||||
tlsGetConfigForClientContextChan := make(chan context.Context, 1)
|
||||
tlsGetCertificateContextChan := make(chan context.Context, 1)
|
||||
tracerContextChan := make(chan context.Context, 1)
|
||||
connContextChan := make(chan context.Context, 1)
|
||||
streamContextChan := make(chan context.Context, 1)
|
||||
|
||||
conn, err := net.ListenUDP("udp", &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 0})
|
||||
require.NoError(t, err)
|
||||
defer conn.Close()
|
||||
|
||||
tr := &quic.Transport{
|
||||
Conn: conn,
|
||||
ConnContext: func(ctx context.Context) context.Context {
|
||||
return context.WithValue(ctx, "foo", "bar") //nolint:staticcheck
|
||||
},
|
||||
}
|
||||
defer tr.Close()
|
||||
|
||||
server, err := tr.Listen(
|
||||
&tls.Config{
|
||||
GetConfigForClient: func(info *tls.ClientHelloInfo) (*tls.Config, error) {
|
||||
tlsGetConfigForClientContextChan <- info.Context()
|
||||
tlsConf := getTLSConfig()
|
||||
tlsConf.GetCertificate = func(info *tls.ClientHelloInfo) (*tls.Certificate, error) {
|
||||
tlsGetCertificateContextChan <- info.Context()
|
||||
return &tlsConf.Certificates[0], nil
|
||||
}
|
||||
return tlsConf, nil
|
||||
},
|
||||
},
|
||||
getQuicConfig(&quic.Config{
|
||||
Tracer: func(ctx context.Context, _ logging.Perspective, _ quic.ConnectionID) *logging.ConnectionTracer {
|
||||
tracerContextChan <- ctx
|
||||
return nil
|
||||
},
|
||||
}),
|
||||
)
|
||||
require.NoError(t, err)
|
||||
defer server.Close()
|
||||
|
||||
c, err := quic.DialAddr(
|
||||
context.Background(),
|
||||
fmt.Sprintf("localhost:%d", server.Addr().(*net.UDPAddr).Port),
|
||||
getTLSClientConfig(),
|
||||
getQuicConfig(nil),
|
||||
)
|
||||
require.NoError(t, err)
|
||||
|
||||
serverConn, err := server.Accept(context.Background())
|
||||
require.NoError(t, err)
|
||||
connContextChan <- serverConn.Context()
|
||||
str, err := serverConn.OpenUniStream()
|
||||
require.NoError(t, err)
|
||||
streamContextChan <- str.Context()
|
||||
str.Write([]byte{1, 2, 3})
|
||||
|
||||
_, err = c.AcceptUniStream(context.Background())
|
||||
require.NoError(t, err)
|
||||
c.CloseWithError(1337, "bye")
|
||||
|
||||
checkContext := func(c <-chan context.Context, checkCancellationCause bool) {
|
||||
t.Helper()
|
||||
var ctx context.Context
|
||||
select {
|
||||
case ctx = <-c:
|
||||
case <-time.After(time.Second):
|
||||
t.Fatal("timeout waiting for context")
|
||||
}
|
||||
|
||||
val := ctx.Value("foo")
|
||||
require.NotNil(t, val)
|
||||
v := val.(string)
|
||||
require.Equal(t, "bar", v)
|
||||
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
case <-time.After(time.Second):
|
||||
t.Fatal("timeout waiting for context to be done")
|
||||
}
|
||||
|
||||
if !checkCancellationCause {
|
||||
return
|
||||
}
|
||||
ctxErr := context.Cause(ctx)
|
||||
var appErr *quic.ApplicationError
|
||||
require.ErrorAs(t, ctxErr, &appErr)
|
||||
require.Equal(t, quic.ApplicationErrorCode(1337), appErr.ErrorCode)
|
||||
}
|
||||
|
||||
checkContext(connContextChan, true)
|
||||
checkContext(tracerContextChan, true)
|
||||
checkContext(streamContextChan, true)
|
||||
// crypto/tls cancels the context when the TLS handshake completes.
|
||||
checkContext(tlsGetConfigForClientContextChan, false)
|
||||
checkContext(tlsGetCertificateContextChan, false)
|
||||
}
|
||||
|
||||
// Users are not supposed to return a fresh context from ConnContext, but we should handle it gracefully.
|
||||
func TestConnContextFreshContext(t *testing.T) {
|
||||
conn, err := net.ListenUDP("udp", &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 0})
|
||||
require.NoError(t, err)
|
||||
defer conn.Close()
|
||||
|
||||
tr := &quic.Transport{
|
||||
Conn: conn,
|
||||
ConnContext: func(ctx context.Context) context.Context { return context.Background() },
|
||||
}
|
||||
defer tr.Close()
|
||||
server, err := tr.Listen(getTLSConfig(), getQuicConfig(nil))
|
||||
require.NoError(t, err)
|
||||
defer server.Close()
|
||||
|
||||
errChan := make(chan error, 1)
|
||||
go func() {
|
||||
conn, err := server.Accept(context.Background())
|
||||
if err != nil {
|
||||
errChan <- err
|
||||
return
|
||||
}
|
||||
conn.CloseWithError(1337, "bye")
|
||||
}()
|
||||
|
||||
c, err := quic.DialAddr(
|
||||
context.Background(),
|
||||
fmt.Sprintf("localhost:%d", server.Addr().(*net.UDPAddr).Port),
|
||||
getTLSClientConfig(),
|
||||
getQuicConfig(nil),
|
||||
)
|
||||
require.NoError(t, err)
|
||||
|
||||
select {
|
||||
case <-c.Context().Done():
|
||||
case err := <-errChan:
|
||||
t.Fatalf("accept failed: %v", err)
|
||||
case <-time.After(time.Second):
|
||||
t.Fatal("timeout")
|
||||
}
|
||||
}
|
||||
|
||||
func TestContextOnClientSide(t *testing.T) {
|
||||
tlsServerConf := getTLSConfig()
|
||||
tlsServerConf.ClientAuth = tls.RequestClientCert
|
||||
server, err := quic.ListenAddr("localhost:0", tlsServerConf, getQuicConfig(nil))
|
||||
require.NoError(t, err)
|
||||
defer server.Close()
|
||||
|
||||
tlsContextChan := make(chan context.Context, 1)
|
||||
tracerContextChan := make(chan context.Context, 1)
|
||||
tlsConf := getTLSClientConfig()
|
||||
tlsConf.GetClientCertificate = func(info *tls.CertificateRequestInfo) (*tls.Certificate, error) {
|
||||
tlsContextChan <- info.Context()
|
||||
return &tlsServerConf.Certificates[0], nil
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithCancel(context.WithValue(context.Background(), "foo", "bar")) //nolint:staticcheck
|
||||
conn, err := quic.DialAddr(
|
||||
ctx,
|
||||
fmt.Sprintf("localhost:%d", server.Addr().(*net.UDPAddr).Port),
|
||||
tlsConf,
|
||||
getQuicConfig(&quic.Config{
|
||||
Tracer: func(ctx context.Context, _ logging.Perspective, _ quic.ConnectionID) *logging.ConnectionTracer {
|
||||
tracerContextChan <- ctx
|
||||
return nil
|
||||
},
|
||||
}),
|
||||
)
|
||||
require.NoError(t, err)
|
||||
cancel()
|
||||
|
||||
// Make sure the connection context is not cancelled (even though derived from the ctx passed to Dial)
|
||||
select {
|
||||
case <-conn.Context().Done():
|
||||
t.Fatal("context should not be cancelled")
|
||||
default:
|
||||
}
|
||||
|
||||
checkContext := func(ctx context.Context, checkCancellationCause bool) {
|
||||
t.Helper()
|
||||
val := ctx.Value("foo")
|
||||
require.NotNil(t, val)
|
||||
require.Equal(t, "bar", val.(string))
|
||||
if !checkCancellationCause {
|
||||
return
|
||||
}
|
||||
ctxErr := context.Cause(ctx)
|
||||
var appErr *quic.ApplicationError
|
||||
require.ErrorAs(t, ctxErr, &appErr)
|
||||
require.EqualValues(t, 1337, appErr.ErrorCode)
|
||||
}
|
||||
|
||||
checkContextFromChan := func(c <-chan context.Context, checkCancellationCause bool) {
|
||||
t.Helper()
|
||||
var ctx context.Context
|
||||
select {
|
||||
case ctx = <-c:
|
||||
case <-time.After(time.Second):
|
||||
t.Fatal("timeout waiting for context")
|
||||
}
|
||||
checkContext(ctx, checkCancellationCause)
|
||||
}
|
||||
|
||||
str, err := conn.OpenUniStream()
|
||||
require.NoError(t, err)
|
||||
conn.CloseWithError(1337, "bye")
|
||||
|
||||
checkContext(conn.Context(), true)
|
||||
checkContext(str.Context(), true)
|
||||
// crypto/tls cancels the context when the TLS handshake completes
|
||||
checkContextFromChan(tlsContextChan, false)
|
||||
checkContextFromChan(tracerContextChan, false)
|
||||
}
|
||||
@@ -1,6 +1,7 @@
|
||||
package self_test
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"fmt"
|
||||
@@ -9,6 +10,7 @@ import (
|
||||
"net"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/quic-go/quic-go"
|
||||
@@ -16,282 +18,266 @@ import (
|
||||
"github.com/quic-go/quic-go/internal/wire"
|
||||
"github.com/quic-go/quic-go/quicvarint"
|
||||
|
||||
. "github.com/onsi/ginkgo/v2"
|
||||
. "github.com/onsi/gomega"
|
||||
"github.com/onsi/gomega/gbytes"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
var directions = []quicproxy.Direction{quicproxy.DirectionIncoming, quicproxy.DirectionOutgoing, quicproxy.DirectionBoth}
|
||||
func startDropTestListenerAndProxy(t *testing.T, rtt, timeout time.Duration, dropCallback quicproxy.DropCallback, doRetry bool, longCertChain bool) (_ *quic.Listener, proxyPort int) {
|
||||
t.Helper()
|
||||
conf := getQuicConfig(&quic.Config{
|
||||
MaxIdleTimeout: timeout,
|
||||
HandshakeIdleTimeout: timeout,
|
||||
DisablePathMTUDiscovery: true,
|
||||
})
|
||||
var tlsConf *tls.Config
|
||||
if longCertChain {
|
||||
tlsConf = getTLSConfigWithLongCertChain()
|
||||
} else {
|
||||
tlsConf = getTLSConfig()
|
||||
}
|
||||
conn, err := net.ListenUDP("udp", &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 0})
|
||||
require.NoError(t, err)
|
||||
t.Cleanup(func() { conn.Close() })
|
||||
tr := &quic.Transport{
|
||||
Conn: conn,
|
||||
VerifySourceAddress: func(net.Addr) bool { return doRetry },
|
||||
}
|
||||
t.Cleanup(func() { tr.Close() })
|
||||
ln, err := tr.Listen(tlsConf, conf)
|
||||
require.NoError(t, err)
|
||||
t.Cleanup(func() { ln.Close() })
|
||||
|
||||
type applicationProtocol struct {
|
||||
name string
|
||||
run func(ln *quic.Listener, port int)
|
||||
proxy, err := quicproxy.NewQuicProxy("localhost:0", &quicproxy.Opts{
|
||||
RemoteAddr: fmt.Sprintf("localhost:%d", ln.Addr().(*net.UDPAddr).Port),
|
||||
DropPacket: dropCallback,
|
||||
DelayPacket: func(dir quicproxy.Direction, packet []byte) time.Duration { return rtt / 2 },
|
||||
})
|
||||
require.NoError(t, err)
|
||||
t.Cleanup(func() { proxy.Close() })
|
||||
return ln, proxy.LocalPort()
|
||||
}
|
||||
|
||||
var _ = Describe("Handshake drop tests", func() {
|
||||
func dropTestProtocolClientSpeaksFirst(t *testing.T, ln *quic.Listener, port int, timeout time.Duration, data []byte) {
|
||||
conn, err := quic.DialAddr(
|
||||
context.Background(),
|
||||
fmt.Sprintf("localhost:%d", port),
|
||||
getTLSClientConfig(),
|
||||
getQuicConfig(&quic.Config{
|
||||
MaxIdleTimeout: timeout,
|
||||
HandshakeIdleTimeout: timeout,
|
||||
DisablePathMTUDiscovery: true,
|
||||
}),
|
||||
)
|
||||
require.NoError(t, err)
|
||||
defer conn.CloseWithError(0, "")
|
||||
|
||||
str, err := conn.OpenUniStream()
|
||||
require.NoError(t, err)
|
||||
errChan := make(chan error, 1)
|
||||
go func() {
|
||||
defer str.Close()
|
||||
_, err := str.Write(data)
|
||||
errChan <- err
|
||||
}()
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), timeout)
|
||||
defer cancel()
|
||||
serverConn, err := ln.Accept(ctx)
|
||||
require.NoError(t, err)
|
||||
serverStr, err := serverConn.AcceptUniStream(ctx)
|
||||
require.NoError(t, err)
|
||||
b, err := io.ReadAll(&readerWithTimeout{Reader: serverStr, Timeout: timeout})
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, b, data)
|
||||
serverConn.CloseWithError(0, "")
|
||||
}
|
||||
|
||||
func dropTestProtocolServerSpeaksFirst(t *testing.T, ln *quic.Listener, port int, timeout time.Duration, data []byte) {
|
||||
conn, err := quic.DialAddr(
|
||||
context.Background(),
|
||||
fmt.Sprintf("localhost:%d", port),
|
||||
getTLSClientConfig(),
|
||||
getQuicConfig(&quic.Config{
|
||||
MaxIdleTimeout: timeout,
|
||||
HandshakeIdleTimeout: timeout,
|
||||
DisablePathMTUDiscovery: true,
|
||||
}),
|
||||
)
|
||||
require.NoError(t, err)
|
||||
|
||||
errChan := make(chan error, 1)
|
||||
go func() {
|
||||
defer close(errChan)
|
||||
defer conn.CloseWithError(0, "")
|
||||
ctx, cancel := context.WithTimeout(context.Background(), timeout)
|
||||
defer cancel()
|
||||
str, err := conn.AcceptUniStream(ctx)
|
||||
if err != nil {
|
||||
errChan <- err
|
||||
return
|
||||
}
|
||||
b, err := io.ReadAll(&readerWithTimeout{Reader: str, Timeout: timeout})
|
||||
if err != nil {
|
||||
errChan <- err
|
||||
return
|
||||
}
|
||||
if !bytes.Equal(b, data) {
|
||||
errChan <- fmt.Errorf("data mismatch: %x != %x", b, data)
|
||||
return
|
||||
}
|
||||
}()
|
||||
|
||||
serverConn, err := ln.Accept(context.Background())
|
||||
require.NoError(t, err)
|
||||
serverStr, err := serverConn.OpenUniStream()
|
||||
require.NoError(t, err)
|
||||
_, err = serverStr.Write(data)
|
||||
require.NoError(t, err)
|
||||
require.NoError(t, serverStr.Close())
|
||||
|
||||
select {
|
||||
case err := <-errChan:
|
||||
require.NoError(t, err)
|
||||
case <-time.After(timeout):
|
||||
t.Fatal("server connection not closed")
|
||||
}
|
||||
|
||||
select {
|
||||
case <-conn.Context().Done():
|
||||
case <-time.After(timeout):
|
||||
t.Fatal("server connection not closed")
|
||||
}
|
||||
}
|
||||
|
||||
func dropTestProtocolNobodySpeaks(t *testing.T, ln *quic.Listener, port int, timeout time.Duration) {
|
||||
conn, err := quic.DialAddr(
|
||||
context.Background(),
|
||||
fmt.Sprintf("localhost:%d", port),
|
||||
getTLSClientConfig(),
|
||||
getQuicConfig(&quic.Config{
|
||||
MaxIdleTimeout: timeout,
|
||||
HandshakeIdleTimeout: timeout,
|
||||
DisablePathMTUDiscovery: true,
|
||||
}),
|
||||
)
|
||||
require.NoError(t, err)
|
||||
|
||||
serverConn, err := ln.Accept(context.Background())
|
||||
require.NoError(t, err)
|
||||
|
||||
conn.CloseWithError(0, "")
|
||||
serverConn.CloseWithError(0, "")
|
||||
}
|
||||
|
||||
func dropCallbackDropNthPacket(direction quicproxy.Direction, n int) quicproxy.DropCallback {
|
||||
var incoming, outgoing atomic.Int32
|
||||
return func(d quicproxy.Direction, packet []byte) bool {
|
||||
var p int32
|
||||
switch d {
|
||||
case quicproxy.DirectionIncoming:
|
||||
p = incoming.Add(1)
|
||||
case quicproxy.DirectionOutgoing:
|
||||
p = outgoing.Add(1)
|
||||
}
|
||||
return p == int32(n) && d.Is(direction)
|
||||
}
|
||||
}
|
||||
|
||||
func dropCallbackDropOneThird(direction quicproxy.Direction) quicproxy.DropCallback {
|
||||
const maxSequentiallyDropped = 10
|
||||
var mx sync.Mutex
|
||||
var incoming, outgoing int
|
||||
return func(d quicproxy.Direction, _ []byte) bool {
|
||||
drop := mrand.Int63n(int64(3)) == 0
|
||||
|
||||
mx.Lock()
|
||||
defer mx.Unlock()
|
||||
// never drop more than 10 consecutive packets
|
||||
if d.Is(quicproxy.DirectionIncoming) {
|
||||
if drop {
|
||||
incoming++
|
||||
if incoming > maxSequentiallyDropped {
|
||||
drop = false
|
||||
}
|
||||
}
|
||||
if !drop {
|
||||
incoming = 0
|
||||
}
|
||||
}
|
||||
if d.Is(quicproxy.DirectionOutgoing) {
|
||||
if drop {
|
||||
outgoing++
|
||||
if outgoing > maxSequentiallyDropped {
|
||||
drop = false
|
||||
}
|
||||
}
|
||||
if !drop {
|
||||
outgoing = 0
|
||||
}
|
||||
}
|
||||
return drop
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandshakeWithPacketLoss(t *testing.T) {
|
||||
data := GeneratePRData(5000)
|
||||
const timeout = 2 * time.Minute
|
||||
const rtt = 20 * time.Millisecond
|
||||
|
||||
startListenerAndProxy := func(dropCallback quicproxy.DropCallback, doRetry bool, longCertChain bool) (ln *quic.Listener, proxyPort int, closeFn func()) {
|
||||
conf := getQuicConfig(&quic.Config{
|
||||
MaxIdleTimeout: timeout,
|
||||
HandshakeIdleTimeout: timeout,
|
||||
})
|
||||
var tlsConf *tls.Config
|
||||
if longCertChain {
|
||||
tlsConf = getTLSConfigWithLongCertChain()
|
||||
} else {
|
||||
tlsConf = getTLSConfig()
|
||||
}
|
||||
laddr, err := net.ResolveUDPAddr("udp", "localhost:0")
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
conn, err := net.ListenUDP("udp", laddr)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
tr := &quic.Transport{Conn: conn}
|
||||
if doRetry {
|
||||
tr.VerifySourceAddress = func(net.Addr) bool { return true }
|
||||
}
|
||||
ln, err = tr.Listen(tlsConf, conf)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
serverPort := ln.Addr().(*net.UDPAddr).Port
|
||||
proxy, err := quicproxy.NewQuicProxy("localhost:0", &quicproxy.Opts{
|
||||
RemoteAddr: fmt.Sprintf("localhost:%d", serverPort),
|
||||
DropPacket: dropCallback,
|
||||
DelayPacket: func(dir quicproxy.Direction, packet []byte) time.Duration {
|
||||
return 10 * time.Millisecond
|
||||
},
|
||||
})
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
||||
return ln, proxy.LocalPort(), func() {
|
||||
ln.Close()
|
||||
tr.Close()
|
||||
conn.Close()
|
||||
proxy.Close()
|
||||
}
|
||||
type dropPattern struct {
|
||||
name string
|
||||
fn quicproxy.DropCallback
|
||||
}
|
||||
|
||||
clientSpeaksFirst := &applicationProtocol{
|
||||
name: "client speaks first",
|
||||
run: func(ln *quic.Listener, port int) {
|
||||
serverConnChan := make(chan quic.Connection)
|
||||
go func() {
|
||||
defer GinkgoRecover()
|
||||
conn, err := ln.Accept(context.Background())
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
defer conn.CloseWithError(0, "")
|
||||
str, err := conn.AcceptStream(context.Background())
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
b, err := io.ReadAll(gbytes.TimeoutReader(str, timeout))
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(b).To(Equal(data))
|
||||
serverConnChan <- conn
|
||||
}()
|
||||
conn, err := quic.DialAddr(
|
||||
context.Background(),
|
||||
fmt.Sprintf("localhost:%d", port),
|
||||
getTLSClientConfig(),
|
||||
getQuicConfig(&quic.Config{
|
||||
MaxIdleTimeout: timeout,
|
||||
HandshakeIdleTimeout: timeout,
|
||||
}),
|
||||
)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
str, err := conn.OpenStream()
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
_, err = str.Write(data)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(str.Close()).To(Succeed())
|
||||
|
||||
var serverConn quic.Connection
|
||||
Eventually(serverConnChan, timeout).Should(Receive(&serverConn))
|
||||
conn.CloseWithError(0, "")
|
||||
serverConn.CloseWithError(0, "")
|
||||
},
|
||||
type serverConfig struct {
|
||||
longCertChain bool
|
||||
doRetry bool
|
||||
}
|
||||
|
||||
serverSpeaksFirst := &applicationProtocol{
|
||||
name: "server speaks first",
|
||||
run: func(ln *quic.Listener, port int) {
|
||||
serverConnChan := make(chan quic.Connection)
|
||||
go func() {
|
||||
defer GinkgoRecover()
|
||||
conn, err := ln.Accept(context.Background())
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
str, err := conn.OpenStream()
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
_, err = str.Write(data)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(str.Close()).To(Succeed())
|
||||
serverConnChan <- conn
|
||||
}()
|
||||
conn, err := quic.DialAddr(
|
||||
context.Background(),
|
||||
fmt.Sprintf("localhost:%d", port),
|
||||
getTLSClientConfig(),
|
||||
getQuicConfig(&quic.Config{
|
||||
MaxIdleTimeout: timeout,
|
||||
HandshakeIdleTimeout: timeout,
|
||||
}),
|
||||
)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
str, err := conn.AcceptStream(context.Background())
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
b, err := io.ReadAll(gbytes.TimeoutReader(str, timeout))
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(b).To(Equal(data))
|
||||
for _, direction := range []quicproxy.Direction{quicproxy.DirectionIncoming, quicproxy.DirectionOutgoing, quicproxy.DirectionBoth} {
|
||||
for _, dropPattern := range []dropPattern{
|
||||
{name: "drop 1st packet", fn: dropCallbackDropNthPacket(direction, 1)},
|
||||
{name: "drop 2nd packet", fn: dropCallbackDropNthPacket(direction, 2)},
|
||||
{name: "drop 1/3 of packets", fn: dropCallbackDropOneThird(direction)},
|
||||
} {
|
||||
t.Run(fmt.Sprintf("%s in %s direction", dropPattern.name, direction), func(t *testing.T) {
|
||||
for _, conf := range []serverConfig{
|
||||
{longCertChain: false, doRetry: true},
|
||||
{longCertChain: false, doRetry: false},
|
||||
{longCertChain: true, doRetry: false},
|
||||
} {
|
||||
t.Run(fmt.Sprintf("retry: %t", conf.doRetry), func(t *testing.T) {
|
||||
t.Run("client speaks first", func(t *testing.T) {
|
||||
ln, proxyPort := startDropTestListenerAndProxy(t, rtt, timeout, dropPattern.fn, conf.doRetry, conf.longCertChain)
|
||||
dropTestProtocolClientSpeaksFirst(t, ln, proxyPort, timeout, data)
|
||||
})
|
||||
|
||||
var serverConn quic.Connection
|
||||
Eventually(serverConnChan, timeout).Should(Receive(&serverConn))
|
||||
conn.CloseWithError(0, "")
|
||||
serverConn.CloseWithError(0, "")
|
||||
},
|
||||
}
|
||||
t.Run("server speaks first", func(t *testing.T) {
|
||||
ln, proxyPort := startDropTestListenerAndProxy(t, rtt, timeout, dropPattern.fn, conf.doRetry, conf.longCertChain)
|
||||
dropTestProtocolServerSpeaksFirst(t, ln, proxyPort, timeout, data)
|
||||
})
|
||||
|
||||
nobodySpeaks := &applicationProtocol{
|
||||
name: "nobody speaks",
|
||||
run: func(ln *quic.Listener, port int) {
|
||||
serverConnChan := make(chan quic.Connection)
|
||||
go func() {
|
||||
defer GinkgoRecover()
|
||||
conn, err := ln.Accept(context.Background())
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
serverConnChan <- conn
|
||||
}()
|
||||
conn, err := quic.DialAddr(
|
||||
context.Background(),
|
||||
fmt.Sprintf("localhost:%d", port),
|
||||
getTLSClientConfig(),
|
||||
getQuicConfig(&quic.Config{
|
||||
MaxIdleTimeout: timeout,
|
||||
HandshakeIdleTimeout: timeout,
|
||||
}),
|
||||
)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
var serverConn quic.Connection
|
||||
Eventually(serverConnChan, timeout).Should(Receive(&serverConn))
|
||||
// both server and client accepted a connection. Close now.
|
||||
conn.CloseWithError(0, "")
|
||||
serverConn.CloseWithError(0, "")
|
||||
},
|
||||
}
|
||||
|
||||
for _, d := range directions {
|
||||
direction := d
|
||||
|
||||
for _, dr := range []bool{true, false} {
|
||||
doRetry := dr
|
||||
desc := "when using Retry"
|
||||
if !dr {
|
||||
desc = "when not using Retry"
|
||||
}
|
||||
|
||||
Context(desc, func() {
|
||||
for _, lcc := range []bool{false, true} {
|
||||
longCertChain := lcc
|
||||
|
||||
Context(fmt.Sprintf("using a long certificate chain: %t", longCertChain), func() {
|
||||
for _, a := range []*applicationProtocol{clientSpeaksFirst, serverSpeaksFirst, nobodySpeaks} {
|
||||
app := a
|
||||
|
||||
Context(app.name, func() {
|
||||
It(fmt.Sprintf("establishes a connection when the first packet is lost in %s direction", direction), func() {
|
||||
var incoming, outgoing atomic.Int32
|
||||
ln, proxyPort, closeFn := startListenerAndProxy(func(d quicproxy.Direction, _ []byte) bool {
|
||||
var p int32
|
||||
switch d {
|
||||
case quicproxy.DirectionIncoming:
|
||||
p = incoming.Add(1)
|
||||
case quicproxy.DirectionOutgoing:
|
||||
p = outgoing.Add(1)
|
||||
}
|
||||
return p == 1 && d.Is(direction)
|
||||
}, doRetry, longCertChain)
|
||||
defer closeFn()
|
||||
app.run(ln, proxyPort)
|
||||
})
|
||||
|
||||
It(fmt.Sprintf("establishes a connection when the second packet is lost in %s direction", direction), func() {
|
||||
var incoming, outgoing atomic.Int32
|
||||
ln, proxyPort, closeFn := startListenerAndProxy(func(d quicproxy.Direction, _ []byte) bool {
|
||||
var p int32
|
||||
switch d {
|
||||
case quicproxy.DirectionIncoming:
|
||||
p = incoming.Add(1)
|
||||
case quicproxy.DirectionOutgoing:
|
||||
p = outgoing.Add(1)
|
||||
}
|
||||
return p == 2 && d.Is(direction)
|
||||
}, doRetry, longCertChain)
|
||||
defer closeFn()
|
||||
app.run(ln, proxyPort)
|
||||
})
|
||||
|
||||
It(fmt.Sprintf("establishes a connection when 1/3 of the packets are lost in %s direction", direction), func() {
|
||||
const maxSequentiallyDropped = 10
|
||||
var mx sync.Mutex
|
||||
var incoming, outgoing int
|
||||
|
||||
ln, proxyPort, closeFn := startListenerAndProxy(func(d quicproxy.Direction, _ []byte) bool {
|
||||
drop := mrand.Int63n(int64(3)) == 0
|
||||
|
||||
mx.Lock()
|
||||
defer mx.Unlock()
|
||||
// never drop more than 10 consecutive packets
|
||||
if d.Is(quicproxy.DirectionIncoming) {
|
||||
if drop {
|
||||
incoming++
|
||||
if incoming > maxSequentiallyDropped {
|
||||
drop = false
|
||||
}
|
||||
}
|
||||
if !drop {
|
||||
incoming = 0
|
||||
}
|
||||
}
|
||||
if d.Is(quicproxy.DirectionOutgoing) {
|
||||
if drop {
|
||||
outgoing++
|
||||
if outgoing > maxSequentiallyDropped {
|
||||
drop = false
|
||||
}
|
||||
}
|
||||
if !drop {
|
||||
outgoing = 0
|
||||
}
|
||||
}
|
||||
return drop
|
||||
}, doRetry, longCertChain)
|
||||
defer closeFn()
|
||||
app.run(ln, proxyPort)
|
||||
})
|
||||
})
|
||||
}
|
||||
t.Run("nobody speaks", func(t *testing.T) {
|
||||
ln, proxyPort := startDropTestListenerAndProxy(t, rtt, timeout, dropPattern.fn, conf.doRetry, conf.longCertChain)
|
||||
dropTestProtocolNobodySpeaks(t, ln, proxyPort, timeout)
|
||||
})
|
||||
})
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
It("establishes a connection when the ClientHello is larger than 1 MTU (e.g. post-quantum)", func() {
|
||||
origAdditionalTransportParametersClient := wire.AdditionalTransportParametersClient
|
||||
defer func() {
|
||||
wire.AdditionalTransportParametersClient = origAdditionalTransportParametersClient
|
||||
}()
|
||||
b := make([]byte, 2500) // the ClientHello will now span across 3 packets
|
||||
mrand.New(mrand.NewSource(GinkgoRandomSeed())).Read(b)
|
||||
wire.AdditionalTransportParametersClient = map[uint64][]byte{
|
||||
// Avoid random collisions with the greased transport parameters.
|
||||
uint64(27+31*(1000+mrand.Int63()/31)) % quicvarint.Max: b,
|
||||
}
|
||||
|
||||
ln, proxyPort, closeFn := startListenerAndProxy(func(d quicproxy.Direction, _ []byte) bool {
|
||||
if d == quicproxy.DirectionOutgoing {
|
||||
return false
|
||||
}
|
||||
return mrand.Intn(3) == 0
|
||||
}, false, false)
|
||||
defer closeFn()
|
||||
clientSpeaksFirst.run(ln, proxyPort)
|
||||
})
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestPostQuantumClientHello(t *testing.T) {
|
||||
origAdditionalTransportParametersClient := wire.AdditionalTransportParametersClient
|
||||
t.Cleanup(func() { wire.AdditionalTransportParametersClient = origAdditionalTransportParametersClient })
|
||||
|
||||
b := make([]byte, 2500) // the ClientHello will now span across 3 packets
|
||||
mrand.New(mrand.NewSource(time.Now().UnixNano())).Read(b)
|
||||
wire.AdditionalTransportParametersClient = map[uint64][]byte{
|
||||
// Avoid random collisions with the greased transport parameters.
|
||||
uint64(27+31*(1000+mrand.Int63()/31)) % quicvarint.Max: b,
|
||||
}
|
||||
|
||||
ln, proxyPort := startDropTestListenerAndProxy(t, 10*time.Millisecond, 20*time.Second, dropCallbackDropOneThird(quicproxy.DirectionIncoming), false, false)
|
||||
dropTestProtocolClientSpeaksFirst(t, ln, proxyPort, time.Minute, GeneratePRData(5000))
|
||||
}
|
||||
|
||||
@@ -6,192 +6,175 @@ import (
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/quic-go/quic-go"
|
||||
quicproxy "github.com/quic-go/quic-go/integrationtests/tools/proxy"
|
||||
|
||||
. "github.com/onsi/ginkgo/v2"
|
||||
. "github.com/onsi/gomega"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
var _ = Describe("Handshake RTT tests", func() {
|
||||
var (
|
||||
proxy *quicproxy.QuicProxy
|
||||
serverConfig *quic.Config
|
||||
serverTLSConfig *tls.Config
|
||||
func handshakeWithRTT(t *testing.T, serverAddr net.Addr, tlsConf *tls.Config, quicConf *quic.Config, rtt time.Duration) quic.Connection {
|
||||
t.Helper()
|
||||
|
||||
proxy, err := quicproxy.NewQuicProxy("localhost:0", &quicproxy.Opts{
|
||||
RemoteAddr: serverAddr.String(),
|
||||
DelayPacket: func(_ quicproxy.Direction, _ []byte) time.Duration { return rtt / 2 },
|
||||
})
|
||||
require.NoError(t, err)
|
||||
t.Cleanup(func() { proxy.Close() })
|
||||
|
||||
conn, err := quic.DialAddr(
|
||||
context.Background(),
|
||||
fmt.Sprintf("localhost:%d", proxy.LocalAddr().(*net.UDPAddr).Port),
|
||||
tlsConf,
|
||||
quicConf,
|
||||
)
|
||||
require.NoError(t, err)
|
||||
t.Cleanup(func() { conn.CloseWithError(0, "") })
|
||||
return conn
|
||||
}
|
||||
|
||||
func TestHandshakeRTTWithoutRetry(t *testing.T) {
|
||||
ln, err := quic.ListenAddr("localhost:0", getTLSConfig(), getQuicConfig(nil))
|
||||
require.NoError(t, err)
|
||||
defer ln.Close()
|
||||
|
||||
clientConfig := getQuicConfig(&quic.Config{
|
||||
GetConfigForClient: func(info *quic.ClientHelloInfo) (*quic.Config, error) {
|
||||
require.False(t, info.AddrVerified)
|
||||
return nil, nil
|
||||
},
|
||||
})
|
||||
|
||||
const rtt = 400 * time.Millisecond
|
||||
start := time.Now()
|
||||
handshakeWithRTT(t, ln.Addr(), getTLSClientConfig(), clientConfig, rtt)
|
||||
rtts := time.Since(start).Seconds() / rtt.Seconds()
|
||||
require.GreaterOrEqual(t, rtts, float64(1))
|
||||
require.Less(t, rtts, float64(2))
|
||||
}
|
||||
|
||||
BeforeEach(func() {
|
||||
serverConfig = getQuicConfig(nil)
|
||||
serverTLSConfig = getTLSConfig()
|
||||
func TestHandshakeRTTWithRetry(t *testing.T) {
|
||||
udpConn, err := net.ListenUDP("udp", &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 0})
|
||||
require.NoError(t, err)
|
||||
defer udpConn.Close()
|
||||
|
||||
tr := &quic.Transport{
|
||||
Conn: udpConn,
|
||||
VerifySourceAddress: func(net.Addr) bool { return true },
|
||||
}
|
||||
addTracer(tr)
|
||||
defer tr.Close()
|
||||
ln, err := tr.Listen(getTLSConfig(), getQuicConfig(nil))
|
||||
require.NoError(t, err)
|
||||
defer ln.Close()
|
||||
|
||||
clientConfig := getQuicConfig(&quic.Config{
|
||||
GetConfigForClient: func(info *quic.ClientHelloInfo) (*quic.Config, error) {
|
||||
require.True(t, info.AddrVerified)
|
||||
return nil, nil
|
||||
},
|
||||
})
|
||||
const rtt = 400 * time.Millisecond
|
||||
start := time.Now()
|
||||
handshakeWithRTT(t, ln.Addr(), getTLSClientConfig(), clientConfig, rtt)
|
||||
rtts := time.Since(start).Seconds() / rtt.Seconds()
|
||||
require.GreaterOrEqual(t, rtts, float64(2))
|
||||
require.Less(t, rtts, float64(3))
|
||||
}
|
||||
|
||||
AfterEach(func() {
|
||||
Expect(proxy.Close()).To(Succeed())
|
||||
})
|
||||
func TestHandshakeRTTWithHelloRetryRequest(t *testing.T) {
|
||||
tlsConf := getTLSConfig()
|
||||
tlsConf.CurvePreferences = []tls.CurveID{tls.CurveP384}
|
||||
|
||||
runProxy := func(serverAddr net.Addr) {
|
||||
var err error
|
||||
// start the proxy
|
||||
proxy, err = quicproxy.NewQuicProxy("localhost:0", &quicproxy.Opts{
|
||||
RemoteAddr: serverAddr.String(),
|
||||
DelayPacket: func(_ quicproxy.Direction, _ []byte) time.Duration { return rtt / 2 },
|
||||
})
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
ln, err := quic.ListenAddr("localhost:0", tlsConf, getQuicConfig(nil))
|
||||
require.NoError(t, err)
|
||||
defer ln.Close()
|
||||
|
||||
const rtt = 400 * time.Millisecond
|
||||
start := time.Now()
|
||||
handshakeWithRTT(t, ln.Addr(), getTLSClientConfig(), getQuicConfig(nil), rtt)
|
||||
rtts := time.Since(start).Seconds() / rtt.Seconds()
|
||||
require.GreaterOrEqual(t, rtts, float64(2))
|
||||
require.Less(t, rtts, float64(3))
|
||||
}
|
||||
|
||||
func TestHandshakeRTTReceiveMessage(t *testing.T) {
|
||||
sendAndReceive := func(t *testing.T, serverConn, clientConn quic.Connection) {
|
||||
t.Helper()
|
||||
serverStr, err := serverConn.OpenUniStream()
|
||||
require.NoError(t, err)
|
||||
_, err = serverStr.Write([]byte("foobar"))
|
||||
require.NoError(t, err)
|
||||
require.NoError(t, serverStr.Close())
|
||||
|
||||
str, err := clientConn.AcceptUniStream(context.Background())
|
||||
require.NoError(t, err)
|
||||
data, err := io.ReadAll(str)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, []byte("foobar"), data)
|
||||
}
|
||||
|
||||
expectDurationInRTTs := func(startTime time.Time, num int) {
|
||||
testDuration := time.Since(startTime)
|
||||
rtts := float32(testDuration) / float32(rtt)
|
||||
Expect(rtts).To(SatisfyAll(
|
||||
BeNumerically(">=", num),
|
||||
BeNumerically("<", num+1),
|
||||
))
|
||||
}
|
||||
t.Run("using ListenAddr", func(t *testing.T) {
|
||||
ln, err := quic.ListenAddr("localhost:0", getTLSConfig(), getQuicConfig(nil))
|
||||
require.NoError(t, err)
|
||||
defer ln.Close()
|
||||
|
||||
// 1 RTT for verifying the source address
|
||||
// 1 RTT for the TLS handshake
|
||||
It("is forward-secure after 2 RTTs with Retry", func() {
|
||||
laddr, err := net.ResolveUDPAddr("udp", "localhost:0")
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
udpConn, err := net.ListenUDP("udp", laddr)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
defer udpConn.Close()
|
||||
tr := &quic.Transport{
|
||||
Conn: udpConn,
|
||||
VerifySourceAddress: func(net.Addr) bool { return true },
|
||||
connChan := make(chan quic.Connection, 1)
|
||||
go func() {
|
||||
conn, err := ln.Accept(context.Background())
|
||||
if err != nil {
|
||||
t.Logf("failed to accept connection: %s", err)
|
||||
close(connChan)
|
||||
return
|
||||
}
|
||||
connChan <- conn
|
||||
}()
|
||||
|
||||
const rtt = 400 * time.Millisecond
|
||||
start := time.Now()
|
||||
conn := handshakeWithRTT(t, ln.Addr(), getTLSClientConfig(), getQuicConfig(nil), rtt)
|
||||
serverConn := <-connChan
|
||||
if serverConn == nil {
|
||||
t.Fatal("serverConn is nil")
|
||||
}
|
||||
addTracer(tr)
|
||||
defer tr.Close()
|
||||
ln, err := tr.Listen(serverTLSConfig, serverConfig)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
defer ln.Close()
|
||||
sendAndReceive(t, serverConn, conn)
|
||||
|
||||
runProxy(ln.Addr())
|
||||
startTime := time.Now()
|
||||
conn, err := quic.DialAddr(
|
||||
context.Background(),
|
||||
fmt.Sprintf("localhost:%d", proxy.LocalAddr().(*net.UDPAddr).Port),
|
||||
getTLSClientConfig(),
|
||||
getQuicConfig(&quic.Config{GetConfigForClient: func(info *quic.ClientHelloInfo) (*quic.Config, error) {
|
||||
Expect(info.AddrVerified).To(BeTrue())
|
||||
return nil, nil
|
||||
}}),
|
||||
)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
defer conn.CloseWithError(0, "")
|
||||
expectDurationInRTTs(startTime, 2)
|
||||
rtts := time.Since(start).Seconds() / rtt.Seconds()
|
||||
require.GreaterOrEqual(t, rtts, float64(2))
|
||||
require.Less(t, rtts, float64(3))
|
||||
})
|
||||
|
||||
It("establishes a connection in 1 RTT when the server doesn't require a token", func() {
|
||||
ln, err := quic.ListenAddr("localhost:0", serverTLSConfig, serverConfig)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
t.Run("using ListenAddrEarly", func(t *testing.T) {
|
||||
ln, err := quic.ListenAddrEarly("localhost:0", getTLSConfig(), getQuicConfig(nil))
|
||||
require.NoError(t, err)
|
||||
defer ln.Close()
|
||||
|
||||
runProxy(ln.Addr())
|
||||
startTime := time.Now()
|
||||
conn, err := quic.DialAddr(
|
||||
context.Background(),
|
||||
fmt.Sprintf("localhost:%d", proxy.LocalAddr().(*net.UDPAddr).Port),
|
||||
getTLSClientConfig(),
|
||||
getQuicConfig(&quic.Config{GetConfigForClient: func(info *quic.ClientHelloInfo) (*quic.Config, error) {
|
||||
Expect(info.AddrVerified).To(BeFalse())
|
||||
return nil, nil
|
||||
}}),
|
||||
)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
defer conn.CloseWithError(0, "")
|
||||
expectDurationInRTTs(startTime, 1)
|
||||
})
|
||||
|
||||
It("establishes a connection in 2 RTTs if a HelloRetryRequest is performed", func() {
|
||||
serverTLSConfig.CurvePreferences = []tls.CurveID{tls.CurveP384}
|
||||
ln, err := quic.ListenAddr("localhost:0", serverTLSConfig, serverConfig)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
defer ln.Close()
|
||||
|
||||
runProxy(ln.Addr())
|
||||
startTime := time.Now()
|
||||
conn, err := quic.DialAddr(
|
||||
context.Background(),
|
||||
fmt.Sprintf("localhost:%d", proxy.LocalAddr().(*net.UDPAddr).Port),
|
||||
getTLSClientConfig(),
|
||||
getQuicConfig(nil),
|
||||
)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
defer conn.CloseWithError(0, "")
|
||||
expectDurationInRTTs(startTime, 2)
|
||||
})
|
||||
|
||||
It("receives the first message from the server after 2 RTTs, when the server uses ListenAddr", func() {
|
||||
ln, err := quic.ListenAddr("localhost:0", serverTLSConfig, serverConfig)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
connChan := make(chan quic.Connection, 1)
|
||||
go func() {
|
||||
defer GinkgoRecover()
|
||||
conn, err := ln.Accept(context.Background())
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
str, err := conn.OpenUniStream()
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
_, err = str.Write([]byte("foobar"))
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(str.Close()).To(Succeed())
|
||||
if err != nil {
|
||||
t.Logf("failed to accept connection: %s", err)
|
||||
close(connChan)
|
||||
return
|
||||
}
|
||||
connChan <- conn
|
||||
}()
|
||||
defer ln.Close()
|
||||
|
||||
runProxy(ln.Addr())
|
||||
startTime := time.Now()
|
||||
conn, err := quic.DialAddr(
|
||||
context.Background(),
|
||||
fmt.Sprintf("localhost:%d", proxy.LocalAddr().(*net.UDPAddr).Port),
|
||||
getTLSClientConfig(),
|
||||
getQuicConfig(nil),
|
||||
)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
defer conn.CloseWithError(0, "")
|
||||
str, err := conn.AcceptUniStream(context.Background())
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
data, err := io.ReadAll(str)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(data).To(Equal([]byte("foobar")))
|
||||
expectDurationInRTTs(startTime, 2)
|
||||
const rtt = 400 * time.Millisecond
|
||||
start := time.Now()
|
||||
conn := handshakeWithRTT(t, ln.Addr(), getTLSClientConfig(), getQuicConfig(nil), rtt)
|
||||
serverConn := <-connChan
|
||||
if serverConn == nil {
|
||||
t.Fatal("serverConn is nil")
|
||||
}
|
||||
sendAndReceive(t, serverConn, conn)
|
||||
|
||||
took := time.Since(start)
|
||||
rtts := float64(took) / float64(rtt)
|
||||
require.GreaterOrEqual(t, rtts, float64(1))
|
||||
require.Less(t, rtts, float64(2))
|
||||
})
|
||||
|
||||
It("receives the first message from the server after 1 RTT, when the server uses ListenAddrEarly", func() {
|
||||
ln, err := quic.ListenAddrEarly("localhost:0", serverTLSConfig, serverConfig)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
go func() {
|
||||
defer GinkgoRecover()
|
||||
conn, err := ln.Accept(context.Background())
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
// Check the ALPN now. This is probably what an application would do.
|
||||
// It makes sure that ConnectionState does not block until the handshake completes.
|
||||
Expect(conn.ConnectionState().TLS.NegotiatedProtocol).To(Equal(alpn))
|
||||
str, err := conn.OpenUniStream()
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
_, err = str.Write([]byte("foobar"))
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(str.Close()).To(Succeed())
|
||||
}()
|
||||
defer ln.Close()
|
||||
|
||||
runProxy(ln.Addr())
|
||||
startTime := time.Now()
|
||||
conn, err := quic.DialAddr(
|
||||
context.Background(),
|
||||
fmt.Sprintf("localhost:%d", proxy.LocalAddr().(*net.UDPAddr).Port),
|
||||
getTLSClientConfig(),
|
||||
getQuicConfig(nil),
|
||||
)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
defer conn.CloseWithError(0, "")
|
||||
str, err := conn.AcceptUniStream(context.Background())
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
data, err := io.ReadAll(str)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(data).To(Equal([]byte("foobar")))
|
||||
expectDurationInRTTs(startTime, 1)
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -1,170 +0,0 @@
|
||||
package self_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"io"
|
||||
"net"
|
||||
"net/http"
|
||||
"strconv"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/quic-go/quic-go"
|
||||
"github.com/quic-go/quic-go/http3"
|
||||
|
||||
. "github.com/onsi/ginkgo/v2"
|
||||
. "github.com/onsi/gomega"
|
||||
"github.com/onsi/gomega/gbytes"
|
||||
)
|
||||
|
||||
type listenerWrapper struct {
|
||||
http3.QUICEarlyListener
|
||||
listenerClosed bool
|
||||
count atomic.Int32
|
||||
}
|
||||
|
||||
func (ln *listenerWrapper) Close() error {
|
||||
ln.listenerClosed = true
|
||||
return ln.QUICEarlyListener.Close()
|
||||
}
|
||||
|
||||
func (ln *listenerWrapper) Faker() *fakeClosingListener {
|
||||
ln.count.Add(1)
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
return &fakeClosingListener{
|
||||
listenerWrapper: ln,
|
||||
ctx: ctx,
|
||||
cancel: cancel,
|
||||
}
|
||||
}
|
||||
|
||||
type fakeClosingListener struct {
|
||||
*listenerWrapper
|
||||
closed atomic.Bool
|
||||
ctx context.Context
|
||||
cancel context.CancelFunc
|
||||
}
|
||||
|
||||
func (ln *fakeClosingListener) Accept(ctx context.Context) (quic.EarlyConnection, error) {
|
||||
return ln.listenerWrapper.Accept(ln.ctx)
|
||||
}
|
||||
|
||||
func (ln *fakeClosingListener) Close() error {
|
||||
if ln.closed.CompareAndSwap(false, true) {
|
||||
ln.cancel()
|
||||
if ln.listenerWrapper.count.Add(-1) == 0 {
|
||||
ln.listenerWrapper.Close()
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
var _ = Describe("HTTP3 Server hotswap test", func() {
|
||||
var (
|
||||
mux1 *http.ServeMux
|
||||
mux2 *http.ServeMux
|
||||
client *http.Client
|
||||
rt *http3.Transport
|
||||
server1 *http3.Server
|
||||
server2 *http3.Server
|
||||
ln *listenerWrapper
|
||||
port string
|
||||
)
|
||||
|
||||
BeforeEach(func() {
|
||||
mux1 = http.NewServeMux()
|
||||
mux1.HandleFunc("/hello1", func(w http.ResponseWriter, r *http.Request) {
|
||||
defer GinkgoRecover()
|
||||
io.WriteString(w, "Hello, World 1!\n") // don't check the error here. Stream may be reset.
|
||||
})
|
||||
|
||||
mux2 = http.NewServeMux()
|
||||
mux2.HandleFunc("/hello2", func(w http.ResponseWriter, r *http.Request) {
|
||||
defer GinkgoRecover()
|
||||
io.WriteString(w, "Hello, World 2!\n") // don't check the error here. Stream may be reset.
|
||||
})
|
||||
|
||||
server1 = &http3.Server{
|
||||
Handler: mux1,
|
||||
QUICConfig: getQuicConfig(nil),
|
||||
}
|
||||
server2 = &http3.Server{
|
||||
Handler: mux2,
|
||||
QUICConfig: getQuicConfig(nil),
|
||||
}
|
||||
|
||||
tlsConf := http3.ConfigureTLSConfig(getTLSConfig())
|
||||
quicln, err := quic.ListenAddrEarly("0.0.0.0:0", tlsConf, getQuicConfig(nil))
|
||||
ln = &listenerWrapper{QUICEarlyListener: quicln}
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
port = strconv.Itoa(ln.Addr().(*net.UDPAddr).Port)
|
||||
})
|
||||
|
||||
AfterEach(func() {
|
||||
Expect(rt.Close()).NotTo(HaveOccurred())
|
||||
Expect(ln.Close()).NotTo(HaveOccurred())
|
||||
})
|
||||
|
||||
BeforeEach(func() {
|
||||
rt = &http3.Transport{
|
||||
TLSClientConfig: getTLSClientConfig(),
|
||||
DisableCompression: true,
|
||||
QUICConfig: getQuicConfig(&quic.Config{MaxIdleTimeout: 10 * time.Second}),
|
||||
}
|
||||
client = &http.Client{Transport: rt}
|
||||
})
|
||||
|
||||
It("hotswap works", func() {
|
||||
// open first server and make single request to it
|
||||
fake1 := ln.Faker()
|
||||
stoppedServing1 := make(chan struct{})
|
||||
go func() {
|
||||
defer GinkgoRecover()
|
||||
server1.ServeListener(fake1)
|
||||
close(stoppedServing1)
|
||||
}()
|
||||
|
||||
resp, err := client.Get("https://localhost:" + port + "/hello1")
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(resp.StatusCode).To(Equal(200))
|
||||
body, err := io.ReadAll(gbytes.TimeoutReader(resp.Body, 3*time.Second))
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(string(body)).To(Equal("Hello, World 1!\n"))
|
||||
|
||||
// open second server with same underlying listener,
|
||||
// make sure it opened and both servers are currently running
|
||||
fake2 := ln.Faker()
|
||||
stoppedServing2 := make(chan struct{})
|
||||
go func() {
|
||||
defer GinkgoRecover()
|
||||
server2.ServeListener(fake2)
|
||||
close(stoppedServing2)
|
||||
}()
|
||||
|
||||
Consistently(stoppedServing1).ShouldNot(BeClosed())
|
||||
Consistently(stoppedServing2).ShouldNot(BeClosed())
|
||||
|
||||
// now close first server, no errors should occur here
|
||||
// and only the fake listener should be closed
|
||||
Expect(server1.Close()).NotTo(HaveOccurred())
|
||||
Eventually(stoppedServing1).Should(BeClosed())
|
||||
Expect(fake1.closed.Load()).To(BeTrue())
|
||||
Expect(fake2.closed.Load()).To(BeFalse())
|
||||
Expect(ln.listenerClosed).ToNot(BeTrue())
|
||||
Expect(client.Transport.(*http3.Transport).Close()).NotTo(HaveOccurred())
|
||||
|
||||
// verify that new connections are being initiated from the second server now
|
||||
resp, err = client.Get("https://localhost:" + port + "/hello2")
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(resp.StatusCode).To(Equal(200))
|
||||
body, err = io.ReadAll(gbytes.TimeoutReader(resp.Body, 3*time.Second))
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(string(body)).To(Equal("Hello, World 2!\n"))
|
||||
|
||||
// close the other server - both the fake and the actual listeners must close now
|
||||
Expect(server2.Close()).NotTo(HaveOccurred())
|
||||
Eventually(stoppedServing2).Should(BeClosed())
|
||||
Expect(fake2.closed.Load()).To(BeTrue())
|
||||
Expect(ln.listenerClosed).To(BeTrue())
|
||||
})
|
||||
})
|
||||
160
integrationtests/self/http_hotswap_test.go
Normal file
160
integrationtests/self/http_hotswap_test.go
Normal file
@@ -0,0 +1,160 @@
|
||||
package self_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"io"
|
||||
"net"
|
||||
"net/http"
|
||||
"strconv"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/quic-go/quic-go"
|
||||
"github.com/quic-go/quic-go/http3"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
type listenerWrapper struct {
|
||||
http3.QUICEarlyListener
|
||||
listenerClosed bool
|
||||
count atomic.Int32
|
||||
}
|
||||
|
||||
func (ln *listenerWrapper) Close() error {
|
||||
ln.listenerClosed = true
|
||||
return ln.QUICEarlyListener.Close()
|
||||
}
|
||||
|
||||
func (ln *listenerWrapper) Faker() *fakeClosingListener {
|
||||
ln.count.Add(1)
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
return &fakeClosingListener{
|
||||
listenerWrapper: ln,
|
||||
ctx: ctx,
|
||||
cancel: cancel,
|
||||
}
|
||||
}
|
||||
|
||||
type fakeClosingListener struct {
|
||||
*listenerWrapper
|
||||
closed atomic.Bool
|
||||
ctx context.Context
|
||||
cancel context.CancelFunc
|
||||
}
|
||||
|
||||
func (ln *fakeClosingListener) Accept(ctx context.Context) (quic.EarlyConnection, error) {
|
||||
return ln.listenerWrapper.Accept(ln.ctx)
|
||||
}
|
||||
|
||||
func (ln *fakeClosingListener) Close() error {
|
||||
if ln.closed.CompareAndSwap(false, true) {
|
||||
ln.cancel()
|
||||
if ln.listenerWrapper.count.Add(-1) == 0 {
|
||||
ln.listenerWrapper.Close()
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func TestHTTP3ServerHotswap(t *testing.T) {
|
||||
mux1 := http.NewServeMux()
|
||||
mux1.HandleFunc("/hello1", func(w http.ResponseWriter, r *http.Request) {
|
||||
io.WriteString(w, "Hello, World 1!\n") // don't check the error here. Stream may be reset.
|
||||
})
|
||||
|
||||
mux2 := http.NewServeMux()
|
||||
mux2.HandleFunc("/hello2", func(w http.ResponseWriter, r *http.Request) {
|
||||
io.WriteString(w, "Hello, World 2!\n") // don't check the error here. Stream may be reset.
|
||||
})
|
||||
|
||||
server1 := &http3.Server{
|
||||
Handler: mux1,
|
||||
QUICConfig: getQuicConfig(nil),
|
||||
}
|
||||
server2 := &http3.Server{
|
||||
Handler: mux2,
|
||||
QUICConfig: getQuicConfig(nil),
|
||||
}
|
||||
|
||||
tlsConf := http3.ConfigureTLSConfig(getTLSConfig())
|
||||
quicln, err := quic.ListenAddrEarly("0.0.0.0:0", tlsConf, getQuicConfig(nil))
|
||||
require.NoError(t, err)
|
||||
ln := &listenerWrapper{QUICEarlyListener: quicln}
|
||||
port := strconv.Itoa(ln.Addr().(*net.UDPAddr).Port)
|
||||
|
||||
rt := &http3.Transport{
|
||||
TLSClientConfig: getTLSClientConfig(),
|
||||
DisableCompression: true,
|
||||
QUICConfig: getQuicConfig(&quic.Config{MaxIdleTimeout: 10 * time.Second}),
|
||||
}
|
||||
client := &http.Client{Transport: rt}
|
||||
|
||||
defer func() {
|
||||
require.NoError(t, rt.Close())
|
||||
require.NoError(t, ln.Close())
|
||||
}()
|
||||
|
||||
// open first server and make single request to it
|
||||
fake1 := ln.Faker()
|
||||
stoppedServing1 := make(chan struct{})
|
||||
go func() {
|
||||
server1.ServeListener(fake1)
|
||||
close(stoppedServing1)
|
||||
}()
|
||||
|
||||
resp, err := client.Get("https://localhost:" + port + "/hello1")
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, http.StatusOK, resp.StatusCode)
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "Hello, World 1!\n", string(body))
|
||||
|
||||
// open second server with same underlying listener
|
||||
fake2 := ln.Faker()
|
||||
stoppedServing2 := make(chan struct{})
|
||||
go func() {
|
||||
server2.ServeListener(fake2)
|
||||
close(stoppedServing2)
|
||||
}()
|
||||
|
||||
// Verify both servers are running by waiting a bit and checking channels aren't closed
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
select {
|
||||
case <-stoppedServing1:
|
||||
t.Fatal("server1 stopped unexpectedly")
|
||||
case <-stoppedServing2:
|
||||
t.Fatal("server2 stopped unexpectedly")
|
||||
default:
|
||||
}
|
||||
|
||||
// now close first server
|
||||
require.NoError(t, server1.Close())
|
||||
select {
|
||||
case <-stoppedServing1:
|
||||
case <-time.After(time.Second):
|
||||
t.Fatal("timed out waiting for server1 to stop")
|
||||
}
|
||||
require.True(t, fake1.closed.Load())
|
||||
require.False(t, fake2.closed.Load())
|
||||
require.False(t, ln.listenerClosed)
|
||||
require.NoError(t, client.Transport.(*http3.Transport).Close())
|
||||
|
||||
// verify that new connections are being initiated from the second server now
|
||||
resp, err = client.Get("https://localhost:" + port + "/hello2")
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, http.StatusOK, resp.StatusCode)
|
||||
body, err = io.ReadAll(resp.Body)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "Hello, World 2!\n", string(body))
|
||||
|
||||
// close the other server - both the fake and the actual listeners must close now
|
||||
require.NoError(t, server2.Close())
|
||||
select {
|
||||
case <-stoppedServing2:
|
||||
case <-time.After(time.Second):
|
||||
t.Fatal("timed out waiting for server2 to stop")
|
||||
}
|
||||
require.True(t, fake2.closed.Load())
|
||||
require.True(t, ln.listenerClosed)
|
||||
}
|
||||
@@ -5,88 +5,94 @@ import (
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"testing"
|
||||
|
||||
"github.com/quic-go/quic-go"
|
||||
"github.com/quic-go/quic-go/internal/handshake"
|
||||
"github.com/quic-go/quic-go/internal/protocol"
|
||||
"github.com/quic-go/quic-go/logging"
|
||||
|
||||
. "github.com/onsi/ginkgo/v2"
|
||||
. "github.com/onsi/gomega"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
var (
|
||||
sentHeaders []*logging.ShortHeader
|
||||
receivedHeaders []*logging.ShortHeader
|
||||
)
|
||||
func TestKeyUpdates(t *testing.T) {
|
||||
origKeyUpdateInterval := handshake.KeyUpdateInterval
|
||||
t.Cleanup(func() { handshake.KeyUpdateInterval = origKeyUpdateInterval })
|
||||
handshake.KeyUpdateInterval = 1 // update keys as frequently as possible
|
||||
|
||||
func countKeyPhases() (sent, received int) {
|
||||
lastKeyPhase := protocol.KeyPhaseOne
|
||||
for _, hdr := range sentHeaders {
|
||||
if hdr.KeyPhase != lastKeyPhase {
|
||||
sent++
|
||||
lastKeyPhase = hdr.KeyPhase
|
||||
var sentHeaders []*logging.ShortHeader
|
||||
var receivedHeaders []*logging.ShortHeader
|
||||
|
||||
countKeyPhases := func() (sent, received int) {
|
||||
lastKeyPhase := protocol.KeyPhaseOne
|
||||
for _, hdr := range sentHeaders {
|
||||
if hdr.KeyPhase != lastKeyPhase {
|
||||
sent++
|
||||
lastKeyPhase = hdr.KeyPhase
|
||||
}
|
||||
}
|
||||
}
|
||||
lastKeyPhase = protocol.KeyPhaseOne
|
||||
for _, hdr := range receivedHeaders {
|
||||
if hdr.KeyPhase != lastKeyPhase {
|
||||
received++
|
||||
lastKeyPhase = hdr.KeyPhase
|
||||
lastKeyPhase = protocol.KeyPhaseOne
|
||||
for _, hdr := range receivedHeaders {
|
||||
if hdr.KeyPhase != lastKeyPhase {
|
||||
received++
|
||||
lastKeyPhase = hdr.KeyPhase
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
return
|
||||
|
||||
server, err := quic.ListenAddr("localhost:0", getTLSConfig(), nil)
|
||||
require.NoError(t, err)
|
||||
defer server.Close()
|
||||
|
||||
serverErrChan := make(chan error, 1)
|
||||
go func() {
|
||||
conn, err := server.Accept(context.Background())
|
||||
if err != nil {
|
||||
serverErrChan <- err
|
||||
return
|
||||
}
|
||||
str, err := conn.OpenUniStream()
|
||||
if err != nil {
|
||||
serverErrChan <- err
|
||||
return
|
||||
}
|
||||
defer str.Close()
|
||||
if _, err := str.Write(PRDataLong); err != nil {
|
||||
serverErrChan <- err
|
||||
return
|
||||
}
|
||||
close(serverErrChan)
|
||||
}()
|
||||
|
||||
conn, err := quic.DialAddr(
|
||||
context.Background(),
|
||||
fmt.Sprintf("localhost:%d", server.Addr().(*net.UDPAddr).Port),
|
||||
getTLSClientConfig(),
|
||||
getQuicConfig(&quic.Config{Tracer: func(context.Context, logging.Perspective, quic.ConnectionID) *logging.ConnectionTracer {
|
||||
return &logging.ConnectionTracer{
|
||||
SentShortHeaderPacket: func(hdr *logging.ShortHeader, _ logging.ByteCount, _ logging.ECN, _ *logging.AckFrame, _ []logging.Frame) {
|
||||
sentHeaders = append(sentHeaders, hdr)
|
||||
},
|
||||
ReceivedShortHeaderPacket: func(hdr *logging.ShortHeader, _ logging.ByteCount, _ logging.ECN, _ []logging.Frame) {
|
||||
receivedHeaders = append(receivedHeaders, hdr)
|
||||
},
|
||||
}
|
||||
}}),
|
||||
)
|
||||
require.NoError(t, err)
|
||||
defer conn.CloseWithError(0, "")
|
||||
|
||||
str, err := conn.AcceptUniStream(context.Background())
|
||||
require.NoError(t, err)
|
||||
data, err := io.ReadAll(str)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, PRDataLong, data)
|
||||
require.NoError(t, conn.CloseWithError(0, ""))
|
||||
|
||||
require.NoError(t, <-serverErrChan)
|
||||
|
||||
keyPhasesSent, keyPhasesReceived := countKeyPhases()
|
||||
t.Logf("Used %d key phases on outgoing and %d key phases on incoming packets.", keyPhasesSent, keyPhasesReceived)
|
||||
require.Greater(t, keyPhasesReceived, 10)
|
||||
require.InDelta(t, keyPhasesSent, keyPhasesReceived, 2)
|
||||
}
|
||||
|
||||
var keyUpdateConnTracer = &logging.ConnectionTracer{
|
||||
SentShortHeaderPacket: func(hdr *logging.ShortHeader, _ logging.ByteCount, _ logging.ECN, _ *logging.AckFrame, _ []logging.Frame) {
|
||||
sentHeaders = append(sentHeaders, hdr)
|
||||
},
|
||||
ReceivedShortHeaderPacket: func(hdr *logging.ShortHeader, _ logging.ByteCount, _ logging.ECN, _ []logging.Frame) {
|
||||
receivedHeaders = append(receivedHeaders, hdr)
|
||||
},
|
||||
}
|
||||
|
||||
var _ = Describe("Key Update tests", func() {
|
||||
It("downloads a large file", func() {
|
||||
origKeyUpdateInterval := handshake.KeyUpdateInterval
|
||||
defer func() { handshake.KeyUpdateInterval = origKeyUpdateInterval }()
|
||||
handshake.KeyUpdateInterval = 1 // update keys as frequently as possible
|
||||
|
||||
server, err := quic.ListenAddr("localhost:0", getTLSConfig(), nil)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
defer server.Close()
|
||||
|
||||
go func() {
|
||||
defer GinkgoRecover()
|
||||
conn, err := server.Accept(context.Background())
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
str, err := conn.OpenUniStream()
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
defer str.Close()
|
||||
_, err = str.Write(PRDataLong)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
}()
|
||||
|
||||
conn, err := quic.DialAddr(
|
||||
context.Background(),
|
||||
fmt.Sprintf("localhost:%d", server.Addr().(*net.UDPAddr).Port),
|
||||
getTLSClientConfig(),
|
||||
getQuicConfig(&quic.Config{Tracer: func(context.Context, logging.Perspective, quic.ConnectionID) *logging.ConnectionTracer {
|
||||
return keyUpdateConnTracer
|
||||
}}),
|
||||
)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
str, err := conn.AcceptUniStream(context.Background())
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
data, err := io.ReadAll(str)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(data).To(Equal(PRDataLong))
|
||||
Expect(conn.CloseWithError(0, "")).To(Succeed())
|
||||
|
||||
keyPhasesSent, keyPhasesReceived := countKeyPhases()
|
||||
fmt.Fprintf(GinkgoWriter, "Used %d key phases on outgoing and %d key phases on incoming packets.\n", keyPhasesSent, keyPhasesReceived)
|
||||
Expect(keyPhasesReceived).To(BeNumerically(">", 10))
|
||||
Expect(keyPhasesReceived).To(BeNumerically("~", keyPhasesSent, 2))
|
||||
})
|
||||
})
|
||||
|
||||
@@ -7,7 +7,9 @@ import (
|
||||
"io"
|
||||
"math"
|
||||
"net"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"golang.org/x/exp/rand"
|
||||
@@ -18,455 +20,401 @@ import (
|
||||
"github.com/quic-go/quic-go/internal/wire"
|
||||
"github.com/quic-go/quic-go/testutils"
|
||||
|
||||
. "github.com/onsi/ginkgo/v2"
|
||||
. "github.com/onsi/gomega"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
var _ = Describe("MITM test", func() {
|
||||
const connIDLen = 6 // explicitly set the connection ID length, so the proxy can parse it
|
||||
const mitmTestConnIDLen = 6
|
||||
|
||||
var (
|
||||
clientUDPConn net.PacketConn
|
||||
serverTransport, clientTransport *quic.Transport
|
||||
serverConn quic.Connection
|
||||
serverConfig *quic.Config
|
||||
)
|
||||
func getTransportsForMITMTest(t *testing.T) (serverTransport, clientTransport *quic.Transport) {
|
||||
serverUDPConn, err := net.ListenUDP("udp", &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 0})
|
||||
require.NoError(t, err)
|
||||
serverTransport = &quic.Transport{
|
||||
Conn: serverUDPConn,
|
||||
ConnectionIDLength: mitmTestConnIDLen,
|
||||
}
|
||||
addTracer(serverTransport)
|
||||
t.Cleanup(func() { serverTransport.Close() })
|
||||
|
||||
startServerAndProxy := func(delayCb quicproxy.DelayCallback, dropCb quicproxy.DropCallback, forceAddressValidation bool) (proxyPort int, closeFn func()) {
|
||||
addr, err := net.ResolveUDPAddr("udp", "localhost:0")
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
c, err := net.ListenUDP("udp", addr)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
serverTransport = &quic.Transport{
|
||||
Conn: c,
|
||||
ConnectionIDLength: connIDLen,
|
||||
}
|
||||
addTracer(serverTransport)
|
||||
if forceAddressValidation {
|
||||
serverTransport.VerifySourceAddress = func(net.Addr) bool { return true }
|
||||
}
|
||||
ln, err := serverTransport.Listen(getTLSConfig(), serverConfig)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
defer GinkgoRecover()
|
||||
defer close(done)
|
||||
var err error
|
||||
serverConn, err = ln.Accept(context.Background())
|
||||
clientUDPConn, err := net.ListenUDP("udp", &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 0})
|
||||
require.NoError(t, err)
|
||||
clientTransport = &quic.Transport{
|
||||
Conn: clientUDPConn,
|
||||
ConnectionIDLength: mitmTestConnIDLen,
|
||||
}
|
||||
addTracer(clientTransport)
|
||||
t.Cleanup(func() { clientTransport.Close() })
|
||||
|
||||
return serverTransport, clientTransport
|
||||
}
|
||||
|
||||
func TestMITMInjectRandomPackets(t *testing.T) {
|
||||
t.Run("towards the server", func(t *testing.T) {
|
||||
testMITMInjectRandomPackets(t, quicproxy.DirectionIncoming)
|
||||
})
|
||||
|
||||
t.Run("towards the client", func(t *testing.T) {
|
||||
testMITMInjectRandomPackets(t, quicproxy.DirectionOutgoing)
|
||||
})
|
||||
}
|
||||
|
||||
func TestMITMDuplicatePackets(t *testing.T) {
|
||||
t.Run("towards the server", func(t *testing.T) {
|
||||
testMITMDuplicatePackets(t, quicproxy.DirectionIncoming)
|
||||
})
|
||||
|
||||
t.Run("towards the client", func(t *testing.T) {
|
||||
testMITMDuplicatePackets(t, quicproxy.DirectionOutgoing)
|
||||
})
|
||||
}
|
||||
|
||||
func TestMITInjectCorruptedPackets(t *testing.T) {
|
||||
t.Run("towards the server", func(t *testing.T) {
|
||||
testMITMInjectCorruptedPackets(t, quicproxy.DirectionIncoming)
|
||||
})
|
||||
|
||||
t.Run("towards the client", func(t *testing.T) {
|
||||
testMITMInjectCorruptedPackets(t, quicproxy.DirectionOutgoing)
|
||||
})
|
||||
}
|
||||
|
||||
func testMITMInjectRandomPackets(t *testing.T, direction quicproxy.Direction) {
|
||||
createRandomPacketOfSameType := func(b []byte) []byte {
|
||||
if wire.IsLongHeaderPacket(b[0]) {
|
||||
hdr, _, _, err := wire.ParsePacket(b)
|
||||
if err != nil {
|
||||
return
|
||||
return nil
|
||||
}
|
||||
str, err := serverConn.OpenUniStream()
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
_, err = str.Write(PRData)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(str.Close()).To(Succeed())
|
||||
}()
|
||||
serverPort := ln.Addr().(*net.UDPAddr).Port
|
||||
proxy, err := quicproxy.NewQuicProxy("localhost:0", &quicproxy.Opts{
|
||||
RemoteAddr: fmt.Sprintf("localhost:%d", serverPort),
|
||||
DelayPacket: delayCb,
|
||||
DropPacket: dropCb,
|
||||
})
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
return proxy.LocalPort(), func() {
|
||||
proxy.Close()
|
||||
ln.Close()
|
||||
serverTransport.Close()
|
||||
<-done
|
||||
replyHdr := &wire.ExtendedHeader{
|
||||
Header: wire.Header{
|
||||
DestConnectionID: hdr.DestConnectionID,
|
||||
SrcConnectionID: hdr.SrcConnectionID,
|
||||
Type: hdr.Type,
|
||||
Version: hdr.Version,
|
||||
},
|
||||
PacketNumber: protocol.PacketNumber(rand.Int31n(math.MaxInt32 / 4)),
|
||||
PacketNumberLen: protocol.PacketNumberLen(rand.Int31n(4) + 1),
|
||||
}
|
||||
payloadLen := rand.Int31n(100)
|
||||
replyHdr.Length = protocol.ByteCount(rand.Int31n(payloadLen + 1))
|
||||
data, err := replyHdr.Append(nil, hdr.Version)
|
||||
if err != nil {
|
||||
panic("failed to append header: " + err.Error())
|
||||
}
|
||||
r := make([]byte, payloadLen)
|
||||
rand.Read(r)
|
||||
return append(data, r...)
|
||||
}
|
||||
// short header packet
|
||||
connID, err := wire.ParseConnectionID(b, mitmTestConnIDLen)
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
_, pn, pnLen, _, err := wire.ParseShortHeader(b, mitmTestConnIDLen)
|
||||
if err != nil && !errors.Is(err, wire.ErrInvalidReservedBits) { // normally, ParseShortHeader is called after decrypting the header
|
||||
panic("failed to parse short header: " + err.Error())
|
||||
}
|
||||
data, err := wire.AppendShortHeader(nil, connID, pn, pnLen, protocol.KeyPhaseBit(rand.Intn(2)))
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
payloadLen := rand.Int31n(100)
|
||||
r := make([]byte, payloadLen)
|
||||
rand.Read(r)
|
||||
return append(data, r...)
|
||||
}
|
||||
|
||||
BeforeEach(func() {
|
||||
serverConfig = getQuicConfig(nil)
|
||||
addr, err := net.ResolveUDPAddr("udp", "localhost:0")
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
clientUDPConn, err = net.ListenUDP("udp", addr)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
clientTransport = &quic.Transport{
|
||||
Conn: clientUDPConn,
|
||||
ConnectionIDLength: connIDLen,
|
||||
rtt := scaleDuration(10 * time.Millisecond)
|
||||
serverTransport, clientTransport := getTransportsForMITMTest(t)
|
||||
|
||||
dropCallback := func(dir quicproxy.Direction, b []byte) bool {
|
||||
if dir != direction {
|
||||
return false
|
||||
}
|
||||
addTracer(clientTransport)
|
||||
})
|
||||
|
||||
Context("unsuccessful attacks", func() {
|
||||
AfterEach(func() {
|
||||
Eventually(serverConn.Context().Done()).Should(BeClosed())
|
||||
// Test shutdown is tricky due to the proxy. Just wait for a bit.
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
Expect(clientUDPConn.Close()).To(Succeed())
|
||||
Expect(clientTransport.Close()).To(Succeed())
|
||||
})
|
||||
|
||||
Context("injecting invalid packets", func() {
|
||||
const rtt = 20 * time.Millisecond
|
||||
|
||||
sendRandomPacketsOfSameType := func(conn *quic.Transport, remoteAddr net.Addr, raw []byte) {
|
||||
defer GinkgoRecover()
|
||||
const numPackets = 10
|
||||
ticker := time.NewTicker(rtt / numPackets)
|
||||
defer ticker.Stop()
|
||||
|
||||
if wire.IsLongHeaderPacket(raw[0]) {
|
||||
hdr, _, _, err := wire.ParsePacket(raw)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
replyHdr := &wire.ExtendedHeader{
|
||||
Header: wire.Header{
|
||||
DestConnectionID: hdr.DestConnectionID,
|
||||
SrcConnectionID: hdr.SrcConnectionID,
|
||||
Type: hdr.Type,
|
||||
Version: hdr.Version,
|
||||
},
|
||||
PacketNumber: protocol.PacketNumber(rand.Int31n(math.MaxInt32 / 4)),
|
||||
PacketNumberLen: protocol.PacketNumberLen(rand.Int31n(4) + 1),
|
||||
}
|
||||
|
||||
for i := 0; i < numPackets; i++ {
|
||||
payloadLen := rand.Int31n(100)
|
||||
replyHdr.Length = protocol.ByteCount(rand.Int31n(payloadLen + 1))
|
||||
b, err := replyHdr.Append(nil, hdr.Version)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
r := make([]byte, payloadLen)
|
||||
rand.Read(r)
|
||||
b = append(b, r...)
|
||||
if _, err := conn.WriteTo(b, remoteAddr); err != nil {
|
||||
return
|
||||
}
|
||||
<-ticker.C
|
||||
}
|
||||
} else {
|
||||
connID, err := wire.ParseConnectionID(raw, connIDLen)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
_, pn, pnLen, _, err := wire.ParseShortHeader(raw, connIDLen)
|
||||
if err != nil { // normally, ParseShortHeader is called after decrypting the header
|
||||
Expect(err).To(MatchError(wire.ErrInvalidReservedBits))
|
||||
}
|
||||
for i := 0; i < numPackets; i++ {
|
||||
b, err := wire.AppendShortHeader(nil, connID, pn, pnLen, protocol.KeyPhaseBit(rand.Intn(2)))
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
payloadLen := rand.Int31n(100)
|
||||
r := make([]byte, payloadLen)
|
||||
rand.Read(r)
|
||||
b = append(b, r...)
|
||||
if _, err := conn.WriteTo(b, remoteAddr); err != nil {
|
||||
return
|
||||
}
|
||||
<-ticker.C
|
||||
}
|
||||
go func() {
|
||||
ticker := time.NewTicker(rtt / 10)
|
||||
defer ticker.Stop()
|
||||
for i := 0; i < 10; i++ {
|
||||
switch direction {
|
||||
case quicproxy.DirectionIncoming:
|
||||
clientTransport.WriteTo(createRandomPacketOfSameType(b), serverTransport.Conn.LocalAddr())
|
||||
case quicproxy.DirectionOutgoing:
|
||||
serverTransport.WriteTo(createRandomPacketOfSameType(b), clientTransport.Conn.LocalAddr())
|
||||
}
|
||||
<-ticker.C
|
||||
}
|
||||
}()
|
||||
return false
|
||||
}
|
||||
|
||||
runTest := func(delayCb quicproxy.DelayCallback) {
|
||||
proxyPort, closeFn := startServerAndProxy(delayCb, nil, false)
|
||||
defer closeFn()
|
||||
raddr, err := net.ResolveUDPAddr("udp", fmt.Sprintf("localhost:%d", proxyPort))
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
conn, err := clientTransport.Dial(
|
||||
context.Background(),
|
||||
raddr,
|
||||
getTLSClientConfig(),
|
||||
getQuicConfig(nil),
|
||||
runMITMTest(t, serverTransport, clientTransport, direction, rtt, dropCallback)
|
||||
}
|
||||
|
||||
func testMITMDuplicatePackets(t *testing.T, direction quicproxy.Direction) {
|
||||
serverTransport, clientTransport := getTransportsForMITMTest(t)
|
||||
rtt := scaleDuration(10 * time.Millisecond)
|
||||
|
||||
dropCallback := func(dir quicproxy.Direction, b []byte) bool {
|
||||
if dir != direction {
|
||||
return false
|
||||
}
|
||||
switch direction {
|
||||
case quicproxy.DirectionIncoming:
|
||||
clientTransport.WriteTo(b, serverTransport.Conn.LocalAddr())
|
||||
case quicproxy.DirectionOutgoing:
|
||||
serverTransport.WriteTo(b, clientTransport.Conn.LocalAddr())
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
runMITMTest(t, serverTransport, clientTransport, direction, rtt, dropCallback)
|
||||
}
|
||||
|
||||
func testMITMInjectCorruptedPackets(t *testing.T, direction quicproxy.Direction) {
|
||||
serverTransport, clientTransport := getTransportsForMITMTest(t)
|
||||
rtt := scaleDuration(10 * time.Millisecond)
|
||||
|
||||
var numCorrupted atomic.Int32
|
||||
const interval = 4
|
||||
dropCallback := func(dir quicproxy.Direction, b []byte) bool {
|
||||
if dir != direction {
|
||||
return false
|
||||
}
|
||||
if rand.Intn(interval) == 0 {
|
||||
numCorrupted.Add(1)
|
||||
pos := rand.Intn(len(b))
|
||||
b[pos] = byte(rand.Intn(256))
|
||||
return true
|
||||
}
|
||||
switch direction {
|
||||
case quicproxy.DirectionIncoming:
|
||||
clientTransport.WriteTo(b, serverTransport.Conn.LocalAddr())
|
||||
case quicproxy.DirectionOutgoing:
|
||||
serverTransport.WriteTo(b, clientTransport.Conn.LocalAddr())
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
runMITMTest(t, serverTransport, clientTransport, direction, rtt, dropCallback)
|
||||
require.NotZero(t, int(numCorrupted.Load()))
|
||||
}
|
||||
|
||||
func runMITMTest(t *testing.T,
|
||||
serverTransport, clientTransport *quic.Transport,
|
||||
direction quicproxy.Direction,
|
||||
rtt time.Duration,
|
||||
dropCallback quicproxy.DropCallback,
|
||||
) {
|
||||
ln, err := serverTransport.Listen(getTLSConfig(), getQuicConfig(nil))
|
||||
require.NoError(t, err)
|
||||
defer ln.Close()
|
||||
|
||||
serverPort := ln.Addr().(*net.UDPAddr).Port
|
||||
proxy, err := quicproxy.NewQuicProxy("localhost:0", &quicproxy.Opts{
|
||||
RemoteAddr: fmt.Sprintf("localhost:%d", serverPort),
|
||||
DelayPacket: func(dir quicproxy.Direction, b []byte) time.Duration { return rtt / 2 },
|
||||
DropPacket: dropCallback,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
defer proxy.Close()
|
||||
|
||||
conn, err := clientTransport.Dial(
|
||||
context.Background(),
|
||||
&net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: proxy.LocalPort()},
|
||||
getTLSClientConfig(),
|
||||
getQuicConfig(nil),
|
||||
)
|
||||
require.NoError(t, err)
|
||||
|
||||
serverConn, err := ln.Accept(context.Background())
|
||||
require.NoError(t, err)
|
||||
serverStr, err := serverConn.OpenUniStream()
|
||||
require.NoError(t, err)
|
||||
|
||||
errChan := make(chan error, 1)
|
||||
go func() {
|
||||
defer close(errChan)
|
||||
if _, err := serverStr.Write(PRData); err != nil {
|
||||
errChan <- err
|
||||
return
|
||||
}
|
||||
serverStr.Close()
|
||||
}()
|
||||
require.NoError(t, <-errChan)
|
||||
|
||||
str, err := conn.AcceptUniStream(context.Background())
|
||||
require.NoError(t, err)
|
||||
data, err := io.ReadAll(str)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, PRData, data)
|
||||
require.NoError(t, conn.CloseWithError(0, ""))
|
||||
}
|
||||
|
||||
func TestMITMForgedVersionNegotiationPacket(t *testing.T) {
|
||||
serverTransport, clientTransport := getTransportsForMITMTest(t)
|
||||
rtt := scaleDuration(10 * time.Millisecond)
|
||||
|
||||
const supportedVersion protocol.Version = 42
|
||||
|
||||
var once sync.Once
|
||||
delayCb := func(dir quicproxy.Direction, raw []byte) time.Duration {
|
||||
if dir != quicproxy.DirectionIncoming {
|
||||
return rtt / 2
|
||||
}
|
||||
once.Do(func() {
|
||||
hdr, _, _, err := wire.ParsePacket(raw)
|
||||
if err != nil {
|
||||
panic("failed to parse packet: " + err.Error())
|
||||
}
|
||||
// create fake version negotiation packet with a fake supported version
|
||||
packet := wire.ComposeVersionNegotiation(
|
||||
protocol.ArbitraryLenConnectionID(hdr.SrcConnectionID.Bytes()),
|
||||
protocol.ArbitraryLenConnectionID(hdr.DestConnectionID.Bytes()),
|
||||
[]protocol.Version{supportedVersion},
|
||||
)
|
||||
if _, err := serverTransport.WriteTo(packet, clientTransport.Conn.LocalAddr()); err != nil {
|
||||
panic("failed to write packet: " + err.Error())
|
||||
}
|
||||
})
|
||||
return rtt / 2
|
||||
}
|
||||
|
||||
err := runMITMTestSuccessful(t, serverTransport, clientTransport, delayCb)
|
||||
var vnErr *quic.VersionNegotiationError
|
||||
require.ErrorAs(t, err, &vnErr)
|
||||
require.Contains(t, vnErr.Theirs, supportedVersion) // might contain greased versions
|
||||
}
|
||||
|
||||
// times out, because client doesn't accept subsequent real retry packets from server
|
||||
// as it has already accepted a retry.
|
||||
// TODO: determine behavior when server does not send Retry packets
|
||||
func TestMITMForgedRetryPacket(t *testing.T) {
|
||||
serverTransport, clientTransport := getTransportsForMITMTest(t)
|
||||
serverTransport.VerifySourceAddress = func(net.Addr) bool { return true }
|
||||
rtt := scaleDuration(10 * time.Millisecond)
|
||||
|
||||
var once sync.Once
|
||||
delayCb := func(dir quicproxy.Direction, raw []byte) time.Duration {
|
||||
hdr, _, _, err := wire.ParsePacket(raw)
|
||||
if err != nil {
|
||||
panic("failed to parse packet: " + err.Error())
|
||||
}
|
||||
if dir == quicproxy.DirectionIncoming && hdr.Type == protocol.PacketTypeInitial {
|
||||
once.Do(func() {
|
||||
fakeSrcConnID := protocol.ParseConnectionID([]byte{0x12, 0x12, 0x12, 0x12, 0x12, 0x12, 0x12, 0x12})
|
||||
retryPacket := testutils.ComposeRetryPacket(fakeSrcConnID, hdr.SrcConnectionID, hdr.DestConnectionID, []byte("token"), hdr.Version)
|
||||
if _, err := serverTransport.WriteTo(retryPacket, clientTransport.Conn.LocalAddr()); err != nil {
|
||||
panic("failed to write packet: " + err.Error())
|
||||
}
|
||||
})
|
||||
}
|
||||
return rtt / 2
|
||||
}
|
||||
err := runMITMTestSuccessful(t, serverTransport, clientTransport, delayCb)
|
||||
var nerr net.Error
|
||||
require.ErrorAs(t, err, &nerr)
|
||||
require.True(t, nerr.Timeout())
|
||||
}
|
||||
|
||||
func TestMITMForgedInitialPacket(t *testing.T) {
|
||||
serverTransport, clientTransport := getTransportsForMITMTest(t)
|
||||
rtt := scaleDuration(10 * time.Millisecond)
|
||||
|
||||
var once sync.Once
|
||||
delayCb := func(dir quicproxy.Direction, raw []byte) time.Duration {
|
||||
if dir == quicproxy.DirectionIncoming {
|
||||
hdr, _, _, err := wire.ParsePacket(raw)
|
||||
if err != nil {
|
||||
panic("failed to parse packet: " + err.Error())
|
||||
}
|
||||
if hdr.Type != protocol.PacketTypeInitial {
|
||||
return 0
|
||||
}
|
||||
once.Do(func() {
|
||||
initialPacket := testutils.ComposeInitialPacket(
|
||||
hdr.DestConnectionID,
|
||||
hdr.SrcConnectionID,
|
||||
hdr.DestConnectionID,
|
||||
nil,
|
||||
nil,
|
||||
protocol.PerspectiveServer,
|
||||
hdr.Version,
|
||||
)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
str, err := conn.AcceptUniStream(context.Background())
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
data, err := io.ReadAll(str)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(data).To(Equal(PRData))
|
||||
Expect(conn.CloseWithError(0, "")).To(Succeed())
|
||||
}
|
||||
|
||||
It("downloads a message when the packets are injected towards the server", func() {
|
||||
delayCb := func(dir quicproxy.Direction, raw []byte) time.Duration {
|
||||
if dir == quicproxy.DirectionIncoming {
|
||||
defer GinkgoRecover()
|
||||
go sendRandomPacketsOfSameType(clientTransport, serverTransport.Conn.LocalAddr(), raw)
|
||||
}
|
||||
return rtt / 2
|
||||
if _, err := serverTransport.WriteTo(initialPacket, clientTransport.Conn.LocalAddr()); err != nil {
|
||||
panic("failed to write packet: " + err.Error())
|
||||
}
|
||||
runTest(delayCb)
|
||||
})
|
||||
|
||||
It("downloads a message when the packets are injected towards the client", func() {
|
||||
delayCb := func(dir quicproxy.Direction, raw []byte) time.Duration {
|
||||
if dir == quicproxy.DirectionOutgoing {
|
||||
defer GinkgoRecover()
|
||||
go sendRandomPacketsOfSameType(serverTransport, clientTransport.Conn.LocalAddr(), raw)
|
||||
}
|
||||
return rtt / 2
|
||||
}
|
||||
runTest(delayCb)
|
||||
})
|
||||
})
|
||||
|
||||
runTest := func(dropCb quicproxy.DropCallback) {
|
||||
proxyPort, closeFn := startServerAndProxy(nil, dropCb, false)
|
||||
defer closeFn()
|
||||
raddr, err := net.ResolveUDPAddr("udp", fmt.Sprintf("localhost:%d", proxyPort))
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
conn, err := clientTransport.Dial(
|
||||
context.Background(),
|
||||
raddr,
|
||||
getTLSClientConfig(),
|
||||
getQuicConfig(nil),
|
||||
)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
str, err := conn.AcceptUniStream(context.Background())
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
data, err := io.ReadAll(str)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(data).To(Equal(PRData))
|
||||
Expect(conn.CloseWithError(0, "")).To(Succeed())
|
||||
}
|
||||
return rtt / 2
|
||||
}
|
||||
err := runMITMTestSuccessful(t, serverTransport, clientTransport, delayCb)
|
||||
var nerr net.Error
|
||||
require.ErrorAs(t, err, &nerr)
|
||||
require.True(t, nerr.Timeout())
|
||||
}
|
||||
|
||||
Context("duplicating packets", func() {
|
||||
It("downloads a message when packets are duplicated towards the server", func() {
|
||||
dropCb := func(dir quicproxy.Direction, raw []byte) bool {
|
||||
defer GinkgoRecover()
|
||||
if dir == quicproxy.DirectionIncoming {
|
||||
_, err := clientTransport.WriteTo(raw, serverTransport.Conn.LocalAddr())
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
}
|
||||
return false
|
||||
func TestMITMForgedInitialPacketWithAck(t *testing.T) {
|
||||
serverTransport, clientTransport := getTransportsForMITMTest(t)
|
||||
rtt := scaleDuration(10 * time.Millisecond)
|
||||
|
||||
var once sync.Once
|
||||
delayCb := func(dir quicproxy.Direction, raw []byte) time.Duration {
|
||||
if dir == quicproxy.DirectionIncoming {
|
||||
hdr, _, _, err := wire.ParsePacket(raw)
|
||||
if err != nil {
|
||||
panic("failed to parse packet: " + err.Error())
|
||||
}
|
||||
if hdr.Type != protocol.PacketTypeInitial {
|
||||
return 0
|
||||
}
|
||||
once.Do(func() {
|
||||
// Fake Initial with ACK for packet 2 (unsent)
|
||||
ack := &wire.AckFrame{AckRanges: []wire.AckRange{{Smallest: 2, Largest: 2}}}
|
||||
initialPacket := testutils.ComposeInitialPacket(
|
||||
hdr.DestConnectionID,
|
||||
hdr.SrcConnectionID,
|
||||
hdr.DestConnectionID,
|
||||
nil,
|
||||
[]wire.Frame{ack},
|
||||
protocol.PerspectiveServer,
|
||||
hdr.Version,
|
||||
)
|
||||
if _, err := serverTransport.WriteTo(initialPacket, clientTransport.Conn.LocalAddr()); err != nil {
|
||||
panic("failed to write packet: " + err.Error())
|
||||
}
|
||||
runTest(dropCb)
|
||||
})
|
||||
|
||||
It("downloads a message when packets are duplicated towards the client", func() {
|
||||
dropCb := func(dir quicproxy.Direction, raw []byte) bool {
|
||||
defer GinkgoRecover()
|
||||
if dir == quicproxy.DirectionOutgoing {
|
||||
_, err := serverTransport.WriteTo(raw, clientTransport.Conn.LocalAddr())
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
}
|
||||
return false
|
||||
}
|
||||
runTest(dropCb)
|
||||
})
|
||||
})
|
||||
|
||||
Context("corrupting packets", func() {
|
||||
const idleTimeout = time.Second
|
||||
|
||||
var numCorrupted, numPackets atomic.Int32
|
||||
|
||||
BeforeEach(func() {
|
||||
numCorrupted.Store(0)
|
||||
numPackets.Store(0)
|
||||
serverConfig.MaxIdleTimeout = idleTimeout
|
||||
})
|
||||
|
||||
AfterEach(func() {
|
||||
num := numCorrupted.Load()
|
||||
fmt.Fprintf(GinkgoWriter, "Corrupted %d of %d packets.", num, numPackets.Load())
|
||||
Expect(num).To(BeNumerically(">=", 1))
|
||||
// If the packet containing the CONNECTION_CLOSE is corrupted,
|
||||
// we have to wait for the connection to time out.
|
||||
Eventually(serverConn.Context().Done(), 3*idleTimeout).Should(BeClosed())
|
||||
})
|
||||
|
||||
It("downloads a message when packet are corrupted towards the server", func() {
|
||||
const interval = 4 // corrupt every 4th packet (stochastically)
|
||||
dropCb := func(dir quicproxy.Direction, raw []byte) bool {
|
||||
defer GinkgoRecover()
|
||||
if dir == quicproxy.DirectionIncoming {
|
||||
numPackets.Add(1)
|
||||
if rand.Intn(interval) == 0 {
|
||||
pos := rand.Intn(len(raw))
|
||||
raw[pos] = byte(rand.Intn(256))
|
||||
_, err := clientTransport.WriteTo(raw, serverTransport.Conn.LocalAddr())
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
numCorrupted.Add(1)
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
runTest(dropCb)
|
||||
})
|
||||
|
||||
It("downloads a message when packet are corrupted towards the client", func() {
|
||||
const interval = 10 // corrupt every 10th packet (stochastically)
|
||||
dropCb := func(dir quicproxy.Direction, raw []byte) bool {
|
||||
defer GinkgoRecover()
|
||||
if dir == quicproxy.DirectionOutgoing {
|
||||
numPackets.Add(1)
|
||||
if rand.Intn(interval) == 0 {
|
||||
pos := rand.Intn(len(raw))
|
||||
raw[pos] = byte(rand.Intn(256))
|
||||
_, err := serverTransport.WriteTo(raw, clientTransport.Conn.LocalAddr())
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
numCorrupted.Add(1)
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
runTest(dropCb)
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
Context("successful injection attacks", func() {
|
||||
// These tests demonstrate that the QUIC protocol is vulnerable to injection attacks before the handshake
|
||||
// finishes. In particular, an adversary who can intercept packets coming from one endpoint and send a reply
|
||||
// that arrives before the real reply can tear down the connection in multiple ways.
|
||||
|
||||
const rtt = 20 * time.Millisecond
|
||||
|
||||
runTest := func(proxyPort int) (closeFn func(), err error) {
|
||||
raddr, err := net.ResolveUDPAddr("udp", fmt.Sprintf("localhost:%d", proxyPort))
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
_, err = clientTransport.Dial(
|
||||
context.Background(),
|
||||
raddr,
|
||||
getTLSClientConfig(),
|
||||
getQuicConfig(&quic.Config{HandshakeIdleTimeout: scaleDuration(200 * time.Millisecond)}),
|
||||
)
|
||||
return func() { clientTransport.Close() }, err
|
||||
}
|
||||
return rtt / 2
|
||||
}
|
||||
|
||||
// fails immediately because client connection closes when it can't find compatible version
|
||||
It("fails when a forged version negotiation packet is sent to client", func() {
|
||||
done := make(chan struct{})
|
||||
delayCb := func(dir quicproxy.Direction, raw []byte) time.Duration {
|
||||
if dir == quicproxy.DirectionIncoming {
|
||||
defer GinkgoRecover()
|
||||
err := runMITMTestSuccessful(t, serverTransport, clientTransport, delayCb)
|
||||
var transportErr *quic.TransportError
|
||||
require.ErrorAs(t, err, &transportErr)
|
||||
require.Equal(t, quic.ProtocolViolation, transportErr.ErrorCode)
|
||||
require.Contains(t, transportErr.ErrorMessage, "received ACK for an unsent packet")
|
||||
}
|
||||
|
||||
hdr, _, _, err := wire.ParsePacket(raw)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
func runMITMTestSuccessful(t *testing.T, serverTransport, clientTransport *quic.Transport, delayCb quicproxy.DelayCallback) error {
|
||||
t.Helper()
|
||||
ln, err := serverTransport.Listen(getTLSConfig(), getQuicConfig(nil))
|
||||
require.NoError(t, err)
|
||||
defer ln.Close()
|
||||
|
||||
if hdr.Type != protocol.PacketTypeInitial {
|
||||
return 0
|
||||
}
|
||||
|
||||
// Create fake version negotiation packet with no supported versions
|
||||
versions := []protocol.Version{}
|
||||
packet := wire.ComposeVersionNegotiation(
|
||||
protocol.ArbitraryLenConnectionID(hdr.SrcConnectionID.Bytes()),
|
||||
protocol.ArbitraryLenConnectionID(hdr.DestConnectionID.Bytes()),
|
||||
versions,
|
||||
)
|
||||
|
||||
// Send the packet
|
||||
_, err = serverTransport.WriteTo(packet, clientTransport.Conn.LocalAddr())
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
close(done)
|
||||
}
|
||||
return rtt / 2
|
||||
}
|
||||
proxyPort, serverCloseFn := startServerAndProxy(delayCb, nil, false)
|
||||
defer serverCloseFn()
|
||||
closeFn, err := runTest(proxyPort)
|
||||
defer closeFn()
|
||||
Expect(err).To(HaveOccurred())
|
||||
vnErr := &quic.VersionNegotiationError{}
|
||||
Expect(errors.As(err, &vnErr)).To(BeTrue())
|
||||
Eventually(done).Should(BeClosed())
|
||||
})
|
||||
|
||||
// times out, because client doesn't accept subsequent real retry packets from server
|
||||
// as it has already accepted a retry.
|
||||
// TODO: determine behavior when server does not send Retry packets
|
||||
It("fails when a forged Retry packet with modified Source Connection ID is sent to client", func() {
|
||||
var initialPacketIntercepted bool
|
||||
done := make(chan struct{})
|
||||
delayCb := func(dir quicproxy.Direction, raw []byte) time.Duration {
|
||||
if dir == quicproxy.DirectionIncoming && !initialPacketIntercepted {
|
||||
defer GinkgoRecover()
|
||||
defer close(done)
|
||||
|
||||
hdr, _, _, err := wire.ParsePacket(raw)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
||||
if hdr.Type != protocol.PacketTypeInitial {
|
||||
return 0
|
||||
}
|
||||
|
||||
initialPacketIntercepted = true
|
||||
fakeSrcConnID := protocol.ParseConnectionID([]byte{0x12, 0x12, 0x12, 0x12, 0x12, 0x12, 0x12, 0x12})
|
||||
retryPacket := testutils.ComposeRetryPacket(fakeSrcConnID, hdr.SrcConnectionID, hdr.DestConnectionID, []byte("token"), hdr.Version)
|
||||
|
||||
_, err = serverTransport.WriteTo(retryPacket, clientTransport.Conn.LocalAddr())
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
}
|
||||
return rtt / 2
|
||||
}
|
||||
proxyPort, serverCloseFn := startServerAndProxy(delayCb, nil, true)
|
||||
defer serverCloseFn()
|
||||
closeFn, err := runTest(proxyPort)
|
||||
defer closeFn()
|
||||
Expect(err).To(HaveOccurred())
|
||||
Expect(err.(net.Error).Timeout()).To(BeTrue())
|
||||
Eventually(done).Should(BeClosed())
|
||||
})
|
||||
|
||||
// times out, because client doesn't accept real retry packets from server because
|
||||
// it has already accepted an initial.
|
||||
// TODO: determine behavior when server does not send Retry packets
|
||||
It("fails when a forged initial packet is sent to client", func() {
|
||||
done := make(chan struct{})
|
||||
var injected bool
|
||||
delayCb := func(dir quicproxy.Direction, raw []byte) time.Duration {
|
||||
if dir == quicproxy.DirectionIncoming {
|
||||
defer GinkgoRecover()
|
||||
|
||||
hdr, _, _, err := wire.ParsePacket(raw)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
if hdr.Type != protocol.PacketTypeInitial || injected {
|
||||
return 0
|
||||
}
|
||||
defer close(done)
|
||||
injected = true
|
||||
initialPacket := testutils.ComposeInitialPacket(hdr.DestConnectionID, hdr.SrcConnectionID, hdr.DestConnectionID, nil, nil, protocol.PerspectiveServer, hdr.Version)
|
||||
_, err = serverTransport.WriteTo(initialPacket, clientTransport.Conn.LocalAddr())
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
}
|
||||
return rtt
|
||||
}
|
||||
proxyPort, serverCloseFn := startServerAndProxy(delayCb, nil, false)
|
||||
defer serverCloseFn()
|
||||
closeFn, err := runTest(proxyPort)
|
||||
defer closeFn()
|
||||
Expect(err).To(HaveOccurred())
|
||||
Expect(err.(net.Error).Timeout()).To(BeTrue())
|
||||
Eventually(done).Should(BeClosed())
|
||||
})
|
||||
|
||||
// client connection closes immediately on receiving ack for unsent packet
|
||||
It("fails when a forged initial packet with ack for unsent packet is sent to client", func() {
|
||||
done := make(chan struct{})
|
||||
var injected bool
|
||||
delayCb := func(dir quicproxy.Direction, raw []byte) time.Duration {
|
||||
if dir == quicproxy.DirectionIncoming {
|
||||
defer GinkgoRecover()
|
||||
|
||||
hdr, _, _, err := wire.ParsePacket(raw)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
if hdr.Type != protocol.PacketTypeInitial || injected {
|
||||
return 0
|
||||
}
|
||||
defer close(done)
|
||||
injected = true
|
||||
// Fake Initial with ACK for packet 2 (unsent)
|
||||
ack := &wire.AckFrame{AckRanges: []wire.AckRange{{Smallest: 2, Largest: 2}}}
|
||||
initialPacket := testutils.ComposeInitialPacket(hdr.DestConnectionID, hdr.SrcConnectionID, hdr.DestConnectionID, nil, []wire.Frame{ack}, protocol.PerspectiveServer, hdr.Version)
|
||||
_, err = serverTransport.WriteTo(initialPacket, clientTransport.Conn.LocalAddr())
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
}
|
||||
return rtt
|
||||
}
|
||||
proxyPort, serverCloseFn := startServerAndProxy(delayCb, nil, false)
|
||||
defer serverCloseFn()
|
||||
closeFn, err := runTest(proxyPort)
|
||||
defer closeFn()
|
||||
Expect(err).To(HaveOccurred())
|
||||
var transportErr *quic.TransportError
|
||||
Expect(errors.As(err, &transportErr)).To(BeTrue())
|
||||
Expect(transportErr.ErrorCode).To(Equal(quic.ProtocolViolation))
|
||||
Expect(transportErr.ErrorMessage).To(ContainSubstring("received ACK for an unsent packet"))
|
||||
Eventually(done).Should(BeClosed())
|
||||
})
|
||||
serverPort := ln.Addr().(*net.UDPAddr).Port
|
||||
proxy, err := quicproxy.NewQuicProxy("localhost:0", &quicproxy.Opts{
|
||||
RemoteAddr: fmt.Sprintf("localhost:%d", serverPort),
|
||||
DelayPacket: delayCb,
|
||||
})
|
||||
})
|
||||
require.NoError(t, err)
|
||||
defer proxy.Close()
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), scaleDuration(50*time.Millisecond))
|
||||
defer cancel()
|
||||
_, err = clientTransport.Dial(
|
||||
ctx,
|
||||
&net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: proxy.LocalPort()},
|
||||
getTLSClientConfig(),
|
||||
getQuicConfig(nil),
|
||||
)
|
||||
require.Error(t, err)
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -1,11 +1,13 @@
|
||||
package self_test
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/quic-go/quic-go"
|
||||
@@ -13,158 +15,188 @@ import (
|
||||
"github.com/quic-go/quic-go/internal/protocol"
|
||||
"github.com/quic-go/quic-go/logging"
|
||||
|
||||
. "github.com/onsi/ginkgo/v2"
|
||||
. "github.com/onsi/gomega"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
var _ = Describe("DPLPMTUD", func() {
|
||||
It("discovers the MTU", func() {
|
||||
rtt := scaleDuration(5 * time.Millisecond)
|
||||
const mtu = 1400
|
||||
func TestInitialPacketSize(t *testing.T) {
|
||||
server, err := net.ListenUDP("udp4", &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 0})
|
||||
require.NoError(t, err)
|
||||
defer server.Close()
|
||||
client, err := net.ListenUDP("udp4", &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 0})
|
||||
require.NoError(t, err)
|
||||
defer client.Close()
|
||||
|
||||
ln, err := quic.ListenAddr(
|
||||
"localhost:0",
|
||||
getTLSConfig(),
|
||||
getQuicConfig(&quic.Config{
|
||||
InitialPacketSize: 1234,
|
||||
DisablePathMTUDiscovery: true,
|
||||
EnableDatagrams: true,
|
||||
}),
|
||||
)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
defer ln.Close()
|
||||
go func() {
|
||||
defer GinkgoRecover()
|
||||
conn, err := ln.Accept(context.Background())
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
str, err := conn.AcceptStream(context.Background())
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
_, err = io.Copy(str, str)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
str.Close()
|
||||
}()
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
defer close(done)
|
||||
quic.Dial(ctx, client, server.LocalAddr(), getTLSClientConfig(), getQuicConfig(&quic.Config{
|
||||
InitialPacketSize: 1337,
|
||||
}))
|
||||
}()
|
||||
|
||||
var mx sync.Mutex
|
||||
var maxPacketSizeServer int
|
||||
var clientPacketSizes []int
|
||||
serverPort := ln.Addr().(*net.UDPAddr).Port
|
||||
proxy, err := quicproxy.NewQuicProxy("localhost:0", &quicproxy.Opts{
|
||||
RemoteAddr: fmt.Sprintf("localhost:%d", serverPort),
|
||||
DelayPacket: func(quicproxy.Direction, []byte) time.Duration { return rtt / 2 },
|
||||
DropPacket: func(dir quicproxy.Direction, packet []byte) bool {
|
||||
if len(packet) > mtu {
|
||||
return true
|
||||
}
|
||||
mx.Lock()
|
||||
defer mx.Unlock()
|
||||
switch dir {
|
||||
case quicproxy.DirectionIncoming:
|
||||
clientPacketSizes = append(clientPacketSizes, len(packet))
|
||||
case quicproxy.DirectionOutgoing:
|
||||
if len(packet) > maxPacketSizeServer {
|
||||
maxPacketSizeServer = len(packet)
|
||||
}
|
||||
}
|
||||
return false
|
||||
},
|
||||
})
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
defer proxy.Close()
|
||||
buf := make([]byte, 2000)
|
||||
n, _, err := server.ReadFrom(buf)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, 1337, n)
|
||||
|
||||
// Make sure to use v4-only socket here.
|
||||
// We can't reliably set the DF bit on dual-stack sockets on macOS before Sequoia (macOS 15).
|
||||
udpConn, err := net.ListenUDP("udp4", &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 0})
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
defer udpConn.Close()
|
||||
tr := &quic.Transport{Conn: udpConn}
|
||||
defer tr.Close()
|
||||
var mtus []logging.ByteCount
|
||||
mtuTracer := &logging.ConnectionTracer{
|
||||
UpdatedMTU: func(mtu logging.ByteCount, _ bool) {
|
||||
mtus = append(mtus, mtu)
|
||||
},
|
||||
cancel()
|
||||
<-done
|
||||
}
|
||||
|
||||
func TestPathMTUDiscovery(t *testing.T) {
|
||||
rtt := scaleDuration(5 * time.Millisecond)
|
||||
const mtu = 1400
|
||||
|
||||
ln, err := quic.ListenAddr(
|
||||
"localhost:0",
|
||||
getTLSConfig(),
|
||||
getQuicConfig(&quic.Config{
|
||||
InitialPacketSize: 1234,
|
||||
DisablePathMTUDiscovery: true,
|
||||
EnableDatagrams: true,
|
||||
}),
|
||||
)
|
||||
require.NoError(t, err)
|
||||
defer ln.Close()
|
||||
|
||||
serverErrChan := make(chan error, 1)
|
||||
go func() {
|
||||
conn, err := ln.Accept(context.Background())
|
||||
if err != nil {
|
||||
serverErrChan <- err
|
||||
return
|
||||
}
|
||||
conn, err := tr.Dial(
|
||||
context.Background(),
|
||||
proxy.LocalAddr(),
|
||||
getTLSClientConfig(),
|
||||
getQuicConfig(&quic.Config{
|
||||
InitialPacketSize: protocol.MinInitialPacketSize,
|
||||
EnableDatagrams: true,
|
||||
Tracer: func(context.Context, logging.Perspective, quic.ConnectionID) *logging.ConnectionTracer {
|
||||
return mtuTracer
|
||||
},
|
||||
}),
|
||||
)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
defer conn.CloseWithError(0, "")
|
||||
str, err := conn.OpenStream()
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
defer close(done)
|
||||
defer GinkgoRecover()
|
||||
data, err := io.ReadAll(str)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(data).To(Equal(PRDataLong))
|
||||
}()
|
||||
err = conn.SendDatagram(make([]byte, 2000))
|
||||
Expect(err).To(BeAssignableToTypeOf(&quic.DatagramTooLargeError{}))
|
||||
initialMaxDatagramSize := err.(*quic.DatagramTooLargeError).MaxDatagramPayloadSize
|
||||
_, err = str.Write(PRDataLong)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
str.Close()
|
||||
Eventually(done, 20*time.Second).Should(BeClosed())
|
||||
err = conn.SendDatagram(make([]byte, 2000))
|
||||
Expect(err).To(BeAssignableToTypeOf(&quic.DatagramTooLargeError{}))
|
||||
finalMaxDatagramSize := err.(*quic.DatagramTooLargeError).MaxDatagramPayloadSize
|
||||
str, err := conn.AcceptStream(context.Background())
|
||||
if err != nil {
|
||||
serverErrChan <- err
|
||||
return
|
||||
}
|
||||
defer str.Close()
|
||||
if _, err := io.Copy(str, str); err != nil {
|
||||
serverErrChan <- err
|
||||
return
|
||||
}
|
||||
}()
|
||||
|
||||
mx.Lock()
|
||||
defer mx.Unlock()
|
||||
Expect(mtus).ToNot(BeEmpty())
|
||||
maxPacketSizeClient := int(mtus[len(mtus)-1])
|
||||
fmt.Fprintf(GinkgoWriter, "max client packet size: %d, MTU: %d\n", maxPacketSizeClient, mtu)
|
||||
fmt.Fprintf(GinkgoWriter, "max datagram size: initial: %d, final: %d\n", initialMaxDatagramSize, finalMaxDatagramSize)
|
||||
fmt.Fprintf(GinkgoWriter, "max server packet size: %d, MTU: %d\n", maxPacketSizeServer, mtu)
|
||||
Expect(maxPacketSizeClient).To(BeNumerically(">=", mtu-25))
|
||||
const maxDiff = 40 // this includes the 21 bytes for the short header, 16 bytes for the encryption tag, and framing overhead
|
||||
Expect(initialMaxDatagramSize).To(BeNumerically(">=", protocol.MinInitialPacketSize-maxDiff))
|
||||
Expect(finalMaxDatagramSize).To(BeNumerically(">=", maxPacketSizeClient-maxDiff))
|
||||
// MTU discovery was disabled on the server side
|
||||
Expect(maxPacketSizeServer).To(Equal(1234))
|
||||
|
||||
var numPacketsLargerThanDiscoveredMTU int
|
||||
for _, s := range clientPacketSizes {
|
||||
if s > maxPacketSizeClient {
|
||||
numPacketsLargerThanDiscoveredMTU++
|
||||
var mx sync.Mutex
|
||||
var maxPacketSizeServer int
|
||||
var clientPacketSizes []int
|
||||
serverPort := ln.Addr().(*net.UDPAddr).Port
|
||||
proxy, err := quicproxy.NewQuicProxy("localhost:0", &quicproxy.Opts{
|
||||
RemoteAddr: fmt.Sprintf("localhost:%d", serverPort),
|
||||
DelayPacket: func(quicproxy.Direction, []byte) time.Duration { return rtt / 2 },
|
||||
DropPacket: func(dir quicproxy.Direction, packet []byte) bool {
|
||||
if len(packet) > mtu {
|
||||
return true
|
||||
}
|
||||
mx.Lock()
|
||||
defer mx.Unlock()
|
||||
switch dir {
|
||||
case quicproxy.DirectionIncoming:
|
||||
clientPacketSizes = append(clientPacketSizes, len(packet))
|
||||
case quicproxy.DirectionOutgoing:
|
||||
if len(packet) > maxPacketSizeServer {
|
||||
maxPacketSizeServer = len(packet)
|
||||
}
|
||||
}
|
||||
return false
|
||||
},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
defer proxy.Close()
|
||||
|
||||
// Make sure to use v4-only socket here.
|
||||
// We can't reliably set the DF bit on dual-stack sockets on older versions of macOS (before Sequoia).
|
||||
udpConn, err := net.ListenUDP("udp4", &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 0})
|
||||
require.NoError(t, err)
|
||||
defer udpConn.Close()
|
||||
tr := &quic.Transport{Conn: udpConn}
|
||||
defer tr.Close()
|
||||
|
||||
var mtus []logging.ByteCount
|
||||
conn, err := tr.Dial(
|
||||
context.Background(),
|
||||
proxy.LocalAddr(),
|
||||
getTLSClientConfig(),
|
||||
getQuicConfig(&quic.Config{
|
||||
InitialPacketSize: protocol.MinInitialPacketSize,
|
||||
EnableDatagrams: true,
|
||||
Tracer: func(context.Context, logging.Perspective, quic.ConnectionID) *logging.ConnectionTracer {
|
||||
return &logging.ConnectionTracer{
|
||||
UpdatedMTU: func(mtu logging.ByteCount, _ bool) { mtus = append(mtus, mtu) },
|
||||
}
|
||||
},
|
||||
}),
|
||||
)
|
||||
require.NoError(t, err)
|
||||
defer conn.CloseWithError(0, "")
|
||||
|
||||
err = conn.SendDatagram(make([]byte, 2000))
|
||||
require.Error(t, err)
|
||||
var datagramErr *quic.DatagramTooLargeError
|
||||
require.ErrorAs(t, err, &datagramErr)
|
||||
initialMaxDatagramSize := datagramErr.MaxDatagramPayloadSize
|
||||
|
||||
str, err := conn.OpenStream()
|
||||
require.NoError(t, err)
|
||||
|
||||
clientErrChan := make(chan error, 1)
|
||||
go func() {
|
||||
data, err := io.ReadAll(str)
|
||||
if err != nil {
|
||||
clientErrChan <- err
|
||||
return
|
||||
}
|
||||
// The client shouldn't have sent any packets larger than the MTU it discovered,
|
||||
// except for at most one MTU probe packet.
|
||||
Expect(numPacketsLargerThanDiscoveredMTU).To(BeNumerically("<=", 1))
|
||||
})
|
||||
if !bytes.Equal(data, PRDataLong) {
|
||||
clientErrChan <- fmt.Errorf("echoed data doesn't match: %x", data)
|
||||
return
|
||||
}
|
||||
clientErrChan <- nil
|
||||
}()
|
||||
|
||||
It("uses the initial packet size", func() {
|
||||
c, err := net.ListenUDP("udp4", &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 0})
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
defer c.Close()
|
||||
_, err = str.Write(PRDataLong)
|
||||
require.NoError(t, err)
|
||||
str.Close()
|
||||
|
||||
cconn, err := net.ListenUDP("udp4", &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 0})
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
defer cconn.Close()
|
||||
select {
|
||||
case err := <-clientErrChan:
|
||||
require.NoError(t, err)
|
||||
case err := <-serverErrChan:
|
||||
t.Fatalf("server error: %v", err)
|
||||
case <-time.After(20 * time.Second):
|
||||
t.Fatal("timeout")
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
defer close(done)
|
||||
quic.Dial(ctx, cconn, c.LocalAddr(), getTLSClientConfig(), getQuicConfig(&quic.Config{InitialPacketSize: 1337}))
|
||||
}()
|
||||
err = conn.SendDatagram(make([]byte, 2000))
|
||||
require.Error(t, err)
|
||||
require.ErrorAs(t, err, &datagramErr)
|
||||
finalMaxDatagramSize := datagramErr.MaxDatagramPayloadSize
|
||||
|
||||
b := make([]byte, 2000)
|
||||
n, _, err := c.ReadFrom(b)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(n).To(Equal(1337))
|
||||
cancel()
|
||||
Eventually(done).Should(BeClosed())
|
||||
})
|
||||
})
|
||||
mx.Lock()
|
||||
defer mx.Unlock()
|
||||
require.NotEmpty(t, mtus)
|
||||
|
||||
maxPacketSizeClient := int(mtus[len(mtus)-1])
|
||||
t.Logf("max client packet size: %d, MTU: %d", maxPacketSizeClient, mtu)
|
||||
t.Logf("max datagram size: initial: %d, final: %d", initialMaxDatagramSize, finalMaxDatagramSize)
|
||||
t.Logf("max server packet size: %d, MTU: %d", maxPacketSizeServer, mtu)
|
||||
|
||||
require.GreaterOrEqual(t, maxPacketSizeClient, mtu-25)
|
||||
const maxDiff = 40 // this includes the 21 bytes for the short header, 16 bytes for the encryption tag, and framing overhead
|
||||
require.GreaterOrEqual(t, int(initialMaxDatagramSize), protocol.MinInitialPacketSize-maxDiff)
|
||||
require.GreaterOrEqual(t, int(finalMaxDatagramSize), maxPacketSizeClient-maxDiff)
|
||||
// MTU discovery was disabled on the server side
|
||||
require.Equal(t, 1234, maxPacketSizeServer)
|
||||
|
||||
var numPacketsLargerThanDiscoveredMTU int
|
||||
for _, s := range clientPacketSizes {
|
||||
if s > maxPacketSizeClient {
|
||||
numPacketsLargerThanDiscoveredMTU++
|
||||
}
|
||||
}
|
||||
// The client shouldn't have sent any packets larger than the MTU it discovered,
|
||||
// except for at most one MTU probe packet.
|
||||
require.LessOrEqual(t, numPacketsLargerThanDiscoveredMTU, 1)
|
||||
}
|
||||
|
||||
@@ -1,296 +1,323 @@
|
||||
package self_test
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"crypto/rand"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"runtime"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"golang.org/x/exp/rand"
|
||||
|
||||
"github.com/quic-go/quic-go"
|
||||
|
||||
. "github.com/onsi/ginkgo/v2"
|
||||
. "github.com/onsi/gomega"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
var _ = Describe("Multiplexing", func() {
|
||||
runServer := func(ln *quic.Listener) {
|
||||
func runMultiplexTestServer(t *testing.T, ln *quic.Listener) {
|
||||
t.Helper()
|
||||
for {
|
||||
conn, err := ln.Accept(context.Background())
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
str, err := conn.OpenUniStream()
|
||||
require.NoError(t, err)
|
||||
go func() {
|
||||
defer GinkgoRecover()
|
||||
for {
|
||||
conn, err := ln.Accept(context.Background())
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
go func() {
|
||||
defer GinkgoRecover()
|
||||
str, err := conn.OpenStream()
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
defer str.Close()
|
||||
_, err = str.Write(PRData)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
}()
|
||||
}
|
||||
defer str.Close()
|
||||
_, err = str.Write(PRData)
|
||||
require.NoError(t, err)
|
||||
}()
|
||||
|
||||
t.Cleanup(func() { conn.CloseWithError(0, "") })
|
||||
}
|
||||
}
|
||||
|
||||
func dialAndReceiveData(tr *quic.Transport, addr net.Addr) error {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
|
||||
defer cancel()
|
||||
conn, err := tr.Dial(ctx, addr, getTLSClientConfig(), getQuicConfig(nil))
|
||||
if err != nil {
|
||||
return fmt.Errorf("error dialing: %w", err)
|
||||
}
|
||||
str, err := conn.AcceptUniStream(ctx)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error accepting stream: %w", err)
|
||||
}
|
||||
data, err := io.ReadAll(str)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error reading data: %w", err)
|
||||
}
|
||||
if !bytes.Equal(data, PRData) {
|
||||
return fmt.Errorf("data mismatch: got %q, expected %q", data, PRData)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func TestMultiplexesConnectionsToSameServer(t *testing.T) {
|
||||
server, err := quic.ListenAddr("localhost:0", getTLSConfig(), getQuicConfig(nil))
|
||||
require.NoError(t, err)
|
||||
defer server.Close()
|
||||
go runMultiplexTestServer(t, server)
|
||||
|
||||
conn, err := net.ListenUDP("udp", &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 0})
|
||||
require.NoError(t, err)
|
||||
defer conn.Close()
|
||||
tr := &quic.Transport{Conn: conn}
|
||||
addTracer(tr)
|
||||
defer tr.Close()
|
||||
|
||||
errChan1 := make(chan error, 1)
|
||||
go func() { errChan1 <- dialAndReceiveData(tr, server.Addr()) }()
|
||||
errChan2 := make(chan error, 1)
|
||||
go func() { errChan2 <- dialAndReceiveData(tr, server.Addr()) }()
|
||||
|
||||
select {
|
||||
case err := <-errChan1:
|
||||
require.NoError(t, err, "error dialing server 1")
|
||||
case <-time.After(5 * time.Second):
|
||||
t.Error("timeout waiting for done1 to close")
|
||||
}
|
||||
select {
|
||||
case err := <-errChan2:
|
||||
require.NoError(t, err)
|
||||
case <-time.After(5 * time.Second):
|
||||
t.Error("timeout waiting for done2 to close")
|
||||
}
|
||||
}
|
||||
|
||||
func TestMultiplexingToDifferentServers(t *testing.T) {
|
||||
server1, err := quic.ListenAddr("localhost:0", getTLSConfig(), getQuicConfig(nil))
|
||||
require.NoError(t, err)
|
||||
defer server1.Close()
|
||||
go runMultiplexTestServer(t, server1)
|
||||
|
||||
server2, err := quic.ListenAddr("localhost:0", getTLSConfig(), getQuicConfig(nil))
|
||||
require.NoError(t, err)
|
||||
defer server2.Close()
|
||||
go runMultiplexTestServer(t, server2)
|
||||
|
||||
conn, err := net.ListenUDP("udp", &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 0})
|
||||
require.NoError(t, err)
|
||||
defer conn.Close()
|
||||
|
||||
tr := &quic.Transport{Conn: conn}
|
||||
addTracer(tr)
|
||||
defer tr.Close()
|
||||
|
||||
errChan1 := make(chan error, 1)
|
||||
go func() { errChan1 <- dialAndReceiveData(tr, server1.Addr()) }()
|
||||
errChan2 := make(chan error, 1)
|
||||
go func() { errChan2 <- dialAndReceiveData(tr, server2.Addr()) }()
|
||||
|
||||
select {
|
||||
case err := <-errChan1:
|
||||
require.NoError(t, err, "error dialing server 1")
|
||||
case <-time.After(5 * time.Second):
|
||||
t.Error("timeout waiting for done1 to close")
|
||||
}
|
||||
select {
|
||||
case err := <-errChan2:
|
||||
require.NoError(t, err, "error dialing server 2")
|
||||
case <-time.After(5 * time.Second):
|
||||
t.Error("timeout waiting for done2 to close")
|
||||
}
|
||||
}
|
||||
|
||||
func TestMultiplexingConnectToSelf(t *testing.T) {
|
||||
conn, err := net.ListenUDP("udp", &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 0})
|
||||
require.NoError(t, err)
|
||||
defer conn.Close()
|
||||
tr := &quic.Transport{Conn: conn}
|
||||
addTracer(tr)
|
||||
defer tr.Close()
|
||||
|
||||
server, err := tr.Listen(getTLSConfig(), getQuicConfig(nil))
|
||||
require.NoError(t, err)
|
||||
defer server.Close()
|
||||
go runMultiplexTestServer(t, server)
|
||||
|
||||
errChan := make(chan error, 1)
|
||||
go func() { errChan <- dialAndReceiveData(tr, server.Addr()) }()
|
||||
|
||||
select {
|
||||
case err := <-errChan:
|
||||
require.NoError(t, err, "error dialing server")
|
||||
case <-time.After(5 * time.Second):
|
||||
t.Error("timeout waiting for connection to close")
|
||||
}
|
||||
}
|
||||
|
||||
func TestMultiplexingServerAndClientOnSameConn(t *testing.T) {
|
||||
if runtime.GOOS == "linux" {
|
||||
t.Skip("This test requires setting of iptables rules on Linux, see https://stackoverflow.com/questions/23859164/linux-udp-socket-sendto-operation-not-permitted.")
|
||||
}
|
||||
|
||||
dial := func(tr *quic.Transport, addr net.Addr) {
|
||||
conn, err := tr.Dial(
|
||||
context.Background(),
|
||||
addr,
|
||||
getTLSClientConfig(),
|
||||
getQuicConfig(nil),
|
||||
)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
defer conn.CloseWithError(0, "")
|
||||
str, err := conn.AcceptStream(context.Background())
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
data, err := io.ReadAll(str)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(data).To(Equal(PRData))
|
||||
conn1, err := net.ListenUDP("udp", &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 0})
|
||||
require.NoError(t, err)
|
||||
defer conn1.Close()
|
||||
tr1 := &quic.Transport{Conn: conn1}
|
||||
addTracer(tr1)
|
||||
defer tr1.Close()
|
||||
server1, err := tr1.Listen(getTLSConfig(), getQuicConfig(nil))
|
||||
require.NoError(t, err)
|
||||
defer server1.Close()
|
||||
|
||||
conn2, err := net.ListenUDP("udp", &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 0})
|
||||
require.NoError(t, err)
|
||||
defer conn2.Close()
|
||||
tr2 := &quic.Transport{Conn: conn2}
|
||||
addTracer(tr2)
|
||||
defer tr2.Close()
|
||||
server2, err := tr2.Listen(getTLSConfig(), getQuicConfig(nil))
|
||||
require.NoError(t, err)
|
||||
defer server2.Close()
|
||||
|
||||
done1 := make(chan struct{})
|
||||
go func() {
|
||||
defer close(done1)
|
||||
dialAndReceiveData(tr2, server1.Addr())
|
||||
}()
|
||||
|
||||
done2 := make(chan struct{})
|
||||
go func() {
|
||||
defer close(done2)
|
||||
dialAndReceiveData(tr1, server2.Addr())
|
||||
}()
|
||||
|
||||
select {
|
||||
case <-done1:
|
||||
case <-time.After(5 * time.Second):
|
||||
t.Error("timeout waiting for done1 to close")
|
||||
}
|
||||
select {
|
||||
case <-done2:
|
||||
case <-time.After(time.Second):
|
||||
t.Error("timeout waiting for done2 to close")
|
||||
}
|
||||
}
|
||||
|
||||
Context("multiplexing clients on the same conn", func() {
|
||||
getListener := func() *quic.Listener {
|
||||
ln, err := quic.ListenAddr(
|
||||
"localhost:0",
|
||||
getTLSConfig(),
|
||||
getQuicConfig(nil),
|
||||
)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
return ln
|
||||
func TestMultiplexingNonQUICPackets(t *testing.T) {
|
||||
conn1, err := net.ListenUDP("udp", &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 0})
|
||||
require.NoError(t, err)
|
||||
defer conn1.Close()
|
||||
tr1 := &quic.Transport{Conn: conn1}
|
||||
defer tr1.Close()
|
||||
addTracer(tr1)
|
||||
|
||||
conn2, err := net.ListenUDP("udp", &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 0})
|
||||
require.NoError(t, err)
|
||||
defer conn2.Close()
|
||||
tr2 := &quic.Transport{Conn: conn2}
|
||||
defer tr2.Close()
|
||||
addTracer(tr2)
|
||||
|
||||
server, err := tr1.Listen(getTLSConfig(), getQuicConfig(nil))
|
||||
require.NoError(t, err)
|
||||
defer server.Close()
|
||||
|
||||
type nonQUICPacket struct {
|
||||
b []byte
|
||||
addr net.Addr
|
||||
err error
|
||||
}
|
||||
done := make(chan struct{})
|
||||
var rcvdPackets []nonQUICPacket
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
// start receiving non-QUIC packets
|
||||
go func() {
|
||||
defer close(done)
|
||||
for {
|
||||
b := make([]byte, 1024)
|
||||
n, addr, err := tr2.ReadNonQUICPacket(ctx, b)
|
||||
if errors.Is(err, context.Canceled) {
|
||||
return
|
||||
}
|
||||
rcvdPackets = append(rcvdPackets, nonQUICPacket{b: b[:n], addr: addr, err: err})
|
||||
}
|
||||
}()
|
||||
|
||||
It("multiplexes connections to the same server", func() {
|
||||
server := getListener()
|
||||
runServer(server)
|
||||
defer server.Close()
|
||||
// send a non-QUIC packet every 100µs
|
||||
const packetLen = 128
|
||||
var sentPackets atomic.Int64
|
||||
errChan := make(chan error, 1)
|
||||
go func() {
|
||||
ticker := time.NewTicker(time.Millisecond / 10)
|
||||
defer ticker.Stop()
|
||||
|
||||
addr, err := net.ResolveUDPAddr("udp", "localhost:0")
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
conn, err := net.ListenUDP("udp", addr)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
defer conn.Close()
|
||||
tr := &quic.Transport{Conn: conn}
|
||||
addTracer(tr)
|
||||
|
||||
done1 := make(chan struct{})
|
||||
done2 := make(chan struct{})
|
||||
go func() {
|
||||
defer GinkgoRecover()
|
||||
dial(tr, server.Addr())
|
||||
close(done1)
|
||||
}()
|
||||
go func() {
|
||||
defer GinkgoRecover()
|
||||
dial(tr, server.Addr())
|
||||
close(done2)
|
||||
}()
|
||||
timeout := 30 * time.Second
|
||||
if debugLog() {
|
||||
timeout = time.Minute
|
||||
}
|
||||
Eventually(done1, timeout).Should(BeClosed())
|
||||
Eventually(done2, timeout).Should(BeClosed())
|
||||
})
|
||||
|
||||
It("multiplexes connections to different servers", func() {
|
||||
server1 := getListener()
|
||||
runServer(server1)
|
||||
defer server1.Close()
|
||||
server2 := getListener()
|
||||
runServer(server2)
|
||||
defer server2.Close()
|
||||
|
||||
addr, err := net.ResolveUDPAddr("udp", "localhost:0")
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
conn, err := net.ListenUDP("udp", addr)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
defer conn.Close()
|
||||
tr := &quic.Transport{Conn: conn}
|
||||
addTracer(tr)
|
||||
|
||||
done1 := make(chan struct{})
|
||||
done2 := make(chan struct{})
|
||||
go func() {
|
||||
defer GinkgoRecover()
|
||||
dial(tr, server1.Addr())
|
||||
close(done1)
|
||||
}()
|
||||
go func() {
|
||||
defer GinkgoRecover()
|
||||
dial(tr, server2.Addr())
|
||||
close(done2)
|
||||
}()
|
||||
timeout := 30 * time.Second
|
||||
if debugLog() {
|
||||
timeout = time.Minute
|
||||
}
|
||||
Eventually(done1, timeout).Should(BeClosed())
|
||||
Eventually(done2, timeout).Should(BeClosed())
|
||||
})
|
||||
})
|
||||
|
||||
Context("multiplexing server and client on the same conn", func() {
|
||||
It("connects to itself", func() {
|
||||
addr, err := net.ResolveUDPAddr("udp", "localhost:0")
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
conn, err := net.ListenUDP("udp", addr)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
defer conn.Close()
|
||||
tr := &quic.Transport{Conn: conn}
|
||||
addTracer(tr)
|
||||
server, err := tr.Listen(
|
||||
getTLSConfig(),
|
||||
getQuicConfig(nil),
|
||||
)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
runServer(server)
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
defer GinkgoRecover()
|
||||
dial(tr, server.Addr())
|
||||
close(done)
|
||||
}()
|
||||
timeout := 30 * time.Second
|
||||
if debugLog() {
|
||||
timeout = time.Minute
|
||||
}
|
||||
Eventually(done, timeout).Should(BeClosed())
|
||||
})
|
||||
|
||||
// This test would require setting of iptables rules, see https://stackoverflow.com/questions/23859164/linux-udp-socket-sendto-operation-not-permitted.
|
||||
if runtime.GOOS != "linux" {
|
||||
It("runs a server and client on the same conn", func() {
|
||||
addr1, err := net.ResolveUDPAddr("udp", "localhost:0")
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
conn1, err := net.ListenUDP("udp", addr1)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
defer conn1.Close()
|
||||
tr1 := &quic.Transport{Conn: conn1}
|
||||
addTracer(tr1)
|
||||
|
||||
addr2, err := net.ResolveUDPAddr("udp", "localhost:0")
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
conn2, err := net.ListenUDP("udp", addr2)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
defer conn2.Close()
|
||||
tr2 := &quic.Transport{Conn: conn2}
|
||||
addTracer(tr2)
|
||||
|
||||
server1, err := tr1.Listen(
|
||||
getTLSConfig(),
|
||||
getQuicConfig(nil),
|
||||
)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
runServer(server1)
|
||||
defer server1.Close()
|
||||
|
||||
server2, err := tr2.Listen(
|
||||
getTLSConfig(),
|
||||
getQuicConfig(nil),
|
||||
)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
runServer(server2)
|
||||
defer server2.Close()
|
||||
|
||||
done1 := make(chan struct{})
|
||||
done2 := make(chan struct{})
|
||||
go func() {
|
||||
defer GinkgoRecover()
|
||||
dial(tr2, server1.Addr())
|
||||
close(done1)
|
||||
}()
|
||||
go func() {
|
||||
defer GinkgoRecover()
|
||||
dial(tr1, server2.Addr())
|
||||
close(done2)
|
||||
}()
|
||||
timeout := 30 * time.Second
|
||||
if debugLog() {
|
||||
timeout = time.Minute
|
||||
}
|
||||
Eventually(done1, timeout).Should(BeClosed())
|
||||
Eventually(done2, timeout).Should(BeClosed())
|
||||
})
|
||||
}
|
||||
})
|
||||
|
||||
It("sends and receives non-QUIC packets", func() {
|
||||
addr1, err := net.ResolveUDPAddr("udp", "localhost:0")
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
conn1, err := net.ListenUDP("udp", addr1)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
defer conn1.Close()
|
||||
tr1 := &quic.Transport{Conn: conn1}
|
||||
addTracer(tr1)
|
||||
|
||||
addr2, err := net.ResolveUDPAddr("udp", "localhost:0")
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
conn2, err := net.ListenUDP("udp", addr2)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
defer conn2.Close()
|
||||
tr2 := &quic.Transport{Conn: conn2}
|
||||
addTracer(tr2)
|
||||
|
||||
server, err := tr1.Listen(getTLSConfig(), getQuicConfig(nil))
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
runServer(server)
|
||||
defer server.Close()
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
var sentPackets, rcvdPackets atomic.Int64
|
||||
const packetLen = 128
|
||||
// send a non-QUIC packet every 100µs
|
||||
go func() {
|
||||
defer GinkgoRecover()
|
||||
ticker := time.NewTicker(time.Millisecond / 10)
|
||||
defer ticker.Stop()
|
||||
var wroteFirstPacket bool
|
||||
for {
|
||||
select {
|
||||
case <-ticker.C:
|
||||
case <-ctx.Done():
|
||||
return
|
||||
}
|
||||
var wroteFirstPacket bool
|
||||
for {
|
||||
select {
|
||||
case <-ticker.C:
|
||||
b := make([]byte, packetLen)
|
||||
rand.Read(b[1:]) // keep the first byte set to 0, so it's not classified as a QUIC packet
|
||||
_, err := tr1.WriteTo(b, tr2.Conn.LocalAddr())
|
||||
// The first sendmsg call on a new UDP socket sometimes errors on Linux.
|
||||
// It's not clear why this happens.
|
||||
// See https://github.com/golang/go/issues/63322.
|
||||
if err != nil && !wroteFirstPacket && isPermissionError(err) {
|
||||
if err != nil && !wroteFirstPacket && runtime.GOOS == "linux" && isPermissionError(err) {
|
||||
_, err = tr1.WriteTo(b, tr2.Conn.LocalAddr())
|
||||
}
|
||||
if ctx.Err() != nil { // ctx canceled while Read was executing
|
||||
if err != nil {
|
||||
errChan <- err
|
||||
return
|
||||
}
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
sentPackets.Add(1)
|
||||
wroteFirstPacket = true
|
||||
case <-ctx.Done():
|
||||
return
|
||||
}
|
||||
}()
|
||||
}
|
||||
}()
|
||||
|
||||
// receive and count non-QUIC packets
|
||||
go func() {
|
||||
defer GinkgoRecover()
|
||||
for {
|
||||
b := make([]byte, 1024)
|
||||
n, addr, err := tr2.ReadNonQUICPacket(ctx, b)
|
||||
if err != nil {
|
||||
Expect(err).To(MatchError(context.Canceled))
|
||||
return
|
||||
}
|
||||
Expect(addr).To(Equal(tr1.Conn.LocalAddr()))
|
||||
Expect(n).To(Equal(packetLen))
|
||||
rcvdPackets.Add(1)
|
||||
}
|
||||
}()
|
||||
dial(tr2, server.Addr())
|
||||
Eventually(func() int64 { return sentPackets.Load() }).Should(BeNumerically(">", 10))
|
||||
Eventually(func() int64 { return rcvdPackets.Load() }).Should(BeNumerically(">=", sentPackets.Load()*4/5))
|
||||
})
|
||||
})
|
||||
conn, err := tr2.Dial(
|
||||
context.Background(),
|
||||
server.Addr(),
|
||||
getTLSClientConfig(),
|
||||
getQuicConfig(nil),
|
||||
)
|
||||
require.NoError(t, err)
|
||||
defer conn.CloseWithError(0, "")
|
||||
|
||||
serverConn, err := server.Accept(context.Background())
|
||||
require.NoError(t, err)
|
||||
serverStr, err := serverConn.OpenUniStream()
|
||||
require.NoError(t, err)
|
||||
go func() {
|
||||
defer serverStr.Close()
|
||||
_, _ = serverStr.Write(PRData)
|
||||
}()
|
||||
|
||||
str, err := conn.AcceptUniStream(context.Background())
|
||||
require.NoError(t, err)
|
||||
data, err := io.ReadAll(str)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, PRData, data)
|
||||
|
||||
// stop sending non-QUIC packets
|
||||
cancel()
|
||||
|
||||
select {
|
||||
case err := <-errChan:
|
||||
t.Fatalf("error sending non-QUIC packets: %v", err)
|
||||
case <-done:
|
||||
}
|
||||
|
||||
sent := int(sentPackets.Load())
|
||||
require.Greater(t, sent, 10, "not enough non-QUIC packets sent: %d", sent)
|
||||
rcvd := len(rcvdPackets)
|
||||
minExpected := sent * 4 / 5
|
||||
require.GreaterOrEqual(t, rcvd, minExpected, "not enough packets received. got: %d, expected at least: %d", rcvd, minExpected)
|
||||
|
||||
for _, p := range rcvdPackets {
|
||||
require.Equal(t, tr1.Conn.LocalAddr(), p.addr, "non-QUIC packet received from wrong address")
|
||||
require.Equal(t, packetLen, len(p.b), "non-QUIC packet incorrect length")
|
||||
require.NoError(t, p.err, "error receiving non-QUIC packet")
|
||||
}
|
||||
}
|
||||
|
||||
@@ -4,117 +4,120 @@ import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/quic-go/quic-go"
|
||||
quicproxy "github.com/quic-go/quic-go/integrationtests/tools/proxy"
|
||||
"github.com/quic-go/quic-go/logging"
|
||||
|
||||
. "github.com/onsi/ginkgo/v2"
|
||||
. "github.com/onsi/gomega"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
var _ = Describe("Packetization", func() {
|
||||
// In this test, the client sends 100 small messages. The server echoes these messages.
|
||||
// This means that every endpoint will send 100 ack-eliciting packets in short succession.
|
||||
// This test then tests that no more than 110 packets are sent in every direction, making sure that ACK are bundled.
|
||||
It("bundles ACKs", func() {
|
||||
const numMsg = 100
|
||||
func TestACKBundling(t *testing.T) {
|
||||
const numMsg = 100
|
||||
|
||||
serverCounter, serverTracer := newPacketTracer()
|
||||
server, err := quic.ListenAddr(
|
||||
"localhost:0",
|
||||
getTLSConfig(),
|
||||
getQuicConfig(&quic.Config{
|
||||
DisablePathMTUDiscovery: true,
|
||||
Tracer: newTracer(serverTracer),
|
||||
}),
|
||||
)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
defer server.Close()
|
||||
serverAddr := fmt.Sprintf("localhost:%d", server.Addr().(*net.UDPAddr).Port)
|
||||
serverCounter, serverTracer := newPacketTracer()
|
||||
server, err := quic.ListenAddr(
|
||||
"localhost:0",
|
||||
getTLSConfig(),
|
||||
getQuicConfig(&quic.Config{
|
||||
DisablePathMTUDiscovery: true,
|
||||
Tracer: newTracer(serverTracer),
|
||||
}),
|
||||
)
|
||||
require.NoError(t, err)
|
||||
defer server.Close()
|
||||
serverAddr := fmt.Sprintf("localhost:%d", server.Addr().(*net.UDPAddr).Port)
|
||||
|
||||
proxy, err := quicproxy.NewQuicProxy("localhost:0", &quicproxy.Opts{
|
||||
RemoteAddr: serverAddr,
|
||||
DelayPacket: func(dir quicproxy.Direction, _ []byte) time.Duration {
|
||||
return 5 * time.Millisecond
|
||||
},
|
||||
})
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
defer proxy.Close()
|
||||
proxy, err := quicproxy.NewQuicProxy("localhost:0", &quicproxy.Opts{
|
||||
RemoteAddr: serverAddr,
|
||||
DelayPacket: func(dir quicproxy.Direction, _ []byte) time.Duration {
|
||||
return 5 * time.Millisecond
|
||||
},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
defer proxy.Close()
|
||||
|
||||
clientCounter, clientTracer := newPacketTracer()
|
||||
conn, err := quic.DialAddr(
|
||||
context.Background(),
|
||||
fmt.Sprintf("localhost:%d", proxy.LocalPort()),
|
||||
getTLSClientConfig(),
|
||||
getQuicConfig(&quic.Config{
|
||||
DisablePathMTUDiscovery: true,
|
||||
Tracer: newTracer(clientTracer),
|
||||
}),
|
||||
)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
defer conn.CloseWithError(0, "")
|
||||
clientCounter, clientTracer := newPacketTracer()
|
||||
conn, err := quic.DialAddr(
|
||||
context.Background(),
|
||||
fmt.Sprintf("localhost:%d", proxy.LocalPort()),
|
||||
getTLSClientConfig(),
|
||||
getQuicConfig(&quic.Config{
|
||||
DisablePathMTUDiscovery: true,
|
||||
Tracer: newTracer(clientTracer),
|
||||
}),
|
||||
)
|
||||
require.NoError(t, err)
|
||||
defer conn.CloseWithError(0, "")
|
||||
|
||||
go func() {
|
||||
defer GinkgoRecover()
|
||||
conn, err := server.Accept(context.Background())
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
str, err := conn.AcceptStream(context.Background())
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
b := make([]byte, 1)
|
||||
// Echo every byte received from the client.
|
||||
for {
|
||||
if _, err := str.Read(b); err != nil {
|
||||
break
|
||||
}
|
||||
_, err = str.Write(b)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
}
|
||||
}()
|
||||
|
||||
str, err := conn.OpenStreamSync(context.Background())
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
b := make([]byte, 1)
|
||||
// Send numMsg 1-byte messages.
|
||||
for i := 0; i < numMsg; i++ {
|
||||
_, err = str.Write([]byte{uint8(i)})
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
_, err = str.Read(b)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(b[0]).To(Equal(uint8(i)))
|
||||
}
|
||||
Expect(conn.CloseWithError(0, "")).To(Succeed())
|
||||
|
||||
countBundledPackets := func(packets []shortHeaderPacket) (numBundled int) {
|
||||
for _, p := range packets {
|
||||
var hasAck, hasStreamFrame bool
|
||||
for _, f := range p.frames {
|
||||
switch f.(type) {
|
||||
case *logging.AckFrame:
|
||||
hasAck = true
|
||||
case *logging.StreamFrame:
|
||||
hasStreamFrame = true
|
||||
}
|
||||
}
|
||||
if hasAck && hasStreamFrame {
|
||||
numBundled++
|
||||
}
|
||||
}
|
||||
serverErrChan := make(chan error, 1)
|
||||
go func() {
|
||||
defer close(serverErrChan)
|
||||
conn, err := server.Accept(context.Background())
|
||||
if err != nil {
|
||||
serverErrChan <- fmt.Errorf("accept failed: %w", err)
|
||||
return
|
||||
}
|
||||
str, err := conn.AcceptStream(context.Background())
|
||||
if err != nil {
|
||||
serverErrChan <- fmt.Errorf("accept stream failed: %w", err)
|
||||
return
|
||||
}
|
||||
b := make([]byte, 1)
|
||||
// Echo every byte received from the client.
|
||||
for {
|
||||
if _, err := str.Read(b); err != nil {
|
||||
break
|
||||
}
|
||||
_, err = str.Write(b)
|
||||
if err != nil {
|
||||
serverErrChan <- fmt.Errorf("write failed: %w", err)
|
||||
return
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
numBundledIncoming := countBundledPackets(clientCounter.getRcvdShortHeaderPackets())
|
||||
numBundledOutgoing := countBundledPackets(serverCounter.getRcvdShortHeaderPackets())
|
||||
fmt.Fprintf(GinkgoWriter, "bundled incoming packets: %d / %d\n", numBundledIncoming, numMsg)
|
||||
fmt.Fprintf(GinkgoWriter, "bundled outgoing packets: %d / %d\n", numBundledOutgoing, numMsg)
|
||||
Expect(numBundledIncoming).To(And(
|
||||
BeNumerically("<=", numMsg),
|
||||
BeNumerically(">", numMsg*9/10),
|
||||
))
|
||||
Expect(numBundledOutgoing).To(And(
|
||||
BeNumerically("<=", numMsg),
|
||||
BeNumerically(">", numMsg*9/10),
|
||||
))
|
||||
})
|
||||
})
|
||||
str, err := conn.OpenStreamSync(context.Background())
|
||||
require.NoError(t, err)
|
||||
b := make([]byte, 1)
|
||||
// Send numMsg 1-byte messages.
|
||||
for i := 0; i < numMsg; i++ {
|
||||
_, err = str.Write([]byte{uint8(i)})
|
||||
require.NoError(t, err)
|
||||
_, err = str.Read(b)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, uint8(i), b[0])
|
||||
}
|
||||
require.NoError(t, conn.CloseWithError(0, ""))
|
||||
require.NoError(t, <-serverErrChan)
|
||||
|
||||
countBundledPackets := func(packets []shortHeaderPacket) (numBundled int) {
|
||||
for _, p := range packets {
|
||||
var hasAck, hasStreamFrame bool
|
||||
for _, f := range p.frames {
|
||||
switch f.(type) {
|
||||
case *logging.AckFrame:
|
||||
hasAck = true
|
||||
case *logging.StreamFrame:
|
||||
hasStreamFrame = true
|
||||
}
|
||||
}
|
||||
if hasAck && hasStreamFrame {
|
||||
numBundled++
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
numBundledIncoming := countBundledPackets(clientCounter.getRcvdShortHeaderPackets())
|
||||
numBundledOutgoing := countBundledPackets(serverCounter.getRcvdShortHeaderPackets())
|
||||
t.Logf("bundled incoming packets: %d / %d", numBundledIncoming, numMsg)
|
||||
t.Logf("bundled outgoing packets: %d / %d", numBundledOutgoing, numMsg)
|
||||
|
||||
require.LessOrEqual(t, numBundledIncoming, numMsg)
|
||||
require.Greater(t, numBundledIncoming, numMsg*9/10)
|
||||
require.LessOrEqual(t, numBundledOutgoing, numMsg)
|
||||
require.Greater(t, numBundledOutgoing, numMsg*9/10)
|
||||
}
|
||||
|
||||
@@ -5,87 +5,77 @@ import (
|
||||
"os"
|
||||
"path"
|
||||
"regexp"
|
||||
"testing"
|
||||
|
||||
. "github.com/onsi/ginkgo/v2"
|
||||
. "github.com/onsi/gomega"
|
||||
"github.com/quic-go/quic-go"
|
||||
"github.com/quic-go/quic-go/qlog"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
var _ = Describe("qlog dir tests", Serial, func() {
|
||||
var originalQlogDirValue string
|
||||
var tempTestDirPath string
|
||||
|
||||
BeforeEach(func() {
|
||||
originalQlogDirValue = os.Getenv("QLOGDIR")
|
||||
var err error
|
||||
tempTestDirPath, err = os.MkdirTemp("", "temp_test_dir")
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
func TestQlogDirEnvironmentVariable(t *testing.T) {
|
||||
originalQlogDirValue := os.Getenv("QLOGDIR")
|
||||
tempTestDirPath, err := os.MkdirTemp("", "temp_test_dir")
|
||||
require.NoError(t, err)
|
||||
t.Cleanup(func() {
|
||||
require.NoError(t, os.Setenv("QLOGDIR", originalQlogDirValue))
|
||||
require.NoError(t, os.RemoveAll(tempTestDirPath))
|
||||
})
|
||||
qlogDir := path.Join(tempTestDirPath, "qlogs")
|
||||
require.NoError(t, os.Setenv("QLOGDIR", qlogDir))
|
||||
|
||||
AfterEach(func() {
|
||||
err := os.Setenv("QLOGDIR", originalQlogDirValue)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
err = os.RemoveAll(tempTestDirPath)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
})
|
||||
serverStopped := make(chan struct{})
|
||||
server, err := quic.ListenAddr(
|
||||
"localhost:0",
|
||||
getTLSConfig(),
|
||||
&quic.Config{
|
||||
Tracer: qlog.DefaultConnectionTracer,
|
||||
},
|
||||
)
|
||||
require.NoError(t, err)
|
||||
|
||||
handshake := func() {
|
||||
serverStopped := make(chan struct{})
|
||||
server, err := quic.ListenAddr(
|
||||
"localhost:0",
|
||||
getTLSConfig(),
|
||||
&quic.Config{
|
||||
Tracer: qlog.DefaultConnectionTracer,
|
||||
},
|
||||
)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
||||
go func() {
|
||||
defer GinkgoRecover()
|
||||
defer close(serverStopped)
|
||||
for {
|
||||
if _, err := server.Accept(context.Background()); err != nil {
|
||||
return
|
||||
}
|
||||
go func() {
|
||||
defer close(serverStopped)
|
||||
for {
|
||||
if _, err := server.Accept(context.Background()); err != nil {
|
||||
return
|
||||
}
|
||||
}()
|
||||
}
|
||||
}()
|
||||
|
||||
conn, err := quic.DialAddr(
|
||||
context.Background(),
|
||||
server.Addr().String(),
|
||||
getTLSClientConfig(),
|
||||
&quic.Config{
|
||||
Tracer: qlog.DefaultConnectionTracer,
|
||||
},
|
||||
)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
conn.CloseWithError(0, "")
|
||||
server.Close()
|
||||
<-serverStopped
|
||||
conn, err := quic.DialAddr(
|
||||
context.Background(),
|
||||
server.Addr().String(),
|
||||
getTLSClientConfig(),
|
||||
&quic.Config{
|
||||
Tracer: qlog.DefaultConnectionTracer,
|
||||
},
|
||||
)
|
||||
require.NoError(t, err)
|
||||
conn.CloseWithError(0, "")
|
||||
server.Close()
|
||||
<-serverStopped
|
||||
|
||||
_, err = os.Stat(tempTestDirPath)
|
||||
qlogDirCreated := !os.IsNotExist(err)
|
||||
require.True(t, qlogDirCreated)
|
||||
|
||||
childs, err := os.ReadDir(qlogDir)
|
||||
require.NoError(t, err)
|
||||
require.Len(t, childs, 2)
|
||||
|
||||
odcids := make([]string, 0, 2)
|
||||
vantagePoints := make([]string, 0, 2)
|
||||
qlogFileNameRegexp := regexp.MustCompile(`^([0-f]+)_(client|server).sqlog$`)
|
||||
|
||||
for _, child := range childs {
|
||||
matches := qlogFileNameRegexp.FindStringSubmatch(child.Name())
|
||||
require.Len(t, matches, 3)
|
||||
odcids = append(odcids, matches[1])
|
||||
vantagePoints = append(vantagePoints, matches[2])
|
||||
}
|
||||
|
||||
It("environment variable is set", func() {
|
||||
qlogDir := path.Join(tempTestDirPath, "qlogs")
|
||||
err := os.Setenv("QLOGDIR", qlogDir)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
handshake()
|
||||
_, err = os.Stat(tempTestDirPath)
|
||||
qlogDirCreated := !os.IsNotExist(err)
|
||||
Expect(qlogDirCreated).To(BeTrue())
|
||||
childs, err := os.ReadDir(qlogDir)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(len(childs)).To(Equal(2))
|
||||
odcids := make([]string, 0)
|
||||
vantagePoints := make([]string, 0)
|
||||
qlogFileNameRegexp := regexp.MustCompile(`^([0-f]+)_(client|server).sqlog$`)
|
||||
for _, child := range childs {
|
||||
matches := qlogFileNameRegexp.FindStringSubmatch(child.Name())
|
||||
Expect(matches).To(HaveLen(3))
|
||||
odcids = append(odcids, matches[1])
|
||||
vantagePoints = append(vantagePoints, matches[2])
|
||||
}
|
||||
Expect(odcids[0]).To(Equal(odcids[1]))
|
||||
Expect(vantagePoints).To(ContainElements("client", "server"))
|
||||
})
|
||||
})
|
||||
require.Equal(t, odcids[0], odcids[1])
|
||||
require.Contains(t, vantagePoints, "client")
|
||||
require.Contains(t, vantagePoints, "server")
|
||||
}
|
||||
|
||||
@@ -5,19 +5,18 @@ import (
|
||||
"crypto/tls"
|
||||
"fmt"
|
||||
"net"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/quic-go/quic-go"
|
||||
|
||||
. "github.com/onsi/ginkgo/v2"
|
||||
. "github.com/onsi/gomega"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
type clientSessionCache struct {
|
||||
cache tls.ClientSessionCache
|
||||
|
||||
gets chan<- string
|
||||
puts chan<- string
|
||||
gets chan<- string
|
||||
puts chan<- string
|
||||
}
|
||||
|
||||
func newClientSessionCache(cache tls.ClientSessionCache, gets, puts chan<- string) *clientSessionCache {
|
||||
@@ -45,94 +44,18 @@ func (c *clientSessionCache) Put(sessionKey string, cs *tls.ClientSessionState)
|
||||
}
|
||||
}
|
||||
|
||||
var _ = Describe("TLS session resumption", func() {
|
||||
It("uses session resumption", func() {
|
||||
server, err := quic.ListenAddr("localhost:0", getTLSConfig(), getQuicConfig(nil))
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
defer server.Close()
|
||||
|
||||
gets := make(chan string, 100)
|
||||
puts := make(chan string, 100)
|
||||
cache := newClientSessionCache(tls.NewLRUClientSessionCache(10), gets, puts)
|
||||
tlsConf := getTLSClientConfig()
|
||||
tlsConf.ClientSessionCache = cache
|
||||
conn1, err := quic.DialAddr(
|
||||
context.Background(),
|
||||
fmt.Sprintf("localhost:%d", server.Addr().(*net.UDPAddr).Port),
|
||||
tlsConf,
|
||||
getQuicConfig(nil),
|
||||
)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
defer conn1.CloseWithError(0, "")
|
||||
var sessionKey string
|
||||
Eventually(puts).Should(Receive(&sessionKey))
|
||||
Expect(conn1.ConnectionState().TLS.DidResume).To(BeFalse())
|
||||
|
||||
serverConn, err := server.Accept(context.Background())
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(serverConn.ConnectionState().TLS.DidResume).To(BeFalse())
|
||||
|
||||
conn2, err := quic.DialAddr(
|
||||
context.Background(),
|
||||
fmt.Sprintf("localhost:%d", server.Addr().(*net.UDPAddr).Port),
|
||||
tlsConf,
|
||||
getQuicConfig(nil),
|
||||
)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(gets).To(Receive(Equal(sessionKey)))
|
||||
Expect(conn2.ConnectionState().TLS.DidResume).To(BeTrue())
|
||||
|
||||
serverConn, err = server.Accept(context.Background())
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(serverConn.ConnectionState().TLS.DidResume).To(BeTrue())
|
||||
conn2.CloseWithError(0, "")
|
||||
func TestTLSSessionResumption(t *testing.T) {
|
||||
t.Run("uses session resumption", func(t *testing.T) {
|
||||
handshakeWithSessionResumption(t, getTLSConfig(), true)
|
||||
})
|
||||
|
||||
It("doesn't use session resumption, if the config disables it", func() {
|
||||
t.Run("disabled in tls.Config", func(t *testing.T) {
|
||||
sConf := getTLSConfig()
|
||||
sConf.SessionTicketsDisabled = true
|
||||
server, err := quic.ListenAddr("localhost:0", sConf, getQuicConfig(nil))
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
defer server.Close()
|
||||
|
||||
gets := make(chan string, 100)
|
||||
puts := make(chan string, 100)
|
||||
cache := newClientSessionCache(tls.NewLRUClientSessionCache(10), gets, puts)
|
||||
tlsConf := getTLSClientConfig()
|
||||
tlsConf.ClientSessionCache = cache
|
||||
conn1, err := quic.DialAddr(
|
||||
context.Background(),
|
||||
fmt.Sprintf("localhost:%d", server.Addr().(*net.UDPAddr).Port),
|
||||
tlsConf,
|
||||
getQuicConfig(nil),
|
||||
)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
defer conn1.CloseWithError(0, "")
|
||||
Consistently(puts).ShouldNot(Receive())
|
||||
Expect(conn1.ConnectionState().TLS.DidResume).To(BeFalse())
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
|
||||
defer cancel()
|
||||
serverConn, err := server.Accept(ctx)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(serverConn.ConnectionState().TLS.DidResume).To(BeFalse())
|
||||
|
||||
conn2, err := quic.DialAddr(
|
||||
context.Background(),
|
||||
fmt.Sprintf("localhost:%d", server.Addr().(*net.UDPAddr).Port),
|
||||
tlsConf,
|
||||
getQuicConfig(nil),
|
||||
)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(conn2.ConnectionState().TLS.DidResume).To(BeFalse())
|
||||
defer conn2.CloseWithError(0, "")
|
||||
|
||||
serverConn, err = server.Accept(context.Background())
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(serverConn.ConnectionState().TLS.DidResume).To(BeFalse())
|
||||
handshakeWithSessionResumption(t, sConf, false)
|
||||
})
|
||||
|
||||
It("doesn't use session resumption, if the config returned by GetConfigForClient disables it", func() {
|
||||
t.Run("disabled in tls.Config.GetConfigForClient", func(t *testing.T) {
|
||||
sConf := &tls.Config{
|
||||
GetConfigForClient: func(*tls.ClientHelloInfo) (*tls.Config, error) {
|
||||
conf := getTLSConfig()
|
||||
@@ -140,45 +63,78 @@ var _ = Describe("TLS session resumption", func() {
|
||||
return conf, nil
|
||||
},
|
||||
}
|
||||
|
||||
server, err := quic.ListenAddr("localhost:0", sConf, getQuicConfig(nil))
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
defer server.Close()
|
||||
|
||||
gets := make(chan string, 100)
|
||||
puts := make(chan string, 100)
|
||||
cache := newClientSessionCache(tls.NewLRUClientSessionCache(10), gets, puts)
|
||||
tlsConf := getTLSClientConfig()
|
||||
tlsConf.ClientSessionCache = cache
|
||||
conn1, err := quic.DialAddr(
|
||||
context.Background(),
|
||||
fmt.Sprintf("localhost:%d", server.Addr().(*net.UDPAddr).Port),
|
||||
tlsConf,
|
||||
getQuicConfig(nil),
|
||||
)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Consistently(puts).ShouldNot(Receive())
|
||||
Expect(conn1.ConnectionState().TLS.DidResume).To(BeFalse())
|
||||
defer conn1.CloseWithError(0, "")
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
|
||||
defer cancel()
|
||||
serverConn, err := server.Accept(ctx)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(serverConn.ConnectionState().TLS.DidResume).To(BeFalse())
|
||||
|
||||
conn2, err := quic.DialAddr(
|
||||
context.Background(),
|
||||
fmt.Sprintf("localhost:%d", server.Addr().(*net.UDPAddr).Port),
|
||||
tlsConf,
|
||||
getQuicConfig(nil),
|
||||
)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(conn2.ConnectionState().TLS.DidResume).To(BeFalse())
|
||||
defer conn2.CloseWithError(0, "")
|
||||
|
||||
serverConn, err = server.Accept(context.Background())
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(serverConn.ConnectionState().TLS.DidResume).To(BeFalse())
|
||||
handshakeWithSessionResumption(t, sConf, false)
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
func handshakeWithSessionResumption(t *testing.T, serverTLSConf *tls.Config, expectSessionTicket bool) {
|
||||
server, err := quic.ListenAddr("localhost:0", serverTLSConf, getQuicConfig(nil))
|
||||
require.NoError(t, err)
|
||||
defer server.Close()
|
||||
|
||||
gets := make(chan string, 100)
|
||||
puts := make(chan string, 100)
|
||||
cache := newClientSessionCache(tls.NewLRUClientSessionCache(10), gets, puts)
|
||||
tlsConf := getTLSClientConfig()
|
||||
tlsConf.ClientSessionCache = cache
|
||||
|
||||
// first connection - doesn't use resumption
|
||||
conn1, err := quic.DialAddr(
|
||||
context.Background(),
|
||||
fmt.Sprintf("localhost:%d", server.Addr().(*net.UDPAddr).Port),
|
||||
tlsConf,
|
||||
getQuicConfig(nil),
|
||||
)
|
||||
require.NoError(t, err)
|
||||
defer conn1.CloseWithError(0, "")
|
||||
require.False(t, conn1.ConnectionState().TLS.DidResume)
|
||||
|
||||
var sessionKey string
|
||||
select {
|
||||
case sessionKey = <-puts:
|
||||
if !expectSessionTicket {
|
||||
t.Fatal("unexpected session ticket")
|
||||
}
|
||||
case <-time.After(scaleDuration(50 * time.Millisecond)):
|
||||
if expectSessionTicket {
|
||||
t.Fatal("timeout waiting for session ticket")
|
||||
}
|
||||
}
|
||||
|
||||
serverConn, err := server.Accept(context.Background())
|
||||
require.NoError(t, err)
|
||||
require.False(t, serverConn.ConnectionState().TLS.DidResume)
|
||||
|
||||
// second connection - will use resumption, if enabled
|
||||
conn2, err := quic.DialAddr(
|
||||
context.Background(),
|
||||
fmt.Sprintf("localhost:%d", server.Addr().(*net.UDPAddr).Port),
|
||||
tlsConf,
|
||||
getQuicConfig(nil),
|
||||
)
|
||||
require.NoError(t, err)
|
||||
defer conn2.CloseWithError(0, "")
|
||||
|
||||
select {
|
||||
case k := <-gets:
|
||||
if expectSessionTicket {
|
||||
// we can only perform this check if we got a session ticket before
|
||||
require.Equal(t, sessionKey, k)
|
||||
}
|
||||
case <-time.After(scaleDuration(50 * time.Millisecond)):
|
||||
if expectSessionTicket {
|
||||
t.Fatal("timeout waiting for retrieval of session ticket")
|
||||
}
|
||||
}
|
||||
|
||||
serverConn, err = server.Accept(context.Background())
|
||||
require.NoError(t, err)
|
||||
|
||||
if expectSessionTicket {
|
||||
require.True(t, conn2.ConnectionState().TLS.DidResume)
|
||||
require.True(t, serverConn.ConnectionState().TLS.DidResume)
|
||||
} else {
|
||||
require.False(t, conn2.ConnectionState().TLS.DidResume)
|
||||
require.False(t, serverConn.ConnectionState().TLS.DidResume)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -5,109 +5,134 @@ import (
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/quic-go/quic-go"
|
||||
quicproxy "github.com/quic-go/quic-go/integrationtests/tools/proxy"
|
||||
|
||||
. "github.com/onsi/ginkgo/v2"
|
||||
. "github.com/onsi/gomega"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
var _ = Describe("non-zero RTT", func() {
|
||||
runServer := func() *quic.Listener {
|
||||
ln, err := quic.ListenAddr(
|
||||
"localhost:0",
|
||||
getTLSConfig(),
|
||||
getQuicConfig(nil),
|
||||
)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
go func() {
|
||||
defer GinkgoRecover()
|
||||
func runServerForRTTTest(t *testing.T) (net.Addr, <-chan error) {
|
||||
ln, err := quic.ListenAddr(
|
||||
"localhost:0",
|
||||
getTLSConfig(),
|
||||
getQuicConfig(nil),
|
||||
)
|
||||
require.NoError(t, err)
|
||||
t.Cleanup(func() { ln.Close() })
|
||||
|
||||
errChan := make(chan error, 1)
|
||||
go func() {
|
||||
defer close(errChan)
|
||||
for {
|
||||
conn, err := ln.Accept(context.Background())
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
if err != nil {
|
||||
errChan <- fmt.Errorf("accept error: %w", err)
|
||||
return
|
||||
}
|
||||
str, err := conn.OpenStream()
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
if err != nil {
|
||||
errChan <- fmt.Errorf("open stream error: %w", err)
|
||||
return
|
||||
}
|
||||
_, err = str.Write(PRData)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
if err != nil {
|
||||
errChan <- fmt.Errorf("write error: %w", err)
|
||||
return
|
||||
}
|
||||
str.Close()
|
||||
}()
|
||||
return ln
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
downloadFile := func(port int) {
|
||||
conn, err := quic.DialAddr(
|
||||
context.Background(),
|
||||
fmt.Sprintf("localhost:%d", port),
|
||||
getTLSClientConfig(),
|
||||
getQuicConfig(nil),
|
||||
)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
str, err := conn.AcceptStream(context.Background())
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
data, err := io.ReadAll(str)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(data).To(Equal(PRData))
|
||||
conn.CloseWithError(0, "")
|
||||
}
|
||||
return ln.Addr(), errChan
|
||||
}
|
||||
|
||||
for _, r := range [...]time.Duration{
|
||||
func TestDownloadWithFixedRTT(t *testing.T) {
|
||||
addr, errChan := runServerForRTTTest(t)
|
||||
|
||||
for _, rtt := range []time.Duration{
|
||||
10 * time.Millisecond,
|
||||
50 * time.Millisecond,
|
||||
100 * time.Millisecond,
|
||||
200 * time.Millisecond,
|
||||
250 * time.Millisecond,
|
||||
} {
|
||||
rtt := r
|
||||
|
||||
It(fmt.Sprintf("downloads a message with %s RTT", rtt), func() {
|
||||
ln := runServer()
|
||||
defer ln.Close()
|
||||
serverPort := ln.Addr().(*net.UDPAddr).Port
|
||||
proxy, err := quicproxy.NewQuicProxy("localhost:0", &quicproxy.Opts{
|
||||
RemoteAddr: fmt.Sprintf("localhost:%d", serverPort),
|
||||
DelayPacket: func(quicproxy.Direction, []byte) time.Duration {
|
||||
return rtt / 2
|
||||
},
|
||||
t.Run(fmt.Sprintf("RTT %s", rtt), func(t *testing.T) {
|
||||
t.Cleanup(func() {
|
||||
select {
|
||||
case err := <-errChan:
|
||||
t.Errorf("server error: %v", err)
|
||||
default:
|
||||
}
|
||||
})
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
defer proxy.Close()
|
||||
|
||||
proxy, err := quicproxy.NewQuicProxy("localhost:0", &quicproxy.Opts{
|
||||
RemoteAddr: fmt.Sprintf("localhost:%d", addr.(*net.UDPAddr).Port),
|
||||
DelayPacket: func(quicproxy.Direction, []byte) time.Duration { return rtt / 2 },
|
||||
})
|
||||
require.NoError(t, err)
|
||||
t.Cleanup(func() { proxy.Close() })
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
|
||||
defer cancel()
|
||||
conn, err := quic.DialAddr(
|
||||
context.Background(),
|
||||
ctx,
|
||||
fmt.Sprintf("localhost:%d", proxy.LocalPort()),
|
||||
getTLSClientConfig(),
|
||||
getQuicConfig(nil),
|
||||
)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
str, err := conn.AcceptStream(context.Background())
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
require.NoError(t, err)
|
||||
defer conn.CloseWithError(0, "")
|
||||
|
||||
str, err := conn.AcceptStream(ctx)
|
||||
require.NoError(t, err)
|
||||
data, err := io.ReadAll(str)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(data).To(Equal(PRData))
|
||||
conn.CloseWithError(0, "")
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, PRData, data)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
for _, r := range [...]time.Duration{
|
||||
10 * time.Millisecond,
|
||||
40 * time.Millisecond,
|
||||
func TestDownloadWithReordering(t *testing.T) {
|
||||
addr, errChan := runServerForRTTTest(t)
|
||||
|
||||
for _, rtt := range []time.Duration{
|
||||
5 * time.Millisecond,
|
||||
30 * time.Millisecond,
|
||||
} {
|
||||
rtt := r
|
||||
t.Run(fmt.Sprintf("RTT %s", rtt), func(t *testing.T) {
|
||||
t.Cleanup(func() {
|
||||
select {
|
||||
case err := <-errChan:
|
||||
t.Errorf("server error: %v", err)
|
||||
default:
|
||||
}
|
||||
})
|
||||
|
||||
It(fmt.Sprintf("downloads a message with %s RTT, with reordering", rtt), func() {
|
||||
ln := runServer()
|
||||
defer ln.Close()
|
||||
serverPort := ln.Addr().(*net.UDPAddr).Port
|
||||
proxy, err := quicproxy.NewQuicProxy("localhost:0", &quicproxy.Opts{
|
||||
RemoteAddr: fmt.Sprintf("localhost:%d", serverPort),
|
||||
RemoteAddr: fmt.Sprintf("localhost:%d", addr.(*net.UDPAddr).Port),
|
||||
DelayPacket: func(quicproxy.Direction, []byte) time.Duration {
|
||||
return randomDuration(rtt/2, rtt*3/2) / 2
|
||||
},
|
||||
})
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
defer proxy.Close()
|
||||
require.NoError(t, err)
|
||||
t.Cleanup(func() { proxy.Close() })
|
||||
|
||||
downloadFile(proxy.LocalPort())
|
||||
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
|
||||
defer cancel()
|
||||
conn, err := quic.DialAddr(
|
||||
ctx,
|
||||
fmt.Sprintf("localhost:%d", proxy.LocalPort()),
|
||||
getTLSClientConfig(),
|
||||
getQuicConfig(nil),
|
||||
)
|
||||
require.NoError(t, err)
|
||||
defer conn.CloseWithError(0, "")
|
||||
|
||||
str, err := conn.AcceptStream(ctx)
|
||||
require.NoError(t, err)
|
||||
data, err := io.ReadAll(str)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, PRData, data)
|
||||
})
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
@@ -8,24 +8,19 @@ import (
|
||||
"flag"
|
||||
"fmt"
|
||||
"io"
|
||||
"log"
|
||||
"math/rand"
|
||||
"os"
|
||||
"runtime/pprof"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/quic-go/quic-go"
|
||||
"github.com/quic-go/quic-go/integrationtests/tools"
|
||||
"github.com/quic-go/quic-go/internal/protocol"
|
||||
"github.com/quic-go/quic-go/internal/utils"
|
||||
"github.com/quic-go/quic-go/internal/wire"
|
||||
"github.com/quic-go/quic-go/logging"
|
||||
|
||||
. "github.com/onsi/ginkgo/v2"
|
||||
. "github.com/onsi/gomega"
|
||||
)
|
||||
|
||||
const alpn = tools.ALPN
|
||||
@@ -53,56 +48,17 @@ func GeneratePRData(l int) []byte {
|
||||
return res
|
||||
}
|
||||
|
||||
const logBufSize = 100 * 1 << 20 // initial size of the log buffer: 100 MB
|
||||
|
||||
type syncedBuffer struct {
|
||||
mutex sync.Mutex
|
||||
|
||||
*bytes.Buffer
|
||||
}
|
||||
|
||||
func (b *syncedBuffer) Write(p []byte) (int, error) {
|
||||
b.mutex.Lock()
|
||||
n, err := b.Buffer.Write(p)
|
||||
b.mutex.Unlock()
|
||||
return n, err
|
||||
}
|
||||
|
||||
func (b *syncedBuffer) Bytes() []byte {
|
||||
b.mutex.Lock()
|
||||
p := b.Buffer.Bytes()
|
||||
b.mutex.Unlock()
|
||||
return p
|
||||
}
|
||||
|
||||
func (b *syncedBuffer) Reset() {
|
||||
b.mutex.Lock()
|
||||
b.Buffer.Reset()
|
||||
b.mutex.Unlock()
|
||||
}
|
||||
|
||||
var (
|
||||
logFileName string // the log file set in the ginkgo flags
|
||||
logBufOnce sync.Once
|
||||
logBuf *syncedBuffer
|
||||
versionParam string
|
||||
|
||||
version quic.Version
|
||||
enableQlog bool
|
||||
|
||||
version quic.Version
|
||||
tlsConfig *tls.Config
|
||||
tlsConfigLongChain *tls.Config
|
||||
tlsClientConfig *tls.Config
|
||||
tlsClientConfigWithoutServerName *tls.Config
|
||||
)
|
||||
|
||||
// read the logfile command line flag
|
||||
// to set call ginkgo -- -logfile=log.txt
|
||||
func init() {
|
||||
flag.StringVar(&logFileName, "logfile", "", "log file")
|
||||
flag.StringVar(&versionParam, "version", "1", "QUIC version")
|
||||
flag.BoolVar(&enableQlog, "qlog", false, "enable qlog")
|
||||
|
||||
ca, caPrivateKey, err := tools.GenerateCA()
|
||||
if err != nil {
|
||||
panic(err)
|
||||
@@ -137,31 +93,9 @@ func init() {
|
||||
}
|
||||
}
|
||||
|
||||
var _ = BeforeSuite(func() {
|
||||
switch versionParam {
|
||||
case "1":
|
||||
version = quic.Version1
|
||||
case "2":
|
||||
version = quic.Version2
|
||||
default:
|
||||
Fail(fmt.Sprintf("unknown QUIC version: %s", versionParam))
|
||||
}
|
||||
fmt.Printf("Using QUIC version: %s\n", version)
|
||||
protocol.SupportedVersions = []quic.Version{version}
|
||||
})
|
||||
|
||||
func getTLSConfig() *tls.Config {
|
||||
return tlsConfig.Clone()
|
||||
}
|
||||
|
||||
func getTLSConfigWithLongCertChain() *tls.Config {
|
||||
return tlsConfigLongChain.Clone()
|
||||
}
|
||||
|
||||
func getTLSClientConfig() *tls.Config {
|
||||
return tlsClientConfig.Clone()
|
||||
}
|
||||
|
||||
func getTLSConfig() *tls.Config { return tlsConfig.Clone() }
|
||||
func getTLSConfigWithLongCertChain() *tls.Config { return tlsConfigLongChain.Clone() }
|
||||
func getTLSClientConfig() *tls.Config { return tlsClientConfig.Clone() }
|
||||
func getTLSClientConfigWithoutServerName() *tls.Config {
|
||||
return tlsClientConfigWithoutServerName.Clone()
|
||||
}
|
||||
@@ -178,7 +112,7 @@ func getQuicConfig(conf *quic.Config) *quic.Config {
|
||||
if conf.Tracer == nil {
|
||||
conf.Tracer = func(ctx context.Context, p logging.Perspective, connID quic.ConnectionID) *logging.ConnectionTracer {
|
||||
return logging.NewMultiplexedConnectionTracer(
|
||||
tools.NewQlogConnectionTracer(GinkgoWriter)(ctx, p, connID),
|
||||
tools.NewQlogConnectionTracer(os.Stdout)(ctx, p, connID),
|
||||
// multiplex it with an empty tracer to check that we're correctly ignoring unset callbacks everywhere
|
||||
&logging.ConnectionTracer{},
|
||||
)
|
||||
@@ -188,7 +122,7 @@ func getQuicConfig(conf *quic.Config) *quic.Config {
|
||||
origTracer := conf.Tracer
|
||||
conf.Tracer = func(ctx context.Context, p logging.Perspective, connID quic.ConnectionID) *logging.ConnectionTracer {
|
||||
tr := origTracer(ctx, p, connID)
|
||||
qlogger := tools.NewQlogConnectionTracer(GinkgoWriter)(ctx, p, connID)
|
||||
qlogger := tools.NewQlogConnectionTracer(os.Stdout)(ctx, p, connID)
|
||||
if tr == nil {
|
||||
return qlogger
|
||||
}
|
||||
@@ -203,7 +137,7 @@ func addTracer(tr *quic.Transport) {
|
||||
}
|
||||
if tr.Tracer == nil {
|
||||
tr.Tracer = logging.NewMultiplexedTracer(
|
||||
tools.QlogTracer(GinkgoWriter),
|
||||
tools.QlogTracer(os.Stdout),
|
||||
// multiplex it with an empty tracer to check that we're correctly ignoring unset callbacks everywhere
|
||||
&logging.Tracer{},
|
||||
)
|
||||
@@ -211,23 +145,11 @@ func addTracer(tr *quic.Transport) {
|
||||
}
|
||||
origTracer := tr.Tracer
|
||||
tr.Tracer = logging.NewMultiplexedTracer(
|
||||
tools.QlogTracer(GinkgoWriter),
|
||||
tools.QlogTracer(os.Stdout),
|
||||
origTracer,
|
||||
)
|
||||
}
|
||||
|
||||
var _ = BeforeEach(func() {
|
||||
log.SetFlags(log.Ldate | log.Ltime | log.Lmicroseconds)
|
||||
|
||||
if debugLog() {
|
||||
logBufOnce.Do(func() {
|
||||
logBuf = &syncedBuffer{Buffer: bytes.NewBuffer(make([]byte, 0, logBufSize))}
|
||||
})
|
||||
utils.DefaultLogger.SetLogLevel(utils.LogLevelDebug)
|
||||
log.SetOutput(logBuf)
|
||||
}
|
||||
})
|
||||
|
||||
func areHandshakesRunning() bool {
|
||||
var b bytes.Buffer
|
||||
pprof.Lookup("goroutine").WriteTo(&b, 1)
|
||||
@@ -240,22 +162,36 @@ func areTransportsRunning() bool {
|
||||
return strings.Contains(b.String(), "quic-go.(*Transport).listen")
|
||||
}
|
||||
|
||||
var _ = AfterEach(func() {
|
||||
Expect(areHandshakesRunning()).To(BeFalse())
|
||||
Eventually(areTransportsRunning).Should(BeFalse())
|
||||
func TestMain(m *testing.M) {
|
||||
var versionParam string
|
||||
flag.StringVar(&versionParam, "version", "1", "QUIC version")
|
||||
flag.BoolVar(&enableQlog, "qlog", false, "enable qlog")
|
||||
flag.Parse()
|
||||
|
||||
if debugLog() {
|
||||
logFile, err := os.Create(logFileName)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
logFile.Write(logBuf.Bytes())
|
||||
logFile.Close()
|
||||
logBuf.Reset()
|
||||
switch versionParam {
|
||||
case "1":
|
||||
version = quic.Version1
|
||||
case "2":
|
||||
version = quic.Version2
|
||||
default:
|
||||
fmt.Printf("unknown QUIC version: %s\n", versionParam)
|
||||
os.Exit(1)
|
||||
}
|
||||
})
|
||||
fmt.Printf("using QUIC version: %s\n", version)
|
||||
|
||||
// Debug says if this test is being logged
|
||||
func debugLog() bool {
|
||||
return len(logFileName) > 0
|
||||
status := m.Run()
|
||||
if status != 0 {
|
||||
os.Exit(status)
|
||||
}
|
||||
if areHandshakesRunning() {
|
||||
fmt.Println("stray handshake goroutines found")
|
||||
os.Exit(1)
|
||||
}
|
||||
if areTransportsRunning() {
|
||||
fmt.Println("stray transport goroutines found")
|
||||
os.Exit(1)
|
||||
}
|
||||
os.Exit(status)
|
||||
}
|
||||
|
||||
func scaleDuration(d time.Duration) time.Duration {
|
||||
@@ -301,6 +237,17 @@ func (t *packetCounter) getRcvdLongHeaderPackets() []packet {
|
||||
return t.rcvdLongHdr
|
||||
}
|
||||
|
||||
func (t *packetCounter) getRcvd0RTTPacketNumbers() []protocol.PacketNumber {
|
||||
packets := t.getRcvdLongHeaderPackets()
|
||||
var zeroRTTPackets []protocol.PacketNumber
|
||||
for _, p := range packets {
|
||||
if p.hdr.Type == protocol.PacketType0RTT {
|
||||
zeroRTTPackets = append(zeroRTTPackets, p.hdr.PacketNumber)
|
||||
}
|
||||
}
|
||||
return zeroRTTPackets
|
||||
}
|
||||
|
||||
func (t *packetCounter) getRcvdShortHeaderPackets() []shortHeaderPacket {
|
||||
<-t.closed
|
||||
return t.rcvdShortHdr
|
||||
@@ -345,7 +292,25 @@ func (r *readerWithTimeout) Read(p []byte) (n int, err error) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestSelf(t *testing.T) {
|
||||
RegisterFailHandler(Fail)
|
||||
RunSpecs(t, "Self integration tests")
|
||||
func randomDuration(min, max time.Duration) time.Duration {
|
||||
return min + time.Duration(rand.Int63n(int64(max-min)))
|
||||
}
|
||||
|
||||
// contains0RTTPacket says if a packet contains a 0-RTT long header packet.
|
||||
// It correctly handles coalesced packets.
|
||||
func contains0RTTPacket(data []byte) bool {
|
||||
for len(data) > 0 {
|
||||
if !wire.IsLongHeaderPacket(data[0]) {
|
||||
return false
|
||||
}
|
||||
hdr, _, rest, err := wire.ParsePacket(data)
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
if hdr.Type == protocol.PacketType0RTT {
|
||||
return true
|
||||
}
|
||||
data = rest
|
||||
}
|
||||
return false
|
||||
}
|
||||
@@ -6,121 +6,119 @@ import (
|
||||
"fmt"
|
||||
"net"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/quic-go/quic-go"
|
||||
quicproxy "github.com/quic-go/quic-go/integrationtests/tools/proxy"
|
||||
|
||||
. "github.com/onsi/ginkgo/v2"
|
||||
. "github.com/onsi/gomega"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
var _ = Describe("Stateless Resets", func() {
|
||||
connIDLens := []int{0, 10}
|
||||
func TestStatelessResets(t *testing.T) {
|
||||
t.Run("0 byte connection IDs", func(t *testing.T) {
|
||||
testStatelessReset(t, 0)
|
||||
})
|
||||
t.Run("10 byte connection IDs", func(t *testing.T) {
|
||||
testStatelessReset(t, 10)
|
||||
})
|
||||
}
|
||||
|
||||
for i := range connIDLens {
|
||||
connIDLen := connIDLens[i]
|
||||
func testStatelessReset(t *testing.T, connIDLen int) {
|
||||
var statelessResetKey quic.StatelessResetKey
|
||||
rand.Read(statelessResetKey[:])
|
||||
|
||||
It(fmt.Sprintf("sends and recognizes stateless resets, for %d byte connection IDs", connIDLen), func() {
|
||||
var statelessResetKey quic.StatelessResetKey
|
||||
rand.Read(statelessResetKey[:])
|
||||
|
||||
c, err := net.ListenUDP("udp", nil)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
tr := &quic.Transport{
|
||||
Conn: c,
|
||||
StatelessResetKey: &statelessResetKey,
|
||||
ConnectionIDLength: connIDLen,
|
||||
}
|
||||
defer tr.Close()
|
||||
ln, err := tr.Listen(getTLSConfig(), getQuicConfig(nil))
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
serverPort := ln.Addr().(*net.UDPAddr).Port
|
||||
|
||||
closeServer := make(chan struct{})
|
||||
|
||||
go func() {
|
||||
defer GinkgoRecover()
|
||||
conn, err := ln.Accept(context.Background())
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
str, err := conn.OpenStream()
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
_, err = str.Write([]byte("foobar"))
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
<-closeServer
|
||||
Expect(ln.Close()).To(Succeed())
|
||||
Expect(tr.Close()).To(Succeed())
|
||||
}()
|
||||
|
||||
var drop atomic.Bool
|
||||
proxy, err := quicproxy.NewQuicProxy("localhost:0", &quicproxy.Opts{
|
||||
RemoteAddr: fmt.Sprintf("localhost:%d", serverPort),
|
||||
DropPacket: func(quicproxy.Direction, []byte) bool {
|
||||
return drop.Load()
|
||||
},
|
||||
})
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
defer proxy.Close()
|
||||
|
||||
addr, err := net.ResolveUDPAddr("udp", "localhost:0")
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
udpConn, err := net.ListenUDP("udp", addr)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
defer udpConn.Close()
|
||||
cl := &quic.Transport{
|
||||
Conn: udpConn,
|
||||
ConnectionIDLength: connIDLen,
|
||||
}
|
||||
defer cl.Close()
|
||||
conn, err := cl.Dial(
|
||||
context.Background(),
|
||||
&net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: proxy.LocalPort()},
|
||||
getTLSClientConfig(),
|
||||
getQuicConfig(&quic.Config{MaxIdleTimeout: 2 * time.Second}),
|
||||
)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
str, err := conn.AcceptStream(context.Background())
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
data := make([]byte, 6)
|
||||
_, err = str.Read(data)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(data).To(Equal([]byte("foobar")))
|
||||
|
||||
// make sure that the CONNECTION_CLOSE is dropped
|
||||
drop.Store(true)
|
||||
close(closeServer)
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
// We need to create a new Transport here, since the old one is still sending out
|
||||
// CONNECTION_CLOSE packets for (recently) closed connections).
|
||||
tr2 := &quic.Transport{
|
||||
Conn: c,
|
||||
ConnectionIDLength: connIDLen,
|
||||
StatelessResetKey: &statelessResetKey,
|
||||
}
|
||||
defer tr2.Close()
|
||||
ln2, err := tr2.Listen(getTLSConfig(), getQuicConfig(nil))
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
drop.Store(false)
|
||||
|
||||
acceptStopped := make(chan struct{})
|
||||
go func() {
|
||||
defer GinkgoRecover()
|
||||
_, err := ln2.Accept(context.Background())
|
||||
Expect(err).To(HaveOccurred())
|
||||
close(acceptStopped)
|
||||
}()
|
||||
|
||||
// Trigger something (not too small) to be sent, so that we receive the stateless reset.
|
||||
// If the client already sent another packet, it might already have received a packet.
|
||||
_, serr := str.Write([]byte("Lorem ipsum dolor sit amet."))
|
||||
if serr == nil {
|
||||
_, serr = str.Read([]byte{0})
|
||||
}
|
||||
Expect(serr).To(HaveOccurred())
|
||||
Expect(serr).To(BeAssignableToTypeOf(&quic.StatelessResetError{}))
|
||||
Expect(ln2.Close()).To(Succeed())
|
||||
Eventually(acceptStopped).Should(BeClosed())
|
||||
})
|
||||
c, err := net.ListenUDP("udp", nil)
|
||||
require.NoError(t, err)
|
||||
tr := &quic.Transport{
|
||||
Conn: c,
|
||||
StatelessResetKey: &statelessResetKey,
|
||||
ConnectionIDLength: connIDLen,
|
||||
}
|
||||
})
|
||||
defer tr.Close()
|
||||
ln, err := tr.Listen(getTLSConfig(), getQuicConfig(nil))
|
||||
require.NoError(t, err)
|
||||
serverPort := ln.Addr().(*net.UDPAddr).Port
|
||||
|
||||
serverErr := make(chan error, 1)
|
||||
go func() {
|
||||
conn, err := ln.Accept(context.Background())
|
||||
if err != nil {
|
||||
serverErr <- err
|
||||
return
|
||||
}
|
||||
str, err := conn.OpenStream()
|
||||
if err != nil {
|
||||
serverErr <- err
|
||||
return
|
||||
}
|
||||
_, err = str.Write([]byte("foobar"))
|
||||
if err != nil {
|
||||
serverErr <- err
|
||||
return
|
||||
}
|
||||
close(serverErr)
|
||||
}()
|
||||
|
||||
var drop atomic.Bool
|
||||
proxy, err := quicproxy.NewQuicProxy("localhost:0", &quicproxy.Opts{
|
||||
RemoteAddr: fmt.Sprintf("localhost:%d", serverPort),
|
||||
DropPacket: func(quicproxy.Direction, []byte) bool {
|
||||
return drop.Load()
|
||||
},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
defer proxy.Close()
|
||||
|
||||
addr, err := net.ResolveUDPAddr("udp", "localhost:0")
|
||||
require.NoError(t, err)
|
||||
udpConn, err := net.ListenUDP("udp", addr)
|
||||
require.NoError(t, err)
|
||||
defer udpConn.Close()
|
||||
cl := &quic.Transport{
|
||||
Conn: udpConn,
|
||||
ConnectionIDLength: connIDLen,
|
||||
}
|
||||
defer cl.Close()
|
||||
conn, err := cl.Dial(
|
||||
context.Background(),
|
||||
&net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: proxy.LocalPort()},
|
||||
getTLSClientConfig(),
|
||||
getQuicConfig(&quic.Config{MaxIdleTimeout: 2 * time.Second}),
|
||||
)
|
||||
require.NoError(t, err)
|
||||
str, err := conn.AcceptStream(context.Background())
|
||||
require.NoError(t, err)
|
||||
data := make([]byte, 6)
|
||||
_, err = str.Read(data)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, []byte("foobar"), data)
|
||||
|
||||
// make sure that the CONNECTION_CLOSE is dropped
|
||||
drop.Store(true)
|
||||
require.NoError(t, ln.Close())
|
||||
require.NoError(t, tr.Close())
|
||||
require.NoError(t, <-serverErr)
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
// We need to create a new Transport here, since the old one is still sending out
|
||||
// CONNECTION_CLOSE packets for (recently) closed connections).
|
||||
tr2 := &quic.Transport{
|
||||
Conn: c,
|
||||
ConnectionIDLength: connIDLen,
|
||||
StatelessResetKey: &statelessResetKey,
|
||||
}
|
||||
defer tr2.Close()
|
||||
ln2, err := tr2.Listen(getTLSConfig(), getQuicConfig(nil))
|
||||
require.NoError(t, err)
|
||||
drop.Store(false)
|
||||
|
||||
// Trigger something (not too small) to be sent, so that we receive the stateless reset.
|
||||
// If the client already sent another packet, it might already have received a packet.
|
||||
_, serr := str.Write([]byte("Lorem ipsum dolor sit amet."))
|
||||
if serr == nil {
|
||||
_, serr = str.Read([]byte{0})
|
||||
}
|
||||
require.Error(t, serr)
|
||||
require.IsType(t, &quic.StatelessResetError{}, serr)
|
||||
require.NoError(t, ln2.Close())
|
||||
}
|
||||
|
||||
@@ -1,156 +1,321 @@
|
||||
package self_test
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
. "github.com/onsi/ginkgo/v2"
|
||||
. "github.com/onsi/gomega"
|
||||
"github.com/quic-go/quic-go"
|
||||
|
||||
"golang.org/x/sync/errgroup"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
var _ = Describe("Bidirectional streams", func() {
|
||||
const numStreams = 300
|
||||
func TestBidirectionalStreamMultiplexing(t *testing.T) {
|
||||
const numStreams = 75
|
||||
|
||||
var (
|
||||
server *quic.Listener
|
||||
serverAddr string
|
||||
)
|
||||
|
||||
BeforeEach(func() {
|
||||
var err error
|
||||
server, err = quic.ListenAddr("localhost:0", getTLSConfig(), getQuicConfig(nil))
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
serverAddr = fmt.Sprintf("localhost:%d", server.Addr().(*net.UDPAddr).Port)
|
||||
})
|
||||
|
||||
AfterEach(func() {
|
||||
server.Close()
|
||||
})
|
||||
|
||||
runSendingPeer := func(conn quic.Connection) {
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(numStreams)
|
||||
runSendingPeer := func(conn quic.Connection) error {
|
||||
g := new(errgroup.Group)
|
||||
for i := 0; i < numStreams; i++ {
|
||||
str, err := conn.OpenStreamSync(context.Background())
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
data := GeneratePRData(25 * i)
|
||||
go func() {
|
||||
defer GinkgoRecover()
|
||||
_, err := str.Write(data)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(str.Close()).To(Succeed())
|
||||
}()
|
||||
go func() {
|
||||
defer GinkgoRecover()
|
||||
defer wg.Done()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
data := GeneratePRData(50 * i)
|
||||
g.Go(func() error {
|
||||
if _, err := str.Write(data); err != nil {
|
||||
return err
|
||||
}
|
||||
return str.Close()
|
||||
})
|
||||
g.Go(func() error {
|
||||
dataRead, err := io.ReadAll(str)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(dataRead).To(Equal(data))
|
||||
}()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if !bytes.Equal(dataRead, data) {
|
||||
return fmt.Errorf("data mismatch: %q != %q", dataRead, data)
|
||||
}
|
||||
return nil
|
||||
})
|
||||
}
|
||||
wg.Wait()
|
||||
return g.Wait()
|
||||
}
|
||||
|
||||
runReceivingPeer := func(conn quic.Connection) {
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(numStreams)
|
||||
runReceivingPeer := func(conn quic.Connection) error {
|
||||
g := new(errgroup.Group)
|
||||
for i := 0; i < numStreams; i++ {
|
||||
str, err := conn.AcceptStream(context.Background())
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
go func() {
|
||||
defer GinkgoRecover()
|
||||
defer wg.Done()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
g.Go(func() error {
|
||||
// shouldn't use io.Copy here
|
||||
// we should read from the stream as early as possible, to free flow control credit
|
||||
data, err := io.ReadAll(str)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
_, err = str.Write(data)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(str.Close()).To(Succeed())
|
||||
}()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if _, err := str.Write(data); err != nil {
|
||||
return err
|
||||
}
|
||||
return str.Close()
|
||||
})
|
||||
}
|
||||
wg.Wait()
|
||||
return g.Wait()
|
||||
}
|
||||
|
||||
It(fmt.Sprintf("client opening %d streams to a server", numStreams), func() {
|
||||
var conn quic.Connection
|
||||
go func() {
|
||||
defer GinkgoRecover()
|
||||
var err error
|
||||
conn, err = server.Accept(context.Background())
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
runReceivingPeer(conn)
|
||||
}()
|
||||
t.Run("client -> server", func(t *testing.T) {
|
||||
ln, err := quic.ListenAddr(
|
||||
"localhost:0",
|
||||
getTLSConfig(),
|
||||
getQuicConfig(&quic.Config{
|
||||
MaxIncomingStreams: 10,
|
||||
InitialStreamReceiveWindow: 10000,
|
||||
InitialConnectionReceiveWindow: 5000,
|
||||
}),
|
||||
)
|
||||
require.NoError(t, err)
|
||||
defer ln.Close()
|
||||
|
||||
client, err := quic.DialAddr(
|
||||
context.Background(),
|
||||
serverAddr,
|
||||
fmt.Sprintf("localhost:%d", ln.Addr().(*net.UDPAddr).Port),
|
||||
getTLSClientConfig(),
|
||||
getQuicConfig(nil),
|
||||
getQuicConfig(&quic.Config{InitialConnectionReceiveWindow: 2000}),
|
||||
)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
runSendingPeer(client)
|
||||
require.NoError(t, err)
|
||||
|
||||
conn, err := ln.Accept(context.Background())
|
||||
require.NoError(t, err)
|
||||
|
||||
errChan := make(chan error, 1)
|
||||
go func() { errChan <- runReceivingPeer(conn) }()
|
||||
require.NoError(t, runSendingPeer(client))
|
||||
client.CloseWithError(0, "")
|
||||
<-conn.Context().Done()
|
||||
|
||||
select {
|
||||
case err := <-errChan:
|
||||
require.NoError(t, err)
|
||||
case <-time.After(time.Second):
|
||||
require.Fail(t, "timeout")
|
||||
}
|
||||
select {
|
||||
case <-conn.Context().Done():
|
||||
case <-time.After(time.Second):
|
||||
require.Fail(t, "timeout")
|
||||
}
|
||||
})
|
||||
|
||||
It(fmt.Sprintf("server opening %d streams to a client", numStreams), func() {
|
||||
t.Run("client <-> server", func(t *testing.T) {
|
||||
ln, err := quic.ListenAddr(
|
||||
"localhost:0",
|
||||
getTLSConfig(),
|
||||
getQuicConfig(&quic.Config{
|
||||
MaxIncomingStreams: 30,
|
||||
InitialStreamReceiveWindow: 25000,
|
||||
InitialConnectionReceiveWindow: 50000,
|
||||
}),
|
||||
)
|
||||
require.NoError(t, err)
|
||||
defer ln.Close()
|
||||
|
||||
client, err := quic.DialAddr(
|
||||
context.Background(),
|
||||
fmt.Sprintf("localhost:%d", ln.Addr().(*net.UDPAddr).Port),
|
||||
getTLSClientConfig(),
|
||||
getQuicConfig(&quic.Config{InitialConnectionReceiveWindow: 2000}),
|
||||
)
|
||||
require.NoError(t, err)
|
||||
|
||||
conn, err := ln.Accept(context.Background())
|
||||
require.NoError(t, err)
|
||||
|
||||
errChan1 := make(chan error, 1)
|
||||
errChan2 := make(chan error, 1)
|
||||
errChan3 := make(chan error, 1)
|
||||
errChan4 := make(chan error, 1)
|
||||
|
||||
go func() { errChan1 <- runReceivingPeer(conn) }()
|
||||
go func() { errChan2 <- runSendingPeer(conn) }()
|
||||
go func() { errChan3 <- runReceivingPeer(client) }()
|
||||
go func() { errChan4 <- runSendingPeer(client) }()
|
||||
|
||||
for _, ch := range []chan error{errChan1, errChan2, errChan3, errChan4} {
|
||||
select {
|
||||
case err := <-ch:
|
||||
require.NoError(t, err)
|
||||
case <-time.After(time.Second):
|
||||
require.Fail(t, "timeout")
|
||||
}
|
||||
}
|
||||
|
||||
client.CloseWithError(0, "")
|
||||
select {
|
||||
case <-conn.Context().Done():
|
||||
case <-time.After(time.Second):
|
||||
require.Fail(t, "timeout")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestUnidirectionalStreams(t *testing.T) {
|
||||
const numStreams = 500
|
||||
|
||||
dataForStream := func(id uint64) []byte { return GeneratePRData(10 * int(id)) }
|
||||
|
||||
runSendingPeer := func(conn quic.Connection) error {
|
||||
g := new(errgroup.Group)
|
||||
for i := 0; i < numStreams; i++ {
|
||||
str, err := conn.OpenUniStreamSync(context.Background())
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
g.Go(func() error {
|
||||
if _, err := str.Write(dataForStream(uint64(str.StreamID()))); err != nil {
|
||||
return err
|
||||
}
|
||||
return str.Close()
|
||||
})
|
||||
}
|
||||
return g.Wait()
|
||||
}
|
||||
|
||||
runReceivingPeer := func(conn quic.Connection) error {
|
||||
g := new(errgroup.Group)
|
||||
for i := 0; i < numStreams; i++ {
|
||||
str, err := conn.AcceptUniStream(context.Background())
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
g.Go(func() error {
|
||||
data, err := io.ReadAll(str)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if !bytes.Equal(data, dataForStream(uint64(str.StreamID()))) {
|
||||
return fmt.Errorf("data mismatch")
|
||||
}
|
||||
return nil
|
||||
})
|
||||
}
|
||||
return g.Wait()
|
||||
}
|
||||
|
||||
t.Run("client -> server", func(t *testing.T) {
|
||||
ln, err := quic.ListenAddr(
|
||||
"localhost:0",
|
||||
getTLSConfig(),
|
||||
getQuicConfig(nil),
|
||||
)
|
||||
require.NoError(t, err)
|
||||
defer ln.Close()
|
||||
|
||||
client, err := quic.DialAddr(
|
||||
context.Background(),
|
||||
fmt.Sprintf("localhost:%d", ln.Addr().(*net.UDPAddr).Port),
|
||||
getTLSClientConfig(),
|
||||
getQuicConfig(nil),
|
||||
)
|
||||
require.NoError(t, err)
|
||||
|
||||
serverConn, err := ln.Accept(context.Background())
|
||||
require.NoError(t, err)
|
||||
errChan := make(chan error, 1)
|
||||
go func() { errChan <- runSendingPeer(client) }()
|
||||
require.NoError(t, runReceivingPeer(serverConn))
|
||||
serverConn.CloseWithError(0, "")
|
||||
|
||||
select {
|
||||
case err := <-errChan:
|
||||
require.NoError(t, err)
|
||||
case <-time.After(time.Second):
|
||||
require.Fail(t, "timeout")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("server -> client", func(t *testing.T) {
|
||||
ln, err := quic.ListenAddr(
|
||||
"localhost:0",
|
||||
getTLSConfig(),
|
||||
getQuicConfig(nil),
|
||||
)
|
||||
require.NoError(t, err)
|
||||
defer ln.Close()
|
||||
|
||||
client, err := quic.DialAddr(
|
||||
context.Background(),
|
||||
fmt.Sprintf("localhost:%d", ln.Addr().(*net.UDPAddr).Port),
|
||||
getTLSClientConfig(),
|
||||
getQuicConfig(nil),
|
||||
)
|
||||
require.NoError(t, err)
|
||||
|
||||
serverConn, err := ln.Accept(context.Background())
|
||||
require.NoError(t, err)
|
||||
errChan := make(chan error, 1)
|
||||
go func() { errChan <- runSendingPeer(serverConn) }()
|
||||
|
||||
require.NoError(t, runReceivingPeer(client))
|
||||
client.CloseWithError(0, "")
|
||||
|
||||
select {
|
||||
case err := <-errChan:
|
||||
require.NoError(t, err)
|
||||
case <-time.After(time.Second):
|
||||
require.Fail(t, "timeout")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("client <-> server", func(t *testing.T) {
|
||||
ln, err := quic.ListenAddr(
|
||||
"localhost:0",
|
||||
getTLSConfig(),
|
||||
getQuicConfig(nil),
|
||||
)
|
||||
require.NoError(t, err)
|
||||
defer ln.Close()
|
||||
|
||||
errChan1 := make(chan error, 1)
|
||||
errChan2 := make(chan error, 1)
|
||||
go func() {
|
||||
defer GinkgoRecover()
|
||||
conn, err := server.Accept(context.Background())
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
runSendingPeer(conn)
|
||||
conn.CloseWithError(0, "")
|
||||
conn, err := ln.Accept(context.Background())
|
||||
if err != nil {
|
||||
errChan1 <- err
|
||||
errChan2 <- err
|
||||
return
|
||||
}
|
||||
errChan1 <- runReceivingPeer(conn)
|
||||
errChan2 <- runSendingPeer(conn)
|
||||
}()
|
||||
|
||||
client, err := quic.DialAddr(
|
||||
context.Background(),
|
||||
serverAddr,
|
||||
fmt.Sprintf("localhost:%d", ln.Addr().(*net.UDPAddr).Port),
|
||||
getTLSClientConfig(),
|
||||
getQuicConfig(nil),
|
||||
)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
runReceivingPeer(client)
|
||||
Eventually(client.Context().Done()).Should(BeClosed())
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
It(fmt.Sprintf("client and server opening %d each and sending data to the peer", numStreams), func() {
|
||||
done1 := make(chan struct{})
|
||||
errChan3 := make(chan error, 1)
|
||||
go func() {
|
||||
defer GinkgoRecover()
|
||||
conn, err := server.Accept(context.Background())
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
defer GinkgoRecover()
|
||||
runReceivingPeer(conn)
|
||||
close(done)
|
||||
}()
|
||||
runSendingPeer(conn)
|
||||
<-done
|
||||
close(done1)
|
||||
errChan3 <- runSendingPeer(client)
|
||||
}()
|
||||
require.NoError(t, runReceivingPeer(client))
|
||||
|
||||
client, err := quic.DialAddr(
|
||||
context.Background(),
|
||||
serverAddr,
|
||||
getTLSClientConfig(),
|
||||
getQuicConfig(nil),
|
||||
)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
done2 := make(chan struct{})
|
||||
go func() {
|
||||
defer GinkgoRecover()
|
||||
runSendingPeer(client)
|
||||
close(done2)
|
||||
}()
|
||||
runReceivingPeer(client)
|
||||
<-done1
|
||||
<-done2
|
||||
for _, ch := range []chan error{errChan1, errChan2, errChan3} {
|
||||
select {
|
||||
case err := <-ch:
|
||||
require.NoError(t, err)
|
||||
case <-time.After(time.Second):
|
||||
require.Fail(t, "timeout")
|
||||
}
|
||||
}
|
||||
client.CloseWithError(0, "")
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
@@ -1,22 +1,366 @@
|
||||
package self_test
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
mrand "math/rand"
|
||||
"net"
|
||||
"runtime"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/quic-go/quic-go"
|
||||
quicproxy "github.com/quic-go/quic-go/integrationtests/tools/proxy"
|
||||
"github.com/quic-go/quic-go/internal/protocol"
|
||||
"github.com/quic-go/quic-go/logging"
|
||||
|
||||
. "github.com/onsi/ginkgo/v2"
|
||||
. "github.com/onsi/gomega"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func requireIdleTimeoutError(t *testing.T, err error) {
|
||||
t.Helper()
|
||||
require.Error(t, err)
|
||||
var idleTimeoutErr *quic.IdleTimeoutError
|
||||
require.ErrorAs(t, err, &idleTimeoutErr)
|
||||
require.True(t, idleTimeoutErr.Timeout())
|
||||
var nerr net.Error
|
||||
require.True(t, errors.As(err, &nerr))
|
||||
require.True(t, nerr.Timeout())
|
||||
}
|
||||
|
||||
func TestHandshakeIdleTimeout(t *testing.T) {
|
||||
conn, err := net.ListenUDP("udp", &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 0})
|
||||
require.NoError(t, err)
|
||||
defer conn.Close()
|
||||
|
||||
errChan := make(chan error, 1)
|
||||
go func() {
|
||||
_, err := quic.DialAddr(
|
||||
context.Background(),
|
||||
fmt.Sprintf("localhost:%d", conn.LocalAddr().(*net.UDPAddr).Port),
|
||||
getTLSClientConfig(),
|
||||
getQuicConfig(&quic.Config{HandshakeIdleTimeout: scaleDuration(50 * time.Millisecond)}),
|
||||
)
|
||||
errChan <- err
|
||||
}()
|
||||
select {
|
||||
case err := <-errChan:
|
||||
requireIdleTimeoutError(t, err)
|
||||
case <-time.After(5 * time.Second):
|
||||
t.Fatal("timeout waiting for dial error")
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandshakeTimeoutContext(t *testing.T) {
|
||||
conn, err := net.ListenUDP("udp", &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 0})
|
||||
require.NoError(t, err)
|
||||
defer conn.Close()
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Millisecond)
|
||||
defer cancel()
|
||||
errChan := make(chan error)
|
||||
go func() {
|
||||
_, err := quic.DialAddr(
|
||||
ctx,
|
||||
fmt.Sprintf("localhost:%d", conn.LocalAddr().(*net.UDPAddr).Port),
|
||||
getTLSClientConfig(),
|
||||
getQuicConfig(nil),
|
||||
)
|
||||
errChan <- err
|
||||
}()
|
||||
select {
|
||||
case err := <-errChan:
|
||||
require.ErrorIs(t, err, context.DeadlineExceeded)
|
||||
case <-time.After(5 * time.Second):
|
||||
t.Fatal("timeout waiting for dial error")
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandshakeTimeout0RTTContext(t *testing.T) {
|
||||
conn, err := net.ListenUDP("udp", &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 0})
|
||||
require.NoError(t, err)
|
||||
defer conn.Close()
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Millisecond)
|
||||
defer cancel()
|
||||
errChan := make(chan error)
|
||||
go func() {
|
||||
_, err := quic.DialAddrEarly(
|
||||
ctx,
|
||||
fmt.Sprintf("localhost:%d", conn.LocalAddr().(*net.UDPAddr).Port),
|
||||
getTLSClientConfig(),
|
||||
getQuicConfig(nil),
|
||||
)
|
||||
errChan <- err
|
||||
}()
|
||||
select {
|
||||
case err := <-errChan:
|
||||
require.ErrorIs(t, err, context.DeadlineExceeded)
|
||||
case <-time.After(5 * time.Second):
|
||||
t.Fatal("timeout waiting for dial error")
|
||||
}
|
||||
}
|
||||
|
||||
func TestIdleTimeout(t *testing.T) {
|
||||
idleTimeout := scaleDuration(200 * time.Millisecond)
|
||||
|
||||
server, err := quic.ListenAddr(
|
||||
"localhost:0",
|
||||
getTLSConfig(),
|
||||
getQuicConfig(&quic.Config{DisablePathMTUDiscovery: true}),
|
||||
)
|
||||
require.NoError(t, err)
|
||||
defer server.Close()
|
||||
|
||||
var drop atomic.Bool
|
||||
proxy, err := quicproxy.NewQuicProxy("localhost:0", &quicproxy.Opts{
|
||||
RemoteAddr: fmt.Sprintf("localhost:%d", server.Addr().(*net.UDPAddr).Port),
|
||||
DropPacket: func(quicproxy.Direction, []byte) bool {
|
||||
return drop.Load()
|
||||
},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
defer proxy.Close()
|
||||
|
||||
conn, err := quic.DialAddr(
|
||||
context.Background(),
|
||||
fmt.Sprintf("localhost:%d", proxy.LocalPort()),
|
||||
getTLSClientConfig(),
|
||||
getQuicConfig(&quic.Config{DisablePathMTUDiscovery: true, MaxIdleTimeout: idleTimeout}),
|
||||
)
|
||||
require.NoError(t, err)
|
||||
|
||||
serverConn, err := server.Accept(context.Background())
|
||||
require.NoError(t, err)
|
||||
str, err := serverConn.OpenStream()
|
||||
require.NoError(t, err)
|
||||
_, err = str.Write([]byte("foobar"))
|
||||
require.NoError(t, err)
|
||||
|
||||
strIn, err := conn.AcceptStream(context.Background())
|
||||
require.NoError(t, err)
|
||||
strOut, err := conn.OpenStream()
|
||||
require.NoError(t, err)
|
||||
_, err = strIn.Read(make([]byte, 6))
|
||||
require.NoError(t, err)
|
||||
|
||||
drop.Store(true)
|
||||
time.Sleep(2 * idleTimeout)
|
||||
_, err = strIn.Write([]byte("test"))
|
||||
requireIdleTimeoutError(t, err)
|
||||
_, err = strIn.Read([]byte{0})
|
||||
requireIdleTimeoutError(t, err)
|
||||
_, err = strOut.Write([]byte("test"))
|
||||
requireIdleTimeoutError(t, err)
|
||||
_, err = strOut.Read([]byte{0})
|
||||
requireIdleTimeoutError(t, err)
|
||||
_, err = conn.OpenStream()
|
||||
requireIdleTimeoutError(t, err)
|
||||
_, err = conn.OpenUniStream()
|
||||
requireIdleTimeoutError(t, err)
|
||||
_, err = conn.AcceptStream(context.Background())
|
||||
requireIdleTimeoutError(t, err)
|
||||
_, err = conn.AcceptUniStream(context.Background())
|
||||
requireIdleTimeoutError(t, err)
|
||||
}
|
||||
|
||||
func TestKeepAlive(t *testing.T) {
|
||||
idleTimeout := scaleDuration(150 * time.Millisecond)
|
||||
|
||||
server, err := quic.ListenAddr(
|
||||
"localhost:0",
|
||||
getTLSConfig(),
|
||||
getQuicConfig(&quic.Config{DisablePathMTUDiscovery: true}),
|
||||
)
|
||||
require.NoError(t, err)
|
||||
defer server.Close()
|
||||
|
||||
var drop atomic.Bool
|
||||
proxy, err := quicproxy.NewQuicProxy("localhost:0", &quicproxy.Opts{
|
||||
RemoteAddr: fmt.Sprintf("localhost:%d", server.Addr().(*net.UDPAddr).Port),
|
||||
DropPacket: func(quicproxy.Direction, []byte) bool { return drop.Load() },
|
||||
})
|
||||
require.NoError(t, err)
|
||||
defer proxy.Close()
|
||||
|
||||
conn, err := quic.DialAddr(
|
||||
context.Background(),
|
||||
fmt.Sprintf("localhost:%d", proxy.LocalPort()),
|
||||
getTLSClientConfig(),
|
||||
getQuicConfig(&quic.Config{
|
||||
MaxIdleTimeout: idleTimeout,
|
||||
KeepAlivePeriod: idleTimeout / 2,
|
||||
DisablePathMTUDiscovery: true,
|
||||
}),
|
||||
)
|
||||
require.NoError(t, err)
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
|
||||
defer cancel()
|
||||
serverConn, err := server.Accept(ctx)
|
||||
require.NoError(t, err)
|
||||
|
||||
// wait longer than the idle timeout
|
||||
time.Sleep(3 * idleTimeout)
|
||||
str, err := conn.OpenUniStream()
|
||||
require.NoError(t, err)
|
||||
_, err = str.Write([]byte("foobar"))
|
||||
require.NoError(t, err)
|
||||
|
||||
// verify connection is still alive
|
||||
select {
|
||||
case <-serverConn.Context().Done():
|
||||
t.Fatal("server connection closed unexpectedly")
|
||||
default:
|
||||
}
|
||||
|
||||
// idle timeout will still kick in if PINGs are dropped
|
||||
drop.Store(true)
|
||||
time.Sleep(2 * idleTimeout)
|
||||
_, err = str.Write([]byte("foobar"))
|
||||
var nerr net.Error
|
||||
require.True(t, errors.As(err, &nerr))
|
||||
require.True(t, nerr.Timeout())
|
||||
|
||||
// can't rely on the server connection closing, since we impose a minimum idle timeout of 5s,
|
||||
// see https://github.com/quic-go/quic-go/issues/4751
|
||||
serverConn.CloseWithError(0, "")
|
||||
}
|
||||
|
||||
func TestTimeoutAfterInactivity(t *testing.T) {
|
||||
idleTimeout := scaleDuration(150 * time.Millisecond)
|
||||
|
||||
server, err := quic.ListenAddr(
|
||||
"localhost:0",
|
||||
getTLSConfig(),
|
||||
getQuicConfig(&quic.Config{DisablePathMTUDiscovery: true}),
|
||||
)
|
||||
require.NoError(t, err)
|
||||
defer server.Close()
|
||||
|
||||
counter, tr := newPacketTracer()
|
||||
conn, err := quic.DialAddr(
|
||||
context.Background(),
|
||||
fmt.Sprintf("localhost:%d", server.Addr().(*net.UDPAddr).Port),
|
||||
getTLSClientConfig(),
|
||||
getQuicConfig(&quic.Config{
|
||||
MaxIdleTimeout: idleTimeout,
|
||||
Tracer: newTracer(tr),
|
||||
DisablePathMTUDiscovery: true,
|
||||
}),
|
||||
)
|
||||
require.NoError(t, err)
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
|
||||
defer cancel()
|
||||
serverConn, err := server.Accept(ctx)
|
||||
require.NoError(t, err)
|
||||
|
||||
ctx, cancel = context.WithTimeout(context.Background(), 2*idleTimeout)
|
||||
defer cancel()
|
||||
_, err = conn.AcceptStream(ctx)
|
||||
requireIdleTimeoutError(t, err)
|
||||
|
||||
var lastAckElicitingPacketSentAt time.Time
|
||||
for _, p := range counter.getSentShortHeaderPackets() {
|
||||
var hasAckElicitingFrame bool
|
||||
for _, f := range p.frames {
|
||||
if _, ok := f.(*logging.AckFrame); ok {
|
||||
continue
|
||||
}
|
||||
hasAckElicitingFrame = true
|
||||
break
|
||||
}
|
||||
if hasAckElicitingFrame {
|
||||
lastAckElicitingPacketSentAt = p.time
|
||||
}
|
||||
}
|
||||
rcvdPackets := counter.getRcvdShortHeaderPackets()
|
||||
lastPacketRcvdAt := rcvdPackets[len(rcvdPackets)-1].time
|
||||
// We're ignoring here that only the first ack-eliciting packet sent resets the idle timeout.
|
||||
// This is ok since we're dealing with a lossless connection here,
|
||||
// and we'd expect to receive an ACK for additional other ack-eliciting packet sent.
|
||||
timeSinceLastAckEliciting := time.Since(lastAckElicitingPacketSentAt)
|
||||
timeSinceLastRcvd := time.Since(lastPacketRcvdAt)
|
||||
maxDuration := max(timeSinceLastAckEliciting, timeSinceLastRcvd)
|
||||
require.GreaterOrEqual(t, maxDuration, idleTimeout)
|
||||
require.Less(t, maxDuration, idleTimeout*6/5)
|
||||
|
||||
select {
|
||||
case <-serverConn.Context().Done():
|
||||
t.Fatal("server connection closed unexpectedly")
|
||||
default:
|
||||
}
|
||||
|
||||
serverConn.CloseWithError(0, "")
|
||||
}
|
||||
|
||||
func TestTimeoutAfterSendingPacket(t *testing.T) {
|
||||
if runtime.GOOS == "windows" {
|
||||
t.Skip("This test is flaky on Windows due to low timer precision.")
|
||||
}
|
||||
idleTimeout := scaleDuration(150 * time.Millisecond)
|
||||
|
||||
server, err := quic.ListenAddr(
|
||||
"localhost:0",
|
||||
getTLSConfig(),
|
||||
getQuicConfig(&quic.Config{DisablePathMTUDiscovery: true}),
|
||||
)
|
||||
require.NoError(t, err)
|
||||
defer server.Close()
|
||||
|
||||
var drop atomic.Bool
|
||||
proxy, err := quicproxy.NewQuicProxy("localhost:0", &quicproxy.Opts{
|
||||
RemoteAddr: fmt.Sprintf("localhost:%d", server.Addr().(*net.UDPAddr).Port),
|
||||
DropPacket: func(d quicproxy.Direction, _ []byte) bool { return d == quicproxy.DirectionOutgoing && drop.Load() },
|
||||
})
|
||||
require.NoError(t, err)
|
||||
defer proxy.Close()
|
||||
|
||||
conn, err := quic.DialAddr(
|
||||
context.Background(),
|
||||
fmt.Sprintf("localhost:%d", proxy.LocalPort()),
|
||||
getTLSClientConfig(),
|
||||
getQuicConfig(&quic.Config{MaxIdleTimeout: idleTimeout, DisablePathMTUDiscovery: true}),
|
||||
)
|
||||
require.NoError(t, err)
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
|
||||
defer cancel()
|
||||
serverConn, err := server.Accept(ctx)
|
||||
require.NoError(t, err)
|
||||
|
||||
// wait half the idle timeout, then send a packet
|
||||
time.Sleep(idleTimeout / 2)
|
||||
drop.Store(true)
|
||||
str, err := conn.OpenUniStream()
|
||||
require.NoError(t, err)
|
||||
_, err = str.Write([]byte("foobar"))
|
||||
require.NoError(t, err)
|
||||
|
||||
// now make sure that the idle timeout is based on this packet
|
||||
startTime := time.Now()
|
||||
ctx, cancel = context.WithTimeout(context.Background(), 2*idleTimeout)
|
||||
defer cancel()
|
||||
_, err = conn.AcceptStream(ctx)
|
||||
requireIdleTimeoutError(t, err)
|
||||
dur := time.Since(startTime)
|
||||
require.GreaterOrEqual(t, dur, idleTimeout)
|
||||
require.Less(t, dur, idleTimeout*12/10)
|
||||
|
||||
// Verify server connection is still open
|
||||
select {
|
||||
case <-serverConn.Context().Done():
|
||||
t.Fatal("server connection closed unexpectedly")
|
||||
default:
|
||||
}
|
||||
serverConn.CloseWithError(0, "")
|
||||
}
|
||||
|
||||
type faultyConn struct {
|
||||
net.PacketConn
|
||||
|
||||
@@ -41,490 +385,136 @@ func (c *faultyConn) WriteTo(p []byte, addr net.Addr) (int, error) {
|
||||
return 0, io.ErrClosedPipe
|
||||
}
|
||||
|
||||
var _ = Describe("Timeout tests", func() {
|
||||
checkTimeoutError := func(err error) {
|
||||
ExpectWithOffset(1, err).To(MatchError(&quic.IdleTimeoutError{}))
|
||||
nerr, ok := err.(net.Error)
|
||||
ExpectWithOffset(1, ok).To(BeTrue())
|
||||
ExpectWithOffset(1, nerr.Timeout()).To(BeTrue())
|
||||
func TestFaultyPacketConn(t *testing.T) {
|
||||
t.Run("client", func(t *testing.T) {
|
||||
testFaultyPacketConn(t, protocol.PerspectiveClient)
|
||||
})
|
||||
|
||||
t.Run("server", func(t *testing.T) {
|
||||
testFaultyPacketConn(t, protocol.PerspectiveServer)
|
||||
})
|
||||
}
|
||||
|
||||
func testFaultyPacketConn(t *testing.T, pers protocol.Perspective) {
|
||||
handshakeTimeout := scaleDuration(100 * time.Millisecond)
|
||||
|
||||
runServer := func(ln *quic.Listener) error {
|
||||
conn, err := ln.Accept(context.Background())
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
str, err := conn.OpenUniStream()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer str.Close()
|
||||
_, err = str.Write(PRData)
|
||||
return err
|
||||
}
|
||||
|
||||
It("returns net.Error timeout errors when dialing", func() {
|
||||
errChan := make(chan error)
|
||||
go func() {
|
||||
_, err := quic.DialAddr(
|
||||
context.Background(),
|
||||
"localhost:12345",
|
||||
getTLSClientConfig(),
|
||||
getQuicConfig(&quic.Config{HandshakeIdleTimeout: scaleDuration(50 * time.Millisecond)}),
|
||||
)
|
||||
errChan <- err
|
||||
}()
|
||||
var err error
|
||||
Eventually(errChan).Should(Receive(&err))
|
||||
checkTimeoutError(err)
|
||||
})
|
||||
runClient := func(conn quic.Connection) error {
|
||||
str, err := conn.AcceptUniStream(context.Background())
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
data, err := io.ReadAll(str)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if !bytes.Equal(data, PRData) {
|
||||
return fmt.Errorf("wrong data: %q vs %q", data, PRData)
|
||||
}
|
||||
return conn.CloseWithError(0, "done")
|
||||
}
|
||||
|
||||
It("returns the context error when the context expires", func() {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Millisecond)
|
||||
defer cancel()
|
||||
errChan := make(chan error)
|
||||
go func() {
|
||||
_, err := quic.DialAddr(
|
||||
ctx,
|
||||
"localhost:12345",
|
||||
getTLSClientConfig(),
|
||||
getQuicConfig(nil),
|
||||
)
|
||||
errChan <- err
|
||||
}()
|
||||
var err error
|
||||
Eventually(errChan).Should(Receive(&err))
|
||||
// This is not a net.Error timeout error
|
||||
Expect(err).To(MatchError(context.DeadlineExceeded))
|
||||
})
|
||||
var cconn, sconn net.PacketConn
|
||||
var err error
|
||||
cconn, err = net.ListenUDP("udp", &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 0})
|
||||
require.NoError(t, err)
|
||||
defer cconn.Close()
|
||||
sconn, err = net.ListenUDP("udp", &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 0})
|
||||
require.NoError(t, err)
|
||||
defer sconn.Close()
|
||||
|
||||
It("returns the context error when the context expires with 0RTT enabled", func() {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Millisecond)
|
||||
defer cancel()
|
||||
errChan := make(chan error)
|
||||
go func() {
|
||||
_, err := quic.DialAddrEarly(
|
||||
ctx,
|
||||
"localhost:12345",
|
||||
getTLSClientConfig(),
|
||||
getQuicConfig(nil),
|
||||
)
|
||||
errChan <- err
|
||||
}()
|
||||
var err error
|
||||
Eventually(errChan).Should(Receive(&err))
|
||||
// This is not a net.Error timeout error
|
||||
Expect(err).To(MatchError(context.DeadlineExceeded))
|
||||
})
|
||||
maxPackets := mrand.Int31n(25)
|
||||
t.Logf("blocking %s's connection after %d packets", pers, maxPackets)
|
||||
switch pers {
|
||||
case protocol.PerspectiveClient:
|
||||
cconn = &faultyConn{PacketConn: cconn, MaxPackets: maxPackets}
|
||||
case protocol.PerspectiveServer:
|
||||
sconn = &faultyConn{PacketConn: sconn, MaxPackets: maxPackets}
|
||||
}
|
||||
|
||||
It("returns net.Error timeout errors when an idle timeout occurs", func() {
|
||||
const idleTimeout = 500 * time.Millisecond
|
||||
ln, err := quic.Listen(
|
||||
sconn,
|
||||
getTLSConfig(),
|
||||
getQuicConfig(&quic.Config{
|
||||
HandshakeIdleTimeout: handshakeTimeout,
|
||||
MaxIdleTimeout: handshakeTimeout,
|
||||
KeepAlivePeriod: handshakeTimeout / 2,
|
||||
DisablePathMTUDiscovery: true,
|
||||
}),
|
||||
)
|
||||
require.NoError(t, err)
|
||||
defer ln.Close()
|
||||
|
||||
server, err := quic.ListenAddr(
|
||||
"localhost:0",
|
||||
getTLSConfig(),
|
||||
getQuicConfig(&quic.Config{DisablePathMTUDiscovery: true}),
|
||||
)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
defer server.Close()
|
||||
serverErrChan := make(chan error, 1)
|
||||
go func() { serverErrChan <- runServer(ln) }()
|
||||
|
||||
go func() {
|
||||
defer GinkgoRecover()
|
||||
conn, err := server.Accept(context.Background())
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
str, err := conn.OpenStream()
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
_, err = str.Write([]byte("foobar"))
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
}()
|
||||
|
||||
var drop atomic.Bool
|
||||
proxy, err := quicproxy.NewQuicProxy("localhost:0", &quicproxy.Opts{
|
||||
RemoteAddr: fmt.Sprintf("localhost:%d", server.Addr().(*net.UDPAddr).Port),
|
||||
DropPacket: func(quicproxy.Direction, []byte) bool {
|
||||
return drop.Load()
|
||||
},
|
||||
})
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
defer proxy.Close()
|
||||
|
||||
conn, err := quic.DialAddr(
|
||||
clientErrChan := make(chan error, 1)
|
||||
go func() {
|
||||
conn, err := quic.Dial(
|
||||
context.Background(),
|
||||
fmt.Sprintf("localhost:%d", proxy.LocalPort()),
|
||||
getTLSClientConfig(),
|
||||
getQuicConfig(&quic.Config{DisablePathMTUDiscovery: true, MaxIdleTimeout: idleTimeout}),
|
||||
)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
strIn, err := conn.AcceptStream(context.Background())
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
strOut, err := conn.OpenStream()
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
_, err = strIn.Read(make([]byte, 6))
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
||||
drop.Store(true)
|
||||
time.Sleep(2 * idleTimeout)
|
||||
_, err = strIn.Write([]byte("test"))
|
||||
checkTimeoutError(err)
|
||||
_, err = strIn.Read([]byte{0})
|
||||
checkTimeoutError(err)
|
||||
_, err = strOut.Write([]byte("test"))
|
||||
checkTimeoutError(err)
|
||||
_, err = strOut.Read([]byte{0})
|
||||
checkTimeoutError(err)
|
||||
_, err = conn.OpenStream()
|
||||
checkTimeoutError(err)
|
||||
_, err = conn.OpenUniStream()
|
||||
checkTimeoutError(err)
|
||||
_, err = conn.AcceptStream(context.Background())
|
||||
checkTimeoutError(err)
|
||||
_, err = conn.AcceptUniStream(context.Background())
|
||||
checkTimeoutError(err)
|
||||
})
|
||||
|
||||
Context("timing out at the right time", func() {
|
||||
var idleTimeout time.Duration
|
||||
|
||||
BeforeEach(func() {
|
||||
idleTimeout = scaleDuration(500 * time.Millisecond)
|
||||
})
|
||||
|
||||
It("times out after inactivity", func() {
|
||||
server, err := quic.ListenAddr(
|
||||
"localhost:0",
|
||||
getTLSConfig(),
|
||||
getQuicConfig(&quic.Config{DisablePathMTUDiscovery: true}),
|
||||
)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
defer server.Close()
|
||||
|
||||
serverConnChan := make(chan quic.Connection, 1)
|
||||
serverConnClosed := make(chan struct{})
|
||||
go func() {
|
||||
defer GinkgoRecover()
|
||||
conn, err := server.Accept(context.Background())
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
serverConnChan <- conn
|
||||
conn.AcceptStream(context.Background()) // blocks until the connection is closed
|
||||
close(serverConnClosed)
|
||||
}()
|
||||
|
||||
counter, tr := newPacketTracer()
|
||||
conn, err := quic.DialAddr(
|
||||
context.Background(),
|
||||
fmt.Sprintf("localhost:%d", server.Addr().(*net.UDPAddr).Port),
|
||||
getTLSClientConfig(),
|
||||
getQuicConfig(&quic.Config{
|
||||
MaxIdleTimeout: idleTimeout,
|
||||
Tracer: newTracer(tr),
|
||||
DisablePathMTUDiscovery: true,
|
||||
}),
|
||||
)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
defer GinkgoRecover()
|
||||
_, err := conn.AcceptStream(context.Background())
|
||||
checkTimeoutError(err)
|
||||
close(done)
|
||||
}()
|
||||
Eventually(done, 2*idleTimeout).Should(BeClosed())
|
||||
var lastAckElicitingPacketSentAt time.Time
|
||||
for _, p := range counter.getSentShortHeaderPackets() {
|
||||
var hasAckElicitingFrame bool
|
||||
for _, f := range p.frames {
|
||||
if _, ok := f.(*logging.AckFrame); ok {
|
||||
continue
|
||||
}
|
||||
hasAckElicitingFrame = true
|
||||
break
|
||||
}
|
||||
if hasAckElicitingFrame {
|
||||
lastAckElicitingPacketSentAt = p.time
|
||||
}
|
||||
}
|
||||
rcvdPackets := counter.getRcvdShortHeaderPackets()
|
||||
lastPacketRcvdAt := rcvdPackets[len(rcvdPackets)-1].time
|
||||
// We're ignoring here that only the first ack-eliciting packet sent resets the idle timeout.
|
||||
// This is ok since we're dealing with a lossless connection here,
|
||||
// and we'd expect to receive an ACK for additional other ack-eliciting packet sent.
|
||||
Expect(max(time.Since(lastAckElicitingPacketSentAt), time.Since(lastPacketRcvdAt))).To(And(
|
||||
BeNumerically(">=", idleTimeout),
|
||||
BeNumerically("<", idleTimeout*6/5),
|
||||
))
|
||||
Consistently(serverConnClosed).ShouldNot(BeClosed())
|
||||
|
||||
// make the go routine return
|
||||
(<-serverConnChan).CloseWithError(0, "")
|
||||
Eventually(serverConnClosed).Should(BeClosed())
|
||||
})
|
||||
|
||||
It("times out after sending a packet", func() {
|
||||
server, err := quic.ListenAddr(
|
||||
"localhost:0",
|
||||
getTLSConfig(),
|
||||
getQuicConfig(&quic.Config{DisablePathMTUDiscovery: true}),
|
||||
)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
defer server.Close()
|
||||
|
||||
var drop atomic.Bool
|
||||
proxy, err := quicproxy.NewQuicProxy("localhost:0", &quicproxy.Opts{
|
||||
RemoteAddr: fmt.Sprintf("localhost:%d", server.Addr().(*net.UDPAddr).Port),
|
||||
DropPacket: func(dir quicproxy.Direction, _ []byte) bool {
|
||||
if dir == quicproxy.DirectionOutgoing {
|
||||
return drop.Load()
|
||||
}
|
||||
return false
|
||||
},
|
||||
})
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
defer proxy.Close()
|
||||
|
||||
serverConnChan := make(chan quic.Connection, 1)
|
||||
serverConnClosed := make(chan struct{})
|
||||
go func() {
|
||||
defer GinkgoRecover()
|
||||
conn, err := server.Accept(context.Background())
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
serverConnChan <- conn
|
||||
<-conn.Context().Done() // block until the connection is closed
|
||||
close(serverConnClosed)
|
||||
}()
|
||||
|
||||
conn, err := quic.DialAddr(
|
||||
context.Background(),
|
||||
fmt.Sprintf("localhost:%d", proxy.LocalPort()),
|
||||
getTLSClientConfig(),
|
||||
getQuicConfig(&quic.Config{MaxIdleTimeout: idleTimeout, DisablePathMTUDiscovery: true}),
|
||||
)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
||||
// wait half the idle timeout, then send a packet
|
||||
time.Sleep(idleTimeout / 2)
|
||||
drop.Store(true)
|
||||
str, err := conn.OpenUniStream()
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
_, err = str.Write([]byte("foobar"))
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
||||
// now make sure that the idle timeout is based on this packet
|
||||
startTime := time.Now()
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
defer GinkgoRecover()
|
||||
_, err := conn.AcceptStream(context.Background())
|
||||
checkTimeoutError(err)
|
||||
close(done)
|
||||
}()
|
||||
Eventually(done, 2*idleTimeout).Should(BeClosed())
|
||||
dur := time.Since(startTime)
|
||||
Expect(dur).To(And(
|
||||
BeNumerically(">=", idleTimeout),
|
||||
BeNumerically("<", idleTimeout*12/10),
|
||||
))
|
||||
Consistently(serverConnClosed).ShouldNot(BeClosed())
|
||||
|
||||
// make the go routine return
|
||||
(<-serverConnChan).CloseWithError(0, "")
|
||||
Eventually(serverConnClosed).Should(BeClosed())
|
||||
})
|
||||
})
|
||||
|
||||
It("does not time out if keepalive is set", func() {
|
||||
const idleTimeout = 500 * time.Millisecond
|
||||
|
||||
server, err := quic.ListenAddr(
|
||||
"localhost:0",
|
||||
getTLSConfig(),
|
||||
getQuicConfig(&quic.Config{DisablePathMTUDiscovery: true}),
|
||||
)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
defer server.Close()
|
||||
|
||||
serverConnChan := make(chan quic.Connection, 1)
|
||||
serverConnClosed := make(chan struct{})
|
||||
go func() {
|
||||
defer GinkgoRecover()
|
||||
conn, err := server.Accept(context.Background())
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
serverConnChan <- conn
|
||||
conn.AcceptStream(context.Background()) // blocks until the connection is closed
|
||||
close(serverConnClosed)
|
||||
}()
|
||||
|
||||
var drop atomic.Bool
|
||||
proxy, err := quicproxy.NewQuicProxy("localhost:0", &quicproxy.Opts{
|
||||
RemoteAddr: fmt.Sprintf("localhost:%d", server.Addr().(*net.UDPAddr).Port),
|
||||
DropPacket: func(quicproxy.Direction, []byte) bool {
|
||||
return drop.Load()
|
||||
},
|
||||
})
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
defer proxy.Close()
|
||||
|
||||
conn, err := quic.DialAddr(
|
||||
context.Background(),
|
||||
fmt.Sprintf("localhost:%d", proxy.LocalPort()),
|
||||
cconn,
|
||||
ln.Addr(),
|
||||
getTLSClientConfig(),
|
||||
getQuicConfig(&quic.Config{
|
||||
MaxIdleTimeout: idleTimeout,
|
||||
KeepAlivePeriod: idleTimeout / 2,
|
||||
HandshakeIdleTimeout: handshakeTimeout,
|
||||
MaxIdleTimeout: handshakeTimeout,
|
||||
KeepAlivePeriod: handshakeTimeout / 2,
|
||||
DisablePathMTUDiscovery: true,
|
||||
}),
|
||||
)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
||||
// wait longer than the idle timeout
|
||||
time.Sleep(3 * idleTimeout)
|
||||
str, err := conn.OpenUniStream()
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
_, err = str.Write([]byte("foobar"))
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Consistently(serverConnClosed).ShouldNot(BeClosed())
|
||||
|
||||
// idle timeout will still kick in if pings are dropped
|
||||
drop.Store(true)
|
||||
time.Sleep(2 * idleTimeout)
|
||||
_, err = str.Write([]byte("foobar"))
|
||||
checkTimeoutError(err)
|
||||
|
||||
(<-serverConnChan).CloseWithError(0, "")
|
||||
Eventually(serverConnClosed).Should(BeClosed())
|
||||
})
|
||||
|
||||
Context("faulty packet conns", func() {
|
||||
const handshakeTimeout = time.Second / 2
|
||||
|
||||
runServer := func(ln *quic.Listener) error {
|
||||
conn, err := ln.Accept(context.Background())
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
str, err := conn.OpenUniStream()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer str.Close()
|
||||
_, err = str.Write(PRData)
|
||||
return err
|
||||
if err != nil {
|
||||
clientErrChan <- err
|
||||
return
|
||||
}
|
||||
clientErrChan <- runClient(conn)
|
||||
}()
|
||||
|
||||
runClient := func(conn quic.Connection) error {
|
||||
str, err := conn.AcceptUniStream(context.Background())
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
data, err := io.ReadAll(str)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
Expect(data).To(Equal(PRData))
|
||||
return conn.CloseWithError(0, "done")
|
||||
var clientErr error
|
||||
select {
|
||||
case clientErr = <-clientErrChan:
|
||||
case <-time.After(5 * handshakeTimeout):
|
||||
t.Fatal("timeout waiting for client error")
|
||||
}
|
||||
require.Error(t, clientErr)
|
||||
if pers == protocol.PerspectiveClient {
|
||||
require.Contains(t, clientErr.Error(), io.ErrClosedPipe.Error())
|
||||
} else {
|
||||
var nerr net.Error
|
||||
require.True(t, errors.As(clientErr, &nerr))
|
||||
require.True(t, nerr.Timeout())
|
||||
}
|
||||
|
||||
require.Eventually(t, func() bool { return !areHandshakesRunning() }, 5*handshakeTimeout, 5*time.Millisecond)
|
||||
|
||||
select {
|
||||
case serverErr := <-serverErrChan: // The handshake completed on the server side.
|
||||
require.Error(t, serverErr)
|
||||
if pers == protocol.PerspectiveServer {
|
||||
require.Contains(t, serverErr.Error(), io.ErrClosedPipe.Error())
|
||||
} else {
|
||||
var nerr net.Error
|
||||
require.True(t, errors.As(serverErr, &nerr))
|
||||
require.True(t, nerr.Timeout())
|
||||
}
|
||||
|
||||
It("deals with an erroring packet conn, on the server side", func() {
|
||||
addr, err := net.ResolveUDPAddr("udp", "localhost:0")
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
conn, err := net.ListenUDP("udp", addr)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
maxPackets := mrand.Int31n(25)
|
||||
fmt.Fprintf(GinkgoWriter, "blocking connection after %d packets\n", maxPackets)
|
||||
ln, err := quic.Listen(
|
||||
&faultyConn{PacketConn: conn, MaxPackets: maxPackets},
|
||||
getTLSConfig(),
|
||||
getQuicConfig(&quic.Config{DisablePathMTUDiscovery: true}),
|
||||
)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
||||
serverErrChan := make(chan error, 1)
|
||||
go func() {
|
||||
defer GinkgoRecover()
|
||||
serverErrChan <- runServer(ln)
|
||||
}()
|
||||
|
||||
clientErrChan := make(chan error, 1)
|
||||
go func() {
|
||||
defer GinkgoRecover()
|
||||
conn, err := quic.DialAddr(
|
||||
context.Background(),
|
||||
fmt.Sprintf("localhost:%d", ln.Addr().(*net.UDPAddr).Port),
|
||||
getTLSClientConfig(),
|
||||
getQuicConfig(&quic.Config{
|
||||
HandshakeIdleTimeout: handshakeTimeout,
|
||||
MaxIdleTimeout: handshakeTimeout,
|
||||
DisablePathMTUDiscovery: true,
|
||||
}),
|
||||
)
|
||||
if err != nil {
|
||||
clientErrChan <- err
|
||||
return
|
||||
}
|
||||
clientErrChan <- runClient(conn)
|
||||
}()
|
||||
|
||||
var clientErr error
|
||||
Eventually(clientErrChan, 5*handshakeTimeout).Should(Receive(&clientErr))
|
||||
Expect(clientErr).To(HaveOccurred())
|
||||
nErr, ok := clientErr.(net.Error)
|
||||
Expect(ok).To(BeTrue())
|
||||
Expect(nErr.Timeout()).To(BeTrue())
|
||||
|
||||
select {
|
||||
case serverErr := <-serverErrChan:
|
||||
Expect(serverErr).To(HaveOccurred())
|
||||
Expect(serverErr.Error()).To(ContainSubstring(io.ErrClosedPipe.Error()))
|
||||
defer ln.Close()
|
||||
default:
|
||||
Expect(ln.Close()).To(Succeed())
|
||||
Eventually(serverErrChan).Should(Receive())
|
||||
}
|
||||
})
|
||||
|
||||
It("deals with an erroring packet conn, on the client side", func() {
|
||||
ln, err := quic.ListenAddr(
|
||||
"localhost:0",
|
||||
getTLSConfig(),
|
||||
getQuicConfig(&quic.Config{
|
||||
HandshakeIdleTimeout: handshakeTimeout,
|
||||
MaxIdleTimeout: handshakeTimeout,
|
||||
KeepAlivePeriod: handshakeTimeout / 2,
|
||||
DisablePathMTUDiscovery: true,
|
||||
}),
|
||||
)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
defer ln.Close()
|
||||
|
||||
serverErrChan := make(chan error, 1)
|
||||
go func() {
|
||||
defer GinkgoRecover()
|
||||
serverErrChan <- runServer(ln)
|
||||
}()
|
||||
|
||||
addr, err := net.ResolveUDPAddr("udp", "localhost:0")
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
conn, err := net.ListenUDP("udp", addr)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
maxPackets := mrand.Int31n(25)
|
||||
fmt.Fprintf(GinkgoWriter, "blocking connection after %d packets\n", maxPackets)
|
||||
clientErrChan := make(chan error, 1)
|
||||
go func() {
|
||||
defer GinkgoRecover()
|
||||
conn, err := quic.Dial(
|
||||
context.Background(),
|
||||
&faultyConn{PacketConn: conn, MaxPackets: maxPackets},
|
||||
ln.Addr(),
|
||||
getTLSClientConfig(),
|
||||
getQuicConfig(&quic.Config{DisablePathMTUDiscovery: true}),
|
||||
)
|
||||
if err != nil {
|
||||
clientErrChan <- err
|
||||
return
|
||||
}
|
||||
clientErrChan <- runClient(conn)
|
||||
}()
|
||||
|
||||
var clientErr error
|
||||
Eventually(clientErrChan, 5*handshakeTimeout).Should(Receive(&clientErr))
|
||||
Expect(clientErr).To(HaveOccurred())
|
||||
Expect(clientErr.Error()).To(ContainSubstring(io.ErrClosedPipe.Error()))
|
||||
Eventually(areHandshakesRunning, 5*handshakeTimeout).Should(BeFalse())
|
||||
select {
|
||||
case serverErr := <-serverErrChan: // The handshake completed on the server side.
|
||||
Expect(serverErr).To(HaveOccurred())
|
||||
nErr, ok := serverErr.(net.Error)
|
||||
Expect(ok).To(BeTrue())
|
||||
Expect(nErr.Timeout()).To(BeTrue())
|
||||
default: // The handshake didn't complete
|
||||
Expect(ln.Close()).To(Succeed())
|
||||
Eventually(serverErrChan).Should(Receive())
|
||||
}
|
||||
})
|
||||
})
|
||||
})
|
||||
default: // The handshake didn't complete
|
||||
require.NoError(t, ln.Close())
|
||||
select {
|
||||
case <-serverErrChan:
|
||||
case <-time.After(time.Second):
|
||||
t.Fatal("timeout waiting for server to close")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -9,6 +9,7 @@ import (
|
||||
mrand "math/rand"
|
||||
"net"
|
||||
"sync"
|
||||
"testing"
|
||||
|
||||
"github.com/quic-go/quic-go"
|
||||
"github.com/quic-go/quic-go/internal/protocol"
|
||||
@@ -17,30 +18,29 @@ import (
|
||||
"github.com/quic-go/quic-go/metrics"
|
||||
"github.com/quic-go/quic-go/qlog"
|
||||
|
||||
. "github.com/onsi/ginkgo/v2"
|
||||
. "github.com/onsi/gomega"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
var _ = Describe("Tracer tests", func() {
|
||||
func TestTracerHandshake(t *testing.T) {
|
||||
addTracers := func(pers protocol.Perspective, conf *quic.Config) *quic.Config {
|
||||
enableQlog := mrand.Int()%2 != 0
|
||||
enableMetrcis := mrand.Int()%2 != 0
|
||||
enableMetrics := mrand.Int()%2 != 0
|
||||
enableCustomTracer := mrand.Int()%2 != 0
|
||||
|
||||
fmt.Fprintf(GinkgoWriter, "%s using qlog: %t, metrics: %t, custom: %t\n", pers, enableQlog, enableMetrcis, enableCustomTracer)
|
||||
t.Logf("%s using qlog: %t, metrics: %t, custom: %t", pers, enableQlog, enableMetrics, enableCustomTracer)
|
||||
|
||||
var tracerConstructors []func(context.Context, logging.Perspective, quic.ConnectionID) *logging.ConnectionTracer
|
||||
if enableQlog {
|
||||
tracerConstructors = append(tracerConstructors, func(_ context.Context, p logging.Perspective, connID quic.ConnectionID) *logging.ConnectionTracer {
|
||||
if mrand.Int()%2 == 0 { // simulate that a qlog collector might only want to log some connections
|
||||
fmt.Fprintf(GinkgoWriter, "%s qlog tracer deciding to not trace connection %s\n", p, connID)
|
||||
t.Logf("%s qlog tracer deciding to not trace connection %s", p, connID)
|
||||
return nil
|
||||
}
|
||||
fmt.Fprintf(GinkgoWriter, "%s qlog tracing connection %s\n", p, connID)
|
||||
t.Logf("%s qlog tracing connection %s", p, connID)
|
||||
return qlog.NewConnectionTracer(utils.NewBufferedWriteCloser(bufio.NewWriter(&bytes.Buffer{}), io.NopCloser(nil)), p, connID)
|
||||
})
|
||||
}
|
||||
if enableMetrcis {
|
||||
if enableMetrics {
|
||||
tracerConstructors = append(tracerConstructors, metrics.DefaultConnectionTracer)
|
||||
}
|
||||
if enableCustomTracer {
|
||||
@@ -62,39 +62,21 @@ var _ = Describe("Tracer tests", func() {
|
||||
}
|
||||
|
||||
for i := 0; i < 3; i++ {
|
||||
It("handshakes with a random combination of tracers", func() {
|
||||
t.Run(fmt.Sprintf("run %d", i+1), func(t *testing.T) {
|
||||
if enableQlog {
|
||||
Skip("This test sets tracers and won't produce any qlogs.")
|
||||
t.Skip("This test sets tracers and won't produce any qlogs.")
|
||||
}
|
||||
|
||||
quicClientConf := addTracers(protocol.PerspectiveClient, getQuicConfig(nil))
|
||||
quicServerConf := addTracers(protocol.PerspectiveServer, getQuicConfig(nil))
|
||||
|
||||
serverChan := make(chan *quic.Listener)
|
||||
serverDone := make(chan struct{})
|
||||
go func() {
|
||||
defer GinkgoRecover()
|
||||
defer close(serverDone)
|
||||
ln, err := quic.ListenAddr("localhost:0", getTLSConfig(), quicServerConf)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
serverChan <- ln
|
||||
for {
|
||||
conn, err := ln.Accept(context.Background())
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
str, err := conn.OpenUniStream()
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
_, err = str.Write(PRData)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(str.Close()).To(Succeed())
|
||||
}
|
||||
}()
|
||||
|
||||
ln := <-serverChan
|
||||
ln, err := quic.ListenAddr("localhost:0", getTLSConfig(), quicServerConf)
|
||||
require.NoError(t, err)
|
||||
t.Cleanup(func() { require.NoError(t, ln.Close()) })
|
||||
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(3)
|
||||
for i := 0; i < 3; i++ {
|
||||
for j := 0; j < 3; j++ {
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
conn, err := quic.DialAddr(
|
||||
@@ -103,18 +85,27 @@ var _ = Describe("Tracer tests", func() {
|
||||
getTLSClientConfig(),
|
||||
quicClientConf,
|
||||
)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
require.NoError(t, err)
|
||||
defer conn.CloseWithError(0, "")
|
||||
|
||||
sconn, err := ln.Accept(context.Background())
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
sstr, err := sconn.OpenUniStream()
|
||||
require.NoError(t, err)
|
||||
_, err = sstr.Write(PRData)
|
||||
require.NoError(t, err)
|
||||
require.NoError(t, sstr.Close())
|
||||
|
||||
str, err := conn.AcceptUniStream(context.Background())
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
require.NoError(t, err)
|
||||
data, err := io.ReadAll(str)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(data).To(Equal(PRData))
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, PRData, data)
|
||||
}()
|
||||
}
|
||||
wg.Wait()
|
||||
ln.Close()
|
||||
Eventually(serverDone).Should(BeClosed())
|
||||
})
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
@@ -1,147 +0,0 @@
|
||||
package self_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"sync"
|
||||
|
||||
"github.com/quic-go/quic-go"
|
||||
"github.com/quic-go/quic-go/internal/protocol"
|
||||
|
||||
. "github.com/onsi/ginkgo/v2"
|
||||
. "github.com/onsi/gomega"
|
||||
)
|
||||
|
||||
var _ = Describe("Unidirectional Streams", func() {
|
||||
const numStreams = 500
|
||||
|
||||
var (
|
||||
server *quic.Listener
|
||||
serverAddr string
|
||||
)
|
||||
|
||||
BeforeEach(func() {
|
||||
var err error
|
||||
server, err = quic.ListenAddr("localhost:0", getTLSConfig(), getQuicConfig(nil))
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
serverAddr = fmt.Sprintf("localhost:%d", server.Addr().(*net.UDPAddr).Port)
|
||||
})
|
||||
|
||||
AfterEach(func() {
|
||||
server.Close()
|
||||
})
|
||||
|
||||
dataForStream := func(id protocol.StreamID) []byte {
|
||||
return GeneratePRData(10 * int(id))
|
||||
}
|
||||
|
||||
runSendingPeer := func(conn quic.Connection) {
|
||||
for i := 0; i < numStreams; i++ {
|
||||
str, err := conn.OpenUniStreamSync(context.Background())
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
go func() {
|
||||
defer GinkgoRecover()
|
||||
_, err := str.Write(dataForStream(str.StreamID()))
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(str.Close()).To(Succeed())
|
||||
}()
|
||||
}
|
||||
}
|
||||
|
||||
runReceivingPeer := func(conn quic.Connection) {
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(numStreams)
|
||||
for i := 0; i < numStreams; i++ {
|
||||
str, err := conn.AcceptUniStream(context.Background())
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
go func() {
|
||||
defer GinkgoRecover()
|
||||
defer wg.Done()
|
||||
data, err := io.ReadAll(str)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(data).To(Equal(dataForStream(str.StreamID())))
|
||||
}()
|
||||
}
|
||||
wg.Wait()
|
||||
}
|
||||
|
||||
It(fmt.Sprintf("client opening %d streams to a server", numStreams), func() {
|
||||
go func() {
|
||||
defer GinkgoRecover()
|
||||
conn, err := server.Accept(context.Background())
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
runReceivingPeer(conn)
|
||||
conn.CloseWithError(0, "")
|
||||
}()
|
||||
|
||||
client, err := quic.DialAddr(
|
||||
context.Background(),
|
||||
serverAddr,
|
||||
getTLSClientConfig(),
|
||||
getQuicConfig(nil),
|
||||
)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
runSendingPeer(client)
|
||||
<-client.Context().Done()
|
||||
})
|
||||
|
||||
It(fmt.Sprintf("server opening %d streams to a client", numStreams), func() {
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
defer GinkgoRecover()
|
||||
defer close(done)
|
||||
conn, err := server.Accept(context.Background())
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
runSendingPeer(conn)
|
||||
<-conn.Context().Done()
|
||||
}()
|
||||
|
||||
client, err := quic.DialAddr(
|
||||
context.Background(),
|
||||
serverAddr,
|
||||
getTLSClientConfig(),
|
||||
getQuicConfig(nil),
|
||||
)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
runReceivingPeer(client)
|
||||
client.CloseWithError(0, "")
|
||||
})
|
||||
|
||||
It(fmt.Sprintf("client and server opening %d streams each and sending data to the peer", numStreams), func() {
|
||||
done1 := make(chan struct{})
|
||||
go func() {
|
||||
defer GinkgoRecover()
|
||||
conn, err := server.Accept(context.Background())
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
defer GinkgoRecover()
|
||||
runReceivingPeer(conn)
|
||||
close(done)
|
||||
}()
|
||||
runSendingPeer(conn)
|
||||
<-done
|
||||
close(done1)
|
||||
}()
|
||||
|
||||
client, err := quic.DialAddr(
|
||||
context.Background(),
|
||||
serverAddr,
|
||||
getTLSClientConfig(),
|
||||
getQuicConfig(nil),
|
||||
)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
done2 := make(chan struct{})
|
||||
go func() {
|
||||
defer GinkgoRecover()
|
||||
runSendingPeer(client)
|
||||
close(done2)
|
||||
}()
|
||||
runReceivingPeer(client)
|
||||
<-done1
|
||||
<-done2
|
||||
client.CloseWithError(0, "")
|
||||
})
|
||||
})
|
||||
File diff suppressed because it is too large
Load Diff
Reference in New Issue
Block a user