forked from quic-go/quic-go
move dialing logic from the client into the Transport (#4859)
This commit is contained in:
@@ -13,6 +13,7 @@ import (
|
||||
mocklogging "github.com/quic-go/quic-go/internal/mocks/logging"
|
||||
"github.com/quic-go/quic-go/internal/protocol"
|
||||
"github.com/quic-go/quic-go/internal/qerr"
|
||||
"github.com/quic-go/quic-go/internal/utils"
|
||||
"github.com/quic-go/quic-go/internal/wire"
|
||||
"github.com/quic-go/quic-go/logging"
|
||||
|
||||
@@ -455,3 +456,162 @@ func TestTransportSetTLSConfigServerName(t *testing.T) {
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestTransportDial(t *testing.T) {
|
||||
t.Run("regular", func(t *testing.T) {
|
||||
testTransportDial(t, false)
|
||||
})
|
||||
|
||||
t.Run("early", func(t *testing.T) {
|
||||
testTransportDial(t, true)
|
||||
})
|
||||
}
|
||||
|
||||
func testTransportDial(t *testing.T, early bool) {
|
||||
originalClientConnConstructor := newClientConnection
|
||||
t.Cleanup(func() { newClientConnection = originalClientConnConstructor })
|
||||
|
||||
mockCtrl := gomock.NewController(t)
|
||||
conn := NewMockQUICConn(mockCtrl)
|
||||
handshakeChan := make(chan struct{})
|
||||
if early {
|
||||
conn.EXPECT().earlyConnReady().Return(handshakeChan)
|
||||
conn.EXPECT().HandshakeComplete().Return(make(chan struct{}))
|
||||
} else {
|
||||
conn.EXPECT().HandshakeComplete().Return(handshakeChan)
|
||||
}
|
||||
blockRun := make(chan struct{})
|
||||
conn.EXPECT().run().DoAndReturn(func() error {
|
||||
<-blockRun
|
||||
return errors.New("done")
|
||||
})
|
||||
defer close(blockRun)
|
||||
|
||||
newClientConnection = func(
|
||||
_ context.Context,
|
||||
_ sendConn,
|
||||
_ connRunner,
|
||||
_ protocol.ConnectionID,
|
||||
_ protocol.ConnectionID,
|
||||
_ ConnectionIDGenerator,
|
||||
_ *statelessResetter,
|
||||
_ *Config,
|
||||
_ *tls.Config,
|
||||
_ protocol.PacketNumber,
|
||||
_ bool,
|
||||
_ bool,
|
||||
_ *logging.ConnectionTracer,
|
||||
_ utils.Logger,
|
||||
_ protocol.Version,
|
||||
) quicConn {
|
||||
return conn
|
||||
}
|
||||
|
||||
tr := &Transport{Conn: newUPDConnLocalhost(t)}
|
||||
tr.init(true)
|
||||
defer tr.Close()
|
||||
|
||||
errChan := make(chan error, 1)
|
||||
go func() {
|
||||
var err error
|
||||
if early {
|
||||
_, err = tr.DialEarly(context.Background(), nil, &tls.Config{}, nil)
|
||||
} else {
|
||||
_, err = tr.Dial(context.Background(), nil, &tls.Config{}, nil)
|
||||
}
|
||||
errChan <- err
|
||||
}()
|
||||
|
||||
select {
|
||||
case <-errChan:
|
||||
t.Fatal("Dial shouldn't have returned")
|
||||
case <-time.After(scaleDuration(10 * time.Millisecond)):
|
||||
}
|
||||
|
||||
close(handshakeChan)
|
||||
select {
|
||||
case err := <-errChan:
|
||||
require.NoError(t, err)
|
||||
case <-time.After(time.Second):
|
||||
}
|
||||
|
||||
// for test tear-down
|
||||
conn.EXPECT().destroy(gomock.Any()).AnyTimes()
|
||||
}
|
||||
|
||||
func TestTransportDialingVersionNegotiation(t *testing.T) {
|
||||
originalClientConnConstructor := newClientConnection
|
||||
t.Cleanup(func() { newClientConnection = originalClientConnConstructor })
|
||||
|
||||
// connID := protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8})
|
||||
mockCtrl := gomock.NewController(t)
|
||||
// runner := NewMockConnRunner(mockCtrl)
|
||||
conn := NewMockQUICConn(mockCtrl)
|
||||
conn.EXPECT().HandshakeComplete().Return(make(chan struct{}))
|
||||
conn.EXPECT().run().Return(&errCloseForRecreating{nextPacketNumber: 109, nextVersion: 789})
|
||||
|
||||
conn2 := NewMockQUICConn(mockCtrl)
|
||||
conn2.EXPECT().HandshakeComplete().Return(make(chan struct{}))
|
||||
conn2.EXPECT().run().Return(errors.New("test done"))
|
||||
|
||||
type connParams struct {
|
||||
pn protocol.PacketNumber
|
||||
hasNegotiatedVersion bool
|
||||
version protocol.Version
|
||||
}
|
||||
|
||||
connChan := make(chan connParams, 2)
|
||||
var counter int
|
||||
newClientConnection = func(
|
||||
_ context.Context,
|
||||
_ sendConn,
|
||||
_ connRunner,
|
||||
_ protocol.ConnectionID,
|
||||
_ protocol.ConnectionID,
|
||||
_ ConnectionIDGenerator,
|
||||
_ *statelessResetter,
|
||||
_ *Config,
|
||||
_ *tls.Config,
|
||||
pn protocol.PacketNumber,
|
||||
_ bool,
|
||||
hasNegotiatedVersion bool,
|
||||
_ *logging.ConnectionTracer,
|
||||
_ utils.Logger,
|
||||
v protocol.Version,
|
||||
) quicConn {
|
||||
connChan <- connParams{pn: pn, hasNegotiatedVersion: hasNegotiatedVersion, version: v}
|
||||
if counter == 0 {
|
||||
counter++
|
||||
return conn
|
||||
}
|
||||
return conn2
|
||||
}
|
||||
|
||||
tr := &Transport{Conn: newUPDConnLocalhost(t)}
|
||||
tr.init(true)
|
||||
defer tr.Close()
|
||||
|
||||
_, err := tr.Dial(context.Background(), nil, &tls.Config{}, nil)
|
||||
require.EqualError(t, err, "test done")
|
||||
|
||||
select {
|
||||
case params := <-connChan:
|
||||
require.Zero(t, params.pn)
|
||||
require.False(t, params.hasNegotiatedVersion)
|
||||
require.Equal(t, protocol.Version1, params.version)
|
||||
case <-time.After(time.Second):
|
||||
t.Fatal("timeout")
|
||||
}
|
||||
select {
|
||||
case params := <-connChan:
|
||||
require.Equal(t, protocol.PacketNumber(109), params.pn)
|
||||
require.True(t, params.hasNegotiatedVersion)
|
||||
require.Equal(t, protocol.Version(789), params.version)
|
||||
case <-time.After(time.Second):
|
||||
t.Fatal("timeout")
|
||||
}
|
||||
|
||||
// for test tear down
|
||||
conn.EXPECT().destroy(gomock.Any()).AnyTimes()
|
||||
conn2.EXPECT().destroy(gomock.Any()).AnyTimes()
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user