You can not select more than 25 topics Topics must start with a chinese character,a letter or number, can include dashes ('-') and can be up to 35 characters long.

msgraph.py 18 kB

5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415
# Copyright 2020-2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
  15. """This file is used to define the MindSpore graph."""
  16. from mindinsight.datavisual.common.log import logger
  17. from mindinsight.datavisual.proto_files.mindinsight_anf_ir_pb2 import DataType
  18. from mindinsight.datavisual.common.enums import PluginNameEnum
  19. from mindinsight.domain.graph.base import Source
  20. from .node_tree import NodeTree
  21. from .node import Node
  22. from .node import NodeTypeEnum
  23. from .graph import Graph
  24. from .graph import EdgeTypeEnum
  25. from .graph import check_invalid_character
  26. class MSGraph(Graph):
  27. """The object describes the MindSpore graph, and it is defined in the anf_ir proto file."""
  28. def _parse_data(self, proto_data):
  29. """
  30. The proto data is parsed and all nodes are stored in the specified structure.
  31. Args:
  32. proto_data (anf_ir_pb2.GraphProto): Refer to anf_ir_pb2.GraphProto object.
  33. """
  34. logger.info("Start to parse graph proto data.")
  35. self._parse_op_nodes(proto_data.node)
  36. self._parse_parameters(proto_data.parameters)
  37. self._parse_consts(proto_data.const_vals)
  38. self._update_input_after_create_node()
  39. self._update_output_after_create_node()
  40. logger.info("Parse proto data end, normal node count(only contain op node, "
  41. "parameter, const): %s.", self.normal_node_count)
  42. def _parse_op_nodes(self, node_protos):
  43. """
  44. Parse `anf_ir_pb2.NodeProto` object, and create a normal node.
  45. Args:
  46. node_protos (list[anf_ir_pb2.NodeProto]): Refer to anf_ir_pb2.NodeProto.
  47. """
  48. logger.debug("Start to parse op nodes from proto.")
  49. for topological_index, node_proto in enumerate(node_protos):
  50. if not node_proto.name:
  51. logger.warning("Finding a node with an empty name will not save it.")
  52. continue
  53. if node_proto.op_type == "Load":
  54. # The Load operator needs to be renamed as it has the same name with parameter
  55. node_name = Node.create_node_name(scope=node_proto.scope,
  56. base_name=f'{node_proto.op_type}-op{node_proto.name}')
  57. node_proto.full_name = node_name
  58. elif not node_proto.full_name or any(
  59. node_proto.full_name.lower().endswith(f'[:{plugin.value.lower()}]') for plugin in PluginNameEnum):
  60. node_name = Node.create_node_name(scope=node_proto.scope,
  61. base_name=f'{node_proto.op_type}{node_proto.name}')
  62. else:
  63. node_name = node_proto.full_name
  64. # The Graphviz plug-in that the UI USES can't handle these special characters.
  65. check_invalid_character(node_name)
  66. node = Node(name=node_name, node_id=node_proto.name, topological_index=topological_index)
  67. node.full_name = node_proto.full_name
  68. node.type = node_proto.op_type
  69. if getattr(node_proto, 'source_address', None):
  70. node.stack = Source.build_stack_from_source_address(node_proto.source_address)
  71. self._parse_attributes(node_proto.attribute, node)
  72. self._parse_inputs(node_proto.input, node)
  73. node.output_i = node_proto.output_i
  74. node.scope = node_proto.scope
  75. node.output_shape = self._get_shape_by_parse_type_proto(node_proto.output_type)
  76. node.output_nums = len(node.output_shape)
  77. node.output_data_type = self._get_data_type_by_parse_type_proto(node_proto.output_type, node)
  78. self._cache_node(node)
  79. def _parse_parameters(self, parameter_protos):
  80. """
  81. Parse `anf_ir_pb2.ParameterProto` object, and create a parameter node.
  82. Args:
  83. parameter_protos (list[anf_ir_pb2.ParameterProto]): Refer to anf_ir_pb2.ParameterProto.
  84. """
  85. logger.debug("Start to parse parameters from proto.")
  86. for parameter in parameter_protos:
  87. if not parameter.name:
  88. logger.warning("Finding a parameter with an empty name will not save it.")
  89. continue
  90. check_invalid_character(parameter.name)
  91. node = Node(name=parameter.name, node_id=parameter.name)
  92. node.type = NodeTypeEnum.PARAMETER.value
  93. node.output_shape = self._get_shape_by_parse_type_proto(parameter.type)
  94. node.output_nums = len(node.output_shape)
  95. node.output_data_type = self._get_data_type_by_parse_type_proto(parameter.type, node)
  96. attr = dict(
  97. type=self._get_data_type_by_parse_type_proto(parameter.type, node),
  98. shape=str(self._get_shape_by_parse_type_proto(parameter.type))
  99. )
  100. node.add_attr(attr)
  101. self._cache_node(node)
  102. logger.debug("Foreach graph proto parameters, node id: %s, node name: %s, "
  103. "node def name: %s", node.node_id, node.name, parameter.name)
  104. def _parse_consts(self, consts):
  105. """
  106. Parse `anf_ir_pb2.NameValueProto` object, and create a const node.
  107. Args:
  108. consts (list[anf_ir_pb2.NameValueProto]): Refer to `anf_ir_pb2.NameValueProto` object.
  109. """
  110. logger.debug("Start to parse consts from proto.")
  111. for const in consts:
  112. if not const.key:
  113. logger.warning("Finding a const with an empty key will not save it.")
  114. continue
  115. check_invalid_character(const.key)
  116. node = Node(name=const.key, node_id=const.key)
  117. node.type = NodeTypeEnum.CONST.value
  118. if const.value.ByteSize() > self.MAX_NODE_ATTRIBUTE_VALUE_BYTES:
  119. node.add_attr({const.key: 'dtype: ' + DataType.Name(const.value.dtype)})
  120. else:
  121. node.add_attr({const.key: str(const.value)})
  122. if const.value.dtype == DataType.DT_TENSOR:
  123. shape = list(const.value.tensor_val.dims)
  124. node.output_shape.append(shape)
  125. if const.value.tensor_val.HasField('data_type'):
  126. node.elem_types.append(DataType.Name(const.value.tensor_val.data_type))
  127. else:
  128. node.elem_types.append(DataType.Name(const.value.dtype))
  129. # dim is zero
  130. node.output_shape.append([])
  131. node.output_nums = len(node.output_shape)
  132. self._cache_node(node)
  133. def _get_shape_by_parse_type_proto(self, type_proto):
  134. """
  135. Parse proto's `message TypeProto` to get shape information.
  136. Args:
  137. type_proto (anf_ir_pb2.TypeProto): Refer to anf_ir_pb2.TypeProto.
  138. Returns:
  139. list, a list of shape.
  140. """
  141. shapes = []
  142. if type_proto.HasField('data_type'):
  143. if type_proto.data_type != DataType.DT_TENSOR and \
  144. type_proto.data_type != DataType.DT_TUPLE:
  145. # Append an empty list as a placeholder
  146. # for the convenience of output number calculation.
  147. shapes.append([])
  148. return shapes
  149. if type_proto.HasField('tensor_type'):
  150. tensor_type = type_proto.tensor_type
  151. tensor_shape_proto = tensor_type.shape
  152. shape = [dim.size for dim in tensor_shape_proto.dim]
  153. shapes.append(shape)
  154. if type_proto.HasField('sequence_type'):
  155. for elem_type in type_proto.sequence_type.elem_types:
  156. shapes.extend(self._get_shape_by_parse_type_proto(elem_type))
  157. return shapes
  158. def _get_data_type_by_parse_type_proto(self, type_proto, node):
  159. """
  160. Get data type by parse type proto object.
  161. The name of the DataType, refer to `anf_ir_pb2.DataType` object.
  162. If data type is tensor or tuple, the data name we return is `data_type[element_type, element_type]`.
  163. Args:
  164. type_proto (anf_ir_pb2.TypeProto): Refer to anf_ir_pb2.TypeProto.
  165. Returns:
  166. str, the data type.
  167. """
  168. data_type_name = self._get_data_type_name_by_value(type_proto, type_proto.data_type, field_name='data_type')
  169. if type_proto.data_type == DataType.DT_TENSOR:
  170. tensor_type_proto = type_proto.tensor_type
  171. value = type_proto.tensor_type.elem_type
  172. elem_type_name = self._get_data_type_name_by_value(tensor_type_proto, value, field_name='elem_type')
  173. node.elem_types.append(elem_type_name)
  174. return f'{data_type_name}[{elem_type_name}]'
  175. if type_proto.data_type == DataType.DT_TUPLE:
  176. data_types = []
  177. for elem_type in type_proto.sequence_type.elem_types:
  178. data_types.append(self._get_data_type_by_parse_type_proto(elem_type, node))
  179. return f'{data_type_name}{str(data_types)}'
  180. node.elem_types.append(data_type_name)
  181. return data_type_name
  182. def get_nodes(self, searched_node_list):
  183. """
  184. Get node tree by a searched_node_list.
  185. Args:
  186. searched_node_list (list[Node]): A list of nodes that
  187. matches the given search pattern.
  188. Returns:
  189. A list of dict including the searched nodes.
  190. [{
  191. "name": "Default",
  192. "type": "name_scope",
  193. "nodes": [{
  194. "name": "Default/Conv2D1",
  195. "type": "name_scope",
  196. "nodes": [{
  197. ...
  198. }]
  199. }]
  200. },
  201. {
  202. "name": "Gradients",
  203. "type": "name_scope",
  204. "nodes": [{
  205. "name": "Gradients/Default",
  206. "type": "name_scope",
  207. "nodes": [{
  208. ...
  209. }]
  210. }]
  211. """
  212. # save the node in the NodeTree
  213. root = NodeTree()
  214. for node in searched_node_list:
  215. self._build_node_tree(root, node.name, node.type)
  216. # get the searched nodes in the NodeTree and reorganize them
  217. searched_list = []
  218. self._traverse_node_tree(root, searched_list)
  219. return searched_list
  220. def search_leaf_nodes_by_pattern(self, pattern, scope_pattern=False):
  221. """
  222. Search leaf node by a given pattern.
  223. Args:
  224. pattern (Union[str, None]): The pattern of the node to search,
  225. if None, return all node names.
  226. scope_pattern (bool): If true, return the children nodes of the scope. Default: False.
  227. Returns:
  228. list[Node], a list of nodes.
  229. """
  230. is_match = lambda x, y: x.lower().startswith(y) if scope_pattern else y in x.lower()
  231. if pattern is not None:
  232. pattern = pattern.lower()
  233. searched_nodes = [
  234. node for name, node in self._leaf_nodes.items()
  235. if is_match(name, pattern)
  236. ]
  237. else:
  238. searched_nodes = [node for node in self._leaf_nodes.values()]
  239. return searched_nodes
  240. def search_nodes_by_pattern(self, pattern):
  241. """
  242. Search node by a given pattern.
  243. Search node which pattern is the part of the last node. Example: pattern=ops, node1=default/ops,
  244. node2=default/ops/weight, so node2 will be ignore and only node1 will be return.
  245. Args:
  246. pattern (Union[str, None]): The pattern of the node to search.
  247. Returns:
  248. list[Node], a list of nodes.
  249. """
  250. searched_nodes = []
  251. if pattern and pattern != '/':
  252. pattern = pattern.lower()
  253. for name, node in self._normal_node_map.items():
  254. name = name.lower()
  255. pattern_index = name.rfind(pattern)
  256. if pattern_index >= 0 and name.find('/', pattern_index + len(pattern)) == -1:
  257. searched_nodes.append(node)
  258. return searched_nodes
  259. def _build_node_tree(self, root, node_name, node_type):
  260. """
  261. Build node tree.
  262. Args:
  263. root (NodeTree): Root node of node tree.
  264. node_name (str): Node name.
  265. node_type (str): Node type.
  266. """
  267. scope_names = node_name.split('/')
  268. cur_node = root
  269. full_name = ""
  270. for scope_name in scope_names[:-1]:
  271. full_name = '/'.join([full_name, scope_name]) if full_name else scope_name
  272. scope_node = self._get_normal_node(node_name=full_name)
  273. sub_node = cur_node.get(scope_name)
  274. if not sub_node:
  275. sub_node = cur_node.add(scope_name, scope_node.type)
  276. cur_node = sub_node
  277. cur_node.add(scope_names[-1], node_type)
  278. def _traverse_node_tree(self, cur_node, search_node_list):
  279. """Traverse the node tree and construct the searched nodes list."""
  280. for _, sub_node in cur_node.get_children():
  281. sub_nodes = []
  282. self._traverse_node_tree(sub_node, sub_nodes)
  283. sub_node_dict = {
  284. 'name': sub_node.node_name,
  285. 'type': sub_node.node_type,
  286. 'nodes': sub_nodes
  287. }
  288. search_node_list.append(sub_node_dict)
  289. def _parse_inputs(self, input_protos, node):
  290. """
  291. Parse `anf_ir_pb2.InputProto` object.
  292. Args:
  293. input_protos (list[anf_ir_pb2.InputProto]): Refer to `anf_ir_pb2.InputProto` object.
  294. node (Node): Refer to `Node` object, it is used to log message and update input.
  295. """
  296. for input_proto in input_protos:
  297. if not input_proto.name:
  298. logger.warning("The name in input proto of node(%s) is empty, will ignore.", node.name)
  299. continue
  300. edge_type = EdgeTypeEnum.DATA.value if not input_proto.type else EdgeTypeEnum.CONTROL.value
  301. # Notice:
  302. # 1. The name in the input proto is the node id of the Node object.
  303. # 2. In the current step, the shape of source node cannot be obtained,
  304. # so it is set to empty list by default, and the next step will update it.
  305. # 3. Same with scope, set the default value first.
  306. input_attr = {
  307. "shape": [],
  308. "edge_type": edge_type,
  309. "independent_layout": False,
  310. 'data_type': ''
  311. }
  312. node.add_inputs(src_name=input_proto.name, input_attr=input_attr)
  313. def _parse_attributes(self, attributes, node):
  314. """
  315. Parse `anf_ir_pb2.AttributeProto` object., and Filters large attribute values.
  316. Args:
  317. attributes (list[anf_ir_pb2.AttributeProto]): Refer to `anf_ir_pb2.AttributeProto` object.
  318. node (Node): Refer to `Node` object, it is used to log message and update attr.
  319. """
  320. for attr in attributes:
  321. if attr.value.ByteSize() > self.MAX_NODE_ATTRIBUTE_VALUE_BYTES:
  322. message = f"The attribute value of node({node.name}) " \
  323. f"is over {self.MAX_NODE_ATTRIBUTE_VALUE_BYTES} Bytes, will ignore."
  324. logger.warning(message)
  325. continue
  326. node.add_attr({attr.name: str(attr.value)})
  327. def _update_input_after_create_node(self):
  328. """Update the input of node after create node."""
  329. for node in self._normal_node_map.values():
  330. for src_node_id, input_attr in dict(node.inputs).items():
  331. node.delete_inputs(src_node_id)
  332. if not self._is_node_exist(node_id=src_node_id):
  333. message = f"The input node could not be found by node id({src_node_id}) " \
  334. f"while updating the input of the node({node})"
  335. logger.warning(message)
  336. continue
  337. src_node = self._get_normal_node(node_id=src_node_id)
  338. input_attr['shape'] = src_node.output_shape
  339. input_attr['data_type'] = src_node.output_data_type
  340. node.add_inputs(src_name=src_node.name, input_attr=input_attr)
  341. def _update_output_after_create_node(self):
  342. """Update the output of node after create node."""
  343. # Constants and parameter should not exist for input and output.
  344. filtered_node = {NodeTypeEnum.CONST.value, NodeTypeEnum.PARAMETER.value}
  345. for node in self._normal_node_map.values():
  346. for src_name, input_attr in node.inputs.items():
  347. src_node = self._get_normal_node(node_name=src_name)
  348. if src_node.type in filtered_node:
  349. continue
  350. src_node.add_outputs(node.name, input_attr)
  351. @staticmethod
  352. def _get_data_type_name_by_value(data_type, value, field_name='data_type'):
  353. """Get the data type name by the enum value, data_type refer to `DataType` object."""
  354. return data_type.DESCRIPTOR.fields_by_name[field_name].enum_type.values_by_number[value].name