diff --git a/integrationtests/self/key_update_test.go b/integrationtests/self/key_update_test.go index ae72299e..d70c1a8e 100644 --- a/integrationtests/self/key_update_test.go +++ b/integrationtests/self/key_update_test.go @@ -7,7 +7,6 @@ import ( "time" "github.com/quic-go/quic-go" - "github.com/quic-go/quic-go/internal/handshake" "github.com/quic-go/quic-go/internal/protocol" "github.com/quic-go/quic-go/logging" @@ -15,9 +14,7 @@ import ( ) func TestKeyUpdates(t *testing.T) { - origKeyUpdateInterval := handshake.KeyUpdateInterval - t.Cleanup(func() { handshake.KeyUpdateInterval = origKeyUpdateInterval }) - handshake.KeyUpdateInterval = 1 // update keys as frequently as possible + t.Setenv("QUIC_GO_TEST_KEY_UPDATE_INTERVAL", "1") // update keys as frequently as possible var sentHeaders []*logging.ShortHeader var receivedHeaders []*logging.ShortHeader diff --git a/integrationtests/self/self_test.go b/integrationtests/self/self_test.go index 7516cafc..6df7863b 100644 --- a/integrationtests/self/self_test.go +++ b/integrationtests/self/self_test.go @@ -1,7 +1,6 @@ package self_test import ( - "bytes" "context" "crypto/tls" "crypto/x509" @@ -11,9 +10,7 @@ import ( "math/rand/v2" "net" "os" - "runtime/pprof" "strconv" - "strings" "testing" "time" @@ -161,12 +158,6 @@ func newUDPConnLocalhost(t testing.TB) *net.UDPConn { return conn } -func areTransportsRunning() bool { - var b bytes.Buffer - pprof.Lookup("goroutine").WriteTo(&b, 1) - return strings.Contains(b.String(), "quic-go.(*Transport).listen") -} - func TestMain(m *testing.M) { var versionParam string flag.StringVar(&versionParam, "version", "1", "QUIC version") @@ -184,15 +175,7 @@ func TestMain(m *testing.M) { } fmt.Printf("using QUIC version: %s\n", version) - status := m.Run() - if status != 0 { - os.Exit(status) - } - if areTransportsRunning() { - fmt.Println("stray transport goroutines found") - os.Exit(1) - } - os.Exit(status) + os.Exit(m.Run()) } func scaleDuration(d time.Duration) time.Duration { diff --git a/internal/handshake/updatable_aead.go b/internal/handshake/updatable_aead.go index ceaa8047..0aca46fe 100644 --- a/internal/handshake/updatable_aead.go +++ b/internal/handshake/updatable_aead.go @@ -6,6 +6,9 @@ import ( "crypto/tls" "encoding/binary" "fmt" + "os" + "strconv" + "testing" "time" "github.com/quic-go/quic-go/internal/protocol" @@ -14,9 +17,15 @@ import ( "github.com/quic-go/quic-go/logging" ) -// KeyUpdateInterval is the maximum number of packets we send or receive before initiating a key update. -// It's a package-level variable to allow modifying it for testing purposes. -var KeyUpdateInterval uint64 = protocol.KeyUpdateInterval +func keyUpdateInterval() uint64 { + // Reparsing the environment variable is not very performant, but it's only done in tests. + if testing.Testing() { + if v, err := strconv.ParseUint(os.Getenv("QUIC_GO_TEST_KEY_UPDATE_INTERVAL"), 10, 64); err == nil { + return v + } + } + return protocol.KeyUpdateInterval +} // FirstKeyUpdateInterval is the maximum number of packets we send or receive before initiating the first key update. // It's a package-level variable to allow modifying it for testing purposes. @@ -293,11 +302,11 @@ func (a *updatableAEAD) shouldInitiateKeyUpdate() bool { return true } } - if a.numRcvdWithCurrentKey >= KeyUpdateInterval { + if a.numRcvdWithCurrentKey >= keyUpdateInterval() { a.logger.Debugf("Received %d packets with current key phase. Initiating key update to the next key phase: %d", a.numRcvdWithCurrentKey, a.keyPhase+1) return true } - if a.numSentWithCurrentKey >= KeyUpdateInterval { + if a.numSentWithCurrentKey >= keyUpdateInterval() { a.logger.Debugf("Sent %d packets with current key phase. Initiating key update to the next key phase: %d", a.numSentWithCurrentKey, a.keyPhase+1) return true } diff --git a/internal/handshake/updatable_aead_test.go b/internal/handshake/updatable_aead_test.go index 2b4cefa9..36e0638e 100644 --- a/internal/handshake/updatable_aead_test.go +++ b/internal/handshake/updatable_aead_test.go @@ -330,15 +330,12 @@ func TestRejectFrequentKeyUpdates(t *testing.T) { } func setKeyUpdateIntervals(t *testing.T, firstKeyUpdateInterval, keyUpdateInterval uint64) { - origKeyUpdateInterval := KeyUpdateInterval + t.Setenv("QUIC_GO_TEST_KEY_UPDATE_INTERVAL", fmt.Sprintf("%d", keyUpdateInterval)) + origFirstKeyUpdateInterval := FirstKeyUpdateInterval - KeyUpdateInterval = keyUpdateInterval FirstKeyUpdateInterval = firstKeyUpdateInterval - t.Cleanup(func() { - KeyUpdateInterval = origKeyUpdateInterval - FirstKeyUpdateInterval = origFirstKeyUpdateInterval - }) + t.Cleanup(func() { FirstKeyUpdateInterval = origFirstKeyUpdateInterval }) } func TestInitiateKeyUpdateAfterSendingMaxPackets(t *testing.T) { @@ -382,7 +379,7 @@ func TestInitiateKeyUpdateAfterSendingMaxPackets(t *testing.T) { func TestKeyUpdateEnforceACKKeyPhase(t *testing.T) { const firstKeyUpdateInterval = 5 - setKeyUpdateIntervals(t, firstKeyUpdateInterval, KeyUpdateInterval) + setKeyUpdateIntervals(t, firstKeyUpdateInterval, protocol.KeyUpdateInterval) _, server, serverTracer := setupEndpoints(t, &utils.RTTStats{}) server.SetHandshakeConfirmed()