Browse Source

!868 change the docstring format and restful api url

From: @yelihua
Reviewed-by: @ouwenchang,@wangyue01
Signed-off-by: @wangyue01
tags/v1.1.0
mindspore-ci-bot Gitee 5 years ago
parent
commit
194a16be46
9 changed files with 14 additions and 26 deletions
  1. +4
    -4
      mindinsight/backend/debugger/debugger_api.py
  2. +0
    -1
      mindinsight/debugger/debugger_server.py
  3. +0
    -1
      mindinsight/debugger/stream_cache/debugger_graph.py
  4. +3
    -1
      mindinsight/debugger/stream_cache/debugger_multigraph.py
  5. +0
    -1
      mindinsight/debugger/stream_cache/watchpoint.py
  6. +0
    -3
      mindinsight/debugger/stream_handler/tensor_handler.py
  7. +0
    -4
      mindinsight/debugger/stream_handler/watchpoint_handler.py
  8. +2
    -2
      tests/st/func/debugger/test_restful_api.py
  9. +5
    -9
      tests/ut/debugger/stream_handler/test_tensor_handler.py

+ 4
- 4
mindinsight/backend/debugger/debugger_api.py View File

@@ -312,7 +312,7 @@ def recheck():
return reply return reply




@BLUEPRINT.route("/debugger/tensor_graphs", methods=["GET"])
@BLUEPRINT.route("/debugger/tensor-graphs", methods=["GET"])
def retrieve_tensor_graph(): def retrieve_tensor_graph():
""" """
Retrieve tensor value according to name and shape. Retrieve tensor value according to name and shape.
@@ -321,7 +321,7 @@ def retrieve_tensor_graph():
str, the required data. str, the required data.


Examples: Examples:
>>> GET http://xxxx/v1/mindinsight/debugger/tensor_graphs?tensor_name=tensor_name$graph_name=graph_name
>>> GET http://xxxx/v1/mindinsight/debugger/tensor-graphs?tensor_name=tensor_name&graph_name=graph_name
""" """
tensor_name = request.args.get('tensor_name') tensor_name = request.args.get('tensor_name')
graph_name = request.args.get('graph_name') graph_name = request.args.get('graph_name')
@@ -329,7 +329,7 @@ def retrieve_tensor_graph():
return reply return reply




@BLUEPRINT.route("/debugger/tensor_hits", methods=["GET"])
@BLUEPRINT.route("/debugger/tensor-hits", methods=["GET"])
def retrieve_tensor_hits(): def retrieve_tensor_hits():
""" """
Retrieve tensor value according to name and shape. Retrieve tensor value according to name and shape.
@@ -338,7 +338,7 @@ def retrieve_tensor_hits():
str, the required data. str, the required data.


Examples: Examples:
>>> GET http://xxxx/v1/mindinsight/debugger/tensor_hits?tensor_name=tensor_name$graph_name=graph_name
>>> GET http://xxxx/v1/mindinsight/debugger/tensor-hits?tensor_name=tensor_name&graph_name=graph_name
""" """
tensor_name = request.args.get('tensor_name') tensor_name = request.args.get('tensor_name')
graph_name = request.args.get('graph_name') graph_name = request.args.get('graph_name')


+ 0
- 1
mindinsight/debugger/debugger_server.py View File

@@ -478,7 +478,6 @@ class DebuggerServer:
} }


- id (str): Id of condition. - id (str): Id of condition.

- params (list[dict]): The list of param for this condition. - params (list[dict]): The list of param for this condition.
watch_nodes (list[str]): The list of node names. watch_nodes (list[str]): The list of node names.
watch_point_id (int): The id of watchpoint. watch_point_id (int): The id of watchpoint.


+ 0
- 1
mindinsight/debugger/stream_cache/debugger_graph.py View File

@@ -123,7 +123,6 @@ class DebuggerGraph(MSGraph):


