You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

msgraph.py 12 kB

5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278
  1. # Copyright 2020 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ============================================================================
  15. """This file is used to define the MindSpore graph."""
  16. from mindinsight.datavisual.common.log import logger
  17. from mindinsight.datavisual.proto_files.mindinsight_anf_ir_pb2 import DataType
  18. from mindinsight.datavisual.common.enums import PluginNameEnum
  19. from .node import Node
  20. from .node import NodeTypeEnum
  21. from .graph import Graph
  22. from .graph import EdgeTypeEnum
  23. class MSGraph(Graph):
  24. """The object describes the MindSpore graph, and it is defined in the anf_ir proto file."""
  25. def _parse_data(self, proto_data):
  26. """
  27. The proto data is parsed and all nodes are stored in the specified structure.
  28. Args:
  29. proto_data (anf_ir_pb2.GraphProto): Refer to anf_ir_pb2.GraphProto object.
  30. """
  31. logger.info("Start to parse graph proto data.")
  32. self._parse_op_nodes(proto_data.node)
  33. self._parse_parameters(proto_data.parameters)
  34. self._parse_consts(proto_data.const_vals)
  35. self._update_input_after_create_node()
  36. self._update_output_after_create_node()
  37. logger.info("Parse proto data end, normal node count(only contain op node, "
  38. "parameter, const): %s.", self.normal_node_count)
  39. def _parse_op_nodes(self, node_protos):
  40. """
  41. Parse `anf_ir_pb2.NodeProto` object, and create a normal node.
  42. Args:
  43. node_protos (list[anf_ir_pb2.NodeProto]): Refer to anf_ir_pb2.NodeProto.
  44. """
  45. logger.debug("Start to parse op nodes from proto.")
  46. for node_proto in node_protos:
  47. if not node_proto.name:
  48. logger.warning("Finding a node with an empty name will not save it.")
  49. continue
  50. if not node_proto.full_name or any(
  51. node_proto.full_name.lower().endswith(f'[:{plugin.value.lower()}]') for plugin in PluginNameEnum):
  52. node_name = Node.create_node_name(scope=node_proto.scope,
  53. base_name=f'{node_proto.op_type}{node_proto.name}')
  54. else:
  55. node_name = node_proto.full_name
  56. node = Node(name=node_name, node_id=node_proto.name)
  57. node.full_name = node_proto.full_name
  58. node.type = node_proto.op_type
  59. self._parse_attributes(node_proto.attribute, node)
  60. self._parse_inputs(node_proto.input, node)
  61. node.output_i = node_proto.output_i
  62. node.scope = node_proto.scope
  63. node.output_shape = self._get_shape_by_parse_type_proto(node_proto.output_type)
  64. node.output_nums = len(node.output_shape)
  65. node.output_data_type = self._get_data_type_by_parse_type_proto(node_proto.output_type, node)
  66. self._cache_node(node)
  67. def _parse_parameters(self, parameter_protos):
  68. """
  69. Parse `anf_ir_pb2.ParameterProto` object, and create a parameter node.
  70. Args:
  71. parameter_protos (list[anf_ir_pb2.ParameterProto]): Refer to anf_ir_pb2.ParameterProto.
  72. """
  73. logger.debug("Start to parse parameters from proto.")
  74. for parameter in parameter_protos:
  75. if not parameter.name:
  76. logger.warning("Finding a parameter with an empty name will not save it.")
  77. continue
  78. node = Node(name=parameter.name, node_id=parameter.name)
  79. node.type = NodeTypeEnum.PARAMETER.value
  80. node.output_shape = self._get_shape_by_parse_type_proto(parameter.type)
  81. node.output_nums = len(node.output_shape)
  82. node.output_data_type = self._get_data_type_by_parse_type_proto(parameter.type, node)
  83. attr = dict(
  84. type=self._get_data_type_by_parse_type_proto(parameter.type, node),
  85. shape=str(self._get_shape_by_parse_type_proto(parameter.type))
  86. )
  87. node.add_attr(attr)
  88. self._cache_node(node)
  89. logger.debug("Foreach graph proto parameters, node id: %s, node name: %s, "
  90. "node def name: %s", node.node_id, node.name, parameter.name)
  91. def _parse_consts(self, consts):
  92. """
  93. Parse `anf_ir_pb2.NameValueProto` object, and create a const node.
  94. Args:
  95. consts (list[anf_ir_pb2.NameValueProto]): Refer to `anf_ir_pb2.NameValueProto` object.
  96. """
  97. logger.debug("Start to parse consts from proto.")
  98. for const in consts:
  99. if not const.key:
  100. logger.warning("Finding a const with an empty key will not save it.")
  101. continue
  102. node = Node(name=const.key, node_id=const.key)
  103. node.type = NodeTypeEnum.CONST.value
  104. if const.value.ByteSize() > self.MAX_NODE_ATTRIBUTE_VALUE_BYTES:
  105. node.add_attr({const.key: 'dtype: ' + DataType.Name(const.value.dtype)})
  106. else:
  107. node.add_attr({const.key: str(const.value)})
  108. if const.value.dtype == DataType.DT_TENSOR:
  109. shape = list(const.value.tensor_val.dims)
  110. node.output_shape.append(shape)
  111. if const.value.tensor_val.HasField('data_type'):
  112. node.elem_types.append(DataType.Name(const.value.tensor_val.data_type))
  113. else:
  114. node.elem_types.append(DataType.Name(const.value.dtype))
  115. # dim is zero
  116. node.output_shape.append([])
  117. node.output_nums = len(node.output_shape)
  118. self._cache_node(node)
  119. def _get_shape_by_parse_type_proto(self, type_proto):
  120. """
  121. Parse proto's `message TypeProto` to get shape information.
  122. Args:
  123. type_proto (anf_ir_pb2.TypeProto): Refer to anf_ir_pb2.TypeProto.
  124. Returns:
  125. list, a list of shape.
  126. """
  127. shapes = []
  128. if type_proto.HasField('data_type'):
  129. if type_proto.data_type != DataType.DT_TENSOR and \
  130. type_proto.data_type != DataType.DT_TUPLE:
  131. # Append an empty list as a placeholder
  132. # for the convenience of output number calculation.
  133. shapes.append([])
  134. return shapes
  135. if type_proto.HasField('tensor_type'):
  136. tensor_type = type_proto.tensor_type
  137. tensor_shape_proto = tensor_type.shape
  138. shape = [dim.size for dim in tensor_shape_proto.dim]
  139. shapes.append(shape)
  140. if type_proto.HasField('sequence_type'):
  141. for elem_type in type_proto.sequence_type.elem_types:
  142. shapes.extend(self._get_shape_by_parse_type_proto(elem_type))
  143. return shapes
  144. def _get_data_type_by_parse_type_proto(self, type_proto, node):
  145. """
  146. Get data type by parse type proto object.
  147. The name of the DataType, refer to `anf_ir_pb2.DataType` object.
  148. If data type is tensor or tuple, the data name we return is `data_type[element_type, element_type]`.
  149. Args:
  150. type_proto (anf_ir_pb2.TypeProto): Refer to anf_ir_pb2.TypeProto.
  151. Returns:
  152. str, the data type.
  153. """
  154. data_type_name = self._get_data_type_name_by_value(type_proto, type_proto.data_type, field_name='data_type')
  155. if type_proto.data_type == DataType.DT_TENSOR:
  156. tensor_type_proto = type_proto.tensor_type
  157. value = type_proto.tensor_type.elem_type
  158. elem_type_name = self._get_data_type_name_by_value(tensor_type_proto, value, field_name='elem_type')
  159. node.elem_types.append(elem_type_name)
  160. return f'{data_type_name}[{elem_type_name}]'
  161. if type_proto.data_type == DataType.DT_TUPLE:
  162. data_types = []
  163. for elem_type in type_proto.sequence_type.elem_types:
  164. data_types.append(self._get_data_type_by_parse_type_proto(elem_type, node))
  165. return f'{data_type_name}{str(data_types)}'
  166. node.elem_types.append(data_type_name)
  167. return data_type_name
  168. def _parse_inputs(self, input_protos, node):
  169. """
  170. Parse `anf_ir_pb2.InputProto` object.
  171. Args:
  172. input_protos (list[anf_ir_pb2.InputProto]): Refer to `anf_ir_pb2.InputProto` object.
  173. node (Node): Refer to `Node` object, it is used to log message and update input.
  174. """
  175. for input_proto in input_protos:
  176. if not input_proto.name:
  177. logger.warning("The name in input proto of node(%s) is empty, will ignore.", node.name)
  178. continue
  179. edge_type = EdgeTypeEnum.DATA.value if not input_proto.type else EdgeTypeEnum.CONTROL.value
  180. # Notice:
  181. # 1. The name in the input proto is the node id of the Node object.
  182. # 2. In the current step, the shape of source node cannot be obtained,
  183. # so it is set to empty list by default, and the next step will update it.
  184. # 3. Same with scope, set the default value first.
  185. input_attr = {
  186. "shape": [],
  187. "edge_type": edge_type,
  188. "independent_layout": False,
  189. 'data_type': ''
  190. }
  191. node.add_inputs(src_name=input_proto.name, input_attr=input_attr)
  192. def _parse_attributes(self, attributes, node):
  193. """
  194. Parse `anf_ir_pb2.AttributeProto` object., and Filters large attribute values.
  195. Args:
  196. attributes (list[anf_ir_pb2.AttributeProto]): Refer to `anf_ir_pb2.AttributeProto` object.
  197. node (Node): Refer to `Node` object, it is used to log message and update attr.
  198. """
  199. for attr in attributes:
  200. if attr.value.ByteSize() > self.MAX_NODE_ATTRIBUTE_VALUE_BYTES:
  201. message = f"The attribute value of node({node.name}) " \
  202. f"is over {self.MAX_NODE_ATTRIBUTE_VALUE_BYTES} Bytes, will ignore."
  203. logger.warning(message)
  204. continue
  205. node.add_attr({attr.name: str(attr.value)})
  206. def _update_input_after_create_node(self):
  207. """Update the input of node after create node."""
  208. for node in self._normal_node_map.values():
  209. for src_node_id, input_attr in dict(node.inputs).items():
  210. node.delete_inputs(src_node_id)
  211. if not self._is_node_exist(node_id=src_node_id):
  212. message = f"The input node could not be found by node id({src_node_id}) " \
  213. f"while updating the input of the node({node})"
  214. logger.warning(message)
  215. continue
  216. src_node = self._get_normal_node(node_id=src_node_id)
  217. input_attr['shape'] = src_node.output_shape
  218. input_attr['data_type'] = src_node.output_data_type
  219. node.add_inputs(src_name=src_node.name, input_attr=input_attr)
  220. def _update_output_after_create_node(self):
  221. """Update the output of node after create node."""
  222. # Constants and parameter should not exist for input and output.
  223. filtered_node = {NodeTypeEnum.CONST.value, NodeTypeEnum.PARAMETER.value}
  224. for node in self._normal_node_map.values():
  225. for src_name, input_attr in node.inputs.items():
  226. src_node = self._get_normal_node(node_name=src_name)
  227. if src_node.type in filtered_node:
  228. continue
  229. src_node.add_outputs(node.name, input_attr)
  230. @staticmethod
  231. def _get_data_type_name_by_value(data_type, value, field_name='data_type'):
  232. """Get the data type name by the enum value, data_type refer to `DataType` object."""
  233. return data_type.DESCRIPTOR.fields_by_name[field_name].enum_type.values_by_number[value].name