Browse Source

Sub-graph searching and merging module: support transformer architectures.

tags/v1.2.0-rc1
liuchongming 4 years ago
parent
commit
9568fce57e
6 changed files with 117 additions and 26 deletions
  1. +50
    -1
      mindinsight/mindconverter/graph_based_converter/sub_graph_searcher/built_in_pattern.py
  2. +8
    -2
      mindinsight/mindconverter/graph_based_converter/sub_graph_searcher/common.py
  3. +2
    -1
      mindinsight/mindconverter/graph_based_converter/sub_graph_searcher/known_module_name.py
  4. +1
    -0
      mindinsight/mindconverter/graph_based_converter/sub_graph_searcher/pattern.py
  5. +49
    -17
      mindinsight/mindconverter/graph_based_converter/sub_graph_searcher/search_path.py
  6. +7
    -5
      mindinsight/mindconverter/graph_based_converter/sub_graph_searcher/searcher.py

+ 50
- 1
mindinsight/mindconverter/graph_based_converter/sub_graph_searcher/built_in_pattern.py View File

@@ -16,6 +16,8 @@


__all__ = ["BUILT_IN_PATTERN", "register_pattern", "is_built_in_pattern"] __all__ = ["BUILT_IN_PATTERN", "register_pattern", "is_built_in_pattern"]


from mindinsight.mindconverter.graph_based_converter.sub_graph_searcher.known_module_name import register_module_name

from mindinsight.mindconverter.graph_based_converter.sub_graph_searcher.common import cal_matching_score from mindinsight.mindconverter.graph_based_converter.sub_graph_searcher.common import cal_matching_score
from mindinsight.mindconverter.graph_based_converter.sub_graph_searcher.pattern import Pattern from mindinsight.mindconverter.graph_based_converter.sub_graph_searcher.pattern import Pattern


@@ -59,11 +61,16 @@ def register_pattern(ptn_name, in_degree, out_degree):
def _reg(pattern): def _reg(pattern):
result = pattern() result = pattern()
if not result: if not result:
return
return pattern
if ptn_name in BUILT_IN_PATTERN:
raise KeyError(f"{ptn_name} exists, `ptn_name` must be unique.")

BUILT_IN_PATTERN[ptn_name] = Pattern("->".join(result), len(result), BUILT_IN_PATTERN[ptn_name] = Pattern("->".join(result), len(result),
in_degree, out_degree, in_degree, out_degree,
ptn_items=result) ptn_items=result)
BUILT_IN_PATTERN[ptn_name].additional_score = cal_matching_score(BUILT_IN_PATTERN[ptn_name].ptn_length) BUILT_IN_PATTERN[ptn_name].additional_score = cal_matching_score(BUILT_IN_PATTERN[ptn_name].ptn_length)
BUILT_IN_PATTERN[ptn_name].ptn_name = ptn_name
return pattern


return _reg return _reg


@@ -112,3 +119,45 @@ def _up_sampling_in_op10():
return [ return [
"Shape", "Gather", "Cast", "Slice", "Mul", "Slice", "Cast", "Cast", "Div", "Concat", "Resize" "Shape", "Gather", "Cast", "Slice", "Mul", "Slice", "Cast", "Cast", "Div", "Concat", "Resize"
] ]


@register_pattern("Multi-Head-Attention", 2, 1)
@register_module_name("MultiHeadAttn", 2, 1)
def _multi_head_attention():
    """Return the op-type sequence matching an exported multi-head attention block."""
    # Sequence split for readability; presumably the Q/K/V projections, the
    # scaled-dot-product core, and the output projection — verify against a
    # real exported graph before relying on the grouping.
    head = ["MatMul", "Add", "MatMul", "Add", "Reshape", "MatMul", "Add", "Reshape"]
    core = ["Transpose", "Reshape", "Transpose", "Transpose", "MatMul", "Div", "Add", "Softmax"]
    tail = ["MatMul", "Transpose", "Reshape", "MatMul", "Add"]
    return head + core + tail


@register_pattern("Layer-Normalization", 1, 1)
@register_module_name("LayerNorm", 1, 1)
def _layer_norm():
    """Return the op-type sequence matching an exported layer-normalization block."""
    ops = ("ReduceMean", "Sub", "Pow", "ReduceMean", "Add", "Sqrt", "Div", "Mul", "Add")
    return list(ops)


@register_pattern("Layer-Normalization-with-cast", 1, 1)
@register_module_name("LayerNorm", 1, 1)
def _layer_norm_with_cast():
    """Return the layer-normalization op-type sequence with an extra Cast op."""
    # Same shape as `_layer_norm` but with a Cast inserted before Pow.
    ops = ("ReduceMean", "Sub", "Cast", "Pow", "ReduceMean", "Add", "Sqrt", "Div", "Mul", "Add")
    return list(ops)


