You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-') and can be up to 35 characters long.

mock_ms_client.py 8.3 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221
  1. # Copyright 2020 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ============================================================================
  15. """Mocked MindSpore debugger client."""
  16. from threading import Thread
  17. import grpc
  18. import numpy as np
  19. from mindinsight.debugger.proto import ms_graph_pb2
  20. from mindinsight.debugger.proto.debug_grpc_pb2 import Metadata, WatchpointHit, Chunk, EventReply
  21. from mindinsight.debugger.proto.debug_grpc_pb2_grpc import EventListenerStub
  22. from mindinsight.debugger.proto.ms_graph_pb2 import TensorProto, DataType
  23. from tests.st.func.debugger.conftest import GRAPH_PROTO_FILE
  24. class MockDebuggerClient:
  25. """Mocked Debugger client."""
  26. def __init__(self, hostname='localhost:50051', backend='Ascend', graph_num=1):
  27. channel = grpc.insecure_channel(hostname)
  28. self.stub = EventListenerStub(channel)
  29. self.flag = True
  30. self._step = 0
  31. self._watchpoint_id = 0
  32. self._leaf_node = []
  33. self._cur_node = ''
  34. self._backend = backend
  35. self._graph_num = graph_num
  36. def _clean(self):
  37. """Clean cache."""
  38. self._step = 0
  39. self._watchpoint_id = 0
  40. self._leaf_node = []
  41. self._cur_node = ''
  42. def get_thread_instance(self):
  43. """Get debugger client thread."""
  44. return MockDebuggerClientThread(self)
  45. def next_node(self, name=None):
  46. """Update the current node to next node."""
  47. if not self._cur_node:
  48. self._cur_node = self._leaf_node[0]
  49. return
  50. cur_index = self._leaf_node.index(self._cur_node)
  51. # if name is not None, go to the specified node.
  52. if not name:
  53. next_index = cur_index + 1
  54. else:
  55. next_index = self._leaf_node.index(name)
  56. # update step
  57. if next_index <= cur_index or next_index == len(self._leaf_node):
  58. self._step += 1
  59. # update current node
  60. if next_index == len(self._leaf_node):
  61. self._cur_node = self._leaf_node[0]
  62. else:
  63. self._cur_node = self._leaf_node[next_index]
  64. def command_loop(self):
  65. """Wait for the command."""
  66. total_steps = 100
  67. wait_flag = True
  68. while self.flag and wait_flag:
  69. if self._step > total_steps:
  70. self.send_metadata_cmd(training_done=True)
  71. return
  72. wait_flag = self._wait_cmd()
  73. def _wait_cmd(self):
  74. """Wait for command and deal with command."""
  75. metadata = self.get_metadata_cmd()
  76. response = self.stub.WaitCMD(metadata)
  77. assert response.status == EventReply.Status.OK
  78. if response.HasField('run_cmd'):
  79. self._deal_with_run_cmd(response)
  80. elif response.HasField('view_cmd'):
  81. for tensor in response.view_cmd.tensors:
  82. self.send_tensor_cmd(in_tensor=tensor)
  83. elif response.HasField('set_cmd'):
  84. self._watchpoint_id += 1
  85. elif response.HasField('exit'):
  86. self._watchpoint_id = 0
  87. self._step = 0
  88. return False
  89. return True
  90. def _deal_with_run_cmd(self, response):
  91. self._step += response.run_cmd.run_steps
  92. if response.run_cmd.run_level == 'node':
  93. self.next_node(response.run_cmd.node_name)
  94. if self._watchpoint_id > 0:
  95. self.send_watchpoint_hit()
  96. def get_metadata_cmd(self, training_done=False):
  97. """Construct metadata message."""
  98. metadata = Metadata()
  99. metadata.device_name = '0'
  100. metadata.cur_step = self._step
  101. metadata.cur_node = self._cur_node
  102. metadata.backend = self._backend
  103. metadata.training_done = training_done
  104. return metadata
  105. def send_metadata_cmd(self, training_done=False):
  106. """Send metadata command."""
  107. self._clean()
  108. metadata = self.get_metadata_cmd(training_done)
  109. response = self.stub.SendMetadata(metadata)
  110. assert response.status == EventReply.Status.OK
  111. if training_done is False:
  112. self.send_graph_cmd()
  113. print("finish")
  114. def send_graph_cmd(self):
  115. """Send graph to debugger server."""
  116. self._step = 1
  117. if self._graph_num > 1:
  118. chunks = []
  119. for i in range(self._graph_num):
  120. chunks.extend(self._get_graph_chunks('graph_' + str(i)))
  121. response = self.stub.SendMultiGraphs(self._generate_graph(chunks))
  122. else:
  123. chunks = self._get_graph_chunks()
  124. response = self.stub.SendGraph(self._generate_graph(chunks))
  125. assert response.status == EventReply.Status.OK
  126. # go to command loop
  127. self.command_loop()
  128. def _get_graph_chunks(self, graph_name='graph_0'):
  129. """Get graph chunks."""
  130. with open(GRAPH_PROTO_FILE, 'rb') as file_handle:
  131. content = file_handle.read()
  132. size = len(content)
  133. graph = ms_graph_pb2.GraphProto()
  134. graph.ParseFromString(content)
  135. graph.name = graph_name
  136. content = graph.SerializeToString()
  137. self._leaf_node = [node.full_name for node in graph.node]
  138. # the max limit of grpc data size is 4kb
  139. # split graph into 3kb per chunk
  140. chunk_size = 1024 * 1024 * 3
  141. chunks = []
  142. for index in range(0, size, chunk_size):
  143. sub_size = min(chunk_size, size - index)
  144. sub_chunk = Chunk(buffer=content[index: index + sub_size])
  145. chunks.append(sub_chunk)
  146. chunks[-1].finished = True
  147. return chunks
  148. @staticmethod
  149. def _generate_graph(chunks):
  150. """Construct graph generator."""
  151. for buffer in chunks:
  152. yield buffer
  153. def send_tensor_cmd(self, in_tensor=None):
  154. """Send tensor info with value."""
  155. response = self.stub.SendTensors(self.generate_tensor(in_tensor))
  156. assert response.status == EventReply.Status.OK
  157. @staticmethod
  158. def generate_tensor(in_tensor=None):
  159. """Generate tensor message."""
  160. tensor_content = np.asarray([1, 2, 3, 4, 5, 6]).astype(np.float32).tobytes()
  161. tensors = [TensorProto(), TensorProto()]
  162. tensors[0].CopyFrom(in_tensor)
  163. tensors[0].data_type = DataType.DT_FLOAT32
  164. tensors[0].dims.extend([2, 3])
  165. tensors[1].CopyFrom(tensors[0])
  166. tensors[0].tensor_content = tensor_content[:12]
  167. tensors[1].tensor_content = tensor_content[12:]
  168. tensors[0].finished = 0
  169. tensors[1].finished = 1
  170. for sub_tensor in tensors:
  171. yield sub_tensor
  172. def send_watchpoint_hit(self):
  173. """Send watchpoint hit value."""
  174. tensors = [TensorProto(node_name='Default/TransData-op99', slot='0'),
  175. TensorProto(node_name='Default/optimizer-Momentum/ApplyMomentum-op25', slot='0')]
  176. response = self.stub.SendWatchpointHits(self._generate_hits(tensors))
  177. assert response.status == EventReply.Status.OK
  178. @staticmethod
  179. def _generate_hits(tensors):
  180. """Construct watchpoint hits."""
  181. for tensor in tensors:
  182. hit = WatchpointHit()
  183. hit.id = 1
  184. hit.tensor.CopyFrom(tensor)
  185. yield hit
  186. class MockDebuggerClientThread:
  187. """Mocked debugger client thread."""
  188. def __init__(self, debugger_client):
  189. self._debugger_client = debugger_client
  190. self._debugger_client_thread = Thread(target=debugger_client.send_metadata_cmd)
  191. def __enter__(self, backend='Ascend'):
  192. self._debugger_client.flag = True
  193. self._debugger_client_thread.start()
  194. return self._debugger_client_thread
  195. def __exit__(self, exc_type, exc_val, exc_tb):
  196. self._debugger_client_thread.join(timeout=3)
  197. self._debugger_client.flag = False