|
- # Copyright 2020 Huawei Technologies Co., Ltd
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- # ============================================================================
- """
- Function:
- Test debugger server.
- Usage:
- pytest tests/ut/debugger/test_debugger_server.py
- """
- import signal
- from threading import Thread
- from unittest import mock
- from unittest.mock import MagicMock
-
- import grpc
- import pytest
-
- from mindinsight.debugger.common.exceptions.exceptions import DebuggerParamValueError, \
- DebuggerCompareTensorError, DebuggerCreateWatchPointError, DebuggerDeleteWatchPointError
- from mindinsight.debugger.debugger_cache import DebuggerCache
- from mindinsight.debugger.debugger_server import DebuggerServer
- from mindinsight.debugger.debugger_server import grpc_server_base
- from mindinsight.debugger.proto.debug_grpc_pb2 import RunCMD
- from mindinsight.debugger.stream_handler import GraphHandler, WatchpointHandler, MetadataHandler, \
- TensorHandler
- from tests.ut.debugger.configurations import compare_debugger_result_with_file, mock_tensor_history
-
-
class TestDebuggerServer:
    """Test debugger server.

    Several tests receive their mocks through ``*args``. Stacked
    ``mock.patch.object`` decorators inject mocks bottom-up, so ``args[0]`` is
    the mock created by the decorator closest to the ``def`` line. Patches that
    supply an explicit replacement value (e.g. ``MetadataHandler.state`` set to
    ``'waiting'``) do not inject an argument at all.
    """

    @classmethod
    def setup_class(cls):
        """Initialize for test class."""
        cls._server = None

    def setup_method(self):
        """Create a fresh debugger server object before every test."""
        self._server = DebuggerServer()

    @mock.patch.object(signal, 'signal')
    @mock.patch.object(Thread, 'join')
    @mock.patch.object(Thread, 'start')
    @mock.patch.object(grpc_server_base, 'add_EventListenerServicer_to_server')
    @mock.patch.object(grpc, 'server')
    def test_stop_server(self, *args):
        """Test stop debugger server."""
        mock_grpc_server_manager = MagicMock()
        # args[0] is the ``grpc.server`` mock (innermost decorator).
        args[0].return_value = mock_grpc_server_manager
        self._server.start()
        self._server._stop_handler(MagicMock(), MagicMock())
        assert self._server.back_server is not None
        assert self._server.grpc_server_manager == mock_grpc_server_manager

    @mock.patch.object(DebuggerCache, 'get_data')
    def test_poll_data(self, *args):
        """Test poll data request."""
        mock_data = {'pos': 'mock_data'}
        # args[0] is the ``DebuggerCache.get_data`` mock.
        args[0].return_value = mock_data
        res = self._server.poll_data('0')
        assert res == mock_data

    def test_poll_data_with_exept(self):
        """Test poll data with a non-string position."""
        with pytest.raises(DebuggerParamValueError, match='Pos should be string.'):
            self._server.poll_data(1)

    @mock.patch.object(GraphHandler, 'search_nodes')
    def test_search(self, *args):
        """Test search node."""
        mock_graph = {'nodes': ['mock_nodes']}
        # args[0] is the ``GraphHandler.search_nodes`` mock.
        args[0].return_value = mock_graph
        res = self._server.search('mock_name')
        assert res == mock_graph

    def test_tensor_comparision_with_wrong_status(self):
        """Test tensor comparison with wrong status."""
        # MetadataHandler.state is deliberately left unpatched here.
        with pytest.raises(
                DebuggerCompareTensorError,
                match='Failed to compare tensors as the MindSpore is not in waiting state.'):
            self._server.tensor_comparisons(name='mock_node_name:0', shape='[:, :]')

    @mock.patch.object(MetadataHandler, 'state', 'waiting')
    @mock.patch.object(GraphHandler, 'get_node_type')
    @mock.patch.object(GraphHandler, 'get_full_name', return_value='mock_node_name')
    def test_tensor_comparision_with_wrong_type(self, *args):
        """Test tensor comparison with wrong type."""
        # args[0] is get_full_name, args[1] is get_node_type.
        args[1].return_value = 'name_scope'
        with pytest.raises(DebuggerParamValueError, match='The node type must be parameter'):
            self._server.tensor_comparisons(name='mock_node_name:0', shape='[:, :]')

    @mock.patch.object(MetadataHandler, 'state', 'waiting')
    @mock.patch.object(GraphHandler, 'get_node_type', return_value='Parameter')
    @mock.patch.object(GraphHandler, 'get_full_name', return_value='mock_node_name')
    @mock.patch.object(TensorHandler, 'get_tensors_diff')
    def test_tensor_comparision(self, *args):
        """Test tensor comparison."""
        mock_diff_res = {'tensor_value': {}}
        # args[0] is the ``TensorHandler.get_tensors_diff`` mock.
        args[0].return_value = mock_diff_res
        res = self._server.tensor_comparisons('mock_node_name:0', '[:, :]')
        assert res == mock_diff_res

    def test_retrieve_with_pending(self):
        """Test retrieve request in pending status."""
        res = self._server.retrieve(mode='all')
        assert res.get('metadata', {}).get('state') == 'pending'

    @mock.patch.object(MetadataHandler, 'state', 'waiting')
    def test_retrieve_all(self):
        """Test retrieve request."""
        res = self._server.retrieve(mode='all')
        compare_debugger_result_with_file(res, 'debugger_server/retrieve_all.json')

    def test_retrieve_with_invalid_mode(self):
        """Test retrieve with invalid mode."""
        with pytest.raises(DebuggerParamValueError, match='Invalid mode.'):
            self._server.retrieve(mode='invalid_mode')

    @mock.patch.object(GraphHandler, 'get')
    @mock.patch.object(GraphHandler, 'get_node_type', return_value='name_scope')
    @mock.patch.object(GraphHandler, 'get_full_name', return_value='mock_node_name')
    def test_retrieve_node(self, *args):
        """Test retrieve node information."""
        mock_graph = {'graph': {}}
        # args[0] is get_full_name, args[1] is get_node_type, args[2] is GraphHandler.get.
        args[2].return_value = mock_graph
        res = self._server._retrieve_node({'name': 'mock_node_name'})
        assert res == mock_graph

    def test_retrieve_tensor_history_with_pending(self):
        """Test retrieve tensor history in pending status."""
        res = self._server.retrieve_tensor_history('mock_node_name')
        assert res.get('metadata', {}).get('state') == 'pending'

    @mock.patch.object(MetadataHandler, 'state', 'waiting')
    @mock.patch.object(GraphHandler, 'get_tensor_history')
    @mock.patch.object(GraphHandler, 'get_node_type', return_value='Parameter')
    def test_retrieve_tensor_history(self, *args):
        """Test retrieve tensor history."""
        # args[0] is get_node_type, args[1] is get_tensor_history.
        args[1].return_value = mock_tensor_history()
        res = self._server.retrieve_tensor_history('mock_node_name')
        compare_debugger_result_with_file(res, 'debugger_server/retrieve_tensor_history.json')

    @mock.patch.object(GraphHandler, 'get_node_type')
    def test_validate_leaf_name(self, *args):
        """Test validate leaf name."""
        # args[0] is the ``GraphHandler.get_node_type`` mock.
        args[0].return_value = 'name_scope'
        with pytest.raises(DebuggerParamValueError, match='Invalid leaf node name.'):
            self._server._validate_leaf_name(node_name='mock_node_name')

    @mock.patch.object(TensorHandler, 'get')
    @mock.patch.object(DebuggerServer, '_get_tensor_name_and_type_by_ui_name')
    def test_retrieve_tensor_value(self, *args):
        """Test retrieve tensor value."""
        mock_tensor_value = {'tensor_value': {'name': 'mock_name:0'}}
        # args[0] is _get_tensor_name_and_type_by_ui_name, args[1] is TensorHandler.get.
        args[0].return_value = ('Parameter', 'mock_node_name')
        args[1].return_value = mock_tensor_value
        res = self._server.retrieve_tensor_value('mock_name:0', 'data', '[:, :]')
        assert res == mock_tensor_value

    @mock.patch.object(WatchpointHandler, 'get')
    def test_retrieve_watchpoints(self, *args):
        """Test retrieve watchpoints."""
        mock_watchpoint = {'watch_points': {}}
        # args[0] is the ``WatchpointHandler.get`` mock.
        args[0].return_value = mock_watchpoint
        res = self._server._retrieve_watchpoint({})
        assert res == mock_watchpoint

    @mock.patch.object(DebuggerServer, '_retrieve_node')
    def test_retrieve_watchpoint(self, *args):
        """Test retrieve single watchpoint."""
        mock_watchpoint = {'nodes': {}}
        # args[0] is the ``DebuggerServer._retrieve_node`` mock.
        args[0].return_value = mock_watchpoint
        res = self._server._retrieve_watchpoint({'watch_point_id': 1})
        assert res == mock_watchpoint

    @mock.patch.object(DebuggerServer, '_validate_leaf_name')
    @mock.patch.object(DebuggerServer, '_get_tensor_history')
    @mock.patch.object(DebuggerServer, '_get_nodes_info', return_value={'graph': {}})
    def test_retrieve_watchpoint_hit(self, *args):
        """Test retrieve watchpoint hit information."""
        # args[0] is _get_nodes_info, args[1] is _get_tensor_history.
        args[1].return_value = {'tensor_history': {}}
        res = self._server._retrieve_watchpoint_hit({'name': 'hit_node_name', 'single_node': True})
        assert res == {'tensor_history': {}, 'graph': {}}

    def test_create_watchpoint_with_wrong_state(self):
        """Test create watchpoint with wrong state."""
        with pytest.raises(DebuggerCreateWatchPointError, match='Failed to create watchpoint'):
            self._server.create_watchpoint(watch_condition={'condition': 'INF'})

    @mock.patch.object(MetadataHandler, 'state', 'waiting')
    @mock.patch.object(GraphHandler, 'get_full_name', return_value='mock_full_name')
    @mock.patch.object(GraphHandler, 'get_nodes_by_scope', return_value=[MagicMock()])
    @mock.patch.object(GraphHandler, 'get_node_type', return_value='aggregation_scope')
    @mock.patch.object(WatchpointHandler, 'create_watchpoint')
    def test_create_watchpoint(self, *args):
        """Test create watchpoint."""
        # args[0] is the ``WatchpointHandler.create_watchpoint`` mock.
        args[0].return_value = 1
        res = self._server.create_watchpoint({'condition': 'INF'}, ['watch_node_name'])
        assert res == {'id': 1}

    @mock.patch.object(MetadataHandler, 'state', 'waiting')
    @mock.patch.object(GraphHandler, 'get_searched_node_list')
    @mock.patch.object(WatchpointHandler, 'validate_watchpoint_id')
    @mock.patch.object(WatchpointHandler, 'update_watchpoint')
    def test_update_watchpoint(self, *args):
        """Test update watchpoint."""
        # args[0] is update_watchpoint, args[1] is validate_watchpoint_id,
        # args[2] is get_searched_node_list.
        args[2].return_value = [MagicMock(name='search_name/op_name')]
        res = self._server.update_watchpoint(
            watch_point_id=1, watch_nodes=['search_name'], mode=1, name='search_name')
        assert res == {}

    def test_delete_watchpoint_with_wrong_state(self):
        """Test delete watchpoint with wrong state."""
        with pytest.raises(DebuggerDeleteWatchPointError, match='Failed to delete watchpoint'):
            self._server.delete_watchpoint(watch_point_id=1)

    @mock.patch.object(MetadataHandler, 'state', 'waiting')
    @mock.patch.object(WatchpointHandler, 'delete_watchpoint')
    def test_delete_watchpoint(self, *args):
        """Test delete watchpoint in waiting state."""
        # args[0] is the ``WatchpointHandler.delete_watchpoint`` mock.
        args[0].return_value = None
        res = self._server.delete_watchpoint(1)
        assert res == {}

    @pytest.mark.parametrize('mode, cur_state, state', [
        ('continue', 'waiting', 'running'),
        ('pause', 'running', 'waiting'),
        ('terminate', 'waiting', 'pending')])
    def test_control(self, mode, cur_state, state):
        """Test control request: each mode moves cur_state to the expected state."""
        with mock.patch.object(MetadataHandler, 'state', cur_state):
            res = self._server.control({'mode': mode})
        assert res == {'metadata': {'state': state}}

    def test_construct_run_event(self):
        """Test construct run event."""
        res = self._server._construct_run_event({'level': 'node'})
        assert res.run_cmd == RunCMD(run_level='node', node_name='')
|