@register_pattern("GeLU", 1, 1)
@register_module_name("GeLU", 1, 1)
def _gelu():
    """Return the op-type sequence matching an exported GeLU activation."""
    ops = ("Div", "Erf", "Add", "Mul", "Mul")
    return list(ops)


@register_pattern("Linear", 1, 1)
@register_module_name("Linear", 1, 1)
def _linear():
    """Return the op-type sequence matching a linear (MatMul followed by Add) layer."""
    return list(("MatMul", "Add"))

+ 8
- 2
mindinsight/mindconverter/graph_based_converter/sub_graph_searcher/common.py View File

@@ -35,7 +35,7 @@ from mindinsight.mindconverter.graph_based_converter.third_party_graph.onnx_util
MAX_OUT_DEGREE = 1 MAX_OUT_DEGREE = 1
MINI_FREQUENCY = 0.07 MINI_FREQUENCY = 0.07
MAX_ITERATION_DEPTH = 16 MAX_ITERATION_DEPTH = 16
SATISFIED_SCORE = 1.5
SATISFIED_SCORE = 0.74
ACCEPTABLE_RESULT_COUNT = 32 ACCEPTABLE_RESULT_COUNT = 32
PTN_COVERAGE_THRESHOLD = 0.65 PTN_COVERAGE_THRESHOLD = 0.65
# If pattern length is short than `IGNORE_PTN_LEN`, then do not calculate the coverage. # If pattern length is short than `IGNORE_PTN_LEN`, then do not calculate the coverage.
@@ -126,6 +126,7 @@ class AlgorithmContext:
node_collection = None node_collection = None
precursor_table = {} precursor_table = {}
successor_table = {} successor_table = {}
outputs_table = {}


def set_init_node_collection(self, nd_col): def set_init_node_collection(self, nd_col):
"""Init node_collection.""" """Init node_collection."""
@@ -158,7 +159,12 @@ class AlgorithmContext:
pattern_arr = sorted(pattern_arr.items(), key=functools.cmp_to_key(_cmp), pattern_arr = sorted(pattern_arr.items(), key=functools.cmp_to_key(_cmp),
reverse=True) reverse=True)
if len(pattern_arr) > AlgorithmContext.beam_width: if len(pattern_arr) > AlgorithmContext.beam_width:
pattern_arr = pattern_arr[:self.beam_width]
new_pattern_arr = pattern_arr[:self.beam_width]
# Avoid dropping built-in patterns, because built-in patterns are much
# more likely to be correct matches.
for i in range(self.beam_width):
if pattern_arr[i][1].additional_score != 0:
new_pattern_arr.append(pattern_arr[i])
res = OrderedDict() res = OrderedDict()
for i, (key, ptn) in enumerate(pattern_arr): for i, (key, ptn) in enumerate(pattern_arr):
if ptn.count <= AlgorithmContext.MIN_FREQUENCY: if ptn.count <= AlgorithmContext.MIN_FREQUENCY:


+ 2
- 1
mindinsight/mindconverter/graph_based_converter/sub_graph_searcher/known_module_name.py View File

@@ -49,10 +49,11 @@ def register_module_name(md_name: str, in_degree: int, out_degree: int):
def _reg(pattern): def _reg(pattern):
result = pattern() result = pattern()
if not result: if not result:
return
return pattern
BUILT_IN_MODULE_NAME[Pattern("->".join(result), len(result), BUILT_IN_MODULE_NAME[Pattern("->".join(result), len(result),
in_degree, out_degree, in_degree, out_degree,
ptn_items=result)] = md_name ptn_items=result)] = md_name
return pattern


return _reg return _reg




+ 1
- 0
mindinsight/mindconverter/graph_based_converter/sub_graph_searcher/pattern.py View File

@@ -25,6 +25,7 @@ class Pattern:
self.start_index = [] self.start_index = []
self.end_index = [] self.end_index = []
self.module_name = None self.module_name = None
self.ptn_name = ""
self.ptn_length = pattern_length self.ptn_length = pattern_length
self.ptn_items = pattern.split("->") if ptn_items is None else ptn_items self.ptn_items = pattern.split("->") if ptn_items is None else ptn_items
self.in_degree = in_degree self.in_degree = in_degree


+ 49
- 17
mindinsight/mindconverter/graph_based_converter/sub_graph_searcher/search_path.py View File

