Merge pull request #1566 from lucas-clemente/better-stream-helpers

simplify the streamsMap by using the stream helper functions
This commit is contained in:
Marten Seemann
2018-11-01 16:22:20 +07:00
committed by GitHub
4 changed files with 63 additions and 83 deletions

View File

@@ -3,6 +3,16 @@ package protocol
// A StreamID in QUIC // A StreamID in QUIC
type StreamID uint64 type StreamID uint64
// StreamType encodes if this is a unidirectional or bidirectional stream
type StreamType uint8
const (
// StreamTypeUni is a unidirectional stream
StreamTypeUni StreamType = iota
// StreamTypeBidi is a bidirectional stream
StreamTypeBidi
)
// InitiatedBy says if the stream was initiated by the client or by the server // InitiatedBy says if the stream was initiated by the client or by the server
func (s StreamID) InitiatedBy() Perspective { func (s StreamID) InitiatedBy() Perspective {
if s%2 == 0 { if s%2 == 0 {
@@ -11,9 +21,12 @@ func (s StreamID) InitiatedBy() Perspective {
return PerspectiveServer return PerspectiveServer
} }
// IsUniDirectional says if this is a unidirectional stream (true) or not (false) //Type says if this is a unidirectional or bidirectional stream
func (s StreamID) IsUniDirectional() bool { func (s StreamID) Type() StreamType {
return s%4 >= 2 if s%4 >= 2 {
return StreamTypeUni
}
return StreamTypeBidi
} }
// MaxBidiStreamID is the highest stream ID that the peer is allowed to open, // MaxBidiStreamID is the highest stream ID that the peer is allowed to open,

View File

@@ -14,10 +14,10 @@ var _ = Describe("Stream ID", func() {
}) })
It("tells the directionality", func() { It("tells the directionality", func() {
Expect(StreamID(4).IsUniDirectional()).To(BeFalse()) Expect(StreamID(4).Type()).To(Equal(StreamTypeBidi))
Expect(StreamID(5).IsUniDirectional()).To(BeFalse()) Expect(StreamID(5).Type()).To(Equal(StreamTypeBidi))
Expect(StreamID(6).IsUniDirectional()).To(BeTrue()) Expect(StreamID(6).Type()).To(Equal(StreamTypeUni))
Expect(StreamID(7).IsUniDirectional()).To(BeTrue()) Expect(StreamID(7).Type()).To(Equal(StreamTypeUni))
}) })
Context("maximum stream IDs", func() { Context("maximum stream IDs", func() {

View File

@@ -992,7 +992,7 @@ func (s *session) newStream(id protocol.StreamID) streamI {
func (s *session) newFlowController(id protocol.StreamID) flowcontrol.StreamFlowController { func (s *session) newFlowController(id protocol.StreamID) flowcontrol.StreamFlowController {
var initialSendWindow protocol.ByteCount var initialSendWindow protocol.ByteCount
if s.peerParams != nil { if s.peerParams != nil {
if id.IsUniDirectional() { if id.Type() == protocol.StreamTypeUni {
initialSendWindow = s.peerParams.InitialMaxStreamDataUni initialSendWindow = s.peerParams.InitialMaxStreamDataUni
} else { } else {
if id.InitiatedBy() == s.perspective { if id.InitiatedBy() == s.perspective {

View File

@@ -9,15 +9,6 @@ import (
"github.com/lucas-clemente/quic-go/internal/wire" "github.com/lucas-clemente/quic-go/internal/wire"
) )
type streamType int
const (
streamTypeOutgoingBidi streamType = iota
streamTypeIncomingBidi
streamTypeOutgoingUni
streamTypeIncomingUni
)
type streamsMap struct { type streamsMap struct {
perspective protocol.Perspective perspective protocol.Perspective
@@ -93,33 +84,6 @@ func newStreamsMap(
return m return m
} }
func (m *streamsMap) getStreamType(id protocol.StreamID) streamType {
if m.perspective == protocol.PerspectiveServer {
switch id % 4 {
case 0:
return streamTypeIncomingBidi
case 1:
return streamTypeOutgoingBidi
case 2:
return streamTypeIncomingUni
case 3:
return streamTypeOutgoingUni
}
} else {
switch id % 4 {
case 0:
return streamTypeOutgoingBidi
case 1:
return streamTypeIncomingBidi
case 2:
return streamTypeOutgoingUni
case 3:
return streamTypeIncomingUni
}
}
panic("")
}
func (m *streamsMap) OpenStream() (Stream, error) { func (m *streamsMap) OpenStream() (Stream, error) {
return m.outgoingBidiStreams.OpenStream() return m.outgoingBidiStreams.OpenStream()
} }
@@ -145,64 +109,67 @@ func (m *streamsMap) AcceptUniStream() (ReceiveStream, error) {
} }
func (m *streamsMap) DeleteStream(id protocol.StreamID) error { func (m *streamsMap) DeleteStream(id protocol.StreamID) error {
switch m.getStreamType(id) { switch id.Type() {
case streamTypeIncomingBidi: case protocol.StreamTypeUni:
return m.incomingBidiStreams.DeleteStream(id) if id.InitiatedBy() == m.perspective {
case streamTypeOutgoingBidi: return m.outgoingUniStreams.DeleteStream(id)
return m.outgoingBidiStreams.DeleteStream(id) }
case streamTypeIncomingUni:
return m.incomingUniStreams.DeleteStream(id) return m.incomingUniStreams.DeleteStream(id)
case streamTypeOutgoingUni: case protocol.StreamTypeBidi:
return m.outgoingUniStreams.DeleteStream(id) if id.InitiatedBy() == m.perspective {
default: return m.outgoingBidiStreams.DeleteStream(id)
panic("invalid stream type") }
return m.incomingBidiStreams.DeleteStream(id)
} }
panic("")
} }
func (m *streamsMap) GetOrOpenReceiveStream(id protocol.StreamID) (receiveStreamI, error) { func (m *streamsMap) GetOrOpenReceiveStream(id protocol.StreamID) (receiveStreamI, error) {
switch m.getStreamType(id) { switch id.Type() {
case streamTypeOutgoingBidi: case protocol.StreamTypeUni:
return m.outgoingBidiStreams.GetStream(id) if id.InitiatedBy() == m.perspective {
case streamTypeIncomingBidi: // an outgoing unidirectional stream is a send stream, not a receive stream
return m.incomingBidiStreams.GetOrOpenStream(id) return nil, fmt.Errorf("peer attempted to open receive stream %d", id)
case streamTypeIncomingUni: }
return m.incomingUniStreams.GetOrOpenStream(id) return m.incomingUniStreams.GetOrOpenStream(id)
case streamTypeOutgoingUni: case protocol.StreamTypeBidi:
// an outgoing unidirectional stream is a send stream, not a receive stream if id.InitiatedBy() == m.perspective {
return nil, fmt.Errorf("peer attempted to open receive stream %d", id) return m.outgoingBidiStreams.GetStream(id)
default: }
panic("invalid stream type") return m.incomingBidiStreams.GetOrOpenStream(id)
} }
panic("")
} }
func (m *streamsMap) GetOrOpenSendStream(id protocol.StreamID) (sendStreamI, error) { func (m *streamsMap) GetOrOpenSendStream(id protocol.StreamID) (sendStreamI, error) {
switch m.getStreamType(id) { switch id.Type() {
case streamTypeOutgoingBidi: case protocol.StreamTypeUni:
return m.outgoingBidiStreams.GetStream(id) if id.InitiatedBy() == m.perspective {
case streamTypeIncomingBidi: return m.outgoingUniStreams.GetStream(id)
return m.incomingBidiStreams.GetOrOpenStream(id) }
case streamTypeOutgoingUni:
return m.outgoingUniStreams.GetStream(id)
case streamTypeIncomingUni:
// an incoming unidirectional stream is a receive stream, not a send stream // an incoming unidirectional stream is a receive stream, not a send stream
return nil, fmt.Errorf("peer attempted to open send stream %d", id) return nil, fmt.Errorf("peer attempted to open send stream %d", id)
default: case protocol.StreamTypeBidi:
panic("invalid stream type") if id.InitiatedBy() == m.perspective {
return m.outgoingBidiStreams.GetStream(id)
}
return m.incomingBidiStreams.GetOrOpenStream(id)
} }
panic("")
} }
func (m *streamsMap) HandleMaxStreamIDFrame(f *wire.MaxStreamIDFrame) error { func (m *streamsMap) HandleMaxStreamIDFrame(f *wire.MaxStreamIDFrame) error {
id := f.StreamID id := f.StreamID
switch m.getStreamType(id) { if id.InitiatedBy() != m.perspective {
case streamTypeOutgoingBidi:
m.outgoingBidiStreams.SetMaxStream(id)
return nil
case streamTypeOutgoingUni:
m.outgoingUniStreams.SetMaxStream(id)
return nil
default:
return fmt.Errorf("received MAX_STREAM_DATA frame for incoming stream %d", id) return fmt.Errorf("received MAX_STREAM_DATA frame for incoming stream %d", id)
} }
switch id.Type() {
case protocol.StreamTypeUni:
m.outgoingUniStreams.SetMaxStream(id)
case protocol.StreamTypeBidi:
m.outgoingBidiStreams.SetMaxStream(id)
}
return nil
} }
func (m *streamsMap) UpdateLimits(p *handshake.TransportParameters) { func (m *streamsMap) UpdateLimits(p *handshake.TransportParameters) {