package http3 import ( "context" "errors" "fmt" "io" "log/slog" "net" "net/http" "net/http/httptrace" "sync" "sync/atomic" "time" "git.geeks-team.ru/gr1ffon/quic-go" "git.geeks-team.ru/gr1ffon/quic-go/http3/qlog" "git.geeks-team.ru/gr1ffon/quic-go/qlogwriter" "git.geeks-team.ru/gr1ffon/quic-go/quicvarint" "github.com/quic-go/qpack" ) const maxQuarterStreamID = 1<<60 - 1 var errGoAway = errors.New("connection in graceful shutdown") // invalidStreamID is a stream ID that is invalid. The first valid stream ID in QUIC is 0. const invalidStreamID = quic.StreamID(-1) // Conn is an HTTP/3 connection. // It has all methods from the quic.Conn expect for AcceptStream, AcceptUniStream, // SendDatagram and ReceiveDatagram. type Conn struct { Conn *quic.Conn ctx context.Context isServer bool logger *slog.Logger enableDatagrams bool decoder *qpack.Decoder streamMx sync.Mutex Streams map[quic.StreamID]*stateTrackingStream LastStreamID quic.StreamID maxStreamID quic.StreamID settings *Settings receivedSettings chan struct{} idleTimeout time.Duration idleTimer *time.Timer qlogger qlogwriter.Recorder } func newConnection( ctx context.Context, quicConn *quic.Conn, enableDatagrams bool, isServer bool, logger *slog.Logger, idleTimeout time.Duration, ) *Conn { var qlogger qlogwriter.Recorder if qlogTrace := quicConn.QlogTrace(); qlogTrace != nil && qlogTrace.SupportsSchemas(qlog.EventSchema) { qlogger = qlogTrace.AddProducer() } c := &Conn{ ctx: ctx, Conn: quicConn, isServer: isServer, logger: logger, idleTimeout: idleTimeout, enableDatagrams: enableDatagrams, decoder: qpack.NewDecoder(func(hf qpack.HeaderField) {}), receivedSettings: make(chan struct{}), Streams: make(map[quic.StreamID]*stateTrackingStream), maxStreamID: invalidStreamID, LastStreamID: invalidStreamID, qlogger: qlogger, } if idleTimeout > 0 { c.idleTimer = time.AfterFunc(idleTimeout, c.onIdleTimer) } return c } func (c *Conn) OpenStream() (*quic.Stream, error) { return c.Conn.OpenStream() } func (c *Conn) OpenStreamSync(ctx context.Context) (*quic.Stream, error) { return c.Conn.OpenStreamSync(ctx) } func (c *Conn) OpenUniStream() (*quic.SendStream, error) { return c.Conn.OpenUniStream() } func (c *Conn) OpenUniStreamSync(ctx context.Context) (*quic.SendStream, error) { return c.Conn.OpenUniStreamSync(ctx) } func (c *Conn) LocalAddr() net.Addr { return c.Conn.LocalAddr() } func (c *Conn) RemoteAddr() net.Addr { return c.Conn.RemoteAddr() } func (c *Conn) HandshakeComplete() <-chan struct{} { return c.Conn.HandshakeComplete() } func (c *Conn) ConnectionState() quic.ConnectionState { return c.Conn.ConnectionState() } func (c *Conn) onIdleTimer() { c.CloseWithError(quic.ApplicationErrorCode(ErrCodeNoError), "idle timeout") } func (c *Conn) clearStream(id quic.StreamID) { c.streamMx.Lock() defer c.streamMx.Unlock() delete(c.Streams, id) if c.idleTimeout > 0 && len(c.Streams) == 0 { c.idleTimer.Reset(c.idleTimeout) } // The server is performing a graceful shutdown. // If no more streams are remaining, close the connection. if c.maxStreamID != invalidStreamID { if len(c.Streams) == 0 { c.CloseWithError(quic.ApplicationErrorCode(ErrCodeNoError), "") } } } func (c *Conn) openRequestStream( ctx context.Context, requestWriter *requestWriter, reqDone chan<- struct{}, disableCompression bool, maxHeaderBytes uint64, ) (*RequestStream, error) { c.streamMx.Lock() maxStreamID := c.maxStreamID var nextStreamID quic.StreamID if c.LastStreamID == invalidStreamID { nextStreamID = 0 } else { nextStreamID = c.LastStreamID + 4 } c.streamMx.Unlock() // Streams with stream ID equal to or greater than the stream ID carried in the GOAWAY frame // will be rejected, see section 5.2 of RFC 9114. if maxStreamID != invalidStreamID && nextStreamID >= maxStreamID { return nil, errGoAway } str, err := c.OpenStreamSync(ctx) if err != nil { return nil, err } hstr := newStateTrackingStream(str, c, func(b []byte) error { return c.sendDatagram(str.StreamID(), b) }) c.streamMx.Lock() c.Streams[str.StreamID()] = hstr c.LastStreamID = str.StreamID() c.streamMx.Unlock() rsp := &http.Response{} trace := httptrace.ContextClientTrace(ctx) return newRequestStream( newStream(hstr, c, trace, func(r io.Reader, hf *headersFrame) error { hdr, err := c.decodeTrailers(r, str.StreamID(), hf, maxHeaderBytes) if err != nil { return err } rsp.Trailer = hdr return nil }, c.qlogger), requestWriter, reqDone, c.decoder, disableCompression, maxHeaderBytes, rsp, ), nil } func (c *Conn) decodeTrailers(r io.Reader, streamID quic.StreamID, hf *headersFrame, maxHeaderBytes uint64) (http.Header, error) { if hf.Length > maxHeaderBytes { maybeQlogInvalidHeadersFrame(c.qlogger, streamID, hf.Length) return nil, fmt.Errorf("HEADERS frame too large: %d bytes (max: %d)", hf.Length, maxHeaderBytes) } b := make([]byte, hf.Length) if _, err := io.ReadFull(r, b); err != nil { return nil, err } fields, err := c.decoder.DecodeFull(b) if err != nil { maybeQlogInvalidHeadersFrame(c.qlogger, streamID, hf.Length) return nil, err } if c.qlogger != nil { qlogParsedHeadersFrame(c.qlogger, streamID, hf, fields) } return parseTrailers(fields) } // only used by the server func (c *Conn) acceptStream(ctx context.Context) (*stateTrackingStream, error) { str, err := c.Conn.AcceptStream(ctx) if err != nil { return nil, err } strID := str.StreamID() hstr := newStateTrackingStream(str, c, func(b []byte) error { return c.sendDatagram(strID, b) }) c.streamMx.Lock() c.Streams[strID] = hstr if c.idleTimeout > 0 { if len(c.Streams) == 1 { c.idleTimer.Stop() } } c.streamMx.Unlock() return hstr, nil } func (c *Conn) CloseWithError(code quic.ApplicationErrorCode, msg string) error { if c.idleTimer != nil { c.idleTimer.Stop() } return c.Conn.CloseWithError(code, msg) } func (c *Conn) handleUnidirectionalStreams(hijack func(StreamType, quic.ConnectionTracingID, *quic.ReceiveStream, error) (hijacked bool)) { var ( rcvdControlStr atomic.Bool rcvdQPACKEncoderStr atomic.Bool rcvdQPACKDecoderStr atomic.Bool ) for { str, err := c.Conn.AcceptUniStream(context.Background()) if err != nil { if c.logger != nil { c.logger.Debug("accepting unidirectional stream failed", "error", err) } return } go func(str *quic.ReceiveStream) { streamType, err := quicvarint.Read(quicvarint.NewReader(str)) if err != nil { id := c.Context().Value(quic.ConnectionTracingKey).(quic.ConnectionTracingID) if hijack != nil && hijack(StreamType(streamType), id, str, err) { return } if c.logger != nil { c.logger.Debug("reading stream type on stream failed", "stream ID", str.StreamID(), "error", err) } return } // We're only interested in the control stream here. switch streamType { case streamTypeControlStream: case streamTypeQPACKEncoderStream: if isFirst := rcvdQPACKEncoderStr.CompareAndSwap(false, true); !isFirst { c.CloseWithError(quic.ApplicationErrorCode(ErrCodeStreamCreationError), "duplicate QPACK encoder stream") } // Our QPACK implementation doesn't use the dynamic table yet. return case streamTypeQPACKDecoderStream: if isFirst := rcvdQPACKDecoderStr.CompareAndSwap(false, true); !isFirst { c.CloseWithError(quic.ApplicationErrorCode(ErrCodeStreamCreationError), "duplicate QPACK decoder stream") } // Our QPACK implementation doesn't use the dynamic table yet. return case streamTypePushStream: if c.isServer { // only the server can push c.CloseWithError(quic.ApplicationErrorCode(ErrCodeStreamCreationError), "") } else { // we never increased the Push ID, so we don't expect any push streams c.CloseWithError(quic.ApplicationErrorCode(ErrCodeIDError), "") } return default: if hijack != nil { if hijack( StreamType(streamType), c.Context().Value(quic.ConnectionTracingKey).(quic.ConnectionTracingID), str, nil, ) { return } } str.CancelRead(quic.StreamErrorCode(ErrCodeStreamCreationError)) return } // Only a single control stream is allowed. if isFirstControlStr := rcvdControlStr.CompareAndSwap(false, true); !isFirstControlStr { c.Conn.CloseWithError(quic.ApplicationErrorCode(ErrCodeStreamCreationError), "duplicate control stream") return } c.handleControlStream(str) }(str) } } func (c *Conn) handleControlStream(str *quic.ReceiveStream) { fp := &frameParser{closeConn: c.Conn.CloseWithError, r: str, streamID: str.StreamID()} f, err := fp.ParseNext(c.qlogger) if err != nil { var serr *quic.StreamError if err == io.EOF || errors.As(err, &serr) { c.Conn.CloseWithError(quic.ApplicationErrorCode(ErrCodeClosedCriticalStream), "") return } c.Conn.CloseWithError(quic.ApplicationErrorCode(ErrCodeFrameError), "") return } sf, ok := f.(*settingsFrame) if !ok { c.Conn.CloseWithError(quic.ApplicationErrorCode(ErrCodeMissingSettings), "") return } c.settings = &Settings{ EnableDatagrams: sf.Datagram, EnableExtendedConnect: sf.ExtendedConnect, Other: sf.Other, } close(c.receivedSettings) if sf.Datagram { // If datagram support was enabled on our side as well as on the server side, // we can expect it to have been negotiated both on the transport and on the HTTP/3 layer. // Note: ConnectionState() will block until the handshake is complete (relevant when using 0-RTT). if c.enableDatagrams && !c.ConnectionState().SupportsDatagrams { c.CloseWithError(quic.ApplicationErrorCode(ErrCodeSettingsError), "missing QUIC Datagram support") return } go func() { if err := c.receiveDatagrams(); err != nil { if c.logger != nil { c.logger.Debug("receiving datagrams failed", "error", err) } } }() } // we don't support server push, hence we don't expect any GOAWAY frames from the client if c.isServer { return } for { f, err := fp.ParseNext(c.qlogger) if err != nil { var serr *quic.StreamError if err == io.EOF || errors.As(err, &serr) { c.Conn.CloseWithError(quic.ApplicationErrorCode(ErrCodeClosedCriticalStream), "") return } c.Conn.CloseWithError(quic.ApplicationErrorCode(ErrCodeFrameError), "") return } // GOAWAY is the only frame allowed at this point: // * unexpected frames are ignored by the frame parser // * we don't support any extension that might add support for more frames goaway, ok := f.(*goAwayFrame) if !ok { c.Conn.CloseWithError(quic.ApplicationErrorCode(ErrCodeFrameUnexpected), "") return } if goaway.StreamID%4 != 0 { // client-initiated, bidirectional streams c.Conn.CloseWithError(quic.ApplicationErrorCode(ErrCodeIDError), "") return } c.streamMx.Lock() if c.maxStreamID != invalidStreamID && goaway.StreamID > c.maxStreamID { c.streamMx.Unlock() c.Conn.CloseWithError(quic.ApplicationErrorCode(ErrCodeIDError), "") return } c.maxStreamID = goaway.StreamID hasActiveStreams := len(c.Streams) > 0 c.streamMx.Unlock() // immediately close the connection if there are currently no active requests if !hasActiveStreams { c.CloseWithError(quic.ApplicationErrorCode(ErrCodeNoError), "") return } } } func (c *Conn) sendDatagram(streamID quic.StreamID, b []byte) error { // TODO: this creates a lot of garbage and an additional copy data := make([]byte, 0, len(b)+8) quarterStreamID := uint64(streamID / 4) data = quicvarint.Append(data, uint64(streamID/4)) data = append(data, b...) if c.qlogger != nil { c.qlogger.RecordEvent(qlog.DatagramCreated{ QuaterStreamID: quarterStreamID, Raw: qlog.RawInfo{ Length: len(data), PayloadLength: len(b), }, }) } return c.Conn.SendDatagram(data) } func (c *Conn) receiveDatagrams() error { for { b, err := c.Conn.ReceiveDatagram(context.Background()) if err != nil { return err } quarterStreamID, n, err := quicvarint.Parse(b) if err != nil { c.CloseWithError(quic.ApplicationErrorCode(ErrCodeDatagramError), "") return fmt.Errorf("could not read quarter stream id: %w", err) } if c.qlogger != nil { c.qlogger.RecordEvent(qlog.DatagramParsed{ QuaterStreamID: quarterStreamID, Raw: qlog.RawInfo{ Length: len(b), PayloadLength: len(b) - n, }, }) } if quarterStreamID > maxQuarterStreamID { c.CloseWithError(quic.ApplicationErrorCode(ErrCodeDatagramError), "") return fmt.Errorf("invalid quarter stream id: %w", err) } streamID := quic.StreamID(4 * quarterStreamID) c.streamMx.Lock() dg, ok := c.Streams[streamID] c.streamMx.Unlock() if !ok { continue } dg.enqueueDatagram(b[n:]) } } // ReceivedSettings returns a channel that is closed once the peer's SETTINGS frame was received. // Settings can be optained from the Settings method after the channel was closed. func (c *Conn) ReceivedSettings() <-chan struct{} { return c.receivedSettings } // Settings returns the settings received on this connection. // It is only valid to call this function after the channel returned by ReceivedSettings was closed. func (c *Conn) Settings() *Settings { return c.settings } // Context returns the context of the underlying QUIC connection. func (c *Conn) Context() context.Context { return c.ctx }