@@ -15,14 +15,15 @@
"""Declare search path related.""" """Declare search path related."""
import copy import copy
import uuid import uuid
from typing import Dict, List, Callable, Union
from collections import OrderedDict from collections import OrderedDict
from typing import Dict, List, Callable, Union

from mindinsight.mindconverter.graph_based_converter.sub_graph_searcher.built_in_pattern import BUILT_IN_PATTERN, \
is_built_in_pattern
from mindinsight.mindconverter.graph_based_converter.sub_graph_searcher.common import context, gen_hash_key, DagGraph, \ from mindinsight.mindconverter.graph_based_converter.sub_graph_searcher.common import context, gen_hash_key, DagGraph, \
MAX_OUT_DEGREE, cal_matching_score MAX_OUT_DEGREE, cal_matching_score
from mindinsight.mindconverter.graph_based_converter.sub_graph_searcher.known_module_name import BUILT_IN_MODULE_NAME from mindinsight.mindconverter.graph_based_converter.sub_graph_searcher.known_module_name import BUILT_IN_MODULE_NAME
from mindinsight.mindconverter.graph_based_converter.sub_graph_searcher.pattern import Pattern, scope_name_mapping from mindinsight.mindconverter.graph_based_converter.sub_graph_searcher.pattern import Pattern, scope_name_mapping
from mindinsight.mindconverter.graph_based_converter.sub_graph_searcher.built_in_pattern import BUILT_IN_PATTERN, \
is_built_in_pattern
from mindinsight.mindconverter.graph_based_converter.sub_graph_searcher.pattern_fuzzy_matching import \ from mindinsight.mindconverter.graph_based_converter.sub_graph_searcher.pattern_fuzzy_matching import \
pattern_fuzzy_matching pattern_fuzzy_matching
from mindinsight.mindconverter.graph_based_converter.third_party_graph.onnx_utils import OnnxNode, BaseNode from mindinsight.mindconverter.graph_based_converter.third_party_graph.onnx_utils import OnnxNode, BaseNode
@@ -80,10 +81,13 @@ def _is_valid_pattern(pattern, dag):
""" """
if not pattern: if not pattern:
return False return False
first_op = dag.node_collection[list(pattern.keys())[0]].op_type
head = dag.node_collection[list(pattern.keys())[0]]
op_type = head.op_type
if isinstance(head, MergedONNXNode) and "LayerNorm" in head.known_module_name:
return False
if len(pattern) == 1: if len(pattern) == 1:
return False return False
if first_op in OptimizeRules.CAN_NOT_BE_HEAD:
if op_type in OptimizeRules.CAN_NOT_BE_HEAD:
return False return False
return True return True


@@ -133,9 +137,12 @@ def random_name(module_name):
class MergedONNXNode(BaseNode): class MergedONNXNode(BaseNode):
"""Define merged onnx node.""" """Define merged onnx node."""


def __init__(self, name, module_name, ori_nodes):
def __init__(self, name, module_name, ori_nodes, inputs, outputs, known_module_name):
super(MergedONNXNode, self).__init__(node_name=name, op_type=module_name) super(MergedONNXNode, self).__init__(node_name=name, op_type=module_name)
self.nodes = ori_nodes self.nodes = ori_nodes
self.inputs = inputs
self.outputs = outputs
self.known_module_name = known_module_name if known_module_name else ""


