
debugger_grpc_server.py 23 kB

# Copyright 2020-2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
  15. """Implement the debugger grpc server."""
  16. import copy
  17. from functools import wraps
  18. import mindinsight
  19. from mindinsight.debugger.common.log import LOGGER as log
  20. from mindinsight.debugger.common.utils import get_ack_reply, ServerStatus, \
  21. Streams, RunLevel, version_match
  22. from mindinsight.debugger.conditionmgr.condition import TargetTypeEnum, ParamNameEnum
  23. from mindinsight.debugger.proto import debug_grpc_pb2_grpc as grpc_server_base
  24. from mindinsight.debugger.proto.ms_graph_pb2 import GraphProto


def debugger_wrap(func):
    """Wrapper for catching and logging exceptions."""

    @wraps(func)
    def record_log(*args, **kwargs):
        try:
            return func(*args, **kwargs)
        except Exception as err:
            log.exception(err)
            raise err

    return record_log


class DebuggerGrpcServer(grpc_server_base.EventListenerServicer):
    """The gRPC server used to interact with the gRPC client."""

    def __init__(self, cache_store, condition_mgr):
        """
        Initialize.

        Args:
            cache_store (DebuggerCache): Debugger cache store.
            condition_mgr (ConditionMgr): Watchpoint condition manager.
        """
        cache_store.initialize()
        self._cache_store = cache_store
        self._condition_mgr = condition_mgr
        # the next position of the command queue to be queried
        self._pos = None
        # the status of the grpc server; the value is in ServerStatus
        self._status = None
        # the run command cache, used to deal with remaining continue steps or nodes
        self._old_run_cmd = None
        # the view command cache, used to update tensor history through the data queue
        self._received_view_cmd = None
        # the flag of receiving a watchpoint hit
        self._received_hit = None
        self.init()

    def init(self):
        """Init the debugger grpc server."""
        self._pos = '0'
        self._status = ServerStatus.PENDING
        self._old_run_cmd = {}
        self._received_view_cmd = {}
        self._received_hit = []
        self._cache_store.clean()

    @debugger_wrap
    def WaitCMD(self, request, context):
        """Wait for a command in DebuggerCache."""
        # check if the graph has already been received
        log.info("Received WaitCMD at %s-th step.", request.cur_step)
        if self._status == ServerStatus.PENDING:
            log.warning("No graph received before WaitCMD.")
            reply = get_ack_reply(1)
            return reply

        # send the graph if it has not been sent before
        self._pre_process(request)
        # deal with the old command
        reply = self._deal_with_old_command()
        # wait for the next command
        if reply is None:
            reply = self._wait_for_next_command()
        # check the reply
        if reply is None:
            reply = get_ack_reply(1)
            log.warning("Failed to get command event.")
        else:
            log.debug("Reply to WaitCMD: %s", reply)
        return reply

    def _pre_process(self, request):
        """Pre-process before dealing with the command."""
        # if the versions mismatch, send the mismatch info to the UI
        if self._status == ServerStatus.MISMATCH:
            log.warning("The versions of MindSpore and MindInsight do not match, "
                        "waiting for the user to terminate the script.")
            metadata_stream = self._cache_store.get_stream_handler(Streams.METADATA)
            # put metadata into the data queue
            metadata = metadata_stream.get(['state', 'debugger_version'])
            self._cache_store.put_data(metadata)
            return

        metadata_stream = self._cache_store.get_stream_handler(Streams.METADATA)
        is_new_step = metadata_stream.step < request.cur_step
        is_new_node = metadata_stream.full_name != request.cur_node
        # clean cached data at the beginning of a new step or when the node has changed
        if is_new_step or is_new_node:
            self._cache_store.clean_data()
            self._cache_store.get_stream_handler(Streams.TENSOR).get_tensor_handler_by_rank_id(0).clean_tensors(
                request.cur_step)
        if is_new_step:
            self._cache_store.get_stream_handler(Streams.WATCHPOINT_HIT).get_hit_handler_by_rank_id(0).clean()
        # receive the graph at the beginning of the training
        if self._status == ServerStatus.RECEIVE_GRAPH:
            self._send_graph_flag(metadata_stream)
        # receive new metadata
        if is_new_step or is_new_node:
            self._update_metadata(metadata_stream, request)
        self._send_received_tensor_tag()
        self._send_watchpoint_hit_flag()

    def _send_graph_flag(self, metadata_stream):
        """
        Send the graph and metadata to the UI.

        Args:
            metadata_stream (MetadataHandler): Metadata handler stream.
        """
        self._cache_store.clean_command()
        # the graph is received at the beginning of the training
        self._status = ServerStatus.WAITING
        metadata_stream.state = ServerStatus.WAITING.value
        metadata = metadata_stream.get()
        res = self._cache_store.get_stream_handler(Streams.GRAPH).get_graph_handler_by_rank_id(0).get()
        res.update(metadata)
        self._cache_store.put_data(res)
        log.debug("Put graph into data queue.")

    def _update_metadata(self, metadata_stream, metadata_proto):
        """
        Update metadata.

        Args:
            metadata_stream (MetadataHandler): Metadata handler stream.
            metadata_proto (MetadataProto): Metadata proto sent by the client.
        """
        # put the new metadata into the cache
        metadata_stream.put(metadata_proto)
        # update the current node name and graph name
        graph_stream = self._cache_store.get_stream_handler(Streams.GRAPH).get_graph_handler_by_rank_id(0)
        full_name = metadata_proto.cur_node
        graph_name = graph_stream.get_graph_id_by_full_name(
            full_name) if full_name else metadata_stream.graph_name
        cur_node = graph_stream.get_node_name_by_full_name(full_name, graph_name)
        metadata_stream.node_name = cur_node
        metadata_stream.graph_name = graph_name
        metadata = metadata_stream.get()
        self._cache_store.put_data(metadata)
        log.debug("Put new metadata into data queue.")

    def _send_received_tensor_tag(self):
        """Send the received_finish_tag."""
        node_info = self._received_view_cmd.get('node_info')
        if not node_info or self._received_view_cmd.get('wait_for_tensor'):
            return
        metadata = self._cache_store.get_stream_handler(Streams.METADATA).get(['step', 'state'])
        ret = {'receive_tensor': node_info.copy()}
        ret.update(metadata)
        self._cache_store.put_data(ret)
        self._received_view_cmd.clear()
        log.debug("Send receive tensor flag for %s", node_info)

    def _send_watchpoint_hit_flag(self):
        """Send the watchpoint hit flag."""
        watchpoint_hit_stream = self._cache_store.get_stream_handler(Streams.WATCHPOINT_HIT).get_hit_handler_by_rank_id(
            0)
        if not self._received_hit:
            return
        watchpoint_hits = self._received_hit
        self._received_hit = []
        for watchpoint_hit in watchpoint_hits:
            watchpoint_hit_stream.put(watchpoint_hit)
        watchpoint_hits_info = {'receive_watchpoint_hits': True}
        self._cache_store.put_data(watchpoint_hits_info)
        log.debug("Send the watchpoint hits to DataQueue.\nSend the reply.")

    def _deal_with_old_command(self):
        """Deal with the old command."""
        event = None
        while self._cache_store.has_command(self._pos) and event is None:
            event = self._get_next_command()
            log.debug("Deal with old %s-th command:\n%s.", self._pos, event)

        # deal with the continue run command
        if event is None and self._old_run_cmd:
            left_step_count = self._old_run_cmd.get('left_step_count')
            node_name = self._old_run_cmd.get('node_name')
            # node_name and left_step_count must not be set at the same time
            if not (left_step_count or node_name) or (left_step_count and node_name):
                log.warning("Invalid old run command. %s", self._old_run_cmd)
                self._old_run_cmd.clear()
                return None
            if left_step_count:
                event = self._deal_with_left_continue_step(left_step_count)
            else:
                event = self._deal_with_left_continue_node(node_name)
            log.debug("Send old RunCMD.")
        return event

    def _deal_with_left_continue_step(self, left_step_count):
        """
        Construct a run command with the remaining continue step count.

        Args:
            left_step_count (int): The count of steps left to be executed.

        Returns:
            Event, the run command event.
        """
        event = get_ack_reply()
        event.run_cmd.run_steps = 1
        event.run_cmd.run_level = 'step'
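        # -1 never counts down to zero, so a negative count keeps emitting
        # single-step commands until something else clears _old_run_cmd.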
        left_step_count = left_step_count - 1 if left_step_count > 0 else -1
        if not left_step_count:
            self._old_run_cmd.clear()
        else:
            self._old_run_cmd['left_step_count'] = left_step_count
        log.debug("Send old step RunCMD. Left step count: %s", left_step_count)
        return event

    def _deal_with_left_continue_node(self, node_name):
        """
        Construct a run command for the remaining continue-to-node request.

        Args:
            node_name (str): The target node name.

        Returns:
            Union[None, Event], the run command event.
        """
        cur_full_name = self._cache_store.get_stream_handler(Streams.METADATA).full_name
        if cur_full_name == node_name:
            log.info("Execute to target node: %s", node_name)
            self._old_run_cmd.clear()
            return None
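        # Not at the target yet: issue another bare node-level step; the
        # comparison above runs again on the next WaitCMD.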
        event = get_ack_reply()
        event.run_cmd.run_level = 'node'
        event.run_cmd.node_name = ''
        log.debug("Send old node RunCMD, cur node: %s, target node: %s", cur_full_name, node_name)
        return event

    def _wait_for_next_command(self):
        """
        Wait for the next command.

        Returns:
            EventReply, the command event.
        """
        log.info("Start to wait for command.")
        if self._status != ServerStatus.MISMATCH:
            metadata_stream = self._cache_store.get_stream_handler(Streams.METADATA)
            metadata_stream.state = ServerStatus.WAITING.value
            self._cache_store.put_data(metadata_stream.get())
        event = None
        while event is None and self._status not in [ServerStatus.RUNNING, ServerStatus.PENDING]:
            log.debug("Wait for %s-th command", self._pos)
            event = self._get_next_command()
        return event

    def _get_next_command(self):
        """Get the next command."""
        self._pos, event = self._cache_store.get_command(self._pos)
        if event is None:
            return event
        # deal with the command
        metadata_stream = self._cache_store.get_stream_handler(Streams.METADATA)
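        # View commands are queued as plain dicts; everything else arrives as
        # an EventReply proto carrying a run, exit, or set-watchpoint command.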
        if isinstance(event, dict):
            event = self._deal_with_view_cmd(event)
        elif event.HasField('run_cmd'):
            event = self._deal_with_run_cmd(event)
            self._cache_store.put_data(metadata_stream.get())
        elif event.HasField('exit'):
            self._cache_store.clean()
            self._cache_store.put_data(metadata_stream.get())
            log.debug("Clean cache for exit cmd.")
        else:
            self._cache_store.get_stream_handler(Streams.WATCHPOINT).clean_cache_set_cmd(event.set_cmd)
            log.debug("Get set cmd.")
        return event

    def _deal_with_view_cmd(self, event):
        """
        Deal with a view command.

        Args:
            event (dict): View command params.

                - view_cmd (EventReply): EventReply with the view command.
                - node_name (str): The center node name for the view command.
                - tensor_name (str): The center tensor name for the view command.
                - graph_name (str): The graph name of the center node.

        Returns:
            EventReply, the view command to be sent to the client.
        """
        view_cmd = event.pop('view_cmd', None)
        log.debug("Receive view cmd for node: %s.", event)
        if not (view_cmd and event):
            log.debug("Invalid view command. Ignore it.")
            return None
        self._received_view_cmd['node_info'] = event
        self._received_view_cmd['wait_for_tensor'] = True
        return view_cmd

    def _deal_with_run_cmd(self, event):
        """Deal with a run command."""
        metadata_stream = self._cache_store.get_stream_handler(Streams.METADATA)
        run_cmd = event.run_cmd
        # receive a step command
        if run_cmd.run_level == RunLevel.STEP.value:
            # run_steps == 0 is a pause command
            if not run_cmd.run_steps:
                log.debug("Pause training and wait for next command.")
                self._old_run_cmd.clear()
                # update the metadata state from sending to waiting
                metadata_stream.state = ServerStatus.WAITING.value
                return None
            # receive a step command with run_steps
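            # Decompose a multi-step RunCMD into single-step commands: send
            # one step now and cache the remainder so WaitCMD can replay it.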
            left_steps = run_cmd.run_steps - 1
            event.run_cmd.run_steps = 1
            if left_steps:
                self._old_run_cmd['left_step_count'] = left_steps if left_steps > 0 else -1
        elif run_cmd.node_name:
            self._old_run_cmd['node_name'] = run_cmd.node_name
            run_cmd.node_name = ''
        # clean the watchpoint hit cache on recheck
        if run_cmd.run_level == RunLevel.RECHECK.value:
            self._cache_store.get_stream_handler(Streams.WATCHPOINT_HIT).get_hit_handler_by_rank_id(0).clean()
            log.debug("Receive RunCMD. Clean watchpoint hit cache.")
        # update the metadata state from sending to running
        metadata_stream.state = ServerStatus.RUNNING.value
        return event

    @debugger_wrap
    def SendMetadata(self, request, context):
        """Send metadata into DebuggerCache."""
        log.info("Received Metadata.")
        if self._status != ServerStatus.PENDING:
            log.info("Re-initialize cache store when new session comes.")
            self.init()

        client_ip = context.peer().split(':', 1)[-1]
        metadata_stream = self._cache_store.get_stream_handler(Streams.METADATA)
        reply = get_ack_reply()
        if request.training_done:
            log.info("The training from %s has finished.", client_ip)
        else:
            ms_version = request.ms_version
            if version_match(ms_version, mindinsight.__version__) is False:
                log.info("Version mismatch: MindSpore is %s, MindInsight is %s.",
                         ms_version, mindinsight.__version__)
                self._status = ServerStatus.MISMATCH
                reply.version_matched = False
                metadata_stream.state = ServerStatus.MISMATCH.value
            else:
                log.info("Versions match.")
                reply.version_matched = True

            metadata_stream.debugger_version = {'ms': ms_version, 'mi': mindinsight.__version__}
            log.debug("Put ms_version from %s into cache.", client_ip)
            metadata_stream.put(request)
            metadata_stream.client_ip = client_ip
            log.debug("Put new metadata from %s into cache.", client_ip)
        # put metadata into the data queue
        metadata = metadata_stream.get()
        self._cache_store.put_data(metadata)
        log.debug("Send the reply to %s.", client_ip)
        return reply

    @debugger_wrap
    def SendGraph(self, request_iterator, context):
        """Send the graph into DebuggerCache."""
        log.info("Received graph.")
        reply = get_ack_reply()
        if self._status == ServerStatus.MISMATCH:
            log.info("MindSpore and MindInsight versions do not match, waiting for user to terminate the service.")
            return reply
        serial_graph = b""
        for chunk in request_iterator:
            serial_graph += chunk.buffer
        graph = GraphProto.FromString(serial_graph)
        log.debug("Deserialize the graph %s. Receive %s nodes", graph.name, len(graph.node))
        graph_dict = {graph.name: graph}
        self._cache_store.get_stream_handler(Streams.GRAPH).get_graph_handler_by_rank_id(0).put(graph_dict)
        self._cache_store.get_stream_handler(Streams.TENSOR).get_tensor_handler_by_rank_id(0).put_const_vals(
            graph.const_vals)
        self._cache_store.get_stream_handler(Streams.METADATA).graph_name = graph.name
        self._record_parameter_names()
        self._status = ServerStatus.RECEIVE_GRAPH
        log.debug("Send the reply for graph.")
        return reply

    @debugger_wrap
    def SendMultiGraphs(self, request_iterator, context):
        """Send multiple graphs into DebuggerCache."""
        log.info("Received multi_graphs.")
        reply = get_ack_reply()
        if self._status == ServerStatus.MISMATCH:
            log.info("MindSpore and MindInsight versions do not match, waiting for user to terminate the service.")
            return reply
        serial_graph = b""
        graph_dict = {}
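        # Each sub-graph is streamed as a sequence of chunks; chunk.finished
        # marks the end of one serialized GraphProto.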
        for chunk in request_iterator:
            serial_graph += chunk.buffer
            if chunk.finished:
                sub_graph = GraphProto.FromString(serial_graph)
                graph_dict[sub_graph.name] = sub_graph
                log.debug("Deserialize the graph %s. Receive %s nodes", sub_graph.name,
                          len(sub_graph.node))
                serial_graph = b""
                self._cache_store.get_stream_handler(Streams.TENSOR).get_tensor_handler_by_rank_id(0).put_const_vals(
                    sub_graph.const_vals)

        self._cache_store.get_stream_handler(Streams.GRAPH).get_graph_handler_by_rank_id(0).put(graph_dict)
        self._record_parameter_names()
        self._status = ServerStatus.RECEIVE_GRAPH
        log.debug("Send the reply for graph.")
        return reply

    def _record_parameter_names(self):
        """Record parameter full names in the tensor handler."""
        parameter_nodes = self._cache_store.get_stream_handler(Streams.GRAPH).get_graph_handler_by_rank_id(0) \
            .search_in_graph(pattern={'node_category': TargetTypeEnum.PARAMETER.value})
        tensor_stream = self._cache_store.get_stream_handler(Streams.TENSOR).get_tensor_handler_by_rank_id(0)
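        # Parameter tensors are addressed by the node's full name plus output
        # slot 0, hence the ':0' suffix.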
        for node in parameter_nodes:
            tensor_name = [node.full_name + ':0']
            tensor_stream.record_parameter_names(tensor_name)

    @debugger_wrap
    def SendTensors(self, request_iterator, context):
        """Send tensors into DebuggerCache."""
        log.info("Received tensor.")
        tensor_contents = []
        tensor_stream = self._cache_store.get_stream_handler(Streams.TENSOR).get_tensor_handler_by_rank_id(0)
        metadata_stream = self._cache_store.get_stream_handler(Streams.METADATA)
        step = metadata_stream.step
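        # A tensor value may be split across several chunks; accumulate
        # tensor_content until the chunk flagged `finished` completes it.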
        for tensor in request_iterator:
            tensor_contents.append(tensor.tensor_content)
            if tensor.finished:
                update_flag = tensor_stream.put(
                    {'step': step, 'tensor_proto': tensor, 'tensor_contents': tensor_contents})
                if self._received_view_cmd.get('wait_for_tensor') and update_flag:
                    # update_flag is used to avoid querying empty tensors again
                    self._received_view_cmd['wait_for_tensor'] = False
                    log.debug("Set wait for tensor flag to False.")
                tensor_contents = []
                continue
        reply = get_ack_reply()
        return reply

    @debugger_wrap
    def SendWatchpointHits(self, request_iterator, context):
        """Send watchpoint hits info into DebuggerCache."""
        log.info("Received WatchpointHits. Left run cmd %s change to empty.", self._old_run_cmd)
        self._old_run_cmd.clear()
        if self._cache_store.get_stream_handler(Streams.METADATA).state == ServerStatus.RUNNING.value:
            # if the client session is running a script, all the cached commands should be cleared
            # when watchpoint_hits are received
            self._cache_store.clean_command()

        # save the watchpoint_hits data
        watchpoint_hits = []
        watchpoint_stream = self._cache_store.get_stream_handler(Streams.WATCHPOINT)
        graph_stream = self._cache_store.get_stream_handler(Streams.GRAPH).get_graph_handler_by_rank_id(0)
        for watchpoint_hit_proto in request_iterator:
            node_full_name = watchpoint_hit_proto.tensor.node_name
            graph_name = graph_stream.get_graph_id_by_full_name(node_full_name)
            if not graph_name:
                log.warning("Cannot find node %s in graph. Skip it.", node_full_name)
                continue
            ui_node_name = graph_stream.get_node_name_by_full_name(node_full_name, graph_name)
            log.debug("Receive watchpoint hit: %s", watchpoint_hit_proto)
            if not ui_node_name:
                log.info("Cannot show %s on the graph.", node_full_name)
                continue
            watchpoint_hit = {
                'tensor_proto': watchpoint_hit_proto.tensor,
                'watchpoint': copy.deepcopy(watchpoint_stream.get_watchpoint_by_id(watchpoint_hit_proto.id)),
                'node_name': ui_node_name,
                'graph_name': graph_name
            }
            hit_params = {}
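            # Collect actual values only for valid hits (error_code == 0);
            # rtol and the range bounds are user-supplied condition inputs,
            # so they carry no computed actual value.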
            for param in watchpoint_hit_proto.watch_condition.params:
                if param.name not in (ParamNameEnum.RTOL.value, ParamNameEnum.RANGE_START_INCLUSIVE.value,
                                      ParamNameEnum.RANGE_END_INCLUSIVE.value) \
                        and watchpoint_hit_proto.error_code == 0:
                    hit_params[param.name] = param.actual_value
            for i, param in enumerate(watchpoint_hit['watchpoint'].condition['params']):
                name = param['name']
                if name in hit_params.keys():
                    watchpoint_hit['watchpoint'].condition['params'][i]['actual_value'] = hit_params[name]
                else:
                    watchpoint_hit['watchpoint'].condition['params'][i]['actual_value'] = None
            watchpoint_hit['error_code'] = watchpoint_hit_proto.error_code
            watchpoint_hits.append(watchpoint_hit)
        self._received_hit = watchpoint_hits
        reply = get_ack_reply()
        return reply
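
# A minimal usage sketch (not part of this module). It assumes the standard
# grpc codegen helper `add_EventListenerServicer_to_server` generated in
# debug_grpc_pb2_grpc, plus `cache_store` and `condition_mgr` objects built
# elsewhere; the port is illustrative:
#
#     from concurrent import futures
#     import grpc
#     from mindinsight.debugger.proto import debug_grpc_pb2_grpc
#
#     server = grpc.server(futures.ThreadPoolExecutor(max_workers=10))
#     debug_grpc_pb2_grpc.add_EventListenerServicer_to_server(
#         DebuggerGrpcServer(cache_store, condition_mgr), server)
#     server.add_insecure_port('[::]:50051')
#     server.start()
#     server.wait_for_termination()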