expose the quic.Config in the h2quic.RoundTripper

This commit is contained in:
Marten Seemann
2017-07-08 20:15:09 +08:00
parent abb9594af8
commit ee6ca8dfb4
4 changed files with 60 additions and 16 deletions

View File

@@ -50,19 +50,30 @@ type client struct {
var _ http.RoundTripper = &client{}
var defaultQuicConfig = &quic.Config{
RequestConnectionIDTruncation: true,
KeepAlive: true,
}
// newClient creates a new client
func newClient(hostname string, tlsConfig *tls.Config, opts *roundTripperOpts) *client {
func newClient(
hostname string,
tlsConfig *tls.Config,
opts *roundTripperOpts,
quicConfig *quic.Config,
) *client {
config := defaultQuicConfig
if quicConfig != nil {
config = quicConfig
}
return &client{
hostname: authorityAddr("https", hostname),
responses: make(map[protocol.StreamID]chan *http.Response),
encryptionLevel: protocol.EncryptionUnencrypted,
tlsConf: tlsConfig,
config: &quic.Config{
RequestConnectionIDTruncation: true,
KeepAlive: true,
},
opts: opts,
headerErrored: make(chan struct{}),
config: config,
opts: opts,
headerErrored: make(chan struct{}),
}
}

View File

@@ -15,6 +15,8 @@ import (
"github.com/lucas-clemente/quic-go/protocol"
"github.com/lucas-clemente/quic-go/qerr"
"time"
. "github.com/onsi/ginkgo"
. "github.com/onsi/gomega"
)
@@ -31,7 +33,7 @@ var _ = Describe("Client", func() {
BeforeEach(func() {
origDialAddr = dialAddr
hostname := "quic.clemente.io:1337"
client = newClient(hostname, nil, &roundTripperOpts{})
client = newClient(hostname, nil, &roundTripperOpts{}, nil)
Expect(client.hostname).To(Equal(hostname))
session = &mockSession{}
client.session = session
@@ -50,17 +52,29 @@ var _ = Describe("Client", func() {
It("saves the TLS config", func() {
tlsConf := &tls.Config{InsecureSkipVerify: true}
client = newClient("", tlsConf, &roundTripperOpts{})
client = newClient("", tlsConf, &roundTripperOpts{}, nil)
Expect(client.tlsConf).To(Equal(tlsConf))
})
It("saves the QUIC config", func() {
quicConf := &quic.Config{HandshakeTimeout: time.Nanosecond}
client = newClient("", &tls.Config{}, &roundTripperOpts{}, quicConf)
Expect(client.config).To(Equal(quicConf))
})
It("uses the default QUIC config if none is give", func() {
client = newClient("", &tls.Config{}, &roundTripperOpts{}, nil)
Expect(client.config).ToNot(BeNil())
Expect(client.config).To(Equal(defaultQuicConfig))
})
It("adds the port to the hostname, if none is given", func() {
client = newClient("quic.clemente.io", nil, &roundTripperOpts{})
client = newClient("quic.clemente.io", nil, &roundTripperOpts{}, nil)
Expect(client.hostname).To(Equal("quic.clemente.io:443"))
})
It("dials", func(done Done) {
client = newClient("localhost:1337", nil, &roundTripperOpts{})
client = newClient("localhost:1337", nil, &roundTripperOpts{}, nil)
session.streamsToOpen = []quic.Stream{newMockStream(3), newMockStream(5)}
dialAddr = func(hostname string, _ *tls.Config, _ *quic.Config) (quic.Session, error) {
return session, nil
@@ -73,7 +87,7 @@ var _ = Describe("Client", func() {
It("errors when dialing fails", func() {
testErr := errors.New("handshake error")
client = newClient("localhost:1337", nil, &roundTripperOpts{})
client = newClient("localhost:1337", nil, &roundTripperOpts{}, nil)
dialAddr = func(hostname string, _ *tls.Config, _ *quic.Config) (quic.Session, error) {
return nil, testErr
}
@@ -82,7 +96,7 @@ var _ = Describe("Client", func() {
})
It("errors if the header stream has the wrong stream ID", func() {
client = newClient("localhost:1337", nil, &roundTripperOpts{})
client = newClient("localhost:1337", nil, &roundTripperOpts{}, nil)
session.streamsToOpen = []quic.Stream{&mockStream{id: 2}}
dialAddr = func(hostname string, _ *tls.Config, _ *quic.Config) (quic.Session, error) {
return session, nil
@@ -93,7 +107,7 @@ var _ = Describe("Client", func() {
It("errors if it can't open a stream", func() {
testErr := errors.New("you shall not pass")
client = newClient("localhost:1337", nil, &roundTripperOpts{})
client = newClient("localhost:1337", nil, &roundTripperOpts{}, nil)
session.streamOpenErr = testErr
dialAddr = func(hostname string, _ *tls.Config, _ *quic.Config) (quic.Session, error) {
return session, nil
@@ -252,7 +266,7 @@ var _ = Describe("Client", func() {
It("adds the port for request URLs without one", func(done Done) {
var err error
client = newClient("quic.clemente.io", nil, &roundTripperOpts{})
client = newClient("quic.clemente.io", nil, &roundTripperOpts{}, nil)
req, err := http.NewRequest("https", "https://quic.clemente.io/foobar.html", nil)
Expect(err).ToNot(HaveOccurred())

View File

@@ -8,6 +8,8 @@ import (
"strings"
"sync"
quic "github.com/lucas-clemente/quic-go"
"golang.org/x/net/lex/httplex"
)
@@ -29,6 +31,10 @@ type RoundTripper struct {
// tls.Client. If nil, the default configuration is used.
TLSClientConfig *tls.Config
// QuicConfig is the quic.Config used for dialing new connections.
// If nil, reasonable default values will be used.
QuicConfig *quic.Config
clients map[string]http.RoundTripper
}
@@ -84,7 +90,7 @@ func (r *RoundTripper) getClient(hostname string) http.RoundTripper {
client, ok := r.clients[hostname]
if !ok {
client = newClient(hostname, r.TLSClientConfig, &roundTripperOpts{DisableCompression: r.DisableCompression})
client = newClient(hostname, r.TLSClientConfig, &roundTripperOpts{DisableCompression: r.DisableCompression}, r.QuicConfig)
r.clients[hostname] = client
}
return client

View File

@@ -6,6 +6,7 @@ import (
"errors"
"io"
"net/http"
"time"
quic "github.com/lucas-clemente/quic-go"
. "github.com/onsi/ginkgo"
@@ -82,6 +83,18 @@ var _ = Describe("RoundTripper", func() {
Expect(rt.clients).To(HaveLen(1))
})
It("uses the quic.Config, if provided", func() {
config := &quic.Config{HandshakeTimeout: time.Millisecond}
var receivedConfig *quic.Config
dialAddr = func(addr string, tlsConf *tls.Config, config *quic.Config) (quic.Session, error) {
receivedConfig = config
return nil, errors.New("err")
}
rt.QuicConfig = config
rt.RoundTrip(req1)
Expect(receivedConfig).To(Equal(config))
})
It("reuses existing clients", func() {
req, err := http.NewRequest("GET", "https://quic.clemente.io/file1.html", nil)
Expect(err).ToNot(HaveOccurred())