http3: pass tracing ID instead of quic.Connection to stream hijackers (#4401)

The stream hijackers only need to be able to associate the stream with
the underlying QUIC connection. They are not supposed to call any
functions on the quic.Connection. As such, the better API is to just
pass them a unique identifier.
This commit is contained in:
Marten Seemann
2024-04-02 17:23:40 +13:00
committed by GitHub
parent 27a06f32ce
commit 183d42a729
6 changed files with 80 additions and 27 deletions

View File

@@ -17,7 +17,7 @@ type connection struct {
logger utils.Logger
enableDatagrams bool
uniStreamHijacker func(StreamType, quic.Connection, quic.ReceiveStream, error) (hijacked bool)
uniStreamHijacker func(StreamType, quic.ConnectionTracingID, quic.ReceiveStream, error) (hijacked bool)
settings *Settings
receivedSettings chan struct{}
@@ -26,7 +26,7 @@ type connection struct {
func newConnection(
quicConn quic.Connection,
enableDatagrams bool,
uniStreamHijacker func(StreamType, quic.Connection, quic.ReceiveStream, error) (hijacked bool),
uniStreamHijacker func(StreamType, quic.ConnectionTracingID, quic.ReceiveStream, error) (hijacked bool),
perspective protocol.Perspective,
logger utils.Logger,
) *connection {
@@ -57,7 +57,8 @@ func (c *connection) HandleUnidirectionalStreams() {
go func(str quic.ReceiveStream) {
streamType, err := quicvarint.Read(quicvarint.NewReader(str))
if err != nil {
if c.uniStreamHijacker != nil && c.uniStreamHijacker(StreamType(streamType), c.Connection, str, err) {
id := c.Connection.Context().Value(quic.ConnectionTracingKey).(quic.ConnectionTracingID)
if c.uniStreamHijacker != nil && c.uniStreamHijacker(StreamType(streamType), id, str, err) {
return
}
c.logger.Debugf("reading stream type on stream %d failed: %s", str.StreamID(), err)
@@ -89,8 +90,15 @@ func (c *connection) HandleUnidirectionalStreams() {
}
return
default:
if c.uniStreamHijacker != nil && c.uniStreamHijacker(StreamType(streamType), c.Connection, str, nil) {
return
if c.uniStreamHijacker != nil {
if c.uniStreamHijacker(
StreamType(streamType),
c.Connection.Context().Value(quic.ConnectionTracingKey).(quic.ConnectionTracingID),
str,
nil,
) {
return
}
}
str.CancelRead(quic.StreamErrorCode(ErrCodeStreamCreationError))
return