From 9568fce57ef7fb8f051defb291ffb03b9fa461d0 Mon Sep 17 00:00:00 2001 From: liuchongming Date: Thu, 18 Feb 2021 16:47:42 +0800 Subject: [PATCH] Sub-graph searching and merging module support transformer arch. --- .../sub_graph_searcher/built_in_pattern.py | 51 +++++++++++++- .../sub_graph_searcher/common.py | 10 ++- .../sub_graph_searcher/known_module_name.py | 3 +- .../sub_graph_searcher/pattern.py | 1 + .../sub_graph_searcher/search_path.py | 66 ++++++++++++++----- .../sub_graph_searcher/searcher.py | 12 ++-- 6 files changed, 117 insertions(+), 26 deletions(-) diff --git a/mindinsight/mindconverter/graph_based_converter/sub_graph_searcher/built_in_pattern.py b/mindinsight/mindconverter/graph_based_converter/sub_graph_searcher/built_in_pattern.py index aafe7055..560a5016 100644 --- a/mindinsight/mindconverter/graph_based_converter/sub_graph_searcher/built_in_pattern.py +++ b/mindinsight/mindconverter/graph_based_converter/sub_graph_searcher/built_in_pattern.py @@ -16,6 +16,8 @@ __all__ = ["BUILT_IN_PATTERN", "register_pattern", "is_built_in_pattern"] +from mindinsight.mindconverter.graph_based_converter.sub_graph_searcher.known_module_name import register_module_name + from mindinsight.mindconverter.graph_based_converter.sub_graph_searcher.common import cal_matching_score from mindinsight.mindconverter.graph_based_converter.sub_graph_searcher.pattern import Pattern @@ -59,11 +61,16 @@ def register_pattern(ptn_name, in_degree, out_degree): def _reg(pattern): result = pattern() if not result: - return + return pattern + if ptn_name in BUILT_IN_PATTERN: + raise KeyError(f"{ptn_name} exists, `ptn_name` must be unique.") + BUILT_IN_PATTERN[ptn_name] = Pattern("->".join(result), len(result), in_degree, out_degree, ptn_items=result) BUILT_IN_PATTERN[ptn_name].additional_score = cal_matching_score(BUILT_IN_PATTERN[ptn_name].ptn_length) + BUILT_IN_PATTERN[ptn_name].ptn_name = ptn_name + return pattern return _reg @@ -112,3 +119,45 @@ def _up_sampling_in_op10(): return [ "Shape", "Gather", "Cast", "Slice", "Mul", "Slice", "Cast", "Cast", "Div", "Concat", "Resize" ] + + +@register_pattern("Multi-Head-Attention", 2, 1) +@register_module_name("MultiHeadAttn", 2, 1) +def _multi_head_attention(): + return [ + "MatMul", "Add", "MatMul", "Add", "Reshape", "MatMul", "Add", "Reshape", + "Transpose", "Reshape", "Transpose", "Transpose", "MatMul", "Div", "Add", "Softmax", + "MatMul", "Transpose", "Reshape", "MatMul", "Add" + ] + + +@register_pattern("Layer-Normalization", 1, 1) +@register_module_name("LayerNorm", 1, 1) +def _layer_norm(): + return [ + "ReduceMean", "Sub", "Pow", "ReduceMean", "Add", "Sqrt", "Div", "Mul", "Add" + ] + + +@register_pattern("Layer-Normalization-with-cast", 1, 1) +@register_module_name("LayerNorm", 1, 1) +def _layer_norm_with_cast(): + return [ + "ReduceMean", "Sub", "Cast", "Pow", "ReduceMean", "Add", "Sqrt", "Div", "Mul", "Add" + ] + + +@register_pattern("GeLU", 1, 1) +@register_module_name("GeLU", 1, 1) +def _gelu(): + return [ + "Div", "Erf", "Add", "Mul", "Mul" + ] + + +@register_pattern("Linear", 1, 1) +@register_module_name("Linear", 1, 1) +def _linear(): + return [ + "MatMul", "Add" + ] diff --git a/mindinsight/mindconverter/graph_based_converter/sub_graph_searcher/common.py b/mindinsight/mindconverter/graph_based_converter/sub_graph_searcher/common.py index 0a5960e9..8551afc3 100644 --- a/mindinsight/mindconverter/graph_based_converter/sub_graph_searcher/common.py +++ b/mindinsight/mindconverter/graph_based_converter/sub_graph_searcher/common.py @@ -35,7 +35,7 @@ from mindinsight.mindconverter.graph_based_converter.third_party_graph.onnx_util MAX_OUT_DEGREE = 1 MINI_FREQUENCY = 0.07 MAX_ITERATION_DEPTH = 16 -SATISFIED_SCORE = 1.5 +SATISFIED_SCORE = 0.74 ACCEPTABLE_RESULT_COUNT = 32 PTN_COVERAGE_THRESHOLD = 0.65 # If pattern length is short than `IGNORE_PTN_LEN`, then do not calculate the coverage. @@ -126,6 +126,7 @@ class AlgorithmContext: node_collection = None precursor_table = {} successor_table = {} + outputs_table = {} def set_init_node_collection(self, nd_col): """Init node_collection.""" @@ -158,7 +159,12 @@ class AlgorithmContext: pattern_arr = sorted(pattern_arr.items(), key=functools.cmp_to_key(_cmp), reverse=True) if len(pattern_arr) > AlgorithmContext.beam_width: - pattern_arr = pattern_arr[:self.beam_width] + new_pattern_arr = pattern_arr[:self.beam_width] + # Avoid dropping built-in pattern, because built-in patterns are much + # more potential. + for i in range(self.beam_width): + if pattern_arr[i][1].additional_score != 0: + new_pattern_arr.append(pattern_arr[i]) res = OrderedDict() for i, (key, ptn) in enumerate(pattern_arr): if ptn.count <= AlgorithmContext.MIN_FREQUENCY: diff --git a/mindinsight/mindconverter/graph_based_converter/sub_graph_searcher/known_module_name.py b/mindinsight/mindconverter/graph_based_converter/sub_graph_searcher/known_module_name.py index 7947f60c..88e9abd9 100644 --- a/mindinsight/mindconverter/graph_based_converter/sub_graph_searcher/known_module_name.py +++ b/mindinsight/mindconverter/graph_based_converter/sub_graph_searcher/known_module_name.py @@ -49,10 +49,11 @@ def register_module_name(md_name: str, in_degree: int, out_degree: int): def _reg(pattern): result = pattern() if not result: - return + return pattern BUILT_IN_MODULE_NAME[Pattern("->".join(result), len(result), in_degree, out_degree, ptn_items=result)] = md_name + return pattern return _reg diff --git a/mindinsight/mindconverter/graph_based_converter/sub_graph_searcher/pattern.py b/mindinsight/mindconverter/graph_based_converter/sub_graph_searcher/pattern.py index 42a8079f..7a1f7018 100644 --- a/mindinsight/mindconverter/graph_based_converter/sub_graph_searcher/pattern.py +++ b/mindinsight/mindconverter/graph_based_converter/sub_graph_searcher/pattern.py @@ -25,6 +25,7 @@ class Pattern: self.start_index = [] self.end_index = [] self.module_name = None + self.ptn_name = "" self.ptn_length = pattern_length self.ptn_items = pattern.split("->") if ptn_items is None else ptn_items self.in_degree = in_degree diff --git a/mindinsight/mindconverter/graph_based_converter/sub_graph_searcher/search_path.py b/mindinsight/mindconverter/graph_based_converter/sub_graph_searcher/search_path.py index ac706c8c..9278b464 100644 --- a/mindinsight/mindconverter/graph_based_converter/sub_graph_searcher/search_path.py +++ b/mindinsight/mindconverter/graph_based_converter/sub_graph_searcher/search_path.py @@ -15,14 +15,15 @@ """Declare search path related.""" import copy import uuid -from typing import Dict, List, Callable, Union from collections import OrderedDict +from typing import Dict, List, Callable, Union + +from mindinsight.mindconverter.graph_based_converter.sub_graph_searcher.built_in_pattern import BUILT_IN_PATTERN, \ + is_built_in_pattern from mindinsight.mindconverter.graph_based_converter.sub_graph_searcher.common import context, gen_hash_key, DagGraph, \ MAX_OUT_DEGREE, cal_matching_score from mindinsight.mindconverter.graph_based_converter.sub_graph_searcher.known_module_name import BUILT_IN_MODULE_NAME from mindinsight.mindconverter.graph_based_converter.sub_graph_searcher.pattern import Pattern, scope_name_mapping -from mindinsight.mindconverter.graph_based_converter.sub_graph_searcher.built_in_pattern import BUILT_IN_PATTERN, \ - is_built_in_pattern from mindinsight.mindconverter.graph_based_converter.sub_graph_searcher.pattern_fuzzy_matching import \ pattern_fuzzy_matching from mindinsight.mindconverter.graph_based_converter.third_party_graph.onnx_utils import OnnxNode, BaseNode @@ -80,10 +81,13 @@ def _is_valid_pattern(pattern, dag): """ if not pattern: return False - first_op = dag.node_collection[list(pattern.keys())[0]].op_type + head = dag.node_collection[list(pattern.keys())[0]] + op_type = head.op_type + if isinstance(head, MergedONNXNode) and "LayerNorm" in head.known_module_name: + return False if len(pattern) == 1: return False - if first_op in OptimizeRules.CAN_NOT_BE_HEAD: + if op_type in OptimizeRules.CAN_NOT_BE_HEAD: return False return True @@ -133,9 +137,12 @@ def random_name(module_name): class MergedONNXNode(BaseNode): """Define merged onnx node.""" - def __init__(self, name, module_name, ori_nodes): + def __init__(self, name, module_name, ori_nodes, inputs, outputs, known_module_name): super(MergedONNXNode, self).__init__(node_name=name, op_type=module_name) self.nodes = ori_nodes + self.inputs = inputs + self.outputs = outputs + self.known_module_name = known_module_name if known_module_name else "" def get_name(self): return self.name @@ -192,23 +199,30 @@ def _get_pattern_degree(sequence: Union[OrderedDict, dict, list], dag (DagGraph): Graph instance. Returns: - tuple[int, int], in degree and out degree. + tuple[int, int, set, set], in degree, out degree, precursors and successors. """ in_node = set() node_in_seq = set() items = sequence if isinstance(sequence, list) else sequence.keys() for item in items: node_in_seq.add(item.name if not isinstance(item, str) else item) - out_degree = 0 + out_node = set() for item in items: item = item.name if not isinstance(item, str) else item for ipt in dag.precursor_table[item]: in_node.add(ipt) for opt in dag.successor_table[item]: if opt not in node_in_seq: - out_degree += 1 + if item not in context.outputs_table: + out = dag.node_collection[item].outputs + else: + out = set(context.outputs_table[item]) + out_node.update(out) in_degree = len(in_node - node_in_seq) - return in_degree, out_degree + # Because only record nodes that outputs are referred by other nodes, + # the outputs number of each node must be calculated. + out_degree = len(out_node) + return in_degree, out_degree, in_node - node_in_seq, out_node def _find_pattern_tail(sequence: List[BaseNode], pattern: Dict[str, str], tail_idx: int, dag: DagGraph): @@ -312,14 +326,14 @@ def find_built_in_pattern(topo_order: List[BaseNode], dag: DagGraph) -> Dict[str if not matched: cur_idx += 1 continue - in_degree, out_degree = _get_pattern_degree(init_pattern, dag) + in_degree, out_degree, _, _ = _get_pattern_degree(init_pattern, dag) if in_degree != BUILT_IN_PATTERN[k].in_degree or out_degree != BUILT_IN_PATTERN[k].out_degree: cur_idx += 1 continue ptn_key = f"{BUILT_IN_PATTERN[k].pattern}" \ f"[{BUILT_IN_PATTERN[k].in_degree}, {BUILT_IN_PATTERN[k].out_degree}]" if ptn_key not in pattern: - pattern[ptn_key] = BUILT_IN_PATTERN[k] + pattern[ptn_key] = copy.deepcopy(BUILT_IN_PATTERN[k]) pattern[ptn_key].insert(cur_idx, ptn_len) cur_idx = cur_idx + 1 @@ -375,8 +389,8 @@ def generate_pattern(topo_order: List[BaseNode], dag: DagGraph, offset=cur_idx - jump_step, dag=dag) - in_degree, out_degree = _get_pattern_degree(found_sequence, dag) - if out_degree > MAX_OUT_DEGREE or in_degree > MAX_OUT_DEGREE: + in_degree, out_degree, _, _ = _get_pattern_degree(found_sequence, dag) + if out_degree > MAX_OUT_DEGREE: cur_idx += 1 continue @@ -391,9 +405,24 @@ def generate_pattern(topo_order: List[BaseNode], dag: DagGraph, pattern[ptn_key].insert(cur_idx - sub_graph_size + 1, len(found_sequence)) cur_idx = cur_idx + 1 + pattern = _post_process_overlap(pattern) return pattern +def _post_process_overlap(patterns) -> Dict: + """Post process overlap of found pattern.""" + for name in patterns: + prev_end = patterns[name].end_index[0] + idx = 1 + while idx < len(patterns[name].end_index): + if patterns[name].start_index[idx] <= prev_end: + patterns[name].start_index.pop(idx) + patterns[name].end_index.pop(idx) + continue + idx += 1 + return patterns + + class SearchPath: """ Use SearchPath to store the search path. @@ -510,7 +539,7 @@ class SearchPath: path_length += 1 continue - in_degree, out_degree = _get_pattern_degree(visited_node, self.graph) + in_degree, out_degree, inputs, outputs = _get_pattern_degree(visited_node, self.graph) if in_degree != pattern.in_degree or out_degree != pattern.out_degree: topo_order.extend(visited_node) index += j + 1 @@ -520,7 +549,10 @@ class SearchPath: inverted_index[path_length] = [j + index for j in range(pattern_len)] new_node = MergedONNXNode(name=random_name(pattern.module_name), module_name=pattern.module_name, - ori_nodes=visited_node[:]) + ori_nodes=visited_node[:], + inputs=inputs, + outputs=outputs, + known_module_name=pattern.known_module_name) self._reconnect(new_node) self.graph.node_collection[new_node.name] = new_node topo_order.append(new_node) @@ -625,7 +657,7 @@ class SearchPath: def _actual_val(self): """Calculate ground-truth score of the path.""" bonus = self._cal_bonus() - return bonus * 0.1 + self.replacement_ratio * 0.9 + return (bonus + self.replacement_ratio) / 2 def __lt__(self, other): """Override `<` operator.""" diff --git a/mindinsight/mindconverter/graph_based_converter/sub_graph_searcher/searcher.py b/mindinsight/mindconverter/graph_based_converter/sub_graph_searcher/searcher.py index c370ada0..cde2c8fd 100644 --- a/mindinsight/mindconverter/graph_based_converter/sub_graph_searcher/searcher.py +++ b/mindinsight/mindconverter/graph_based_converter/sub_graph_searcher/searcher.py @@ -39,9 +39,10 @@ def _is_satisfied(path): """ if len(path.recursion_path) > MAX_ITERATION_DEPTH: return True - if not path.new_pattern or not any([is_pattern_satisfied(p, path) for p in path.new_pattern.values()]): + candidate_eval = any([is_pattern_satisfied(p, path) for p in path.new_pattern.values()]) + if not path.new_pattern or not candidate_eval: return True - if path.evaluate_score() > SATISFIED_SCORE: + if path.evaluate_score() > SATISFIED_SCORE and not candidate_eval: return True return False @@ -92,6 +93,7 @@ def _search(init_pattern: Dict[str, Pattern], init_topo_order: List[BaseNode], if _is_satisfied(cur_path): available_path.append(cur_path) deduplicate_path.add(cur_path.hash_of_aft_repl) + _is_satisfied(cur_path) continue if len(available_path) >= ACCEPTABLE_RESULT_COUNT: @@ -154,6 +156,7 @@ def _retrieve_scope_name(found_path): Args: found_path: Found path. """ + _add_known_module_name(found_path) module_name_mgr = dict() module_dict = dict() @@ -257,6 +260,7 @@ def _build_connection(loader): for node_name, node in loader.nodes_dict.items(): context.precursor_table[node_name] = list(node.get_precursor_dict().keys()) context.successor_table[node_name] = list(node.get_successor_dict().keys()) + context.outputs_table[node_name] = node.output_name_list dag = DagGraph(nodes=context.node_collection.copy(), precursor=context.precursor_table.copy(), @@ -308,6 +312,7 @@ def _add_known_module_name(search_path): for it in search_path.recursion_path: if it.pattern.known_module_name: ctx.known_module_name[it.pattern.module_name] = it.pattern.known_module_name + return ctx @SubGraphSearchingError.check_except("Sub-Graph pattern searching fail.") @@ -329,9 +334,6 @@ def generate_scope_name(data_loader): if len(topo_order_with_scope_name_list) != len(data_loader.nodes_dict): topo_order_with_scope_name_list = flatten_graph(init_dag) - if result: - _add_known_module_name(result) - except (ValueError, IndexError, AttributeError, KeyError) as _: topo_order_with_scope_name_list = flatten_graph(init_dag)