forked from quic-go/quic-go
* create a new type for crypto stream used for Initial data This currently the exact same implementation as the other streams, thus no functional change is expected. * handshake: implement a function to find the SNI and the ECH extension * move the SNI parsing logic to the quic package * implement splitting logic * generalize cutting logic * introduce QUIC_GO_DISABLE_CLIENTHELLO_SCRAMBLING * improve testing
101 lines
2.7 KiB
Go
101 lines
2.7 KiB
Go
//go:build go1.24
|
|
|
|
package quic
|
|
|
|
import (
|
|
"fmt"
|
|
mrand "math/rand/v2"
|
|
"slices"
|
|
"strings"
|
|
"testing"
|
|
|
|
"github.com/quic-go/quic-go/internal/protocol"
|
|
"github.com/quic-go/quic-go/internal/wire"
|
|
"github.com/stretchr/testify/require"
|
|
)
|
|
|
|
func randomDomainName(length int) string {
|
|
const alphabet = "abcdefghijklmnopqrstuvwxyz"
|
|
b := make([]byte, length)
|
|
for i := range b {
|
|
if i > 0 && i < length-1 && mrand.IntN(5) == 0 && b[i-1] != '.' {
|
|
b[i] = '.'
|
|
} else {
|
|
b[i] = alphabet[mrand.IntN(len(alphabet))]
|
|
}
|
|
}
|
|
return string(b)
|
|
}
|
|
|
|
func TestInitialCryptoStreamClientRandomizedSizes(t *testing.T) {
|
|
skipIfDisableScramblingEnvSet(t)
|
|
|
|
for i := range 100 {
|
|
t.Run(fmt.Sprintf("run %d", i), func(t *testing.T) {
|
|
var serverName string
|
|
if mrand.Int()%4 > 0 {
|
|
serverName = randomDomainName(6 + mrand.IntN(20))
|
|
}
|
|
var clientHello []byte
|
|
if serverName == "" || !strings.Contains(serverName, ".") || mrand.Int()%2 == 0 {
|
|
t.Logf("using a ClientHello without ECH, hostname: %q", serverName)
|
|
clientHello = getClientHello(t, serverName)
|
|
} else {
|
|
t.Logf("using a ClientHello with ECH, hostname: %q", serverName)
|
|
clientHello = getClientHelloWithECH(t, serverName)
|
|
}
|
|
testInitialCryptoStreamClientRandomizedSizes(t, clientHello, serverName)
|
|
})
|
|
}
|
|
}
|
|
|
|
func testInitialCryptoStreamClientRandomizedSizes(t *testing.T, clientHello []byte, expectedServerName string) {
|
|
str := newInitialCryptoStream(true)
|
|
|
|
b := slices.Clone(clientHello)
|
|
for len(b) > 0 {
|
|
n := min(len(b), mrand.IntN(2*len(b)))
|
|
_, err := str.Write(b[:n])
|
|
require.NoError(t, err)
|
|
b = b[n:]
|
|
}
|
|
|
|
require.True(t, str.HasData())
|
|
_, err := str.Write([]byte("foobar"))
|
|
require.NoError(t, err)
|
|
|
|
segments := make(map[protocol.ByteCount][]byte)
|
|
|
|
var frames []*wire.CryptoFrame
|
|
for str.HasData() {
|
|
// fmt.Println("popping a frame")
|
|
var maxSize protocol.ByteCount
|
|
if mrand.Int()%4 == 0 {
|
|
maxSize = protocol.ByteCount(mrand.IntN(512) + 1)
|
|
} else {
|
|
maxSize = protocol.ByteCount(mrand.IntN(32) + 1)
|
|
}
|
|
f := str.PopCryptoFrame(maxSize)
|
|
if f == nil {
|
|
continue
|
|
}
|
|
frames = append(frames, f)
|
|
require.LessOrEqual(t, f.Length(protocol.Version1), maxSize)
|
|
}
|
|
t.Logf("received %d frames", len(frames))
|
|
|
|
for _, f := range frames {
|
|
t.Logf("offset %d: %d bytes", f.Offset, len(f.Data))
|
|
if expectedServerName != "" {
|
|
require.NotContainsf(t, string(f.Data), expectedServerName, "frame at offset %d contains the server name", f.Offset)
|
|
}
|
|
segments[f.Offset] = f.Data
|
|
}
|
|
|
|
reassembled := reassembleCryptoData(t, segments)
|
|
require.Equal(t, append(clientHello, []byte("foobar")...), reassembled)
|
|
if expectedServerName != "" {
|
|
require.Contains(t, string(reassembled), expectedServerName)
|
|
}
|
|
}
|