Browse Source

Sub-graph searching and merging module supports transformer architectures.

tags/v1.2.0-rc1
liuchongming 4 years ago
parent
commit
9568fce57e
6 changed files with 117 additions and 26 deletions
  1. +50
    -1
      mindinsight/mindconverter/graph_based_converter/sub_graph_searcher/built_in_pattern.py
  2. +8
    -2
      mindinsight/mindconverter/graph_based_converter/sub_graph_searcher/common.py
  3. +2
    -1
      mindinsight/mindconverter/graph_based_converter/sub_graph_searcher/known_module_name.py
  4. +1
    -0
      mindinsight/mindconverter/graph_based_converter/sub_graph_searcher/pattern.py
  5. +49
    -17
      mindinsight/mindconverter/graph_based_converter/sub_graph_searcher/search_path.py
  6. +7
    -5
      mindinsight/mindconverter/graph_based_converter/sub_graph_searcher/searcher.py

+ 50
- 1
mindinsight/mindconverter/graph_based_converter/sub_graph_searcher/built_in_pattern.py View File

@@ -16,6 +16,8 @@

__all__ = ["BUILT_IN_PATTERN", "register_pattern", "is_built_in_pattern"]

from mindinsight.mindconverter.graph_based_converter.sub_graph_searcher.known_module_name import register_module_name

from mindinsight.mindconverter.graph_based_converter.sub_graph_searcher.common import cal_matching_score
from mindinsight.mindconverter.graph_based_converter.sub_graph_searcher.pattern import Pattern

@@ -59,11 +61,16 @@ def register_pattern(ptn_name, in_degree, out_degree):
def _reg(pattern):
result = pattern()
if not result:
return
return pattern
if ptn_name in BUILT_IN_PATTERN:
raise KeyError(f"{ptn_name} exists, `ptn_name` must be unique.")

BUILT_IN_PATTERN[ptn_name] = Pattern("->".join(result), len(result),
in_degree, out_degree,
ptn_items=result)
BUILT_IN_PATTERN[ptn_name].additional_score = cal_matching_score(BUILT_IN_PATTERN[ptn_name].ptn_length)
BUILT_IN_PATTERN[ptn_name].ptn_name = ptn_name
return pattern

return _reg

@@ -112,3 +119,45 @@ def _up_sampling_in_op10():
return [
"Shape", "Gather", "Cast", "Slice", "Mul", "Slice", "Cast", "Cast", "Div", "Concat", "Resize"
]


@register_pattern("Multi-Head-Attention", 2, 1)
@register_module_name("MultiHeadAttn", 2, 1)
def _multi_head_attention():
    """Return the op-type sequence registered as a multi-head attention pattern."""
    op_sequence = [
        "MatMul", "Add", "MatMul", "Add",
        "Reshape", "MatMul", "Add", "Reshape",
        "Transpose", "Reshape", "Transpose", "Transpose",
        "MatMul", "Div", "Add", "Softmax",
        "MatMul", "Transpose", "Reshape", "MatMul", "Add",
    ]
    return op_sequence


@register_pattern("Layer-Normalization", 1, 1)
@register_module_name("LayerNorm", 1, 1)
def _layer_norm():
    """Return the op-type sequence registered as a layer-normalization pattern."""
    op_sequence = [
        "ReduceMean", "Sub", "Pow", "ReduceMean",
        "Add", "Sqrt", "Div", "Mul", "Add",
    ]
    return op_sequence


@register_pattern("Layer-Normalization-with-cast", 1, 1)
@register_module_name("LayerNorm", 1, 1)
def _layer_norm_with_cast():
    """Return the op-type sequence of a layer-normalization pattern that contains a Cast."""
    op_sequence = [
        "ReduceMean", "Sub", "Cast", "Pow", "ReduceMean",
        "Add", "Sqrt", "Div", "Mul", "Add",
    ]
    return op_sequence


@register_pattern("GeLU", 1, 1)
@register_module_name("GeLU", 1, 1)
def _gelu():
    """Return the op-type sequence registered as a GeLU activation pattern."""
    op_sequence = ["Div", "Erf", "Add", "Mul", "Mul"]
    return op_sequence


