http3: cancel reading on request stream if request processing fails (#4417)

This commit is contained in:
Marten Seemann
2024-04-09 16:34:00 -04:00
committed by GitHub
parent eb310a6db8
commit ee698b326f
2 changed files with 61 additions and 77 deletions

View File

@@ -161,8 +161,9 @@ var _ = Describe("Server", func() {
return len(p), nil
}).AnyTimes()
str.EXPECT().CancelRead(gomock.Any())
str.EXPECT().Close()
Expect(s.handleRequest(conn, str, qpackDecoder, nil)).To(Equal(requestError{}))
s.handleRequest(conn, str, qpackDecoder)
var req *http.Request
Eventually(requestChan).Should(Receive(&req))
Expect(req.Host).To(Equal("www.example.com"))
@@ -179,9 +180,9 @@ var _ = Describe("Server", func() {
str.EXPECT().Context().Return(reqContext)
str.EXPECT().Write(gomock.Any()).DoAndReturn(responseBuf.Write).AnyTimes()
str.EXPECT().CancelRead(gomock.Any())
str.EXPECT().Close()
serr := s.handleRequest(conn, str, qpackDecoder, nil)
Expect(serr.err).ToNot(HaveOccurred())
s.handleRequest(conn, str, qpackDecoder)
hfs := decodeHeader(responseBuf)
Expect(hfs).To(HaveKeyWithValue(":status", []string{"200"}))
})
@@ -196,9 +197,9 @@ var _ = Describe("Server", func() {
str.EXPECT().Context().Return(reqContext)
str.EXPECT().Write(gomock.Any()).DoAndReturn(responseBuf.Write).AnyTimes()
str.EXPECT().CancelRead(gomock.Any())
str.EXPECT().Close()
serr := s.handleRequest(conn, str, qpackDecoder, nil)
Expect(serr.err).ToNot(HaveOccurred())
s.handleRequest(conn, str, qpackDecoder)
hfs := decodeHeader(responseBuf)
Expect(hfs).To(HaveKeyWithValue(":status", []string{"200"}))
Expect(hfs).To(HaveKeyWithValue("content-length", []string{"6"}))
@@ -218,9 +219,9 @@ var _ = Describe("Server", func() {
str.EXPECT().Context().Return(reqContext)
str.EXPECT().Write(gomock.Any()).DoAndReturn(responseBuf.Write).AnyTimes()
str.EXPECT().CancelRead(gomock.Any())
str.EXPECT().Close()
serr := s.handleRequest(conn, str, qpackDecoder, nil)
Expect(serr.err).ToNot(HaveOccurred())
s.handleRequest(conn, str, qpackDecoder)
hfs := decodeHeader(responseBuf)
Expect(hfs).To(HaveKeyWithValue(":status", []string{"200"}))
// status, date, content-type
@@ -239,8 +240,9 @@ var _ = Describe("Server", func() {
str.EXPECT().Context().Return(reqContext)
str.EXPECT().Write(gomock.Any()).DoAndReturn(responseBuf.Write).AnyTimes()
str.EXPECT().CancelRead(gomock.Any())
serr := s.handleRequest(conn, str, qpackDecoder, nil)
Expect(serr.err).ToNot(HaveOccurred())
str.EXPECT().Close()
s.handleRequest(conn, str, qpackDecoder)
hfs := decodeHeader(responseBuf)
Expect(hfs).To(HaveKeyWithValue(":status", []string{"200"}))
Expect(responseBuf.Bytes()).To(HaveLen(0))
@@ -258,15 +260,16 @@ var _ = Describe("Server", func() {
str.EXPECT().Context().Return(reqContext)
str.EXPECT().Write(gomock.Any()).DoAndReturn(responseBuf.Write).AnyTimes()
str.EXPECT().CancelRead(gomock.Any())
serr := s.handleRequest(conn, str, qpackDecoder, nil)
Expect(serr.err).ToNot(HaveOccurred())
str.EXPECT().Close()
s.handleRequest(conn, str, qpackDecoder)
hfs := decodeHeader(responseBuf)
Expect(hfs).To(HaveKeyWithValue(":status", []string{"200"}))
Expect(hfs).To(HaveKeyWithValue("content-length", []string{"13"}))
Expect(hfs).To(HaveKeyWithValue("content-type", []string{"text/html; charset=utf-8"}))
})
It("handles a aborting handler", func() {
It("handles an aborting handler", func() {
s.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
panic(http.ErrAbortHandler)
})
@@ -275,10 +278,10 @@ var _ = Describe("Server", func() {
setRequest(encodeRequest(exampleGetRequest))
str.EXPECT().Context().Return(reqContext)
str.EXPECT().Write(gomock.Any()).DoAndReturn(responseBuf.Write).AnyTimes()
str.EXPECT().CancelRead(gomock.Any())
str.EXPECT().CancelRead(quic.StreamErrorCode(ErrCodeInternalError))
str.EXPECT().CancelWrite(quic.StreamErrorCode(ErrCodeInternalError))
serr := s.handleRequest(conn, str, qpackDecoder, nil)
Expect(serr.err).To(MatchError(errPanicked))
s.handleRequest(conn, str, qpackDecoder)
Expect(responseBuf.Bytes()).To(HaveLen(0))
})
@@ -291,10 +294,10 @@ var _ = Describe("Server", func() {
setRequest(encodeRequest(exampleGetRequest))
str.EXPECT().Context().Return(reqContext)
str.EXPECT().Write(gomock.Any()).DoAndReturn(responseBuf.Write).AnyTimes()
str.EXPECT().CancelRead(gomock.Any())
str.EXPECT().CancelRead(quic.StreamErrorCode(ErrCodeInternalError))
str.EXPECT().CancelWrite(quic.StreamErrorCode(ErrCodeInternalError))
serr := s.handleRequest(conn, str, qpackDecoder, nil)
Expect(serr.err).To(MatchError(errPanicked))
s.handleRequest(conn, str, qpackDecoder)
Expect(responseBuf.Bytes()).To(HaveLen(0))
})
@@ -352,6 +355,7 @@ var _ = Describe("Server", func() {
buf := bytes.NewBuffer(quicvarint.Append(nil, 0x41))
unknownStr := mockquic.NewMockStream(mockCtrl)
unknownStr.EXPECT().Read(gomock.Any()).DoAndReturn(buf.Read).AnyTimes()
unknownStr.EXPECT().CancelRead(quic.StreamErrorCode(ErrCodeRequestIncomplete))
unknownStr.EXPECT().CancelWrite(quic.StreamErrorCode(ErrCodeRequestIncomplete))
conn.EXPECT().AcceptStream(gomock.Any()).Return(unknownStr, nil)
conn.EXPECT().AcceptStream(gomock.Any()).Return(nil, errors.New("done"))
@@ -377,6 +381,7 @@ var _ = Describe("Server", func() {
buf := bytes.NewBuffer(quicvarint.Append(nil, 0x41))
unknownStr := mockquic.NewMockStream(mockCtrl)
unknownStr.EXPECT().Read(gomock.Any()).DoAndReturn(buf.Read).AnyTimes()
unknownStr.EXPECT().CancelRead(quic.StreamErrorCode(ErrCodeRequestIncomplete))
unknownStr.EXPECT().CancelWrite(quic.StreamErrorCode(ErrCodeRequestIncomplete))
conn.EXPECT().AcceptStream(gomock.Any()).Return(unknownStr, nil)
conn.EXPECT().AcceptStream(gomock.Any()).Return(nil, errors.New("done"))
@@ -596,6 +601,7 @@ var _ = Describe("Server", func() {
setRequest(append(requestData, b...))
done := make(chan struct{})
str.EXPECT().Write(gomock.Any()).DoAndReturn(responseBuf.Write).AnyTimes()
str.EXPECT().CancelRead(quic.StreamErrorCode(ErrCodeFrameError))
str.EXPECT().CancelWrite(quic.StreamErrorCode(ErrCodeFrameError)).Do(func(quic.StreamErrorCode) { close(done) })
s.handleConn(conn)
@@ -611,6 +617,7 @@ var _ = Describe("Server", func() {
testErr := errors.New("stream reset")
done := make(chan struct{})
str.EXPECT().Read(gomock.Any()).Return(0, testErr)
str.EXPECT().CancelRead(quic.StreamErrorCode(ErrCodeRequestIncomplete))
str.EXPECT().CancelWrite(quic.StreamErrorCode(ErrCodeRequestIncomplete)).Do(func(quic.StreamErrorCode) { close(done) })
s.handleConn(conn)
@@ -638,23 +645,23 @@ var _ = Describe("Server", func() {
Eventually(done).Should(BeClosed())
})
It("closes the connection when the first frame is not a HEADERS frame", func() {
It("rejects a request that has too large request headers", func() {
handlerCalled := make(chan struct{})
s.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
close(handlerCalled)
})
// use 2*DefaultMaxHeaderBytes here. qpack will compress the requiest,
// use 2*DefaultMaxHeaderBytes here. qpack will compress the request,
// but the request will still end up larger than DefaultMaxHeaderBytes.
url := bytes.Repeat([]byte{'a'}, http.DefaultMaxHeaderBytes*2)
req, err := http.NewRequest(http.MethodGet, "https://"+string(url), nil)
Expect(err).ToNot(HaveOccurred())
setRequest(encodeRequest(req))
// str.EXPECT().Context().Return(reqContext)
str.EXPECT().Write(gomock.Any()).DoAndReturn(func(p []byte) (int, error) {
return len(p), nil
}).AnyTimes()
done := make(chan struct{})
str.EXPECT().CancelRead(quic.StreamErrorCode(ErrCodeFrameError))
str.EXPECT().CancelWrite(quic.StreamErrorCode(ErrCodeFrameError)).Do(func(quic.StreamErrorCode) { close(done) })
s.handleConn(conn)
@@ -678,9 +685,9 @@ var _ = Describe("Server", func() {
return len(p), nil
}).AnyTimes()
str.EXPECT().CancelRead(quic.StreamErrorCode(ErrCodeNoError))
str.EXPECT().Close()
serr := s.handleRequest(conn, str, qpackDecoder, nil)
Expect(serr.err).ToNot(HaveOccurred())
s.handleRequest(conn, str, qpackDecoder)
Eventually(handlerCalled).Should(BeClosed())
})
@@ -701,9 +708,9 @@ var _ = Describe("Server", func() {
return len(p), nil
}).AnyTimes()
str.EXPECT().CancelRead(quic.StreamErrorCode(ErrCodeNoError))
str.EXPECT().Close()
serr := s.handleRequest(conn, str, qpackDecoder, nil)
Expect(serr.err).ToNot(HaveOccurred())
s.handleRequest(conn, str, qpackDecoder)
Eventually(handlerCalled).Should(BeClosed())
})
})