You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-') and can be up to 35 characters long.

test_debugger_grpc_server.py 10 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240
  1. # Copyright 2020 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ============================================================================
  15. """
  16. Function:
  17. Test debugger grpc server.
  18. Usage:
  19. pytest tests/ut/debugger/test_debugger_grpc_server.py
  20. """
  21. from unittest import mock
  22. from unittest.mock import MagicMock
  23. import numpy as np
  24. from mindinsight.conditionmgr.conditionmgr import ConditionMgr
  25. from mindinsight.debugger.common.utils import get_ack_reply, ServerStatus
  26. from mindinsight.debugger.debugger_cache import DebuggerCache
  27. from mindinsight.debugger.debugger_grpc_server import DebuggerGrpcServer
  28. from mindinsight.debugger.proto.debug_grpc_pb2 import EventReply, SetCMD, Chunk, WatchpointHit
  29. from mindinsight.debugger.proto.ms_graph_pb2 import TensorProto, DataType
  30. from mindinsight.debugger.stream_handler import WatchpointHitHandler, GraphHandler, \
  31. WatchpointHandler
  32. from tests.ut.debugger.configurations import GRAPH_PROTO_FILE
  33. class MockDataGenerator:
  34. """Mocked Data generator."""
  35. @staticmethod
  36. def get_run_cmd(steps=0, level='step', node_name=''):
  37. """Get run command."""
  38. event = get_ack_reply()
  39. event.run_cmd.run_level = level
  40. if level == 'node':
  41. event.run_cmd.node_name = node_name
  42. else:
  43. event.run_cmd.run_steps = steps
  44. return event
  45. @staticmethod
  46. def get_exit_cmd():
  47. """Get exit command."""
  48. event = get_ack_reply()
  49. event.exit = True
  50. return event
  51. @staticmethod
  52. def get_set_cmd():
  53. """Get set command"""
  54. event = get_ack_reply()
  55. event.set_cmd.CopyFrom(SetCMD(id=1, watch_condition=1))
  56. return event
  57. @staticmethod
  58. def get_view_cmd():
  59. """Get set command"""
  60. view_event = get_ack_reply()
  61. ms_tensor = view_event.view_cmd.tensors.add()
  62. ms_tensor.node_name, ms_tensor.slot = 'mock_node_name', '0'
  63. event = {'view_cmd': view_event, 'node_name': 'mock_node_name'}
  64. return event
  65. @staticmethod
  66. def get_graph_chunks():
  67. """Get graph chunks."""
  68. chunk_size = 1024
  69. with open(GRAPH_PROTO_FILE, 'rb') as file_handler:
  70. content = file_handler.read()
  71. chunks = [Chunk(buffer=content[0:chunk_size]), Chunk(buffer=content[chunk_size:])]
  72. return chunks
  73. @staticmethod
  74. def get_tensors():
  75. """Get tensors."""
  76. tensor_content = np.asarray([1, 2, 3, 4, 5, 6]).astype(np.float32).tobytes()
  77. tensor_pre = TensorProto(
  78. node_name='mock_node_name',
  79. slot='0',
  80. data_type=DataType.DT_FLOAT32,
  81. dims=[2, 3],
  82. tensor_content=tensor_content[:12],
  83. finished=0
  84. )
  85. tensor_succ = TensorProto()
  86. tensor_succ.CopyFrom(tensor_pre)
  87. tensor_succ.tensor_content = tensor_content[12:]
  88. tensor_succ.finished = 1
  89. return [tensor_pre, tensor_succ]
  90. @staticmethod
  91. def get_watchpoint_hit():
  92. """Get watchpoint hit."""
  93. res = WatchpointHit(id=1)
  94. res.tensor.node_name = 'mock_node_name'
  95. res.tensor.slot = '0'
  96. return res
  97. class TestDebuggerGrpcServer:
  98. """Test debugger grpc server."""
  99. @classmethod
  100. def setup_class(cls):
  101. """Initialize for test class."""
  102. cls._server = None
  103. def setup_method(self):
  104. """Initialize for each testcase."""
  105. cache_store = DebuggerCache()
  106. self._server = DebuggerGrpcServer(cache_store, condition_mgr=ConditionMgr())
  107. def test_waitcmd_with_pending_status(self):
  108. """Test wait command interface when status is pending."""
  109. res = self._server.WaitCMD(MagicMock(), MagicMock())
  110. assert res.status == EventReply.Status.FAILED
  111. @mock.patch.object(WatchpointHitHandler, 'empty', False)
  112. @mock.patch.object(WatchpointHitHandler, 'put')
  113. @mock.patch.object(DebuggerGrpcServer, '_deal_with_old_command')
  114. def test_waitcmd_with_old_command(self, *args):
  115. """Test wait command interface with old command."""
  116. old_command = MockDataGenerator.get_run_cmd(steps=1)
  117. args[0].return_value = old_command
  118. setattr(self._server, '_status', ServerStatus.WAITING)
  119. setattr(self._server, '_received_view_cmd', {'node_name': 'mock_node_name'})
  120. setattr(self._server, '_received_hit', [MagicMock()])
  121. res = self._server.WaitCMD(MagicMock(cur_step=1, cur_node=''), MagicMock())
  122. assert res == old_command
  123. @mock.patch.object(DebuggerGrpcServer, '_deal_with_old_command', return_value=None)
  124. @mock.patch.object(DebuggerGrpcServer, '_wait_for_next_command')
  125. def test_waitcmd_with_next_command(self, *args):
  126. """Test wait for next command."""
  127. old_command = MockDataGenerator.get_run_cmd(steps=1)
  128. args[0].return_value = old_command
  129. setattr(self._server, '_status', ServerStatus.WAITING)
  130. res = self._server.WaitCMD(MagicMock(cur_step=1, cur_node=''), MagicMock())
  131. assert res == old_command
  132. @mock.patch.object(DebuggerGrpcServer, '_deal_with_old_command', return_value=None)
  133. @mock.patch.object(DebuggerGrpcServer, '_wait_for_next_command')
  134. def test_waitcmd_with_next_command_is_none(self, *args):
  135. """Test wait command interface with next command is None."""
  136. args[0].return_value = None
  137. setattr(self._server, '_status', ServerStatus.RECEIVE_GRAPH)
  138. res = self._server.WaitCMD(MagicMock(cur_step=1, cur_node=''), MagicMock())
  139. assert res == get_ack_reply(1)
  140. @mock.patch.object(DebuggerCache, 'get_command', return_value=(0, None))
  141. @mock.patch.object(DebuggerCache, 'has_command')
  142. def test_deal_with_old_command_with_continue_steps(self, *args):
  143. """Test deal with old command with continue steps."""
  144. args[0].side_effect = [True, False]
  145. setattr(self._server, '_old_run_cmd', {'left_step_count': 1})
  146. res = self._server._deal_with_old_command()
  147. assert res == MockDataGenerator.get_run_cmd(steps=1)
  148. @mock.patch.object(DebuggerCache, 'get_command')
  149. @mock.patch.object(DebuggerCache, 'has_command', return_value=True)
  150. def test_deal_with_old_command_with_exit_cmd(self, *args):
  151. """Test deal with exit command."""
  152. cmd = MockDataGenerator.get_exit_cmd()
  153. args[1].return_value = ('0', cmd)
  154. res = self._server._deal_with_old_command()
  155. assert res == cmd
  156. @mock.patch.object(DebuggerCache, 'get_command')
  157. @mock.patch.object(DebuggerCache, 'has_command', return_value=True)
  158. def test_deal_with_old_command_with_view_cmd(self, *args):
  159. """Test deal with view command."""
  160. cmd = MockDataGenerator.get_view_cmd()
  161. args[1].return_value = ('0', cmd)
  162. res = self._server._deal_with_old_command()
  163. assert res == cmd.get('view_cmd')
  164. expect_received_view_cmd = {'node_name': cmd.get('node_name'), 'wait_for_tensor': True}
  165. assert getattr(self._server, '_received_view_cmd') == expect_received_view_cmd
  166. @mock.patch.object(DebuggerCache, 'get_command')
  167. def test_wait_for_run_command(self, *args):
  168. """Test wait for run command."""
  169. cmd = MockDataGenerator.get_run_cmd(steps=2)
  170. args[0].return_value = ('0', cmd)
  171. setattr(self._server, '_status', ServerStatus.WAITING)
  172. res = self._server._wait_for_next_command()
  173. assert res == MockDataGenerator.get_run_cmd(steps=1)
  174. assert getattr(self._server, '_old_run_cmd') == {'left_step_count': 1}
  175. @mock.patch.object(DebuggerCache, 'get_command')
  176. def test_wait_for_pause_and_run_command(self, *args):
  177. """Test wait for run command."""
  178. pause_cmd = MockDataGenerator.get_run_cmd(steps=0)
  179. empty_view_cmd = MockDataGenerator.get_view_cmd()
  180. empty_view_cmd.pop('node_name')
  181. run_cmd = MockDataGenerator.get_run_cmd(steps=2)
  182. args[0].side_effect = [('0', pause_cmd), ('0', empty_view_cmd), ('0', run_cmd)]
  183. setattr(self._server, '_status', ServerStatus.WAITING)
  184. res = self._server._wait_for_next_command()
  185. assert res == run_cmd
  186. assert getattr(self._server, '_old_run_cmd') == {'left_step_count': 1}
  187. def test_send_matadata(self):
  188. """Test SendMatadata interface."""
  189. res = self._server.SendMetadata(MagicMock(training_done=False), MagicMock())
  190. assert res == get_ack_reply()
  191. def test_send_matadata_with_training_done(self):
  192. """Test SendMatadata interface."""
  193. res = self._server.SendMetadata(MagicMock(training_done=True), MagicMock())
  194. assert res == get_ack_reply()
  195. def test_send_graph(self):
  196. """Test SendGraph interface."""
  197. res = self._server.SendGraph(MockDataGenerator.get_graph_chunks(), MagicMock())
  198. assert res == get_ack_reply()
  199. def test_send_tensors(self):
  200. """Test SendTensors interface."""
  201. res = self._server.SendTensors(MockDataGenerator.get_tensors(), MagicMock())
  202. assert res == get_ack_reply()
  203. @mock.patch.object(WatchpointHandler, 'get_watchpoint_by_id')
  204. @mock.patch.object(GraphHandler, 'get_graph_id_by_full_name', return_value='mock_graph_name')
  205. @mock.patch.object(GraphHandler, 'get_node_name_by_full_name')
  206. def test_send_watchpoint_hit(self, *args):
  207. """Test SendWatchpointHits interface."""
  208. args[0].side_effect = [None, 'mock_full_name']
  209. watchpoint_hit = MockDataGenerator.get_watchpoint_hit()
  210. res = self._server.SendWatchpointHits([watchpoint_hit, watchpoint_hit], MagicMock())
  211. assert res == get_ack_reply()