Files
quic-go/internal/handshake/crypto_setup_test.go
2025-11-14 04:04:40 +03:00

569 lines
17 KiB
Go

package handshake
import (
"context"
"crypto/ed25519"
"crypto/rand"
"crypto/tls"
"crypto/x509"
"crypto/x509/pkix"
"errors"
"math/big"
"net"
"testing"
"time"
"git.geeks-team.ru/gr1ffon/quic-go/internal/protocol"
"git.geeks-team.ru/gr1ffon/quic-go/internal/qerr"
"git.geeks-team.ru/gr1ffon/quic-go/internal/testdata"
"git.geeks-team.ru/gr1ffon/quic-go/internal/utils"
"git.geeks-team.ru/gr1ffon/quic-go/internal/wire"
"github.com/stretchr/testify/require"
)
// TLS handshake message types, used as the first byte when the tests below
// craft raw handshake messages by hand.
const (
	// typeClientHello is the HandshakeType value of a ClientHello (RFC 8446: client_hello = 1).
	typeClientHello = 1
	// typeNewSessionTicket is the HandshakeType value of a NewSessionTicket (RFC 8446: new_session_ticket = 4).
	typeNewSessionTicket = 4
)
// mockClientSessionCache is a tls.ClientSessionCache that delegates storage to
// a real LRU cache and additionally reports every stored session on the puts
// channel, so that tests can wait for the client to save a session ticket.
type mockClientSessionCache struct {
	cache tls.ClientSessionCache
	puts  chan *tls.ClientSessionState
}

// Compile-time check that *mockClientSessionCache satisfies tls.ClientSessionCache.
var _ tls.ClientSessionCache = (*mockClientSessionCache)(nil)

// newMockClientSessionCache returns a mock backed by an LRU cache with room for
// a single session. The puts channel buffers one report.
func newMockClientSessionCache() *mockClientSessionCache {
	c := &mockClientSessionCache{cache: tls.NewLRUClientSessionCache(1)}
	c.puts = make(chan *tls.ClientSessionState, 1)
	return c
}

// Get looks sessionKey up in the underlying LRU cache.
func (c *mockClientSessionCache) Get(sessionKey string) (*tls.ClientSessionState, bool) {
	session, found := c.cache.Get(sessionKey)
	return session, found
}

