fix connection ID length integration test (#4757)

This commit is contained in:
Marten Seemann
2024-12-18 12:25:09 +08:00
committed by GitHub
parent 3886be4e95
commit 9083be4322

View File

@@ -7,9 +7,13 @@ import (
"io" "io"
mrand "math/rand" mrand "math/rand"
"testing" "testing"
"time"
"github.com/quic-go/quic-go" "github.com/quic-go/quic-go"
"github.com/quic-go/quic-go/internal/protocol" "github.com/quic-go/quic-go/internal/protocol"
"github.com/quic-go/quic-go/logging"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
) )
@@ -57,12 +61,12 @@ func testTransferWithConnectionIDs(
if serverConnIDGenerator != nil { if serverConnIDGenerator != nil {
t.Logf("using %d byte connection ID generator for the server", serverConnIDGenerator.ConnectionIDLen()) t.Logf("using %d byte connection ID generator for the server", serverConnIDGenerator.ConnectionIDLen())
} else { } else {
t.Logf("using %d byte connection ID for the server", serverConnIDLen) t.Logf("issuing %d byte connection ID from the server", serverConnIDLen)
} }
if clientConnIDGenerator != nil { if clientConnIDGenerator != nil {
t.Logf("using %d byte connection ID generator for the client", clientConnIDGenerator.ConnectionIDLen()) t.Logf("using %d byte connection ID generator for the client", clientConnIDGenerator.ConnectionIDLen())
} else { } else {
t.Logf("using %d byte connection ID for the client", clientConnIDLen) t.Logf("issuing %d byte connection ID from the client", clientConnIDLen)
} }
// setup server // setup server
@@ -71,43 +75,79 @@ func testTransferWithConnectionIDs(
ConnectionIDLength: serverConnIDLen, ConnectionIDLength: serverConnIDLen,
ConnectionIDGenerator: serverConnIDGenerator, ConnectionIDGenerator: serverConnIDGenerator,
} }
t.Cleanup(func() { serverTr.Close() }) defer serverTr.Close()
addTracer(serverTr) addTracer(serverTr)
ln, err := serverTr.Listen(getTLSConfig(), getQuicConfig(nil)) serverCounter, serverTracer := newPacketTracer()
ln, err := serverTr.Listen(
getTLSConfig(),
getQuicConfig(&quic.Config{
Tracer: func(context.Context, logging.Perspective, quic.ConnectionID) *logging.ConnectionTracer {
return serverTracer
},
}),
)
require.NoError(t, err) require.NoError(t, err)
// setup client // setup client
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
defer cancel()
var conn quic.Connection
clientCounter, clientTracer := newPacketTracer()
clientQUICConf := getQuicConfig(&quic.Config{
Tracer: func(context.Context, logging.Perspective, quic.ConnectionID) *logging.ConnectionTracer {
return clientTracer
},
})
if clientConnIDGenerator == nil && clientConnIDLen == 0 {
conn, err = quic.Dial(ctx, newUPDConnLocalhost(t), ln.Addr(), getTLSClientConfig(), clientQUICConf)
require.NoError(t, err)
} else {
clientTr := &quic.Transport{ clientTr := &quic.Transport{
Conn: newUPDConnLocalhost(t), Conn: newUPDConnLocalhost(t),
ConnectionIDLength: clientConnIDLen, ConnectionIDLength: clientConnIDLen,
ConnectionIDGenerator: clientConnIDGenerator, ConnectionIDGenerator: clientConnIDGenerator,
} }
t.Cleanup(func() { clientTr.Close() }) defer clientTr.Close()
addTracer(clientTr) addTracer(clientTr)
conn, err = clientTr.Dial(ctx, ln.Addr(), getTLSClientConfig(), clientQUICConf)
cl, err := clientTr.Dial(
context.Background(),
ln.Addr(),
getTLSClientConfig(),
getQuicConfig(nil),
)
require.NoError(t, err) require.NoError(t, err)
t.Cleanup(func() { cl.CloseWithError(0, "") }) }
serverConn, err := ln.Accept(context.Background()) serverConn, err := ln.Accept(context.Background())
require.NoError(t, err) require.NoError(t, err)
serverStr, err := serverConn.OpenStream() serverStr, err := serverConn.OpenStream()
require.NoError(t, err) require.NoError(t, err)
t.Cleanup(func() { serverConn.CloseWithError(0, "") })
go func() { go func() {
serverStr.Write(PRData) serverStr.Write(PRData)
serverStr.Close() serverStr.Close()
}() }()
str, err := cl.AcceptStream(context.Background()) str, err := conn.AcceptStream(context.Background())
require.NoError(t, err) require.NoError(t, err)
data, err := io.ReadAll(str) data, err := io.ReadAll(str)
require.NoError(t, err) require.NoError(t, err)
require.Equal(t, PRData, data) require.Equal(t, PRData, data)
conn.CloseWithError(0, "")
serverConn.CloseWithError(0, "")
for _, p := range serverCounter.getRcvdShortHeaderPackets() {
expectedLen := serverConnIDLen
if serverConnIDGenerator != nil {
expectedLen = serverConnIDGenerator.ConnectionIDLen()
}
if !assert.Equal(t, expectedLen, p.hdr.DestConnectionID.Len(), "server conn length mismatch") {
break
}
}
for _, p := range clientCounter.getRcvdShortHeaderPackets() {
expectedLen := clientConnIDLen
if clientConnIDGenerator != nil {
expectedLen = clientConnIDGenerator.ConnectionIDLen()
}
if !assert.Equal(t, expectedLen, p.hdr.DestConnectionID.Len(), "client conn length mismatch") {
break
}
}
} }