From 41fce88a63c64982be2eb2d58b12313643c47085 Mon Sep 17 00:00:00 2001 From: Marten Seemann Date: Sun, 8 Jun 2025 20:27:14 +0800 Subject: [PATCH] simplify stream ID handling in the outgoing streams map (#5209) No functional change expected. --- streams_map.go | 81 ++++------- streams_map_outgoing.go | 69 +++++----- streams_map_outgoing_test.go | 252 ++++++++++++++++++++--------------- streams_map_test.go | 8 -- 4 files changed, 209 insertions(+), 201 deletions(-) diff --git a/streams_map.go b/streams_map.go index 53bf7f3ff..16340d1a8 100644 --- a/streams_map.go +++ b/streams_map.go @@ -11,27 +11,6 @@ import ( "github.com/quic-go/quic-go/internal/wire" ) -type streamError struct { - message string - nums []protocol.StreamNum -} - -func (e streamError) Error() string { - return e.message -} - -func convertStreamError(err error, stype protocol.StreamType, pers protocol.Perspective) error { - strError, ok := err.(streamError) - if !ok { - return err - } - ids := make([]interface{}, len(strError.nums)) - for i, num := range strError.nums { - ids[i] = num.StreamID(stype, pers) - } - return fmt.Errorf(strError.Error(), ids...) -} - // StreamLimitReachedError is returned from Connection.OpenStream and Connection.OpenUniStream // when it is not possible to open a new stream because the number of opens streams reached // the peer's stream limit. @@ -85,11 +64,11 @@ func newStreamsMap( func (m *streamsMap) initMaps() { m.outgoingBidiStreams = newOutgoingStreamsMap( protocol.StreamTypeBidi, - func(num protocol.StreamNum) *Stream { - id := num.StreamID(protocol.StreamTypeBidi, m.perspective) + func(id protocol.StreamID) *Stream { return newStream(m.ctx, id, m.sender, m.newFlowController(id)) }, m.queueControlFrame, + m.perspective, ) m.incomingBidiStreams = newIncomingStreamsMap( protocol.StreamTypeBidi, @@ -102,11 +81,11 @@ func (m *streamsMap) initMaps() { ) m.outgoingUniStreams = newOutgoingStreamsMap( protocol.StreamTypeUni, - func(num protocol.StreamNum) *SendStream { - id := num.StreamID(protocol.StreamTypeUni, m.perspective) + func(id protocol.StreamID) *SendStream { return newSendStream(m.ctx, id, m.sender, m.newFlowController(id)) }, m.queueControlFrame, + m.perspective, ) m.incomingUniStreams = newIncomingStreamsMap( protocol.StreamTypeUni, @@ -127,8 +106,7 @@ func (m *streamsMap) OpenStream() (*Stream, error) { if reset { return nil, Err0RTTRejected } - str, err := mm.OpenStream() - return str, convertStreamError(err, protocol.StreamTypeBidi, m.perspective) + return mm.OpenStream() } func (m *streamsMap) OpenStreamSync(ctx context.Context) (*Stream, error) { @@ -139,8 +117,7 @@ func (m *streamsMap) OpenStreamSync(ctx context.Context) (*Stream, error) { if reset { return nil, Err0RTTRejected } - str, err := mm.OpenStreamSync(ctx) - return str, convertStreamError(err, protocol.StreamTypeBidi, m.perspective) + return mm.OpenStreamSync(ctx) } func (m *streamsMap) OpenUniStream() (*SendStream, error) { @@ -151,8 +128,7 @@ func (m *streamsMap) OpenUniStream() (*SendStream, error) { if reset { return nil, Err0RTTRejected } - str, err := mm.OpenStream() - return str, convertStreamError(err, protocol.StreamTypeBidi, m.perspective) + return mm.OpenStream() } func (m *streamsMap) OpenUniStreamSync(ctx context.Context) (*SendStream, error) { @@ -163,8 +139,7 @@ func (m *streamsMap) OpenUniStreamSync(ctx context.Context) (*SendStream, error) if reset { return nil, Err0RTTRejected } - str, err := mm.OpenStreamSync(ctx) - return str, convertStreamError(err, protocol.StreamTypeUni, m.perspective) + return mm.OpenStreamSync(ctx) } func (m *streamsMap) AcceptStream(ctx context.Context) (*Stream, error) { @@ -175,8 +150,7 @@ func (m *streamsMap) AcceptStream(ctx context.Context) (*Stream, error) { if reset { return nil, Err0RTTRejected } - str, err := mm.AcceptStream(ctx) - return str, convertStreamError(err, protocol.StreamTypeBidi, m.perspective.Opposite()) + return mm.AcceptStream(ctx) } func (m *streamsMap) AcceptUniStream(ctx context.Context) (*ReceiveStream, error) { @@ -187,21 +161,19 @@ func (m *streamsMap) AcceptUniStream(ctx context.Context) (*ReceiveStream, error if reset { return nil, Err0RTTRejected } - str, err := mm.AcceptStream(ctx) - return str, convertStreamError(err, protocol.StreamTypeUni, m.perspective.Opposite()) + return mm.AcceptStream(ctx) } func (m *streamsMap) DeleteStream(id protocol.StreamID) error { - num := id.StreamNum() switch id.Type() { case protocol.StreamTypeUni: if id.InitiatedBy() == m.perspective { - return convertStreamError(m.outgoingUniStreams.DeleteStream(num), protocol.StreamTypeUni, m.perspective) + return m.outgoingUniStreams.DeleteStream(id) } return m.incomingUniStreams.DeleteStream(id) case protocol.StreamTypeBidi: if id.InitiatedBy() == m.perspective { - return convertStreamError(m.outgoingBidiStreams.DeleteStream(num), protocol.StreamTypeBidi, m.perspective) + return m.outgoingBidiStreams.DeleteStream(id) } return m.incomingBidiStreams.DeleteStream(id) } @@ -220,28 +192,26 @@ func (m *streamsMap) GetOrOpenReceiveStream(id protocol.StreamID) (*ReceiveStrea } func (m *streamsMap) getOrOpenReceiveStream(id protocol.StreamID) (*ReceiveStream, error) { - num := id.StreamNum() switch id.Type() { case protocol.StreamTypeUni: if id.InitiatedBy() == m.perspective { // an outgoing unidirectional stream is a send stream, not a receive stream return nil, fmt.Errorf("peer attempted to open receive stream %d", id) } - str, err := m.incomingUniStreams.GetOrOpenStream(id) - return str, convertStreamError(err, protocol.StreamTypeUni, m.perspective) + return m.incomingUniStreams.GetOrOpenStream(id) case protocol.StreamTypeBidi: if id.InitiatedBy() == m.perspective { - str, err := m.outgoingBidiStreams.GetStream(num) + str, err := m.outgoingBidiStreams.GetStream(id) if str == nil && err == nil { return nil, nil } - return str.ReceiveStream, convertStreamError(err, protocol.StreamTypeBidi, id.InitiatedBy()) + return str.ReceiveStream, err } else { str, err := m.incomingBidiStreams.GetOrOpenStream(id) if str == nil && err == nil { return nil, nil } - return str.ReceiveStream, convertStreamError(err, protocol.StreamTypeBidi, id.InitiatedBy()) + return str.ReceiveStream, err } } panic("") @@ -259,26 +229,25 @@ func (m *streamsMap) GetOrOpenSendStream(id protocol.StreamID) (*SendStream, err } func (m *streamsMap) getOrOpenSendStream(id protocol.StreamID) (*SendStream, error) { - num := id.StreamNum() switch id.Type() { case protocol.StreamTypeUni: if id.InitiatedBy() == m.perspective { - str, err := m.outgoingUniStreams.GetStream(num) + str, err := m.outgoingUniStreams.GetStream(id) if str == nil && err == nil { return nil, nil } - return str, convertStreamError(err, protocol.StreamTypeUni, m.perspective) + return str, err } // an incoming unidirectional stream is a receive stream, not a send stream return nil, fmt.Errorf("peer attempted to open send stream %d", id) case protocol.StreamTypeBidi: if id.InitiatedBy() == m.perspective { - str, err := m.outgoingBidiStreams.GetStream(num) + str, err := m.outgoingBidiStreams.GetStream(id) if str == nil && err == nil { return nil, nil } if err != nil { - return nil, convertStreamError(err, protocol.StreamTypeBidi, id.InitiatedBy()) + return nil, err } return str.SendStream, nil } else { @@ -287,7 +256,7 @@ func (m *streamsMap) getOrOpenSendStream(id protocol.StreamID) (*SendStream, err return nil, nil } if err != nil { - return nil, convertStreamError(err, protocol.StreamTypeBidi, id.InitiatedBy()) + return nil, err } return str.SendStream, nil } @@ -298,17 +267,17 @@ func (m *streamsMap) getOrOpenSendStream(id protocol.StreamID) (*SendStream, err func (m *streamsMap) HandleMaxStreamsFrame(f *wire.MaxStreamsFrame) { switch f.Type { case protocol.StreamTypeUni: - m.outgoingUniStreams.SetMaxStream(f.MaxStreamNum) + m.outgoingUniStreams.SetMaxStream(f.MaxStreamNum.StreamID(protocol.StreamTypeUni, m.perspective)) case protocol.StreamTypeBidi: - m.outgoingBidiStreams.SetMaxStream(f.MaxStreamNum) + m.outgoingBidiStreams.SetMaxStream(f.MaxStreamNum.StreamID(protocol.StreamTypeBidi, m.perspective)) } } func (m *streamsMap) UpdateLimits(p *wire.TransportParameters) { m.outgoingBidiStreams.UpdateSendWindow(p.InitialMaxStreamDataBidiRemote) - m.outgoingBidiStreams.SetMaxStream(p.MaxBidiStreamNum) + m.outgoingBidiStreams.SetMaxStream(p.MaxBidiStreamNum.StreamID(protocol.StreamTypeBidi, m.perspective)) m.outgoingUniStreams.UpdateSendWindow(p.InitialMaxStreamDataUni) - m.outgoingUniStreams.SetMaxStream(p.MaxUniStreamNum) + m.outgoingUniStreams.SetMaxStream(p.MaxUniStreamNum.StreamID(protocol.StreamTypeUni, m.perspective)) } func (m *streamsMap) CloseWithError(err error) { diff --git a/streams_map_outgoing.go b/streams_map_outgoing.go index d4c182592..248390e3a 100644 --- a/streams_map_outgoing.go +++ b/streams_map_outgoing.go @@ -2,6 +2,7 @@ package quic import ( "context" + "fmt" "slices" "sync" @@ -18,15 +19,15 @@ type outgoingStreamsMap[T outgoingStream] struct { mutex sync.RWMutex streamType protocol.StreamType - streams map[protocol.StreamNum]T + streams map[protocol.StreamID]T openQueue []chan struct{} - nextStream protocol.StreamNum // stream ID of the stream returned by OpenStream(Sync) - maxStream protocol.StreamNum // the maximum stream ID we're allowed to open - blockedSent bool // was a STREAMS_BLOCKED sent for the current maxStream + nextStream protocol.StreamID // stream ID of the stream returned by OpenStream(Sync) + maxStream protocol.StreamID // the maximum stream ID we're allowed to open + blockedSent bool // was a STREAMS_BLOCKED sent for the current maxStream - newStream func(protocol.StreamNum) T + newStream func(protocol.StreamID) T queueStreamIDBlocked func(*wire.StreamsBlockedFrame) closeErr error @@ -34,14 +35,26 @@ type outgoingStreamsMap[T outgoingStream] struct { func newOutgoingStreamsMap[T outgoingStream]( streamType protocol.StreamType, - newStream func(protocol.StreamNum) T, + newStream func(protocol.StreamID) T, queueControlFrame func(wire.Frame), + pers protocol.Perspective, ) *outgoingStreamsMap[T] { + var nextStream protocol.StreamID + switch { + case streamType == protocol.StreamTypeBidi && pers == protocol.PerspectiveServer: + nextStream = protocol.FirstOutgoingBidiStreamServer + case streamType == protocol.StreamTypeBidi && pers == protocol.PerspectiveClient: + nextStream = protocol.FirstOutgoingBidiStreamClient + case streamType == protocol.StreamTypeUni && pers == protocol.PerspectiveServer: + nextStream = protocol.FirstOutgoingUniStreamServer + case streamType == protocol.StreamTypeUni && pers == protocol.PerspectiveClient: + nextStream = protocol.FirstOutgoingUniStreamClient + } return &outgoingStreamsMap[T]{ streamType: streamType, - streams: make(map[protocol.StreamNum]T), + streams: make(map[protocol.StreamID]T), maxStream: protocol.InvalidStreamNum, - nextStream: 1, + nextStream: nextStream, newStream: newStream, queueStreamIDBlocked: func(f *wire.StreamsBlockedFrame) { queueControlFrame(f) }, } @@ -114,7 +127,7 @@ func (m *outgoingStreamsMap[T]) OpenStreamSync(ctx context.Context) (T, error) { func (m *outgoingStreamsMap[T]) openStream() T { s := m.newStream(m.nextStream) m.streams[m.nextStream] = s - m.nextStream++ + m.nextStream += 4 return s } @@ -125,55 +138,49 @@ func (m *outgoingStreamsMap[T]) maybeSendBlockedFrame() { return } - var streamNum protocol.StreamNum - if m.maxStream != protocol.InvalidStreamNum { - streamNum = m.maxStream + var streamLimit protocol.StreamNum + if m.maxStream != protocol.InvalidStreamID { + streamLimit = m.maxStream.StreamNum() } m.queueStreamIDBlocked(&wire.StreamsBlockedFrame{ Type: m.streamType, - StreamLimit: streamNum, + StreamLimit: streamLimit, }) m.blockedSent = true } -func (m *outgoingStreamsMap[T]) GetStream(num protocol.StreamNum) (T, error) { +func (m *outgoingStreamsMap[T]) GetStream(id protocol.StreamID) (T, error) { m.mutex.RLock() - if num >= m.nextStream { + if id >= m.nextStream { m.mutex.RUnlock() - return *new(T), streamError{ - message: "peer attempted to open stream %d", - nums: []protocol.StreamNum{num}, - } + return *new(T), fmt.Errorf("peer attempted to open stream %d", id) } - s := m.streams[num] + s := m.streams[id] m.mutex.RUnlock() return s, nil } -func (m *outgoingStreamsMap[T]) DeleteStream(num protocol.StreamNum) error { +func (m *outgoingStreamsMap[T]) DeleteStream(id protocol.StreamID) error { m.mutex.Lock() defer m.mutex.Unlock() - if _, ok := m.streams[num]; !ok { - return streamError{ - message: "tried to delete unknown outgoing stream %d", - nums: []protocol.StreamNum{num}, - } + if _, ok := m.streams[id]; !ok { + return fmt.Errorf("tried to delete unknown outgoing stream %d", id) } - delete(m.streams, num) + delete(m.streams, id) return nil } -func (m *outgoingStreamsMap[T]) SetMaxStream(num protocol.StreamNum) { +func (m *outgoingStreamsMap[T]) SetMaxStream(id protocol.StreamID) { m.mutex.Lock() defer m.mutex.Unlock() - if num <= m.maxStream { + if id <= m.maxStream { return } - m.maxStream = num + m.maxStream = id m.blockedSent = false - if m.maxStream < m.nextStream-1+protocol.StreamNum(len(m.openQueue)) { + if m.maxStream < m.nextStream-4+4*protocol.StreamID(len(m.openQueue)) { m.maybeSendBlockedFrame() } m.maybeUnblockOpenSync() diff --git a/streams_map_outgoing_test.go b/streams_map_outgoing_test.go index 6e157cd23..3009a0af5 100644 --- a/streams_map_outgoing_test.go +++ b/streams_map_outgoing_test.go @@ -15,77 +15,80 @@ import ( "github.com/stretchr/testify/require" ) -type mockGenericStream struct { - num protocol.StreamNum - - closed bool - closeErr error - sendWindow protocol.ByteCount -} - -func (s *mockGenericStream) closeForShutdown(err error) { - s.closed = true - s.closeErr = err -} - -func (s *mockGenericStream) updateSendWindow(limit protocol.ByteCount) { - s.sendWindow = limit -} - func TestStreamsMapOutgoingOpenAndDelete(t *testing.T) { + t.Run("client", func(t *testing.T) { + testStreamsMapOutgoingOpenAndDelete(t, protocol.PerspectiveClient, protocol.FirstOutgoingBidiStreamClient) + }) + t.Run("server", func(t *testing.T) { + testStreamsMapOutgoingOpenAndDelete(t, protocol.PerspectiveServer, protocol.FirstOutgoingBidiStreamServer) + }) +} + +func testStreamsMapOutgoingOpenAndDelete(t *testing.T, perspective protocol.Perspective, firstStream protocol.StreamID) { m := newOutgoingStreamsMap( protocol.StreamTypeBidi, - func(n protocol.StreamNum) *mockGenericStream { return &mockGenericStream{num: n} }, + func(id protocol.StreamID) *mockStream { return &mockStream{id: id} }, func(f wire.Frame) {}, + perspective, ) - m.SetMaxStream(protocol.MaxStreamCount) + m.SetMaxStream(protocol.MaxStreamID) - _, err := m.GetStream(1) - require.Error(t, err) - require.ErrorContains(t, err.(streamError).TestError(), "peer attempted to open stream 1") + _, err := m.GetStream(firstStream) + require.ErrorContains(t, err, fmt.Sprintf("peer attempted to open stream %d", firstStream)) str1, err := m.OpenStream() require.NoError(t, err) - require.Equal(t, protocol.StreamNum(1), str1.num) - s, err := m.GetStream(1) + require.Equal(t, firstStream, str1.id) + s, err := m.GetStream(firstStream) require.NoError(t, err) require.Equal(t, s, str1) str2, err := m.OpenStream() require.NoError(t, err) - require.Equal(t, protocol.StreamNum(2), str2.num) + require.Equal(t, firstStream+4, str2.id) // update send window m.UpdateSendWindow(1000) require.Equal(t, protocol.ByteCount(1000), str1.sendWindow) require.Equal(t, protocol.ByteCount(1000), str2.sendWindow) - err = m.DeleteStream(1337) + err = m.DeleteStream(firstStream + 1337*4) require.Error(t, err) - require.ErrorContains(t, err.(streamError).TestError(), "tried to delete unknown outgoing stream 1337") + require.ErrorContains(t, err, "tried to delete unknown outgoing stream") - require.NoError(t, m.DeleteStream(1)) + require.NoError(t, m.DeleteStream(firstStream)) // deleting the same stream twice will fail - err = m.DeleteStream(1) - require.Error(t, err) - require.ErrorContains(t, err.(streamError).TestError(), "tried to delete unknown outgoing stream 1") + require.ErrorContains(t, + m.DeleteStream(firstStream), + "tried to delete unknown outgoing stream", + ) // after deleting the stream it's not available anymore - str, err := m.GetStream(1) + str, err := m.GetStream(firstStream) require.NoError(t, err) require.Nil(t, str) } func TestStreamsMapOutgoingLimits(t *testing.T) { + t.Run("client", func(t *testing.T) { + testStreamsMapOutgoingLimits(t, protocol.PerspectiveClient, protocol.FirstOutgoingUniStreamClient) + }) + t.Run("server", func(t *testing.T) { + testStreamsMapOutgoingLimits(t, protocol.PerspectiveServer, protocol.FirstOutgoingUniStreamServer) + }) +} + +func testStreamsMapOutgoingLimits(t *testing.T, perspective protocol.Perspective, firstStream protocol.StreamID) { m := newOutgoingStreamsMap( - protocol.StreamTypeBidi, - func(n protocol.StreamNum) *mockGenericStream { return &mockGenericStream{num: n} }, + protocol.StreamTypeUni, + func(id protocol.StreamID) *mockStream { return &mockStream{id: id} }, func(f wire.Frame) {}, + perspective, ) - m.SetMaxStream(1) + m.SetMaxStream(firstStream) str, err := m.OpenStream() require.NoError(t, err) - require.Equal(t, protocol.StreamNum(1), str.num) + require.Equal(t, firstStream, str.id) // We've now reached the limit. OpenStream returns an error _, err = m.OpenStream() @@ -124,17 +127,17 @@ func TestStreamsMapOutgoingLimits(t *testing.T) { } // ... or until it's possible to open a new stream - var openedStream *mockGenericStream + var openedStream *mockStream go func() { str, err := m.OpenStreamSync(context.Background()) openedStream = str errChan <- err }() - m.SetMaxStream(2) + m.SetMaxStream(firstStream + 4) select { case err := <-errChan: require.NoError(t, err) - require.Equal(t, protocol.StreamNum(2), openedStream.num) + require.Equal(t, firstStream+4, openedStream.id) case <-time.After(time.Second): t.Fatal("OpenStreamSync did not return after the stream limit was increased") } @@ -142,40 +145,40 @@ func TestStreamsMapOutgoingLimits(t *testing.T) { func TestStreamsMapOutgoingConcurrentOpenStreamSync(t *testing.T) { m := newOutgoingStreamsMap( - protocol.StreamTypeBidi, - func(n protocol.StreamNum) *mockGenericStream { return &mockGenericStream{num: n} }, + protocol.StreamTypeUni, + func(id protocol.StreamID) *mockStream { return &mockStream{id: id} }, func(f wire.Frame) {}, + protocol.PerspectiveClient, ) type result struct { index int - stream *mockGenericStream + stream *mockStream err error } - results := make(chan result, 3) for i := range 3 { go func(i int) { str, err := m.OpenStreamSync(context.Background()) - results <- result{index: i + 1, stream: str, err: err} + results <- result{index: i, stream: str, err: err} }(i) time.Sleep(scaleDuration(10 * time.Millisecond)) } - m.SetMaxStream(2) - received := make(map[protocol.StreamNum]struct{}) + m.SetMaxStream(protocol.FirstOutgoingUniStreamClient + 4) + received := make(map[protocol.StreamID]struct{}) for range 2 { select { case res := <-results: require.NoError(t, res.err) - require.Equal(t, res.index, int(res.stream.num)) - received[res.stream.num] = struct{}{} + require.Equal(t, protocol.FirstOutgoingUniStreamClient+4*protocol.StreamID(res.index), res.stream.id) + received[res.stream.id] = struct{}{} case <-time.After(time.Second): t.Fatal("OpenStreamSync did not return after the stream limit was increased") } } - require.Contains(t, received, protocol.StreamNum(1)) - require.Contains(t, received, protocol.StreamNum(2)) + require.Contains(t, received, protocol.FirstOutgoingUniStreamClient) + require.Contains(t, received, protocol.FirstOutgoingUniStreamClient+4) // the call to stream 3 is still blocked select { @@ -183,11 +186,11 @@ func TestStreamsMapOutgoingConcurrentOpenStreamSync(t *testing.T) { t.Fatal("expected OpenStreamSync to be blocked") case <-time.After(scaleDuration(10 * time.Millisecond)): } - m.SetMaxStream(3) + m.SetMaxStream(protocol.FirstOutgoingUniStreamClient + 8) select { case res := <-results: require.NoError(t, res.err) - require.Equal(t, protocol.StreamNum(3), res.stream.num) + require.Equal(t, protocol.FirstOutgoingUniStreamClient+8, res.stream.id) case <-time.After(time.Second): t.Fatal("OpenStreamSync did not return after the stream limit was increased") } @@ -195,12 +198,13 @@ func TestStreamsMapOutgoingConcurrentOpenStreamSync(t *testing.T) { func TestStreamsMapOutgoingClosing(t *testing.T) { m := newOutgoingStreamsMap( - protocol.StreamTypeBidi, - func(n protocol.StreamNum) *mockGenericStream { return &mockGenericStream{num: n} }, + protocol.StreamTypeUni, + func(id protocol.StreamID) *mockStream { return &mockStream{id: id} }, func(f wire.Frame) {}, + protocol.PerspectiveServer, ) - m.SetMaxStream(2) + m.SetMaxStream(protocol.FirstOutgoingUniStreamServer + 4) str1, err := m.OpenStream() require.NoError(t, err) str2, err := m.OpenStream() @@ -231,11 +235,12 @@ func TestStreamsMapOutgoingBlockedFrames(t *testing.T) { var frameQueue []wire.Frame m := newOutgoingStreamsMap( protocol.StreamTypeBidi, - func(n protocol.StreamNum) *mockGenericStream { return &mockGenericStream{num: n} }, + func(id protocol.StreamID) *mockStream { return &mockStream{id: id} }, func(f wire.Frame) { frameQueue = append(frameQueue, f) }, + protocol.PerspectiveClient, ) - m.SetMaxStream(3) + m.SetMaxStream(protocol.FirstOutgoingBidiStreamClient + 8) for range 3 { _, err := m.OpenStream() require.NoError(t, err) @@ -250,9 +255,11 @@ func TestStreamsMapOutgoingBlockedFrames(t *testing.T) { frameQueue = frameQueue[:0] // only a single STREAMS_BLOCKED frame is queued per offset - _, err = m.OpenStream() - require.ErrorIs(t, err, &StreamLimitReachedError{}) - require.Empty(t, frameQueue) + for range 5 { + _, err = m.OpenStream() + require.ErrorIs(t, err, &StreamLimitReachedError{}) + require.Empty(t, frameQueue) + } errChan := make(chan error, 3) for range 3 { @@ -264,7 +271,7 @@ func TestStreamsMapOutgoingBlockedFrames(t *testing.T) { time.Sleep(scaleDuration(10 * time.Millisecond)) // allow 2 more streams - m.SetMaxStream(5) + m.SetMaxStream(protocol.FirstOutgoingBidiStreamClient + 16) for range 2 { select { case err := <-errChan: @@ -279,7 +286,7 @@ func TestStreamsMapOutgoingBlockedFrames(t *testing.T) { frameQueue = frameQueue[:0] // now accept the last stream - m.SetMaxStream(6) + m.SetMaxStream(protocol.FirstOutgoingBidiStreamClient + 20) select { case err := <-errChan: require.NoError(t, err) @@ -290,25 +297,31 @@ func TestStreamsMapOutgoingBlockedFrames(t *testing.T) { } func TestStreamsMapOutgoingRandomizedOpenStreamSync(t *testing.T) { + streamType := []protocol.StreamType{protocol.StreamTypeUni, protocol.StreamTypeBidi}[rand.IntN(2)] + firstStream := protocol.FirstOutgoingUniStreamServer + if streamType == protocol.StreamTypeBidi { + firstStream = protocol.FirstOutgoingBidiStreamServer + } + const n = 100 frameQueue := make(chan wire.Frame, n) m := newOutgoingStreamsMap( - protocol.StreamTypeBidi, - func(n protocol.StreamNum) *mockGenericStream { return &mockGenericStream{num: n} }, + streamType, + func(id protocol.StreamID) *mockStream { return &mockStream{id: id} }, func(f wire.Frame) { frameQueue <- f }, + protocol.PerspectiveServer, ) type result struct { - num protocol.StreamNum + id protocol.StreamID err error } - resultChan := make(chan result, n) for range n { go func() { str, err := m.OpenStreamSync(context.Background()) - resultChan <- result{num: str.num, err: err} + resultChan <- result{id: str.id, err: err} }() } @@ -320,44 +333,50 @@ func TestStreamsMapOutgoingRandomizedOpenStreamSync(t *testing.T) { t.Fatal("timed out waiting for STREAMS_BLOCKED frame") } - var limit int - var limits []protocol.StreamNum - seen := make(map[protocol.StreamNum]struct{}) - for limit < n { - add := rand.IntN(n/5) + 1 + limit := firstStream - 4 + var limits []protocol.StreamID + seen := make(map[protocol.StreamID]struct{}) + maxStream := firstStream + 4*(n-1) + for limit < maxStream { + add := 4 * protocol.StreamID(rand.IntN(n/5)+1) limit += add - if limit <= n { - limits = append(limits, protocol.StreamNum(limit)) + if limit <= maxStream { + limits = append(limits, limit) } t.Logf("setting stream limit to %d", limit) - m.SetMaxStream(protocol.StreamNum(limit)) + m.SetMaxStream(limit) - for range min(add, n-(limit-add)) { + loop: + for { select { case res := <-resultChan: require.NoError(t, res.err) - require.NotContains(t, seen, res.num) - seen[res.num] = struct{}{} + require.NotContains(t, seen, res.id) + require.LessOrEqual(t, res.id, limit) + seen[res.id] = struct{}{} + if len(seen) == int(limit.StreamNum()) || len(seen) == n { + break loop + } case <-time.After(time.Second): t.Fatalf("timed out waiting for stream to open") } } str, err := m.OpenStream() - if limit <= n { + if limit <= maxStream { require.ErrorIs(t, err, &StreamLimitReachedError{}) } else { require.NoError(t, err) - require.Equal(t, protocol.StreamNum(n+1), str.num) + require.Equal(t, maxStream+4, str.id) } } require.Len(t, seen, n) - var blockedAt []protocol.StreamNum close(frameQueue) + var blockedAt []protocol.StreamID for f := range frameQueue { if l := f.(*wire.StreamsBlockedFrame).StreamLimit; l <= n { - blockedAt = append(blockedAt, l) + blockedAt = append(blockedAt, l.StreamID(streamType, protocol.PerspectiveServer)) } } require.Equal(t, limits, blockedAt) @@ -366,15 +385,22 @@ func TestStreamsMapOutgoingRandomizedOpenStreamSync(t *testing.T) { func TestStreamsMapOutgoingRandomizedWithCancellation(t *testing.T) { const n = 100 - var frameQueue []wire.Frame + streamType := []protocol.StreamType{protocol.StreamTypeUni, protocol.StreamTypeBidi}[rand.IntN(2)] + firstStream := protocol.FirstOutgoingUniStreamClient + if streamType == protocol.StreamTypeBidi { + firstStream = protocol.FirstOutgoingBidiStreamClient + } + + frameQueue := make(chan wire.Frame, n) m := newOutgoingStreamsMap( - protocol.StreamTypeBidi, - func(n protocol.StreamNum) *mockGenericStream { return &mockGenericStream{num: n} }, - func(f wire.Frame) { frameQueue = append(frameQueue, f) }, + streamType, + func(id protocol.StreamID) *mockStream { return &mockStream{id: id} }, + func(f wire.Frame) { frameQueue <- f }, + protocol.PerspectiveClient, ) type result struct { - str *mockGenericStream + str *mockStream err error } @@ -390,7 +416,7 @@ func TestStreamsMapOutgoingRandomizedWithCancellation(t *testing.T) { count++ } go func() { - var str *mockGenericStream + var str *mockStream var err error if shouldCancel { str, err = m.OpenStreamSync(ctx) @@ -401,32 +427,42 @@ func TestStreamsMapOutgoingRandomizedWithCancellation(t *testing.T) { }() } + select { + case f := <-frameQueue: + require.IsType(t, &wire.StreamsBlockedFrame{}, f) + require.Zero(t, f.(*wire.StreamsBlockedFrame).StreamLimit) + case <-time.After(time.Second): + t.Fatal("timed out waiting for STREAMS_BLOCKED frame") + } + time.Sleep(scaleDuration(10 * time.Millisecond)) cancel() - var limit int - limits := []protocol.StreamNum{0} - seen := make(map[protocol.StreamNum]struct{}) - var lastStreamSeen protocol.StreamNum + limit := firstStream - 4 + maxStream := firstStream + 4*(n-1) + var limits []protocol.StreamID + seen := make(map[protocol.StreamID]struct{}) + var lastStreamSeen protocol.StreamID var numCancelledSeen int - for limit < n { - limit += rand.IntN(n/5) + 1 - if limit < n { - limits = append(limits, protocol.StreamNum(limit)) + for limit < maxStream { + add := 4 * protocol.StreamID(rand.IntN(n/5)+1) + limit += add + if limit < maxStream { + limits = append(limits, limit) } t.Logf("setting stream limit to %d", limit) - m.SetMaxStream(protocol.StreamNum(limit)) + m.SetMaxStream(limit) - for lastStreamSeen < min(n, protocol.StreamNum(limit)) { + for lastStreamSeen < min(maxStream, limit) { select { case res := <-resultChan: if errors.Is(res.err, context.Canceled) { numCancelledSeen++ } else { require.NoError(t, res.err) - require.NotContains(t, seen, res.str.num) - seen[res.str.num] = struct{}{} - lastStreamSeen = res.str.num + require.NotContains(t, seen, res.str.id) + seen[res.str.id] = struct{}{} + lastStreamSeen = res.str.id } case <-time.After(time.Second): t.Fatalf("timed out waiting for stream to open") @@ -436,9 +472,12 @@ func TestStreamsMapOutgoingRandomizedWithCancellation(t *testing.T) { require.Len(t, seen, n) require.Equal(t, numCancelled, numCancelledSeen) - var blockedAt []protocol.StreamNum - for _, f := range frameQueue { - blockedAt = append(blockedAt, f.(*wire.StreamsBlockedFrame).StreamLimit) + close(frameQueue) + var blockedAt []protocol.StreamID + for f := range frameQueue { + sbf := f.(*wire.StreamsBlockedFrame) + require.Equal(t, streamType, sbf.Type) + blockedAt = append(blockedAt, sbf.StreamLimit.StreamID(streamType, protocol.PerspectiveClient)) } require.Equal(t, limits, blockedAt) } @@ -454,8 +493,9 @@ func TestStreamsMapConcurrent(t *testing.T) { func testStreamsMapConcurrent(t *testing.T) { m := newOutgoingStreamsMap( protocol.StreamTypeBidi, - func(n protocol.StreamNum) *mockGenericStream { return &mockGenericStream{num: n} }, + func(id protocol.StreamID) *mockStream { return &mockStream{id: id} }, func(f wire.Frame) {}, + protocol.PerspectiveClient, ) const num = 100 @@ -473,7 +513,7 @@ func testStreamsMapConcurrent(t *testing.T) { time.Sleep(scaleDuration(5 * time.Millisecond)) go m.CloseWithError(assert.AnError) go cancel() - go m.SetMaxStream(protocol.StreamNum(num / 2)) + go m.SetMaxStream(protocol.FirstOutgoingBidiStreamClient + 4*num/2) for range num { select { diff --git a/streams_map_test.go b/streams_map_test.go index 8298b0ee9..80f58dee0 100644 --- a/streams_map_test.go +++ b/streams_map_test.go @@ -17,14 +17,6 @@ import ( "go.uber.org/mock/gomock" ) -func (e streamError) TestError() error { - nums := make([]interface{}, len(e.nums)) - for i, num := range e.nums { - nums[i] = num - } - return fmt.Errorf(e.message, nums...) -} - func TestStreamsMapCreatingStreams(t *testing.T) { t.Run("client", func(t *testing.T) { testStreamsMapCreatingAndDeletingStreams(t, protocol.PerspectiveClient,