diff --git a/receive_stream.go b/receive_stream.go index 59deb744..de76335e 100644 --- a/receive_stream.go +++ b/receive_stream.go @@ -78,7 +78,10 @@ func (s *receiveStream) StreamID() protocol.StreamID { // Read implements io.Reader. It is not thread safe! func (s *receiveStream) Read(p []byte) (int, error) { + s.mutex.Lock() completed, n, err := s.readImpl(p) + s.mutex.Unlock() + if completed { s.streamCompleted() } @@ -86,9 +89,6 @@ func (s *receiveStream) Read(p []byte) (int, error) { } func (s *receiveStream) readImpl(p []byte) (bool /*stream completed */, int, error) { - s.mutex.Lock() - defer s.mutex.Unlock() - if s.finRead { return false, 0, io.EOF } @@ -192,14 +192,18 @@ func (s *receiveStream) dequeueNextFrame() { func (s *receiveStream) CancelRead(errorCode protocol.ApplicationErrorCode) { s.mutex.Lock() - defer s.mutex.Unlock() + completed := s.cancelReadImpl(errorCode) + s.mutex.Unlock() - if s.finRead || s.canceledRead || s.resetRemotely { - return - } - if s.finalOffset != protocol.MaxByteCount { // final offset was already received + if completed { s.streamCompleted() } +} + +func (s *receiveStream) cancelReadImpl(errorCode protocol.ApplicationErrorCode) bool /* completed */ { + if s.finRead || s.canceledRead || s.resetRemotely { + return false + } s.canceledRead = true s.cancelReadErr = fmt.Errorf("Read on stream %d canceled with error code %d", s.streamID, errorCode) s.signalRead() @@ -207,38 +211,44 @@ func (s *receiveStream) CancelRead(errorCode protocol.ApplicationErrorCode) { StreamID: s.streamID, ErrorCode: errorCode, }) + // We're done with this stream if the final offset was already received. + return s.finalOffset != protocol.MaxByteCount } func (s *receiveStream) handleStreamFrame(frame *wire.StreamFrame) error { - maxOffset := frame.Offset + frame.DataLen() - s.mutex.Lock() - defer s.mutex.Unlock() + completed, err := s.handleStreamFrameImpl(frame) + s.mutex.Unlock() + if completed { + s.streamCompleted() + } + return err +} + +func (s *receiveStream) handleStreamFrameImpl(frame *wire.StreamFrame) (bool /* completed */, error) { + maxOffset := frame.Offset + frame.DataLen() if err := s.flowController.UpdateHighestReceived(maxOffset, frame.FinBit); err != nil { - return err + return false, err } if frame.FinBit { s.finalOffset = maxOffset } if s.canceledRead { - if frame.FinBit { - s.streamCompleted() - } - return nil + return frame.FinBit, nil } if err := s.frameQueue.Push(frame.Data, frame.Offset); err != nil { - return err - } - if frame.FinBit { - s.finalOffset = maxOffset + return false, err } s.signalRead() - return nil + return false, nil } func (s *receiveStream) handleResetStreamFrame(frame *wire.ResetStreamFrame) error { + s.mutex.Lock() completed, err := s.handleResetStreamFrameImpl(frame) + s.mutex.Unlock() + if completed { s.streamCompleted() } @@ -246,9 +256,6 @@ func (s *receiveStream) handleResetStreamFrame(frame *wire.ResetStreamFrame) err } func (s *receiveStream) handleResetStreamFrameImpl(frame *wire.ResetStreamFrame) (bool /*completed */, error) { - s.mutex.Lock() - defer s.mutex.Unlock() - if s.closedForShutdown { return false, nil } @@ -298,7 +305,11 @@ func (s *receiveStream) getWindowUpdate() protocol.ByteCount { } func (s *receiveStream) streamCompleted() { - if !s.finRead { + s.mutex.Lock() + finRead := s.finRead + s.mutex.Unlock() + + if !finRead { s.flowController.Abandon() } s.sender.onStreamCompleted(s.streamID) diff --git a/send_stream.go b/send_stream.go index 056fcf56..92387c86 100644 --- a/send_stream.go +++ b/send_stream.go @@ -146,7 +146,10 @@ func (s *sendStream) Write(p []byte) (int, error) { // popStreamFrame returns the next STREAM frame that is supposed to be sent on this stream // maxBytes is the maximum length this frame (including frame header) will have. func (s *sendStream) popStreamFrame(maxBytes protocol.ByteCount) (*wire.StreamFrame, bool /* has more data to send */) { + s.mutex.Lock() completed, frame, hasMoreData := s.popStreamFrameImpl(maxBytes) + s.mutex.Unlock() + if completed { s.sender.onStreamCompleted(s.streamID) } @@ -154,9 +157,6 @@ func (s *sendStream) popStreamFrame(maxBytes protocol.ByteCount) (*wire.StreamFr } func (s *sendStream) popStreamFrameImpl(maxBytes protocol.ByteCount) (bool /* completed */, *wire.StreamFrame, bool /* has more data to send */) { - s.mutex.Lock() - defer s.mutex.Unlock() - if s.canceledWrite || s.closeForShutdownErr != nil { return false, nil, false } @@ -273,6 +273,7 @@ func (s *sendStream) handleMaxStreamDataFrame(frame *wire.MaxStreamDataFrame) { s.mutex.Lock() hasStreamData := s.dataForWriting != nil s.mutex.Unlock() + s.flowController.UpdateSendWindow(frame.ByteOffset) if hasStreamData { s.sender.onHasStreamData(s.streamID) @@ -280,16 +281,17 @@ func (s *sendStream) handleMaxStreamDataFrame(frame *wire.MaxStreamDataFrame) { } func (s *sendStream) handleStopSendingFrame(frame *wire.StopSendingFrame) { - if completed := s.handleStopSendingFrameImpl(frame); completed { + s.mutex.Lock() + completed := s.handleStopSendingFrameImpl(frame) + s.mutex.Unlock() + + if completed { s.sender.onStreamCompleted(s.streamID) } } // must be called after locking the mutex func (s *sendStream) handleStopSendingFrameImpl(frame *wire.StopSendingFrame) bool /*completed*/ { - s.mutex.Lock() - defer s.mutex.Unlock() - writeErr := streamCanceledError{ errorCode: frame.ErrorCode, error: fmt.Errorf("Stream %d was reset with error code %d", s.streamID, frame.ErrorCode),