diff --git a/session.go b/session.go index 891c23b83..2902ecbb4 100644 --- a/session.go +++ b/session.go @@ -5,6 +5,7 @@ import ( "errors" "fmt" "net" + "sync" "time" "github.com/lucas-clemente/quic-go/ackhandler" @@ -29,7 +30,8 @@ type Session struct { Connection *net.UDPConn CurrentRemoteAddr *net.UDPAddr - Streams map[protocol.StreamID]*Stream + Streams map[protocol.StreamID]*Stream + streamsMutex sync.RWMutex outgoingAckHandler ackhandler.OutgoingPacketAckHandler incomingAckHandler ackhandler.IncomingPacketAckHandler @@ -131,16 +133,21 @@ func (s *Session) handleStreamFrame(frame *frames.StreamFrame) error { if frame.StreamID == 0 { return errors.New("Session: 0 is not a valid Stream ID") } - stream, newStream := s.Streams[frame.StreamID] + s.streamsMutex.RLock() + stream, existingStream := s.Streams[frame.StreamID] + s.streamsMutex.RUnlock() - if !newStream { + if !existingStream { stream, _ = s.NewStream(frame.StreamID) } + if stream == nil { + return errors.New("Session: reopening streams is not allowed") + } err := stream.AddStreamFrame(frame) if err != nil { return err } - if !newStream { + if !existingStream { s.streamCallback(s, stream) } return nil @@ -190,6 +197,8 @@ func (s *Session) QueueFrame(frame frames.Frame) error { // NewStream creates a new strean open for reading and writing func (s *Session) NewStream(id protocol.StreamID) (*Stream, error) { + s.streamsMutex.Lock() + defer s.streamsMutex.Unlock() stream := NewStream(s, id) if s.Streams[id] != nil { return nil, fmt.Errorf("Session: stream with ID %d already exists", id) @@ -197,3 +206,11 @@ func (s *Session) NewStream(id protocol.StreamID) (*Stream, error) { s.Streams[id] = stream return stream, nil } + +// closeStream is called by a stream to signal that it was closed remotely +// and has fininshed reading its data. +func (s *Session) closeStream(id protocol.StreamID) { + s.streamsMutex.Lock() + s.Streams[id] = nil + s.streamsMutex.Unlock() +} diff --git a/session_test.go b/session_test.go index 5984e7bdb..e4bda6711 100644 --- a/session_test.go +++ b/session_test.go @@ -1,6 +1,8 @@ package quic import ( + "io" + . "github.com/onsi/ginkgo" . "github.com/onsi/gomega" @@ -56,11 +58,35 @@ var _ = Describe("Session", func() { }) It("closes streams", func() { - + session.handleStreamFrame(&frames.StreamFrame{ + StreamID: 5, + Data: []byte{0xde, 0xca, 0xfb, 0xad}, + FinBit: true, + }) + Expect(session.Streams).To(HaveLen(1)) + Expect(session.Streams[5]).ToNot(BeNil()) + Expect(callbackCalled).To(BeTrue()) + p := make([]byte, 4) + _, err := session.Streams[5].Read(p) + Expect(err).To(Equal(io.EOF)) + Expect(p).To(Equal([]byte{0xde, 0xca, 0xfb, 0xad})) + Expect(session.Streams).To(HaveLen(1)) + Expect(session.Streams[5]).To(BeNil()) }) It("rejects streams that existed previously", func() { - + session.handleStreamFrame(&frames.StreamFrame{ + StreamID: 5, + Data: []byte{}, + FinBit: true, + }) + _, err := session.Streams[5].Read([]byte{0}) + Expect(err).To(Equal(io.EOF)) + err = session.handleStreamFrame(&frames.StreamFrame{ + StreamID: 5, + Data: []byte{}, + }) + Expect(err).To(MatchError("Session: reopening streams is not allowed")) }) }) }) diff --git a/stream.go b/stream.go index 6cff4c766..d80b4a61d 100644 --- a/stream.go +++ b/stream.go @@ -9,9 +9,14 @@ import ( "github.com/lucas-clemente/quic-go/utils" ) +type streamHandler interface { + QueueFrame(frames.Frame) error + closeStream(protocol.StreamID) +} + // A Stream assembles the data from StreamFrames and provides a super-convenient Read-Interface type Stream struct { - Session *Session + Session streamHandler StreamID protocol.StreamID StreamFrames chan *frames.StreamFrame CurrentFrame *frames.StreamFrame @@ -23,7 +28,7 @@ type Stream struct { } // NewStream creates a new Stream -func NewStream(session *Session, StreamID protocol.StreamID) *Stream { +func NewStream(session streamHandler, StreamID protocol.StreamID) *Stream { return &Stream{ Session: session, StreamID: StreamID, @@ -56,6 +61,7 @@ func (s *Stream) Read(p []byte) (int, error) { s.currentErr = io.EOF close(s.StreamFrames) s.CurrentFrame = nil + s.Session.closeStream(s.StreamID) return bytesRead, io.EOF } s.CurrentFrame = nil @@ -131,6 +137,7 @@ func (s *Stream) ReadByte() (byte, error) { return p[0], nil } +// TODO: Test func (s *Stream) Write(p []byte) (int, error) { data := make([]byte, len(p)) copy(data, p) @@ -166,5 +173,6 @@ func (s *Stream) AddStreamFrame(frame *frames.StreamFrame) error { // stream should be closed. func (s *Stream) RegisterError(err error) { s.currentErr = err + s.Session.closeStream(s.StreamID) close(s.StreamFrames) } diff --git a/stream_test.go b/stream_test.go index 8ee03aa53..5853c4cec 100644 --- a/stream_test.go +++ b/stream_test.go @@ -6,17 +6,32 @@ import ( "time" "github.com/lucas-clemente/quic-go/frames" + "github.com/lucas-clemente/quic-go/protocol" . "github.com/onsi/ginkgo" . "github.com/onsi/gomega" ) +type mockStreamHandler struct { + closedStream bool +} + +func (*mockStreamHandler) QueueFrame(frames.Frame) error { + panic("not implemented") +} + +func (m *mockStreamHandler) closeStream(protocol.StreamID) { + m.closedStream = true +} + var _ = Describe("Stream", func() { var ( - stream *Stream + stream *Stream + handler *mockStreamHandler ) BeforeEach(func() { - stream = NewStream(nil, 1337) + handler = &mockStreamHandler{} + stream = NewStream(handler, 1337) }) It("reads a single StreamFrame", func() { @@ -274,6 +289,7 @@ var _ = Describe("Stream", func() { Context("closing", func() { AfterEach(func() { Expect(stream.StreamFrames).To(BeClosed()) + Expect(handler.closedStream).To(BeTrue()) }) Context("with fin bit", func() {