| @@ -78,7 +78,8 @@ class DebuggerCompareTensorError(MindInsightException): | |||
def __init__(self, msg):
    """Initialize the compare-tensor error.

    Args:
        msg (str): Detailed error message, used verbatim as the
            exception message (callers supply the full text instead of
            a DebuggerErrorMsg template).
    """
    # NOTE(review): reconstructed from a garbled diff hunk in which the old
    # ``message=DebuggerErrorMsg...format(msg)`` line and the new
    # ``message=msg`` line were fused, producing a duplicate keyword
    # argument. The post-image (raw msg + http_code=400) is kept.
    super(DebuggerCompareTensorError, self).__init__(
        error=DebuggerErrors.COMPARE_TENSOR_ERROR,
        message=msg,
        http_code=400
    )
| @@ -111,7 +112,8 @@ class DebuggerNodeNotInGraphError(MindInsightException): | |||
| err_msg = f"Cannot find the node in graph by the given name. node name: {node_name}." | |||
| super(DebuggerNodeNotInGraphError, self).__init__( | |||
| error=DebuggerErrors.NODE_NOT_IN_GRAPH_ERROR, | |||
| message=err_msg | |||
| message=err_msg, | |||
| http_code=400 | |||
| ) | |||
| @@ -120,5 +122,6 @@ class DebuggerGraphNotExistError(MindInsightException): | |||
def __init__(self):
    """Initialize the graph-not-exist error with a fixed message and HTTP 400."""
    # NOTE(review): reconstructed from a garbled diff hunk in which the old
    # ``message=...value`` line and the new ``message=...value,`` line were
    # fused, producing a duplicate keyword argument. The post-image
    # (message plus http_code=400) is kept.
    super(DebuggerGraphNotExistError, self).__init__(
        error=DebuggerErrors.GRAPH_NOT_EXIST_ERROR,
        message=DebuggerErrorMsg.GRAPH_NOT_EXIST_ERROR.value,
        http_code=400
    )
