You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or number, can include dashes ('-') and can be up to 35 characters long.

graph.py 26 kB

5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598
  1. # Copyright 2019 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ============================================================================
  15. """
  16. This file is used to define the basic graph.
  17. """
  18. import time
  19. from enum import Enum
  20. from collections import defaultdict
  21. from mindinsight.datavisual.common.exceptions import NodeNotInGraphError
  22. from mindinsight.datavisual.common.log import logger
  23. from mindinsight.utils.exceptions import ParamMissError
  24. from mindinsight.utils.exceptions import ParamValueError
  25. from .node import NodeTypeEnum
  26. from .node import Node
  class EdgeTypeEnum(Enum):
      """Enumerate the kinds of edge that can connect two nodes."""

      # Edge that carries a control relationship.
      CONTROL = 'control'
      # Edge that carries data.
      DATA = 'data'
  31. class Graph:
  32. """The `Graph` object is used to describe a graph file."""
  33. # Limit the size of a single attribute value per node to avoid storing too much data
  34. MAX_NODE_ATTRIBUTE_VALUE_BYTES = 1024
  35. # In the same scope, when the number of children of the same type reaches this threshold, we combine them.
  36. MIN_GROUP_NODE_COUNT = 5
  37. def __init__(self):
  38. # Used to cache all nodes, and the key is node name, value is `Node` object.
  39. self._normal_node_map = {}
  40. self._node_id_map_name = {}
  41. # The additional caching of Const and Parameter is to handle the Const
  42. # and Parameter nodes separately later.
  43. self._const_node_temp_cache = {}
  44. self._parameter_node_temp_cache = {}
  45. self._leaf_nodes = {}
  46. self._full_name_map_name = {}
  def build_graph(self, proto_data):
      """
      Build the whole graph from the given data and log how long it took.

      Args:
          proto_data: The raw data handed to `_parse_data`.
      """
      logger.info("Start to build graph")
      begin = time.time()

      # Notice:
      # Every step below relies on the state left by the previous one, so the order must not be changed.
      self._parse_data(proto_data)
      self._add_variable_nodes(NodeTypeEnum.PARAMETER.value)
      self._build_aggregation_scope_nodes()
      self._process_independent_layout()
      self._build_name_scope_nodes()
      # Const nodes are never aggregated, so adding them last saves a lot of computation.
      self._add_variable_nodes(NodeTypeEnum.CONST.value)
      self._calc_subnode_count()
      self._leaf_nodes = self._get_leaf_nodes()
      self._full_name_map_name = self._get_leaf_node_full_name_map()

      # Elapsed seconds, rounded to six decimal places.
      elapsed = round(time.time() - begin, 6)
      const_count = len(self._const_node_temp_cache)
      parameter_count = len(self._parameter_node_temp_cache)
      logger.info("Build graph end, all node count: %s, const count: %s, parameter count: %s, time-consuming: %s s.",
                  self.normal_node_count, const_count, parameter_count, elapsed)
  68. def _get_leaf_nodes(self):
  69. """
  70. Get all leaf nodes, including normal leaf nodes, const nodes and param nodes.
  71. """
  72. leaf_nodes = {}
  73. for node_name, node in self._normal_node_map.items():
  74. # update full name
  75. if not node.full_name:
  76. node.full_name = node.name
  77. if not node.type or node.type.endswith('_scope'):
  78. continue
  79. leaf_nodes[node_name] = node
  80. return leaf_nodes
  81. def _get_leaf_node_full_name_map(self):
  82. """Get node by debugger name."""
  83. full_name_map = {}
  84. for name, node in self._leaf_nodes.items():
  85. if not node.full_name:
  86. logger.warning("Node %s does not have full name.", name)
  87. continue
  88. full_name_map[node.full_name] = name
  89. return full_name_map
  90. def exist_node(self, name):
  91. """
  92. Check node exist in graph.
  93. Args:
  94. name (str): The node name.
  95. Returns:
  96. bool, if node exists, will return True.
  97. """
  98. if name is None:
  99. return False
  100. return self._is_node_exist(node_name=name)
  101. def list_node_by_scope(self, scope=None):
  102. """
  103. List nodes by the scope of nodes. The scope of a node is the same as its parent node name.
  104. Args:
  105. scope (str): A scope of nodes.
  106. Returns:
  107. list[dict], a list object contain `Node` object.
  108. """
  109. scope = "" if scope is None else scope
  110. nodes = []
  111. for node in self._normal_node_map.values():
  112. if node.scope == scope:
  113. nodes.append(node.to_dict())
  114. return nodes
  115. def search_node_names(self, content, offset, limit):
  116. """
  117. Search node names by content.
  118. Args:
  119. content (Union[str, None]): This content can be the key content of the node to search,
  120. if None, will get all node names.
  121. offset (int): An offset for page. Ex, offset is 0, mean current page is 1.
  122. limit (int): An offset for page. Ex, offset is 0, mean current page is 1.
  123. Returns:
  124. list[str], a list of node names.
  125. """
  126. if content is not None:
  127. content = content.lower()
  128. catch_names = [name for name in self._normal_node_map if content in name.lower()]
  129. else:
  130. catch_names = list(self._normal_node_map)
  131. catch_names = sorted(catch_names)
  132. real_offset = offset * limit
  133. return catch_names[real_offset:real_offset+limit]
  134. def search_single_node(self, node_name):
  135. """
  136. Search node, and return every layer nodes until this node.
  137. Args:
  138. node_name (str): The name of node.
  139. Returns:
  140. dict, a dict object, format is :
  141. item_object = {'nodes': [<Node object>],
  142. 'scope_name': '<Node scope>',
  143. 'children': {<item_object>}}
  144. """
  145. if node_name and not self.exist_node(name=node_name):
  146. raise NodeNotInGraphError(node_name=node_name)
  147. response = {}
  148. nodes = self.list_node_by_scope()
  149. response.update({
  150. 'nodes': nodes,
  151. 'scope_name': '',
  152. 'children': {}
  153. })
  154. children = response['children']
  155. index = node_name.find('/')
  156. while index != -1:
  157. scope = node_name[:index]
  158. nodes = self.list_node_by_scope(scope)
  159. children.update({
  160. 'nodes': nodes,
  161. 'scope_name': scope,
  162. 'children': {}
  163. })
  164. children = children['children']
  165. index = node_name.find('/', index+1)
  166. return response
  167. def _parse_data(self, proto_data):
  168. """
  169. This method will parse the data and create basic nodes to store in the cache.
  170. The graph is then built based on the cache.
  171. """
  172. raise NotImplementedError("Before you can build a graph, you need to parse the data.")
  def _build_name_scope_nodes(self):
      """
      Build name scope node by every node name.

      We create the name scope node by the slash('/') in the node name.
      For example, if a node name is "Default/add", we generate a scope named 'Default' based on slash('/') and
      create a name scope node named 'Default'.

      If a cached node already uses the name of a scope, it is kept as the scope when it is an
      aggregation scope node, otherwise it is renamed to make room for the new name scope node.
      """
      logger.info("Start to build name scope nodes.")
      scope_node_map = {}
      # Iterate over a snapshot: resolving a name conflict renames a node, which removes and
      # re-adds entries of `self._normal_node_map`, and a dict must not change while it is iterated.
      for name, node in list(self._normal_node_map.items()):
          index = name.find('/')
          pre_index = None
          while index > 0:
              scope = name[:index]
              scope_node = scope_node_map.get(scope)
              if scope_node is None:
                  if self._is_node_exist(node_name=scope):
                      exist_node = self._get_normal_node(node_name=scope)
                      if exist_node.type == NodeTypeEnum.AGGREGATION_SCOPE.value:
                          # This scope is aggregation scope, so we don't have to do anything.
                          pre_index = index
                          index = name.find('/', pre_index + 1)
                          continue

                      # We find a node name that conflicts with the current scope and rename the node
                      self._update_conflict_node(conflict_name=scope)

                  # We create a node for current scope.
                  scope_node = Node(scope, node_id=scope)
                  scope_node.type = NodeTypeEnum.NAME_SCOPE.value
                  # The parent scope is everything before the previous slash.
                  scope_node.scope = '' if pre_index is None else name[:pre_index]
                  scope_node_map.update({scope_node.name: scope_node})

              # Inherit input and output from sub nodes.
              self._inherit_input_output_from_subnode(scope_node, subnode_list=[node])

              pre_index = index
              index = name.find('/', pre_index + 1)

      # Cache all the scope node to normal node dict
      for node in scope_node_map.values():
          self._cache_node(node)
  def _update_conflict_node(self, conflict_name):
      """
      Rename the cached node called `conflict_name` by wrapping its base name in parentheses.

      Args:
          conflict_name (str): The name of the node that has to be renamed.
      """
      node = self._get_normal_node(node_name=conflict_name)
      # The base name is the part after the last slash.
      base_name = conflict_name.rsplit('/', 1)[-1]
      renamed = Node.create_node_name(scope=node.scope, base_name=f'({base_name})')
      self._update_node_name_of_cache(node, renamed, update_parent=True)
  215. def _inherit_input_output_from_subnode(self, parent_node, subnode_list, filtered_type=None):
  216. """
  217. Adds the input and output of all direct child nodes to the current node.
  218. Args:
  219. parent_node (Node): The nodes that inherit the input and output of the child nodes.
  220. subnode_list (list[Node]): A list of child nodes that are inherited from the input and output.
  221. filtered_type (set(str)): Filter some input and output that do not require inheritance
  222. based on the node type. Default is filter const node.
  223. Note:
  224. - Only the inputs and outputs of the external scope are inherited.
  225. - Before add_const_node method, if the input is a const,
  226. the scope of the const node is not startswith the name of parent node.
  227. So in this scenario, we need to filter the const nodes.
  228. """
  229. filtered_type = {NodeTypeEnum.CONST.value} if filtered_type is None else filtered_type
  230. for method in ['input', 'output', 'proxy_input', 'proxy_output']:
  231. for node in subnode_list:
  232. for item_name, item_attr in getattr(node, method).items():
  233. target_node = self._get_normal_node(node_name=item_name)
  234. if item_name.startswith(f'{parent_node.name}/'):
  235. # Own scope, ignore
  236. continue
  237. if target_node.type in filtered_type:
  238. continue
  239. getattr(parent_node, f'add_{method}')(item_name, item_attr)
  def _build_aggregation_scope_nodes(self):
      """
      Put large groups of same-typed nodes of one scope under a new aggregation scope node.

      Note:
          The threshold value refers to the `MIN_GROUP_NODE_COUNT`.
      """
      logger.info("Start to build aggregation scope nodes.")
      group_node_map, filtered_group_names = self._find_group_nodes()

      # Create one aggregation scope node per group that is large enough.
      aggregation_nodes = {}
      for order, group_name in enumerate(filtered_group_names):
          # A group name is '<scope>/<type>', or just '<type>' for the top-level scope.
          scope, _, op_type = group_name.rpartition('/')
          members = group_node_map.get(group_name)
          member_count = len(members)

          aggregation_name = Node.create_node_name(scope=scope, base_name=f'{op_type}[{member_count}]_{order}')
          aggregation_node = Node(name=aggregation_name, node_id=aggregation_name)
          aggregation_node.subnode_count = member_count
          aggregation_node.scope = scope
          aggregation_node.type = NodeTypeEnum.AGGREGATION_SCOPE.value

          # Move every member below the new aggregation scope.
          for member in members:
              leaf_name = member.name.rsplit('/', 1)[-1]
              moved_name = Node.create_node_name(scope=aggregation_name, base_name=leaf_name)
              member.scope = aggregation_name
              # Since the name scope has not been created, there is no need to update the parent node.
              self._update_node_name_of_cache(member, moved_name, update_parent=False)

          self._cache_node(aggregation_node)
          aggregation_nodes[group_name] = aggregation_node

      # Let each aggregation scope node inherit the input and output of its direct child nodes.
      for group_name, aggregation_node in aggregation_nodes.items():
          self._inherit_input_output_from_subnode(aggregation_node, group_node_map[group_name])
  def _find_group_nodes(self):
      """
      Group the cached nodes by scope and node type and pick the groups that can be aggregated.

      Const nodes and name scope nodes are left out of the grouping,
      because these types of nodes are not operational nodes.

      Returns:
          tuple, the first item maps every group name to the list of its nodes, the second item
          lists the names of the groups that have at least `MIN_GROUP_NODE_COUNT` nodes.
      """
      skipped_types = {
          NodeTypeEnum.CONST.value,
          NodeTypeEnum.NAME_SCOPE.value,
      }

      group_node_map = defaultdict(list)
      for node in self._normal_node_map.values():
          if node.type not in skipped_types:
              group_name = Node.create_node_name(scope=node.scope, base_name=node.type)
              group_node_map[group_name].append(node)

      # Keep only the groups that are big enough to be aggregated.
      filtered_group_names = [
          name for name, members in group_node_map.items()
          if len(members) >= self.MIN_GROUP_NODE_COUNT
      ]
      return group_node_map, filtered_group_names
  def _add_variable_nodes(self, node_type):
      """
      We create the Const nodes or Parameter nodes in this method.

      For every cached node, each input that refers to a node of the requested type is replaced by
      a copy of that node placed in the scope of the consuming node. The original nodes that are
      left in the matching temporary cache afterwards are removed from the caches.

      Args:
          node_type (str): Decide which type of node to add.
              Optional is `NodeTypeEnum.CONST.value` and `NodeTypeEnum.PARAMETER.value`.

      Raises:
          ParamValueError: If `node_type` is neither of the two supported values.

      Note:
          This method relies on the presence of data in the const cache or parameter cache.
      """
      logger.info("Start to add %s nodes to each scope in graph.", node_type)

      # Validate before touching the graph: with an unsupported type nothing would be filtered
      # below and every input of every node would be rewritten before the error is raised.
      if node_type == NodeTypeEnum.CONST.value:
          variable_cache = self._const_node_temp_cache
      elif node_type == NodeTypeEnum.PARAMETER.value:
          variable_cache = self._parameter_node_temp_cache
      else:
          raise ParamValueError("The node type should be const or parameter.")

      node_map = {}
      for node in self._normal_node_map.values():
          # Iterate over a copy because the inputs of `node` are rewritten inside the loop.
          for src_name, input_attr in dict(node.input).items():
              if not variable_cache.get(src_name):
                  continue

              variable_name = Node.create_node_name(scope=node.scope, base_name=src_name)
              variable_node = node_map.get(variable_name)
              if not variable_node:
                  # First use in this scope, so copy the cached node; later uses share the copy.
                  cache_node = self._get_normal_node(node_name=src_name)
                  variable_node = Node(name=variable_name, node_id=variable_name)
                  Node.copy_node_without_input_output(cache_node, variable_node)
                  variable_node.scope = node.scope

              variable_node.add_output(dst_name=node.name, output_attr=input_attr)
              node_map.update({variable_name: variable_node})

              # Point the consuming node at the scoped copy instead of the original.
              node.delete_input(src_name)
              node.add_input(variable_name, input_attr)

      for node in node_map.values():
          self._cache_node(node)

      # Remove nodes that are not used in the cache.
      unused_names = set(variable_cache) - set(node_map)
      self._delete_nodes_of_cache(unused_names)
  339. def _calc_subnode_count(self):
  340. """Calc all the direct sub node count."""
  341. subnode_count_map = defaultdict(int)
  342. for node in self._normal_node_map.values():
  343. if not node.scope:
  344. continue
  345. if not self._is_node_exist(node_name=node.scope):
  346. logger.warning("Can not find a scope node by the given name(%s), "
  347. "the name scope nodes may not have been created.", node.scope)
  348. continue
  349. subnode_count_map[node.scope] = subnode_count_map[node.scope] + 1
  350. for name, count in subnode_count_map.items():
  351. node = self._get_normal_node(node_name=name)
  352. node.subnode_count = count
  353. def _get_normal_node(self, node_id=None, node_name=None):
  354. """Query node by node id or node name."""
  355. if node_id is not None:
  356. name = self._node_id_map_name.get(node_id)
  357. node = self._normal_node_map.get(name)
  358. return node
  359. if node_name is not None:
  360. return self._normal_node_map.get(node_name)
  361. raise ParamMissError('Method requires an argument that is not None.')
  362. def _is_node_exist(self, node_id=None, node_name=None):
  363. """Check node is exist."""
  364. if node_id is not None:
  365. return bool(self._node_id_map_name.get(node_id))
  366. if node_name is not None:
  367. return bool(self._normal_node_map.get(node_name))
  368. raise ParamMissError('Method requires an argument that is not None.')
  369. @property
  370. def normal_node_count(self):
  371. """Get the normal node count."""
  372. return len(self._normal_node_map)
  def _cache_node(self, node):
      """
      Store the node in the caches, keyed by its name, and record its id.

      Args:
          node (Node): The node to store.
      """
      node_name, node_type = node.name, node.type

      # Notice:
      # Const and Parameter nodes are cached a second time so that they can be handled separately later.
      if node_type == NodeTypeEnum.CONST.value:
          self._const_node_temp_cache[node_name] = node
      if node_type == NodeTypeEnum.PARAMETER.value:
          self._parameter_node_temp_cache[node_name] = node

      self._normal_node_map[node_name] = node
      self._node_id_map_name[node.node_id] = node_name
  def _delete_nodes_of_cache(self, node_names):
      """
      Remove the named nodes from every cache.

      Args:
          node_names (Iterable[str]): The names of the nodes to remove. Every name must be
              present in the normal node cache.
      """
      logger.debug("These nodes will be removed from the cache, node names: %s.", str(node_names))
      temp_caches = (self._parameter_node_temp_cache, self._const_node_temp_cache)
      for name in node_names:
          # The temporary caches only hold some of the nodes, so drop the entry where it exists.
          for cache in temp_caches:
              if cache.get(name):
                  del cache[name]

          # Fetch the node before removing it, its id is needed to clean the id lookup.
          node = self._get_normal_node(node_name=name)
          del self._normal_node_map[name]
          del self._node_id_map_name[node.node_id]
  def _update_node_name_of_cache(self, node, new_name, update_parent=False):
      """
      Update a node name which is stored in cache.

      The node is renamed in place. The nodes it is connected to through its input, output,
      proxy_input and proxy_output (and, when `update_parent` is True, the cached scope nodes
      above those nodes) have their references to the old name replaced by the new name.
      Finally the cache entry under the old name is replaced by one under the new name.

      Args:
          node (Node): The node that will be renamed.
          new_name (str): The new name.
          update_parent (bool): Determines whether the input and output of the parent node need to be updated.
      """
      logger.debug('Update node name of cache, node(%s), new name is %s.', str(node), new_name)
      origin_name = node.name
      node.name = new_name

      # Find all nodes that need to modify the input and output
      update_node_map = {}
      for method in ['input', 'output', 'proxy_input', 'proxy_output']:
          for target_name in getattr(node, method):
              target_node = self._get_normal_node(node_name=target_name)
              if target_node is None:
                  # A connected node that is not cached is reported and skipped.
                  message = f"Node should not be None, name: {target_name}, {method}: {list(getattr(node, method))}."
                  logger.error(message)
                  continue

              update_node_map.update({target_name: target_node})

              if not update_parent:
                  continue

              # Walk through every scope prefix of the target name and collect the scope nodes that are cached.
              slash_index = target_name.find('/')
              while slash_index != -1:
                  scope_name = target_name[:slash_index]
                  # Advance first, so that the `continue` statements below move on to the next prefix.
                  slash_index = target_name.find('/', slash_index+1)

                  if update_node_map.get(scope_name):
                      # Already collected.
                      continue

                  scope_node = self._get_normal_node(node_name=scope_name)
                  if scope_node is None:
                      message = f"Can not find the scope node by scope name({scope_name}), " \
                                f"may be this scope node has not been built."
                      logger.debug(message)
                      continue

                  update_node_map.update({scope_name: scope_node})

      # Update the input and output of the nodes
      for target_node in update_node_map.values():
          for method in ['input', 'output', 'proxy_input', 'proxy_output']:
              attr_temp = getattr(target_node, method).get(origin_name)
              if attr_temp is None:
                  # This method does not have this node, so it is skipped
                  continue

              # Delete the old attribute and update new name to source node or destination node.
              getattr(target_node, f'delete_{method}')(origin_name)
              getattr(target_node, f'add_{method}')(new_name, attr_temp)

      # Delete the origin node in cache.
      self._delete_nodes_of_cache(node_names=[origin_name])
      # Store the node again, now under its new name.
      self._cache_node(node)
  def _process_independent_layout(self):
      """
      Handle separate layout nodes.

      The aggregation scope nodes whose base name contains the parameter type value, and the
      nodes directly inside them, are flagged with `independent_layout` and get proxy edges
      that mirror their outputs.
      """
      # Collect the aggregation scope nodes whose base name contains the parameter type value.
      independent_layout_node_map = {}
      for node in self._normal_node_map.values():
          base_name = node.name.split('/')[-1]
          if node.type == NodeTypeEnum.AGGREGATION_SCOPE.value and NodeTypeEnum.PARAMETER.value in base_name:
              independent_layout_node_map[node.name] = node

      # Find all sub nodes
      subnode_map = defaultdict(list)
      for node in self._normal_node_map.values():
          if independent_layout_node_map.get(node.scope):
              subnode_map[node.scope].append(node)

      # Notice:
      # The following processing is only done for the parameter node, other types of nodes are not processed.
      # Later, when you need to extend to other nodes, the code needs to be adjusted.
      for scope_node in independent_layout_node_map.values():
          scope_node.independent_layout = True

          method = 'output'
          # Iterate over a copy of the outputs of the scope node.
          for target_name, target_attr in dict(getattr(scope_node, method)).items():
              # A proxy edge only carries the edge type of the real edge.
              proxy_attr = dict(edge_type=target_attr['edge_type'])

              target_node = self._get_normal_node(node_name=target_name)
              getattr(target_node, 'add_proxy_input')(scope_node.name, proxy_attr)

              # Note:
              # If the source node and the destination node are not in the same scope,
              # the proxy node is presented as scope in order to simplify the flow of the display data.
              # For example, the data flow is parameter[5]_1 -> add[5]_1/add1
              # we create a scope proxy node(add[5]_1) for parameter[5]_1,
              # so there is a proxy data flow parameter[5]_1 -> add[5]_1 instead of parameter[5]_1 -> add[5]_1/add1.
              if target_node.scope == scope_node.scope:
                  getattr(scope_node, f'add_proxy_{method}')(target_name, proxy_attr)
              else:
                  target_scope_node = self._get_normal_node(node_name=target_node.scope)
                  getattr(scope_node, f'add_proxy_{method}')(target_node.scope, proxy_attr)
                  getattr(target_scope_node, 'add_proxy_input')(scope_node.name, proxy_attr)

          # Give the nodes directly inside the scope node the same treatment.
          for subnode in subnode_map[scope_node.name]:
              subnode.independent_layout = True
              for target_name, target_attr in dict(getattr(subnode, method)).items():
                  proxy_attr = dict(edge_type=target_attr['edge_type'])
                  target_node = self._get_normal_node(node_name=target_name)
                  if target_node.scope == scope_node.scope:
                      getattr(subnode, f'add_proxy_{method}')(target_name, proxy_attr)
                  else:
                      getattr(subnode, f'add_proxy_{method}')(target_node.scope, proxy_attr)

                  # Flag the matching input on the target node as coming from an independent layout node.
                  input_attr = getattr(target_node, 'input')[subnode.name]
                  input_attr['independent_layout'] = True
                  target_node.add_input(subnode.name, input_attr)