You can not select more than 25 topics Topics must start with a chinese character,a letter or number, can include dashes ('-') and can be up to 35 characters long.

watchpoint_handler.py 24 kB

5 years ago
5 years ago
5 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633
  1. # Copyright 2020 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ============================================================================
  15. """Define the watchpoint stream handler."""
  16. from mindinsight.debugger.conditionmgr.condition import ValueTypeEnum
  17. from mindinsight.debugger.conditionmgr.condition import ParamTypeEnum
  18. from mindinsight.debugger.common.exceptions.exceptions import DebuggerParamValueError, \
  19. DebuggerParamTypeError
  20. from mindinsight.debugger.common.log import LOGGER as log
  21. from mindinsight.debugger.common.utils import is_scope_type
  22. from mindinsight.debugger.proto.debug_grpc_pb2 import SetCMD
  23. from mindinsight.debugger.stream_cache.watchpoint import Watchpoint, WatchpointHit, \
  24. WatchNodeTree
  25. from mindinsight.debugger.stream_handler.base_handler import StreamHandlerBase
  26. class WatchpointHandler(StreamHandlerBase):
  27. """Watchpoint Handler."""
  28. def __init__(self):
  29. self._watchpoints = {}
  30. # list of ids of new created watchpoints
  31. self._created_watchpoints = []
  32. # list of SetCMD of watchpoints to be deleted
  33. self._deleted_watchpoints = []
  34. # dict of <id, Watchpoint> of watchpoints to be updated
  35. self._updated_watchpoints = {}
  36. # the collection of watched node full names, which have been sent to MindSpore
  37. self._latest_id = 0
  38. self._cache_set_cmd = {}
  39. # whether the watchpoint list has been changed since last step
  40. self._outdated = False
  41. def put(self, value):
  42. """
  43. Put Watchpoint into watchpoint handler.
  44. Args:
  45. value (Watchpoint): The name of nodes that have been chosen.
  46. """
  47. new_id = value.watchpoint_id
  48. self._watchpoints[new_id] = value
  49. self._created_watchpoints.append(new_id)
  50. self._updated_watchpoints[new_id] = value
  51. self._latest_id = new_id
  52. log.debug("Put watchpoint %d into cache.", new_id)
  53. def sync_set_cmd(self, set_cmds):
  54. """Clean temp watchpoints."""
  55. self._outdated = False
  56. self._created_watchpoints = []
  57. self._deleted_watchpoints = []
  58. self._updated_watchpoints = {}
  59. for set_cmd in set_cmds:
  60. self._cache_set_cmd[set_cmd.id] = set_cmd
  61. def clean_cache_set_cmd(self, set_cmd):
  62. """Clean cache set command."""
  63. self._cache_set_cmd.pop(set_cmd.id, None)
  64. def get_watchpoint_by_id(self, watchpoint_id):
  65. """Get watchpoint by watchpoint id."""
  66. res = self.get(watchpoint_id)
  67. watchpoint = res.get('watch_points')[0]
  68. return watchpoint
  69. def get(self, filter_condition=None):
  70. """
  71. Get the watchpoints.
  72. Args:
  73. filter_condition (Union[None, int]): The filter conditions. Get watchpoint by
  74. id. If None, return all watchpoint. Default: None.
  75. Returns:
  76. dict, the watchpoint list.
  77. """
  78. reply = []
  79. if not filter_condition:
  80. # get watch condition list
  81. for _, watchpoint in self._watchpoints.items():
  82. watchpoint_info = watchpoint.get_watch_condition_info()
  83. reply.append(watchpoint_info)
  84. else:
  85. self.validate_watchpoint_id(filter_condition)
  86. reply = [self._watchpoints.get(filter_condition)]
  87. log.debug("get the watch points with filter_condition:%s", filter_condition)
  88. return {'watch_points': reply}
  89. def get_pending_commands(self, graph_stream):
  90. """
  91. Get all watchpoint in SetCMD proto format.
  92. Args:
  93. graph_stream (GraphHandler): Graph handler.
  94. Returns:
  95. list[SetCMD], updated watchpoint to be sent to MindSpore.
  96. """
  97. res = []
  98. for _, watchpoint in self._updated_watchpoints.items():
  99. # construct set command with leaf nodes
  100. watch_nodes = watchpoint.get_watch_nodes()
  101. leaf_watch_nodes = self._expand_to_leaf_nodes(graph_stream, watch_nodes)
  102. res.append(watchpoint.get_pending_cmd(leaf_watch_nodes))
  103. res.extend(self._deleted_watchpoints)
  104. for _, set_cmd in self._cache_set_cmd.items():
  105. res.append(set_cmd)
  106. return res
  107. @staticmethod
  108. def _expand_to_leaf_nodes(graph_stream, watch_nodes):
  109. """
  110. Get all leaf node basic info according to watch nodes.
  111. Args:
  112. graph_stream (GraphHandler): Graph handler.
  113. watch_nodes (list[NodeBasicInfo]): The list of watch node basic infos.
  114. Returns:
  115. list[NodeBasicInfo], expanded leaf basic node infos.
  116. """
  117. leaf_watch_nodes = []
  118. for node in watch_nodes:
  119. if is_scope_type(node.type):
  120. pure_node_name = None
  121. if len(node.name.split('/')) > 1:
  122. graph_name, pure_node_name = node.name.split('/', 1)
  123. else:
  124. graph_name = node.name
  125. search_node_infos = graph_stream.get_node_basic_info_by_scope(pure_node_name, graph_name=graph_name)
  126. leaf_watch_nodes.extend(search_node_infos)
  127. else:
  128. leaf_watch_nodes.append(node)
  129. return leaf_watch_nodes
  130. def is_recheckable(self):
  131. """
  132. Check if current status is able to recheck.
  133. Returns:
  134. bool, if enable to recheck.
  135. """
  136. return self._outdated
  137. def set_watch_nodes(self, graph, graph_stream, watch_point_id, graph_name=None):
  138. """
  139. set watch nodes for graph.
  140. Args:
  141. graph (dict): The graph with list of nodes.
  142. graph_stream (GraphHandler): The graph handler.
  143. watch_point_id (int): The id of watchpoint.
  144. graph_name (str): The graph name.
  145. """
  146. if not (watch_point_id and graph):
  147. return
  148. log.debug("add watch flags")
  149. watchpoint = self._watchpoints.get(watch_point_id)
  150. self._set_watch_status_recursively(graph, graph_stream, watchpoint, graph_name)
  151. def _set_watch_status_recursively(self, graph, graph_stream, watchpoint, graph_name=None):
  152. """Set watch status to graph."""
  153. if graph.get('children'):
  154. self._set_watch_status_recursively(
  155. graph.get('children'), graph_stream, watchpoint, graph_name)
  156. if graph.get('nodes'):
  157. _ = self._set_watch_state_for_nodes(graph['nodes'], graph_stream, watchpoint, graph_name)
  158. def _set_watch_state_for_nodes(self, nodes, graph_stream, watchpoint, graph_name):
  159. """
  160. Set watch state for nodes.
  161. Args:
  162. nodes (list[Node]): List of node info.
  163. Returns:
  164. int, the number of all watched nodes.
  165. """
  166. all_watched_num = 0
  167. valid_node_num = len(nodes)
  168. # initialize the state of current node.
  169. state = WatchNodeTree.NOT_WATCH
  170. for node in nodes:
  171. node_name = node.get('name')
  172. # search result could have `nodes` in nodes object
  173. if node.get('nodes'):
  174. flag = self._set_watch_state_for_nodes(node.get('nodes'), graph_stream, watchpoint, graph_name)
  175. else:
  176. full_name = graph_stream.get_full_name(node_name, graph_name)
  177. new_node_name = node_name if graph_name is None else '/'.join([graph_name, node_name])
  178. flag = watchpoint.get_node_status(new_node_name, node.get('type'), full_name)
  179. node['watched'] = flag
  180. if flag == WatchNodeTree.NOT_WATCH:
  181. continue
  182. state = WatchNodeTree.PARTIAL_WATCH
  183. if flag == WatchNodeTree.INVALID:
  184. valid_node_num -= 1
  185. elif flag == WatchNodeTree.TOTAL_WATCH:
  186. all_watched_num += 1
  187. # update the watch status of current node
  188. if not valid_node_num:
  189. state = WatchNodeTree.INVALID
  190. elif all_watched_num == valid_node_num:
  191. state = WatchNodeTree.TOTAL_WATCH
  192. return state
  193. def create_watchpoint(self, condition_mgr, watch_condition, watch_nodes=None, watch_point_id=None, name=None):
  194. """
  195. Create watchpoint.
  196. Args:
  197. condition_mgr (ConditionMgr): Instance of ConditionMgr.
  198. watch_condition (dict): The watch condition.
  199. "condition": {
  200. id: "tensor_too_large",
  201. "params": [
  202. {
  203. "name": "abs_mean_gt",
  204. "value": 1.1
  205. }
  206. ]
  207. }
  208. - id (str): Id of condition.
  209. - param (list[dict]): The list of param for this condition.
  210. watch_nodes (list[NodeBasicInfo]): The list of node basic info.
  211. watch_point_id (int): The id of watchpoint.
  212. name (str): The name of watchpoint.
  213. Returns:
  214. int, the new id of watchpoint.
  215. """
  216. validate_watch_condition(condition_mgr, watch_condition)
  217. watch_condition = set_default_param(condition_mgr, watch_condition)
  218. new_id = self._latest_id + 1
  219. watchpoint = Watchpoint(new_id, watch_condition, name)
  220. if watch_nodes:
  221. watchpoint.add_nodes(watch_nodes)
  222. elif watch_point_id:
  223. self.validate_watchpoint_id(watch_point_id)
  224. watchpoint.copy_nodes_from(self._watchpoints.get(watch_point_id))
  225. self.put(watchpoint)
  226. self._outdated = True
  227. return new_id
  228. def update_watchpoint(self, watch_point_id, watch_nodes, watched=False):
  229. """
  230. Update watchpoint.
  231. Args:
  232. watch_point_id (int): The id of watchpoint.
  233. watch_nodes (list[NodeBasicInfo]): The list of node basic info.
  234. watched (bool): The update operator on nodes. If False, remove nodes from watch nodes.
  235. If True, add nodes to watch nodes. Default: False.
  236. """
  237. self.validate_watchpoint_id(watch_point_id)
  238. watchpoint = self._watchpoints.get(watch_point_id)
  239. if watched:
  240. watchpoint.add_nodes(watch_nodes)
  241. else:
  242. watchpoint.remove_nodes(watch_nodes)
  243. self._updated_watchpoints[watch_point_id] = watchpoint
  244. self._outdated = True
  245. log.debug("Update watchpoint %d in cache.", watch_point_id)
  246. def delete_watchpoint(self, watch_point_id=None):
  247. """
  248. Delete watchpoint.
  249. Args:
  250. watch_point_id (Union[None, int]): The id of watchpoint.
  251. If None, delete all watchpoints. Default: None.
  252. """
  253. if watch_point_id is None:
  254. watch_point_ids = [sub_id for sub_id, _ in self._watchpoints.items()]
  255. else:
  256. self.validate_watchpoint_id(watch_point_id)
  257. watch_point_ids = [watch_point_id]
  258. for single_id in watch_point_ids:
  259. self._delete_single_watchpoint(single_id)
  260. self._outdated = True
  261. def _delete_single_watchpoint(self, watch_point_id):
  262. """
  263. Delete single watchpoint.
  264. Args:
  265. watch_point_id (int): The id of watchpoint.
  266. """
  267. self._watchpoints.pop(watch_point_id)
  268. # if the watchpoint has not been created by MindSpore, clean the relative cache directly
  269. if watch_point_id in self._created_watchpoints:
  270. self._created_watchpoints.remove(watch_point_id)
  271. self._updated_watchpoints.pop(watch_point_id)
  272. log.debug("Cancel create watchpoint %d in cache.", watch_point_id)
  273. return
  274. set_cmd = SetCMD()
  275. set_cmd.id = watch_point_id
  276. set_cmd.delete = True
  277. self._deleted_watchpoints.append(set_cmd)
  278. log.debug("Delete watchpoint %d in cache.", watch_point_id)
  279. def validate_watchpoint_id(self, watch_point_id):
  280. """Validate watchpoint id."""
  281. if not isinstance(watch_point_id, int):
  282. log.error("Invalid watchpoint id %s. The watch point id should be int.", watch_point_id)
  283. raise DebuggerParamTypeError("Watchpoint id should be int type.")
  284. if watch_point_id and watch_point_id not in self._watchpoints:
  285. log.error("Invalid watchpoint id: %d.", watch_point_id)
  286. raise DebuggerParamValueError("Invalid watchpoint id: {}".format(watch_point_id))
  287. class WatchpointHitHandler(StreamHandlerBase):
  288. """Watchpoint hit handler."""
  289. def __init__(self):
  290. # dict of <ui node_name, dict of <slot, WatchpointHit>>,
  291. self._hits = {}
  292. @property
  293. def empty(self):
  294. """Whether the watchpoint hit is empty."""
  295. return not self._hits
  296. def put(self, value):
  297. """
  298. Put value into watchpoint hit cache. Called by grpc server.
  299. Args:
  300. value (dict): The watchpoint hit info.
  301. - tensor_proto (TensorProto): The message about hit tensor.
  302. - watchpoint (Watchpoint): The Watchpoint that a node hit.
  303. - node_name (str): The UI node name.
  304. - graph_name (str): The graph name.
  305. """
  306. watchpoint_hit = WatchpointHit(
  307. tensor_proto=value.get('tensor_proto'),
  308. watchpoint=value.get('watchpoint'),
  309. node_name=value.get('node_name'),
  310. graph_name=value.get('graph_name')
  311. )
  312. if 'error_code' in value.keys():
  313. watchpoint_hit.error_code = value.get('error_code')
  314. # get all hit watchpoints according to node name ans tensor slot
  315. watchpoint_hits = self._get_watchpoints_by_tensor_name(watchpoint_hit.node_name,
  316. watchpoint_hit.slot)
  317. if watchpoint_hit not in watchpoint_hits:
  318. watchpoint_hits.append(watchpoint_hit)
  319. def _get_watchpoints_by_tensor_name(self, node_name, slot):
  320. """
  321. Get hit tensors according to ui node name and slot.
  322. Args:
  323. node_name (str): The node name.
  324. slot (str): The tensor slot.
  325. Returns:
  326. list, list of watchpoints.
  327. """
  328. hit_node = self._hits.get(node_name)
  329. if hit_node is None:
  330. hit_node = {}
  331. self._hits[node_name] = hit_node
  332. hit_tensors = hit_node.get(slot)
  333. if hit_tensors is None:
  334. hit_tensors = []
  335. hit_node[slot] = hit_tensors
  336. return hit_tensors
  337. def get(self, filter_condition=None):
  338. """
  339. Get watchpoint hit list.
  340. Args:
  341. filter_condition (str): Get the watchpoint hit according to specified node name.
  342. If not given, get all watchpoint hits. Default: None.
  343. Returns:
  344. dict, the watchpoint hit list.
  345. """
  346. if filter_condition is None:
  347. log.debug("Get all watchpoint hit list.")
  348. reply = self.get_watchpoint_hits()
  349. else:
  350. log.debug("Get the watchpoint for node: <%s>.", filter_condition)
  351. reply = self._hits.get(filter_condition)
  352. return reply
  353. def get_watchpoint_hits(self):
  354. """Return the list of watchpoint hits."""
  355. watch_point_hits = []
  356. for node_name, watchpoint_hits in self._hits.items():
  357. tensors = []
  358. graph_name = None
  359. for slot, tensor_hits in watchpoint_hits.items():
  360. if graph_name is None:
  361. graph_name = tensor_hits[0].graph_name
  362. tensor_info = self._get_tensor_hit_info(slot, tensor_hits)
  363. tensors.append(tensor_info)
  364. watch_point_hits.append({
  365. 'node_name': node_name,
  366. 'tensors': tensors,
  367. 'graph_name': graph_name
  368. })
  369. return {'watch_point_hits': watch_point_hits}
  370. @staticmethod
  371. def _get_tensor_hit_info(slot, tensor_hits):
  372. """
  373. Get watchpoint hit info of specified tensor.
  374. Args:
  375. slot (str): Slot id.
  376. tensor_hits (list): A list of watchpoint hit objects that the tensor hit.
  377. Returns:
  378. dict, tensor hit info.
  379. """
  380. res = {}
  381. watch_points = []
  382. for tensor_hit in tensor_hits:
  383. error_code = tensor_hit.error_code
  384. error_list = _get_error_list(error_code)
  385. watchpoint = tensor_hit.watchpoint
  386. watchpoint['error_code'] = error_code
  387. watchpoint['error_list'] = error_list
  388. watch_points.append(watchpoint)
  389. if watch_points:
  390. res = {
  391. 'slot': slot,
  392. 'watch_points': watch_points
  393. }
  394. return res
  395. def _is_tensor_hit(self, tensor_name):
  396. """
  397. Check if the tensor is record in hit cache.
  398. Args:
  399. tensor_name (str): The name of ui tensor name.
  400. Returns:
  401. bool, if the tensor is hit.
  402. """
  403. node_name, slot = tensor_name.rsplit(':', 1)
  404. watchpoint_hits = self._hits.get(node_name, {}).get(slot)
  405. return bool(watchpoint_hits)
  406. def update_tensor_history(self, tensor_history):
  407. """
  408. Add hit flag to tensor history.
  409. Args:
  410. tensor_history (dict): The tensor history.
  411. """
  412. if not self._hits:
  413. return
  414. # add hit tensor names to `tensor_names`
  415. for tensor_info in tensor_history.get('tensor_history'):
  416. tensor_name = tensor_info['name']
  417. hit_flag = self._is_tensor_hit(tensor_name)
  418. tensor_info['is_hit'] = hit_flag
  419. def get_tensor_hit_infos(self, tensor_name):
  420. """
  421. Get all hit information of a tensor.
  422. Args:
  423. tensor_name (str): Tensor name showed on UI.
  424. Returns:
  425. dict, tensor hit info.
  426. """
  427. tensor_hit_info = {}
  428. if self._is_tensor_hit(tensor_name):
  429. node_name, slot = tensor_name.rsplit(':', 1)
  430. tensor_hits = self._get_watchpoints_by_tensor_name(node_name, slot)
  431. tensor_hit_info = self._get_tensor_hit_info(slot, tensor_hits)
  432. return tensor_hit_info
  433. def validate_watch_condition(condition_mgr, watch_condition):
  434. """Validate watch condition."""
  435. if not isinstance(watch_condition, dict):
  436. log.error("<watch_condition> should be dict. %s received.", watch_condition)
  437. raise DebuggerParamTypeError("<watch_condition> should be dict.")
  438. # validate condition_id
  439. condition_id = watch_condition.get('id')
  440. if condition_id not in condition_mgr.conditions.keys():
  441. log.error("Invalid watch condition. Acceptable values are <%s>. %s received.",
  442. str(condition_mgr.conditions.keys()), condition_id)
  443. raise DebuggerParamValueError("Invalid watch condition value.")
  444. # validate param
  445. validate_watch_condition_params(condition_mgr, watch_condition)
  446. def validate_watch_condition_params(condition_mgr, watch_condition):
  447. """
  448. Validate watch condition parameters.
  449. Args:
  450. condition_mgr (ConditionMgr): Instance of ConditionMgr.
  451. watch_condition (dict): Watch condition.
  452. - id (str): Condition id. Should be in WATCHPOINT_CONDITION_MAPPING.
  453. - param (list): Condition value. Should be given for comparison condition. The value
  454. will be translated to np.float32.
  455. """
  456. condition_id = watch_condition.get('id')
  457. params = watch_condition.get('params')
  458. condition = condition_mgr.get_condition(condition_id)
  459. if condition_id in condition_mgr.get_no_param_condition():
  460. if params:
  461. log.error("No param is expected for %s condition", condition_id)
  462. raise DebuggerParamValueError("No param is expected.")
  463. return
  464. check_param_num = 0
  465. support_params = set()
  466. defined_support_params = set()
  467. for param in params:
  468. if len(param) > 2:
  469. log.error("Invalid param keys for condition: %s", condition_id)
  470. raise DebuggerParamValueError("Invalid param keys.")
  471. condition_param_name = param.get("name")
  472. if condition_param_name not in condition.names:
  473. log.error("Invalid name of parameter for condition: %s, available values: %s",
  474. condition_id, condition.names)
  475. raise DebuggerParamValueError("Invalid name of parameter.")
  476. condition_param = condition.get_parameter_definition(condition_param_name)
  477. if condition_param.type.name in (ValueTypeEnum.FLOAT64.name, ValueTypeEnum.INT64.name) \
  478. and not isinstance(param.get("value"), (float, int)):
  479. log.error("Number param should be given for condition: %s", condition_id)
  480. raise DebuggerParamValueError("Number param should be given.")
  481. if condition_param.type.name == ValueTypeEnum.BOOL.name \
  482. and not isinstance(param.get("value"), bool):
  483. log.error("Bool param should be given for condition: %s", condition_id)
  484. raise DebuggerParamValueError("Bool param should be given.")
  485. if not condition_param.is_valid(param.get("value")):
  486. log.error("Param %s out of range for condition: %s", condition_param_name, condition_id)
  487. raise DebuggerParamValueError("Parameter out of range.")
  488. if condition_param.param_type == ParamTypeEnum.CHECK_PARAM.value:
  489. if condition_param.required_params:
  490. defined_support_params = set(condition_param.required_params)
  491. check_param_num += 1
  492. else:
  493. support_params.add(condition_param.name)
  494. if check_param_num > 1:
  495. log.error("Multiple check params for condition: %s", condition_id)
  496. raise DebuggerParamValueError("Multiple check params.")
  497. if support_params != defined_support_params:
  498. log.error("Invalid support params for condition: %s", condition_id)
  499. raise DebuggerParamValueError("Invalid support params.")
  500. def set_default_param(condition_mgr, watch_condition):
  501. """
  502. Set default param.
  503. Args:
  504. condition_mgr (ConditionMgr): Instance of ConditionMgr.
  505. watch_condition (dict): The watch condition.
  506. "condition": {
  507. id: "tensor_too_large",
  508. "params": [
  509. {
  510. "name": "abs_mean_gt",
  511. "value": 1.1
  512. }
  513. ]
  514. }
  515. - id (str): Id of condition.
  516. - param (list[dict]): The list of param for this condition.
  517. Returns:
  518. dict, the new watch_condition.
  519. """
  520. condition_id = watch_condition.get('id')
  521. condition = condition_mgr.get_condition(condition_id)
  522. for param in condition.parameters:
  523. if not param.visible_on_ui and not param.support_disable:
  524. watch_condition["params"].append({
  525. "name": param.name,
  526. "value": param.default_value
  527. })
  528. watch_condition["abbr"] = condition.abbr
  529. return watch_condition
  530. def _get_error_list(error_code):
  531. """
  532. Get error list.
  533. Args:
  534. error_code (int): the code of errors.
  535. Returns:
  536. list, the error list.
  537. """
  538. all_error_list = ["nan", "inf", "no_prev_tensor"]
  539. error_list = []
  540. for i, error_str in enumerate(all_error_list):
  541. error = (error_code >> i) & 1
  542. if error == 1:
  543. error_list.append(error_str)
  544. return error_list