forked from quic-go/quic-go
handle the header stream in the h2quic client
This commit is contained in:
@@ -2,15 +2,19 @@ package h2quic
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/http"
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
"golang.org/x/net/http2"
|
||||
"golang.org/x/net/http2/hpack"
|
||||
"golang.org/x/net/idna"
|
||||
|
||||
quic "github.com/lucas-clemente/quic-go"
|
||||
"github.com/lucas-clemente/quic-go/protocol"
|
||||
"github.com/lucas-clemente/quic-go/qerr"
|
||||
"github.com/lucas-clemente/quic-go/utils"
|
||||
)
|
||||
|
||||
@@ -22,7 +26,7 @@ type quicClient interface {
|
||||
|
||||
// Client is a HTTP2 client doing QUIC requests
|
||||
type Client struct {
|
||||
mutex sync.Mutex
|
||||
mutex sync.RWMutex
|
||||
cryptoChangedCond sync.Cond
|
||||
|
||||
hostname string
|
||||
@@ -30,6 +34,7 @@ type Client struct {
|
||||
|
||||
client quicClient
|
||||
headerStream utils.Stream
|
||||
headerErr *qerr.QuicError
|
||||
highestOpenedStream protocol.StreamID
|
||||
requestWriter *requestWriter
|
||||
|
||||
@@ -81,9 +86,57 @@ func (c *Client) versionNegotiateCallback() error {
|
||||
return err
|
||||
}
|
||||
c.requestWriter = newRequestWriter(c.headerStream)
|
||||
go c.handleHeaderStream()
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *Client) handleHeaderStream() {
|
||||
decoder := hpack.NewDecoder(4096, func(hf hpack.HeaderField) {})
|
||||
h2framer := http2.NewFramer(nil, c.headerStream)
|
||||
|
||||
var lastStream protocol.StreamID
|
||||
|
||||
for {
|
||||
frame, err := h2framer.ReadFrame()
|
||||
if err != nil {
|
||||
c.headerErr = qerr.Error(qerr.InvalidStreamData, "cannot read frame")
|
||||
break
|
||||
}
|
||||
lastStream = protocol.StreamID(frame.Header().StreamID)
|
||||
hframe, ok := frame.(*http2.HeadersFrame)
|
||||
if !ok {
|
||||
c.headerErr = qerr.Error(qerr.InvalidHeadersStreamData, "not a headers frame")
|
||||
break
|
||||
}
|
||||
mhframe := &http2.MetaHeadersFrame{HeadersFrame: hframe}
|
||||
mhframe.Fields, err = decoder.DecodeFull(hframe.HeaderBlockFragment())
|
||||
if err != nil {
|
||||
c.headerErr = qerr.Error(qerr.InvalidHeadersStreamData, "cannot read header fields")
|
||||
break
|
||||
}
|
||||
|
||||
c.mutex.RLock()
|
||||
headerChan, ok := c.responses[protocol.StreamID(hframe.StreamID)]
|
||||
c.mutex.RUnlock()
|
||||
if !ok {
|
||||
c.headerErr = qerr.Error(qerr.InternalError, fmt.Sprintf("h2client BUG: response channel for stream %d not found", lastStream))
|
||||
break
|
||||
}
|
||||
|
||||
rsp := &http.Response{}
|
||||
// TODO: fill in the right values
|
||||
headerChan <- rsp
|
||||
}
|
||||
|
||||
// stop all running request
|
||||
utils.Debugf("Error handling header stream %d: %s", lastStream, c.headerErr.Error())
|
||||
c.mutex.Lock()
|
||||
for _, responseChan := range c.responses {
|
||||
responseChan <- nil
|
||||
}
|
||||
c.mutex.Unlock()
|
||||
}
|
||||
|
||||
// Do executes a request and returns a response
|
||||
func (c *Client) Do(req *http.Request) (*http.Response, error) {
|
||||
// TODO: add port to address, if it doesn't have one
|
||||
@@ -101,6 +154,9 @@ func (c *Client) Do(req *http.Request) (*http.Response, error) {
|
||||
for c.encryptionLevel != protocol.EncryptionForwardSecure {
|
||||
c.cryptoChangedCond.Wait()
|
||||
}
|
||||
|
||||
hdrChan := make(chan *http.Response)
|
||||
c.responses[dataStreamID] = hdrChan
|
||||
_, err := c.client.OpenStream(dataStreamID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@@ -111,9 +167,20 @@ func (c *Client) Do(req *http.Request) (*http.Response, error) {
|
||||
}
|
||||
c.mutex.Unlock()
|
||||
|
||||
// TODO: get the response
|
||||
var rsp *http.Response
|
||||
select {
|
||||
case rsp = <-hdrChan:
|
||||
c.mutex.Lock()
|
||||
delete(c.responses, dataStreamID)
|
||||
c.mutex.Unlock()
|
||||
}
|
||||
|
||||
return nil, nil
|
||||
// if an error occured on the header stream
|
||||
if rsp == nil {
|
||||
return nil, c.headerErr
|
||||
}
|
||||
|
||||
return rsp, nil
|
||||
}
|
||||
|
||||
// copied from net/transport.go
|
||||
|
||||
@@ -3,7 +3,10 @@ package h2quic
|
||||
import (
|
||||
"net/http"
|
||||
|
||||
"golang.org/x/net/http2"
|
||||
|
||||
"github.com/lucas-clemente/quic-go/protocol"
|
||||
"github.com/lucas-clemente/quic-go/qerr"
|
||||
"github.com/lucas-clemente/quic-go/utils"
|
||||
|
||||
. "github.com/onsi/ginkgo"
|
||||
@@ -35,8 +38,11 @@ func newMockQuicClient() *mockQuicClient {
|
||||
var _ quicClient = &mockQuicClient{}
|
||||
|
||||
var _ = Describe("Client", func() {
|
||||
var client *Client
|
||||
var qClient *mockQuicClient
|
||||
var (
|
||||
client *Client
|
||||
qClient *mockQuicClient
|
||||
headerStream *mockStream
|
||||
)
|
||||
|
||||
BeforeEach(func() {
|
||||
var err error
|
||||
@@ -46,6 +52,11 @@ var _ = Describe("Client", func() {
|
||||
Expect(client.hostname).To(Equal(hostname))
|
||||
qClient = newMockQuicClient()
|
||||
client.client = qClient
|
||||
|
||||
headerStream = &mockStream{}
|
||||
qClient.streams[3] = headerStream
|
||||
client.headerStream = headerStream
|
||||
client.requestWriter = newRequestWriter(headerStream)
|
||||
})
|
||||
|
||||
It("adds the port to the hostname, if none is given", func() {
|
||||
@@ -56,7 +67,11 @@ var _ = Describe("Client", func() {
|
||||
})
|
||||
|
||||
It("opens the header stream only after the version has been negotiated", func() {
|
||||
// delete the headerStream openend in the BeforeEach
|
||||
client.headerStream = nil
|
||||
delete(qClient.streams, 3)
|
||||
Expect(client.headerStream).To(BeNil()) // header stream not yet opened
|
||||
// now start the actual test
|
||||
err := client.versionNegotiateCallback()
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(client.headerStream).ToNot(BeNil())
|
||||
@@ -73,18 +88,33 @@ var _ = Describe("Client", func() {
|
||||
|
||||
Context("Doing requests", func() {
|
||||
BeforeEach(func() {
|
||||
qClient.streams[3] = &mockStream{}
|
||||
client.requestWriter = newRequestWriter(qClient.streams[3])
|
||||
client.encryptionLevel = protocol.EncryptionForwardSecure
|
||||
})
|
||||
|
||||
It("does a request", func(done Done) {
|
||||
client.encryptionLevel = protocol.EncryptionForwardSecure
|
||||
req, err := http.NewRequest("https", "https://quic.clemente.io:1337/file1.dat", nil)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
go client.Do(req)
|
||||
Eventually(func() []byte { return qClient.streams[3].dataWritten.Bytes() }).ShouldNot(BeEmpty())
|
||||
Expect(client.highestOpenedStream).Should(Equal(protocol.StreamID(5)))
|
||||
|
||||
var doRsp *http.Response
|
||||
var doErr error
|
||||
var doReturned bool
|
||||
go func() {
|
||||
doRsp, doErr = client.Do(req)
|
||||
doReturned = true
|
||||
}()
|
||||
|
||||
Eventually(func() []byte { return headerStream.dataWritten.Bytes() }).ShouldNot(BeEmpty())
|
||||
Expect(client.highestOpenedStream).To(Equal(protocol.StreamID(5)))
|
||||
Expect(qClient.streams).Should(HaveKey(protocol.StreamID(5)))
|
||||
Expect(client.responses).To(HaveKey(protocol.StreamID(5)))
|
||||
rsp := &http.Response{
|
||||
Status: "418 I'm a teapot",
|
||||
StatusCode: 418,
|
||||
}
|
||||
client.responses[5] <- rsp
|
||||
Eventually(func() bool { return doReturned }).Should(BeTrue())
|
||||
Expect(doErr).ToNot(HaveOccurred())
|
||||
Expect(doRsp).To(Equal(rsp))
|
||||
close(done)
|
||||
})
|
||||
|
||||
@@ -107,15 +137,77 @@ var _ = Describe("Client", func() {
|
||||
var err error
|
||||
client, err = NewClient("quic.clemente.io")
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
qClient.streams[3] = &mockStream{}
|
||||
client.requestWriter = newRequestWriter(qClient.streams[3])
|
||||
client.encryptionLevel = protocol.EncryptionForwardSecure
|
||||
req, err := http.NewRequest("https", "https://quic.clemente.io/foobar.html", nil)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
_, err = client.Do(req)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
||||
var doErr error
|
||||
var doReturned bool
|
||||
// the client.Do will block, because the encryption level is still set to Unencrypted
|
||||
go func() {
|
||||
_, doErr = client.Do(req)
|
||||
doReturned = true
|
||||
}()
|
||||
|
||||
Consistently(doReturned).Should(BeFalse())
|
||||
Expect(doErr).ToNot(HaveOccurred())
|
||||
close(done)
|
||||
})
|
||||
})
|
||||
|
||||
Context("handling the header stream", func() {
|
||||
var h2framer *http2.Framer
|
||||
|
||||
BeforeEach(func() {
|
||||
h2framer = http2.NewFramer(&headerStream.dataToRead, nil)
|
||||
client.responses[23] = make(chan *http.Response)
|
||||
})
|
||||
|
||||
It("reads a response", func() {
|
||||
headerStream.dataToRead.Write([]byte{
|
||||
0x0, 0x0, 0x11, 0x1, 0x5, 0x0, 0x0, 0x0, 23,
|
||||
// Taken from https://http2.github.io/http2-spec/compression.html#request.examples.with.huffman.coding
|
||||
0x82, 0x86, 0x84, 0x41, 0x8c, 0xf1, 0xe3, 0xc2, 0xe5, 0xf2, 0x3a, 0x6b, 0xa0, 0xab, 0x90, 0xf4, 0xff,
|
||||
})
|
||||
go client.handleHeaderStream()
|
||||
var rsp *http.Response
|
||||
Eventually(client.responses[23]).Should(Receive(&rsp))
|
||||
Expect(rsp).ToNot(BeNil())
|
||||
})
|
||||
|
||||
It("errors if the H2 frame is not a HeadersFrame", func() {
|
||||
var handlerReturned bool
|
||||
go func() {
|
||||
client.handleHeaderStream()
|
||||
handlerReturned = true
|
||||
}()
|
||||
|
||||
h2framer.WritePing(true, [8]byte{0, 0, 0, 0, 0, 0, 0, 0})
|
||||
var rsp *http.Response
|
||||
Eventually(client.responses[23]).Should(Receive(&rsp))
|
||||
Expect(rsp).To(BeNil())
|
||||
Expect(client.headerErr).To(MatchError(qerr.Error(qerr.InvalidHeadersStreamData, "not a headers frame")))
|
||||
Eventually(func() bool { return handlerReturned }).Should(BeTrue())
|
||||
})
|
||||
|
||||
It("errors if it can't read the HPACK encoded header fields", func() {
|
||||
var handlerReturned bool
|
||||
go func() {
|
||||
client.handleHeaderStream()
|
||||
handlerReturned = true
|
||||
}()
|
||||
|
||||
h2framer.WriteHeaders(http2.HeadersFrameParam{
|
||||
StreamID: 23,
|
||||
EndHeaders: true,
|
||||
BlockFragment: []byte("invalid HPACK data"),
|
||||
})
|
||||
|
||||
var rsp *http.Response
|
||||
Eventually(client.responses[23]).Should(Receive(&rsp))
|
||||
Expect(rsp).To(BeNil())
|
||||
Expect(client.headerErr).To(MatchError(qerr.Error(qerr.InvalidHeadersStreamData, "cannot read header fields")))
|
||||
Eventually(func() bool { return handlerReturned }).Should(BeTrue())
|
||||
})
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
Reference in New Issue
Block a user