@register_pattern("Linear", 1, 1)
@register_module_name("Linear", 1, 1)
def _linear():
    """Return the op-type sequence registered as a linear (dense) layer pattern."""
    op_sequence = ["MatMul", "Add"]
    return op_sequence

+ 8
- 2
mindinsight/mindconverter/graph_based_converter/sub_graph_searcher/common.py View File

@@ -35,7 +35,7 @@ from mindinsight.mindconverter.graph_based_converter.third_party_graph.onnx_util
MAX_OUT_DEGREE = 1
MINI_FREQUENCY = 0.07
MAX_ITERATION_DEPTH = 16
SATISFIED_SCORE = 1.5
SATISFIED_SCORE = 0.74
ACCEPTABLE_RESULT_COUNT = 32
PTN_COVERAGE_THRESHOLD = 0.65
# If pattern length is short than `IGNORE_PTN_LEN`, then do not calculate the coverage.
@@ -126,6 +126,7 @@ class AlgorithmContext:
node_collection = None
precursor_table = {}
successor_table = {}
outputs_table = {}

def set_init_node_collection(self, nd_col):
"""Init node_collection."""
@@ -158,7 +159,12 @@ class AlgorithmContext:
pattern_arr = sorted(pattern_arr.items(), key=functools.cmp_to_key(_cmp),
reverse=True)
if len(pattern_arr) > AlgorithmContext.beam_width:
pattern_arr = pattern_arr[:self.beam_width]
new_pattern_arr = pattern_arr[:self.beam_width]
# Avoid dropping built-in patterns, because built-in patterns are
# more promising than mined ones.
for i in range(self.beam_width):
if pattern_arr[i][1].additional_score != 0:
new_pattern_arr.append(pattern_arr[i])
res = OrderedDict()
for i, (key, ptn) in enumerate(pattern_arr):
if ptn.count <= AlgorithmContext.MIN_FREQUENCY:


+ 2
- 1
mindinsight/mindconverter/graph_based_converter/sub_graph_searcher/known_module_name.py View File

@@ -49,10 +49,11 @@ def register_module_name(md_name: str, in_degree: int, out_degree: int):
def _reg(pattern):
result = pattern()
if not result:
return
return pattern
BUILT_IN_MODULE_NAME[Pattern("->".join(result), len(result),
in_degree, out_degree,
ptn_items=result)] = md_name
return pattern

return _reg



+ 1
- 0
mindinsight/mindconverter/graph_based_converter/sub_graph_searcher/pattern.py View File

@@ -25,6 +25,7 @@ class Pattern:
self.start_index = []
self.end_index = []
self.module_name = None
self.ptn_name = ""
self.ptn_length = pattern_length
self.ptn_items = pattern.split("->") if ptn_items is None else ptn_items
self.in_degree = in_degree


+ 49
- 17
mindinsight/mindconverter/graph_based_converter/sub_graph_searcher/search_path.py View File

@@ -15,14 +15,15 @@
"""Declare search path related."""
import copy
import uuid
from typing import Dict, List, Callable, Union
from collections import OrderedDict
from typing import Dict, List, Callable, Union

