You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

msgraph.py 11 kB

5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago

  1. # Copyright 2020 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ============================================================================
  15. """This file is used to define the MindSpore graph."""
  16. import re
  17. import copy
  18. from mindinsight.datavisual.common.log import logger
  19. from .node import Node
  20. from .node import NodeTypeEnum
  21. from .graph import Graph
  22. from .graph import EdgeTypeEnum
  23. from .graph import DataTypeEnum
  24. class MSGraph(Graph):
  25. """The object describes the MindSpore graph, and it is defined in the anf_if proto file."""
  26. def build_graph(self, graph_proto):
  27. """
  28. Build graph by graph proto which refer to `anf_ir_pb2.GraphProto`, and set status to loading.
  29. Args:
  30. graph_proto (anf_ir_pb2.GraphProto): Refer to `anf_ir_pb2.GraphProto`.
  31. """
  32. logger.info("Start to build graph.")
  33. self._build_leaf_nodes(graph_proto)
  34. self._build_polymeric_nodes()
  35. self._build_name_scope_nodes()
  36. self._update_polymeric_input_output()
  37. logger.info("Build graph end, normal node count: %s, polymeric node "
  38. "count: %s.", len(self._normal_nodes), len(self._polymeric_nodes))
  39. def _build_leaf_nodes(self, graph_proto):
  40. """
  41. Build leaf node from graph proto.
  42. Left node will contain operation node, parameter node, const node.
  43. Args:
  44. graph_proto (anf_ir_pb2.model_proto.graph): Refer to anf_ir_pb2.model_proto.graph.
  45. """
  46. logger.info("Start to build leaf nodes.")
  47. leaf_node_id_map_name = {}
  48. const_nodes_map = {}
  49. for node_def in graph_proto.node:
  50. node = self._parse_graph_proto_node(node_def)
  51. leaf_node_id_map_name.update({node.node_id: node.name})
  52. for parameter in graph_proto.parameters:
  53. node = self._parse_graph_proto_parameter(parameter)
  54. const_nodes_map.update({node.name: node})
  55. for i, const in enumerate(graph_proto.const_vals):
  56. node_id = 'const_{}'.format(i)
  57. node = self._parse_graph_proto_const(const, node_id)
  58. const_nodes_map.update({const.key: node})
  59. self._calc_input(leaf_node_id_map_name, graph_proto, const_nodes_map)
  60. self._calc_output()
  61. logger.info("Build leaf nodes end, normal nodes count: %s, group count: %s, "
  62. "left node count: %s.", len(self._normal_nodes), len(self._node_groups),
  63. len(self._leaf_nodes))
  64. def _calc_input(self, leaf_node_id_map_name, graph_proto, const_nodes_map):
  65. """
  66. Calc input for every leaf node.
  67. Args:
  68. leaf_node_id_map_name (dict[str, str]): Format is {'node_id': 'node_name'}.
  69. graph_proto (anf_ir_pb2.model_proto.graph): See anf_ir_pb2.model_proto.graph.
  70. const_nodes_map (dict[str, Node]): Format is {'node name': <Const node>}.
  71. """
  72. logger.debug("Start to calc input.")
  73. for node_def in graph_proto.node:
  74. node_name = leaf_node_id_map_name[node_def.name]
  75. node = self._leaf_nodes[node_name]
  76. for input_def in node_def.input:
  77. edge_type = EdgeTypeEnum.DATA.value
  78. if input_def.type == "CONTROL_EDGE":
  79. edge_type = EdgeTypeEnum.CONTROL.value
  80. if const_nodes_map.get(input_def.name):
  81. const_node = copy.deepcopy(const_nodes_map[input_def.name])
  82. src_name = '{}/{}'.format(node.name_scope, input_def.name)
  83. if not self._normal_nodes.get(src_name):
  84. const_node.name = src_name
  85. const_node.name_scope = node.name_scope
  86. self._normal_nodes.update({src_name: const_node})
  87. self._leaf_nodes.update({src_name: const_node})
  88. src_node = self._leaf_nodes.get(src_name)
  89. else:
  90. src_name = leaf_node_id_map_name.get(input_def.name)
  91. if not src_name:
  92. logger.warning("The input_def name '%s' in node '%s' is invalid, "
  93. "will be ignore.", input_def.name, node_name)
  94. continue
  95. src_node = self._leaf_nodes.get(src_name)
  96. if src_node is None:
  97. logger.warning("The input '%s' in node '%s' is not in "
  98. "leaf nodes.", src_name, node_name)
  99. continue
  100. input_item = {
  101. src_name: {
  102. "shape": src_node.shape,
  103. "edge_type": edge_type,
  104. "scope": NodeTypeEnum.NAME_SCOPE.value
  105. }
  106. }
  107. node.update_input(input_item)
  108. if self._normal_nodes.get(node_name):
  109. self._normal_nodes[node_name] = node
  110. else:
  111. group_name = self._create_group_name(node.name_scope, node.node_type, node.name)
  112. self._node_groups[group_name][node.name] = node
  113. def _calc_output(self):
  114. """Calc output of every node."""
  115. logger.debug("Start to calc output.")
  116. for name, node in self._leaf_nodes.items():
  117. if node.node_type == NodeTypeEnum.CONST.value:
  118. continue
  119. for src_name, input_attr in node.inputs.items():
  120. src_node = self._leaf_nodes[src_name]
  121. if src_node.node_type == NodeTypeEnum.CONST.value:
  122. continue
  123. if self._normal_nodes.get(src_name):
  124. self._normal_nodes[src_name].update_output({name: input_attr})
  125. else:
  126. group_name = self._create_group_name(src_node.name_scope,
  127. src_node.node_type, src_node.name)
  128. self._node_groups[group_name][src_name].update_output({name: input_attr})
  129. def _parse_graph_proto_node(self, node_def):
  130. """
  131. Parse `anf_ir_pb2.model_proto.graph.node_def`, and create a a node.
  132. Args:
  133. node_def (anf_ir_pb2.model_proto.graph.node_def): Refer to anf_ir_pb2.model_proto.graph.node_def.
  134. Returns:
  135. Node, a `Node` object.
  136. """
  137. node_name = '/'.join([node_def.scope, node_def.op_type])+node_def.name
  138. node = Node(name=node_name, node_id=node_def.name)
  139. node.node_type = node_def.op_type
  140. logger.debug("Foreach graph proto nodes, node id: %s, node name: %s, node def name: %s, "
  141. "input count: %s", node.node_id, node.name, node_def.name, len(node_def.input))
  142. for attr in node_def.attribute:
  143. node.update_attr({attr.name: str(attr.value)})
  144. node.output_i = node_def.output_i
  145. node.name_scope = node_def.scope
  146. output_type = node_def.output_type
  147. shape = self._parse_type_proto(output_type)
  148. node.shape = shape
  149. self._leaf_nodes.update({node.name: node})
  150. group_name = self._create_group_name(node.name_scope, node.node_type, node.name)
  151. if group_name is not None:
  152. node_dict = self._node_groups.get(group_name, {})
  153. node_dict.update({node.name: node})
  154. self._node_groups.update({group_name: node_dict})
  155. else:
  156. self._normal_nodes.update({node.name: node})
  157. return node
  158. def _parse_graph_proto_parameter(self, parameter):
  159. """
  160. Parse anf_ir_pb2.model_proto.graph.parameter, and create a parameter node.
  161. Args:
  162. parameter (anf_ir_pb2.model_proto.graph.parameter): Refer to anf_ir_pb2.model_proto.graph.parameter.
  163. Returns:
  164. Node, a `Node` object.
  165. """
  166. node = Node(name=parameter.name, node_id=parameter.name)
  167. node.node_type = NodeTypeEnum.PARAMETER.value
  168. node.shape = self._parse_type_proto(parameter.type)
  169. logger.debug("Foreach graph proto parameters, node id: %s, node name: %s, "
  170. "node def name: %s", node.node_id, node.name, parameter.name)
  171. return node
  172. def _parse_graph_proto_const(self, const, const_node_id):
  173. """
  174. Parse anf_ir_pb2.model_proto.graph.const, and create a const node.
  175. Args:
  176. const (anf_ir_pb2.model_proto.graph.const): Refer to anf_ir_pb2.model_proto.graph.const
  177. const_node_id (str): The id of the new const node, it should be unique in graph.
  178. Returns:
  179. Node, a `Node` object.
  180. """
  181. node = Node(name=const.key, node_id=const_node_id)
  182. node.node_type = NodeTypeEnum.CONST.value
  183. node.update_attr({const.key: str(const.value)})
  184. if const.value.dtype == DataTypeEnum.DT_TENSOR.value:
  185. shape = []
  186. for dim in const.value.tensor_val.dims:
  187. shape.append(dim)
  188. node.shape = shape
  189. return node
  190. def _parse_type_proto(self, type_proto):
  191. """
  192. Parse proto's `message TypeProto` to get shape information.
  193. Args:
  194. type_proto (anf_ir_pb2.TypeProto): Refer to anf_ir_pb2.TypeProto.
  195. Returns:
  196. list, a list of shape.
  197. """
  198. shapes = []
  199. if type_proto.HasField('tensor_type'):
  200. tensor_type = type_proto.tensor_type
  201. tensor_shape_proto = tensor_type.shape
  202. for dim in tensor_shape_proto.dim:
  203. shapes.append(dim.size)
  204. if type_proto.HasField('sequence_type'):
  205. for elem_type in type_proto.sequence_type.elem_types:
  206. shapes.append(self._parse_type_proto(elem_type))
  207. return shapes
  208. def _create_group_name(self, name_scope, node_type, node_name):
  209. """
  210. Create group name by node name, name scope, node type.
  211. Only nodes that conform to the rules are aggregated.
  212. Args:
  213. name_scope (str): The node name scope.
  214. node_type (str): The node type.
  215. node_name (str): The node name.
  216. Returns:
  217. Optional[str], if match the rules will return a group name, else return None.
  218. """
  219. group_types = ['Reshape', 'Variable']
  220. pattern_names = r'.*?/Cast-op\d+'
  221. if node_type in group_types:
  222. group_name = name_scope + '/' + node_type if name_scope else node_type
  223. return group_name
  224. if node_type == 'FrameworkOp' and re.search(pattern_names, node_name):
  225. group_name = name_scope + '/' + 'Cast-op' if name_scope else 'Cast-op'
  226. return group_name
  227. return None

MindInsight为MindSpore提供了简单易用的调优调试能力。在训练过程中，可以将标量、张量、图像、计算图、模型超参、训练耗时等数据记录到文件中，通过MindInsight可视化页面进行查看及分析。