|
|
|
@@ -1,4 +1,4 @@ |
|
|
|
# Copyright 2020 Huawei Technologies Co., Ltd.All Rights Reserved. |
|
|
|
# Copyright 2020-2021 Huawei Technologies Co., Ltd.All Rights Reserved. |
|
|
|
# |
|
|
|
# Licensed under the Apache License, Version 2.0 (the "License"); |
|
|
|
# you may not use this file except in compliance with the License. |
|
|
|
@@ -36,13 +36,23 @@ def _is_satisfied(path): |
|
|
|
""" |
|
|
|
if len(path.recursion_path) > MAX_ITERATION_DEPTH: |
|
|
|
return True |
|
|
|
if not path.new_pattern or max([p.count for _, p in path.new_pattern.items()]) < MINI_FREQUENCY: |
|
|
|
if not path.new_pattern or not any([is_pattern_satisfied(p, path) for p in path.new_pattern.values()]): |
|
|
|
return True |
|
|
|
if path.evaluate_score() > SATISFIED_SCORE: |
|
|
|
return True |
|
|
|
return False |
|
|
|
|
|
|
|
|
|
|
|
def is_pattern_satisfied(pattern, seq): |
|
|
|
"""Whether a pattern is valid.""" |
|
|
|
rpl_ratio = 1.0 * pattern.count * pattern.ptn_length / len(seq.topo_order_aft_repl) |
|
|
|
# If replacement ratio is larger than 7%, |
|
|
|
# then take it, otherwise, reject this pattern. |
|
|
|
if rpl_ratio >= MINI_FREQUENCY: |
|
|
|
return True |
|
|
|
return False |
|
|
|
|
|
|
|
|
|
|
|
def _search(init_pattern: Dict[str, Pattern], init_topo_order: List[BaseNode], |
|
|
|
init_graph, sub_graph_size: int = 2) -> List[SearchPath]: |
|
|
|
""" |
|
|
|
@@ -61,7 +71,7 @@ def _search(init_pattern: Dict[str, Pattern], init_topo_order: List[BaseNode], |
|
|
|
sorted_pattern = context.sort_with_beam(init_pattern) |
|
|
|
# 2. Put pattern into queue. |
|
|
|
queue = PriorityQueue() |
|
|
|
for _, pattern_inst in sorted_pattern.items(): |
|
|
|
for pattern_inst in sorted_pattern.values(): |
|
|
|
queue.put( |
|
|
|
SearchPath(pattern=pattern_inst, sequence=init_topo_order, |
|
|
|
graph=init_graph, |
|
|
|
@@ -70,6 +80,7 @@ def _search(init_pattern: Dict[str, Pattern], init_topo_order: List[BaseNode], |
|
|
|
) |
|
|
|
|
|
|
|
available_path = [] |
|
|
|
deduplicate_path = set() |
|
|
|
while not queue.empty(): |
|
|
|
# a. replace pattern in current topo order. |
|
|
|
cur_path = queue.get(block=False) |
|
|
|
@@ -77,13 +88,18 @@ def _search(init_pattern: Dict[str, Pattern], init_topo_order: List[BaseNode], |
|
|
|
# b. generate new pattern based on replaced topo order. |
|
|
|
if _is_satisfied(cur_path): |
|
|
|
available_path.append(cur_path) |
|
|
|
deduplicate_path.add(cur_path.hash_of_aft_repl) |
|
|
|
continue |
|
|
|
|
|
|
|
if len(available_path) >= ACCEPTABLE_RESULT_COUNT: |
|
|
|
break |
|
|
|
|
|
|
|
for _, cur_pattern in cur_path.new_pattern.items(): |
|
|
|
if cur_pattern.count < MINI_FREQUENCY: |
|
|
|
for cur_pattern in cur_path.new_pattern.values(): |
|
|
|
if not is_pattern_satisfied(cur_pattern, cur_path): |
|
|
|
if cur_path.hash_of_aft_repl in deduplicate_path: |
|
|
|
continue |
|
|
|
available_path.append(cur_path) |
|
|
|
deduplicate_path.add(cur_path.hash_of_aft_repl) |
|
|
|
continue |
|
|
|
key = "/".join([f"{cur_pattern.pattern}[{cur_pattern.in_degree},{cur_pattern.out_degree}]", |
|
|
|
gen_hash_key(cur_topo_order, without_module=True)]) |
|
|
|
@@ -167,13 +183,8 @@ def _scope_name_deduplication(key, scope_names, memo) -> list: |
|
|
|
Returns: |
|
|
|
list, renamed scope name. |
|
|
|
""" |
|
|
|
result = [] |
|
|
|
if key not in memo: |
|
|
|
memo[key] = 0 |
|
|
|
for item in scope_names: |
|
|
|
item = item.replace(key, f"{key}_{memo.get(key)}") |
|
|
|
result.append(item) |
|
|
|
memo[key] += 1 |
|
|
|
memo[key] = memo.setdefault(key, -1) + 1 |
|
|
|
result = [item.replace(key, f"{key}_{memo.get(key)}") for item in scope_names] |
|
|
|
return result |
|
|
|
|
|
|
|
|
|
|
|
@@ -188,14 +199,43 @@ def _retrieve_operators(module_path, module_dict): |
|
|
|
Returns: |
|
|
|
str: module_name, operators in module. |
|
|
|
""" |
|
|
|
|
|
|
|
def _whether_to_lift(sub_module): |
|
|
|
"""Whether to lift a scope according to its depth.""" |
|
|
|
return max(*[len(m.split("/")) for m in sub_module]) > 2 |
|
|
|
|
|
|
|
def _lift(sub_module): |
|
|
|
"""Lift nodes upper.""" |
|
|
|
nonlocal added_module |
|
|
|
lifted_submodule = [] |
|
|
|
continuity_idx = -1 |
|
|
|
lift_needed = _whether_to_lift(sub_module) |
|
|
|
for m in sub_module: |
|
|
|
scopes = m.split("/") |
|
|
|
if lift_needed and len(scopes) == 3: |
|
|
|
# If the scope depth is 3, like ModuleX/ModuleY/Gemm, |
|
|
|
# then we lift ModuleY to top level. |
|
|
|
md_name, md_idx = scopes[-2].split("_") |
|
|
|
if continuity_idx != int(md_idx): |
|
|
|
continuity_idx = int(md_idx) |
|
|
|
added_module[md_name] = added_module.setdefault(md_name, -1) + 1 |
|
|
|
lifted_submodule.append(f"{md_name}_{added_module.setdefault(md_name, 0)}/{scopes[-1]}") |
|
|
|
continue |
|
|
|
if lift_needed and len(scopes) == 2: |
|
|
|
# If the module is required to lifted, then lift leaf node to parent. |
|
|
|
lifted_submodule.append(scopes[-1]) |
|
|
|
continue |
|
|
|
# If lift is not required, then add it directly. |
|
|
|
lifted_submodule.append(m) |
|
|
|
return lifted_submodule |
|
|
|
|
|
|
|
added_module = dict() |
|
|
|
node_in_pattern = module_path.pattern.ptn_items |
|
|
|
node_list = [] |
|
|
|
for node in node_in_pattern: |
|
|
|
if module_dict.get(node): |
|
|
|
node_list += _scope_name_deduplication(node, |
|
|
|
module_dict[node], |
|
|
|
added_module) |
|
|
|
sub_scope = _scope_name_deduplication(node, module_dict[node], added_module) |
|
|
|
node_list += _lift(sub_scope) |
|
|
|
else: |
|
|
|
node_list.append(node) |
|
|
|
val = [f"{module_path.pattern.module_name}/{node}" for node in node_list] |
|
|
|
@@ -231,7 +271,7 @@ def flatten_graph(graph): |
|
|
|
Returns: |
|
|
|
list[str], corresponding scope name. |
|
|
|
""" |
|
|
|
return [f"Model/{node.op_type}" for _, node in graph.node_collection.items()] |
|
|
|
return [f"Model/{node.op_type}" for node in graph.node_collection.values()] |
|
|
|
|
|
|
|
|
|
|
|
def validate_topo_order_succession(): |
|
|
|
@@ -280,9 +320,6 @@ def generate_scope_name(data_loader): |
|
|
|
""" |
|
|
|
init_dag = _build_connection(data_loader) |
|
|
|
try: |
|
|
|
if not validate_topo_order_succession(): |
|
|
|
raise ValueError("Topological order is not successive.") |
|
|
|
|
|
|
|
result = _sub_graph_matching(init_dag, beam_width=5, sub_graph_size=6) |
|
|
|
topo_order_with_scope_name_list = _retrieve_scope_name(result) if result else flatten_graph(init_dag) |
|
|
|
|
|
|
|
@@ -294,4 +331,5 @@ def generate_scope_name(data_loader): |
|
|
|
|
|
|
|
except (ValueError, IndexError, AttributeError, KeyError) as _: |
|
|
|
topo_order_with_scope_name_list = flatten_graph(init_dag) |
|
|
|
|
|
|
|
return topo_order_with_scope_name_list |