diff --git a/session.go b/session.go index 100725aa3..833a7d0f7 100644 --- a/session.go +++ b/session.go @@ -78,6 +78,8 @@ func (s *Session) Run() { if err != nil { fmt.Printf("Error in session: %s\n", err.Error()) } + + s.garbageCollectStreams() } } @@ -154,7 +156,7 @@ func (s *Session) handleStreamFrame(frame *frames.StreamFrame) error { return nil } -// Close closes the connection by sending a ConnectionClose frame +// Close the connection by sending a ConnectionClose frame func (s *Session) Close(e error) error { errorCode := protocol.ErrorCode(1) reasonPhrase := e.Error() @@ -168,6 +170,17 @@ func (s *Session) Close(e error) error { }) } +func (s *Session) closeStreamsWithError(err error) { + s.streamsMutex.Lock() + defer s.streamsMutex.Unlock() + for _, s := range s.streams { + if s == nil { + continue + } + s.RegisterError(err) + } +} + func (s *Session) sendPacket() error { packet, err := s.packer.PackPacket() if err != nil { @@ -207,10 +220,16 @@ func (s *Session) NewStream(id protocol.StreamID) (utils.Stream, error) { return stream, nil } -// closeStream is called by a stream to signal that it was closed remotely -// and has fininshed reading its data. -func (s *Session) closeStream(id protocol.StreamID) { +// garbageCollectStreams goes through all streams and removes EOF'ed streams +// from the streams map. +func (s *Session) garbageCollectStreams() { s.streamsMutex.Lock() - s.streams[id] = nil - s.streamsMutex.Unlock() + defer s.streamsMutex.Unlock() + for k, v := range s.streams { + // Strictly speaking, this is not thread-safe. However it doesn't matter + // if the stream is deleted just shortly later, so we don't care. + if v.finishedReading() { + s.streams[k] = nil + } + } } diff --git a/session_test.go b/session_test.go index de0e73691..30def264f 100644 --- a/session_test.go +++ b/session_test.go @@ -1,6 +1,7 @@ package quic import ( + "errors" "io" . "github.com/onsi/ginkgo" @@ -58,7 +59,7 @@ var _ = Describe("Session", func() { Expect(p).To(Equal([]byte{0xde, 0xca, 0xfb, 0xad})) }) - It("closes streams", func() { + It("closes streams with FIN bits", func() { session.handleStreamFrame(&frames.StreamFrame{ StreamID: 5, Data: []byte{0xde, 0xca, 0xfb, 0xad}, @@ -71,6 +72,43 @@ var _ = Describe("Session", func() { _, err := session.streams[5].Read(p) Expect(err).To(Equal(io.EOF)) Expect(p).To(Equal([]byte{0xde, 0xca, 0xfb, 0xad})) + session.garbageCollectStreams() + Expect(session.streams).To(HaveLen(1)) + Expect(session.streams[5]).To(BeNil()) + }) + + It("closes streams with error", func() { + testErr := errors.New("test") + session.handleStreamFrame(&frames.StreamFrame{ + StreamID: 5, + Data: []byte{0xde, 0xca, 0xfb, 0xad}, + }) + Expect(session.streams).To(HaveLen(1)) + Expect(session.streams[5]).ToNot(BeNil()) + Expect(callbackCalled).To(BeTrue()) + p := make([]byte, 4) + _, err := session.streams[5].Read(p) + Expect(err).ToNot(HaveOccurred()) + session.closeStreamsWithError(testErr) + _, err = session.streams[5].Read(p) + Expect(err).To(Equal(testErr)) + session.garbageCollectStreams() + Expect(session.streams).To(HaveLen(1)) + Expect(session.streams[5]).To(BeNil()) + }) + + It("closes empty streams with error", func() { + testErr := errors.New("test") + session.handleStreamFrame(&frames.StreamFrame{ + StreamID: 5, + }) + Expect(session.streams).To(HaveLen(1)) + Expect(session.streams[5]).ToNot(BeNil()) + Expect(callbackCalled).To(BeTrue()) + session.closeStreamsWithError(testErr) + _, err := session.streams[5].Read([]byte{0}) + Expect(err).To(Equal(testErr)) + session.garbageCollectStreams() Expect(session.streams).To(HaveLen(1)) Expect(session.streams[5]).To(BeNil()) }) @@ -83,6 +121,7 @@ var _ = Describe("Session", func() { }) _, err := session.streams[5].Read([]byte{0}) Expect(err).To(Equal(io.EOF)) + session.garbageCollectStreams() err = session.handleStreamFrame(&frames.StreamFrame{ StreamID: 5, Data: []byte{}, diff --git a/stream.go b/stream.go index 035745b3f..ff0497057 100644 --- a/stream.go +++ b/stream.go @@ -11,19 +11,21 @@ import ( type streamHandler interface { QueueFrame(frames.Frame) error - closeStream(protocol.StreamID) } // A Stream assembles the data from StreamFrames and provides a super-convenient Read-Interface type stream struct { - session streamHandler - streamID protocol.StreamID + session streamHandler + streamID protocol.StreamID + // The chan of unordered stream frames. A nil in this channel is sent by the + // session if an error occured, in this case, remoteErr is filled before. streamFrames chan *frames.StreamFrame currentFrame *frames.StreamFrame readPosInFrame int writeOffset uint64 readOffset uint64 frameQueue []*frames.StreamFrame // TODO: replace with heap + remoteErr error currentErr error } @@ -36,14 +38,18 @@ func newStream(session streamHandler, StreamID protocol.StreamID) *stream { } } -// Read reads data +// Read implements io.Reader func (s *stream) Read(p []byte) (int, error) { + if s.currentErr != nil { + return 0, s.currentErr + } bytesRead := 0 for bytesRead < len(p) { if s.currentFrame == nil { var err error s.currentFrame, err = s.getNextFrameInOrder(bytesRead == 0) if err != nil { + s.currentErr = err return bytesRead, err } if s.currentFrame == nil { @@ -57,14 +63,12 @@ func (s *stream) Read(p []byte) (int, error) { bytesRead += m s.readOffset += uint64(m) if s.readPosInFrame >= len(s.currentFrame.Data) { - if s.currentFrame.FinBit { + fin := s.currentFrame.FinBit + s.currentFrame = nil + if fin { s.currentErr = io.EOF - close(s.streamFrames) - s.currentFrame = nil - s.session.closeStream(s.streamID) return bytesRead, io.EOF } - s.currentFrame = nil } } @@ -105,22 +109,26 @@ func (s *stream) getNextFrameInOrder(wait bool) (*frames.StreamFrame, error) { } } -func (s *stream) nextFrameInChan(blocking bool) (f *frames.StreamFrame, err error) { +func (s *stream) nextFrameInChan(blocking bool) (*frames.StreamFrame, error) { + var f *frames.StreamFrame var ok bool if blocking { - select { - case f, ok = <-s.streamFrames: - } + f, ok = <-s.streamFrames } else { select { case f, ok = <-s.streamFrames: default: + return nil, nil } } if !ok { - return nil, s.currentErr + panic("Stream: internal inconsistency: encountered closed chan without nil value (remote error) or FIN bit") } - return + if f == nil { + // We read nil, which indicates a remoteErr + return nil, s.remoteErr + } + return f, nil } // ReadByte implements io.ByteReader @@ -165,7 +173,10 @@ func (s *stream) AddStreamFrame(frame *frames.StreamFrame) error { // RegisterError is called by session to indicate that an error occured and the // stream should be closed. func (s *stream) RegisterError(err error) { - s.currentErr = err - s.session.closeStream(s.streamID) - close(s.streamFrames) + s.remoteErr = err + s.streamFrames <- nil +} + +func (s *stream) finishedReading() bool { + return s.currentErr != nil } diff --git a/stream_test.go b/stream_test.go index 37819343a..094eccfc6 100644 --- a/stream_test.go +++ b/stream_test.go @@ -6,14 +6,12 @@ import ( "time" "github.com/lucas-clemente/quic-go/frames" - "github.com/lucas-clemente/quic-go/protocol" . "github.com/onsi/ginkgo" . "github.com/onsi/gomega" ) type mockStreamHandler struct { - closedStream bool - frames []frames.Frame + frames []frames.Frame } func (m *mockStreamHandler) QueueFrame(f frames.Frame) error { @@ -21,10 +19,6 @@ func (m *mockStreamHandler) QueueFrame(f frames.Frame) error { return nil } -func (m *mockStreamHandler) closeStream(protocol.StreamID) { - m.closedStream = true -} - var _ = Describe("Stream", func() { var ( str *stream @@ -334,11 +328,6 @@ var _ = Describe("Stream", func() { }) Context("closing", func() { - AfterEach(func() { - Expect(str.streamFrames).To(BeClosed()) - Expect(handler.closedStream).To(BeTrue()) - }) - Context("with fin bit", func() { It("returns EOFs", func() { frame := frames.StreamFrame{ @@ -410,6 +399,24 @@ var _ = Describe("Stream", func() { Context("with remote errors", func() { testErr := errors.New("test error") + It("returns EOF if data is read before", func() { + frame := frames.StreamFrame{ + Offset: 0, + Data: []byte{0xDE, 0xAD, 0xBE, 0xEF}, + FinBit: true, + } + str.AddStreamFrame(&frame) + str.RegisterError(testErr) + b := make([]byte, 4) + n, err := str.Read(b) + Expect(err).To(Equal(io.EOF)) + Expect(n).To(Equal(4)) + Expect(b).To(Equal([]byte{0xDE, 0xAD, 0xBE, 0xEF})) + n, err = str.Read(b) + Expect(n).To(BeZero()) + Expect(err).To(Equal(io.EOF)) + }) + It("returns errors", func() { frame := frames.StreamFrame{ Offset: 0,