
debugger_grpc_server.py

# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Implement the debugger grpc server."""
from functools import wraps

from mindinsight.debugger.common.log import logger as log
from mindinsight.debugger.common.utils import get_ack_reply, ServerStatus, \
    Streams
from mindinsight.debugger.proto import debug_grpc_pb2_grpc as grpc_server_base
from mindinsight.debugger.proto.ms_graph_pb2 import GraphProto

def debugger_wrap(func):
    """Wrapper for catching exceptions."""

    @wraps(func)
    def record_log(*args, **kwargs):
        try:
            return func(*args, **kwargs)
        except Exception as err:
            log.exception(err)
            raise err

    return record_log

class DebuggerGrpcServer(grpc_server_base.EventListenerServicer):
    """The grpc server used to interact with the grpc client."""

    def __init__(self, cache_store):
        """
        Initialize.

        Args:
            cache_store (DebuggerCache): Debugger cache store.
        """
        cache_store.initialize()
        self._cache_store = cache_store
        # the next position of command queue to be queried
        self._pos = None
        # the status of grpc server, the value is in ServerStatus
        self._status = None
        # the run command cache, used to deal with left continue steps or nodes
        self._old_run_cmd = None
        # the view command cache, used to update tensor history through data queue
        self._received_view_cmd = None
        # the flag of receiving watch point hit
        self._received_hit = None
        self.init()

    def init(self):
        """Init debugger grpc server."""
        self._pos = '0'
        self._status = ServerStatus.PENDING
        self._old_run_cmd = {}
        self._received_view_cmd = {}
        self._received_hit = False
        self._cache_store.clean()

    @debugger_wrap
    def WaitCMD(self, request, context):
        """Wait for a command in DebuggerCache."""
        # check if the graph has already been received
        log.info("Received WaitCMD at %s-th step.", request.cur_step)
        if self._status == ServerStatus.PENDING:
            log.warning("No graph received before WaitCMD.")
            reply = get_ack_reply(1)
            return reply
        # send graph if it has not been sent before
        self._pre_process(request)
        # deal with old command
        reply = self._deal_with_old_command()
        # wait for next command
        if reply is None:
            reply = self._wait_for_next_command()
        # check the reply
        if reply is None:
            reply = get_ack_reply(1)
            log.warning("Failed to get command event.")
        else:
            log.info("Reply to WaitCMD: %s", reply)
        return reply

    def _pre_process(self, request):
        """Pre-process before dealing with command."""
        metadata_stream = self._cache_store.get_stream_handler(Streams.METADATA)
        is_new_step = metadata_stream.step < request.cur_step
        # clean cache data at the beginning of a new step
        if is_new_step:
            self._cache_store.clean_data()
            self._cache_store.get_stream_handler(Streams.TENSOR).clean_tensors(request.cur_step)
        # receive graph at the beginning of the training
        if self._status == ServerStatus.RECEIVE_GRAPH:
            self._send_graph_flag(metadata_stream)
        # receive new metadata
        if is_new_step or metadata_stream.full_name != request.cur_node:
            self._update_metadata(metadata_stream, request)
        self._send_received_tensor_tag()
        self._send_watchpoint_hit_flag()

    def _send_graph_flag(self, metadata_stream):
        """
        Send graph and metadata to UI.

        Args:
            metadata_stream (MetadataHandler): Metadata handler stream.
        """
        self._cache_store.clean_command()
        # receive graph in the beginning of the training
        self._status = ServerStatus.WAITING
        metadata_stream.state = 'waiting'
        metadata = metadata_stream.get()
        res = self._cache_store.get_stream_handler(Streams.GRAPH).get()
        res.update(metadata)
        self._cache_store.put_data(res)
        log.debug("Put graph into data queue.")

    def _update_metadata(self, metadata_stream, metadata_proto):
        """
        Update metadata.

        Args:
            metadata_stream (MetadataHandler): Metadata handler stream.
            metadata_proto (MetadataProto): Metadata proto sent by the client.
        """
        # put new metadata into cache
        metadata_stream.put(metadata_proto)
        cur_node = self._cache_store.get_stream_handler(Streams.GRAPH).get_node_name_by_full_name(
            metadata_proto.cur_node) if metadata_proto.cur_node else ''
        metadata_stream.node_name = cur_node
        metadata = metadata_stream.get()
        self._cache_store.put_data(metadata)
        log.debug("Put new metadata into data queue.")

    def _send_received_tensor_tag(self):
        """Send received_finish_tag."""
        node_name = self._received_view_cmd.get('node_name')
        if not node_name or self._received_view_cmd.get('wait_for_tensor'):
            return
        metadata = self._cache_store.get_stream_handler(Streams.METADATA).get()
        ret = {'receive_tensor': {'node_name': node_name}}
        ret.update(metadata)
        self._cache_store.put_data(ret)
        self._received_view_cmd.clear()
        log.debug("Send receive tensor flag for %s", node_name)

    def _send_watchpoint_hit_flag(self):
        """Send watchpoint hit flag."""
        watchpoint_hit_stream = self._cache_store.get_stream_handler(Streams.WATCHPOINT_HIT)
        if watchpoint_hit_stream.empty or not self._received_hit:
            return
        self._received_hit = False
        watchpoint_hits_info = watchpoint_hit_stream.get()
        self._cache_store.put_data(watchpoint_hits_info)
        log.debug("Send the watchpoint hits to DataQueue.\nSend the reply.")

    def _deal_with_old_command(self):
        """Deal with old command."""
        event = None
        while self._cache_store.has_command(self._pos) and event is None:
            event = self._get_next_command()
            log.debug("Deal with old %s-th command:\n%s.", self._pos, event)
        # deal with continue run command
        if event is None and self._old_run_cmd:
            left_step_count = self._old_run_cmd.get('left_step_count')
            node_name = self._old_run_cmd.get('node_name')
            # node_name and left_step_count should not be set at the same time
            if not (left_step_count or node_name) or (left_step_count and node_name):
                log.warning("Invalid old run command. %s", self._old_run_cmd)
                self._old_run_cmd.clear()
                return None
            if left_step_count:
                event = self._deal_with_left_continue_step(left_step_count)
            else:
                event = self._deal_with_left_continue_node(node_name)
            self._cache_store.get_stream_handler(Streams.WATCHPOINT_HIT).clean()
            log.debug("Send old RunCMD. Clean watchpoint hit.")
        return event

    def _deal_with_left_continue_step(self, left_step_count):
        """
        Construct run command with left continue step count.

        Args:
            left_step_count (int): The count of left steps to be executed.

        Returns:
            Event, the run command event.
        """
        event = get_ack_reply()
        event.run_cmd.run_steps = 1
        event.run_cmd.run_level = 'step'
        left_step_count = left_step_count - 1 if left_step_count > 0 else -1
        if not left_step_count:
            self._old_run_cmd.clear()
        else:
            self._old_run_cmd['left_step_count'] = left_step_count
        log.debug("Send old step RunCMD. Left step count: %s", left_step_count)
        return event

    def _deal_with_left_continue_node(self, node_name):
        """
        Construct run command with left continue nodes.

        Args:
            node_name (str): The target node name.

        Returns:
            Union[None, Event], the run command event.
        """
        cur_full_name = self._cache_store.get_stream_handler(Streams.METADATA).full_name
        if cur_full_name == node_name:
            log.info("Execute to target node: %s", node_name)
            self._old_run_cmd.clear()
            return None
        event = get_ack_reply()
        event.run_cmd.run_level = 'node'
        event.run_cmd.node_name = ''
        log.debug("Send old node RunCMD, cur node: %s, target node: %s", cur_full_name, node_name)
        return event

    def _wait_for_next_command(self):
        """
        Wait for next command.

        Returns:
            EventReply, the command event.
        """
        log.info("Start to wait for command.")
        self._cache_store.get_stream_handler(Streams.METADATA).state = 'waiting'
        self._cache_store.put_data({'metadata': {'state': 'waiting'}})
        event = None
        while event is None and self._status == ServerStatus.WAITING:
            log.debug("Wait for %s-th command", self._pos)
            event = self._get_next_command()
        return event

    def _get_next_command(self):
        """Get next command."""
        self._pos, event = self._cache_store.get_command(self._pos)
        if event is None:
            return event
        if isinstance(event, dict):
            event = self._deal_with_view_cmd(event)
        elif event.HasField('run_cmd'):
            event = self._deal_with_run_cmd(event)
        elif event.HasField('exit'):
            self._cache_store.clean()
            log.info("Clean cache for exit cmd.")
        return event

    def _deal_with_view_cmd(self, event):
        """Deal with view cmd."""
        view_cmd = event.get('view_cmd')
        node_name = event.get('node_name')
        log.debug("Receive view cmd for node: %s.", node_name)
        if not (view_cmd and node_name):
            log.debug("Invalid view command. Ignore it.")
            return None
        self._received_view_cmd['node_name'] = node_name
        self._received_view_cmd['wait_for_tensor'] = True
        return view_cmd

    def _deal_with_run_cmd(self, event):
        """Deal with run cmd."""
        run_cmd = event.run_cmd
        # receive step command
        if run_cmd.run_level == 'step':
            # receive pause cmd
            if not run_cmd.run_steps:
                log.debug("Pause training and wait for next command.")
                self._old_run_cmd.clear()
                return None
            # receive step cmd
            left_steps = run_cmd.run_steps - 1
            event.run_cmd.run_steps = 1
            if left_steps:
                self._old_run_cmd['left_step_count'] = left_steps if left_steps > 0 else -1
        elif run_cmd.node_name:
            self._old_run_cmd['node_name'] = run_cmd.node_name
            run_cmd.node_name = ''
        self._cache_store.get_stream_handler(Streams.WATCHPOINT_HIT).clean()
        log.debug("Receive RunCMD. Clean watchpoint hit cache.")
        return event

    @debugger_wrap
    def SendMetadata(self, request, context):
        """Send metadata into DebuggerCache."""
        log.info("Received Metadata.")
        if self._status != ServerStatus.PENDING:
            log.info("Re-initialize cache store when new session comes.")
            self.init()
        client_ip = context.peer().split(':', 1)[-1]
        metadata_stream = self._cache_store.get_stream_handler(Streams.METADATA)
        if request.training_done:
            log.info("The training from %s has finished.", client_ip)
        else:
            metadata_stream.put(request)
            metadata_stream.client_ip = client_ip
            log.debug("Put new metadata from %s into cache.", client_ip)
        # put metadata into data queue
        metadata = metadata_stream.get()
        self._cache_store.put_data(metadata)
        reply = get_ack_reply()
        log.debug("Send the reply to %s.", client_ip)
        return reply

    @debugger_wrap
    def SendGraph(self, request_iterator, context):
        """Send graph into DebuggerCache."""
        log.info("Received graph.")
        serial_graph = b""
        for chunk in request_iterator:
            serial_graph += chunk.buffer
        graph = GraphProto.FromString(serial_graph)
        log.debug("Deserialize the graph. Receive %s nodes", len(graph.node))
        self._cache_store.get_stream_handler(Streams.GRAPH).put(graph)
        self._cache_store.get_stream_handler(Streams.TENSOR).put_const_vals(graph.const_vals)
        self._status = ServerStatus.RECEIVE_GRAPH
        reply = get_ack_reply()
        log.debug("Send the reply for graph.")
        return reply

    @debugger_wrap
    def SendTensors(self, request_iterator, context):
        """Send tensors into DebuggerCache."""
        log.info("Received tensor.")
        tensor_construct = []
        tensor_stream = self._cache_store.get_stream_handler(Streams.TENSOR)
        metadata_stream = self._cache_store.get_stream_handler(Streams.METADATA)
        tensor_names = []
        step = metadata_stream.step
        for tensor in request_iterator:
            tensor_construct.append(tensor)
            if tensor.finished:
                # the finished flag marks the last chunk of one tensor value,
                # so flush the accumulated protos into the tensor stream
                update_flag = tensor_stream.put({'step': step, 'tensor_protos': tensor_construct})
                if self._received_view_cmd.get('wait_for_tensor') and update_flag:
                    self._received_view_cmd['wait_for_tensor'] = False
                    log.debug("Set wait for tensor flag to False.")
                tensor_construct = []
                tensor_names.append(':'.join([tensor.node_name, tensor.slot]))
                continue
        reply = get_ack_reply()
        return reply

    @debugger_wrap
    def SendWatchpointHits(self, request_iterator, context):
        """Send watchpoint hits info into DebuggerCache."""
        log.info("Received WatchpointHits. Left run cmd %s is cleared.", self._old_run_cmd)
        self._old_run_cmd.clear()
        self._received_hit = True
        watchpoint_hit_stream = self._cache_store.get_stream_handler(Streams.WATCHPOINT_HIT)
        watchpoint_stream = self._cache_store.get_stream_handler(Streams.WATCHPOINT)
        graph_stream = self._cache_store.get_stream_handler(Streams.GRAPH)
        for watchpoint_hit_proto in request_iterator:
            ui_node_name = graph_stream.get_node_name_by_full_name(
                watchpoint_hit_proto.tensor.node_name)
            log.debug("Receive watch point hit: %s", watchpoint_hit_proto)
            if not ui_node_name:
                log.info("Unable to show %s on the graph.", watchpoint_hit_proto.tensor.node_name)
                continue
            watchpoint_hit = {
                'tensor_proto': watchpoint_hit_proto.tensor,
                'watchpoint': watchpoint_stream.get_watchpoint_by_id(watchpoint_hit_proto.id),
                'node_name': ui_node_name
            }
            watchpoint_hit_stream.put(watchpoint_hit)
        reply = get_ack_reply()
        return reply
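
For reference, below is a minimal sketch of how a servicer like DebuggerGrpcServer might be registered with a gRPC server. The add_EventListenerServicer_to_server name follows the standard gRPC Python codegen convention for the EventListener service, and the DebuggerCache import path is assumed for illustration; both may differ in the actual MindInsight package layout.

# A minimal registration sketch (not part of debugger_grpc_server.py).
# Assumptions: add_EventListenerServicer_to_server is the generated helper for
# the EventListener service, and DebuggerCache lives under
# mindinsight.debugger.debugger_cache; adjust to the real package layout.
from concurrent import futures

import grpc

from mindinsight.debugger.proto import debug_grpc_pb2_grpc as grpc_server_base
from mindinsight.debugger.debugger_cache import DebuggerCache  # assumed path
from mindinsight.debugger.debugger_grpc_server import DebuggerGrpcServer


def serve(host='127.0.0.1', port=50051):
    """Start a gRPC server that listens for the MindSpore debugger client."""
    grpc_server = grpc.server(futures.ThreadPoolExecutor(max_workers=10))
    grpc_server_base.add_EventListenerServicer_to_server(
        DebuggerGrpcServer(DebuggerCache()), grpc_server)
    grpc_server.add_insecure_port('{}:{}'.format(host, port))
    grpc_server.start()
    grpc_server.wait_for_termination()


if __name__ == '__main__':
    serve()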