forked from quic-go/quic-go
@@ -1,6 +1,7 @@
|
||||
package h2quic
|
||||
|
||||
import (
|
||||
"crypto/tls"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
@@ -47,7 +48,7 @@ type Client struct {
|
||||
var _ h2quicClient = &Client{}
|
||||
|
||||
// NewClient creates a new client
|
||||
func NewClient(t *QuicRoundTripper, hostname string) (*Client, error) {
|
||||
func NewClient(t *QuicRoundTripper, tlsConfig *tls.Config, hostname string) (*Client, error) {
|
||||
c := &Client{
|
||||
t: t,
|
||||
hostname: authorityAddr("https", hostname),
|
||||
@@ -57,7 +58,7 @@ func NewClient(t *QuicRoundTripper, hostname string) (*Client, error) {
|
||||
c.cryptoChangedCond = sync.Cond{L: &c.mutex}
|
||||
|
||||
var err error
|
||||
c.client, err = quic.NewClient(c.hostname, c.cryptoChangeCallback, c.versionNegotiateCallback)
|
||||
c.client, err = quic.NewClient(c.hostname, tlsConfig, c.cryptoChangeCallback, c.versionNegotiateCallback)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
@@ -54,7 +54,7 @@ var _ = Describe("Client", func() {
|
||||
var err error
|
||||
quicTransport = &QuicRoundTripper{}
|
||||
hostname := "quic.clemente.io:1337"
|
||||
client, err = NewClient(quicTransport, hostname)
|
||||
client, err = NewClient(quicTransport, nil, hostname)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(client.hostname).To(Equal(hostname))
|
||||
qClient = newMockQuicClient()
|
||||
@@ -68,7 +68,7 @@ var _ = Describe("Client", func() {
|
||||
|
||||
It("adds the port to the hostname, if none is given", func() {
|
||||
var err error
|
||||
client, err = NewClient(quicTransport, "quic.clemente.io")
|
||||
client, err = NewClient(quicTransport, nil, "quic.clemente.io")
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(client.hostname).To(Equal("quic.clemente.io:443"))
|
||||
})
|
||||
@@ -192,7 +192,7 @@ var _ = Describe("Client", func() {
|
||||
|
||||
It("adds the port for request URLs without one", func(done Done) {
|
||||
var err error
|
||||
client, err = NewClient(quicTransport, "quic.clemente.io")
|
||||
client, err = NewClient(quicTransport, nil, "quic.clemente.io")
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
req, err := http.NewRequest("https", "https://quic.clemente.io/foobar.html", nil)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package h2quic
|
||||
|
||||
import (
|
||||
"crypto/tls"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
@@ -28,6 +29,10 @@ type QuicRoundTripper struct {
|
||||
// uncompressed.
|
||||
DisableCompression bool
|
||||
|
||||
// TLSClientConfig specifies the TLS configuration to use with
|
||||
// tls.Client. If nil, the default configuration is used.
|
||||
TLSClientConfig *tls.Config
|
||||
|
||||
clients map[string]h2quicClient
|
||||
}
|
||||
|
||||
@@ -88,7 +93,7 @@ func (r *QuicRoundTripper) getClient(hostname string) (h2quicClient, error) {
|
||||
client, ok := r.clients[hostname]
|
||||
if !ok {
|
||||
var err error
|
||||
client, err = NewClient(r, hostname)
|
||||
client, err = NewClient(r, r.TLSClientConfig, hostname)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
@@ -57,7 +57,8 @@ var _ = Describe("RoundTripper", func() {
|
||||
It("reuses existing clients", func() {
|
||||
rt.clients = make(map[string]h2quicClient)
|
||||
rt.clients["www.example.org:443"] = &mockQuicRoundTripper{}
|
||||
rsp, _ := rt.RoundTrip(req1)
|
||||
rsp, err := rt.RoundTrip(req1)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(rsp.Request).To(Equal(req1))
|
||||
Expect(rt.clients).To(HaveLen(1))
|
||||
})
|
||||
|
||||
Reference in New Issue
Block a user