|
|
@@ -13,9 +13,11 @@ |
|
|
# limitations under the License. |
|
|
# limitations under the License. |
|
|
# ============================================================================== |
|
|
# ============================================================================== |
|
|
"""Definition of search entry.""" |
|
|
"""Definition of search entry.""" |
|
|
from queue import PriorityQueue |
|
|
|
|
|
|
|
|
from queue import PriorityQueue, Queue |
|
|
from typing import Dict, List |
|
|
from typing import Dict, List |
|
|
|
|
|
|
|
|
|
|
|
from mindinsight.mindconverter.graph_based_converter.sub_graph_searcher.pattern_fuzzy_matching import \ |
|
|
|
|
|
pattern_fuzzy_matching |
|
|
from mindinsight.mindconverter.graph_based_converter.sub_graph_searcher.common import context, DagGraph, gen_hash_key, \ |
|
|
from mindinsight.mindconverter.graph_based_converter.sub_graph_searcher.common import context, DagGraph, gen_hash_key, \ |
|
|
ACCEPTABLE_RESULT_COUNT, MAX_ITERATION_DEPTH_OF_SINGLE_IPT |
|
|
ACCEPTABLE_RESULT_COUNT, MAX_ITERATION_DEPTH_OF_SINGLE_IPT |
|
|
from mindinsight.mindconverter.graph_based_converter.sub_graph_searcher.common import MINI_FREQUENCY, \ |
|
|
from mindinsight.mindconverter.graph_based_converter.sub_graph_searcher.common import MINI_FREQUENCY, \ |
|
|
@@ -196,6 +198,109 @@ def _scope_name_deduplication(key, scope_names, memo) -> list: |
|
|
return result |
|
|
return result |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _is_attn_layer(split_module): |
|
|
|
|
|
""" |
|
|
|
|
|
Whether the submodule is attention layer. |
|
|
|
|
|
|
|
|
|
|
|
Attention layer is defined as: attn-add-norm-fc-gelu-fc-add-norm. |
|
|
|
|
|
|
|
|
|
|
|
Args: |
|
|
|
|
|
split_module (list[list[str]]): Operations list in module. |
|
|
|
|
|
|
|
|
|
|
|
Returns: |
|
|
|
|
|
list, found module name. |
|
|
|
|
|
""" |
|
|
|
|
|
|
|
|
|
|
|
def _matched(modules): |
|
|
|
|
|
"""If the similarity score of sub_module and attention pattern is greater than 0.95, take it.""" |
|
|
|
|
|
threshold = 0.95 |
|
|
|
|
|
leaf_node = [m[-1] for m in modules] |
|
|
|
|
|
attn_layer_ptn_with_gelu = [ |
|
|
|
|
|
"MatMul", "Add", "MatMul", "Add", "Reshape", "MatMul", "Add", "Reshape", "Transpose", "Reshape", |
|
|
|
|
|
"Transpose", "Transpose", "MatMul", "Div", "Add", "Softmax", "MatMul", "Transpose", "Reshape", "MatMul", |
|
|
|
|
|
"Add", "Add", "ReduceMean", "Sub", "Cast", "Pow", "ReduceMean", "Add", "Sqrt", "Div", "Mul", "Add", |
|
|
|
|
|
"MatMul", "Add", "Div", "Erf", "Add", "Mul", "Mul", "MatMul", "Add", "Add", "ReduceMean", "Sub", "Cast", |
|
|
|
|
|
"Pow", "ReduceMean", "Add", "Sqrt", "Div", "Mul", "Add" |
|
|
|
|
|
] |
|
|
|
|
|
attn_layer_ptn_with_new_gelu = [ |
|
|
|
|
|
"MatMul", "Add", "MatMul", "Add", "MatMul", "Add", "Reshape", "Transpose", "Reshape", "Reshape", |
|
|
|
|
|
"Transpose", "Transpose", "MatMul", "Div", "Add", "Softmax", "MatMul", "Transpose", "Einsum", "Add", "Add", |
|
|
|
|
|
"ReduceMean", "Sub", "Cast", "Pow", "ReduceMean", "Add", "Sqrt", "Div", "Mul", "Add", "MatMul", "Add", |
|
|
|
|
|
"Mul", "Pow", "Mul", "Add", "Mul", "Tanh", "Add", "Mul", "MatMul", "Add", "Add", "ReduceMean", "Sub", |
|
|
|
|
|
"Cast", "Pow", "ReduceMean", "Add", "Sqrt", "Div", "Mul", "Add" |
|
|
|
|
|
] |
|
|
|
|
|
matched = max(pattern_fuzzy_matching(leaf_node, attn_layer_ptn_with_gelu)[1], |
|
|
|
|
|
pattern_fuzzy_matching(leaf_node, attn_layer_ptn_with_new_gelu)[1]) > threshold |
|
|
|
|
|
return matched |
|
|
|
|
|
|
|
|
|
|
|
candidates = Queue() |
|
|
|
|
|
candidates.put(split_module, block=False) |
|
|
|
|
|
while not candidates.empty(): |
|
|
|
|
|
candidate = candidates.get(block=False) |
|
|
|
|
|
if _matched(candidate): |
|
|
|
|
|
return candidate[0][0].split("_")[0] |
|
|
|
|
|
cur_scope = candidate[0][1] |
|
|
|
|
|
split_sub_module = [] |
|
|
|
|
|
for item in candidate: |
|
|
|
|
|
# It's not necessary to scan the module which depth is 2. |
|
|
|
|
|
if len(item) == 2: |
|
|
|
|
|
continue |
|
|
|
|
|
if item[1] != cur_scope: |
|
|
|
|
|
cur_scope = item[1] |
|
|
|
|
|
if split_sub_module: |
|
|
|
|
|
candidates.put(split_sub_module[:], block=False) |
|
|
|
|
|
split_sub_module.clear() |
|
|
|
|
|
split_sub_module.append(item[1:]) |
|
|
|
|
|
if split_sub_module: |
|
|
|
|
|
candidates.put(split_sub_module[:], block=False) |
|
|
|
|
|
return None |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _lift_each_module(sub_module): |
|
|
|
|
|
"""Lift each module in sub-module.""" |
|
|
|
|
|
lifted_module = [] |
|
|
|
|
|
split_module = [] |
|
|
|
|
|
cur_scope = sub_module[0].split("/")[0] |
|
|
|
|
|
segmented_pos = 0 |
|
|
|
|
|
|
|
|
|
|
|
def _lift(modules): |
|
|
|
|
|
nonlocal lifted_module, split_module |
|
|
|
|
|
exceed_max_depth = max(*[len(m.split("/")) for m in modules]) > 2 |
|
|
|
|
|
if not exceed_max_depth: |
|
|
|
|
|
for _ in range(len(split_module)): |
|
|
|
|
|
lifted_module.append((False, 0)) |
|
|
|
|
|
return |
|
|
|
|
|
# attn_module_name has been normalized without "_idx", only has raw module name. |
|
|
|
|
|
attn_module_name = _is_attn_layer(split_module) |
|
|
|
|
|
for s_md in split_module: |
|
|
|
|
|
if attn_module_name: |
|
|
|
|
|
md_name = [md for md in s_md if attn_module_name in md] |
|
|
|
|
|
if md_name: |
|
|
|
|
|
md_name = md_name[0] |
|
|
|
|
|
attn_idx = s_md.index(md_name) |
|
|
|
|
|
if attn_idx > 0: |
|
|
|
|
|
lifted_module.append((True, attn_idx)) |
|
|
|
|
|
continue |
|
|
|
|
|
lifted_module.append((False, 0)) |
|
|
|
|
|
continue |
|
|
|
|
|
lifted_module.append((True, 0)) |
|
|
|
|
|
|
|
|
|
|
|
for i, m in enumerate(sub_module): |
|
|
|
|
|
split_md = m.split("/") |
|
|
|
|
|
# Find one module. |
|
|
|
|
|
if cur_scope != split_md[0]: |
|
|
|
|
|
_lift(sub_module[segmented_pos:i]) |
|
|
|
|
|
# Clean up. |
|
|
|
|
|
cur_scope = split_md[0] |
|
|
|
|
|
segmented_pos = i |
|
|
|
|
|
split_module.clear() |
|
|
|
|
|
split_module.append(split_md) |
|
|
|
|
|
|
|
|
|
|
|
# Do lift on last module. |
|
|
|
|
|
_lift(sub_module[segmented_pos:]) |
|
|
|
|
|
return lifted_module |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _retrieve_operators(module_path, module_dict): |
|
|
def _retrieve_operators(module_path, module_dict): |
|
|
""" |
|
|
""" |
|
|
Retrieve operators from path. |
|
|
Retrieve operators from path. |
|
|
@@ -208,26 +313,29 @@ def _retrieve_operators(module_path, module_dict): |
|
|
str: module_name, operators in module. |
|
|
str: module_name, operators in module. |
|
|
""" |
|
|
""" |
|
|
|
|
|
|
|
|
def _whether_to_lift(sub_module): |
|
|
|
|
|
"""Whether to lift a scope according to its depth.""" |
|
|
|
|
|
return max(*[len(m.split("/")) for m in sub_module]) > 2 |
|
|
|
|
|
|
|
|
|
|
|
def _lift(sub_module): |
|
|
def _lift(sub_module): |
|
|
"""Lift nodes upper.""" |
|
|
"""Lift nodes upper.""" |
|
|
nonlocal added_module |
|
|
nonlocal added_module |
|
|
lifted_submodule = [] |
|
|
lifted_submodule = [] |
|
|
record = dict() |
|
|
record = dict() |
|
|
lift_needed = _whether_to_lift(sub_module) |
|
|
|
|
|
for m in sub_module: |
|
|
|
|
|
|
|
|
# DO NOT lift on attn-add-norm-fc with GeLU-fc-add-norm. |
|
|
|
|
|
# It's a fix pattern in Transformer model. |
|
|
|
|
|
lift_on_each_module = _lift_each_module(sub_module) |
|
|
|
|
|
for i, m in enumerate(sub_module): |
|
|
|
|
|
lift_needed, lift_from = lift_on_each_module[i] |
|
|
scopes = m.split("/") |
|
|
scopes = m.split("/") |
|
|
if lift_needed and len(scopes) == 3: |
|
|
|
|
|
|
|
|
if lift_needed and len(scopes) >= 3: |
|
|
# If the scope depth is 3, like ModuleX/ModuleY/Gemm, |
|
|
# If the scope depth is 3, like ModuleX/ModuleY/Gemm, |
|
|
# then we lift ModuleY to top level. |
|
|
# then we lift ModuleY to top level. |
|
|
md_name, md_idx = scopes[-2].split("_") |
|
|
|
|
|
|
|
|
md_name, md_idx = scopes[-2 if lift_from == 0 else lift_from].split("_") |
|
|
if record.get(md_name, -1) != md_idx: |
|
|
if record.get(md_name, -1) != md_idx: |
|
|
record[md_name] = md_idx |
|
|
record[md_name] = md_idx |
|
|
added_module[md_name] = added_module.setdefault(md_name, -1) + 1 |
|
|
added_module[md_name] = added_module.setdefault(md_name, -1) + 1 |
|
|
lifted_submodule.append(f"{md_name}_{added_module.setdefault(md_name, 0)}/{scopes[-1]}") |
|
|
|
|
|
|
|
|
if lift_from != 0: |
|
|
|
|
|
lifted_md = "/".join([f"{md_name}_{added_module.setdefault(md_name, 0)}"] + scopes[lift_from + 1:]) |
|
|
|
|
|
else: |
|
|
|
|
|
lifted_md = f"{md_name}_{added_module.setdefault(md_name, 0)}/{scopes[-1]}" |
|
|
|
|
|
lifted_submodule.append(lifted_md) |
|
|
continue |
|
|
continue |
|
|
if lift_needed and len(scopes) == 2: |
|
|
if lift_needed and len(scopes) == 2: |
|
|
# If the module is required to lifted, then lift leaf node to parent. |
|
|
# If the module is required to lifted, then lift leaf node to parent. |
|
|
@@ -263,7 +371,6 @@ def _build_connection(loader): |
|
|
context.precursor_table[node_name] = list(node.get_precursor_dict().keys()) |
|
|
context.precursor_table[node_name] = list(node.get_precursor_dict().keys()) |
|
|
context.successor_table[node_name] = list(node.get_successor_dict().keys()) |
|
|
context.successor_table[node_name] = list(node.get_successor_dict().keys()) |
|
|
context.outputs_table[node_name] = node.output_name_list |
|
|
context.outputs_table[node_name] = node.output_name_list |
|
|
|
|
|
|
|
|
# Record the model inputs count, use it to control the search algorithm. |
|
|
# Record the model inputs count, use it to control the search algorithm. |
|
|
context.has_multi_inputs = len(loader.input_nodes) > 1 |
|
|
context.has_multi_inputs = len(loader.input_nodes) > 1 |
|
|
dag = DagGraph(nodes=context.node_collection.copy(), |
|
|
dag = DagGraph(nodes=context.node_collection.copy(), |
|
|
|