You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-'), and can be up to 35 characters long.

msgraph.py 12 kB

5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271
  1. # Copyright 2020 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ============================================================================
  15. """This file is used to define the MindSpore graph."""
  16. import time
  17. from mindinsight.datavisual.common.log import logger
  18. from mindinsight.datavisual.proto_files.mindinsight_anf_ir_pb2 import DataType
  19. from .node import Node
  20. from .node import NodeTypeEnum
  21. from .graph import Graph
  22. from .graph import EdgeTypeEnum
  23. class MSGraph(Graph):
  24. """The object describes the MindSpore graph, and it is defined in the anf_ir proto file."""
  25. def build_graph(self, proto_data):
  26. """
  27. Build graph by graph proto which refer to `anf_ir_pb2.GraphProto`.
  28. Args:
  29. proto_data (anf_ir_pb2.GraphProto): Refer to `anf_ir_pb2.GraphProto`.
  30. """
  31. logger.info("Start to build graph, graph name: %s.", proto_data.name)
  32. start_time = time.time()
  33. super(MSGraph, self).build_graph(proto_data)
  34. precision = 6
  35. time_consuming = round(time.time()-start_time, precision)
  36. logger.info("Build graph end, all node count: %s, const count: %s, parameter count: %s, time-consuming: %s s.",
  37. self.normal_node_count, len(self._const_node_temp_cache),
  38. len(self._parameter_node_temp_cache), time_consuming)
  39. def _parse_data(self, proto_data):
  40. """
  41. The proto data is parsed and all nodes are stored in the specified structure.
  42. Args:
  43. proto_data (anf_ir_pb2.GraphProto): Refer to anf_ir_pb2.GraphProto object.
  44. """
  45. logger.info("Start to parse graph proto data.")
  46. self._parse_op_nodes(proto_data.node)
  47. self._parse_parameters(proto_data.parameters)
  48. self._parse_consts(proto_data.const_vals)
  49. self._update_input_after_create_node()
  50. self._update_output_after_create_node()
  51. logger.info("Parse proto data end, normal node count(only contain op node, "
  52. "parameter, const): %s.", self.normal_node_count)
  53. def _parse_op_nodes(self, node_protos):
  54. """
  55. Parse `anf_ir_pb2.NodeProto` object, and create a normal node.
  56. Args:
  57. node_protos (list[anf_ir_pb2.NodeProto]): Refer to anf_ir_pb2.NodeProto.
  58. """
  59. logger.debug("Start to parse op nodes from proto.")
  60. for node_proto in node_protos:
  61. if not node_proto.name:
  62. logger.warning("Finding a node with an empty name will not save it.")
  63. continue
  64. node_name = Node.create_node_name(scope=node_proto.scope,
  65. base_name=f'{node_proto.op_type}{node_proto.name}')
  66. node = Node(name=node_name, node_id=node_proto.name)
  67. node.type = node_proto.op_type
  68. logger.debug("Foreach graph proto nodes, node id: %s, node name: %s, node def name: %s, "
  69. "input count: %s", node.node_id, node.name, node_proto.name, len(node_proto.input))
  70. self._parse_attributes(node_proto.attribute, node)
  71. self._parse_inputs(node_proto.input, node)
  72. node.output_i = node_proto.output_i
  73. node.scope = node_proto.scope
  74. node.output_shape = self._get_shape_by_parse_type_proto(node_proto.output_type)
  75. node.output_data_type = self._get_data_type_by_parse_type_proto(node_proto.output_type)
  76. self._cache_node(node)
  77. def _parse_parameters(self, parameter_protos):
  78. """
  79. Parse `anf_ir_pb2.ParameterProto` object, and create a parameter node.
  80. Args:
  81. parameter_protos (list[anf_ir_pb2.ParameterProto]): Refer to anf_ir_pb2.ParameterProto.
  82. """
  83. logger.debug("Start to parse parameters from proto.")
  84. for parameter in parameter_protos:
  85. if not parameter.name:
  86. logger.warning("Finding a parameter with an empty name will not save it.")
  87. continue
  88. node = Node(name=parameter.name, node_id=parameter.name)
  89. node.type = NodeTypeEnum.PARAMETER.value
  90. node.output_shape = self._get_shape_by_parse_type_proto(parameter.type)
  91. attr = dict(
  92. type=self._get_data_type_by_parse_type_proto(parameter.type),
  93. shape=str(self._get_shape_by_parse_type_proto(parameter.type))
  94. )
  95. node.add_attr(attr)
  96. self._cache_node(node)
  97. logger.debug("Foreach graph proto parameters, node id: %s, node name: %s, "
  98. "node def name: %s", node.node_id, node.name, parameter.name)
  99. def _parse_consts(self, consts):
  100. """
  101. Parse `anf_ir_pb2.NameValueProto` object, and create a const node.
  102. Args:
  103. consts (list[anf_ir_pb2.NameValueProto]): Refer to `anf_ir_pb2.NameValueProto` object.
  104. """
  105. logger.debug("Start to parse consts from proto.")
  106. for const in consts:
  107. if not const.key:
  108. logger.warning("Finding a const with an empty key will not save it.")
  109. continue
  110. node = Node(name=const.key, node_id=const.key)
  111. node.type = NodeTypeEnum.CONST.value
  112. node.add_attr({const.key: str(const.value)})
  113. if const.value.dtype == DataType.DT_TENSOR:
  114. shape = []
  115. for dim in const.value.tensor_val.dims:
  116. shape.append(dim)
  117. node.output_shape = shape
  118. self._cache_node(node)
  119. def _get_shape_by_parse_type_proto(self, type_proto):
  120. """
  121. Parse proto's `message TypeProto` to get shape information.
  122. Args:
  123. type_proto (anf_ir_pb2.TypeProto): Refer to anf_ir_pb2.TypeProto.
  124. Returns:
  125. list, a list of shape.
  126. """
  127. shapes = []
  128. if type_proto.HasField('tensor_type'):
  129. tensor_type = type_proto.tensor_type
  130. tensor_shape_proto = tensor_type.shape
  131. for dim in tensor_shape_proto.dim:
  132. shapes.append(dim.size)
  133. if type_proto.HasField('sequence_type'):
  134. for elem_type in type_proto.sequence_type.elem_types:
  135. shapes.append(self._get_shape_by_parse_type_proto(elem_type))
  136. return shapes
  137. def _get_data_type_by_parse_type_proto(self, type_proto):
  138. """
  139. Get data type by parse type proto object.
  140. The name of the DataType, refer to `anf_ir_pb2.DataType` object.
  141. If data type is tensor or tuple, the data name we return is `data_type[element_type, element_type]`.
  142. Args:
  143. type_proto (anf_ir_pb2.TypeProto): Refer to anf_ir_pb2.TypeProto.
  144. Returns:
  145. str, the data type.
  146. """
  147. data_type_name = self._get_data_type_name_by_value(type_proto, type_proto.data_type, field_name='data_type')
  148. if type_proto.data_type == DataType.DT_TENSOR:
  149. tensor_type_proto = type_proto.tensor_type
  150. value = type_proto.tensor_type.elem_type
  151. elem_type_name = self._get_data_type_name_by_value(tensor_type_proto, value, field_name='elem_type')
  152. return f'{data_type_name}[{elem_type_name}]'
  153. if type_proto.data_type == DataType.DT_TUPLE:
  154. data_types = []
  155. for elem_type in type_proto.sequence_type.elem_types:
  156. data_types.append(self._get_data_type_by_parse_type_proto(elem_type))
  157. return f'{data_type_name}{str(data_types)}'
  158. return data_type_name
  159. def _parse_inputs(self, input_protos, node):
  160. """
  161. Parse `anf_ir_pb2.InputProto` object.
  162. Args:
  163. input_protos (list[anf_ir_pb2.InputProto]): Refer to `anf_ir_pb2.InputProto` object.
  164. node (Node): Refer to `Node` object, it is used to log message and update input.
  165. """
  166. for input_proto in input_protos:
  167. if not input_proto.name:
  168. logger.warning("The name in input proto of node(%s) is empty, will ignore.", node.name)
  169. continue
  170. edge_type = EdgeTypeEnum.DATA.value if not input_proto.type else EdgeTypeEnum.CONTROL.value
  171. # Notice:
  172. # 1. The name in the input proto is the node id of the Node object.
  173. # 2. In the current step, the shape of source node cannot be obtained,
  174. # so it is set to empty list by default, and the next step will update it.
  175. # 3. Same with scope, set the default value first.
  176. input_attr = {
  177. "shape": [],
  178. "edge_type": edge_type,
  179. "independent_layout": False,
  180. 'data_type': ''
  181. }
  182. node.add_input(src_name=input_proto.name, input_attr=input_attr)
  183. def _parse_attributes(self, attributes, node):
  184. """
  185. Parse `anf_ir_pb2.AttributeProto` object., and Filters large attribute values.
  186. Args:
  187. attributes (list[anf_ir_pb2.AttributeProto]): Refer to `anf_ir_pb2.AttributeProto` object.
  188. node (Node): Refer to `Node` object, it is used to log message and update attr.
  189. """
  190. for attr in attributes:
  191. if attr.value.ByteSize() > self.MAX_NODE_ATTRIBUTE_VALUE_BYTES:
  192. message = f"The attribute value of node({node.name}) " \
  193. f"is over {self.MAX_NODE_ATTRIBUTE_VALUE_BYTES} Bytes, will ignore."
  194. logger.info(message)
  195. continue
  196. node.add_attr({attr.name: str(attr.value)})
  197. def _update_input_after_create_node(self):
  198. """Update the input of node after create node."""
  199. for node in self._normal_node_map.values():
  200. for src_node_id, input_attr in dict(node.input).items():
  201. node.delete_input(src_node_id)
  202. if not self._is_node_exist(node_id=src_node_id):
  203. message = f"The input node could not be found by node id({src_node_id}) " \
  204. f"while updating the input of the node({node})"
  205. logger.warning(message)
  206. continue
  207. src_node = self._get_normal_node(node_id=src_node_id)
  208. input_attr['shape'] = src_node.output_shape
  209. input_attr['data_type'] = src_node.output_data_type
  210. node.add_input(src_name=src_node.name, input_attr=input_attr)
  211. def _update_output_after_create_node(self):
  212. """Update the output of node after create node."""
  213. # Constants and parameter should not exist for input and output.
  214. filtered_node = {NodeTypeEnum.CONST.value, NodeTypeEnum.PARAMETER.value}
  215. for node in self._normal_node_map.values():
  216. for src_name, input_attr in node.input.items():
  217. src_node = self._get_normal_node(node_name=src_name)
  218. if src_node.type in filtered_node:
  219. continue
  220. src_node.add_output(node.name, input_attr)
  221. @staticmethod
  222. def _get_data_type_name_by_value(data_type, value, field_name='data_type'):
  223. """Get the data type name by the enum value, data_type refer to `DataType` object."""
  224. return data_type.DESCRIPTOR.fields_by_name[field_name].enum_type.values_by_number[value].name