diff --git a/http3/response_writer.go b/http3/response_writer.go index 5cc329239..b7c79d50e 100644 --- a/http3/response_writer.go +++ b/http3/response_writer.go @@ -6,6 +6,7 @@ import ( "net/http" "strconv" "strings" + "time" "github.com/quic-go/quic-go" "github.com/quic-go/quic-go/internal/utils" @@ -15,6 +16,7 @@ import ( type responseWriter struct { conn quic.Connection + str quic.Stream bufferedStr *bufio.Writer buf []byte @@ -36,6 +38,7 @@ func newResponseWriter(str quic.Stream, conn quic.Connection, logger utils.Logge header: http.Header{}, buf: make([]byte, 16), conn: conn, + str: str, bufferedStr: bufio.NewWriter(str), logger: logger, } @@ -121,6 +124,14 @@ func (w *responseWriter) StreamCreator() StreamCreator { return w.conn } +func (w *responseWriter) SetReadDeadline(deadline time.Time) error { + return w.str.SetReadDeadline(deadline) +} + +func (w *responseWriter) SetWriteDeadline(deadline time.Time) error { + return w.str.SetWriteDeadline(deadline) +} + // copied from http2/http2.go // bodyAllowedForStatus reports whether a given response status code // permits a body. See RFC 2616, section 4.4. diff --git a/http3/response_writer_test.go b/http3/response_writer_test.go index 7644913a4..467201a45 100644 --- a/http3/response_writer_test.go +++ b/http3/response_writer_test.go @@ -4,6 +4,7 @@ import ( "bytes" "io" "net/http" + "time" mockquic "github.com/quic-go/quic-go/internal/mocks/quic" "github.com/quic-go/quic-go/internal/utils" @@ -25,6 +26,8 @@ var _ = Describe("Response Writer", func() { strBuf = &bytes.Buffer{} str := mockquic.NewMockStream(mockCtrl) str.EXPECT().Write(gomock.Any()).DoAndReturn(strBuf.Write).AnyTimes() + str.EXPECT().SetReadDeadline(gomock.Any()).Return(nil).AnyTimes() + str.EXPECT().SetWriteDeadline(gomock.Any()).Return(nil).AnyTimes() rw = newResponseWriter(str, nil, utils.DefaultLogger) }) @@ -156,4 +159,9 @@ var _ = Describe("Response Writer", func() { fields := decodeHeader(strBuf) Expect(fields).To(HaveKeyWithValue("content-type", []string{"text/html; charset=utf-8"})) }) + + It(`is compatible with "net/http".ResponseController`, func() { + Expect(rw.SetReadDeadline(time.Now().Add(1 * time.Second))).To(BeNil()) + Expect(rw.SetWriteDeadline(time.Now().Add(1 * time.Second))).To(BeNil()) + }) }) diff --git a/integrationtests/self/go119_test.go b/integrationtests/self/go119_test.go new file mode 100644 index 000000000..cd9824dd5 --- /dev/null +++ b/integrationtests/self/go119_test.go @@ -0,0 +1,22 @@ +//go:build go1.19 && !go1.20 + +package self_test + +import ( + "errors" + "net/http" + "time" +) + +var ( + go120 = false + errNotSupported = errors.New("not supported") +) + +func setReadDeadline(w http.ResponseWriter, deadline time.Time) error { + return errNotSupported +} + +func setWriteDeadline(w http.ResponseWriter, deadline time.Time) error { + return errNotSupported +} diff --git a/integrationtests/self/go120_test.go b/integrationtests/self/go120_test.go new file mode 100644 index 000000000..4ddf3c7c5 --- /dev/null +++ b/integrationtests/self/go120_test.go @@ -0,0 +1,22 @@ +//go:build go1.20 + +package self_test + +import ( + "net/http" + "time" +) + +var go120 = true + +func setReadDeadline(w http.ResponseWriter, deadline time.Time) error { + rc := http.NewResponseController(w) + + return rc.SetReadDeadline(deadline) +} + +func setWriteDeadline(w http.ResponseWriter, deadline time.Time) error { + rc := http.NewResponseController(w) + + return rc.SetWriteDeadline(deadline) +} diff --git a/integrationtests/self/http_test.go b/integrationtests/self/http_test.go index db67c47e3..0adfff865 100644 --- a/integrationtests/self/http_test.go +++ b/integrationtests/self/http_test.go @@ -11,6 +11,7 @@ import ( "io" "net" "net/http" + "os" "strconv" "time" @@ -25,6 +26,17 @@ import ( "github.com/onsi/gomega/gbytes" ) +type neverEnding byte + +func (b neverEnding) Read(p []byte) (n int, err error) { + for i := range p { + p[i] = byte(b) + } + return len(p), nil +} + +const deadlineDelay = 250 * time.Millisecond + var _ = Describe("HTTP tests", func() { var ( mux *http.ServeMux @@ -374,6 +386,60 @@ var _ = Describe("HTTP tests", func() { Expect(repl).To(Equal(data)) }) + It("supports read deadlines", func() { + if !go120 { + Skip("This test requires Go 1.20+") + } + + mux.HandleFunc("/read-deadline", func(w http.ResponseWriter, r *http.Request) { + defer GinkgoRecover() + err := setReadDeadline(w, time.Now().Add(deadlineDelay)) + Expect(err).ToNot(HaveOccurred()) + + body, err := io.ReadAll(r.Body) + Expect(err).To(MatchError(os.ErrDeadlineExceeded)) + Expect(body).To(ContainSubstring("aa")) + + w.Write([]byte("ok")) + }) + + expectedEnd := time.Now().Add(deadlineDelay) + resp, err := client.Post("https://localhost:"+port+"/read-deadline", "text/plain", neverEnding('a')) + Expect(err).ToNot(HaveOccurred()) + Expect(resp.StatusCode).To(Equal(200)) + + body, err := io.ReadAll(gbytes.TimeoutReader(resp.Body, 2*deadlineDelay)) + Expect(err).ToNot(HaveOccurred()) + Expect(time.Now().After(expectedEnd)).To(BeTrue()) + Expect(string(body)).To(Equal("ok")) + }) + + It("supports write deadlines", func() { + if !go120 { + Skip("This test requires Go 1.20+") + } + + mux.HandleFunc("/write-deadline", func(w http.ResponseWriter, r *http.Request) { + defer GinkgoRecover() + err := setWriteDeadline(w, time.Now().Add(deadlineDelay)) + Expect(err).ToNot(HaveOccurred()) + + _, err = io.Copy(w, neverEnding('a')) + Expect(err).To(MatchError(os.ErrDeadlineExceeded)) + }) + + expectedEnd := time.Now().Add(deadlineDelay) + + resp, err := client.Get("https://localhost:" + port + "/write-deadline") + Expect(err).ToNot(HaveOccurred()) + Expect(resp.StatusCode).To(Equal(200)) + + body, err := io.ReadAll(gbytes.TimeoutReader(resp.Body, 2*deadlineDelay)) + Expect(err).ToNot(HaveOccurred()) + Expect(time.Now().After(expectedEnd)).To(BeTrue()) + Expect(string(body)).To(ContainSubstring("aa")) + }) + if version != protocol.VersionDraft29 { It("serves other QUIC connections", func() { tlsConf := testdata.GetTLSConfig()