
debugger_server.py 29 kB

# Copyright 2020-2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
  15. """Implement the debugger server."""
  16. import signal
  17. from concurrent import futures
  18. from functools import wraps
  19. from threading import Thread
  20. import grpc
  21. from mindinsight.debugger.conditionmgr.condition import ConditionContext
  22. from mindinsight.debugger.conditionmgr.conditionmgr import ConditionMgr
  23. from mindinsight.debugger.conditionmgr.recommender import recommend_watchpoints
  24. from mindinsight.conf import settings
  25. from mindinsight.datavisual.data_transform.graph import NodeTypeEnum
  26. from mindinsight.datavisual.utils.tools import to_float
  27. from mindinsight.debugger.common.exceptions.exceptions import DebuggerParamValueError, \
  28. DebuggerParamTypeError, DebuggerCompareTensorError, DebuggerTensorGraphError, \
  29. DebuggerTensorHitError, DebuggerSetRecommendWatchpointsError, MindInsightException
  30. from mindinsight.debugger.common.log import LOGGER as log
  31. from mindinsight.debugger.common.utils import ServerStatus, \
  32. create_view_event_from_tensor_basic_info, Streams
  33. from mindinsight.debugger.debugger_cache import DebuggerCache
  34. from mindinsight.debugger.debugger_grpc_server import DebuggerGrpcServer
  35. from mindinsight.debugger.proto import debug_grpc_pb2_grpc as grpc_server_base
  36. from mindinsight.debugger.stream_operator.tensor_detail_info import TensorDetailInfo
  37. from mindinsight.debugger.stream_operator.training_control_operator import TrainingControlOperator
  38. from mindinsight.debugger.stream_operator.watchpoint_operator import WatchpointOperator
  39. from mindinsight.utils.tensor import TensorUtils, MAX_DIMENSIONS_FOR_TENSOR


def try_except(func):
    """Send the latest metadata when an exception is caught."""

    @wraps(func)
    def send_latest_metadata(self, *args, **kwargs):
        try:
            return func(self, *args, **kwargs)
        except MindInsightException as err:
            metadata = self.cache_store.get_stream_handler(Streams.METADATA).get()
            self.cache_store.put_data(metadata)
            log.info("Put latest metadata into data-queue.")
            raise err

    return send_latest_metadata
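
# `try_except` is applied to `control` and `recheck` below, so the UI is sent
# fresh metadata even when one of those requests fails.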


class DebuggerServer:
    """The server manager of debugger."""

    def __init__(self):
        self.condition_mgr = ConditionMgr()
        self.cache_store = DebuggerCache()
        self.grpc_server = DebuggerGrpcServer(self.cache_store, self.condition_mgr)
        self.grpc_server_manager = None
        self.back_server = None

    def get_condition_collections(self, train_id):
        """Get default condition collections."""
        metadata_stream = self.cache_store.get_stream_handler(Streams.METADATA)
        condition_context = ConditionContext(metadata_stream.backend, metadata_stream.step)
        log.debug("Train_id: %s, backend: %s", train_id, condition_context.backend)
        return self.condition_mgr.get_all_collections(condition_context)

    def set_recommended_watch_points(self, set_recommended, train_id):
        """Set recommended watch points."""
        if not isinstance(set_recommended, bool):
            log.error("Bool param should be given for set_recommended")
            raise DebuggerParamValueError("Bool param should be given.")
        metadata_stream = self.cache_store.get_stream_handler(Streams.METADATA)
        if metadata_stream.recommendation_confirmed:
            log.error("User has confirmed setting recommended watchpoints")
            raise DebuggerSetRecommendWatchpointsError()
        metadata_stream.recommendation_confirmed = True
        condition_context = ConditionContext(metadata_stream.backend, metadata_stream.step)
        log.debug("Train_id: %s, backend: %s", train_id, condition_context.backend)
        res = metadata_stream.get(['state', 'enable_recheck'])
        if set_recommended:
            res['id'] = self._add_recommended_watchpoints(condition_context)
        return res

    def _add_recommended_watchpoints(self, condition_context):
        """Add predefined watchpoints."""
        log.debug("Add predefined watchpoints.")
        graph_stream = self.cache_store.get_stream_handler(Streams.GRAPH)
        watchpoints = recommend_watchpoints(self.condition_mgr, graph_stream, condition_context)
        watch_point_stream_handler = self.cache_store.get_stream_handler(Streams.WATCHPOINT)
        watch_points_ids = []
        for watchpoint in watchpoints:
            watch_points_id = watch_point_stream_handler.create_watchpoint(
                watch_condition=watchpoint.get_watch_condition_dict(),
                watch_nodes=watchpoint.watch_nodes,
                name=watchpoint.name,
                condition_mgr=self.condition_mgr
            )
            watch_points_ids.append(watch_points_id)
        return watch_points_ids

    def start(self):
        """Start server."""
        grpc_port = settings.DEBUGGER_PORT if hasattr(settings, 'DEBUGGER_PORT') else 50051
        host = settings.HOST if hasattr(settings, 'HOST') else '[::]'
        hostname = "{}:{}".format(host, grpc_port)
        # initialize a grpc server
        grpc_server_manager = grpc.server(futures.ThreadPoolExecutor(max_workers=10))
        grpc_server_base.add_EventListenerServicer_to_server(self.grpc_server, grpc_server_manager)
        grpc_server_manager.add_insecure_port(hostname)
        grpc_server_manager.start()
        my_server_thread = Thread(target=grpc_server_manager.wait_for_termination)
        # start grpc server
        my_server_thread.start()
        self.back_server = my_server_thread
        self.grpc_server_manager = grpc_server_manager
        # register stop server handler
        signal.signal(signal.SIGINT, self._stop_handler)
        log.info("Start grpc server %s", hostname)
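
    # Lifecycle sketch (assumption: the MindInsight backend drives these calls):
    #
    #     server = DebuggerServer()
    #     server.start()   # non-blocking; gRPC serves from a background thread
    #     ...
    #     server.stop()    # also reached via SIGINT through _stop_handler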

    def _stop_handler(self, signum, frame):
        """Handle the stop signal and stop the debugger server."""
        self.stop()
        log.debug("Deal with stop signal: %s, %s", signum, frame)

    def stop(self):
        """Stop debugger server."""
        log.info("Send terminate info to client.")
        self.control({'mode': 'terminate'})
        self.grpc_server_manager.stop(grace=None)
        self.back_server.join()
        log.info("Stop debugger server.")

    def poll_data(self, pos):
        """
        Get the pos-th data from DebuggerCache.

        Args:
            pos (str): The position of data.

        Returns:
            dict, the data to be updated.
        """
        if not isinstance(pos, str):
            log.error("Pos should be string. Received: %s", pos)
            raise DebuggerParamValueError("Pos should be string.")
        reply = self.cache_store.get_data(pos)
        return reply

    def search(self, filter_condition):
        """
        Search for single node in graph.

        Args:
            filter_condition (dict): Filter condition.

                - name (str): The name pattern.
                - graph_name (str): The graph name.
                - watch_point_id (int): The id of watchpoint. Default: 0.
                - node_category (str): The node_category. Default: None.

        Returns:
            dict, the searched nodes.
        """
        log.info("receive search request with filter_condition: %s", filter_condition)
        # validate watchpoint id
        watch_point_id = filter_condition.pop('watch_point_id', 0)
        watchpoint_stream = self.cache_store.get_stream_handler(Streams.WATCHPOINT)
        watchpoint_stream.validate_watchpoint_id(watch_point_id)
        # validate and update graph name
        graph_stream = self.cache_store.get_stream_handler(Streams.GRAPH)
        graph_name = graph_stream.validate_graph_name(filter_condition.get('graph_name'))
        filter_condition['graph_name'] = graph_name
        # get searched graph
        graph = graph_stream.search_nodes(filter_condition)
        # add watched label to graph
        watchpoint_stream.set_watch_nodes(graph, graph_stream, watch_point_id, graph_name)
        return graph
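
    # e.g. search({'name': 'conv', 'node_category': 'weight'}) would return the
    # matching nodes with their watched labels set (values here are hypothetical).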

    def tensor_comparisons(self, name, shape, detail='data', tolerance='0'):
        """
        Get tensor comparisons data for given name, detail, shape and tolerance.

        Args:
            name (str): The name of tensor for ui.
            detail (str): Specify which data to query. Current available value is 'data' which means
                concrete tensor data. Histogram or unique count can be supported in the future.
            shape (str): Specify concrete dimensions of shape.
            tolerance (str): Specify tolerance of difference between current step tensor and previous
                step tensor. Default value is 0.

        Raises:
            DebuggerParamValueError, If node type is not parameter or the value of detail is not supported.
            DebuggerCompareTensorError, If MindSpore is not in waiting state.

        Returns:
            dict, the retrieved data.
        """
        if self.cache_store.get_stream_handler(
                Streams.METADATA).state != ServerStatus.WAITING.value:
            log.error("Failed to compare tensors as MindSpore is not in waiting state.")
            raise DebuggerCompareTensorError(
                "Failed to compare tensors as MindSpore is not in waiting state."
            )
        self.validate_tensor_param(name, detail)
        # Limit to query max two dimensions for tensor in table view.
        parsed_shape = TensorUtils.parse_shape(shape, limit=MAX_DIMENSIONS_FOR_TENSOR)
        node_type, tensor_name = self._get_tensor_name_and_type_by_ui_name(name)
        tolerance = to_float(tolerance, 'tolerance')
        tensor_stream = self.cache_store.get_stream_handler(Streams.TENSOR)
        if node_type == NodeTypeEnum.PARAMETER.value:
            reply = tensor_stream.get_tensors_diff(tensor_name, parsed_shape, tolerance)
        else:
            raise DebuggerParamValueError(
                "The node type must be parameter, but got {}.".format(node_type))
        return reply

    def retrieve(self, mode, filter_condition=None):
        """
        Retrieve data according to mode and params.

        Args:
            mode (str): The type of info message.
            filter_condition (dict): The filter condition.

        Returns:
            dict, the retrieved data.
        """
        log.info("receive retrieve request for mode: %s, filter_condition: %s", mode,
                 filter_condition)
        mode_mapping = {
            'all': self._retrieve_all,
            'node': self._retrieve_node,
            'watchpoint': self._retrieve_watchpoint,
        }
        # validate param <mode>
        if mode not in mode_mapping.keys():
            log.error("Invalid param <mode>. <mode> should be in ['all', 'node', "
                      "'watchpoint'], but got %s.", mode)
            raise DebuggerParamValueError("Invalid mode.")
        # validate backend status
        metadata_stream = self.cache_store.get_stream_handler(Streams.METADATA)
        if metadata_stream.state == ServerStatus.PENDING.value:
            log.info("The backend is in pending status.")
            return metadata_stream.get()
        filter_condition = {} if filter_condition is None else filter_condition
        reply = mode_mapping[mode](filter_condition)
        return reply
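
    # e.g. retrieve('watchpoint', {}) lists all watchpoints, while
    # retrieve('node', {'name': node_name}) returns the graph around that node.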

    def _retrieve_all(self, filter_condition=None):
        """Retrieve metadata, root graph and watchpoint list."""
        if filter_condition:
            log.error("No filter condition required for retrieve all request.")
            raise DebuggerParamTypeError("filter_condition should be empty.")
        self.cache_store.clean_data()
        log.info("Clean data queue cache when retrieve all request.")
        result = {}
        for stream in [Streams.METADATA, Streams.GRAPH]:
            sub_res = self.cache_store.get_stream_handler(stream).get()
            result.update(sub_res)
        sub_res = self._hide_parameters_for_ui()
        result.update(sub_res)
        return result

    def _retrieve_node(self, filter_condition):
        """
        Retrieve node info.

        Args:
            filter_condition (dict): Filter condition.

                - name (str): The name of single node.
                - graph_name (str): The relative graph_name of the node.
                - single_node (bool): If False, return the sub-layer of single node. If True, return
                  the node list from root node to single node.
                - watch_point_id (int): The id of watchpoint.

        Returns:
            dict, reply with graph.
        """
        log.debug("Retrieve node %s.", filter_condition)
        # validate node name
        node_name = filter_condition.get('name')
        graph_stream = self.cache_store.get_stream_handler(Streams.GRAPH)
        graph_name = graph_stream.validate_graph_name(filter_condition.get('graph_name'))
        if node_name:
            # validate node name
            graph_stream.get_node_type(node_name, graph_name)
        filter_condition['single_node'] = bool(filter_condition.get('single_node'))
        filter_condition['graph_name'] = graph_name
        reply = self._get_nodes_info(filter_condition)
        return reply

    def _get_nodes_info(self, filter_condition):
        """
        Get nodes info.

        Args:
            filter_condition (dict): The filter condition.

                - name (str): The node name.
                - graph_name (str): The relative graph_name of the node.
                - single_node (bool): If False, return the sub-layer of single node. If True, return
                  the node list from root node to single node.
                - watch_point_id (int): The id of watchpoint.

        Returns:
            dict, reply with graph.
        """
        # validate watch_point_id
        watch_point_id = filter_condition.get('watch_point_id', 0)
        watchpoint_stream = self.cache_store.get_stream_handler(Streams.WATCHPOINT)
        watchpoint_stream.validate_watchpoint_id(watch_point_id)
        # get graph
        graph_stream = self.cache_store.get_stream_handler(Streams.GRAPH)
        reply = graph_stream.get(filter_condition)
        graph = reply.get('graph')
        # add watched label to graph
        watchpoint_stream.set_watch_nodes(graph, graph_stream, watch_point_id, filter_condition.get('graph_name'))
        return reply

    def retrieve_tensor_history(self, node_name, graph_name=None):
        """
        Retrieve tensor history for leaf node.

        Args:
            node_name (str): The name of leaf node.
            graph_name (str): The graph name. Default: None.

        Returns:
            dict, the tensor history and metadata.
        """
        log.info("Retrieve tensor history for node: %s.", node_name)
        metadata_stream = self.cache_store.get_stream_handler(Streams.METADATA)
        if metadata_stream.state == ServerStatus.PENDING.value:
            log.info("The backend is in pending status.")
            return metadata_stream.get(['state', 'step'])
        res = self._get_tensor_history(node_name, graph_name)
        return res

    def _get_tensor_history(self, node_name, graph_name=None):
        """
        Get tensor history for single node.

        Args:
            node_name (str): The name of leaf node.
            graph_name (str): The graph name. Default: None.

        Returns:
            dict, the tensor history and metadata.
        """
        # get basic tensor history
        graph_stream = self.cache_store.get_stream_handler(Streams.GRAPH)
        tensor_history = graph_stream.get_tensor_history(node_name, graph_name)
        # add tensor value for tensor history
        self._add_tensor_value_for_tensor_history(tensor_history, node_name, graph_name)
        # add hit label for tensor history
        watchpoint_hit_stream = self.cache_store.get_stream_handler(Streams.WATCHPOINT_HIT)
        watchpoint_hit_stream.update_tensor_history(tensor_history)
        # add metadata
        metadata = self.cache_store.get_stream_handler(Streams.METADATA).get(['step'])
        tensor_history.update(metadata)
        return tensor_history

    def _add_tensor_value_for_tensor_history(self, tensor_history, node_name, graph_name):
        """
        Add tensor values for tensor history and send ViewCMD if any tensor value is missing.

        Args:
            tensor_history (list[dict]): A list of tensor info, including name and type.
            node_name (str): The UI node name.
            graph_name (str): The graph name. Default: None.
        """
        tensor_stream = self.cache_store.get_stream_handler(Streams.TENSOR)
        missed_tensors = tensor_stream.update_tensor_history(tensor_history)
        if missed_tensors:
            view_cmd = create_view_event_from_tensor_basic_info(missed_tensors)
            self.cache_store.put_command({'view_cmd': view_cmd, 'node_name': node_name, 'graph_name': graph_name})
            log.debug("Send view cmd.")

    def retrieve_tensor_value(self, name, detail, shape, graph_name=None, prev=False):
        """Retrieve the tensor value."""
        log.info("Retrieve tensor value: name: %s, detail: %s, shape: %s", name, detail, shape)
        self.validate_tensor_param(name, detail)
        # Limit to query max two dimensions for tensor in table view.
        parsed_shape = TensorUtils.parse_shape(shape, limit=MAX_DIMENSIONS_FOR_TENSOR)
        node_type, tensor_name = self._get_tensor_name_and_type_by_ui_name(name, graph_name)
        reply = self.cache_store.get_stream_handler(Streams.TENSOR).get(
            {'name': tensor_name,
             'node_type': node_type,
             'shape': parsed_shape,
             'prev': prev}
        )
        reply['tensor_value']['name'] = name
        return reply

    def _get_tensor_name_and_type_by_ui_name(self, name, graph_name=None):
        """
        Get inner tensor name and type by UI name.

        Args:
            name (str): Node name shown in UI.
            graph_name (Union[str, None]): The graph name. Default: None.

        Returns:
            str, node type of the tensor.
            str, inner name of the tensor.
        """
        node_name, slot = name.rsplit(':', 1)
        graph_stream = self.cache_store.get_stream_handler(Streams.GRAPH)
        graph_name = graph_name if graph_name else graph_stream.get_graph_id_by_name(node_name)
        node_type = graph_stream.get_node_type(node_name, graph_name)
        full_name = graph_stream.get_full_name(node_name, graph_name)
        tensor_name = full_name + ':' + slot
        return node_type, tensor_name
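
    # UI tensor names take the form '<node_name>:<slot>'. For a (hypothetical)
    # 'conv1.weight:0', the inner name becomes the node's full graph name plus
    # the same slot, e.g. 'Default/conv1.weight:0'.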

    @staticmethod
    def validate_tensor_param(name, detail):
        """Validate params for retrieve tensor request."""
        # validate name
        if not isinstance(name, str) or ':' not in name:
            log.error("Invalid tensor name. Received: %s", name)
            raise DebuggerParamValueError("Invalid tensor name.")
        # validate detail
        if detail != 'data':
            log.error("Invalid detail value. Received: %s", detail)
            raise DebuggerParamValueError("Invalid detail value.")
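
    # e.g. validate_tensor_param('Default/conv1.weight:0', 'data') passes, while
    # a name without ':' or a detail other than 'data' raises
    # DebuggerParamValueError. (The tensor name here is hypothetical.)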

    def _retrieve_watchpoint(self, filter_condition):
        """
        Retrieve watchpoint.

        Args:
            filter_condition (dict): Filter condition.

                - watch_point_id (int): The id of watchpoint. If not given, return all watchpoints.
                - name (str): The name of single node.
                - single_node (bool): If False, return the sub-layer of single node. If True, return
                  the node list from root node to single node.

        Returns:
            dict, watch point list or relative graph.
        """
        watchpoint_id = filter_condition.get('watch_point_id', 0)
        if not watchpoint_id:
            reply = self._hide_parameters_for_ui()
            log.debug("Get condition of watchpoints.")
        else:
            reply = self._retrieve_node(filter_condition)
            log.debug("Get graph of %d-th watchpoint.", watchpoint_id)
        return reply

    def search_watchpoint_hits(self, group_condition):
        """
        Retrieve watchpoint hits.

        Args:
            group_condition (dict): Filter condition.

                - limit (int): The limit of each page.
                - offset (int): The offset of current page.
                - node_name (str): The retrieved node name.
                - graph_name (str): The retrieved graph name.

        Returns:
            dict, watch point list or relative graph.
        """
        if not isinstance(group_condition, dict):
            log.error("Group condition for watchpoint-hits request should be a dict")
            raise DebuggerParamTypeError("Group condition for watchpoint-hits request should be a dict")
        metadata_stream = self.cache_store.get_stream_handler(Streams.METADATA)
        if metadata_stream.state == ServerStatus.PENDING.value:
            log.info("The backend is in pending status.")
            return metadata_stream.get()
        reply = self.cache_store.get_stream_handler(Streams.WATCHPOINT_HIT).group_by(group_condition)
        reply['outdated'] = self.cache_store.get_stream_handler(Streams.WATCHPOINT).is_recheckable()
        return reply

    def create_watchpoint(self, params):
        """
        Create watchpoint.

        Args:
            params (dict): Params for create watchpoint.

                - watch_condition (dict): The watch condition. The format is like:

                    {
                        "id": "tensor_too_large",
                        "params": [
                            {
                                "name": "abs_mean_gt",
                                "value": 1.1
                            }
                        ]
                    }

                  - id (str): Id of condition.
                  - params (list[dict]): The list of param for this condition.
                - watch_nodes (list[str]): The list of node names.
                - watch_point_id (int): The id of watchpoint.
                - search_pattern (dict): The search pattern.
                - graph_name (str): The relative graph_name of the watched node.

        Returns:
            dict, the id of new watchpoint and metadata info.
        """
        watchpoint_opt = WatchpointOperator(self.cache_store, self.condition_mgr)
        return watchpoint_opt.create_watchpoint(params)

    def update_watchpoint(self, params):
        """
        Update watchpoint.

        Args:
            params (dict): Params for update watchpoint.

                - watch_point_id (int): The id of watchpoint.
                - watch_nodes (list[str]): The list of node names.
                - mode (int): The update operator on nodes. 0 for remove nodes from watch nodes.
                  1 for add nodes to watch nodes.
                - search_pattern (dict): The search pattern.
                - graph_name (str): The relative graph_name of the watched node.

        Returns:
            dict, the metadata info.
        """
        watchpoint_opt = WatchpointOperator(self.cache_store, self.condition_mgr)
        return watchpoint_opt.update_watchpoint(params)

    def delete_watchpoint(self, watch_point_id=None):
        """
        Delete watchpoint.

        Args:
            watch_point_id (Union[None, int]): The id of watchpoint.
                If None, delete all watchpoints. Default: None.

        Returns:
            dict, the metadata info.
        """
        watchpoint_opt = WatchpointOperator(self.cache_store, self.condition_mgr)
        return watchpoint_opt.delete_watchpoint(watch_point_id=watch_point_id)

    @try_except
    def control(self, params=None):
        """
        Control the training process.

        Args:
            params (dict): The control params.

                - mode (str): Acceptable control command, including `continue`,
                  `pause` and `terminate`.
                - level (str): The control granularity, `node` level or `step` level.
                  Default: `step`.
                - steps (int): Specify the steps that training should run.
                  Used when `level` is `step`.
                - name (str): Specify the name of the node. Used when `level` is `node`.
                - graph_name (str): The graph name.

        Returns:
            dict, the response.
        """
        log.info("Receive control request: %s.", params)
        mode = params.pop('mode', None) if params else None
        training_controller = TrainingControlOperator(self.cache_store)
        training_controller.validate_mode(mode)
        return training_controller.control(mode, params)
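
    # e.g. control({'mode': 'continue', 'level': 'step', 'steps': 1}) runs one
    # training step; stop() above sends control({'mode': 'terminate'}).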

    def retrieve_node_by_bfs(self, node_name, graph_name=None, ascend=False):
        """
        Get the graph of the next node according to node_name.

        Args:
            node_name (str): The name of current chosen leaf node.
            graph_name (str): The graph name.
            ascend (bool): If True, traverse the input nodes;
                if False, traverse the output nodes. Default: False.

        Returns:
            dict, the next node information.
        """
        log.info("Retrieve node <%s> by bfs, `ascend` is %s.",
                 node_name, ascend)
        reply = {}
        graph_stream = self.cache_store.get_stream_handler(Streams.GRAPH)
        graph_name = graph_stream.validate_graph_name(graph_name)
        next_node_name = graph_stream.get_node_by_bfs_order(node_name, ascend)
        # no next node
        if next_node_name is None:
            return reply
        # add graph and tensor history for next node
        filter_condition = {
            'name': next_node_name,
            'graph_name': graph_name,
            'single_node': True
        }
        search_graph = self._get_nodes_info(filter_condition)
        reply = {'name': next_node_name}
        reply.update(search_graph)
        return reply

    @try_except
    def recheck(self):
        """
        Recheck all watchpoints.

        Returns:
            dict, metadata info.
        """
        return TrainingControlOperator(self.cache_store).recheck()

    def retrieve_tensor_graph(self, tensor_name, graph_name):
        """
        Retrieve tensor graph.

        Args:
            tensor_name (str): The tensor name from UI.
            graph_name (str): The graph name.

        Returns:
            dict, tensor graph object.
        """
        if self.cache_store.get_stream_handler(Streams.METADATA).state != ServerStatus.WAITING.value:
            log.error("Failed to get tensor graph as MindSpore is not in waiting state.")
            raise DebuggerTensorGraphError
        log.info("Retrieve tensor graph for %s from %s", tensor_name, graph_name)
        tensor_graph_ops = TensorDetailInfo(self.cache_store).get_tensor_graph(tensor_name, graph_name)
        return tensor_graph_ops

    def retrieve_tensor_hits(self, tensor_name, graph_name):
        """
        Retrieve tensor hit information.

        Args:
            tensor_name (str): The tensor name from UI.
            graph_name (str): The graph name.

        Returns:
            dict, tensor hit info.
        """
        if self.cache_store.get_stream_handler(Streams.METADATA).state != ServerStatus.WAITING.value:
            log.error("Failed to get tensor hits as MindSpore is not in waiting state.")
            raise DebuggerTensorHitError
        log.info("Retrieve tensor hits for %s from %s", tensor_name, graph_name)
        watch_points = TensorDetailInfo(self.cache_store).get_tensor_watch_points(tensor_name, graph_name)
        return {'watch_points': watch_points}

    def _hide_parameters_for_ui(self):
        """
        Hide some parameters on UI.

        Returns:
            dict, watch point list.
        """
        reply = self.cache_store.get_stream_handler(Streams.WATCHPOINT).get()
        watch_points = reply.get('watch_points')
        for i, watch_point in enumerate(watch_points):
            watch_condition = watch_point.get('watch_condition')
            parameters = watch_condition.get('params')
            watch_condition_id = watch_condition.get('id')
            mgr_condition = self.condition_mgr.get_condition(watch_condition_id)
            ui_watch_condition = []
            for param in parameters:
                parameter_definition = mgr_condition.get_parameter_definition(param['name'])
                if not parameter_definition.visible_on_ui:
                    continue
                ui_watch_condition.append(param)
            reply['watch_points'][i]['watch_condition']['params'] = ui_watch_condition
        return reply