forked from quic-go/quic-go
fix flaky TestTransportReplaceWithClosed (#5245)
This commit is contained in:
@@ -5,8 +5,10 @@ import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"errors"
|
||||
"math"
|
||||
"net"
|
||||
"os"
|
||||
"runtime"
|
||||
"sync/atomic"
|
||||
"syscall"
|
||||
"testing"
|
||||
"time"
|
||||
@@ -598,6 +600,8 @@ func TestTransportDialingVersionNegotiation(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestTransportReplaceWithClosed(t *testing.T) {
|
||||
t.Setenv("QUIC_GO_DISABLE_RECEIVE_BUFFER_WARNING", "true")
|
||||
|
||||
t.Run("local", func(t *testing.T) {
|
||||
testTransportReplaceWithClosed(t, true)
|
||||
})
|
||||
@@ -616,7 +620,7 @@ func testTransportReplaceWithClosed(t *testing.T, local bool) {
|
||||
tr.init(true)
|
||||
defer tr.Close()
|
||||
|
||||
dur := scaleDuration(10 * time.Millisecond)
|
||||
dur := scaleDuration(20 * time.Millisecond)
|
||||
|
||||
var closePacket []byte
|
||||
if local {
|
||||
@@ -627,7 +631,6 @@ func testTransportReplaceWithClosed(t *testing.T, local bool) {
|
||||
connID := protocol.ParseConnectionID([]byte{4, 3, 2, 1})
|
||||
m := (*packetHandlerMap)(tr)
|
||||
require.True(t, m.Add(connID, handler))
|
||||
start := time.Now()
|
||||
m.ReplaceWithClosed([]protocol.ConnectionID{connID}, closePacket, dur)
|
||||
|
||||
p := make([]byte, 100)
|
||||
@@ -635,38 +638,64 @@ func testTransportReplaceWithClosed(t *testing.T, local bool) {
|
||||
copy(p[1:], connID.Bytes())
|
||||
|
||||
conn := newUDPConnLocalhost(t)
|
||||
var sent int
|
||||
for now := range time.NewTicker(dur / 20).C {
|
||||
_, err := conn.WriteTo(p, tr.Conn.LocalAddr())
|
||||
require.NoError(t, err)
|
||||
sent++
|
||||
if now.After(start.Add(dur / 2)) {
|
||||
break
|
||||
var sent atomic.Int64
|
||||
errChan := make(chan error, 1)
|
||||
stopSending := make(chan struct{})
|
||||
go func() {
|
||||
defer close(errChan)
|
||||
ticker := time.NewTicker(dur / 50)
|
||||
timeout := time.NewTimer(scaleDuration(time.Second))
|
||||
for {
|
||||
select {
|
||||
case <-stopSending:
|
||||
return
|
||||
case <-timeout.C:
|
||||
errChan <- errors.New("timeout")
|
||||
return
|
||||
case <-ticker.C:
|
||||
}
|
||||
if _, err := conn.WriteTo(p, tr.Conn.LocalAddr()); err != nil {
|
||||
errChan <- err
|
||||
return
|
||||
}
|
||||
sent.Add(1)
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
// For locally closed connections, CONNECTION_CLOSE packets are sent with an exponential backoff
|
||||
for i := 0; i*i < sent; i++ {
|
||||
conn.SetReadDeadline(time.Now().Add(time.Second))
|
||||
b := make([]byte, 100)
|
||||
if local {
|
||||
n, _, err := conn.ReadFrom(b)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, []byte("foobar"), b[:n])
|
||||
}
|
||||
}
|
||||
// Afterwards, we receive a stateless reset, not a copy of the CONNECTION_CLOSE packet.
|
||||
// Retry a few times, since the connection is deleted from the map on a timer.
|
||||
require.Eventually(t, func() bool {
|
||||
_, err := conn.WriteTo(p, tr.Conn.LocalAddr())
|
||||
require.NoError(t, err)
|
||||
conn.SetReadDeadline(time.Now().Add(dur / 4))
|
||||
var received int
|
||||
conn.SetReadDeadline(time.Now().Add(scaleDuration(time.Second)))
|
||||
for {
|
||||
b := make([]byte, 100)
|
||||
n, _, err := conn.ReadFrom(b)
|
||||
if errors.Is(err, os.ErrDeadlineExceeded) || bytes.Equal(b[:n], []byte("foobar")) {
|
||||
return false
|
||||
}
|
||||
require.NoError(t, err)
|
||||
require.GreaterOrEqual(t, n, protocol.MinStatelessResetSize)
|
||||
return true
|
||||
}, scaleDuration(200*time.Millisecond), scaleDuration(10*time.Millisecond))
|
||||
// at some point, the connection is cleaned up, and we'll receive a stateless reset
|
||||
if !bytes.Equal(b[:n], []byte("foobar")) {
|
||||
require.GreaterOrEqual(t, n, protocol.MinStatelessResetSize)
|
||||
close(stopSending) // stop sending packets
|
||||
break
|
||||
}
|
||||
received++
|
||||
}
|
||||
|
||||
select {
|
||||
case err := <-errChan:
|
||||
require.NoError(t, err)
|
||||
case <-time.After(time.Second):
|
||||
t.Fatal("timeout")
|
||||
}
|
||||
|
||||
numSent := sent.Load()
|
||||
if !local {
|
||||
require.Zero(t, received)
|
||||
t.Logf("sent %d packets", numSent)
|
||||
return
|
||||
}
|
||||
t.Logf("sent %d packets, received %d CONNECTION_CLOSE copies", numSent, received)
|
||||
// timer resolution on Windows is terrible
|
||||
if runtime.GOOS != "windows" {
|
||||
require.GreaterOrEqual(t, numSent, int64(8))
|
||||
}
|
||||
require.GreaterOrEqual(t, received, int(math.Floor(math.Log2(float64(numSent)))))
|
||||
require.LessOrEqual(t, received, int(math.Ceil(math.Log2(float64(numSent)))))
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user