You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

watchpoint.py 11 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307
  1. # Copyright 2020 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ============================================================================
  15. """Define the watchpoint stream."""
  16. from mindinsight.datavisual.data_transform.graph.node import NodeTypeEnum
  17. from mindinsight.debugger.common.exceptions.exceptions import DebuggerParamValueError
  18. from mindinsight.debugger.common.log import logger as log
  19. from mindinsight.debugger.proto.debug_grpc_pb2 import SetCMD, WatchCondition
  20. WATCHPOINT_CONDITION_MAPPING = {
  21. 'INF': WatchCondition.Condition.inf,
  22. 'NAN': WatchCondition.Condition.nan,
  23. 'OVERFLOW': WatchCondition.Condition.overflow,
  24. 'MAX_GT': WatchCondition.Condition.max_gt,
  25. 'MAX_LT': WatchCondition.Condition.max_lt,
  26. 'MIN_GT': WatchCondition.Condition.min_gt,
  27. 'MIN_LT': WatchCondition.Condition.min_lt,
  28. 'MAX_MIN_GT': WatchCondition.Condition.max_min_gt,
  29. 'MAX_MIN_LT': WatchCondition.Condition.max_min_lt,
  30. 'MEAN_GT': WatchCondition.Condition.mean_gt,
  31. 'MEAN_LT': WatchCondition.Condition.mean_lt
  32. }
  33. class WatchNodeTree:
  34. """The WatchNode Node Structure."""
  35. NOT_WATCH = 0 # the scope node and the nodes below are not watched
  36. PARTIAL_WATCH = 1 # at least one node under the scope node is not watched
  37. TOTAL_WATCH = 2 # the scope node and the nodes below are all watched
  38. def __init__(self, node_name='', node_type=None, full_name='', watch_status=1):
  39. self._node_name = node_name
  40. self._full_name = full_name
  41. self._node_type = self._translate_node_type(node_type)
  42. self._watch_status = watch_status
  43. self._children = {}
  44. @property
  45. def node_name(self):
  46. """The property of node name."""
  47. return self._node_name
  48. @property
  49. def full_name(self):
  50. """The property of node name."""
  51. return self._full_name
  52. @property
  53. def node_type(self):
  54. """The property of node type."""
  55. return self._node_type
  56. @node_type.setter
  57. def node_type(self, value):
  58. """Set the node type."""
  59. self._node_type = self._translate_node_type(value)
  60. @property
  61. def watch_status(self):
  62. """The property of watch status about current node."""
  63. return self._watch_status
  64. def update_metadata(self, node_type, full_name, watch_status):
  65. """Update the metadata for watched node."""
  66. self._full_name = full_name
  67. self._node_type = self._translate_node_type(node_type)
  68. self._watch_status = watch_status
  69. @staticmethod
  70. def _translate_node_type(node_type):
  71. """Translate node type to watch node type."""
  72. flag = node_type
  73. if not node_type or node_type == NodeTypeEnum.NAME_SCOPE.value:
  74. flag = 'scope'
  75. elif node_type != NodeTypeEnum.AGGREGATION_SCOPE.value:
  76. flag = 'leaf'
  77. return flag
  78. def get(self, sub_name):
  79. """Get sub node."""
  80. return self._children.get(sub_name)
  81. def get_children(self):
  82. """Get all childrens."""
  83. for name_scope, sub_watch_node in self._children.items():
  84. yield name_scope, sub_watch_node
  85. def add_node(self, node_name, node_type, full_name=''):
  86. """
  87. Add watch node to watch node tree.
  88. Args:
  89. node_name (str): The node name.
  90. node_type (str): The node type.
  91. full_name (str): The full name of node.
  92. """
  93. log.debug("Add node %s with type: %s, full_name: %s", node_name, node_type, full_name)
  94. scope_names = node_name.split('/', 1)
  95. if len(scope_names) == 1:
  96. target_node = self.get(node_name)
  97. if not target_node:
  98. self.add(node_name, node_type, full_name, watch_status=WatchNodeTree.TOTAL_WATCH)
  99. else:
  100. target_node.update_metadata(node_type, full_name, WatchNodeTree.TOTAL_WATCH)
  101. return
  102. scope_name, sub_names = scope_names
  103. sub_tree = self.get(scope_name)
  104. if not sub_tree:
  105. sub_tree = self.add(scope_name, watch_status=1)
  106. sub_tree.add_node(sub_names, node_type, full_name)
  107. def add(self, name, node_type=None, full_name='', watch_status=1):
  108. """Add sub WatchPointTree."""
  109. sub_name = '/'.join([self._node_name, name]) if self._node_name else name
  110. sub_tree = WatchNodeTree(sub_name, node_type, full_name, watch_status)
  111. self._children[name] = sub_tree
  112. return sub_tree
  113. def remove_node(self, node_name):
  114. """Remove sub node from current tree."""
  115. log.debug("Remove %s", node_name)
  116. scope_names = node_name.split('/', 1)
  117. sub_tree_name = scope_names[0]
  118. sub_tree = self._children.get(sub_tree_name)
  119. if not sub_tree:
  120. log.error("Failed to find node %s in WatchNodeTree.", sub_tree_name)
  121. raise DebuggerParamValueError("Failed to find node {}".format(sub_tree_name))
  122. if len(scope_names) > 1:
  123. sub_tree.remove_node(scope_names[1])
  124. if sub_tree.watch_status == WatchNodeTree.NOT_WATCH or len(scope_names) == 1:
  125. self._children.pop(sub_tree_name)
  126. self._watch_status = WatchNodeTree.PARTIAL_WATCH if self._children else \
  127. WatchNodeTree.NOT_WATCH
  128. class Watchpoint:
  129. """
  130. The class of watchpoint stream.
  131. Args:
  132. watchpoint_id (int): The id of Watchpoint.
  133. watch_condition (dict): The condition of Watchpoint.
  134. - condition (str): Accept `INF` or `NAN`.
  135. - param (list[float]): Not defined yet.
  136. """
  137. def __init__(self, watchpoint_id, watch_condition):
  138. self._id = watchpoint_id
  139. self._condition = watch_condition
  140. self._watch_node = WatchNodeTree()
  141. @property
  142. def watchpoint_id(self):
  143. """The property of watchpoint id."""
  144. return self._id
  145. @property
  146. def nodes(self):
  147. """The property of watch nodes."""
  148. return self._watch_node
  149. @property
  150. def condition(self):
  151. """The property of watch condition."""
  152. return self._condition
  153. def copy_nodes_from(self, other_watchpoint):
  154. """
  155. Copy nodes from other watchpoint.
  156. Args:
  157. other_watchpoint (Watchpoint): Other watchpoint.
  158. """
  159. self._watch_node = other_watchpoint.nodes
  160. def add_nodes(self, nodes):
  161. """Add node into watchcpoint."""
  162. if not nodes:
  163. log.warning("Add empty nodes.")
  164. return
  165. if not isinstance(nodes, list):
  166. nodes = [nodes]
  167. for node in nodes:
  168. self._watch_node.add_node(node.name, node.type, node.full_name)
  169. def remove_nodes(self, nodes):
  170. """Remove nodes from watchpoint."""
  171. if not nodes:
  172. return
  173. if not isinstance(nodes, list):
  174. nodes = [nodes]
  175. for node in nodes:
  176. node_name = node.split(':')[0]
  177. self._watch_node.remove_node(node_name)
  178. def get_node_status(self, node_name, node_type, full_name):
  179. """Judge if the node is in watch nodes."""
  180. scope_names = node_name.split('/')
  181. cur_node = self._watch_node
  182. status = 1
  183. for scope_name in scope_names:
  184. cur_node = cur_node.get(scope_name)
  185. if cur_node is None:
  186. status = WatchNodeTree.NOT_WATCH
  187. break
  188. if cur_node.watch_status == WatchNodeTree.TOTAL_WATCH:
  189. status = WatchNodeTree.TOTAL_WATCH
  190. break
  191. if status == WatchNodeTree.TOTAL_WATCH and cur_node.node_name != node_name:
  192. self._watch_node.add_node(node_name, node_type, full_name)
  193. return status
  194. def get_watch_node(self, cur_watch_node, watch_node_list):
  195. """
  196. Traverse the watch nodes and add total watched node list to `watch_node_list`.
  197. Args:
  198. cur_watch_node (WatchNodeTree): The current watch node.
  199. watch_node_list (list[WatchNodeTree]): The list of total watched node.
  200. """
  201. if cur_watch_node.watch_status == WatchNodeTree.TOTAL_WATCH and \
  202. cur_watch_node.node_type != NodeTypeEnum.AGGREGATION_SCOPE.value:
  203. watch_node_list.append(cur_watch_node)
  204. return
  205. for _, watch_node in cur_watch_node.get_children():
  206. self.get_watch_node(watch_node, watch_node_list)
  207. def get_set_cmd(self):
  208. """Return the watchpoint in proto format."""
  209. # get watch nodes.
  210. watch_nodes = []
  211. self.get_watch_node(self._watch_node, watch_nodes)
  212. # construct SetCMD
  213. set_cmd = SetCMD()
  214. set_cmd.id = self._id
  215. set_cmd.delete = False
  216. set_cmd.watch_condition.condition = WATCHPOINT_CONDITION_MAPPING.get(
  217. self._condition.get('condition'))
  218. if self._condition.get('param'):
  219. # at most one param is provided
  220. set_cmd.watch_condition.value = self._condition.get('param')
  221. for watch_node in watch_nodes:
  222. event_node = set_cmd.watch_nodes.add()
  223. event_node.node_name = watch_node.full_name
  224. event_node.node_type = watch_node.node_type
  225. return set_cmd
  226. def get_watch_condition_info(self):
  227. """Get watch condition info."""
  228. watchpoint_info = {
  229. 'id': self._id,
  230. 'watch_condition': self._condition
  231. }
  232. return watchpoint_info
  233. class WatchpointHit:
  234. """The watchpoint hit structure."""
  235. def __init__(self, tensor_proto, watchpoint, node_name):
  236. self._node_name = node_name
  237. self._full_name = tensor_proto.node_name
  238. self._slot = tensor_proto.slot
  239. self._watchpoint = watchpoint
  240. @property
  241. def tensor_full_name(self):
  242. """The property of tensor full name."""
  243. tensor_name = ':'.join([self._full_name, self._slot])
  244. return tensor_name
  245. @property
  246. def tensor_name(self):
  247. """The property of tensor ui name."""
  248. tensor_name = ':'.join([self._node_name, self._slot])
  249. return tensor_name
  250. @property
  251. def watchpoint(self):
  252. """The property of watchpoint."""
  253. watchpoint = self._watchpoint.get_watch_condition_info()
  254. return watchpoint
  255. def __eq__(self, other):
  256. """Define the equal condition."""
  257. flag = self.tensor_full_name == other.tensor_full_name and self.watchpoint == other.watchpoint
  258. return flag