handle stream-related frame in the streams map (#5212)

* handle stream-related frame in the streams map

* remove stream manager interface and mock
This commit is contained in:
Marten Seemann
2025-06-09 16:00:46 +08:00
committed by GitHub
parent 4f23ac2752
commit 1b07674b19
10 changed files with 529 additions and 1036 deletions

View File

@@ -4,6 +4,7 @@ import (
"context"
"fmt"
"sync"
"time"
"github.com/quic-go/quic-go/internal/flowcontrol"
"github.com/quic-go/quic-go/internal/protocol"
@@ -37,8 +38,6 @@ type streamsMap struct {
reset bool
}
var _ streamManager = &streamsMap{}
func newStreamsMap(
ctx context.Context,
sender streamSender,
@@ -180,90 +179,6 @@ func (m *streamsMap) DeleteStream(id protocol.StreamID) error {
panic("")
}
func (m *streamsMap) GetOrOpenReceiveStream(id protocol.StreamID) (*ReceiveStream, error) {
str, err := m.getOrOpenReceiveStream(id)
if err != nil {
return nil, &qerr.TransportError{
ErrorCode: qerr.StreamStateError,
ErrorMessage: err.Error(),
}
}
return str, nil
}
func (m *streamsMap) getOrOpenReceiveStream(id protocol.StreamID) (*ReceiveStream, error) {
switch id.Type() {
case protocol.StreamTypeUni:
if id.InitiatedBy() == m.perspective {
// an outgoing unidirectional stream is a send stream, not a receive stream
return nil, fmt.Errorf("peer attempted to open receive stream %d", id)
}
return m.incomingUniStreams.GetOrOpenStream(id)
case protocol.StreamTypeBidi:
if id.InitiatedBy() == m.perspective {
str, err := m.outgoingBidiStreams.GetStream(id)
if str == nil && err == nil {
return nil, nil
}
return str.ReceiveStream, err
} else {
str, err := m.incomingBidiStreams.GetOrOpenStream(id)
if str == nil && err == nil {
return nil, nil
}
return str.ReceiveStream, err
}
}
panic("")
}
func (m *streamsMap) GetOrOpenSendStream(id protocol.StreamID) (*SendStream, error) {
str, err := m.getOrOpenSendStream(id)
if err != nil {
return nil, &qerr.TransportError{
ErrorCode: qerr.StreamStateError,
ErrorMessage: err.Error(),
}
}
return str, nil
}
func (m *streamsMap) getOrOpenSendStream(id protocol.StreamID) (*SendStream, error) {
switch id.Type() {
case protocol.StreamTypeUni:
if id.InitiatedBy() == m.perspective {
str, err := m.outgoingUniStreams.GetStream(id)
if str == nil && err == nil {
return nil, nil
}
return str, err
}
// an incoming unidirectional stream is a receive stream, not a send stream
return nil, fmt.Errorf("peer attempted to open send stream %d", id)
case protocol.StreamTypeBidi:
if id.InitiatedBy() == m.perspective {
str, err := m.outgoingBidiStreams.GetStream(id)
if str == nil && err == nil {
return nil, nil
}
if err != nil {
return nil, err
}
return str.SendStream, nil
} else {
str, err := m.incomingBidiStreams.GetOrOpenStream(id)
if str == nil && err == nil {
return nil, nil
}
if err != nil {
return nil, err
}
return str.SendStream, nil
}
}
panic("")
}
func (m *streamsMap) HandleMaxStreamsFrame(f *wire.MaxStreamsFrame) {
switch f.Type {
case protocol.StreamTypeUni:
@@ -273,6 +188,134 @@ func (m *streamsMap) HandleMaxStreamsFrame(f *wire.MaxStreamsFrame) {
}
}
type sendStreamFrameHandler interface {
updateSendWindow(protocol.ByteCount)
handleStopSendingFrame(*wire.StopSendingFrame)
}
func (m *streamsMap) getSendStream(id protocol.StreamID) (sendStreamFrameHandler, error) {
switch id.Type() {
case protocol.StreamTypeUni:
if id.InitiatedBy() != m.perspective {
// an outgoing unidirectional stream is a send stream, not a receive stream
return nil, &qerr.TransportError{
ErrorCode: qerr.StreamStateError,
ErrorMessage: fmt.Sprintf("invalid frame for send stream %d", id),
}
}
str, err := m.outgoingUniStreams.GetStream(id)
if str == nil || err != nil {
return nil, err
}
return str, nil
case protocol.StreamTypeBidi:
if id.InitiatedBy() == m.perspective {
str, err := m.outgoingBidiStreams.GetStream(id)
if str == nil || err != nil {
return nil, err
}
return str, nil
}
str, err := m.incomingBidiStreams.GetOrOpenStream(id)
if str == nil || err != nil {
return nil, err
}
return str, nil
}
panic("unreachable")
}
func (m *streamsMap) HandleMaxStreamDataFrame(f *wire.MaxStreamDataFrame) error {
str, err := m.getSendStream(f.StreamID)
if err != nil {
return err
}
if str == nil { // stream already deleted
return nil
}
str.updateSendWindow(f.MaximumStreamData)
return nil
}
func (m *streamsMap) HandleStopSendingFrame(f *wire.StopSendingFrame) error {
str, err := m.getSendStream(f.StreamID)
if err != nil {
return err
}
if str == nil { // stream already deleted
return nil
}
str.handleStopSendingFrame(f)
return nil
}
type receiveStreamFrameHandler interface {
handleResetStreamFrame(*wire.ResetStreamFrame, time.Time) error
handleStreamFrame(*wire.StreamFrame, time.Time) error
}
func (m *streamsMap) getReceiveStream(id protocol.StreamID) (receiveStreamFrameHandler, error) {
switch id.Type() {
case protocol.StreamTypeUni:
// an outgoing unidirectional stream is a send stream, not a receive stream
if id.InitiatedBy() == m.perspective {
return nil, &qerr.TransportError{
ErrorCode: qerr.StreamStateError,
ErrorMessage: fmt.Sprintf("invalid frame for receive stream %d", id),
}
}
str, err := m.incomingUniStreams.GetOrOpenStream(id)
if err != nil || str == nil {
return nil, err
}
return str, nil
case protocol.StreamTypeBidi:
var str *Stream
var err error
if id.InitiatedBy() == m.perspective {
str, err = m.outgoingBidiStreams.GetStream(id)
} else {
str, err = m.incomingBidiStreams.GetOrOpenStream(id)
}
if str == nil || err != nil {
return nil, err
}
return str, nil
}
panic("unreachable")
}
func (m *streamsMap) HandleStreamDataBlockedFrame(f *wire.StreamDataBlockedFrame) error {
if _, err := m.getReceiveStream(f.StreamID); err != nil {
return err
}
// We don't need to do anything in response to a STREAM_DATA_BLOCKED frame,
// but we need to make sure that the stream ID is valid.
return nil // we don't need to do anything in response to a STREAM_DATA_BLOCKED frame
}
func (m *streamsMap) HandleResetStreamFrame(f *wire.ResetStreamFrame, rcvTime time.Time) error {
str, err := m.getReceiveStream(f.StreamID)
if err != nil {
return err
}
if str == nil { // stream already deleted
return nil
}
return str.handleResetStreamFrame(f, rcvTime)
}
func (m *streamsMap) HandleStreamFrame(f *wire.StreamFrame, rcvTime time.Time) error {
str, err := m.getReceiveStream(f.StreamID)
if err != nil {
return err
}
if str == nil { // stream already deleted
return nil
}
return str.handleStreamFrame(f, rcvTime)
}
func (m *streamsMap) UpdateLimits(p *wire.TransportParameters) {
m.outgoingBidiStreams.UpdateSendWindow(p.InitialMaxStreamDataBidiRemote)
m.outgoingBidiStreams.SetMaxStream(p.MaxBidiStreamNum.StreamID(protocol.StreamTypeBidi, m.perspective))