| @@ -24,7 +24,8 @@ from mindinsight.datavisual.data_transform.graph import NodeTypeEnum | |||
| from mindinsight.datavisual.utils.tools import to_float | |||
| from mindinsight.debugger.common.exceptions.exceptions import DebuggerParamValueError, \ | |||
| DebuggerParamTypeError, DebuggerCreateWatchPointError, DebuggerUpdateWatchPointError, \ | |||
| DebuggerDeleteWatchPointError, DebuggerContinueError, DebuggerPauseError, DebuggerCompareTensorError | |||
| DebuggerDeleteWatchPointError, DebuggerContinueError, DebuggerPauseError, \ | |||
| DebuggerCompareTensorError | |||
| from mindinsight.debugger.common.log import logger as log | |||
| from mindinsight.debugger.common.utils import get_ack_reply, ServerStatus, \ | |||
| create_view_event_from_tensor_history, Streams, is_scope_type, NodeBasicInfo | |||
| @@ -146,13 +147,10 @@ class DebuggerServer: | |||
| node_type, tensor_name = self._get_tensor_name_and_type_by_ui_name(name) | |||
| tolerance = to_float(tolerance, 'tolerance') | |||
| tensor_stream = self.cache_store.get_stream_handler(Streams.TENSOR) | |||
| if detail == 'data': | |||
| if node_type == NodeTypeEnum.PARAMETER.value: | |||
| reply = tensor_stream.get_tensors_diff(tensor_name, parsed_shape, tolerance) | |||
| else: | |||
| raise DebuggerParamValueError("The node type must be parameter, but got {}.".format(node_type)) | |||
| if node_type == NodeTypeEnum.PARAMETER.value: | |||
| reply = tensor_stream.get_tensors_diff(tensor_name, parsed_shape, tolerance) | |||
| else: | |||
| raise DebuggerParamValueError("The value of detail: {} is not support.".format(detail)) | |||
| raise DebuggerParamValueError("The node type must be parameter, but got {}.".format(node_type)) | |||
| return reply | |||
| def retrieve(self, mode, filter_condition=None): | |||
| @@ -177,8 +175,8 @@ class DebuggerServer: | |||
| # validate param <mode> | |||
| if mode not in mode_mapping.keys(): | |||
| log.error("Invalid param <mode>. <mode> should be in ['all', 'node', 'watchpoint', " | |||
| "'watchpoint_hit', 'tensor'], but got %s.", mode_mapping) | |||
| raise DebuggerParamTypeError("Invalid mode.") | |||
| "'watchpoint_hit'], but got %s.", mode_mapping) | |||
| raise DebuggerParamValueError("Invalid mode.") | |||
| # validate backend status | |||
| metadata_stream = self.cache_store.get_stream_handler(Streams.METADATA) | |||
| if metadata_stream.state == ServerStatus.PENDING.value: | |||
| @@ -12,4 +12,9 @@ | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| """Test for debugger module.""" | |||
| """ | |||
| Function: | |||
| Unit test for debugger module. | |||
| Usage: | |||
| pytest tests/ut/debugger/ | |||
| """ | |||
| @@ -23,10 +23,12 @@ from mindinsight.debugger.common.utils import NodeBasicInfo | |||
| from mindinsight.debugger.proto import ms_graph_pb2 | |||
| from mindinsight.debugger.stream_handler.graph_handler import GraphHandler | |||
| from mindinsight.debugger.stream_handler.watchpoint_handler import WatchpointHitHandler | |||
| from tests.utils.tools import compare_result_with_file | |||
| GRAPH_PROTO_FILE = os.path.join( | |||
| os.path.dirname(__file__), '../../utils/resource/graph_pb/lenet.pb' | |||
| ) | |||
| DEBUGGER_EXPECTED_RESULTS = os.path.join(os.path.dirname(__file__), 'expected_results') | |||
| def get_graph_proto(): | |||
| @@ -137,3 +139,15 @@ def mock_tensor_history(): | |||
| } | |||
| return tensor_history | |||
def compare_debugger_result_with_file(res, expect_file):
    """
    Compare a debugger result against an expected-result file.

    Args:
        res (dict): The debugger result in dict type.
        expect_file (str): File name relative to the expected-results
            directory (DEBUGGER_EXPECTED_RESULTS).
    """
    expected_path = os.path.join(DEBUGGER_EXPECTED_RESULTS, expect_file)
    compare_result_with_file(res, expected_path)
