From 16da08a440c64df044d51b8750140c04ae5f563c Mon Sep 17 00:00:00 2001 From: Marten Seemann Date: Tue, 13 Dec 2016 11:44:40 +0700 Subject: [PATCH] add client functionality to the streamsMap --- packet_packer_test.go | 2 +- session.go | 4 ++-- session_test.go | 2 +- stream_framer_test.go | 2 +- streams_map.go | 28 ++++++++++++++++++++++----- streams_map_test.go | 45 ++++++++++++++++++++++++++++++++++++++++--- 6 files changed, 70 insertions(+), 13 deletions(-) diff --git a/packet_packer_test.go b/packet_packer_test.go index 07ef076f..da219322 100644 --- a/packet_packer_test.go +++ b/packet_packer_test.go @@ -41,7 +41,7 @@ var _ = Describe("Packet packer", func() { fcm.sendWindowSizes[7] = protocol.MaxByteCount cpm := &mockConnectionParametersManager{} - streamFramer = newStreamFramer(newStreamsMap(nil, cpm), fcm) + streamFramer = newStreamFramer(newStreamsMap(nil, protocol.PerspectiveServer, cpm), fcm) packer = &packetPacker{ cryptoSetup: &mockCryptoSetup{}, diff --git a/session.go b/session.go index a28575ef..5727291b 100644 --- a/session.go +++ b/session.go @@ -136,7 +136,7 @@ func newClientSession(conn *net.UDPConn, addr *net.UDPAddr, hostname string, v p session.receivedPacketHandler = ackhandler.NewReceivedPacketHandler(session.ackAlarmChanged) session.setup() - cryptoStream, _ := session.GetOrOpenStream(1) + cryptoStream, _ := session.OpenStream(1) var err error session.cryptoSetup, err = handshake.NewCryptoSetupClient(hostname, connectionID, v, cryptoStream, session.connectionParameters, session.aeadChanged) if err != nil { @@ -174,7 +174,7 @@ func (s *Session) setup() { s.lastNetworkActivityTime = now s.sessionCreationTime = now - s.streamsMap = newStreamsMap(s.newStream, s.connectionParameters) + s.streamsMap = newStreamsMap(s.newStream, s.perspective, s.connectionParameters) s.streamFramer = newStreamFramer(s.streamsMap, s.flowControlManager) } diff --git a/session_test.go b/session_test.go index fd511651..2e6004eb 100644 --- a/session_test.go +++ b/session_test.go @@ -159,7 +159,7 @@ var _ = Describe("Session", func() { func(protocol.ConnectionID) { closeCallbackCalled = true }, ) Expect(err).ToNot(HaveOccurred()) - Expect(clientSession.streamsMap.openStreams).To(HaveLen(1)) + Expect(clientSession.streamsMap.openStreams).To(HaveLen(1)) // Crypto stream }) diff --git a/stream_framer_test.go b/stream_framer_test.go index a8802acf..9226b2a8 100644 --- a/stream_framer_test.go +++ b/stream_framer_test.go @@ -31,7 +31,7 @@ var _ = Describe("Stream Framer", func() { stream1 = &stream{streamID: 10} stream2 = &stream{streamID: 11} - streamsMap = newStreamsMap(nil, &mockConnectionParametersManager{}) + streamsMap = newStreamsMap(nil, protocol.PerspectiveServer, &mockConnectionParametersManager{}) streamsMap.putStream(stream1) streamsMap.putStream(stream2) diff --git a/streams_map.go b/streams_map.go index b4606772..7666497c 100644 --- a/streams_map.go +++ b/streams_map.go @@ -13,6 +13,7 @@ import ( type streamsMap struct { mutex sync.RWMutex + perspective protocol.Perspective connectionParameters handshake.ConnectionParametersManager streams map[protocol.StreamID]*stream @@ -38,8 +39,9 @@ var ( errMapAccess = errors.New("streamsMap: Error accessing the streams map") ) -func newStreamsMap(newStream newStreamLambda, connectionParameters handshake.ConnectionParametersManager) *streamsMap { +func newStreamsMap(newStream newStreamLambda, pers protocol.Perspective, connectionParameters handshake.ConnectionParametersManager) *streamsMap { return &streamsMap{ + perspective: pers, streams: map[protocol.StreamID]*stream{}, openStreams: make([]protocol.StreamID, 0), newStream: newStream, @@ -68,9 +70,12 @@ func (m *streamsMap) GetOrOpenStream(id protocol.StreamID) (*stream, error) { if m.numIncomingStreams >= m.connectionParameters.GetMaxIncomingStreams() { return nil, qerr.TooManyOpenStreams } - if id%2 == 0 { + if m.perspective == protocol.PerspectiveServer && id%2 == 0 { return nil, qerr.Error(qerr.InvalidStreamID, fmt.Sprintf("attempted to open stream %d from client-side", id)) } + if m.perspective == protocol.PerspectiveClient && id%2 == 1 { + return nil, qerr.Error(qerr.InvalidStreamID, fmt.Sprintf("attempted to open stream %d from server-side", id)) + } if id+protocol.MaxNewStreamIDDelta < m.highestStreamOpenedByClient { return nil, qerr.Error(qerr.InvalidStreamID, fmt.Sprintf("attempted to open stream %d, which is a lot smaller than the highest opened stream, %d", id, m.highestStreamOpenedByClient)) } @@ -79,7 +84,12 @@ func (m *streamsMap) GetOrOpenStream(id protocol.StreamID) (*stream, error) { if err != nil { return nil, err } - m.numIncomingStreams++ + + if m.perspective == protocol.PerspectiveServer { + m.numIncomingStreams++ + } else { + m.numOutgoingStreams++ + } if id > m.highestStreamOpenedByClient { m.highestStreamOpenedByClient = id @@ -97,9 +107,12 @@ func (m *streamsMap) GetOrOpenStream(id protocol.StreamID) (*stream, error) { // OpenStream opens a stream from the server's side func (m *streamsMap) OpenStream(id protocol.StreamID) (*stream, error) { - if id%2 == 1 { + if m.perspective == protocol.PerspectiveServer && id%2 == 1 { return nil, qerr.Error(qerr.InvalidStreamID, fmt.Sprintf("attempted to open stream %d from server-side", id)) } + if m.perspective == protocol.PerspectiveClient && id%2 == 0 { + return nil, qerr.Error(qerr.InvalidStreamID, fmt.Sprintf("attempted to open stream %d from client-side", id)) + } m.mutex.Lock() defer m.mutex.Unlock() @@ -115,7 +128,12 @@ func (m *streamsMap) OpenStream(id protocol.StreamID) (*stream, error) { if err != nil { return nil, err } - m.numOutgoingStreams++ + + if m.perspective == protocol.PerspectiveServer { + m.numOutgoingStreams++ + } else { + m.numIncomingStreams++ + } m.putStream(s) return s, nil diff --git a/streams_map_test.go b/streams_map_test.go index 48ab611d..234f2b54 100644 --- a/streams_map_test.go +++ b/streams_map_test.go @@ -59,7 +59,7 @@ var _ = Describe("Streams Map", func() { maxIncomingStreams: 75, maxOutgoingStreams: 60, } - m = newStreamsMap(nil, cpm) + m = newStreamsMap(nil, protocol.PerspectiveServer, cpm) }) Context("getting and creating streams", func() { @@ -77,7 +77,7 @@ var _ = Describe("Streams Map", func() { Expect(m.numOutgoingStreams).To(BeZero()) }) - Context("client-side streams", func() { + Context("client-side streams, as a server", func() { It("rejects streams with even IDs", func() { _, err := m.GetOrOpenStream(6) Expect(err).To(MatchError("InvalidStreamID: attempted to open stream 6 from client-side")) @@ -129,7 +129,26 @@ var _ = Describe("Streams Map", func() { }) }) - Context("server-side streams", func() { + Context("client-side streams, as a client", func() { + BeforeEach(func() { + m.perspective = protocol.PerspectiveClient + }) + + It("rejects streams with odd IDs", func() { + _, err := m.GetOrOpenStream(5) + Expect(err).To(MatchError("InvalidStreamID: attempted to open stream 5 from server-side")) + }) + + It("gets new streams", func() { + s, err := m.GetOrOpenStream(6) + Expect(err).NotTo(HaveOccurred()) + Expect(s.StreamID()).To(Equal(protocol.StreamID(6))) + Expect(m.numOutgoingStreams).To(Equal(uint32(1))) + Expect(m.numIncomingStreams).To(BeZero()) + }) + }) + + Context("server-side streams, as a server", func() { It("rejects streams with odd IDs", func() { _, err := m.OpenStream(5) Expect(err).To(MatchError("InvalidStreamID: attempted to open stream 5 from server-side")) @@ -188,6 +207,26 @@ var _ = Describe("Streams Map", func() { }) }) + Context("server-side streams, as a client", func() { + BeforeEach(func() { + m.perspective = protocol.PerspectiveClient + }) + + It("rejects streams with even IDs", func() { + _, err := m.OpenStream(6) + Expect(err).To(MatchError("InvalidStreamID: attempted to open stream 6 from client-side")) + }) + + It("opens a new stream", func() { + s, err := m.OpenStream(7) + Expect(err).ToNot(HaveOccurred()) + Expect(s).ToNot(BeNil()) + Expect(s.StreamID()).To(Equal(protocol.StreamID(7))) + Expect(m.numOutgoingStreams).To(BeZero()) + Expect(m.numIncomingStreams).To(Equal(uint32(1))) + }) + }) + Context("DoS mitigation", func() { It("opens and closes a lot of streams", func() { for i := 1; i < 2*protocol.MaxNewStreamIDDelta; i += 2 {