forked from quic-go/quic-go
don't close the stream if it was hijacked
This commit is contained in:
@@ -12,6 +12,7 @@ import (
|
||||
// * for the server: the http.Request.Body
|
||||
// * for the client: the http.Response.Body
|
||||
// On the client side, the stream will be closed for writing, unless the DontCloseRequestStream RoundTripOpt was set.
|
||||
// When a stream is taken over, it's the caller's responsibility to close the stream.
|
||||
type HTTPStreamer interface {
|
||||
HTTPStream() Stream
|
||||
}
|
||||
@@ -36,6 +37,8 @@ type Hijacker interface {
|
||||
// The body of a http.Request or http.Response.
|
||||
type body struct {
|
||||
str quic.Stream
|
||||
|
||||
wasHijacked bool // set when HTTPStream is called
|
||||
}
|
||||
|
||||
var (
|
||||
@@ -48,9 +51,14 @@ func newRequestBody(str Stream) *body {
|
||||
}
|
||||
|
||||
func (r *body) HTTPStream() Stream {
|
||||
r.wasHijacked = true
|
||||
return r.str
|
||||
}
|
||||
|
||||
func (r *body) wasStreamHijacked() bool {
|
||||
return r.wasHijacked
|
||||
}
|
||||
|
||||
func (r *body) Read(b []byte) (int, error) {
|
||||
return r.str.Read(b)
|
||||
}
|
||||
|
||||
@@ -549,7 +549,8 @@ func (s *Server) handleRequest(conn quic.Connection, str quic.Stream, decoder *q
|
||||
}
|
||||
|
||||
req.RemoteAddr = conn.RemoteAddr().String()
|
||||
req.Body = newRequestBody(newStream(str, onFrameError))
|
||||
body := newRequestBody(newStream(str, onFrameError))
|
||||
req.Body = body
|
||||
|
||||
if s.logger.Debug() {
|
||||
s.logger.Infof("%s %s%s, on stream %d", req.Method, req.Host, req.RequestURI, str.StreamID())
|
||||
@@ -583,6 +584,10 @@ func (s *Server) handleRequest(conn quic.Connection, str quic.Stream, decoder *q
|
||||
handler.ServeHTTP(r, req)
|
||||
}()
|
||||
|
||||
if body.wasStreamHijacked() {
|
||||
return requestError{err: errHijacked}
|
||||
}
|
||||
|
||||
if panicked {
|
||||
r.WriteHeader(500)
|
||||
} else {
|
||||
|
||||
@@ -159,7 +159,6 @@ var _ = Describe("Server", func() {
|
||||
|
||||
qpackDecoder = qpack.NewDecoder(nil)
|
||||
str = mockquic.NewMockStream(mockCtrl)
|
||||
|
||||
conn = mockquic.NewMockEarlyConnection(mockCtrl)
|
||||
addr := &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 1337}
|
||||
conn.EXPECT().RemoteAddr().Return(addr).AnyTimes()
|
||||
@@ -480,7 +479,7 @@ var _ = Describe("Server", func() {
|
||||
})
|
||||
}
|
||||
|
||||
It("reset streams Other than the control stream and the QPACK streams", func() {
|
||||
It("reset streams other than the control stream and the QPACK streams", func() {
|
||||
buf := &bytes.Buffer{}
|
||||
quicvarint.Write(buf, 1337)
|
||||
str := mockquic.NewMockStream(mockCtrl)
|
||||
@@ -623,9 +622,9 @@ var _ = Describe("Server", func() {
|
||||
AfterEach(func() { testDone <- struct{}{} })
|
||||
|
||||
It("cancels reading when client sends a body in GET request", func() {
|
||||
handlerCalled := make(chan struct{})
|
||||
var handlerCalled bool
|
||||
s.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
close(handlerCalled)
|
||||
handlerCalled = true
|
||||
})
|
||||
|
||||
requestData := encodeRequest(exampleGetRequest)
|
||||
@@ -644,6 +643,27 @@ var _ = Describe("Server", func() {
|
||||
Eventually(done).Should(BeClosed())
|
||||
hfs := decodeHeader(responseBuf)
|
||||
Expect(hfs).To(HaveKeyWithValue(":status", []string{"200"}))
|
||||
Expect(handlerCalled).To(BeTrue())
|
||||
})
|
||||
|
||||
It("doesn't close the stream if the stream was hijacked (via HTTPStream)", func() {
|
||||
handlerCalled := make(chan struct{})
|
||||
s.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
defer close(handlerCalled)
|
||||
r.Body.(HTTPStreamer).HTTPStream()
|
||||
str.Write([]byte("foobar"))
|
||||
})
|
||||
|
||||
requestData := encodeRequest(exampleGetRequest)
|
||||
buf := &bytes.Buffer{}
|
||||
(&dataFrame{Length: 6}).Write(buf) // add a body
|
||||
buf.Write([]byte("foobar"))
|
||||
setRequest(append(requestData, buf.Bytes()...))
|
||||
str.EXPECT().Context().Return(reqContext)
|
||||
str.EXPECT().Write([]byte("foobar")).Return(6, nil)
|
||||
|
||||
s.handleConn(conn)
|
||||
Eventually(handlerCalled).Should(BeClosed())
|
||||
})
|
||||
|
||||
It("errors when the client sends a too large header frame", func() {
|
||||
|
||||
Reference in New Issue
Block a user