
debugger_grpc_server.py

# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Implement the debugger grpc server."""
from functools import wraps

from mindinsight.debugger.common.log import logger as log
from mindinsight.debugger.common.utils import get_ack_reply, ServerStatus, \
    Streams
from mindinsight.debugger.proto import debug_grpc_pb2_grpc as grpc_server_base
from mindinsight.debugger.proto.ms_graph_pb2 import GraphProto


def debugger_wrap(func):
    """Wrapper for catching and logging exceptions."""
    @wraps(func)
    def record_log(*args, **kwargs):
        try:
            return func(*args, **kwargs)
        except Exception as err:
            log.exception(err)
            raise err

    return record_log

class DebuggerGrpcServer(grpc_server_base.EventListenerServicer):
    """The grpc server used to interact with the grpc client."""

    def __init__(self, cache_store):
        """
        Initialize.

        Args:
            cache_store (DebuggerCache): Debugger cache store.
        """
        cache_store.initialize()
        self._cache_store = cache_store
        self._pos = None
        self._status = None
        self._continue_steps = None
        self._received_view_cmd = None
        self.init()

    def init(self):
        """Init debugger grpc server."""
        self._pos = '0'
        self._status = ServerStatus.PENDING
        self._continue_steps = 0
        self._received_view_cmd = {}
        self._cache_store.clean()

    @debugger_wrap
    def WaitCMD(self, request, context):
        """Wait for a command in DebuggerCache."""
        # check if the graph has already been received
        log.info("Received WaitCMD at %s-th step.", request.cur_step)
        if self._status == ServerStatus.PENDING:
            log.warning("No graph received before WaitCMD.")
            reply = get_ack_reply(1)
            return reply
        # send graph if it has not been sent before
        self._pre_process(request)
        # deal with old command
        reply = self._deal_with_old_command()
        # wait for next command
        if reply is None:
            reply = self._wait_for_next_command()
        # check the reply
        if reply is None:
            reply = get_ack_reply(1)
            log.warning("Failed to get command event.")
        else:
            log.info("Reply to WaitCMD: %s", reply)
        return reply

    def _pre_process(self, request):
        """Pre-process before dealing with command."""
        metadata_stream = self._cache_store.get_stream_handler(Streams.METADATA)
        is_new_step = metadata_stream.step < request.cur_step
        # clean cache data at the beginning of new step
        if is_new_step:
            self._cache_store.clean_data()
            self._cache_store.get_stream_handler(Streams.TENSOR).clean_tensors(request.cur_step)
        # receive graph at the beginning of the training
        if self._status == ServerStatus.RECEIVE_GRAPH:
            self._send_graph_flag(metadata_stream)
        # receive new metadata
        if is_new_step or metadata_stream.full_name != request.cur_node:
            self._update_metadata(metadata_stream, request)
        self._send_received_tensor_tag()
        self._send_watchpoint_hit_flag()

    def _send_graph_flag(self, metadata_stream):
        """
        Send graph and metadata to UI.

        Args:
            metadata_stream (MetadataHandler): Metadata handler stream.
        """
        self._cache_store.clean_command()
        # receive graph in the beginning of the training
        self._status = ServerStatus.WAITING
        metadata_stream.state = 'waiting'
        metadata = metadata_stream.get()
        res = self._cache_store.get_stream_handler(Streams.GRAPH).get()
        res.update(metadata)
        self._cache_store.put_data(res)
        log.debug("Put graph into data queue.")

    def _update_metadata(self, metadata_stream, metadata_proto):
        """
        Update metadata.

        Args:
            metadata_stream (MetadataHandler): Metadata handler stream.
            metadata_proto (MetadataProto): Metadata proto sent by the client.
        """
        # put new metadata into cache
        metadata_stream.put(metadata_proto)
        cur_node = self._cache_store.get_stream_handler(Streams.GRAPH).get_node_name_by_full_name(
            metadata_proto.cur_node) if metadata_proto.cur_node else ''
        metadata_stream.node_name = cur_node
        metadata = metadata_stream.get()
        self._cache_store.put_data(metadata)
        log.debug("Put new metadata into data queue.")

    def _send_received_tensor_tag(self):
        """Send received_finish_tag."""
        node_name = self._received_view_cmd.get('node_name')
        if not node_name or self._received_view_cmd.get('wait_for_tensor'):
            return
        metadata = self._cache_store.get_stream_handler(Streams.METADATA).get()
        ret = {'receive_tensor': {'node_name': node_name}}
        ret.update(metadata)
        self._cache_store.put_data(ret)
        self._received_view_cmd.clear()
        log.debug("Send receive tensor flag for %s", node_name)

    def _send_watchpoint_hit_flag(self):
        """Send watchpoint hit flag."""
        watchpoint_hit_stream = self._cache_store.get_stream_handler(Streams.WATCHPOINT_HIT)
        if watchpoint_hit_stream.empty:
            return
        watchpoint_hits_info = watchpoint_hit_stream.get()
        self._cache_store.put_data(watchpoint_hits_info)
        log.debug("Send the watchpoint hits to DataQueue.\nSend the reply.")

    def _deal_with_old_command(self):
        """Deal with old command."""
        event = None
        while self._cache_store.has_command(self._pos) and event is None:
            event = self._get_next_command()
            log.debug("Deal with old %s-th command:\n%s.", self._pos, event)
        # continue multiple steps training
        if event is None and self._continue_steps:
            event = get_ack_reply()
            event.run_cmd.run_steps = 1
            event.run_cmd.run_level = 'step'
            self._continue_steps = self._continue_steps - 1 if self._continue_steps > 0 else -1
            self._cache_store.get_stream_handler(Streams.WATCHPOINT_HIT).clean()
            log.debug("Send RunCMD. Clean watchpoint hit.")
        return event

    def _wait_for_next_command(self):
        """
        Wait for next command.

        Returns:
            EventReply, the command event.
        """
        log.info("Start to wait for command.")
        self._cache_store.get_stream_handler(Streams.METADATA).state = 'waiting'
        self._cache_store.put_data({'metadata': {'state': 'waiting'}})
        event = None
        while event is None and self._status == ServerStatus.WAITING:
            log.debug("Wait for %s-th command", self._pos)
            event = self._get_next_command()
        return event

    def _get_next_command(self):
        """Get next command."""
        self._pos, event = self._cache_store.get_command(self._pos)
        if event is None:
            return event
        if isinstance(event, dict):
            event = self._deal_with_view_cmd(event)
        elif event.HasField('run_cmd'):
            event = self._deal_with_run_cmd(event)
        elif event.HasField('exit'):
            self._cache_store.clean()
            log.info("Clean cache for exit cmd.")
        return event

    def _deal_with_view_cmd(self, event):
        """Deal with view cmd."""
        view_cmd = event.get('view_cmd')
        node_name = event.get('node_name')
        log.debug("Receive view cmd for node: %s.", node_name)
        if not (view_cmd and node_name):
            log.debug("Invalid view command. Ignore it.")
            return None
        self._received_view_cmd['node_name'] = node_name
        self._received_view_cmd['wait_for_tensor'] = True
        return view_cmd

    def _deal_with_run_cmd(self, event):
        """Deal with run cmd."""
        run_cmd = event.run_cmd
        # receive step command
        if run_cmd.run_level == 'step':
            # receive pause cmd
            if run_cmd.run_steps == 0:
                log.debug("Pause training and wait for next command.")
                self._continue_steps = 0
                return None
            # receive step cmd
            self._continue_steps = run_cmd.run_steps - 1
            event.run_cmd.run_steps = 1
        self._cache_store.get_stream_handler(Streams.WATCHPOINT_HIT).clean()
        log.debug("Receive RunCMD. Clean watchpoint hit cache.")
        return event

    @debugger_wrap
    def SendMetadata(self, request, context):
        """Send metadata into DebuggerCache."""
        log.info("Received Metadata.")
        if self._status != ServerStatus.PENDING:
            log.info("Re-initialize cache store when new session comes.")
            self.init()
        client_ip = context.peer().split(':', 1)[-1]
        metadata_stream = self._cache_store.get_stream_handler(Streams.METADATA)
        if request.training_done:
            log.info("The training from %s has finished.", client_ip)
        else:
            metadata_stream.put(request)
            metadata_stream.client_ip = client_ip
            log.debug("Put new metadata from %s into cache.", client_ip)
        # put metadata into data queue
        metadata = metadata_stream.get()
        self._cache_store.put_data(metadata)
        reply = get_ack_reply()
        log.debug("Send the reply to %s.", client_ip)
        return reply

    @debugger_wrap
    def SendGraph(self, request_iterator, context):
        """Send graph into DebuggerCache."""
        log.info("Received graph.")
        serial_graph = b""
        for chunk in request_iterator:
            serial_graph += chunk.buffer
        graph = GraphProto.FromString(serial_graph)
        log.debug("Deserialize the graph. Receive %s nodes", len(graph.node))
        self._cache_store.get_stream_handler(Streams.GRAPH).put(graph)
        self._cache_store.get_stream_handler(Streams.TENSOR).put_const_vals(graph.const_vals)
        self._status = ServerStatus.RECEIVE_GRAPH
        reply = get_ack_reply()
        log.debug("Send the reply for graph.")
        return reply

    @debugger_wrap
    def SendTensors(self, request_iterator, context):
        """Send tensors into DebuggerCache."""
        log.info("Received tensor.")
        tensor_construct = []
        tensor_stream = self._cache_store.get_stream_handler(Streams.TENSOR)
        metadata_stream = self._cache_store.get_stream_handler(Streams.METADATA)
        tensor_names = []
        step = metadata_stream.step
        for tensor in request_iterator:
            tensor_construct.append(tensor)
            if tensor.finished:
                update_flag = tensor_stream.put({'step': step, 'tensor_protos': tensor_construct})
                if self._received_view_cmd.get('wait_for_tensor') and update_flag:
                    self._received_view_cmd['wait_for_tensor'] = False
                    log.debug("Set wait for tensor flag to False.")
                tensor_construct = []
                tensor_names.append(':'.join([tensor.node_name, tensor.slot]))
                continue
        reply = get_ack_reply()
        return reply

    @debugger_wrap
    def SendWatchpointHits(self, request_iterator, context):
        """Send watchpoint hits info into DebuggerCache."""
        log.info("Received WatchpointHits. Left steps %d change to 0.", self._continue_steps)
        self._continue_steps = 0
        watchpoint_hit_stream = self._cache_store.get_stream_handler(Streams.WATCHPOINT_HIT)
        watchpoint_stream = self._cache_store.get_stream_handler(Streams.WATCHPOINT)
        graph_stream = self._cache_store.get_stream_handler(Streams.GRAPH)
        for watchpoint_hit_proto in request_iterator:
            ui_node_name = graph_stream.get_node_name_by_full_name(
                watchpoint_hit_proto.tensor.node_name)
            log.debug("Receive watch point hit: %s", watchpoint_hit_proto)
            if not ui_node_name:
                log.info("Not support to show %s on graph.", watchpoint_hit_proto.tensor.node_name)
                continue
            watchpoint_hit = {
                'tensor_proto': watchpoint_hit_proto.tensor,
                'watchpoint': watchpoint_stream.get_watchpoint_by_id(watchpoint_hit_proto.id),
                'node_name': ui_node_name
            }
            watchpoint_hit_stream.put(watchpoint_hit)
        reply = get_ack_reply()
        return reply
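
For context, below is a minimal sketch of how this servicer could be registered with a running gRPC server. The add_EventListenerServicer_to_server helper follows the standard naming convention of grpc-generated *_pb2_grpc modules; the import paths for DebuggerCache and DebuggerGrpcServer, the thread-pool size, and the listening port are illustrative assumptions, not details taken from this file.

# Usage sketch only -- not part of debugger_grpc_server.py.
from concurrent import futures

import grpc

from mindinsight.debugger.debugger_cache import DebuggerCache  # assumed import path
from mindinsight.debugger.debugger_grpc_server import DebuggerGrpcServer  # assumed import path
from mindinsight.debugger.proto import debug_grpc_pb2_grpc as grpc_server_base


def serve():
    """Start a gRPC server and register the debugger servicer on an assumed port."""
    server = grpc.server(futures.ThreadPoolExecutor(max_workers=10))
    grpc_server_base.add_EventListenerServicer_to_server(
        DebuggerGrpcServer(DebuggerCache()), server)
    server.add_insecure_port('[::]:50051')  # port is an assumption
    server.start()
    server.wait_for_termination()


if __name__ == '__main__':
    serve()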