from mindinsight.mindconverter.graph_based_converter.sub_graph_searcher.built_in_pattern import BUILT_IN_PATTERN, \
is_built_in_pattern
from mindinsight.mindconverter.graph_based_converter.sub_graph_searcher.common import context, gen_hash_key, DagGraph, \
MAX_OUT_DEGREE, cal_matching_score
from mindinsight.mindconverter.graph_based_converter.sub_graph_searcher.known_module_name import BUILT_IN_MODULE_NAME
from mindinsight.mindconverter.graph_based_converter.sub_graph_searcher.pattern import Pattern, scope_name_mapping
from mindinsight.mindconverter.graph_based_converter.sub_graph_searcher.built_in_pattern import BUILT_IN_PATTERN, \
is_built_in_pattern
from mindinsight.mindconverter.graph_based_converter.sub_graph_searcher.pattern_fuzzy_matching import \
pattern_fuzzy_matching
from mindinsight.mindconverter.graph_based_converter.third_party_graph.onnx_utils import OnnxNode, BaseNode
@@ -80,10 +81,13 @@ def _is_valid_pattern(pattern, dag):
"""
if not pattern:
return False
first_op = dag.node_collection[list(pattern.keys())[0]].op_type
head = dag.node_collection[list(pattern.keys())[0]]
op_type = head.op_type
if isinstance(head, MergedONNXNode) and "LayerNorm" in head.known_module_name:
return False
if len(pattern) == 1:
return False
if first_op in OptimizeRules.CAN_NOT_BE_HEAD:
if op_type in OptimizeRules.CAN_NOT_BE_HEAD:
return False
return True

@@ -133,9 +137,12 @@ def random_name(module_name):
class MergedONNXNode(BaseNode):
"""Define merged onnx node."""

def __init__(self, name, module_name, ori_nodes):
def __init__(self, name, module_name, ori_nodes, inputs, outputs, known_module_name):
super(MergedONNXNode, self).__init__(node_name=name, op_type=module_name)
self.nodes = ori_nodes
self.inputs = inputs
self.outputs = outputs
self.known_module_name = known_module_name if known_module_name else ""

def get_name(self):
return self.name
@@ -192,23 +199,30 @@ def _get_pattern_degree(sequence: Union[OrderedDict, dict, list],
dag (DagGraph): Graph instance.

Returns:
tuple[int, int], in degree and out degree.
tuple[int, int, set, set], in degree, out degree, precursors and successors.
"""
in_node = set()
node_in_seq = set()
items = sequence if isinstance(sequence, list) else sequence.keys()
for item in items:
node_in_seq.add(item.name if not isinstance(item, str) else item)
out_degree = 0
out_node = set()
for item in items:
item = item.name if not isinstance(item, str) else item
for ipt in dag.precursor_table[item]:
in_node.add(ipt)
for opt in dag.successor_table[item]:
if opt not in node_in_seq:
out_degree += 1
if item not in context.outputs_table:
out = dag.node_collection[item].outputs
else:
out = set(context.outputs_table[item])
out_node.update(out)
in_degree = len(in_node - node_in_seq)
return in_degree, out_degree
# Because only nodes whose outputs are referred to by other nodes are recorded,
# the number of outputs of each node must be calculated explicitly.
out_degree = len(out_node)
return in_degree, out_degree, in_node - node_in_seq, out_node


