diff --git a/internal/protocol/stream_id.go b/internal/protocol/stream_id.go index 49c4df7e..0e0ab7e9 100644 --- a/internal/protocol/stream_id.go +++ b/internal/protocol/stream_id.go @@ -3,6 +3,16 @@ package protocol // A StreamID in QUIC type StreamID uint64 +// StreamType encodes if this is a unidirectional or bidirectional stream +type StreamType uint8 + +const ( + // StreamTypeUni is a unidirectional stream + StreamTypeUni StreamType = iota + // StreamTypeBidi is a bidirectional stream + StreamTypeBidi +) + // InitiatedBy says if the stream was initiated by the client or by the server func (s StreamID) InitiatedBy() Perspective { if s%2 == 0 { @@ -11,9 +21,12 @@ func (s StreamID) InitiatedBy() Perspective { return PerspectiveServer } -// IsUniDirectional says if this is a unidirectional stream (true) or not (false) -func (s StreamID) IsUniDirectional() bool { - return s%4 >= 2 +//Type says if this is a unidirectional or bidirectional stream +func (s StreamID) Type() StreamType { + if s%4 >= 2 { + return StreamTypeUni + } + return StreamTypeBidi } // MaxBidiStreamID is the highest stream ID that the peer is allowed to open, diff --git a/internal/protocol/stream_id_test.go b/internal/protocol/stream_id_test.go index 72adac35..e19823d3 100644 --- a/internal/protocol/stream_id_test.go +++ b/internal/protocol/stream_id_test.go @@ -14,10 +14,10 @@ var _ = Describe("Stream ID", func() { }) It("tells the directionality", func() { - Expect(StreamID(4).IsUniDirectional()).To(BeFalse()) - Expect(StreamID(5).IsUniDirectional()).To(BeFalse()) - Expect(StreamID(6).IsUniDirectional()).To(BeTrue()) - Expect(StreamID(7).IsUniDirectional()).To(BeTrue()) + Expect(StreamID(4).Type()).To(Equal(StreamTypeBidi)) + Expect(StreamID(5).Type()).To(Equal(StreamTypeBidi)) + Expect(StreamID(6).Type()).To(Equal(StreamTypeUni)) + Expect(StreamID(7).Type()).To(Equal(StreamTypeUni)) }) Context("maximum stream IDs", func() { diff --git a/session.go b/session.go index d6a2716b..12e164fc 100644 --- a/session.go +++ b/session.go @@ -992,7 +992,7 @@ func (s *session) newStream(id protocol.StreamID) streamI { func (s *session) newFlowController(id protocol.StreamID) flowcontrol.StreamFlowController { var initialSendWindow protocol.ByteCount if s.peerParams != nil { - if id.IsUniDirectional() { + if id.Type() == protocol.StreamTypeUni { initialSendWindow = s.peerParams.InitialMaxStreamDataUni } else { if id.InitiatedBy() == s.perspective { diff --git a/streams_map.go b/streams_map.go index 584a7934..101d4fa6 100644 --- a/streams_map.go +++ b/streams_map.go @@ -9,15 +9,6 @@ import ( "github.com/lucas-clemente/quic-go/internal/wire" ) -type streamType int - -const ( - streamTypeOutgoingBidi streamType = iota - streamTypeIncomingBidi - streamTypeOutgoingUni - streamTypeIncomingUni -) - type streamsMap struct { perspective protocol.Perspective @@ -93,33 +84,6 @@ func newStreamsMap( return m } -func (m *streamsMap) getStreamType(id protocol.StreamID) streamType { - if m.perspective == protocol.PerspectiveServer { - switch id % 4 { - case 0: - return streamTypeIncomingBidi - case 1: - return streamTypeOutgoingBidi - case 2: - return streamTypeIncomingUni - case 3: - return streamTypeOutgoingUni - } - } else { - switch id % 4 { - case 0: - return streamTypeOutgoingBidi - case 1: - return streamTypeIncomingBidi - case 2: - return streamTypeOutgoingUni - case 3: - return streamTypeIncomingUni - } - } - panic("") -} - func (m *streamsMap) OpenStream() (Stream, error) { return m.outgoingBidiStreams.OpenStream() } @@ -145,64 +109,67 @@ func (m *streamsMap) AcceptUniStream() (ReceiveStream, error) { } func (m *streamsMap) DeleteStream(id protocol.StreamID) error { - switch m.getStreamType(id) { - case streamTypeIncomingBidi: - return m.incomingBidiStreams.DeleteStream(id) - case streamTypeOutgoingBidi: - return m.outgoingBidiStreams.DeleteStream(id) - case streamTypeIncomingUni: + switch id.Type() { + case protocol.StreamTypeUni: + if id.InitiatedBy() == m.perspective { + return m.outgoingUniStreams.DeleteStream(id) + } return m.incomingUniStreams.DeleteStream(id) - case streamTypeOutgoingUni: - return m.outgoingUniStreams.DeleteStream(id) - default: - panic("invalid stream type") + case protocol.StreamTypeBidi: + if id.InitiatedBy() == m.perspective { + return m.outgoingBidiStreams.DeleteStream(id) + } + return m.incomingBidiStreams.DeleteStream(id) } + panic("") } func (m *streamsMap) GetOrOpenReceiveStream(id protocol.StreamID) (receiveStreamI, error) { - switch m.getStreamType(id) { - case streamTypeOutgoingBidi: - return m.outgoingBidiStreams.GetStream(id) - case streamTypeIncomingBidi: - return m.incomingBidiStreams.GetOrOpenStream(id) - case streamTypeIncomingUni: + switch id.Type() { + case protocol.StreamTypeUni: + if id.InitiatedBy() == m.perspective { + // an outgoing unidirectional stream is a send stream, not a receive stream + return nil, fmt.Errorf("peer attempted to open receive stream %d", id) + } return m.incomingUniStreams.GetOrOpenStream(id) - case streamTypeOutgoingUni: - // an outgoing unidirectional stream is a send stream, not a receive stream - return nil, fmt.Errorf("peer attempted to open receive stream %d", id) - default: - panic("invalid stream type") + case protocol.StreamTypeBidi: + if id.InitiatedBy() == m.perspective { + return m.outgoingBidiStreams.GetStream(id) + } + return m.incomingBidiStreams.GetOrOpenStream(id) } + panic("") } func (m *streamsMap) GetOrOpenSendStream(id protocol.StreamID) (sendStreamI, error) { - switch m.getStreamType(id) { - case streamTypeOutgoingBidi: - return m.outgoingBidiStreams.GetStream(id) - case streamTypeIncomingBidi: - return m.incomingBidiStreams.GetOrOpenStream(id) - case streamTypeOutgoingUni: - return m.outgoingUniStreams.GetStream(id) - case streamTypeIncomingUni: + switch id.Type() { + case protocol.StreamTypeUni: + if id.InitiatedBy() == m.perspective { + return m.outgoingUniStreams.GetStream(id) + } // an incoming unidirectional stream is a receive stream, not a send stream return nil, fmt.Errorf("peer attempted to open send stream %d", id) - default: - panic("invalid stream type") + case protocol.StreamTypeBidi: + if id.InitiatedBy() == m.perspective { + return m.outgoingBidiStreams.GetStream(id) + } + return m.incomingBidiStreams.GetOrOpenStream(id) } + panic("") } func (m *streamsMap) HandleMaxStreamIDFrame(f *wire.MaxStreamIDFrame) error { id := f.StreamID - switch m.getStreamType(id) { - case streamTypeOutgoingBidi: - m.outgoingBidiStreams.SetMaxStream(id) - return nil - case streamTypeOutgoingUni: - m.outgoingUniStreams.SetMaxStream(id) - return nil - default: + if id.InitiatedBy() != m.perspective { return fmt.Errorf("received MAX_STREAM_DATA frame for incoming stream %d", id) } + switch id.Type() { + case protocol.StreamTypeUni: + m.outgoingUniStreams.SetMaxStream(id) + case protocol.StreamTypeBidi: + m.outgoingBidiStreams.SetMaxStream(id) + } + return nil } func (m *streamsMap) UpdateLimits(p *handshake.TransportParameters) {