package quic import ( "errors" "fmt" "sync" "github.com/lucas-clemente/quic-go/handshake" "github.com/lucas-clemente/quic-go/protocol" "github.com/lucas-clemente/quic-go/qerr" ) type streamsMap struct { mutex sync.RWMutex perspective protocol.Perspective connectionParameters handshake.ConnectionParametersManager streams map[protocol.StreamID]*stream // needed for round-robin scheduling openStreams []protocol.StreamID roundRobinIndex uint32 nextStream protocol.StreamID // StreamID of the next Stream that will be returned by OpenStream() highestStreamOpenedByPeer protocol.StreamID nextStreamOrErrCond sync.Cond openStreamOrErrCond sync.Cond closeErr error nextStreamToAccept protocol.StreamID newStream newStreamLambda maxOutgoingStreams uint32 numOutgoingStreams uint32 maxIncomingStreams uint32 numIncomingStreams uint32 } type streamLambda func(*stream) (bool, error) type newStreamLambda func(protocol.StreamID) (*stream, error) var ( errMapAccess = errors.New("streamsMap: Error accessing the streams map") ) func newStreamsMap(newStream newStreamLambda, pers protocol.Perspective, connectionParameters handshake.ConnectionParametersManager) *streamsMap { sm := streamsMap{ perspective: pers, streams: map[protocol.StreamID]*stream{}, openStreams: make([]protocol.StreamID, 0), newStream: newStream, connectionParameters: connectionParameters, } sm.nextStreamOrErrCond.L = &sm.mutex sm.openStreamOrErrCond.L = &sm.mutex if pers == protocol.PerspectiveClient { sm.nextStream = 1 sm.nextStreamToAccept = 2 } else { sm.nextStream = 2 sm.nextStreamToAccept = 1 } return &sm } // GetOrOpenStream either returns an existing stream, a newly opened stream, or nil if a stream with the provided ID is already closed. // Newly opened streams should only originate from the client. To open a stream from the server, OpenStream should be used. func (m *streamsMap) GetOrOpenStream(id protocol.StreamID) (*stream, error) { m.mutex.RLock() s, ok := m.streams[id] m.mutex.RUnlock() if ok { return s, nil // s may be nil } // ... we don't have an existing stream m.mutex.Lock() defer m.mutex.Unlock() // We need to check whether another invocation has already created a stream (between RUnlock() and Lock()). s, ok = m.streams[id] if ok { return s, nil } if id <= m.highestStreamOpenedByPeer { return nil, nil } if m.perspective == protocol.PerspectiveServer && id%2 == 0 { return nil, qerr.Error(qerr.InvalidStreamID, fmt.Sprintf("attempted to open stream %d from client-side", id)) } if m.perspective == protocol.PerspectiveClient && id%2 == 1 { return nil, qerr.Error(qerr.InvalidStreamID, fmt.Sprintf("attempted to open stream %d from server-side", id)) } // sid is the next stream that will be opened sid := m.highestStreamOpenedByPeer + 2 // if there is no stream opened yet, and this is the server, stream 1 should be openend if sid == 2 && m.perspective == protocol.PerspectiveServer { sid = 1 } for ; sid <= id; sid += 2 { _, err := m.openRemoteStream(sid) if err != nil { return nil, err } } m.nextStreamOrErrCond.Broadcast() return m.streams[id], nil } func (m *streamsMap) openRemoteStream(id protocol.StreamID) (*stream, error) { if m.numIncomingStreams >= m.connectionParameters.GetMaxIncomingStreams() { return nil, qerr.TooManyOpenStreams } if id+protocol.MaxNewStreamIDDelta < m.highestStreamOpenedByPeer { return nil, qerr.Error(qerr.InvalidStreamID, fmt.Sprintf("attempted to open stream %d, which is a lot smaller than the highest opened stream, %d", id, m.highestStreamOpenedByPeer)) } s, err := m.newStream(id) if err != nil { return nil, err } if m.perspective == protocol.PerspectiveServer { m.numIncomingStreams++ } else { m.numOutgoingStreams++ } if id > m.highestStreamOpenedByPeer { m.highestStreamOpenedByPeer = id } m.putStream(s) return s, nil } func (m *streamsMap) openStreamImpl() (*stream, error) { id := m.nextStream if m.numOutgoingStreams >= m.connectionParameters.GetMaxOutgoingStreams() { return nil, qerr.TooManyOpenStreams } s, err := m.newStream(id) if err != nil { return nil, err } if m.perspective == protocol.PerspectiveServer { m.numOutgoingStreams++ } else { m.numIncomingStreams++ } m.nextStream += 2 m.putStream(s) return s, nil } // OpenStream opens the next available stream func (m *streamsMap) OpenStream() (*stream, error) { m.mutex.Lock() defer m.mutex.Unlock() return m.openStreamImpl() } func (m *streamsMap) OpenStreamSync() (*stream, error) { m.mutex.Lock() defer m.mutex.Unlock() for { if m.closeErr != nil { return nil, m.closeErr } str, err := m.openStreamImpl() if err == nil { return str, err } if err != nil && err != qerr.TooManyOpenStreams { return nil, err } m.openStreamOrErrCond.Wait() } } // AcceptStream returns the next stream opened by the peer // it blocks until a new stream is opened func (m *streamsMap) AcceptStream() (*stream, error) { m.mutex.Lock() defer m.mutex.Unlock() var str *stream for { var ok bool if m.closeErr != nil { return nil, m.closeErr } str, ok = m.streams[m.nextStreamToAccept] if ok { break } m.nextStreamOrErrCond.Wait() } m.nextStreamToAccept += 2 return str, nil } func (m *streamsMap) Iterate(fn streamLambda) error { m.mutex.Lock() defer m.mutex.Unlock() openStreams := make([]protocol.StreamID, len(m.openStreams), len(m.openStreams)) for i, streamID := range m.openStreams { // copy openStreams openStreams[i] = streamID } for _, streamID := range openStreams { cont, err := m.iterateFunc(streamID, fn) if err != nil { return err } if !cont { break } } return nil } // RoundRobinIterate executes the streamLambda for every open stream, until the streamLambda returns false // It uses a round-robin-like scheduling to ensure that every stream is considered fairly // It prioritizes the crypto- and the header-stream (StreamIDs 1 and 3) func (m *streamsMap) RoundRobinIterate(fn streamLambda) error { m.mutex.Lock() defer m.mutex.Unlock() numStreams := uint32(len(m.streams)) startIndex := m.roundRobinIndex for _, i := range []protocol.StreamID{1, 3} { cont, err := m.iterateFunc(i, fn) if err != nil && err != errMapAccess { return err } if !cont { return nil } } for i := uint32(0); i < numStreams; i++ { streamID := m.openStreams[(i+startIndex)%numStreams] if streamID == 1 || streamID == 3 { continue } cont, err := m.iterateFunc(streamID, fn) if err != nil { return err } m.roundRobinIndex = (m.roundRobinIndex + 1) % numStreams if !cont { break } } return nil } func (m *streamsMap) iterateFunc(streamID protocol.StreamID, fn streamLambda) (bool, error) { str, ok := m.streams[streamID] if !ok { return true, errMapAccess } return fn(str) } func (m *streamsMap) putStream(s *stream) error { id := s.StreamID() if _, ok := m.streams[id]; ok { return fmt.Errorf("a stream with ID %d already exists", id) } m.streams[id] = s m.openStreams = append(m.openStreams, id) return nil } // Attention: this function must only be called if a mutex has been acquired previously func (m *streamsMap) RemoveStream(id protocol.StreamID) error { s, ok := m.streams[id] if !ok || s == nil { return fmt.Errorf("attempted to remove non-existing stream: %d", id) } if id%2 == 0 { m.numOutgoingStreams-- } else { m.numIncomingStreams-- } for i, s := range m.openStreams { if s == id { // delete the streamID from the openStreams slice m.openStreams = m.openStreams[:i+copy(m.openStreams[i:], m.openStreams[i+1:])] // adjust round-robin index, if necessary if uint32(i) < m.roundRobinIndex { m.roundRobinIndex-- } break } } delete(m.streams, id) m.openStreamOrErrCond.Signal() return nil } func (m *streamsMap) CloseWithError(err error) { m.mutex.Lock() m.closeErr = err m.nextStreamOrErrCond.Broadcast() m.openStreamOrErrCond.Broadcast() m.mutex.Unlock() }