You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-') and can be up to 35 characters long.

watchpoint.py 14 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360
  1. # Copyright 2020 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ============================================================================
  15. """Define the watchpoint stream."""
  16. import copy
  17. from mindinsight.debugger.common.exceptions.exceptions import DebuggerParamValueError
  18. from mindinsight.debugger.common.log import LOGGER as log
  19. from mindinsight.debugger.common.utils import is_scope_type, is_cst_type
  20. from mindinsight.debugger.conditionmgr.conditionmgr import ConditionMgr
  21. from mindinsight.debugger.conditionmgr.common.utils import NodeBasicInfo
  22. from mindinsight.debugger.conditionmgr.condition import ConditionIdEnum
  23. from mindinsight.debugger.proto.debug_grpc_pb2 import SetCMD, WatchCondition
  # Lookup table from a ConditionIdEnum value to the WatchCondition.Condition
  # enum member placed into SetCMD.watch_condition.condition.
  # Several condition ids deliberately share one proto condition
  # (e.g. the gradient/weight/tensor "too large" ids all map to tensor_too_large).
  WATCHPOINT_CONDITION_MAPPING = {
      ConditionIdEnum.ACTIVATION_RANGE.value: WatchCondition.Condition.tensor_range,
      ConditionIdEnum.GRADIENT_EXPLODING.value: WatchCondition.Condition.tensor_general_overflow,
      ConditionIdEnum.GRADIENT_TOO_LARGE.value: WatchCondition.Condition.tensor_too_large,
      ConditionIdEnum.GRADIENT_VANISHING.value: WatchCondition.Condition.tensor_too_small,
      ConditionIdEnum.INF.value: WatchCondition.Condition.inf,
      ConditionIdEnum.MAX_GT.value: WatchCondition.Condition.max_gt,
      ConditionIdEnum.MAX_LT.value: WatchCondition.Condition.max_lt,
      ConditionIdEnum.MAX_MIN_GT.value: WatchCondition.Condition.max_min_gt,
      ConditionIdEnum.MAX_MIN_LT.value: WatchCondition.Condition.max_min_lt,
      ConditionIdEnum.MEAN_GT.value: WatchCondition.Condition.mean_gt,
      ConditionIdEnum.MEAN_LT.value: WatchCondition.Condition.mean_lt,
      ConditionIdEnum.MIN_GT.value: WatchCondition.Condition.min_gt,
      ConditionIdEnum.MIN_LT.value: WatchCondition.Condition.min_lt,
      ConditionIdEnum.NAN.value: WatchCondition.Condition.nan,
      ConditionIdEnum.OPERATOR_OVERFLOW.value: WatchCondition.Condition.overflow,
      ConditionIdEnum.OVERFLOW_ASCEND_CHIP.value: WatchCondition.Condition.overflow,
      ConditionIdEnum.TENSOR_ALL_ZERO.value: WatchCondition.Condition.tensor_all_zero,
      ConditionIdEnum.TENSOR_INITIALIZATION.value: WatchCondition.Condition.tensor_initialization,
      ConditionIdEnum.TENSOR_OVERFLOW.value: WatchCondition.Condition.tensor_general_overflow,
      ConditionIdEnum.TENSOR_RANGE.value: WatchCondition.Condition.tensor_range,
      ConditionIdEnum.TENSOR_TOO_LARGE.value: WatchCondition.Condition.tensor_too_large,
      ConditionIdEnum.TENSOR_TOO_SMALL.value: WatchCondition.Condition.tensor_too_small,
      ConditionIdEnum.WEIGHT_CHANGE_TOO_LARGE.value: WatchCondition.Condition.tensor_change_too_large,
      ConditionIdEnum.WEIGHT_CHANGE_TOO_SMALL.value: WatchCondition.Condition.tensor_change_too_small,
      ConditionIdEnum.WEIGHT_INITIALIZATION.value: WatchCondition.Condition.tensor_initialization,
      ConditionIdEnum.WEIGHT_NOT_CHANGED.value: WatchCondition.Condition.tensor_not_changed,
      ConditionIdEnum.WEIGHT_OVERFLOW.value: WatchCondition.Condition.tensor_general_overflow,
      ConditionIdEnum.WEIGHT_TOO_LARGE.value: WatchCondition.Condition.tensor_too_large,
      ConditionIdEnum.WEIGHT_TOO_SMALL.value: WatchCondition.Condition.tensor_too_small,
  }
  class WatchNodeTree:
      """
      The WatchNode Node Structure.

      A tree whose levels correspond to the '/'-separated segments of a node
      name. Each tree node keeps its own name (the joined path from the root),
      a full name, a translated node type, a watch status and its children
      keyed by their last name segment.

      Args:
          node_name (str): The name of this tree node. Default: ''.
          node_type (str): The raw node type; see `_translate_node_type`. Default: None.
          full_name (str): The full name of the node. Default: ''.
          watch_status (int): One of the status constants below. Default: PARTIAL_WATCH.
      """
      INVALID = -1  # the scope node and the nodes below are invalid
      NOT_WATCH = 0  # the scope node and the nodes below are not watched
      PARTIAL_WATCH = 1  # at least one node under the scope node is not watched
      TOTAL_WATCH = 2  # the scope node and the nodes below are all watched

      def __init__(self, node_name='', node_type=None, full_name='', watch_status=PARTIAL_WATCH):
          self._node_name = node_name
          self._full_name = full_name
          self._node_type = self._translate_node_type(node_type)
          self._watch_status = watch_status
          # Children keyed by their own (last) name segment.
          self._children = {}

      @property
      def node_name(self):
          """The property of node name."""
          return self._node_name

      @property
      def full_name(self):
          """The property of full name."""
          return self._full_name

      @property
      def node_type(self):
          """The property of node type."""
          return self._node_type

      @node_type.setter
      def node_type(self, value):
          """Set the node type."""
          self._node_type = self._translate_node_type(value)

      @property
      def watch_status(self):
          """The property of watch status about current node."""
          return self._watch_status

      def update_metadata(self, node_type, full_name, watch_status):
          """
          Update the metadata for watched node.

          Args:
              node_type (str): The raw node type, translated before being stored.
              full_name (str): The new full name.
              watch_status (int): The new watch status.
          """
          self._full_name = full_name
          self._node_type = self._translate_node_type(node_type)
          self._watch_status = watch_status

      @staticmethod
      def _translate_node_type(node_type):
          """
          Translate node type to watch node type.

          Returns:
              str, 'scope' when `node_type` is empty or `is_scope_type` accepts it,
              otherwise `node_type` unchanged.
          """
          if not node_type or is_scope_type(node_type):
              return 'scope'
          return node_type

      def get(self, sub_name):
          """Get the direct child stored under `sub_name`, or None if absent."""
          return self._children.get(sub_name)

      def get_children(self):
          """Yield `(name, sub_watch_node)` pairs for all direct children."""
          yield from self._children.items()

      def add_node(self, node_name, node_type, full_name=''):
          """
          Add watch node to watch node tree.

          Intermediate scopes that do not exist yet are created with
          PARTIAL_WATCH status; the final node is created, or updated if it
          already exists, with TOTAL_WATCH status.

          Args:
              node_name (str): The node name, relative to this tree node.
              node_type (str): The node type.
              full_name (str): The full name of node.
          """
          log.debug("Add node %s with type: %s, full_name: %s", node_name, node_type, full_name)
          scope_names = node_name.split('/', 1)
          if len(scope_names) == 1:
              # Last name segment: this is the watched node itself.
              target_node = self.get(node_name)
              if target_node is None:
                  self.add(node_name, node_type, full_name, watch_status=WatchNodeTree.TOTAL_WATCH)
              else:
                  target_node.update_metadata(node_type, full_name, WatchNodeTree.TOTAL_WATCH)
              return

          # Descend into (or create) the scope named by the first segment.
          scope_name, sub_names = scope_names
          sub_tree = self.get(scope_name)
          if sub_tree is None:
              sub_tree = self.add(scope_name, watch_status=WatchNodeTree.PARTIAL_WATCH)
          sub_tree.add_node(sub_names, node_type, full_name)

      def add(self, name, node_type=None, full_name='', watch_status=PARTIAL_WATCH):
          """
          Add sub WatchPointTree.

          Args:
              name (str): The last name segment of the new child; an existing
                  child with the same name is replaced.
              node_type (str): The node type of the child. Default: None.
              full_name (str): The full name of the child. Default: ''.
              watch_status (int): The watch status of the child. Default: PARTIAL_WATCH.

          Returns:
              WatchNodeTree, the newly created child.
          """
          # The child's node name is the path from the root; the root's name is ''.
          sub_name = '/'.join([self._node_name, name]) if self._node_name else name
          sub_tree = WatchNodeTree(sub_name, node_type, full_name, watch_status)
          self._children[name] = sub_tree
          return sub_tree

      def remove_node(self, node_name):
          """
          Remove sub node from current tree.

          Args:
              node_name (str): The node name, relative to this tree node.

          Raises:
              DebuggerParamValueError: If the first name segment is not a child
                  of this tree node.
          """
          log.debug("Remove %s", node_name)
          scope_names = node_name.split('/', 1)
          sub_tree_name = scope_names[0]
          sub_tree = self._children.get(sub_tree_name)
          if not sub_tree:
              log.error("Failed to find node %s in WatchNodeTree.", sub_tree_name)
              raise DebuggerParamValueError("Failed to find node {}".format(sub_tree_name))
          if len(scope_names) > 1:
              sub_tree.remove_node(scope_names[1])
          # Drop the child when it is the removal target itself or when the
          # recursive removal left it with nothing watched.
          if sub_tree.watch_status == WatchNodeTree.NOT_WATCH or len(scope_names) == 1:
              self._children.pop(sub_tree_name)
          if self._children:
              self._watch_status = WatchNodeTree.PARTIAL_WATCH
          else:
              self._watch_status = WatchNodeTree.NOT_WATCH
  class Watchpoint:
      """
      The class of watchpoint stream.

      Args:
          watchpoint_id (int): The id of Watchpoint.
          watch_condition (dict): The condition of Watchpoint.

              - id: The condition id, used as key into `WATCHPOINT_CONDITION_MAPPING`
                and passed to `ConditionMgr.get_condition`.
              - params (list[dict]): The condition parameters; each entry is read
                through its `name` and `value` keys. May be missing or empty.
          name (str): The watchpoint name; reported by `get_watch_condition_info`
              only when it is truthy. Default: None.
      """

      def __init__(self, watchpoint_id, watch_condition, name=None):
          self._id = watchpoint_id
          self._condition = watch_condition
          self._watch_node = WatchNodeTree()
          self.name = name

      @property
      def watchpoint_id(self):
          """The property of watchpoint id."""
          return self._id

      @property
      def nodes(self):
          """The property of watch nodes."""
          return self._watch_node

      @property
      def condition(self):
          """The property of watch condition."""
          return self._condition

      def copy_nodes_from(self, other_watchpoint, deep_copy=False):
          """
          Copy nodes from other watchpoint.

          Args:
              other_watchpoint (Watchpoint): Other watchpoint.
              deep_copy (bool): Whether using deepcopy. When False, both
                  watchpoints share the same node tree object.
          """
          if deep_copy:
              self._watch_node = copy.deepcopy(other_watchpoint.nodes)
          else:
              self._watch_node = other_watchpoint.nodes

      def add_nodes(self, nodes):
          """
          Add node into watchpoint.

          Args:
              nodes (Union[list, object]): A node or a list of nodes; each node
                  is read through its `name`, `type` and `full_name` attributes.
          """
          if not nodes:
              log.warning("Add empty nodes.")
              return
          if not isinstance(nodes, list):
              nodes = [nodes]
          for node in nodes:
              self._watch_node.add_node(node.name, node.type, node.full_name)

      def remove_nodes(self, nodes):
          """
          Remove nodes from watchpoint.

          Args:
              nodes (Union[list, object]): A node or a list of nodes; each node
                  is read through its `name` attribute.
          """
          if not nodes:
              return
          if not isinstance(nodes, list):
              nodes = [nodes]
          for node in nodes:
              self._watch_node.remove_node(node.name)

      def get_node_status(self, node_name, node_type, full_name):
          """
          Judge if the node is in watch nodes.

          Args:
              node_name (str): The '/'-separated node name.
              node_type (str): The node type.
              full_name (str): The full name of the node.

          Returns:
              int, `WatchNodeTree.INVALID` when `is_cst_type` accepts the type,
              `NOT_WATCH` when a name segment is missing from the tree,
              `TOTAL_WATCH` when the node or one of its ancestors is totally
              watched, otherwise `PARTIAL_WATCH`.
          """
          if is_cst_type(node_type):
              return WatchNodeTree.INVALID
          scope_names = node_name.split('/')
          cur_node = self._watch_node
          status = WatchNodeTree.PARTIAL_WATCH
          for scope_name in scope_names:
              cur_node = cur_node.get(scope_name)
              if cur_node is None:
                  status = WatchNodeTree.NOT_WATCH
                  break
              if cur_node.watch_status == WatchNodeTree.TOTAL_WATCH:
                  status = WatchNodeTree.TOTAL_WATCH
                  break
          # The match stopped at a totally watched ancestor: record the node
          # itself in the tree as well.
          if status == WatchNodeTree.TOTAL_WATCH and cur_node.node_name != node_name:
              self._watch_node.add_node(node_name, node_type, full_name)
          return status

      def _get_watch_node(self, cur_watch_node, watch_node_list):
          """
          Traverse the watch nodes and add total watched node list to `watch_node_list`.

          Args:
              cur_watch_node (WatchNodeTree): The current watch node.
              watch_node_list (list[NodeBasicInfo]): The list of watch node basic infos.
          """
          if cur_watch_node.watch_status == WatchNodeTree.TOTAL_WATCH:
              # A totally watched node stands for its whole subtree, so stop here.
              node_info = NodeBasicInfo(name=cur_watch_node.node_name,
                                        full_name=cur_watch_node.full_name,
                                        type=cur_watch_node.node_type)
              watch_node_list.append(node_info)
              return
          for _, watch_node in cur_watch_node.get_children():
              self._get_watch_node(watch_node, watch_node_list)

      def get_watch_nodes(self):
          """
          Get the name of all total watched nodes.

          Returns:
              list[NodeBasicInfo], the list of watch node basic infos.
          """
          watch_nodes = []
          self._get_watch_node(self._watch_node, watch_nodes)
          return watch_nodes

      def get_pending_cmd(self, watch_nodes):
          """
          Return the watchpoint in proto format.

          Args:
              watch_nodes (list): The nodes to watch; each node is read through
                  its `full_name` and `type` attributes.

          Returns:
              SetCMD, the command describing this watchpoint.
          """
          # construct SetCMD
          condition_id = self._condition.get('id')
          set_cmd = SetCMD()
          set_cmd.id = self._id
          set_cmd.delete = False
          set_cmd.watch_condition.condition = WATCHPOINT_CONDITION_MAPPING.get(condition_id)
          condition_mgr = ConditionMgr()
          condition = condition_mgr.get_condition(condition_id)
          # A condition without parameters may carry no 'params' entry (or None);
          # treat that the same as an empty parameter list.
          param_dict = {
              param.get('name'): param for param in (self._condition.get('params') or [])
          }
          for param_name in condition.ordered_parameter_names:
              param = param_dict.get(param_name)
              param_proto = set_cmd.watch_condition.params.add()
              if param:
                  param_proto.name = param.get('name')
                  param_proto.value = param.get('value')
                  param_proto.disabled = False
                  # Only one parameter of condition in old mindspore version.
                  set_cmd.watch_condition.value = param.get('value')
              else:
                  # Parameter not supplied by the user: send it as disabled.
                  param_proto.name = param_name
                  param_proto.disabled = True
          for watch_node in watch_nodes:
              event_node = set_cmd.watch_nodes.add()
              event_node.node_name = watch_node.full_name
              event_node.node_type = watch_node.type
          return set_cmd

      def get_watch_condition_info(self):
          """
          Get watch condition info.

          Returns:
              dict, with keys `id` and `watch_condition`, plus `name` when the
              watchpoint has a truthy name.
          """
          watchpoint_info = {
              'id': self._id,
              'watch_condition': self._condition
          }
          if self.name:
              watchpoint_info['name'] = self.name
          return watchpoint_info
  class WatchpointHit:
      """
      The watchpoint hit structure.

      Args:
          tensor_proto (object): The hit tensor; its `node_name` and `slot`
              attributes are read.
          watchpoint (Watchpoint): The watchpoint that was hit; its
              `get_watch_condition_info` method is used.
          node_name (str): The node name of the hit.
          graph_name (str): The graph name of the hit.
      """

      # NOTE: `__eq__` is defined without `__hash__`, so instances are unhashable.

      def __init__(self, tensor_proto, watchpoint, node_name, graph_name):
          self._full_name = tensor_proto.node_name
          self._watchpoint = watchpoint
          self.node_name = node_name
          self.slot = tensor_proto.slot
          self.graph_name = graph_name
          self.error_code = 0

      @property
      def tensor_full_name(self):
          """The property of tensor full name, formatted as `<full_name>:<slot>`."""
          # str.join requires both parts to be strings.
          return ':'.join([self._full_name, self.slot])

      @property
      def watchpoint(self):
          """The property of watchpoint, i.e. its watch condition info."""
          return self._watchpoint.get_watch_condition_info()

      def __eq__(self, other):
          """
          Define the equal condition.

          Two hits are equal when their tensor full name, watchpoint info and
          graph name are all equal. Comparing with an object that is not a
          WatchpointHit returns NotImplemented, so `hit == None` evaluates to
          False instead of raising AttributeError.
          """
          if not isinstance(other, WatchpointHit):
              return NotImplemented
          return (self.tensor_full_name == other.tensor_full_name
                  and self.watchpoint == other.watchpoint
                  and self.graph_name == other.graph_name)