- activation_func (Union[str, list[str]): The target functions. Used when node_type - activation_func (Union[str, list[str]): The target functions. Used when node_type
is TargetTypeEnum.ACTIVATION. is TargetTypeEnum.ACTIVATION.

- search_range (list[Node]): The list of nodes to be searched from. - search_range (list[Node]): The list of nodes to be searched from.


Returns: Returns:


+ 3
- 1
mindinsight/debugger/stream_cache/debugger_multigraph.py View File

@@ -18,12 +18,14 @@ from mindinsight.debugger.common.log import LOGGER as log
from mindinsight.datavisual.data_transform.graph.node import Node, NodeTypeEnum from mindinsight.datavisual.data_transform.graph.node import Node, NodeTypeEnum
from .debugger_graph import DebuggerGraph from .debugger_graph import DebuggerGraph



class DebuggerMultiGraph(DebuggerGraph): class DebuggerMultiGraph(DebuggerGraph):
"""The `DebuggerMultiGraph` object provides interfaces to describe a debugger multigraph.""" """The `DebuggerMultiGraph` object provides interfaces to describe a debugger multigraph."""


def add_graph(self, graph_dict): def add_graph(self, graph_dict):
""" """
add graphs to DebuggerMultiGraph
Add graphs to DebuggerMultiGraph.

Args: Args:
graph_dict (dict): The <graph_name, graph_object> dict. graph_dict (dict): The <graph_name, graph_object> dict.
""" """


+ 0
- 1
mindinsight/debugger/stream_cache/watchpoint.py View File

@@ -178,7 +178,6 @@ class Watchpoint:
watch_condition (dict): The condition of Watchpoint. watch_condition (dict): The condition of Watchpoint.


- condition (str): Accept `INF` or `NAN`. - condition (str): Accept `INF` or `NAN`.

- param (list[float]): Not defined yet. - param (list[float]): Not defined yet.
""" """




+ 0
- 3
mindinsight/debugger/stream_handler/tensor_handler.py View File

@@ -50,7 +50,6 @@ class TensorHandler(StreamHandlerBase):
value (dict): The Tensor proto message. value (dict): The Tensor proto message.


- step (int): The current step of tensor. - step (int): The current step of tensor.

- tensor_protos (list[TensorProto]): The tensor proto. - tensor_protos (list[TensorProto]): The tensor proto.


Returns: Returns:
@@ -153,9 +152,7 @@ class TensorHandler(StreamHandlerBase):
filter_condition (dict): Filter condition. filter_condition (dict): Filter condition.


- name (str): The full name of tensor. - name (str): The full name of tensor.

- node_type (str): The type of the node. - node_type (str): The type of the node.

- prev (bool): Whether to get previous tensor. - prev (bool): Whether to get previous tensor.


Returns: Returns:


+ 0
- 4
mindinsight/debugger/stream_handler/watchpoint_handler.py View File

@@ -389,11 +389,8 @@ class WatchpointHitHandler(StreamHandlerBase):
value (dict): The watchpoint hit info. value (dict): The watchpoint hit info.


- tensor_proto (TensorProto): The message about hit tensor. - tensor_proto (TensorProto): The message about hit tensor.

- watchpoint (Watchpoint): The Watchpoint that a node hit. - watchpoint (Watchpoint): The Watchpoint that a node hit.

- node_name (str): The UI node name. - node_name (str): The UI node name.

- graph_name (str): The graph name. - graph_name (str): The graph name.
""" """
watchpoint_hit = WatchpointHit( watchpoint_hit = WatchpointHit(
@@ -561,7 +558,6 @@ def validate_watch_condition_params(condition_mgr, watch_condition):
watch_condition (dict): Watch condition. watch_condition (dict): Watch condition.


- id (str): Condition id. Should be in WATCHPOINT_CONDITION_MAPPING. - id (str): Condition id. Should be in WATCHPOINT_CONDITION_MAPPING.

- param (list): Condition value. Should be given for comparison condition. The value - param (list): Condition value. Should be given for comparison condition. The value
will be translated to np.float32. will be translated to np.float32.
""" """


+ 2
- 2
tests/st/func/debugger/test_restful_api.py View File

@@ -379,7 +379,7 @@ class TestAscendDebugger:
]) ])
def test_retrieve_tensor_graph(self, app_client, body_data, expect_file): def test_retrieve_tensor_graph(self, app_client, body_data, expect_file):
"""Test retrieve tensor graph.""" """Test retrieve tensor graph."""
url = 'tensor_graphs'
url = 'tensor-graphs'
with self._debugger_client.get_thread_instance(): with self._debugger_client.get_thread_instance():
create_watchpoint_and_wait(app_client) create_watchpoint_and_wait(app_client)
send_and_compare_result(app_client, url, body_data, expect_file, method='GET') send_and_compare_result(app_client, url, body_data, expect_file, method='GET')
@@ -637,7 +637,7 @@ class TestMultiGraphDebugger:
]) ])
def test_retrieve_tensor_hits(self, app_client, body_data, expect_file): def test_retrieve_tensor_hits(self, app_client, body_data, expect_file):
"""Test retrieve tensor graph.""" """Test retrieve tensor graph."""
url = 'tensor_hits'
url = 'tensor-hits'
with self._debugger_client.get_thread_instance(): with self._debugger_client.get_thread_instance():
check_waiting_state(app_client) check_waiting_state(app_client)
send_and_compare_result(app_client, url, body_data, expect_file, method='GET') send_and_compare_result(app_client, url, body_data, expect_file, method='GET')


