# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""This file is used to define the MindSpore graph."""
import re
import copy

from mindinsight.datavisual.common.log import logger

from .node import Node
from .node import NodeTypeEnum
from .graph import Graph
from .graph import EdgeTypeEnum
from .graph import DataTypeEnum

# Node types whose nodes are aggregated into one group per name scope.
_GROUP_NODE_TYPES = ('Reshape', 'Variable')
# `FrameworkOp` nodes whose name matches this pattern are aggregated into a
# per-scope 'Cast-op' group. Compiled once because it is used for every node.
_CAST_OP_PATTERN = re.compile(r'.*?/Cast-op\d+')


class MSGraph(Graph):
    """The object describes the MindSpore graph, and it is defined in the anf_ir proto file."""

    def build_graph(self, graph_proto):
        """
        Build graph by graph proto which refer to `anf_ir_pb2.GraphProto`.

        Leaf nodes are built by this class; polymeric nodes, name scope nodes
        and the polymeric input/output update are delegated to helpers that
        this class does not define itself (provided by the `Graph` base class).

        Args:
            graph_proto (anf_ir_pb2.GraphProto): Refer to `anf_ir_pb2.GraphProto`.
        """
        logger.info("Start to build graph.")

        self._build_leaf_nodes(graph_proto)
        self._build_polymeric_nodes()
        self._build_name_scope_nodes()
        self._update_polymeric_input_output()

        logger.info("Build graph end, normal node count: %s, polymeric node "
                    "count: %s.", len(self._normal_nodes), len(self._polymeric_nodes))

    def _build_leaf_nodes(self, graph_proto):
        """
        Build leaf nodes from graph proto.

        Leaf nodes include operation nodes, parameter nodes and const nodes.
        Operation nodes are registered immediately. Parameter and const nodes
        are only collected as templates here; `_calc_input` materializes a
        copy of them for each name scope that consumes them.

        Args:
            graph_proto (anf_ir_pb2.model_proto.graph): Refer to anf_ir_pb2.model_proto.graph.
        """
        logger.info("Start to build leaf nodes.")
        leaf_node_id_map_name = {}
        const_nodes_map = {}

        for node_def in graph_proto.node:
            node = self._parse_graph_proto_node(node_def)
            leaf_node_id_map_name[node.node_id] = node.name

        # Parameters and consts share one lookup table, keyed by the name that
        # an operation node's input uses to reference them.
        for parameter in graph_proto.parameters:
            node = self._parse_graph_proto_parameter(parameter)
            const_nodes_map[node.name] = node

        for i, const in enumerate(graph_proto.const_vals):
            node_id = 'const_{}'.format(i)
            node = self._parse_graph_proto_const(const, node_id)
            const_nodes_map[const.key] = node

        self._calc_input(leaf_node_id_map_name, graph_proto, const_nodes_map)
        self._calc_output()

        logger.info("Build leaf nodes end, normal nodes count: %s, group count: %s, "
                    "left node count: %s.", len(self._normal_nodes),
                    len(self._node_groups), len(self._leaf_nodes))

    def _calc_input(self, leaf_node_id_map_name, graph_proto, const_nodes_map):
        """
        Calc input for every leaf node.

        A parameter/const input is materialized as a per-scope copy named
        '<consumer name scope>/<const name>', so each name scope that uses it
        gets its own node. Inputs that cannot be resolved are logged as
        warnings and skipped.

        Args:
            leaf_node_id_map_name (dict[str, str]): Format is {'node_id': 'node_name'}.
            graph_proto (anf_ir_pb2.model_proto.graph): See anf_ir_pb2.model_proto.graph.
            const_nodes_map (dict[str, Node]): Format is {'node name': <Node>}; holds the
                parameter and const template nodes.
        """
        logger.debug("Start to calc input.")
        for node_def in graph_proto.node:
            node_name = leaf_node_id_map_name[node_def.name]
            node = self._leaf_nodes[node_name]
            for input_def in node_def.input:
                edge_type = EdgeTypeEnum.DATA.value
                if input_def.type == "CONTROL_EDGE":
                    edge_type = EdgeTypeEnum.CONTROL.value

                if input_def.name in const_nodes_map:
                    src_name = '{}/{}'.format(node.name_scope, input_def.name)
                    if src_name not in self._normal_nodes:
                        # Copy only when this scope has not materialized the
                        # const yet; the deepcopy keeps the shared template in
                        # const_nodes_map untouched by the rename below.
                        const_node = copy.deepcopy(const_nodes_map[input_def.name])
                        const_node.name = src_name
                        const_node.name_scope = node.name_scope
                        self._normal_nodes[src_name] = const_node
                        self._leaf_nodes[src_name] = const_node
                    src_node = self._leaf_nodes.get(src_name)
                else:
                    src_name = leaf_node_id_map_name.get(input_def.name)
                    if not src_name:
                        logger.warning("The input_def name '%s' in node '%s' is invalid, "
                                       "will be ignore.", input_def.name, node_name)
                        continue
                    src_node = self._leaf_nodes.get(src_name)

                if src_node is None:
                    logger.warning("The input '%s' in node '%s' is not in "
                                   "leaf nodes.", src_name, node_name)
                    continue

                input_item = {
                    src_name: {
                        "shape": src_node.shape,
                        "edge_type": edge_type,
                        "scope": NodeTypeEnum.NAME_SCOPE.value
                    }
                }
                node.update_input(input_item)

            # `node` is the same object `_parse_graph_proto_node` stored; store
            # it back into whichever container it was registered in.
            if node_name in self._normal_nodes:
                self._normal_nodes[node_name] = node
            else:
                group_name = self._create_group_name(node.name_scope, node.node_type, node.name)
                self._node_groups[group_name][node.name] = node

    def _calc_output(self):
        """
        Calc output of every leaf node from the inputs recorded on its consumers.

        Const nodes are skipped both as consumers and as sources, so they never
        receive output entries. Parameter nodes are not skipped.
        """
        logger.debug("Start to calc output.")

        for name, node in self._leaf_nodes.items():
            if node.node_type == NodeTypeEnum.CONST.value:
                continue
            for src_name, input_attr in node.input.items():
                src_node = self._leaf_nodes[src_name]
                if src_node.node_type == NodeTypeEnum.CONST.value:
                    continue

                if src_name in self._normal_nodes:
                    self._normal_nodes[src_name].update_output({name: input_attr})
                else:
                    group_name = self._create_group_name(src_node.name_scope,
                                                         src_node.node_type,
                                                         src_node.name)
                    self._node_groups[group_name][src_name].update_output({name: input_attr})

    def _parse_graph_proto_node(self, node_def):
        """
        Parse `anf_ir_pb2.model_proto.graph.node_def`, and create a node.

        The node is registered in `_leaf_nodes` and, depending on
        `_create_group_name`, either in a group of `_node_groups` or in
        `_normal_nodes`.

        Args:
            node_def (anf_ir_pb2.model_proto.graph.node_def): Refer to
                anf_ir_pb2.model_proto.graph.node_def.

        Returns:
            Node, a `Node` object.
        """
        # Name is '<scope>/<op_type><node id>', e.g. scope 'Default', op type
        # 'Conv2D' and id '1' give 'Default/Conv2D1'.
        node_name = '/'.join([node_def.scope, node_def.op_type]) + node_def.name
        node = Node(name=node_name, node_id=node_def.name)
        node.node_type = node_def.op_type
        logger.debug("Foreach graph proto nodes, node id: %s, node name: %s, node def name: %s, "
                     "input count: %s", node.node_id, node.name, node_def.name, len(node_def.input))

        for attr in node_def.attribute:
            node.update_attr({attr.name: str(attr.value)})

        node.output_i = node_def.output_i
        node.name_scope = node_def.scope
        node.shape = self._parse_type_proto(node_def.output_type)

        self._leaf_nodes[node.name] = node
        group_name = self._create_group_name(node.name_scope, node.node_type, node.name)
        if group_name is None:
            self._normal_nodes[node.name] = node
        else:
            self._node_groups.setdefault(group_name, {})[node.name] = node

        return node

    def _parse_graph_proto_parameter(self, parameter):
        """
        Parse anf_ir_pb2.model_proto.graph.parameter, and create a parameter node.

        The node is not registered in any container here; it only serves as a
        template that `_calc_input` copies per consuming name scope.

        Args:
            parameter (anf_ir_pb2.model_proto.graph.parameter): Refer to
                anf_ir_pb2.model_proto.graph.parameter.

        Returns:
            Node, a `Node` object.
        """
        node = Node(name=parameter.name, node_id=parameter.name)
        node.node_type = NodeTypeEnum.PARAMETER.value
        node.shape = self._parse_type_proto(parameter.type)
        logger.debug("Foreach graph proto parameters, node id: %s, node name: %s, "
                     "node def name: %s", node.node_id, node.name, parameter.name)
        return node

    def _parse_graph_proto_const(self, const, const_node_id):
        """
        Parse anf_ir_pb2.model_proto.graph.const, and create a const node.

        The const value is stored as a string attribute under the const key.
        The shape is only set for tensor consts, from the tensor dims.

        Args:
            const (anf_ir_pb2.model_proto.graph.const): Refer to anf_ir_pb2.model_proto.graph.const.
            const_node_id (str): The id of the new const node, it should be unique in graph.

        Returns:
            Node, a `Node` object.
        """
        node = Node(name=const.key, node_id=const_node_id)
        node.node_type = NodeTypeEnum.CONST.value
        node.update_attr({const.key: str(const.value)})
        if const.value.dtype == DataTypeEnum.DT_TENSOR.value:
            node.shape = list(const.value.tensor_val.dims)
        return node

    def _parse_type_proto(self, type_proto):
        """
        Parse proto's `message TypeProto` to get shape information.

        Args:
            type_proto (anf_ir_pb2.TypeProto): Refer to anf_ir_pb2.TypeProto.

        Returns:
            list, the dim sizes of a tensor type; for a sequence type, one
            nested list (parsed recursively) is appended per element type.
        """
        shapes = []
        if type_proto.HasField('tensor_type'):
            tensor_shape_proto = type_proto.tensor_type.shape
            shapes.extend(dim.size for dim in tensor_shape_proto.dim)
        if type_proto.HasField('sequence_type'):
            for elem_type in type_proto.sequence_type.elem_types:
                shapes.append(self._parse_type_proto(elem_type))
        return shapes

    def _create_group_name(self, name_scope, node_type, node_name):
        """
        Create group name by node name, name scope, node type.

        Only nodes that conform to the rules are aggregated:
        - node types in `_GROUP_NODE_TYPES` are grouped as '<scope>/<node_type>';
        - 'FrameworkOp' nodes whose name matches `_CAST_OP_PATTERN` are grouped
          as '<scope>/Cast-op'.
        With an empty name scope the group name has no scope prefix.

        Args:
            name_scope (str): The node name scope.
            node_type (str): The node type.
            node_name (str): The node name.

        Returns:
            Optional[str], if match the rules will return a group name, else return None.
        """
        if node_type in _GROUP_NODE_TYPES:
            return (name_scope + '/' + node_type) if name_scope else node_type

        if node_type == 'FrameworkOp' and _CAST_OP_PATTERN.search(node_name):
            return (name_scope + '/' + 'Cast-op') if name_scope else 'Cast-op'

        return None