def get_name(self): def get_name(self):
return self.name return self.name
@@ -192,23 +199,30 @@ def _get_pattern_degree(sequence: Union[OrderedDict, dict, list],
dag (DagGraph): Graph instance. dag (DagGraph): Graph instance.


Returns: Returns:
tuple[int, int], in degree and out degree.
tuple[int, int, set, set], in degree, out degree, precursors and successors.
""" """
in_node = set() in_node = set()
node_in_seq = set() node_in_seq = set()
items = sequence if isinstance(sequence, list) else sequence.keys() items = sequence if isinstance(sequence, list) else sequence.keys()
for item in items: for item in items:
node_in_seq.add(item.name if not isinstance(item, str) else item) node_in_seq.add(item.name if not isinstance(item, str) else item)
out_degree = 0
out_node = set()
for item in items: for item in items:
item = item.name if not isinstance(item, str) else item item = item.name if not isinstance(item, str) else item
for ipt in dag.precursor_table[item]: for ipt in dag.precursor_table[item]:
in_node.add(ipt) in_node.add(ipt)
for opt in dag.successor_table[item]: for opt in dag.successor_table[item]:
if opt not in node_in_seq: if opt not in node_in_seq:
out_degree += 1
if item not in context.outputs_table:
out = dag.node_collection[item].outputs
else:
out = set(context.outputs_table[item])
out_node.update(out)
in_degree = len(in_node - node_in_seq) in_degree = len(in_node - node_in_seq)
return in_degree, out_degree
# Because the outputs table only records nodes whose outputs are referred
# to by other nodes, the number of outputs of each node must be calculated.
out_degree = len(out_node)
return in_degree, out_degree, in_node - node_in_seq, out_node




def _find_pattern_tail(sequence: List[BaseNode], pattern: Dict[str, str], tail_idx: int, dag: DagGraph): def _find_pattern_tail(sequence: List[BaseNode], pattern: Dict[str, str], tail_idx: int, dag: DagGraph):
@@ -312,14 +326,14 @@ def find_built_in_pattern(topo_order: List[BaseNode], dag: DagGraph) -> Dict[str
if not matched: if not matched:
cur_idx += 1 cur_idx += 1
continue continue
in_degree, out_degree = _get_pattern_degree(init_pattern, dag)
in_degree, out_degree, _, _ = _get_pattern_degree(init_pattern, dag)
if in_degree != BUILT_IN_PATTERN[k].in_degree or out_degree != BUILT_IN_PATTERN[k].out_degree: if in_degree != BUILT_IN_PATTERN[k].in_degree or out_degree != BUILT_IN_PATTERN[k].out_degree:
cur_idx += 1 cur_idx += 1
continue continue
ptn_key = f"{BUILT_IN_PATTERN[k].pattern}" \ ptn_key = f"{BUILT_IN_PATTERN[k].pattern}" \
f"[{BUILT_IN_PATTERN[k].in_degree}, {BUILT_IN_PATTERN[k].out_degree}]" f"[{BUILT_IN_PATTERN[k].in_degree}, {BUILT_IN_PATTERN[k].out_degree}]"
if ptn_key not in pattern: if ptn_key not in pattern:
pattern[ptn_key] = BUILT_IN_PATTERN[k]
pattern[ptn_key] = copy.deepcopy(BUILT_IN_PATTERN[k])


pattern[ptn_key].insert(cur_idx, ptn_len) pattern[ptn_key].insert(cur_idx, ptn_len)
cur_idx = cur_idx + 1 cur_idx = cur_idx + 1
@@ -375,8 +389,8 @@ def generate_pattern(topo_order: List[BaseNode], dag: DagGraph,
offset=cur_idx - jump_step, offset=cur_idx - jump_step,
dag=dag) dag=dag)


in_degree, out_degree = _get_pattern_degree(found_sequence, dag)
if out_degree > MAX_OUT_DEGREE or in_degree > MAX_OUT_DEGREE:
in_degree, out_degree, _, _ = _get_pattern_degree(found_sequence, dag)
if out_degree > MAX_OUT_DEGREE:
cur_idx += 1 cur_idx += 1
continue continue


@@ -391,9 +405,24 @@ def generate_pattern(topo_order: List[BaseNode], dag: DagGraph,
pattern[ptn_key].insert(cur_idx - sub_graph_size + 1, len(found_sequence)) pattern[ptn_key].insert(cur_idx - sub_graph_size + 1, len(found_sequence))
cur_idx = cur_idx + 1 cur_idx = cur_idx + 1


pattern = _post_process_overlap(pattern)
return pattern return pattern




def _post_process_overlap(patterns) -> Dict:
"""Post process overlap of found pattern."""
for name in patterns:
prev_end = patterns[name].end_index[0]
idx = 1
while idx < len(patterns[name].end_index):
if patterns[name].start_index[idx] <= prev_end:
patterns[name].start_index.pop(idx)
patterns[name].end_index.pop(idx)
continue
idx += 1
return patterns


class SearchPath: class SearchPath:
""" """
Use SearchPath to store the search path. Use SearchPath to store the search path.
@@ -510,7 +539,7 @@ class SearchPath:
path_length += 1 path_length += 1
continue continue


in_degree, out_degree = _get_pattern_degree(visited_node, self.graph)
in_degree, out_degree, inputs, outputs = _get_pattern_degree(visited_node, self.graph)
if in_degree != pattern.in_degree or out_degree != pattern.out_degree: if in_degree != pattern.in_degree or out_degree != pattern.out_degree:
topo_order.extend(visited_node) topo_order.extend(visited_node)
index += j + 1 index += j + 1
@@ -520,7 +549,10 @@ class SearchPath:
inverted_index[path_length] = [j + index for j in range(pattern_len)] inverted_index[path_length] = [j + index for j in range(pattern_len)]
new_node = MergedONNXNode(name=random_name(pattern.module_name), new_node = MergedONNXNode(name=random_name(pattern.module_name),
module_name=pattern.module_name, module_name=pattern.module_name,
ori_nodes=visited_node[:])
ori_nodes=visited_node[:],
inputs=inputs,
outputs=outputs,
known_module_name=pattern.known_module_name)
self._reconnect(new_node) self._reconnect(new_node)
self.graph.node_collection[new_node.name] = new_node self.graph.node_collection[new_node.name] = new_node
topo_order.append(new_node) topo_order.append(new_node)
@@ -625,7 +657,7 @@ class SearchPath:
def _actual_val(self): def _actual_val(self):
"""Calculate ground-truth score of the path.""" """Calculate ground-truth score of the path."""
bonus = self._cal_bonus() bonus = self._cal_bonus()
return bonus * 0.1 + self.replacement_ratio * 0.9
return (bonus + self.replacement_ratio) / 2


def __lt__(self, other): def __lt__(self, other):
"""Override `<` operator.""" """Override `<` operator."""


+ 7
- 5
mindinsight/mindconverter/graph_based_converter/sub_graph_searcher/searcher.py View File

@@ -39,9 +39,10 @@ def _is_satisfied(path):
""" """
if len(path.recursion_path) > MAX_ITERATION_DEPTH: if len(path.recursion_path) > MAX_ITERATION_DEPTH:
return True return True
if not path.new_pattern or not any([is_pattern_satisfied(p, path) for p in path.new_pattern.values()]):
candidate_eval = any([is_pattern_satisfied(p, path) for p in path.new_pattern.values()])
if not path.new_pattern or not candidate_eval:
return True return True
if path.evaluate_score() > SATISFIED_SCORE:
if path.evaluate_score() > SATISFIED_SCORE and not candidate_eval:
return True return True
return False return False


@@ -92,6 +93,7 @@ def _search(init_pattern: Dict[str, Pattern], init_topo_order: List[BaseNode],
if _is_satisfied(cur_path): if _is_satisfied(cur_path):
available_path.append(cur_path) available_path.append(cur_path)
deduplicate_path.add(cur_path.hash_of_aft_repl) deduplicate_path.add(cur_path.hash_of_aft_repl)
_is_satisfied(cur_path)
continue continue


if len(available_path) >= ACCEPTABLE_RESULT_COUNT: if len(available_path) >= ACCEPTABLE_RESULT_COUNT:
@@ -154,6 +156,7 @@ def _retrieve_scope_name(found_path):
Args: Args:
found_path: Found path. found_path: Found path.
""" """
_add_known_module_name(found_path)
module_name_mgr = dict() module_name_mgr = dict()


module_dict = dict() module_dict = dict()
@@ -257,6 +260,7 @@ def _build_connection(loader):
for node_name, node in loader.nodes_dict.items(): for node_name, node in loader.nodes_dict.items():
context.precursor_table[node_name] = list(node.get_precursor_dict().keys()) context.precursor_table[node_name] = list(node.get_precursor_dict().keys())
context.successor_table[node_name] = list(node.get_successor_dict().keys()) context.successor_table[node_name] = list(node.get_successor_dict().keys())
context.outputs_table[node_name] = node.output_name_list


dag = DagGraph(nodes=context.node_collection.copy(), dag = DagGraph(nodes=context.node_collection.copy(),
precursor=context.precursor_table.copy(), precursor=context.precursor_table.copy(),
@@ -308,6 +312,7 @@ def _add_known_module_name(search_path):
for it in search_path.recursion_path: for it in search_path.recursion_path:
if it.pattern.known_module_name: if it.pattern.known_module_name:
ctx.known_module_name[it.pattern.module_name] = it.pattern.known_module_name ctx.known_module_name[it.pattern.module_name] = it.pattern.known_module_name
return ctx




@SubGraphSearchingError.check_except("Sub-Graph pattern searching fail.") @SubGraphSearchingError.check_except("Sub-Graph pattern searching fail.")
@@ -329,9 +334,6 @@ def generate_scope_name(data_loader):
if len(topo_order_with_scope_name_list) != len(data_loader.nodes_dict): if len(topo_order_with_scope_name_list) != len(data_loader.nodes_dict):
topo_order_with_scope_name_list = flatten_graph(init_dag) topo_order_with_scope_name_list = flatten_graph(init_dag)


if result:
_add_known_module_name(result)

except (ValueError, IndexError, AttributeError, KeyError) as _: except (ValueError, IndexError, AttributeError, KeyError) as _:
topo_order_with_scope_name_list = flatten_graph(init_dag) topo_order_with_scope_name_list = flatten_graph(init_dag)




Loading…
Cancel
Save