You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-') and can be up to 35 characters long.

test_debugger_server.py 12 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250
  1. # Copyright 2020 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ============================================================================
  15. """
  16. Function:
  17. Test debugger server.
  18. Usage:
  19. pytest tests/ut/debugger/test_debugger_server.py
  20. """
  21. import signal
  22. from threading import Thread
  23. from unittest import mock
  24. from unittest.mock import MagicMock
  25. import grpc
  26. import pytest
  27. from mindinsight.debugger.common.exceptions.exceptions import DebuggerParamValueError, \
  28. DebuggerCompareTensorError, DebuggerCreateWatchPointError, DebuggerDeleteWatchPointError
  29. from mindinsight.debugger.debugger_cache import DebuggerCache
  30. from mindinsight.debugger.debugger_server import DebuggerServer
  31. from mindinsight.debugger.debugger_server import grpc_server_base
  32. from mindinsight.debugger.proto.debug_grpc_pb2 import RunCMD
  33. from mindinsight.debugger.stream_handler import GraphHandler, WatchpointHandler, MetadataHandler, \
  34. TensorHandler
  35. from tests.ut.debugger.configurations import compare_debugger_result_with_file, mock_tensor_history
  class TestDebuggerServer:
      """Unit tests for ``DebuggerServer``.

      Most tests replace handler methods with ``mock.patch.object``. A patch that
      is given an explicit replacement value (e.g. ``MetadataHandler.state`` set to
      ``'waiting'``) injects no argument. Every other patch injects a mock, and
      stacked patch decorators inject them bottom-up. So ``args[0]`` is the mock
      from the decorator written closest to the ``def`` line.
      """

      @classmethod
      def setup_class(cls):
          """Initialize for test class."""
          cls._server = None

      def setup_method(self):
          """Create a fresh debugger server object before every test."""
          self._server = DebuggerServer()

      @mock.patch.object(signal, 'signal')
      @mock.patch.object(Thread, 'join')
      @mock.patch.object(Thread, 'start')
      @mock.patch.object(grpc_server_base, 'add_EventListenerServicer_to_server')
      @mock.patch.object(grpc, 'server')
      def test_stop_server(self, *args):
          """Test stop debugger server."""
          # Thread/signal patches presumably keep start() from spawning real
          # threads or installing real signal handlers -- NOTE(review): confirm.
          # args[0] is the mocked ``grpc.server`` factory (innermost patch).
          mock_grpc_server_manager = MagicMock()
          args[0].return_value = mock_grpc_server_manager
          self._server.start()
          self._server._stop_handler(MagicMock(), MagicMock())
          assert self._server.back_server is not None
          assert self._server.grpc_server_manager == mock_grpc_server_manager

      @mock.patch.object(DebuggerCache, 'get_data')
      def test_poll_data(self, *args):
          """Test poll data request."""
          # args[0] is the mocked ``DebuggerCache.get_data``.
          mock_data = {'pos': 'mock_data'}
          args[0].return_value = mock_data
          res = self._server.poll_data('0')
          assert res == mock_data

      def test_poll_data_with_exept(self):
          """Test poll data with wrong input (a non-string position)."""
          with pytest.raises(DebuggerParamValueError, match='Pos should be string.'):
              self._server.poll_data(1)

      @mock.patch.object(GraphHandler, 'search_nodes')
      def test_search(self, *args):
          """Test search node."""
          # args[0] is the mocked ``GraphHandler.search_nodes``.
          mock_graph = {'nodes': ['mock_nodes']}
          args[0].return_value = mock_graph
          res = self._server.search('mock_name')
          assert res == mock_graph

      def test_tensor_comparision_with_wrong_status(self):
          """Test tensor comparison with wrong status."""
          # MetadataHandler.state is not patched here, so the freshly created
          # server keeps its initial state, which this test expects to be
          # something other than 'waiting'.
          with pytest.raises(
                  DebuggerCompareTensorError,
                  match='Failed to compare tensors as the MindSpore is not in waiting state.'):
              self._server.tensor_comparisons(name='mock_node_name:0', shape='[:, :]')

      @mock.patch.object(MetadataHandler, 'state', 'waiting')
      @mock.patch.object(GraphHandler, 'get_node_type')
      @mock.patch.object(GraphHandler, 'get_full_name', return_value='mock_node_name')
      def test_tensor_comparision_with_wrong_type(self, *args):
          """Test tensor comparison with wrong type."""
          # args == (get_full_name, get_node_type); the state patch injects nothing.
          args[1].return_value = 'name_scope'
          with pytest.raises(DebuggerParamValueError, match='The node type must be parameter'):
              self._server.tensor_comparisons(name='mock_node_name:0', shape='[:, :]')

      @mock.patch.object(MetadataHandler, 'state', 'waiting')
      @mock.patch.object(GraphHandler, 'get_node_type', return_value='Parameter')
      @mock.patch.object(GraphHandler, 'get_full_name', return_value='mock_node_name')
      @mock.patch.object(TensorHandler, 'get_tensors_diff')
      def test_tensor_comparision(self, *args):
          """Test tensor comparison."""
          # args[0] is the mocked ``TensorHandler.get_tensors_diff``.
          mock_diff_res = {'tensor_value': {}}
          args[0].return_value = mock_diff_res
          res = self._server.tensor_comparisons('mock_node_name:0', '[:, :]')
          assert res == mock_diff_res

      def test_retrieve_with_pending(self):
          """Test retrieve request in pending status."""
          res = self._server.retrieve(mode='all')
          assert res.get('metadata', {}).get('state') == 'pending'

      @mock.patch.object(MetadataHandler, 'state', 'waiting')
      def test_retrieve_all(self):
          """Test retrieve request."""
          res = self._server.retrieve(mode='all')
          compare_debugger_result_with_file(res, 'debugger_server/retrieve_all.json')

      def test_retrieve_with_invalid_mode(self):
          """Test retrieve with invalid mode."""
          with pytest.raises(DebuggerParamValueError, match='Invalid mode.'):
              self._server.retrieve(mode='invalid_mode')

      @mock.patch.object(GraphHandler, 'get')
      @mock.patch.object(GraphHandler, 'get_node_type', return_value='name_scope')
      @mock.patch.object(GraphHandler, 'get_full_name', return_value='mock_node_name')
      def test_retrieve_node(self, *args):
          """Test retrieve node information."""
          # args == (get_full_name, get_node_type, get); args[2] is GraphHandler.get.
          mock_graph = {'graph': {}}
          args[2].return_value = mock_graph
          res = self._server._retrieve_node({'name': 'mock_node_name'})
          assert res == mock_graph

      def test_retrieve_tensor_history_with_pending(self):
          """Test retrieve tensor history in pending status."""
          res = self._server.retrieve_tensor_history('mock_node_name')
          assert res.get('metadata', {}).get('state') == 'pending'

      @mock.patch.object(MetadataHandler, 'state', 'waiting')
      @mock.patch.object(GraphHandler, 'get_tensor_history')
      @mock.patch.object(GraphHandler, 'get_node_type', return_value='Parameter')
      def test_retrieve_tensor_history(self, *args):
          """Test retrieve tensor history."""
          # args == (get_node_type, get_tensor_history).
          args[1].return_value = mock_tensor_history()
          res = self._server.retrieve_tensor_history('mock_node_name')
          compare_debugger_result_with_file(res, 'debugger_server/retrieve_tensor_history.json')

      @mock.patch.object(GraphHandler, 'get_node_type')
      def test_validate_leaf_name(self, *args):
          """Test validate leaf name."""
          # A 'name_scope' node is expected to be rejected as a leaf.
          args[0].return_value = 'name_scope'
          with pytest.raises(DebuggerParamValueError, match='Invalid leaf node name.'):
              self._server._validate_leaf_name(node_name='mock_node_name')

      @mock.patch.object(TensorHandler, 'get')
      @mock.patch.object(DebuggerServer, '_get_tensor_name_and_type_by_ui_name')
      def test_retrieve_tensor_value(self, *args):
          """Test retrieve tensor value."""
          # args == (_get_tensor_name_and_type_by_ui_name, TensorHandler.get).
          mock_tensor_value = {'tensor_value': {'name': 'mock_name:0'}}
          args[0].return_value = ('Parameter', 'mock_node_name')
          args[1].return_value = mock_tensor_value
          res = self._server.retrieve_tensor_value('mock_name:0', 'data', '[:, :]')
          assert res == mock_tensor_value

      @mock.patch.object(WatchpointHandler, 'get')
      def test_retrieve_watchpoints(self, *args):
          """Test retrieve watchpoints."""
          # args[0] is the mocked ``WatchpointHandler.get``.
          mock_watchpoint = {'watch_points': {}}
          args[0].return_value = mock_watchpoint
          res = self._server._retrieve_watchpoint({})
          assert res == mock_watchpoint

      @mock.patch.object(DebuggerServer, '_retrieve_node')
      def test_retrieve_watchpoint(self, *args):
          """Test retrieve single watchpoint."""
          # args[0] is the mocked ``DebuggerServer._retrieve_node``.
          mock_watchpoint = {'nodes': {}}
          args[0].return_value = mock_watchpoint
          res = self._server._retrieve_watchpoint({'watch_point_id': 1})
          assert res == mock_watchpoint

      @mock.patch.object(DebuggerServer, '_validate_leaf_name')
      @mock.patch.object(DebuggerServer, '_get_tensor_history')
      @mock.patch.object(DebuggerServer, '_get_nodes_info', return_value={'graph': {}})
      def test_retrieve_watchpoint_hit(self, *args):
          """Test retrieve watchpoint hit for a single node."""
          # args == (_get_nodes_info, _get_tensor_history, _validate_leaf_name).
          args[1].return_value = {'tensor_history': {}}
          res = self._server._retrieve_watchpoint_hit({'name': 'hit_node_name', 'single_node': True})
          assert res == {'tensor_history': {}, 'graph': {}}

      def test_create_watchpoint_with_wrong_state(self):
          """Test create watchpoint with wrong state."""
          with pytest.raises(DebuggerCreateWatchPointError, match='Failed to create watchpoint'):
              self._server.create_watchpoint(watch_condition={'condition': 'INF'})

      @mock.patch.object(MetadataHandler, 'state', 'waiting')
      @mock.patch.object(GraphHandler, 'get_full_name', return_value='mock_full_name')
      @mock.patch.object(GraphHandler, 'get_nodes_by_scope', return_value=[MagicMock()])
      @mock.patch.object(GraphHandler, 'get_node_type', return_value='aggregation_scope')
      @mock.patch.object(WatchpointHandler, 'create_watchpoint')
      def test_create_watchpoint(self, *args):
          """Test create watchpoint."""
          # args[0] is the mocked ``WatchpointHandler.create_watchpoint``; the
          # 'get_full_name' patch was previously applied twice, which was redundant.
          args[0].return_value = 1
          res = self._server.create_watchpoint({'condition': 'INF'}, ['watch_node_name'])
          assert res == {'id': 1}

      @mock.patch.object(MetadataHandler, 'state', 'waiting')
      @mock.patch.object(GraphHandler, 'get_searched_node_list')
      @mock.patch.object(WatchpointHandler, 'validate_watchpoint_id')
      @mock.patch.object(WatchpointHandler, 'update_watchpoint')
      def test_update_watchpoint(self, *args):
          """Test update watchpoint."""
          # args == (update_watchpoint, validate_watchpoint_id, get_searched_node_list).
          # NOTE(review): MagicMock(name=...) only sets the mock's repr; the
          # mock's ``.name`` attribute is still a child mock. Confirm whether the
          # server reads node.name here.
          args[2].return_value = [MagicMock(name='search_name/op_name')]
          res = self._server.update_watchpoint(
              watch_point_id=1, watch_nodes=['search_name'], mode=1, name='search_name')
          assert res == {}

      def test_delete_watchpoint_with_wrong_state(self):
          """Test delete watchpoint with wrong state."""
          with pytest.raises(DebuggerDeleteWatchPointError, match='Failed to delete watchpoint'):
              self._server.delete_watchpoint(watch_point_id=1)

      @mock.patch.object(MetadataHandler, 'state', 'waiting')
      @mock.patch.object(WatchpointHandler, 'delete_watchpoint')
      def test_delete_watchpoint(self, *args):
          """Test delete watchpoint in waiting state."""
          # args[0] is the mocked ``WatchpointHandler.delete_watchpoint``.
          args[0].return_value = None
          res = self._server.delete_watchpoint(1)
          assert res == {}

      @pytest.mark.parametrize('mode, cur_state, state', [
          ('continue', 'waiting', 'running'),
          ('pause', 'running', 'waiting'),
          ('terminate', 'waiting', 'pending')])
      def test_control(self, mode, cur_state, state):
          """Test control request moves the server from cur_state to state."""
          with mock.patch.object(MetadataHandler, 'state', cur_state):
              res = self._server.control({'mode': mode})
          assert res == {'metadata': {'state': state}}

      def test_construct_run_event(self):
          """Test construct run event."""
          res = self._server._construct_run_event({'level': 'node'})
          assert res.run_cmd == RunCMD(run_level='node', node_name='')