forked from quic-go/quic-go
implement a h2quic client that can send H2 requests
This commit is contained in:
140
h2quic/client.go
Normal file
140
h2quic/client.go
Normal file
@@ -0,0 +1,140 @@
|
||||
package h2quic
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"net"
|
||||
"net/http"
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
"golang.org/x/net/idna"
|
||||
|
||||
quic "github.com/lucas-clemente/quic-go"
|
||||
"github.com/lucas-clemente/quic-go/protocol"
|
||||
"github.com/lucas-clemente/quic-go/utils"
|
||||
)
|
||||
|
||||
type quicClient interface {
|
||||
OpenStream(protocol.StreamID) (utils.Stream, error)
|
||||
Close() error
|
||||
Listen() error
|
||||
}
|
||||
|
||||
// Client is a HTTP2 client doing QUIC requests
|
||||
type Client struct {
|
||||
mutex sync.Mutex
|
||||
cryptoChangedCond sync.Cond
|
||||
|
||||
hostname string
|
||||
encryptionLevel protocol.EncryptionLevel
|
||||
|
||||
client quicClient
|
||||
headerStream utils.Stream
|
||||
highestOpenedStream protocol.StreamID
|
||||
requestWriter *requestWriter
|
||||
|
||||
responses map[protocol.StreamID]chan *http.Response
|
||||
}
|
||||
|
||||
// NewClient creates a new client
|
||||
func NewClient(hostname string) (*Client, error) {
|
||||
c := &Client{
|
||||
hostname: authorityAddr("https", hostname),
|
||||
highestOpenedStream: 3,
|
||||
responses: make(map[protocol.StreamID]chan *http.Response),
|
||||
}
|
||||
c.cryptoChangedCond = sync.Cond{L: &c.mutex}
|
||||
|
||||
var err error
|
||||
c.client, err = quic.NewClient(c.hostname, c.cryptoChangeCallback, c.versionNegotiateCallback)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
go c.client.Listen()
|
||||
return c, nil
|
||||
}
|
||||
|
||||
func (c *Client) handleStreamCb(session *quic.Session, stream utils.Stream) {
|
||||
utils.Debugf("Handling stream %d", stream.StreamID())
|
||||
}
|
||||
|
||||
func (c *Client) cryptoChangeCallback(isForwardSecure bool) {
|
||||
c.cryptoChangedCond.L.Lock()
|
||||
defer c.cryptoChangedCond.L.Unlock()
|
||||
|
||||
if isForwardSecure {
|
||||
c.encryptionLevel = protocol.EncryptionForwardSecure
|
||||
utils.Debugf("is forward secure")
|
||||
} else {
|
||||
c.encryptionLevel = protocol.EncryptionSecure
|
||||
utils.Debugf("is secure")
|
||||
}
|
||||
c.cryptoChangedCond.Broadcast()
|
||||
}
|
||||
|
||||
func (c *Client) versionNegotiateCallback() error {
|
||||
var err error
|
||||
// once the version has been negotiated, open the header stream
|
||||
c.headerStream, err = c.client.OpenStream(3)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
c.requestWriter = newRequestWriter(c.headerStream)
|
||||
return nil
|
||||
}
|
||||
|
||||
// Do executes a request and returns a response
|
||||
func (c *Client) Do(req *http.Request) (*http.Response, error) {
|
||||
// TODO: add port to address, if it doesn't have one
|
||||
if req.URL.Scheme != "https" {
|
||||
return nil, errors.New("quic http2: unsupported scheme")
|
||||
}
|
||||
if authorityAddr("https", req.Host) != c.hostname {
|
||||
utils.Debugf("%s vs %s", req.Host, c.hostname)
|
||||
return nil, errors.New("h2quic Client BUG: Do called for the wrong client")
|
||||
}
|
||||
|
||||
c.mutex.Lock()
|
||||
c.highestOpenedStream += 2
|
||||
dataStreamID := c.highestOpenedStream
|
||||
for c.encryptionLevel != protocol.EncryptionForwardSecure {
|
||||
c.cryptoChangedCond.Wait()
|
||||
}
|
||||
_, err := c.client.OpenStream(dataStreamID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
err = c.requestWriter.WriteRequest(req, dataStreamID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
c.mutex.Unlock()
|
||||
|
||||
// TODO: get the response
|
||||
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
// copied from net/transport.go
|
||||
|
||||
// authorityAddr returns a given authority (a host/IP, or host:port / ip:port)
|
||||
// and returns a host:port. The port 443 is added if needed.
|
||||
func authorityAddr(scheme string, authority string) (addr string) {
|
||||
host, port, err := net.SplitHostPort(authority)
|
||||
if err != nil { // authority didn't have a port
|
||||
port = "443"
|
||||
if scheme == "http" {
|
||||
port = "80"
|
||||
}
|
||||
host = authority
|
||||
}
|
||||
if a, err := idna.ToASCII(host); err == nil {
|
||||
host = a
|
||||
}
|
||||
// IPv6 address literal, without a port:
|
||||
if strings.HasPrefix(host, "[") && strings.HasSuffix(host, "]") {
|
||||
return host + ":" + port
|
||||
}
|
||||
return net.JoinHostPort(host, port)
|
||||
}
|
||||
121
h2quic/client_test.go
Normal file
121
h2quic/client_test.go
Normal file
@@ -0,0 +1,121 @@
|
||||
package h2quic
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
|
||||
"github.com/lucas-clemente/quic-go/protocol"
|
||||
"github.com/lucas-clemente/quic-go/utils"
|
||||
|
||||
. "github.com/onsi/ginkgo"
|
||||
. "github.com/onsi/gomega"
|
||||
)
|
||||
|
||||
type mockQuicClient struct {
|
||||
streams map[protocol.StreamID]*mockStream
|
||||
}
|
||||
|
||||
func (m *mockQuicClient) Close() error { panic("not implemented") }
|
||||
func (m *mockQuicClient) Listen() error { panic("not implemented") }
|
||||
func (m *mockQuicClient) OpenStream(id protocol.StreamID) (utils.Stream, error) {
|
||||
_, ok := m.streams[id]
|
||||
if ok {
|
||||
panic("Stream already exists")
|
||||
}
|
||||
ms := &mockStream{id: id}
|
||||
m.streams[id] = ms
|
||||
return ms, nil
|
||||
}
|
||||
|
||||
func newMockQuicClient() *mockQuicClient {
|
||||
return &mockQuicClient{
|
||||
streams: make(map[protocol.StreamID]*mockStream),
|
||||
}
|
||||
}
|
||||
|
||||
var _ quicClient = &mockQuicClient{}
|
||||
|
||||
var _ = Describe("Client", func() {
|
||||
var client *Client
|
||||
var qClient *mockQuicClient
|
||||
|
||||
BeforeEach(func() {
|
||||
var err error
|
||||
hostname := "quic.clemente.io:1337"
|
||||
client, err = NewClient(hostname)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(client.hostname).To(Equal(hostname))
|
||||
qClient = newMockQuicClient()
|
||||
client.client = qClient
|
||||
})
|
||||
|
||||
It("adds the port to the hostname, if none is given", func() {
|
||||
var err error
|
||||
client, err = NewClient("quic.clemente.io")
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(client.hostname).To(Equal("quic.clemente.io:443"))
|
||||
})
|
||||
|
||||
It("opens the header stream only after the version has been negotiated", func() {
|
||||
Expect(client.headerStream).To(BeNil()) // header stream not yet opened
|
||||
err := client.versionNegotiateCallback()
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(client.headerStream).ToNot(BeNil())
|
||||
Expect(client.headerStream.StreamID()).To(Equal(protocol.StreamID(3)))
|
||||
})
|
||||
|
||||
It("sets the correct crypto level", func() {
|
||||
Expect(client.encryptionLevel).To(Equal(protocol.Unencrypted))
|
||||
client.cryptoChangeCallback(false)
|
||||
Expect(client.encryptionLevel).To(Equal(protocol.EncryptionSecure))
|
||||
client.cryptoChangeCallback(true)
|
||||
Expect(client.encryptionLevel).To(Equal(protocol.EncryptionForwardSecure))
|
||||
})
|
||||
|
||||
Context("Doing requests", func() {
|
||||
BeforeEach(func() {
|
||||
qClient.streams[3] = &mockStream{}
|
||||
client.requestWriter = newRequestWriter(qClient.streams[3])
|
||||
})
|
||||
|
||||
It("does a request", func(done Done) {
|
||||
client.encryptionLevel = protocol.EncryptionForwardSecure
|
||||
req, err := http.NewRequest("https", "https://quic.clemente.io:1337/file1.dat", nil)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
go client.Do(req)
|
||||
Eventually(func() []byte { return qClient.streams[3].dataWritten.Bytes() }).ShouldNot(BeEmpty())
|
||||
Expect(client.highestOpenedStream).Should(Equal(protocol.StreamID(5)))
|
||||
Expect(qClient.streams).Should(HaveKey(protocol.StreamID(5)))
|
||||
close(done)
|
||||
})
|
||||
|
||||
Context("validating the address", func() {
|
||||
It("refuses to do requests for the wrong host", func() {
|
||||
req, err := http.NewRequest("https", "https://quic.clemente.io:1336/foobar.html", nil)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
_, err = client.Do(req)
|
||||
Expect(err).To(MatchError("h2quic Client BUG: Do called for the wrong client"))
|
||||
})
|
||||
|
||||
It("refuses to do plain HTTP requests", func() {
|
||||
req, err := http.NewRequest("https", "http://quic.clemente.io:1337/foobar.html", nil)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
_, err = client.Do(req)
|
||||
Expect(err).To(MatchError("quic http2: unsupported scheme"))
|
||||
})
|
||||
|
||||
It("adds the port for request URLs without one", func(done Done) {
|
||||
var err error
|
||||
client, err = NewClient("quic.clemente.io")
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
qClient.streams[3] = &mockStream{}
|
||||
client.requestWriter = newRequestWriter(qClient.streams[3])
|
||||
client.encryptionLevel = protocol.EncryptionForwardSecure
|
||||
req, err := http.NewRequest("https", "https://quic.clemente.io/foobar.html", nil)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
_, err = client.Do(req)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
close(done)
|
||||
})
|
||||
})
|
||||
})
|
||||
})
|
||||
Reference in New Issue
Block a user