+ 5
- 9
tests/ut/debugger/stream_handler/test_tensor_handler.py View File

@@ -25,19 +25,17 @@ class TestTensorHandler:
"""Test TensorHandler.""" """Test TensorHandler."""


def setup_method(self): def setup_method(self):
"""Setup method for each test case."""
self.tensor_handler = TensorHandler() self.tensor_handler = TensorHandler()


@mock.patch.object(TensorHandler, '_get_tensor') @mock.patch.object(TensorHandler, '_get_tensor')
@mock.patch.object(log, "error") @mock.patch.object(log, "error")
@pytest.mark.parametrize("filter_condition", {})
def test_get(self, mock_get_tensor, mock_error, filter_condition):
"""
Test get full tensor value.
"""
def test_get(self, mock_get_tensor, mock_error):
"""Test get full tensor value."""
mock_get_tensor.return_value = None mock_get_tensor.return_value = None
mock_error.return_value = None mock_error.return_value = None
with pytest.raises(DebuggerParamValueError) as ex: with pytest.raises(DebuggerParamValueError) as ex:
self.tensor_handler.get(filter_condition)
self.tensor_handler.get({})
assert "No tensor named {}".format(None) in str(ex.value) assert "No tensor named {}".format(None) in str(ex.value)


def test_get_tensor_value_by_name_none(self): def test_get_tensor_value_by_name_none(self):
@@ -48,9 +46,7 @@ class TestTensorHandler:
@mock.patch.object(log, "error") @mock.patch.object(log, "error")
@pytest.mark.parametrize("tensor_name", "name") @pytest.mark.parametrize("tensor_name", "name")
def test_get_tensors_diff_error(self, mock_error, tensor_name): def test_get_tensors_diff_error(self, mock_error, tensor_name):
"""
Test get_tensors_diff.
"""
"""Test get_tensors_diff."""
mock_error.return_value = None mock_error.return_value = None
with pytest.raises(DebuggerParamValueError) as ex: with pytest.raises(DebuggerParamValueError) as ex:
self.tensor_handler.get_tensors_diff(tensor_name, {1, 1}) self.tensor_handler.get_tensors_diff(tensor_name, {1, 1})


Loading…
Cancel
Save