| @@ -78,7 +78,8 @@ class DebuggerCompareTensorError(MindInsightException): | |||||
| def __init__(self, msg): | def __init__(self, msg): | ||||
| super(DebuggerCompareTensorError, self).__init__( | super(DebuggerCompareTensorError, self).__init__( | ||||
| error=DebuggerErrors.COMPARE_TENSOR_ERROR, | error=DebuggerErrors.COMPARE_TENSOR_ERROR, | ||||
| message=DebuggerErrorMsg.COMPARE_TENSOR_ERROR.value.format(msg) | |||||
| message=msg, | |||||
| http_code=400 | |||||
| ) | ) | ||||
| @@ -111,7 +112,8 @@ class DebuggerNodeNotInGraphError(MindInsightException): | |||||
| err_msg = f"Cannot find the node in graph by the given name. node name: {node_name}." | err_msg = f"Cannot find the node in graph by the given name. node name: {node_name}." | ||||
| super(DebuggerNodeNotInGraphError, self).__init__( | super(DebuggerNodeNotInGraphError, self).__init__( | ||||
| error=DebuggerErrors.NODE_NOT_IN_GRAPH_ERROR, | error=DebuggerErrors.NODE_NOT_IN_GRAPH_ERROR, | ||||
| message=err_msg | |||||
| message=err_msg, | |||||
| http_code=400 | |||||
| ) | ) | ||||
| @@ -120,5 +122,6 @@ class DebuggerGraphNotExistError(MindInsightException): | |||||
| def __init__(self): | def __init__(self): | ||||
| super(DebuggerGraphNotExistError, self).__init__( | super(DebuggerGraphNotExistError, self).__init__( | ||||
| error=DebuggerErrors.GRAPH_NOT_EXIST_ERROR, | error=DebuggerErrors.GRAPH_NOT_EXIST_ERROR, | ||||
| message=DebuggerErrorMsg.GRAPH_NOT_EXIST_ERROR.value | |||||
| message=DebuggerErrorMsg.GRAPH_NOT_EXIST_ERROR.value, | |||||
| http_code=400 | |||||
| ) | ) | ||||
| @@ -24,7 +24,8 @@ from mindinsight.datavisual.data_transform.graph import NodeTypeEnum | |||||
| from mindinsight.datavisual.utils.tools import to_float | from mindinsight.datavisual.utils.tools import to_float | ||||
| from mindinsight.debugger.common.exceptions.exceptions import DebuggerParamValueError, \ | from mindinsight.debugger.common.exceptions.exceptions import DebuggerParamValueError, \ | ||||
| DebuggerParamTypeError, DebuggerCreateWatchPointError, DebuggerUpdateWatchPointError, \ | DebuggerParamTypeError, DebuggerCreateWatchPointError, DebuggerUpdateWatchPointError, \ | ||||
| DebuggerDeleteWatchPointError, DebuggerContinueError, DebuggerPauseError, DebuggerCompareTensorError | |||||
| DebuggerDeleteWatchPointError, DebuggerContinueError, DebuggerPauseError, \ | |||||
| DebuggerCompareTensorError | |||||
| from mindinsight.debugger.common.log import logger as log | from mindinsight.debugger.common.log import logger as log | ||||
| from mindinsight.debugger.common.utils import get_ack_reply, ServerStatus, \ | from mindinsight.debugger.common.utils import get_ack_reply, ServerStatus, \ | ||||
| create_view_event_from_tensor_history, Streams, is_scope_type, NodeBasicInfo | create_view_event_from_tensor_history, Streams, is_scope_type, NodeBasicInfo | ||||
| @@ -146,13 +147,10 @@ class DebuggerServer: | |||||
| node_type, tensor_name = self._get_tensor_name_and_type_by_ui_name(name) | node_type, tensor_name = self._get_tensor_name_and_type_by_ui_name(name) | ||||
| tolerance = to_float(tolerance, 'tolerance') | tolerance = to_float(tolerance, 'tolerance') | ||||
| tensor_stream = self.cache_store.get_stream_handler(Streams.TENSOR) | tensor_stream = self.cache_store.get_stream_handler(Streams.TENSOR) | ||||
| if detail == 'data': | |||||
| if node_type == NodeTypeEnum.PARAMETER.value: | |||||
| reply = tensor_stream.get_tensors_diff(tensor_name, parsed_shape, tolerance) | |||||
| else: | |||||
| raise DebuggerParamValueError("The node type must be parameter, but got {}.".format(node_type)) | |||||
| if node_type == NodeTypeEnum.PARAMETER.value: | |||||
| reply = tensor_stream.get_tensors_diff(tensor_name, parsed_shape, tolerance) | |||||
| else: | else: | ||||
| raise DebuggerParamValueError("The value of detail: {} is not support.".format(detail)) | |||||
| raise DebuggerParamValueError("The node type must be parameter, but got {}.".format(node_type)) | |||||
| return reply | return reply | ||||
| def retrieve(self, mode, filter_condition=None): | def retrieve(self, mode, filter_condition=None): | ||||
| @@ -177,8 +175,8 @@ class DebuggerServer: | |||||
| # validate param <mode> | # validate param <mode> | ||||
| if mode not in mode_mapping.keys(): | if mode not in mode_mapping.keys(): | ||||
| log.error("Invalid param <mode>. <mode> should be in ['all', 'node', 'watchpoint', " | log.error("Invalid param <mode>. <mode> should be in ['all', 'node', 'watchpoint', " | ||||
| "'watchpoint_hit', 'tensor'], but got %s.", mode_mapping) | |||||
| raise DebuggerParamTypeError("Invalid mode.") | |||||
| "'watchpoint_hit'], but got %s.", mode_mapping) | |||||
| raise DebuggerParamValueError("Invalid mode.") | |||||
| # validate backend status | # validate backend status | ||||
| metadata_stream = self.cache_store.get_stream_handler(Streams.METADATA) | metadata_stream = self.cache_store.get_stream_handler(Streams.METADATA) | ||||
| if metadata_stream.state == ServerStatus.PENDING.value: | if metadata_stream.state == ServerStatus.PENDING.value: | ||||
| @@ -12,4 +12,9 @@ | |||||
| # See the License for the specific language governing permissions and | # See the License for the specific language governing permissions and | ||||
| # limitations under the License. | # limitations under the License. | ||||
| # ============================================================================ | # ============================================================================ | ||||
| """Test for debugger module.""" | |||||
| """ | |||||
| Function: | |||||
| Unit test for debugger module. | |||||
| Usage: | |||||
| pytest tests/ut/debugger/ | |||||
| """ | |||||
| @@ -23,10 +23,12 @@ from mindinsight.debugger.common.utils import NodeBasicInfo | |||||
| from mindinsight.debugger.proto import ms_graph_pb2 | from mindinsight.debugger.proto import ms_graph_pb2 | ||||
| from mindinsight.debugger.stream_handler.graph_handler import GraphHandler | from mindinsight.debugger.stream_handler.graph_handler import GraphHandler | ||||
| from mindinsight.debugger.stream_handler.watchpoint_handler import WatchpointHitHandler | from mindinsight.debugger.stream_handler.watchpoint_handler import WatchpointHitHandler | ||||
| from tests.utils.tools import compare_result_with_file | |||||
| GRAPH_PROTO_FILE = os.path.join( | GRAPH_PROTO_FILE = os.path.join( | ||||
| os.path.dirname(__file__), '../../utils/resource/graph_pb/lenet.pb' | os.path.dirname(__file__), '../../utils/resource/graph_pb/lenet.pb' | ||||
| ) | ) | ||||
| DEBUGGER_EXPECTED_RESULTS = os.path.join(os.path.dirname(__file__), 'expected_results') | |||||
| def get_graph_proto(): | def get_graph_proto(): | ||||
| @@ -137,3 +139,15 @@ def mock_tensor_history(): | |||||
| } | } | ||||
| return tensor_history | return tensor_history | ||||
| def compare_debugger_result_with_file(res, expect_file): | |||||
| """ | |||||
| Compare debugger result with file. | |||||
| Args: | |||||
| res (dict): The debugger result in dict type. | |||||
| expect_file: The expected file name. | |||||
| """ | |||||
| real_path = os.path.join(DEBUGGER_EXPECTED_RESULTS, expect_file) | |||||
| compare_result_with_file(res, real_path) | |||||
| @@ -0,0 +1 @@ | |||||
| {"metadata": {"state": "waiting", "step": 0, "device_name": "", "pos": "0", "ip": "", "node_name": "", "backend": ""}, "graph": {}, "watch_points": []} | |||||
| @@ -0,0 +1 @@ | |||||
| {"tensor_history": [{"name": "Default/TransData-op99:0", "full_name": "Default/TransData-op99:0", "node_type": "TransData", "type": "output", "step": 0, "dtype": "DT_FLOAT32", "shape": [2, 3], "has_prev_step": false, "value": "click to view"}, {"name": "Default/args0:0", "full_name": "Default/args0:0", "node_type": "Parameter", "type": "input", "step": 0, "dtype": "DT_FLOAT32", "shape": [2, 3], "has_prev_step": false, "value": "click to view"}], "metadata": {"state": "waiting", "step": 0, "device_name": "", "pos": "0", "ip": "", "node_name": "", "backend": ""}} | |||||
| @@ -12,7 +12,12 @@ | |||||
| # See the License for the specific language governing permissions and | # See the License for the specific language governing permissions and | ||||
| # limitations under the License. | # limitations under the License. | ||||
| # ============================================================================ | # ============================================================================ | ||||
| """Test WatchpointHandler.""" | |||||
| """ | |||||
| Function: | |||||
| Test query debugger watchpoint handler. | |||||
| Usage: | |||||
| pytest tests/ut/debugger | |||||
| """ | |||||
| import json | import json | ||||
| import os | import os | ||||
| from unittest import mock, TestCase | from unittest import mock, TestCase | ||||
| @@ -0,0 +1,237 @@ | |||||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================ | |||||
| """ | |||||
| Function: | |||||
| Test debugger grpc server. | |||||
| Usage: | |||||
| pytest tests/ut/debugger/test_debugger_grpc_server.py | |||||
| """ | |||||
| from unittest import mock | |||||
| from unittest.mock import MagicMock | |||||
| import numpy as np | |||||
| from mindinsight.debugger.common.utils import get_ack_reply, ServerStatus | |||||
| from mindinsight.debugger.debugger_cache import DebuggerCache | |||||
| from mindinsight.debugger.debugger_grpc_server import DebuggerGrpcServer | |||||
| from mindinsight.debugger.proto.debug_grpc_pb2 import EventReply, SetCMD, Chunk, WatchpointHit | |||||
| from mindinsight.debugger.proto.ms_graph_pb2 import TensorProto, DataType | |||||
| from mindinsight.debugger.stream_handler import WatchpointHitHandler, GraphHandler, \ | |||||
| WatchpointHandler | |||||
| from tests.ut.debugger.configurations import GRAPH_PROTO_FILE | |||||
| class MockDataGenerator: | |||||
| """Mocked Data generator.""" | |||||
| @staticmethod | |||||
| def get_run_cmd(steps=0, level='step', node_name=''): | |||||
| """Get run command.""" | |||||
| event = get_ack_reply() | |||||
| event.run_cmd.run_level = level | |||||
| if level == 'node': | |||||
| event.run_cmd.node_name = node_name | |||||
| else: | |||||
| event.run_cmd.run_steps = steps | |||||
| return event | |||||
| @staticmethod | |||||
| def get_exit_cmd(): | |||||
| """Get exit command.""" | |||||
| event = get_ack_reply() | |||||
| event.exit = True | |||||
| return event | |||||
| @staticmethod | |||||
| def get_set_cmd(): | |||||
| """Get set command""" | |||||
| event = get_ack_reply() | |||||
| event.set_cmd.CopyFrom(SetCMD(id=1, watch_condition=1)) | |||||
| return event | |||||
| @staticmethod | |||||
| def get_view_cmd(): | |||||
| """Get set command""" | |||||
| view_event = get_ack_reply() | |||||
| ms_tensor = view_event.view_cmd.tensors.add() | |||||
| ms_tensor.node_name, ms_tensor.slot = 'mock_node_name', '0' | |||||
| event = {'view_cmd': view_event, 'node_name': 'mock_node_name'} | |||||
| return event | |||||
| @staticmethod | |||||
| def get_graph_chunks(): | |||||
| """Get graph chunks.""" | |||||
| chunk_size = 1024 | |||||
| with open(GRAPH_PROTO_FILE, 'rb') as file_handler: | |||||
| content = file_handler.read() | |||||
| chunks = [Chunk(buffer=content[0:chunk_size]), Chunk(buffer=content[chunk_size:])] | |||||
| return chunks | |||||
| @staticmethod | |||||
| def get_tensors(): | |||||
| """Get tensors.""" | |||||
| tensor_content = np.asarray([1, 2, 3, 4, 5, 6]).astype(np.float32).tobytes() | |||||
| tensor_pre = TensorProto( | |||||
| node_name='mock_node_name', | |||||
| slot='0', | |||||
| data_type=DataType.DT_FLOAT32, | |||||
| dims=[2, 3], | |||||
| tensor_content=tensor_content[:12], | |||||
| finished=0 | |||||
| ) | |||||
| tensor_succ = TensorProto() | |||||
| tensor_succ.CopyFrom(tensor_pre) | |||||
| tensor_succ.tensor_content = tensor_content[12:] | |||||
| tensor_succ.finished = 1 | |||||
| return [tensor_pre, tensor_succ] | |||||
| @staticmethod | |||||
| def get_watchpoint_hit(): | |||||
| """Get watchpoint hit.""" | |||||
| res = WatchpointHit(id=1) | |||||
| res.tensor.node_name = 'mock_node_name' | |||||
| res.tensor.slot = '0' | |||||
| return res | |||||
| class TestDebuggerGrpcServer: | |||||
| """Test debugger grpc server.""" | |||||
| @classmethod | |||||
| def setup_class(cls): | |||||
| """Initialize for test class.""" | |||||
| cls._server = None | |||||
| def setup_method(self): | |||||
| """Initialize for each testcase.""" | |||||
| cache_store = DebuggerCache() | |||||
| self._server = DebuggerGrpcServer(cache_store) | |||||
| def test_waitcmd_with_pending_status(self): | |||||
| """Test wait command interface when status is pending.""" | |||||
| res = self._server.WaitCMD(MagicMock(), MagicMock()) | |||||
| assert res.status == EventReply.Status.FAILED | |||||
| @mock.patch.object(WatchpointHitHandler, 'empty', False) | |||||
| @mock.patch.object(DebuggerGrpcServer, '_deal_with_old_command') | |||||
| def test_waitcmd_with_old_command(self, *args): | |||||
| """Test wait command interface with old command.""" | |||||
| old_command = MockDataGenerator.get_run_cmd(steps=1) | |||||
| args[0].return_value = old_command | |||||
| setattr(self._server, '_status', ServerStatus.WAITING) | |||||
| setattr(self._server, '_received_view_cmd', {'node_name': 'mock_node_name'}) | |||||
| setattr(self._server, '_received_hit', True) | |||||
| res = self._server.WaitCMD(MagicMock(cur_step=1), MagicMock()) | |||||
| assert res == old_command | |||||
| @mock.patch.object(DebuggerGrpcServer, '_deal_with_old_command', return_value=None) | |||||
| @mock.patch.object(DebuggerGrpcServer, '_wait_for_next_command') | |||||
| def test_waitcmd_with_next_command(self, *args): | |||||
| """Test wait for next command.""" | |||||
| old_command = MockDataGenerator.get_run_cmd(steps=1) | |||||
| args[0].return_value = old_command | |||||
| setattr(self._server, '_status', ServerStatus.WAITING) | |||||
| res = self._server.WaitCMD(MagicMock(cur_step=1), MagicMock()) | |||||
| assert res == old_command | |||||
| @mock.patch.object(DebuggerGrpcServer, '_deal_with_old_command', return_value=None) | |||||
| @mock.patch.object(DebuggerGrpcServer, '_wait_for_next_command') | |||||
| def test_waitcmd_with_next_command_is_none(self, *args): | |||||
| """Test wait command interface with next command is None.""" | |||||
| args[0].return_value = None | |||||
| setattr(self._server, '_status', ServerStatus.RECEIVE_GRAPH) | |||||
| res = self._server.WaitCMD(MagicMock(cur_step=1), MagicMock()) | |||||
| assert res == get_ack_reply(1) | |||||
| @mock.patch.object(DebuggerCache, 'get_command', return_value=(0, None)) | |||||
| @mock.patch.object(DebuggerCache, 'has_command') | |||||
| def test_deal_with_old_command_with_continue_steps(self, *args): | |||||
| """Test deal with old command with continue steps.""" | |||||
| args[0].side_effect = [True, False] | |||||
| setattr(self._server, '_continue_steps', 1) | |||||
| res = self._server._deal_with_old_command() | |||||
| assert res == MockDataGenerator.get_run_cmd(steps=1) | |||||
| @mock.patch.object(DebuggerCache, 'get_command') | |||||
| @mock.patch.object(DebuggerCache, 'has_command', return_value=True) | |||||
| def test_deal_with_old_command_with_exit_cmd(self, *args): | |||||
| """Test deal with exit command.""" | |||||
| cmd = MockDataGenerator.get_exit_cmd() | |||||
| args[1].return_value = ('0', cmd) | |||||
| res = self._server._deal_with_old_command() | |||||
| assert res == cmd | |||||
| @mock.patch.object(DebuggerCache, 'get_command') | |||||
| @mock.patch.object(DebuggerCache, 'has_command', return_value=True) | |||||
| def test_deal_with_old_command_with_view_cmd(self, *args): | |||||
| """Test deal with view command.""" | |||||
| cmd = MockDataGenerator.get_view_cmd() | |||||
| args[1].return_value = ('0', cmd) | |||||
| res = self._server._deal_with_old_command() | |||||
| assert res == cmd.get('view_cmd') | |||||
| expect_received_view_cmd = {'node_name': cmd.get('node_name'), 'wait_for_tensor': True} | |||||
| assert getattr(self._server, '_received_view_cmd') == expect_received_view_cmd | |||||
| @mock.patch.object(DebuggerCache, 'get_command') | |||||
| def test_wait_for_run_command(self, *args): | |||||
| """Test wait for run command.""" | |||||
| cmd = MockDataGenerator.get_run_cmd(steps=2) | |||||
| args[0].return_value = ('0', cmd) | |||||
| setattr(self._server, '_status', ServerStatus.WAITING) | |||||
| res = self._server._wait_for_next_command() | |||||
| assert res == MockDataGenerator.get_run_cmd(steps=1) | |||||
| assert getattr(self._server, '_continue_steps') == 1 | |||||
| @mock.patch.object(DebuggerCache, 'get_command') | |||||
| def test_wait_for_pause_and_run_command(self, *args): | |||||
| """Test wait for run command.""" | |||||
| pause_cmd = MockDataGenerator.get_run_cmd(steps=0) | |||||
| empty_view_cmd = MockDataGenerator.get_view_cmd() | |||||
| empty_view_cmd.pop('node_name') | |||||
| run_cmd = MockDataGenerator.get_run_cmd(steps=2) | |||||
| args[0].side_effect = [('0', pause_cmd), ('0', empty_view_cmd), ('0', run_cmd)] | |||||
| setattr(self._server, '_status', ServerStatus.WAITING) | |||||
| res = self._server._wait_for_next_command() | |||||
| assert res == run_cmd | |||||
| assert getattr(self._server, '_continue_steps') == 1 | |||||
| def test_send_matadata(self): | |||||
| """Test SendMatadata interface.""" | |||||
| res = self._server.SendMetadata(MagicMock(training_done=False), MagicMock()) | |||||
| assert res == get_ack_reply() | |||||
| def test_send_matadata_with_training_done(self): | |||||
| """Test SendMatadata interface.""" | |||||
| res = self._server.SendMetadata(MagicMock(training_done=True), MagicMock()) | |||||
| assert res == get_ack_reply() | |||||
| def test_send_graph(self): | |||||
| """Test SendGraph interface.""" | |||||
| res = self._server.SendGraph(MockDataGenerator.get_graph_chunks(), MagicMock()) | |||||
| assert res == get_ack_reply() | |||||
| def test_send_tensors(self): | |||||
| """Test SendTensors interface.""" | |||||
| res = self._server.SendTensors(MockDataGenerator.get_tensors(), MagicMock()) | |||||
| assert res == get_ack_reply() | |||||
| @mock.patch.object(WatchpointHandler, 'get_watchpoint_by_id') | |||||
| @mock.patch.object(GraphHandler, 'get_node_name_by_full_name') | |||||
| def test_send_watchpoint_hit(self, *args): | |||||
| """Test SendWatchpointHits interface.""" | |||||
| args[0].side_effect = [None, 'mock_full_name'] | |||||
| watchpoint_hit = MockDataGenerator.get_watchpoint_hit() | |||||
| res = self._server.SendWatchpointHits([watchpoint_hit, watchpoint_hit], MagicMock()) | |||||
| assert res == get_ack_reply() | |||||
| @@ -0,0 +1,250 @@ | |||||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================ | |||||
| """ | |||||
| Function: | |||||
| Test debugger server. | |||||
| Usage: | |||||
| pytest tests/ut/debugger/test_debugger_server.py | |||||
| """ | |||||
| import signal | |||||
| from threading import Thread | |||||
| from unittest import mock | |||||
| from unittest.mock import MagicMock | |||||
| import grpc | |||||
| import pytest | |||||
| from mindinsight.debugger.common.exceptions.exceptions import DebuggerParamValueError, \ | |||||
| DebuggerCompareTensorError, DebuggerCreateWatchPointError, DebuggerDeleteWatchPointError | |||||
| from mindinsight.debugger.debugger_cache import DebuggerCache | |||||
| from mindinsight.debugger.debugger_server import DebuggerServer | |||||
| from mindinsight.debugger.debugger_server import grpc_server_base | |||||
| from mindinsight.debugger.proto.debug_grpc_pb2 import RunCMD | |||||
| from mindinsight.debugger.stream_handler import GraphHandler, WatchpointHandler, MetadataHandler, \ | |||||
| TensorHandler | |||||
| from tests.ut.debugger.configurations import compare_debugger_result_with_file, mock_tensor_history | |||||
| class TestDebuggerServer: | |||||
| """Test debugger server.""" | |||||
| @classmethod | |||||
| def setup_class(cls): | |||||
| """Initialize for test class.""" | |||||
| cls._server = None | |||||
| def setup_method(self): | |||||
| """Prepare debugger server object.""" | |||||
| self._server = DebuggerServer() | |||||
| @mock.patch.object(signal, 'signal') | |||||
| @mock.patch.object(Thread, 'join') | |||||
| @mock.patch.object(Thread, 'start') | |||||
| @mock.patch.object(grpc_server_base, 'add_EventListenerServicer_to_server') | |||||
| @mock.patch.object(grpc, 'server') | |||||
| def test_stop_server(self, *args): | |||||
| """Test stop debugger server.""" | |||||
| mock_grpc_server_manager = MagicMock() | |||||
| args[0].return_value = mock_grpc_server_manager | |||||
| self._server.start() | |||||
| self._server._stop_handler(MagicMock(), MagicMock()) | |||||
| assert self._server.back_server is not None | |||||
| assert self._server.grpc_server_manager == mock_grpc_server_manager | |||||
| @mock.patch.object(DebuggerCache, 'get_data') | |||||
| def test_poll_data(self, *args): | |||||
| """Test poll data request.""" | |||||
| mock_data = {'pos': 'mock_data'} | |||||
| args[0].return_value = mock_data | |||||
| res = self._server.poll_data('0') | |||||
| assert res == mock_data | |||||
| def test_poll_data_with_exept(self): | |||||
| """Test poll data with wrong input.""" | |||||
| with pytest.raises(DebuggerParamValueError, match='Pos should be string.'): | |||||
| self._server.poll_data(1) | |||||
| @mock.patch.object(GraphHandler, 'search_nodes') | |||||
| def test_search(self, *args): | |||||
| """Test search node.""" | |||||
| mock_graph = {'nodes': ['mock_nodes']} | |||||
| args[0].return_value = mock_graph | |||||
| res = self._server.search('mock_name') | |||||
| assert res == mock_graph | |||||
| def test_tensor_comparision_with_wrong_status(self): | |||||
| """Test tensor comparison with wrong status.""" | |||||
| with pytest.raises( | |||||
| DebuggerCompareTensorError, | |||||
| match='Failed to compare tensors as the MindSpore is not in waiting state.'): | |||||
| self._server.tensor_comparisons(name='mock_node_name:0', shape='[:, :]') | |||||
| @mock.patch.object(MetadataHandler, 'state', 'waiting') | |||||
| @mock.patch.object(GraphHandler, 'get_node_type') | |||||
| @mock.patch.object(GraphHandler, 'get_full_name', return_value='mock_node_name') | |||||
| def test_tensor_comparision_with_wrong_type(self, *args): | |||||
| """Test tensor comparison with wrong type.""" | |||||
| args[1].return_value = 'name_scope' | |||||
| with pytest.raises(DebuggerParamValueError, match='The node type must be parameter'): | |||||
| self._server.tensor_comparisons(name='mock_node_name:0', shape='[:, :]') | |||||
| @mock.patch.object(MetadataHandler, 'state', 'waiting') | |||||
| @mock.patch.object(GraphHandler, 'get_node_type', return_value='Parameter') | |||||
| @mock.patch.object(GraphHandler, 'get_full_name', return_value='mock_node_name') | |||||
| @mock.patch.object(TensorHandler, 'get_tensors_diff') | |||||
| def test_tensor_comparision(self, *args): | |||||
| """Test tensor comparison""" | |||||
| mock_diff_res = {'tensor_value': {}} | |||||
| args[0].return_value = mock_diff_res | |||||
| res = self._server.tensor_comparisons('mock_node_name:0', '[:, :]') | |||||
| assert res == mock_diff_res | |||||
| def test_retrieve_with_pending(self): | |||||
| """Test retrieve request in pending status.""" | |||||
| res = self._server.retrieve(mode='all') | |||||
| assert res.get('metadata', {}).get('state') == 'pending' | |||||
| @mock.patch.object(MetadataHandler, 'state', 'waiting') | |||||
| def test_retrieve_all(self): | |||||
| """Test retrieve request.""" | |||||
| res = self._server.retrieve(mode='all') | |||||
| compare_debugger_result_with_file(res, 'debugger_server/retrieve_all.json') | |||||
| def test_retrieve_with_invalid_mode(self): | |||||
| """Test retrieve with invalid mode.""" | |||||
| with pytest.raises(DebuggerParamValueError, match='Invalid mode.'): | |||||
| self._server.retrieve(mode='invalid_mode') | |||||
| @mock.patch.object(GraphHandler, 'get') | |||||
| @mock.patch.object(GraphHandler, 'get_node_type', return_value='name_scope') | |||||
| @mock.patch.object(GraphHandler, 'get_full_name', return_value='mock_node_name') | |||||
| def test_retrieve_node(self, *args): | |||||
| """Test retrieve node information.""" | |||||
| mock_graph = {'graph': {}} | |||||
| args[2].return_value = mock_graph | |||||
| res = self._server._retrieve_node({'name': 'mock_node_name'}) | |||||
| assert res == mock_graph | |||||
| def test_retrieve_tensor_history_with_pending(self): | |||||
| """Test retrieve request in pending status.""" | |||||
| res = self._server.retrieve_tensor_history('mock_node_name') | |||||
| assert res.get('metadata', {}).get('state') == 'pending' | |||||
| @mock.patch.object(MetadataHandler, 'state', 'waiting') | |||||
| @mock.patch.object(GraphHandler, 'get_tensor_history') | |||||
| @mock.patch.object(GraphHandler, 'get_node_type', return_value='Parameter') | |||||
| def test_retrieve_tensor_history(self, *args): | |||||
| """Test retrieve tensor history.""" | |||||
| args[1].return_value = mock_tensor_history() | |||||
| res = self._server.retrieve_tensor_history('mock_node_name') | |||||
| compare_debugger_result_with_file(res, 'debugger_server/retrieve_tensor_history.json') | |||||
| @mock.patch.object(GraphHandler, 'get_node_type') | |||||
| def test_validate_leaf_name(self, *args): | |||||
| """Test validate leaf name.""" | |||||
| args[0].return_value = 'name_scope' | |||||
| with pytest.raises(DebuggerParamValueError, match='Invalid leaf node name.'): | |||||
| self._server._validate_leaf_name(node_name='mock_node_name') | |||||
| @mock.patch.object(TensorHandler, 'get') | |||||
| @mock.patch.object(DebuggerServer, '_get_tensor_name_and_type_by_ui_name') | |||||
| def test_retrieve_tensor_value(self, *args): | |||||
| """Test retrieve tensor value.""" | |||||
| mock_tensor_value = {'tensor_value': {'name': 'mock_name:0'}} | |||||
| args[0].return_value = ('Parameter', 'mock_node_name') | |||||
| args[1].return_value = mock_tensor_value | |||||
| res = self._server.retrieve_tensor_value('mock_name:0', 'data', '[:, :]') | |||||
| assert res == mock_tensor_value | |||||
| @mock.patch.object(WatchpointHandler, 'get') | |||||
| def test_retrieve_watchpoints(self, *args): | |||||
| """Test retrieve watchpoints.""" | |||||
| mock_watchpoint = {'watch_points': {}} | |||||
| args[0].return_value = mock_watchpoint | |||||
| res = self._server._retrieve_watchpoint({}) | |||||
| assert res == mock_watchpoint | |||||
| @mock.patch.object(DebuggerServer, '_retrieve_node') | |||||
| def test_retrieve_watchpoint(self, *args): | |||||
| """Test retrieve single watchpoint.""" | |||||
| mock_watchpoint = {'nodes': {}} | |||||
| args[0].return_value = mock_watchpoint | |||||
| res = self._server._retrieve_watchpoint({'watch_point_id': 1}) | |||||
| assert res == mock_watchpoint | |||||
| @mock.patch.object(DebuggerServer, '_validate_leaf_name') | |||||
| @mock.patch.object(DebuggerServer, '_get_tensor_history') | |||||
| @mock.patch.object(DebuggerServer, '_get_nodes_info', return_value={'graph': {}}) | |||||
| def test_retrieve_watchpoint_hit(self, *args): | |||||
| """Test retrieve single watchpoint.""" | |||||
| args[1].return_value = {'tensor_history': {}} | |||||
| res = self._server._retrieve_watchpoint_hit({'name': 'hit_node_name', 'single_node': True}) | |||||
| assert res == {'tensor_history': {}, 'graph': {}} | |||||
| def test_create_watchpoint_with_wrong_state(self): | |||||
| """Test create watchpoint with wrong state.""" | |||||
| with pytest.raises(DebuggerCreateWatchPointError, match='Failed to create watchpoint'): | |||||
| self._server.create_watchpoint(watch_condition={'condition': 'INF'}) | |||||
| @mock.patch.object(MetadataHandler, 'state', 'waiting') | |||||
| @mock.patch.object(GraphHandler, 'get_full_name', return_value='mock_full_name') | |||||
| @mock.patch.object(GraphHandler, 'get_full_name', return_value='mock_full_name') | |||||
| @mock.patch.object(GraphHandler, 'get_nodes_by_scope', return_value=[MagicMock()]) | |||||
| @mock.patch.object(GraphHandler, 'get_node_type', return_value='aggregation_scope') | |||||
| @mock.patch.object(WatchpointHandler, 'create_watchpoint') | |||||
| def test_create_watchpoint(self, *args): | |||||
| """Test create watchpoint.""" | |||||
| args[0].return_value = 1 | |||||
| res = self._server.create_watchpoint({'condition': 'INF'}, ['watch_node_name']) | |||||
| assert res == {'id': 1} | |||||
| @mock.patch.object(MetadataHandler, 'state', 'waiting') | |||||
| @mock.patch.object(GraphHandler, 'get_searched_node_list') | |||||
| @mock.patch.object(WatchpointHandler, 'validate_watchpoint_id') | |||||
| @mock.patch.object(WatchpointHandler, 'update_watchpoint') | |||||
| def test_update_watchpoint(self, *args): | |||||
| """Test update watchpoint.""" | |||||
| args[2].return_value = [MagicMock(name='seatch_name/op_name')] | |||||
| res = self._server.update_watchpoint( | |||||
| watch_point_id=1, watch_nodes=['search_name'], mode=1, name='search_name') | |||||
| assert res == {} | |||||
| def test_delete_watchpoint_with_wrong_state(self): | |||||
| """Test delete watchpoint with wrong state.""" | |||||
| with pytest.raises(DebuggerDeleteWatchPointError, match='Failed to delete watchpoint'): | |||||
| self._server.delete_watchpoint(watch_point_id=1) | |||||
| @mock.patch.object(MetadataHandler, 'state', 'waiting') | |||||
| @mock.patch.object(WatchpointHandler, 'delete_watchpoint') | |||||
| def test_delete_watchpoint(self, *args): | |||||
| """Test delete watchpoint with wrong state.""" | |||||
| args[0].return_value = None | |||||
| res = self._server.delete_watchpoint(1) | |||||
| assert res == {} | |||||
| @pytest.mark.parametrize('mode, cur_state, state', [ | |||||
| ('continue', 'waiting', 'running'), | |||||
| ('pause', 'running', 'waiting'), | |||||
| ('terminate', 'waiting', 'pending')]) | |||||
| def test_control(self, mode, cur_state, state): | |||||
| """Test control request.""" | |||||
| with mock.patch.object(MetadataHandler, 'state', cur_state): | |||||
| res = self._server.control({'mode': mode}) | |||||
| assert res == {'metadata': {'state': state}} | |||||
| def test_construct_run_event(self): | |||||
| """Test construct run event.""" | |||||
| res = self._server._construct_run_event({'level': 'node'}) | |||||
| assert res.run_cmd == RunCMD(run_level='node', node_name='') | |||||