Browse Source

!868 change the docstring format and restful api url

From: @yelihua
Reviewed-by: @ouwenchang,@wangyue01
Signed-off-by: @wangyue01
tags/v1.1.0
mindspore-ci-bot Gitee 5 years ago
parent
commit
194a16be46
9 changed files with 14 additions and 26 deletions
  1. +4
    -4
      mindinsight/backend/debugger/debugger_api.py
  2. +0
    -1
      mindinsight/debugger/debugger_server.py
  3. +0
    -1
      mindinsight/debugger/stream_cache/debugger_graph.py
  4. +3
    -1
      mindinsight/debugger/stream_cache/debugger_multigraph.py
  5. +0
    -1
      mindinsight/debugger/stream_cache/watchpoint.py
  6. +0
    -3
      mindinsight/debugger/stream_handler/tensor_handler.py
  7. +0
    -4
      mindinsight/debugger/stream_handler/watchpoint_handler.py
  8. +2
    -2
      tests/st/func/debugger/test_restful_api.py
  9. +5
    -9
      tests/ut/debugger/stream_handler/test_tensor_handler.py

+ 4
- 4
mindinsight/backend/debugger/debugger_api.py View File

@@ -312,7 +312,7 @@ def recheck():
return reply


@BLUEPRINT.route("/debugger/tensor_graphs", methods=["GET"])
@BLUEPRINT.route("/debugger/tensor-graphs", methods=["GET"])
def retrieve_tensor_graph():
"""
Retrieve tensor value according to name and shape.
@@ -321,7 +321,7 @@ def retrieve_tensor_graph():
str, the required data.

Examples:
>>> GET http://xxxx/v1/mindinsight/debugger/tensor_graphs?tensor_name=tensor_name$graph_name=graph_name
>>> GET http://xxxx/v1/mindinsight/debugger/tensor-graphs?tensor_name=tensor_name&graph_name=graph_name
"""
tensor_name = request.args.get('tensor_name')
graph_name = request.args.get('graph_name')
@@ -329,7 +329,7 @@ def retrieve_tensor_graph():
return reply


@BLUEPRINT.route("/debugger/tensor_hits", methods=["GET"])
@BLUEPRINT.route("/debugger/tensor-hits", methods=["GET"])
def retrieve_tensor_hits():
"""
Retrieve tensor value according to name and shape.
@@ -338,7 +338,7 @@ def retrieve_tensor_hits():
str, the required data.

Examples:
>>> GET http://xxxx/v1/mindinsight/debugger/tensor_hits?tensor_name=tensor_name$graph_name=graph_name
>>> GET http://xxxx/v1/mindinsight/debugger/tensor-hits?tensor_name=tensor_name&graph_name=graph_name
"""
tensor_name = request.args.get('tensor_name')
graph_name = request.args.get('graph_name')


+ 0
- 1
mindinsight/debugger/debugger_server.py View File

@@ -478,7 +478,6 @@ class DebuggerServer:
}

- id (str): Id of condition.

- params (list[dict]): The list of param for this condition.
watch_nodes (list[str]): The list of node names.
watch_point_id (int): The id of watchpoint.


+ 0
- 1
mindinsight/debugger/stream_cache/debugger_graph.py View File

@@ -123,7 +123,6 @@ class DebuggerGraph(MSGraph):

