forked from quic-go/quic-go
http3: cancel reading on request stream if request processing fails (#4417)
This commit is contained in:
@@ -30,7 +30,6 @@ var (
|
||||
quicListenAddr = func(addr string, tlsConf *tls.Config, config *quic.Config) (QUICEarlyListener, error) {
|
||||
return quic.ListenAddrEarly(addr, tlsConf, config)
|
||||
}
|
||||
errPanicked = errors.New("panicked")
|
||||
)
|
||||
|
||||
// NextProtoH3 is the ALPN protocol negotiated during the TLS handshake, for QUIC v1 and v2.
|
||||
@@ -126,20 +125,6 @@ var ServerContextKey = &contextKey{"http3-server"}
|
||||
// than its string representation.
|
||||
var RemoteAddrContextKey = &contextKey{"remote-addr"}
|
||||
|
||||
type requestError struct {
|
||||
err error
|
||||
streamErr ErrCode
|
||||
connErr ErrCode
|
||||
}
|
||||
|
||||
func newStreamError(code ErrCode, err error) requestError {
|
||||
return requestError{err: err, streamErr: code}
|
||||
}
|
||||
|
||||
func newConnError(code ErrCode, err error) requestError {
|
||||
return requestError{err: err, connErr: code}
|
||||
}
|
||||
|
||||
// listenerInfo contains info about specific listener added with addListener
|
||||
type listenerInfo struct {
|
||||
port int // 0 means that no info about port is available
|
||||
@@ -476,29 +461,7 @@ func (s *Server) handleConn(conn quic.Connection) error {
|
||||
}
|
||||
return fmt.Errorf("accepting stream failed: %w", err)
|
||||
}
|
||||
go func() {
|
||||
rerr := s.handleRequest(hconn, str, decoder, func(e ErrCode) {
|
||||
conn.CloseWithError(quic.ApplicationErrorCode(e), "")
|
||||
})
|
||||
if rerr.err == errHijacked {
|
||||
return
|
||||
}
|
||||
if rerr.err != nil || rerr.streamErr != 0 || rerr.connErr != 0 {
|
||||
s.logger.Debugf("Handling request failed: %s", err)
|
||||
if rerr.streamErr != 0 {
|
||||
str.CancelWrite(quic.StreamErrorCode(rerr.streamErr))
|
||||
}
|
||||
if rerr.connErr != 0 {
|
||||
var reason string
|
||||
if rerr.err != nil {
|
||||
reason = rerr.err.Error()
|
||||
}
|
||||
conn.CloseWithError(quic.ApplicationErrorCode(rerr.connErr), reason)
|
||||
}
|
||||
return
|
||||
}
|
||||
str.Close()
|
||||
}()
|
||||
go s.handleRequest(hconn, str, decoder)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -509,7 +472,7 @@ func (s *Server) maxHeaderBytes() uint64 {
|
||||
return uint64(s.MaxHeaderBytes)
|
||||
}
|
||||
|
||||
func (s *Server) handleRequest(conn *connection, str quic.Stream, decoder *qpack.Decoder, closeConnection func(ErrCode)) requestError {
|
||||
func (s *Server) handleRequest(conn *connection, str quic.Stream, decoder *qpack.Decoder) {
|
||||
var ufh unknownFrameHandlerFunc
|
||||
if s.StreamHijacker != nil {
|
||||
ufh = func(ft FrameType, e error) (processed bool, err error) {
|
||||
@@ -523,30 +486,39 @@ func (s *Server) handleRequest(conn *connection, str quic.Stream, decoder *qpack
|
||||
}
|
||||
frame, err := parseNextFrame(str, ufh)
|
||||
if err != nil {
|
||||
if err == errHijacked {
|
||||
return requestError{err: errHijacked}
|
||||
if !errors.Is(err, errHijacked) {
|
||||
str.CancelRead(quic.StreamErrorCode(ErrCodeRequestIncomplete))
|
||||
str.CancelWrite(quic.StreamErrorCode(ErrCodeRequestIncomplete))
|
||||
}
|
||||
return newStreamError(ErrCodeRequestIncomplete, err)
|
||||
return
|
||||
}
|
||||
hf, ok := frame.(*headersFrame)
|
||||
if !ok {
|
||||
return newConnError(ErrCodeFrameUnexpected, errors.New("expected first frame to be a HEADERS frame"))
|
||||
conn.CloseWithError(quic.ApplicationErrorCode(ErrCodeFrameUnexpected), "expected first frame to be a HEADERS frame")
|
||||
return
|
||||
}
|
||||
if hf.Length > s.maxHeaderBytes() {
|
||||
return newStreamError(ErrCodeFrameError, fmt.Errorf("HEADERS frame too large: %d bytes (max: %d)", hf.Length, s.maxHeaderBytes()))
|
||||
str.CancelRead(quic.StreamErrorCode(ErrCodeFrameError))
|
||||
str.CancelWrite(quic.StreamErrorCode(ErrCodeFrameError))
|
||||
return
|
||||
}
|
||||
headerBlock := make([]byte, hf.Length)
|
||||
if _, err := io.ReadFull(str, headerBlock); err != nil {
|
||||
return newStreamError(ErrCodeRequestIncomplete, err)
|
||||
str.CancelRead(quic.StreamErrorCode(ErrCodeRequestIncomplete))
|
||||
str.CancelWrite(quic.StreamErrorCode(ErrCodeRequestIncomplete))
|
||||
return
|
||||
}
|
||||
hfs, err := decoder.DecodeFull(headerBlock)
|
||||
if err != nil {
|
||||
// TODO: use the right error code
|
||||
return newConnError(ErrCodeGeneralProtocolError, err)
|
||||
conn.CloseWithError(quic.ApplicationErrorCode(ErrCodeGeneralProtocolError), "expected first frame to be a HEADERS frame")
|
||||
return
|
||||
}
|
||||
req, err := requestFromHeaders(hfs)
|
||||
if err != nil {
|
||||
return newStreamError(ErrCodeMessageError, err)
|
||||
str.CancelRead(quic.StreamErrorCode(ErrCodeMessageError))
|
||||
str.CancelWrite(quic.StreamErrorCode(ErrCodeMessageError))
|
||||
return
|
||||
}
|
||||
|
||||
connState := conn.ConnectionState().TLS
|
||||
@@ -556,6 +528,7 @@ func (s *Server) handleRequest(conn *connection, str quic.Stream, decoder *qpack
|
||||
// Check that the client doesn't send more data in DATA frames than indicated by the Content-Length header (if set).
|
||||
// See section 4.1.2 of RFC 9114.
|
||||
var httpStr Stream
|
||||
closeConnection := func(e ErrCode) { conn.CloseWithError(quic.ApplicationErrorCode(e), "") }
|
||||
if _, ok := req.Header["Content-Length"]; ok && req.ContentLength >= 0 {
|
||||
httpStr = newLengthLimitedStream(newStream(str, closeConnection), req.ContentLength)
|
||||
} else {
|
||||
@@ -609,7 +582,7 @@ func (s *Server) handleRequest(conn *connection, str quic.Stream, decoder *qpack
|
||||
}()
|
||||
|
||||
if body.wasStreamHijacked() {
|
||||
return requestError{err: errHijacked}
|
||||
return
|
||||
}
|
||||
|
||||
// only write response when there is no panic
|
||||
@@ -622,14 +595,18 @@ func (s *Server) handleRequest(conn *connection, str quic.Stream, decoder *qpack
|
||||
}
|
||||
r.Flush()
|
||||
}
|
||||
// If the EOF was read by the handler, CancelRead() is a no-op.
|
||||
str.CancelRead(quic.StreamErrorCode(ErrCodeNoError))
|
||||
|
||||
// abort the stream when there is a panic
|
||||
if panicked {
|
||||
return newStreamError(ErrCodeInternalError, errPanicked)
|
||||
str.CancelRead(quic.StreamErrorCode(ErrCodeInternalError))
|
||||
str.CancelWrite(quic.StreamErrorCode(ErrCodeInternalError))
|
||||
return
|
||||
}
|
||||
return requestError{}
|
||||
|
||||
// If the EOF was read by the handler, CancelRead() is a no-op.
|
||||
str.CancelRead(quic.StreamErrorCode(ErrCodeNoError))
|
||||
|
||||
str.Close()
|
||||
}
|
||||
|
||||
// Close the server immediately, aborting requests and sending CONNECTION_CLOSE frames to connected clients.
|
||||
|
||||
Reference in New Issue
Block a user