diff --git a/example/main.go b/example/main.go index cdc4ff03f..368b51844 100644 --- a/example/main.go +++ b/example/main.go @@ -2,9 +2,11 @@ package main import ( "bytes" + "errors" "flag" "fmt" "net/http" + "net/url" "os" "strconv" @@ -44,9 +46,6 @@ type responseWriter struct { header http.Header headerWritten bool - - bytesWritten int - contentLength int } func (w *responseWriter) Header() http.Header { @@ -59,8 +58,6 @@ func (w *responseWriter) WriteHeader(status int) { var headers bytes.Buffer enc := hpack.NewEncoder(&headers) enc.WriteField(hpack.HeaderField{Name: ":status", Value: strconv.Itoa(status)}) - // enc.WriteField(hpack.HeaderField{Name: "content-length", Value: strconv.Itoa(len(p))}) - // enc.WriteField(hpack.HeaderField{Name: "content-type", Value: http.DetectContentType(p)}) for k, v := range w.header { enc.WriteField(hpack.HeaderField{Name: k, Value: v[0]}) @@ -73,11 +70,10 @@ func (w *responseWriter) WriteHeader(status int) { EndHeaders: true, BlockFragment: headers.Bytes(), }) - - w.contentLength, _ = strconv.Atoi(w.header.Get("content-length")) } func (w *responseWriter) Write(p []byte) (int, error) { + fmt.Printf("%#v\n", w.header) if !w.headerWritten { w.WriteHeader(200) } @@ -90,17 +86,8 @@ func (w *responseWriter) Write(p []byte) (int, error) { return 0, fmt.Errorf("error creating data stream: %s\n", err.Error()) } } - - n, err := w.dataStream.Write(p) - w.bytesWritten += n - - if w.bytesWritten >= w.contentLength { - defer w.dataStream.Close() - } - - return n, err + return w.dataStream.Write(p) } - return 0, nil } @@ -126,18 +113,12 @@ func handleStream(session *quic.Session, headerStream utils.Stream) { continue } - headersMap := map[string]string{} - for _, h := range headers { - headersMap[h.Name] = h.Value - } - - fmt.Printf("Request: %s %s://%s%s on stream %d\n", headersMap[":method"], headersMap[":scheme"], headersMap[":authority"], headersMap[":path"], h2headersFrame.StreamID) - - req, err := http.NewRequest(headersMap[":method"], headersMap[":path"], nil) + req, err := requestFromHeaders(headers) if err != nil { fmt.Printf("invalid http2 frame: %s\n", err.Error()) continue } + fmt.Printf("Request: %#v\n", req) responseWriter := &responseWriter{ header: http.Header{}, @@ -146,7 +127,54 @@ func handleStream(session *quic.Session, headerStream utils.Stream) { session: session, } - go http.DefaultServeMux.ServeHTTP(responseWriter, req) + go func() { + http.DefaultServeMux.ServeHTTP(responseWriter, req) + if responseWriter.dataStream != nil { + responseWriter.dataStream.Close() + } + }() } }() } + +func requestFromHeaders(headers []hpack.HeaderField) (*http.Request, error) { + var path, authority, method string + httpHeaders := http.Header{} + + for _, h := range headers { + switch h.Name { + case ":path": + path = h.Value + case ":method": + method = h.Value + case ":authority": + authority = h.Value + default: + if !h.IsPseudo() { + httpHeaders.Add(h.Name, h.Value) + } + } + } + + if len(path) == 0 || len(authority) == 0 || len(method) == 0 { + return nil, errors.New(":path, :authority and :method must not be empty") + } + + u, err := url.Parse(path) + if err != nil { + return nil, err + } + + return &http.Request{ + Method: method, + URL: u, + Proto: "HTTP/2.0", + ProtoMajor: 2, + ProtoMinor: 0, + Header: httpHeaders, + Body: nil, + // ContentLength: -1, + Host: authority, + RequestURI: path, + }, nil +}