| @@ -0,0 +1 @@ | |||
| {"metadata": {"state": "waiting", "step": 0, "device_name": "", "pos": "0", "ip": "", "node_name": "", "backend": ""}, "graph": {}, "watch_points": []} | |||
| @@ -0,0 +1 @@ | |||
| {"tensor_history": [{"name": "Default/TransData-op99:0", "full_name": "Default/TransData-op99:0", "node_type": "TransData", "type": "output", "step": 0, "dtype": "DT_FLOAT32", "shape": [2, 3], "has_prev_step": false, "value": "click to view"}, {"name": "Default/args0:0", "full_name": "Default/args0:0", "node_type": "Parameter", "type": "input", "step": 0, "dtype": "DT_FLOAT32", "shape": [2, 3], "has_prev_step": false, "value": "click to view"}], "metadata": {"state": "waiting", "step": 0, "device_name": "", "pos": "0", "ip": "", "node_name": "", "backend": ""}} | |||
| @@ -12,7 +12,12 @@ | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| """Test WatchpointHandler.""" | |||
| """ | |||
| Function: | |||
| Test query debugger watchpoint handler. | |||
| Usage: | |||
| pytest tests/ut/debugger | |||
| """ | |||
| import json | |||
| import os | |||
| from unittest import mock, TestCase | |||
| @@ -0,0 +1,237 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| """ | |||
| Function: | |||
| Test debugger grpc server. | |||
| Usage: | |||
| pytest tests/ut/debugger/test_debugger_grpc_server.py | |||
| """ | |||
| from unittest import mock | |||
| from unittest.mock import MagicMock | |||
| import numpy as np | |||
| from mindinsight.debugger.common.utils import get_ack_reply, ServerStatus | |||
| from mindinsight.debugger.debugger_cache import DebuggerCache | |||
| from mindinsight.debugger.debugger_grpc_server import DebuggerGrpcServer | |||
| from mindinsight.debugger.proto.debug_grpc_pb2 import EventReply, SetCMD, Chunk, WatchpointHit | |||
| from mindinsight.debugger.proto.ms_graph_pb2 import TensorProto, DataType | |||
| from mindinsight.debugger.stream_handler import WatchpointHitHandler, GraphHandler, \ | |||
| WatchpointHandler | |||
| from tests.ut.debugger.configurations import GRAPH_PROTO_FILE | |||
class MockDataGenerator:
    """Mocked data generator for debugger gRPC server tests."""

    @staticmethod
    def get_run_cmd(steps=0, level='step', node_name=''):
        """Build an ack reply carrying a run command."""
        reply = get_ack_reply()
        reply.run_cmd.run_level = level
        if level == 'node':
            reply.run_cmd.node_name = node_name
        else:
            reply.run_cmd.run_steps = steps
        return reply

    @staticmethod
    def get_exit_cmd():
        """Build an ack reply carrying an exit command."""
        reply = get_ack_reply()
        reply.exit = True
        return reply

    @staticmethod
    def get_set_cmd():
        """Build an ack reply carrying a set command."""
        reply = get_ack_reply()
        reply.set_cmd.CopyFrom(SetCMD(id=1, watch_condition=1))
        return reply

    @staticmethod
    def get_view_cmd():
        """Build a view command paired with the watched node name."""
        reply = get_ack_reply()
        tensor = reply.view_cmd.tensors.add()
        tensor.node_name, tensor.slot = 'mock_node_name', '0'
        return {'view_cmd': reply, 'node_name': 'mock_node_name'}

    @staticmethod
    def get_graph_chunks():
        """Read the graph proto file and split it into two chunks."""
        chunk_size = 1024
        with open(GRAPH_PROTO_FILE, 'rb') as file_handler:
            content = file_handler.read()
        return [Chunk(buffer=content[0:chunk_size]), Chunk(buffer=content[chunk_size:])]

    @staticmethod
    def get_tensors():
        """Build one tensor value split across two TensorProto messages."""
        tensor_content = np.asarray([1, 2, 3, 4, 5, 6]).astype(np.float32).tobytes()
        # First proto carries the leading 12 bytes, second the remainder.
        first = TensorProto(
            node_name='mock_node_name',
            slot='0',
            data_type=DataType.DT_FLOAT32,
            dims=[2, 3],
            tensor_content=tensor_content[:12],
            finished=0
        )
        second = TensorProto()
        second.CopyFrom(first)
        second.tensor_content = tensor_content[12:]
        second.finished = 1
        return [first, second]

    @staticmethod
    def get_watchpoint_hit():
        """Build a watchpoint hit message."""
        hit = WatchpointHit(id=1)
        hit.tensor.node_name = 'mock_node_name'
        hit.tensor.slot = '0'
        return hit
