You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

test_debugger_grpc_server.py 9.8 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237
  1. # Copyright 2020 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ============================================================================
  15. """
  16. Function:
  17. Test debugger grpc server.
  18. Usage:
  19. pytest tests/ut/debugger/test_debugger_grpc_server.py
  20. """
  21. from unittest import mock
  22. from unittest.mock import MagicMock
  23. import numpy as np
  24. from mindinsight.debugger.common.utils import get_ack_reply, ServerStatus
  25. from mindinsight.debugger.debugger_cache import DebuggerCache
  26. from mindinsight.debugger.debugger_grpc_server import DebuggerGrpcServer
  27. from mindinsight.debugger.proto.debug_grpc_pb2 import EventReply, SetCMD, Chunk, WatchpointHit
  28. from mindinsight.debugger.proto.ms_graph_pb2 import TensorProto, DataType
  29. from mindinsight.debugger.stream_handler import WatchpointHitHandler, GraphHandler, \
  30. WatchpointHandler
  31. from tests.ut.debugger.configurations import GRAPH_PROTO_FILE
  class MockDataGenerator:
      """Mocked data generator producing commands, graph chunks, tensors and watchpoint hits."""

      @staticmethod
      def get_run_cmd(steps=0, level='step', node_name=''):
          """Get run command.

          Args:
              steps (int): Value written to `run_cmd.run_steps`; ignored when `level` is 'node'.
              level (str): Value written to `run_cmd.run_level`.
              node_name (str): Value written to `run_cmd.node_name`; only used when `level` is 'node'.

          Returns:
              The reply from `get_ack_reply()` with its `run_cmd` field filled in.
          """
          event = get_ack_reply()
          event.run_cmd.run_level = level
          # A node-level run targets a node name; any other level is driven by a step count.
          if level == 'node':
              event.run_cmd.node_name = node_name
          else:
              event.run_cmd.run_steps = steps
          return event

      @staticmethod
      def get_exit_cmd():
          """Get exit command.

          Returns:
              The reply from `get_ack_reply()` with `exit` set to True.
          """
          event = get_ack_reply()
          event.exit = True
          return event

      @staticmethod
      def get_set_cmd():
          """Get set command.

          Returns:
              The reply from `get_ack_reply()` whose `set_cmd` is `SetCMD(id=1, watch_condition=1)`.
          """
          event = get_ack_reply()
          event.set_cmd.CopyFrom(SetCMD(id=1, watch_condition=1))
          return event

      @staticmethod
      def get_view_cmd():
          """Get view command.

          Returns:
              dict, with key 'view_cmd' holding a reply that requests tensor
              ('mock_node_name', slot '0'), and key 'node_name' set to 'mock_node_name'.
          """
          view_event = get_ack_reply()
          ms_tensor = view_event.view_cmd.tensors.add()
          ms_tensor.node_name, ms_tensor.slot = 'mock_node_name', '0'
          event = {'view_cmd': view_event, 'node_name': 'mock_node_name'}
          return event

      @staticmethod
      def get_graph_chunks():
          """Get graph chunks.

          Returns:
              list[Chunk], two chunks holding the bytes of GRAPH_PROTO_FILE: the first
              1024 bytes, then the remainder.
          """
          chunk_size = 1024
          with open(GRAPH_PROTO_FILE, 'rb') as file_handler:
              content = file_handler.read()
          chunks = [Chunk(buffer=content[0:chunk_size]), Chunk(buffer=content[chunk_size:])]
          return chunks

      @staticmethod
      def get_tensors():
          """Get tensors.

          Returns:
              list[TensorProto], one 2x3 float32 tensor for ('mock_node_name', slot '0')
              split into two messages: the first carries the first 12 bytes with
              finished=0, the second carries the last 12 bytes with finished=1.
          """
          # Six float32 values -> 24 bytes, sent as two 12-byte pieces.
          tensor_content = np.asarray([1, 2, 3, 4, 5, 6]).astype(np.float32).tobytes()
          tensor_pre = TensorProto(
              node_name='mock_node_name',
              slot='0',
              data_type=DataType.DT_FLOAT32,
              dims=[2, 3],
              tensor_content=tensor_content[:12],
              finished=0
          )
          # The second piece shares all metadata with the first; only content and finished differ.
          tensor_succ = TensorProto()
          tensor_succ.CopyFrom(tensor_pre)
          tensor_succ.tensor_content = tensor_content[12:]
          tensor_succ.finished = 1
          return [tensor_pre, tensor_succ]

      @staticmethod
      def get_watchpoint_hit():
          """Get watchpoint hit.

          Returns:
              WatchpointHit, with id 1 on tensor ('mock_node_name', slot '0').
          """
          res = WatchpointHit(id=1)
          res.tensor.node_name = 'mock_node_name'
          res.tensor.slot = '0'
          return res
  class TestDebuggerGrpcServer:
      """Test debugger grpc server.

      Mocks injected by `mock.patch` decorators are received as named parameters.
      The bottom-most decorator supplies the first mock argument.
      """

      @classmethod
      def setup_class(cls):
          """Initialize for test class."""
          cls._server = None

      def setup_method(self):
          """Initialize for each testcase with a fresh server and DebuggerCache."""
          cache_store = DebuggerCache()
          self._server = DebuggerGrpcServer(cache_store)

      def test_waitcmd_with_pending_status(self):
          """Test wait command interface when status is pending."""
          res = self._server.WaitCMD(MagicMock(), MagicMock())
          assert res.status == EventReply.Status.FAILED

      @mock.patch.object(WatchpointHitHandler, 'empty', False)
      @mock.patch.object(DebuggerGrpcServer, '_deal_with_old_command')
      def test_waitcmd_with_old_command(self, mock_deal_with_old_command):
          """Test wait command interface with old command."""
          old_command = MockDataGenerator.get_run_cmd(steps=1)
          mock_deal_with_old_command.return_value = old_command
          setattr(self._server, '_status', ServerStatus.WAITING)
          setattr(self._server, '_received_view_cmd', {'node_name': 'mock_node_name'})
          setattr(self._server, '_received_hit', True)
          res = self._server.WaitCMD(MagicMock(cur_step=1), MagicMock())
          assert res == old_command

      @mock.patch.object(DebuggerGrpcServer, '_deal_with_old_command', return_value=None)
      @mock.patch.object(DebuggerGrpcServer, '_wait_for_next_command')
      def test_waitcmd_with_next_command(self, mock_wait_for_next_command, _mock_deal_with_old_command):
          """Test wait for next command when there is no old command."""
          next_command = MockDataGenerator.get_run_cmd(steps=1)
          mock_wait_for_next_command.return_value = next_command
          setattr(self._server, '_status', ServerStatus.WAITING)
          res = self._server.WaitCMD(MagicMock(cur_step=1), MagicMock())
          assert res == next_command

      @mock.patch.object(DebuggerGrpcServer, '_deal_with_old_command', return_value=None)
      @mock.patch.object(DebuggerGrpcServer, '_wait_for_next_command')
      def test_waitcmd_with_next_command_is_none(self, mock_wait_for_next_command,
                                                 _mock_deal_with_old_command):
          """Test wait command interface with next command is None."""
          mock_wait_for_next_command.return_value = None
          setattr(self._server, '_status', ServerStatus.RECEIVE_GRAPH)
          res = self._server.WaitCMD(MagicMock(cur_step=1), MagicMock())
          assert res == get_ack_reply(1)

      @mock.patch.object(DebuggerCache, 'get_command', return_value=(0, None))
      @mock.patch.object(DebuggerCache, 'has_command')
      def test_deal_with_old_command_with_continue_steps(self, mock_has_command, _mock_get_command):
          """Test deal with old command with continue steps."""
          # has_command answers True on its first call and False on its second.
          mock_has_command.side_effect = [True, False]
          # Simulate one step remaining from a previously received run command.
          setattr(self._server, '_old_run_cmd', {'left_step_count': 1})
          res = self._server._deal_with_old_command()
          assert res == MockDataGenerator.get_run_cmd(steps=1)

      @mock.patch.object(DebuggerCache, 'get_command')
      @mock.patch.object(DebuggerCache, 'has_command', return_value=True)
      def test_deal_with_old_command_with_exit_cmd(self, _mock_has_command, mock_get_command):
          """Test deal with exit command."""
          cmd = MockDataGenerator.get_exit_cmd()
          mock_get_command.return_value = ('0', cmd)
          res = self._server._deal_with_old_command()
          assert res == cmd

      @mock.patch.object(DebuggerCache, 'get_command')
      @mock.patch.object(DebuggerCache, 'has_command', return_value=True)
      def test_deal_with_old_command_with_view_cmd(self, _mock_has_command, mock_get_command):
          """Test deal with view command."""
          cmd = MockDataGenerator.get_view_cmd()
          mock_get_command.return_value = ('0', cmd)
          res = self._server._deal_with_old_command()
          assert res == cmd.get('view_cmd')
          expect_received_view_cmd = {'node_name': cmd.get('node_name'), 'wait_for_tensor': True}
          assert getattr(self._server, '_received_view_cmd') == expect_received_view_cmd

      @mock.patch.object(DebuggerCache, 'get_command')
      def test_wait_for_run_command(self, mock_get_command):
          """Test wait for run command."""
          cmd = MockDataGenerator.get_run_cmd(steps=2)
          mock_get_command.return_value = ('0', cmd)
          setattr(self._server, '_status', ServerStatus.WAITING)
          res = self._server._wait_for_next_command()
          # A 2-step run command is expected to yield a 1-step command now and
          # leave one step recorded in _old_run_cmd.
          assert res == MockDataGenerator.get_run_cmd(steps=1)
          assert getattr(self._server, '_old_run_cmd') == {'left_step_count': 1}

      @mock.patch.object(DebuggerCache, 'get_command')
      def test_wait_for_pause_and_run_command(self, mock_get_command):
          """Test wait for next command after a pause command and a view command without node name."""
          pause_cmd = MockDataGenerator.get_run_cmd(steps=0)
          empty_view_cmd = MockDataGenerator.get_view_cmd()
          empty_view_cmd.pop('node_name')
          run_cmd = MockDataGenerator.get_run_cmd(steps=2)
          # get_command yields, in order: a 0-step run (pause), a view command
          # without 'node_name', then a 2-step run command.
          mock_get_command.side_effect = [('0', pause_cmd), ('0', empty_view_cmd), ('0', run_cmd)]
          setattr(self._server, '_status', ServerStatus.WAITING)
          res = self._server._wait_for_next_command()
          assert res == run_cmd
          assert getattr(self._server, '_old_run_cmd') == {'left_step_count': 1}

      def test_send_matadata(self):
          """Test SendMetadata interface when training is not done."""
          res = self._server.SendMetadata(MagicMock(training_done=False), MagicMock())
          assert res == get_ack_reply()

      def test_send_matadata_with_training_done(self):
          """Test SendMetadata interface when training is done."""
          res = self._server.SendMetadata(MagicMock(training_done=True), MagicMock())
          assert res == get_ack_reply()

      def test_send_graph(self):
          """Test SendGraph interface."""
          res = self._server.SendGraph(MockDataGenerator.get_graph_chunks(), MagicMock())
          assert res == get_ack_reply()

      def test_send_tensors(self):
          """Test SendTensors interface."""
          res = self._server.SendTensors(MockDataGenerator.get_tensors(), MagicMock())
          assert res == get_ack_reply()

      @mock.patch.object(WatchpointHandler, 'get_watchpoint_by_id')
      @mock.patch.object(GraphHandler, 'get_node_name_by_full_name')
      def test_send_watchpoint_hit(self, mock_get_node_name_by_full_name, _mock_get_watchpoint_by_id):
          """Test SendWatchpointHits interface."""
          # The first hit's full name resolves to no node; the second resolves to a name.
          mock_get_node_name_by_full_name.side_effect = [None, 'mock_full_name']
          watchpoint_hit = MockDataGenerator.get_watchpoint_hit()
          res = self._server.SendWatchpointHits([watchpoint_hit, watchpoint_hit], MagicMock())
          assert res == get_ack_reply()