protocol: use math/rand/v2 to generate greased versions (#5046)

This commit is contained in:
Marten Seemann
2025-04-14 20:30:04 +08:00
committed by GitHub
parent e3f1b7c410
commit 0894931c64
2 changed files with 32 additions and 12 deletions

View File

@@ -1,13 +1,12 @@
package protocol
import (
"crypto/rand"
"encoding/binary"
"fmt"
"math"
mrand "math/rand/v2"
"sync"
"time"
"golang.org/x/exp/rand"
)
// Version is a version number as int
@@ -90,13 +89,22 @@ func ChooseSupportedVersion(ours, theirs []Version) (Version, bool) {
var (
versionNegotiationMx sync.Mutex
versionNegotiationRand = rand.New(rand.NewSource(uint64(time.Now().UnixNano())))
versionNegotiationRand mrand.Rand
)
func init() {
var seed [16]byte
rand.Read(seed[:])
versionNegotiationRand = *mrand.New(mrand.NewPCG(
binary.BigEndian.Uint64(seed[:8]),
binary.BigEndian.Uint64(seed[8:]),
))
}
// generateReservedVersion generates a reserved version (v & 0x0f0f0f0f == 0x0a0a0a0a)
func generateReservedVersion() Version {
var b [4]byte
_, _ = versionNegotiationRand.Read(b[:]) // ignore the error here. Failure to read random data doesn't break anything
binary.BigEndian.PutUint32(b[:], versionNegotiationRand.Uint32())
return Version((binary.BigEndian.Uint32(b[:]) | 0x0a0a0a0a) & 0xfafafafa)
}
@@ -105,7 +113,7 @@ func generateReservedVersion() Version {
func GetGreasedVersions(supported []Version) []Version {
versionNegotiationMx.Lock()
defer versionNegotiationMx.Unlock()
randPos := rand.Intn(len(supported) + 1)
randPos := versionNegotiationRand.IntN(len(supported) + 1)
greased := make([]Version, len(supported)+1)
copy(greased, supported[:randPos])
greased[randPos] = generateReservedVersion()

View File

@@ -1,6 +1,7 @@
package protocol
import (
"slices"
"testing"
"github.com/stretchr/testify/require"
@@ -95,21 +96,33 @@ func TestVersionSelection(t *testing.T) {
func isReservedVersion(v Version) bool { return v&0x0f0f0f0f == 0x0a0a0a0a }
func TestAddGreasedVersionToEmptySlice(t *testing.T) {
func TestVersionGreasing(t *testing.T) {
// adding to an empty slice
greased := GetGreasedVersions([]Version{})
require.Len(t, greased, 1)
require.True(t, isReservedVersion(greased[0]))
}
func TestAddGreasedVersion(t *testing.T) {
// make sure that the greased versions are distinct
var versions []Version
for range 20 {
versions = GetGreasedVersions(versions)
}
slices.Sort(versions)
for i, v := range versions {
require.True(t, isReservedVersion(v))
if i > 0 {
require.NotEqual(t, versions[i-1], v)
}
}
// adding it somewhere in a slice of supported versions
supported := []Version{10, 18, 29}
for _, v := range supported {
require.False(t, isReservedVersion(v))
}
var greasedVersionFirst, greasedVersionLast, greasedVersionMiddle int
for i := 0; i < 100; i++ {
for range 100 {
greased := GetGreasedVersions(supported)
require.Len(t, greased, 4)
@@ -129,7 +142,6 @@ func TestAddGreasedVersion(t *testing.T) {
j++
}
}
require.NotZero(t, greasedVersionFirst)
require.NotZero(t, greasedVersionLast)
require.NotZero(t, greasedVersionMiddle)