forked from quic-go/quic-go
make the dependency-injected dialAddr in h2quic.client a global variable
It's only used for testing, so there's no need to have in each h2quic.client instance.
This commit is contained in:
@@ -24,11 +24,12 @@ type roundTripperOpts struct {
|
|||||||
DisableCompression bool
|
DisableCompression bool
|
||||||
}
|
}
|
||||||
|
|
||||||
|
var dialAddr = quic.DialAddr
|
||||||
|
|
||||||
// client is a HTTP2 client doing QUIC requests
|
// client is a HTTP2 client doing QUIC requests
|
||||||
type client struct {
|
type client struct {
|
||||||
mutex sync.RWMutex
|
mutex sync.RWMutex
|
||||||
|
|
||||||
dialAddr func(hostname string, tlsConf *tls.Config, config *quic.Config) (quic.Session, error)
|
|
||||||
tlsConf *tls.Config
|
tlsConf *tls.Config
|
||||||
config *quic.Config
|
config *quic.Config
|
||||||
opts *roundTripperOpts
|
opts *roundTripperOpts
|
||||||
@@ -52,7 +53,6 @@ var _ http.RoundTripper = &client{}
|
|||||||
// newClient creates a new client
|
// newClient creates a new client
|
||||||
func newClient(tlsConfig *tls.Config, hostname string, opts *roundTripperOpts) *client {
|
func newClient(tlsConfig *tls.Config, hostname string, opts *roundTripperOpts) *client {
|
||||||
return &client{
|
return &client{
|
||||||
dialAddr: quic.DialAddr,
|
|
||||||
hostname: authorityAddr("https", hostname),
|
hostname: authorityAddr("https", hostname),
|
||||||
responses: make(map[protocol.StreamID]chan *http.Response),
|
responses: make(map[protocol.StreamID]chan *http.Response),
|
||||||
encryptionLevel: protocol.EncryptionUnencrypted,
|
encryptionLevel: protocol.EncryptionUnencrypted,
|
||||||
@@ -69,7 +69,7 @@ func newClient(tlsConfig *tls.Config, hostname string, opts *roundTripperOpts) *
|
|||||||
// dial dials the connection
|
// dial dials the connection
|
||||||
func (c *client) dial() error {
|
func (c *client) dial() error {
|
||||||
var err error
|
var err error
|
||||||
c.session, err = c.dialAddr(c.hostname, c.tlsConf, c.config)
|
c.session, err = dialAddr(c.hostname, c.tlsConf, c.config)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -25,9 +25,11 @@ var _ = Describe("Client", func() {
|
|||||||
session *mockSession
|
session *mockSession
|
||||||
headerStream *mockStream
|
headerStream *mockStream
|
||||||
req *http.Request
|
req *http.Request
|
||||||
|
origDialAddr = dialAddr
|
||||||
)
|
)
|
||||||
|
|
||||||
BeforeEach(func() {
|
BeforeEach(func() {
|
||||||
|
origDialAddr = dialAddr
|
||||||
hostname := "quic.clemente.io:1337"
|
hostname := "quic.clemente.io:1337"
|
||||||
client = newClient(nil, hostname, &roundTripperOpts{})
|
client = newClient(nil, hostname, &roundTripperOpts{})
|
||||||
Expect(client.hostname).To(Equal(hostname))
|
Expect(client.hostname).To(Equal(hostname))
|
||||||
@@ -42,6 +44,10 @@ var _ = Describe("Client", func() {
|
|||||||
Expect(err).ToNot(HaveOccurred())
|
Expect(err).ToNot(HaveOccurred())
|
||||||
})
|
})
|
||||||
|
|
||||||
|
AfterEach(func() {
|
||||||
|
dialAddr = origDialAddr
|
||||||
|
})
|
||||||
|
|
||||||
It("saves the TLS config", func() {
|
It("saves the TLS config", func() {
|
||||||
tlsConf := &tls.Config{InsecureSkipVerify: true}
|
tlsConf := &tls.Config{InsecureSkipVerify: true}
|
||||||
client = newClient(tlsConf, "", &roundTripperOpts{})
|
client = newClient(tlsConf, "", &roundTripperOpts{})
|
||||||
@@ -56,7 +62,7 @@ var _ = Describe("Client", func() {
|
|||||||
It("dials", func(done Done) {
|
It("dials", func(done Done) {
|
||||||
client = newClient(nil, "localhost:1337", &roundTripperOpts{})
|
client = newClient(nil, "localhost:1337", &roundTripperOpts{})
|
||||||
session.streamsToOpen = []quic.Stream{newMockStream(3), newMockStream(5)}
|
session.streamsToOpen = []quic.Stream{newMockStream(3), newMockStream(5)}
|
||||||
client.dialAddr = func(hostname string, _ *tls.Config, _ *quic.Config) (quic.Session, error) {
|
dialAddr = func(hostname string, _ *tls.Config, _ *quic.Config) (quic.Session, error) {
|
||||||
return session, nil
|
return session, nil
|
||||||
}
|
}
|
||||||
close(headerStream.unblockRead)
|
close(headerStream.unblockRead)
|
||||||
@@ -68,7 +74,7 @@ var _ = Describe("Client", func() {
|
|||||||
It("errors when dialing fails", func() {
|
It("errors when dialing fails", func() {
|
||||||
testErr := errors.New("handshake error")
|
testErr := errors.New("handshake error")
|
||||||
client = newClient(nil, "localhost:1337", &roundTripperOpts{})
|
client = newClient(nil, "localhost:1337", &roundTripperOpts{})
|
||||||
client.dialAddr = func(hostname string, _ *tls.Config, _ *quic.Config) (quic.Session, error) {
|
dialAddr = func(hostname string, _ *tls.Config, _ *quic.Config) (quic.Session, error) {
|
||||||
return nil, testErr
|
return nil, testErr
|
||||||
}
|
}
|
||||||
_, err := client.RoundTrip(req)
|
_, err := client.RoundTrip(req)
|
||||||
@@ -78,7 +84,7 @@ var _ = Describe("Client", func() {
|
|||||||
It("errors if the header stream has the wrong stream ID", func() {
|
It("errors if the header stream has the wrong stream ID", func() {
|
||||||
client = newClient(nil, "localhost:1337", &roundTripperOpts{})
|
client = newClient(nil, "localhost:1337", &roundTripperOpts{})
|
||||||
session.streamsToOpen = []quic.Stream{&mockStream{id: 2}}
|
session.streamsToOpen = []quic.Stream{&mockStream{id: 2}}
|
||||||
client.dialAddr = func(hostname string, _ *tls.Config, _ *quic.Config) (quic.Session, error) {
|
dialAddr = func(hostname string, _ *tls.Config, _ *quic.Config) (quic.Session, error) {
|
||||||
return session, nil
|
return session, nil
|
||||||
}
|
}
|
||||||
_, err := client.RoundTrip(req)
|
_, err := client.RoundTrip(req)
|
||||||
@@ -89,7 +95,7 @@ var _ = Describe("Client", func() {
|
|||||||
testErr := errors.New("you shall not pass")
|
testErr := errors.New("you shall not pass")
|
||||||
client = newClient(nil, "localhost:1337", &roundTripperOpts{})
|
client = newClient(nil, "localhost:1337", &roundTripperOpts{})
|
||||||
session.streamOpenErr = testErr
|
session.streamOpenErr = testErr
|
||||||
client.dialAddr = func(hostname string, _ *tls.Config, _ *quic.Config) (quic.Session, error) {
|
dialAddr = func(hostname string, _ *tls.Config, _ *quic.Config) (quic.Session, error) {
|
||||||
return session, nil
|
return session, nil
|
||||||
}
|
}
|
||||||
_, err := client.RoundTrip(req)
|
_, err := client.RoundTrip(req)
|
||||||
@@ -98,7 +104,7 @@ var _ = Describe("Client", func() {
|
|||||||
|
|
||||||
It("returns a request when dial fails", func() {
|
It("returns a request when dial fails", func() {
|
||||||
testErr := errors.New("dial error")
|
testErr := errors.New("dial error")
|
||||||
client.dialAddr = func(hostname string, _ *tls.Config, _ *quic.Config) (quic.Session, error) {
|
dialAddr = func(hostname string, _ *tls.Config, _ *quic.Config) (quic.Session, error) {
|
||||||
return nil, testErr
|
return nil, testErr
|
||||||
}
|
}
|
||||||
request, err := http.NewRequest("https", "https://quic.clemente.io:1337/file1.dat", nil)
|
request, err := http.NewRequest("https", "https://quic.clemente.io:1337/file1.dat", nil)
|
||||||
@@ -140,7 +146,7 @@ var _ = Describe("Client", func() {
|
|||||||
BeforeEach(func() {
|
BeforeEach(func() {
|
||||||
var err error
|
var err error
|
||||||
client.encryptionLevel = protocol.EncryptionForwardSecure
|
client.encryptionLevel = protocol.EncryptionForwardSecure
|
||||||
client.dialAddr = func(hostname string, _ *tls.Config, _ *quic.Config) (quic.Session, error) {
|
dialAddr = func(hostname string, _ *tls.Config, _ *quic.Config) (quic.Session, error) {
|
||||||
return session, nil
|
return session, nil
|
||||||
}
|
}
|
||||||
dataStream = newMockStream(5)
|
dataStream = newMockStream(5)
|
||||||
|
|||||||
@@ -2,9 +2,12 @@ package h2quic
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
|
"crypto/tls"
|
||||||
|
"errors"
|
||||||
"io"
|
"io"
|
||||||
"net/http"
|
"net/http"
|
||||||
|
|
||||||
|
quic "github.com/lucas-clemente/quic-go"
|
||||||
. "github.com/onsi/ginkgo"
|
. "github.com/onsi/ginkgo"
|
||||||
. "github.com/onsi/gomega"
|
. "github.com/onsi/gomega"
|
||||||
)
|
)
|
||||||
@@ -54,15 +57,45 @@ var _ = Describe("RoundTripper", func() {
|
|||||||
Expect(err).ToNot(HaveOccurred())
|
Expect(err).ToNot(HaveOccurred())
|
||||||
})
|
})
|
||||||
|
|
||||||
It("reuses existing clients", func() {
|
Context("dialing hosts", func() {
|
||||||
rt.clients = make(map[string]http.RoundTripper)
|
origDialAddr := dialAddr
|
||||||
rt.clients["www.example.org:443"] = &mockRoundTripper{}
|
streamOpenErr := errors.New("error opening stream")
|
||||||
rsp, err := rt.RoundTrip(req1)
|
|
||||||
|
BeforeEach(func() {
|
||||||
|
origDialAddr = dialAddr
|
||||||
|
dialAddr = func(addr string, tlsConf *tls.Config, config *quic.Config) (quic.Session, error) {
|
||||||
|
// return an error when trying to open a stream
|
||||||
|
// we don't want to test all the dial logic here, just that dialing happens at all
|
||||||
|
return &mockSession{streamOpenErr: streamOpenErr}, nil
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
AfterEach(func() {
|
||||||
|
dialAddr = origDialAddr
|
||||||
|
})
|
||||||
|
|
||||||
|
It("creates new clients", func() {
|
||||||
|
req, err := http.NewRequest("GET", "https://quic.clemente.io/foobar.html", nil)
|
||||||
Expect(err).ToNot(HaveOccurred())
|
Expect(err).ToNot(HaveOccurred())
|
||||||
Expect(rsp.Request).To(Equal(req1))
|
_, err = rt.RoundTrip(req)
|
||||||
|
Expect(err).To(MatchError(streamOpenErr))
|
||||||
Expect(rt.clients).To(HaveLen(1))
|
Expect(rt.clients).To(HaveLen(1))
|
||||||
})
|
})
|
||||||
|
|
||||||
|
It("reuses existing clients", func() {
|
||||||
|
req, err := http.NewRequest("GET", "https://quic.clemente.io/file1.html", nil)
|
||||||
|
Expect(err).ToNot(HaveOccurred())
|
||||||
|
_, err = rt.RoundTrip(req)
|
||||||
|
Expect(err).To(MatchError(streamOpenErr))
|
||||||
|
Expect(rt.clients).To(HaveLen(1))
|
||||||
|
req2, err := http.NewRequest("GET", "https://quic.clemente.io/file2.html", nil)
|
||||||
|
Expect(err).ToNot(HaveOccurred())
|
||||||
|
_, err = rt.RoundTrip(req2)
|
||||||
|
Expect(err).To(MatchError(streamOpenErr))
|
||||||
|
Expect(rt.clients).To(HaveLen(1))
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
Context("validating request", func() {
|
Context("validating request", func() {
|
||||||
It("rejects plain HTTP requests", func() {
|
It("rejects plain HTTP requests", func() {
|
||||||
req, err := http.NewRequest("GET", "http://www.example.org/", nil)
|
req, err := http.NewRequest("GET", "http://www.example.org/", nil)
|
||||||
|
|||||||
Reference in New Issue
Block a user