// Put reports cs on the puts channel and then stores it in the underlying cache.
// Because puts has a buffer of one, Put blocks while an earlier report has not
// yet been received from the channel.
func (c *mockClientSessionCache) Put(sessionKey string, cs *tls.ClientSessionState) {
	c.puts <- cs
	c.cache.Put(sessionKey, cs)
}
// getTLSConfigs returns a matching client/server pair of TLS configurations.
// The client trusts testdata's root CA and expects the server name "localhost".
// The server configuration is the one returned by testdata.GetTLSConfig. Both
// sides offer only the ALPN protocol "crypto-setup", each in its own slice.
func getTLSConfigs() (clientConf, serverConf *tls.Config) {
	const alpn = "crypto-setup"
	clientConf = &tls.Config{
		ServerName: "localhost",
		RootCAs:    testdata.GetRootCA(),
		NextProtos: []string{alpn},
	}
	serverConf = testdata.GetTLSConfig()
	serverConf.NextProtos = []string{alpn}
	return clientConf, serverConf
}
// TestErrorBeforeClientHelloGeneration checks that a TLS configuration error,
// detected before any ClientHello is produced, is reported by StartHandshake as
// a *qerr.TransportError with error code 0x100+0x50.
func TestErrorBeforeClientHelloGeneration(t *testing.T) {
	conf := testdata.GetTLSConfig()
	conf.InsecureSkipVerify = true
	// An empty protocol name is rejected by crypto/tls ("invalid NextProtos
	// value"), so the client cannot build its ClientHello.
	conf.NextProtos = []string{""}
	client := NewCryptoSetupClient(
		protocol.ConnectionID{},
		&wire.TransportParameters{},
		conf,
		false,
		utils.NewRTTStats(),
		nil,
		utils.DefaultLogger.WithPrefix("client"),
		protocol.Version1,
	)

	err := client.StartHandshake(context.Background())
	var transportErr *qerr.TransportError
	require.True(t, errors.As(err, &transportErr))
	// 0x100 is the base of the crypto error range; 0x50 is TLS alert 80 (internal_error).
	require.Equal(t, uint64(0x100+0x50), uint64(transportErr.ErrorCode))
	require.Contains(t, err.Error(), "tls: invalid NextProtos value")
}
// TestMessageReceivedAtWrongEncryptionLevel feeds the server a ClientHello at
// the Handshake encryption level and expects the TLS stack to reject it.
func TestMessageReceivedAtWrongEncryptionLevel(t *testing.T) {
	var resetToken protocol.StatelessResetToken
	server := NewCryptoSetupServer(
		protocol.ConnectionID{},
		&net.UDPAddr{IP: net.IPv6loopback, Port: 1234},
		&net.UDPAddr{IP: net.IPv6loopback, Port: 4321},
		&wire.TransportParameters{StatelessResetToken: &resetToken},
		testdata.GetTLSConfig(),
		false,
		utils.NewRTTStats(),
		nil,
		utils.DefaultLogger.WithPrefix("server"),
		protocol.Version1,
	)
	require.NoError(t, server.StartHandshake(context.Background()))

	// Fake ClientHello: message type, 3-byte length (6), then a 6-byte body.
	fakeClientHello := append([]byte{typeClientHello, 0, 0, 6}, "foobar"...)
	err := server.HandleMessage(fakeClientHello, protocol.EncryptionHandshake)
	require.ErrorContains(t, err, "tls: handshake data received at wrong level")
}
// handshake drives a handshake between client and server to completion. It
// passes the CRYPTO data each side emits to the other side at the same
// encryption level. It starts both handshakes, then repeatedly drains the
// client's and the server's event queues until both sides have reported
// EventHandshakeComplete. When the server completes, the session ticket
// returned by server.GetSessionTicket (if any) is delivered to the client at
// the 1-RTT encryption level.
//
// The clientEvents and serverEvents contain all events that were not processed by the function,
// i.e. not EventWriteInitialData, EventWriteHandshakeData, EventHandshakeComplete.
// If HandleMessage fails on either side, handshake returns immediately. The
// error is placed in clientErr or serverErr, and the events collected so far
// are returned with it. If a whole round passes in which neither side emits
// any event before both handshakes are complete, the test fails. Without that
// check it would loop forever.
func handshake(t *testing.T, client, server CryptoSetup) (clientEvents []Event, clientErr error, serverEvents []Event, serverErr error) {
	t.Helper()
	require.NoError(t, client.StartHandshake(context.Background()))
	require.NoError(t, server.StartHandshake(context.Background()))

	var clientHandshakeComplete, serverHandshakeComplete bool
	for !clientHandshakeComplete || !serverHandshakeComplete {
		// progress is set once either side yields an event in this round. A round
		// without events means nothing new was passed to HandleMessage. The
		// handshake is then stuck, assuming new events only arise from
		// StartHandshake/HandleMessage (TODO confirm against CryptoSetup).
		progress := false

		for {
			ev := client.NextEvent()
			if ev.Kind == EventNoEvent {
				break
			}
			progress = true
			switch ev.Kind {
			case EventWriteInitialData:
				if serverErr = server.HandleMessage(ev.Data, protocol.EncryptionInitial); serverErr != nil {
					return clientEvents, clientErr, serverEvents, serverErr
				}
			case EventWriteHandshakeData:
				if serverErr = server.HandleMessage(ev.Data, protocol.EncryptionHandshake); serverErr != nil {
					return clientEvents, clientErr, serverEvents, serverErr
				}
			case EventHandshakeComplete:
				clientHandshakeComplete = true
			default:
				clientEvents = append(clientEvents, ev)
			}
		}

		for {
			ev := server.NextEvent()
			if ev.Kind == EventNoEvent {
				break
			}
			progress = true
			switch ev.Kind {
			case EventWriteInitialData:
				if clientErr = client.HandleMessage(ev.Data, protocol.EncryptionInitial); clientErr != nil {
					return clientEvents, clientErr, serverEvents, serverErr
				}
			case EventWriteHandshakeData:
				if clientErr = client.HandleMessage(ev.Data, protocol.EncryptionHandshake); clientErr != nil {
					return clientEvents, clientErr, serverEvents, serverErr
				}
			case EventHandshakeComplete:
				serverHandshakeComplete = true
				ticket, err := server.GetSessionTicket()
				require.NoError(t, err)
				if ticket != nil {
					require.NoError(t, client.HandleMessage(ticket, protocol.Encryption1RTT))
				}
			default:
				serverEvents = append(serverEvents, ev)
			}
		}

		if !progress {
			t.Fatalf("handshake stalled: no events emitted (client complete: %t, server complete: %t)",
				clientHandshakeComplete, serverHandshakeComplete)
		}
	}
	return clientEvents, clientErr, serverEvents, serverErr
}
// handshakeWithTLSConf creates a client and a server CryptoSetup and runs
// handshake on them. Each side uses the given TLS configuration, RTT stats and
// transport parameters, and both use the same enable0RTT setting. Both sides
// use the zero connection ID and QUIC version 1. The server is given two fixed
// IPv6 loopback addresses (ports 1234 and 4321). If serverTransportParameters
// has no StatelessResetToken, a zero-value token is assigned to it. Note that
// this modifies the caller's struct.
// It returns both CryptoSetups, along with the leftover events and the errors
// reported by handshake.
func handshakeWithTLSConf(
	t *testing.T,
	clientConf, serverConf *tls.Config,
	clientRTTStats, serverRTTStats *utils.RTTStats,
	clientTransportParameters, serverTransportParameters *wire.TransportParameters,
	enable0RTT bool,
) (CryptoSetup /* client */, []Event /* more client events */, error, /* client error */
	CryptoSetup /* server */, []Event /* more server events */, error, /* server error */
) {
	t.Helper()
	clientSetup := NewCryptoSetupClient(
		protocol.ConnectionID{},
		clientTransportParameters,
		clientConf,
		enable0RTT,
		clientRTTStats,
		nil,
		utils.DefaultLogger.WithPrefix("client"),
		protocol.Version1,
	)

	if serverTransportParameters.StatelessResetToken == nil {
		serverTransportParameters.StatelessResetToken = new(protocol.StatelessResetToken)
	}
	serverSetup := NewCryptoSetupServer(
		protocol.ConnectionID{},
		&net.UDPAddr{IP: net.IPv6loopback, Port: 1234},
		&net.UDPAddr{IP: net.IPv6loopback, Port: 4321},
		serverTransportParameters,
		serverConf,
		enable0RTT,
		serverRTTStats,
		nil,
		utils.DefaultLogger.WithPrefix("server"),
		protocol.Version1,
	)

	clientEvents, clientErr, serverEvents, serverErr := handshake(t, clientSetup, serverSetup)
	return clientSetup, clientEvents, clientErr, serverSetup, serverEvents, serverErr
}
// TestHandshake runs a handshake with the default test configurations and
// expects neither side to report an error.
func TestHandshake(t *testing.T) {
	clientConf, serverConf := getTLSConfigs()
	clientParams := &wire.TransportParameters{ActiveConnectionIDLimit: 2}
	serverParams := &wire.TransportParameters{ActiveConnectionIDLimit: 2}

	_, _, clientErr, _, _, serverErr := handshakeWithTLSConf(
		t,
		clientConf, serverConf,
		utils.NewRTTStats(), utils.NewRTTStats(),
		clientParams, serverParams,
		false,
	)
	require.NoError(t, clientErr)
	require.NoError(t, serverErr)
}
// TestHelloRetryRequest restricts the server to the P-384 curve and expects
// the handshake to succeed. The restriction is meant to force a
// HelloRetryRequest. This assumes the client's initial key share uses a
// different group; the test does not verify that an HRR actually happened.
func TestHelloRetryRequest(t *testing.T) {
	clientConf, serverConf := getTLSConfigs()
	serverConf.CurvePreferences = []tls.CurveID{tls.CurveP384}
	clientParams := &wire.TransportParameters{ActiveConnectionIDLimit: 2}
	serverParams := &wire.TransportParameters{ActiveConnectionIDLimit: 2}

	_, _, clientErr, _, _, serverErr := handshakeWithTLSConf(
		t,
		clientConf, serverConf,
		utils.NewRTTStats(), utils.NewRTTStats(),
		clientParams, serverParams,
		false,
	)
	require.NoError(t, clientErr)
	require.NoError(t, serverErr)
}
// TestWithClientAuth has the client present a freshly generated, self-signed
// Ed25519 certificate to a server that requires a client certificate.
func TestWithClientAuth(t *testing.T) {
	publicKey, privateKey, err := ed25519.GenerateKey(rand.Reader)
	require.NoError(t, err)
	template := &x509.Certificate{
		SerialNumber:          big.NewInt(1),
		Subject:               pkix.Name{},
		SignatureAlgorithm:    x509.PureEd25519,
		NotBefore:             time.Now(),
		NotAfter:              time.Now().Add(time.Hour),
		BasicConstraintsValid: true,
	}
	// Self-signed: the template acts as both the certificate and its issuer.
	der, err := x509.CreateCertificate(rand.Reader, template, template, publicKey, privateKey)
	require.NoError(t, err)

	clientConf, serverConf := getTLSConfigs()
	clientConf.Certificates = []tls.Certificate{{
		PrivateKey:  privateKey,
		Certificate: [][]byte{der},
	}}
	// RequireAnyClientCert demands a certificate without verifying it, so the
	// self-signed certificate is acceptable.
	serverConf.ClientAuth = tls.RequireAnyClientCert

	_, _, clientErr, _, _, serverErr := handshakeWithTLSConf(
		t,
		clientConf, serverConf,
		utils.NewRTTStats(), utils.NewRTTStats(),
		&wire.TransportParameters{ActiveConnectionIDLimit: 2},
		&wire.TransportParameters{ActiveConnectionIDLimit: 2},
		false,
	)
	require.NoError(t, clientErr)
	require.NoError(t, serverErr)
}
// TestTransportParameters checks that each side receives the transport
// parameters the other side sent. Different MaxIdleTimeout values are used to
// tell them apart.
func TestTransportParameters(t *testing.T) {
	clientConf, serverConf := getTLSConfigs()
	var token protocol.StatelessResetToken
	// A reset token is provided, so handshakeWithTLSConf leaves the server
	// parameters untouched.
	_, clientEvents, cErr, _, serverEvents, sErr := handshakeWithTLSConf(
		t,
		clientConf, serverConf,
		utils.NewRTTStats(), utils.NewRTTStats(),
		&wire.TransportParameters{ActiveConnectionIDLimit: 2, MaxIdleTimeout: 42 * time.Second},
		&wire.TransportParameters{
			MaxIdleTimeout:          1337 * time.Second,
			StatelessResetToken:     &token,
			ActiveConnectionIDLimit: 2,
		},
		false,
	)
	require.NoError(t, cErr)
	require.NoError(t, sErr)

	// lastReceived returns the parameters of the last EventReceivedTransportParameters
	// in events, or nil if there is none.
	lastReceived := func(events []Event) *wire.TransportParameters {
		for i := len(events) - 1; i >= 0; i-- {
			if events[i].Kind == EventReceivedTransportParameters {
				return events[i].TransportParameters
			}
		}
		return nil
	}

	fromServer := lastReceived(clientEvents)
	require.NotNil(t, fromServer)
	require.Equal(t, 1337*time.Second, fromServer.MaxIdleTimeout)

	fromClient := lastReceived(serverEvents)
	require.NotNil(t, fromClient)
	require.Equal(t, 42*time.Second, fromClient.MaxIdleTimeout)
}
// TestNewSessionTicketAtWrongEncryptionLevel completes a handshake and then
// hands the client a NewSessionTicket at the Handshake encryption level. The
// TLS stack is expected to reject it, since handshake itself delivers the
// server's ticket at the 1-RTT level.
func TestNewSessionTicketAtWrongEncryptionLevel(t *testing.T) {
	clientConf, serverConf := getTLSConfigs()
	client, _, clientErr, _, _, serverErr := handshakeWithTLSConf(
		t,
		clientConf, serverConf,
		utils.NewRTTStats(), utils.NewRTTStats(),
		&wire.TransportParameters{ActiveConnectionIDLimit: 2},
		&wire.TransportParameters{ActiveConnectionIDLimit: 2},
		false,
	)
	require.NoError(t, clientErr)
	require.NoError(t, serverErr)

	// Message type, 3-byte length (6), then a 6-byte body.
	ticketMsg := append([]byte{typeNewSessionTicket, 0, 0, 6}, "foobar"...)
	err := client.HandleMessage(ticketMsg, protocol.EncryptionHandshake)
	require.ErrorContains(t, err, "tls: handshake data received at wrong level")
}
// TestHandlingNewSessionTicketFails completes a handshake and then hands the
// client a bogus NewSessionTicket at the 1-RTT encryption level. It expects a
// *qerr.TransportError with a crypto error code. The error is matched with
// require.ErrorAs, so wrapping does not break the check, unlike an exact-type
// assertion.
func TestHandlingNewSessionTicketFails(t *testing.T) {
	clientConf, serverConf := getTLSConfigs()
	client, _, clientErr, _, _, serverErr := handshakeWithTLSConf(
		t,
		clientConf, serverConf,
		utils.NewRTTStats(), utils.NewRTTStats(),
		&wire.TransportParameters{ActiveConnectionIDLimit: 2},
		&wire.TransportParameters{ActiveConnectionIDLimit: 2},
		false,
	)
	require.NoError(t, clientErr)
	require.NoError(t, serverErr)

	// Inject an invalid session ticket: correct message type and level, but a
	// meaningless 6-byte body.
	ticketMsg := append([]byte{typeNewSessionTicket, 0, 0, 6}, "foobar"...)
	err := client.HandleMessage(ticketMsg, protocol.Encryption1RTT)

	var transportErr *qerr.TransportError
	require.ErrorAs(t, err, &transportErr)
	require.True(t, transportErr.ErrorCode.IsCryptoError())
}
// TestSessionResumption checks that the client stores the session ticket from
// a first, full handshake. It then checks that a second handshake with the
// same configurations resumes that session on both sides.
func TestSessionResumption(t *testing.T) {
	clientConf, serverConf := getTLSConfigs()
	cache := newMockClientSessionCache()
	clientConf.ClientSessionCache = cache

	// connect runs one handshake with fresh RTT stats and transport parameters.
	connect := func() (CryptoSetup, CryptoSetup) {
		t.Helper()
		client, _, clientErr, server, _, serverErr := handshakeWithTLSConf(
			t,
			clientConf, serverConf,
			utils.NewRTTStats(), utils.NewRTTStats(),
			&wire.TransportParameters{ActiveConnectionIDLimit: 2},
			&wire.TransportParameters{ActiveConnectionIDLimit: 2},
			false,
		)
		require.NoError(t, clientErr)
		require.NoError(t, serverErr)
		return client, server
	}
	// awaitTicket fails the test unless the cache reports a Put within a second.
	awaitTicket := func() {
		t.Helper()
		select {
		case <-cache.puts:
		case <-time.After(time.Second):
			t.Fatal("didn't receive a session ticket")
		}
	}

	client, server := connect()
	awaitTicket()
	require.False(t, server.ConnectionState().DidResume)
	require.False(t, client.ConnectionState().DidResume)

	client, server = connect()
	awaitTicket()
	require.True(t, server.ConnectionState().DidResume)
	require.True(t, client.ConnectionState().DidResume)
}
// TestSessionResumptionDisabled obtains a session ticket from a first
// handshake. It then disables session tickets on the server and checks that
// the second handshake neither resumes nor produces a new ticket.
func TestSessionResumptionDisabled(t *testing.T) {
	clientConf, serverConf := getTLSConfigs()
	cache := newMockClientSessionCache()
	clientConf.ClientSessionCache = cache

	// connect runs one handshake; it picks up later changes to serverConf.
	connect := func() (CryptoSetup, CryptoSetup) {
		t.Helper()
		client, _, clientErr, server, _, serverErr := handshakeWithTLSConf(
			t,
			clientConf, serverConf,
			utils.NewRTTStats(), utils.NewRTTStats(),
			&wire.TransportParameters{ActiveConnectionIDLimit: 2},
			&wire.TransportParameters{ActiveConnectionIDLimit: 2},
			false,
		)
		require.NoError(t, clientErr)
		require.NoError(t, serverErr)
		return client, server
	}

	client, server := connect()
	select {
	case <-cache.puts:
	case <-time.After(time.Second):
		t.Fatal("didn't receive a session ticket")
	}
	require.False(t, server.ConnectionState().DidResume)
	require.False(t, client.ConnectionState().DidResume)

	serverConf.SessionTicketsDisabled = true
	client, server = connect()
	select {
	case <-cache.puts:
		t.Fatal("didn't expect to receive a session ticket")
	case <-time.After(25 * time.Millisecond):
	}
	require.False(t, server.ConnectionState().DidResume)
	require.False(t, client.ConnectionState().DidResume)
}
// Test0RTT obtains a session ticket with 0-RTT enabled and then resumes the
// session. It checks that the client restores the server's remembered
// transport parameters and that both sides report 0-RTT as used.
func Test0RTT(t *testing.T) {
	clientConf, serverConf := getTLSConfigs()
	cache := newMockClientSessionCache()
	clientConf.ClientSessionCache = cache
	const initialMaxData protocol.ByteCount = 1337

	// connect runs one 0-RTT-enabled handshake and returns both sides and their leftover events.
	connect := func() (CryptoSetup, []Event, CryptoSetup, []Event) {
		t.Helper()
		client, clientEvents, clientErr, server, serverEvents, serverErr := handshakeWithTLSConf(
			t,
			clientConf, serverConf,
			utils.NewRTTStats(), utils.NewRTTStats(),
			&wire.TransportParameters{ActiveConnectionIDLimit: 2},
			&wire.TransportParameters{ActiveConnectionIDLimit: 2, InitialMaxData: initialMaxData},
			true,
		)
		require.NoError(t, clientErr)
		require.NoError(t, serverErr)
		return client, clientEvents, server, serverEvents
	}

	client, _, server, _ := connect()
	select {
	case <-cache.puts:
	case <-time.After(time.Second):
		t.Fatal("didn't receive a session ticket")
	}
	require.False(t, server.ConnectionState().DidResume)
	require.False(t, client.ConnectionState().DidResume)

	client, clientEvents, server, serverEvents := connect()

	// NOTE(review): 0-RTT data only flows from client to server, so on the client
	// EventReceivedReadKeys presumably stems from handshake/1-RTT keys, not
	// 0-RTT. Test0RTTRejectionOnTransportParametersChanged expects it even when
	// 0-RTT is rejected.
	var restored *wire.TransportParameters
	var clientGotReadKeys bool
	for _, ev := range clientEvents {
		if ev.Kind == EventRestoredTransportParameters {
			restored = ev.TransportParameters
		} else if ev.Kind == EventReceivedReadKeys {
			clientGotReadKeys = true
		}
	}
	require.True(t, clientGotReadKeys)
	require.NotNil(t, restored)
	require.Equal(t, initialMaxData, restored.InitialMaxData)

	var serverGotReadKeys bool
	for _, ev := range serverEvents {
		if ev.Kind == EventReceivedReadKeys {
			serverGotReadKeys = true
		}
	}
	require.True(t, serverGotReadKeys)

	require.True(t, server.ConnectionState().DidResume)
	require.True(t, client.ConnectionState().DidResume)
	require.True(t, server.ConnectionState().Used0RTT)
	require.True(t, client.ConnectionState().Used0RTT)
}
// Test0RTTRejectionOnTransportParametersChanged obtains a 0-RTT-capable
// session ticket. It then resumes against a server announcing a smaller
// InitialMaxData. The client still restores the old parameters and the
// session is resumed, but neither side reports 0-RTT as used.
func Test0RTTRejectionOnTransportParametersChanged(t *testing.T) {
	clientConf, serverConf := getTLSConfigs()
	cache := newMockClientSessionCache()
	clientConf.ClientSessionCache = cache
	const initialMaxData protocol.ByteCount = 1337

	// connect runs one 0-RTT-enabled handshake with the given server InitialMaxData.
	connect := func(serverMaxData protocol.ByteCount) (CryptoSetup, []Event, CryptoSetup) {
		t.Helper()
		client, clientEvents, clientErr, server, _, serverErr := handshakeWithTLSConf(
			t,
			clientConf, serverConf,
			utils.NewRTTStats(), utils.NewRTTStats(),
			&wire.TransportParameters{ActiveConnectionIDLimit: 2},
			&wire.TransportParameters{ActiveConnectionIDLimit: 2, InitialMaxData: serverMaxData},
			true,
		)
		require.NoError(t, clientErr)
		require.NoError(t, serverErr)
		return client, clientEvents, server
	}

	client, _, server := connect(initialMaxData)
	select {
	case <-cache.puts:
	case <-time.After(time.Second):
		t.Fatal("didn't receive a session ticket")
	}
	require.False(t, server.ConnectionState().DidResume)
	require.False(t, client.ConnectionState().DidResume)

	client, clientEvents, server := connect(initialMaxData - 1)

	var restored *wire.TransportParameters
	var clientGotReadKeys bool
	for _, ev := range clientEvents {
		if ev.Kind == EventRestoredTransportParameters {
			restored = ev.TransportParameters
		} else if ev.Kind == EventReceivedReadKeys {
			clientGotReadKeys = true
		}
	}
	require.True(t, clientGotReadKeys)
	// The restored parameters are the ones remembered from the first connection.
	require.NotNil(t, restored)
	require.Equal(t, initialMaxData, restored.InitialMaxData)

	require.True(t, server.ConnectionState().DidResume)
	require.True(t, client.ConnectionState().DidResume)
	require.False(t, server.ConnectionState().Used0RTT)
	require.False(t, client.ConnectionState().Used0RTT)
}