You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

msgraph.py 17 kB

5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408
  1. # Copyright 2020 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ============================================================================
  15. """This file is used to define the MindSpore graph."""
  16. from mindinsight.datavisual.common.log import logger
  17. from mindinsight.datavisual.proto_files.mindinsight_anf_ir_pb2 import DataType
  18. from mindinsight.datavisual.common.enums import PluginNameEnum
  19. from .node_tree import NodeTree
  20. from .node import Node
  21. from .node import NodeTypeEnum
  22. from .graph import Graph
  23. from .graph import EdgeTypeEnum
  24. from .graph import check_invalid_character
class MSGraph(Graph):
    """The object describes the MindSpore graph, which is defined in the anf_ir proto file."""
  27. def _parse_data(self, proto_data):
  28. """
  29. The proto data is parsed and all nodes are stored in the specified structure.
  30. Args:
  31. proto_data (anf_ir_pb2.GraphProto): Refer to anf_ir_pb2.GraphProto object.
  32. """
  33. logger.info("Start to parse graph proto data.")
  34. self._parse_op_nodes(proto_data.node)
  35. self._parse_parameters(proto_data.parameters)
  36. self._parse_consts(proto_data.const_vals)
  37. self._update_input_after_create_node()
  38. self._update_output_after_create_node()
  39. logger.info("Parse proto data end, normal node count(only contain op node, "
  40. "parameter, const): %s.", self.normal_node_count)
  41. def _parse_op_nodes(self, node_protos):
  42. """
  43. Parse `anf_ir_pb2.NodeProto` object, and create a normal node.
  44. Args:
  45. node_protos (list[anf_ir_pb2.NodeProto]): Refer to anf_ir_pb2.NodeProto.
  46. """
  47. logger.debug("Start to parse op nodes from proto.")
  48. for topological_index, node_proto in enumerate(node_protos):
  49. if not node_proto.name:
  50. logger.warning("Finding a node with an empty name will not save it.")
  51. continue
  52. if not node_proto.full_name or any(
  53. node_proto.full_name.lower().endswith(f'[:{plugin.value.lower()}]') for plugin in PluginNameEnum):
  54. node_name = Node.create_node_name(scope=node_proto.scope,
  55. base_name=f'{node_proto.op_type}{node_proto.name}')
  56. else:
  57. node_name = node_proto.full_name
  58. # The Graphviz plug-in that the UI USES can't handle these special characters.
  59. check_invalid_character(node_name)
  60. node = Node(name=node_name, node_id=node_proto.name, topological_index=topological_index)
  61. node.full_name = node_proto.full_name
  62. node.type = node_proto.op_type
  63. self._parse_attributes(node_proto.attribute, node)
  64. self._parse_inputs(node_proto.input, node)
  65. node.output_i = node_proto.output_i
  66. node.scope = node_proto.scope
  67. node.output_shape = self._get_shape_by_parse_type_proto(node_proto.output_type)
  68. node.output_nums = len(node.output_shape)
  69. node.output_data_type = self._get_data_type_by_parse_type_proto(node_proto.output_type, node)
  70. self._cache_node(node)
  71. def _parse_parameters(self, parameter_protos):
  72. """
  73. Parse `anf_ir_pb2.ParameterProto` object, and create a parameter node.
  74. Args:
  75. parameter_protos (list[anf_ir_pb2.ParameterProto]): Refer to anf_ir_pb2.ParameterProto.
  76. """
  77. logger.debug("Start to parse parameters from proto.")
  78. for parameter in parameter_protos:
  79. if not parameter.name:
  80. logger.warning("Finding a parameter with an empty name will not save it.")
  81. continue
  82. check_invalid_character(parameter.name)
  83. node = Node(name=parameter.name, node_id=parameter.name)
  84. node.type = NodeTypeEnum.PARAMETER.value
  85. node.output_shape = self._get_shape_by_parse_type_proto(parameter.type)
  86. node.output_nums = len(node.output_shape)
  87. node.output_data_type = self._get_data_type_by_parse_type_proto(parameter.type, node)
  88. attr = dict(
  89. type=self._get_data_type_by_parse_type_proto(parameter.type, node),
  90. shape=str(self._get_shape_by_parse_type_proto(parameter.type))
  91. )
  92. node.add_attr(attr)
  93. self._cache_node(node)
  94. logger.debug("Foreach graph proto parameters, node id: %s, node name: %s, "
  95. "node def name: %s", node.node_id, node.name, parameter.name)
  96. def _parse_consts(self, consts):
  97. """
  98. Parse `anf_ir_pb2.NameValueProto` object, and create a const node.
  99. Args:
  100. consts (list[anf_ir_pb2.NameValueProto]): Refer to `anf_ir_pb2.NameValueProto` object.
  101. """
  102. logger.debug("Start to parse consts from proto.")
  103. for const in consts:
  104. if not const.key:
  105. logger.warning("Finding a const with an empty key will not save it.")
  106. continue
  107. check_invalid_character(const.key)
  108. node = Node(name=const.key, node_id=const.key)
  109. node.type = NodeTypeEnum.CONST.value
  110. if const.value.ByteSize() > self.MAX_NODE_ATTRIBUTE_VALUE_BYTES:
  111. node.add_attr({const.key: 'dtype: ' + DataType.Name(const.value.dtype)})
  112. else:
  113. node.add_attr({const.key: str(const.value)})
  114. if const.value.dtype == DataType.DT_TENSOR:
  115. shape = list(const.value.tensor_val.dims)
  116. node.output_shape.append(shape)
  117. if const.value.tensor_val.HasField('data_type'):
  118. node.elem_types.append(DataType.Name(const.value.tensor_val.data_type))
  119. else:
  120. node.elem_types.append(DataType.Name(const.value.dtype))
  121. # dim is zero
  122. node.output_shape.append([])
  123. node.output_nums = len(node.output_shape)
  124. self._cache_node(node)
  125. def _get_shape_by_parse_type_proto(self, type_proto):
  126. """
  127. Parse proto's `message TypeProto` to get shape information.
  128. Args:
  129. type_proto (anf_ir_pb2.TypeProto): Refer to anf_ir_pb2.TypeProto.
  130. Returns:
  131. list, a list of shape.
  132. """
  133. shapes = []
  134. if type_proto.HasField('data_type'):
  135. if type_proto.data_type != DataType.DT_TENSOR and \
  136. type_proto.data_type != DataType.DT_TUPLE:
  137. # Append an empty list as a placeholder
  138. # for the convenience of output number calculation.
  139. shapes.append([])
  140. return shapes
  141. if type_proto.HasField('tensor_type'):
  142. tensor_type = type_proto.tensor_type
  143. tensor_shape_proto = tensor_type.shape
  144. shape = [dim.size for dim in tensor_shape_proto.dim]
  145. shapes.append(shape)
  146. if type_proto.HasField('sequence_type'):
  147. for elem_type in type_proto.sequence_type.elem_types:
  148. shapes.extend(self._get_shape_by_parse_type_proto(elem_type))
  149. return shapes
  150. def _get_data_type_by_parse_type_proto(self, type_proto, node):
  151. """
  152. Get data type by parse type proto object.
  153. The name of the DataType, refer to `anf_ir_pb2.DataType` object.
  154. If data type is tensor or tuple, the data name we return is `data_type[element_type, element_type]`.
  155. Args:
  156. type_proto (anf_ir_pb2.TypeProto): Refer to anf_ir_pb2.TypeProto.
  157. Returns:
  158. str, the data type.
  159. """
  160. data_type_name = self._get_data_type_name_by_value(type_proto, type_proto.data_type, field_name='data_type')
  161. if type_proto.data_type == DataType.DT_TENSOR:
  162. tensor_type_proto = type_proto.tensor_type
  163. value = type_proto.tensor_type.elem_type
  164. elem_type_name = self._get_data_type_name_by_value(tensor_type_proto, value, field_name='elem_type')
  165. node.elem_types.append(elem_type_name)
  166. return f'{data_type_name}[{elem_type_name}]'
  167. if type_proto.data_type == DataType.DT_TUPLE:
  168. data_types = []
  169. for elem_type in type_proto.sequence_type.elem_types:
  170. data_types.append(self._get_data_type_by_parse_type_proto(elem_type, node))
  171. return f'{data_type_name}{str(data_types)}'
  172. node.elem_types.append(data_type_name)
  173. return data_type_name
  174. def get_nodes(self, searched_node_list):
  175. """
  176. Get node tree by a searched_node_list.
  177. Args:
  178. searched_node_list (list[Node]): A list of nodes that
  179. matches the given search pattern.
  180. Returns:
  181. A list of dict including the searched nodes.
  182. [{
  183. "name": "Default",
  184. "type": "name_scope",
  185. "nodes": [{
  186. "name": "Default/Conv2D1",
  187. "type": "name_scope",
  188. "nodes": [{
  189. ...
  190. }]
  191. }]
  192. },
  193. {
  194. "name": "Gradients",
  195. "type": "name_scope",
  196. "nodes": [{
  197. "name": "Gradients/Default",
  198. "type": "name_scope",
  199. "nodes": [{
  200. ...
  201. }]
  202. }]
  203. """
  204. # save the node in the NodeTree
  205. root = NodeTree()
  206. for node in searched_node_list:
  207. self._build_node_tree(root, node.name, node.type)
  208. # get the searched nodes in the NodeTree and reorganize them
  209. searched_list = []
  210. self._traverse_node_tree(root, searched_list)
  211. return searched_list
  212. def search_leaf_nodes_by_pattern(self, pattern):
  213. """
  214. Search leaf node by a given pattern.
  215. Args:
  216. pattern (Union[str, None]): The pattern of the node to search,
  217. if None, return all node names.
  218. Returns:
  219. list[Node], a list of nodes.
  220. """
  221. if pattern is not None:
  222. pattern = pattern.lower()
  223. searched_nodes = [
  224. node for name, node in self._leaf_nodes.items()
  225. if pattern in name.lower()
  226. ]
  227. else:
  228. searched_nodes = [node for node in self._leaf_nodes.values()]
  229. return searched_nodes
  230. def search_nodes_by_pattern(self, pattern):
  231. """
  232. Search node by a given pattern.
  233. Search node which pattern is the part of the last node. Example: pattern=ops, node1=default/ops,
  234. node2=default/ops/weight, so node2 will be ignore and only node1 will be return.
  235. Args:
  236. pattern (Union[str, None]): The pattern of the node to search.
  237. Returns:
  238. list[Node], a list of nodes.
  239. """
  240. searched_nodes = []
  241. if pattern and pattern != '/':
  242. pattern = pattern.lower()
  243. for name, node in self._normal_node_map.items():
  244. name = name.lower()
  245. pattern_index = name.rfind(pattern)
  246. if pattern_index >= 0 and name.find('/', pattern_index + len(pattern)) == -1:
  247. searched_nodes.append(node)
  248. return searched_nodes
  249. def _build_node_tree(self, root, node_name, node_type):
  250. """
  251. Build node tree.
  252. Args:
  253. root (NodeTree): Root node of node tree.
  254. node_name (str): Node name.
  255. node_type (str): Node type.
  256. """
  257. scope_names = node_name.split('/')
  258. cur_node = root
  259. full_name = ""
  260. for scope_name in scope_names[:-1]:
  261. full_name = '/'.join([full_name, scope_name]) if full_name else scope_name
  262. scope_node = self._get_normal_node(node_name=full_name)
  263. sub_node = cur_node.get(scope_name)
  264. if not sub_node:
  265. sub_node = cur_node.add(scope_name, scope_node.type)
  266. cur_node = sub_node
  267. cur_node.add(scope_names[-1], node_type)
  268. def _traverse_node_tree(self, cur_node, search_node_list):
  269. """Traverse the node tree and construct the searched nodes list."""
  270. if not cur_node.get_children():
  271. return
  272. for _, sub_node in cur_node.get_children():
  273. sub_nodes = []
  274. self._traverse_node_tree(sub_node, sub_nodes)
  275. sub_node_dict = {
  276. 'name': sub_node.node_name,
  277. 'type': sub_node.node_type,
  278. 'nodes': sub_nodes
  279. }
  280. search_node_list.append(sub_node_dict)
  281. def _parse_inputs(self, input_protos, node):
  282. """
  283. Parse `anf_ir_pb2.InputProto` object.
  284. Args:
  285. input_protos (list[anf_ir_pb2.InputProto]): Refer to `anf_ir_pb2.InputProto` object.
  286. node (Node): Refer to `Node` object, it is used to log message and update input.
  287. """
  288. for input_proto in input_protos:
  289. if not input_proto.name:
  290. logger.warning("The name in input proto of node(%s) is empty, will ignore.", node.name)
  291. continue
  292. edge_type = EdgeTypeEnum.DATA.value if not input_proto.type else EdgeTypeEnum.CONTROL.value
  293. # Notice:
  294. # 1. The name in the input proto is the node id of the Node object.
  295. # 2. In the current step, the shape of source node cannot be obtained,
  296. # so it is set to empty list by default, and the next step will update it.
  297. # 3. Same with scope, set the default value first.
  298. input_attr = {
  299. "shape": [],
  300. "edge_type": edge_type,
  301. "independent_layout": False,
  302. 'data_type': ''
  303. }
  304. node.add_inputs(src_name=input_proto.name, input_attr=input_attr)
  305. def _parse_attributes(self, attributes, node):
  306. """
  307. Parse `anf_ir_pb2.AttributeProto` object., and Filters large attribute values.
  308. Args:
  309. attributes (list[anf_ir_pb2.AttributeProto]): Refer to `anf_ir_pb2.AttributeProto` object.
  310. node (Node): Refer to `Node` object, it is used to log message and update attr.
  311. """
  312. for attr in attributes:
  313. if attr.value.ByteSize() > self.MAX_NODE_ATTRIBUTE_VALUE_BYTES:
  314. message = f"The attribute value of node({node.name}) " \
  315. f"is over {self.MAX_NODE_ATTRIBUTE_VALUE_BYTES} Bytes, will ignore."
  316. logger.warning(message)
  317. continue
  318. node.add_attr({attr.name: str(attr.value)})
  319. def _update_input_after_create_node(self):
  320. """Update the input of node after create node."""
  321. for node in self._normal_node_map.values():
  322. for src_node_id, input_attr in dict(node.inputs).items():
  323. node.delete_inputs(src_node_id)
  324. if not self._is_node_exist(node_id=src_node_id):
  325. message = f"The input node could not be found by node id({src_node_id}) " \
  326. f"while updating the input of the node({node})"
  327. logger.warning(message)
  328. continue
  329. src_node = self._get_normal_node(node_id=src_node_id)
  330. input_attr['shape'] = src_node.output_shape
  331. input_attr['data_type'] = src_node.output_data_type
  332. node.add_inputs(src_name=src_node.name, input_attr=input_attr)
  333. def _update_output_after_create_node(self):
  334. """Update the output of node after create node."""
  335. # Constants and parameter should not exist for input and output.
  336. filtered_node = {NodeTypeEnum.CONST.value, NodeTypeEnum.PARAMETER.value}
  337. for node in self._normal_node_map.values():
  338. for src_name, input_attr in node.inputs.items():
  339. src_node = self._get_normal_node(node_name=src_name)
  340. if src_node.type in filtered_node:
  341. continue
  342. src_node.add_outputs(node.name, input_attr)
  343. @staticmethod
  344. def _get_data_type_name_by_value(data_type, value, field_name='data_type'):
  345. """Get the data type name by the enum value, data_type refer to `DataType` object."""
  346. return data_type.DESCRIPTOR.fields_by_name[field_name].enum_type.values_by_number[value].name