fix race condition in tests when setting the key update interval (#5121)

* fix race conditions in tests when setting the key update interval

* remove test for running transports and handshakes
This commit is contained in:
Marten Seemann
2025-05-06 10:45:27 +08:00
committed by GitHub
parent 4c6aca6b43
commit cfc6c16f36
4 changed files with 20 additions and 34 deletions

View File

@@ -7,7 +7,6 @@ import (
"time" "time"
"github.com/quic-go/quic-go" "github.com/quic-go/quic-go"
"github.com/quic-go/quic-go/internal/handshake"
"github.com/quic-go/quic-go/internal/protocol" "github.com/quic-go/quic-go/internal/protocol"
"github.com/quic-go/quic-go/logging" "github.com/quic-go/quic-go/logging"
@@ -15,9 +14,7 @@ import (
) )
func TestKeyUpdates(t *testing.T) { func TestKeyUpdates(t *testing.T) {
origKeyUpdateInterval := handshake.KeyUpdateInterval t.Setenv("QUIC_GO_TEST_KEY_UPDATE_INTERVAL", "1") // update keys as frequently as possible
t.Cleanup(func() { handshake.KeyUpdateInterval = origKeyUpdateInterval })
handshake.KeyUpdateInterval = 1 // update keys as frequently as possible
var sentHeaders []*logging.ShortHeader var sentHeaders []*logging.ShortHeader
var receivedHeaders []*logging.ShortHeader var receivedHeaders []*logging.ShortHeader

View File

@@ -1,7 +1,6 @@
package self_test package self_test
import ( import (
"bytes"
"context" "context"
"crypto/tls" "crypto/tls"
"crypto/x509" "crypto/x509"
@@ -11,9 +10,7 @@ import (
"math/rand/v2" "math/rand/v2"
"net" "net"
"os" "os"
"runtime/pprof"
"strconv" "strconv"
"strings"
"testing" "testing"
"time" "time"
@@ -161,12 +158,6 @@ func newUDPConnLocalhost(t testing.TB) *net.UDPConn {
return conn return conn
} }
func areTransportsRunning() bool {
var b bytes.Buffer
pprof.Lookup("goroutine").WriteTo(&b, 1)
return strings.Contains(b.String(), "quic-go.(*Transport).listen")
}
func TestMain(m *testing.M) { func TestMain(m *testing.M) {
var versionParam string var versionParam string
flag.StringVar(&versionParam, "version", "1", "QUIC version") flag.StringVar(&versionParam, "version", "1", "QUIC version")
@@ -184,15 +175,7 @@ func TestMain(m *testing.M) {
} }
fmt.Printf("using QUIC version: %s\n", version) fmt.Printf("using QUIC version: %s\n", version)
status := m.Run() os.Exit(m.Run())
if status != 0 {
os.Exit(status)
}
if areTransportsRunning() {
fmt.Println("stray transport goroutines found")
os.Exit(1)
}
os.Exit(status)
} }
func scaleDuration(d time.Duration) time.Duration { func scaleDuration(d time.Duration) time.Duration {

View File

@@ -6,6 +6,9 @@ import (
"crypto/tls" "crypto/tls"
"encoding/binary" "encoding/binary"
"fmt" "fmt"
"os"
"strconv"
"testing"
"time" "time"
"github.com/quic-go/quic-go/internal/protocol" "github.com/quic-go/quic-go/internal/protocol"
@@ -14,9 +17,15 @@ import (
"github.com/quic-go/quic-go/logging" "github.com/quic-go/quic-go/logging"
) )
// KeyUpdateInterval is the maximum number of packets we send or receive before initiating a key update. func keyUpdateInterval() uint64 {
// It's a package-level variable to allow modifying it for testing purposes. // Reparsing the environment variable is not very performant, but it's only done in tests.
var KeyUpdateInterval uint64 = protocol.KeyUpdateInterval if testing.Testing() {
if v, err := strconv.ParseUint(os.Getenv("QUIC_GO_TEST_KEY_UPDATE_INTERVAL"), 10, 64); err == nil {
return v
}
}
return protocol.KeyUpdateInterval
}
// FirstKeyUpdateInterval is the maximum number of packets we send or receive before initiating the first key update. // FirstKeyUpdateInterval is the maximum number of packets we send or receive before initiating the first key update.
// It's a package-level variable to allow modifying it for testing purposes. // It's a package-level variable to allow modifying it for testing purposes.
@@ -293,11 +302,11 @@ func (a *updatableAEAD) shouldInitiateKeyUpdate() bool {
return true return true
} }
} }
if a.numRcvdWithCurrentKey >= KeyUpdateInterval { if a.numRcvdWithCurrentKey >= keyUpdateInterval() {
a.logger.Debugf("Received %d packets with current key phase. Initiating key update to the next key phase: %d", a.numRcvdWithCurrentKey, a.keyPhase+1) a.logger.Debugf("Received %d packets with current key phase. Initiating key update to the next key phase: %d", a.numRcvdWithCurrentKey, a.keyPhase+1)
return true return true
} }
if a.numSentWithCurrentKey >= KeyUpdateInterval { if a.numSentWithCurrentKey >= keyUpdateInterval() {
a.logger.Debugf("Sent %d packets with current key phase. Initiating key update to the next key phase: %d", a.numSentWithCurrentKey, a.keyPhase+1) a.logger.Debugf("Sent %d packets with current key phase. Initiating key update to the next key phase: %d", a.numSentWithCurrentKey, a.keyPhase+1)
return true return true
} }

View File

@@ -330,15 +330,12 @@ func TestRejectFrequentKeyUpdates(t *testing.T) {
} }
func setKeyUpdateIntervals(t *testing.T, firstKeyUpdateInterval, keyUpdateInterval uint64) { func setKeyUpdateIntervals(t *testing.T, firstKeyUpdateInterval, keyUpdateInterval uint64) {
origKeyUpdateInterval := KeyUpdateInterval t.Setenv("QUIC_GO_TEST_KEY_UPDATE_INTERVAL", fmt.Sprintf("%d", keyUpdateInterval))
origFirstKeyUpdateInterval := FirstKeyUpdateInterval origFirstKeyUpdateInterval := FirstKeyUpdateInterval
KeyUpdateInterval = keyUpdateInterval
FirstKeyUpdateInterval = firstKeyUpdateInterval FirstKeyUpdateInterval = firstKeyUpdateInterval
t.Cleanup(func() { t.Cleanup(func() { FirstKeyUpdateInterval = origFirstKeyUpdateInterval })
KeyUpdateInterval = origKeyUpdateInterval
FirstKeyUpdateInterval = origFirstKeyUpdateInterval
})
} }
func TestInitiateKeyUpdateAfterSendingMaxPackets(t *testing.T) { func TestInitiateKeyUpdateAfterSendingMaxPackets(t *testing.T) {
@@ -382,7 +379,7 @@ func TestInitiateKeyUpdateAfterSendingMaxPackets(t *testing.T) {
func TestKeyUpdateEnforceACKKeyPhase(t *testing.T) { func TestKeyUpdateEnforceACKKeyPhase(t *testing.T) {
const firstKeyUpdateInterval = 5 const firstKeyUpdateInterval = 5
setKeyUpdateIntervals(t, firstKeyUpdateInterval, KeyUpdateInterval) setKeyUpdateIntervals(t, firstKeyUpdateInterval, protocol.KeyUpdateInterval)
_, server, serverTracer := setupEndpoints(t, &utils.RTTStats{}) _, server, serverTracer := setupEndpoints(t, &utils.RTTStats{})
server.SetHandshakeConfirmed() server.SetHandshakeConfirmed()