You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

msgraph.py 18 kB

5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413
  1. # Copyright 2020-2021 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ============================================================================
  15. """This file is used to define the MindSpore graph."""
  16. from mindinsight.datavisual.common.log import logger
  17. from mindinsight.datavisual.proto_files.mindinsight_anf_ir_pb2 import DataType
  18. from mindinsight.datavisual.common.enums import PluginNameEnum
  19. from .node_tree import NodeTree
  20. from .node import Node
  21. from .node import NodeTypeEnum
  22. from .graph import Graph
  23. from .graph import EdgeTypeEnum
  24. from .graph import check_invalid_character
  25. class MSGraph(Graph):
  26. """The object describes the MindSpore graph, and it is defined in the anf_ir proto file."""
  def _parse_data(self, proto_data):
      """
      Build every node of the graph from the proto data, then link the nodes together.

      Operator nodes, parameters and constants are created first. Inputs are rewritten
      afterwards, and outputs are derived last from the rewritten inputs.

      Args:
          proto_data (anf_ir_pb2.GraphProto): Refer to anf_ir_pb2.GraphProto object.
      """
      logger.info("Start to parse graph proto data.")

      # Creation phase: each parser consumes one field of the graph proto, in this order.
      creation_steps = (
          (self._parse_op_nodes, 'node'),
          (self._parse_parameters, 'parameters'),
          (self._parse_consts, 'const_vals'),
      )
      for parse_section, field_name in creation_steps:
          parse_section(getattr(proto_data, field_name))

      # Linking phase: outputs are computed from the inputs, so inputs are updated first.
      self._update_input_after_create_node()
      self._update_output_after_create_node()

      logger.info("Parse proto data end, normal node count(only contain op node, "
                  "parameter, const): %s.", self.normal_node_count)
  def _parse_op_nodes(self, node_protos):
      """
      Create and cache one normal node for every named `anf_ir_pb2.NodeProto`.

      Protos with an empty name are skipped with a warning. The position of a proto in
      `node_protos` (skipped protos included) is used as the node's topological index.

      Args:
          node_protos (list[anf_ir_pb2.NodeProto]): Refer to anf_ir_pb2.NodeProto.
      """
      logger.debug("Start to parse op nodes from proto.")
      for topological_index, node_proto in enumerate(node_protos):
          if not node_proto.name:
              logger.warning("Finding a node with an empty name will not save it.")
              continue

          if node_proto.op_type == "Load":
              # The Load operator needs to be renamed as it has the same name with parameter
              node_name = Node.create_node_name(
                  scope=node_proto.scope,
                  base_name=f'{node_proto.op_type}-op{node_proto.name}')
              node_proto.full_name = node_name
          else:
              # A name is generated when the proto has no full name, or when its full
              # name ends with a "[:<plugin name>]" marker (compared case-insensitively).
              original_name = node_proto.full_name
              use_generated_name = not original_name
              if not use_generated_name:
                  lowered_name = original_name.lower()
                  use_generated_name = any(
                      lowered_name.endswith(f'[:{plugin.value.lower()}]') for plugin in PluginNameEnum)
              if use_generated_name:
                  node_name = Node.create_node_name(
                      scope=node_proto.scope,
                      base_name=f'{node_proto.op_type}{node_proto.name}')
              else:
                  node_name = original_name

          # The Graphviz plug-in that the UI USES can't handle these special characters.
          check_invalid_character(node_name)

          node = Node(name=node_name, node_id=node_proto.name, topological_index=topological_index)
          node.full_name = node_proto.full_name
          node.type = node_proto.op_type
          self._parse_attributes(node_proto.attribute, node)
          self._parse_inputs(node_proto.input, node)
          node.output_i = node_proto.output_i
          node.scope = node_proto.scope

          output_type = node_proto.output_type
          node.output_shape = self._get_shape_by_parse_type_proto(output_type)
          node.output_nums = len(node.output_shape)
          node.output_data_type = self._get_data_type_by_parse_type_proto(output_type, node)

          self._cache_node(node)
  def _parse_parameters(self, parameter_protos):
      """
      Parse `anf_ir_pb2.ParameterProto` object, and create a parameter node.

      Parameters with an empty name are skipped with a warning. Each created node
      gets its output shape, output count and output data type from the parameter's
      type proto, plus a `type`/`shape` attribute holding the same information.

      Args:
          parameter_protos (list[anf_ir_pb2.ParameterProto]): Refer to anf_ir_pb2.ParameterProto.
      """
      logger.debug("Start to parse parameters from proto.")
      for parameter in parameter_protos:
          if not parameter.name:
              logger.warning("Finding a parameter with an empty name will not save it.")
              continue

          check_invalid_character(parameter.name)
          node = Node(name=parameter.name, node_id=parameter.name)
          node.type = NodeTypeEnum.PARAMETER.value

          # Resolve the shape and the data type exactly once and reuse them below.
          # The data type helper appends to `node.elem_types` on every call, so
          # calling it a second time for the attribute would record each element
          # type twice on the node.
          output_shape = self._get_shape_by_parse_type_proto(parameter.type)
          node.output_shape = output_shape
          node.output_nums = len(node.output_shape)
          data_type = self._get_data_type_by_parse_type_proto(parameter.type, node)
          node.output_data_type = data_type

          attr = dict(
              type=data_type,
              shape=str(output_shape)
          )
          node.add_attr(attr)

          self._cache_node(node)
          logger.debug("Foreach graph proto parameters, node id: %s, node name: %s, "
                       "node def name: %s", node.node_id, node.name, parameter.name)
  def _parse_consts(self, consts):
      """
      Create and cache one const node for every keyed `anf_ir_pb2.NameValueProto`.

      Entries with an empty key are skipped with a warning. The const value is stored
      as a node attribute under the const key; a value larger than
      `MAX_NODE_ATTRIBUTE_VALUE_BYTES` is replaced by a short text naming its dtype.

      Args:
          consts (list[anf_ir_pb2.NameValueProto]): Refer to `anf_ir_pb2.NameValueProto` object.
      """
      logger.debug("Start to parse consts from proto.")
      for const in consts:
          if not const.key:
              logger.warning("Finding a const with an empty key will not save it.")
              continue

          check_invalid_character(const.key)
          node = Node(name=const.key, node_id=const.key)
          node.type = NodeTypeEnum.CONST.value

          value = const.value
          oversized = value.ByteSize() > self.MAX_NODE_ATTRIBUTE_VALUE_BYTES
          attr_text = ('dtype: ' + DataType.Name(value.dtype)) if oversized else str(value)
          node.add_attr({const.key: attr_text})

          if value.dtype != DataType.DT_TENSOR:
              node.elem_types.append(DataType.Name(value.dtype))
              # dim is zero
              node.output_shape.append([])
          else:
              tensor = value.tensor_val
              node.output_shape.append(list(tensor.dims))
              # The element type is only recorded when the tensor proto carries one.
              if tensor.HasField('data_type'):
                  node.elem_types.append(DataType.Name(tensor.data_type))

          node.output_nums = len(node.output_shape)
          self._cache_node(node)
  130. def _get_shape_by_parse_type_proto(self, type_proto):
  131. """
  132. Parse proto's `message TypeProto` to get shape information.
  133. Args:
  134. type_proto (anf_ir_pb2.TypeProto): Refer to anf_ir_pb2.TypeProto.
  135. Returns:
  136. list, a list of shape.
  137. """
  138. shapes = []
  139. if type_proto.HasField('data_type'):
  140. if type_proto.data_type != DataType.DT_TENSOR and \
  141. type_proto.data_type != DataType.DT_TUPLE:
  142. # Append an empty list as a placeholder
  143. # for the convenience of output number calculation.
  144. shapes.append([])
  145. return shapes
  146. if type_proto.HasField('tensor_type'):
  147. tensor_type = type_proto.tensor_type
  148. tensor_shape_proto = tensor_type.shape
  149. shape = [dim.size for dim in tensor_shape_proto.dim]
  150. shapes.append(shape)
  151. if type_proto.HasField('sequence_type'):
  152. for elem_type in type_proto.sequence_type.elem_types:
  153. shapes.extend(self._get_shape_by_parse_type_proto(elem_type))
  154. return shapes
  def _get_data_type_by_parse_type_proto(self, type_proto, node):
      """
      Build the data type text for a type proto and record element types on the node.

      The name of the DataType, refer to `anf_ir_pb2.DataType` object.
      If data type is tensor or tuple, the data name we return is `data_type[element_type, element_type]`.

      Args:
          type_proto (anf_ir_pb2.TypeProto): Refer to anf_ir_pb2.TypeProto.
          node (Node): Refer to `Node` object. Every call appends to `node.elem_types`:
              the element type for a tensor, the data type itself for a plain type,
              and one entry per member (recursively) for a tuple.

      Returns:
          str, the data type.
      """
      outer_type = type_proto.data_type
      outer_name = self._get_data_type_name_by_value(type_proto, outer_type, field_name='data_type')

      if outer_type == DataType.DT_TENSOR:
          tensor_proto = type_proto.tensor_type
          elem_name = self._get_data_type_name_by_value(
              tensor_proto, tensor_proto.elem_type, field_name='elem_type')
          node.elem_types.append(elem_name)
          return f'{outer_name}[{elem_name}]'

      if outer_type == DataType.DT_TUPLE:
          member_names = [
              self._get_data_type_by_parse_type_proto(member_type, node)
              for member_type in type_proto.sequence_type.elem_types
          ]
          return f'{outer_name}{str(member_names)}'

      node.elem_types.append(outer_name)
      return outer_name
  def get_nodes(self, searched_node_list):
      """
      Arrange the searched nodes into a nested, scope-by-scope structure.

      Args:
          searched_node_list (list[Node]): A list of nodes that
              matches the given search pattern.

      Returns:
          A list of dict including the searched nodes.
          [{
              "name": "Default",
              "type": "name_scope",
              "nodes": [{
                  "name": "Default/Conv2D1",
                  "type": "name_scope",
                  "nodes": [{
                      ...
                  }]
              }]
          },
          {
              "name": "Gradients",
              "type": "name_scope",
              "nodes": [{
                  "name": "Gradients/Default",
                  "type": "name_scope",
                  "nodes": [{
                      ...
                  }]
              }]
          }]
      """
      # First insert every searched node into a fresh tree, keyed by its name path.
      tree_root = NodeTree()
      for matched_node in searched_node_list:
          self._build_node_tree(tree_root, matched_node.name, matched_node.type)

      # Then walk the tree and flatten it into nested dicts.
      nested_result = []
      self._traverse_node_tree(tree_root, nested_result)
      return nested_result
  217. def search_leaf_nodes_by_pattern(self, pattern, scope_pattern=False):
  218. """
  219. Search leaf node by a given pattern.
  220. Args:
  221. pattern (Union[str, None]): The pattern of the node to search,
  222. if None, return all node names.
  223. scope_pattern (bool): If true, return the children nodes of the scope. Default: False.
  224. Returns:
  225. list[Node], a list of nodes.
  226. """
  227. is_match = lambda x, y: x.lower().startswith(y) if scope_pattern else y in x.lower()
  228. if pattern is not None:
  229. pattern = pattern.lower()
  230. searched_nodes = [
  231. node for name, node in self._leaf_nodes.items()
  232. if is_match(name, pattern)
  233. ]
  234. else:
  235. searched_nodes = [node for node in self._leaf_nodes.values()]
  236. return searched_nodes
  237. def search_nodes_by_pattern(self, pattern):
  238. """
  239. Search node by a given pattern.
  240. Search node which pattern is the part of the last node. Example: pattern=ops, node1=default/ops,
  241. node2=default/ops/weight, so node2 will be ignore and only node1 will be return.
  242. Args:
  243. pattern (Union[str, None]): The pattern of the node to search.
  244. Returns:
  245. list[Node], a list of nodes.
  246. """
  247. searched_nodes = []
  248. if pattern and pattern != '/':
  249. pattern = pattern.lower()
  250. for name, node in self._normal_node_map.items():
  251. name = name.lower()
  252. pattern_index = name.rfind(pattern)
  253. if pattern_index >= 0 and name.find('/', pattern_index + len(pattern)) == -1:
  254. searched_nodes.append(node)
  255. return searched_nodes
  256. def _build_node_tree(self, root, node_name, node_type):
  257. """
  258. Build node tree.
  259. Args:
  260. root (NodeTree): Root node of node tree.
  261. node_name (str): Node name.
  262. node_type (str): Node type.
  263. """
  264. scope_names = node_name.split('/')
  265. cur_node = root
  266. full_name = ""
  267. for scope_name in scope_names[:-1]:
  268. full_name = '/'.join([full_name, scope_name]) if full_name else scope_name
  269. scope_node = self._get_normal_node(node_name=full_name)
  270. sub_node = cur_node.get(scope_name)
  271. if not sub_node:
  272. sub_node = cur_node.add(scope_name, scope_node.type)
  273. cur_node = sub_node
  274. cur_node.add(scope_names[-1], node_type)
  275. def _traverse_node_tree(self, cur_node, search_node_list):
  276. """Traverse the node tree and construct the searched nodes list."""
  277. for _, sub_node in cur_node.get_children():
  278. sub_nodes = []
  279. self._traverse_node_tree(sub_node, sub_nodes)
  280. sub_node_dict = {
  281. 'name': sub_node.node_name,
  282. 'type': sub_node.node_type,
  283. 'nodes': sub_nodes
  284. }
  285. search_node_list.append(sub_node_dict)
  def _parse_inputs(self, input_protos, node):
      """
      Register every named `anf_ir_pb2.InputProto` as an input of the node.

      Args:
          input_protos (list[anf_ir_pb2.InputProto]): Refer to `anf_ir_pb2.InputProto` object.
          node (Node): Refer to `Node` object, it is used to log message and update input.
      """
      for input_proto in input_protos:
          src_id = input_proto.name
          if not src_id:
              logger.warning("The name in input proto of node(%s) is empty, will ignore.", node.name)
              continue

          # A falsy input type means a data edge; anything else is a control edge.
          if input_proto.type:
              edge_type = EdgeTypeEnum.CONTROL.value
          else:
              edge_type = EdgeTypeEnum.DATA.value

          # Notice:
          # 1. The name in the input proto is the node id of the Node object.
          # 2. In the current step, the shape of source node cannot be obtained,
          #    so it is set to empty list by default, and the next step will update it.
          # 3. Same with scope, set the default value first.
          node.add_inputs(
              src_name=src_id,
              input_attr={
                  "shape": [],
                  "edge_type": edge_type,
                  "independent_layout": False,
                  'data_type': ''
              })
  310. def _parse_attributes(self, attributes, node):
  311. """
  312. Parse `anf_ir_pb2.AttributeProto` object., and Filters large attribute values.
  313. Args:
  314. attributes (list[anf_ir_pb2.AttributeProto]): Refer to `anf_ir_pb2.AttributeProto` object.
  315. node (Node): Refer to `Node` object, it is used to log message and update attr.
  316. """
  317. for attr in attributes:
  318. if attr.value.ByteSize() > self.MAX_NODE_ATTRIBUTE_VALUE_BYTES:
  319. message = f"The attribute value of node({node.name}) " \
  320. f"is over {self.MAX_NODE_ATTRIBUTE_VALUE_BYTES} Bytes, will ignore."
  321. logger.warning(message)
  322. continue
  323. node.add_attr({attr.name: str(attr.value)})
  324. def _update_input_after_create_node(self):
  325. """Update the input of node after create node."""
  326. for node in self._normal_node_map.values():
  327. for src_node_id, input_attr in dict(node.inputs).items():
  328. node.delete_inputs(src_node_id)
  329. if not self._is_node_exist(node_id=src_node_id):
  330. message = f"The input node could not be found by node id({src_node_id}) " \
  331. f"while updating the input of the node({node})"
  332. logger.warning(message)
  333. continue
  334. src_node = self._get_normal_node(node_id=src_node_id)
  335. input_attr['shape'] = src_node.output_shape
  336. input_attr['data_type'] = src_node.output_data_type
  337. node.add_inputs(src_name=src_node.name, input_attr=input_attr)
  def _update_output_after_create_node(self):
      """
      Update the output of node after create node.

      For every input edge of every node, the consuming node is registered as an
      output of the source node, unless the source is a const or a parameter.
      """
      # Constants and parameter should not exist for input and output.
      skipped_types = {NodeTypeEnum.CONST.value, NodeTypeEnum.PARAMETER.value}
      for node in self._normal_node_map.values():
          for src_name, input_attr in node.inputs.items():
              src_node = self._get_normal_node(node_name=src_name)
              if src_node.type not in skipped_types:
                  src_node.add_outputs(node.name, input_attr)
  348. @staticmethod
  349. def _get_data_type_name_by_value(data_type, value, field_name='data_type'):
  350. """Get the data type name by the enum value, data_type refer to `DataType` object."""
  351. return data_type.DESCRIPTOR.fields_by_name[field_name].enum_type.values_by_number[value].name