
debugger_grpc_server.py 20 kB

# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
  15. """Implement the debugger grpc server."""
  16. from functools import wraps
  17. import mindinsight.conditionmgr.recommender
  18. from mindinsight.debugger.common.log import LOGGER as log
  19. from mindinsight.debugger.common.utils import get_ack_reply, ServerStatus, \
  20. Streams, RunLevel
  21. from mindinsight.debugger.proto import debug_grpc_pb2_grpc as grpc_server_base
  22. from mindinsight.debugger.proto.ms_graph_pb2 import GraphProto
  23. from mindinsight.conditionmgr.condition import ConditionContext


def debugger_wrap(func):
    """Wrapper to catch and log exceptions."""

    @wraps(func)
    def record_log(*args, **kwargs):
        try:
            return func(*args, **kwargs)
        except Exception as err:
            log.exception(err)
            raise err

    return record_log
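
# Note: every RPC entry point below (WaitCMD, SendMetadata, SendGraph,
# SendMultiGraphs, SendTensors, SendWatchpointHits) is decorated with
# debugger_wrap, so unexpected errors are logged before being re-raised
# to the grpc framework.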


class DebuggerGrpcServer(grpc_server_base.EventListenerServicer):
    """The grpc server used to interact with the grpc client."""

    def __init__(self, cache_store, condition_mgr):
        """
        Initialize.

        Args:
            cache_store (DebuggerCache): Debugger cache store.
            condition_mgr (ConditionMgr): Condition manager.
        """
        cache_store.initialize()
        self._cache_store = cache_store
        self._condition_mgr = condition_mgr
        # the next position of command queue to be queried
        self._pos = None
        # the status of grpc server, the value is in ServerStatus
        self._status = None
        # the run command cache, used to deal with left continue steps or nodes
        self._old_run_cmd = None
        # the view command cache, used to update tensor history through data queue
        self._received_view_cmd = None
        # the flag of receiving watch point hit
        self._received_hit = None
        self.init()

    def init(self):
        """Init debugger grpc server."""
        self._pos = '0'
        self._status = ServerStatus.PENDING
        self._old_run_cmd = {}
        self._received_view_cmd = {}
        self._received_hit = []
        self._cache_store.clean()

    @debugger_wrap
    def WaitCMD(self, request, context):
        """Wait for a command in DebuggerCache."""
        # check if the graph has already been received
        log.info("Received WaitCMD at %s-th step.", request.cur_step)
        if self._status == ServerStatus.PENDING:
            log.warning("No graph received before WaitCMD.")
            reply = get_ack_reply(1)
            return reply
        # send graph if it has not been sent before
        self._pre_process(request)
        # deal with old command
        reply = self._deal_with_old_command()
        # wait for next command
        if reply is None:
            reply = self._wait_for_next_command()
        # check the reply
        if reply is None:
            reply = get_ack_reply(1)
            log.warning("Failed to get command event.")
        else:
            log.debug("Reply to WaitCMD: %s", reply)
        return reply

    def _add_predefined_watchpoints(self, condition_context):
        """Add predefined watchpoints."""
        log.debug("Add predefined watchpoints.")
        graph_stream = self._cache_store.get_stream_handler(Streams.GRAPH)
        watchpoints = mindinsight.conditionmgr.recommender.recommend_watchpoints(self._condition_mgr, graph_stream,
                                                                                 condition_context)
        watch_point_stream_handler = self._cache_store.get_stream_handler(Streams.WATCHPOINT)
        for watchpoint in watchpoints:
            watch_point_stream_handler.create_watchpoint(
                watch_condition=watchpoint.get_watch_condition_dict(),
                watch_nodes=watchpoint.watch_nodes,
                condition_mgr=self._condition_mgr
            )

    def _pre_process(self, request):
        """Pre-process before dealing with command."""
        metadata_stream = self._cache_store.get_stream_handler(Streams.METADATA)
        watchpoint_stream = self._cache_store.get_stream_handler(Streams.WATCHPOINT)
        is_new_step = metadata_stream.step < request.cur_step
        is_new_node = metadata_stream.full_name != request.cur_node
        # clean cache data at the beginning of a new step or when the node has changed
        if is_new_step or is_new_node:
            self._cache_store.clean_data()
            if is_new_step:
                self._cache_store.get_stream_handler(Streams.WATCHPOINT_HIT).clean()
                self._cache_store.get_stream_handler(Streams.TENSOR).clean_tensors(request.cur_step)
                watchpoint_stream.clean_temp_cached_names()
        # receive graph at the beginning of the training
        if self._status == ServerStatus.RECEIVE_GRAPH:
            condition_context = ConditionContext(backend=request.backend, debugger_capability=(1, 0))
            self._add_predefined_watchpoints(condition_context)
            self._send_graph_flag(metadata_stream)
        # receive new metadata
        if is_new_step or is_new_node:
            self._update_metadata(metadata_stream, request)
        # save the full name of the node for which MindSpore has stored the tensor
        watchpoint_stream.add_temp_cached_name(request.cur_node)
        self._send_received_tensor_tag()
        self._send_watchpoint_hit_flag()

    def _send_graph_flag(self, metadata_stream):
        """
        Send graph and metadata to UI.

        Args:
            metadata_stream (MetadataHandler): Metadata handler stream.
        """
        self._cache_store.clean_command()
        # receive graph in the beginning of the training
        self._status = ServerStatus.WAITING
        metadata_stream.state = 'waiting'
        metadata = metadata_stream.get()
        res = self._cache_store.get_stream_handler(Streams.GRAPH).get()
        res.update(metadata)
        self._cache_store.put_data(res)
        log.debug("Put graph into data queue.")

    def _update_metadata(self, metadata_stream, metadata_proto):
        """
        Update metadata.

        Args:
            metadata_stream (MetadataHandler): Metadata handler stream.
            metadata_proto (MetadataProto): Metadata proto sent by the client.
        """
        # put new metadata into cache
        metadata_stream.put(metadata_proto)
        # update current node name and graph name
        graph_stream = self._cache_store.get_stream_handler(Streams.GRAPH)
        full_name = metadata_proto.cur_node
        graph_name = graph_stream.get_graph_id_by_full_name(
            full_name) if full_name else metadata_stream.graph_name
        cur_node = graph_stream.get_node_name_by_full_name(full_name, graph_name)
        metadata_stream.node_name = cur_node
        metadata_stream.graph_name = graph_name
        metadata = metadata_stream.get()
        self._cache_store.put_data(metadata)
        log.debug("Put new metadata into data queue.")

    def _send_received_tensor_tag(self):
        """Send received_finish_tag."""
        node_name = self._received_view_cmd.get('node_name')
        if not node_name or self._received_view_cmd.get('wait_for_tensor'):
            return
        metadata = self._cache_store.get_stream_handler(Streams.METADATA).get(['step', 'state'])
        ret = {'receive_tensor': {'node_name': node_name}}
        ret.update(metadata)
        self._cache_store.put_data(ret)
        self._received_view_cmd.clear()
        log.debug("Send receive tensor flag for %s", node_name)

    def _send_watchpoint_hit_flag(self):
        """Send Watchpoint hit flag."""
        watchpoint_hit_stream = self._cache_store.get_stream_handler(Streams.WATCHPOINT_HIT)
        if not self._received_hit:
            return
        watchpoint_hits = self._received_hit
        self._received_hit = []
        for watchpoint_hit in watchpoint_hits:
            watchpoint_hit_stream.put(watchpoint_hit)
        watchpoint_hits_info = watchpoint_hit_stream.get()
        self._cache_store.put_data(watchpoint_hits_info)
        log.debug("Send the watchpoint hits to DataQueue.\nSend the reply.")

    def _deal_with_old_command(self):
        """Deal with old command."""
        event = None
        while self._cache_store.has_command(self._pos) and event is None:
            event = self._get_next_command()
            log.debug("Deal with old %s-th command:\n%s.", self._pos, event)
        # deal with continue run command
        if event is None and self._old_run_cmd:
            left_step_count = self._old_run_cmd.get('left_step_count')
            node_name = self._old_run_cmd.get('node_name')
            # node_name and left_step_count should not be set at the same time
            if not (left_step_count or node_name) or (left_step_count and node_name):
                log.warning("Invalid old run command. %s", self._old_run_cmd)
                self._old_run_cmd.clear()
                return None
            if left_step_count:
                event = self._deal_with_left_continue_step(left_step_count)
            else:
                event = self._deal_with_left_continue_node(node_name)
            log.debug("Send old RunCMD. Clean watchpoint hit.")
        return event

    def _deal_with_left_continue_step(self, left_step_count):
        """
        Construct run command with left continue step count.

        Args:
            left_step_count (int): The count of left steps to be executed.

        Returns:
            Event, the run command event.
        """
        event = get_ack_reply()
        event.run_cmd.run_steps = 1
        event.run_cmd.run_level = 'step'
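        # Note: the remaining count is decremented one step at a time; a negative
        # value (-1) is kept as-is, so the old run command keeps issuing single-step
        # commands until it is cleared (e.g. by a pause command or a watchpoint hit).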
        left_step_count = left_step_count - 1 if left_step_count > 0 else -1
        if not left_step_count:
            self._old_run_cmd.clear()
        else:
            self._old_run_cmd['left_step_count'] = left_step_count
        log.debug("Send old step RunCMD. Left step count: %s", left_step_count)
        return event

    def _deal_with_left_continue_node(self, node_name):
        """
        Construct run command with left continue nodes.

        Args:
            node_name (str): The target node name.

        Returns:
            Union[None, Event], the run command event.
        """
        cur_full_name = self._cache_store.get_stream_handler(Streams.METADATA).full_name
        if cur_full_name == node_name:
            log.info("Execute to target node: %s", node_name)
            self._old_run_cmd.clear()
            return None
        event = get_ack_reply()
        event.run_cmd.run_level = 'node'
        event.run_cmd.node_name = ''
        log.debug("Send old node RunCMD, cur node: %s, target node: %s", cur_full_name, node_name)
        return event

    def _wait_for_next_command(self):
        """
        Wait for next command.

        Returns:
            EventReply, the command event.
        """
        log.info("Start to wait for command.")
        self._cache_store.get_stream_handler(Streams.METADATA).state = 'waiting'
        self._cache_store.put_data({'metadata': {'state': 'waiting'}})
        event = None
        while event is None and self._status == ServerStatus.WAITING:
            log.debug("Wait for %s-th command", self._pos)
            event = self._get_next_command()
        return event

    def _get_next_command(self):
        """Get next command."""
        self._pos, event = self._cache_store.get_command(self._pos)
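        # The cached command is either a dict (a view command assembled by the
        # backend) or a protobuf event carrying a run, exit, or set command.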
        if event is None:
            return event
        if isinstance(event, dict):
            event = self._deal_with_view_cmd(event)
        elif event.HasField('run_cmd'):
            event = self._deal_with_run_cmd(event)
        elif event.HasField('exit'):
            self._cache_store.clean()
            log.debug("Clean cache for exit cmd.")
        else:
            self._cache_store.get_stream_handler(Streams.WATCHPOINT).clean_cache_set_cmd(event.set_cmd)
            log.debug("get set cmd.")
        return event

    def _deal_with_view_cmd(self, event):
        """Deal with view cmd."""
        view_cmd = event.get('view_cmd')
        node_name = event.get('node_name')
        log.debug("Receive view cmd for node: %s.", node_name)
        if not (view_cmd and node_name):
            log.debug("Invalid view command. Ignore it.")
            return None
        self._received_view_cmd['node_name'] = node_name
        self._received_view_cmd['wait_for_tensor'] = True
        return view_cmd

    def _deal_with_run_cmd(self, event):
        """Deal with run cmd."""
        run_cmd = event.run_cmd
        # receive step command
        if run_cmd.run_level == 'step':
            # receive pause cmd
            if not run_cmd.run_steps:
                log.debug("Pause training and wait for next command.")
                self._old_run_cmd.clear()
                return None
            # receive step cmd
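            # A multi-step run command is split up: the client is asked to run a
            # single step now, and the remaining count is cached in _old_run_cmd so
            # that WaitCMD can replay it step by step on the following iterations.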
            left_steps = run_cmd.run_steps - 1
            event.run_cmd.run_steps = 1
            if left_steps:
                self._old_run_cmd['left_step_count'] = left_steps if left_steps > 0 else -1
        elif run_cmd.node_name:
            self._old_run_cmd['node_name'] = run_cmd.node_name
            run_cmd.node_name = ''
        # clean watchpoint hit cache
        if run_cmd.run_level == RunLevel.RECHECK.value:
            self._cache_store.get_stream_handler(Streams.WATCHPOINT_HIT).clean()
            log.debug("Receive RunCMD. Clean watchpoint hit cache.")
        return event

    @debugger_wrap
    def SendMetadata(self, request, context):
        """Send metadata into DebuggerCache."""
        log.info("Received Metadata.")
        if self._status != ServerStatus.PENDING:
            log.info("Re-initialize cache store when new session comes.")
            self.init()
        client_ip = context.peer().split(':', 1)[-1]
        metadata_stream = self._cache_store.get_stream_handler(Streams.METADATA)
        if request.training_done:
            log.info("The training from %s has finished.", client_ip)
        else:
            metadata_stream.put(request)
            metadata_stream.client_ip = client_ip
            log.debug("Put new metadata from %s into cache.", client_ip)
        # put metadata into data queue
        metadata = metadata_stream.get()
        self._cache_store.put_data(metadata)
        reply = get_ack_reply()
        log.debug("Send the reply to %s.", client_ip)
        return reply

    @debugger_wrap
    def SendGraph(self, request_iterator, context):
        """Send graph into DebuggerCache."""
        log.info("Received graph.")
        serial_graph = b""
        for chunk in request_iterator:
            serial_graph += chunk.buffer
        graph = GraphProto.FromString(serial_graph)
        log.debug("Deserialize the graph %s. Receive %s nodes", graph.name, len(graph.node))
        graph_dict = {graph.name: graph}
        self._cache_store.get_stream_handler(Streams.GRAPH).put(graph_dict)
        self._cache_store.get_stream_handler(Streams.TENSOR).put_const_vals(graph.const_vals)
        self._cache_store.get_stream_handler(Streams.METADATA).graph_name = graph.name
        self._status = ServerStatus.RECEIVE_GRAPH
        reply = get_ack_reply()
        log.debug("Send the reply for graph.")
        return reply

    @debugger_wrap
    def SendMultiGraphs(self, request_iterator, context):
        """Send multiple graphs into DebuggerCache."""
        log.info("Received graph.")
        serial_graph = b""
        graph_dict = {}
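        # Sub-graphs are streamed as raw chunks; chunk.finished marks the end of
        # one serialized sub-graph, which is deserialized and collected into
        # graph_dict before being handed to the graph stream handler in one batch.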
        for chunk in request_iterator:
            serial_graph += chunk.buffer
            if chunk.finished:
                sub_graph = GraphProto.FromString(serial_graph)
                graph_dict[sub_graph.name] = sub_graph
                log.debug("Deserialize the graph %s. Receive %s nodes", sub_graph.name,
                          len(sub_graph.node))
                serial_graph = b""
                self._cache_store.get_stream_handler(Streams.TENSOR).put_const_vals(
                    sub_graph.const_vals)
        self._cache_store.get_stream_handler(Streams.GRAPH).put(graph_dict)
        self._status = ServerStatus.RECEIVE_GRAPH
        reply = get_ack_reply()
        log.debug("Send the reply for graph.")
        return reply

    @debugger_wrap
    def SendTensors(self, request_iterator, context):
        """Send tensors into DebuggerCache."""
        log.info("Received tensor.")
        tensor_construct = []
        tensor_stream = self._cache_store.get_stream_handler(Streams.TENSOR)
        metadata_stream = self._cache_store.get_stream_handler(Streams.METADATA)
        tensor_names = []
        step = metadata_stream.step
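        # Tensor protos may arrive split across several messages; they are buffered
        # in tensor_construct until a proto with the finished flag closes one tensor,
        # at which point the buffered protos are written to the tensor stream for
        # the current step.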
        for tensor in request_iterator:
            tensor_construct.append(tensor)
            if tensor.finished:
                update_flag = tensor_stream.put({'step': step, 'tensor_protos': tensor_construct})
                if self._received_view_cmd.get('wait_for_tensor') and update_flag:
                    self._received_view_cmd['wait_for_tensor'] = False
                    log.debug("Set wait for tensor flag to False.")
                tensor_construct = []
                tensor_names.append(':'.join([tensor.node_name, tensor.slot]))
                continue
        reply = get_ack_reply()
        return reply

    @debugger_wrap
    def SendWatchpointHits(self, request_iterator, context):
        """Send watchpoint hits info into DebuggerCache."""
        log.info("Received WatchpointHits. Left run cmd %s changed to empty.", self._old_run_cmd)
        self._old_run_cmd.clear()
        if self._cache_store.get_stream_handler(Streams.METADATA).state == ServerStatus.RUNNING.value:
            # if the client session is running a script, all the cached commands should be cleared
            # when watchpoint_hits are received.
            self._cache_store.clean_command()
        # save the watchpoint_hits data
        watchpoint_hits = []
        watchpoint_stream = self._cache_store.get_stream_handler(Streams.WATCHPOINT)
        graph_stream = self._cache_store.get_stream_handler(Streams.GRAPH)
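        # Each hit refers to a node by its full name; it is mapped to the UI node
        # name through the graph stream, and hits on nodes that cannot be shown on
        # the graph are skipped.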
        for watchpoint_hit_proto in request_iterator:
            node_full_name = watchpoint_hit_proto.tensor.node_name
            graph_name = graph_stream.get_graph_id_by_full_name(node_full_name)
            ui_node_name = graph_stream.get_node_name_by_full_name(node_full_name, graph_name)
            log.debug("Receive watch point hit: %s", watchpoint_hit_proto)
            if not ui_node_name:
                log.info("Cannot show %s on the graph.", node_full_name)
                continue
            watchpoint_hit = {
                'tensor_proto': watchpoint_hit_proto.tensor,
                'watchpoint': watchpoint_stream.get_watchpoint_by_id(watchpoint_hit_proto.id),
                'node_name': ui_node_name,
                'graph_name': graph_name
            }
            watchpoint_hits.append(watchpoint_hit)
        self._received_hit = watchpoint_hits
        reply = get_ack_reply()
        return reply