def _find_pattern_tail(sequence: List[BaseNode], pattern: Dict[str, str], tail_idx: int, dag: DagGraph):
@@ -312,14 +326,14 @@ def find_built_in_pattern(topo_order: List[BaseNode], dag: DagGraph) -> Dict[str
if not matched:
cur_idx += 1
continue
in_degree, out_degree = _get_pattern_degree(init_pattern, dag)
in_degree, out_degree, _, _ = _get_pattern_degree(init_pattern, dag)
if in_degree != BUILT_IN_PATTERN[k].in_degree or out_degree != BUILT_IN_PATTERN[k].out_degree:
cur_idx += 1
continue
ptn_key = f"{BUILT_IN_PATTERN[k].pattern}" \
f"[{BUILT_IN_PATTERN[k].in_degree}, {BUILT_IN_PATTERN[k].out_degree}]"
if ptn_key not in pattern:
pattern[ptn_key] = BUILT_IN_PATTERN[k]
pattern[ptn_key] = copy.deepcopy(BUILT_IN_PATTERN[k])

pattern[ptn_key].insert(cur_idx, ptn_len)
cur_idx = cur_idx + 1
@@ -375,8 +389,8 @@ def generate_pattern(topo_order: List[BaseNode], dag: DagGraph,
offset=cur_idx - jump_step,
dag=dag)

in_degree, out_degree = _get_pattern_degree(found_sequence, dag)
if out_degree > MAX_OUT_DEGREE or in_degree > MAX_OUT_DEGREE:
in_degree, out_degree, _, _ = _get_pattern_degree(found_sequence, dag)
if out_degree > MAX_OUT_DEGREE:
cur_idx += 1
continue

@@ -391,9 +405,24 @@ def generate_pattern(topo_order: List[BaseNode], dag: DagGraph,
pattern[ptn_key].insert(cur_idx - sub_graph_size + 1, len(found_sequence))
cur_idx = cur_idx + 1

pattern = _post_process_overlap(pattern)
return pattern


def _post_process_overlap(patterns) -> Dict:
"""Post process overlap of found pattern."""
for name in patterns:
prev_end = patterns[name].end_index[0]
idx = 1
while idx < len(patterns[name].end_index):
if patterns[name].start_index[idx] <= prev_end:
patterns[name].start_index.pop(idx)
patterns[name].end_index.pop(idx)
continue
idx += 1
return patterns


class SearchPath:
"""
Use SearchPath to store the search path.
@@ -510,7 +539,7 @@ class SearchPath:
path_length += 1
continue

in_degree, out_degree = _get_pattern_degree(visited_node, self.graph)
in_degree, out_degree, inputs, outputs = _get_pattern_degree(visited_node, self.graph)
if in_degree != pattern.in_degree or out_degree != pattern.out_degree:
topo_order.extend(visited_node)
index += j + 1
@@ -520,7 +549,10 @@ class SearchPath:
inverted_index[path_length] = [j + index for j in range(pattern_len)]
new_node = MergedONNXNode(name=random_name(pattern.module_name),
module_name=pattern.module_name,
ori_nodes=visited_node[:])
ori_nodes=visited_node[:],
inputs=inputs,
outputs=outputs,
known_module_name=pattern.known_module_name)
self._reconnect(new_node)
self.graph.node_collection[new_node.name] = new_node
topo_order.append(new_node)
@@ -625,7 +657,7 @@ class SearchPath:
def _actual_val(self):
"""Calculate ground-truth score of the path."""
bonus = self._cal_bonus()
return bonus * 0.1 + self.replacement_ratio * 0.9
return (bonus + self.replacement_ratio) / 2

def __lt__(self, other):
"""Override `<` operator."""


+ 7
- 5
mindinsight/mindconverter/graph_based_converter/sub_graph_searcher/searcher.py View File

@@ -39,9 +39,10 @@ def _is_satisfied(path):
"""
if len(path.recursion_path) > MAX_ITERATION_DEPTH:
return True
if not path.new_pattern or not any([is_pattern_satisfied(p, path) for p in path.new_pattern.values()]):
candidate_eval = any([is_pattern_satisfied(p, path) for p in path.new_pattern.values()])
if not path.new_pattern or not candidate_eval:
return True
if path.evaluate_score() > SATISFIED_SCORE:
if path.evaluate_score() > SATISFIED_SCORE and not candidate_eval:
return True
return False

@@ -92,6 +93,7 @@ def _search(init_pattern: Dict[str, Pattern], init_topo_order: List[BaseNode],
if _is_satisfied(cur_path):
available_path.append(cur_path)
deduplicate_path.add(cur_path.hash_of_aft_repl)
_is_satisfied(cur_path)
continue

if len(available_path) >= ACCEPTABLE_RESULT_COUNT:
@@ -154,6 +156,7 @@ def _retrieve_scope_name(found_path):
Args:
found_path: Found path.
"""
_add_known_module_name(found_path)
module_name_mgr = dict()

module_dict = dict()
@@ -257,6 +260,7 @@ def _build_connection(loader):
for node_name, node in loader.nodes_dict.items():
context.precursor_table[node_name] = list(node.get_precursor_dict().keys())
context.successor_table[node_name] = list(node.get_successor_dict().keys())
context.outputs_table[node_name] = node.output_name_list

dag = DagGraph(nodes=context.node_collection.copy(),
precursor=context.precursor_table.copy(),
@@ -308,6 +312,7 @@ def _add_known_module_name(search_path):
for it in search_path.recursion_path:
if it.pattern.known_module_name:
ctx.known_module_name[it.pattern.module_name] = it.pattern.known_module_name
return ctx


@SubGraphSearchingError.check_except("Sub-Graph pattern searching fail.")
@@ -329,9 +334,6 @@ def generate_scope_name(data_loader):
if len(topo_order_with_scope_name_list) != len(data_loader.nodes_dict):
topo_order_with_scope_name_list = flatten_graph(init_dag)

if result:
_add_known_module_name(result)

except (ValueError, IndexError, AttributeError, KeyError) as _:
topo_order_with_scope_name_list = flatten_graph(init_dag)



Loading…
Cancel
Save