forked from quic-go/quic-go
also the context when dialing an address
This commit is contained in:
@@ -36,7 +36,7 @@ var defaultQuicConfig = &quic.Config{
|
|||||||
|
|
||||||
type dialFunc func(ctx context.Context, network, addr string, tlsCfg *tls.Config, cfg *quic.Config) (quic.EarlySession, error)
|
type dialFunc func(ctx context.Context, network, addr string, tlsCfg *tls.Config, cfg *quic.Config) (quic.EarlySession, error)
|
||||||
|
|
||||||
var dialAddr = quic.DialAddrEarly
|
var dialAddr = quic.DialAddrEarlyContext
|
||||||
|
|
||||||
type roundTripperOpts struct {
|
type roundTripperOpts struct {
|
||||||
DisableCompression bool
|
DisableCompression bool
|
||||||
@@ -103,7 +103,7 @@ func (c *client) dial(ctx context.Context) error {
|
|||||||
if c.dialer != nil {
|
if c.dialer != nil {
|
||||||
c.session, err = c.dialer(ctx, "udp", c.hostname, c.tlsConf, c.config)
|
c.session, err = c.dialer(ctx, "udp", c.hostname, c.tlsConf, c.config)
|
||||||
} else {
|
} else {
|
||||||
c.session, err = dialAddr(c.hostname, c.tlsConf, c.config)
|
c.session, err = dialAddr(ctx, c.hostname, c.tlsConf, c.config)
|
||||||
}
|
}
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
|
|||||||
@@ -65,7 +65,7 @@ var _ = Describe("Client", func() {
|
|||||||
client, err := newClient("localhost:1337", nil, &roundTripperOpts{}, nil, nil)
|
client, err := newClient("localhost:1337", nil, &roundTripperOpts{}, nil, nil)
|
||||||
Expect(err).ToNot(HaveOccurred())
|
Expect(err).ToNot(HaveOccurred())
|
||||||
var dialAddrCalled bool
|
var dialAddrCalled bool
|
||||||
dialAddr = func(_ string, tlsConf *tls.Config, quicConf *quic.Config) (quic.EarlySession, error) {
|
dialAddr = func(_ context.Context, _ string, tlsConf *tls.Config, quicConf *quic.Config) (quic.EarlySession, error) {
|
||||||
Expect(quicConf).To(Equal(defaultQuicConfig))
|
Expect(quicConf).To(Equal(defaultQuicConfig))
|
||||||
Expect(tlsConf.NextProtos).To(Equal([]string{nextProtoH3}))
|
Expect(tlsConf.NextProtos).To(Equal([]string{nextProtoH3}))
|
||||||
Expect(quicConf.Versions).To(Equal([]protocol.VersionNumber{protocol.Version1}))
|
Expect(quicConf.Versions).To(Equal([]protocol.VersionNumber{protocol.Version1}))
|
||||||
@@ -80,7 +80,7 @@ var _ = Describe("Client", func() {
|
|||||||
client, err := newClient("quic.clemente.io", nil, &roundTripperOpts{}, nil, nil)
|
client, err := newClient("quic.clemente.io", nil, &roundTripperOpts{}, nil, nil)
|
||||||
Expect(err).ToNot(HaveOccurred())
|
Expect(err).ToNot(HaveOccurred())
|
||||||
var dialAddrCalled bool
|
var dialAddrCalled bool
|
||||||
dialAddr = func(hostname string, _ *tls.Config, _ *quic.Config) (quic.EarlySession, error) {
|
dialAddr = func(_ context.Context, hostname string, _ *tls.Config, _ *quic.Config) (quic.EarlySession, error) {
|
||||||
Expect(hostname).To(Equal("quic.clemente.io:443"))
|
Expect(hostname).To(Equal("quic.clemente.io:443"))
|
||||||
dialAddrCalled = true
|
dialAddrCalled = true
|
||||||
return nil, errors.New("test done")
|
return nil, errors.New("test done")
|
||||||
@@ -100,12 +100,8 @@ var _ = Describe("Client", func() {
|
|||||||
client, err := newClient("localhost:1337", tlsConf, &roundTripperOpts{}, quicConf, nil)
|
client, err := newClient("localhost:1337", tlsConf, &roundTripperOpts{}, quicConf, nil)
|
||||||
Expect(err).ToNot(HaveOccurred())
|
Expect(err).ToNot(HaveOccurred())
|
||||||
var dialAddrCalled bool
|
var dialAddrCalled bool
|
||||||
dialAddr = func(
|
dialAddr = func(_ context.Context, host string, tlsConfP *tls.Config, quicConfP *quic.Config) (quic.EarlySession, error) {
|
||||||
hostname string,
|
Expect(host).To(Equal("localhost:1337"))
|
||||||
tlsConfP *tls.Config,
|
|
||||||
quicConfP *quic.Config,
|
|
||||||
) (quic.EarlySession, error) {
|
|
||||||
Expect(hostname).To(Equal("localhost:1337"))
|
|
||||||
Expect(tlsConfP.ServerName).To(Equal(tlsConf.ServerName))
|
Expect(tlsConfP.ServerName).To(Equal(tlsConf.ServerName))
|
||||||
Expect(tlsConfP.NextProtos).To(Equal([]string{nextProtoH3}))
|
Expect(tlsConfP.NextProtos).To(Equal([]string{nextProtoH3}))
|
||||||
Expect(quicConfP.MaxIdleTimeout).To(Equal(quicConf.MaxIdleTimeout))
|
Expect(quicConfP.MaxIdleTimeout).To(Equal(quicConf.MaxIdleTimeout))
|
||||||
@@ -145,7 +141,7 @@ var _ = Describe("Client", func() {
|
|||||||
testErr := errors.New("handshake error")
|
testErr := errors.New("handshake error")
|
||||||
client, err := newClient("localhost:1337", nil, &roundTripperOpts{EnableDatagram: true}, nil, nil)
|
client, err := newClient("localhost:1337", nil, &roundTripperOpts{EnableDatagram: true}, nil, nil)
|
||||||
Expect(err).ToNot(HaveOccurred())
|
Expect(err).ToNot(HaveOccurred())
|
||||||
dialAddr = func(hostname string, _ *tls.Config, quicConf *quic.Config) (quic.EarlySession, error) {
|
dialAddr = func(_ context.Context, _ string, _ *tls.Config, quicConf *quic.Config) (quic.EarlySession, error) {
|
||||||
Expect(quicConf.EnableDatagrams).To(BeTrue())
|
Expect(quicConf.EnableDatagrams).To(BeTrue())
|
||||||
return nil, testErr
|
return nil, testErr
|
||||||
}
|
}
|
||||||
@@ -157,7 +153,7 @@ var _ = Describe("Client", func() {
|
|||||||
testErr := errors.New("handshake error")
|
testErr := errors.New("handshake error")
|
||||||
client, err := newClient("localhost:1337", nil, &roundTripperOpts{}, nil, nil)
|
client, err := newClient("localhost:1337", nil, &roundTripperOpts{}, nil, nil)
|
||||||
Expect(err).ToNot(HaveOccurred())
|
Expect(err).ToNot(HaveOccurred())
|
||||||
dialAddr = func(hostname string, _ *tls.Config, _ *quic.Config) (quic.EarlySession, error) {
|
dialAddr = func(context.Context, string, *tls.Config, *quic.Config) (quic.EarlySession, error) {
|
||||||
return nil, testErr
|
return nil, testErr
|
||||||
}
|
}
|
||||||
_, err = client.RoundTrip(req)
|
_, err = client.RoundTrip(req)
|
||||||
@@ -182,7 +178,7 @@ var _ = Describe("Client", func() {
|
|||||||
testErr := errors.New("handshake error")
|
testErr := errors.New("handshake error")
|
||||||
req, err := http.NewRequest("masque", "masque://quic.clemente.io:1337/foobar.html", nil)
|
req, err := http.NewRequest("masque", "masque://quic.clemente.io:1337/foobar.html", nil)
|
||||||
Expect(err).ToNot(HaveOccurred())
|
Expect(err).ToNot(HaveOccurred())
|
||||||
dialAddr = func(hostname string, _ *tls.Config, _ *quic.Config) (quic.EarlySession, error) {
|
dialAddr = func(context.Context, string, *tls.Config, *quic.Config) (quic.EarlySession, error) {
|
||||||
return nil, testErr
|
return nil, testErr
|
||||||
}
|
}
|
||||||
_, err = client.RoundTrip(req)
|
_, err = client.RoundTrip(req)
|
||||||
@@ -209,7 +205,7 @@ var _ = Describe("Client", func() {
|
|||||||
sess.EXPECT().OpenUniStream().Return(controlStr, nil)
|
sess.EXPECT().OpenUniStream().Return(controlStr, nil)
|
||||||
sess.EXPECT().HandshakeComplete().Return(handshakeCtx)
|
sess.EXPECT().HandshakeComplete().Return(handshakeCtx)
|
||||||
sess.EXPECT().OpenStreamSync(gomock.Any()).Return(nil, errors.New("done"))
|
sess.EXPECT().OpenStreamSync(gomock.Any()).Return(nil, errors.New("done"))
|
||||||
dialAddr = func(hostname string, _ *tls.Config, _ *quic.Config) (quic.EarlySession, error) { return sess, nil }
|
dialAddr = func(context.Context, string, *tls.Config, *quic.Config) (quic.EarlySession, error) { return sess, nil }
|
||||||
var err error
|
var err error
|
||||||
request, err = http.NewRequest("GET", "https://quic.clemente.io:1337/file1.dat", nil)
|
request, err = http.NewRequest("GET", "https://quic.clemente.io:1337/file1.dat", nil)
|
||||||
Expect(err).ToNot(HaveOccurred())
|
Expect(err).ToNot(HaveOccurred())
|
||||||
@@ -456,7 +452,7 @@ var _ = Describe("Client", func() {
|
|||||||
<-testDone
|
<-testDone
|
||||||
return nil, errors.New("test done")
|
return nil, errors.New("test done")
|
||||||
})
|
})
|
||||||
dialAddr = func(hostname string, _ *tls.Config, _ *quic.Config) (quic.EarlySession, error) { return sess, nil }
|
dialAddr = func(context.Context, string, *tls.Config, *quic.Config) (quic.EarlySession, error) { return sess, nil }
|
||||||
var err error
|
var err error
|
||||||
request, err = http.NewRequest("GET", "https://quic.clemente.io:1337/file1.dat", nil)
|
request, err = http.NewRequest("GET", "https://quic.clemente.io:1337/file1.dat", nil)
|
||||||
Expect(err).ToNot(HaveOccurred())
|
Expect(err).ToNot(HaveOccurred())
|
||||||
|
|||||||
@@ -82,7 +82,7 @@ var _ = Describe("RoundTripper", func() {
|
|||||||
BeforeEach(func() {
|
BeforeEach(func() {
|
||||||
session = mockquic.NewMockEarlySession(mockCtrl)
|
session = mockquic.NewMockEarlySession(mockCtrl)
|
||||||
origDialAddr = dialAddr
|
origDialAddr = dialAddr
|
||||||
dialAddr = func(addr string, tlsConf *tls.Config, config *quic.Config) (quic.EarlySession, error) {
|
dialAddr = func(context.Context, string, *tls.Config, *quic.Config) (quic.EarlySession, error) {
|
||||||
// return an error when trying to open a stream
|
// return an error when trying to open a stream
|
||||||
// we don't want to test all the dial logic here, just that dialing happens at all
|
// we don't want to test all the dial logic here, just that dialing happens at all
|
||||||
return session, nil
|
return session, nil
|
||||||
@@ -115,7 +115,7 @@ var _ = Describe("RoundTripper", func() {
|
|||||||
It("uses the quic.Config, if provided", func() {
|
It("uses the quic.Config, if provided", func() {
|
||||||
config := &quic.Config{HandshakeIdleTimeout: time.Millisecond}
|
config := &quic.Config{HandshakeIdleTimeout: time.Millisecond}
|
||||||
var receivedConfig *quic.Config
|
var receivedConfig *quic.Config
|
||||||
dialAddr = func(addr string, tlsConf *tls.Config, config *quic.Config) (quic.EarlySession, error) {
|
dialAddr = func(_ context.Context, _ string, _ *tls.Config, config *quic.Config) (quic.EarlySession, error) {
|
||||||
receivedConfig = config
|
receivedConfig = config
|
||||||
return nil, errors.New("handshake error")
|
return nil, errors.New("handshake error")
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user