forked from quic-go/quic-go
fix connection ID length integration test (#4757)
This commit is contained in:
@@ -7,9 +7,13 @@ import (
|
||||
"io"
|
||||
mrand "math/rand"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/quic-go/quic-go"
|
||||
"github.com/quic-go/quic-go/internal/protocol"
|
||||
"github.com/quic-go/quic-go/logging"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
@@ -57,12 +61,12 @@ func testTransferWithConnectionIDs(
|
||||
if serverConnIDGenerator != nil {
|
||||
t.Logf("using %d byte connection ID generator for the server", serverConnIDGenerator.ConnectionIDLen())
|
||||
} else {
|
||||
t.Logf("using %d byte connection ID for the server", serverConnIDLen)
|
||||
t.Logf("issuing %d byte connection ID from the server", serverConnIDLen)
|
||||
}
|
||||
if clientConnIDGenerator != nil {
|
||||
t.Logf("using %d byte connection ID generator for the client", clientConnIDGenerator.ConnectionIDLen())
|
||||
} else {
|
||||
t.Logf("using %d byte connection ID for the client", clientConnIDLen)
|
||||
t.Logf("issuing %d byte connection ID from the client", clientConnIDLen)
|
||||
}
|
||||
|
||||
// setup server
|
||||
@@ -71,43 +75,79 @@ func testTransferWithConnectionIDs(
|
||||
ConnectionIDLength: serverConnIDLen,
|
||||
ConnectionIDGenerator: serverConnIDGenerator,
|
||||
}
|
||||
t.Cleanup(func() { serverTr.Close() })
|
||||
defer serverTr.Close()
|
||||
addTracer(serverTr)
|
||||
ln, err := serverTr.Listen(getTLSConfig(), getQuicConfig(nil))
|
||||
serverCounter, serverTracer := newPacketTracer()
|
||||
ln, err := serverTr.Listen(
|
||||
getTLSConfig(),
|
||||
getQuicConfig(&quic.Config{
|
||||
Tracer: func(context.Context, logging.Perspective, quic.ConnectionID) *logging.ConnectionTracer {
|
||||
return serverTracer
|
||||
},
|
||||
}),
|
||||
)
|
||||
require.NoError(t, err)
|
||||
|
||||
// setup client
|
||||
clientTr := &quic.Transport{
|
||||
Conn: newUPDConnLocalhost(t),
|
||||
ConnectionIDLength: clientConnIDLen,
|
||||
ConnectionIDGenerator: clientConnIDGenerator,
|
||||
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
|
||||
defer cancel()
|
||||
var conn quic.Connection
|
||||
clientCounter, clientTracer := newPacketTracer()
|
||||
clientQUICConf := getQuicConfig(&quic.Config{
|
||||
Tracer: func(context.Context, logging.Perspective, quic.ConnectionID) *logging.ConnectionTracer {
|
||||
return clientTracer
|
||||
},
|
||||
})
|
||||
if clientConnIDGenerator == nil && clientConnIDLen == 0 {
|
||||
conn, err = quic.Dial(ctx, newUPDConnLocalhost(t), ln.Addr(), getTLSClientConfig(), clientQUICConf)
|
||||
require.NoError(t, err)
|
||||
} else {
|
||||
clientTr := &quic.Transport{
|
||||
Conn: newUPDConnLocalhost(t),
|
||||
ConnectionIDLength: clientConnIDLen,
|
||||
ConnectionIDGenerator: clientConnIDGenerator,
|
||||
}
|
||||
defer clientTr.Close()
|
||||
addTracer(clientTr)
|
||||
conn, err = clientTr.Dial(ctx, ln.Addr(), getTLSClientConfig(), clientQUICConf)
|
||||
require.NoError(t, err)
|
||||
}
|
||||
t.Cleanup(func() { clientTr.Close() })
|
||||
addTracer(clientTr)
|
||||
|
||||
cl, err := clientTr.Dial(
|
||||
context.Background(),
|
||||
ln.Addr(),
|
||||
getTLSClientConfig(),
|
||||
getQuicConfig(nil),
|
||||
)
|
||||
require.NoError(t, err)
|
||||
t.Cleanup(func() { cl.CloseWithError(0, "") })
|
||||
|
||||
serverConn, err := ln.Accept(context.Background())
|
||||
require.NoError(t, err)
|
||||
serverStr, err := serverConn.OpenStream()
|
||||
require.NoError(t, err)
|
||||
t.Cleanup(func() { serverConn.CloseWithError(0, "") })
|
||||
|
||||
go func() {
|
||||
serverStr.Write(PRData)
|
||||
serverStr.Close()
|
||||
}()
|
||||
|
||||
str, err := cl.AcceptStream(context.Background())
|
||||
str, err := conn.AcceptStream(context.Background())
|
||||
require.NoError(t, err)
|
||||
data, err := io.ReadAll(str)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, PRData, data)
|
||||
|
||||
conn.CloseWithError(0, "")
|
||||
serverConn.CloseWithError(0, "")
|
||||
|
||||
for _, p := range serverCounter.getRcvdShortHeaderPackets() {
|
||||
expectedLen := serverConnIDLen
|
||||
if serverConnIDGenerator != nil {
|
||||
expectedLen = serverConnIDGenerator.ConnectionIDLen()
|
||||
}
|
||||
if !assert.Equal(t, expectedLen, p.hdr.DestConnectionID.Len(), "server conn length mismatch") {
|
||||
break
|
||||
}
|
||||
}
|
||||
for _, p := range clientCounter.getRcvdShortHeaderPackets() {
|
||||
expectedLen := clientConnIDLen
|
||||
if clientConnIDGenerator != nil {
|
||||
expectedLen = clientConnIDGenerator.ConnectionIDLen()
|
||||
}
|
||||
if !assert.Equal(t, expectedLen, p.hdr.DestConnectionID.Len(), "client conn length mismatch") {
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user