forked from quic-go/quic-go
implement a basic request writer
This commit is contained in:
192
h2quic/request_writer.go
Normal file
192
h2quic/request_writer.go
Normal file
@@ -0,0 +1,192 @@
|
||||
package h2quic
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"golang.org/x/net/http2"
|
||||
"golang.org/x/net/http2/hpack"
|
||||
"golang.org/x/net/lex/httplex"
|
||||
|
||||
"github.com/lucas-clemente/quic-go/protocol"
|
||||
"github.com/lucas-clemente/quic-go/utils"
|
||||
)
|
||||
|
||||
type requestWriter struct {
|
||||
headerStream utils.Stream
|
||||
|
||||
henc *hpack.Encoder
|
||||
hbuf bytes.Buffer // HPACK encoder writes into this
|
||||
}
|
||||
|
||||
const defaultUserAgent = "quic-go"
|
||||
|
||||
func newRequestWriter(headerStream utils.Stream) *requestWriter {
|
||||
rw := &requestWriter{
|
||||
headerStream: headerStream,
|
||||
}
|
||||
rw.henc = hpack.NewEncoder(&rw.hbuf)
|
||||
return rw
|
||||
}
|
||||
|
||||
func (w *requestWriter) WriteRequest(req *http.Request, dataStreamID protocol.StreamID) error {
|
||||
// TODO: add support for trailers
|
||||
// TODO: add support for gzip compression
|
||||
// TODO: write continuation frames, if the header frame is too long
|
||||
w.encodeHeaders(req, false, "", actualContentLength(req))
|
||||
h2framer := http2.NewFramer(w.headerStream, nil)
|
||||
return h2framer.WriteHeaders(http2.HeadersFrameParam{
|
||||
StreamID: uint32(dataStreamID),
|
||||
EndHeaders: true,
|
||||
BlockFragment: w.hbuf.Bytes(),
|
||||
})
|
||||
}
|
||||
|
||||
// the rest of this files is copied from http2.Transport
|
||||
func (w *requestWriter) encodeHeaders(req *http.Request, addGzipHeader bool, trailers string, contentLength int64) ([]byte, error) {
|
||||
w.hbuf.Reset()
|
||||
|
||||
host := req.Host
|
||||
if host == "" {
|
||||
host = req.URL.Host
|
||||
}
|
||||
host, err := httplex.PunycodeHostPort(host)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var path string
|
||||
if req.Method != "CONNECT" {
|
||||
path = req.URL.RequestURI()
|
||||
if !validPseudoPath(path) {
|
||||
orig := path
|
||||
path = strings.TrimPrefix(path, req.URL.Scheme+"://"+host)
|
||||
if !validPseudoPath(path) {
|
||||
if req.URL.Opaque != "" {
|
||||
return nil, fmt.Errorf("invalid request :path %q from URL.Opaque = %q", orig, req.URL.Opaque)
|
||||
} else {
|
||||
return nil, fmt.Errorf("invalid request :path %q", orig)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Check for any invalid headers and return an error before we
|
||||
// potentially pollute our hpack state. (We want to be able to
|
||||
// continue to reuse the hpack encoder for future requests)
|
||||
for k, vv := range req.Header {
|
||||
if !httplex.ValidHeaderFieldName(k) {
|
||||
return nil, fmt.Errorf("invalid HTTP header name %q", k)
|
||||
}
|
||||
for _, v := range vv {
|
||||
if !httplex.ValidHeaderFieldValue(v) {
|
||||
return nil, fmt.Errorf("invalid HTTP header value %q for header %q", v, k)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 8.1.2.3 Request Pseudo-Header Fields
|
||||
// The :path pseudo-header field includes the path and query parts of the
|
||||
// target URI (the path-absolute production and optionally a '?' character
|
||||
// followed by the query production (see Sections 3.3 and 3.4 of
|
||||
// [RFC3986]).
|
||||
w.writeHeader(":authority", host)
|
||||
w.writeHeader(":method", req.Method)
|
||||
if req.Method != "CONNECT" {
|
||||
w.writeHeader(":path", path)
|
||||
w.writeHeader(":scheme", req.URL.Scheme)
|
||||
}
|
||||
if trailers != "" {
|
||||
w.writeHeader("trailer", trailers)
|
||||
}
|
||||
|
||||
var didUA bool
|
||||
for k, vv := range req.Header {
|
||||
lowKey := strings.ToLower(k)
|
||||
switch lowKey {
|
||||
case "host", "content-length":
|
||||
// Host is :authority, already sent.
|
||||
// Content-Length is automatic, set below.
|
||||
continue
|
||||
case "connection", "proxy-connection", "transfer-encoding", "upgrade", "keep-alive":
|
||||
// Per 8.1.2.2 Connection-Specific Header
|
||||
// Fields, don't send connection-specific
|
||||
// fields. We have already checked if any
|
||||
// are error-worthy so just ignore the rest.
|
||||
continue
|
||||
case "user-agent":
|
||||
// Match Go's http1 behavior: at most one
|
||||
// User-Agent. If set to nil or empty string,
|
||||
// then omit it. Otherwise if not mentioned,
|
||||
// include the default (below).
|
||||
didUA = true
|
||||
if len(vv) < 1 {
|
||||
continue
|
||||
}
|
||||
vv = vv[:1]
|
||||
if vv[0] == "" {
|
||||
continue
|
||||
}
|
||||
}
|
||||
for _, v := range vv {
|
||||
w.writeHeader(lowKey, v)
|
||||
}
|
||||
}
|
||||
if shouldSendReqContentLength(req.Method, contentLength) {
|
||||
w.writeHeader("content-length", strconv.FormatInt(contentLength, 10))
|
||||
}
|
||||
if addGzipHeader {
|
||||
w.writeHeader("accept-encoding", "gzip")
|
||||
}
|
||||
if !didUA {
|
||||
w.writeHeader("user-agent", defaultUserAgent)
|
||||
}
|
||||
return w.hbuf.Bytes(), nil
|
||||
}
|
||||
|
||||
func (w *requestWriter) writeHeader(name, value string) {
|
||||
utils.Debugf("http2: Transport encoding header %q = %q", name, value)
|
||||
w.henc.WriteField(hpack.HeaderField{Name: name, Value: value})
|
||||
}
|
||||
|
||||
// shouldSendReqContentLength reports whether the http2.Transport should send
|
||||
// a "content-length" request header. This logic is basically a copy of the net/http
|
||||
// transferWriter.shouldSendContentLength.
|
||||
// The contentLength is the corrected contentLength (so 0 means actually 0, not unknown).
|
||||
// -1 means unknown.
|
||||
func shouldSendReqContentLength(method string, contentLength int64) bool {
|
||||
if contentLength > 0 {
|
||||
return true
|
||||
}
|
||||
if contentLength < 0 {
|
||||
return false
|
||||
}
|
||||
// For zero bodies, whether we send a content-length depends on the method.
|
||||
// It also kinda doesn't matter for http2 either way, with END_STREAM.
|
||||
switch method {
|
||||
case "POST", "PUT", "PATCH":
|
||||
return true
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
func validPseudoPath(v string) bool {
|
||||
return (len(v) > 0 && v[0] == '/' && (len(v) == 1 || v[1] != '/')) || v == "*"
|
||||
}
|
||||
|
||||
// actualContentLength returns a sanitized version of
|
||||
// req.ContentLength, where 0 actually means zero (not unknown) and -1
|
||||
// means unknown.
|
||||
func actualContentLength(req *http.Request) int64 {
|
||||
if req.Body == nil {
|
||||
return 0
|
||||
}
|
||||
if req.ContentLength != 0 {
|
||||
return req.ContentLength
|
||||
}
|
||||
return -1
|
||||
}
|
||||
87
h2quic/request_writer_test.go
Normal file
87
h2quic/request_writer_test.go
Normal file
@@ -0,0 +1,87 @@
|
||||
package h2quic
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"golang.org/x/net/http2"
|
||||
"golang.org/x/net/http2/hpack"
|
||||
|
||||
. "github.com/onsi/ginkgo"
|
||||
. "github.com/onsi/gomega"
|
||||
)
|
||||
|
||||
var _ = Describe("Request", func() {
|
||||
var (
|
||||
rw *requestWriter
|
||||
headerStream *mockStream
|
||||
decoder *hpack.Decoder
|
||||
)
|
||||
|
||||
BeforeEach(func() {
|
||||
headerStream = &mockStream{}
|
||||
rw = newRequestWriter(headerStream)
|
||||
decoder = hpack.NewDecoder(4096, func(hf hpack.HeaderField) {})
|
||||
})
|
||||
|
||||
decode := func(p []byte) (*http2.HeadersFrame, map[string] /* HeaderField.Name */ string /* HeaderField.Value */) {
|
||||
framer := http2.NewFramer(nil, bytes.NewReader(p))
|
||||
frame, err := framer.ReadFrame()
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
headerFrame := frame.(*http2.HeadersFrame)
|
||||
fields, err := decoder.DecodeFull(headerFrame.HeaderBlockFragment())
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
values := make(map[string]string)
|
||||
for _, headerField := range fields {
|
||||
values[headerField.Name] = headerField.Value
|
||||
}
|
||||
return headerFrame, values
|
||||
}
|
||||
|
||||
It("writes a GET request", func() {
|
||||
req, err := http.NewRequest("GET", "https://quic.clemente.io/index.html?foo=bar", nil)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
rw.WriteRequest(req, 1337)
|
||||
headerFrame, headerFields := decode(headerStream.Bytes())
|
||||
Expect(headerFrame.StreamID).To(Equal(uint32(1337)))
|
||||
Expect(headerFields).To(HaveKeyWithValue(":authority", "quic.clemente.io"))
|
||||
Expect(headerFields).To(HaveKeyWithValue(":method", "GET"))
|
||||
Expect(headerFields).To(HaveKeyWithValue(":path", "/index.html?foo=bar"))
|
||||
Expect(headerFields).To(HaveKeyWithValue(":scheme", "https"))
|
||||
})
|
||||
|
||||
It("writes a POST request", func() {
|
||||
form := url.Values{}
|
||||
form.Add("foo", "bar")
|
||||
req, err := http.NewRequest("POST", "https://quic.clemente.io/upload.html", strings.NewReader(form.Encode()))
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
rw.WriteRequest(req, 5)
|
||||
_, headerFields := decode(headerStream.Bytes())
|
||||
Expect(headerFields).To(HaveKeyWithValue(":method", "POST"))
|
||||
Expect(headerFields).To(HaveKey("content-length"))
|
||||
contentLength, err := strconv.Atoi(headerFields["content-length"])
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(contentLength).To(BeNumerically(">", 0))
|
||||
})
|
||||
|
||||
It("sends cookies", func() {
|
||||
req, err := http.NewRequest("GET", "https://quic.clemente.io/", nil)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
cookie1 := &http.Cookie{
|
||||
Name: "Cookie #1",
|
||||
Value: "Value #1",
|
||||
}
|
||||
cookie2 := &http.Cookie{
|
||||
Name: "Cookie #2",
|
||||
Value: "Value #2",
|
||||
}
|
||||
req.AddCookie(cookie1)
|
||||
req.AddCookie(cookie2)
|
||||
rw.WriteRequest(req, 11)
|
||||
_, headerFields := decode(headerStream.Bytes())
|
||||
Expect(headerFields).To(HaveKeyWithValue("cookie", "Cookie #1=Value #1; Cookie #2=Value #2"))
|
||||
})
|
||||
})
|
||||
Reference in New Issue
Block a user