diff --git a/session.go b/session.go index 24137873..d839f824 100644 --- a/session.go +++ b/session.go @@ -3,6 +3,7 @@ package quic import ( "errors" "fmt" + "sync" "sync/atomic" "time" @@ -49,6 +50,7 @@ type Session struct { streamsMap *streamsMap openStreamsCount uint32 + streamsMutex sync.RWMutex sentPacketHandler ackhandler.SentPacketHandler receivedPacketHandler ackhandler.ReceivedPacketHandler @@ -330,6 +332,8 @@ func (s *Session) handlePacket(remoteAddr interface{}, hdr *PublicHeader, data [ } func (s *Session) handleStreamFrame(frame *frames.StreamFrame) error { + s.streamsMutex.Lock() + defer s.streamsMutex.Unlock() str, strExists := s.streamsMap.GetStream(frame.StreamID) var err error @@ -362,6 +366,8 @@ func (s *Session) isValidStreamID(streamID protocol.StreamID) bool { } func (s *Session) handleWindowUpdateFrame(frame *frames.WindowUpdateFrame) error { + s.streamsMutex.RLock() + defer s.streamsMutex.RUnlock() if frame.StreamID != 0 { str, strExists := s.streamsMap.GetStream(frame.StreamID) if strExists && str == nil { @@ -380,7 +386,9 @@ func (s *Session) handleWindowUpdateFrame(frame *frames.WindowUpdateFrame) error // TODO: Handle frame.byteOffset func (s *Session) handleRstStreamFrame(frame *frames.RstStreamFrame) error { + s.streamsMutex.RLock() str, streamExists := s.streamsMap.GetStream(frame.StreamID) + s.streamsMutex.RUnlock() if !streamExists || str == nil { return errRstStreamOnInvalidStream } @@ -438,6 +446,9 @@ func (s *Session) closeImpl(e error, remoteClose bool) error { } func (s *Session) closeStreamsWithError(err error) { + s.streamsMutex.Lock() + defer s.streamsMutex.Unlock() + fn := func(str *stream) (bool, error) { s.closeStreamWithError(str, err) return true, nil @@ -579,11 +590,15 @@ func (s *Session) logPacket(packet *packedPacket) { // OpenStream creates a new stream open for reading and writing func (s *Session) OpenStream(id protocol.StreamID) (utils.Stream, error) { + s.streamsMutex.Lock() + defer s.streamsMutex.Unlock() return s.newStreamImpl(id) } // GetOrOpenStream returns an existing stream with the given id, or opens a new stream func (s *Session) GetOrOpenStream(id protocol.StreamID) (utils.Stream, error) { + s.streamsMutex.Lock() + defer s.streamsMutex.Unlock() stream, strExists := s.streamsMap.GetStream(id) if strExists { return stream, nil @@ -591,6 +606,7 @@ func (s *Session) GetOrOpenStream(id protocol.StreamID) (utils.Stream, error) { return s.newStreamImpl(id) } +// The streamsMutex is locked by OpenStream or GetOrOpenStream before calling this function. func (s *Session) newStreamImpl(id protocol.StreamID) (*stream, error) { maxAllowedStreams := uint32(protocol.MaxStreamsMultiplier * float32(s.connectionParametersManager.GetMaxStreamsPerConnection())) if atomic.LoadUint32(&s.openStreamsCount) >= maxAllowedStreams { @@ -669,6 +685,9 @@ func (s *Session) tryDecryptingQueuedPackets() { } func (s *Session) getWindowUpdateFrames() ([]*frames.WindowUpdateFrame, error) { + s.streamsMutex.RLock() + defer s.streamsMutex.RUnlock() + var res []*frames.WindowUpdateFrame fn := func(str *stream) (bool, error) {