- activation_func (Union[str, list[str]): The target functions. Used when node_type
is TargetTypeEnum.ACTIVATION.

- search_range (list[Node]): The list of nodes to be searched from.

Returns:


+ 3
- 1
mindinsight/debugger/stream_cache/debugger_multigraph.py View File

@@ -18,12 +18,14 @@ from mindinsight.debugger.common.log import LOGGER as log
from mindinsight.datavisual.data_transform.graph.node import Node, NodeTypeEnum
from .debugger_graph import DebuggerGraph


class DebuggerMultiGraph(DebuggerGraph):
"""The `DebuggerMultiGraph` object provides interfaces to describe a debugger multigraph."""

def add_graph(self, graph_dict):
"""
add graphs to DebuggerMultiGraph
Add graphs to DebuggerMultiGraph.

Args:
graph_dict (dict): The <graph_name, graph_object> dict.
"""


+ 0
- 1
mindinsight/debugger/stream_cache/watchpoint.py View File

@@ -178,7 +178,6 @@ class Watchpoint:
watch_condition (dict): The condition of Watchpoint.

- condition (str): Accept `INF` or `NAN`.

- param (list[float]): Not defined yet.
"""



+ 0
- 3
mindinsight/debugger/stream_handler/tensor_handler.py View File

@@ -50,7 +50,6 @@ class TensorHandler(StreamHandlerBase):
value (dict): The Tensor proto message.

- step (int): The current step of tensor.

- tensor_protos (list[TensorProto]): The tensor proto.

Returns:
@@ -153,9 +152,7 @@ class TensorHandler(StreamHandlerBase):
filter_condition (dict): Filter condition.

- name (str): The full name of tensor.

- node_type (str): The type of the node.

- prev (bool): Whether to get previous tensor.

Returns:


+ 0
- 4
mindinsight/debugger/stream_handler/watchpoint_handler.py View File

@@ -389,11 +389,8 @@ class WatchpointHitHandler(StreamHandlerBase):
value (dict): The watchpoint hit info.

- tensor_proto (TensorProto): The message about hit tensor.

- watchpoint (Watchpoint): The Watchpoint that a node hit.

- node_name (str): The UI node name.

- graph_name (str): The graph name.
"""
watchpoint_hit = WatchpointHit(
@@ -561,7 +558,6 @@ def validate_watch_condition_params(condition_mgr, watch_condition):
watch_condition (dict): Watch condition.

- id (str): Condition id. Should be in WATCHPOINT_CONDITION_MAPPING.

- param (list): Condition value. Should be given for comparison condition. The value
will be translated to np.float32.
"""


+ 2
- 2
tests/st/func/debugger/test_restful_api.py View File

@@ -379,7 +379,7 @@ class TestAscendDebugger:
])
def test_retrieve_tensor_graph(self, app_client, body_data, expect_file):
"""Test retrieve tensor graph."""
url = 'tensor_graphs'
url = 'tensor-graphs'
with self._debugger_client.get_thread_instance():
create_watchpoint_and_wait(app_client)
send_and_compare_result(app_client, url, body_data, expect_file, method='GET')
@@ -637,7 +637,7 @@ class TestMultiGraphDebugger:
])
def test_retrieve_tensor_hits(self, app_client, body_data, expect_file):
"""Test retrieve tensor graph."""
url = 'tensor_hits'
url = 'tensor-hits'
with self._debugger_client.get_thread_instance():
check_waiting_state(app_client)
send_and_compare_result(app_client, url, body_data, expect_file, method='GET')


+ 5
- 9
tests/ut/debugger/stream_handler/test_tensor_handler.py View File

@@ -25,19 +25,17 @@ class TestTensorHandler:
"""Test TensorHandler."""

def setup_method(self):
"""Setup method for each test case."""
self.tensor_handler = TensorHandler()

@mock.patch.object(TensorHandler, '_get_tensor')
@mock.patch.object(log, "error")
@pytest.mark.parametrize("filter_condition", {})
def test_get(self, mock_get_tensor, mock_error, filter_condition):
"""
Test get full tensor value.
"""
def test_get(self, mock_get_tensor, mock_error):
"""Test get full tensor value."""
mock_get_tensor.return_value = None
mock_error.return_value = None
with pytest.raises(DebuggerParamValueError) as ex:
self.tensor_handler.get(filter_condition)
self.tensor_handler.get({})
assert "No tensor named {}".format(None) in str(ex.value)

def test_get_tensor_value_by_name_none(self):
@@ -48,9 +46,7 @@ class TestTensorHandler:
@mock.patch.object(log, "error")
@pytest.mark.parametrize("tensor_name", "name")
def test_get_tensors_diff_error(self, mock_error, tensor_name):
"""
Test get_tensors_diff.
"""
"""Test get_tensors_diff."""
mock_error.return_value = None
with pytest.raises(DebuggerParamValueError) as ex:
self.tensor_handler.get_tensors_diff(tensor_name, {1, 1})


Loading…
Cancel
Save