From 1afeca201580e139b23b2d7c670fbf6602b4812b Mon Sep 17 00:00:00 2001
From: liuchongming
Date: Wed, 13 Jan 2021 16:42:10 +0800
Subject: [PATCH] Optimize sub-graph searcher.

---
 .../sub_graph_searcher/common.py   | 63 ++++++++-------
 .../sub_graph_searcher/pattern.py  |  4 +-
 .../sub_graph_searcher/searcher.py | 76 ++++++++++++++-----
 3 files changed, 94 insertions(+), 49 deletions(-)

diff --git a/mindinsight/mindconverter/graph_based_converter/sub_graph_searcher/common.py b/mindinsight/mindconverter/graph_based_converter/sub_graph_searcher/common.py
index f66b2bc2..0a5960e9 100644
--- a/mindinsight/mindconverter/graph_based_converter/sub_graph_searcher/common.py
+++ b/mindinsight/mindconverter/graph_based_converter/sub_graph_searcher/common.py
@@ -1,4 +1,4 @@
-# Copyright 2020 Huawei Technologies Co., Ltd.All Rights Reserved.
+# Copyright 2020-2021 Huawei Technologies Co., Ltd.All Rights Reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -33,10 +33,10 @@ from typing import List
 from mindinsight.mindconverter.graph_based_converter.third_party_graph.onnx_utils import BaseNode
 
 MAX_OUT_DEGREE = 1
-MINI_FREQUENCY = 4
-MAX_ITERATION_DEPTH = 4
-SATISFIED_SCORE = 0.6
-ACCEPTABLE_RESULT_COUNT = 16
+MINI_FREQUENCY = 0.07
+MAX_ITERATION_DEPTH = 16
+SATISFIED_SCORE = 1.5
+ACCEPTABLE_RESULT_COUNT = 32
 PTN_COVERAGE_THRESHOLD = 0.65
 # If pattern length is short than `IGNORE_PTN_LEN`, then do not calculate the coverage.
 IGNORE_PTN_LEN = 5
@@ -52,6 +52,23 @@ def cal_matching_score(sequence_len: int):
     return 2 / (1 + math.pow(math.e, -0.1 * sequence_len)) - 1
 
 
+def _cmp(x, y):
+    """Cmp function to sort pattern."""
+    if x[1].count > y[1].count:
+        return CmpRelation.GREATER
+    if x[1].count < y[1].count:
+        return CmpRelation.LESS
+    if x[1].additional_score > y[1].additional_score:
+        return CmpRelation.GREATER
+    if x[1].additional_score < y[1].additional_score:
+        return CmpRelation.LESS
+    if x[1].ptn_length > y[1].ptn_length:
+        return CmpRelation.GREATER
+    if x[1].ptn_length < y[1].ptn_length:
+        return CmpRelation.LESS
+    return CmpRelation.EQUAL
+
+
 class CmpRelation:
     """Define cmp relation between `x` and `y`."""
     # When x is equal to y in logic.
@@ -138,39 +155,29 @@ class AlgorithmContext:
         Returns:
            OrderedDict, sorted pattern.
""" - - def _cmp(x, y): - """Cmp function to sort pattern.""" - if x[1].count > y[1].count: - return CmpRelation.GREATER - if x[1].count < y[1].count: - return CmpRelation.LESS - if x[1].additional_score > y[1].additional_score: - return CmpRelation.GREATER - if x[1].additional_score < y[1].additional_score: - return CmpRelation.LESS - if x[1].ptn_length > y[1].ptn_length: - return CmpRelation.GREATER - if x[1].ptn_length < y[1].ptn_length: - return CmpRelation.LESS - return CmpRelation.EQUAL - pattern_arr = sorted(pattern_arr.items(), key=functools.cmp_to_key(_cmp), reverse=True) - if len(pattern_arr) > self.beam_width: + if len(pattern_arr) > AlgorithmContext.beam_width: pattern_arr = pattern_arr[:self.beam_width] res = OrderedDict() for i, (key, ptn) in enumerate(pattern_arr): - if ptn.count <= self.MIN_FREQUENCY: + if ptn.count <= AlgorithmContext.MIN_FREQUENCY: + continue + if ptn.additional_score > 0 and ptn.ptn_length > IGNORE_PTN_LEN: + res[key] = ptn continue skip = False for j, (_, candidate) in enumerate(pattern_arr): - if i == j or (ptn.additional_score > 0 and ptn.ptn_length > IGNORE_PTN_LEN): + if i == j: continue + # If `ptn` is a sub-pattern of `candidate`, and `ptn` count equals to `candidate`, + # then reject the `ptn`. if candidate.ptn_length >= ptn.ptn_length and ptn.count == candidate.count \ and ptn.pattern in candidate.pattern: skip = True break + # If `candidate` is sub-pattern of `ptn`, `candidate` has additional score, + # and `ptn` has no additional score, then calculate its replacement ratio. if candidate.ptn_length < ptn.ptn_length and candidate.additional_score != 0 \ and ptn.additional_score == 0 and candidate.pattern in ptn.pattern: ratio = candidate.ptn_length / ptn.ptn_length @@ -178,9 +185,9 @@ class AlgorithmContext: skip = True break - if skip: - continue - res[key] = ptn + if not skip: + res[key] = ptn + return res diff --git a/mindinsight/mindconverter/graph_based_converter/sub_graph_searcher/pattern.py b/mindinsight/mindconverter/graph_based_converter/sub_graph_searcher/pattern.py index 27ca4426..42a8079f 100644 --- a/mindinsight/mindconverter/graph_based_converter/sub_graph_searcher/pattern.py +++ b/mindinsight/mindconverter/graph_based_converter/sub_graph_searcher/pattern.py @@ -1,4 +1,4 @@ -# Copyright 2020 Huawei Technologies Co., Ltd.All Rights Reserved. +# Copyright 2020-2021 Huawei Technologies Co., Ltd.All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -34,7 +34,7 @@ class Pattern: # If pattern in BUILD_IN_MODULE_NAME or BUILD_IN_PATTERN, # the pattern will get additional score. self.additional_score = 0 - self.know_module_name = None + self.known_module_name = None def insert(self, idx, seq_len): """ diff --git a/mindinsight/mindconverter/graph_based_converter/sub_graph_searcher/searcher.py b/mindinsight/mindconverter/graph_based_converter/sub_graph_searcher/searcher.py index 0e812272..d61c8df0 100644 --- a/mindinsight/mindconverter/graph_based_converter/sub_graph_searcher/searcher.py +++ b/mindinsight/mindconverter/graph_based_converter/sub_graph_searcher/searcher.py @@ -1,4 +1,4 @@ -# Copyright 2020 Huawei Technologies Co., Ltd.All Rights Reserved. +# Copyright 2020-2021 Huawei Technologies Co., Ltd.All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
@@ -36,13 +36,23 @@ def _is_satisfied(path):
     """
     if len(path.recursion_path) > MAX_ITERATION_DEPTH:
         return True
-    if not path.new_pattern or max([p.count for _, p in path.new_pattern.items()]) < MINI_FREQUENCY:
+    if not path.new_pattern or not any([is_pattern_satisfied(p, path) for p in path.new_pattern.values()]):
         return True
     if path.evaluate_score() > SATISFIED_SCORE:
         return True
     return False
 
 
+def is_pattern_satisfied(pattern, seq):
+    """Whether a pattern is valid."""
+    rpl_ratio = 1.0 * pattern.count * pattern.ptn_length / len(seq.topo_order_aft_repl)
+    # If the replacement ratio reaches `MINI_FREQUENCY` (7%), then accept the pattern;
+    # otherwise, reject it.
+    if rpl_ratio >= MINI_FREQUENCY:
+        return True
+    return False
+
+
 def _search(init_pattern: Dict[str, Pattern], init_topo_order: List[BaseNode], init_graph,
             sub_graph_size: int = 2) -> List[SearchPath]:
     """
@@ -61,7 +71,7 @@ def _search(init_pattern: Dict[str, Pattern], init_topo_order: List[BaseNode],
     sorted_pattern = context.sort_with_beam(init_pattern)
     # 2. Put pattern into queue.
     queue = PriorityQueue()
-    for _, pattern_inst in sorted_pattern.items():
+    for pattern_inst in sorted_pattern.values():
         queue.put(
             SearchPath(pattern=pattern_inst, sequence=init_topo_order, graph=init_graph,
@@ -70,6 +80,7 @@ def _search(init_pattern: Dict[str, Pattern], init_topo_order: List[BaseNode],
         )
 
     available_path = []
+    deduplicate_path = set()
     while not queue.empty():
         # a. replace pattern in current topo order.
         cur_path = queue.get(block=False)
@@ -77,13 +88,18 @@ def _search(init_pattern: Dict[str, Pattern], init_topo_order: List[BaseNode],
         # b. generate new pattern based on replaced topo order.
         if _is_satisfied(cur_path):
             available_path.append(cur_path)
+            deduplicate_path.add(cur_path.hash_of_aft_repl)
             continue
 
         if len(available_path) >= ACCEPTABLE_RESULT_COUNT:
             break
 
-        for _, cur_pattern in cur_path.new_pattern.items():
-            if cur_pattern.count < MINI_FREQUENCY:
+        for cur_pattern in cur_path.new_pattern.values():
+            if not is_pattern_satisfied(cur_pattern, cur_path):
+                if cur_path.hash_of_aft_repl in deduplicate_path:
+                    continue
+                available_path.append(cur_path)
+                deduplicate_path.add(cur_path.hash_of_aft_repl)
                 continue
             key = "/".join([f"{cur_pattern.pattern}[{cur_pattern.in_degree},{cur_pattern.out_degree}]",
                             gen_hash_key(cur_topo_order, without_module=True)])
@@ -167,13 +183,8 @@ def _scope_name_deduplication(key, scope_names, memo) -> list:
     Returns:
         list, renamed scope name.
     """
-    result = []
-    if key not in memo:
-        memo[key] = 0
-    for item in scope_names:
-        item = item.replace(key, f"{key}_{memo.get(key)}")
-        result.append(item)
-    memo[key] += 1
+    memo[key] = memo.setdefault(key, -1) + 1
+    result = [item.replace(key, f"{key}_{memo.get(key)}") for item in scope_names]
     return result
 
 
@@ -188,14 +199,43 @@ def _retrieve_operators(module_path, module_dict):
     Returns:
         str: module_name, operators in module.
     """
+
+    def _whether_to_lift(sub_module):
+        """Whether to lift a scope according to its depth."""
+        return max([len(m.split("/")) for m in sub_module]) > 2
+
+    def _lift(sub_module):
+        """Lift nodes up one level."""
+        nonlocal added_module
+        lifted_submodule = []
+        continuity_idx = -1
+        lift_needed = _whether_to_lift(sub_module)
+        for m in sub_module:
+            scopes = m.split("/")
+            if lift_needed and len(scopes) == 3:
+                # If the scope depth is 3, like ModuleX/ModuleY/Gemm,
+                # then lift ModuleY to the top level.
+                md_name, md_idx = scopes[-2].split("_")
+                if continuity_idx != int(md_idx):
+                    continuity_idx = int(md_idx)
+                    added_module[md_name] = added_module.setdefault(md_name, -1) + 1
+                lifted_submodule.append(f"{md_name}_{added_module.setdefault(md_name, 0)}/{scopes[-1]}")
+                continue
+            if lift_needed and len(scopes) == 2:
+                # If the module is required to be lifted, then lift its leaf node to the parent.
+                lifted_submodule.append(scopes[-1])
+                continue
+            # If lifting is not required, then add the scope name directly.
+            lifted_submodule.append(m)
+        return lifted_submodule
+
     added_module = dict()
     node_in_pattern = module_path.pattern.ptn_items
     node_list = []
     for node in node_in_pattern:
         if module_dict.get(node):
-            node_list += _scope_name_deduplication(node,
-                                                   module_dict[node],
-                                                   added_module)
+            sub_scope = _scope_name_deduplication(node, module_dict[node], added_module)
+            node_list += _lift(sub_scope)
         else:
             node_list.append(node)
     val = [f"{module_path.pattern.module_name}/{node}" for node in node_list]
@@ -231,7 +271,7 @@ def flatten_graph(graph):
     Returns:
         list[str], corresponding scope name.
     """
-    return [f"Model/{node.op_type}" for _, node in graph.node_collection.items()]
+    return [f"Model/{node.op_type}" for node in graph.node_collection.values()]
 
 
 def validate_topo_order_succession():
@@ -280,9 +320,6 @@ def generate_scope_name(data_loader):
     """
     init_dag = _build_connection(data_loader)
     try:
-        if not validate_topo_order_succession():
-            raise ValueError("Topological order is not successive.")
-
         result = _sub_graph_matching(init_dag, beam_width=5, sub_graph_size=6)
         topo_order_with_scope_name_list = _retrieve_scope_name(result) if result else flatten_graph(init_dag)
 
@@ -294,4 +331,5 @@ def generate_scope_name(data_loader):
     except (ValueError, IndexError, AttributeError, KeyError) as _:
         topo_order_with_scope_name_list = flatten_graph(init_dag)
 
+    return topo_order_with_scope_name_list
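
Below is a minimal, self-contained sketch (not part of the patch above) of the new acceptance criterion this change introduces: a candidate pattern is now kept when the share of the current topological order it would replace reaches `MINI_FREQUENCY` (0.07), instead of requiring an absolute occurrence count as the old `MINI_FREQUENCY = 4` did. `SimplePattern` and `SimplePath` are hypothetical stand-ins for the real `Pattern` and `SearchPath` classes, exposing only the fields the check reads.

# Illustrative sketch only; assumes nothing beyond what the patch shows.
from dataclasses import dataclass, field
from typing import List

MINI_FREQUENCY = 0.07  # ratio threshold introduced by this patch


@dataclass
class SimplePattern:
    count: int       # occurrences of the pattern in the topological order
    ptn_length: int  # number of nodes covered by a single occurrence


@dataclass
class SimplePath:
    topo_order_aft_repl: List[str] = field(default_factory=list)


def is_pattern_satisfied(pattern: SimplePattern, seq: SimplePath) -> bool:
    """Accept a pattern once it would replace at least MINI_FREQUENCY of the sequence."""
    rpl_ratio = pattern.count * pattern.ptn_length / len(seq.topo_order_aft_repl)
    return rpl_ratio >= MINI_FREQUENCY


if __name__ == "__main__":
    path = SimplePath(topo_order_aft_repl=[f"op_{i}" for i in range(200)])
    # 3 occurrences x 5 nodes = 15 of 200 nodes (7.5%) -> accepted.
    print(is_pattern_satisfied(SimplePattern(count=3, ptn_length=5), path))  # True
    # 2 occurrences x 3 nodes = 6 of 200 nodes (3.0%) -> rejected.
    print(is_pattern_satisfied(SimplePattern(count=2, ptn_length=3), path))  # False

Because the threshold is a ratio of `len(topo_order_aft_repl)`, the cut-off scales with the size of the model being converted, which is presumably why `MAX_ITERATION_DEPTH` and `ACCEPTABLE_RESULT_COUNT` could be raised in the same change.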