class TestDebuggerGrpcServer:
    """Test debugger grpc server."""

    @classmethod
    def setup_class(cls):
        """Initialize for test class."""
        cls._server = None

    def setup_method(self):
        """Initialize for each testcase."""
        cache_store = DebuggerCache()
        self._server = DebuggerGrpcServer(cache_store)

    def test_waitcmd_with_pending_status(self):
        """Test wait command interface when status is pending."""
        res = self._server.WaitCMD(MagicMock(), MagicMock())
        assert res.status == EventReply.Status.FAILED

    @mock.patch.object(WatchpointHitHandler, 'empty', False)
    @mock.patch.object(DebuggerGrpcServer, '_deal_with_old_command')
    def test_waitcmd_with_old_command(self, *args):
        """Test wait command interface with old command."""
        old_command = MockDataGenerator.get_run_cmd(steps=1)
        args[0].return_value = old_command
        setattr(self._server, '_status', ServerStatus.WAITING)
        setattr(self._server, '_received_view_cmd', {'node_name': 'mock_node_name'})
        setattr(self._server, '_received_hit', True)
        res = self._server.WaitCMD(MagicMock(cur_step=1), MagicMock())
        assert res == old_command

    @mock.patch.object(DebuggerGrpcServer, '_deal_with_old_command', return_value=None)
    @mock.patch.object(DebuggerGrpcServer, '_wait_for_next_command')
    def test_waitcmd_with_next_command(self, *args):
        """Test wait for next command."""
        old_command = MockDataGenerator.get_run_cmd(steps=1)
        args[0].return_value = old_command
        setattr(self._server, '_status', ServerStatus.WAITING)
        res = self._server.WaitCMD(MagicMock(cur_step=1), MagicMock())
        assert res == old_command

    @mock.patch.object(DebuggerGrpcServer, '_deal_with_old_command', return_value=None)
    @mock.patch.object(DebuggerGrpcServer, '_wait_for_next_command')
    def test_waitcmd_with_next_command_is_none(self, *args):
        """Test wait command interface with next command is None."""
        args[0].return_value = None
        setattr(self._server, '_status', ServerStatus.RECEIVE_GRAPH)
        res = self._server.WaitCMD(MagicMock(cur_step=1), MagicMock())
        # When no command arrives, the server answers with a failure ack.
        assert res == get_ack_reply(1)

    @mock.patch.object(DebuggerCache, 'get_command', return_value=(0, None))
    @mock.patch.object(DebuggerCache, 'has_command')
    def test_deal_with_old_command_with_continue_steps(self, *args):
        """Test deal with old command with continue steps."""
        args[0].side_effect = [True, False]
        setattr(self._server, '_continue_steps', 1)
        res = self._server._deal_with_old_command()
        assert res == MockDataGenerator.get_run_cmd(steps=1)

    @mock.patch.object(DebuggerCache, 'get_command')
    @mock.patch.object(DebuggerCache, 'has_command', return_value=True)
    def test_deal_with_old_command_with_exit_cmd(self, *args):
        """Test deal with exit command."""
        cmd = MockDataGenerator.get_exit_cmd()
        args[1].return_value = ('0', cmd)
        res = self._server._deal_with_old_command()
        assert res == cmd

    @mock.patch.object(DebuggerCache, 'get_command')
    @mock.patch.object(DebuggerCache, 'has_command', return_value=True)
    def test_deal_with_old_command_with_view_cmd(self, *args):
        """Test deal with view command."""
        cmd = MockDataGenerator.get_view_cmd()
        args[1].return_value = ('0', cmd)
        res = self._server._deal_with_old_command()
        assert res == cmd.get('view_cmd')
        expect_received_view_cmd = {'node_name': cmd.get('node_name'), 'wait_for_tensor': True}
        assert getattr(self._server, '_received_view_cmd') == expect_received_view_cmd

    @mock.patch.object(DebuggerCache, 'get_command')
    def test_wait_for_run_command(self, *args):
        """Test wait for run command."""
        cmd = MockDataGenerator.get_run_cmd(steps=2)
        args[0].return_value = ('0', cmd)
        setattr(self._server, '_status', ServerStatus.WAITING)
        res = self._server._wait_for_next_command()
        # A 2-step run command is consumed one step at a time.
        assert res == MockDataGenerator.get_run_cmd(steps=1)
        assert getattr(self._server, '_continue_steps') == 1

    @mock.patch.object(DebuggerCache, 'get_command')
    def test_wait_for_pause_and_run_command(self, *args):
        """Test wait for run command after pause and empty view commands."""
        pause_cmd = MockDataGenerator.get_run_cmd(steps=0)
        empty_view_cmd = MockDataGenerator.get_view_cmd()
        empty_view_cmd.pop('node_name')
        run_cmd = MockDataGenerator.get_run_cmd(steps=2)
        args[0].side_effect = [('0', pause_cmd), ('0', empty_view_cmd), ('0', run_cmd)]
        setattr(self._server, '_status', ServerStatus.WAITING)
        res = self._server._wait_for_next_command()
        assert res == run_cmd
        assert getattr(self._server, '_continue_steps') == 1

    def test_send_metadata(self):
        """Test SendMetadata interface."""
        res = self._server.SendMetadata(MagicMock(training_done=False), MagicMock())
        assert res == get_ack_reply()

    def test_send_metadata_with_training_done(self):
        """Test SendMetadata interface when training is done."""
        res = self._server.SendMetadata(MagicMock(training_done=True), MagicMock())
        assert res == get_ack_reply()

    def test_send_graph(self):
        """Test SendGraph interface."""
        res = self._server.SendGraph(MockDataGenerator.get_graph_chunks(), MagicMock())
        assert res == get_ack_reply()

    def test_send_tensors(self):
        """Test SendTensors interface."""
        res = self._server.SendTensors(MockDataGenerator.get_tensors(), MagicMock())
        assert res == get_ack_reply()

    @mock.patch.object(WatchpointHandler, 'get_watchpoint_by_id')
    @mock.patch.object(GraphHandler, 'get_node_name_by_full_name')
    def test_send_watchpoint_hit(self, *args):
        """Test SendWatchpointHits interface."""
        args[0].side_effect = [None, 'mock_full_name']
        watchpoint_hit = MockDataGenerator.get_watchpoint_hit()
        res = self._server.SendWatchpointHits([watchpoint_hit, watchpoint_hit], MagicMock())
        assert res == get_ack_reply()
