
mock_ms_client.py 8.5 kB

# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Mocked MindSpore debugger client."""
from threading import Thread

import grpc
import numpy as np

from mindinsight.debugger.proto import ms_graph_pb2
from mindinsight.debugger.proto.debug_grpc_pb2 import Metadata, WatchpointHit, Chunk, EventReply
from mindinsight.debugger.proto.debug_grpc_pb2_grpc import EventListenerStub
from mindinsight.debugger.proto.ms_graph_pb2 import TensorProto, DataType
from tests.st.func.debugger.conftest import GRAPH_PROTO_FILE
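

# This mock plays the training-process side of the debugger gRPC protocol:
# it registers itself via SendMetadata, streams its graph to the server, then
# answers WaitCMD replies (run/view/set/exit) until training is done.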
class MockDebuggerClient:
    """Mocked Debugger client."""

    def __init__(self, hostname='localhost:50051', backend='Ascend', graph_num=1, ms_version='1.1.0'):
        channel = grpc.insecure_channel(hostname)
        self.stub = EventListenerStub(channel)
        self.flag = True
        self._step = 0
        self._watchpoint_id = 0
        self._leaf_node = []
        self._cur_node = ''
        self._backend = backend
        self._graph_num = graph_num
        self._ms_version = ms_version

    def _clean(self):
        """Clean cache."""
        self._step = 0
        self._watchpoint_id = 0
        self._leaf_node = []
        self._cur_node = ''

    def get_thread_instance(self):
        """Get debugger client thread."""
        return MockDebuggerClientThread(self)

    def next_node(self, name=None):
        """Update the current node to next node."""
        if not self._cur_node:
            self._cur_node = self._leaf_node[0]
            return
        cur_index = self._leaf_node.index(self._cur_node)
        # if name is not None, go to the specified node.
        if not name:
            next_index = cur_index + 1
        else:
            next_index = self._leaf_node.index(name)
        # update step
        if next_index <= cur_index or next_index == len(self._leaf_node):
            self._step += 1
        # update current node
        if next_index == len(self._leaf_node):
            self._cur_node = self._leaf_node[0]
        else:
            self._cur_node = self._leaf_node[next_index]
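
    # The command loop polls the debugger server: each WaitCMD reply carries a
    # run/view/set/exit command, and the loop keeps going until the step budget
    # is exhausted or an exit command resets the mock's state.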
    def command_loop(self):
        """Wait for the command."""
        total_steps = 100
        wait_flag = True
        while self.flag and wait_flag:
            if self._step > total_steps:
                self.send_metadata_cmd(training_done=True)
                return
            wait_flag = self._wait_cmd()

    def _wait_cmd(self):
        """Wait for command and deal with command."""
        metadata = self.get_metadata_cmd()
        response = self.stub.WaitCMD(metadata)
        assert response.status == EventReply.Status.OK
        if response.HasField('run_cmd'):
            self._deal_with_run_cmd(response)
        elif response.HasField('view_cmd'):
            for tensor in response.view_cmd.tensors:
                self.send_tensor_cmd(in_tensor=tensor)
        elif response.HasField('set_cmd'):
            self._watchpoint_id += 1
        elif response.HasField('exit'):
            self._watchpoint_id = 0
            self._step = 0
            return False
        return True

    def _deal_with_run_cmd(self, response):
        """Deal with run command: advance the step or node, then report watchpoint hits if any watchpoint is set."""
        self._step += response.run_cmd.run_steps
        if response.run_cmd.run_level == 'node':
            self.next_node(response.run_cmd.node_name)
        if self._watchpoint_id > 0:
            self.send_watchpoint_hit()

    def get_metadata_cmd(self, training_done=False):
        """Construct metadata message."""
        metadata = Metadata()
        metadata.device_name = '0'
        metadata.cur_step = self._step
        metadata.cur_node = self._cur_node
        metadata.backend = self._backend
        metadata.training_done = training_done
        metadata.ms_version = self._ms_version
        return metadata

    def send_metadata_cmd(self, training_done=False):
        """Send metadata command."""
        self._clean()
        metadata = self.get_metadata_cmd(training_done)
        response = self.stub.SendMetadata(metadata)
        assert response.status == EventReply.Status.OK
        if response.version_matched is False:
            self.command_loop()
        if training_done is False:
            self.send_graph_cmd()
        print("finish")
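
    # The serialized graph can exceed the gRPC message size limit, so it is
    # streamed to the server in chunks; multiple graphs go through SendMultiGraphs.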
    def send_graph_cmd(self):
        """Send graph to debugger server."""
        self._step = 1
        if self._graph_num > 1:
            chunks = []
            for i in range(self._graph_num):
                chunks.extend(self._get_graph_chunks('graph_' + str(i)))
            response = self.stub.SendMultiGraphs(self._generate_graph(chunks))
        else:
            chunks = self._get_graph_chunks()
            response = self.stub.SendGraph(self._generate_graph(chunks))
        assert response.status == EventReply.Status.OK
        # go to command loop
        self.command_loop()

    def _get_graph_chunks(self, graph_name='graph_0'):
        """Get graph chunks."""
        with open(GRAPH_PROTO_FILE, 'rb') as file_handle:
            content = file_handle.read()
        size = len(content)
        graph = ms_graph_pb2.GraphProto()
        graph.ParseFromString(content)
        graph.name = graph_name
        content = graph.SerializeToString()
        self._leaf_node = [node.full_name for node in graph.node]
        # the default gRPC message size limit is 4 MB,
        # so split the graph into chunks of at most 3 MB.
        chunk_size = 1024 * 1024 * 3
        chunks = []
        for index in range(0, size, chunk_size):
            sub_size = min(chunk_size, size - index)
            sub_chunk = Chunk(buffer=content[index: index + sub_size])
            chunks.append(sub_chunk)
        chunks[-1].finished = True
        return chunks

    @staticmethod
    def _generate_graph(chunks):
        """Construct graph generator."""
        for buffer in chunks:
            yield buffer

    def send_tensor_cmd(self, in_tensor=None):
        """Send tensor info with value."""
        response = self.stub.SendTensors(self.generate_tensor(in_tensor))
        assert response.status == EventReply.Status.OK
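
    # Every view command is answered with the same fixed 2x3 float32 tensor,
    # split across two TensorProto messages; only the last one is marked finished.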
    @staticmethod
    def generate_tensor(in_tensor=None):
        """Generate tensor message."""
        tensor_content = np.asarray([1, 2, 3, 4, 5, 6]).astype(np.float32).tobytes()
        tensors = [TensorProto(), TensorProto()]
        tensors[0].CopyFrom(in_tensor)
        tensors[0].data_type = DataType.DT_FLOAT32
        tensors[0].dims.extend([2, 3])
        tensors[1].CopyFrom(tensors[0])
        tensors[0].tensor_content = tensor_content[:12]
        tensors[1].tensor_content = tensor_content[12:]
        tensors[0].finished = 0
        tensors[1].finished = 1
        for sub_tensor in tensors:
            yield sub_tensor

    def send_watchpoint_hit(self):
        """Send watchpoint hit value."""
        tensors = [TensorProto(node_name='Default/TransData-op99', slot='0'),
                   TensorProto(node_name='Default/optimizer-Momentum/ApplyMomentum-op25', slot='0')]
        response = self.stub.SendWatchpointHits(self._generate_hits(tensors))
        assert response.status == EventReply.Status.OK

    @staticmethod
    def _generate_hits(tensors):
        """Construct watchpoint hits."""
        for tensor in tensors:
            hit = WatchpointHit()
            hit.id = 1
            hit.tensor.CopyFrom(tensor)
            yield hit
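

# Context manager that runs the mock client's handshake (send_metadata_cmd) in a
# background thread, so tests can drive the debugger server from the main thread.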
class MockDebuggerClientThread:
    """Mocked debugger client thread."""

    def __init__(self, debugger_client):
        self._debugger_client = debugger_client
        self._debugger_client_thread = Thread(target=debugger_client.send_metadata_cmd)

    def __enter__(self, backend='Ascend'):
        self._debugger_client.flag = True
        self._debugger_client_thread.start()
        return self._debugger_client_thread

    def __exit__(self, exc_type, exc_val, exc_tb):
        self._debugger_client_thread.join(timeout=2)
        self._debugger_client.flag = False
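

# A minimal usage sketch (illustration only, not part of the original file):
# tests would typically construct the mock and use the thread wrapper as a
# context manager, assuming a debugger server is listening on localhost:50051:
#
#     client = MockDebuggerClient(backend='Ascend')
#     with client.get_thread_instance():
#         # interact with the debugger server here
#         ...
#
# The handshake runs in the background thread, and __exit__ joins it with a
# 2-second timeout after clearing the client's run flag.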