diff --git a/http3/client.go b/http3/client.go index 501e3db57..61ded6246 100644 --- a/http3/client.go +++ b/http3/client.go @@ -7,6 +7,8 @@ import ( "io" "log/slog" "net/http" + "net/http/httptrace" + "net/textproto" "sync" "time" @@ -313,9 +315,37 @@ func (c *SingleDestinationRoundTripper) doRequest(req *http.Request, str quic.St }() } - res, err := hstr.ReadResponse() - if err != nil { - return nil, err + var ( + res *http.Response + err error + ) + + // copy from net/http: support 1xx responses + trace := httptrace.ContextClientTrace(req.Context()) + num1xx := 0 // number of informational 1xx headers received + const max1xxResponses = 5 // arbitrary bound on number of informational responses + + for { + if res, err = hstr.ReadResponse(); err != nil { + return nil, err + } + resCode := res.StatusCode + is1xx := 100 <= resCode && resCode <= 199 + // treat 101 as a terminal status, see https://github.com/golang/go/issues/26161 + is1xxNonTerminal := is1xx && resCode != http.StatusSwitchingProtocols + if is1xxNonTerminal { + num1xx++ + if num1xx > max1xxResponses { + return nil, errors.New("http: too many 1xx informational responses") + } + if trace != nil && trace.Got1xxResponse != nil { + if err := trace.Got1xxResponse(resCode, textproto.MIMEHeader(res.Header)); err != nil { + return nil, err + } + } + continue + } + break } connState := c.Connection.ConnectionState().TLS res.TLS = &connState diff --git a/http3/client_test.go b/http3/client_test.go index 7adc62e85..5c62ec6a1 100644 --- a/http3/client_test.go +++ b/http3/client_test.go @@ -7,6 +7,8 @@ import ( "errors" "io" "net/http" + "net/http/httptrace" + "net/textproto" "sync" "time" @@ -26,6 +28,10 @@ func encodeResponse(status int) []byte { rstr := mockquic.NewMockStream(mockCtrl) rstr.EXPECT().Write(gomock.Any()).Do(buf.Write).AnyTimes() rw := newResponseWriter(newStream(rstr, nil), nil, false, nil) + if status == http.StatusEarlyHints { + rw.header.Add("Link", "; rel=preload; as=style") + rw.header.Add("Link", "; rel=preload; as=script") + } rw.WriteHeader(status) rw.Flush() return buf.Bytes() @@ -858,5 +864,75 @@ var _ = Describe("Client", func() { Expect(rsp.Header.Get("Content-Encoding")).To(BeEmpty()) }) }) + + Context("1xx status code", func() { + It("continues to read next header if code is 103", func() { + var ( + cnt int + status int + hdr textproto.MIMEHeader + ) + header1 := "; rel=preload; as=style" + header2 := "; rel=preload; as=script" + ctx := httptrace.WithClientTrace(req.Context(), &httptrace.ClientTrace{ + Got1xxResponse: func(code int, header textproto.MIMEHeader) error { + cnt++ + status = code + hdr = header + return nil + }, + }) + req := req.WithContext(ctx) + rspBuf := bytes.NewBuffer(encodeResponse(103)) + gomock.InOrder( + conn.EXPECT().HandshakeComplete().Return(handshakeChan), + conn.EXPECT().OpenStreamSync(ctx).Return(str, nil), + conn.EXPECT().ConnectionState().Return(quic.ConnectionState{}), + ) + str.EXPECT().Write(gomock.Any()).AnyTimes().DoAndReturn(func(p []byte) (int, error) { return len(p), nil }) + str.EXPECT().Close() + str.EXPECT().Read(gomock.Any()).DoAndReturn(rspBuf.Read).AnyTimes() + rsp, err := cl.RoundTrip(req) + Expect(err).ToNot(HaveOccurred()) + Expect(rsp.Proto).To(Equal("HTTP/3.0")) + Expect(rsp.ProtoMajor).To(Equal(3)) + Expect(rsp.StatusCode).To(Equal(200)) + Expect(rsp.Header).To(HaveKeyWithValue("Link", []string{header1, header2})) + Expect(status).To(Equal(103)) + Expect(cnt).To(Equal(1)) + Expect(hdr).To(HaveKeyWithValue("Link", []string{header1, header2})) + Expect(rsp.Request).ToNot(BeNil()) + }) + + It("doesn't continue to read next header if code is a terminal status", func() { + cnt := 0 + status := 0 + ctx := httptrace.WithClientTrace(req.Context(), &httptrace.ClientTrace{ + Got1xxResponse: func(code int, header textproto.MIMEHeader) error { + cnt++ + status = code + return nil + }, + }) + req := req.WithContext(ctx) + rspBuf := bytes.NewBuffer(encodeResponse(101)) + gomock.InOrder( + conn.EXPECT().HandshakeComplete().Return(handshakeChan), + conn.EXPECT().OpenStreamSync(ctx).Return(str, nil), + conn.EXPECT().ConnectionState().Return(quic.ConnectionState{}), + ) + str.EXPECT().Write(gomock.Any()).AnyTimes().DoAndReturn(func(p []byte) (int, error) { return len(p), nil }) + str.EXPECT().Close() + str.EXPECT().Read(gomock.Any()).DoAndReturn(rspBuf.Read).AnyTimes() + rsp, err := cl.RoundTrip(req) + Expect(err).ToNot(HaveOccurred()) + Expect(rsp.Proto).To(Equal("HTTP/3.0")) + Expect(rsp.ProtoMajor).To(Equal(3)) + Expect(rsp.StatusCode).To(Equal(101)) + Expect(status).To(Equal(0)) + Expect(cnt).To(Equal(0)) + Expect(rsp.Request).ToNot(BeNil()) + }) + }) }) }) diff --git a/integrationtests/self/http_test.go b/integrationtests/self/http_test.go index 44beaf958..cf0cfb6f2 100644 --- a/integrationtests/self/http_test.go +++ b/integrationtests/self/http_test.go @@ -11,6 +11,8 @@ import ( "io" "net" "net/http" + "net/http/httptrace" + "net/textproto" "os" "strconv" "sync/atomic" @@ -638,6 +640,81 @@ var _ = Describe("HTTP tests", func() { Expect(settings.Other).To(HaveKeyWithValue(uint64(1337), uint64(42))) }) + It("processes 1xx response", func() { + header1 := "; rel=preload; as=style" + header2 := "; rel=preload; as=script" + data := "1xx-test-data" + mux.HandleFunc("/103-early-data", func(w http.ResponseWriter, r *http.Request) { + defer GinkgoRecover() + w.Header().Add("Link", header1) + w.Header().Add("Link", header2) + w.WriteHeader(http.StatusEarlyHints) + n, err := w.Write([]byte(data)) + Expect(err).NotTo(HaveOccurred()) + Expect(n).To(Equal(len(data))) + w.WriteHeader(http.StatusOK) + }) + + var ( + cnt int + status int + hdr textproto.MIMEHeader + ) + ctx := httptrace.WithClientTrace(context.Background(), &httptrace.ClientTrace{ + Got1xxResponse: func(code int, header textproto.MIMEHeader) error { + hdr = header + status = code + cnt++ + return nil + }, + }) + + req, err := http.NewRequestWithContext(ctx, http.MethodGet, fmt.Sprintf("https://localhost:%d/103-early-data", port), nil) + Expect(err).ToNot(HaveOccurred()) + resp, err := client.Do(req) + Expect(err).ToNot(HaveOccurred()) + Expect(resp.StatusCode).To(Equal(http.StatusOK)) + body, err := io.ReadAll(resp.Body) + Expect(err).ToNot(HaveOccurred()) + Expect(string(body)).To(Equal(data)) + Expect(status).To(Equal(http.StatusEarlyHints)) + Expect(hdr).To(HaveKeyWithValue("Link", []string{header1, header2})) + Expect(cnt).To(Equal(1)) + Expect(resp.Header).To(HaveKeyWithValue("Link", []string{header1, header2})) + Expect(resp.Body.Close()).To(Succeed()) + }) + + It("processes 1xx terminal response", func() { + mux.HandleFunc("/101-switch-protocols", func(w http.ResponseWriter, r *http.Request) { + defer GinkgoRecover() + w.Header().Add("Connection", "upgrade") + w.Header().Add("Upgrade", "proto") + w.WriteHeader(http.StatusSwitchingProtocols) + }) + + var ( + cnt int + status int + ) + ctx := httptrace.WithClientTrace(context.Background(), &httptrace.ClientTrace{ + Got1xxResponse: func(code int, header textproto.MIMEHeader) error { + status = code + cnt++ + return nil + }, + }) + + req, err := http.NewRequestWithContext(ctx, http.MethodGet, fmt.Sprintf("https://localhost:%d/101-switch-protocols", port), nil) + Expect(err).ToNot(HaveOccurred()) + resp, err := client.Do(req) + Expect(err).ToNot(HaveOccurred()) + Expect(resp.StatusCode).To(Equal(http.StatusSwitchingProtocols)) + Expect(resp.Header).To(HaveKeyWithValue("Connection", []string{"upgrade"})) + Expect(resp.Header).To(HaveKeyWithValue("Upgrade", []string{"proto"})) + Expect(status).To(Equal(0)) + Expect(cnt).To(Equal(0)) + }) + Context("0-RTT", func() { runCountingProxy := func(serverPort int, rtt time.Duration) (*quicproxy.QuicProxy, *atomic.Uint32) { var num0RTTPackets atomic.Uint32