|
- # Copyright 2019 Huawei Technologies Co., Ltd
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- # ============================================================================
- """
- This file is used to define the basic graph.
- """
- import time
-
- from enum import Enum
- from collections import defaultdict
-
- from mindinsight.datavisual.common.exceptions import NodeNotInGraphError
- from mindinsight.datavisual.common.log import logger
- from mindinsight.utils.exceptions import ParamMissError
- from mindinsight.utils.exceptions import ParamValueError
- from .node import NodeTypeEnum
- from .node import Node
-
-
class EdgeTypeEnum(Enum):
    """
    Node edge type enum.

    The value of each member is its lowercase edge type string.
    """
    CONTROL = "control"
    DATA = "data"
-
-
class Graph:
    """
    The `Graph` object is used to describe a graph file.

    Subclasses must implement `_parse_data`, which creates the basic nodes and
    stores them in the caches. `build_graph` then derives the parameter,
    aggregation scope, name scope and const nodes from those caches.
    """
    # Limit the size of a single attribute value per node to avoid storing too much data
    MAX_NODE_ATTRIBUTE_VALUE_BYTES = 1024

    # When a scope has at least this many direct children of the same type, they are
    # combined into a single aggregation scope node (see `_find_group_nodes`).
    MIN_GROUP_NODE_COUNT = 5

    def __init__(self):
        # Cache of all nodes: the key is the node name, the value is the `Node` object.
        self._normal_node_map = {}
        # Maps a node id to the node's current name.
        self._node_id_map_name = {}

        # The additional caching of Const and Parameter is to handle the Const
        # and Parameter nodes separately later.
        self._const_node_temp_cache = {}
        self._parameter_node_temp_cache = {}

        # Filled in at the end of `build_graph`.
        self._leaf_nodes = {}
        self._full_name_map_name = {}

    def build_graph(self, proto_data):
        """
        Build the graph from the given data.

        Args:
            proto_data: The raw graph data, passed unchanged to `_parse_data`.
                Its format is defined by the subclass.
        """
        logger.info("Start to build graph")
        start_time = time.time()

        # Notice:
        # The following methods are interdependent and cannot be switched at will.
        self._parse_data(proto_data)
        self._add_variable_nodes(NodeTypeEnum.PARAMETER.value)
        self._build_aggregation_scope_nodes()
        self._process_independent_layout()
        self._build_name_scope_nodes()

        # Since const nodes are not aggregated, adding them at the end can save a lot of computation.
        self._add_variable_nodes(NodeTypeEnum.CONST.value)
        self._calc_subnode_count()
        self._leaf_nodes = self._get_leaf_nodes()
        self._full_name_map_name = self._get_leaf_node_full_name_map()

        precision = 6
        time_consuming = round(time.time() - start_time, precision)
        logger.info("Build graph end, all node count: %s, const count: %s, parameter count: %s, time-consuming: %s s.",
                    self.normal_node_count, len(self._const_node_temp_cache),
                    len(self._parameter_node_temp_cache), time_consuming)

    def _get_leaf_nodes(self):
        """
        Get all leaf nodes, including normal leaf nodes, const nodes and param nodes.

        A leaf node is any node whose type is set and does not end with '_scope'.
        As a side effect, every node without a full name (scope nodes included)
        gets its own name as its full name.

        Returns:
            dict[str, Node], leaf nodes keyed by node name.
        """
        leaf_nodes = {}
        for node_name, node in self._normal_node_map.items():
            # update full name
            if not node.full_name:
                node.full_name = node.name
            if not node.type or node.type.endswith('_scope'):
                continue
            leaf_nodes[node_name] = node

        return leaf_nodes

    def _get_leaf_node_full_name_map(self):
        """
        Map the full name of every leaf node to its node name.

        Leaf nodes without a full name are skipped with a warning.

        Returns:
            dict[str, str], node names keyed by full name.
        """
        full_name_map = {}
        for name, node in self._leaf_nodes.items():
            if not node.full_name:
                logger.warning("Node %s does not have full name.", name)
                continue
            full_name_map[node.full_name] = name

        return full_name_map

    def exist_node(self, name):
        """
        Check node exist in graph.

        Args:
            name (str): The node name.

        Returns:
            bool, True if a node with this name is cached, otherwise False (also for None).
        """
        if name is None:
            return False
        return self._is_node_exist(node_name=name)

    def list_node_by_scope(self, scope=None):
        """
        List nodes by the scope of nodes. The scope of a node is the same as its parent node name.

        Args:
            scope (str): A scope of nodes. None means the top-level scope ('').

        Returns:
            list[dict], the `to_dict()` result of every node in that scope.
        """
        scope = "" if scope is None else scope
        return [node.to_dict() for node in self._normal_node_map.values() if node.scope == scope]

    def search_node_names(self, content, offset, limit):
        """
        Search node names by content (case-insensitive substring match).

        Args:
            content (Union[str, None]): This content can be the key content of the node to search,
                if None, will get all node names.
            offset (int): The zero-based page index. Ex, offset is 0, mean current page is 1.
            limit (int): The page size, i.e. the maximum number of names returned.

        Returns:
            list[str], a sorted page of node names.
        """
        if content is None:
            catch_names = list(self._normal_node_map)
        else:
            content = content.lower()
            catch_names = [name for name in self._normal_node_map if content in name.lower()]
        catch_names.sort()
        real_offset = offset * limit
        return catch_names[real_offset:real_offset + limit]

    def search_single_node(self, node_name):
        """
        Search node, and return every layer nodes until this node.

        Args:
            node_name (Union[str, None]): The name of node. If empty or None, only the
                top-level layer is returned.

        Returns:
            dict, a dict object, format is :
                item_object = {'nodes': [<Node dict>],
                               'scope_name': '<Node scope>',
                               'children': <item_object or {}>}

        Raises:
            NodeNotInGraphError: If `node_name` is given but not in the graph.
        """
        if node_name and not self.exist_node(name=node_name):
            raise NodeNotInGraphError(node_name=node_name)

        response = {
            'nodes': self.list_node_by_scope(),
            'scope_name': '',
            'children': {}
        }

        # Without a node name there are no nested layers; calling `find` on None would crash.
        if not node_name:
            return response

        children = response['children']

        # Descend one scope layer per slash in the node name.
        index = node_name.find('/')
        while index != -1:
            scope = node_name[:index]
            children.update({
                'nodes': self.list_node_by_scope(scope),
                'scope_name': scope,
                'children': {}
            })
            children = children['children']

            index = node_name.find('/', index + 1)

        return response

    def _parse_data(self, proto_data):
        """
        This method will parse the data and create basic nodes to store in the cache.

        The graph is then built based on the cache. Subclasses must override it.

        Raises:
            NotImplementedError: Always, in this base class.
        """
        raise NotImplementedError("Before you can build a graph, you need to parse the data.")

    def _build_name_scope_nodes(self):
        """
        Build name scope node by every node name.

        We create the name scope node by the slash('/') in the node name.
        For example, if a node name is "Default/add", we generate a scope named 'Default' based on slash('/') and
        create a name scope node named 'Default'.

        Existing aggregation scope nodes are reused as-is. Any other existing node whose name
        collides with a scope is renamed first (see `_update_conflict_node`).
        """
        logger.info("Start to build name scope nodes.")
        scope_node_map = {}
        # Iterate over a snapshot: `_update_conflict_node` deletes and re-inserts an entry of
        # `self._normal_node_map`, and mutating a dict while iterating it may skip entries.
        for node in list(self._normal_node_map.values()):
            name = node.name
            index = name.find('/')
            pre_index = None
            while index > 0:
                scope = name[:index]
                scope_node = scope_node_map.get(scope)
                if scope_node is None:
                    if self._is_node_exist(node_name=scope):
                        exist_node = self._get_normal_node(node_name=scope)
                        if exist_node.type == NodeTypeEnum.AGGREGATION_SCOPE.value:
                            # This scope is aggregation scope, so we don't have to do anything.
                            pre_index = index
                            index = name.find('/', pre_index + 1)
                            continue

                        # We find a node name that conflicts with the current scope and rename the node
                        self._update_conflict_node(conflict_name=scope)

                    # We create a node for current scope.
                    scope_node = Node(scope, node_id=scope)
                    scope_node.type = NodeTypeEnum.NAME_SCOPE.value
                    scope_node.scope = '' if pre_index is None else name[:pre_index]
                    scope_node_map.update({scope_node.name: scope_node})

                # Inherit input and output from sub nodes.
                self._inherit_input_output_from_subnode(scope_node, subnode_list=[node])

                pre_index = index
                index = name.find('/', pre_index + 1)

        # Cache all the scope node to normal node dict
        for scope_node in scope_node_map.values():
            self._cache_node(scope_node)

    def _update_conflict_node(self, conflict_name):
        """
        Rename the node whose name collides with a scope name.

        The node stays in its scope and its base name is wrapped in parentheses, which frees
        the original name for the new name scope node.

        Args:
            conflict_name (str): The name of the existing node to rename.
        """
        conflict_node = self._get_normal_node(node_name=conflict_name)
        base_name = conflict_name.split('/')[-1]
        new_name = Node.create_node_name(scope=conflict_node.scope, base_name=f'({base_name})')
        self._update_node_name_of_cache(conflict_node, new_name, update_parent=True)

    def _inherit_input_output_from_subnode(self, parent_node, subnode_list, filtered_type=None):
        """
        Adds the input and output of all direct child nodes to the current node.

        Args:
            parent_node (Node): The nodes that inherit the input and output of the child nodes.
            subnode_list (list[Node]): A list of child nodes that are inherited from the input and output.
            filtered_type (set(str)): Node types whose edges are not inherited.
                Defaults to {NodeTypeEnum.CONST.value}.

        Note:
            - Only the inputs and outputs of the external scope are inherited.
            - Before add_const_node method, if the input is a const,
              the scope of the const node is not startswith the name of parent node.
              So in this scenario, we need to filter the const nodes.
        """
        filtered_type = {NodeTypeEnum.CONST.value} if filtered_type is None else filtered_type
        own_scope_prefix = f'{parent_node.name}/'
        for method in ['input', 'output', 'proxy_input', 'proxy_output']:
            for node in subnode_list:
                for item_name, item_attr in getattr(node, method).items():
                    if item_name.startswith(own_scope_prefix):
                        # Own scope, ignore
                        continue

                    target_node = self._get_normal_node(node_name=item_name)
                    if target_node.type in filtered_type:
                        continue

                    getattr(parent_node, f'add_{method}')(item_name, item_attr)

    def _build_aggregation_scope_nodes(self):
        """
        Under the same scope, nodes of the same type are aggregated once their count reaches the threshold.

        Every grouped node is renamed into its new aggregation scope.
        Then each aggregation node inherits the external edges of its children.

        Note:
            The threshold value refers to the `MIN_GROUP_NODE_COUNT`.
        """
        logger.info("Start to build aggregation scope nodes.")
        group_node_map, filtered_group_names = self._find_group_nodes()

        # create merge scope nodes
        aggregation_scope_node_map = {}
        for i, group_name in enumerate(filtered_group_names):
            # A group name without a slash belongs to the top-level scope ('').
            scope, _, op_type = group_name.rpartition('/')

            group_nodes = group_node_map[group_name]
            count = len(group_nodes)
            aggregation_node_name = Node.create_node_name(scope=scope, base_name=f'{op_type}[{count}]_{i}')
            aggregation_scope_node = Node(name=aggregation_node_name, node_id=aggregation_node_name)
            aggregation_scope_node.subnode_count = count
            aggregation_scope_node.scope = scope
            aggregation_scope_node.type = NodeTypeEnum.AGGREGATION_SCOPE.value

            # Update the name and scope of all children nodes
            for node in group_nodes:
                base_name = node.name.split('/')[-1]
                new_name = Node.create_node_name(scope=aggregation_node_name, base_name=base_name)
                node.scope = aggregation_node_name

                # Since the name scope has not been created, there is no need to update the parent node.
                self._update_node_name_of_cache(node, new_name, update_parent=False)

            # Cache this node
            self._cache_node(aggregation_scope_node)
            aggregation_scope_node_map.update({group_name: aggregation_scope_node})

        # Adds the input and output of all direct child nodes to the current node.
        for group_name, node in aggregation_scope_node_map.items():
            self._inherit_input_output_from_subnode(node, group_node_map[group_name])

    def _find_group_nodes(self):
        """
        Find nodes that can be grouped into a group.

        For direct child nodes in a scope, we divide them into multiple groups by node type.
        However, we will exclude several types of child nodes,
        because these types of nodes are not operational nodes.

        Returns:
            tuple, (group_node_map, filtered_group_names): a defaultdict mapping each group name
            ('<scope>/<type>') to its nodes, and the names of groups with at least
            `MIN_GROUP_NODE_COUNT` nodes.
        """
        exclude_types = {
            NodeTypeEnum.CONST.value,
            NodeTypeEnum.NAME_SCOPE.value,
        }

        group_node_map = defaultdict(list)
        for node in self._normal_node_map.values():
            if node.type in exclude_types:
                continue
            group_name = Node.create_node_name(scope=node.scope, base_name=node.type)
            group_node_map[group_name].append(node)

        # filter can group scope.
        filtered_group_names = [name for name, nodes in group_node_map.items()
                                if len(nodes) >= self.MIN_GROUP_NODE_COUNT]

        return group_node_map, filtered_group_names

    def _add_variable_nodes(self, node_type):
        """
        We create the Const nodes or Parameter nodes in this method.

        For every input that refers to a cached variable node, a copy of that node is created in
        the consumer's scope, and the consumer's input is rewired to the copy. Cached variable
        nodes that end up without a copy of the same name are deleted.

        Args:
            node_type (str): Decide which type of node to add.
                Optional is `NodeTypeEnum.CONST.value` and `NodeTypeEnum.PARAMETER.value`.

        Raises:
            ParamValueError: If `node_type` is neither const nor parameter. It is raised before
                the graph is modified.

        Note:
            This method relies on the presence of data in the const cache or parameter cache.
        """
        # Validate first so that an invalid type cannot leave the graph half-rewritten.
        if node_type == NodeTypeEnum.CONST.value:
            variable_cache = self._const_node_temp_cache
        elif node_type == NodeTypeEnum.PARAMETER.value:
            variable_cache = self._parameter_node_temp_cache
        else:
            raise ParamValueError("The node type should be const or parameter.")

        logger.info("Start to add %s nodes to each scope in graph.", node_type)
        node_map = {}
        for node in self._normal_node_map.values():
            # Iterate over a copy because the node's inputs are rewritten below.
            for src_name, input_attr in dict(node.input).items():
                if not variable_cache.get(src_name):
                    continue

                variable_name = Node.create_node_name(scope=node.scope, base_name=src_name)
                # There is no need to create the node repeatedly
                variable_node = node_map.get(variable_name)
                if variable_node is None:
                    cache_node = self._get_normal_node(node_name=src_name)
                    variable_node = Node(name=variable_name, node_id=variable_name)
                    Node.copy_node_without_input_output(cache_node, variable_node)
                    variable_node.scope = node.scope

                variable_node.add_output(dst_name=node.name, output_attr=input_attr)
                node_map.update({variable_name: variable_node})

                node.delete_input(src_name)
                node.add_input(variable_name, input_attr)

        for variable_node in node_map.values():
            self._cache_node(variable_node)

        # Remove nodes that are not used in the cache.
        unused_names = set(variable_cache) - set(node_map)
        self._delete_nodes_of_cache(unused_names)

    def _calc_subnode_count(self):
        """Calc all the direct sub node count and store it on each existing scope node."""
        subnode_count_map = defaultdict(int)
        for node in self._normal_node_map.values():
            if not node.scope:
                continue

            if not self._is_node_exist(node_name=node.scope):
                logger.warning("Can not find a scope node by the given name(%s), "
                               "the name scope nodes may not have been created.", node.scope)
                continue
            subnode_count_map[node.scope] += 1

        for name, count in subnode_count_map.items():
            node = self._get_normal_node(node_name=name)
            node.subnode_count = count

    def _get_normal_node(self, node_id=None, node_name=None):
        """
        Query node by node id or node name.

        Args:
            node_id (str): The node id; takes precedence when both arguments are given.
            node_name (str): The node name.

        Returns:
            Node, the cached node, or None if it is not found.

        Raises:
            ParamMissError: If both arguments are None.
        """
        if node_id is not None:
            name = self._node_id_map_name.get(node_id)
            node = self._normal_node_map.get(name)
            return node

        if node_name is not None:
            return self._normal_node_map.get(node_name)

        raise ParamMissError('Method requires an argument that is not None.')

    def _is_node_exist(self, node_id=None, node_name=None):
        """
        Check node is exist.

        Args:
            node_id (str): The node id; takes precedence when both arguments are given.
            node_name (str): The node name.

        Returns:
            bool, True if the node is cached.

        Raises:
            ParamMissError: If both arguments are None.
        """
        if node_id is not None:
            return bool(self._node_id_map_name.get(node_id))

        if node_name is not None:
            return bool(self._normal_node_map.get(node_name))

        raise ParamMissError('Method requires an argument that is not None.')

    @property
    def normal_node_count(self):
        """Get the normal node count."""
        return len(self._normal_node_map)

    def _cache_node(self, node):
        """Store the node in the cache, keyed by its current name and id."""
        # Notice:
        # The additional caching of Const and Parameter is to handle the Const and Parameter nodes separately later.
        if node.type == NodeTypeEnum.CONST.value:
            self._const_node_temp_cache.update({node.name: node})
        if node.type == NodeTypeEnum.PARAMETER.value:
            self._parameter_node_temp_cache.update({node.name: node})

        self._normal_node_map.update({node.name: node})
        self._node_id_map_name.update({node.node_id: node.name})

    def _delete_nodes_of_cache(self, node_names):
        """
        Delete nodes from every cache.

        Args:
            node_names (Iterable[str]): Names of the nodes to delete. Each name must exist
                in the normal node cache, otherwise KeyError is raised.
        """
        logger.debug("These nodes will be removed from the cache, node names: %s.", str(node_names))
        for name in node_names:
            self._parameter_node_temp_cache.pop(name, None)
            self._const_node_temp_cache.pop(name, None)

            node = self._normal_node_map.pop(name)
            self._node_id_map_name.pop(node.node_id)

    def _update_node_name_of_cache(self, node, new_name, update_parent=False):
        """
        Update a node name which is stored in cache.

        All cached neighbours that refer to the old name in their input/output/proxy edges are
        rewired to the new name. Optionally, their cached ancestor scope nodes are rewired too.

        Args:
            node (Node): The node that will be renamed.
            new_name (str): The new name.
            update_parent (bool): Determines whether the input and output of the parent node need to be updated.
        """
        logger.debug('Update node name of cache, node(%s), new name is %s.', str(node), new_name)
        origin_name = node.name
        node.name = new_name

        # Find all nodes that need to modify the input and output
        update_node_map = {}
        for method in ['input', 'output', 'proxy_input', 'proxy_output']:
            for target_name in getattr(node, method):
                target_node = self._get_normal_node(node_name=target_name)
                if target_node is None:
                    message = f"Node should not be None, name: {target_name}, {method}: {list(getattr(node, method))}."
                    logger.error(message)
                    continue

                update_node_map.update({target_name: target_node})

                if not update_parent:
                    continue

                # Walk every ancestor scope of the target node.
                slash_index = target_name.find('/')
                while slash_index != -1:
                    scope_name = target_name[:slash_index]
                    slash_index = target_name.find('/', slash_index + 1)

                    if update_node_map.get(scope_name):
                        continue

                    scope_node = self._get_normal_node(node_name=scope_name)
                    if scope_node is None:
                        message = f"Can not find the scope node by scope name({scope_name}), " \
                                  f"may be this scope node has not been built."
                        logger.debug(message)
                        continue

                    update_node_map.update({scope_name: scope_node})

        # Update the input and output of the nodes
        for target_node in update_node_map.values():
            for method in ['input', 'output', 'proxy_input', 'proxy_output']:
                attr_temp = getattr(target_node, method).get(origin_name)
                if attr_temp is None:
                    # This method does not have this node, so it is skipped
                    continue

                # Delete the old attribute and update new name to source node or destination node.
                getattr(target_node, f'delete_{method}')(origin_name)
                getattr(target_node, f'add_{method}')(new_name, attr_temp)

        # Delete the origin node in cache.
        self._delete_nodes_of_cache(node_names=[origin_name])
        self._cache_node(node)

    def _process_independent_layout(self):
        """
        Handle separate layout nodes.

        The targets are aggregation scope nodes whose base name contains the parameter type
        value. Each of them, and each of its direct children, is marked with
        `independent_layout = True`.

        Proxy edges are then added from them to the nodes they output to, or to those nodes'
        scopes. On every destination of a child node, the matching input attribute is flagged
        with 'independent_layout'.
        """
        independent_layout_node_map = {}
        for node in self._normal_node_map.values():
            base_name = node.name.split('/')[-1]
            if node.type == NodeTypeEnum.AGGREGATION_SCOPE.value and NodeTypeEnum.PARAMETER.value in base_name:
                independent_layout_node_map[node.name] = node

        # Find all sub nodes
        subnode_map = defaultdict(list)
        for node in self._normal_node_map.values():
            if independent_layout_node_map.get(node.scope):
                subnode_map[node.scope].append(node)

        # Notice:
        # The following processing is only done for the parameter node, other types of nodes are not processed.
        # Later, when you need to extend to other nodes, the code needs to be adjusted.
        for scope_node in independent_layout_node_map.values():
            scope_node.independent_layout = True

            for target_name, target_attr in dict(scope_node.output).items():
                proxy_attr = dict(edge_type=target_attr['edge_type'])

                target_node = self._get_normal_node(node_name=target_name)
                target_node.add_proxy_input(scope_node.name, proxy_attr)

                # Note:
                # If the source node and the destination node are not in the same scope,
                # the proxy node is presented as scope in order to simplify the flow of the display data.
                # For example, the data flow is parameter[5]_1 -> add[5]_1/add1
                # we create a scope proxy node(add[5]_1) for parameter[5]_1,
                # so there is a proxy data flow parameter[5]_1 -> add[5]_1 instead of parameter[5]_1 -> add[5]_1/add1.
                if target_node.scope == scope_node.scope:
                    scope_node.add_proxy_output(target_name, proxy_attr)
                else:
                    # NOTE(review): this assumes the scope node of `target_node` already exists.
                    # Name scope nodes are only built later (`_build_name_scope_nodes`), so this
                    # lookup would return None if `target_node.scope` were a name scope - confirm.
                    target_scope_node = self._get_normal_node(node_name=target_node.scope)
                    scope_node.add_proxy_output(target_node.scope, proxy_attr)
                    target_scope_node.add_proxy_input(scope_node.name, proxy_attr)

            for subnode in subnode_map[scope_node.name]:
                subnode.independent_layout = True
                for target_name, target_attr in dict(subnode.output).items():
                    proxy_attr = dict(edge_type=target_attr['edge_type'])
                    target_node = self._get_normal_node(node_name=target_name)
                    if target_node.scope == scope_node.scope:
                        subnode.add_proxy_output(target_name, proxy_attr)
                    else:
                        subnode.add_proxy_output(target_node.scope, proxy_attr)

                    input_attr = target_node.input[subnode.name]
                    input_attr['independent_layout'] = True
                    target_node.add_input(subnode.name, input_attr)
|