Files
quic-go/h2quic/client_test.go
Marten Seemann fc4adb4775 improve error handling in the h2quic client for header stream handling
When the underlying QUIC stream is closed, the close error should be
returned. This always happens when receiving a CONNECTION_CLOSE from the
server.
Furthermore, this adds a missing break statement in the case when
receiving an invalid HTTP request.
2018-01-03 09:07:18 +07:00

507 lines
18 KiB
Go

package h2quic
import (
"bytes"
"compress/gzip"
"context"
"crypto/tls"
"errors"
"io"
"net/http"
"golang.org/x/net/http2"
"golang.org/x/net/http2/hpack"
quic "github.com/lucas-clemente/quic-go"
"github.com/lucas-clemente/quic-go/internal/protocol"
"github.com/lucas-clemente/quic-go/qerr"
"time"
. "github.com/onsi/ginkgo"
. "github.com/onsi/gomega"
)
var _ = Describe("Client", func() {
var (
client *client
session *mockSession
headerStream *mockStream
req *http.Request
origDialAddr = dialAddr
)
BeforeEach(func() {
origDialAddr = dialAddr
hostname := "quic.clemente.io:1337"
client = newClient(hostname, nil, &roundTripperOpts{}, nil)
Expect(client.hostname).To(Equal(hostname))
session = &mockSession{}
session.ctx, session.ctxCancel = context.WithCancel(context.Background())
client.session = session
headerStream = newMockStream(3)
client.headerStream = headerStream
client.requestWriter = newRequestWriter(headerStream)
var err error
req, err = http.NewRequest("GET", "https://localhost:1337", nil)
Expect(err).ToNot(HaveOccurred())
})
AfterEach(func() {
dialAddr = origDialAddr
})
It("saves the TLS config", func() {
tlsConf := &tls.Config{InsecureSkipVerify: true}
client = newClient("", tlsConf, &roundTripperOpts{}, nil)
Expect(client.tlsConf).To(Equal(tlsConf))
})
It("saves the QUIC config", func() {
quicConf := &quic.Config{HandshakeTimeout: time.Nanosecond}
client = newClient("", &tls.Config{}, &roundTripperOpts{}, quicConf)
Expect(client.config).To(Equal(quicConf))
})
It("uses the default QUIC config if none is give", func() {
client = newClient("", &tls.Config{}, &roundTripperOpts{}, nil)
Expect(client.config).ToNot(BeNil())
Expect(client.config).To(Equal(defaultQuicConfig))
})
It("adds the port to the hostname, if none is given", func() {
client = newClient("quic.clemente.io", nil, &roundTripperOpts{}, nil)
Expect(client.hostname).To(Equal("quic.clemente.io:443"))
})
It("dials", func(done Done) {
client = newClient("localhost:1337", nil, &roundTripperOpts{}, nil)
session.streamsToOpen = []quic.Stream{newMockStream(3), newMockStream(5)}
dialAddr = func(hostname string, _ *tls.Config, _ *quic.Config) (quic.Session, error) {
return session, nil
}
close(headerStream.unblockRead)
go client.RoundTrip(req)
Eventually(func() quic.Session { return client.session }).Should(Equal(session))
close(done)
}, 2)
It("errors when dialing fails", func() {
testErr := errors.New("handshake error")
client = newClient("localhost:1337", nil, &roundTripperOpts{}, nil)
dialAddr = func(hostname string, _ *tls.Config, _ *quic.Config) (quic.Session, error) {
return nil, testErr
}
_, err := client.RoundTrip(req)
Expect(err).To(MatchError(testErr))
})
It("errors if it can't open a stream", func() {
testErr := errors.New("you shall not pass")
client = newClient("localhost:1337", nil, &roundTripperOpts{}, nil)
session.streamOpenErr = testErr
dialAddr = func(hostname string, _ *tls.Config, _ *quic.Config) (quic.Session, error) {
return session, nil
}
_, err := client.RoundTrip(req)
Expect(err).To(MatchError(testErr))
})
It("returns a request when dial fails", func() {
testErr := errors.New("dial error")
dialAddr = func(hostname string, _ *tls.Config, _ *quic.Config) (quic.Session, error) {
return nil, testErr
}
request, err := http.NewRequest("https", "https://quic.clemente.io:1337/file1.dat", nil)
Expect(err).ToNot(HaveOccurred())
var doErr error
go func() {
_, doErr = client.RoundTrip(request)
}()
_, err = client.RoundTrip(request)
Expect(err).To(MatchError(testErr))
Eventually(func() error { return doErr }).Should(MatchError(testErr))
})
Context("Doing requests", func() {
var request *http.Request
var dataStream *mockStream
getRequest := func(data []byte) *http2.MetaHeadersFrame {
r := bytes.NewReader(data)
decoder := hpack.NewDecoder(4096, func(hf hpack.HeaderField) {})
h2framer := http2.NewFramer(nil, r)
frame, err := h2framer.ReadFrame()
Expect(err).ToNot(HaveOccurred())
mhframe := &http2.MetaHeadersFrame{HeadersFrame: frame.(*http2.HeadersFrame)}
mhframe.Fields, err = decoder.DecodeFull(mhframe.HeadersFrame.HeaderBlockFragment())
Expect(err).ToNot(HaveOccurred())
return mhframe
}
getHeaderFields := func(f *http2.MetaHeadersFrame) map[string]string {
fields := make(map[string]string)
for _, hf := range f.Fields {
fields[hf.Name] = hf.Value
}
return fields
}
BeforeEach(func() {
var err error
dialAddr = func(hostname string, _ *tls.Config, _ *quic.Config) (quic.Session, error) {
return session, nil
}
dataStream = newMockStream(5)
session.streamsToOpen = []quic.Stream{headerStream, dataStream}
request, err = http.NewRequest("https", "https://quic.clemente.io:1337/file1.dat", nil)
Expect(err).ToNot(HaveOccurred())
})
It("does a request", func(done Done) {
var doRsp *http.Response
var doErr error
var doReturned bool
go func() {
doRsp, doErr = client.RoundTrip(request)
doReturned = true
}()
Eventually(func() []byte { return headerStream.dataWritten.Bytes() }).ShouldNot(BeEmpty())
Eventually(func() map[protocol.StreamID]chan *http.Response { return client.responses }).Should(HaveKey(protocol.StreamID(5)))
rsp := &http.Response{
Status: "418 I'm a teapot",
StatusCode: 418,
}
Expect(client.responses[5]).ToNot(BeClosed())
Expect(client.headerErrored).ToNot(BeClosed())
client.responses[5] <- rsp
Eventually(func() bool { return doReturned }).Should(BeTrue())
Expect(doErr).ToNot(HaveOccurred())
Expect(doRsp).To(Equal(rsp))
Expect(doRsp.Body).To(Equal(dataStream))
Expect(doRsp.ContentLength).To(BeEquivalentTo(-1))
Expect(doRsp.Request).To(Equal(request))
close(done)
})
It("closes the quic client when encountering an error on the header stream", func() {
headerStream.dataToRead.Write(bytes.Repeat([]byte{0}, 100))
done := make(chan struct{})
go func() {
defer GinkgoRecover()
rsp, err := client.RoundTrip(request)
Expect(err).To(MatchError(client.headerErr))
Expect(rsp).To(BeNil())
close(done)
}()
Eventually(done).Should(BeClosed())
Expect(client.headerErr.ErrorCode).To(Equal(qerr.InvalidHeadersStreamData))
Expect(client.session.(*mockSession).closedWithError).To(MatchError(client.headerErr))
})
It("returns subsequent request if there was an error on the header stream before", func() {
session.streamsToOpen = []quic.Stream{headerStream, dataStream, newMockStream(7)}
headerStream.dataToRead.Write(bytes.Repeat([]byte{0}, 100))
_, err := client.RoundTrip(request)
Expect(err).To(BeAssignableToTypeOf(&qerr.QuicError{}))
Expect(err.(*qerr.QuicError).ErrorCode).To(Equal(qerr.InvalidHeadersStreamData))
// now that the first request failed due to an error on the header stream, try another request
_, nextErr := client.RoundTrip(request)
Expect(nextErr).To(MatchError(err))
})
It("blocks if no stream is available", func() {
session.streamsToOpen = []quic.Stream{headerStream}
session.blockOpenStreamSync = true
var doReturned bool
go func() {
defer GinkgoRecover()
_, err := client.RoundTrip(request)
Expect(err).ToNot(HaveOccurred())
doReturned = true
}()
go client.handleHeaderStream()
Consistently(func() bool { return doReturned }).Should(BeFalse())
})
Context("validating the address", func() {
It("refuses to do requests for the wrong host", func() {
req, err := http.NewRequest("https", "https://quic.clemente.io:1336/foobar.html", nil)
Expect(err).ToNot(HaveOccurred())
_, err = client.RoundTrip(req)
Expect(err).To(MatchError("h2quic Client BUG: RoundTrip called for the wrong client (expected quic.clemente.io:1337, got quic.clemente.io:1336)"))
})
It("refuses to do plain HTTP requests", func() {
req, err := http.NewRequest("https", "http://quic.clemente.io:1337/foobar.html", nil)
Expect(err).ToNot(HaveOccurred())
_, err = client.RoundTrip(req)
Expect(err).To(MatchError("quic http2: unsupported scheme"))
})
It("adds the port for request URLs without one", func(done Done) {
var err error
client = newClient("quic.clemente.io", nil, &roundTripperOpts{}, nil)
req, err := http.NewRequest("https", "https://quic.clemente.io/foobar.html", nil)
Expect(err).ToNot(HaveOccurred())
var doErr error
var doReturned bool
// the client.RoundTrip will block, because the encryption level is still set to Unencrypted
go func() {
_, doErr = client.RoundTrip(req)
doReturned = true
}()
Consistently(doReturned).Should(BeFalse())
Expect(doErr).ToNot(HaveOccurred())
close(done)
})
})
It("sets the EndStream header for requests without a body", func() {
go func() { client.RoundTrip(request) }()
Eventually(func() []byte { return headerStream.dataWritten.Bytes() }).ShouldNot(BeNil())
mhf := getRequest(headerStream.dataWritten.Bytes())
Expect(mhf.HeadersFrame.StreamEnded()).To(BeTrue())
})
It("sets the EndStream header to false for requests with a body", func() {
request.Body = &mockBody{}
go func() { client.RoundTrip(request) }()
Eventually(func() []byte { return headerStream.dataWritten.Bytes() }).ShouldNot(BeNil())
mhf := getRequest(headerStream.dataWritten.Bytes())
Expect(mhf.HeadersFrame.StreamEnded()).To(BeFalse())
})
Context("requests containing a Body", func() {
var requestBody []byte
var response *http.Response
BeforeEach(func() {
requestBody = []byte("request body")
body := &mockBody{}
body.SetData(requestBody)
request.Body = body
response = &http.Response{
StatusCode: 200,
Header: http.Header{"Content-Length": []string{"1000"}},
}
// fake a handshake
client.dialOnce.Do(func() {})
session.streamsToOpen = []quic.Stream{dataStream}
})
It("sends a request", func() {
var doRsp *http.Response
var doErr error
var doReturned bool
go func() {
defer GinkgoRecover()
doRsp, doErr = client.RoundTrip(request)
Expect(doErr).ToNot(HaveOccurred())
doReturned = true
}()
Eventually(func() chan *http.Response { return client.responses[5] }).ShouldNot(BeNil())
client.responses[5] <- response
Eventually(func() bool { return doReturned }).Should(BeTrue())
Expect(dataStream.dataWritten.Bytes()).To(Equal(requestBody))
Expect(dataStream.closed).To(BeTrue())
Expect(request.Body.(*mockBody).closed).To(BeTrue())
Expect(doRsp).To(Equal(response))
})
It("returns the error that occurred when reading the body", func() {
testErr := errors.New("testErr")
request.Body.(*mockBody).readErr = testErr
var doRsp *http.Response
var doErr error
var doReturned bool
go func() {
doRsp, doErr = client.RoundTrip(request)
doReturned = true
}()
Eventually(func() bool { return doReturned }).Should(BeTrue())
Expect(doErr).To(MatchError(testErr))
Expect(doRsp).To(BeNil())
Expect(request.Body.(*mockBody).closed).To(BeTrue())
})
It("returns the error that occurred when closing the body", func() {
testErr := errors.New("testErr")
request.Body.(*mockBody).closeErr = testErr
var doRsp *http.Response
var doErr error
var doReturned bool
go func() {
doRsp, doErr = client.RoundTrip(request)
doReturned = true
}()
Eventually(func() bool { return doReturned }).Should(BeTrue())
Expect(doErr).To(MatchError(testErr))
Expect(doRsp).To(BeNil())
Expect(request.Body.(*mockBody).closed).To(BeTrue())
})
})
Context("gzip compression", func() {
var gzippedData []byte // a gzipped foobar
var response *http.Response
BeforeEach(func() {
var b bytes.Buffer
w := gzip.NewWriter(&b)
w.Write([]byte("foobar"))
w.Close()
gzippedData = b.Bytes()
response = &http.Response{
StatusCode: 200,
Header: http.Header{"Content-Length": []string{"1000"}},
}
})
It("adds the gzip header to requests", func(done Done) {
var doRsp *http.Response
var doErr error
go func() { doRsp, doErr = client.RoundTrip(request) }()
Eventually(func() chan *http.Response { return client.responses[5] }).ShouldNot(BeNil())
dataStream.dataToRead.Write(gzippedData)
response.Header.Add("Content-Encoding", "gzip")
client.responses[5] <- response
Eventually(func() *http.Response { return doRsp }).ShouldNot(BeNil())
Expect(doErr).ToNot(HaveOccurred())
headers := getHeaderFields(getRequest(headerStream.dataWritten.Bytes()))
Expect(headers).To(HaveKeyWithValue("accept-encoding", "gzip"))
Expect(doRsp.ContentLength).To(BeEquivalentTo(-1))
Expect(doRsp.Header.Get("Content-Encoding")).To(BeEmpty())
Expect(doRsp.Header.Get("Content-Length")).To(BeEmpty())
close(dataStream.unblockRead)
data := make([]byte, 6)
_, err := io.ReadFull(doRsp.Body, data)
Expect(err).ToNot(HaveOccurred())
Expect(data).To(Equal([]byte("foobar")))
close(done)
}, 2)
It("doesn't add gzip if the header disable it", func() {
client.opts.DisableCompression = true
var doErr error
go func() { _, doErr = client.RoundTrip(request) }()
Eventually(func() chan *http.Response { return client.responses[5] }).ShouldNot(BeNil())
Expect(doErr).ToNot(HaveOccurred())
Eventually(func() []byte { return headerStream.dataWritten.Bytes() }).ShouldNot(BeEmpty())
headers := getHeaderFields(getRequest(headerStream.dataWritten.Bytes()))
Expect(headers).ToNot(HaveKey("accept-encoding"))
})
It("only decompresses the response if the response contains the right content-encoding header", func() {
var doRsp *http.Response
var doErr error
go func() { doRsp, doErr = client.RoundTrip(request) }()
Eventually(func() chan *http.Response { return client.responses[5] }).ShouldNot(BeNil())
dataStream.dataToRead.Write([]byte("not gzipped"))
client.responses[5] <- response
Eventually(func() *http.Response { return doRsp }).ShouldNot(BeNil())
Expect(doErr).ToNot(HaveOccurred())
headers := getHeaderFields(getRequest(headerStream.dataWritten.Bytes()))
Expect(headers).To(HaveKeyWithValue("accept-encoding", "gzip"))
data := make([]byte, 11)
doRsp.Body.Read(data)
Expect(doRsp.ContentLength).ToNot(BeEquivalentTo(-1))
Expect(data).To(Equal([]byte("not gzipped")))
})
It("doesn't add the gzip header for requests that have the accept-enconding set", func() {
request.Header.Add("accept-encoding", "gzip")
var doRsp *http.Response
var doErr error
go func() { doRsp, doErr = client.RoundTrip(request) }()
Eventually(func() chan *http.Response { return client.responses[5] }).ShouldNot(BeNil())
dataStream.dataToRead.Write([]byte("gzipped data"))
client.responses[5] <- response
Eventually(func() *http.Response { return doRsp }).ShouldNot(BeNil())
Expect(doErr).ToNot(HaveOccurred())
headers := getHeaderFields(getRequest(headerStream.dataWritten.Bytes()))
Expect(headers).To(HaveKeyWithValue("accept-encoding", "gzip"))
data := make([]byte, 12)
doRsp.Body.Read(data)
Expect(doRsp.ContentLength).ToNot(BeEquivalentTo(-1))
Expect(data).To(Equal([]byte("gzipped data")))
})
})
Context("handling the header stream", func() {
var h2framer *http2.Framer
BeforeEach(func() {
h2framer = http2.NewFramer(&headerStream.dataToRead, nil)
client.responses[23] = make(chan *http.Response)
})
It("reads header values from a response", func() {
// Taken from https://http2.github.io/http2-spec/compression.html#request.examples.with.huffman.coding
data := []byte{0x48, 0x03, 0x33, 0x30, 0x32, 0x58, 0x07, 0x70, 0x72, 0x69, 0x76, 0x61, 0x74, 0x65, 0x61, 0x1d, 0x4d, 0x6f, 0x6e, 0x2c, 0x20, 0x32, 0x31, 0x20, 0x4f, 0x63, 0x74, 0x20, 0x32, 0x30, 0x31, 0x33, 0x20, 0x32, 0x30, 0x3a, 0x31, 0x33, 0x3a, 0x32, 0x31, 0x20, 0x47, 0x4d, 0x54, 0x6e, 0x17, 0x68, 0x74, 0x74, 0x70, 0x73, 0x3a, 0x2f, 0x2f, 0x77, 0x77, 0x77, 0x2e, 0x65, 0x78, 0x61, 0x6d, 0x70, 0x6c, 0x65, 0x2e, 0x63, 0x6f, 0x6d}
headerStream.dataToRead.Write([]byte{0x0, 0x0, byte(len(data)), 0x1, 0x5, 0x0, 0x0, 0x0, 23})
headerStream.dataToRead.Write(data)
go client.handleHeaderStream()
var rsp *http.Response
Eventually(client.responses[23]).Should(Receive(&rsp))
Expect(rsp).ToNot(BeNil())
Expect(rsp.Proto).To(Equal("HTTP/2.0"))
Expect(rsp.ProtoMajor).To(BeEquivalentTo(2))
Expect(rsp.StatusCode).To(BeEquivalentTo(302))
Expect(rsp.Status).To(Equal("302 Found"))
Expect(rsp.Header).To(HaveKeyWithValue("Location", []string{"https://www.example.com"}))
Expect(rsp.Header).To(HaveKeyWithValue("Cache-Control", []string{"private"}))
})
It("errors if the H2 frame is not a HeadersFrame", func() {
h2framer.WritePing(true, [8]byte{0, 0, 0, 0, 0, 0, 0, 0})
client.handleHeaderStream()
Eventually(client.headerErrored).Should(BeClosed())
Expect(client.headerErr).To(MatchError(qerr.Error(qerr.InvalidHeadersStreamData, "not a headers frame")))
})
It("errors if it can't read the HPACK encoded header fields", func() {
h2framer.WriteHeaders(http2.HeadersFrameParam{
StreamID: 23,
EndHeaders: true,
BlockFragment: []byte("invalid HPACK data"),
})
client.handleHeaderStream()
Eventually(client.headerErrored).Should(BeClosed())
Expect(client.headerErr.ErrorCode).To(Equal(qerr.InvalidHeadersStreamData))
Expect(client.headerErr.ErrorMessage).To(ContainSubstring("cannot read header fields"))
})
It("errors if the stream cannot be found", func() {
var headers bytes.Buffer
enc := hpack.NewEncoder(&headers)
enc.WriteField(hpack.HeaderField{Name: ":status", Value: "200"})
err := h2framer.WriteHeaders(http2.HeadersFrameParam{
StreamID: 1337,
EndHeaders: true,
BlockFragment: headers.Bytes(),
})
Expect(err).ToNot(HaveOccurred())
client.handleHeaderStream()
Eventually(client.headerErrored).Should(BeClosed())
Expect(client.headerErr.ErrorCode).To(Equal(qerr.InvalidHeadersStreamData))
Expect(client.headerErr.ErrorMessage).To(ContainSubstring("response channel for stream 1337 not found"))
})
})
})
})