You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

debugger_session.py 29 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677
  1. # Copyright 2020-2021 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ============================================================================
  15. """Implement the debugger server."""
  16. from functools import wraps
  17. from mindinsight.datavisual.data_transform.graph import NodeTypeEnum
  18. from mindinsight.datavisual.utils.tools import to_float
  19. from mindinsight.debugger.common.exceptions.exceptions import DebuggerParamValueError, \
  20. DebuggerParamTypeError, DebuggerCompareTensorError, DebuggerTensorGraphError, \
  21. DebuggerTensorHitError, DebuggerSetRecommendWatchpointsError, MindInsightException
  22. from mindinsight.debugger.common.log import LOGGER as log
  23. from mindinsight.debugger.common.utils import ServerStatus, \
  24. create_view_event_from_tensor_basic_info, Streams
  25. from mindinsight.debugger.conditionmgr.condition import ConditionContext
  26. from mindinsight.debugger.conditionmgr.conditionmgr import ConditionMgr
  27. from mindinsight.debugger.conditionmgr.recommender import recommend_watchpoints
  28. from mindinsight.debugger.debugger_cache import DebuggerCache
  29. from mindinsight.debugger.debugger_services.debugger_server_factory import DebuggerServerFactory
  30. from mindinsight.debugger.stream_operator.tensor_detail_info import TensorDetailInfo
  31. from mindinsight.debugger.stream_operator.training_control_operator import TrainingControlOperator
  32. from mindinsight.debugger.stream_operator.watchpoint_operator import WatchpointOperator
  33. from mindinsight.utils.tensor import TensorUtils, MAX_DIMENSIONS_FOR_TENSOR
  34. def try_except(func):
  35. """Send latest metadata when catch exception."""
  36. @wraps(func)
  37. def send_latest_metadata(self, *args, **kwargs):
  38. try:
  39. return func(self, *args, **kwargs)
  40. except MindInsightException as err:
  41. metadata = self.cache_store.get_stream_handler(Streams.METADATA).get()
  42. self.cache_store.put_data(metadata)
  43. log.info("Put latest metadata into data-queue.")
  44. raise err
  45. return send_latest_metadata
  46. class DebuggerSession:
  47. """The server manager of debugger."""
  48. def __init__(self, context):
  49. self.condition_mgr = ConditionMgr()
  50. self.cache_store = DebuggerCache()
  51. self.context = context
  52. self.back_server = DebuggerServerFactory().get_debugger_server(self.cache_store, context)
  53. @property
  54. def train_job(self):
  55. """The property of train job."""
  56. return self.context.train_job
  57. def get_condition_collections(self, train_id=""):
  58. """Get default condition_collections"""
  59. metadata_stream = self.cache_store.get_stream_handler(Streams.METADATA)
  60. condition_context = ConditionContext(metadata_stream.backend, metadata_stream.step)
  61. log.debug("Train_id: %s, backend: %s", train_id, condition_context.backend)
  62. return self.condition_mgr.get_all_collections(condition_context)
  63. def set_recommended_watch_points(self, set_recommended, train_id=""):
  64. """Set recommended watch points."""
  65. if not isinstance(set_recommended, bool):
  66. log.error("Bool param should be given for set_recommended")
  67. raise DebuggerParamValueError("Bool param should be given.")
  68. metadata_stream = self.cache_store.get_stream_handler(Streams.METADATA)
  69. if metadata_stream.recommendation_confirmed:
  70. log.error("User has confirmed setting recommended watchpoints")
  71. raise DebuggerSetRecommendWatchpointsError()
  72. metadata_stream.recommendation_confirmed = True
  73. condition_context = ConditionContext(metadata_stream.backend, metadata_stream.step)
  74. log.debug("Train_id: %s, backend: %s", train_id, condition_context.backend)
  75. res = metadata_stream.get(['state', 'enable_recheck'])
  76. if set_recommended:
  77. res['id'] = self._add_recommended_watchpoints(condition_context)
  78. return res
  79. def _add_recommended_watchpoints(self, condition_context):
  80. """Add predefined watchpoints."""
  81. log.debug("Add predefined watchpoints.")
  82. multi_card_graph_stream = self.cache_store.get_stream_handler(Streams.GRAPH)
  83. watchpoints = recommend_watchpoints(self.condition_mgr, multi_card_graph_stream, condition_context)
  84. watch_point_stream_handler = self.cache_store.get_stream_handler(Streams.WATCHPOINT)
  85. device_stream = self.cache_store.get_stream_handler(Streams.DEVICE)
  86. watch_points_ids = []
  87. for watchpoint in watchpoints:
  88. watch_points_id = watch_point_stream_handler.create_watchpoint(
  89. watch_condition=watchpoint.get_watch_condition_dict(),
  90. watch_nodes=watchpoint.watch_nodes,
  91. name=watchpoint.name,
  92. condition_mgr=self.condition_mgr,
  93. device_amount=device_stream.device_amount
  94. )
  95. watch_points_ids.append(watch_points_id)
  96. return watch_points_ids
  97. def start(self):
  98. """Start server."""
  99. self.back_server.start()
  100. # register stop server handler
  101. #signal.signal(signal.SIGINT, self._stop_handler)
  102. log.info("Start debugger backend server.")
  103. def _stop_handler(self, signum, frame):
  104. """Register stop server handler."""
  105. self.stop()
  106. log.debug("Deal with stop signal: %s, %s", signum, frame)
  107. def stop(self):
  108. """Stop debugger server."""
  109. log.info("Send terminate info to client.")
  110. self.control({'mode': 'terminate'})
  111. self.back_server.stop()
  112. log.info("Stop debugger server.")
  113. def poll_data(self, pos):
  114. """
  115. Get the pos-th data from DebuggerCache.
  116. Args:
  117. pos (int): The index of data.
  118. Returns:
  119. dict, the data to be updated.
  120. """
  121. if not isinstance(pos, str):
  122. log.error("Pos should be string. Received: %s", pos)
  123. raise DebuggerParamValueError("Pos should be string.")
  124. reply = self.cache_store.get_data(pos)
  125. return reply
  126. def search(self, filter_condition):
  127. """
  128. Search for single node in graph.
  129. Args:
  130. filter_condition (dict): Filter condition.
  131. - name (str): The name pattern.
  132. - graph_name (str): The graph name.
  133. - watch_point_id (int): The id of watchpoint. Default: 0.
  134. - node_category (str): The node_category. Default: None
  135. - rank_id (int): The id of rank. Default: 0.
  136. Returns:
  137. dict, the searched nodes.
  138. """
  139. log.info("receive search request with filter_condition: %s", filter_condition)
  140. # validate watchpoint id
  141. watch_point_id = filter_condition.pop('watch_point_id', 0)
  142. rank_id = filter_condition.pop('rank_id', 0)
  143. watchpoint_stream = self.cache_store.get_stream_handler(Streams.WATCHPOINT)
  144. watchpoint_stream.validate_watchpoint_id(watch_point_id)
  145. # validate and update graph name
  146. graph_stream = self.cache_store.get_stream_handler(Streams.GRAPH).get_graph_handler_by_rank_id(rank_id)
  147. graph_name = graph_stream.validate_graph_name(filter_condition.get('graph_name'))
  148. filter_condition['graph_name'] = graph_name
  149. # get searched graph
  150. graph = graph_stream.search_nodes(filter_condition)
  151. # add watched label to graph
  152. watchpoint_stream.set_watch_nodes(graph, graph_stream, watch_point_id, graph_name, rank_id)
  153. return graph
  154. def tensor_comparisons(self, name, shape, detail='data', tolerance='0', rank_id=0):
  155. """
  156. Get tensor comparisons data for given name, detail, shape and tolerance.
  157. Args:
  158. name (str): The name of tensor for ui.
  159. detail (str): Specify which data to query. Current available value is 'data' which means
  160. concrete tensor data. Histogram or unique count can be supported in the future.
  161. shape (str): Specify concrete dimensions of shape.
  162. tolerance (str): Specify tolerance of difference between current step tensor and previous
  163. step tensor. Default value is 0.
  164. rank_id (int): The id of rank. Default: 0.
  165. Raises:
  166. DebuggerParamValueError, If node type is not parameter or value of detail is not support.
  167. DebuggerCompareTensorError, If MindSpore is not in waiting state.
  168. Returns:
  169. dict, the retrieved data.
  170. """
  171. if self.cache_store.get_stream_handler(
  172. Streams.METADATA).state != ServerStatus.WAITING.value:
  173. log.error("Failed to compare tensors as the MindSpore is not in waiting state.")
  174. raise DebuggerCompareTensorError(
  175. "Failed to compare tensors as the MindSpore is not in waiting state."
  176. )
  177. self.validate_tensor_param(name, detail)
  178. # Limit to query max two dimensions for tensor in table view.
  179. parsed_shape = TensorUtils.parse_shape(shape, limit=MAX_DIMENSIONS_FOR_TENSOR)
  180. node_type, tensor_name = self._get_tensor_name_and_type_by_ui_name(name)
  181. tolerance = to_float(tolerance, 'tolerance')
  182. tensor_stream = self.cache_store.get_stream_handler(Streams.TENSOR).get_tensor_handler_by_rank_id(rank_id)
  183. cur_step = self.cache_store.get_stream_handler(Streams.METADATA).step
  184. if node_type == NodeTypeEnum.PARAMETER.value:
  185. reply = tensor_stream.get_tensors_diff(tensor_name, parsed_shape, tolerance, cur_step)
  186. else:
  187. raise DebuggerParamValueError(
  188. "The node type must be parameter, but got {}.".format(node_type))
  189. return reply
  190. def retrieve(self, mode, filter_condition=None):
  191. """
  192. Retrieve data according to mode and params.
  193. Args:
  194. mode (str): The type of info message.
  195. filter_condition (dict): The filter condition.
  196. Returns:
  197. dict, the retrieved data.
  198. """
  199. log.info("receive retrieve request for mode:%s\n, filter_condition: %s", mode,
  200. filter_condition)
  201. mode_mapping = {
  202. 'all': self._retrieve_all,
  203. 'node': self._retrieve_node,
  204. 'watchpoint': self._retrieve_watchpoint,
  205. }
  206. # validate param <mode>
  207. if mode not in mode_mapping.keys():
  208. log.error("Invalid param <mode>. <mode> should be in ['all', 'node', 'watchpoint', "
  209. "'watchpoint_hit'], but got %s.", mode_mapping)
  210. raise DebuggerParamValueError("Invalid mode.")
  211. # validate backend status
  212. metadata_stream = self.cache_store.get_stream_handler(Streams.METADATA)
  213. if metadata_stream.state == ServerStatus.PENDING.value:
  214. log.info("The backend is in pending status.")
  215. return metadata_stream.get()
  216. filter_condition = {} if filter_condition is None else filter_condition
  217. reply = mode_mapping[mode](filter_condition)
  218. return reply
  219. def _retrieve_all(self, filter_condition=None):
  220. """Retrieve metadata, root graph and watchpoint list."""
  221. if filter_condition:
  222. log.error("No filter condition required for retrieve all request.")
  223. raise DebuggerParamTypeError("filter_condition should be empty.")
  224. self.cache_store.clean_data()
  225. log.info("Clean data queue cache when retrieve all request.")
  226. result = {}
  227. for stream in [Streams.METADATA, Streams.GRAPH, Streams.DEVICE]:
  228. sub_res = self.cache_store.get_stream_handler(stream).get()
  229. result.update(sub_res)
  230. devices = result['devices']
  231. if not devices:
  232. graph = result['graph']
  233. metadata = result['metadata']
  234. device = {'rank_id': 0, 'server_ip': metadata.get('ip', 'localhost'),
  235. 'device_id': metadata.get('device_name', ''),
  236. 'graph_names': graph.get('graph_names', [])}
  237. devices.append(device)
  238. sub_res = self._hide_parameters_for_ui()
  239. result.update(sub_res)
  240. return result
  241. def _retrieve_node(self, filter_condition):
  242. """
  243. Retrieve node info.
  244. Args:
  245. filter_condition (dict): Filter condition.
  246. - name (str): The name of single node.
  247. - graph_name (str): The relative graph_name of the node.
  248. - single_node (bool): If False, return the sub-layer of single node. If True, return
  249. the node list from root node to single node.
  250. - watch_point_id (int): The id of watchpoint.
  251. Returns:
  252. dict, reply with graph.
  253. """
  254. log.debug("Retrieve node %s.", filter_condition)
  255. # validate node name
  256. node_name = filter_condition.get('name')
  257. rank_id = filter_condition.get('rank_id', 0)
  258. graph_stream = self.cache_store.get_stream_handler(Streams.GRAPH).get_graph_handler_by_rank_id(rank_id)
  259. graph_name = graph_stream.validate_graph_name(filter_condition.get('graph_name'))
  260. if node_name:
  261. # validate node name
  262. graph_stream.get_node_type(node_name, graph_name)
  263. filter_condition['single_node'] = bool(filter_condition.get('single_node'))
  264. filter_condition['graph_name'] = graph_name
  265. reply = self._get_nodes_info(filter_condition)
  266. return reply
  267. def _get_nodes_info(self, filter_condition):
  268. """
  269. Get nodes info.
  270. Args:
  271. filter_condition (dict): The filter condition.
  272. - name (str): The node name.
  273. - graph_name (str): The relative graph_name of the node.
  274. - single_node (bool): If False, return the sub-layer of single node. If True, return
  275. the node list from root node to single node.
  276. - watch_point_id (int): The id of watchpoint.
  277. Returns:
  278. dict, reply with graph.
  279. """
  280. # validate watch_point_id
  281. rank_id = filter_condition.get('rank_id', 0)
  282. watch_point_id = filter_condition.get('watch_point_id', 0)
  283. watchpoint_stream = self.cache_store.get_stream_handler(Streams.WATCHPOINT)
  284. watchpoint_stream.validate_watchpoint_id(watch_point_id)
  285. # get graph
  286. graph_stream = self.cache_store.get_stream_handler(Streams.GRAPH).get_graph_handler_by_rank_id(rank_id)
  287. reply = graph_stream.get(filter_condition)
  288. graph = reply.get('graph')
  289. # add watched label to graph
  290. watchpoint_stream.set_watch_nodes(graph, graph_stream, watch_point_id, filter_condition.get('graph_name'),
  291. rank_id)
  292. return reply
  293. def retrieve_tensor_history(self, node_name, graph_name=None, rank_id=0):
  294. """
  295. Retrieve tensor history for leaf node.
  296. Args:
  297. node_name (str): The name of leaf node.
  298. graph_name (str): The graph name. Default: None.
  299. rank_id (int): The id of rank. Default: 0.
  300. Returns:
  301. dict, the tensor history and metadata.
  302. """
  303. log.info("Retrieve tensor history for node: %s.", node_name)
  304. metadata_stream = self.cache_store.get_stream_handler(Streams.METADATA)
  305. if metadata_stream.state == ServerStatus.PENDING.value:
  306. log.info("The backend is in pending status.")
  307. return metadata_stream.get(['state', 'step'])
  308. res = self._get_tensor_history(node_name, graph_name, rank_id)
  309. return res
  310. def _get_tensor_history(self, node_name, graph_name=None, rank_id=0):
  311. """
  312. Get tensor history for single node.
  313. Args:
  314. node_name (str): The name of leaf node.
  315. graph_name (str): The graph name. Default: None.
  316. rank_id (int): The id of rank. Default: 0.
  317. Returns:
  318. dict, the tensor history and metadata.
  319. """
  320. # get basic tensor history
  321. graph_stream = self.cache_store.get_stream_handler(Streams.GRAPH).get_graph_handler_by_rank_id(rank_id)
  322. tensor_history = graph_stream.get_tensor_history(node_name, graph_name)
  323. # add tensor value for tensor history
  324. self._add_tensor_value_for_tensor_history(tensor_history, node_name, graph_name, rank_id)
  325. # add hit label for tensor history
  326. self.cache_store.get_stream_handler(Streams.WATCHPOINT_HIT).update_tensor_history(tensor_history, rank_id)
  327. # add metadata
  328. metadata = self.cache_store.get_stream_handler(Streams.METADATA).get(['step'])
  329. tensor_history.update(metadata)
  330. return tensor_history
  331. def _add_tensor_value_for_tensor_history(self, tensor_history, node_name, graph_name, rank_id):
  332. """
  333. Add tensor value for_tensor_history and send ViewCMD if tensor value missed.
  334. Args:
  335. tensor_history (list[dict]): A list of tensor info, including name and type.
  336. node_name (str): The UI node name.
  337. graph_name (str): The graph name. Default: None.
  338. rank_id (int): The id of rank. Default: 0.
  339. Returns:
  340. dict, the tensor info.
  341. """
  342. tensor_stream = self.cache_store.get_stream_handler(Streams.TENSOR).get_tensor_handler_by_rank_id(rank_id)
  343. cur_step = self.cache_store.get_stream_handler(Streams.METADATA).step
  344. missed_tensors = tensor_stream.update_tensor_history(tensor_history, cur_step)
  345. if missed_tensors:
  346. view_cmd = create_view_event_from_tensor_basic_info(missed_tensors)
  347. self.cache_store.put_command(
  348. {'view_cmd': view_cmd, 'node_name': node_name, 'graph_name': graph_name, 'rank_id': rank_id})
  349. log.debug("Send view cmd.")
  350. def retrieve_tensor_value(self, name, detail, shape, graph_name=None, prev=False, rank_id=0):
  351. """Retrieve the tensor value."""
  352. log.info("Retrieve tensor value: name: %s, detail: %s, shape: %s", name, detail, shape)
  353. self.validate_tensor_param(name, detail)
  354. # Limit to query max two dimensions for tensor in table view.
  355. parsed_shape = TensorUtils.parse_shape(shape, limit=MAX_DIMENSIONS_FOR_TENSOR)
  356. node_type, tensor_name = self._get_tensor_name_and_type_by_ui_name(name, graph_name, rank_id)
  357. reply = self.cache_store.get_stream_handler(Streams.TENSOR).get(
  358. {'name': tensor_name,
  359. 'node_type': node_type,
  360. 'shape': parsed_shape,
  361. 'prev': prev},
  362. rank_id
  363. )
  364. reply['tensor_value']['name'] = name
  365. return reply
  366. def _get_tensor_name_and_type_by_ui_name(self, name, graph_name=None, rank_id=0):
  367. """
  368. Get inner tensor name and type by UI name.
  369. Args:
  370. name (str): Node name shown in UI.
  371. graph_name (Union[str, None]): The graph name, default is: None.
  372. rank_id (int): The id of rank. Default: 0.
  373. Returns:
  374. str, full name of tensor.
  375. str, node type of tensor.
  376. """
  377. node_name, slot = name.rsplit(':', 1)
  378. graph_stream = self.cache_store.get_stream_handler(Streams.GRAPH).get_graph_handler_by_rank_id(rank_id)
  379. graph_name = graph_name if graph_name else graph_stream.get_graph_id_by_name(node_name)
  380. node_type = graph_stream.get_node_type(node_name, graph_name)
  381. full_name = graph_stream.get_full_name(node_name, graph_name)
  382. tensor_name = full_name + ':' + slot
  383. return node_type, tensor_name
  384. @staticmethod
  385. def validate_tensor_param(name, detail):
  386. """Validate params for retrieve tensor request."""
  387. # validate name
  388. if not isinstance(name, str) or ':' not in name:
  389. log.error("Invalid tensor name. Received: %s", name)
  390. raise DebuggerParamValueError("Invalid tensor name.")
  391. # validate data
  392. if detail != 'data':
  393. log.error("Invalid detail value. Received: %s", detail)
  394. raise DebuggerParamValueError("Invalid detail value.")
  395. def _retrieve_watchpoint(self, filter_condition):
  396. """
  397. Retrieve watchpoint.
  398. Args:
  399. filter_condition (dict): Filter condition.
  400. - watch_point_id (int): The id of watchpoint. If not given, return all watchpoints.
  401. - name (str): The name of single node.
  402. - single_node (bool): If False, return the sub-layer of single node. If True, return
  403. the node list from root node to single node.
  404. Returns:
  405. dict, watch point list or relative graph.
  406. """
  407. watchpoint_id = filter_condition.get('watch_point_id', 0)
  408. if not watchpoint_id:
  409. reply = self._hide_parameters_for_ui()
  410. log.debug("Get condition of watchpoints.")
  411. else:
  412. reply = self._retrieve_node(filter_condition)
  413. log.debug("Get graph of %d-th watchpoint.", watchpoint_id)
  414. return reply
  415. def search_watchpoint_hits(self, group_condition):
  416. """
  417. Retrieve watchpoint hit.
  418. Args:
  419. group_condition (dict): Filter condition.
  420. - limit (int): The limit of each page.
  421. - offset (int): The offset of current page.
  422. - node_name (str): The retrieved node name.
  423. - graph_name (str): The retrieved graph name.
  424. - rank_id (int): The rank id.
  425. Returns:
  426. dict, watch point list or relative graph.
  427. """
  428. if not isinstance(group_condition, dict):
  429. log.error("Group condition for watchpoint-hits request should be a dict")
  430. raise DebuggerParamTypeError("Group condition for watchpoint-hits request should be a dict")
  431. metadata_stream = self.cache_store.get_stream_handler(Streams.METADATA)
  432. if metadata_stream.state == ServerStatus.PENDING.value:
  433. log.info("The backend is in pending status.")
  434. return metadata_stream.get()
  435. rank_id = group_condition.pop('rank_id', 0)
  436. reply = {}
  437. multi_watchpoint_hit_stream = self.cache_store.get_stream_handler(Streams.WATCHPOINT_HIT)
  438. if multi_watchpoint_hit_stream.check_rank_id(rank_id):
  439. watchpoint_hit_stream = multi_watchpoint_hit_stream.get_hit_handler_by_rank_id(rank_id)
  440. reply = watchpoint_hit_stream.group_by(group_condition)
  441. reply['outdated'] = self.cache_store.get_stream_handler(Streams.WATCHPOINT).is_recheckable()
  442. return reply
  443. def create_watchpoint(self, params):
  444. """
  445. Create watchpoint.
  446. Args:
  447. params (dict): Params for create watchpoint.
  448. - watch_condition (dict): The watch condition. The format is like:
  449. {
  450. "id": "tensor_too_large",
  451. "params": [
  452. {
  453. "name": "abs_mean_gt",
  454. "value": 1.1
  455. }
  456. ]
  457. }
  458. - id (str): Id of condition.
  459. - params (list[dict]): The list of param for this condition.
  460. - watch_nodes (list[str]): The list of node names.
  461. - watch_point_id (int): The id of watchpoint.
  462. - search_pattern (dict): The search pattern.
  463. - graph_name (str): The relative graph_name of the watched node.
  464. Returns:
  465. dict, the id of new watchpoint and metadata info.
  466. """
  467. watchpoint_opt = WatchpointOperator(self.cache_store, self.condition_mgr)
  468. return watchpoint_opt.create_watchpoint(params)
  469. def update_watchpoint(self, params):
  470. """
  471. Update watchpoint.
  472. Args:
  473. params (dict): Params for update watchpoint.
  474. - watch_point_id (int): The id of watchpoint.
  475. - watch_nodes (list[str]): The list of node names.
  476. - mode (int): The update operator on nodes. 0 for remove nodes from watch nodes.
  477. 1 for add nodes to watch nodes.
  478. - search_pattern (dict): The search pattern.
  479. - graph_name (str): The relative graph_name of the watched node.
  480. Returns:
  481. dict, the metadata info.
  482. """
  483. watchpoint_opt = WatchpointOperator(self.cache_store, self.condition_mgr)
  484. return watchpoint_opt.update_watchpoint(params)
  485. def delete_watchpoint(self, watch_point_id=None):
  486. """
  487. Delete watchpoint.
  488. Args:
  489. watch_point_id (Union[None, int]): The id of watchpoint.
  490. If None, delete all watchpoints. Default: None.
  491. Returns:
  492. dict, the metadata info.
  493. """
  494. watchpoint_opt = WatchpointOperator(self.cache_store, self.condition_mgr)
  495. return watchpoint_opt.delete_watchpoint(watch_point_id=watch_point_id)
  496. @try_except
  497. def control(self, params=None):
  498. """
  499. Control the training process.
  500. Args:
  501. params (dict): The control params.
  502. - mode (str): Acceptable control command, including `continue`,
  503. `pause` and `terminate`.
  504. - level (str): The control granularity, `node` level or `step` level.
  505. Default: `step`.
  506. - steps (int): Specify the steps that training should run.
  507. Used when `level` is `step`.
  508. - name (str): Specify the name of the node. Used when `level` is `node`.
  509. - graph_name (str): The graph name.
  510. Returns:
  511. dict, the response.
  512. """
  513. log.info("Receive control request: %s.", params)
  514. mode = params.pop('mode', None) if params else None
  515. training_controller = TrainingControlOperator(self.cache_store)
  516. training_controller.validate_mode(mode)
  517. return training_controller.control(mode, params)
  518. @try_except
  519. def recheck(self):
  520. """
  521. Recheck all watchpoints.
  522. Returns:
  523. dict, metadata info.
  524. """
  525. return TrainingControlOperator(self.cache_store).recheck()
  526. def retrieve_tensor_graph(self, tensor_name, graph_name, rank_id=0):
  527. """
  528. Retrieve tensor graph.
  529. Args:
  530. tensor_name (str): The tensor name from UI.
  531. graph_name (str): The graph name.
  532. rank_id (int): The id of rank. Default: 0.
  533. Returns:
  534. dict, tensor graph object.
  535. """
  536. if self.cache_store.get_stream_handler(Streams.METADATA).state != ServerStatus.WAITING.value:
  537. log.error("Failed to get tensor graph the MindSpore is not in waiting state.")
  538. raise DebuggerTensorGraphError
  539. log.info("Retrieve tensor graph for %s from %s", tensor_name, graph_name)
  540. tensor_graph_ops = TensorDetailInfo(self.cache_store).get_tensor_graph(tensor_name, graph_name, rank_id)
  541. return tensor_graph_ops
  542. def retrieve_tensor_hits(self, tensor_name, graph_name, rank_id=0):
  543. """
  544. Retrieve tensor hit information.
  545. Args:
  546. tensor_name (str): The tensor name from UI.
  547. graph_name (str): The graph name.
  548. rank_id (int): The id of rank. Default: 0.
  549. Returns:
  550. dict, tensor hit info.
  551. """
  552. if self.cache_store.get_stream_handler(Streams.METADATA).state != ServerStatus.WAITING.value:
  553. log.error("Failed to get tensor hits as the MindSpore is not in waiting state.")
  554. raise DebuggerTensorHitError
  555. log.info("Retrieve tensor hits for %s from %s", tensor_name, graph_name)
  556. watch_points = TensorDetailInfo(self.cache_store).get_tensor_watch_points(tensor_name, graph_name, rank_id)
  557. return {'watch_points': watch_points}
  558. def _hide_parameters_for_ui(self):
  559. """
  560. Hide some parameters on ui.
  561. Returns:
  562. dict, watch point list.
  563. """
  564. reply = self.cache_store.get_stream_handler(Streams.WATCHPOINT).get()
  565. watch_points = reply.get('watch_points')
  566. for i, watch_point in enumerate(watch_points):
  567. watch_condition = watch_point.get('watch_condition')
  568. parameters = watch_condition.get('params')
  569. watch_condition_id = watch_condition.get('id')
  570. mgr_condition = self.condition_mgr.get_condition(watch_condition_id)
  571. ui_watch_condition = []
  572. for param in parameters:
  573. parameter_definition = mgr_condition.get_parameter_definition(param['name'])
  574. if not parameter_definition.visible_on_ui:
  575. continue
  576. ui_watch_condition.append(param)
  577. reply['watch_points'][i]['watch_condition']['params'] = ui_watch_condition
  578. return reply