You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

graph.py 24 kB

5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550
  1. # Copyright 2019 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ============================================================================
  15. """
  16. This file is used to define the basic graph.
  17. """
  18. from enum import Enum
  19. from collections import defaultdict
  20. from mindinsight.datavisual.common.exceptions import NodeNotInGraphError
  21. from mindinsight.datavisual.common.log import logger
  22. from mindinsight.utils.exceptions import ParamMissError
  23. from mindinsight.utils.exceptions import ParamValueError
  24. from .node import NodeTypeEnum
  25. from .node import Node
  26. class EdgeTypeEnum(Enum):
  27. """Node edge type enum."""
  28. CONTROL = 'control'
  29. DATA = 'data'
  30. class Graph:
  31. """The `Graph` object is used to describe a graph file."""
  32. # Limit the size of a single attribute value per node to avoid storing too much data
  33. MAX_NODE_ATTRIBUTE_VALUE_BYTES = 1024
  34. # In the same scope, the number of children of the same type exceeds this threshold, and we will combine them.
  35. MIN_GROUP_NODE_COUNT = 5
  36. def __init__(self):
  37. # Used to cache all nodes, and the key is node name, value is `Node` object.
  38. self._normal_node_map = {}
  39. self._node_id_map_name = {}
  40. # The additional caching of Const and Parameter is to handle the Const
  41. # and Parameter nodes separately later.
  42. self._const_node_temp_cache = {}
  43. self._parameter_node_temp_cache = {}
  44. def build_graph(self, proto_data):
  45. """This method is used to build the graph."""
  46. # Notice:
  47. # The following methods are interdependent and cannot be switched at will.
  48. self._parse_data(proto_data)
  49. self._add_variable_nodes(NodeTypeEnum.PARAMETER.value)
  50. self._build_aggregation_scope_nodes()
  51. self._process_independent_layout()
  52. self._build_name_scope_nodes()
  53. # Since const nodes are not aggregated, adding them at the end can save a lot of computation.
  54. self._add_variable_nodes(NodeTypeEnum.CONST.value)
  55. self._calc_subnode_count()
  56. def exist_node(self, name):
  57. """
  58. Check node exist in graph.
  59. Args:
  60. name (str): The node name.
  61. Returns:
  62. bool, if node exists, will return True.
  63. """
  64. if name is None:
  65. return False
  66. return self._is_node_exist(node_name=name)
  67. def list_node_by_scope(self, scope=None):
  68. """
  69. List nodes by the scope of nodes. The scope of a node is the same as its parent node name.
  70. Args:
  71. scope (str): A scope of nodes.
  72. Returns:
  73. list[dict], a list object contain `Node` object.
  74. """
  75. scope = "" if scope is None else scope
  76. nodes = []
  77. for node in self._normal_node_map.values():
  78. if node.scope == scope:
  79. nodes.append(node.to_dict())
  80. return nodes
  81. def search_node_names(self, content, offset, limit):
  82. """
  83. Search node names by content.
  84. Args:
  85. content (Union[str, None]): This content can be the key content of the node to search,
  86. if None, will get all node names.
  87. offset (int): An offset for page. Ex, offset is 0, mean current page is 1.
  88. limit (int): An offset for page. Ex, offset is 0, mean current page is 1.
  89. Returns:
  90. list[str], a list of node names.
  91. """
  92. if content is not None:
  93. content = content.lower()
  94. catch_names = [name for name in self._normal_node_map if content in name.lower()]
  95. else:
  96. catch_names = list(self._normal_node_map)
  97. catch_names = sorted(catch_names)
  98. real_offset = offset * limit
  99. return catch_names[real_offset:real_offset+limit]
  100. def search_single_node(self, node_name):
  101. """
  102. Search node, and return every layer nodes until this node.
  103. Args:
  104. node_name (str): The name of node.
  105. Returns:
  106. dict, a dict object, format is :
  107. item_object = {'nodes': [<Node object>],
  108. 'scope_name': '<Node scope>',
  109. 'children': {<item_object>}}
  110. """
  111. if node_name and not self.exist_node(name=node_name):
  112. raise NodeNotInGraphError(node_name=node_name)
  113. response = {}
  114. nodes = self.list_node_by_scope()
  115. response.update({
  116. 'nodes': nodes,
  117. 'scope_name': '',
  118. 'children': {}
  119. })
  120. children = response['children']
  121. index = node_name.find('/')
  122. while index != -1:
  123. scope = node_name[:index]
  124. nodes = self.list_node_by_scope(scope)
  125. children.update({
  126. 'nodes': nodes,
  127. 'scope_name': scope,
  128. 'children': {}
  129. })
  130. children = children['children']
  131. index = node_name.find('/', index+1)
  132. return response
  133. def _parse_data(self, proto_data):
  134. """
  135. This method will parse the data and create basic nodes to store in the cache.
  136. The graph is then built based on the cache.
  137. """
  138. raise NotImplementedError("Before you can build a graph, you need to parse the data.")
  139. def _build_name_scope_nodes(self):
  140. """
  141. Build name scope node by every node name.
  142. We create the name scope node by the slash('/') in the node name.
  143. For example, if a node name is "Default/add", we generate a scope named 'Default' based on slash('/') and
  144. create a name scope node named 'Default'.
  145. """
  146. logger.info("Start to build name scope nodes.")
  147. scope_node_map = {}
  148. for name, node in self._normal_node_map.items():
  149. index = name.find('/')
  150. pre_index = None
  151. while index > 0:
  152. scope = name[:index]
  153. scope_node = scope_node_map.get(scope)
  154. if scope_node is None:
  155. if self._is_node_exist(node_name=scope):
  156. exist_node = self._get_normal_node(node_name=scope)
  157. if exist_node.type == NodeTypeEnum.AGGREGATION_SCOPE.value:
  158. # This scope is aggregation scope, so we don't have to do anything.
  159. pre_index = index
  160. index = name.find('/', pre_index + 1)
  161. continue
  162. # We find a node name that conflicts with the current scope and rename the node
  163. self._update_conflict_node(conflict_name=scope)
  164. # We create a node for current scope.
  165. scope_node = Node(scope, node_id=scope)
  166. scope_node.type = NodeTypeEnum.NAME_SCOPE.value
  167. scope_node.scope = '' if pre_index is None else name[:pre_index]
  168. scope_node_map.update({scope_node.name: scope_node})
  169. # Inherit input and output from sub nodes.
  170. self._inherit_input_output_from_subnode(scope_node, subnode_list=[node])
  171. pre_index = index
  172. index = name.find('/', pre_index+1)
  173. # Cache all the scope node to normal node dict
  174. for node in scope_node_map.values():
  175. self._cache_node(node)
  176. def _update_conflict_node(self, conflict_name):
  177. conflict_node = self._get_normal_node(node_name=conflict_name)
  178. base_name = conflict_name.split('/')[-1]
  179. new_name = Node.create_node_name(scope=conflict_node.scope, base_name=base_name)
  180. self._update_node_name_of_cache(conflict_node, new_name, update_parent=True)
  181. def _inherit_input_output_from_subnode(self, parent_node, subnode_list, filtered_type=None):
  182. """
  183. Adds the input and output of all direct child nodes to the current node.
  184. Args:
  185. parent_node (Node): The nodes that inherit the input and output of the child nodes.
  186. subnode_list (list[Node]): A list of child nodes that are inherited from the input and output.
  187. filtered_type (set(str)): Filter some input and output that do not require inheritance
  188. based on the node type. Default is filter const node.
  189. Note:
  190. - Only the inputs and outputs of the external scope are inherited.
  191. - Before add_const_node method, if the input is a const,
  192. the scope of the const node is not startswith the name of parent node.
  193. So in this scenario, we need to filter the const nodes.
  194. """
  195. filtered_type = {NodeTypeEnum.CONST.value} if filtered_type is None else filtered_type
  196. for method in ['input', 'output', 'proxy_input', 'proxy_output']:
  197. for node in subnode_list:
  198. for item_name, item_attr in getattr(node, method).items():
  199. target_node = self._get_normal_node(node_name=item_name)
  200. if item_name.startswith(f'{parent_node.name}/'):
  201. # Own scope, ignore
  202. continue
  203. if target_node.type in filtered_type:
  204. continue
  205. getattr(parent_node, f'add_{method}')(item_name, item_attr)
  206. def _build_aggregation_scope_nodes(self):
  207. """
  208. Under the same scope, the number of nodes of the same type will be aggregated after exceeding the set threshold.
  209. Note:
  210. The threshold value refers to the `MIN_GROUP_NODE_COUNT`.
  211. """
  212. logger.info("Start to build aggregation scope nodes.")
  213. group_node_map, filtered_group_names = self._find_group_nodes()
  214. # create merge scope nodes
  215. aggregation_scope_node_map = {}
  216. for i, group_name in enumerate(filtered_group_names):
  217. slash_index = group_name.rfind('/')
  218. if slash_index != -1:
  219. scope, op_type = group_name[:slash_index], group_name[slash_index+1:]
  220. else:
  221. scope, op_type = '', group_name
  222. count = len(group_node_map.get(group_name))
  223. aggregation_node_name = Node.create_node_name(scope=scope, base_name=f'{op_type}[{count}]_{i}')
  224. aggregation_scope_node = Node(name=aggregation_node_name, node_id=aggregation_node_name)
  225. aggregation_scope_node.subnode_count = count
  226. aggregation_scope_node.scope = scope
  227. aggregation_scope_node.type = NodeTypeEnum.AGGREGATION_SCOPE.value
  228. # Update the name and scope of all children nodes
  229. for node in group_node_map[group_name]:
  230. base_name = node.name.split('/')[-1]
  231. new_name = Node.create_node_name(scope=aggregation_node_name, base_name=base_name)
  232. node.scope = aggregation_node_name
  233. # Since the name scope has not been created, there is no need to update the parent node.
  234. self._update_node_name_of_cache(node, new_name, update_parent=False)
  235. # Cache this node
  236. self._cache_node(aggregation_scope_node)
  237. aggregation_scope_node_map.update({group_name: aggregation_scope_node})
  238. # Adds the input and output of all direct child nodes to the current node.
  239. for group_name, node in aggregation_scope_node_map.items():
  240. self._inherit_input_output_from_subnode(node, group_node_map[group_name])
  241. def _find_group_nodes(self):
  242. """
  243. Find nodes that can be grouped into a group.
  244. For direct child nodes in a scope, we divide them into multiple groups by node type.
  245. However, we will exclude several types of child nodes,
  246. because these types of nodes are not operational nodes.
  247. """
  248. exclude_types = {
  249. NodeTypeEnum.CONST.value,
  250. NodeTypeEnum.NAME_SCOPE.value,
  251. }
  252. group_node_map = defaultdict(list)
  253. for node in self._normal_node_map.values():
  254. if node.type in exclude_types:
  255. continue
  256. group_name = Node.create_node_name(scope=node.scope, base_name=node.type)
  257. group_node_map[group_name].append(node)
  258. # filter can group scope.
  259. filtered_group_names = []
  260. for name, nodes in group_node_map.items():
  261. if len(nodes) < self.MIN_GROUP_NODE_COUNT:
  262. continue
  263. filtered_group_names.append(name)
  264. return group_node_map, filtered_group_names
  265. def _add_variable_nodes(self, node_type):
  266. """
  267. We create the Const nodes or Parameter nodes in this method.
  268. Args:
  269. node_type (str): Decide which type of node to add.
  270. Optional is `NodeTypeEnum.CONST.value` and `NodeTypeEnum.PARAMETER.value`.
  271. Note:
  272. This method relies on the presence of data in the const cache or parameter cache.
  273. """
  274. logger.info("Start to add %s nodes to each scope in graph.", node_type)
  275. node_map = {}
  276. for node in self._normal_node_map.values():
  277. for src_name, input_attr in dict(node.input).items():
  278. if node_type == NodeTypeEnum.CONST.value and not self._const_node_temp_cache.get(src_name):
  279. continue
  280. if node_type == NodeTypeEnum.PARAMETER.value and not self._parameter_node_temp_cache.get(src_name):
  281. continue
  282. variable_name = Node.create_node_name(scope=node.scope, base_name=src_name)
  283. if node_map.get(variable_name):
  284. # There is no need to create the node repeatedly
  285. variable_node = node_map.get(variable_name)
  286. else:
  287. cache_node = self._get_normal_node(node_name=src_name)
  288. variable_node = Node(name=variable_name, node_id=variable_name)
  289. Node.copy_node_without_input_output(cache_node, variable_node)
  290. variable_node.scope = node.scope
  291. variable_node.add_output(dst_name=node.name, output_attr=input_attr)
  292. node_map.update({variable_name: variable_node})
  293. node.delete_input(src_name)
  294. node.add_input(variable_name, input_attr)
  295. for node in node_map.values():
  296. self._cache_node(node)
  297. # Remove nodes that are not used in the cache.
  298. if node_type == NodeTypeEnum.CONST.value:
  299. unused_names = set(self._const_node_temp_cache) - set(node_map)
  300. elif node_type == NodeTypeEnum.PARAMETER.value:
  301. unused_names = set(self._parameter_node_temp_cache) - set(node_map)
  302. else:
  303. raise ParamValueError("The node type should be const or parameter.")
  304. self._delete_nodes_of_cache(unused_names)
  305. def _calc_subnode_count(self):
  306. """Calc all the direct sub node count."""
  307. subnode_count_map = defaultdict(int)
  308. for node in self._normal_node_map.values():
  309. if not node.scope:
  310. continue
  311. if not self._is_node_exist(node_name=node.scope):
  312. logger.warning("Can not find a scope node by the given name(%s), "
  313. "the name scope nodes may not have been created.", node.scope)
  314. continue
  315. subnode_count_map[node.scope] = subnode_count_map[node.scope] + 1
  316. for name, count in subnode_count_map.items():
  317. node = self._get_normal_node(node_name=name)
  318. node.subnode_count = count
  319. def _get_normal_node(self, node_id=None, node_name=None):
  320. """Query node by node id or node name."""
  321. if node_id is not None:
  322. name = self._node_id_map_name.get(node_id)
  323. node = self._normal_node_map.get(name)
  324. return node
  325. if node_name is not None:
  326. return self._normal_node_map.get(node_name)
  327. raise ParamMissError('Method requires an argument that is not None.')
  328. def _is_node_exist(self, node_id=None, node_name=None):
  329. """Check node is exist."""
  330. if node_id is not None:
  331. return bool(self._node_id_map_name.get(node_id))
  332. if node_name is not None:
  333. return bool(self._normal_node_map.get(node_name))
  334. raise ParamMissError('Method requires an argument that is not None.')
  335. @property
  336. def normal_node_count(self):
  337. """Get the normal node count."""
  338. return len(self._normal_node_map)
  339. def _cache_node(self, node):
  340. """Store the node in the cache."""
  341. # Notice:
  342. # The additional caching of Const and Parameter is to handle the Const and Parameter nodes separately later.
  343. if node.type == NodeTypeEnum.CONST.value:
  344. self._const_node_temp_cache.update({node.name: node})
  345. if node.type == NodeTypeEnum.PARAMETER.value:
  346. self._parameter_node_temp_cache.update({node.name: node})
  347. self._normal_node_map.update({node.name: node})
  348. self._node_id_map_name.update({node.node_id: node.name})
  349. def _delete_nodes_of_cache(self, node_names):
  350. """Delete node from cache."""
  351. logger.debug("These nodes will be removed from the cache, node names: %s.", str(node_names))
  352. for name in node_names:
  353. if self._parameter_node_temp_cache.get(name):
  354. self._parameter_node_temp_cache.pop(name)
  355. if self._const_node_temp_cache.get(name):
  356. self._const_node_temp_cache.pop(name)
  357. node = self._get_normal_node(node_name=name)
  358. self._normal_node_map.pop(name)
  359. self._node_id_map_name.pop(node.node_id)
  360. def _update_node_name_of_cache(self, node, new_name, update_parent=False):
  361. """
  362. Update a node name which is stored in cache.
  363. Args:
  364. node (Node): The node that will be renamed.
  365. new_name (str): The new name.
  366. update_parent (bool): Determines whether the input and output of the parent node need to be updated.
  367. """
  368. logger.debug('Update node name of cache, node(%s), new name is %s.', str(node), new_name)
  369. origin_name = node.name
  370. node.name = new_name
  371. # Find all nodes that need to modify the input and input
  372. update_node_map = {}
  373. for method in ['input', 'output', 'proxy_input', 'proxy_output']:
  374. for target_name in getattr(node, method):
  375. target_node = self._get_normal_node(node_name=target_name)
  376. if target_node is None:
  377. message = f"Node should not be None, name: {target_name}, {method}: {list(getattr(node, method))}."
  378. logger.error(message)
  379. continue
  380. update_node_map.update({target_name: target_node})
  381. if not update_parent:
  382. continue
  383. slash_index = target_name.find('/')
  384. while slash_index != -1:
  385. scope_name = target_name[:slash_index]
  386. slash_index = target_name.find('/', slash_index+1)
  387. if update_node_map.get(scope_name):
  388. continue
  389. scope_node = self._get_normal_node(node_name=scope_name)
  390. update_node_map.update({scope_name: scope_node})
  391. # Update the input and output of the nodes
  392. for target_node in update_node_map.values():
  393. for method in ['input', 'output', 'proxy_input', 'proxy_output']:
  394. attr_temp = getattr(target_node, method).get(origin_name)
  395. if attr_temp is None:
  396. # This method does not have this node, so it is skipped
  397. continue
  398. # Delete the old attribute and update new name to source node or destination node.
  399. getattr(target_node, f'delete_{method}')(origin_name)
  400. getattr(target_node, f'add_{method}')(new_name, attr_temp)
  401. # Delete the origin node in cache.
  402. self._delete_nodes_of_cache(node_names=[origin_name])
  403. self._cache_node(node)
  404. def _process_independent_layout(self):
  405. """Handle separate layout nodes."""
  406. independent_layout_node_map = {}
  407. for node in self._normal_node_map.values():
  408. base_name = node.name.split('/')[-1]
  409. if node.type == NodeTypeEnum.AGGREGATION_SCOPE.value and NodeTypeEnum.PARAMETER.value in base_name:
  410. independent_layout_node_map[node.name] = node
  411. # Find all sub nodes
  412. subnode_map = defaultdict(list)
  413. for node in self._normal_node_map.values():
  414. if independent_layout_node_map.get(node.scope):
  415. subnode_map[node.scope].append(node)
  416. # Notice:
  417. # The following processing is only done for the parameter node, other types of nodes are not processed.
  418. # Later, when you need to extend to other nodes, the code needs to be adjusted.
  419. for scope_node in independent_layout_node_map.values():
  420. scope_node.independent_layout = True
  421. method = 'output'
  422. for target_name, target_attr in dict(getattr(scope_node, method)).items():
  423. proxy_attr = dict(edge_type=target_attr['edge_type'])
  424. target_node = self._get_normal_node(node_name=target_name)
  425. getattr(target_node, 'add_proxy_input')(scope_node.name, proxy_attr)
  426. # Note:
  427. # If the source node and the destination node are not in the same scope,
  428. # the proxy node is presented as scope in order to simplify the flow of the display data.
  429. # For example, the data flow is parameter[5]_1 -> add[5]_1/add1
  430. # we create a scope proxy node(add[5]_1) for parameter[5]_1,
  431. # so there is a proxy data flow parameter[5]_1 -> add[5]_1 instead of parameter[5]_1 -> add[5]_1/add1.
  432. if target_node.scope == scope_node.scope:
  433. getattr(scope_node, f'add_proxy_{method}')(target_name, proxy_attr)
  434. else:
  435. target_scope_node = self._get_normal_node(node_name=target_node.scope)
  436. getattr(scope_node, f'add_proxy_{method}')(target_node.scope, proxy_attr)
  437. getattr(target_scope_node, 'add_proxy_input')(scope_node.name, proxy_attr)
  438. for subnode in subnode_map[scope_node.name]:
  439. for target_name, target_attr in dict(getattr(subnode, method)).items():
  440. proxy_attr = dict(edge_type=target_attr['edge_type'])
  441. target_node = self._get_normal_node(node_name=target_name)
  442. if target_node.scope == scope_node.scope:
  443. getattr(subnode, f'add_proxy_{method}')(target_name, proxy_attr)
  444. else:
  445. getattr(subnode, f'add_proxy_{method}')(target_node.scope, proxy_attr)
  446. input_attr = getattr(target_node, 'input')[subnode.name]
  447. input_attr['independent_layout'] = True
  448. target_node.add_input(subnode.name, input_attr)

MindInsight为MindSpore提供了简单易用的调优调试能力。在训练过程中，可以将标量、张量、图像、计算图、模型超参、训练耗时等数据记录到文件中，通过MindInsight可视化页面进行查看及分析。