
debugger_server.py

# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Implement the debugger server."""
import signal
from concurrent import futures
from threading import Thread

import grpc

from mindinsight.conf import settings
from mindinsight.datavisual.data_transform.graph import NodeTypeEnum
from mindinsight.datavisual.utils.tools import to_float
from mindinsight.debugger.common.exceptions.exceptions import DebuggerParamValueError, \
    DebuggerParamTypeError, DebuggerCreateWatchPointError, DebuggerUpdateWatchPointError, \
    DebuggerDeleteWatchPointError, DebuggerContinueError, DebuggerPauseError, \
    DebuggerCompareTensorError
from mindinsight.debugger.common.log import logger as log
from mindinsight.debugger.common.utils import get_ack_reply, ServerStatus, \
    create_view_event_from_tensor_history, Streams, is_scope_type, NodeBasicInfo
from mindinsight.debugger.debugger_cache import DebuggerCache
from mindinsight.debugger.debugger_grpc_server import DebuggerGrpcServer
from mindinsight.debugger.proto import debug_grpc_pb2_grpc as grpc_server_base
from mindinsight.debugger.proto.debug_grpc_pb2 import RunCMD
from mindinsight.utils.exceptions import MindInsightException
from mindinsight.utils.tensor import TensorUtils, MAX_DIMENSIONS_FOR_TENSOR


class DebuggerServer:
    """The server manager of debugger."""

    def __init__(self, grpc_port=None):
        self.grpc_port = grpc_port
        self.cache_store = DebuggerCache()
        self.grpc_server = DebuggerGrpcServer(self.cache_store)
        self.grpc_server_manager = None
        self.back_server = None

    def start(self):
        """Start server."""
        grpc_port = self.grpc_port if self.grpc_port else "50051"
        host = settings.HOST if hasattr(settings, 'HOST') else '[::]'
        hostname = "{}:{}".format(host, grpc_port)
        # initialize a grpc server
        grpc_server_manager = grpc.server(futures.ThreadPoolExecutor(max_workers=10))
        grpc_server_base.add_EventListenerServicer_to_server(self.grpc_server, grpc_server_manager)
        grpc_server_manager.add_insecure_port(hostname)
        grpc_server_manager.start()
        my_server_thread = Thread(target=grpc_server_manager.wait_for_termination)
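        # Note (editor comment): wait_for_termination() blocks until the gRPC
        # server shuts down, so it runs on a helper thread; stop() later joins
        # this thread after calling grpc_server_manager.stop().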
        # start grpc server
        my_server_thread.start()
        self.back_server = my_server_thread
        self.grpc_server_manager = grpc_server_manager
        # register stop server handler
        signal.signal(signal.SIGINT, self._stop_handler)
        log.info("Start grpc server %s", hostname)

    def _stop_handler(self, signum, frame):
        """Handle the stop signal and shut down the server."""
        self.stop()
        log.debug("Deal with stop signal: %s, %s", signum, frame)

    def stop(self):
        """Stop debugger server."""
        log.info("Send terminate info to client.")
        self.control({'mode': 'terminate'})
        self.grpc_server_manager.stop(grace=None)
        self.back_server.join()
        log.info("Stop debugger server.")

    def poll_data(self, pos):
        """
        Get the pos-th data from DebuggerCache.

        Args:
            pos (str): The index of data.

        Returns:
            dict, the data to be updated.
        """
        if not isinstance(pos, str):
            log.error("Pos should be string. Received: %s", pos)
            raise DebuggerParamValueError("Pos should be string.")

        reply = self.cache_store.get_data(pos)

        return reply

    def search(self, name, watch_point_id=0):
        """
        Search for single node in graph.

        Args:
            name (str): The name pattern.
            watch_point_id (int): The id of watchpoint. Default: 0.

        Returns:
            dict, the searched nodes.
        """
        log.info("Receive search request for node: %s, in watchpoint: %d", name, watch_point_id)
        watchpoint_stream = self.cache_store.get_stream_handler(Streams.WATCHPOINT)
        watchpoint_stream.validate_watchpoint_id(watch_point_id)
        graph_stream = self.cache_store.get_stream_handler(Streams.GRAPH)
        graph = graph_stream.search_nodes(name)
        # add watched label to graph
        watchpoint_stream.set_watch_nodes(graph, graph_stream, watch_point_id)
        return graph
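    # Example (illustrative; node names depend on the loaded graph):
    #   server.search('conv1', watch_point_id=0)
    # returns the matching sub-graph with watched nodes labeled.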
    def tensor_comparisons(self, name, shape, detail='data', tolerance='0'):
        """
        Get tensor comparisons data for given name, shape, detail and tolerance.

        Args:
            name (str): The name of tensor for ui.
            shape (str): Specify concrete dimensions of shape.
            detail (str): Specify which data to query. Current available value is 'data' which means
                concrete tensor data. Histogram or unique count can be supported in the future.
            tolerance (str): Specify tolerance of difference between current step tensor and previous
                step tensor. Default value is 0.

        Raises:
            DebuggerParamValueError, if node type is not parameter or the value of detail is not supported.
            DebuggerCompareTensorError, if MindSpore is not in waiting state.

        Returns:
            dict, the retrieved data.
        """
        if self.cache_store.get_stream_handler(
                Streams.METADATA).state != ServerStatus.WAITING.value:
            log.error("Failed to compare tensors as MindSpore is not in waiting state.")
            raise DebuggerCompareTensorError(
                "Failed to compare tensors as MindSpore is not in waiting state."
            )
        self.validate_tensor_param(name, detail)
        # Limit to query max two dimensions for tensor in table view.
        parsed_shape = TensorUtils.parse_shape(shape, limit=MAX_DIMENSIONS_FOR_TENSOR)
        node_type, tensor_name = self._get_tensor_name_and_type_by_ui_name(name)
        tolerance = to_float(tolerance, 'tolerance')
        tensor_stream = self.cache_store.get_stream_handler(Streams.TENSOR)
        if node_type == NodeTypeEnum.PARAMETER.value:
            reply = tensor_stream.get_tensors_diff(tensor_name, parsed_shape, tolerance)
        else:
            raise DebuggerParamValueError(
                "The node type must be parameter, but got {}.".format(node_type))
        return reply
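    # Example (illustrative; the shape string format is whatever
    # TensorUtils.parse_shape accepts, assumed here to be a slice-like string):
    #   server.tensor_comparisons('Default/conv1.weight:0', shape='[:, :]',
    #                             tolerance='0.1')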
    def retrieve(self, mode, filter_condition=None):
        """
        Retrieve data according to mode and params.

        Args:
            mode (str): The type of info message.
            filter_condition (dict): The filter condition.

        Returns:
            dict, the retrieved data.
        """
        log.info("Receive retrieve request for mode: %s, filter_condition: %s", mode,
                 filter_condition)
        mode_mapping = {
            'all': self._retrieve_all,
            'node': self._retrieve_node,
            'watchpoint': self._retrieve_watchpoint,
            'watchpoint_hit': self._retrieve_watchpoint_hit
        }
        # validate param <mode>
        if mode not in mode_mapping.keys():
            log.error("Invalid param <mode>. <mode> should be in ['all', 'node', 'watchpoint', "
                      "'watchpoint_hit'], but got %s.", mode)
            raise DebuggerParamValueError("Invalid mode.")
        # validate backend status
        metadata_stream = self.cache_store.get_stream_handler(Streams.METADATA)
        if metadata_stream.state == ServerStatus.PENDING.value:
            log.info("The backend is in pending status.")
            return metadata_stream.get()
        filter_condition = {} if filter_condition is None else filter_condition
        reply = mode_mapping[mode](filter_condition)

        return reply
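    # Example (illustrative): fetch everything once a session is attached:
    #   server.retrieve('all')
    # or one node and its sub-layer:
    #   server.retrieve('node', {'name': 'Default/network', 'single_node': False})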
    def _retrieve_all(self, filter_condition=None):
        """Retrieve metadata, root graph and watchpoint list."""
        if filter_condition:
            log.error("No filter condition required for retrieve all request.")
            raise DebuggerParamTypeError("filter_condition should be empty.")
        self.cache_store.clean_data()
        log.info("Clean the data queue cache for a retrieve-all request.")
        result = {}
        for stream in [Streams.METADATA, Streams.GRAPH, Streams.WATCHPOINT]:
            sub_res = self.cache_store.get_stream_handler(stream).get()
            result.update(sub_res)

        return result

    def _retrieve_node(self, filter_condition):
        """
        Retrieve node info.

        Args:
            filter_condition (dict): Filter condition.

                - name (str): The name of single node.
                - single_node (bool): If False, return the sub-layer of single node. If True,
                  return the node list from root node to single node.
                - watch_point_id (int): The id of watchpoint.

        Returns:
            dict, reply with graph.
        """
        log.debug("Retrieve node %s.", filter_condition)
        # validate node name
        node_name = filter_condition.get('name')
        if node_name:
            self.cache_store.get_stream_handler(Streams.GRAPH).get_node_type(node_name)
        filter_condition['single_node'] = bool(filter_condition.get('single_node'))
        reply = self._get_nodes_info(filter_condition)
        return reply

    def _get_nodes_info(self, filter_condition):
        """
        Get nodes info.

        Args:
            filter_condition (dict): The filter condition.

                - name (str): The node name.
                - single_node (bool): If False, return the sub-layer of single node. If True,
                  return the node list from root node to single node.
                - watch_point_id (int): The id of watchpoint.

        Returns:
            dict, reply with graph.
        """
        # validate watch_point_id
        watch_point_id = filter_condition.get('watch_point_id', 0)
        watchpoint_stream = self.cache_store.get_stream_handler(Streams.WATCHPOINT)
        watchpoint_stream.validate_watchpoint_id(watch_point_id)
        # get graph
        graph_stream = self.cache_store.get_stream_handler(Streams.GRAPH)
        reply = graph_stream.get(filter_condition)
        graph = reply.get('graph')
        # add watched label to graph
        watchpoint_stream.set_watch_nodes(graph, graph_stream, watch_point_id)
        return reply

    def retrieve_tensor_history(self, node_name):
        """
        Retrieve tensor history for leaf node.

        Args:
            node_name (str): The name of leaf node.

        Returns:
            dict, the tensor history and metadata.
        """
        log.info("Retrieve tensor history for node: %s.", node_name)
        metadata_stream = self.cache_store.get_stream_handler(Streams.METADATA)
        if metadata_stream.state == ServerStatus.PENDING.value:
            log.info("The backend is in pending status.")
            return metadata_stream.get()
        self._validate_leaf_name(node_name)
        res = self._get_tensor_history(node_name)
        return res

    def _validate_leaf_name(self, node_name):
        """Validate if the node is a leaf node."""
        graph_stream = self.cache_store.get_stream_handler(Streams.GRAPH)
        node_type = graph_stream.get_node_type(node_name)
        if is_scope_type(node_type):
            log.error("Scope type node has no tensor history.")
            raise DebuggerParamValueError("Invalid leaf node name.")

    def _get_tensor_history(self, node_name):
        """
        Get tensor history for single node.

        Args:
            node_name (str): The name of leaf node.

        Returns:
            dict, the tensor history and metadata.
        """
        # get basic tensor history
        graph_stream = self.cache_store.get_stream_handler(Streams.GRAPH)
        tensor_history = graph_stream.get_tensor_history(node_name)
        # add tensor value for tensor history
        self._add_tensor_value_for_tensor_history(tensor_history, node_name)
        # add hit label for tensor history
        watchpoint_hit_stream = self.cache_store.get_stream_handler(Streams.WATCHPOINT_HIT)
        watchpoint_hit_stream.update_tensor_history(tensor_history)
        # add metadata
        metadata = self.cache_store.get_stream_handler(Streams.METADATA).get()
        tensor_history.update(metadata)
        return tensor_history

    def _add_tensor_value_for_tensor_history(self, tensor_history, node_name):
        """
        Add tensor values to the tensor history and send a ViewCMD for any missing values.

        Args:
            tensor_history (list[dict]): A list of tensor info, including name and type.
            node_name (str): The UI node name.
        """
        tensor_stream = self.cache_store.get_stream_handler(Streams.TENSOR)
        missed_tensors = tensor_stream.update_tensor_history(tensor_history)
        if missed_tensors:
            view_cmd = create_view_event_from_tensor_history(missed_tensors)
            self.cache_store.put_command({'view_cmd': view_cmd, 'node_name': node_name})
            log.debug("Send view cmd.")

    def retrieve_tensor_value(self, name, detail, shape):
        """Retrieve the tensor value."""
        log.info("Retrieve tensor value: name: %s, detail: %s, shape: %s", name, detail, shape)
        self.validate_tensor_param(name, detail)
        # Limit to query max two dimensions for tensor in table view.
        parsed_shape = TensorUtils.parse_shape(shape, limit=MAX_DIMENSIONS_FOR_TENSOR)
        node_type, tensor_name = self._get_tensor_name_and_type_by_ui_name(name)
        reply = self.cache_store.get_stream_handler(Streams.TENSOR).get(
            {'name': tensor_name,
             'node_type': node_type,
             'shape': parsed_shape}
        )
        reply['tensor_value']['name'] = name

        return reply

    def _get_tensor_name_and_type_by_ui_name(self, name):
        """
        Get inner tensor name and type by UI name.

        Args:
            name (str): Node name shown in UI.

        Returns:
            str, node type of the tensor.
            str, full (inner) name of the tensor.
        """
        node_name, slot = name.rsplit(':', 1)
        graph_stream = self.cache_store.get_stream_handler(Streams.GRAPH)
        node_type = graph_stream.get_node_type(node_name)
        full_name = graph_stream.get_full_name(node_name)
        tensor_name = full_name + ':' + slot
        return node_type, tensor_name
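    # Example (illustrative): a UI name has the form '<scope node name>:<slot>',
    # e.g. 'Default/network/conv1:0'; it maps to the node type plus the
    # framework's full tensor name, e.g. ('Parameter', '<full name>:0').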
    @staticmethod
    def validate_tensor_param(name, detail):
        """Validate params for retrieve tensor request."""
        # validate name
        if not isinstance(name, str) or ':' not in name:
            log.error("Invalid tensor name. Received: %s", name)
            raise DebuggerParamValueError("Invalid tensor name.")
        # validate detail
        if detail != 'data':
            log.error("Invalid detail value. Received: %s", detail)
            raise DebuggerParamValueError("Invalid detail value.")

    def _retrieve_watchpoint(self, filter_condition):
        """
        Retrieve watchpoint.

        Args:
            filter_condition (dict): Filter condition.

                - watch_point_id (int): The id of watchpoint. If not given, return all watchpoints.
                - name (str): The name of single node.
                - single_node (bool): If False, return the sub-layer of single node. If True,
                  return the node list from root node to single node.

        Returns:
            dict, watchpoint list or relative graph.
        """
        watchpoint_id = filter_condition.get('watch_point_id', 0)
        if not watchpoint_id:
            reply = self.cache_store.get_stream_handler(Streams.WATCHPOINT).get()
            log.debug("Get condition of watchpoints.")
        else:
            reply = self._retrieve_node(filter_condition)
            log.debug("Get graph of %d-th watchpoint.", watchpoint_id)

        return reply

    def _retrieve_watchpoint_hit(self, filter_condition):
        """
        Retrieve watchpoint hit.

        Args:
            filter_condition (dict): Filter condition.

                - name (str): The name of single node.
                - single_node (bool): If False, return the sub-layer of single node. If True,
                  return the node list from root node to single node.

        Returns:
            dict, watchpoint hit list or relative graph.
        """
        node_name = filter_condition.get('name')
        # get all watchpoint hit list
        if node_name is None:
            reply = self.cache_store.get_stream_handler(Streams.WATCHPOINT_HIT).get()
            return reply
        # get tensor history and graph of the hit node
        self._validate_leaf_name(node_name)
        # get tensor history
        reply = self._get_tensor_history(node_name)
        log.debug("Get tensor history for watchpoint hit node.")
        # get single graph
        if filter_condition.get('single_node'):
            graph = self._get_nodes_info(filter_condition)
            reply.update(graph)
            log.debug("Get single node graph for watchpoint hit node.")

        return reply

    def create_watchpoint(self, watch_condition, watch_nodes=None, watch_point_id=None):
        """
        Create watchpoint.

        Args:
            watch_condition (dict): The watch condition.

                - condition (str): Accept `INF` or `NAN`.
                - param (list[float]): Not defined yet.
            watch_nodes (list[str]): The list of node names.
            watch_point_id (int): The id of watchpoint.

        Returns:
            dict, the id of new watchpoint.
        """
        log.info("Received create watchpoint request. WatchCondition: %s", watch_condition)
        metadata_stream = self.cache_store.get_stream_handler(Streams.METADATA)
        if metadata_stream.state != ServerStatus.WAITING.value:
            log.error("Failed to create watchpoint as MindSpore is not in waiting state.")
            raise DebuggerCreateWatchPointError(
                "Failed to create watchpoint as MindSpore is not in waiting state.")
        if metadata_stream.backend == 'GPU' and watch_condition.get('condition') == 'OVERFLOW':
            log.error("GPU doesn't support OVERFLOW watch condition.")
            raise DebuggerParamValueError("GPU doesn't support OVERFLOW watch condition.")
        watch_nodes = self._get_node_basic_infos(watch_nodes)
        watch_point_id = self.cache_store.get_stream_handler(Streams.WATCHPOINT).create_watchpoint(
            watch_condition, watch_nodes, watch_point_id)
        log.info("Create watchpoint %d", watch_point_id)
        return {'id': watch_point_id}
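    # Example (illustrative; condition values follow the docstring above):
    #   server.create_watchpoint({'condition': 'NAN'},
    #                            watch_nodes=['Default/network/conv1'])
    # returns {'id': <new watchpoint id>}.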
    def update_watchpoint(self, watch_point_id, watch_nodes, mode, name=None):
        """
        Update watchpoint.

        Args:
            watch_point_id (int): The id of watchpoint.
            watch_nodes (list[str]): The list of node names.
            mode (int): The update operation on nodes. 0 to remove nodes from the watch
                nodes, 1 to add nodes to the watch nodes.
            name (str): The search name. Default: None.

        Returns:
            dict, empty response.
        """
        if self.cache_store.get_stream_handler(
                Streams.METADATA).state != ServerStatus.WAITING.value:
            log.error("Failed to update watchpoint as MindSpore is not in waiting state.")
            raise DebuggerUpdateWatchPointError(
                "Failed to update watchpoint as MindSpore is not in waiting state."
            )
        # validate parameter
        watchpoint_stream = self.cache_store.get_stream_handler(Streams.WATCHPOINT)
        watchpoint_stream.validate_watchpoint_id(watch_point_id)
        if not watch_nodes or not watch_point_id:
            log.error("Invalid parameter for update watchpoint.")
            raise DebuggerParamValueError("Invalid parameter for update watchpoint.")
        # update watch node
        if name is not None:
            watch_nodes = self._get_watch_nodes_by_search(watch_nodes)
        elif mode == 1:
            watch_nodes = self._get_node_basic_infos(watch_nodes)

        watchpoint_stream.update_watchpoint(watch_point_id, watch_nodes, mode)
        log.info("Update watchpoint with id: %d", watch_point_id)
        return {}
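    # Example (illustrative): add a node to watchpoint 1, then remove it again:
    #   server.update_watchpoint(1, ['Default/network/conv1'], mode=1)
    #   server.update_watchpoint(1, ['Default/network/conv1'], mode=0)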
    def _get_watch_nodes_by_search(self, watch_nodes):
        """Get watched leaf nodes by search name."""
        watched_leaf_nodes = []
        graph_stream = self.cache_store.get_stream_handler(Streams.GRAPH)
        for search_name in watch_nodes:
            search_nodes = graph_stream.get_searched_node_list()
            search_node_names = [
                NodeBasicInfo(name=node.name, full_name=node.full_name, type=node.type)
                for node in search_nodes
                if node.name.startswith(search_name)]
            watched_leaf_nodes.extend(search_node_names)

        log.debug("Update nodes: %s", watched_leaf_nodes)

        return watched_leaf_nodes

    def delete_watchpoint(self, watch_point_id):
        """
        Delete watchpoint.

        Args:
            watch_point_id (int): The id of watchpoint.

        Returns:
            dict, empty response.
        """
        if self.cache_store.get_stream_handler(
                Streams.METADATA).state != ServerStatus.WAITING.value:
            log.error("Failed to delete watchpoint as MindSpore is not in waiting state.")
            raise DebuggerDeleteWatchPointError(
                "Failed to delete watchpoint as MindSpore is not in waiting state."
            )
        self.cache_store.get_stream_handler(Streams.WATCHPOINT).delete_watchpoint(watch_point_id)
        log.info("Delete watchpoint with id: %d", watch_point_id)
        return {}

    def _get_node_basic_infos(self, node_names):
        """Get node info according to node names."""
        if not node_names:
            return []
        graph_stream = self.cache_store.get_stream_handler(Streams.GRAPH)
        node_infos = []
        for node_name in node_names:
            node_type = graph_stream.get_node_type(node_name)
            if node_type == NodeTypeEnum.AGGREGATION_SCOPE.value:
                sub_nodes = graph_stream.get_nodes_by_scope(node_name)
                sub_infos = [NodeBasicInfo(name=node.name, full_name=node.full_name, type=node.type)
                             for node in sub_nodes]
                node_infos.extend(sub_infos)
            full_name = graph_stream.get_full_name(node_name)
            node_infos.append(NodeBasicInfo(name=node_name, full_name=full_name, type=node_type))
        return node_infos

    def control(self, params=None):
        """
        Control the training process.

        Args:
            params (dict): The control params.

                - mode (str): Acceptable control command, including `continue`,
                  `pause` and `terminate`.
                - level (str): The control granularity, `node` level or `step` level.
                  Default: `step`.
                - steps (int): Specify the steps that training should run.
                  Used when `level` is `step`.
                - name (str): Specify the name of the node. Used when `level` is `node`.

        Returns:
            dict, the response.
        """
        log.info("Receive control request: %s.", params)
        mode = params.get('mode')
        metadata_stream = self.cache_store.get_stream_handler(Streams.METADATA)
        if mode == 'continue':
            reply = self._continue(metadata_stream, params)
        elif mode in ['pause', 'terminate']:
            mode_mapping = {
                'pause': self._pause,
                'terminate': self._terminate
            }
            reply = mode_mapping.get(mode)(metadata_stream)
        else:
            log.error("Invalid control mode %s", mode)
            raise DebuggerParamValueError("Invalid control mode.")

        return reply
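    # Example (illustrative; param keys follow the docstring above):
    #   server.control({'mode': 'continue', 'level': 'step', 'steps': 1})
    #   server.control({'mode': 'pause'})
    #   server.control({'mode': 'terminate'})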
    def _continue(self, metadata_stream, params):
        """
        Send RunCMD to MindSpore.

        Args:
            metadata_stream (MetadataHandler): The metadata handler.
            params (dict): The control params.
        """
        if metadata_stream.state != ServerStatus.WAITING.value:
            log.error("MindSpore is not ready to run. Current state is: %s", metadata_stream.state)
            raise DebuggerContinueError(
                "MindSpore is not ready to run or is running currently."
            )
        metadata_stream.state = ServerStatus.RUNNING.value
        current_state = ServerStatus.RUNNING.value
        try:
            event = self._construct_run_event(params)
            self._send_watchpoints()
            self.cache_store.put_command(event)
        except MindInsightException as err:
            log.error("Failed to send run event.")
            log.exception(err)
            current_state = ServerStatus.WAITING.value
            metadata_stream.state = current_state
            raise DebuggerContinueError("Failed to send run command.")
        else:
            log.debug("Send the RunCMD to command queue.")

        return {'metadata': {'state': current_state}}

    def _construct_run_event(self, params):
        """
        Construct run cmd from input control params.

        Args:
            params (dict): The control params.

                - level (str): The control granularity, `node` level or `step` level.
                  Default: `step`.
                - steps (int): Specify the steps that training should run.
                  Used when `level` is `step`.
                - name (str): Specify the name of the node. Used when `level` is `node`.

        Returns:
            EventReply, control event with run command.
        """
        level = params.get('level', 'step')
        event = get_ack_reply()
        if level == 'step':
            steps = params.get('steps')
            if not steps:
                steps = 1
            run_cmd = RunCMD(run_level='step', run_steps=steps)
        elif level == 'node':
            name = params.get('name')
            if name:
                self._validate_leaf_name(name)
                name = self.cache_store.get_stream_handler(Streams.GRAPH).get_full_name(name)
            else:
                name = ''
            run_cmd = RunCMD(run_level='node', node_name=name)
        else:
            log.error("Invalid value. `level` should be `step` or `node`. Got %s", level)
            raise DebuggerParamValueError("`level` should be `step` or `node`.")

        event.run_cmd.CopyFrom(run_cmd)
        log.debug("Construct run event. %s", event)
        return event
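    # The two RunCMD shapes produced above (illustrative):
    #   RunCMD(run_level='step', run_steps=1)              # run N steps
    #   RunCMD(run_level='node', node_name='<full name>')  # run to a named node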
    def _send_watchpoints(self):
        """Set watchpoints."""
        watchpoint_stream = self.cache_store.get_stream_handler(Streams.WATCHPOINT)
        watchpoints = watchpoint_stream.get(filter_condition=True).get('watch_points')
        if watchpoints:
            for watchpoint in watchpoints:
                event = get_ack_reply()
                event.set_cmd.CopyFrom(watchpoint)
                self.cache_store.put_command(event)
            watchpoint_stream.sync_set_cmd()
            log.debug("Send SetCMD to MindSpore. %s", event)

    def _pause(self, metadata_stream):
        """
        Pause the training.

        Args:
            metadata_stream (MetadataHandler): The metadata stream handler.
        """
        if metadata_stream.state != ServerStatus.RUNNING.value:
            log.error("MindSpore is not running.")
            raise DebuggerPauseError("MindSpore is not running.")
        metadata_stream.state = 'waiting'
        event = get_ack_reply()
        event.run_cmd.CopyFrom(RunCMD(run_level='step', run_steps=0))
        self.cache_store.put_command(event)
        log.debug("Send the Pause command.")
        return {'metadata': {'state': 'waiting'}}

    def _terminate(self, metadata_stream):
        """
        Terminate the training.

        Args:
            metadata_stream (MetadataHandler): The metadata stream handler.
        """
        metadata_stream.state = 'pending'
        self.cache_store.clean_data()
        event = get_ack_reply()
        event.exit = True
        self.cache_store.put_command(event)
        log.debug("Send the ExitCMD.")
        return {'metadata': {'state': 'pending'}}

    def retrieve_node_by_bfs(self, node_name, ascend=False):
        """
        Get the graph of the next node according to node_name.

        Args:
            node_name (str): The name of current chosen leaf node.
            ascend (bool): If True, traverse the input nodes;
                if False, traverse the output nodes. Default: False.

        Returns:
            dict, the next node information.
        """
        log.info("Retrieve node <%s> by bfs, `ascend` is %s.", node_name, ascend)
        reply = {}
        graph_stream = self.cache_store.get_stream_handler(Streams.GRAPH)
        next_node_name = graph_stream.get_node_by_bfs_order(node_name, ascend)
        # no next node
        if next_node_name is None:
            return reply
        # add graph and tensor history for next node
        filter_condition = {
            'name': next_node_name,
            'single_node': True
        }
        search_graph = self._get_nodes_info(filter_condition)
        reply = {'name': next_node_name}
        reply.update(search_graph)

        return reply
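
# A minimal end-to-end sketch (illustrative only; in practice the MindInsight
# backend creates and drives this server, and a MindSpore training script must
# connect to the gRPC port before most calls return data):
#   server = DebuggerServer(grpc_port='50051')
#   server.start()            # serve until SIGINT or an explicit stop()
#   ...
#   server.stop()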