forked from quic-go/quic-go
fix race condition in tests when setting the key update interval (#5121)
* fix race conditions in tests when setting the key update interval * remove test for running transports and handshakes
This commit is contained in:
@@ -7,7 +7,6 @@ import (
|
|||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/quic-go/quic-go"
|
"github.com/quic-go/quic-go"
|
||||||
"github.com/quic-go/quic-go/internal/handshake"
|
|
||||||
"github.com/quic-go/quic-go/internal/protocol"
|
"github.com/quic-go/quic-go/internal/protocol"
|
||||||
"github.com/quic-go/quic-go/logging"
|
"github.com/quic-go/quic-go/logging"
|
||||||
|
|
||||||
@@ -15,9 +14,7 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
func TestKeyUpdates(t *testing.T) {
|
func TestKeyUpdates(t *testing.T) {
|
||||||
origKeyUpdateInterval := handshake.KeyUpdateInterval
|
t.Setenv("QUIC_GO_TEST_KEY_UPDATE_INTERVAL", "1") // update keys as frequently as possible
|
||||||
t.Cleanup(func() { handshake.KeyUpdateInterval = origKeyUpdateInterval })
|
|
||||||
handshake.KeyUpdateInterval = 1 // update keys as frequently as possible
|
|
||||||
|
|
||||||
var sentHeaders []*logging.ShortHeader
|
var sentHeaders []*logging.ShortHeader
|
||||||
var receivedHeaders []*logging.ShortHeader
|
var receivedHeaders []*logging.ShortHeader
|
||||||
|
|||||||
@@ -1,7 +1,6 @@
|
|||||||
package self_test
|
package self_test
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"bytes"
|
|
||||||
"context"
|
"context"
|
||||||
"crypto/tls"
|
"crypto/tls"
|
||||||
"crypto/x509"
|
"crypto/x509"
|
||||||
@@ -11,9 +10,7 @@ import (
|
|||||||
"math/rand/v2"
|
"math/rand/v2"
|
||||||
"net"
|
"net"
|
||||||
"os"
|
"os"
|
||||||
"runtime/pprof"
|
|
||||||
"strconv"
|
"strconv"
|
||||||
"strings"
|
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
@@ -161,12 +158,6 @@ func newUDPConnLocalhost(t testing.TB) *net.UDPConn {
|
|||||||
return conn
|
return conn
|
||||||
}
|
}
|
||||||
|
|
||||||
func areTransportsRunning() bool {
|
|
||||||
var b bytes.Buffer
|
|
||||||
pprof.Lookup("goroutine").WriteTo(&b, 1)
|
|
||||||
return strings.Contains(b.String(), "quic-go.(*Transport).listen")
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestMain(m *testing.M) {
|
func TestMain(m *testing.M) {
|
||||||
var versionParam string
|
var versionParam string
|
||||||
flag.StringVar(&versionParam, "version", "1", "QUIC version")
|
flag.StringVar(&versionParam, "version", "1", "QUIC version")
|
||||||
@@ -184,15 +175,7 @@ func TestMain(m *testing.M) {
|
|||||||
}
|
}
|
||||||
fmt.Printf("using QUIC version: %s\n", version)
|
fmt.Printf("using QUIC version: %s\n", version)
|
||||||
|
|
||||||
status := m.Run()
|
os.Exit(m.Run())
|
||||||
if status != 0 {
|
|
||||||
os.Exit(status)
|
|
||||||
}
|
|
||||||
if areTransportsRunning() {
|
|
||||||
fmt.Println("stray transport goroutines found")
|
|
||||||
os.Exit(1)
|
|
||||||
}
|
|
||||||
os.Exit(status)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func scaleDuration(d time.Duration) time.Duration {
|
func scaleDuration(d time.Duration) time.Duration {
|
||||||
|
|||||||
@@ -6,6 +6,9 @@ import (
|
|||||||
"crypto/tls"
|
"crypto/tls"
|
||||||
"encoding/binary"
|
"encoding/binary"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"os"
|
||||||
|
"strconv"
|
||||||
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/quic-go/quic-go/internal/protocol"
|
"github.com/quic-go/quic-go/internal/protocol"
|
||||||
@@ -14,9 +17,15 @@ import (
|
|||||||
"github.com/quic-go/quic-go/logging"
|
"github.com/quic-go/quic-go/logging"
|
||||||
)
|
)
|
||||||
|
|
||||||
// KeyUpdateInterval is the maximum number of packets we send or receive before initiating a key update.
|
func keyUpdateInterval() uint64 {
|
||||||
// It's a package-level variable to allow modifying it for testing purposes.
|
// Reparsing the environment variable is not very performant, but it's only done in tests.
|
||||||
var KeyUpdateInterval uint64 = protocol.KeyUpdateInterval
|
if testing.Testing() {
|
||||||
|
if v, err := strconv.ParseUint(os.Getenv("QUIC_GO_TEST_KEY_UPDATE_INTERVAL"), 10, 64); err == nil {
|
||||||
|
return v
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return protocol.KeyUpdateInterval
|
||||||
|
}
|
||||||
|
|
||||||
// FirstKeyUpdateInterval is the maximum number of packets we send or receive before initiating the first key update.
|
// FirstKeyUpdateInterval is the maximum number of packets we send or receive before initiating the first key update.
|
||||||
// It's a package-level variable to allow modifying it for testing purposes.
|
// It's a package-level variable to allow modifying it for testing purposes.
|
||||||
@@ -293,11 +302,11 @@ func (a *updatableAEAD) shouldInitiateKeyUpdate() bool {
|
|||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if a.numRcvdWithCurrentKey >= KeyUpdateInterval {
|
if a.numRcvdWithCurrentKey >= keyUpdateInterval() {
|
||||||
a.logger.Debugf("Received %d packets with current key phase. Initiating key update to the next key phase: %d", a.numRcvdWithCurrentKey, a.keyPhase+1)
|
a.logger.Debugf("Received %d packets with current key phase. Initiating key update to the next key phase: %d", a.numRcvdWithCurrentKey, a.keyPhase+1)
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
if a.numSentWithCurrentKey >= KeyUpdateInterval {
|
if a.numSentWithCurrentKey >= keyUpdateInterval() {
|
||||||
a.logger.Debugf("Sent %d packets with current key phase. Initiating key update to the next key phase: %d", a.numSentWithCurrentKey, a.keyPhase+1)
|
a.logger.Debugf("Sent %d packets with current key phase. Initiating key update to the next key phase: %d", a.numSentWithCurrentKey, a.keyPhase+1)
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -330,15 +330,12 @@ func TestRejectFrequentKeyUpdates(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func setKeyUpdateIntervals(t *testing.T, firstKeyUpdateInterval, keyUpdateInterval uint64) {
|
func setKeyUpdateIntervals(t *testing.T, firstKeyUpdateInterval, keyUpdateInterval uint64) {
|
||||||
origKeyUpdateInterval := KeyUpdateInterval
|
t.Setenv("QUIC_GO_TEST_KEY_UPDATE_INTERVAL", fmt.Sprintf("%d", keyUpdateInterval))
|
||||||
|
|
||||||
origFirstKeyUpdateInterval := FirstKeyUpdateInterval
|
origFirstKeyUpdateInterval := FirstKeyUpdateInterval
|
||||||
KeyUpdateInterval = keyUpdateInterval
|
|
||||||
FirstKeyUpdateInterval = firstKeyUpdateInterval
|
FirstKeyUpdateInterval = firstKeyUpdateInterval
|
||||||
|
|
||||||
t.Cleanup(func() {
|
t.Cleanup(func() { FirstKeyUpdateInterval = origFirstKeyUpdateInterval })
|
||||||
KeyUpdateInterval = origKeyUpdateInterval
|
|
||||||
FirstKeyUpdateInterval = origFirstKeyUpdateInterval
|
|
||||||
})
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestInitiateKeyUpdateAfterSendingMaxPackets(t *testing.T) {
|
func TestInitiateKeyUpdateAfterSendingMaxPackets(t *testing.T) {
|
||||||
@@ -382,7 +379,7 @@ func TestInitiateKeyUpdateAfterSendingMaxPackets(t *testing.T) {
|
|||||||
|
|
||||||
func TestKeyUpdateEnforceACKKeyPhase(t *testing.T) {
|
func TestKeyUpdateEnforceACKKeyPhase(t *testing.T) {
|
||||||
const firstKeyUpdateInterval = 5
|
const firstKeyUpdateInterval = 5
|
||||||
setKeyUpdateIntervals(t, firstKeyUpdateInterval, KeyUpdateInterval)
|
setKeyUpdateIntervals(t, firstKeyUpdateInterval, protocol.KeyUpdateInterval)
|
||||||
|
|
||||||
_, server, serverTracer := setupEndpoints(t, &utils.RTTStats{})
|
_, server, serverTracer := setupEndpoints(t, &utils.RTTStats{})
|
||||||
server.SetHandshakeConfirmed()
|
server.SetHandshakeConfirmed()
|
||||||
|
|||||||
Reference in New Issue
Block a user