From 44243b4f522cf79b31c308f9d658ea4f4a297e1c Mon Sep 17 00:00:00 2001 From: Marten Seemann Date: Wed, 31 Oct 2018 10:17:59 +0700 Subject: [PATCH] use the stream helper function in the streamsMap --- streams_map.go | 117 ++++++++++++++++++------------------------------- 1 file changed, 42 insertions(+), 75 deletions(-) diff --git a/streams_map.go b/streams_map.go index 584a7934b..101d4fa67 100644 --- a/streams_map.go +++ b/streams_map.go @@ -9,15 +9,6 @@ import ( "github.com/lucas-clemente/quic-go/internal/wire" ) -type streamType int - -const ( - streamTypeOutgoingBidi streamType = iota - streamTypeIncomingBidi - streamTypeOutgoingUni - streamTypeIncomingUni -) - type streamsMap struct { perspective protocol.Perspective @@ -93,33 +84,6 @@ func newStreamsMap( return m } -func (m *streamsMap) getStreamType(id protocol.StreamID) streamType { - if m.perspective == protocol.PerspectiveServer { - switch id % 4 { - case 0: - return streamTypeIncomingBidi - case 1: - return streamTypeOutgoingBidi - case 2: - return streamTypeIncomingUni - case 3: - return streamTypeOutgoingUni - } - } else { - switch id % 4 { - case 0: - return streamTypeOutgoingBidi - case 1: - return streamTypeIncomingBidi - case 2: - return streamTypeOutgoingUni - case 3: - return streamTypeIncomingUni - } - } - panic("") -} - func (m *streamsMap) OpenStream() (Stream, error) { return m.outgoingBidiStreams.OpenStream() } @@ -145,64 +109,67 @@ func (m *streamsMap) AcceptUniStream() (ReceiveStream, error) { } func (m *streamsMap) DeleteStream(id protocol.StreamID) error { - switch m.getStreamType(id) { - case streamTypeIncomingBidi: - return m.incomingBidiStreams.DeleteStream(id) - case streamTypeOutgoingBidi: - return m.outgoingBidiStreams.DeleteStream(id) - case streamTypeIncomingUni: + switch id.Type() { + case protocol.StreamTypeUni: + if id.InitiatedBy() == m.perspective { + return m.outgoingUniStreams.DeleteStream(id) + } return m.incomingUniStreams.DeleteStream(id) - case streamTypeOutgoingUni: - return m.outgoingUniStreams.DeleteStream(id) - default: - panic("invalid stream type") + case protocol.StreamTypeBidi: + if id.InitiatedBy() == m.perspective { + return m.outgoingBidiStreams.DeleteStream(id) + } + return m.incomingBidiStreams.DeleteStream(id) } + panic("") } func (m *streamsMap) GetOrOpenReceiveStream(id protocol.StreamID) (receiveStreamI, error) { - switch m.getStreamType(id) { - case streamTypeOutgoingBidi: - return m.outgoingBidiStreams.GetStream(id) - case streamTypeIncomingBidi: - return m.incomingBidiStreams.GetOrOpenStream(id) - case streamTypeIncomingUni: + switch id.Type() { + case protocol.StreamTypeUni: + if id.InitiatedBy() == m.perspective { + // an outgoing unidirectional stream is a send stream, not a receive stream + return nil, fmt.Errorf("peer attempted to open receive stream %d", id) + } return m.incomingUniStreams.GetOrOpenStream(id) - case streamTypeOutgoingUni: - // an outgoing unidirectional stream is a send stream, not a receive stream - return nil, fmt.Errorf("peer attempted to open receive stream %d", id) - default: - panic("invalid stream type") + case protocol.StreamTypeBidi: + if id.InitiatedBy() == m.perspective { + return m.outgoingBidiStreams.GetStream(id) + } + return m.incomingBidiStreams.GetOrOpenStream(id) } + panic("") } func (m *streamsMap) GetOrOpenSendStream(id protocol.StreamID) (sendStreamI, error) { - switch m.getStreamType(id) { - case streamTypeOutgoingBidi: - return m.outgoingBidiStreams.GetStream(id) - case streamTypeIncomingBidi: - return m.incomingBidiStreams.GetOrOpenStream(id) - case streamTypeOutgoingUni: - return m.outgoingUniStreams.GetStream(id) - case streamTypeIncomingUni: + switch id.Type() { + case protocol.StreamTypeUni: + if id.InitiatedBy() == m.perspective { + return m.outgoingUniStreams.GetStream(id) + } // an incoming unidirectional stream is a receive stream, not a send stream return nil, fmt.Errorf("peer attempted to open send stream %d", id) - default: - panic("invalid stream type") + case protocol.StreamTypeBidi: + if id.InitiatedBy() == m.perspective { + return m.outgoingBidiStreams.GetStream(id) + } + return m.incomingBidiStreams.GetOrOpenStream(id) } + panic("") } func (m *streamsMap) HandleMaxStreamIDFrame(f *wire.MaxStreamIDFrame) error { id := f.StreamID - switch m.getStreamType(id) { - case streamTypeOutgoingBidi: - m.outgoingBidiStreams.SetMaxStream(id) - return nil - case streamTypeOutgoingUni: - m.outgoingUniStreams.SetMaxStream(id) - return nil - default: + if id.InitiatedBy() != m.perspective { return fmt.Errorf("received MAX_STREAM_DATA frame for incoming stream %d", id) } + switch id.Type() { + case protocol.StreamTypeUni: + m.outgoingUniStreams.SetMaxStream(id) + case protocol.StreamTypeBidi: + m.outgoingBidiStreams.SetMaxStream(id) + } + return nil } func (m *streamsMap) UpdateLimits(p *handshake.TransportParameters) {