You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

watchpoint_handler.py 31 kB

5 years ago
5 years ago
5 years ago
12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817
  1. # Copyright 2020 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ============================================================================
  15. """Define the watchpoint stream handler."""
  16. from mindinsight.debugger.conditionmgr.condition import ValueTypeEnum
  17. from mindinsight.debugger.conditionmgr.condition import ParamTypeEnum
  18. from mindinsight.debugger.common.exceptions.exceptions import DebuggerParamValueError, \
  19. DebuggerParamTypeError
  20. from mindinsight.debugger.common.log import LOGGER as log
  21. from mindinsight.debugger.common.utils import is_scope_type
  22. from mindinsight.debugger.proto.debug_grpc_pb2 import SetCMD
  23. from mindinsight.debugger.stream_cache.watchpoint import Watchpoint, WatchpointHit, \
  24. WatchNodeTree
  25. from mindinsight.debugger.stream_handler.base_handler import StreamHandlerBase
  26. RANGE_START = 'range_start_inclusive'
  27. RANGE_END = 'range_end_inclusive'
  28. class WatchpointHandler(StreamHandlerBase):
  29. """Watchpoint Handler."""
  30. def __init__(self):
  31. self._watchpoints = {}
  32. # list of ids of new created watchpoints
  33. self._created_watchpoints = []
  34. # list of SetCMD of watchpoints to be deleted
  35. self._deleted_watchpoints = []
  36. # dict of <id, Watchpoint> of watchpoints to be updated
  37. self._updated_watchpoints = {}
  38. # the collection of watched node full names, which have been sent to MindSpore
  39. self._latest_id = 0
  40. self._cache_set_cmd = {}
  41. # whether the watchpoint list has been changed since last step
  42. self._outdated = False
  43. def set_outdated(self):
  44. """"Set outdated as True."""
  45. self._outdated = True
  46. def put(self, value):
  47. """
  48. Put Watchpoint into watchpoint handler.
  49. Args:
  50. value (Watchpoint): The name of nodes that have been chosen.
  51. """
  52. new_id = value.watchpoint_id
  53. self._watchpoints[new_id] = value
  54. self._created_watchpoints.append(new_id)
  55. self._updated_watchpoints[new_id] = value
  56. self._latest_id = new_id
  57. log.debug("Put watchpoint %d into cache.", new_id)
  58. def sync_set_cmd(self, set_cmds):
  59. """Clean temp watchpoints."""
  60. self._outdated = False
  61. self._created_watchpoints = []
  62. self._deleted_watchpoints = []
  63. self._updated_watchpoints = {}
  64. for set_cmd in set_cmds:
  65. self._cache_set_cmd[set_cmd.id] = set_cmd
  66. def clean_cache_set_cmd(self, set_cmd):
  67. """Clean cache set command."""
  68. self._cache_set_cmd.pop(set_cmd.id, None)
  69. def get_watchpoint_by_id(self, watchpoint_id):
  70. """Get watchpoint by watchpoint id."""
  71. res = self.get(watchpoint_id)
  72. watchpoint = res.get('watch_points')[0]
  73. return watchpoint
  74. def get(self, filter_condition=None):
  75. """
  76. Get the watchpoints.
  77. Args:
  78. filter_condition (Union[None, int]): The filter conditions. Get watchpoint by
  79. id. If None, return all watchpoint. Default: None.
  80. Returns:
  81. dict, the watchpoint list.
  82. """
  83. reply = []
  84. if not filter_condition:
  85. # get watch condition list
  86. for _, watchpoint in self._watchpoints.items():
  87. watchpoint_info = watchpoint.get_watch_condition_info()
  88. reply.append(watchpoint_info)
  89. else:
  90. self.validate_watchpoint_id(filter_condition)
  91. reply = [self._watchpoints.get(filter_condition)]
  92. log.debug("get the watch points with filter_condition:%s", filter_condition)
  93. return {'watch_points': reply}
  94. def get_pending_commands(self, multi_card_graph_stream):
  95. """
  96. Get all watchpoint in SetCMD proto format.
  97. Args:
  98. multi_card_graph_stream (MultiCardGraphHandler): Multi card graph handler.
  99. Returns:
  100. list[SetCMD], updated watchpoint to be sent to MindSpore.
  101. """
  102. newly_set_cmds = []
  103. for _, watchpoint in self._updated_watchpoints.items():
  104. # construct set command with leaf nodes
  105. watch_nodes_for_devices = watchpoint.get_watch_nodes()
  106. leaf_watch_nodes_for_devices = {}
  107. for rank_id, watch_nodes in watch_nodes_for_devices.items():
  108. graph_stream = multi_card_graph_stream.get_graph_handler_by_rank_id(rank_id)
  109. leaf_watch_nodes = self._expand_to_leaf_nodes(graph_stream, watch_nodes)
  110. leaf_watch_nodes_for_devices[rank_id] = leaf_watch_nodes
  111. newly_set_cmds.append(watchpoint.get_pending_cmd(leaf_watch_nodes_for_devices))
  112. newly_set_cmds.extend(self._deleted_watchpoints)
  113. self.sync_set_cmd(newly_set_cmds)
  114. return list(self._cache_set_cmd.values())
  115. @staticmethod
  116. def _expand_to_leaf_nodes(graph_stream, watch_nodes):
  117. """
  118. Get all leaf node basic info according to watch nodes.
  119. Args:
  120. graph_stream (GraphHandler): Graph handler.
  121. watch_nodes (list[NodeBasicInfo]): The list of watch node basic infos.
  122. Returns:
  123. list[NodeBasicInfo], expanded leaf basic node infos.
  124. """
  125. leaf_watch_nodes = []
  126. for node in watch_nodes:
  127. if is_scope_type(node.type):
  128. pure_node_name = ''
  129. if len(node.name.split('/')) > 1:
  130. graph_name, pure_node_name = node.name.split('/', 1)
  131. else:
  132. graph_name = node.name
  133. search_node_infos = graph_stream.get_node_basic_info_by_scope(pure_node_name, graph_name=graph_name)
  134. leaf_watch_nodes.extend(search_node_infos)
  135. else:
  136. leaf_watch_nodes.append(node)
  137. return leaf_watch_nodes
  138. def is_recheckable(self):
  139. """
  140. Check if current status is able to recheck.
  141. Returns:
  142. bool, if enable to recheck.
  143. """
  144. return self._outdated
  145. def set_watch_nodes(self, graph, graph_stream, watch_point_id, graph_name=None, rank_id=0):
  146. """
  147. set watch nodes for graph.
  148. Args:
  149. graph (dict): The graph with list of nodes.
  150. graph_stream (GraphHandler): The graph handler.
  151. watch_point_id (int): The id of watchpoint.
  152. graph_name (str): The graph name.
  153. rank_id (int): The rank id.
  154. """
  155. if not (watch_point_id and graph):
  156. return
  157. log.debug("add watch flags")
  158. watchpoint = self._watchpoints.get(watch_point_id)
  159. self._set_watch_status_recursively(graph, graph_stream, watchpoint, graph_name, rank_id)
  160. def _set_watch_status_recursively(self, graph, graph_stream, watchpoint, graph_name=None, rank_id=0):
  161. """Set watch status to graph."""
  162. if graph.get('children'):
  163. self._set_watch_status_recursively(
  164. graph.get('children'), graph_stream, watchpoint, graph_name, rank_id=0)
  165. if graph.get('nodes'):
  166. _ = self._set_watch_state_for_nodes(graph['nodes'], graph_stream, watchpoint, graph_name, rank_id)
  167. def _set_watch_state_for_nodes(self, nodes, graph_stream, watchpoint, graph_name, rank_id=0):
  168. """
  169. Set watch state for nodes.
  170. Args:
  171. nodes (list[Node]): List of node info.
  172. Returns:
  173. int, the number of all watched nodes.
  174. """
  175. all_watched_num = 0
  176. valid_node_num = len(nodes)
  177. # initialize the state of current node.
  178. state = WatchNodeTree.NOT_WATCH
  179. for node in nodes:
  180. node_name = node.get('name')
  181. # search result could have `nodes` in nodes object
  182. if node.get('nodes'):
  183. flag = self._set_watch_state_for_nodes(node.get('nodes'), graph_stream, watchpoint, graph_name, rank_id)
  184. else:
  185. full_name = graph_stream.get_full_name(node_name, graph_name)
  186. new_node_name = node_name if graph_name is None else '/'.join([graph_name, node_name])
  187. flag = watchpoint.get_node_status(new_node_name, node.get('type'), full_name, rank_id)
  188. node['watched'] = flag
  189. if flag == WatchNodeTree.NOT_WATCH:
  190. continue
  191. state = WatchNodeTree.PARTIAL_WATCH
  192. if flag == WatchNodeTree.INVALID:
  193. valid_node_num -= 1
  194. elif flag == WatchNodeTree.TOTAL_WATCH:
  195. all_watched_num += 1
  196. # update the watch status of current node
  197. if not valid_node_num:
  198. state = WatchNodeTree.INVALID
  199. elif all_watched_num == valid_node_num:
  200. state = WatchNodeTree.TOTAL_WATCH
  201. return state
  202. def create_watchpoint(self, condition_mgr, watch_condition, watch_nodes=None, watch_point_id=None, name=None,
  203. device_amount=8):
  204. """
  205. Create watchpoint.
  206. Args:
  207. condition_mgr (ConditionMgr): Instance of ConditionMgr.
  208. watch_condition (dict): The watch condition.
  209. "condition": {
  210. id: "tensor_too_large",
  211. "params": [
  212. {
  213. "name": "abs_mean_gt",
  214. "value": 1.1
  215. }
  216. ]
  217. }
  218. - id (str): Id of condition.
  219. - param (list[dict]): The list of param for this condition.
  220. watch_nodes (dict[list[NodeBasicInfo]]): The list of node basic info.
  221. watch_point_id (int): The id of watchpoint.
  222. name (str): The name of watchpoint.
  223. device_amount (int): The amount of devices.
  224. Returns:
  225. int, the new id of watchpoint.
  226. """
  227. validate_watch_condition(condition_mgr, watch_condition)
  228. watch_condition = set_default_param(condition_mgr, watch_condition)
  229. new_id = self._latest_id + 1
  230. watchpoint = Watchpoint(new_id, watch_condition, name)
  231. if watch_nodes:
  232. for rank_id, watch_nodes_for_device in watch_nodes.items():
  233. validate_rank_id(rank_id, device_amount)
  234. watchpoint.add_nodes(watch_nodes_for_device, rank_id)
  235. elif watch_point_id:
  236. self.validate_watchpoint_id(watch_point_id)
  237. watchpoint.copy_nodes_from(self._watchpoints.get(watch_point_id))
  238. self.put(watchpoint)
  239. self._outdated = True
  240. return new_id
  241. def update_watchpoint(self, watch_point_id, watch_nodes, watched=False, rank_id=0):
  242. """
  243. Update watchpoint.
  244. Args:
  245. watch_point_id (int): The id of watchpoint.
  246. watch_nodes (list[NodeBasicInfo]): The list of node basic info.
  247. watched (bool): The update operator on nodes. If False, remove nodes from watch nodes.
  248. If True, add nodes to watch nodes. Default: False.
  249. rank_id (int): The rank id.
  250. """
  251. self.validate_watchpoint_id(watch_point_id)
  252. watchpoint = self._watchpoints.get(watch_point_id)
  253. if watched:
  254. watchpoint.add_nodes(watch_nodes, rank_id)
  255. else:
  256. watchpoint.remove_nodes(watch_nodes, rank_id)
  257. self._updated_watchpoints[watch_point_id] = watchpoint
  258. self._outdated = True
  259. log.debug("Update watchpoint %d in cache.", watch_point_id)
  260. def delete_watchpoint(self, watch_point_id=None):
  261. """
  262. Delete watchpoint.
  263. Args:
  264. watch_point_id (Union[None, int]): The id of watchpoint.
  265. If None, delete all watchpoints. Default: None.
  266. """
  267. if watch_point_id is None:
  268. watch_point_ids = [sub_id for sub_id, _ in self._watchpoints.items()]
  269. else:
  270. self.validate_watchpoint_id(watch_point_id)
  271. watch_point_ids = [watch_point_id]
  272. for single_id in watch_point_ids:
  273. self._delete_single_watchpoint(single_id)
  274. self._outdated = True
  275. def _delete_single_watchpoint(self, watch_point_id):
  276. """
  277. Delete single watchpoint.
  278. Args:
  279. watch_point_id (int): The id of watchpoint.
  280. """
  281. self._watchpoints.pop(watch_point_id)
  282. # if the watchpoint has not been created by MindSpore, clean the relative cache directly
  283. if watch_point_id in self._created_watchpoints:
  284. self._created_watchpoints.remove(watch_point_id)
  285. self._updated_watchpoints.pop(watch_point_id)
  286. log.debug("Cancel create watchpoint %d in cache.", watch_point_id)
  287. return
  288. set_cmd = SetCMD()
  289. set_cmd.id = watch_point_id
  290. set_cmd.delete = True
  291. self._deleted_watchpoints.append(set_cmd)
  292. log.debug("Delete watchpoint %d in cache.", watch_point_id)
  293. def validate_watchpoint_id(self, watch_point_id):
  294. """Validate watchpoint id."""
  295. if not isinstance(watch_point_id, int):
  296. log.error("Invalid watchpoint id %s. The watch point id should be int.", watch_point_id)
  297. raise DebuggerParamTypeError("Watchpoint id should be int type.")
  298. if watch_point_id and watch_point_id not in self._watchpoints:
  299. log.error("Invalid watchpoint id: %d.", watch_point_id)
  300. raise DebuggerParamValueError("Invalid watchpoint id: {}".format(watch_point_id))
  301. class MultiCardWatchpointHitHandler:
  302. """Multi-card Watchpoint-hit Handler."""
  303. def __init__(self):
  304. self.watchpoint_hit_handlers = {0: WatchpointHitHandler()}
  305. def get_hit_handler_by_rank_id(self, rank_id=0):
  306. """Get handler by rank id."""
  307. if rank_id in self.watchpoint_hit_handlers:
  308. return self.watchpoint_hit_handlers.get(rank_id)
  309. log.error("There is no rank id %d.", rank_id)
  310. raise ValueError
  311. def put(self, value):
  312. """Put watchpoint hit into cache."""
  313. for rank_id, tensor_hit_values in value.items():
  314. if rank_id not in self.watchpoint_hit_handlers:
  315. self.watchpoint_hit_handlers[rank_id] = WatchpointHitHandler()
  316. cur_hit_handler = self.watchpoint_hit_handlers[rank_id]
  317. for tensor_hit_value in tensor_hit_values:
  318. cur_hit_handler.put(tensor_hit_value)
  319. def get(self, filter_condition=None, rank_id=0):
  320. """Get the graph of specific node for specific device."""
  321. if rank_id in self.watchpoint_hit_handlers:
  322. return self.watchpoint_hit_handlers.get(rank_id).get(filter_condition)
  323. log.error("There is no rank id %d.", rank_id)
  324. raise ValueError
  325. def update_tensor_history(self, tensor_history, rank_id):
  326. """
  327. Add hit flag to tensor history.
  328. Args:
  329. tensor_history (dict): The tensor history.
  330. rank_id (int): The rank id.
  331. """
  332. if rank_id in self.watchpoint_hit_handlers:
  333. self.watchpoint_hit_handlers[rank_id].update_tensor_history(tensor_history)
  334. else:
  335. for tensor_info in tensor_history.get('tensor_history'):
  336. tensor_info['is_hit'] = False
  337. def check_rank_id(self, rank_id):
  338. """check if has the rank id."""
  339. return rank_id in self.watchpoint_hit_handlers
  340. def clean(self):
  341. """Clean cache."""
  342. self.__init__()
  343. class WatchpointHitHandler(StreamHandlerBase):
  344. """Watchpoint hit handler."""
  345. def __init__(self):
  346. # dict of <ui node_name, dict of <slot, WatchpointHit>>,
  347. self._ordered_hits = []
  348. self._multi_graph_hits = {}
  349. @property
  350. def empty(self):
  351. """Whether the watchpoint hit is empty."""
  352. return not self._multi_graph_hits
  353. def put(self, value):
  354. """
  355. Put value into watchpoint hit cache. Called by grpc server.
  356. Args:
  357. value (dict): The watchpoint hit info.
  358. - tensor_proto (TensorProto): The message about hit tensor.
  359. - watchpoint (Watchpoint): The Watchpoint that a node hit.
  360. - node_name (str): The UI node name.
  361. - graph_name (str): The graph name.
  362. - error_code (int): The code of errors.
  363. """
  364. watchpoint_hit = WatchpointHit(
  365. tensor_proto=value.get('tensor_proto'),
  366. watchpoint=value.get('watchpoint'),
  367. node_name=value.get('node_name'),
  368. graph_name=value.get('graph_name')
  369. )
  370. if 'error_code' in value.keys():
  371. watchpoint_hit.error_code = value.get('error_code')
  372. # get all hit watchpoints according to node name ans tensor slot
  373. watchpoint_hits = self._get_watchpoints_by_tensor_name(watchpoint_hit.graph_name, watchpoint_hit.node_name,
  374. watchpoint_hit.slot)
  375. if watchpoint_hit not in watchpoint_hits:
  376. watchpoint_hits.append(watchpoint_hit)
  377. def _get_watchpoints_by_tensor_name(self, graph_name, node_name, slot):
  378. """
  379. Get hit tensors according to ui node name and slot.
  380. Args:
  381. node_name (str): The node name.
  382. slot (str): The tensor slot.
  383. Returns:
  384. list, list of watchpoints.
  385. """
  386. index = self._multi_graph_hits.get((graph_name, node_name))
  387. if index is None:
  388. hit_node = {}
  389. self._ordered_hits.append(hit_node)
  390. index = len(self._ordered_hits) - 1
  391. self._multi_graph_hits[(graph_name, node_name)] = index
  392. hit_node = self._ordered_hits[index]
  393. hit_tensors = hit_node.get(slot)
  394. if hit_tensors is None:
  395. hit_tensors = []
  396. hit_node[slot] = hit_tensors
  397. return hit_tensors
  398. def get(self, filter_condition=None):
  399. """
  400. Get watchpoint hit list.
  401. Args:
  402. filter_condition (str): Get the watchpoint hit according to specified node name.
  403. If not given, get all watchpoint hits. Default: None.
  404. Returns:
  405. dict, the watchpoint hit list.
  406. """
  407. reply = None
  408. if filter_condition is None:
  409. log.debug("Get all watchpoint hit list.")
  410. reply = self.get_watchpoint_hits()
  411. else:
  412. log.debug("Get the watchpoint for node: <%s>.", filter_condition)
  413. index = self._multi_graph_hits.get(("", filter_condition))
  414. if index is not None:
  415. reply = self._ordered_hits[index]
  416. return reply
  417. def group_by(self, group_condition):
  418. """
  419. Return the watchpoint hits by group condition.
  420. Args:
  421. group_condition (dict): The group conditions.
  422. - limit (int): The limit number of watchpoint hits each page.
  423. - offset (int): The page offset.
  424. - node_name (str): The node name.
  425. - graph_name (str): The graph name.
  426. Returns:
  427. dict, the watchpoint hit list.
  428. """
  429. node_name = group_condition.get('node_name')
  430. # get all watchpoint hit list
  431. if node_name is None:
  432. reply = self._get_by_offset(group_condition)
  433. else:
  434. reply = self._get_by_name(group_condition)
  435. return reply
  436. def _get_by_offset(self, group_condition):
  437. """Return the list of watchpoint hits on the offset page."""
  438. limit = group_condition.get('limit')
  439. offset = group_condition.get('offset')
  440. if not isinstance(limit, int) or not isinstance(offset, int):
  441. log.error("Param limit or offset is not a integer")
  442. raise DebuggerParamValueError("Param limit or offset is not a integer")
  443. watch_point_hits = []
  444. total = len(self._ordered_hits)
  445. if limit * offset >= total and offset != 0:
  446. log.error("Param offset out of bounds")
  447. raise DebuggerParamValueError("Param offset out of bounds")
  448. if total == 0:
  449. return {}
  450. for watchpoint_hits in self._ordered_hits[(limit * offset): (limit * (offset + 1))]:
  451. self._get_tensors(watchpoint_hits, watch_point_hits)
  452. return {
  453. 'watch_point_hits': watch_point_hits,
  454. 'offset': offset,
  455. 'total': total
  456. }
  457. def _get_by_name(self, group_condition):
  458. """Return the list of watchpoint hits by the group condition."""
  459. limit = group_condition.get('limit')
  460. if not isinstance(limit, int) or limit == 0:
  461. log.error("Param limit is 0 or not a integer")
  462. raise DebuggerParamValueError("Param limit is 0 or not a integer")
  463. index = self._multi_graph_hits.get((group_condition.get('graph_name'), group_condition.get('node_name')))
  464. if index is not None:
  465. group_condition['offset'] = index//limit
  466. return self._get_by_offset(group_condition)
  467. return {}
  468. def get_watchpoint_hits(self):
  469. """Return the list of watchpoint hits."""
  470. watch_point_hits = []
  471. for watchpoint_hits in self._ordered_hits:
  472. self._get_tensors(watchpoint_hits, watch_point_hits)
  473. return {'watch_point_hits': watch_point_hits}
  474. def _get_tensors(self, watchpoint_hits, watch_point_hits):
  475. """Get the tensors info for the watchpoint_hits."""
  476. tensors = []
  477. graph_name = None
  478. node_name = None
  479. for slot, tensor_hits in watchpoint_hits.items():
  480. if graph_name is None:
  481. graph_name = tensor_hits[0].graph_name
  482. if node_name is None:
  483. node_name = tensor_hits[0].node_name
  484. tensor_info = self._get_tensor_hit_info(slot, tensor_hits)
  485. tensors.append(tensor_info)
  486. watch_point_hits.append({
  487. 'node_name': node_name,
  488. 'tensors': tensors,
  489. 'graph_name': graph_name
  490. })
  491. @staticmethod
  492. def _get_tensor_hit_info(slot, tensor_hits):
  493. """
  494. Get watchpoint hit info of specified tensor.
  495. Args:
  496. slot (str): Slot id.
  497. tensor_hits (list): A list of watchpoint hit objects that the tensor hit.
  498. Returns:
  499. dict, tensor hit info.
  500. """
  501. res = {}
  502. watch_points = []
  503. for tensor_hit in tensor_hits:
  504. error_code = tensor_hit.error_code
  505. error_list = _get_error_list(error_code)
  506. watchpoint = tensor_hit.watchpoint
  507. watchpoint['error_code'] = error_code
  508. watchpoint['error_list'] = error_list
  509. watch_points.append(watchpoint)
  510. if watch_points:
  511. watch_points.sort(key=lambda watch_point: watch_point.get('id'))
  512. res = {
  513. 'slot': slot,
  514. 'watch_points': watch_points
  515. }
  516. return res
  517. def _is_tensor_hit(self, tensor_name, graph_name):
  518. """
  519. Check if the tensor is record in hit cache.
  520. Args:
  521. tensor_name (str): The name of ui tensor name.
  522. graph_name (str): The name of ui graph name
  523. Returns:
  524. bool, if the tensor is hit.
  525. """
  526. node_name, slot = tensor_name.rsplit(':', 1)
  527. index = self._multi_graph_hits.get((graph_name, node_name))
  528. if index is not None:
  529. watchpoint_hits = self._ordered_hits[index].get(slot)
  530. return bool(watchpoint_hits)
  531. return False
  532. def update_tensor_history(self, tensor_history):
  533. """
  534. Add hit flag to tensor history.
  535. Args:
  536. tensor_history (dict): The tensor history.
  537. """
  538. if not self._multi_graph_hits:
  539. return
  540. # add hit tensor names to `tensor_names`
  541. for tensor_info in tensor_history.get('tensor_history'):
  542. tensor_name = tensor_info['name']
  543. graph_name = tensor_info['graph_name']
  544. hit_flag = self._is_tensor_hit(tensor_name, graph_name)
  545. tensor_info['is_hit'] = hit_flag
  546. def get_tensor_hit_infos(self, tensor_name, graph_name):
  547. """
  548. Get all hit information of a tensor.
  549. Args:
  550. tensor_name (str): Tensor name showed on UI.
  551. Returns:
  552. dict, tensor hit info.
  553. """
  554. tensor_hit_info = {}
  555. if self._is_tensor_hit(tensor_name, graph_name):
  556. node_name, slot = tensor_name.rsplit(':', 1)
  557. tensor_hits = self._get_watchpoints_by_tensor_name(graph_name, node_name, slot)
  558. tensor_hit_info = self._get_tensor_hit_info(slot, tensor_hits)
  559. return tensor_hit_info
  560. def validate_watch_condition(condition_mgr, watch_condition):
  561. """Validate watch condition."""
  562. if not isinstance(watch_condition, dict):
  563. log.error("<watch_condition> should be dict. %s received.", watch_condition)
  564. raise DebuggerParamTypeError("<watch_condition> should be dict.")
  565. # validate condition_id
  566. condition_id = watch_condition.get('id')
  567. if condition_id not in condition_mgr.conditions.keys():
  568. log.error("Invalid watch condition. Acceptable values are <%s>. %s received.",
  569. str(condition_mgr.conditions.keys()), condition_id)
  570. raise DebuggerParamValueError("Invalid watch condition value.")
  571. # validate param
  572. validate_watch_condition_params(condition_mgr, watch_condition)
  573. def validate_watch_condition_params(condition_mgr, watch_condition):
  574. """
  575. Validate watch condition parameters.
  576. Args:
  577. condition_mgr (ConditionMgr): Instance of ConditionMgr.
  578. watch_condition (dict): Watch condition.
  579. - id (str): Condition id. Should be in WATCHPOINT_CONDITION_MAPPING.
  580. - param (list): Condition value. Should be given for comparison condition. The value
  581. will be translated to np.float32.
  582. """
  583. condition_id = watch_condition.get('id')
  584. params = watch_condition.get('params')
  585. condition = condition_mgr.get_condition(condition_id)
  586. if condition_id in condition_mgr.get_no_param_condition():
  587. if params:
  588. log.error("No param is expected for %s condition", condition_id)
  589. raise DebuggerParamValueError("No param is expected.")
  590. return
  591. check_param_num = 0
  592. support_params = set()
  593. defined_support_params = set()
  594. range_param = {RANGE_START: None, RANGE_END: None}
  595. for param in params:
  596. if len(param) > 2:
  597. log.error("Invalid param keys for condition: %s", condition_id)
  598. raise DebuggerParamValueError("Invalid param keys.")
  599. condition_param_name = param.get("name")
  600. if condition_param_name not in condition.names:
  601. log.error("Invalid name of parameter for condition: %s, available values: %s",
  602. condition_id, condition.names)
  603. raise DebuggerParamValueError("Invalid name of parameter.")
  604. condition_param = condition.get_parameter_definition(condition_param_name)
  605. validate_param_type(condition_id, condition_param, param)
  606. if not condition_param.is_valid(param.get("value")):
  607. log.error("Param %s out of range for condition: %s", condition_param_name, condition_id)
  608. raise DebuggerParamValueError("Parameter out of range.")
  609. if condition_param.param_type == ParamTypeEnum.CHECK_PARAM.value:
  610. if condition_param.required_params:
  611. defined_support_params = set(condition_param.required_params)
  612. check_param_num += 1
  613. else:
  614. support_params.add(condition_param.name)
  615. if condition_param_name in range_param:
  616. range_param[condition_param_name] = param.get("value")
  617. if check_param_num > 1:
  618. log.error("Multiple check params for condition: %s", condition_id)
  619. raise DebuggerParamValueError("Multiple check params.")
  620. if support_params != defined_support_params:
  621. log.error("Invalid support params for condition: %s", condition_id)
  622. raise DebuggerParamValueError("Invalid support params.")
  623. if range_param.get(RANGE_START) is not None and \
  624. range_param.get(RANGE_END) is not None and range_param.get(RANGE_START) > \
  625. range_param.get(RANGE_END):
  626. log.error("Invalid support params for condition: %s", condition_id)
  627. raise DebuggerParamValueError("Invalid support params.")
  628. def validate_param_type(condition_id, condition_param, param):
  629. """
  630. Validate parameter type.
  631. Args:
  632. condition_id (str): Condition id. Should be in WATCHPOINT_CONDITION_MAPPING.
  633. condition_param (ConditionParameter): Condition Parameter object.
  634. param (dict): Condition parameter value.
  635. """
  636. if condition_param.type.name in (ValueTypeEnum.FLOAT64.name, ValueTypeEnum.INT64.name) \
  637. and not isinstance(param.get("value"), (float, int)):
  638. log.error("Number param should be given for condition: %s", condition_id)
  639. raise DebuggerParamValueError("Number param should be given.")
  640. if condition_param.type.name == ValueTypeEnum.BOOL.name \
  641. and not isinstance(param.get("value"), bool):
  642. log.error("Bool param should be given for condition: %s", condition_id)
  643. raise DebuggerParamValueError("Bool param should be given.")
  644. def set_default_param(condition_mgr, watch_condition):
  645. """
  646. Set default param.
  647. Args:
  648. condition_mgr (ConditionMgr): Instance of ConditionMgr.
  649. watch_condition (dict): The watch condition.
  650. "condition": {
  651. id: "tensor_too_large",
  652. "params": [
  653. {
  654. "name": "abs_mean_gt",
  655. "value": 1.1
  656. }
  657. ]
  658. }
  659. - id (str): Id of condition.
  660. - param (list[dict]): The list of param for this condition.
  661. Returns:
  662. dict, the new watch_condition.
  663. """
  664. condition_id = watch_condition.get('id')
  665. condition = condition_mgr.get_condition(condition_id)
  666. for param in condition.parameters:
  667. if not param.visible_on_ui and not param.support_disable:
  668. watch_condition["params"].append({
  669. "name": param.name,
  670. "value": param.default_value
  671. })
  672. watch_condition["abbr"] = condition.abbr
  673. return watch_condition
  674. def _get_error_list(error_code):
  675. """
  676. Get error list.
  677. Args:
  678. error_code (int): The code of errors.
  679. Returns:
  680. list, the error list.
  681. """
  682. all_error_list = ["nan", "inf", "no_prev_tensor"]
  683. error_list = []
  684. for i, error_str in enumerate(all_error_list):
  685. error = (error_code >> i) & 1
  686. if error == 1:
  687. error_list.append(error_str)
  688. return error_list
  689. def validate_rank_id(rank_id, device_amount):
  690. """validate rank id"""
  691. if rank_id >= device_amount:
  692. log.debug("The rank id %d over device amount.", rank_id)