|
|
|
@@ -15,14 +15,15 @@ |
|
|
|
"""Declare search path related.""" |
|
|
|
import copy |
|
|
|
import uuid |
|
|
|
from typing import Dict, List, Callable, Union |
|
|
|
from collections import OrderedDict |
|
|
|
from typing import Dict, List, Callable, Union |
|
|
|
|
|
|
|
from mindinsight.mindconverter.graph_based_converter.sub_graph_searcher.built_in_pattern import BUILT_IN_PATTERN, \ |
|
|
|
is_built_in_pattern |
|
|
|
from mindinsight.mindconverter.graph_based_converter.sub_graph_searcher.common import context, gen_hash_key, DagGraph, \ |
|
|
|
MAX_OUT_DEGREE, cal_matching_score |
|
|
|
from mindinsight.mindconverter.graph_based_converter.sub_graph_searcher.known_module_name import BUILT_IN_MODULE_NAME |
|
|
|
from mindinsight.mindconverter.graph_based_converter.sub_graph_searcher.pattern import Pattern, scope_name_mapping |
|
|
|
from mindinsight.mindconverter.graph_based_converter.sub_graph_searcher.built_in_pattern import BUILT_IN_PATTERN, \ |
|
|
|
is_built_in_pattern |
|
|
|
from mindinsight.mindconverter.graph_based_converter.sub_graph_searcher.pattern_fuzzy_matching import \ |
|
|
|
pattern_fuzzy_matching |
|
|
|
from mindinsight.mindconverter.graph_based_converter.third_party_graph.onnx_utils import OnnxNode, BaseNode |
|
|
|
@@ -80,10 +81,13 @@ def _is_valid_pattern(pattern, dag): |
|
|
|
""" |
|
|
|
if not pattern: |
|
|
|
return False |
|
|
|
first_op = dag.node_collection[list(pattern.keys())[0]].op_type |
|
|
|
head = dag.node_collection[list(pattern.keys())[0]] |
|
|
|
op_type = head.op_type |
|
|
|
if isinstance(head, MergedONNXNode) and "LayerNorm" in head.known_module_name: |
|
|
|
return False |
|
|
|
if len(pattern) == 1: |
|
|
|
return False |
|
|
|
if first_op in OptimizeRules.CAN_NOT_BE_HEAD: |
|
|
|
if op_type in OptimizeRules.CAN_NOT_BE_HEAD: |
|
|
|
return False |
|
|
|
return True |
|
|
|
|
|
|
|
@@ -133,9 +137,12 @@ def random_name(module_name): |
|
|
|
class MergedONNXNode(BaseNode): |
|
|
|
"""Define merged onnx node.""" |
|
|
|
|
|
|
|
def __init__(self, name, module_name, ori_nodes): |
|
|
|
def __init__(self, name, module_name, ori_nodes, inputs, outputs, known_module_name): |
|
|
|
super(MergedONNXNode, self).__init__(node_name=name, op_type=module_name) |
|
|
|
self.nodes = ori_nodes |
|
|
|
self.inputs = inputs |
|
|
|
self.outputs = outputs |
|
|
|
self.known_module_name = known_module_name if known_module_name else "" |
|
|
|
|
|
|
|
def get_name(self): |
|
|
|
return self.name |
|
|
|
@@ -192,23 +199,30 @@ def _get_pattern_degree(sequence: Union[OrderedDict, dict, list], |
|
|
|
dag (DagGraph): Graph instance. |
|
|
|
|
|
|
|
Returns: |
|
|
|
tuple[int, int], in degree and out degree. |
|
|
|
tuple[int, int, set, set], in degree, out degree, precursors and successors. |
|
|
|
""" |
|
|
|
in_node = set() |
|
|
|
node_in_seq = set() |
|
|
|
items = sequence if isinstance(sequence, list) else sequence.keys() |
|
|
|
for item in items: |
|
|
|
node_in_seq.add(item.name if not isinstance(item, str) else item) |
|
|
|
out_degree = 0 |
|
|
|
out_node = set() |
|
|
|
for item in items: |
|
|
|
item = item.name if not isinstance(item, str) else item |
|
|
|
for ipt in dag.precursor_table[item]: |
|
|
|
in_node.add(ipt) |
|
|
|
for opt in dag.successor_table[item]: |
|
|
|
if opt not in node_in_seq: |
|
|
|
out_degree += 1 |
|
|
|
if item not in context.outputs_table: |
|
|
|
out = dag.node_collection[item].outputs |
|
|
|
else: |
|
|
|
out = set(context.outputs_table[item]) |
|
|
|
out_node.update(out) |
|
|
|
in_degree = len(in_node - node_in_seq) |
|
|
|
return in_degree, out_degree |
|
|
|
# Because only record nodes that outputs are referred by other nodes, |
|
|
|
# the outputs number of each node must be calculated. |
|
|
|
out_degree = len(out_node) |
|
|
|
return in_degree, out_degree, in_node - node_in_seq, out_node |
|
|
|
|
|
|
|
|
|
|
|
def _find_pattern_tail(sequence: List[BaseNode], pattern: Dict[str, str], tail_idx: int, dag: DagGraph): |
|
|
|
@@ -312,14 +326,14 @@ def find_built_in_pattern(topo_order: List[BaseNode], dag: DagGraph) -> Dict[str |
|
|
|
if not matched: |
|
|
|
cur_idx += 1 |
|
|
|
continue |
|
|
|
in_degree, out_degree = _get_pattern_degree(init_pattern, dag) |
|
|
|
in_degree, out_degree, _, _ = _get_pattern_degree(init_pattern, dag) |
|
|
|
if in_degree != BUILT_IN_PATTERN[k].in_degree or out_degree != BUILT_IN_PATTERN[k].out_degree: |
|
|
|
cur_idx += 1 |
|
|
|
continue |
|
|
|
ptn_key = f"{BUILT_IN_PATTERN[k].pattern}" \ |
|
|
|
f"[{BUILT_IN_PATTERN[k].in_degree}, {BUILT_IN_PATTERN[k].out_degree}]" |
|
|
|
if ptn_key not in pattern: |
|
|
|
pattern[ptn_key] = BUILT_IN_PATTERN[k] |
|
|
|
pattern[ptn_key] = copy.deepcopy(BUILT_IN_PATTERN[k]) |
|
|
|
|
|
|
|
pattern[ptn_key].insert(cur_idx, ptn_len) |
|
|
|
cur_idx = cur_idx + 1 |
|
|
|
@@ -375,8 +389,8 @@ def generate_pattern(topo_order: List[BaseNode], dag: DagGraph, |
|
|
|
offset=cur_idx - jump_step, |
|
|
|
dag=dag) |
|
|
|
|
|
|
|
in_degree, out_degree = _get_pattern_degree(found_sequence, dag) |
|
|
|
if out_degree > MAX_OUT_DEGREE or in_degree > MAX_OUT_DEGREE: |
|
|
|
in_degree, out_degree, _, _ = _get_pattern_degree(found_sequence, dag) |
|
|
|
if out_degree > MAX_OUT_DEGREE: |
|
|
|
cur_idx += 1 |
|
|
|
continue |
|
|
|
|
|
|
|
@@ -391,9 +405,24 @@ def generate_pattern(topo_order: List[BaseNode], dag: DagGraph, |
|
|
|
pattern[ptn_key].insert(cur_idx - sub_graph_size + 1, len(found_sequence)) |
|
|
|
cur_idx = cur_idx + 1 |
|
|
|
|
|
|
|
pattern = _post_process_overlap(pattern) |
|
|
|
return pattern |
|
|
|
|
|
|
|
|
|
|
|
def _post_process_overlap(patterns) -> Dict: |
|
|
|
"""Post process overlap of found pattern.""" |
|
|
|
for name in patterns: |
|
|
|
prev_end = patterns[name].end_index[0] |
|
|
|
idx = 1 |
|
|
|
while idx < len(patterns[name].end_index): |
|
|
|
if patterns[name].start_index[idx] <= prev_end: |
|
|
|
patterns[name].start_index.pop(idx) |
|
|
|
patterns[name].end_index.pop(idx) |
|
|
|
continue |
|
|
|
idx += 1 |
|
|
|
return patterns |
|
|
|
|
|
|
|
|
|
|
|
class SearchPath: |
|
|
|
""" |
|
|
|
Use SearchPath to store the search path. |
|
|
|
@@ -510,7 +539,7 @@ class SearchPath: |
|
|
|
path_length += 1 |
|
|
|
continue |
|
|
|
|
|
|
|
in_degree, out_degree = _get_pattern_degree(visited_node, self.graph) |
|
|
|
in_degree, out_degree, inputs, outputs = _get_pattern_degree(visited_node, self.graph) |
|
|
|
if in_degree != pattern.in_degree or out_degree != pattern.out_degree: |
|
|
|
topo_order.extend(visited_node) |
|
|
|
index += j + 1 |
|
|
|
@@ -520,7 +549,10 @@ class SearchPath: |
|
|
|
inverted_index[path_length] = [j + index for j in range(pattern_len)] |
|
|
|
new_node = MergedONNXNode(name=random_name(pattern.module_name), |
|
|
|
module_name=pattern.module_name, |
|
|
|
ori_nodes=visited_node[:]) |
|
|
|
ori_nodes=visited_node[:], |
|
|
|
inputs=inputs, |
|
|
|
outputs=outputs, |
|
|
|
known_module_name=pattern.known_module_name) |
|
|
|
self._reconnect(new_node) |
|
|
|
self.graph.node_collection[new_node.name] = new_node |
|
|
|
topo_order.append(new_node) |
|
|
|
@@ -625,7 +657,7 @@ class SearchPath: |
|
|
|
def _actual_val(self): |
|
|
|
"""Calculate ground-truth score of the path.""" |
|
|
|
bonus = self._cal_bonus() |
|
|
|
return bonus * 0.1 + self.replacement_ratio * 0.9 |
|
|
|
return (bonus + self.replacement_ratio) / 2 |
|
|
|
|
|
|
|
def __lt__(self, other): |
|
|
|
"""Override `<` operator.""" |
|
|
|
|