move dialing logic from the client into the Transport (#4859)

This commit is contained in:
Marten Seemann
2025-01-14 00:40:20 -08:00
committed by GitHub
parent fbbc3c9e30
commit 62a94758e6
5 changed files with 362 additions and 539 deletions

View File

@@ -13,6 +13,7 @@ import (
mocklogging "github.com/quic-go/quic-go/internal/mocks/logging"
"github.com/quic-go/quic-go/internal/protocol"
"github.com/quic-go/quic-go/internal/qerr"
"github.com/quic-go/quic-go/internal/utils"
"github.com/quic-go/quic-go/internal/wire"
"github.com/quic-go/quic-go/logging"
@@ -455,3 +456,162 @@ func TestTransportSetTLSConfigServerName(t *testing.T) {
})
}
}
func TestTransportDial(t *testing.T) {
t.Run("regular", func(t *testing.T) {
testTransportDial(t, false)
})
t.Run("early", func(t *testing.T) {
testTransportDial(t, true)
})
}
func testTransportDial(t *testing.T, early bool) {
originalClientConnConstructor := newClientConnection
t.Cleanup(func() { newClientConnection = originalClientConnConstructor })
mockCtrl := gomock.NewController(t)
conn := NewMockQUICConn(mockCtrl)
handshakeChan := make(chan struct{})
if early {
conn.EXPECT().earlyConnReady().Return(handshakeChan)
conn.EXPECT().HandshakeComplete().Return(make(chan struct{}))
} else {
conn.EXPECT().HandshakeComplete().Return(handshakeChan)
}
blockRun := make(chan struct{})
conn.EXPECT().run().DoAndReturn(func() error {
<-blockRun
return errors.New("done")
})
defer close(blockRun)
newClientConnection = func(
_ context.Context,
_ sendConn,
_ connRunner,
_ protocol.ConnectionID,
_ protocol.ConnectionID,
_ ConnectionIDGenerator,
_ *statelessResetter,
_ *Config,
_ *tls.Config,
_ protocol.PacketNumber,
_ bool,
_ bool,
_ *logging.ConnectionTracer,
_ utils.Logger,
_ protocol.Version,
) quicConn {
return conn
}
tr := &Transport{Conn: newUPDConnLocalhost(t)}
tr.init(true)
defer tr.Close()
errChan := make(chan error, 1)
go func() {
var err error
if early {
_, err = tr.DialEarly(context.Background(), nil, &tls.Config{}, nil)
} else {
_, err = tr.Dial(context.Background(), nil, &tls.Config{}, nil)
}
errChan <- err
}()
select {
case <-errChan:
t.Fatal("Dial shouldn't have returned")
case <-time.After(scaleDuration(10 * time.Millisecond)):
}
close(handshakeChan)
select {
case err := <-errChan:
require.NoError(t, err)
case <-time.After(time.Second):
}
// for test tear-down
conn.EXPECT().destroy(gomock.Any()).AnyTimes()
}
func TestTransportDialingVersionNegotiation(t *testing.T) {
originalClientConnConstructor := newClientConnection
t.Cleanup(func() { newClientConnection = originalClientConnConstructor })
// connID := protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8})
mockCtrl := gomock.NewController(t)
// runner := NewMockConnRunner(mockCtrl)
conn := NewMockQUICConn(mockCtrl)
conn.EXPECT().HandshakeComplete().Return(make(chan struct{}))
conn.EXPECT().run().Return(&errCloseForRecreating{nextPacketNumber: 109, nextVersion: 789})
conn2 := NewMockQUICConn(mockCtrl)
conn2.EXPECT().HandshakeComplete().Return(make(chan struct{}))
conn2.EXPECT().run().Return(errors.New("test done"))
type connParams struct {
pn protocol.PacketNumber
hasNegotiatedVersion bool
version protocol.Version
}
connChan := make(chan connParams, 2)
var counter int
newClientConnection = func(
_ context.Context,
_ sendConn,
_ connRunner,
_ protocol.ConnectionID,
_ protocol.ConnectionID,
_ ConnectionIDGenerator,
_ *statelessResetter,
_ *Config,
_ *tls.Config,
pn protocol.PacketNumber,
_ bool,
hasNegotiatedVersion bool,
_ *logging.ConnectionTracer,
_ utils.Logger,
v protocol.Version,
) quicConn {
connChan <- connParams{pn: pn, hasNegotiatedVersion: hasNegotiatedVersion, version: v}
if counter == 0 {
counter++
return conn
}
return conn2
}
tr := &Transport{Conn: newUPDConnLocalhost(t)}
tr.init(true)
defer tr.Close()
_, err := tr.Dial(context.Background(), nil, &tls.Config{}, nil)
require.EqualError(t, err, "test done")
select {
case params := <-connChan:
require.Zero(t, params.pn)
require.False(t, params.hasNegotiatedVersion)
require.Equal(t, protocol.Version1, params.version)
case <-time.After(time.Second):
t.Fatal("timeout")
}
select {
case params := <-connChan:
require.Equal(t, protocol.PacketNumber(109), params.pn)
require.True(t, params.hasNegotiatedVersion)
require.Equal(t, protocol.Version(789), params.version)
case <-time.After(time.Second):
t.Fatal("timeout")
}
// for test tear down
conn.EXPECT().destroy(gomock.Any()).AnyTimes()
conn2.EXPECT().destroy(gomock.Any()).AnyTimes()
}