http3: cancel reading on request stream if request processing fails (#4417)

This commit is contained in:
Marten Seemann
2024-04-09 16:34:00 -04:00
committed by GitHub
parent eb310a6db8
commit ee698b326f
2 changed files with 61 additions and 77 deletions

View File

@@ -30,7 +30,6 @@ var (
quicListenAddr = func(addr string, tlsConf *tls.Config, config *quic.Config) (QUICEarlyListener, error) {
return quic.ListenAddrEarly(addr, tlsConf, config)
}
errPanicked = errors.New("panicked")
)
// NextProtoH3 is the ALPN protocol negotiated during the TLS handshake, for QUIC v1 and v2.
@@ -126,20 +125,6 @@ var ServerContextKey = &contextKey{"http3-server"}
// than its string representation.
var RemoteAddrContextKey = &contextKey{"remote-addr"}
type requestError struct {
err error
streamErr ErrCode
connErr ErrCode
}
func newStreamError(code ErrCode, err error) requestError {
return requestError{err: err, streamErr: code}
}
func newConnError(code ErrCode, err error) requestError {
return requestError{err: err, connErr: code}
}
// listenerInfo contains info about specific listener added with addListener
type listenerInfo struct {
port int // 0 means that no info about port is available
@@ -476,29 +461,7 @@ func (s *Server) handleConn(conn quic.Connection) error {
}
return fmt.Errorf("accepting stream failed: %w", err)
}
go func() {
rerr := s.handleRequest(hconn, str, decoder, func(e ErrCode) {
conn.CloseWithError(quic.ApplicationErrorCode(e), "")
})
if rerr.err == errHijacked {
return
}
if rerr.err != nil || rerr.streamErr != 0 || rerr.connErr != 0 {
s.logger.Debugf("Handling request failed: %s", err)
if rerr.streamErr != 0 {
str.CancelWrite(quic.StreamErrorCode(rerr.streamErr))
}
if rerr.connErr != 0 {
var reason string
if rerr.err != nil {
reason = rerr.err.Error()
}
conn.CloseWithError(quic.ApplicationErrorCode(rerr.connErr), reason)
}
return
}
str.Close()
}()
go s.handleRequest(hconn, str, decoder)
}
}
@@ -509,7 +472,7 @@ func (s *Server) maxHeaderBytes() uint64 {
return uint64(s.MaxHeaderBytes)
}
func (s *Server) handleRequest(conn *connection, str quic.Stream, decoder *qpack.Decoder, closeConnection func(ErrCode)) requestError {
func (s *Server) handleRequest(conn *connection, str quic.Stream, decoder *qpack.Decoder) {
var ufh unknownFrameHandlerFunc
if s.StreamHijacker != nil {
ufh = func(ft FrameType, e error) (processed bool, err error) {
@@ -523,30 +486,39 @@ func (s *Server) handleRequest(conn *connection, str quic.Stream, decoder *qpack
}
frame, err := parseNextFrame(str, ufh)
if err != nil {
if err == errHijacked {
return requestError{err: errHijacked}
if !errors.Is(err, errHijacked) {
str.CancelRead(quic.StreamErrorCode(ErrCodeRequestIncomplete))
str.CancelWrite(quic.StreamErrorCode(ErrCodeRequestIncomplete))
}
return newStreamError(ErrCodeRequestIncomplete, err)
return
}
hf, ok := frame.(*headersFrame)
if !ok {
return newConnError(ErrCodeFrameUnexpected, errors.New("expected first frame to be a HEADERS frame"))
conn.CloseWithError(quic.ApplicationErrorCode(ErrCodeFrameUnexpected), "expected first frame to be a HEADERS frame")
return
}
if hf.Length > s.maxHeaderBytes() {
return newStreamError(ErrCodeFrameError, fmt.Errorf("HEADERS frame too large: %d bytes (max: %d)", hf.Length, s.maxHeaderBytes()))
str.CancelRead(quic.StreamErrorCode(ErrCodeFrameError))
str.CancelWrite(quic.StreamErrorCode(ErrCodeFrameError))
return
}
headerBlock := make([]byte, hf.Length)
if _, err := io.ReadFull(str, headerBlock); err != nil {
return newStreamError(ErrCodeRequestIncomplete, err)
str.CancelRead(quic.StreamErrorCode(ErrCodeRequestIncomplete))
str.CancelWrite(quic.StreamErrorCode(ErrCodeRequestIncomplete))
return
}
hfs, err := decoder.DecodeFull(headerBlock)
if err != nil {
// TODO: use the right error code
return newConnError(ErrCodeGeneralProtocolError, err)
conn.CloseWithError(quic.ApplicationErrorCode(ErrCodeGeneralProtocolError), "expected first frame to be a HEADERS frame")
return
}
req, err := requestFromHeaders(hfs)
if err != nil {
return newStreamError(ErrCodeMessageError, err)
str.CancelRead(quic.StreamErrorCode(ErrCodeMessageError))
str.CancelWrite(quic.StreamErrorCode(ErrCodeMessageError))
return
}
connState := conn.ConnectionState().TLS
@@ -556,6 +528,7 @@ func (s *Server) handleRequest(conn *connection, str quic.Stream, decoder *qpack
// Check that the client doesn't send more data in DATA frames than indicated by the Content-Length header (if set).
// See section 4.1.2 of RFC 9114.
var httpStr Stream
closeConnection := func(e ErrCode) { conn.CloseWithError(quic.ApplicationErrorCode(e), "") }
if _, ok := req.Header["Content-Length"]; ok && req.ContentLength >= 0 {
httpStr = newLengthLimitedStream(newStream(str, closeConnection), req.ContentLength)
} else {
@@ -609,7 +582,7 @@ func (s *Server) handleRequest(conn *connection, str quic.Stream, decoder *qpack
}()
if body.wasStreamHijacked() {
return requestError{err: errHijacked}
return
}
// only write response when there is no panic
@@ -622,14 +595,18 @@ func (s *Server) handleRequest(conn *connection, str quic.Stream, decoder *qpack
}
r.Flush()
}
// If the EOF was read by the handler, CancelRead() is a no-op.
str.CancelRead(quic.StreamErrorCode(ErrCodeNoError))
// abort the stream when there is a panic
if panicked {
return newStreamError(ErrCodeInternalError, errPanicked)
str.CancelRead(quic.StreamErrorCode(ErrCodeInternalError))
str.CancelWrite(quic.StreamErrorCode(ErrCodeInternalError))
return
}
return requestError{}
// If the EOF was read by the handler, CancelRead() is a no-op.
str.CancelRead(quic.StreamErrorCode(ErrCodeNoError))
str.Close()
}
// Close the server immediately, aborting requests and sending CONNECTION_CLOSE frames to connected clients.