diff --git a/mindinsight/mindconverter/graph_based_converter/sub_graph_searcher/built_in_pattern.py b/mindinsight/mindconverter/graph_based_converter/sub_graph_searcher/built_in_pattern.py
index 560a5016..1a7f3dfb 100644
--- a/mindinsight/mindconverter/graph_based_converter/sub_graph_searcher/built_in_pattern.py
+++ b/mindinsight/mindconverter/graph_based_converter/sub_graph_searcher/built_in_pattern.py
@@ -131,6 +131,26 @@ def _multi_head_attention():
     ]
 
 
+@register_pattern("Multi-Head-Attention-1", 2, 1)
+@register_module_name("MultiHeadAttn", 2, 1)
+def _multi_head_attention_1():
+    return [
+        "MatMul", "Add", "MatMul", "Add", "MatMul", "Add", "Reshape",
+        "Transpose", "Reshape", "Reshape", "Transpose", "Transpose", "MatMul", "Div", "Add",
+        "Softmax", "MatMul", "Transpose", "Reshape", "MatMul", "Add"
+    ]
+
+
+@register_pattern("Multi-Head-Attention-with-Einsum", 2, 1)
+@register_module_name("MultiHeadAttn", 2, 1)
+def _multi_head_attention_with_einsum():
+    return [
+        "MatMul", "Add", "MatMul", "Add", "MatMul", "Add", "Reshape",
+        "Transpose", "Reshape", "Reshape", "Transpose", "Transpose", "MatMul", "Div", "Add", "Softmax",
+        "MatMul", "Transpose", "Einsum", "Add"
+    ]
+
+
 @register_pattern("Layer-Normalization", 1, 1)
 @register_module_name("LayerNorm", 1, 1)
 def _layer_norm():
@@ -161,3 +181,11 @@ def _linear():
     return [
         "MatMul", "Add"
     ]
+
+
+@register_pattern("New-GeLU", 1, 1)
+@register_module_name("NewGeLU", 1, 1)
+def _new_gelu():
+    return [
+        "Mul", "Pow", "Mul", "Add", "Mul", "Tanh", "Add", "Mul"
+    ]
diff --git a/mindinsight/mindconverter/graph_based_converter/sub_graph_searcher/searcher.py b/mindinsight/mindconverter/graph_based_converter/sub_graph_searcher/searcher.py
index 8ecefb71..d92b3bf8 100644
--- a/mindinsight/mindconverter/graph_based_converter/sub_graph_searcher/searcher.py
+++ b/mindinsight/mindconverter/graph_based_converter/sub_graph_searcher/searcher.py
@@ -13,9 +13,11 @@
 # limitations under the License.
 # ==============================================================================
 """Definition of search entry."""
-from queue import PriorityQueue
+from queue import PriorityQueue, Queue
 from typing import Dict, List
 
+from mindinsight.mindconverter.graph_based_converter.sub_graph_searcher.pattern_fuzzy_matching import \
+    pattern_fuzzy_matching
 from mindinsight.mindconverter.graph_based_converter.sub_graph_searcher.common import context, DagGraph, gen_hash_key, \
     ACCEPTABLE_RESULT_COUNT, MAX_ITERATION_DEPTH_OF_SINGLE_IPT
 from mindinsight.mindconverter.graph_based_converter.sub_graph_searcher.common import MINI_FREQUENCY, \
@@ -196,6 +198,112 @@ def _scope_name_deduplication(key, scope_names, memo) -> list:
     return result
 
 
+def _is_attn_layer(split_module):
+    """
+    Whether the submodule is an attention layer.
+
+    Attention layer is defined as: attn-add-norm-fc-gelu-fc-add-norm.
+
+    Args:
+        split_module (list[list[str]]): Operations list in module.
+
+    Returns:
+        str, name of the matched module, otherwise None.
+    """
+
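+    # Breadth-first scan over scopes: the whole module is matched first, and a
+    # candidate that fails the fuzzy match is split into its child scopes and
+    # re-queued.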
+ """ + + def _matched(modules): + """If the similarity score of sub_module and attention pattern is greater than 0.95, take it.""" + threshold = 0.95 + leaf_node = [m[-1] for m in modules] + attn_layer_ptn_with_gelu = [ + "MatMul", "Add", "MatMul", "Add", "Reshape", "MatMul", "Add", "Reshape", "Transpose", "Reshape", + "Transpose", "Transpose", "MatMul", "Div", "Add", "Softmax", "MatMul", "Transpose", "Reshape", "MatMul", + "Add", "Add", "ReduceMean", "Sub", "Cast", "Pow", "ReduceMean", "Add", "Sqrt", "Div", "Mul", "Add", + "MatMul", "Add", "Div", "Erf", "Add", "Mul", "Mul", "MatMul", "Add", "Add", "ReduceMean", "Sub", "Cast", + "Pow", "ReduceMean", "Add", "Sqrt", "Div", "Mul", "Add" + ] + attn_layer_ptn_with_new_gelu = [ + "MatMul", "Add", "MatMul", "Add", "MatMul", "Add", "Reshape", "Transpose", "Reshape", "Reshape", + "Transpose", "Transpose", "MatMul", "Div", "Add", "Softmax", "MatMul", "Transpose", "Einsum", "Add", "Add", + "ReduceMean", "Sub", "Cast", "Pow", "ReduceMean", "Add", "Sqrt", "Div", "Mul", "Add", "MatMul", "Add", + "Mul", "Pow", "Mul", "Add", "Mul", "Tanh", "Add", "Mul", "MatMul", "Add", "Add", "ReduceMean", "Sub", + "Cast", "Pow", "ReduceMean", "Add", "Sqrt", "Div", "Mul", "Add" + ] + matched = max(pattern_fuzzy_matching(leaf_node, attn_layer_ptn_with_gelu)[1], + pattern_fuzzy_matching(leaf_node, attn_layer_ptn_with_new_gelu)[1]) > threshold + return matched + + candidates = Queue() + candidates.put(split_module, block=False) + while not candidates.empty(): + candidate = candidates.get(block=False) + if _matched(candidate): + return candidate[0][0].split("_")[0] + cur_scope = candidate[0][1] + split_sub_module = [] + for item in candidate: + # It's not necessary to scan the module which depth is 2. + if len(item) == 2: + continue + if item[1] != cur_scope: + cur_scope = item[1] + if split_sub_module: + candidates.put(split_sub_module[:], block=False) + split_sub_module.clear() + split_sub_module.append(item[1:]) + if split_sub_module: + candidates.put(split_sub_module[:], block=False) + return None + + +def _lift_each_module(sub_module): + """Lift each module in sub-module.""" + lifted_module = [] + split_module = [] + cur_scope = sub_module[0].split("/")[0] + segmented_pos = 0 + + def _lift(modules): + nonlocal lifted_module, split_module + exceed_max_depth = max(*[len(m.split("/")) for m in modules]) > 2 + if not exceed_max_depth: + for _ in range(len(split_module)): + lifted_module.append((False, 0)) + return + # attn_module_name has been normalized without "_idx", only has raw module name. + attn_module_name = _is_attn_layer(split_module) + for s_md in split_module: + if attn_module_name: + md_name = [md for md in s_md if attn_module_name in md] + if md_name: + md_name = md_name[0] + attn_idx = s_md.index(md_name) + if attn_idx > 0: + lifted_module.append((True, attn_idx)) + continue + lifted_module.append((False, 0)) + continue + lifted_module.append((True, 0)) + + for i, m in enumerate(sub_module): + split_md = m.split("/") + # Find one module. + if cur_scope != split_md[0]: + _lift(sub_module[segmented_pos:i]) + # Clean up. + cur_scope = split_md[0] + segmented_pos = i + split_module.clear() + split_module.append(split_md) + + # Do lift on last module. + _lift(sub_module[segmented_pos:]) + return lifted_module + + def _retrieve_operators(module_path, module_dict): """ Retrieve operators from path. @@ -208,26 +313,29 @@ def _retrieve_operators(module_path, module_dict): str: module_name, operators in module. 
""" - def _whether_to_lift(sub_module): - """Whether to lift a scope according to its depth.""" - return max(*[len(m.split("/")) for m in sub_module]) > 2 - def _lift(sub_module): """Lift nodes upper.""" nonlocal added_module lifted_submodule = [] record = dict() - lift_needed = _whether_to_lift(sub_module) - for m in sub_module: + # DO NOT lift on attn-add-norm-fc with GeLU-fc-add-norm. + # It's a fix pattern in Transformer model. + lift_on_each_module = _lift_each_module(sub_module) + for i, m in enumerate(sub_module): + lift_needed, lift_from = lift_on_each_module[i] scopes = m.split("/") - if lift_needed and len(scopes) == 3: + if lift_needed and len(scopes) >= 3: # If the scope depth is 3, like ModuleX/ModuleY/Gemm, # then we lift ModuleY to top level. - md_name, md_idx = scopes[-2].split("_") + md_name, md_idx = scopes[-2 if lift_from == 0 else lift_from].split("_") if record.get(md_name, -1) != md_idx: record[md_name] = md_idx added_module[md_name] = added_module.setdefault(md_name, -1) + 1 - lifted_submodule.append(f"{md_name}_{added_module.setdefault(md_name, 0)}/{scopes[-1]}") + if lift_from != 0: + lifted_md = "/".join([f"{md_name}_{added_module.setdefault(md_name, 0)}"] + scopes[lift_from + 1:]) + else: + lifted_md = f"{md_name}_{added_module.setdefault(md_name, 0)}/{scopes[-1]}" + lifted_submodule.append(lifted_md) continue if lift_needed and len(scopes) == 2: # If the module is required to lifted, then lift leaf node to parent. @@ -263,7 +371,6 @@ def _build_connection(loader): context.precursor_table[node_name] = list(node.get_precursor_dict().keys()) context.successor_table[node_name] = list(node.get_successor_dict().keys()) context.outputs_table[node_name] = node.output_name_list - # Record the model inputs count, use it to control the search algorithm. context.has_multi_inputs = len(loader.input_nodes) > 1 dag = DagGraph(nodes=context.node_collection.copy(), diff --git a/mindinsight/mindconverter/graph_based_converter/third_party_graph/base.py b/mindinsight/mindconverter/graph_based_converter/third_party_graph/base.py index ef1c3b85..ef0e033c 100644 --- a/mindinsight/mindconverter/graph_based_converter/third_party_graph/base.py +++ b/mindinsight/mindconverter/graph_based_converter/third_party_graph/base.py @@ -192,7 +192,6 @@ class Graph(BaseGraph, abc.ABC): """ for name, node in self._nodes_collection.items(): if node.in_degree == 0: - # NOTICE: what's usage of `scope`? self._input_nodes.append(name) if node.out_degree == 0: diff --git a/mindinsight/mindconverter/graph_based_converter/third_party_graph/optimizer.py b/mindinsight/mindconverter/graph_based_converter/third_party_graph/optimizer.py index f29fc944..390185f2 100644 --- a/mindinsight/mindconverter/graph_based_converter/third_party_graph/optimizer.py +++ b/mindinsight/mindconverter/graph_based_converter/third_party_graph/optimizer.py @@ -107,8 +107,16 @@ class OnnxSimplify: output_nodes_name = list() for node in self._constant_nodes: output_nodes_name.extend(node.output) + original_outputs = [nd.name for nd in self._onnx_model.graph.output] self._outputs_infer = fetch_output_from_onnx_model(self._onnx_model, feed_dict, output_nodes_name) + idx = 0 + while idx < len(self._onnx_model.graph.output): + cur_opt = self._onnx_model.graph.output[idx] + if cur_opt.name not in original_outputs: + self._onnx_model.graph.output.remove(cur_opt) + continue + idx += 1 def _replace_constant_nodes(self): """Replace constant nodes to nodes with op_type 'Constant'."""