diff --git a/http3/client_test.go b/http3/client_test.go index 3cd09f94..c4ba168c 100644 --- a/http3/client_test.go +++ b/http3/client_test.go @@ -26,7 +26,7 @@ func encodeResponse(status int) []byte { buf := &bytes.Buffer{} rstr := mockquic.NewMockStream(mockCtrl) rstr.EXPECT().Write(gomock.Any()).Do(buf.Write).AnyTimes() - rw := newResponseWriter(rstr, nil, utils.DefaultLogger) + rw := newResponseWriter(rstr, nil, false, utils.DefaultLogger) rw.WriteHeader(status) rw.Flush() return buf.Bytes() @@ -738,7 +738,7 @@ var _ = Describe("Client", func() { buf := &bytes.Buffer{} rstr := mockquic.NewMockStream(mockCtrl) rstr.EXPECT().Write(gomock.Any()).Do(buf.Write).AnyTimes() - rw := newResponseWriter(rstr, nil, utils.DefaultLogger) + rw := newResponseWriter(rstr, nil, false, utils.DefaultLogger) rw.Header().Set("Content-Encoding", "gzip") gz := gzip.NewWriter(rw) gz.Write([]byte("gzipped response")) @@ -764,7 +764,7 @@ var _ = Describe("Client", func() { buf := &bytes.Buffer{} rstr := mockquic.NewMockStream(mockCtrl) rstr.EXPECT().Write(gomock.Any()).Do(buf.Write).AnyTimes() - rw := newResponseWriter(rstr, nil, utils.DefaultLogger) + rw := newResponseWriter(rstr, nil, false, utils.DefaultLogger) rw.Write([]byte("not gzipped")) rw.Flush() str.EXPECT().Write(gomock.Any()).AnyTimes().DoAndReturn(func(p []byte) (int, error) { return len(p), nil }) diff --git a/http3/response_writer.go b/http3/response_writer.go index 46d8fff0..ba8b34db 100644 --- a/http3/response_writer.go +++ b/http3/response_writer.go @@ -1,7 +1,6 @@ package http3 import ( - "bufio" "bytes" "fmt" "net/http" @@ -21,10 +20,9 @@ const frameHeaderLen = 16 // headerWriter wraps the stream, so that the first Write call flushes the header to the stream type headerWriter struct { - str quic.Stream - header http.Header - status int // status code passed to WriteHeader - written bool + str quic.Stream + header http.Header + status int // status code passed to WriteHeader logger utils.Logger } @@ -33,11 +31,15 @@ type headerWriter struct { func (hw *headerWriter) writeHeader() error { var headers bytes.Buffer enc := qpack.NewEncoder(&headers) - enc.WriteField(qpack.HeaderField{Name: ":status", Value: strconv.Itoa(hw.status)}) + if err := enc.WriteField(qpack.HeaderField{Name: ":status", Value: strconv.Itoa(hw.status)}); err != nil { + return err + } for k, v := range hw.header { for index := range v { - enc.WriteField(qpack.HeaderField{Name: strings.ToLower(k), Value: v[index]}) + if err := enc.WriteField(qpack.HeaderField{Name: strings.ToLower(k), Value: v[index]}); err != nil { + return err + } } } @@ -50,27 +52,22 @@ func (hw *headerWriter) writeHeader() error { return err } -// first Write will trigger flushing header -func (hw *headerWriter) Write(p []byte) (int, error) { - if !hw.written { - if err := hw.writeHeader(); err != nil { - return 0, err - } - hw.written = true - } - return hw.str.Write(p) -} +const maxSmallResponseSize = 4096 type responseWriter struct { *headerWriter - conn Connection - bufferedStr *bufio.Writer - buf []byte + conn Connection + buf []byte - contentLen int64 // if handler set valid Content-Length header - numWritten int64 // bytes written - headerWritten bool - isHead bool + // for responses smaller than maxSmallResponseSize, we buffer calls to Write, + // and automatically add the Content-Length header + smallResponseBuf []byte + + contentLen int64 // if handler set valid Content-Length header + numWritten int64 // bytes written + headerComplete bool // set once WriteHeader is called with a status code >= 200 + headerWritten bool // set once the response header has been serialized to the stream + isHead bool } var ( @@ -79,7 +76,7 @@ var ( _ Hijacker = &responseWriter{} ) -func newResponseWriter(str quic.Stream, conn Connection, logger utils.Logger) *responseWriter { +func newResponseWriter(str quic.Stream, conn Connection, isHead bool, logger utils.Logger) *responseWriter { hw := &headerWriter{ str: str, header: http.Header{}, @@ -89,7 +86,7 @@ func newResponseWriter(str quic.Stream, conn Connection, logger utils.Logger) *r headerWriter: hw, buf: make([]byte, frameHeaderLen), conn: conn, - bufferedStr: bufio.NewWriter(hw), + isHead: isHead, } } @@ -98,7 +95,7 @@ func (w *responseWriter) Header() http.Header { } func (w *responseWriter) WriteHeader(status int) { - if w.headerWritten { + if w.headerComplete { return } @@ -106,37 +103,38 @@ func (w *responseWriter) WriteHeader(status int) { if status < 100 || status > 999 { panic(fmt.Sprintf("invalid WriteHeader code %v", status)) } - - if status >= 200 { - w.headerWritten = true - // Add Date header. - // This is what the standard library does. - // Can be disabled by setting the Date header to nil. - if _, ok := w.header["Date"]; !ok { - w.header.Set("Date", time.Now().UTC().Format(http.TimeFormat)) - } - // Content-Length checking - // use ParseUint instead of ParseInt, as negative values are invalid - if clen := w.header.Get("Content-Length"); clen != "" { - if cl, err := strconv.ParseUint(clen, 10, 63); err == nil { - w.contentLen = int64(cl) - } else { - // emit a warning for malformed Content-Length and remove it - w.logger.Errorf("Malformed Content-Length %s", clen) - w.header.Del("Content-Length") - } - } - } w.status = status - if !w.headerWritten { + // immediately write 1xx headers + if status < 200 { w.writeHeader() + return + } + + // We're done with headers once we write a status >= 200. + w.headerComplete = true + // Add Date header. + // This is what the standard library does. + // Can be disabled by setting the Date header to nil. + if _, ok := w.header["Date"]; !ok { + w.header.Set("Date", time.Now().UTC().Format(http.TimeFormat)) + } + // Content-Length checking + // use ParseUint instead of ParseInt, as negative values are invalid + if clen := w.header.Get("Content-Length"); clen != "" { + if cl, err := strconv.ParseUint(clen, 10, 63); err == nil { + w.contentLen = int64(cl) + } else { + // emit a warning for malformed Content-Length and remove it + w.logger.Errorf("Malformed Content-Length %s", clen) + w.header.Del("Content-Length") + } } } func (w *responseWriter) Write(p []byte) (int, error) { bodyAllowed := bodyAllowedForStatus(w.status) - if !w.headerWritten { + if !w.headerComplete { // If body is not allowed, we don't need to (and we can't) sniff the content type. if bodyAllowed { // If no content type, apply sniffing algorithm to body. @@ -167,27 +165,58 @@ func (w *responseWriter) Write(p []byte) (int, error) { return len(p), nil } - df := &dataFrame{Length: uint64(len(p))} + if !w.headerWritten { + // Buffer small responses. + // This allows us to automatically set the Content-Length field. + if len(w.smallResponseBuf)+len(p) < maxSmallResponseSize { + w.smallResponseBuf = append(w.smallResponseBuf, p...) + return len(p), nil + } + } + return w.doWrite(p) +} + +func (w *responseWriter) doWrite(p []byte) (int, error) { + if !w.headerWritten { + if err := w.writeHeader(); err != nil { + return 0, maybeReplaceError(err) + } + w.headerWritten = true + } + + l := uint64(len(w.smallResponseBuf) + len(p)) + if l == 0 { + return 0, nil + } + df := &dataFrame{Length: l} w.buf = w.buf[:0] w.buf = df.Append(w.buf) - if _, err := w.bufferedStr.Write(w.buf); err != nil { + if _, err := w.str.Write(w.buf); err != nil { return 0, maybeReplaceError(err) } - n, err := w.bufferedStr.Write(p) - return n, maybeReplaceError(err) + if len(w.smallResponseBuf) > 0 { + if _, err := w.str.Write(w.smallResponseBuf); err != nil { + return 0, maybeReplaceError(err) + } + w.smallResponseBuf = nil + } + var n int + if len(p) > 0 { + var err error + n, err = w.str.Write(p) + if err != nil { + return n, maybeReplaceError(err) + } + } + return n, nil } func (w *responseWriter) FlushError() error { - if !w.headerWritten { + if !w.headerComplete { w.WriteHeader(http.StatusOK) } - if !w.written { - if err := w.writeHeader(); err != nil { - return maybeReplaceError(err) - } - w.written = true - } - return w.bufferedStr.Flush() + _, err := w.doWrite(nil) + return err } func (w *responseWriter) Flush() { diff --git a/http3/response_writer_test.go b/http3/response_writer_test.go index ed6c1d28..3254b6e2 100644 --- a/http3/response_writer_test.go +++ b/http3/response_writer_test.go @@ -28,7 +28,7 @@ var _ = Describe("Response Writer", func() { str.EXPECT().Write(gomock.Any()).DoAndReturn(strBuf.Write).AnyTimes() str.EXPECT().SetReadDeadline(gomock.Any()).Return(nil).AnyTimes() str.EXPECT().SetWriteDeadline(gomock.Any()).Return(nil).AnyTimes() - rw = newResponseWriter(str, nil, utils.DefaultLogger) + rw = newResponseWriter(str, nil, false, utils.DefaultLogger) }) decodeHeader := func(str io.Reader) map[string][]string { diff --git a/http3/server.go b/http3/server.go index 50b5ba40..86666297 100644 --- a/http3/server.go +++ b/http3/server.go @@ -554,10 +554,7 @@ func (s *Server) handleRequest(conn *connection, str quic.Stream, decoder *qpack } } req = req.WithContext(ctx) - r := newResponseWriter(str, conn, s.logger) - if req.Method == http.MethodHead { - r.isHead = true - } + r := newResponseWriter(str, conn, req.Method == http.MethodHead, s.logger) handler := s.Handler if handler == nil { handler = http.DefaultServeMux @@ -588,7 +585,7 @@ func (s *Server) handleRequest(conn *connection, str quic.Stream, decoder *qpack // only write response when there is no panic if !panicked { // response not written to the client yet, set Content-Length - if !r.written { + if !r.headerWritten { if _, haveCL := r.header["Content-Length"]; !haveCL { r.header.Set("Content-Length", strconv.FormatInt(r.numWritten, 10)) } diff --git a/http3/server_test.go b/http3/server_test.go index 868a6b06..dc41c5e8 100644 --- a/http3/server_test.go +++ b/http3/server_test.go @@ -228,12 +228,12 @@ var _ = Describe("Server", func() { Expect(hfs).To(HaveLen(3)) }) - It("response to HEAD request should not have body", func() { + It("ignores calls to Write for responses to HEAD requests", func() { s.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.Write([]byte("foobar")) }) - headRequest, err := http.NewRequest("HEAD", "https://www.example.com", nil) + headRequest, err := http.NewRequest(http.MethodHead, "https://www.example.com", nil) Expect(err).ToNot(HaveOccurred()) responseBuf := &bytes.Buffer{} setRequest(encodeRequest(headRequest)) @@ -245,7 +245,7 @@ var _ = Describe("Server", func() { s.handleRequest(conn, str, qpackDecoder) hfs := decodeHeader(responseBuf) Expect(hfs).To(HaveKeyWithValue(":status", []string{"200"})) - Expect(responseBuf.Bytes()).To(HaveLen(0)) + Expect(responseBuf.Bytes()).To(BeEmpty()) }) It("response to HEAD request should also do content sniffing", func() { @@ -253,7 +253,7 @@ var _ = Describe("Server", func() { w.Write([]byte("")) }) - headRequest, err := http.NewRequest("HEAD", "https://www.example.com", nil) + headRequest, err := http.NewRequest(http.MethodHead, "https://www.example.com", nil) Expect(err).ToNot(HaveOccurred()) responseBuf := &bytes.Buffer{} setRequest(encodeRequest(headRequest)) diff --git a/integrationtests/self/http_test.go b/integrationtests/self/http_test.go index 327d61d7..41bdc20f 100644 --- a/integrationtests/self/http_test.go +++ b/integrationtests/self/http_test.go @@ -136,13 +136,14 @@ var _ = Describe("HTTP tests", func() { It("sets content-length for small response", func() { mux.HandleFunc("/small", func(w http.ResponseWriter, r *http.Request) { defer GinkgoRecover() - w.Write([]byte("foobar")) + w.Write([]byte("foo")) + w.Write([]byte("bar")) }) resp, err := client.Get(fmt.Sprintf("https://localhost:%d/small", port)) Expect(err).ToNot(HaveOccurred()) Expect(resp.StatusCode).To(Equal(200)) - Expect(resp.Header.Get("Content-Length")).To(Equal(strconv.Itoa(len("foobar")))) + Expect(resp.Header.Get("Content-Length")).To(Equal("6")) }) It("detects stream errors when server panics when writing response", func() {