You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-') and can be up to 35 characters long.

mock_ms_client.py 8.5 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226
  1. # Copyright 2020 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ============================================================================
  15. """Mocked MindSpore debugger client."""
  16. from threading import Thread
  17. from time import sleep
  18. import grpc
  19. import numpy as np
  20. from mindinsight.debugger.proto import ms_graph_pb2
  21. from mindinsight.debugger.proto.debug_grpc_pb2 import Metadata, WatchpointHit, Chunk, EventReply
  22. from mindinsight.debugger.proto.debug_grpc_pb2_grpc import EventListenerStub
  23. from mindinsight.debugger.proto.ms_graph_pb2 import TensorProto, DataType
  24. from tests.st.func.debugger.conftest import GRAPH_PROTO_FILE
class MockDebuggerClient:
    """Mocked Debugger client.

    Simulates a MindSpore training process talking to the MindInsight
    debugger gRPC server: it sends metadata, streams a graph, then enters a
    command loop where it reacts to run/view/set/exit commands.
    """

    def __init__(self, hostname='localhost:50051', backend='Ascend', graph_num=1, ms_version='1.2.0'):
        """Initialize the mocked client.

        Args:
            hostname (str): Address of the debugger gRPC server.
            backend (str): Device backend reported in metadata, e.g. 'Ascend'.
            graph_num (int): Number of graphs to stream to the server.
            ms_version (str): MindSpore version string reported in metadata.
        """
        channel = grpc.insecure_channel(hostname)
        self.stub = EventListenerStub(channel)
        # Set to False externally to make command_loop() stop.
        self.flag = True
        self._step = 0  # current training step
        self._watchpoint_id = 0  # > 0 means at least one watchpoint is set
        self._leaf_node = []  # full names of the leaf nodes of the sent graph
        self._cur_node = ''  # node the mocked execution is currently paused at
        self._backend = backend
        self._graph_num = graph_num
        self._ms_version = ms_version

    def _clean(self):
        """Clean cache."""
        self._step = 0
        self._watchpoint_id = 0
        self._leaf_node = []
        self._cur_node = ''

    def get_thread_instance(self):
        """Get debugger client thread."""
        return MockDebuggerClientThread(self)

    def next_node(self, name=None):
        """Update the current node to next node.

        Args:
            name (str): Optional target node name. If given, jump directly to
                that node; otherwise advance to the next leaf node in order.
        """
        if not self._cur_node:
            # Execution has not started yet: begin at the first leaf node.
            self._cur_node = self._leaf_node[0]
            return
        cur_index = self._leaf_node.index(self._cur_node)
        # If name is given, go to the specified node; else step forward by one.
        if not name:
            next_index = cur_index + 1
        else:
            next_index = self._leaf_node.index(name)
        # Wrapping past the end (or jumping backwards) means a new step began.
        if next_index <= cur_index or next_index == len(self._leaf_node):
            self._step += 1
        # Update current node, wrapping around to the first leaf node.
        if next_index == len(self._leaf_node):
            self._cur_node = self._leaf_node[0]
        else:
            self._cur_node = self._leaf_node[next_index]

    def command_loop(self):
        """Wait for the command.

        Polls the server for commands until either `self.flag` is cleared,
        the server sends an exit command, or the step budget is exhausted.
        """
        total_steps = 100
        wait_flag = True
        while self.flag and wait_flag:
            if self._step > total_steps:
                # Step budget exceeded: report training finished and stop.
                sleep(0.5)
                self.send_metadata_cmd(training_done=True)
                return
            wait_flag = self._wait_cmd()

    def _wait_cmd(self):
        """Wait for command and deal with command.

        Returns:
            bool, False when the server ordered an exit, True otherwise.
        """
        metadata = self.get_metadata_cmd()
        response = self.stub.WaitCMD(metadata)
        assert response.status == EventReply.Status.OK
        if response.HasField('run_cmd'):
            self._deal_with_run_cmd(response)
        elif response.HasField('view_cmd'):
            # Server asked to view tensor values: send each requested tensor.
            for tensor in response.view_cmd.tensors:
                self.send_tensor_cmd(in_tensor=tensor)
        elif response.HasField('set_cmd'):
            # A new watchpoint was set; only the count matters for this mock.
            self._watchpoint_id += 1
        elif response.HasField('exit'):
            self._watchpoint_id = 0
            self._step = 0
            return False
        return True

    def _deal_with_run_cmd(self, response):
        """Apply a run command: advance steps/nodes and report watchpoint hits."""
        self._step += response.run_cmd.run_steps
        if response.run_cmd.run_level == 'node':
            self.next_node(response.run_cmd.node_name)
        if self._watchpoint_id > 0:
            self.send_watchpoint_hit()

    def get_metadata_cmd(self, training_done=False):
        """Construct metadata message."""
        metadata = Metadata()
        metadata.device_name = '0'
        metadata.cur_step = self._step
        metadata.cur_node = self._cur_node
        metadata.backend = self._backend
        metadata.training_done = training_done
        metadata.ms_version = self._ms_version
        return metadata

    def send_metadata_cmd(self, training_done=False):
        """Send metadata command.

        Entry point of a mocked training session: resets state, reports
        metadata, and (unless training is done) streams the graph.
        """
        self._clean()
        metadata = self.get_metadata_cmd(training_done)
        response = self.stub.SendMetadata(metadata)
        assert response.status == EventReply.Status.OK
        if response.HasField('version_matched') and response.version_matched is False:
            # Version mismatch: skip the graph and just wait for commands.
            self.command_loop()
        if training_done is False:
            self.send_graph_cmd()

    def send_graph_cmd(self):
        """Send graph to debugger server."""
        self._step = 1
        if self._graph_num > 1:
            chunks = []
            for i in range(self._graph_num):
                chunks.extend(self._get_graph_chunks('graph_' + str(i)))
            response = self.stub.SendMultiGraphs(self._generate_graph(chunks))
        else:
            chunks = self._get_graph_chunks()
            response = self.stub.SendGraph(self._generate_graph(chunks))
        assert response.status == EventReply.Status.OK
        # go to command loop
        self.command_loop()

    def _get_graph_chunks(self, graph_name='graph_0'):
        """Get graph chunks.

        Loads the serialized graph fixture, renames it, and splits it into
        chunks small enough for gRPC messages.
        """
        with open(GRAPH_PROTO_FILE, 'rb') as file_handle:
            content = file_handle.read()
        size = len(content)
        graph = ms_graph_pb2.GraphProto()
        graph.ParseFromString(content)
        graph.name = graph_name
        content = graph.SerializeToString()
        self._leaf_node = [node.full_name for node in graph.node]
        # gRPC's default max message size is 4MB,
        # so split the graph into 3MB chunks to stay safely under it.
        chunk_size = 1024 * 1024 * 3
        chunks = []
        for index in range(0, size, chunk_size):
            sub_size = min(chunk_size, size - index)
            sub_chunk = Chunk(buffer=content[index: index + sub_size])
            chunks.append(sub_chunk)
        # Mark the last chunk so the server knows the stream is complete.
        chunks[-1].finished = True
        return chunks

    @staticmethod
    def _generate_graph(chunks):
        """Construct graph generator."""
        for buffer in chunks:
            yield buffer

    def send_tensor_cmd(self, in_tensor=None):
        """Send tensor info with value."""
        response = self.stub.SendTensors(self.generate_tensor(in_tensor))
        assert response.status == EventReply.Status.OK

    @staticmethod
    def generate_tensor(in_tensor=None):
        """Generate tensor message.

        Yields the requested tensor as two parts of a fixed 2x3 float32
        value, mimicking a tensor split across multiple gRPC messages.
        """
        tensor_content = np.asarray([1, 2, 3, 4, 5, 6]).astype(np.float32).tobytes()
        tensors = [TensorProto(), TensorProto()]
        tensors[0].CopyFrom(in_tensor)
        tensors[0].data_type = DataType.DT_FLOAT32
        tensors[0].dims.extend([2, 3])
        tensors[1].CopyFrom(tensors[0])
        # Split the 24 raw bytes across the two messages (12 bytes each).
        tensors[0].tensor_content = tensor_content[:12]
        tensors[1].tensor_content = tensor_content[12:]
        tensors[0].finished = 0
        tensors[1].finished = 1
        for sub_tensor in tensors:
            yield sub_tensor

    def send_watchpoint_hit(self):
        """Send watchpoint hit value."""
        # Two hard-coded node names that exist in the graph fixture.
        tensors = [TensorProto(node_name='Default/TransData-op99', slot='0'),
                   TensorProto(node_name='Default/optimizer-Momentum/ApplyMomentum-op25', slot='0')]
        response = self.stub.SendWatchpointHits(self._generate_hits(tensors))
        assert response.status == EventReply.Status.OK

    @staticmethod
    def _generate_hits(tensors):
        """Construct watchpoint hits."""
        for tensor in tensors:
            hit = WatchpointHit()
            hit.id = 1
            hit.tensor.CopyFrom(tensor)
            yield hit
  191. class MockDebuggerClientThread:
  192. """Mocked debugger client thread."""
  193. def __init__(self, debugger_client):
  194. self._debugger_client = debugger_client
  195. self._debugger_client_thread = Thread(target=debugger_client.send_metadata_cmd)
  196. def __enter__(self, backend='Ascend'):
  197. self._debugger_client.flag = True
  198. self._debugger_client_thread.start()
  199. return self._debugger_client_thread
  200. def __exit__(self, exc_type, exc_val, exc_tb):
  201. self._debugger_client_thread.join(timeout=2)
  202. self._debugger_client.flag = False