forked from quic-go/quic-go
handle stream-related frame in the streams map (#5212)
* handle stream-related frame in the streams map * remove stream manager interface and mock
This commit is contained in:
215
streams_map.go
215
streams_map.go
@@ -4,6 +4,7 @@ import (
|
||||
"context"
|
||||
"fmt"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/quic-go/quic-go/internal/flowcontrol"
|
||||
"github.com/quic-go/quic-go/internal/protocol"
|
||||
@@ -37,8 +38,6 @@ type streamsMap struct {
|
||||
reset bool
|
||||
}
|
||||
|
||||
var _ streamManager = &streamsMap{}
|
||||
|
||||
func newStreamsMap(
|
||||
ctx context.Context,
|
||||
sender streamSender,
|
||||
@@ -180,90 +179,6 @@ func (m *streamsMap) DeleteStream(id protocol.StreamID) error {
|
||||
panic("")
|
||||
}
|
||||
|
||||
func (m *streamsMap) GetOrOpenReceiveStream(id protocol.StreamID) (*ReceiveStream, error) {
|
||||
str, err := m.getOrOpenReceiveStream(id)
|
||||
if err != nil {
|
||||
return nil, &qerr.TransportError{
|
||||
ErrorCode: qerr.StreamStateError,
|
||||
ErrorMessage: err.Error(),
|
||||
}
|
||||
}
|
||||
return str, nil
|
||||
}
|
||||
|
||||
func (m *streamsMap) getOrOpenReceiveStream(id protocol.StreamID) (*ReceiveStream, error) {
|
||||
switch id.Type() {
|
||||
case protocol.StreamTypeUni:
|
||||
if id.InitiatedBy() == m.perspective {
|
||||
// an outgoing unidirectional stream is a send stream, not a receive stream
|
||||
return nil, fmt.Errorf("peer attempted to open receive stream %d", id)
|
||||
}
|
||||
return m.incomingUniStreams.GetOrOpenStream(id)
|
||||
case protocol.StreamTypeBidi:
|
||||
if id.InitiatedBy() == m.perspective {
|
||||
str, err := m.outgoingBidiStreams.GetStream(id)
|
||||
if str == nil && err == nil {
|
||||
return nil, nil
|
||||
}
|
||||
return str.ReceiveStream, err
|
||||
} else {
|
||||
str, err := m.incomingBidiStreams.GetOrOpenStream(id)
|
||||
if str == nil && err == nil {
|
||||
return nil, nil
|
||||
}
|
||||
return str.ReceiveStream, err
|
||||
}
|
||||
}
|
||||
panic("")
|
||||
}
|
||||
|
||||
func (m *streamsMap) GetOrOpenSendStream(id protocol.StreamID) (*SendStream, error) {
|
||||
str, err := m.getOrOpenSendStream(id)
|
||||
if err != nil {
|
||||
return nil, &qerr.TransportError{
|
||||
ErrorCode: qerr.StreamStateError,
|
||||
ErrorMessage: err.Error(),
|
||||
}
|
||||
}
|
||||
return str, nil
|
||||
}
|
||||
|
||||
func (m *streamsMap) getOrOpenSendStream(id protocol.StreamID) (*SendStream, error) {
|
||||
switch id.Type() {
|
||||
case protocol.StreamTypeUni:
|
||||
if id.InitiatedBy() == m.perspective {
|
||||
str, err := m.outgoingUniStreams.GetStream(id)
|
||||
if str == nil && err == nil {
|
||||
return nil, nil
|
||||
}
|
||||
return str, err
|
||||
}
|
||||
// an incoming unidirectional stream is a receive stream, not a send stream
|
||||
return nil, fmt.Errorf("peer attempted to open send stream %d", id)
|
||||
case protocol.StreamTypeBidi:
|
||||
if id.InitiatedBy() == m.perspective {
|
||||
str, err := m.outgoingBidiStreams.GetStream(id)
|
||||
if str == nil && err == nil {
|
||||
return nil, nil
|
||||
}
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return str.SendStream, nil
|
||||
} else {
|
||||
str, err := m.incomingBidiStreams.GetOrOpenStream(id)
|
||||
if str == nil && err == nil {
|
||||
return nil, nil
|
||||
}
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return str.SendStream, nil
|
||||
}
|
||||
}
|
||||
panic("")
|
||||
}
|
||||
|
||||
func (m *streamsMap) HandleMaxStreamsFrame(f *wire.MaxStreamsFrame) {
|
||||
switch f.Type {
|
||||
case protocol.StreamTypeUni:
|
||||
@@ -273,6 +188,134 @@ func (m *streamsMap) HandleMaxStreamsFrame(f *wire.MaxStreamsFrame) {
|
||||
}
|
||||
}
|
||||
|
||||
type sendStreamFrameHandler interface {
|
||||
updateSendWindow(protocol.ByteCount)
|
||||
handleStopSendingFrame(*wire.StopSendingFrame)
|
||||
}
|
||||
|
||||
func (m *streamsMap) getSendStream(id protocol.StreamID) (sendStreamFrameHandler, error) {
|
||||
switch id.Type() {
|
||||
case protocol.StreamTypeUni:
|
||||
if id.InitiatedBy() != m.perspective {
|
||||
// an outgoing unidirectional stream is a send stream, not a receive stream
|
||||
return nil, &qerr.TransportError{
|
||||
ErrorCode: qerr.StreamStateError,
|
||||
ErrorMessage: fmt.Sprintf("invalid frame for send stream %d", id),
|
||||
}
|
||||
}
|
||||
str, err := m.outgoingUniStreams.GetStream(id)
|
||||
if str == nil || err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return str, nil
|
||||
case protocol.StreamTypeBidi:
|
||||
if id.InitiatedBy() == m.perspective {
|
||||
str, err := m.outgoingBidiStreams.GetStream(id)
|
||||
if str == nil || err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return str, nil
|
||||
}
|
||||
str, err := m.incomingBidiStreams.GetOrOpenStream(id)
|
||||
if str == nil || err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return str, nil
|
||||
}
|
||||
panic("unreachable")
|
||||
}
|
||||
|
||||
func (m *streamsMap) HandleMaxStreamDataFrame(f *wire.MaxStreamDataFrame) error {
|
||||
str, err := m.getSendStream(f.StreamID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if str == nil { // stream already deleted
|
||||
return nil
|
||||
}
|
||||
str.updateSendWindow(f.MaximumStreamData)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *streamsMap) HandleStopSendingFrame(f *wire.StopSendingFrame) error {
|
||||
str, err := m.getSendStream(f.StreamID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if str == nil { // stream already deleted
|
||||
return nil
|
||||
}
|
||||
str.handleStopSendingFrame(f)
|
||||
return nil
|
||||
}
|
||||
|
||||
type receiveStreamFrameHandler interface {
|
||||
handleResetStreamFrame(*wire.ResetStreamFrame, time.Time) error
|
||||
handleStreamFrame(*wire.StreamFrame, time.Time) error
|
||||
}
|
||||
|
||||
func (m *streamsMap) getReceiveStream(id protocol.StreamID) (receiveStreamFrameHandler, error) {
|
||||
switch id.Type() {
|
||||
case protocol.StreamTypeUni:
|
||||
// an outgoing unidirectional stream is a send stream, not a receive stream
|
||||
if id.InitiatedBy() == m.perspective {
|
||||
return nil, &qerr.TransportError{
|
||||
ErrorCode: qerr.StreamStateError,
|
||||
ErrorMessage: fmt.Sprintf("invalid frame for receive stream %d", id),
|
||||
}
|
||||
}
|
||||
str, err := m.incomingUniStreams.GetOrOpenStream(id)
|
||||
if err != nil || str == nil {
|
||||
return nil, err
|
||||
}
|
||||
return str, nil
|
||||
case protocol.StreamTypeBidi:
|
||||
var str *Stream
|
||||
var err error
|
||||
if id.InitiatedBy() == m.perspective {
|
||||
str, err = m.outgoingBidiStreams.GetStream(id)
|
||||
} else {
|
||||
str, err = m.incomingBidiStreams.GetOrOpenStream(id)
|
||||
}
|
||||
if str == nil || err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return str, nil
|
||||
}
|
||||
panic("unreachable")
|
||||
}
|
||||
|
||||
func (m *streamsMap) HandleStreamDataBlockedFrame(f *wire.StreamDataBlockedFrame) error {
|
||||
if _, err := m.getReceiveStream(f.StreamID); err != nil {
|
||||
return err
|
||||
}
|
||||
// We don't need to do anything in response to a STREAM_DATA_BLOCKED frame,
|
||||
// but we need to make sure that the stream ID is valid.
|
||||
return nil // we don't need to do anything in response to a STREAM_DATA_BLOCKED frame
|
||||
}
|
||||
|
||||
func (m *streamsMap) HandleResetStreamFrame(f *wire.ResetStreamFrame, rcvTime time.Time) error {
|
||||
str, err := m.getReceiveStream(f.StreamID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if str == nil { // stream already deleted
|
||||
return nil
|
||||
}
|
||||
return str.handleResetStreamFrame(f, rcvTime)
|
||||
}
|
||||
|
||||
func (m *streamsMap) HandleStreamFrame(f *wire.StreamFrame, rcvTime time.Time) error {
|
||||
str, err := m.getReceiveStream(f.StreamID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if str == nil { // stream already deleted
|
||||
return nil
|
||||
}
|
||||
return str.handleStreamFrame(f, rcvTime)
|
||||
}
|
||||
|
||||
func (m *streamsMap) UpdateLimits(p *wire.TransportParameters) {
|
||||
m.outgoingBidiStreams.UpdateSendWindow(p.InitialMaxStreamDataBidiRemote)
|
||||
m.outgoingBidiStreams.SetMaxStream(p.MaxBidiStreamNum.StreamID(protocol.StreamTypeBidi, m.perspective))
|
||||
|
||||
Reference in New Issue
Block a user