Merge pull request #722 from lucas-clemente/fix-677

expose the quic.Config in h2quic.Server and h2quic.RoundTripper
This commit is contained in:
Marten Seemann
2017-07-10 23:16:50 +08:00
committed by GitHub
6 changed files with 161 additions and 75 deletions

View File

@@ -24,14 +24,15 @@ type roundTripperOpts struct {
DisableCompression bool
}
var dialAddr = quic.DialAddr
// client is a HTTP2 client doing QUIC requests
type client struct {
mutex sync.RWMutex
dialAddr func(hostname string, tlsConf *tls.Config, config *quic.Config) (quic.Session, error)
tlsConf *tls.Config
config *quic.Config
opts *roundTripperOpts
tlsConf *tls.Config
config *quic.Config
opts *roundTripperOpts
hostname string
encryptionLevel protocol.EncryptionLevel
@@ -49,27 +50,37 @@ type client struct {
var _ http.RoundTripper = &client{}
var defaultQuicConfig = &quic.Config{
RequestConnectionIDTruncation: true,
KeepAlive: true,
}
// newClient creates a new client
func newClient(tlsConfig *tls.Config, hostname string, opts *roundTripperOpts) *client {
func newClient(
hostname string,
tlsConfig *tls.Config,
opts *roundTripperOpts,
quicConfig *quic.Config,
) *client {
config := defaultQuicConfig
if quicConfig != nil {
config = quicConfig
}
return &client{
dialAddr: quic.DialAddr,
hostname: authorityAddr("https", hostname),
responses: make(map[protocol.StreamID]chan *http.Response),
encryptionLevel: protocol.EncryptionUnencrypted,
tlsConf: tlsConfig,
config: &quic.Config{
RequestConnectionIDTruncation: true,
KeepAlive: true,
},
opts: opts,
headerErrored: make(chan struct{}),
config: config,
opts: opts,
headerErrored: make(chan struct{}),
}
}
// dial dials the connection
func (c *client) dial() error {
var err error
c.session, err = c.dialAddr(c.hostname, c.tlsConf, c.config)
c.session, err = dialAddr(c.hostname, c.tlsConf, c.config)
if err != nil {
return err
}

View File

@@ -15,6 +15,8 @@ import (
"github.com/lucas-clemente/quic-go/protocol"
"github.com/lucas-clemente/quic-go/qerr"
"time"
. "github.com/onsi/ginkgo"
. "github.com/onsi/gomega"
)
@@ -25,11 +27,13 @@ var _ = Describe("Client", func() {
session *mockSession
headerStream *mockStream
req *http.Request
origDialAddr = dialAddr
)
BeforeEach(func() {
origDialAddr = dialAddr
hostname := "quic.clemente.io:1337"
client = newClient(nil, hostname, &roundTripperOpts{})
client = newClient(hostname, nil, &roundTripperOpts{}, nil)
Expect(client.hostname).To(Equal(hostname))
session = &mockSession{}
client.session = session
@@ -42,21 +46,37 @@ var _ = Describe("Client", func() {
Expect(err).ToNot(HaveOccurred())
})
AfterEach(func() {
dialAddr = origDialAddr
})
It("saves the TLS config", func() {
tlsConf := &tls.Config{InsecureSkipVerify: true}
client = newClient(tlsConf, "", &roundTripperOpts{})
client = newClient("", tlsConf, &roundTripperOpts{}, nil)
Expect(client.tlsConf).To(Equal(tlsConf))
})
It("saves the QUIC config", func() {
quicConf := &quic.Config{HandshakeTimeout: time.Nanosecond}
client = newClient("", &tls.Config{}, &roundTripperOpts{}, quicConf)
Expect(client.config).To(Equal(quicConf))
})
It("uses the default QUIC config if none is give", func() {
client = newClient("", &tls.Config{}, &roundTripperOpts{}, nil)
Expect(client.config).ToNot(BeNil())
Expect(client.config).To(Equal(defaultQuicConfig))
})
It("adds the port to the hostname, if none is given", func() {
client = newClient(nil, "quic.clemente.io", &roundTripperOpts{})
client = newClient("quic.clemente.io", nil, &roundTripperOpts{}, nil)
Expect(client.hostname).To(Equal("quic.clemente.io:443"))
})
It("dials", func(done Done) {
client = newClient(nil, "localhost:1337", &roundTripperOpts{})
client = newClient("localhost:1337", nil, &roundTripperOpts{}, nil)
session.streamsToOpen = []quic.Stream{newMockStream(3), newMockStream(5)}
client.dialAddr = func(hostname string, _ *tls.Config, _ *quic.Config) (quic.Session, error) {
dialAddr = func(hostname string, _ *tls.Config, _ *quic.Config) (quic.Session, error) {
return session, nil
}
close(headerStream.unblockRead)
@@ -67,8 +87,8 @@ var _ = Describe("Client", func() {
It("errors when dialing fails", func() {
testErr := errors.New("handshake error")
client = newClient(nil, "localhost:1337", &roundTripperOpts{})
client.dialAddr = func(hostname string, _ *tls.Config, _ *quic.Config) (quic.Session, error) {
client = newClient("localhost:1337", nil, &roundTripperOpts{}, nil)
dialAddr = func(hostname string, _ *tls.Config, _ *quic.Config) (quic.Session, error) {
return nil, testErr
}
_, err := client.RoundTrip(req)
@@ -76,9 +96,9 @@ var _ = Describe("Client", func() {
})
It("errors if the header stream has the wrong stream ID", func() {
client = newClient(nil, "localhost:1337", &roundTripperOpts{})
client = newClient("localhost:1337", nil, &roundTripperOpts{}, nil)
session.streamsToOpen = []quic.Stream{&mockStream{id: 2}}
client.dialAddr = func(hostname string, _ *tls.Config, _ *quic.Config) (quic.Session, error) {
dialAddr = func(hostname string, _ *tls.Config, _ *quic.Config) (quic.Session, error) {
return session, nil
}
_, err := client.RoundTrip(req)
@@ -87,9 +107,9 @@ var _ = Describe("Client", func() {
It("errors if it can't open a stream", func() {
testErr := errors.New("you shall not pass")
client = newClient(nil, "localhost:1337", &roundTripperOpts{})
client = newClient("localhost:1337", nil, &roundTripperOpts{}, nil)
session.streamOpenErr = testErr
client.dialAddr = func(hostname string, _ *tls.Config, _ *quic.Config) (quic.Session, error) {
dialAddr = func(hostname string, _ *tls.Config, _ *quic.Config) (quic.Session, error) {
return session, nil
}
_, err := client.RoundTrip(req)
@@ -98,7 +118,7 @@ var _ = Describe("Client", func() {
It("returns a request when dial fails", func() {
testErr := errors.New("dial error")
client.dialAddr = func(hostname string, _ *tls.Config, _ *quic.Config) (quic.Session, error) {
dialAddr = func(hostname string, _ *tls.Config, _ *quic.Config) (quic.Session, error) {
return nil, testErr
}
request, err := http.NewRequest("https", "https://quic.clemente.io:1337/file1.dat", nil)
@@ -140,7 +160,7 @@ var _ = Describe("Client", func() {
BeforeEach(func() {
var err error
client.encryptionLevel = protocol.EncryptionForwardSecure
client.dialAddr = func(hostname string, _ *tls.Config, _ *quic.Config) (quic.Session, error) {
dialAddr = func(hostname string, _ *tls.Config, _ *quic.Config) (quic.Session, error) {
return session, nil
}
dataStream = newMockStream(5)
@@ -246,7 +266,7 @@ var _ = Describe("Client", func() {
It("adds the port for request URLs without one", func(done Done) {
var err error
client = newClient(nil, "quic.clemente.io", &roundTripperOpts{})
client = newClient("quic.clemente.io", nil, &roundTripperOpts{}, nil)
req, err := http.NewRequest("https", "https://quic.clemente.io/foobar.html", nil)
Expect(err).ToNot(HaveOccurred())

View File

@@ -8,6 +8,8 @@ import (
"strings"
"sync"
quic "github.com/lucas-clemente/quic-go"
"golang.org/x/net/lex/httplex"
)
@@ -29,6 +31,10 @@ type RoundTripper struct {
// tls.Client. If nil, the default configuration is used.
TLSClientConfig *tls.Config
// QuicConfig is the quic.Config used for dialing new connections.
// If nil, reasonable default values will be used.
QuicConfig *quic.Config
clients map[string]http.RoundTripper
}
@@ -84,7 +90,7 @@ func (r *RoundTripper) getClient(hostname string) http.RoundTripper {
client, ok := r.clients[hostname]
if !ok {
client = newClient(r.TLSClientConfig, hostname, &roundTripperOpts{DisableCompression: r.DisableCompression})
client = newClient(hostname, r.TLSClientConfig, &roundTripperOpts{DisableCompression: r.DisableCompression}, r.QuicConfig)
r.clients[hostname] = client
}
return client

View File

@@ -2,9 +2,13 @@ package h2quic
import (
"bytes"
"crypto/tls"
"errors"
"io"
"net/http"
"time"
quic "github.com/lucas-clemente/quic-go"
. "github.com/onsi/ginkgo"
. "github.com/onsi/gomega"
)
@@ -54,13 +58,55 @@ var _ = Describe("RoundTripper", func() {
Expect(err).ToNot(HaveOccurred())
})
It("reuses existing clients", func() {
rt.clients = make(map[string]http.RoundTripper)
rt.clients["www.example.org:443"] = &mockRoundTripper{}
rsp, err := rt.RoundTrip(req1)
Expect(err).ToNot(HaveOccurred())
Expect(rsp.Request).To(Equal(req1))
Expect(rt.clients).To(HaveLen(1))
Context("dialing hosts", func() {
origDialAddr := dialAddr
streamOpenErr := errors.New("error opening stream")
BeforeEach(func() {
origDialAddr = dialAddr
dialAddr = func(addr string, tlsConf *tls.Config, config *quic.Config) (quic.Session, error) {
// return an error when trying to open a stream
// we don't want to test all the dial logic here, just that dialing happens at all
return &mockSession{streamOpenErr: streamOpenErr}, nil
}
})
AfterEach(func() {
dialAddr = origDialAddr
})
It("creates new clients", func() {
req, err := http.NewRequest("GET", "https://quic.clemente.io/foobar.html", nil)
Expect(err).ToNot(HaveOccurred())
_, err = rt.RoundTrip(req)
Expect(err).To(MatchError(streamOpenErr))
Expect(rt.clients).To(HaveLen(1))
})
It("uses the quic.Config, if provided", func() {
config := &quic.Config{HandshakeTimeout: time.Millisecond}
var receivedConfig *quic.Config
dialAddr = func(addr string, tlsConf *tls.Config, config *quic.Config) (quic.Session, error) {
receivedConfig = config
return nil, errors.New("err")
}
rt.QuicConfig = config
rt.RoundTrip(req1)
Expect(receivedConfig).To(Equal(config))
})
It("reuses existing clients", func() {
req, err := http.NewRequest("GET", "https://quic.clemente.io/file1.html", nil)
Expect(err).ToNot(HaveOccurred())
_, err = rt.RoundTrip(req)
Expect(err).To(MatchError(streamOpenErr))
Expect(rt.clients).To(HaveLen(1))
req2, err := http.NewRequest("GET", "https://quic.clemente.io/file2.html", nil)
Expect(err).ToNot(HaveOccurred())
_, err = rt.RoundTrip(req2)
Expect(err).To(MatchError(streamOpenErr))
Expect(rt.clients).To(HaveLen(1))
})
})
Context("validating request", func() {

View File

@@ -29,10 +29,20 @@ type remoteCloser interface {
CloseRemote(protocol.ByteCount)
}
// allows mocking of quic.Listen and quic.ListenAddr
var (
quicListen = quic.Listen
quicListenAddr = quic.ListenAddr
)
// Server is a HTTP2 server listening for QUIC connections.
type Server struct {
*http.Server
// By providing a quic.Config, it is possible to set parameters of the QUIC connection.
// If nil, it uses reasonable default values.
QuicConfig *quic.Config
// Private flag for demo, do not use
CloseAfterFirstRequest bool
@@ -83,16 +93,12 @@ func (s *Server) serveImpl(tlsConfig *tls.Config, conn net.PacketConn) error {
return errors.New("ListenAndServe may only be called once")
}
config := quic.Config{
Versions: protocol.SupportedVersions,
}
var ln quic.Listener
var err error
if conn == nil {
ln, err = quic.ListenAddr(s.Addr, tlsConfig, &config)
ln, err = quicListenAddr(s.Addr, tlsConfig, s.QuicConfig)
} else {
ln, err = quic.Listen(conn, tlsConfig, &config)
ln, err = quicListen(conn, tlsConfig, s.QuicConfig)
}
if err != nil {
s.listenerMutex.Unlock()

View File

@@ -2,13 +2,12 @@ package h2quic
import (
"bytes"
"crypto/tls"
"errors"
"io"
"net"
"net/http"
"os"
"runtime"
"sync"
"syscall"
"time"
"golang.org/x/net/http2"
@@ -66,9 +65,10 @@ func (s *mockSession) WaitUntilClosed() { panic("not implemented") }
var _ = Describe("H2 server", func() {
var (
s *Server
session *mockSession
dataStream *mockStream
s *Server
session *mockSession
dataStream *mockStream
origQuicListenAddr = quicListenAddr
)
BeforeEach(func() {
@@ -80,6 +80,11 @@ var _ = Describe("H2 server", func() {
dataStream = newMockStream(0)
close(dataStream.unblockRead)
session = &mockSession{dataStream: dataStream}
origQuicListenAddr = quicListenAddr
})
AfterEach(func() {
quicListenAddr = origQuicListenAddr
})
Context("handling requests", func() {
@@ -380,8 +385,7 @@ var _ = Describe("H2 server", func() {
})
AfterEach(func() {
err := s.Close()
Expect(err).NotTo(HaveOccurred())
Expect(s.Close()).To(Succeed())
})
It("may only be called once", func() {
@@ -399,8 +403,19 @@ var _ = Describe("H2 server", func() {
Expect(err).To(MatchError("ListenAndServe may only be called once"))
err = s.Close()
Expect(err).NotTo(HaveOccurred())
}, 0.5)
It("uses the quic.Config to start the quic server", func() {
conf := &quic.Config{HandshakeTimeout: time.Nanosecond}
var receivedConf *quic.Config
quicListenAddr = func(addr string, tlsConf *tls.Config, config *quic.Config) (quic.Listener, error) {
receivedConf = config
return nil, errors.New("listen err")
}
s.QuicConfig = conf
go s.ListenAndServe()
Eventually(func() *quic.Config { return receivedConf }).Should(Equal(conf))
})
})
Context("ListenAndServeTLS", func() {
@@ -436,31 +451,13 @@ var _ = Describe("H2 server", func() {
Expect(err).NotTo(HaveOccurred())
})
It("at least errors in global ListenAndServeQUIC", func() {
// It's quite hard to test this, since we cannot properly shutdown the server
// once it's started. So, we open a socket on the same port before the test,
// so that ListenAndServeQUIC definitely fails. This way we know it at least
// created a socket on the proper address :)
const addr = "127.0.0.1:4826"
udpAddr, err := net.ResolveUDPAddr("udp", addr)
Expect(err).NotTo(HaveOccurred())
c, err := net.ListenUDP("udp", udpAddr)
Expect(err).NotTo(HaveOccurred())
defer c.Close()
fullpem, privkey := testdata.GetCertificatePaths()
err = ListenAndServeQUIC(addr, fullpem, privkey, nil)
// Check that it's an EADDRINUSE
Expect(err).ToNot(BeNil())
opErr, ok := err.(*net.OpError)
Expect(ok).To(BeTrue())
syscallErr, ok := opErr.Err.(*os.SyscallError)
Expect(ok).To(BeTrue())
if runtime.GOOS == "windows" {
// for some reason, Windows return a different error number, corresponding to an WSAEADDRINUSE error
// see https://msdn.microsoft.com/en-us/library/windows/desktop/ms681391(v=vs.85).aspx
Expect(syscallErr.Err).To(Equal(syscall.Errno(0x2740)))
} else {
Expect(syscallErr.Err).To(MatchError(syscall.EADDRINUSE))
It("errors when listening fails", func() {
testErr := errors.New("listen error")
quicListenAddr = func(addr string, tlsConf *tls.Config, config *quic.Config) (quic.Listener, error) {
return nil, testErr
}
fullpem, privkey := testdata.GetCertificatePaths()
err := ListenAndServeQUIC("", fullpem, privkey, nil)
Expect(err).To(MatchError(testErr))
})
})