| @@ -0,0 +1,250 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| """ | |||
| Function: | |||
| Test debugger server. | |||
| Usage: | |||
| pytest tests/ut/debugger/test_debugger_server.py | |||
| """ | |||
| import signal | |||
| from threading import Thread | |||
| from unittest import mock | |||
| from unittest.mock import MagicMock | |||
| import grpc | |||
| import pytest | |||
| from mindinsight.debugger.common.exceptions.exceptions import DebuggerParamValueError, \ | |||
| DebuggerCompareTensorError, DebuggerCreateWatchPointError, DebuggerDeleteWatchPointError | |||
| from mindinsight.debugger.debugger_cache import DebuggerCache | |||
| from mindinsight.debugger.debugger_server import DebuggerServer | |||
| from mindinsight.debugger.debugger_server import grpc_server_base | |||
| from mindinsight.debugger.proto.debug_grpc_pb2 import RunCMD | |||
| from mindinsight.debugger.stream_handler import GraphHandler, WatchpointHandler, MetadataHandler, \ | |||
| TensorHandler | |||
| from tests.ut.debugger.configurations import compare_debugger_result_with_file, mock_tensor_history | |||
class TestDebuggerServer:
    """Test debugger server."""

    @classmethod
    def setup_class(cls):
        """Initialize for test class."""
        cls._server = None

    def setup_method(self):
        """Prepare debugger server object."""
        self._server = DebuggerServer()

    @mock.patch.object(signal, 'signal')
    @mock.patch.object(Thread, 'join')
    @mock.patch.object(Thread, 'start')
    @mock.patch.object(grpc_server_base, 'add_EventListenerServicer_to_server')
    @mock.patch.object(grpc, 'server')
    def test_stop_server(self, *args):
        """Test stop debugger server."""
        mock_grpc_server_manager = MagicMock()
        args[0].return_value = mock_grpc_server_manager
        self._server.start()
        self._server._stop_handler(MagicMock(), MagicMock())
        assert self._server.back_server is not None
        assert self._server.grpc_server_manager == mock_grpc_server_manager

    @mock.patch.object(DebuggerCache, 'get_data')
    def test_poll_data(self, *args):
        """Test poll data request."""
        mock_data = {'pos': 'mock_data'}
        args[0].return_value = mock_data
        res = self._server.poll_data('0')
        assert res == mock_data

    def test_poll_data_with_except(self):
        """Test poll data with wrong input."""
        with pytest.raises(DebuggerParamValueError, match='Pos should be string.'):
            self._server.poll_data(1)

    @mock.patch.object(GraphHandler, 'search_nodes')
    def test_search(self, *args):
        """Test search node."""
        mock_graph = {'nodes': ['mock_nodes']}
        args[0].return_value = mock_graph
        res = self._server.search('mock_name')
        assert res == mock_graph

    def test_tensor_comparison_with_wrong_status(self):
        """Test tensor comparison with wrong status."""
        with pytest.raises(
                DebuggerCompareTensorError,
                match='Failed to compare tensors as the MindSpore is not in waiting state.'):
            self._server.tensor_comparisons(name='mock_node_name:0', shape='[:, :]')

    @mock.patch.object(MetadataHandler, 'state', 'waiting')
    @mock.patch.object(GraphHandler, 'get_node_type')
    @mock.patch.object(GraphHandler, 'get_full_name', return_value='mock_node_name')
    def test_tensor_comparison_with_wrong_type(self, *args):
        """Test tensor comparison with wrong type."""
        args[1].return_value = 'name_scope'
        with pytest.raises(DebuggerParamValueError, match='The node type must be parameter'):
            self._server.tensor_comparisons(name='mock_node_name:0', shape='[:, :]')

    @mock.patch.object(MetadataHandler, 'state', 'waiting')
    @mock.patch.object(GraphHandler, 'get_node_type', return_value='Parameter')
    @mock.patch.object(GraphHandler, 'get_full_name', return_value='mock_node_name')
    @mock.patch.object(TensorHandler, 'get_tensors_diff')
    def test_tensor_comparison(self, *args):
        """Test tensor comparison."""
        mock_diff_res = {'tensor_value': {}}
        args[0].return_value = mock_diff_res
        res = self._server.tensor_comparisons('mock_node_name:0', '[:, :]')
        assert res == mock_diff_res

    def test_retrieve_with_pending(self):
        """Test retrieve request in pending status."""
        res = self._server.retrieve(mode='all')
        assert res.get('metadata', {}).get('state') == 'pending'

    @mock.patch.object(MetadataHandler, 'state', 'waiting')
    def test_retrieve_all(self):
        """Test retrieve request."""
        res = self._server.retrieve(mode='all')
        compare_debugger_result_with_file(res, 'debugger_server/retrieve_all.json')

    def test_retrieve_with_invalid_mode(self):
        """Test retrieve with invalid mode."""
        with pytest.raises(DebuggerParamValueError, match='Invalid mode.'):
            self._server.retrieve(mode='invalid_mode')

    @mock.patch.object(GraphHandler, 'get')
    @mock.patch.object(GraphHandler, 'get_node_type', return_value='name_scope')
    @mock.patch.object(GraphHandler, 'get_full_name', return_value='mock_node_name')
    def test_retrieve_node(self, *args):
        """Test retrieve node information."""
        mock_graph = {'graph': {}}
        args[2].return_value = mock_graph
        res = self._server._retrieve_node({'name': 'mock_node_name'})
        assert res == mock_graph

    def test_retrieve_tensor_history_with_pending(self):
        """Test retrieve tensor history in pending status."""
        res = self._server.retrieve_tensor_history('mock_node_name')
        assert res.get('metadata', {}).get('state') == 'pending'

    @mock.patch.object(MetadataHandler, 'state', 'waiting')
    @mock.patch.object(GraphHandler, 'get_tensor_history')
    @mock.patch.object(GraphHandler, 'get_node_type', return_value='Parameter')
    def test_retrieve_tensor_history(self, *args):
        """Test retrieve tensor history."""
        args[1].return_value = mock_tensor_history()
        res = self._server.retrieve_tensor_history('mock_node_name')
        compare_debugger_result_with_file(res, 'debugger_server/retrieve_tensor_history.json')

    @mock.patch.object(GraphHandler, 'get_node_type')
    def test_validate_leaf_name(self, *args):
        """Test validate leaf name."""
        args[0].return_value = 'name_scope'
        with pytest.raises(DebuggerParamValueError, match='Invalid leaf node name.'):
            self._server._validate_leaf_name(node_name='mock_node_name')

    @mock.patch.object(TensorHandler, 'get')
    @mock.patch.object(DebuggerServer, '_get_tensor_name_and_type_by_ui_name')
    def test_retrieve_tensor_value(self, *args):
        """Test retrieve tensor value."""
        mock_tensor_value = {'tensor_value': {'name': 'mock_name:0'}}
        args[0].return_value = ('Parameter', 'mock_node_name')
        args[1].return_value = mock_tensor_value
        res = self._server.retrieve_tensor_value('mock_name:0', 'data', '[:, :]')
        assert res == mock_tensor_value

    @mock.patch.object(WatchpointHandler, 'get')
    def test_retrieve_watchpoints(self, *args):
        """Test retrieve watchpoints."""
        mock_watchpoint = {'watch_points': {}}
        args[0].return_value = mock_watchpoint
        res = self._server._retrieve_watchpoint({})
        assert res == mock_watchpoint

    @mock.patch.object(DebuggerServer, '_retrieve_node')
    def test_retrieve_watchpoint(self, *args):
        """Test retrieve single watchpoint."""
        mock_watchpoint = {'nodes': {}}
        args[0].return_value = mock_watchpoint
        res = self._server._retrieve_watchpoint({'watch_point_id': 1})
        assert res == mock_watchpoint

    @mock.patch.object(DebuggerServer, '_validate_leaf_name')
    @mock.patch.object(DebuggerServer, '_get_tensor_history')
    @mock.patch.object(DebuggerServer, '_get_nodes_info', return_value={'graph': {}})
    def test_retrieve_watchpoint_hit(self, *args):
        """Test retrieve watchpoint hit."""
        args[1].return_value = {'tensor_history': {}}
        res = self._server._retrieve_watchpoint_hit({'name': 'hit_node_name', 'single_node': True})
        assert res == {'tensor_history': {}, 'graph': {}}

    def test_create_watchpoint_with_wrong_state(self):
        """Test create watchpoint with wrong state."""
        with pytest.raises(DebuggerCreateWatchPointError, match='Failed to create watchpoint'):
            self._server.create_watchpoint(watch_condition={'condition': 'INF'})

    # NOTE(review): a duplicated `get_full_name` patch decorator was removed;
    # only args[0] (create_watchpoint) is used, so indices are unaffected.
    @mock.patch.object(MetadataHandler, 'state', 'waiting')
    @mock.patch.object(GraphHandler, 'get_full_name', return_value='mock_full_name')
    @mock.patch.object(GraphHandler, 'get_nodes_by_scope', return_value=[MagicMock()])
    @mock.patch.object(GraphHandler, 'get_node_type', return_value='aggregation_scope')
    @mock.patch.object(WatchpointHandler, 'create_watchpoint')
    def test_create_watchpoint(self, *args):
        """Test create watchpoint."""
        args[0].return_value = 1
        res = self._server.create_watchpoint({'condition': 'INF'}, ['watch_node_name'])
        assert res == {'id': 1}

    @mock.patch.object(MetadataHandler, 'state', 'waiting')
    @mock.patch.object(GraphHandler, 'get_searched_node_list')
    @mock.patch.object(WatchpointHandler, 'validate_watchpoint_id')
    @mock.patch.object(WatchpointHandler, 'update_watchpoint')
    def test_update_watchpoint(self, *args):
        """Test update watchpoint."""
        args[2].return_value = [MagicMock(name='seatch_name/op_name')]
        res = self._server.update_watchpoint(
            watch_point_id=1, watch_nodes=['search_name'], mode=1, name='search_name')
        assert res == {}

    def test_delete_watchpoint_with_wrong_state(self):
        """Test delete watchpoint with wrong state."""
        with pytest.raises(DebuggerDeleteWatchPointError, match='Failed to delete watchpoint'):
            self._server.delete_watchpoint(watch_point_id=1)

    @mock.patch.object(MetadataHandler, 'state', 'waiting')
    @mock.patch.object(WatchpointHandler, 'delete_watchpoint')
    def test_delete_watchpoint(self, *args):
        """Test delete watchpoint."""
        args[0].return_value = None
        res = self._server.delete_watchpoint(1)
        assert res == {}

    @pytest.mark.parametrize('mode, cur_state, state', [
        ('continue', 'waiting', 'running'),
        ('pause', 'running', 'waiting'),
        ('terminate', 'waiting', 'pending')])
    def test_control(self, mode, cur_state, state):
        """Test control request."""
        with mock.patch.object(MetadataHandler, 'state', cur_state):
            res = self._server.control({'mode': mode})
            assert res == {'metadata': {'state': state}}

    def test_construct_run_event(self):
        """Test construct run event."""
        res = self._server._construct_run_event({'level': 'node'})
        assert res.run_cmd == RunCMD(run_level='node', node_name='')