don't close the stream if it was hijacked

This commit is contained in:
Marten Seemann
2022-05-30 20:03:21 +02:00
parent 6fa7494c2f
commit 3eed9d1104
4 changed files with 46 additions and 8 deletions

View File

@@ -12,6 +12,7 @@ import (
// * for the server: the http.Request.Body
// * for the client: the http.Response.Body
// On the client side, the stream will be closed for writing, unless the DontCloseRequestStream RoundTripOpt was set.
// When a stream is taken over, it's the caller's responsibility to close the stream.
type HTTPStreamer interface {
HTTPStream() Stream
}
@@ -36,6 +37,8 @@ type Hijacker interface {
// The body of a http.Request or http.Response.
type body struct {
str quic.Stream
wasHijacked bool // set when HTTPStream is called
}
var (
@@ -48,9 +51,14 @@ func newRequestBody(str Stream) *body {
}
func (r *body) HTTPStream() Stream {
r.wasHijacked = true
return r.str
}
func (r *body) wasStreamHijacked() bool {
return r.wasHijacked
}
func (r *body) Read(b []byte) (int, error) {
return r.str.Read(b)
}

View File

@@ -549,7 +549,8 @@ func (s *Server) handleRequest(conn quic.Connection, str quic.Stream, decoder *q
}
req.RemoteAddr = conn.RemoteAddr().String()
req.Body = newRequestBody(newStream(str, onFrameError))
body := newRequestBody(newStream(str, onFrameError))
req.Body = body
if s.logger.Debug() {
s.logger.Infof("%s %s%s, on stream %d", req.Method, req.Host, req.RequestURI, str.StreamID())
@@ -583,6 +584,10 @@ func (s *Server) handleRequest(conn quic.Connection, str quic.Stream, decoder *q
handler.ServeHTTP(r, req)
}()
if body.wasStreamHijacked() {
return requestError{err: errHijacked}
}
if panicked {
r.WriteHeader(500)
} else {

View File

@@ -159,7 +159,6 @@ var _ = Describe("Server", func() {
qpackDecoder = qpack.NewDecoder(nil)
str = mockquic.NewMockStream(mockCtrl)
conn = mockquic.NewMockEarlyConnection(mockCtrl)
addr := &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 1337}
conn.EXPECT().RemoteAddr().Return(addr).AnyTimes()
@@ -480,7 +479,7 @@ var _ = Describe("Server", func() {
})
}
It("reset streams Other than the control stream and the QPACK streams", func() {
It("reset streams other than the control stream and the QPACK streams", func() {
buf := &bytes.Buffer{}
quicvarint.Write(buf, 1337)
str := mockquic.NewMockStream(mockCtrl)
@@ -623,9 +622,9 @@ var _ = Describe("Server", func() {
AfterEach(func() { testDone <- struct{}{} })
It("cancels reading when client sends a body in GET request", func() {
handlerCalled := make(chan struct{})
var handlerCalled bool
s.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
close(handlerCalled)
handlerCalled = true
})
requestData := encodeRequest(exampleGetRequest)
@@ -644,6 +643,27 @@ var _ = Describe("Server", func() {
Eventually(done).Should(BeClosed())
hfs := decodeHeader(responseBuf)
Expect(hfs).To(HaveKeyWithValue(":status", []string{"200"}))
Expect(handlerCalled).To(BeTrue())
})
It("doesn't close the stream if the stream was hijacked (via HTTPStream)", func() {
handlerCalled := make(chan struct{})
s.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
defer close(handlerCalled)
r.Body.(HTTPStreamer).HTTPStream()
str.Write([]byte("foobar"))
})
requestData := encodeRequest(exampleGetRequest)
buf := &bytes.Buffer{}
(&dataFrame{Length: 6}).Write(buf) // add a body
buf.Write([]byte("foobar"))
setRequest(append(requestData, buf.Bytes()...))
str.EXPECT().Context().Return(reqContext)
str.EXPECT().Write([]byte("foobar")).Return(6, nil)
s.handleConn(conn)
Eventually(handlerCalled).Should(BeClosed())
})
It("errors when the client sends a too large header frame", func() {