diff --git a/mindinsight/mindconverter/graph_based_converter/hierarchical_tree/hierarchical_tree.py b/mindinsight/mindconverter/graph_based_converter/hierarchical_tree/hierarchical_tree.py index d922c966..e67fc33a 100644 --- a/mindinsight/mindconverter/graph_based_converter/hierarchical_tree/hierarchical_tree.py +++ b/mindinsight/mindconverter/graph_based_converter/hierarchical_tree/hierarchical_tree.py @@ -619,11 +619,12 @@ class HierarchicalTree(Tree): cur_nd = self.get_node(p_nd[0]) return ", ".join(self._find_all_previous_opt_var_(cur_nd, pre_nd)) - def hash_key(self, node): + def hash_key(self, node, depth: int = 0): """ Generate hash key for each node. Args: + depth (int): Recursion depth. node (Node): Node. Returns: @@ -633,12 +634,12 @@ class HierarchicalTree(Tree): for s in node.successors(self.tree_identifier): cur_nd = self.get_node(s) if cur_nd.data.hash_key: - scsr_topo_order.append(cur_nd.data.hash_key) + scsr_topo_order.append(f"{cur_nd.data.hash_key}[{depth}]") continue if cur_nd.data.node_type in {NodeType.MODULE.value, NodeType.FUNC.value, NodeType.CLASS.value}: - scsr_topo_order.append(self.hash_key(cur_nd)) + scsr_topo_order.append(self.hash_key(cur_nd, depth + 1)) continue unique_key = "->".join(scsr_topo_order) node.data.hash_key = unique_key diff --git a/mindinsight/mindconverter/graph_based_converter/sub_graph_searcher/common.py b/mindinsight/mindconverter/graph_based_converter/sub_graph_searcher/common.py index 13636b11..4386416c 100644 --- a/mindinsight/mindconverter/graph_based_converter/sub_graph_searcher/common.py +++ b/mindinsight/mindconverter/graph_based_converter/sub_graph_searcher/common.py @@ -20,6 +20,8 @@ from typing import List from mindinsight.mindconverter.graph_based_converter.third_party_graph.onnx_utils import BaseNode +MAX_OUT_DEGREE = 1 + class CmpRelation: """Define cmp relation between `x` and `y`.""" @@ -123,7 +125,6 @@ class AlgorithmContext: context = AlgorithmContext() - __all__ = ["context", "gen_hash_key", "DagGraph"] diff --git a/mindinsight/mindconverter/graph_based_converter/sub_graph_searcher/search_path.py b/mindinsight/mindconverter/graph_based_converter/sub_graph_searcher/search_path.py index 22c3fe92..c2fa97de 100644 --- a/mindinsight/mindconverter/graph_based_converter/sub_graph_searcher/search_path.py +++ b/mindinsight/mindconverter/graph_based_converter/sub_graph_searcher/search_path.py @@ -17,7 +17,7 @@ import copy import uuid from typing import Dict, List, Callable, Union from collections import OrderedDict -from .common import context, gen_hash_key, DagGraph +from .common import context, gen_hash_key, DagGraph, MAX_OUT_DEGREE from ..third_party_graph.onnx_utils import OnnxNode, BaseNode scope_name_mapping = {} @@ -198,18 +198,19 @@ def _get_pattern_degree(sequence: Union[OrderedDict, dict, list], tuple[int, int], in degree and out degree. """ in_node = set() - out_node = set() node_in_seq = set() items = sequence if isinstance(sequence, list) else sequence.keys() - for _, item in enumerate(items): + for item in items: + node_in_seq.add(item.name if not isinstance(item, str) else item) + out_degree = 0 + for item in items: item = item.name if not isinstance(item, str) else item for ipt in dag.precursor_table[item]: in_node.add(ipt) for opt in dag.successor_table[item]: - out_node.add(opt) - node_in_seq.add(item) + if opt not in node_in_seq: + out_degree += 1 in_degree = len(in_node - node_in_seq) - out_degree = len(out_node - node_in_seq) return in_degree, out_degree @@ -335,6 +336,10 @@ def generate_pattern(topo_order: List[BaseNode], dag: DagGraph, dag=dag) in_degree, out_degree = _get_pattern_degree(found_sequence, dag) + if out_degree > MAX_OUT_DEGREE: + cur_idx += 1 + continue + ptn = '->'.join(found_sequence.values()) ptn_key = f"{ptn}[{in_degree}, {out_degree}]" if ptn_key not in pattern: