forked from quic-go/quic-go
expose the quic.Config in the h2quic.RoundTripper
This commit is contained in:
@@ -50,19 +50,30 @@ type client struct {
|
||||
|
||||
var _ http.RoundTripper = &client{}
|
||||
|
||||
var defaultQuicConfig = &quic.Config{
|
||||
RequestConnectionIDTruncation: true,
|
||||
KeepAlive: true,
|
||||
}
|
||||
|
||||
// newClient creates a new client
|
||||
func newClient(hostname string, tlsConfig *tls.Config, opts *roundTripperOpts) *client {
|
||||
func newClient(
|
||||
hostname string,
|
||||
tlsConfig *tls.Config,
|
||||
opts *roundTripperOpts,
|
||||
quicConfig *quic.Config,
|
||||
) *client {
|
||||
config := defaultQuicConfig
|
||||
if quicConfig != nil {
|
||||
config = quicConfig
|
||||
}
|
||||
return &client{
|
||||
hostname: authorityAddr("https", hostname),
|
||||
responses: make(map[protocol.StreamID]chan *http.Response),
|
||||
encryptionLevel: protocol.EncryptionUnencrypted,
|
||||
tlsConf: tlsConfig,
|
||||
config: &quic.Config{
|
||||
RequestConnectionIDTruncation: true,
|
||||
KeepAlive: true,
|
||||
},
|
||||
opts: opts,
|
||||
headerErrored: make(chan struct{}),
|
||||
config: config,
|
||||
opts: opts,
|
||||
headerErrored: make(chan struct{}),
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -15,6 +15,8 @@ import (
|
||||
"github.com/lucas-clemente/quic-go/protocol"
|
||||
"github.com/lucas-clemente/quic-go/qerr"
|
||||
|
||||
"time"
|
||||
|
||||
. "github.com/onsi/ginkgo"
|
||||
. "github.com/onsi/gomega"
|
||||
)
|
||||
@@ -31,7 +33,7 @@ var _ = Describe("Client", func() {
|
||||
BeforeEach(func() {
|
||||
origDialAddr = dialAddr
|
||||
hostname := "quic.clemente.io:1337"
|
||||
client = newClient(hostname, nil, &roundTripperOpts{})
|
||||
client = newClient(hostname, nil, &roundTripperOpts{}, nil)
|
||||
Expect(client.hostname).To(Equal(hostname))
|
||||
session = &mockSession{}
|
||||
client.session = session
|
||||
@@ -50,17 +52,29 @@ var _ = Describe("Client", func() {
|
||||
|
||||
It("saves the TLS config", func() {
|
||||
tlsConf := &tls.Config{InsecureSkipVerify: true}
|
||||
client = newClient("", tlsConf, &roundTripperOpts{})
|
||||
client = newClient("", tlsConf, &roundTripperOpts{}, nil)
|
||||
Expect(client.tlsConf).To(Equal(tlsConf))
|
||||
})
|
||||
|
||||
It("saves the QUIC config", func() {
|
||||
quicConf := &quic.Config{HandshakeTimeout: time.Nanosecond}
|
||||
client = newClient("", &tls.Config{}, &roundTripperOpts{}, quicConf)
|
||||
Expect(client.config).To(Equal(quicConf))
|
||||
})
|
||||
|
||||
It("uses the default QUIC config if none is give", func() {
|
||||
client = newClient("", &tls.Config{}, &roundTripperOpts{}, nil)
|
||||
Expect(client.config).ToNot(BeNil())
|
||||
Expect(client.config).To(Equal(defaultQuicConfig))
|
||||
})
|
||||
|
||||
It("adds the port to the hostname, if none is given", func() {
|
||||
client = newClient("quic.clemente.io", nil, &roundTripperOpts{})
|
||||
client = newClient("quic.clemente.io", nil, &roundTripperOpts{}, nil)
|
||||
Expect(client.hostname).To(Equal("quic.clemente.io:443"))
|
||||
})
|
||||
|
||||
It("dials", func(done Done) {
|
||||
client = newClient("localhost:1337", nil, &roundTripperOpts{})
|
||||
client = newClient("localhost:1337", nil, &roundTripperOpts{}, nil)
|
||||
session.streamsToOpen = []quic.Stream{newMockStream(3), newMockStream(5)}
|
||||
dialAddr = func(hostname string, _ *tls.Config, _ *quic.Config) (quic.Session, error) {
|
||||
return session, nil
|
||||
@@ -73,7 +87,7 @@ var _ = Describe("Client", func() {
|
||||
|
||||
It("errors when dialing fails", func() {
|
||||
testErr := errors.New("handshake error")
|
||||
client = newClient("localhost:1337", nil, &roundTripperOpts{})
|
||||
client = newClient("localhost:1337", nil, &roundTripperOpts{}, nil)
|
||||
dialAddr = func(hostname string, _ *tls.Config, _ *quic.Config) (quic.Session, error) {
|
||||
return nil, testErr
|
||||
}
|
||||
@@ -82,7 +96,7 @@ var _ = Describe("Client", func() {
|
||||
})
|
||||
|
||||
It("errors if the header stream has the wrong stream ID", func() {
|
||||
client = newClient("localhost:1337", nil, &roundTripperOpts{})
|
||||
client = newClient("localhost:1337", nil, &roundTripperOpts{}, nil)
|
||||
session.streamsToOpen = []quic.Stream{&mockStream{id: 2}}
|
||||
dialAddr = func(hostname string, _ *tls.Config, _ *quic.Config) (quic.Session, error) {
|
||||
return session, nil
|
||||
@@ -93,7 +107,7 @@ var _ = Describe("Client", func() {
|
||||
|
||||
It("errors if it can't open a stream", func() {
|
||||
testErr := errors.New("you shall not pass")
|
||||
client = newClient("localhost:1337", nil, &roundTripperOpts{})
|
||||
client = newClient("localhost:1337", nil, &roundTripperOpts{}, nil)
|
||||
session.streamOpenErr = testErr
|
||||
dialAddr = func(hostname string, _ *tls.Config, _ *quic.Config) (quic.Session, error) {
|
||||
return session, nil
|
||||
@@ -252,7 +266,7 @@ var _ = Describe("Client", func() {
|
||||
|
||||
It("adds the port for request URLs without one", func(done Done) {
|
||||
var err error
|
||||
client = newClient("quic.clemente.io", nil, &roundTripperOpts{})
|
||||
client = newClient("quic.clemente.io", nil, &roundTripperOpts{}, nil)
|
||||
req, err := http.NewRequest("https", "https://quic.clemente.io/foobar.html", nil)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
||||
|
||||
@@ -8,6 +8,8 @@ import (
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
quic "github.com/lucas-clemente/quic-go"
|
||||
|
||||
"golang.org/x/net/lex/httplex"
|
||||
)
|
||||
|
||||
@@ -29,6 +31,10 @@ type RoundTripper struct {
|
||||
// tls.Client. If nil, the default configuration is used.
|
||||
TLSClientConfig *tls.Config
|
||||
|
||||
// QuicConfig is the quic.Config used for dialing new connections.
|
||||
// If nil, reasonable default values will be used.
|
||||
QuicConfig *quic.Config
|
||||
|
||||
clients map[string]http.RoundTripper
|
||||
}
|
||||
|
||||
@@ -84,7 +90,7 @@ func (r *RoundTripper) getClient(hostname string) http.RoundTripper {
|
||||
|
||||
client, ok := r.clients[hostname]
|
||||
if !ok {
|
||||
client = newClient(hostname, r.TLSClientConfig, &roundTripperOpts{DisableCompression: r.DisableCompression})
|
||||
client = newClient(hostname, r.TLSClientConfig, &roundTripperOpts{DisableCompression: r.DisableCompression}, r.QuicConfig)
|
||||
r.clients[hostname] = client
|
||||
}
|
||||
return client
|
||||
|
||||
@@ -6,6 +6,7 @@ import (
|
||||
"errors"
|
||||
"io"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
quic "github.com/lucas-clemente/quic-go"
|
||||
. "github.com/onsi/ginkgo"
|
||||
@@ -82,6 +83,18 @@ var _ = Describe("RoundTripper", func() {
|
||||
Expect(rt.clients).To(HaveLen(1))
|
||||
})
|
||||
|
||||
It("uses the quic.Config, if provided", func() {
|
||||
config := &quic.Config{HandshakeTimeout: time.Millisecond}
|
||||
var receivedConfig *quic.Config
|
||||
dialAddr = func(addr string, tlsConf *tls.Config, config *quic.Config) (quic.Session, error) {
|
||||
receivedConfig = config
|
||||
return nil, errors.New("err")
|
||||
}
|
||||
rt.QuicConfig = config
|
||||
rt.RoundTrip(req1)
|
||||
Expect(receivedConfig).To(Equal(config))
|
||||
})
|
||||
|
||||
It("reuses existing clients", func() {
|
||||
req, err := http.NewRequest("GET", "https://quic.clemente.io/file1.html", nil)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
||||
Reference in New Issue
Block a user