You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

graph.py 18 kB

5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452
  1. # Copyright 2019 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ============================================================================
  15. """
  16. This file is used to define the basic graph.
  17. """
  18. import copy
  19. import time
  20. from enum import Enum
  21. from mindinsight.datavisual.common.log import logger
  22. from mindinsight.datavisual.common import exceptions
  23. from .node import NodeTypeEnum
  24. from .node import Node
  25. class EdgeTypeEnum(Enum):
  26. """Node edge type enum."""
  27. CONTROL = 'control'
  28. DATA = 'data'
  29. class DataTypeEnum(Enum):
  30. """Data type enum."""
  31. DT_TENSOR = 13
  32. class Graph:
  33. """The `Graph` object is used to describe a graph file."""
  34. MIN_POLYMERIC_NODE_COUNT = 5
  35. def __init__(self):
  36. # Store nodes contain leaf nodes, name scope node, except polymeric nodes
  37. self._normal_nodes = {}
  38. # Store polymeric nodes.
  39. self._polymeric_nodes = {}
  40. # Store all nodes resolved from the file.
  41. self._leaf_nodes = {}
  42. # The format of node groups is {'group_name': {'node_name': <Node>}}
  43. self._node_groups = {}
  44. def exist_node(self, name):
  45. """
  46. Check node exist in graph.
  47. Args:
  48. name (str): The node name.
  49. Returns:
  50. bool, if node is exist will return True.
  51. """
  52. if self._normal_nodes.get(name) is None:
  53. return False
  54. return True
  55. def get_normal_nodes(self, namescope=None):
  56. """
  57. Get nodes by namescope.
  58. Args:
  59. namescope (str): A namescope of nodes.
  60. Returns:
  61. list[dict], a list object contain `Node` object.
  62. """
  63. nodes = []
  64. if namescope is None:
  65. for name, node in self._normal_nodes.items():
  66. if '/' not in name:
  67. # Get first layer nodes
  68. nodes.append(node.to_dict())
  69. return nodes
  70. namescope = namescope + '/'
  71. for name, node in self._normal_nodes.items():
  72. if name.startswith(namescope) and '/' not in name.split(namescope)[1]:
  73. nodes.append(node.to_dict())
  74. return nodes
  75. def get_polymeric_nodes(self, polymeric_scope):
  76. """
  77. Get polymeric nodes by polymeric scope.
  78. Args:
  79. polymeric_scope (str): The polymeric scope name of nodes.
  80. Returns:
  81. list[dict], a list object contain `Node` object.
  82. """
  83. nodes = []
  84. for node in self._polymeric_nodes.values():
  85. if node.polymeric_scope_name == polymeric_scope:
  86. nodes.append(node.to_dict())
  87. return nodes
  88. def search_node_names(self, content, offset, limit):
  89. """
  90. Search node names by content.
  91. Args:
  92. content (Union[str, None]): This content can be the key content of the node to search,
  93. if None, will get all node names.
  94. offset (int): An offset for page. Ex, offset is 0, mean current page is 1.
  95. limit (int): An offset for page. Ex, offset is 0, mean current page is 1.
  96. Returns:
  97. list[str], a list of node names.
  98. """
  99. all_names = []
  100. all_names.extend(list(self._normal_nodes.keys()))
  101. all_names.extend(list(self._polymeric_nodes.keys()))
  102. if content is not None:
  103. content = content.lower()
  104. catch_names = [name for name in all_names if content in name.lower()]
  105. else:
  106. catch_names = all_names
  107. catch_names = sorted(catch_names)
  108. real_offset = offset * limit
  109. return catch_names[real_offset:real_offset+limit]
  110. def search_single_node(self, node_name):
  111. """
  112. Search node, and return every layer nodes until this node.
  113. Args:
  114. node_name (str): The name of node.
  115. Returns:
  116. dict, a dict object, format is :
  117. item_object = {'nodes': [<Node object>],
  118. 'scope_name': '<Node scope>',
  119. 'children': {<item_object>}}
  120. """
  121. if node_name and self._polymeric_nodes.get(node_name) is None \
  122. and self._normal_nodes.get(node_name) is None:
  123. raise exceptions.NodeNotInGraphError()
  124. response = {}
  125. nodes = self.get_normal_nodes()
  126. response.update({
  127. 'nodes': nodes,
  128. 'scope_name': '',
  129. 'children': {}
  130. })
  131. names = node_name.split('/')
  132. children = response['children']
  133. for i in range(1, len(names)+1):
  134. if i == len(names):
  135. polymeric_node = self._polymeric_nodes.get(node_name)
  136. if polymeric_node:
  137. polymeric_scope = polymeric_node.polymeric_scope_name
  138. nodes = self.get_polymeric_nodes(polymeric_scope)
  139. children.update({'nodes': nodes,
  140. 'scope_name': polymeric_scope,
  141. 'children': {}})
  142. break
  143. name_scope = '/'.join(names[:i])
  144. nodes = self.get_normal_nodes(name_scope)
  145. children.update({
  146. 'nodes': nodes,
  147. 'scope_name': name_scope,
  148. 'children': {}
  149. })
  150. children = children['children']
  151. return response
  152. def _build_polymeric_nodes(self):
  153. """Build polymeric node."""
  154. logger.debug("Start to build polymeric nodes")
  155. self._find_polymeric_nodes()
  156. group_count_map = {}
  157. for group_name, group in self._node_groups.items():
  158. name = group_name.split('/')[-1]
  159. count = group_count_map.get(name, 0)
  160. count += 1
  161. group_count_map[name] = count
  162. polymeric_node_name = group_name + '_{}_[{}]'.format(count, len(group))
  163. polymeric_node = Node(polymeric_node_name, node_id=polymeric_node_name)
  164. polymeric_node.node_type = NodeTypeEnum.POLYMERIC_SCOPE.value
  165. polymeric_node.name_scope = '/'.join(group_name.split('/')[:-1])
  166. polymeric_node.subnode_count = len(group)
  167. for name_tmp, node_tmp in group.items():
  168. node_tmp.polymeric_scope_name = polymeric_node_name
  169. self._polymeric_nodes.update({name_tmp: node_tmp})
  170. polymeric_node.update_input(node_tmp.inputs)
  171. polymeric_node.update_output(node_tmp.outputs)
  172. self._normal_nodes.update({polymeric_node_name: polymeric_node})
  173. self._update_input_output()
  174. def _find_polymeric_nodes(self):
  175. """Find polymeric nodes from node groups."""
  176. node_groups = copy.deepcopy(self._node_groups)
  177. for group_name, group in node_groups.items():
  178. if len(group) < self.MIN_POLYMERIC_NODE_COUNT:
  179. self._normal_nodes.update(group)
  180. self._node_groups.pop(group_name)
  181. continue
  182. move_node_names = []
  183. is_move_group = False
  184. for node_name, group_node in group.items():
  185. node_list = []
  186. is_in_group = False
  187. for dst_name in group_node.outputs:
  188. node_tmp = self._leaf_nodes[dst_name]
  189. node_list.append(node_tmp)
  190. start = time.time()
  191. run_count = 0
  192. visit_nodes = {}
  193. while node_list:
  194. # Iterate to find if the output of the node in the group causes a loop
  195. # example: there is a group A, and node_a is a Node in group.
  196. # if there is a loop in node_a, like A/node_a -> B/node_b -> A/node_b
  197. # we will remove the node_a from group A.
  198. node_tmp = node_list[0]
  199. node_list = node_list[1:]
  200. visit_nodes.update({node_tmp.name: True})
  201. if node_tmp in group.values():
  202. is_in_group = True
  203. break
  204. for dst_name_tmp in node_tmp.outputs:
  205. run_count += 1
  206. node_tmp = self._leaf_nodes[dst_name_tmp]
  207. if visit_nodes.get(dst_name_tmp):
  208. continue
  209. node_list.append(node_tmp)
  210. logger.debug("Find group %s node end, is_in_group: %s, use time: %s, "
  211. "run count: %s.", group_name, is_in_group,
  212. time.time() - start, run_count)
  213. if is_in_group:
  214. move_node_names.append(node_name)
  215. if (len(group) - len(move_node_names)) < self.MIN_POLYMERIC_NODE_COUNT:
  216. is_move_group = True
  217. break
  218. if is_move_group:
  219. self._normal_nodes.update(group)
  220. self._node_groups.pop(group_name)
  221. else:
  222. for name_tmp in move_node_names:
  223. node_tmp = self._node_groups[group_name].pop(name_tmp)
  224. self._normal_nodes.update({name_tmp: node_tmp})
  225. def _update_input_output(self):
  226. """We need to update input and output attribute after build polymeric node."""
  227. for node in self._normal_nodes.values():
  228. for src_name, input_attr in node.inputs.items():
  229. if self._polymeric_nodes.get(src_name):
  230. input_attr['scope'] = NodeTypeEnum.POLYMERIC_SCOPE.value
  231. node.update_input({src_name: input_attr})
  232. for dst_name, output_attr in node.outputs.items():
  233. if self._polymeric_nodes.get(dst_name):
  234. output_attr['scope'] = NodeTypeEnum.POLYMERIC_SCOPE.value
  235. node.update_output({dst_name: output_attr})
  236. for node in self._polymeric_nodes.values():
  237. for src_name, input_attr in node.inputs.items():
  238. if self._polymeric_nodes.get(src_name):
  239. input_attr['scope'] = NodeTypeEnum.POLYMERIC_SCOPE.value
  240. node.update_input({src_name: input_attr})
  241. for dst_name, output_attr in node.outputs.items():
  242. if self._polymeric_nodes.get(dst_name):
  243. output_attr['scope'] = NodeTypeEnum.POLYMERIC_SCOPE.value
  244. node.update_output({dst_name: output_attr})
  245. def _update_polymeric_input_output(self):
  246. """Calc polymeric input and output after build polymeric node."""
  247. for node in self._normal_nodes.values():
  248. polymeric_input = self._calc_polymeric_attr(node, 'inputs')
  249. node.update_polymeric_input(polymeric_input)
  250. polymeric_output = self._calc_polymeric_attr(node, 'outputs')
  251. node.update_polymeric_output(polymeric_output)
  252. for name, node in self._polymeric_nodes.items():
  253. polymeric_input = {}
  254. for src_name in node.inputs:
  255. output_name = self._calc_dummy_node_name(name, src_name)
  256. polymeric_input.update({output_name: {'edge_type': EdgeTypeEnum.DATA.value}})
  257. node.update_polymeric_input(polymeric_input)
  258. polymeric_output = {}
  259. for dst_name in node.outputs:
  260. polymeric_output = {}
  261. output_name = self._calc_dummy_node_name(name, dst_name)
  262. polymeric_output.update({output_name: {'edge_type': EdgeTypeEnum.DATA.value}})
  263. node.update_polymeric_output(polymeric_output)
  264. def _calc_polymeric_attr(self, node, attr):
  265. """
  266. Calc polymeric input or polymeric output after build polymeric node.
  267. Args:
  268. node (Node): Computes the polymeric input for a given node.
  269. attr (str): The polymeric attr, optional value is `input` or `output`.
  270. Returns:
  271. dict, return polymeric input or polymeric output of the given node.
  272. """
  273. polymeric_attr = {}
  274. for node_name in getattr(node, attr):
  275. polymeric_node = self._polymeric_nodes.get(node_name)
  276. if node.node_type == NodeTypeEnum.POLYMERIC_SCOPE.value:
  277. node_name = node_name if not polymeric_node else polymeric_node.polymeric_scope_name
  278. dummy_node_name = self._calc_dummy_node_name(node.name, node_name)
  279. polymeric_attr.update({dummy_node_name: {'edge_type': EdgeTypeEnum.DATA.value}})
  280. continue
  281. if not polymeric_node:
  282. continue
  283. if not node.name_scope and polymeric_node.name_scope:
  284. # If current node is in top-level layer, and the polymeric_node node is not in
  285. # the top-level layer, the polymeric node will not be the polymeric input
  286. # or polymeric output of current node.
  287. continue
  288. if node.name_scope == polymeric_node.name_scope \
  289. or node.name_scope.startswith(polymeric_node.name_scope + '/'):
  290. polymeric_attr.update(
  291. {polymeric_node.polymeric_scope_name: {'edge_type': EdgeTypeEnum.DATA.value}})
  292. return polymeric_attr
  293. def _calc_dummy_node_name(self, current_node_name, other_node_name):
  294. """
  295. Calc dummy node name.
  296. Args:
  297. current_node_name (str): The name of current node.
  298. other_node_name (str): The target dummy node name.
  299. Returns:
  300. str, the dummy node name.
  301. """
  302. name_tmp = other_node_name
  303. if self._polymeric_nodes.get(other_node_name):
  304. name_tmp = self._polymeric_nodes[other_node_name].polymeric_scope_name
  305. name_tmp_list = name_tmp.split('/')
  306. current_name_list = current_node_name.split('/')
  307. index = 0
  308. min_len = min(len(name_tmp_list), len(current_name_list))
  309. for i in range(min_len):
  310. index = i
  311. if name_tmp_list[index] != current_name_list[index]:
  312. break
  313. dummy_node_name = '/'.join(name_tmp_list[:index+1])
  314. return dummy_node_name
  315. def _build_name_scope_nodes(self):
  316. """Build name scope node by every node name."""
  317. normal_nodes = dict(self._normal_nodes)
  318. rename_node_names = {}
  319. for name, node in normal_nodes.items():
  320. name_list = name.split('/')
  321. for i in range(1, len(name_list)):
  322. name_scope = '/'.join(name_list[:i])
  323. name_scope_node = self._normal_nodes.get(name_scope)
  324. if name_scope_node is None:
  325. name_scope_node = Node(name_scope, node_id=name_scope)
  326. name_scope_node.node_type = NodeTypeEnum.NAME_SCOPE.value
  327. name_scope_node.name_scope = '/'.join(name_list[:i-1])
  328. elif name_scope_node.node_type != NodeTypeEnum.NAME_SCOPE.value:
  329. # The name of this node conflicts with namescope, so rename this node
  330. old_name = name_scope_node.name
  331. old_names = name_scope_node.name.split('/')
  332. old_names[-1] = f'({old_names[-1]})'
  333. new_name = '/'.join(old_names)
  334. name_scope_node.name = new_name
  335. self._normal_nodes.pop(old_name)
  336. self._normal_nodes.update({new_name: name_scope_node})
  337. rename_node_names.update({old_name: new_name})
  338. # create new namescope
  339. name_scope_node = Node(name_scope, node_id=name_scope)
  340. name_scope_node.node_type = NodeTypeEnum.NAME_SCOPE.value
  341. name_scope_node.name_scope = '/'.join(name_list[:i-1])
  342. # update the input and output of this to namescope node
  343. name_scope_with_slash = name_scope + '/'
  344. for src_name, input_attr in node.inputs.items():
  345. if src_name.startswith(name_scope_with_slash):
  346. continue
  347. name_scope_node.update_input({src_name: input_attr})
  348. for dst_name, output_attr in node.outputs.items():
  349. if dst_name.startswith(name_scope_with_slash):
  350. continue
  351. name_scope_node.update_output({dst_name: output_attr})
  352. self._normal_nodes.update({name_scope: name_scope_node})
  353. if rename_node_names:
  354. # If existing nodes are renamed, the inputs and outputs of all nodes need to be refreshed
  355. nodes = []
  356. nodes.extend(self._normal_nodes.values())
  357. nodes.extend(self._polymeric_nodes.values())
  358. for node in nodes:
  359. attrs = ['inputs', 'outputs', 'polymeric_inputs', 'polymeric_outputs']
  360. for item in attrs:
  361. tmp_dict = dict(getattr(node, item))
  362. for name, value in tmp_dict.items():
  363. new_name = rename_node_names.get(name, False)
  364. if new_name:
  365. getattr(node, item).pop(name)
  366. getattr(node, f'update_{item}')({new_name: value})
  367. self._calc_subnode_count()
  368. def _calc_subnode_count(self):
  369. """Calc the sub node count of scope node."""
  370. name_scope_mapping = {}
  371. for node in self._normal_nodes.values():
  372. if node.name_scope:
  373. count = name_scope_mapping.get(node.name_scope, 0)
  374. name_scope_mapping[node.name_scope] = count + 1
  375. for name_scope, count in name_scope_mapping.items():
  376. node = self._normal_nodes[name_scope]
  377. node.subnode_count = count

MindInsight为MindSpore提供了简单易用的调优调试能力。在训练过程中,可以将标量、张量、图像、计算图、模型超参、训练耗时等数据记录到文件中,通过MindInsight可视化页面进行查看及分析。