diff --git a/h2quic/client.go b/h2quic/client.go index 5f12c9f2..e86e7659 100644 --- a/h2quic/client.go +++ b/h2quic/client.go @@ -2,15 +2,19 @@ package h2quic import ( "errors" + "fmt" "net" "net/http" "strings" "sync" + "golang.org/x/net/http2" + "golang.org/x/net/http2/hpack" "golang.org/x/net/idna" quic "github.com/lucas-clemente/quic-go" "github.com/lucas-clemente/quic-go/protocol" + "github.com/lucas-clemente/quic-go/qerr" "github.com/lucas-clemente/quic-go/utils" ) @@ -22,7 +26,7 @@ type quicClient interface { // Client is a HTTP2 client doing QUIC requests type Client struct { - mutex sync.Mutex + mutex sync.RWMutex cryptoChangedCond sync.Cond hostname string @@ -30,6 +34,7 @@ type Client struct { client quicClient headerStream utils.Stream + headerErr *qerr.QuicError highestOpenedStream protocol.StreamID requestWriter *requestWriter @@ -81,9 +86,57 @@ func (c *Client) versionNegotiateCallback() error { return err } c.requestWriter = newRequestWriter(c.headerStream) + go c.handleHeaderStream() return nil } +func (c *Client) handleHeaderStream() { + decoder := hpack.NewDecoder(4096, func(hf hpack.HeaderField) {}) + h2framer := http2.NewFramer(nil, c.headerStream) + + var lastStream protocol.StreamID + + for { + frame, err := h2framer.ReadFrame() + if err != nil { + c.headerErr = qerr.Error(qerr.InvalidStreamData, "cannot read frame") + break + } + lastStream = protocol.StreamID(frame.Header().StreamID) + hframe, ok := frame.(*http2.HeadersFrame) + if !ok { + c.headerErr = qerr.Error(qerr.InvalidHeadersStreamData, "not a headers frame") + break + } + mhframe := &http2.MetaHeadersFrame{HeadersFrame: hframe} + mhframe.Fields, err = decoder.DecodeFull(hframe.HeaderBlockFragment()) + if err != nil { + c.headerErr = qerr.Error(qerr.InvalidHeadersStreamData, "cannot read header fields") + break + } + + c.mutex.RLock() + headerChan, ok := c.responses[protocol.StreamID(hframe.StreamID)] + c.mutex.RUnlock() + if !ok { + c.headerErr = qerr.Error(qerr.InternalError, fmt.Sprintf("h2client BUG: response channel for stream %d not found", lastStream)) + break + } + + rsp := &http.Response{} + // TODO: fill in the right values + headerChan <- rsp + } + + // stop all running request + utils.Debugf("Error handling header stream %d: %s", lastStream, c.headerErr.Error()) + c.mutex.Lock() + for _, responseChan := range c.responses { + responseChan <- nil + } + c.mutex.Unlock() +} + // Do executes a request and returns a response func (c *Client) Do(req *http.Request) (*http.Response, error) { // TODO: add port to address, if it doesn't have one @@ -101,6 +154,9 @@ func (c *Client) Do(req *http.Request) (*http.Response, error) { for c.encryptionLevel != protocol.EncryptionForwardSecure { c.cryptoChangedCond.Wait() } + + hdrChan := make(chan *http.Response) + c.responses[dataStreamID] = hdrChan _, err := c.client.OpenStream(dataStreamID) if err != nil { return nil, err @@ -111,9 +167,20 @@ func (c *Client) Do(req *http.Request) (*http.Response, error) { } c.mutex.Unlock() - // TODO: get the response + var rsp *http.Response + select { + case rsp = <-hdrChan: + c.mutex.Lock() + delete(c.responses, dataStreamID) + c.mutex.Unlock() + } - return nil, nil + // if an error occured on the header stream + if rsp == nil { + return nil, c.headerErr + } + + return rsp, nil } // copied from net/transport.go diff --git a/h2quic/client_test.go b/h2quic/client_test.go index f078663c..061ee21f 100644 --- a/h2quic/client_test.go +++ b/h2quic/client_test.go @@ -3,7 +3,10 @@ package h2quic import ( "net/http" + "golang.org/x/net/http2" + "github.com/lucas-clemente/quic-go/protocol" + "github.com/lucas-clemente/quic-go/qerr" "github.com/lucas-clemente/quic-go/utils" . "github.com/onsi/ginkgo" @@ -35,8 +38,11 @@ func newMockQuicClient() *mockQuicClient { var _ quicClient = &mockQuicClient{} var _ = Describe("Client", func() { - var client *Client - var qClient *mockQuicClient + var ( + client *Client + qClient *mockQuicClient + headerStream *mockStream + ) BeforeEach(func() { var err error @@ -46,6 +52,11 @@ var _ = Describe("Client", func() { Expect(client.hostname).To(Equal(hostname)) qClient = newMockQuicClient() client.client = qClient + + headerStream = &mockStream{} + qClient.streams[3] = headerStream + client.headerStream = headerStream + client.requestWriter = newRequestWriter(headerStream) }) It("adds the port to the hostname, if none is given", func() { @@ -56,7 +67,11 @@ var _ = Describe("Client", func() { }) It("opens the header stream only after the version has been negotiated", func() { + // delete the headerStream openend in the BeforeEach + client.headerStream = nil + delete(qClient.streams, 3) Expect(client.headerStream).To(BeNil()) // header stream not yet opened + // now start the actual test err := client.versionNegotiateCallback() Expect(err).ToNot(HaveOccurred()) Expect(client.headerStream).ToNot(BeNil()) @@ -73,18 +88,33 @@ var _ = Describe("Client", func() { Context("Doing requests", func() { BeforeEach(func() { - qClient.streams[3] = &mockStream{} - client.requestWriter = newRequestWriter(qClient.streams[3]) + client.encryptionLevel = protocol.EncryptionForwardSecure }) It("does a request", func(done Done) { - client.encryptionLevel = protocol.EncryptionForwardSecure req, err := http.NewRequest("https", "https://quic.clemente.io:1337/file1.dat", nil) Expect(err).ToNot(HaveOccurred()) - go client.Do(req) - Eventually(func() []byte { return qClient.streams[3].dataWritten.Bytes() }).ShouldNot(BeEmpty()) - Expect(client.highestOpenedStream).Should(Equal(protocol.StreamID(5))) + + var doRsp *http.Response + var doErr error + var doReturned bool + go func() { + doRsp, doErr = client.Do(req) + doReturned = true + }() + + Eventually(func() []byte { return headerStream.dataWritten.Bytes() }).ShouldNot(BeEmpty()) + Expect(client.highestOpenedStream).To(Equal(protocol.StreamID(5))) Expect(qClient.streams).Should(HaveKey(protocol.StreamID(5))) + Expect(client.responses).To(HaveKey(protocol.StreamID(5))) + rsp := &http.Response{ + Status: "418 I'm a teapot", + StatusCode: 418, + } + client.responses[5] <- rsp + Eventually(func() bool { return doReturned }).Should(BeTrue()) + Expect(doErr).ToNot(HaveOccurred()) + Expect(doRsp).To(Equal(rsp)) close(done) }) @@ -107,15 +137,77 @@ var _ = Describe("Client", func() { var err error client, err = NewClient("quic.clemente.io") Expect(err).ToNot(HaveOccurred()) - qClient.streams[3] = &mockStream{} - client.requestWriter = newRequestWriter(qClient.streams[3]) - client.encryptionLevel = protocol.EncryptionForwardSecure req, err := http.NewRequest("https", "https://quic.clemente.io/foobar.html", nil) Expect(err).ToNot(HaveOccurred()) - _, err = client.Do(req) - Expect(err).ToNot(HaveOccurred()) + + var doErr error + var doReturned bool + // the client.Do will block, because the encryption level is still set to Unencrypted + go func() { + _, doErr = client.Do(req) + doReturned = true + }() + + Consistently(doReturned).Should(BeFalse()) + Expect(doErr).ToNot(HaveOccurred()) close(done) }) }) + + Context("handling the header stream", func() { + var h2framer *http2.Framer + + BeforeEach(func() { + h2framer = http2.NewFramer(&headerStream.dataToRead, nil) + client.responses[23] = make(chan *http.Response) + }) + + It("reads a response", func() { + headerStream.dataToRead.Write([]byte{ + 0x0, 0x0, 0x11, 0x1, 0x5, 0x0, 0x0, 0x0, 23, + // Taken from https://http2.github.io/http2-spec/compression.html#request.examples.with.huffman.coding + 0x82, 0x86, 0x84, 0x41, 0x8c, 0xf1, 0xe3, 0xc2, 0xe5, 0xf2, 0x3a, 0x6b, 0xa0, 0xab, 0x90, 0xf4, 0xff, + }) + go client.handleHeaderStream() + var rsp *http.Response + Eventually(client.responses[23]).Should(Receive(&rsp)) + Expect(rsp).ToNot(BeNil()) + }) + + It("errors if the H2 frame is not a HeadersFrame", func() { + var handlerReturned bool + go func() { + client.handleHeaderStream() + handlerReturned = true + }() + + h2framer.WritePing(true, [8]byte{0, 0, 0, 0, 0, 0, 0, 0}) + var rsp *http.Response + Eventually(client.responses[23]).Should(Receive(&rsp)) + Expect(rsp).To(BeNil()) + Expect(client.headerErr).To(MatchError(qerr.Error(qerr.InvalidHeadersStreamData, "not a headers frame"))) + Eventually(func() bool { return handlerReturned }).Should(BeTrue()) + }) + + It("errors if it can't read the HPACK encoded header fields", func() { + var handlerReturned bool + go func() { + client.handleHeaderStream() + handlerReturned = true + }() + + h2framer.WriteHeaders(http2.HeadersFrameParam{ + StreamID: 23, + EndHeaders: true, + BlockFragment: []byte("invalid HPACK data"), + }) + + var rsp *http.Response + Eventually(client.responses[23]).Should(Receive(&rsp)) + Expect(rsp).To(BeNil()) + Expect(client.headerErr).To(MatchError(qerr.Error(qerr.InvalidHeadersStreamData, "cannot read header fields"))) + Eventually(func() bool { return handlerReturned }).Should(BeTrue()) + }) + }) }) })