From 576aa1747f01c2805539dcf4045769458ef1140b Mon Sep 17 00:00:00 2001 From: Marten Seemann Date: Sun, 8 Jun 2025 18:25:36 +0800 Subject: [PATCH] simplify stream ID handling in the incoming streams map (#5207) * simplify stream ID handling in the incoming streams map No functional change expected. * protocol: move stream ID constant definition --- internal/protocol/stream.go | 26 +++++++ streams_map.go | 18 ++--- streams_map_incoming.go | 98 ++++++++++++----------- streams_map_incoming_test.go | 147 ++++++++++++++++++++++------------- streams_map_outgoing_test.go | 17 ++++ streams_map_test.go | 46 ++++------- 6 files changed, 214 insertions(+), 138 deletions(-) diff --git a/internal/protocol/stream.go b/internal/protocol/stream.go index ad7de864b..6db4a95b9 100644 --- a/internal/protocol/stream.go +++ b/internal/protocol/stream.go @@ -1,5 +1,7 @@ package protocol +import "github.com/quic-go/quic-go/quicvarint" + // StreamType encodes if this is a unidirectional or bidirectional stream type StreamType uint8 @@ -23,6 +25,30 @@ const ( // MaxStreamCount is the maximum stream count value that can be sent in MAX_STREAMS frames // and as the stream count in the transport parameters MaxStreamCount StreamNum = 1 << 60 + // MaxStreamID is the maximum stream ID + MaxStreamID StreamID = quicvarint.Max +) + +const ( + // FirstOutgoingBidiStreamClient is the first bidirectional stream opened by the client + FirstOutgoingBidiStreamClient StreamID = 0 + // FirstOutgoingUniStreamClient is the first unidirectional stream opened by the client + FirstOutgoingUniStreamClient StreamID = 2 + // FirstOutgoingBidiStreamServer is the first bidirectional stream opened by the server + FirstOutgoingBidiStreamServer StreamID = 1 + // FirstOutgoingUniStreamServer is the first unidirectional stream opened by the server + FirstOutgoingUniStreamServer StreamID = 3 +) + +const ( + // FirstIncomingBidiStreamServer is the first bidirectional stream accepted by the server + FirstIncomingBidiStreamServer = FirstOutgoingBidiStreamClient + // FirstIncomingUniStreamServer is the first unidirectional stream accepted by the server + FirstIncomingUniStreamServer = FirstOutgoingUniStreamClient + // FirstIncomingBidiStreamClient is the first bidirectional stream accepted by the client + FirstIncomingBidiStreamClient = FirstOutgoingBidiStreamServer + // FirstIncomingUniStreamClient is the first unidirectional stream accepted by the client + FirstIncomingUniStreamClient = FirstOutgoingUniStreamServer ) // StreamID calculates the stream ID. diff --git a/streams_map.go b/streams_map.go index dfc0e901a..53bf7f3ff 100644 --- a/streams_map.go +++ b/streams_map.go @@ -93,12 +93,12 @@ func (m *streamsMap) initMaps() { ) m.incomingBidiStreams = newIncomingStreamsMap( protocol.StreamTypeBidi, - func(num protocol.StreamNum) *Stream { - id := num.StreamID(protocol.StreamTypeBidi, m.perspective.Opposite()) + func(id protocol.StreamID) *Stream { return newStream(m.ctx, id, m.sender, m.newFlowController(id)) }, m.maxIncomingBidiStreams, m.queueControlFrame, + m.perspective, ) m.outgoingUniStreams = newOutgoingStreamsMap( protocol.StreamTypeUni, @@ -110,12 +110,12 @@ func (m *streamsMap) initMaps() { ) m.incomingUniStreams = newIncomingStreamsMap( protocol.StreamTypeUni, - func(num protocol.StreamNum) *ReceiveStream { - id := num.StreamID(protocol.StreamTypeUni, m.perspective.Opposite()) + func(id protocol.StreamID) *ReceiveStream { return newReceiveStream(id, m.sender, m.newFlowController(id)) }, m.maxIncomingUniStreams, m.queueControlFrame, + m.perspective, ) } @@ -198,12 +198,12 @@ func (m *streamsMap) DeleteStream(id protocol.StreamID) error { if id.InitiatedBy() == m.perspective { return convertStreamError(m.outgoingUniStreams.DeleteStream(num), protocol.StreamTypeUni, m.perspective) } - return convertStreamError(m.incomingUniStreams.DeleteStream(num), protocol.StreamTypeUni, m.perspective.Opposite()) + return m.incomingUniStreams.DeleteStream(id) case protocol.StreamTypeBidi: if id.InitiatedBy() == m.perspective { return convertStreamError(m.outgoingBidiStreams.DeleteStream(num), protocol.StreamTypeBidi, m.perspective) } - return convertStreamError(m.incomingBidiStreams.DeleteStream(num), protocol.StreamTypeBidi, m.perspective.Opposite()) + return m.incomingBidiStreams.DeleteStream(id) } panic("") } @@ -227,7 +227,7 @@ func (m *streamsMap) getOrOpenReceiveStream(id protocol.StreamID) (*ReceiveStrea // an outgoing unidirectional stream is a send stream, not a receive stream return nil, fmt.Errorf("peer attempted to open receive stream %d", id) } - str, err := m.incomingUniStreams.GetOrOpenStream(num) + str, err := m.incomingUniStreams.GetOrOpenStream(id) return str, convertStreamError(err, protocol.StreamTypeUni, m.perspective) case protocol.StreamTypeBidi: if id.InitiatedBy() == m.perspective { @@ -237,7 +237,7 @@ func (m *streamsMap) getOrOpenReceiveStream(id protocol.StreamID) (*ReceiveStrea } return str.ReceiveStream, convertStreamError(err, protocol.StreamTypeBidi, id.InitiatedBy()) } else { - str, err := m.incomingBidiStreams.GetOrOpenStream(num) + str, err := m.incomingBidiStreams.GetOrOpenStream(id) if str == nil && err == nil { return nil, nil } @@ -282,7 +282,7 @@ func (m *streamsMap) getOrOpenSendStream(id protocol.StreamID) (*SendStream, err } return str.SendStream, nil } else { - str, err := m.incomingBidiStreams.GetOrOpenStream(num) + str, err := m.incomingBidiStreams.GetOrOpenStream(id) if str == nil && err == nil { return nil, nil } diff --git a/streams_map_incoming.go b/streams_map_incoming.go index 18ec6f998..28637798e 100644 --- a/streams_map_incoming.go +++ b/streams_map_incoming.go @@ -2,6 +2,7 @@ package quic import ( "context" + "fmt" "sync" "github.com/quic-go/quic-go/internal/protocol" @@ -24,14 +25,14 @@ type incomingStreamsMap[T incomingStream] struct { newStreamChan chan struct{} streamType protocol.StreamType - streams map[protocol.StreamNum]incomingStreamEntry[T] + streams map[protocol.StreamID]incomingStreamEntry[T] - nextStreamToAccept protocol.StreamNum // the next stream that will be returned by AcceptStream() - nextStreamToOpen protocol.StreamNum // the highest stream that the peer opened - maxStream protocol.StreamNum // the highest stream that the peer is allowed to open - maxNumStreams uint64 // maximum number of streams + nextStreamToAccept protocol.StreamID // the next stream that will be returned by AcceptStream() + nextStreamToOpen protocol.StreamID // the highest stream that the peer opened + maxStream protocol.StreamID // the highest stream that the peer is allowed to open + maxNumStreams uint64 // maximum number of streams - newStream func(protocol.StreamNum) T + newStream func(protocol.StreamID) T queueMaxStreamID func(*wire.MaxStreamsFrame) closeErr error @@ -39,19 +40,31 @@ type incomingStreamsMap[T incomingStream] struct { func newIncomingStreamsMap[T incomingStream]( streamType protocol.StreamType, - newStream func(protocol.StreamNum) T, + newStream func(protocol.StreamID) T, maxStreams uint64, queueControlFrame func(wire.Frame), + pers protocol.Perspective, ) *incomingStreamsMap[T] { + var nextStreamToAccept protocol.StreamID + switch { + case streamType == protocol.StreamTypeBidi && pers == protocol.PerspectiveServer: + nextStreamToAccept = protocol.FirstIncomingBidiStreamServer + case streamType == protocol.StreamTypeBidi && pers == protocol.PerspectiveClient: + nextStreamToAccept = protocol.FirstIncomingBidiStreamClient + case streamType == protocol.StreamTypeUni && pers == protocol.PerspectiveServer: + nextStreamToAccept = protocol.FirstIncomingUniStreamServer + case streamType == protocol.StreamTypeUni && pers == protocol.PerspectiveClient: + nextStreamToAccept = protocol.FirstIncomingUniStreamClient + } return &incomingStreamsMap[T]{ newStreamChan: make(chan struct{}, 1), streamType: streamType, - streams: make(map[protocol.StreamNum]incomingStreamEntry[T]), - maxStream: protocol.StreamNum(maxStreams), + streams: make(map[protocol.StreamID]incomingStreamEntry[T]), + maxStream: protocol.StreamNum(maxStreams).StreamID(streamType, pers.Opposite()), maxNumStreams: maxStreams, newStream: newStream, - nextStreamToOpen: 1, - nextStreamToAccept: 1, + nextStreamToOpen: nextStreamToAccept, + nextStreamToAccept: nextStreamToAccept, queueMaxStreamID: func(f *wire.MaxStreamsFrame) { queueControlFrame(f) }, } } @@ -65,16 +78,16 @@ func (m *incomingStreamsMap[T]) AcceptStream(ctx context.Context) (T, error) { m.mutex.Lock() - var num protocol.StreamNum + var id protocol.StreamID var entry incomingStreamEntry[T] for { - num = m.nextStreamToAccept + id = m.nextStreamToAccept if m.closeErr != nil { m.mutex.Unlock() return *new(T), m.closeErr } var ok bool - entry, ok = m.streams[num] + entry, ok = m.streams[id] if ok { break } @@ -86,10 +99,10 @@ func (m *incomingStreamsMap[T]) AcceptStream(ctx context.Context) (T, error) { } m.mutex.Lock() } - m.nextStreamToAccept++ + m.nextStreamToAccept += 4 // If this stream was completed before being accepted, we can delete it now. if entry.shouldDelete { - if err := m.deleteStream(num); err != nil { + if err := m.deleteStream(id); err != nil { m.mutex.Unlock() return *new(T), err } @@ -98,22 +111,19 @@ func (m *incomingStreamsMap[T]) AcceptStream(ctx context.Context) (T, error) { return entry.stream, nil } -func (m *incomingStreamsMap[T]) GetOrOpenStream(num protocol.StreamNum) (T, error) { +func (m *incomingStreamsMap[T]) GetOrOpenStream(id protocol.StreamID) (T, error) { m.mutex.RLock() - if num > m.maxStream { + if id > m.maxStream { m.mutex.RUnlock() - return *new(T), streamError{ - message: "peer tried to open stream %d (current limit: %d)", - nums: []protocol.StreamNum{num, m.maxStream}, - } + return *new(T), fmt.Errorf("peer tried to open stream %d (current limit: %d)", id, m.maxStream) } // if the num is smaller than the highest we accepted // * this stream exists in the map, and we can return it, or // * this stream was already closed, then we can return the nil - if num < m.nextStreamToOpen { + if id < m.nextStreamToOpen { var s T // If the stream was already queued for deletion, and is just waiting to be accepted, don't return it. - if entry, ok := m.streams[num]; ok && !entry.shouldDelete { + if entry, ok := m.streams[id]; ok && !entry.shouldDelete { s = entry.stream } m.mutex.RUnlock() @@ -125,59 +135,53 @@ func (m *incomingStreamsMap[T]) GetOrOpenStream(num protocol.StreamNum) (T, erro // no need to check the two error conditions from above again // * maxStream can only increase, so if the id was valid before, it definitely is valid now // * highestStream is only modified by this function - for newNum := m.nextStreamToOpen; newNum <= num; newNum++ { + for newNum := m.nextStreamToOpen; newNum <= id; newNum += 4 { m.streams[newNum] = incomingStreamEntry[T]{stream: m.newStream(newNum)} select { case m.newStreamChan <- struct{}{}: default: } } - m.nextStreamToOpen = num + 1 - entry := m.streams[num] + m.nextStreamToOpen = id + 4 + entry := m.streams[id] m.mutex.Unlock() return entry.stream, nil } -func (m *incomingStreamsMap[T]) DeleteStream(num protocol.StreamNum) error { +func (m *incomingStreamsMap[T]) DeleteStream(id protocol.StreamID) error { m.mutex.Lock() defer m.mutex.Unlock() - return m.deleteStream(num) + return m.deleteStream(id) } -func (m *incomingStreamsMap[T]) deleteStream(num protocol.StreamNum) error { - if _, ok := m.streams[num]; !ok { - return streamError{ - message: "tried to delete unknown incoming stream %d", - nums: []protocol.StreamNum{num}, - } +func (m *incomingStreamsMap[T]) deleteStream(id protocol.StreamID) error { + if _, ok := m.streams[id]; !ok { + return fmt.Errorf("tried to delete unknown incoming stream %d", id) } // Don't delete this stream yet, if it was not yet accepted. // Just save it to streamsToDelete map, to make sure it is deleted as soon as it gets accepted. - if num >= m.nextStreamToAccept { - entry, ok := m.streams[num] + if id >= m.nextStreamToAccept { + entry, ok := m.streams[id] if ok && entry.shouldDelete { - return streamError{ - message: "tried to delete incoming stream %d multiple times", - nums: []protocol.StreamNum{num}, - } + return fmt.Errorf("tried to delete incoming stream %d multiple times", id) } entry.shouldDelete = true - m.streams[num] = entry // can't assign to struct in map, so we need to reassign + m.streams[id] = entry // can't assign to struct in map, so we need to reassign return nil } - delete(m.streams, num) + delete(m.streams, id) // queue a MAX_STREAM_ID frame, giving the peer the option to open a new stream if m.maxNumStreams > uint64(len(m.streams)) { - maxStream := m.nextStreamToOpen + protocol.StreamNum(m.maxNumStreams-uint64(len(m.streams))) - 1 - // Never send a value larger than protocol.MaxStreamCount. - if maxStream <= protocol.MaxStreamCount { + maxStream := m.nextStreamToOpen + 4*protocol.StreamID(m.maxNumStreams-uint64(len(m.streams))-1) + // never send a value larger than the maximum value for a stream number + if maxStream <= protocol.MaxStreamID { m.maxStream = maxStream m.queueMaxStreamID(&wire.MaxStreamsFrame{ Type: m.streamType, - MaxStreamNum: m.maxStream, + MaxStreamNum: m.maxStream.StreamNum(), }) } } diff --git a/streams_map_incoming_test.go b/streams_map_incoming_test.go index 5ec2a5230..c62e9ca64 100644 --- a/streams_map_incoming_test.go +++ b/streams_map_incoming_test.go @@ -13,77 +13,88 @@ import ( "github.com/stretchr/testify/require" ) -type mockGenericStream struct { - num protocol.StreamNum +type mockStream struct { + id protocol.StreamID closed bool closeErr error sendWindow protocol.ByteCount } -func (s *mockGenericStream) closeForShutdown(err error) { +func (s *mockStream) closeForShutdown(err error) { s.closed = true s.closeErr = err } -func (s *mockGenericStream) updateSendWindow(limit protocol.ByteCount) { +func (s *mockStream) updateSendWindow(limit protocol.ByteCount) { s.sendWindow = limit } func TestStreamsMapIncomingGettingStreams(t *testing.T) { - var newItemCounter int + t.Run("client", func(t *testing.T) { + testStreamsMapIncomingGettingStreams(t, protocol.PerspectiveClient, protocol.FirstIncomingUniStreamClient) + }) + t.Run("server", func(t *testing.T) { + testStreamsMapIncomingGettingStreams(t, protocol.PerspectiveServer, protocol.FirstIncomingUniStreamServer) + }) +} + +func testStreamsMapIncomingGettingStreams(t *testing.T, perspective protocol.Perspective, firstStream protocol.StreamID) { + var newStreamCounter int const maxNumStreams = 10 m := newIncomingStreamsMap( protocol.StreamTypeUni, - func(num protocol.StreamNum) *mockGenericStream { - newItemCounter++ - return &mockGenericStream{num: num} + func(id protocol.StreamID) *mockStream { + newStreamCounter++ + return &mockStream{id: id} }, maxNumStreams, func(f wire.Frame) {}, + perspective, ) // all streams up to the id on GetOrOpenStream are opened - str, err := m.GetOrOpenStream(2) + str, err := m.GetOrOpenStream(firstStream + 4) require.NoError(t, err) - require.Equal(t, 2, newItemCounter) - require.Equal(t, protocol.StreamNum(2), str.num) + require.NotNil(t, str) + require.Equal(t, 2, newStreamCounter) + require.Equal(t, firstStream+4, str.id) // accept one of the streams ctx, cancel := context.WithTimeout(context.Background(), time.Second) defer cancel() str, err = m.AcceptStream(ctx) require.NoError(t, err) - require.Equal(t, protocol.StreamNum(1), str.num) + require.Equal(t, firstStream, str.id) // open some more streams - str, err = m.GetOrOpenStream(5) + str, err = m.GetOrOpenStream(firstStream + 16) require.NoError(t, err) - require.Equal(t, 5, newItemCounter) - require.Equal(t, protocol.StreamNum(5), str.num) + require.Equal(t, 5, newStreamCounter) + require.Equal(t, firstStream+16, str.id) // and accept all of them - for i := 2; i <= 5; i++ { + for i := 1; i < 5; i++ { str, err := m.AcceptStream(ctx) require.NoError(t, err) - require.Equal(t, protocol.StreamNum(i), str.num) + require.Equal(t, firstStream+4*protocol.StreamID(i), str.id) } - _, err = m.GetOrOpenStream(maxNumStreams) + _, err = m.GetOrOpenStream(firstStream + 4*maxNumStreams - 4) require.NoError(t, err) - _, err = m.GetOrOpenStream(maxNumStreams + 1) - require.Error(t, err) + _, err = m.GetOrOpenStream(firstStream + 4*maxNumStreams) require.ErrorContains(t, err, "peer tried to open stream") + require.Equal(t, maxNumStreams, newStreamCounter) } func TestStreamsMapIncomingAcceptingStreams(t *testing.T) { m := newIncomingStreamsMap( protocol.StreamTypeUni, - func(num protocol.StreamNum) *mockGenericStream { return &mockGenericStream{num: num} }, + func(id protocol.StreamID) *mockStream { return &mockStream{id: id} }, 5, func(f wire.Frame) {}, + protocol.PerspectiveClient, ) - errChan := make(chan error, 1) - // AcceptStream should respect the context + errChan := make(chan error, 1) go func() { ctx, cancel := context.WithTimeout(context.Background(), scaleDuration(10*time.Millisecond)) defer cancel() @@ -110,7 +121,7 @@ func TestStreamsMapIncomingAcceptingStreams(t *testing.T) { case <-time.After(scaleDuration(10 * time.Millisecond)): } - _, err := m.GetOrOpenStream(1) + _, err := m.GetOrOpenStream(protocol.FirstIncomingUniStreamClient) require.NoError(t, err) select { @@ -122,43 +133,52 @@ func TestStreamsMapIncomingAcceptingStreams(t *testing.T) { } func TestStreamsMapIncomingDeletingStreams(t *testing.T) { + t.Run("client", func(t *testing.T) { + testStreamsMapIncomingDeletingStreams(t, protocol.PerspectiveClient, protocol.FirstIncomingUniStreamClient) + }) + t.Run("server", func(t *testing.T) { + testStreamsMapIncomingDeletingStreams(t, protocol.PerspectiveServer, protocol.FirstIncomingUniStreamServer) + }) +} + +func testStreamsMapIncomingDeletingStreams(t *testing.T, perspective protocol.Perspective, firstStream protocol.StreamID) { var frameQueue []wire.Frame m := newIncomingStreamsMap( protocol.StreamTypeUni, - func(num protocol.StreamNum) *mockGenericStream { return &mockGenericStream{num: num} }, + func(id protocol.StreamID) *mockStream { return &mockStream{id: id} }, 5, func(f wire.Frame) { frameQueue = append(frameQueue, f) }, + perspective, ) - err := m.DeleteStream(1337) - require.Error(t, err) - require.ErrorContains(t, err.(streamError).TestError(), "tried to delete unknown incoming stream 1337") + err := m.DeleteStream(firstStream + 1337*4) + require.ErrorContains(t, err, "tried to delete unknown incoming stream") - s, err := m.GetOrOpenStream(2) + s, err := m.GetOrOpenStream(firstStream + 4) require.NoError(t, err) require.NotNil(t, s) // delete the stream - require.NoError(t, m.DeleteStream(2)) + require.NoError(t, m.DeleteStream(firstStream+4)) require.Empty(t, frameQueue) // it's not returned by GetOrOpenStream anymore - s, err = m.GetOrOpenStream(2) + s, err = m.GetOrOpenStream(firstStream + 4) require.NoError(t, err) require.Nil(t, s) // AcceptStream still returns this stream str, err := m.AcceptStream(context.Background()) require.NoError(t, err) - require.Equal(t, protocol.StreamNum(1), str.num) + require.Equal(t, firstStream, str.id) require.Empty(t, frameQueue) str, err = m.AcceptStream(context.Background()) require.NoError(t, err) - require.Equal(t, protocol.StreamNum(2), str.num) + require.Equal(t, firstStream+4, str.id) // now the stream is deleted and new stream credit is issued require.Len(t, frameQueue, 1) require.Equal(t, &wire.MaxStreamsFrame{Type: protocol.StreamTypeUni, MaxStreamNum: 6}, frameQueue[0]) frameQueue = frameQueue[:0] - require.NoError(t, m.DeleteStream(1)) + require.NoError(t, m.DeleteStream(firstStream)) require.Len(t, frameQueue, 1) require.Equal(t, &wire.MaxStreamsFrame{Type: protocol.StreamTypeUni, MaxStreamNum: 7}, frameQueue[0]) } @@ -167,45 +187,62 @@ func TestStreamsMapIncomingDeletingStreams(t *testing.T) { // Since the stream limit is configurable by the user, we can't rely on this number // being high enough that it will never be reached in practice. func TestStreamsMapIncomingDeletingStreamsWithHighLimits(t *testing.T) { + t.Run("client", func(t *testing.T) { + testStreamsMapIncomingDeletingStreamsWithHighLimits(t, protocol.PerspectiveClient, protocol.FirstIncomingUniStreamClient) + }) + t.Run("server", func(t *testing.T) { + testStreamsMapIncomingDeletingStreamsWithHighLimits(t, protocol.PerspectiveServer, protocol.FirstIncomingUniStreamServer) + }) +} + +func testStreamsMapIncomingDeletingStreamsWithHighLimits(t *testing.T, pers protocol.Perspective, firstStream protocol.StreamID) { var frameQueue []wire.Frame m := newIncomingStreamsMap( protocol.StreamTypeUni, - func(num protocol.StreamNum) *mockGenericStream { return &mockGenericStream{num: num} }, + func(id protocol.StreamID) *mockStream { return &mockStream{id: id} }, uint64(protocol.MaxStreamCount-2), func(f wire.Frame) { frameQueue = append(frameQueue, f) }, + pers, ) // open a bunch of streams - _, err := m.GetOrOpenStream(5) + _, err := m.GetOrOpenStream(firstStream + 16) require.NoError(t, err) // accept all streams - for i := 0; i < 5; i++ { + for range 5 { _, err := m.AcceptStream(context.Background()) require.NoError(t, err) } require.Empty(t, frameQueue) - require.NoError(t, m.DeleteStream(4)) + require.NoError(t, m.DeleteStream(firstStream+12)) require.Len(t, frameQueue, 1) - require.Equal(t, &wire.MaxStreamsFrame{Type: protocol.StreamTypeUni, MaxStreamNum: protocol.MaxStreamCount - 1}, frameQueue[0]) - require.NoError(t, m.DeleteStream(3)) + require.Equal(t, + &wire.MaxStreamsFrame{Type: protocol.StreamTypeUni, MaxStreamNum: protocol.MaxStreamCount - 1}, + frameQueue[0], + ) + require.NoError(t, m.DeleteStream(firstStream+8)) require.Len(t, frameQueue, 2) - require.Equal(t, &wire.MaxStreamsFrame{Type: protocol.StreamTypeUni, MaxStreamNum: protocol.MaxStreamCount}, frameQueue[1]) + require.Equal(t, + &wire.MaxStreamsFrame{Type: protocol.StreamTypeUni, MaxStreamNum: protocol.MaxStreamCount}, + frameQueue[1], + ) // at this point, we can't increase the stream limit any further, so no more MAX_STREAMS frames will be sent - require.NoError(t, m.DeleteStream(2)) - require.NoError(t, m.DeleteStream(1)) + require.NoError(t, m.DeleteStream(firstStream+4)) + require.NoError(t, m.DeleteStream(firstStream)) require.Len(t, frameQueue, 2) } func TestStreamsMapIncomingClosing(t *testing.T) { m := newIncomingStreamsMap( protocol.StreamTypeUni, - func(num protocol.StreamNum) *mockGenericStream { return &mockGenericStream{num: num} }, + func(id protocol.StreamID) *mockStream { return &mockStream{id: id} }, 5, func(f wire.Frame) {}, + protocol.PerspectiveServer, ) - var streams []*mockGenericStream - _, err := m.GetOrOpenStream(3) + var streams []*mockStream + _, err := m.GetOrOpenStream(protocol.FirstIncomingUniStreamServer + 8) require.NoError(t, err) for range 3 { str, err := m.AcceptStream(context.Background()) @@ -238,16 +275,23 @@ func TestStreamsMapIncomingClosing(t *testing.T) { func TestStreamsMapIncomingRandomized(t *testing.T) { const num = 1000 + streamType := []protocol.StreamType{protocol.StreamTypeUni, protocol.StreamTypeBidi}[rand.IntN(2)] + firstStream := protocol.FirstIncomingUniStreamServer + if streamType == protocol.StreamTypeBidi { + firstStream = protocol.FirstIncomingBidiStreamServer + } + m := newIncomingStreamsMap( - protocol.StreamTypeUni, - func(num protocol.StreamNum) *mockGenericStream { return &mockGenericStream{num: num} }, + streamType, + func(id protocol.StreamID) *mockStream { return &mockStream{id: id} }, num, func(f wire.Frame) {}, + protocol.PerspectiveServer, ) - ids := make([]protocol.StreamNum, num) + ids := make([]protocol.StreamID, num) for i := range num { - ids[i] = protocol.StreamNum(i + 1) + ids[i] = firstStream + 4*protocol.StreamID(i) } rand.Shuffle(len(ids), func(i, j int) { ids[i], ids[j] = ids[j], ids[i] }) @@ -268,8 +312,7 @@ func TestStreamsMapIncomingRandomized(t *testing.T) { errChan2 := make(chan error, 1) go func() { for i := range num { - _, err := m.GetOrOpenStream(ids[i]) - if err != nil { + if _, err := m.GetOrOpenStream(ids[i]); err != nil { errChan2 <- err return } diff --git a/streams_map_outgoing_test.go b/streams_map_outgoing_test.go index f47090176..6e157cd23 100644 --- a/streams_map_outgoing_test.go +++ b/streams_map_outgoing_test.go @@ -15,6 +15,23 @@ import ( "github.com/stretchr/testify/require" ) +type mockGenericStream struct { + num protocol.StreamNum + + closed bool + closeErr error + sendWindow protocol.ByteCount +} + +func (s *mockGenericStream) closeForShutdown(err error) { + s.closed = true + s.closeErr = err +} + +func (s *mockGenericStream) updateSendWindow(limit protocol.ByteCount) { + s.sendWindow = limit +} + func TestStreamsMapOutgoingOpenAndDelete(t *testing.T) { m := newOutgoingStreamsMap( protocol.StreamTypeBidi, diff --git a/streams_map_test.go b/streams_map_test.go index 9710926c5..8298b0ee9 100644 --- a/streams_map_test.go +++ b/streams_map_test.go @@ -17,20 +17,6 @@ import ( "go.uber.org/mock/gomock" ) -const ( - firstIncomingBidiStreamServer protocol.StreamID = 0 - firstOutgoingBidiStreamServer protocol.StreamID = 1 - firstIncomingUniStreamServer protocol.StreamID = 2 - firstOutgoingUniStreamServer protocol.StreamID = 3 -) - -const ( - firstIncomingBidiStreamClient protocol.StreamID = 1 - firstOutgoingBidiStreamClient protocol.StreamID = 0 - firstIncomingUniStreamClient protocol.StreamID = 3 - firstOutgoingUniStreamClient protocol.StreamID = 2 -) - func (e streamError) TestError() error { nums := make([]interface{}, len(e.nums)) for i, num := range e.nums { @@ -42,18 +28,18 @@ func (e streamError) TestError() error { func TestStreamsMapCreatingStreams(t *testing.T) { t.Run("client", func(t *testing.T) { testStreamsMapCreatingAndDeletingStreams(t, protocol.PerspectiveClient, - firstIncomingBidiStreamClient, - firstOutgoingBidiStreamClient, - firstIncomingUniStreamClient, - firstOutgoingUniStreamClient, + protocol.FirstIncomingBidiStreamClient, + protocol.FirstOutgoingBidiStreamClient, + protocol.FirstIncomingUniStreamClient, + protocol.FirstOutgoingUniStreamClient, ) }) t.Run("server", func(t *testing.T) { testStreamsMapCreatingAndDeletingStreams(t, protocol.PerspectiveServer, - firstIncomingBidiStreamServer, - firstOutgoingBidiStreamServer, - firstIncomingUniStreamServer, - firstOutgoingUniStreamServer, + protocol.FirstIncomingBidiStreamServer, + protocol.FirstOutgoingBidiStreamServer, + protocol.FirstIncomingUniStreamServer, + protocol.FirstOutgoingUniStreamServer, ) }) } @@ -132,18 +118,18 @@ func testStreamsMapCreatingAndDeletingStreams(t *testing.T, func TestStreamsMapDeletingStreams(t *testing.T) { t.Run("client", func(t *testing.T) { testStreamsMapDeletingStreams(t, protocol.PerspectiveClient, - firstIncomingBidiStreamClient, - firstOutgoingBidiStreamClient, - firstIncomingUniStreamClient, - firstOutgoingUniStreamClient, + protocol.FirstIncomingBidiStreamClient, + protocol.FirstOutgoingBidiStreamClient, + protocol.FirstIncomingUniStreamClient, + protocol.FirstOutgoingUniStreamClient, ) }) t.Run("server", func(t *testing.T) { testStreamsMapDeletingStreams(t, protocol.PerspectiveServer, - firstIncomingBidiStreamServer, - firstOutgoingBidiStreamServer, - firstIncomingUniStreamServer, - firstOutgoingUniStreamServer, + protocol.FirstIncomingBidiStreamServer, + protocol.FirstOutgoingBidiStreamServer, + protocol.FirstIncomingUniStreamServer, + protocol.FirstOutgoingUniStreamServer, ) }) }