You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-'), and can be up to 35 characters long.

debugger_graph.py 6.9 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190
  1. # Copyright 2020 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ============================================================================
  15. """This file is used to define the basic graph."""
  16. from collections import deque
  17. from mindinsight.datavisual.data_transform.graph.msgraph import MSGraph
  18. from mindinsight.debugger.common.exceptions.exceptions import \
  19. DebuggerNodeNotInGraphError, DebuggerParamValueError
  20. from mindinsight.debugger.common.log import logger as log
  21. class DebuggerGraph(MSGraph):
  22. """The `DebuggerGraph` object provides interfaces to describe a debugger graph."""
  23. def get_node_name_by_full_name(self, full_name):
  24. """Get node name by full names."""
  25. inner_name = self._full_name_map_name.get(full_name, '')
  26. if not inner_name:
  27. log.warning("Node %s does not find the relative inner node name.", full_name)
  28. return inner_name
  29. def get_full_name_by_node_name(self, node_name):
  30. """Get full name by node name for leaf nodes."""
  31. node = self._normal_node_map.get(node_name)
  32. if not node:
  33. log.warning("Node %s is not leaf node.", node_name)
  34. return node.full_name if node else ''
  35. def get_node_type(self, node_name):
  36. """
  37. Get the type of the node.
  38. Args:
  39. node_name (str): The full name of the node with its scope.
  40. Returns:
  41. A string, leaf or name_scope.
  42. """
  43. if node_name and not self.exist_node(name=node_name):
  44. raise DebuggerNodeNotInGraphError(node_name=node_name)
  45. node = self._normal_node_map.get(node_name)
  46. return node.type
  47. def get_tensor_history(self, node_name, depth=0):
  48. """
  49. Get the tensor history of a specified node.
  50. Args:
  51. node_name (str): The debug name of the node.
  52. depth (int): The number of layers the user wants to trace. Default is 0.
  53. Returns:
  54. list, a list of the traced tensors' name and node type,
  55. arranged in order from leaf node to root node.
  56. int, the number of output tensors.
  57. """
  58. node = self._leaf_nodes.get(node_name)
  59. tensor_history = self._get_tensor_infos_of_node(node)
  60. cur_outputs_nums = len(tensor_history)
  61. cur_depth = 0
  62. trace_list = deque([(node, cur_depth)])
  63. while trace_list:
  64. cur_node, cur_depth = trace_list.popleft()
  65. tensors_info = self._get_input_tensors_of_node(cur_node)
  66. if tensors_info:
  67. tensor_history.extend(tensors_info)
  68. if cur_depth < depth:
  69. for name in cur_node.inputs.keys():
  70. trace_list.append((self._leaf_nodes[name], cur_depth + 1))
  71. return tensor_history, cur_outputs_nums
  72. @staticmethod
  73. def _get_tensor_infos_of_node(cur_node, slot=None):
  74. """Get tensors info of specified node."""
  75. tensors_info = []
  76. if slot is None:
  77. slots = range(cur_node.output_nums)
  78. elif slot >= 0:
  79. slots = [slot]
  80. else:
  81. log.info("Skip get tensor info for %s:%s.", cur_node.name, slot)
  82. return tensors_info
  83. for num in slots:
  84. tensor_info = {
  85. 'name': cur_node.name + ':' + str(num),
  86. 'full_name': cur_node.full_name + ':' + str(num),
  87. 'node_type': cur_node.type
  88. }
  89. tensors_info.append(tensor_info)
  90. return tensors_info
  91. def _get_input_tensors_of_node(self, cur_node):
  92. """Get input tensors of node."""
  93. tensors_info = []
  94. for name in cur_node.inputs.keys():
  95. node = self._leaf_nodes.get(name)
  96. tensor_info = self._get_tensor_infos_of_node(node)
  97. tensors_info.extend(tensor_info)
  98. return tensors_info
  99. def get_bfs_order(self):
  100. """
  101. Traverse the graph in order of breath-first search.
  102. Returns:
  103. list, including the leaf nodes arranged in BFS order.
  104. """
  105. root = self.get_default_root()
  106. log.info('Randomly choose node %s as root to do BFS.', root.name)
  107. bfs_order = []
  108. self.get_bfs_graph(root.name, bfs_order)
  109. length = len(self._leaf_nodes.keys())
  110. # Find rest un-traversed nodes
  111. for node_name, _ in self._leaf_nodes.items():
  112. if node_name not in bfs_order:
  113. self.get_bfs_graph(node_name, bfs_order)
  114. if len(bfs_order) != length:
  115. log.error("The length of bfs and leaf nodes are not equal.")
  116. msg = "Not all nodes are traversed!"
  117. raise DebuggerParamValueError(msg)
  118. return bfs_order
  119. def get_bfs_graph(self, node_name, bfs_order):
  120. """
  121. Traverse the graph in order of breath-first search.
  122. Returns:
  123. list, including the leaf nodes arranged in BFS order.
  124. """
  125. temp_list = deque()
  126. temp_list.append(node_name)
  127. while temp_list:
  128. node_name = temp_list.popleft()
  129. node = self._leaf_nodes.get(node_name)
  130. if not node:
  131. log.warning('Cannot find node %s in graph. Ignored.', node_name)
  132. continue
  133. bfs_order.append(node_name)
  134. if node.inputs:
  135. for name in node.inputs.keys():
  136. if name not in temp_list and name not in bfs_order:
  137. temp_list.append(name)
  138. if node.outputs:
  139. for name in node.outputs.keys():
  140. if name not in temp_list and name not in bfs_order:
  141. temp_list.append(name)
  142. def get_default_root(self):
  143. """
  144. Get a node as default root for BFS in graph. Using the
  145. leaf node with the smallest node id as the default root.
  146. Returns:
  147. str, the name of the default root.
  148. """
  149. default_root = None
  150. for _, item in self._leaf_nodes.items():
  151. if item.node_id == '1':
  152. default_root = item
  153. break
  154. if default_root is None:
  155. log.error("Abnormal graph. Invalid node for BFS.")
  156. msg = 'Abnormal graph. Invalid node for BFS.'
  157. raise DebuggerParamValueError(msg)
  158. return default_root