You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

recommender.py 20 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483
  1. # Copyright 2020-2021 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ============================================================================
  15. """
  16. Predefined watchpoints.
  17. This module predefine recommend watchpoints.
  18. """
  19. import math
  20. import queue as Queue
  21. from mindinsight.debugger.conditionmgr.conditionmgr import ConditionMgr
  22. from mindinsight.debugger.conditionmgr.condition import TargetTypeEnum
  23. from mindinsight.debugger.conditionmgr.condition import ConditionIdEnum
  24. from mindinsight.debugger.conditionmgr.condition import ActivationFuncEnum
  25. from mindinsight.debugger.conditionmgr.common.utils import NodeBasicInfo
  26. from mindinsight.debugger.conditionmgr.log import logger
  27. from mindinsight.conf import settings
  28. from mindinsight.debugger.stream_cache.watchpoint import WatchNodeTree
  29. class _WatchPointData:
  30. """
  31. WatchPoint data container
  32. Args:
  33. watch_condition (dict): The dict of watch conditions.
  34. watch_nodes (list[NodeBasicInfo]): The list of node basic info.
  35. name (str): The name of watchpoint.
  36. """
  37. def __init__(self, watch_condition, watch_nodes, name):
  38. self.watch_condition = watch_condition
  39. self.watch_nodes = watch_nodes
  40. self.name = name
  41. def get_watch_condition_dict(self):
  42. return {
  43. "id": self.watch_condition.get("condition"),
  44. "params": [{
  45. "name": param.get_parameter_name(),
  46. "value": param.value
  47. } for param in self.watch_condition.get("params")]
  48. }
  49. class _ConditionParameterValue:
  50. """Condition parameter data container"""
  51. def __init__(self, parameter, value):
  52. self.parameter = parameter
  53. self.value = value
  54. def get_parameter_name(self):
  55. return self.parameter.name
  56. def recommend_watchpoints(condition_mgr: ConditionMgr, multi_card_graph_stream, condition_context):
  57. """
  58. Recommend watchpoints.
  59. Args:
  60. condition_mgr (ConditionMgr): Condition manager instance.
  61. multi_card_graph_stream (GraphHandler): Multi card graph handler instance.
  62. condition_context (ConditionContext): Context for condition.
  63. Returns:
  64. list[WatchPointData], watch points to be created.
  65. """
  66. watch_points = []
  67. if not multi_card_graph_stream.has_graph:
  68. logger.warning("Given graph is None.")
  69. return watch_points
  70. if not settings.ENABLE_RECOMMENDED_WATCHPOINTS:
  71. return watch_points
  72. # add weight watch points
  73. merged_info = get_basic_node_info(TargetTypeEnum.WEIGHT.value, multi_card_graph_stream)
  74. _recommend_weight_initialization(merged_info, condition_mgr, watch_points, condition_context)
  75. _recommend_weight_change_too_large(merged_info, condition_mgr, watch_points, condition_context)
  76. # Because we cannot identify trainable weights currently, weight_no_change and weight_change_too_small will not be
  77. # recommended.
  78. trainable_weight_nodes = []
  79. _recommend_weight_not_changed(condition_mgr, trainable_weight_nodes, watch_points, condition_context)
  80. _recommend_weight_change_too_small(condition_mgr, trainable_weight_nodes, watch_points, condition_context)
  81. # add gradient watch points
  82. merged_info = get_basic_node_info(TargetTypeEnum.GRADIENT.value, multi_card_graph_stream)
  83. _recommend_gradient_vanishing(merged_info, condition_mgr, watch_points, condition_context)
  84. # add tensor watch points
  85. merged_info = get_basic_node_info(TargetTypeEnum.TENSOR.value, multi_card_graph_stream)
  86. _recommend_operator_overflow(merged_info, condition_mgr, watch_points, condition_context)
  87. _recommend_tensor_overflow(merged_info, condition_mgr, watch_points, condition_context)
  88. _recommend_tensor_all_zero(merged_info, condition_mgr, watch_points, condition_context)
  89. # add activation watch points
  90. merged_info = get_basic_node_info(TargetTypeEnum.ACTIVATION.value, multi_card_graph_stream,
  91. ActivationFuncEnum.TANH.value)
  92. _recommend_activation_range(merged_info, condition_mgr, watch_points, condition_context,
  93. ActivationFuncEnum.TANH.value)
  94. merged_info = get_basic_node_info(TargetTypeEnum.ACTIVATION.value, multi_card_graph_stream,
  95. ActivationFuncEnum.SIGMOID.value)
  96. _recommend_activation_range(merged_info, condition_mgr, watch_points, condition_context,
  97. ActivationFuncEnum.SIGMOID.value)
  98. merged_info = get_basic_node_info(TargetTypeEnum.ACTIVATION.value, multi_card_graph_stream,
  99. [ActivationFuncEnum.RELU.value, ActivationFuncEnum.RELUV2.value])
  100. _recommend_activation_range(merged_info, condition_mgr, watch_points, condition_context,
  101. ActivationFuncEnum.RELU.value)
  102. return watch_points
  103. def _recommend_tensor_all_zero(basic_info_nodes, condition_mgr, watch_points, condition_context):
  104. """Recommend tensor all zero watchpoint."""
  105. if not basic_info_nodes:
  106. return
  107. if not condition_mgr.has_condition(ConditionIdEnum.TENSOR_ALL_ZERO.value, condition_context):
  108. return
  109. condition = condition_mgr.get_condition(condition_id=ConditionIdEnum.TENSOR_ALL_ZERO.value)
  110. tensor_all_zero_watchpoint = _WatchPointData(
  111. watch_condition={
  112. "condition": condition.id,
  113. "params": [_ConditionParameterValue(
  114. parameter=condition.get_parameter_definition("zero_percentage_ge"),
  115. value=100 # set default value to 100
  116. )]
  117. },
  118. watch_nodes=basic_info_nodes.copy(),
  119. name='recommend_tensor_all_zero_watchpoint'
  120. )
  121. watch_points.append(tensor_all_zero_watchpoint)
  122. def _recommend_tensor_overflow(basic_info_nodes, condition_mgr, watch_points, condition_context):
  123. """Recommend tensor general overflow watchpoint."""
  124. if not basic_info_nodes:
  125. return
  126. if not condition_mgr.has_condition(ConditionIdEnum.TENSOR_OVERFLOW.value, condition_context):
  127. return
  128. condition = condition_mgr.get_condition(condition_id=ConditionIdEnum.TENSOR_OVERFLOW.value)
  129. overflow_watchpoint = _WatchPointData(
  130. watch_condition={
  131. "condition": condition.id,
  132. "params": []
  133. },
  134. watch_nodes=basic_info_nodes.copy(),
  135. name='recommend_tensor_overflow_watchpoint'
  136. )
  137. watch_points.append(overflow_watchpoint)
  138. def _recommend_operator_overflow(basic_info_nodes, condition_mgr, watch_points, condition_context):
  139. """Recommend tensor overflow watchpoint."""
  140. if not basic_info_nodes:
  141. return
  142. if not condition_mgr.has_condition(ConditionIdEnum.OPERATOR_OVERFLOW.value, condition_context):
  143. return
  144. condition = condition_mgr.get_condition(condition_id=ConditionIdEnum.OPERATOR_OVERFLOW.value)
  145. overflow_d_watchpoint = _WatchPointData(
  146. watch_condition={
  147. "condition": condition.id,
  148. "params": []
  149. },
  150. watch_nodes=basic_info_nodes.copy(),
  151. name='recommend_operator_overflow_watchpoint'
  152. )
  153. watch_points.append(overflow_d_watchpoint)
  154. def _recommend_gradient_vanishing(basic_info_nodes, condition_mgr, watch_points, condition_context):
  155. """Recommend gradient vanishing watchpoint."""
  156. if not basic_info_nodes:
  157. return
  158. if not condition_mgr.has_condition(ConditionIdEnum.GRADIENT_VANISHING.value, condition_context):
  159. return
  160. condition = condition_mgr.get_condition(condition_id=ConditionIdEnum.GRADIENT_VANISHING.value)
  161. gradient_vanishing_watchpoint = _WatchPointData(
  162. watch_condition={
  163. "condition": condition.id,
  164. "params": [_ConditionParameterValue(
  165. parameter=condition.get_parameter_definition("abs_mean_lt"),
  166. value=1e-9 # set default value to 1e-9
  167. )]
  168. },
  169. watch_nodes=basic_info_nodes.copy(),
  170. name='recommend_gradient_vanishing_watchpoint'
  171. )
  172. watch_points.append(gradient_vanishing_watchpoint)
  173. def _recommend_weight_change_too_small(condition_mgr, trainable_weight_nodes, watch_points, condition_context):
  174. """Recommend weight change too small watchpoint."""
  175. if not trainable_weight_nodes:
  176. return
  177. if not condition_mgr.has_condition(ConditionIdEnum.WEIGHT_CHANGE_TOO_SMALL.value, condition_context):
  178. return
  179. condition = condition_mgr.get_condition(condition_id=ConditionIdEnum.WEIGHT_CHANGE_TOO_SMALL.value)
  180. weight_change_too_small_watchpoint = _WatchPointData(
  181. watch_condition={
  182. "condition": condition.id,
  183. "params": [
  184. _ConditionParameterValue(
  185. parameter=condition.get_parameter_definition("abs_mean_update_ratio_lt"),
  186. value=1.0e-4 # set default value to 1.0e-4
  187. ),
  188. ]
  189. },
  190. watch_nodes=trainable_weight_nodes,
  191. name='recommend_weight_change_too_small_watchpoint'
  192. )
  193. watch_points.append(weight_change_too_small_watchpoint)
  194. def _recommend_weight_not_changed(condition_mgr, trainable_weight_nodes, watch_points, condition_context):
  195. """Recommend weight not changed watchpoint."""
  196. if not trainable_weight_nodes:
  197. return
  198. if not condition_mgr.has_condition(ConditionIdEnum.WEIGHT_NOT_CHANGED.value, condition_context):
  199. return
  200. condition = condition_mgr.get_condition(condition_id=ConditionIdEnum.WEIGHT_NOT_CHANGED.value)
  201. weight_no_change_watchpoint = _WatchPointData(
  202. watch_condition={
  203. "condition": condition.id,
  204. "params": [
  205. _ConditionParameterValue(
  206. parameter=condition.get_parameter_definition("rtol"),
  207. value=1.0e-5 # set default value to 1.0e-5
  208. ),
  209. _ConditionParameterValue(
  210. parameter=condition.get_parameter_definition("atol"),
  211. value=1.0e-8 # set default value to 1.0e-8
  212. ),
  213. ]
  214. },
  215. watch_nodes=trainable_weight_nodes,
  216. name='recommend_weight_not_changed_watchpoint'
  217. )
  218. watch_points.append(weight_no_change_watchpoint)
  219. def _recommend_weight_change_too_large(basic_info_nodes, condition_mgr, watch_points, condition_context):
  220. """Recommend weight change too large watchpoint."""
  221. if not basic_info_nodes:
  222. return
  223. if not condition_mgr.has_condition(ConditionIdEnum.WEIGHT_CHANGE_TOO_LARGE.value, condition_context):
  224. return
  225. condition = condition_mgr.get_condition(condition_id=ConditionIdEnum.WEIGHT_CHANGE_TOO_LARGE.value)
  226. weight_initialization_watchpoint = _WatchPointData(
  227. watch_condition={
  228. "condition": condition.id,
  229. "params": [_ConditionParameterValue(
  230. parameter=condition.get_parameter_definition("abs_mean_update_ratio_gt"),
  231. value=1 # set default value to 1
  232. )]
  233. },
  234. watch_nodes=basic_info_nodes.copy(),
  235. name='recommend_weight_change_too_large_watchpoint'
  236. )
  237. watch_points.append(weight_initialization_watchpoint)
  238. def _recommend_weight_initialization(basic_info_nodes, condition_mgr, watch_points, condition_context):
  239. """Recommend weight initialization watchpoint."""
  240. if not basic_info_nodes:
  241. return
  242. if not condition_mgr.has_condition(ConditionIdEnum.WEIGHT_INITIALIZATION.value, condition_context):
  243. return
  244. condition = condition_mgr.get_condition(condition_id=ConditionIdEnum.WEIGHT_INITIALIZATION.value)
  245. weight_initialization_watchpoint = _WatchPointData(
  246. watch_condition={
  247. "condition": condition.id,
  248. "params": [_ConditionParameterValue(
  249. parameter=condition.get_parameter_definition("zero_percentage_ge"),
  250. value=100 # set default value to 100
  251. )]
  252. },
  253. watch_nodes=basic_info_nodes.copy(),
  254. name='recommend_weight_initialization_watchpoint'
  255. )
  256. watch_points.append(weight_initialization_watchpoint)
  257. def _recommend_activation_range(basic_info_nodes, condition_mgr, watch_points, condition_context, activation_func):
  258. """Recommend activation range watchpoint."""
  259. if not basic_info_nodes:
  260. return
  261. if not condition_mgr.has_condition(ConditionIdEnum.ACTIVATION_RANGE.value, condition_context):
  262. return
  263. condition = condition_mgr.get_condition(condition_id=ConditionIdEnum.ACTIVATION_RANGE.value)
  264. params = _get_recommend_activation_params(condition, activation_func)
  265. activation_range_watchpoint = _WatchPointData(
  266. watch_condition={
  267. "condition": condition.id,
  268. "params": params
  269. },
  270. watch_nodes=basic_info_nodes.copy(),
  271. name='recommend_{}_activation_range_watchpoint'.format(activation_func.lower())
  272. )
  273. watch_points.append(activation_range_watchpoint)
  274. def get_basic_node_info(node_category, multi_card_graph_stream, activation_func=None):
  275. """Get node merged info."""
  276. nodes_for_devices = {}
  277. has_node = False
  278. for rank_id, graph_stream in multi_card_graph_stream.graph_handlers.items():
  279. basic_info_nodes = _get_basic_node_info_by_node_category(node_category, graph_stream, activation_func)
  280. merged_info = _merge_nodes(basic_info_nodes, graph_stream.whole_graph)
  281. merged_info = _add_graph_name(merged_info, graph_stream)
  282. nodes_for_devices[rank_id] = merged_info
  283. has_node = has_node or merged_info
  284. if has_node:
  285. return nodes_for_devices
  286. return {}
  287. def _get_basic_node_info_by_node_category(node_category, graph_stream, activation_func=None):
  288. """Get node basic info by node category."""
  289. pattern = {'node_category': node_category}
  290. if activation_func:
  291. pattern['condition'] = {'activation_func': activation_func}
  292. all_graph_nodes = graph_stream.search_in_graph(pattern)
  293. return all_graph_nodes
  294. def _convert_tree_to_node_list(node_tree, node_list):
  295. """Convert WatchNodeTree to Node list."""
  296. if node_tree.watch_status in [WatchNodeTree.NOT_WATCH, WatchNodeTree.INVALID]:
  297. logger.debug("The watch_status of node: %s is not_watch or invalid.", node_tree.node_name)
  298. return
  299. if node_tree.watch_status == WatchNodeTree.TOTAL_WATCH:
  300. node_basic_info = NodeBasicInfo(name=node_tree.node_name, full_name=node_tree.full_name,
  301. type=node_tree.node_type)
  302. node_list.append(node_basic_info)
  303. return
  304. if node_tree.watch_status == WatchNodeTree.PARTIAL_WATCH:
  305. for _, sub_tree in node_tree.get_children():
  306. _convert_tree_to_node_list(sub_tree, node_list)
  307. def _update_watch_status(node_tree, graph):
  308. """Update the watch_status, if all sub_nodes of a WatchNodeTree are total_watch,
  309. then the WatchNodeTree is changed to total_watch status."""
  310. tmp_node_queue = Queue.Queue()
  311. tmp_node_queue.put(node_tree)
  312. # watch node list in layer order
  313. watch_tree_list = []
  314. while not tmp_node_queue.empty():
  315. cur_tree = tmp_node_queue.get()
  316. watch_tree_list.append(cur_tree)
  317. for _, sub_tree in cur_tree.get_children():
  318. tmp_node_queue.put(sub_tree)
  319. # update the watch_status from bottom to top
  320. while watch_tree_list:
  321. cur_tree = watch_tree_list.pop()
  322. node_name = cur_tree.node_name
  323. logger.debug("Update status of node: %s.", node_name)
  324. # if node_name is "", it is the root node, which is not in normal_node_map
  325. if not node_name:
  326. continue
  327. sub_count = graph.normal_node_map.get(node_name).subnode_count
  328. # if the children_count of WatchNodeTree is less than the responding subnode_count in the graph,
  329. # its watch_status must be partial_watch
  330. if cur_tree.get_children_count() < sub_count:
  331. continue
  332. is_all_chosen = True
  333. for _, sub_tree in cur_tree.get_children():
  334. if sub_tree.watch_status != WatchNodeTree.TOTAL_WATCH:
  335. is_all_chosen = False
  336. break
  337. if is_all_chosen:
  338. cur_tree.watch_status = WatchNodeTree.TOTAL_WATCH
  339. def _merge_nodes(leaf_nodes, graph):
  340. """Merge nodes in one graph."""
  341. watch_node_tree = WatchNodeTree()
  342. for node in leaf_nodes:
  343. watch_node_tree.add_node(node.name, node.type, node.full_name)
  344. _update_watch_status(watch_node_tree, graph)
  345. out_nodes = []
  346. _convert_tree_to_node_list(watch_node_tree, out_nodes)
  347. logger.debug("out_nodes: %s", out_nodes)
  348. return out_nodes
  349. def _add_graph_name(nodes, graph_stream):
  350. """Add graph_name in node.name."""
  351. if len(graph_stream.graph) > 1:
  352. return nodes
  353. graph_name = graph_stream.graph_names[0]
  354. output_nodes = []
  355. for node in nodes:
  356. node_basic_info = graph_stream.construct_node_basic_info(
  357. full_name=node.full_name, graph_name=graph_name, node_name=node.name, node_type=node.type)
  358. output_nodes.append(node_basic_info)
  359. return output_nodes
  360. def _sigmoid(value):
  361. """Calculate the sigmoid of value."""
  362. return 1.0 / (1.0 + math.exp(-value))
  363. def _get_recommend_activation_params(condition, activation_func):
  364. """Get recommend params for tanh, sigmoid and relu activation function."""
  365. params = []
  366. if activation_func == ActivationFuncEnum.TANH.value:
  367. # The recommend params for Tanh: The percentage of value in range (tanh(-8.8), tanh(8.8)) is lower than 0.1%
  368. params = [
  369. _ConditionParameterValue(
  370. parameter=condition.get_parameter_definition("range_percentage_lt"),
  371. value=0.1
  372. ),
  373. _ConditionParameterValue(
  374. parameter=condition.get_parameter_definition("range_start_inclusive"),
  375. value=math.tanh(-8.8)
  376. ),
  377. _ConditionParameterValue(
  378. parameter=condition.get_parameter_definition("range_end_inclusive"),
  379. value=math.tanh(8.8)
  380. )]
  381. if activation_func == ActivationFuncEnum.SIGMOID.value:
  382. # The recommend params for Sigmoid:
  383. # The percentage of value in range (sigmoid(-16.2)), sigmoid(16.2)) is lower than 0.1%
  384. params = [
  385. _ConditionParameterValue(
  386. parameter=condition.get_parameter_definition("range_percentage_lt"),
  387. value=0.1
  388. ),
  389. _ConditionParameterValue(
  390. parameter=condition.get_parameter_definition("range_start_inclusive"),
  391. value=_sigmoid(-16.2)
  392. ),
  393. _ConditionParameterValue(
  394. parameter=condition.get_parameter_definition("range_end_inclusive"),
  395. value=_sigmoid(16.2)
  396. )]
  397. if activation_func == ActivationFuncEnum.RELU.value:
  398. # The recommend params for ReLU:
  399. # The percentage of value in range (-1, 0) is greater than 99.9%
  400. params = [
  401. _ConditionParameterValue(
  402. parameter=condition.get_parameter_definition("range_percentage_gt"),
  403. value=99.9
  404. ),
  405. _ConditionParameterValue(
  406. parameter=condition.get_parameter_definition("range_start_inclusive"),
  407. value=-1
  408. ),
  409. _ConditionParameterValue(
  410. parameter=condition.get_parameter_definition("range_end_inclusive"),
  411. value=0
  412. )]
  413. return params