Files
quic-go/integrationtests/versionnegotiation/handshake_test.go
Marten Seemann 7d54aa41a3 remove unneeded Connection.GetVersion method (#4792)
Instead, use Connection.ConnectionState().Version.
2024-12-23 20:29:26 +08:00

187 lines
6.8 KiB
Go

package versionnegotiation
import (
"context"
"errors"
"fmt"
"net"
"testing"
"time"
"github.com/quic-go/quic-go"
"github.com/quic-go/quic-go/internal/protocol"
"github.com/quic-go/quic-go/logging"
"github.com/stretchr/testify/require"
)
// result collects what a tracer created by newVersionNegotiationTracer
// observed on a single connection.
type result struct {
	// loggedVersions is set to true when the NegotiatedVersion callback fires.
	loggedVersions bool
	// receivedVersionNegotiation is set to true when the
	// ReceivedVersionNegotiationPacket callback fires.
	receivedVersionNegotiation bool
	// chosen is the version passed to the NegotiatedVersion callback.
	chosen logging.Version
	// clientVersions is the client version list passed to NegotiatedVersion.
	clientVersions []logging.Version
	// serverVersions is the server version list passed to NegotiatedVersion.
	serverVersions []logging.Version
}
// newVersionNegotiationTracer returns a fresh result together with a
// connection tracer that records version negotiation events into it.
//
// The NegotiatedVersion callback stores the chosen version and both version
// lists in the result, and marks the test as failed if it is invoked a second
// time. The ReceivedVersionNegotiationPacket callback only records that such a
// packet was seen; its arguments are ignored.
func newVersionNegotiationTracer(t *testing.T) (*result, *logging.ConnectionTracer) {
	r := &result{}
	return r, &logging.ConnectionTracer{
		NegotiatedVersion: func(chosen logging.Version, clientVersions, serverVersions []logging.Version) {
			if r.loggedVersions {
				// This callback is invoked by the QUIC stack and is not guaranteed
				// to run on the goroutine executing the test, where t.Fatal
				// (FailNow) is not allowed. Report the failure with t.Error and
				// bail out so the first recorded values are kept.
				t.Error("only expected one call to NegotiatedVersions")
				return
			}
			r.loggedVersions = true
			r.chosen = chosen
			r.clientVersions = clientVersions
			r.serverVersions = serverVersions
		},
		ReceivedVersionNegotiationPacket: func(dest, src logging.ArbitraryLenConnectionID, _ []logging.Version) {
			r.receivedVersionNegotiation = true
		},
	}
}
// TestServerSupportsMoreVersionsThanClient configures a server with a version
// list that contains extra, made-up versions around the first supported one,
// dials it with a client that leaves Versions unset, and expects both sides to
// end up on the first supported version without the client ever seeing a
// Version Negotiation packet.
func TestServerSupportsMoreVersionsThanClient(t *testing.T) {
	// Extend the global version list with made-up versions for the duration of
	// the test, restoring the saved copy on exit.
	savedVersions := append([]quic.Version{}, protocol.SupportedVersions...)
	protocol.SupportedVersions = append(protocol.SupportedVersions, []protocol.Version{7, 8, 9, 10}...)
	defer func() { protocol.SupportedVersions = savedVersions }()

	expectedVersion := protocol.SupportedVersions[0]

	// Server side: the expected version sits in the middle of the made-up ones.
	serverResult, serverTracer := newVersionNegotiationTracer(t)
	serverConfig := &quic.Config{
		Versions: []protocol.Version{7, 8, expectedVersion, 9},
		Tracer: func(context.Context, logging.Perspective, quic.ConnectionID) *logging.ConnectionTracer {
			return serverTracer
		},
	}
	server, err := quic.ListenAddr("localhost:0", getTLSConfig(), serverConfig)
	require.NoError(t, err)
	defer server.Close()

	// Client side: only a tracer is configured.
	clientResult, clientTracer := newVersionNegotiationTracer(t)
	clientConfig := maybeAddQLOGTracer(&quic.Config{
		Tracer: func(context.Context, logging.Perspective, quic.ConnectionID) *logging.ConnectionTracer {
			return clientTracer
		},
	})
	serverAddr := fmt.Sprintf("localhost:%d", server.Addr().(*net.UDPAddr).Port)
	conn, err := quic.DialAddr(context.Background(), serverAddr, getTLSClientConfig(), clientConfig)
	require.NoError(t, err)

	sconn, err := server.Accept(context.Background())
	require.NoError(t, err)

	// Both endpoints must report the expected version.
	require.Equal(t, expectedVersion, sconn.ConnectionState().Version)
	require.Equal(t, expectedVersion, conn.ConnectionState().Version)

	// Close from the client and wait until the server-side connection is done
	// before inspecting what the tracers recorded.
	require.NoError(t, conn.CloseWithError(0, ""))
	select {
	case <-sconn.Context().Done():
		// the server noticed the close
	case <-time.After(5 * time.Second):
		t.Fatal("Timeout waiting for connection to close")
	}

	// Client tracer expectations.
	require.Equal(t, expectedVersion, clientResult.chosen)
	require.False(t, clientResult.receivedVersionNegotiation)
	require.Equal(t, protocol.SupportedVersions, clientResult.clientVersions)
	require.Empty(t, clientResult.serverVersions)

	// Server tracer expectations.
	require.Equal(t, expectedVersion, serverResult.chosen)
	require.Equal(t, serverConfig.Versions, serverResult.serverVersions)
	require.Empty(t, serverResult.clientVersions)
}
// TestClientSupportsMoreVersionsThanServer configures a client whose version
// list starts with made-up versions and dials a server that is limited to the
// original supported versions. The test expects the client to receive a
// Version Negotiation packet and both sides to end up on the first supported
// version.
func TestClientSupportsMoreVersionsThanServer(t *testing.T) {
	// Extend the global version list with made-up versions for the duration of
	// the test, restoring the saved copy on exit.
	savedVersions := append([]quic.Version{}, protocol.SupportedVersions...)
	protocol.SupportedVersions = append(protocol.SupportedVersions, []protocol.Version{7, 8, 9, 10}...)
	defer func() { protocol.SupportedVersions = savedVersions }()

	expectedVersion := protocol.SupportedVersions[0]

	// Server side: restricted to the versions that were supported before the
	// made-up ones were appended, so it does not know the client's first choice.
	serverResult, serverTracer := newVersionNegotiationTracer(t)
	serverConfig := &quic.Config{
		Versions: savedVersions,
		Tracer: func(context.Context, logging.Perspective, quic.ConnectionID) *logging.ConnectionTracer {
			return serverTracer
		},
	}
	server, err := quic.ListenAddr("localhost:0", getTLSConfig(), serverConfig)
	require.NoError(t, err)
	defer server.Close()

	// Client side: the expected version is listed after several made-up ones.
	clientVersions := []protocol.Version{7, 8, 9, expectedVersion, 10}
	clientResult, clientTracer := newVersionNegotiationTracer(t)
	clientConfig := maybeAddQLOGTracer(&quic.Config{
		Versions: clientVersions,
		Tracer: func(context.Context, logging.Perspective, quic.ConnectionID) *logging.ConnectionTracer {
			return clientTracer
		},
	})
	serverAddr := fmt.Sprintf("localhost:%d", server.Addr().(*net.UDPAddr).Port)
	conn, err := quic.DialAddr(context.Background(), serverAddr, getTLSClientConfig(), clientConfig)
	require.NoError(t, err)

	sconn, err := server.Accept(context.Background())
	require.NoError(t, err)

	// Both endpoints must report the expected version.
	require.Equal(t, expectedVersion, sconn.ConnectionState().Version)
	require.Equal(t, expectedVersion, conn.ConnectionState().Version)

	// Close from the client and wait until the server-side connection is done
	// before inspecting what the tracers recorded.
	require.NoError(t, conn.CloseWithError(0, ""))
	select {
	case <-sconn.Context().Done():
		// the server noticed the close
	case <-time.After(5 * time.Second):
		t.Fatal("Timeout waiting for connection to close")
	}

	// Client tracer expectations.
	require.Equal(t, expectedVersion, clientResult.chosen)
	require.True(t, clientResult.receivedVersionNegotiation)
	require.Equal(t, clientVersions, clientResult.clientVersions)
	require.Subset(t, clientResult.serverVersions, savedVersions) // may contain greased versions

	// Server tracer expectations.
	require.Equal(t, expectedVersion, serverResult.chosen)
	require.Equal(t, serverConfig.Versions, serverResult.serverVersions)
	require.Empty(t, serverResult.clientVersions)
}
// TestServerDisablesVersionNegotiation runs a server on a transport with
// DisableVersionNegotiationPackets set and dials it with a client that offers
// only a version the server is not configured for. The test expects the dial
// to fail with a timeout error and the client tracer to never see a Version
// Negotiation packet.
func TestServerDisablesVersionNegotiation(t *testing.T) {
	// The server is configured for QUIC v1 only, whereas the client below
	// offers QUIC v2 only, so the two have no version in common.
	_, serverTracer := newVersionNegotiationTracer(t)
	serverConfig := &quic.Config{Versions: []protocol.Version{quic.Version1}}
	serverConfig.Tracer = func(context.Context, logging.Perspective, quic.ConnectionID) *logging.ConnectionTracer {
		return serverTracer
	}

	conn, err := net.ListenUDP("udp", &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1)})
	require.NoError(t, err)
	// Release the UDP socket when the test ends. Deferred first so that it runs
	// last, after the listener and the transport have been shut down.
	defer conn.Close()

	tr := &quic.Transport{
		Conn:                             conn,
		DisableVersionNegotiationPackets: true,
	}
	// Shut the transport down as well; closing only the listener leaves it running.
	defer tr.Close()

	ln, err := tr.Listen(getTLSConfig(), serverConfig)
	require.NoError(t, err)
	defer ln.Close()

	clientVersions := []protocol.Version{quic.Version2}
	clientResult, clientTracer := newVersionNegotiationTracer(t)
	_, err = quic.DialAddr(
		context.Background(),
		fmt.Sprintf("localhost:%d", conn.LocalAddr().(*net.UDPAddr).Port),
		getTLSClientConfig(),
		maybeAddQLOGTracer(&quic.Config{
			Versions: clientVersions,
			Tracer: func(context.Context, logging.Perspective, quic.ConnectionID) *logging.ConnectionTracer {
				return clientTracer
			},
			// Keep the handshake timeout short so the expected failure is quick.
			HandshakeIdleTimeout: 100 * time.Millisecond,
		}),
	)

	// The dial must fail with a timeout, not with a version negotiation error.
	require.Error(t, err)
	var nerr net.Error
	require.True(t, errors.As(err, &nerr))
	require.True(t, nerr.Timeout())
	require.False(t, clientResult.receivedVersionNegotiation)
}