diff --git a/mindinsight/debugger/common/exceptions/exceptions.py b/mindinsight/debugger/common/exceptions/exceptions.py index 3efbcf64..52ba5805 100644 --- a/mindinsight/debugger/common/exceptions/exceptions.py +++ b/mindinsight/debugger/common/exceptions/exceptions.py @@ -78,7 +78,8 @@ class DebuggerCompareTensorError(MindInsightException): def __init__(self, msg): super(DebuggerCompareTensorError, self).__init__( error=DebuggerErrors.COMPARE_TENSOR_ERROR, - message=DebuggerErrorMsg.COMPARE_TENSOR_ERROR.value.format(msg) + message=msg, + http_code=400 ) @@ -111,7 +112,8 @@ class DebuggerNodeNotInGraphError(MindInsightException): err_msg = f"Cannot find the node in graph by the given name. node name: {node_name}." super(DebuggerNodeNotInGraphError, self).__init__( error=DebuggerErrors.NODE_NOT_IN_GRAPH_ERROR, - message=err_msg + message=err_msg, + http_code=400 ) @@ -120,5 +122,6 @@ class DebuggerGraphNotExistError(MindInsightException): def __init__(self): super(DebuggerGraphNotExistError, self).__init__( error=DebuggerErrors.GRAPH_NOT_EXIST_ERROR, - message=DebuggerErrorMsg.GRAPH_NOT_EXIST_ERROR.value + message=DebuggerErrorMsg.GRAPH_NOT_EXIST_ERROR.value, + http_code=400 ) diff --git a/mindinsight/debugger/debugger_server.py b/mindinsight/debugger/debugger_server.py index 2fae23d3..f7182e57 100644 --- a/mindinsight/debugger/debugger_server.py +++ b/mindinsight/debugger/debugger_server.py @@ -24,7 +24,8 @@ from mindinsight.datavisual.data_transform.graph import NodeTypeEnum from mindinsight.datavisual.utils.tools import to_float from mindinsight.debugger.common.exceptions.exceptions import DebuggerParamValueError, \ DebuggerParamTypeError, DebuggerCreateWatchPointError, DebuggerUpdateWatchPointError, \ - DebuggerDeleteWatchPointError, DebuggerContinueError, DebuggerPauseError, DebuggerCompareTensorError + DebuggerDeleteWatchPointError, DebuggerContinueError, DebuggerPauseError, \ + DebuggerCompareTensorError from mindinsight.debugger.common.log import logger as log from mindinsight.debugger.common.utils import get_ack_reply, ServerStatus, \ create_view_event_from_tensor_history, Streams, is_scope_type, NodeBasicInfo @@ -146,13 +147,10 @@ class DebuggerServer: node_type, tensor_name = self._get_tensor_name_and_type_by_ui_name(name) tolerance = to_float(tolerance, 'tolerance') tensor_stream = self.cache_store.get_stream_handler(Streams.TENSOR) - if detail == 'data': - if node_type == NodeTypeEnum.PARAMETER.value: - reply = tensor_stream.get_tensors_diff(tensor_name, parsed_shape, tolerance) - else: - raise DebuggerParamValueError("The node type must be parameter, but got {}.".format(node_type)) + if node_type == NodeTypeEnum.PARAMETER.value: + reply = tensor_stream.get_tensors_diff(tensor_name, parsed_shape, tolerance) else: - raise DebuggerParamValueError("The value of detail: {} is not support.".format(detail)) + raise DebuggerParamValueError("The node type must be parameter, but got {}.".format(node_type)) return reply def retrieve(self, mode, filter_condition=None): @@ -177,8 +175,8 @@ class DebuggerServer: # validate param if mode not in mode_mapping.keys(): log.error("Invalid param . should be in ['all', 'node', 'watchpoint', " - "'watchpoint_hit', 'tensor'], but got %s.", mode_mapping) - raise DebuggerParamTypeError("Invalid mode.") + "'watchpoint_hit'], but got %s.", mode_mapping) + raise DebuggerParamValueError("Invalid mode.") # validate backend status metadata_stream = self.cache_store.get_stream_handler(Streams.METADATA) if metadata_stream.state == ServerStatus.PENDING.value: diff --git a/tests/ut/debugger/__init__.py b/tests/ut/debugger/__init__.py index 95a33e57..c6e0e62c 100644 --- a/tests/ut/debugger/__init__.py +++ b/tests/ut/debugger/__init__.py @@ -12,4 +12,9 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================ -"""Test for debugger module.""" +""" +Function: + Unit test for debugger module. +Usage: + pytest tests/ut/debugger/ +""" diff --git a/tests/ut/debugger/configurations.py b/tests/ut/debugger/configurations.py index fd5ac870..c0426679 100644 --- a/tests/ut/debugger/configurations.py +++ b/tests/ut/debugger/configurations.py @@ -23,10 +23,12 @@ from mindinsight.debugger.common.utils import NodeBasicInfo from mindinsight.debugger.proto import ms_graph_pb2 from mindinsight.debugger.stream_handler.graph_handler import GraphHandler from mindinsight.debugger.stream_handler.watchpoint_handler import WatchpointHitHandler +from tests.utils.tools import compare_result_with_file GRAPH_PROTO_FILE = os.path.join( os.path.dirname(__file__), '../../utils/resource/graph_pb/lenet.pb' ) +DEBUGGER_EXPECTED_RESULTS = os.path.join(os.path.dirname(__file__), 'expected_results') def get_graph_proto(): @@ -137,3 +139,15 @@ def mock_tensor_history(): } return tensor_history + + +def compare_debugger_result_with_file(res, expect_file): + """ + Compare debugger result with file. + + Args: + res (dict): The debugger result in dict type. + expect_file: The expected file name. + """ + real_path = os.path.join(DEBUGGER_EXPECTED_RESULTS, expect_file) + compare_result_with_file(res, real_path) diff --git a/tests/ut/debugger/expected_results/debugger_server/retrieve_all.json b/tests/ut/debugger/expected_results/debugger_server/retrieve_all.json new file mode 100644 index 00000000..81cad5eb --- /dev/null +++ b/tests/ut/debugger/expected_results/debugger_server/retrieve_all.json @@ -0,0 +1 @@ +{"metadata": {"state": "waiting", "step": 0, "device_name": "", "pos": "0", "ip": "", "node_name": "", "backend": ""}, "graph": {}, "watch_points": []} \ No newline at end of file diff --git a/tests/ut/debugger/expected_results/debugger_server/retrieve_tensor_history.json b/tests/ut/debugger/expected_results/debugger_server/retrieve_tensor_history.json new file mode 100644 index 00000000..1beb6696 --- /dev/null +++ b/tests/ut/debugger/expected_results/debugger_server/retrieve_tensor_history.json @@ -0,0 +1 @@ +{"tensor_history": [{"name": "Default/TransData-op99:0", "full_name": "Default/TransData-op99:0", "node_type": "TransData", "type": "output", "step": 0, "dtype": "DT_FLOAT32", "shape": [2, 3], "has_prev_step": false, "value": "click to view"}, {"name": "Default/args0:0", "full_name": "Default/args0:0", "node_type": "Parameter", "type": "input", "step": 0, "dtype": "DT_FLOAT32", "shape": [2, 3], "has_prev_step": false, "value": "click to view"}], "metadata": {"state": "waiting", "step": 0, "device_name": "", "pos": "0", "ip": "", "node_name": "", "backend": ""}} \ No newline at end of file diff --git a/tests/ut/debugger/stream_handler/test_watchpoint_handler.py b/tests/ut/debugger/stream_handler/test_watchpoint_handler.py index 67e1f6ec..27d7e737 100644 --- a/tests/ut/debugger/stream_handler/test_watchpoint_handler.py +++ b/tests/ut/debugger/stream_handler/test_watchpoint_handler.py @@ -12,7 +12,12 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================ -"""Test WatchpointHandler.""" +""" +Function: + Test query debugger watchpoint handler. +Usage: + pytest tests/ut/debugger +""" import json import os from unittest import mock, TestCase diff --git a/tests/ut/debugger/test_debugger_grpc_server.py b/tests/ut/debugger/test_debugger_grpc_server.py new file mode 100644 index 00000000..bd20fff7 --- /dev/null +++ b/tests/ut/debugger/test_debugger_grpc_server.py @@ -0,0 +1,237 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +""" +Function: + Test debugger grpc server. +Usage: + pytest tests/ut/debugger/test_debugger_grpc_server.py +""" +from unittest import mock +from unittest.mock import MagicMock + +import numpy as np + +from mindinsight.debugger.common.utils import get_ack_reply, ServerStatus +from mindinsight.debugger.debugger_cache import DebuggerCache +from mindinsight.debugger.debugger_grpc_server import DebuggerGrpcServer +from mindinsight.debugger.proto.debug_grpc_pb2 import EventReply, SetCMD, Chunk, WatchpointHit +from mindinsight.debugger.proto.ms_graph_pb2 import TensorProto, DataType +from mindinsight.debugger.stream_handler import WatchpointHitHandler, GraphHandler, \ + WatchpointHandler +from tests.ut.debugger.configurations import GRAPH_PROTO_FILE + + +class MockDataGenerator: + """Mocked Data generator.""" + + @staticmethod + def get_run_cmd(steps=0, level='step', node_name=''): + """Get run command.""" + event = get_ack_reply() + event.run_cmd.run_level = level + if level == 'node': + event.run_cmd.node_name = node_name + else: + event.run_cmd.run_steps = steps + return event + + @staticmethod + def get_exit_cmd(): + """Get exit command.""" + event = get_ack_reply() + event.exit = True + return event + + @staticmethod + def get_set_cmd(): + """Get set command""" + event = get_ack_reply() + event.set_cmd.CopyFrom(SetCMD(id=1, watch_condition=1)) + return event + + @staticmethod + def get_view_cmd(): + """Get set command""" + view_event = get_ack_reply() + ms_tensor = view_event.view_cmd.tensors.add() + ms_tensor.node_name, ms_tensor.slot = 'mock_node_name', '0' + event = {'view_cmd': view_event, 'node_name': 'mock_node_name'} + return event + + @staticmethod + def get_graph_chunks(): + """Get graph chunks.""" + chunk_size = 1024 + with open(GRAPH_PROTO_FILE, 'rb') as file_handler: + content = file_handler.read() + chunks = [Chunk(buffer=content[0:chunk_size]), Chunk(buffer=content[chunk_size:])] + return chunks + + @staticmethod + def get_tensors(): + """Get tensors.""" + tensor_content = np.asarray([1, 2, 3, 4, 5, 6]).astype(np.float32).tobytes() + tensor_pre = TensorProto( + node_name='mock_node_name', + slot='0', + data_type=DataType.DT_FLOAT32, + dims=[2, 3], + tensor_content=tensor_content[:12], + finished=0 + ) + tensor_succ = TensorProto() + tensor_succ.CopyFrom(tensor_pre) + tensor_succ.tensor_content = tensor_content[12:] + tensor_succ.finished = 1 + return [tensor_pre, tensor_succ] + + @staticmethod + def get_watchpoint_hit(): + """Get watchpoint hit.""" + res = WatchpointHit(id=1) + res.tensor.node_name = 'mock_node_name' + res.tensor.slot = '0' + return res + + +class TestDebuggerGrpcServer: + """Test debugger grpc server.""" + + @classmethod + def setup_class(cls): + """Initialize for test class.""" + cls._server = None + + def setup_method(self): + """Initialize for each testcase.""" + cache_store = DebuggerCache() + self._server = DebuggerGrpcServer(cache_store) + + def test_waitcmd_with_pending_status(self): + """Test wait command interface when status is pending.""" + res = self._server.WaitCMD(MagicMock(), MagicMock()) + assert res.status == EventReply.Status.FAILED + + @mock.patch.object(WatchpointHitHandler, 'empty', False) + @mock.patch.object(DebuggerGrpcServer, '_deal_with_old_command') + def test_waitcmd_with_old_command(self, *args): + """Test wait command interface with old command.""" + old_command = MockDataGenerator.get_run_cmd(steps=1) + args[0].return_value = old_command + setattr(self._server, '_status', ServerStatus.WAITING) + setattr(self._server, '_received_view_cmd', {'node_name': 'mock_node_name'}) + setattr(self._server, '_received_hit', True) + res = self._server.WaitCMD(MagicMock(cur_step=1), MagicMock()) + assert res == old_command + + @mock.patch.object(DebuggerGrpcServer, '_deal_with_old_command', return_value=None) + @mock.patch.object(DebuggerGrpcServer, '_wait_for_next_command') + def test_waitcmd_with_next_command(self, *args): + """Test wait for next command.""" + old_command = MockDataGenerator.get_run_cmd(steps=1) + args[0].return_value = old_command + setattr(self._server, '_status', ServerStatus.WAITING) + res = self._server.WaitCMD(MagicMock(cur_step=1), MagicMock()) + assert res == old_command + + @mock.patch.object(DebuggerGrpcServer, '_deal_with_old_command', return_value=None) + @mock.patch.object(DebuggerGrpcServer, '_wait_for_next_command') + def test_waitcmd_with_next_command_is_none(self, *args): + """Test wait command interface with next command is None.""" + args[0].return_value = None + setattr(self._server, '_status', ServerStatus.RECEIVE_GRAPH) + res = self._server.WaitCMD(MagicMock(cur_step=1), MagicMock()) + assert res == get_ack_reply(1) + + @mock.patch.object(DebuggerCache, 'get_command', return_value=(0, None)) + @mock.patch.object(DebuggerCache, 'has_command') + def test_deal_with_old_command_with_continue_steps(self, *args): + """Test deal with old command with continue steps.""" + args[0].side_effect = [True, False] + setattr(self._server, '_continue_steps', 1) + res = self._server._deal_with_old_command() + assert res == MockDataGenerator.get_run_cmd(steps=1) + + @mock.patch.object(DebuggerCache, 'get_command') + @mock.patch.object(DebuggerCache, 'has_command', return_value=True) + def test_deal_with_old_command_with_exit_cmd(self, *args): + """Test deal with exit command.""" + cmd = MockDataGenerator.get_exit_cmd() + args[1].return_value = ('0', cmd) + res = self._server._deal_with_old_command() + assert res == cmd + + @mock.patch.object(DebuggerCache, 'get_command') + @mock.patch.object(DebuggerCache, 'has_command', return_value=True) + def test_deal_with_old_command_with_view_cmd(self, *args): + """Test deal with view command.""" + cmd = MockDataGenerator.get_view_cmd() + args[1].return_value = ('0', cmd) + res = self._server._deal_with_old_command() + assert res == cmd.get('view_cmd') + expect_received_view_cmd = {'node_name': cmd.get('node_name'), 'wait_for_tensor': True} + assert getattr(self._server, '_received_view_cmd') == expect_received_view_cmd + + @mock.patch.object(DebuggerCache, 'get_command') + def test_wait_for_run_command(self, *args): + """Test wait for run command.""" + cmd = MockDataGenerator.get_run_cmd(steps=2) + args[0].return_value = ('0', cmd) + setattr(self._server, '_status', ServerStatus.WAITING) + res = self._server._wait_for_next_command() + assert res == MockDataGenerator.get_run_cmd(steps=1) + assert getattr(self._server, '_continue_steps') == 1 + + @mock.patch.object(DebuggerCache, 'get_command') + def test_wait_for_pause_and_run_command(self, *args): + """Test wait for run command.""" + pause_cmd = MockDataGenerator.get_run_cmd(steps=0) + empty_view_cmd = MockDataGenerator.get_view_cmd() + empty_view_cmd.pop('node_name') + run_cmd = MockDataGenerator.get_run_cmd(steps=2) + args[0].side_effect = [('0', pause_cmd), ('0', empty_view_cmd), ('0', run_cmd)] + setattr(self._server, '_status', ServerStatus.WAITING) + res = self._server._wait_for_next_command() + assert res == run_cmd + assert getattr(self._server, '_continue_steps') == 1 + + def test_send_matadata(self): + """Test SendMatadata interface.""" + res = self._server.SendMetadata(MagicMock(training_done=False), MagicMock()) + assert res == get_ack_reply() + + def test_send_matadata_with_training_done(self): + """Test SendMatadata interface.""" + res = self._server.SendMetadata(MagicMock(training_done=True), MagicMock()) + assert res == get_ack_reply() + + def test_send_graph(self): + """Test SendGraph interface.""" + res = self._server.SendGraph(MockDataGenerator.get_graph_chunks(), MagicMock()) + assert res == get_ack_reply() + + def test_send_tensors(self): + """Test SendTensors interface.""" + res = self._server.SendTensors(MockDataGenerator.get_tensors(), MagicMock()) + assert res == get_ack_reply() + + @mock.patch.object(WatchpointHandler, 'get_watchpoint_by_id') + @mock.patch.object(GraphHandler, 'get_node_name_by_full_name') + def test_send_watchpoint_hit(self, *args): + """Test SendWatchpointHits interface.""" + args[0].side_effect = [None, 'mock_full_name'] + watchpoint_hit = MockDataGenerator.get_watchpoint_hit() + res = self._server.SendWatchpointHits([watchpoint_hit, watchpoint_hit], MagicMock()) + assert res == get_ack_reply() diff --git a/tests/ut/debugger/test_debugger_server.py b/tests/ut/debugger/test_debugger_server.py new file mode 100644 index 00000000..43bccc8e --- /dev/null +++ b/tests/ut/debugger/test_debugger_server.py @@ -0,0 +1,250 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +""" +Function: + Test debugger server. +Usage: + pytest tests/ut/debugger/test_debugger_server.py +""" +import signal +from threading import Thread +from unittest import mock +from unittest.mock import MagicMock + +import grpc +import pytest + +from mindinsight.debugger.common.exceptions.exceptions import DebuggerParamValueError, \ + DebuggerCompareTensorError, DebuggerCreateWatchPointError, DebuggerDeleteWatchPointError +from mindinsight.debugger.debugger_cache import DebuggerCache +from mindinsight.debugger.debugger_server import DebuggerServer +from mindinsight.debugger.debugger_server import grpc_server_base +from mindinsight.debugger.proto.debug_grpc_pb2 import RunCMD +from mindinsight.debugger.stream_handler import GraphHandler, WatchpointHandler, MetadataHandler, \ + TensorHandler +from tests.ut.debugger.configurations import compare_debugger_result_with_file, mock_tensor_history + + +class TestDebuggerServer: + """Test debugger server.""" + + @classmethod + def setup_class(cls): + """Initialize for test class.""" + cls._server = None + + def setup_method(self): + """Prepare debugger server object.""" + self._server = DebuggerServer() + + @mock.patch.object(signal, 'signal') + @mock.patch.object(Thread, 'join') + @mock.patch.object(Thread, 'start') + @mock.patch.object(grpc_server_base, 'add_EventListenerServicer_to_server') + @mock.patch.object(grpc, 'server') + def test_stop_server(self, *args): + """Test stop debugger server.""" + mock_grpc_server_manager = MagicMock() + args[0].return_value = mock_grpc_server_manager + self._server.start() + self._server._stop_handler(MagicMock(), MagicMock()) + assert self._server.back_server is not None + assert self._server.grpc_server_manager == mock_grpc_server_manager + + @mock.patch.object(DebuggerCache, 'get_data') + def test_poll_data(self, *args): + """Test poll data request.""" + mock_data = {'pos': 'mock_data'} + args[0].return_value = mock_data + res = self._server.poll_data('0') + assert res == mock_data + + def test_poll_data_with_exept(self): + """Test poll data with wrong input.""" + with pytest.raises(DebuggerParamValueError, match='Pos should be string.'): + self._server.poll_data(1) + + @mock.patch.object(GraphHandler, 'search_nodes') + def test_search(self, *args): + """Test search node.""" + mock_graph = {'nodes': ['mock_nodes']} + args[0].return_value = mock_graph + res = self._server.search('mock_name') + assert res == mock_graph + + def test_tensor_comparision_with_wrong_status(self): + """Test tensor comparison with wrong status.""" + with pytest.raises( + DebuggerCompareTensorError, + match='Failed to compare tensors as the MindSpore is not in waiting state.'): + self._server.tensor_comparisons(name='mock_node_name:0', shape='[:, :]') + + @mock.patch.object(MetadataHandler, 'state', 'waiting') + @mock.patch.object(GraphHandler, 'get_node_type') + @mock.patch.object(GraphHandler, 'get_full_name', return_value='mock_node_name') + def test_tensor_comparision_with_wrong_type(self, *args): + """Test tensor comparison with wrong type.""" + args[1].return_value = 'name_scope' + with pytest.raises(DebuggerParamValueError, match='The node type must be parameter'): + self._server.tensor_comparisons(name='mock_node_name:0', shape='[:, :]') + + @mock.patch.object(MetadataHandler, 'state', 'waiting') + @mock.patch.object(GraphHandler, 'get_node_type', return_value='Parameter') + @mock.patch.object(GraphHandler, 'get_full_name', return_value='mock_node_name') + @mock.patch.object(TensorHandler, 'get_tensors_diff') + def test_tensor_comparision(self, *args): + """Test tensor comparison""" + mock_diff_res = {'tensor_value': {}} + args[0].return_value = mock_diff_res + res = self._server.tensor_comparisons('mock_node_name:0', '[:, :]') + assert res == mock_diff_res + + def test_retrieve_with_pending(self): + """Test retrieve request in pending status.""" + res = self._server.retrieve(mode='all') + assert res.get('metadata', {}).get('state') == 'pending' + + @mock.patch.object(MetadataHandler, 'state', 'waiting') + def test_retrieve_all(self): + """Test retrieve request.""" + res = self._server.retrieve(mode='all') + compare_debugger_result_with_file(res, 'debugger_server/retrieve_all.json') + + def test_retrieve_with_invalid_mode(self): + """Test retrieve with invalid mode.""" + with pytest.raises(DebuggerParamValueError, match='Invalid mode.'): + self._server.retrieve(mode='invalid_mode') + + @mock.patch.object(GraphHandler, 'get') + @mock.patch.object(GraphHandler, 'get_node_type', return_value='name_scope') + @mock.patch.object(GraphHandler, 'get_full_name', return_value='mock_node_name') + def test_retrieve_node(self, *args): + """Test retrieve node information.""" + mock_graph = {'graph': {}} + args[2].return_value = mock_graph + res = self._server._retrieve_node({'name': 'mock_node_name'}) + assert res == mock_graph + + def test_retrieve_tensor_history_with_pending(self): + """Test retrieve request in pending status.""" + res = self._server.retrieve_tensor_history('mock_node_name') + assert res.get('metadata', {}).get('state') == 'pending' + + @mock.patch.object(MetadataHandler, 'state', 'waiting') + @mock.patch.object(GraphHandler, 'get_tensor_history') + @mock.patch.object(GraphHandler, 'get_node_type', return_value='Parameter') + def test_retrieve_tensor_history(self, *args): + """Test retrieve tensor history.""" + args[1].return_value = mock_tensor_history() + res = self._server.retrieve_tensor_history('mock_node_name') + compare_debugger_result_with_file(res, 'debugger_server/retrieve_tensor_history.json') + + @mock.patch.object(GraphHandler, 'get_node_type') + def test_validate_leaf_name(self, *args): + """Test validate leaf name.""" + args[0].return_value = 'name_scope' + with pytest.raises(DebuggerParamValueError, match='Invalid leaf node name.'): + self._server._validate_leaf_name(node_name='mock_node_name') + + @mock.patch.object(TensorHandler, 'get') + @mock.patch.object(DebuggerServer, '_get_tensor_name_and_type_by_ui_name') + def test_retrieve_tensor_value(self, *args): + """Test retrieve tensor value.""" + mock_tensor_value = {'tensor_value': {'name': 'mock_name:0'}} + args[0].return_value = ('Parameter', 'mock_node_name') + args[1].return_value = mock_tensor_value + res = self._server.retrieve_tensor_value('mock_name:0', 'data', '[:, :]') + assert res == mock_tensor_value + + @mock.patch.object(WatchpointHandler, 'get') + def test_retrieve_watchpoints(self, *args): + """Test retrieve watchpoints.""" + mock_watchpoint = {'watch_points': {}} + args[0].return_value = mock_watchpoint + res = self._server._retrieve_watchpoint({}) + assert res == mock_watchpoint + + @mock.patch.object(DebuggerServer, '_retrieve_node') + def test_retrieve_watchpoint(self, *args): + """Test retrieve single watchpoint.""" + mock_watchpoint = {'nodes': {}} + args[0].return_value = mock_watchpoint + res = self._server._retrieve_watchpoint({'watch_point_id': 1}) + assert res == mock_watchpoint + + @mock.patch.object(DebuggerServer, '_validate_leaf_name') + @mock.patch.object(DebuggerServer, '_get_tensor_history') + @mock.patch.object(DebuggerServer, '_get_nodes_info', return_value={'graph': {}}) + def test_retrieve_watchpoint_hit(self, *args): + """Test retrieve single watchpoint.""" + args[1].return_value = {'tensor_history': {}} + res = self._server._retrieve_watchpoint_hit({'name': 'hit_node_name', 'single_node': True}) + assert res == {'tensor_history': {}, 'graph': {}} + + def test_create_watchpoint_with_wrong_state(self): + """Test create watchpoint with wrong state.""" + with pytest.raises(DebuggerCreateWatchPointError, match='Failed to create watchpoint'): + self._server.create_watchpoint(watch_condition={'condition': 'INF'}) + + @mock.patch.object(MetadataHandler, 'state', 'waiting') + @mock.patch.object(GraphHandler, 'get_full_name', return_value='mock_full_name') + @mock.patch.object(GraphHandler, 'get_full_name', return_value='mock_full_name') + @mock.patch.object(GraphHandler, 'get_nodes_by_scope', return_value=[MagicMock()]) + @mock.patch.object(GraphHandler, 'get_node_type', return_value='aggregation_scope') + @mock.patch.object(WatchpointHandler, 'create_watchpoint') + def test_create_watchpoint(self, *args): + """Test create watchpoint.""" + args[0].return_value = 1 + res = self._server.create_watchpoint({'condition': 'INF'}, ['watch_node_name']) + assert res == {'id': 1} + + @mock.patch.object(MetadataHandler, 'state', 'waiting') + @mock.patch.object(GraphHandler, 'get_searched_node_list') + @mock.patch.object(WatchpointHandler, 'validate_watchpoint_id') + @mock.patch.object(WatchpointHandler, 'update_watchpoint') + def test_update_watchpoint(self, *args): + """Test update watchpoint.""" + args[2].return_value = [MagicMock(name='seatch_name/op_name')] + res = self._server.update_watchpoint( + watch_point_id=1, watch_nodes=['search_name'], mode=1, name='search_name') + assert res == {} + + def test_delete_watchpoint_with_wrong_state(self): + """Test delete watchpoint with wrong state.""" + with pytest.raises(DebuggerDeleteWatchPointError, match='Failed to delete watchpoint'): + self._server.delete_watchpoint(watch_point_id=1) + + @mock.patch.object(MetadataHandler, 'state', 'waiting') + @mock.patch.object(WatchpointHandler, 'delete_watchpoint') + def test_delete_watchpoint(self, *args): + """Test delete watchpoint with wrong state.""" + args[0].return_value = None + res = self._server.delete_watchpoint(1) + assert res == {} + + @pytest.mark.parametrize('mode, cur_state, state', [ + ('continue', 'waiting', 'running'), + ('pause', 'running', 'waiting'), + ('terminate', 'waiting', 'pending')]) + def test_control(self, mode, cur_state, state): + """Test control request.""" + with mock.patch.object(MetadataHandler, 'state', cur_state): + res = self._server.control({'mode': mode}) + assert res == {'metadata': {'state': state}} + + def test_construct_run_event(self): + """Test construct run event.""" + res = self._server._construct_run_event({'level': 'node'}) + assert res.run_cmd == RunCMD(run_level='node', node_name='')