You can not select more than 25 topics Topics must start with a chinese character,a letter or number, can include dashes ('-') and can be up to 35 characters long.

watchpoint_handler.py 25 kB

5 years ago
5 years ago
5 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645
  1. # Copyright 2020 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ============================================================================
  15. """Define the watchpoint stream handler."""
  16. from mindinsight.debugger.conditionmgr.condition import ValueTypeEnum
  17. from mindinsight.debugger.common.exceptions.exceptions import DebuggerParamValueError, \
  18. DebuggerParamTypeError
  19. from mindinsight.debugger.common.log import LOGGER as log
  20. from mindinsight.debugger.common.utils import is_scope_type
  21. from mindinsight.debugger.proto.debug_grpc_pb2 import SetCMD
  22. from mindinsight.debugger.stream_cache.watchpoint import Watchpoint, WatchpointHit, \
  23. WatchNodeTree
  24. from mindinsight.debugger.stream_handler.base_handler import StreamHandlerBase
  25. class WatchpointHandler(StreamHandlerBase):
  26. """Watchpoint Handler."""
  27. def __init__(self):
  28. self._watchpoints = {}
  29. # list of ids of new created watchpoints
  30. self._created_watchpoints = []
  31. # list of SetCMD of watchpoints to be deleted
  32. self._deleted_watchpoints = []
  33. # dict of <id, SetCMD> of watchpoint to be updated
  34. self._updated_watchpoints = {}
  35. # the collection of watched node full names, which have been sent to MindSpore
  36. self._all_watched_node_full_names = set()
  37. # the collection of new watched node full names, which have not been sent to MindSpore
  38. self._new_watched_node_full_names = set()
  39. # record the temp stored nodes in MS, which could be set as watch node for recheck on GPU
  40. # should be clean at the beginning of each step
  41. self._temp_cached_node_full_names = set()
  42. self._latest_id = 0
  43. self._cache_set_cmd = {}
  44. # whether the watchpoint list has been changed since last step
  45. self.outdated = False
  46. def put(self, value):
  47. """
  48. Put Watchpoint into watchpoint handler.
  49. Args:
  50. value (Watchpoint): The name of nodes that have been chosen.
  51. """
  52. new_id = value.watchpoint_id
  53. self._watchpoints[new_id] = value
  54. self._created_watchpoints.append(new_id)
  55. self._updated_watchpoints[new_id] = value
  56. self._latest_id = new_id
  57. log.debug("Put watchpoint %d into cache.", new_id)
  58. def clean_temp_cached_names(self):
  59. """Clean temp cached node."""
  60. self._temp_cached_node_full_names.clear()
  61. def add_temp_cached_name(self, node_full_name):
  62. """Add temp stored node in cache."""
  63. if node_full_name:
  64. self._temp_cached_node_full_names.add(node_full_name)
  65. def sync_set_cmd(self, set_cmds):
  66. """Clean temp watchpoints."""
  67. self._new_watched_node_full_names = set()
  68. self._created_watchpoints = []
  69. self._deleted_watchpoints = []
  70. self._updated_watchpoints = {}
  71. for set_cmd in set_cmds:
  72. self._cache_set_cmd[set_cmd.id] = set_cmd
  73. def clean_cache_set_cmd(self, set_cmd):
  74. """Clean cache set command."""
  75. self._cache_set_cmd.pop(set_cmd.id, None)
  76. def get_watchpoint_by_id(self, watchpoint_id):
  77. """Get watchpoint by watchpoint id."""
  78. res = self.get(watchpoint_id)
  79. watchpoint = res.get('watch_points')[0]
  80. return watchpoint
  81. def get(self, filter_condition=None):
  82. """
  83. Get the watchpoints.
  84. Args:
  85. filter_condition (Union[None, int]): The filter conditions. Get watchpoint by
  86. id. If None, return all watchpoint. Default: None.
  87. Returns:
  88. dict, the watchpoint list.
  89. """
  90. reply = []
  91. if not filter_condition:
  92. # get watch condition list
  93. for _, watchpoint in self._watchpoints.items():
  94. watchpoint_info = watchpoint.get_watch_condition_info()
  95. reply.append(watchpoint_info)
  96. else:
  97. self.validate_watchpoint_id(filter_condition)
  98. reply = [self._watchpoints.get(filter_condition)]
  99. log.debug("get the watch points with filter_condition:%s", filter_condition)
  100. return {'watch_points': reply}
  101. def get_pending_commands(self, graph_stream):
  102. """
  103. Get all watchpoint in SetCMD proto format.
  104. Args:
  105. graph_stream (GraphHandler): Graph handler.
  106. Returns:
  107. list[SetCMD], updated watchpoint to be sent to MindSpore.
  108. """
  109. res = []
  110. new_watched_nodes = set()
  111. self._all_watched_node_full_names.clear()
  112. for _, watchpoint in self._updated_watchpoints.items():
  113. # construct set command with leaf nodes
  114. watch_nodes = watchpoint.get_watch_nodes()
  115. leaf_watch_nodes = self._expand_to_leaf_nodes(graph_stream, watch_nodes)
  116. res.append(watchpoint.get_pending_cmd(leaf_watch_nodes))
  117. # update all watched node names
  118. watch_node_names = [watch_node.full_name for watch_node in [*watch_nodes, *leaf_watch_nodes]]
  119. new_watched_nodes.update(watch_node_names)
  120. res.extend(self._deleted_watchpoints)
  121. for _, set_cmd in self._cache_set_cmd.items():
  122. res.append(set_cmd)
  123. self._all_watched_node_full_names = new_watched_nodes
  124. return res
  125. @staticmethod
  126. def _expand_to_leaf_nodes(graph_stream, watch_nodes):
  127. """
  128. Get all leaf node basic info according to watch nodes.
  129. Args:
  130. graph_stream (GraphHandler): Graph handler.
  131. watch_nodes (list[NodeBasicInfo]): The list of watch node basic infos.
  132. Returns:
  133. list[NodeBasicInfo], expanded leaf basic node infos.
  134. """
  135. leaf_watch_nodes = []
  136. for node in watch_nodes:
  137. if is_scope_type(node.type):
  138. pure_node_name = None
  139. if len(node.name.split('/')) > 1:
  140. graph_name, pure_node_name = node.name.split('/', 1)
  141. else:
  142. graph_name = node.name
  143. search_node_infos = graph_stream.get_node_basic_info_by_scope(pure_node_name, graph_name=graph_name)
  144. leaf_watch_nodes.extend(search_node_infos)
  145. else:
  146. leaf_watch_nodes.append(node)
  147. return leaf_watch_nodes
  148. def is_recheckable(self, backend=None):
  149. """
  150. Check if current status is able to recheck.
  151. Args:
  152. backend (str): The backend info. 'Ascend' or 'GPU'. Default: None.
  153. Returns:
  154. bool, if enable to recheck.
  155. """
  156. enable_recheck = self.outdated
  157. if backend == 'GPU' and enable_recheck:
  158. # on GPU, disable to recheck if there are new watched node of which the tensor
  159. # has not been stored on MindSpore
  160. diff_set = self._new_watched_node_full_names - self._all_watched_node_full_names
  161. enable_recheck = not diff_set or diff_set.issubset(self._temp_cached_node_full_names)
  162. return enable_recheck
  163. def set_watch_nodes(self, graph, graph_stream, watch_point_id, graph_name=None):
  164. """
  165. set watch nodes for graph.
  166. Args:
  167. graph (dict): The graph with list of nodes.
  168. graph_stream (GraphHandler): The graph handler.
  169. watch_point_id (int): The id of watchpoint.
  170. graph_name (str): The graph name.
  171. """
  172. if not (watch_point_id and graph):
  173. return
  174. log.debug("add watch flags")
  175. watchpoint = self._watchpoints.get(watch_point_id)
  176. self._set_watch_status_recursively(graph, graph_stream, watchpoint, graph_name)
  177. def _set_watch_status_recursively(self, graph, graph_stream, watchpoint, graph_name=None):
  178. """Set watch status to graph."""
  179. if graph.get('children'):
  180. self._set_watch_status_recursively(
  181. graph.get('children'), graph_stream, watchpoint, graph_name)
  182. if graph.get('nodes'):
  183. _ = self._set_watch_state_for_nodes(graph['nodes'], graph_stream, watchpoint, graph_name)
  184. def _set_watch_state_for_nodes(self, nodes, graph_stream, watchpoint, graph_name):
  185. """
  186. Set watch state for nodes.
  187. Args:
  188. nodes (list[Node]): List of node info.
  189. Returns:
  190. int, the number of all watched nodes.
  191. """
  192. all_watched_num = 0
  193. for node in nodes:
  194. node_name = node.get('name')
  195. # search result could have `nodes` in nodes object
  196. if node.get('nodes'):
  197. flag = self._set_watch_state_for_nodes(node.get('nodes'), graph_stream, watchpoint, graph_name)
  198. else:
  199. full_name = graph_stream.get_full_name(node_name, graph_name)
  200. new_node_name = node_name if graph_name is None else '/'.join([graph_name, node_name])
  201. flag = watchpoint.get_node_status(new_node_name, node.get('type'), full_name)
  202. node['watched'] = flag
  203. if flag == WatchNodeTree.TOTAL_WATCH:
  204. all_watched_num += 1
  205. # calculate the state of current node.
  206. if not all_watched_num:
  207. state = WatchNodeTree.NOT_WATCH
  208. elif all_watched_num == len(nodes):
  209. state = WatchNodeTree.TOTAL_WATCH
  210. else:
  211. state = WatchNodeTree.PARTIAL_WATCH
  212. return state
  213. def create_watchpoint(self, condition_mgr, watch_condition, watch_nodes=None, watch_point_id=None):
  214. """
  215. Create watchpoint.
  216. Args:
  217. condition_mgr (ConditionMgr): Instance of ConditionMgr.
  218. watch_condition (dict): The watch condition.
  219. "condition": {
  220. id: "tensor_too_large",
  221. "params": [
  222. {
  223. "name": "abs_mean_gt",
  224. "value": 1.1
  225. }
  226. ]
  227. }
  228. - id (str): Id of condition.
  229. - param (list[dict]): The list of param for this condition.
  230. watch_nodes (list[NodeBasicInfo]): The list of node basic info.
  231. watch_point_id (int): The id of watchpoint.
  232. Returns:
  233. int, the new id of watchpoint.
  234. """
  235. validate_watch_condition(condition_mgr, watch_condition)
  236. watch_condition = set_default_param(condition_mgr, watch_condition)
  237. new_id = self._latest_id + 1
  238. watchpoint = Watchpoint(new_id, watch_condition)
  239. if watch_nodes:
  240. watchpoint.add_nodes(watch_nodes)
  241. self._add_watch_node_in_cache(watch_nodes)
  242. elif watch_point_id:
  243. self.validate_watchpoint_id(watch_point_id)
  244. watchpoint.copy_nodes_from(self._watchpoints.get(watch_point_id))
  245. self.put(watchpoint)
  246. self.outdated = True
  247. return new_id
  248. def update_watchpoint(self, watch_point_id, watch_nodes, watched=False):
  249. """
  250. Update watchpoint.
  251. Args:
  252. watch_point_id (int): The id of watchpoint.
  253. watch_nodes (list[NodeBasicInfo]): The list of node basic info.
  254. watched (bool): The update operator on nodes. If False, remove nodes from watch nodes.
  255. If True, add nodes to watch nodes. Default: False.
  256. """
  257. self.validate_watchpoint_id(watch_point_id)
  258. watchpoint = self._watchpoints.get(watch_point_id)
  259. if watched:
  260. watchpoint.add_nodes(watch_nodes)
  261. self._add_watch_node_in_cache(watch_nodes)
  262. else:
  263. watchpoint.remove_nodes(watch_nodes)
  264. self._remove_watch_node_from_cache(watch_nodes)
  265. self._updated_watchpoints[watch_point_id] = watchpoint
  266. self.outdated = True
  267. log.debug("Update watchpoint %d in cache.", watch_point_id)
  268. def delete_watchpoint(self, watch_point_id=None):
  269. """
  270. Delete watchpoint.
  271. Args:
  272. watch_point_id (Union[None, int]): The id of watchpoint.
  273. If None, delete all watchpoints. Default: None.
  274. """
  275. if watch_point_id is None:
  276. watch_point_ids = [sub_id for sub_id, _ in self._watchpoints.items()]
  277. else:
  278. self.validate_watchpoint_id(watch_point_id)
  279. watch_point_ids = [watch_point_id]
  280. for single_id in watch_point_ids:
  281. self._delete_single_watchpoint(single_id)
  282. self.outdated = True
  283. def _delete_single_watchpoint(self, watch_point_id):
  284. """
  285. Delete single watchpoint.
  286. Args:
  287. watch_point_id (int): The id of watchpoint.
  288. """
  289. self._watchpoints.pop(watch_point_id)
  290. # if the watchpoint has not been created by MindSpore, clean the relative cache directly
  291. if watch_point_id in self._created_watchpoints:
  292. self._created_watchpoints.remove(watch_point_id)
  293. self._updated_watchpoints.pop(watch_point_id)
  294. log.debug("Cancel create watchpoint %d in cache.", watch_point_id)
  295. return
  296. set_cmd = SetCMD()
  297. set_cmd.id = watch_point_id
  298. set_cmd.delete = True
  299. self._deleted_watchpoints.append(set_cmd)
  300. log.debug("Delete watchpoint %d in cache.", watch_point_id)
  301. def validate_watchpoint_id(self, watch_point_id):
  302. """Validate watchpoint id."""
  303. if not isinstance(watch_point_id, int):
  304. log.error("Invalid watchpoint id %s. The watch point id should be int.", watch_point_id)
  305. raise DebuggerParamTypeError("Watchpoint id should be int type.")
  306. if watch_point_id and watch_point_id not in self._watchpoints:
  307. log.error("Invalid watchpoint id: %d.", watch_point_id)
  308. raise DebuggerParamValueError("Invalid watchpoint id: {}".format(watch_point_id))
  309. def _add_watch_node_in_cache(self, watch_nodes):
  310. """
  311. Add watch nodes in cache.
  312. Args:
  313. watch_nodes (list[NodeBasicInfo]): The list of node basic info.
  314. """
  315. node_full_names = [node.full_name for node in watch_nodes]
  316. self._new_watched_node_full_names.update(node_full_names)
  317. def _remove_watch_node_from_cache(self, watch_nodes):
  318. """
  319. Remove watch nodes from cache.
  320. Args:
  321. watch_nodes (list[NodeBasicInfo]): The list of node basic info.
  322. """
  323. for node in watch_nodes:
  324. if node.full_name in self._new_watched_node_full_names:
  325. self._new_watched_node_full_names.remove(node.full_name)
  326. class WatchpointHitHandler(StreamHandlerBase):
  327. """Watchpoint hit handler."""
  328. def __init__(self):
  329. # dict of <ui node_name, dict of <slot, WatchpointHit>>,
  330. self._hits = {}
  331. @property
  332. def empty(self):
  333. """Whether the watchpoint hit is empty."""
  334. return not self._hits
  335. def put(self, value):
  336. """
  337. Put value into watchpoint hit cache. Called by grpc server.
  338. Args:
  339. value (dict): The watchpoint hit info.
  340. - tensor_proto (TensorProto): The message about hit tensor.
  341. - watchpoint (Watchpoint): The Watchpoint that a node hit.
  342. - node_name (str): The UI node name.
  343. - graph_name (str): The graph name.
  344. """
  345. watchpoint_hit = WatchpointHit(
  346. tensor_proto=value.get('tensor_proto'),
  347. watchpoint=value.get('watchpoint'),
  348. node_name=value.get('node_name'),
  349. graph_name=value.get('graph_name')
  350. )
  351. if 'error_code' in value.keys():
  352. watchpoint_hit.error_code = value.get('error_code')
  353. # get all hit watchpoints according to node name ans tensor slot
  354. watchpoint_hits = self._get_watchpoints_by_tensor_name(watchpoint_hit.node_name,
  355. watchpoint_hit.slot)
  356. if watchpoint_hit not in watchpoint_hits:
  357. watchpoint_hits.append(watchpoint_hit)
  358. def _get_watchpoints_by_tensor_name(self, node_name, slot):
  359. """
  360. Get hit tensors according to ui node name and slot.
  361. Args:
  362. node_name (str): The node name.
  363. slot (str): The tensor slot.
  364. Returns:
  365. list, list of watchpoints.
  366. """
  367. hit_node = self._hits.get(node_name)
  368. if hit_node is None:
  369. hit_node = {}
  370. self._hits[node_name] = hit_node
  371. hit_tensors = hit_node.get(slot)
  372. if hit_tensors is None:
  373. hit_tensors = []
  374. hit_node[slot] = hit_tensors
  375. return hit_tensors
  376. def get(self, filter_condition=None):
  377. """
  378. Get watchpoint hit list.
  379. Args:
  380. filter_condition (str): Get the watchpoint hit according to specified node name.
  381. If not given, get all watchpoint hits. Default: None.
  382. Returns:
  383. dict, the watchpoint hit list.
  384. """
  385. if filter_condition is None:
  386. log.debug("Get all watchpoint hit list.")
  387. reply = self.get_watchpoint_hits()
  388. else:
  389. log.debug("Get the watchpoint for node: <%s>.", filter_condition)
  390. reply = self._hits.get(filter_condition)
  391. return reply
  392. def get_watchpoint_hits(self):
  393. """Return the list of watchpoint hits."""
  394. watch_point_hits = []
  395. for node_name, watchpoint_hits in self._hits.items():
  396. tensors = []
  397. graph_name = None
  398. for slot, tensor_hits in watchpoint_hits.items():
  399. if graph_name is None:
  400. graph_name = tensor_hits[0].graph_name
  401. tensor_info = self._get_tensor_hit_info(slot, tensor_hits)
  402. tensors.append(tensor_info)
  403. watch_point_hits.append({
  404. 'node_name': node_name,
  405. 'tensors': tensors,
  406. 'graph_name': graph_name
  407. })
  408. return {'watch_point_hits': watch_point_hits}
  409. @staticmethod
  410. def _get_tensor_hit_info(slot, tensor_hits):
  411. """
  412. Get watchpoint hit info of specified tensor.
  413. Args:
  414. slot (str): Slot id.
  415. tensor_hits (list): A list of watchpoint hit objects that the tensor hit.
  416. Returns:
  417. dict, tensor hit info.
  418. """
  419. res = {}
  420. watch_points = []
  421. error_codes = set()
  422. for tensor_hit in tensor_hits:
  423. error_code = tensor_hit.error_code
  424. watchpoint = tensor_hit.watchpoint
  425. watchpoint['error_code'] = error_code
  426. watch_points.append(watchpoint)
  427. error_codes.add(error_code)
  428. summarized_error_code = error_codes.pop()
  429. while error_codes:
  430. temp = error_codes.pop()
  431. summarized_error_code = summarized_error_code | temp
  432. if watch_points:
  433. res = {
  434. 'slot': slot,
  435. 'summarized_error_code': summarized_error_code,
  436. 'watch_points': watch_points
  437. }
  438. return res
  439. def _is_tensor_hit(self, tensor_name):
  440. """
  441. Check if the tensor is record in hit cache.
  442. Args:
  443. tensor_name (str): The name of ui tensor name.
  444. Returns:
  445. bool, if the tensor is hit.
  446. """
  447. node_name, slot = tensor_name.rsplit(':', 1)
  448. watchpoint_hits = self._hits.get(node_name, {}).get(slot)
  449. return bool(watchpoint_hits)
  450. def update_tensor_history(self, tensor_history):
  451. """
  452. Add hit flag to tensor history.
  453. Args:
  454. tensor_history (dict): The tensor history.
  455. """
  456. if not self._hits:
  457. return
  458. # add hit tensor names to `tensor_names`
  459. for tensor_info in tensor_history.get('tensor_history'):
  460. tensor_name = tensor_info['name']
  461. hit_flag = self._is_tensor_hit(tensor_name)
  462. tensor_info['is_hit'] = hit_flag
  463. def get_tensor_hit_infos(self, tensor_name):
  464. """
  465. Get all hit information of a tensor.
  466. Args:
  467. tensor_name (str): Tensor name showed on UI.
  468. Returns:
  469. dict, tensor hit info.
  470. """
  471. tensor_hit_info = {}
  472. if self._is_tensor_hit(tensor_name):
  473. node_name, slot = tensor_name.rsplit(':', 1)
  474. tensor_hits = self._get_watchpoints_by_tensor_name(node_name, slot)
  475. tensor_hit_info = self._get_tensor_hit_info(slot, tensor_hits)
  476. return tensor_hit_info
  477. def validate_watch_condition(condition_mgr, watch_condition):
  478. """Validate watch condition."""
  479. if not isinstance(watch_condition, dict):
  480. log.error("<watch_condition> should be dict. %s received.", watch_condition)
  481. raise DebuggerParamTypeError("<watch_condition> should be dict.")
  482. # validate condition_id
  483. condition_id = watch_condition.get('id')
  484. if condition_id not in condition_mgr.conditions.keys():
  485. log.error("Invalid watch condition. Acceptable values are <%s>. %s received.",
  486. str(condition_mgr.conditions.keys()), condition_id)
  487. raise DebuggerParamValueError("Invalid watch condition value.")
  488. # validate param
  489. validate_watch_condition_params(condition_mgr, watch_condition)
  490. def validate_watch_condition_params(condition_mgr, watch_condition):
  491. """
  492. Validate watch condition parameters.
  493. Args:
  494. condition_mgr (ConditionMgr): Instance of ConditionMgr.
  495. watch_condition (dict): Watch condition.
  496. - id (str): Condition id. Should be in WATCHPOINT_CONDITION_MAPPING.
  497. - param (list): Condition value. Should be given for comparison condition. The value
  498. will be translated to np.float32.
  499. """
  500. condition_id = watch_condition.get('id')
  501. params = watch_condition.get('params')
  502. condition = condition_mgr.get_condition(condition_id)
  503. if condition_id in condition_mgr.get_no_param_condition():
  504. if params:
  505. log.error("No param is expected for %s condition", condition_id)
  506. raise DebuggerParamValueError("No param is expected.")
  507. return
  508. for param in params:
  509. condition_param_name = param.get("name")
  510. if condition_param_name not in condition.names:
  511. log.error("Invalid name of parameter for condition: %s, available values: %s",
  512. condition_id, condition.names)
  513. raise DebuggerParamValueError("Invalid name of parameter.")
  514. condition_param = condition.get_parameter_definition(condition_param_name)
  515. if condition_param.type.name in (ValueTypeEnum.FLOAT64.name, ValueTypeEnum.INT64.name) \
  516. and not isinstance(param.get("value"), (float, int)):
  517. log.error("Number param should be given for condition: %s", condition_id)
  518. raise DebuggerParamValueError("Number param should be given.")
  519. if condition_param.type.name == ValueTypeEnum.BOOL.name \
  520. and not isinstance(param.get("value"), bool):
  521. log.error("Bool param should be given for condition: %s", condition_id)
  522. raise DebuggerParamValueError("Bool param should be given.")
  523. if not condition_param.is_valid(param.get("value")):
  524. log.error("Param %s out of range for condition: %s", condition_param_name, condition_id)
  525. raise DebuggerParamValueError("Parameter out of range.")
  526. def set_default_param(condition_mgr, watch_condition):
  527. """
  528. Set default param.
  529. Args:
  530. condition_mgr (ConditionMgr): Instance of ConditionMgr.
  531. watch_condition (dict): The watch condition.
  532. "condition": {
  533. id: "tensor_too_large",
  534. "params": [
  535. {
  536. "name": "abs_mean_gt",
  537. "value": 1.1
  538. }
  539. ]
  540. }
  541. - id (str): Id of condition.
  542. - param (list[dict]): The list of param for this condition.
  543. Returns:
  544. dict, the new watch_condition.
  545. """
  546. condition_id = watch_condition.get('id')
  547. condition = condition_mgr.get_condition(condition_id)
  548. for param in condition.parameters:
  549. if not param.visible_on_ui and not param.support_disable:
  550. watch_condition["params"].append({
  551. "name": param.name,
  552. "value": param.default_value
  553. })
  554. watch_condition["abbr"] = condition.abbr
  555. return watch_condition