diff --git a/integrationtests/self/conn_id_test.go b/integrationtests/self/conn_id_test.go index c6cc97aa..e18821b7 100644 --- a/integrationtests/self/conn_id_test.go +++ b/integrationtests/self/conn_id_test.go @@ -7,9 +7,13 @@ import ( "io" mrand "math/rand" "testing" + "time" "github.com/quic-go/quic-go" "github.com/quic-go/quic-go/internal/protocol" + "github.com/quic-go/quic-go/logging" + + "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) @@ -57,12 +61,12 @@ func testTransferWithConnectionIDs( if serverConnIDGenerator != nil { t.Logf("using %d byte connection ID generator for the server", serverConnIDGenerator.ConnectionIDLen()) } else { - t.Logf("using %d byte connection ID for the server", serverConnIDLen) + t.Logf("issuing %d byte connection ID from the server", serverConnIDLen) } if clientConnIDGenerator != nil { t.Logf("using %d byte connection ID generator for the client", clientConnIDGenerator.ConnectionIDLen()) } else { - t.Logf("using %d byte connection ID for the client", clientConnIDLen) + t.Logf("issuing %d byte connection ID from the client", clientConnIDLen) } // setup server @@ -71,43 +75,79 @@ func testTransferWithConnectionIDs( ConnectionIDLength: serverConnIDLen, ConnectionIDGenerator: serverConnIDGenerator, } - t.Cleanup(func() { serverTr.Close() }) + defer serverTr.Close() addTracer(serverTr) - ln, err := serverTr.Listen(getTLSConfig(), getQuicConfig(nil)) + serverCounter, serverTracer := newPacketTracer() + ln, err := serverTr.Listen( + getTLSConfig(), + getQuicConfig(&quic.Config{ + Tracer: func(context.Context, logging.Perspective, quic.ConnectionID) *logging.ConnectionTracer { + return serverTracer + }, + }), + ) require.NoError(t, err) // setup client - clientTr := &quic.Transport{ - Conn: newUPDConnLocalhost(t), - ConnectionIDLength: clientConnIDLen, - ConnectionIDGenerator: clientConnIDGenerator, + ctx, cancel := context.WithTimeout(context.Background(), time.Second) + defer cancel() + var conn quic.Connection + clientCounter, clientTracer := newPacketTracer() + clientQUICConf := getQuicConfig(&quic.Config{ + Tracer: func(context.Context, logging.Perspective, quic.ConnectionID) *logging.ConnectionTracer { + return clientTracer + }, + }) + if clientConnIDGenerator == nil && clientConnIDLen == 0 { + conn, err = quic.Dial(ctx, newUPDConnLocalhost(t), ln.Addr(), getTLSClientConfig(), clientQUICConf) + require.NoError(t, err) + } else { + clientTr := &quic.Transport{ + Conn: newUPDConnLocalhost(t), + ConnectionIDLength: clientConnIDLen, + ConnectionIDGenerator: clientConnIDGenerator, + } + defer clientTr.Close() + addTracer(clientTr) + conn, err = clientTr.Dial(ctx, ln.Addr(), getTLSClientConfig(), clientQUICConf) + require.NoError(t, err) } - t.Cleanup(func() { clientTr.Close() }) - addTracer(clientTr) - - cl, err := clientTr.Dial( - context.Background(), - ln.Addr(), - getTLSClientConfig(), - getQuicConfig(nil), - ) - require.NoError(t, err) - t.Cleanup(func() { cl.CloseWithError(0, "") }) serverConn, err := ln.Accept(context.Background()) require.NoError(t, err) serverStr, err := serverConn.OpenStream() require.NoError(t, err) - t.Cleanup(func() { serverConn.CloseWithError(0, "") }) go func() { serverStr.Write(PRData) serverStr.Close() }() - str, err := cl.AcceptStream(context.Background()) + str, err := conn.AcceptStream(context.Background()) require.NoError(t, err) data, err := io.ReadAll(str) require.NoError(t, err) require.Equal(t, PRData, data) + + conn.CloseWithError(0, "") + serverConn.CloseWithError(0, "") + + for _, p := range serverCounter.getRcvdShortHeaderPackets() { + expectedLen := serverConnIDLen + if serverConnIDGenerator != nil { + expectedLen = serverConnIDGenerator.ConnectionIDLen() + } + if !assert.Equal(t, expectedLen, p.hdr.DestConnectionID.Len(), "server conn length mismatch") { + break + } + } + for _, p := range clientCounter.getRcvdShortHeaderPackets() { + expectedLen := clientConnIDLen + if clientConnIDGenerator != nil { + expectedLen = clientConnIDGenerator.ConnectionIDLen() + } + if !assert.Equal(t, expectedLen, p.hdr.DestConnectionID.Len(), "client conn length mismatch") { + break + } + } }