You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-') and can be up to 35 characters long.

debugger_graph.py 13 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360
  1. # Copyright 2020 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ============================================================================
  15. """This file is used to define the basic graph."""
  16. from collections import deque
  17. from copy import deepcopy
  18. from mindinsight.datavisual.data_transform.graph.msgraph import MSGraph
  19. from mindinsight.debugger.common.exceptions.exceptions import \
  20. DebuggerNodeNotInGraphError, DebuggerParamValueError
  21. from mindinsight.debugger.common.log import LOGGER as log
  22. from .node_type_identifier import NodeTypeIdentifier
  23. def _is_match(identifier, node, condition):
  24. """Check if the node is matched to the identifier.
  25. Args:
  26. identifier (NodeTypeIdentifier): The debug name of the node.
  27. node (Node obj): The number of layers the user wants to trace. Default is 0.
  28. Returns:
  29. list, a list of the traced tensors' name and node type,
  30. arranged in order from leaf node to root node.
  31. int, the number of output tensors.
  32. """
  33. if condition:
  34. matched = identifier.is_match(node, condition)
  35. else:
  36. matched = identifier.is_match(node)
  37. return matched
  38. class DebuggerGraph(MSGraph):
  39. """The `DebuggerGraph` object provides interfaces to describe a debugger graph."""
  40. @property
  41. def leaf_nodes(self):
  42. """Return the leaf nodes."""
  43. return self._leaf_nodes
  44. @property
  45. def normal_node_map(self):
  46. """Return the normal_node_map"""
  47. return self._normal_node_map
  48. @property
  49. def node_id_map_name(self):
  50. """Return the node_id_map_name"""
  51. return self._node_id_map_name
  52. @property
  53. def const_node_temp_cache(self):
  54. """Return const_node_temp_cache"""
  55. return self._const_node_temp_cache
  56. @property
  57. def parameter_node_temp_cache(self):
  58. """Return parameter_node_temp_cache"""
  59. return self._parameter_node_temp_cache
  60. @property
  61. def full_name_map_name(self):
  62. """Return full_name_map_name"""
  63. return self._full_name_map_name
  64. def get_node_name_by_full_name(self, full_name):
  65. """Get node name by full names."""
  66. inner_name = self._full_name_map_name.get(full_name, '')
  67. if not inner_name:
  68. log.warning("Node %s does not find the relative inner node name.", full_name)
  69. return inner_name
  70. def get_full_name_by_node_name(self, node_name):
  71. """Get full name by node name."""
  72. if not node_name:
  73. return ''
  74. node = self._normal_node_map.get(node_name)
  75. if not node:
  76. log.error("Node <%s> is not in graph.", node_name)
  77. raise DebuggerNodeNotInGraphError(node_name=node_name)
  78. return node.full_name
  79. def get_node_type(self, node_name):
  80. """
  81. Get the type of the node.
  82. Args:
  83. node_name (str): The full name of the node with its scope.
  84. Returns:
  85. str, node type or name_scope.
  86. """
  87. if not node_name:
  88. return 'name_scope'
  89. node = self._normal_node_map.get(node_name)
  90. if not node:
  91. log.error("Node <%s> is not in graph.", node_name)
  92. raise DebuggerNodeNotInGraphError(node_name=node_name)
  93. return node.type
  94. def search_nodes_by_category(self, node_category, condition=None):
  95. """
  96. Search nodes by type.
  97. Args:
  98. node_category (TargetTypeEnum): The node type supported in
  99. mindinsight.conditionmgr.condition.TargetTypeEnum.
  100. condition (dict): Search condition. Default: None.
  101. - activation_func (Union[str, list[str]): The target functions. Used when node_type
  102. is TargetTypeEnum.ACTIVATION.
  103. - search_range (list[Node]): The list of nodes to be searched from.
  104. Returns:
  105. list[Node], list of nodes.
  106. """
  107. identifier = NodeTypeIdentifier(node_category.value)
  108. # get search range
  109. condition = {} if condition is None else condition
  110. search_range = condition.pop('search_range', None)
  111. if not search_range:
  112. search_range = self._leaf_nodes.values()
  113. # search match nodes
  114. matched_nodes = []
  115. for node in search_range:
  116. matched = _is_match(identifier, node, condition)
  117. if matched:
  118. matched_nodes.append(node)
  119. return matched_nodes
  120. def get_tensor_history(self, node_name, depth=0):
  121. """
  122. Get the tensor history of a specified node.
  123. Args:
  124. node_name (str): The debug name of the node.
  125. depth (int): The number of layers the user wants to trace. Default is 0.
  126. Returns:
  127. list, a list of the traced tensors' name and node type,
  128. arranged in order from leaf node to root node.
  129. int, the number of output tensors.
  130. """
  131. node = self._leaf_nodes.get(node_name)
  132. tensor_history = self._get_tensor_infos_of_node(node)
  133. cur_outputs_nums = len(tensor_history)
  134. cur_depth = 0
  135. trace_list = deque([(node, cur_depth)])
  136. while trace_list:
  137. cur_node, cur_depth = trace_list.popleft()
  138. tensors_info = self._get_input_tensors_of_node(cur_node)
  139. if tensors_info:
  140. tensor_history.extend(tensors_info)
  141. if cur_depth < depth:
  142. for name in cur_node.inputs.keys():
  143. trace_list.append((self._leaf_nodes[name], cur_depth + 1))
  144. return tensor_history, cur_outputs_nums
  145. @staticmethod
  146. def _get_tensor_infos_of_node(cur_node, slot=None):
  147. """Get tensors info of specified node."""
  148. tensors_info = []
  149. if slot is None:
  150. slots = range(cur_node.output_nums)
  151. elif slot >= 0:
  152. slots = [slot]
  153. else:
  154. log.info("Skip get tensor info for %s:%s.", cur_node.name, slot)
  155. return tensors_info
  156. for num in slots:
  157. tensor_info = {
  158. 'name': cur_node.name + ':' + str(num),
  159. 'full_name': cur_node.full_name + ':' + str(num),
  160. 'node_type': cur_node.type
  161. }
  162. tensors_info.append(tensor_info)
  163. return tensors_info
  164. def _get_input_tensors_of_node(self, cur_node):
  165. """Get input tensors of node."""
  166. tensors_info = []
  167. for name in cur_node.inputs.keys():
  168. node = self._leaf_nodes.get(name)
  169. tensor_info = self._get_tensor_infos_of_node(node)
  170. tensors_info.extend(tensor_info)
  171. return tensors_info
  172. def get_bfs_order(self):
  173. """
  174. Traverse the graph in order of breath-first search.
  175. Returns:
  176. list, including the leaf nodes arranged in BFS order.
  177. """
  178. root = self.get_default_root()
  179. log.info('Randomly choose node %s as root to do BFS.', root.name)
  180. bfs_order = []
  181. self.get_bfs_graph(root.name, bfs_order)
  182. length = len(self._leaf_nodes.keys())
  183. # Find rest un-traversed nodes
  184. for node_name, _ in self._leaf_nodes.items():
  185. if node_name not in bfs_order:
  186. self.get_bfs_graph(node_name, bfs_order)
  187. if len(bfs_order) != length:
  188. log.error("The length of bfs and leaf nodes are not equal.")
  189. msg = "Not all nodes are traversed!"
  190. raise DebuggerParamValueError(msg)
  191. return bfs_order
  192. def get_bfs_graph(self, node_name, bfs_order):
  193. """
  194. Traverse the graph in order of breath-first search.
  195. Returns:
  196. list, including the leaf nodes arranged in BFS order.
  197. """
  198. temp_list = deque()
  199. temp_list.append(node_name)
  200. while temp_list:
  201. node_name = temp_list.popleft()
  202. node = self._leaf_nodes.get(node_name)
  203. if not node:
  204. log.warning('Cannot find node %s in graph. Ignored.', node_name)
  205. continue
  206. bfs_order.append(node_name)
  207. if node.inputs:
  208. for name in node.inputs.keys():
  209. if name not in temp_list and name not in bfs_order:
  210. temp_list.append(name)
  211. if node.outputs:
  212. for name in node.outputs.keys():
  213. if name not in temp_list and name not in bfs_order:
  214. temp_list.append(name)
  215. def get_default_root(self):
  216. """
  217. Get a node as default root for BFS in graph. Using the
  218. leaf node with the smallest node id as the default root.
  219. Returns:
  220. str, the name of the default root.
  221. """
  222. default_root = None
  223. for _, item in self._leaf_nodes.items():
  224. if item.node_id == '1':
  225. default_root = item
  226. break
  227. if default_root is None:
  228. log.error("Abnormal graph. Invalid node for BFS.")
  229. msg = 'Abnormal graph. Invalid node for BFS.'
  230. raise DebuggerParamValueError(msg)
  231. return default_root
  232. def get_tensor_graph(self, node_name):
  233. """
  234. Get graph relative to a node.
  235. Args:
  236. node_name (str): Node name.
  237. Returns:
  238. dict, tensor graph, format is:
  239. {'nodes': [
  240. {'name': <node name>,
  241. 'full_name': <node full name>,
  242. 'type': <node type>
  243. 'input': <input objects>,
  244. 'output': <output objects>,
  245. 'slot': {'id': <slot id>}
  246. }
  247. ]}
  248. """
  249. graph_nodes = []
  250. cur_node = self._leaf_nodes.get(node_name)
  251. node_detail_info = cur_node.to_dict()
  252. cur_node_info = self._get_node_info_for_tensor_graph(cur_node)
  253. cur_node_info['input'] = deepcopy(node_detail_info.get('input'))
  254. cur_node_info['output'] = deepcopy(node_detail_info.get('output'))
  255. self._add_input_node_info(cur_node_info=cur_node_info, graph_nodes=graph_nodes)
  256. self._add_output_node_info(cur_node=cur_node, cur_node_info=cur_node_info, graph_nodes=graph_nodes)
  257. graph_nodes.append(cur_node_info)
  258. return {'nodes': graph_nodes}
  259. @staticmethod
  260. def _get_node_info_for_tensor_graph(node):
  261. """Get node infos for tensor graph."""
  262. node_info = {
  263. 'name': node.name,
  264. 'full_name': node.full_name,
  265. 'type': node.type,
  266. 'input': {},
  267. 'output': {},
  268. 'slots': [{'slot': str(slot)} for slot in range(node.output_nums)]
  269. }
  270. return node_info
  271. def _add_output_node_info(self, cur_node, cur_node_info, graph_nodes):
  272. """
  273. Add output node info into cur_node_info and node list.
  274. Args:
  275. cur_node (Node): The current node object.
  276. cur_node_info (dict): Current node info.
  277. graph_nodes (list[<Node info>]): The nodes in tensor graph.
  278. """
  279. output_slot_mapping = self._get_slot_mapping(cur_node)
  280. for node_name, edge_info in cur_node_info.get('output').items():
  281. edge_info['slot_mapping'] = output_slot_mapping
  282. # add output node info into graph
  283. output_node = self._leaf_nodes.get(node_name)
  284. output_node_info = self._get_node_info_for_tensor_graph(output_node)
  285. output_node_info['input'][cur_node.name] = edge_info
  286. graph_nodes.append(output_node_info)
  287. def _add_input_node_info(self, cur_node_info, graph_nodes):
  288. """
  289. Add input node info into cur_node_info and node list.
  290. Args:
  291. cur_node_info (dict): Current node info.
  292. graph_nodes (list[<Node info>]): The nodes in tensor graph.
  293. """
  294. cur_node_name = cur_node_info.get('name')
  295. for node_name, edge_info in cur_node_info.get('input').items():
  296. input_node = self._leaf_nodes.get(node_name)
  297. edge_info['slot_mapping'] = self._get_slot_mapping(input_node)
  298. # add input node info into graph
  299. input_node_info = self._get_node_info_for_tensor_graph(input_node)
  300. input_node_info['output'][cur_node_name] = edge_info
  301. graph_nodes.append(input_node_info)
  302. @staticmethod
  303. def _get_slot_mapping(input_node):
  304. """Get slot mapping between nodes."""
  305. return [[str(slot), ''] for slot in range(input_node.output_nums)]