From: @yelihua Reviewed-by: @ouwenchang,@wangyue01 Signed-off-by: @wangyue01tags/v1.1.0
| @@ -312,7 +312,7 @@ def recheck(): | |||||
| return reply | return reply | ||||
| @BLUEPRINT.route("/debugger/tensor_graphs", methods=["GET"]) | |||||
| @BLUEPRINT.route("/debugger/tensor-graphs", methods=["GET"]) | |||||
| def retrieve_tensor_graph(): | def retrieve_tensor_graph(): | ||||
| """ | """ | ||||
| Retrieve tensor value according to name and shape. | Retrieve tensor value according to name and shape. | ||||
| @@ -321,7 +321,7 @@ def retrieve_tensor_graph(): | |||||
| str, the required data. | str, the required data. | ||||
| Examples: | Examples: | ||||
| >>> GET http://xxxx/v1/mindinsight/debugger/tensor_graphs?tensor_name=tensor_name$graph_name=graph_name | |||||
| >>> GET http://xxxx/v1/mindinsight/debugger/tensor-graphs?tensor_name=tensor_name&graph_name=graph_name | |||||
| """ | """ | ||||
| tensor_name = request.args.get('tensor_name') | tensor_name = request.args.get('tensor_name') | ||||
| graph_name = request.args.get('graph_name') | graph_name = request.args.get('graph_name') | ||||
| @@ -329,7 +329,7 @@ def retrieve_tensor_graph(): | |||||
| return reply | return reply | ||||
| @BLUEPRINT.route("/debugger/tensor_hits", methods=["GET"]) | |||||
| @BLUEPRINT.route("/debugger/tensor-hits", methods=["GET"]) | |||||
| def retrieve_tensor_hits(): | def retrieve_tensor_hits(): | ||||
| """ | """ | ||||
| Retrieve tensor value according to name and shape. | Retrieve tensor value according to name and shape. | ||||
| @@ -338,7 +338,7 @@ def retrieve_tensor_hits(): | |||||
| str, the required data. | str, the required data. | ||||
| Examples: | Examples: | ||||
| >>> GET http://xxxx/v1/mindinsight/debugger/tensor_hits?tensor_name=tensor_name$graph_name=graph_name | |||||
| >>> GET http://xxxx/v1/mindinsight/debugger/tensor-hits?tensor_name=tensor_name&graph_name=graph_name | |||||
| """ | """ | ||||
| tensor_name = request.args.get('tensor_name') | tensor_name = request.args.get('tensor_name') | ||||
| graph_name = request.args.get('graph_name') | graph_name = request.args.get('graph_name') | ||||
| @@ -478,7 +478,6 @@ class DebuggerServer: | |||||
| } | } | ||||
| - id (str): Id of condition. | - id (str): Id of condition. | ||||
| - params (list[dict]): The list of param for this condition. | - params (list[dict]): The list of param for this condition. | ||||
| watch_nodes (list[str]): The list of node names. | watch_nodes (list[str]): The list of node names. | ||||
| watch_point_id (int): The id of watchpoint. | watch_point_id (int): The id of watchpoint. | ||||
| @@ -123,7 +123,6 @@ class DebuggerGraph(MSGraph): | |||||
| - activation_func (Union[str, list[str]): The target functions. Used when node_type | - activation_func (Union[str, list[str]): The target functions. Used when node_type | ||||
| is TargetTypeEnum.ACTIVATION. | is TargetTypeEnum.ACTIVATION. | ||||
| - search_range (list[Node]): The list of nodes to be searched from. | - search_range (list[Node]): The list of nodes to be searched from. | ||||
| Returns: | Returns: | ||||
| @@ -18,12 +18,14 @@ from mindinsight.debugger.common.log import LOGGER as log | |||||
| from mindinsight.datavisual.data_transform.graph.node import Node, NodeTypeEnum | from mindinsight.datavisual.data_transform.graph.node import Node, NodeTypeEnum | ||||
| from .debugger_graph import DebuggerGraph | from .debugger_graph import DebuggerGraph | ||||
| class DebuggerMultiGraph(DebuggerGraph): | class DebuggerMultiGraph(DebuggerGraph): | ||||
| """The `DebuggerMultiGraph` object provides interfaces to describe a debugger multigraph.""" | """The `DebuggerMultiGraph` object provides interfaces to describe a debugger multigraph.""" | ||||
| def add_graph(self, graph_dict): | def add_graph(self, graph_dict): | ||||
| """ | """ | ||||
| add graphs to DebuggerMultiGraph | |||||
| Add graphs to DebuggerMultiGraph. | |||||
| Args: | Args: | ||||
| graph_dict (dict): The <graph_name, graph_object> dict. | graph_dict (dict): The <graph_name, graph_object> dict. | ||||
| """ | """ | ||||
| @@ -178,7 +178,6 @@ class Watchpoint: | |||||
| watch_condition (dict): The condition of Watchpoint. | watch_condition (dict): The condition of Watchpoint. | ||||
| - condition (str): Accept `INF` or `NAN`. | - condition (str): Accept `INF` or `NAN`. | ||||
| - param (list[float]): Not defined yet. | - param (list[float]): Not defined yet. | ||||
| """ | """ | ||||
| @@ -50,7 +50,6 @@ class TensorHandler(StreamHandlerBase): | |||||
| value (dict): The Tensor proto message. | value (dict): The Tensor proto message. | ||||
| - step (int): The current step of tensor. | - step (int): The current step of tensor. | ||||
| - tensor_protos (list[TensorProto]): The tensor proto. | - tensor_protos (list[TensorProto]): The tensor proto. | ||||
| Returns: | Returns: | ||||
| @@ -153,9 +152,7 @@ class TensorHandler(StreamHandlerBase): | |||||
| filter_condition (dict): Filter condition. | filter_condition (dict): Filter condition. | ||||
| - name (str): The full name of tensor. | - name (str): The full name of tensor. | ||||
| - node_type (str): The type of the node. | - node_type (str): The type of the node. | ||||
| - prev (bool): Whether to get previous tensor. | - prev (bool): Whether to get previous tensor. | ||||
| Returns: | Returns: | ||||
| @@ -389,11 +389,8 @@ class WatchpointHitHandler(StreamHandlerBase): | |||||
| value (dict): The watchpoint hit info. | value (dict): The watchpoint hit info. | ||||
| - tensor_proto (TensorProto): The message about hit tensor. | - tensor_proto (TensorProto): The message about hit tensor. | ||||
| - watchpoint (Watchpoint): The Watchpoint that a node hit. | - watchpoint (Watchpoint): The Watchpoint that a node hit. | ||||
| - node_name (str): The UI node name. | - node_name (str): The UI node name. | ||||
| - graph_name (str): The graph name. | - graph_name (str): The graph name. | ||||
| """ | """ | ||||
| watchpoint_hit = WatchpointHit( | watchpoint_hit = WatchpointHit( | ||||
| @@ -561,7 +558,6 @@ def validate_watch_condition_params(condition_mgr, watch_condition): | |||||
| watch_condition (dict): Watch condition. | watch_condition (dict): Watch condition. | ||||
| - id (str): Condition id. Should be in WATCHPOINT_CONDITION_MAPPING. | - id (str): Condition id. Should be in WATCHPOINT_CONDITION_MAPPING. | ||||
| - param (list): Condition value. Should be given for comparison condition. The value | - param (list): Condition value. Should be given for comparison condition. The value | ||||
| will be translated to np.float32. | will be translated to np.float32. | ||||
| """ | """ | ||||
| @@ -379,7 +379,7 @@ class TestAscendDebugger: | |||||
| ]) | ]) | ||||
| def test_retrieve_tensor_graph(self, app_client, body_data, expect_file): | def test_retrieve_tensor_graph(self, app_client, body_data, expect_file): | ||||
| """Test retrieve tensor graph.""" | """Test retrieve tensor graph.""" | ||||
| url = 'tensor_graphs' | |||||
| url = 'tensor-graphs' | |||||
| with self._debugger_client.get_thread_instance(): | with self._debugger_client.get_thread_instance(): | ||||
| create_watchpoint_and_wait(app_client) | create_watchpoint_and_wait(app_client) | ||||
| send_and_compare_result(app_client, url, body_data, expect_file, method='GET') | send_and_compare_result(app_client, url, body_data, expect_file, method='GET') | ||||
| @@ -637,7 +637,7 @@ class TestMultiGraphDebugger: | |||||
| ]) | ]) | ||||
| def test_retrieve_tensor_hits(self, app_client, body_data, expect_file): | def test_retrieve_tensor_hits(self, app_client, body_data, expect_file): | ||||
| """Test retrieve tensor graph.""" | """Test retrieve tensor graph.""" | ||||
| url = 'tensor_hits' | |||||
| url = 'tensor-hits' | |||||
| with self._debugger_client.get_thread_instance(): | with self._debugger_client.get_thread_instance(): | ||||
| check_waiting_state(app_client) | check_waiting_state(app_client) | ||||
| send_and_compare_result(app_client, url, body_data, expect_file, method='GET') | send_and_compare_result(app_client, url, body_data, expect_file, method='GET') | ||||
| @@ -25,19 +25,17 @@ class TestTensorHandler: | |||||
| """Test TensorHandler.""" | """Test TensorHandler.""" | ||||
| def setup_method(self): | def setup_method(self): | ||||
| """Setup method for each test case.""" | |||||
| self.tensor_handler = TensorHandler() | self.tensor_handler = TensorHandler() | ||||
| @mock.patch.object(TensorHandler, '_get_tensor') | @mock.patch.object(TensorHandler, '_get_tensor') | ||||
| @mock.patch.object(log, "error") | @mock.patch.object(log, "error") | ||||
| @pytest.mark.parametrize("filter_condition", {}) | |||||
| def test_get(self, mock_get_tensor, mock_error, filter_condition): | |||||
| """ | |||||
| Test get full tensor value. | |||||
| """ | |||||
| def test_get(self, mock_get_tensor, mock_error): | |||||
| """Test get full tensor value.""" | |||||
| mock_get_tensor.return_value = None | mock_get_tensor.return_value = None | ||||
| mock_error.return_value = None | mock_error.return_value = None | ||||
| with pytest.raises(DebuggerParamValueError) as ex: | with pytest.raises(DebuggerParamValueError) as ex: | ||||
| self.tensor_handler.get(filter_condition) | |||||
| self.tensor_handler.get({}) | |||||
| assert "No tensor named {}".format(None) in str(ex.value) | assert "No tensor named {}".format(None) in str(ex.value) | ||||
| def test_get_tensor_value_by_name_none(self): | def test_get_tensor_value_by_name_none(self): | ||||
| @@ -48,9 +46,7 @@ class TestTensorHandler: | |||||
| @mock.patch.object(log, "error") | @mock.patch.object(log, "error") | ||||
| @pytest.mark.parametrize("tensor_name", "name") | @pytest.mark.parametrize("tensor_name", "name") | ||||
| def test_get_tensors_diff_error(self, mock_error, tensor_name): | def test_get_tensors_diff_error(self, mock_error, tensor_name): | ||||
| """ | |||||
| Test get_tensors_diff. | |||||
| """ | |||||
| """Test get_tensors_diff.""" | |||||
| mock_error.return_value = None | mock_error.return_value = None | ||||
| with pytest.raises(DebuggerParamValueError) as ex: | with pytest.raises(DebuggerParamValueError) as ex: | ||||
| self.tensor_handler.get_tensors_diff(tensor_name, {1, 1}) | self.tensor_handler.get_tensors_diff(tensor_name, {1, 1}) | ||||