handle stream creation in streamsMap, remove streams mutex from session

This commit is contained in:
Lucas Clemente
2016-08-08 15:25:22 +02:00
parent 65663c3314
commit a1e2977f50
3 changed files with 53 additions and 119 deletions

View File

@@ -3,7 +3,6 @@ package quic
import (
"errors"
"fmt"
"sync"
"sync/atomic"
"time"
@@ -48,9 +47,7 @@ type Session struct {
conn connection
streamsMap *streamsMap
openStreamsCount uint32
streamsMutex sync.RWMutex
streamsMap *streamsMap
sentPacketHandler ackhandler.SentPacketHandler
receivedPacketHandler ackhandler.ReceivedPacketHandler
@@ -129,7 +126,7 @@ func newSession(conn connection, v protocol.VersionNumber, connectionID protocol
session.streamsMap = newStreamsMap(session.newStream)
cryptoStream, _ := session.OpenStream(1)
cryptoStream, _ := session.GetOrOpenStream(1)
var err error
session.cryptoSetup, err = handshake.NewCryptoSetup(connectionID, conn.IP(), v, sCfg, cryptoStream, session.connectionParametersManager, session.aeadChanged)
if err != nil {
@@ -332,20 +329,9 @@ func (s *Session) handlePacket(remoteAddr interface{}, hdr *PublicHeader, data [
}
func (s *Session) handleStreamFrame(frame *frames.StreamFrame) error {
s.streamsMutex.Lock()
defer s.streamsMutex.Unlock()
str, strExists := s.streamsMap.GetStream(frame.StreamID)
var err error
if !strExists {
if !s.isValidStreamID(frame.StreamID) {
return qerr.InvalidStreamID
}
str, err = s.newStreamImpl(frame.StreamID)
if err != nil {
return err
}
str, err := s.streamsMap.GetOrOpenStream(frame.StreamID)
if err != nil {
return err
}
if str == nil {
// Stream is closed, ignore
@@ -355,29 +341,17 @@ func (s *Session) handleStreamFrame(frame *frames.StreamFrame) error {
if err != nil {
return err
}
if !strExists {
s.streamCallback(s, str)
}
return nil
}
func (s *Session) isValidStreamID(streamID protocol.StreamID) bool {
return streamID%2 == 1
}
func (s *Session) handleWindowUpdateFrame(frame *frames.WindowUpdateFrame) error {
s.streamsMutex.RLock()
defer s.streamsMutex.RUnlock()
if frame.StreamID != 0 {
str, strExists := s.streamsMap.GetStream(frame.StreamID)
if strExists && str == nil {
return errWindowUpdateOnClosedStream
str, err := s.streamsMap.GetOrOpenStream(frame.StreamID)
if err != nil {
return err
}
// open new stream when receiving a WindowUpdate for a non-existing stream
// this can occur if the client immediately sends a WindowUpdate for a newly opened stream, and packet reordering occurs such that the packet opening the new stream arrives after the WindowUpdate
if !strExists {
s.newStreamImpl(frame.StreamID)
if str == nil {
return errWindowUpdateOnClosedStream
}
}
_, err := s.flowControlManager.UpdateWindow(frame.StreamID, frame.ByteOffset)
@@ -386,10 +360,11 @@ func (s *Session) handleWindowUpdateFrame(frame *frames.WindowUpdateFrame) error
// TODO: Handle frame.byteOffset
func (s *Session) handleRstStreamFrame(frame *frames.RstStreamFrame) error {
s.streamsMutex.RLock()
str, streamExists := s.streamsMap.GetStream(frame.StreamID)
s.streamsMutex.RUnlock()
if !streamExists || str == nil {
str, err := s.streamsMap.GetOrOpenStream(frame.StreamID)
if err != nil {
return err
}
if str == nil {
return errRstStreamOnInvalidStream
}
s.closeStreamWithError(str, fmt.Errorf("RST_STREAM received with code %d", frame.ErrorCode))
@@ -446,15 +421,10 @@ func (s *Session) closeImpl(e error, remoteClose bool) error {
}
func (s *Session) closeStreamsWithError(err error) {
s.streamsMutex.Lock()
defer s.streamsMutex.Unlock()
fn := func(str *stream) (bool, error) {
s.streamsMap.Iterate(func(str *stream) (bool, error) {
s.closeStreamWithError(str, err)
return true, nil
}
s.streamsMap.Iterate(fn)
})
}
func (s *Session) closeStreamWithError(str *stream, err error) {
@@ -588,35 +558,17 @@ func (s *Session) logPacket(packet *packedPacket) {
}
}
// OpenStream creates a new stream open for reading and writing
func (s *Session) OpenStream(id protocol.StreamID) (utils.Stream, error) {
s.streamsMutex.Lock()
defer s.streamsMutex.Unlock()
return s.newStreamImpl(id)
}
// GetOrOpenStream returns an existing stream with the given id, or opens a new stream
// GetOrOpenStream either returns an existing stream, a newly opened stream, or nil if a stream with the provided ID is already closed.
func (s *Session) GetOrOpenStream(id protocol.StreamID) (utils.Stream, error) {
s.streamsMutex.Lock()
defer s.streamsMutex.Unlock()
stream, strExists := s.streamsMap.GetStream(id)
if strExists {
return stream, nil
}
return s.newStreamImpl(id)
return s.streamsMap.GetOrOpenStream(id)
}
// The streamsMutex is locked by OpenStream or GetOrOpenStream before calling this function.
func (s *Session) newStreamImpl(id protocol.StreamID) (*stream, error) {
maxAllowedStreams := uint32(protocol.MaxStreamsMultiplier * float32(s.connectionParametersManager.GetMaxStreamsPerConnection()))
if atomic.LoadUint32(&s.openStreamsCount) >= maxAllowedStreams {
go s.Close(qerr.TooManyOpenStreams)
return nil, qerr.TooManyOpenStreams
}
_, strExists := s.streamsMap.GetStream(id)
if strExists {
return nil, fmt.Errorf("Session: stream with ID %d already exists", id)
}
return s.streamsMap.GetOrOpenStream(id)
}
func (s *Session) newStream(id protocol.StreamID) (*stream, error) {
stream, err := newStream(id, s.scheduleSending, s.flowControlManager)
if err != nil {
return nil, err
@@ -629,25 +581,17 @@ func (s *Session) newStreamImpl(id protocol.StreamID) (*stream, error) {
s.flowControlManager.NewStream(id, true)
}
atomic.AddUint32(&s.openStreamsCount, 1)
err = s.streamsMap.PutStream(stream)
if err != nil {
return nil, err
}
return stream, nil
}
s.streamCallback(s, stream)
func (s *Session) newStream(id protocol.StreamID) (*stream, error) {
return nil, errors.New("not implemented")
return stream, nil
}
// garbageCollectStreams goes through all streams and removes EOF'ed streams
// from the streams map.
func (s *Session) garbageCollectStreams() {
fn := func(str *stream) (bool, error) {
s.streamsMap.Iterate(func(str *stream) (bool, error) {
id := str.StreamID()
if str.finished() {
atomic.AddUint32(&s.openStreamsCount, ^uint32(0)) // decrement
err := s.streamsMap.RemoveStream(id)
if err != nil {
return false, err
@@ -655,9 +599,7 @@ func (s *Session) garbageCollectStreams() {
s.flowControlManager.RemoveStream(id)
}
return true, nil
}
s.streamsMap.Iterate(fn)
})
}
func (s *Session) sendPublicReset(rejectedPacketNumber protocol.PacketNumber) error {
@@ -689,12 +631,9 @@ func (s *Session) tryDecryptingQueuedPackets() {
}
func (s *Session) getWindowUpdateFrames() ([]*frames.WindowUpdateFrame, error) {
s.streamsMutex.RLock()
defer s.streamsMutex.RUnlock()
var res []*frames.WindowUpdateFrame
fn := func(str *stream) (bool, error) {
s.streamsMap.Iterate(func(str *stream) (bool, error) {
id := str.StreamID()
doUpdate, offset, err := s.flowControlManager.MaybeTriggerStreamWindowUpdate(id)
if err != nil {
@@ -704,9 +643,7 @@ func (s *Session) getWindowUpdateFrames() ([]*frames.WindowUpdateFrame, error) {
res = append(res, &frames.WindowUpdateFrame{StreamID: id, ByteOffset: offset})
}
return true, nil
}
s.streamsMap.Iterate(fn)
})
doUpdate, offset := s.flowControlManager.MaybeTriggerConnectionWindowUpdate()
if doUpdate {