You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

debugger_graph.py 10 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289
  1. # Copyright 2020 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ============================================================================
  15. """This file is used to define the basic graph."""
  16. from collections import deque
  17. from mindinsight.datavisual.data_transform.graph.msgraph import MSGraph
  18. from mindinsight.datavisual.data_transform.graph.node import NodeTypeEnum
  19. from mindinsight.debugger.common.exceptions.exceptions import \
  20. DebuggerNodeNotInGraphError, DebuggerParamValueError
  21. from mindinsight.debugger.common.log import logger as log
  22. from .node import NodeTree
  23. class DebuggerGraph(MSGraph):
  24. """The `DebuggerGraph` object provides interfaces to describe a debugger graph."""
  25. def __init__(self):
  26. super(DebuggerGraph, self).__init__()
  27. self._node_tree = None
  28. def get_node_name_by_full_name(self, full_name):
  29. """Get node name by full names."""
  30. inner_name = self._full_name_map_name.get(full_name, '')
  31. if not inner_name:
  32. log.warning("Node %s does not find the relative inner node name.", full_name)
  33. return inner_name
  34. def get_full_name_by_node_name(self, node_name):
  35. """Get full name by node name for leaf nodes."""
  36. node = self._normal_node_map.get(node_name)
  37. if not node:
  38. log.warning("Node %s is not leaf node.", node_name)
  39. return node.full_name if node else ''
  40. def get_nodes(self, searched_node_list):
  41. """
  42. Search node names by a given pattern.
  43. Args:
  44. searched_node_list (list[Node]): A list of leaf nodes that
  45. matches the given search pattern.
  46. Returns:
  47. A list of dict including the searched nodes.
  48. [{
  49. "name": "Default",
  50. "type": "name_scope",
  51. "nodes": [{
  52. "name": "Default/Conv2D1",
  53. "type": "name_scope",
  54. "nodes": [{
  55. ...
  56. }]
  57. }]
  58. },
  59. {
  60. "name": "Gradients",
  61. "type": "name_scope",
  62. "nodes": [{
  63. "name": "Gradients/Default",
  64. "type": "name_scope",
  65. "nodes": [{
  66. ...
  67. }]
  68. }]
  69. """
  70. # save the node in the NodeTree
  71. self._node_tree = NodeTree()
  72. for node in searched_node_list:
  73. self._build_node_tree(node.name, node.type)
  74. # get the searched nodes in the NodeTree and reorganize them
  75. searched_list = []
  76. self._traverse_node_tree(self._node_tree, searched_list)
  77. return searched_list
  78. def search_nodes_by_pattern(self, pattern):
  79. """
  80. Search node names by a given pattern.
  81. Args:
  82. pattern (Union[str, None]): The pattern of the node to search,
  83. if None, return all node names.
  84. Returns:
  85. list[(str, str)], a list of tuple (node name, node type).
  86. """
  87. if pattern is not None:
  88. pattern = pattern.lower()
  89. searched_nodes = [
  90. node for name, node in self._leaf_nodes.items()
  91. if pattern in name.lower()
  92. ]
  93. else:
  94. searched_nodes = [node for name, node in self._leaf_nodes.items()]
  95. return searched_nodes
  96. def _build_node_tree(self, node_name, node_type):
  97. """Build node tree."""
  98. scope_names = node_name.split('/')
  99. cur_node = self._node_tree
  100. for scope_name in scope_names[:-1]:
  101. sub_node = cur_node.get(scope_name)
  102. if not sub_node:
  103. sub_node = cur_node.add(scope_name)
  104. cur_node = sub_node
  105. cur_node.add(scope_names[-1], node_type)
  106. def _traverse_node_tree(self, cur_node, search_node_list):
  107. """Traverse the watch nodes and update the total watched node list."""
  108. if not cur_node.get_children():
  109. return
  110. for _, sub_node in cur_node.get_children():
  111. sub_nodes = []
  112. self._traverse_node_tree(sub_node, sub_nodes)
  113. sub_node_dict = {
  114. 'name': sub_node.node_name,
  115. 'type': sub_node.node_type,
  116. 'nodes': sub_nodes
  117. }
  118. search_node_list.append(sub_node_dict)
  119. def get_node_type(self, node_name):
  120. """
  121. Get the type of the node.
  122. Args:
  123. node_name (str): The full name of the node with its scope.
  124. Returns:
  125. A string, leaf or name_scope.
  126. """
  127. if node_name and not self.exist_node(name=node_name):
  128. raise DebuggerNodeNotInGraphError(node_name=node_name)
  129. node = self._leaf_nodes.get(node_name)
  130. if node is not None:
  131. node_type = node.type
  132. else:
  133. node_type = NodeTypeEnum.NAME_SCOPE.value
  134. return node_type
  135. def get_tensor_history(self, node_name, depth=0):
  136. """
  137. Get the tensor history of a specified node.
  138. Args:
  139. node_name (str): The debug name of the node.
  140. depth (int): The number of layers the user wants to trace. Default is 0.
  141. Returns:
  142. list, a list of the traced tensors' name and node type,
  143. arranged in order from leaf node to root node.
  144. int, the number of output tensors.
  145. """
  146. node = self._leaf_nodes.get(node_name)
  147. tensor_history = self._get_tensor_infos_of_node(node)
  148. cur_outputs_nums = len(tensor_history)
  149. cur_depth = 0
  150. trace_list = deque([(node, cur_depth)])
  151. while trace_list:
  152. cur_node, cur_depth = trace_list.popleft()
  153. tensors_info = self._get_input_tensors_of_node(cur_node)
  154. if tensors_info:
  155. tensor_history.extend(tensors_info)
  156. if cur_depth < depth:
  157. for name in cur_node.inputs.keys():
  158. trace_list.append((self._leaf_nodes[name], cur_depth + 1))
  159. return tensor_history, cur_outputs_nums
  160. @staticmethod
  161. def _get_tensor_infos_of_node(cur_node, slot=None):
  162. """Get tensors info of specified node."""
  163. tensors_info = []
  164. if slot is None:
  165. slots = range(cur_node.output_nums)
  166. elif slot >= 0:
  167. slots = [slot]
  168. else:
  169. log.info("Skip get tensor info for %s:%s.", cur_node.name, slot)
  170. return tensors_info
  171. for num in slots:
  172. tensor_info = {
  173. 'name': cur_node.name + ':' + str(num),
  174. 'full_name': cur_node.full_name + ':' + str(num),
  175. 'node_type': cur_node.type
  176. }
  177. tensors_info.append(tensor_info)
  178. return tensors_info
  179. def _get_input_tensors_of_node(self, cur_node):
  180. """Get input tensors of node."""
  181. tensors_info = []
  182. for name in cur_node.inputs.keys():
  183. node = self._leaf_nodes.get(name)
  184. tensor_info = self._get_tensor_infos_of_node(node)
  185. tensors_info.extend(tensor_info)
  186. return tensors_info
  187. def get_bfs_order(self):
  188. """
  189. Traverse the graph in order of breath-first search.
  190. Returns:
  191. list, including the leaf nodes arranged in BFS order.
  192. """
  193. root = self.get_default_root()
  194. log.info('Randomly choose node %s as root to do BFS.', root.name)
  195. bfs_order = []
  196. self.get_bfs_graph(root.name, bfs_order)
  197. length = len(self._leaf_nodes.keys())
  198. # Find rest un-traversed nodes
  199. for node_name, _ in self._leaf_nodes.items():
  200. if node_name not in bfs_order:
  201. self.get_bfs_graph(node_name, bfs_order)
  202. if len(bfs_order) != length:
  203. log.error("The length of bfs and leaf nodes are not equal.")
  204. msg = "Not all nodes are traversed!"
  205. raise DebuggerParamValueError(msg)
  206. return bfs_order
  207. def get_bfs_graph(self, node_name, bfs_order):
  208. """
  209. Traverse the graph in order of breath-first search.
  210. Returns:
  211. list, including the leaf nodes arranged in BFS order.
  212. """
  213. temp_list = deque()
  214. temp_list.append(node_name)
  215. while temp_list:
  216. node_name = temp_list.popleft()
  217. node = self._leaf_nodes.get(node_name)
  218. if not node:
  219. log.warning('Cannot find node %s in graph. Ignored.', node_name)
  220. continue
  221. bfs_order.append(node_name)
  222. if node.inputs:
  223. for name in node.inputs.keys():
  224. if name not in temp_list and name not in bfs_order:
  225. temp_list.append(name)
  226. if node.outputs:
  227. for name in node.outputs.keys():
  228. if name not in temp_list and name not in bfs_order:
  229. temp_list.append(name)
  230. def get_default_root(self):
  231. """
  232. Get a node as default root for BFS in graph. Using the
  233. leaf node with the smallest node id as the default root.
  234. Returns:
  235. str, the name of the default root.
  236. """
  237. default_root = None
  238. for _, item in self._leaf_nodes.items():
  239. if item.node_id == '1':
  240. default_root = item
  241. break
  242. if default_root is None:
  243. log.error("Abnormal graph. Invalid node for BFS.")
  244. msg = 'Abnormal graph. Invalid node for BFS.'
  245. raise DebuggerParamValueError(msg)
  246. return default_root