recommender.py 14 kB

# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
  15. """
  16. Predefined watchpoints.
  17. This module predefine recommend watchpoints.
  18. """
  19. import queue as Queue
  20. from mindinsight.debugger.conditionmgr.conditionmgr import ConditionMgr
  21. from mindinsight.debugger.conditionmgr.condition import TargetTypeEnum
  22. from mindinsight.debugger.conditionmgr.condition import ConditionIdEnum
  23. from mindinsight.debugger.conditionmgr.common.utils import NodeBasicInfo
  24. from mindinsight.debugger.conditionmgr.log import logger
  25. from mindinsight.conf import settings
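# Selection status used when merging watch nodes in _merge_nodes: leaf nodes start
# as SELECTED, name_scope nodes start as UNSELECTED, and a scope whose sub nodes
# are only partially selected ends up HALF_SELECTED.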
UNSELECTED_STATUS = 0
HALF_SELECTED_STATUS = 1
SELECTED_STATUS = 2


class _WatchPointData:
    """WatchPoint data container"""

    def __init__(self, watch_condition, watch_nodes):
        self.watch_condition = watch_condition
        self.watch_nodes = watch_nodes

    def get_watch_condition_dict(self):
        return {
            "id": self.watch_condition.get("condition"),
            "params": [{
                "name": param.get_parameter_name(),
                "value": param.value
            } for param in self.watch_condition.get("params")]
        }


class _ConditionParameterValue:
    """Condition parameter data container"""

    def __init__(self, parameter, value):
        self.parameter = parameter
        self.value = value

    def get_parameter_name(self):
        return self.parameter.name


def recommend_watchpoints(condition_mgr: ConditionMgr, graph_stream, condition_context):
    """
    Recommend watchpoints.

    Args:
        condition_mgr (ConditionMgr): Condition manager instance.
        graph_stream (GraphHandler): Graph handler instance.
        condition_context (ConditionContext): Context for condition.

    Returns:
        list[WatchPointData], watch points to be created.
    """
    watch_points = []
    if not graph_stream.graph:
        logger.warning("Given graph is None.")
        return watch_points
    if not settings.ENABLE_RECOMMENDED_WATCHPOINTS:
        return watch_points

    # add weight watch points
    merged_info = get_basic_node_info(TargetTypeEnum.WEIGHT.value, graph_stream)
    _recommend_weight_initialization(merged_info, condition_mgr, watch_points, condition_context)
    _recommend_weight_change_too_large(merged_info, condition_mgr, watch_points, condition_context)
    # Because we cannot identify trainable weights currently, weight_no_change and weight_change_too_small will not be
    # recommended.
    trainable_weight_nodes = []
    _recommend_weight_not_changed(condition_mgr, trainable_weight_nodes, watch_points, condition_context)
    _recommend_weight_change_too_small(condition_mgr, trainable_weight_nodes, watch_points, condition_context)

    # add gradient watch points
    merged_info = get_basic_node_info(TargetTypeEnum.GRADIENT.value, graph_stream)
    _recommend_gradient_vanishing(merged_info, condition_mgr, watch_points, condition_context)

    # add tensor watch points
    merged_info = get_basic_node_info(TargetTypeEnum.TENSOR.value, graph_stream)
    _recommend_overflow_ascend_chip(merged_info, condition_mgr, watch_points, condition_context)
    _recommend_tensor_overflow(merged_info, condition_mgr, watch_points, condition_context)
    _recommend_tensor_all_zero(merged_info, condition_mgr, watch_points, condition_context)
    return watch_points


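# Each _recommend_* helper below follows the same pattern: return early if there are
# no candidate nodes or the condition is not supported in the current context,
# otherwise build a _WatchPointData with default parameter values and append it to
# watch_points.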
def _recommend_tensor_all_zero(basic_info_nodes, condition_mgr, watch_points, condition_context):
    """Recommend tensor all zero watchpoint."""
    if not basic_info_nodes:
        return
    if not condition_mgr.has_condition(ConditionIdEnum.TENSOR_ALL_ZERO.value, condition_context):
        return
    condition = condition_mgr.get_condition(condition_id=ConditionIdEnum.TENSOR_ALL_ZERO.value)
    tensor_all_zero_watchpoint = _WatchPointData(
        watch_condition={
            "condition": condition.id,
            "params": [_ConditionParameterValue(
                parameter=condition.get_parameter_definition("zero_percentage_ge"),
                value=100  # set default value to 100
            )]
        },
        watch_nodes=basic_info_nodes.copy(),
    )
    watch_points.append(tensor_all_zero_watchpoint)


def _recommend_tensor_overflow(basic_info_nodes, condition_mgr, watch_points, condition_context):
    """Recommend tensor general overflow watchpoint."""
    if not basic_info_nodes:
        return
    if not condition_mgr.has_condition(ConditionIdEnum.TENSOR_OVERFLOW.value, condition_context):
        return
    condition = condition_mgr.get_condition(condition_id=ConditionIdEnum.TENSOR_OVERFLOW.value)
    overflow_watchpoint = _WatchPointData(
        watch_condition={
            "condition": condition.id,
            "params": []
        },
        watch_nodes=basic_info_nodes.copy(),
    )
    watch_points.append(overflow_watchpoint)


def _recommend_overflow_ascend_chip(basic_info_nodes, condition_mgr, watch_points, condition_context):
    """Recommend tensor overflow watchpoint for the Ascend chip."""
    if not basic_info_nodes:
        return
    if not condition_mgr.has_condition(ConditionIdEnum.OVERFLOW_ASCEND_CHIP.value, condition_context):
        return
    condition = condition_mgr.get_condition(condition_id=ConditionIdEnum.OVERFLOW_ASCEND_CHIP.value)
    overflow_d_watchpoint = _WatchPointData(
        watch_condition={
            "condition": condition.id,
            "params": []
        },
        watch_nodes=basic_info_nodes.copy(),
    )
    watch_points.append(overflow_d_watchpoint)


def _recommend_gradient_vanishing(basic_info_nodes, condition_mgr, watch_points, condition_context):
    """Recommend gradient vanishing watchpoint."""
    if not basic_info_nodes:
        return
    if not condition_mgr.has_condition(ConditionIdEnum.GRADIENT_VANISHING.value, condition_context):
        return
    condition = condition_mgr.get_condition(condition_id=ConditionIdEnum.GRADIENT_VANISHING.value)
    gradient_vanishing_watchpoint = _WatchPointData(
        watch_condition={
            "condition": condition.id,
            "params": [_ConditionParameterValue(
                parameter=condition.get_parameter_definition("abs_mean_lt"),
                value=1e-9  # set default value to 1e-9
            )]
        },
        watch_nodes=basic_info_nodes.copy(),
    )
    watch_points.append(gradient_vanishing_watchpoint)


def _recommend_weight_change_too_small(condition_mgr, trainable_weight_nodes, watch_points, condition_context):
    """Recommend weight change too small watchpoint."""
    if not trainable_weight_nodes:
        return
    if not condition_mgr.has_condition(ConditionIdEnum.WEIGHT_CHANGE_TOO_SMALL.value, condition_context):
        return
    condition = condition_mgr.get_condition(condition_id=ConditionIdEnum.WEIGHT_CHANGE_TOO_SMALL.value)
    weight_change_too_small_watchpoint = _WatchPointData(
        watch_condition={
            "condition": condition.id,
            "params": [
                _ConditionParameterValue(
                    parameter=condition.get_parameter_definition("abs_update_ratio_mean_lt"),
                    value=1.0e-4  # set default value to 1.0e-4
                ),
            ]
        },
        watch_nodes=trainable_weight_nodes,
    )
    watch_points.append(weight_change_too_small_watchpoint)


def _recommend_weight_not_changed(condition_mgr, trainable_weight_nodes, watch_points, condition_context):
    """Recommend weight not changed watchpoint."""
    if not trainable_weight_nodes:
        return
    if not condition_mgr.has_condition(ConditionIdEnum.WEIGHT_NOT_CHANGED.value, condition_context):
        return
    condition = condition_mgr.get_condition(condition_id=ConditionIdEnum.WEIGHT_NOT_CHANGED.value)
    weight_no_change_watchpoint = _WatchPointData(
        watch_condition={
            "condition": condition.id,
            "params": [
                _ConditionParameterValue(
                    parameter=condition.get_parameter_definition("rtol"),
                    value=1.0e-5  # set default value to 1.0e-5
                ),
                _ConditionParameterValue(
                    parameter=condition.get_parameter_definition("atol"),
                    value=1.0e-8  # set default value to 1.0e-8
                ),
            ]
        },
        watch_nodes=trainable_weight_nodes,
    )
    watch_points.append(weight_no_change_watchpoint)


def _recommend_weight_change_too_large(basic_info_nodes, condition_mgr, watch_points, condition_context):
    """Recommend weight change too large watchpoint."""
    if not basic_info_nodes:
        return
    if not condition_mgr.has_condition(ConditionIdEnum.WEIGHT_CHANGE_TOO_LARGE.value, condition_context):
        return
    condition = condition_mgr.get_condition(condition_id=ConditionIdEnum.WEIGHT_CHANGE_TOO_LARGE.value)
    weight_change_too_large_watchpoint = _WatchPointData(
        watch_condition={
            "condition": condition.id,
            "params": [_ConditionParameterValue(
                parameter=condition.get_parameter_definition("abs_update_ratio_mean_gt"),
                value=0.1  # set default value to 0.1
            )]
        },
        watch_nodes=basic_info_nodes.copy(),
    )
    watch_points.append(weight_change_too_large_watchpoint)


def _recommend_weight_initialization(basic_info_nodes, condition_mgr, watch_points, condition_context):
    """Recommend weight initialization watchpoint."""
    if not basic_info_nodes:
        return
    if not condition_mgr.has_condition(ConditionIdEnum.WEIGHT_INITIALIZATION.value, condition_context):
        return
    condition = condition_mgr.get_condition(condition_id=ConditionIdEnum.WEIGHT_INITIALIZATION.value)
    weight_initialization_watchpoint = _WatchPointData(
        watch_condition={
            "condition": condition.id,
            "params": [_ConditionParameterValue(
                parameter=condition.get_parameter_definition("zero_percentage_ge"),
                value=100  # set default value to 100
            )]
        },
        watch_nodes=basic_info_nodes.copy(),
    )
    watch_points.append(weight_initialization_watchpoint)


def get_basic_node_info(node_category, graph_stream):
    """Get node merged info."""
    basic_info_nodes = _get_basic_node_info_by_node_category(node_category, graph_stream)
    merged_info = _merge_nodes(basic_info_nodes, graph_stream.whole_graph)
    merged_info = _add_graph_name(merged_info, graph_stream)
    return merged_info


def _get_basic_node_info_by_node_category(node_category, graph_stream):
    """Get node basic info by node category."""
    all_graph_nodes = graph_stream.get_searched_nodes(pattern={'node_category': node_category})
    basic_info_nodes = []
    for graph_name, nodes in all_graph_nodes.items():
        if len(all_graph_nodes) == 1:
            logger.debug("This is a single graph")
            graph_name = ""
        for node in nodes:
            if graph_name == "":
                basic_node_info = NodeBasicInfo(name=node.name, full_name=node.full_name, type=node.type)
            else:
                basic_node_info = graph_stream.construct_node_basic_info(
                    full_name=node.full_name, graph_name=graph_name, node_name=node.name, node_type=node.type)
            basic_info_nodes.append(basic_node_info)
    return basic_info_nodes


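# Bottom-up merge: a name_scope whose sub nodes are all selected becomes a single
# watch node, so a whole scope can be watched instead of listing each leaf node.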
def _merge_nodes(leaf_nodes, graph):
    """merge nodes in one graph"""
    unmerged_tree = graph.get_nodes(leaf_nodes)

    tmp_node_queue = Queue.Queue()
    # watch node list in layer order
    watch_nodes = []
    for node in unmerged_tree:
        if node["type"] != "name_scope":
            # if node is leaf_node, it is totally chosen
            node["status"] = SELECTED_STATUS
        else:
            # if node is not leaf_node, it is not chosen initially
            node["status"] = UNSELECTED_STATUS
        tmp_node_queue.put(node)
    while not tmp_node_queue.empty():
        cur_node = tmp_node_queue.get()
        watch_nodes.append(cur_node)
        for sub_node in cur_node["nodes"]:
            if sub_node["type"] != "name_scope":
                # if node is leaf_node, it is totally chosen
                sub_node["status"] = SELECTED_STATUS
            else:
                # if node is not leaf_node, it is not chosen initially
                sub_node["status"] = UNSELECTED_STATUS
            tmp_node_queue.put(sub_node)

    merged_watch_nodes = []
    while watch_nodes:
        cur_node = watch_nodes.pop()
        node_name = cur_node["name"]
        sub_count = graph.normal_node_map.get(node_name).subnode_count
        if len(cur_node["nodes"]) < sub_count or not cur_node["nodes"]:
            continue
        is_all_chosen = True
        for sub_node in cur_node["nodes"]:
            if sub_node["status"] != SELECTED_STATUS:
                is_all_chosen = False
                break
        if is_all_chosen:
            cur_node["status"] = SELECTED_STATUS
            merged_watch_nodes.append(cur_node)
        else:
            cur_node["status"] = HALF_SELECTED_STATUS
    logger.debug("merged_watch_nodes: %s", merged_watch_nodes)

    out_nodes = []
    for node_info in merged_watch_nodes:
        node_basic_info = NodeBasicInfo(name=node_info["name"], full_name=node_info["name"], type=node_info["type"])
        out_nodes.append(node_basic_info)
    logger.debug("out_nodes: %s", out_nodes)
    return out_nodes


def _add_graph_name(nodes, graph_stream):
    """add graph_name in node.name"""
    if len(graph_stream.graph) > 1:
        return nodes

    graph_name = graph_stream.graph_names[0]
    output_nodes = []
    for node in nodes:
        node_basic_info = graph_stream.construct_node_basic_info(
            full_name=node.name, graph_name=graph_name, node_name=node.name, node_type=node.type)
        output_nodes.append(node_basic_info)
    return output_nodes
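
For context, a minimal sketch of how the recommender's output might be consumed. It assumes this file is importable as mindinsight.debugger.conditionmgr.recommender and that the condition manager, graph handler, and condition context already exist in a debugger session (as described in the recommend_watchpoints docstring); the driver function below is hypothetical, not part of the module.

    from mindinsight.debugger.conditionmgr.recommender import recommend_watchpoints

    def create_recommended_watchpoints(condition_mgr, graph_stream, condition_context):
        """Hypothetical driver: all three arguments are assumed to come from an
        existing debugger session."""
        watch_points = recommend_watchpoints(condition_mgr, graph_stream, condition_context)
        for watch_point in watch_points:
            # Each _WatchPointData carries the condition dict and the merged watch
            # nodes needed to create one watchpoint.
            condition_dict = watch_point.get_watch_condition_dict()
            print(condition_dict["id"], len(watch_point.watch_nodes))
        return watch_points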