| @@ -13,6 +13,6 @@ | |||||
| # limitations under the License. | # limitations under the License. | ||||
| # ============================================================================== | # ============================================================================== | ||||
| """Searcher of scope name.""" | """Searcher of scope name.""" | ||||
| from .searcher import generate_scope_name | |||||
| __all__ = ["generate_scope_name"] | __all__ = ["generate_scope_name"] | ||||
| from .searcher import generate_scope_name | |||||
| @@ -13,11 +13,34 @@ | |||||
| # limitations under the License. | # limitations under the License. | ||||
| # ============================================================================== | # ============================================================================== | ||||
| """Introduce some standard pattern into MindConverter.""" | """Introduce some standard pattern into MindConverter.""" | ||||
| __all__ = ["BUILT_IN_PATTERN", "register_pattern", "is_built_in_pattern"] | |||||
| from .common import cal_matching_score | |||||
| from .pattern import Pattern | from .pattern import Pattern | ||||
| BUILT_IN_PATTERN = dict() | BUILT_IN_PATTERN = dict() | ||||
| def is_built_in_pattern(pattern: Pattern): | |||||
| """ | |||||
| Whether the module name was built-in. | |||||
| Args: | |||||
| pattern (Pattern): Found pattern. | |||||
| Returns: | |||||
| bool, true or false. | |||||
| """ | |||||
| for ptn in BUILT_IN_PATTERN: | |||||
| if BUILT_IN_PATTERN[ptn].ptn_length == pattern.ptn_length and \ | |||||
| BUILT_IN_PATTERN[ptn].in_degree == pattern.in_degree and \ | |||||
| BUILT_IN_PATTERN[ptn].out_degree == pattern.out_degree and \ | |||||
| BUILT_IN_PATTERN[ptn].ptn_items == pattern.ptn_items: | |||||
| return True | |||||
| return False | |||||
| def register_pattern(ptn_name, in_degree, out_degree): | def register_pattern(ptn_name, in_degree, out_degree): | ||||
| """ | """ | ||||
| Register pattern to MindConverter. | Register pattern to MindConverter. | ||||
| @@ -40,6 +63,7 @@ def register_pattern(ptn_name, in_degree, out_degree): | |||||
| BUILT_IN_PATTERN[ptn_name] = Pattern("->".join(result), len(result), | BUILT_IN_PATTERN[ptn_name] = Pattern("->".join(result), len(result), | ||||
| in_degree, out_degree, | in_degree, out_degree, | ||||
| ptn_items=result) | ptn_items=result) | ||||
| BUILT_IN_PATTERN[ptn_name].additional_score = cal_matching_score(BUILT_IN_PATTERN[ptn_name].ptn_length) | |||||
| return _reg | return _reg | ||||
| @@ -76,4 +100,15 @@ def _convbnrelux3_convbn_add_relu(): | |||||
| "Conv", "BatchNormalization", "Add", "Relu"] | "Conv", "BatchNormalization", "Add", "Relu"] | ||||
| __all__ = ["BUILT_IN_PATTERN", "register_pattern"] | |||||
| @register_pattern("UnSampling-op12", 1, 1) | |||||
| def _up_sampling_in_op12(): | |||||
| return [ | |||||
| "Shape", "Slice", "Gather", "Cast", "Slice", "Mul", "Cast", "Concat", "Resize" | |||||
| ] | |||||
| @register_pattern("UpSampling-op10", 1, 1) | |||||
| def _up_sampling_in_op10(): | |||||
| return [ | |||||
| "Shape", "Gather", "Cast", "Slice", "Mul", "Slice", "Cast", "Cast", "Div", "Concat", "Resize" | |||||
| ] | |||||
| @@ -13,6 +13,18 @@ | |||||
| # limitations under the License. | # limitations under the License. | ||||
| # ============================================================================== | # ============================================================================== | ||||
| """Declare generic variable and functions.""" | """Declare generic variable and functions.""" | ||||
| __all__ = ["context", | |||||
| "gen_hash_key", | |||||
| "DagGraph", | |||||
| "MAX_OUT_DEGREE", | |||||
| "cal_matching_score", | |||||
| "ACCEPTABLE_RESULT_COUNT", | |||||
| "MINI_FREQUENCY", | |||||
| "SATISFIED_SCORE", | |||||
| "MAX_ITERATION_DEPTH"] | |||||
| import math | |||||
| import copy | import copy | ||||
| import functools | import functools | ||||
| from collections import OrderedDict | from collections import OrderedDict | ||||
| @@ -23,8 +35,21 @@ from mindinsight.mindconverter.graph_based_converter.third_party_graph.onnx_util | |||||
| MAX_OUT_DEGREE = 1 | MAX_OUT_DEGREE = 1 | ||||
| MINI_FREQUENCY = 4 | MINI_FREQUENCY = 4 | ||||
| MAX_ITERATION_DEPTH = 4 | MAX_ITERATION_DEPTH = 4 | ||||
| SATISFIED_SCORE = 0.55 | |||||
| SATISFIED_SCORE = 0.6 | |||||
| ACCEPTABLE_RESULT_COUNT = 16 | ACCEPTABLE_RESULT_COUNT = 16 | ||||
| PTN_COVERAGE_THRESHOLD = 0.65 | |||||
| # If pattern length is short than `IGNORE_PTN_LEN`, then do not calculate the coverage. | |||||
| IGNORE_PTN_LEN = 5 | |||||
| def cal_matching_score(sequence_len: int): | |||||
| """ | |||||
| Calculate matching score. | |||||
| Args: | |||||
| sequence_len (int): Pattern length. | |||||
| """ | |||||
| return 2 / (1 + math.pow(math.e, -0.1 * sequence_len)) - 1 | |||||
| class CmpRelation: | class CmpRelation: | ||||
| @@ -79,8 +104,8 @@ class AlgorithmContext: | |||||
| found_pattern = {} | found_pattern = {} | ||||
| visited = set() | visited = set() | ||||
| beam_width = 5 | beam_width = 5 | ||||
| total_len = 0 | |||||
| MIN_FREQUENCY = 1 | MIN_FREQUENCY = 1 | ||||
| total_len = 0 | |||||
| node_collection = None | node_collection = None | ||||
| precursor_table = {} | precursor_table = {} | ||||
| successor_table = {} | successor_table = {} | ||||
| @@ -120,6 +145,10 @@ class AlgorithmContext: | |||||
| return CmpRelation.GREATER | return CmpRelation.GREATER | ||||
| if x[1].count < y[1].count: | if x[1].count < y[1].count: | ||||
| return CmpRelation.LESS | return CmpRelation.LESS | ||||
| if x[1].additional_score > y[1].additional_score: | |||||
| return CmpRelation.GREATER | |||||
| if x[1].additional_score < y[1].additional_score: | |||||
| return CmpRelation.LESS | |||||
| if x[1].ptn_length > y[1].ptn_length: | if x[1].ptn_length > y[1].ptn_length: | ||||
| return CmpRelation.GREATER | return CmpRelation.GREATER | ||||
| if x[1].ptn_length < y[1].ptn_length: | if x[1].ptn_length < y[1].ptn_length: | ||||
| @@ -136,11 +165,19 @@ class AlgorithmContext: | |||||
| continue | continue | ||||
| skip = False | skip = False | ||||
| for j, (_, candidate) in enumerate(pattern_arr): | for j, (_, candidate) in enumerate(pattern_arr): | ||||
| if i == j: | |||||
| if i == j or (ptn.additional_score > 0 and ptn.ptn_length > IGNORE_PTN_LEN): | |||||
| continue | continue | ||||
| if candidate.ptn_length >= ptn.ptn_length and ptn.ptn_items == candidate.ptn_items[:ptn.ptn_length]: | |||||
| if candidate.ptn_length >= ptn.ptn_length and ptn.count == candidate.count \ | |||||
| and ptn.pattern in candidate.pattern: | |||||
| skip = True | skip = True | ||||
| break | break | ||||
| if candidate.ptn_length < ptn.ptn_length and candidate.additional_score != 0 \ | |||||
| and ptn.additional_score == 0 and candidate.pattern in ptn.pattern: | |||||
| ratio = candidate.ptn_length / ptn.ptn_length | |||||
| if ratio >= PTN_COVERAGE_THRESHOLD: | |||||
| skip = True | |||||
| break | |||||
| if skip: | if skip: | ||||
| continue | continue | ||||
| res[key] = ptn | res[key] = ptn | ||||
| @@ -148,12 +185,3 @@ class AlgorithmContext: | |||||
| context = AlgorithmContext() | context = AlgorithmContext() | ||||
| __all__ = ["context", | |||||
| "gen_hash_key", | |||||
| "DagGraph", | |||||
| "MAX_OUT_DEGREE", | |||||
| "MAX_ITERATION_DEPTH", | |||||
| "SATISFIED_SCORE", | |||||
| "MINI_FREQUENCY", | |||||
| "ACCEPTABLE_RESULT_COUNT"] | |||||
| @@ -13,12 +13,28 @@ | |||||
| # limitations under the License. | # limitations under the License. | ||||
| # ============================================================================== | # ============================================================================== | ||||
| """Introduce some standard pattern name into MindConverter.""" | """Introduce some standard pattern name into MindConverter.""" | ||||
| __all__ = ["register_module_name", "is_built_in_module_name", "BUILT_IN_MODULE_NAME"] | |||||
| from mindinsight.mindconverter.graph_based_converter.sub_graph_searcher.pattern import Pattern | from mindinsight.mindconverter.graph_based_converter.sub_graph_searcher.pattern import Pattern | ||||
| PLACEHOLDER = "PLC" | PLACEHOLDER = "PLC" | ||||
| BUILT_IN_MODULE_NAME = dict() | BUILT_IN_MODULE_NAME = dict() | ||||
| def is_built_in_module_name(module_name: str): | |||||
| """ | |||||
| Whether the module name was built-in. | |||||
| Args: | |||||
| module_name (str): Module name. | |||||
| Returns: | |||||
| bool, true or false. | |||||
| """ | |||||
| return module_name.split("_")[0] in BUILT_IN_MODULE_NAME | |||||
| def register_module_name(md_name: str, in_degree: int, out_degree: int): | def register_module_name(md_name: str, in_degree: int, out_degree: int): | ||||
| """ | """ | ||||
| Register pattern to MindConverter. | Register pattern to MindConverter. | ||||
| @@ -71,3 +87,17 @@ def _basic_conv_block_0(): | |||||
| def _conv_bn(): | def _conv_bn(): | ||||
| """Add basic conv block.""" | """Add basic conv block.""" | ||||
| return ["Conv", "BatchNormalization"] | return ["Conv", "BatchNormalization"] | ||||
| @register_module_name("UnSample", 1, 1) | |||||
| def _up_sampling_in_op12(): | |||||
| return [ | |||||
| "Shape", "Slice", "Gather", "Cast", "Slice", "Mul", "Cast", "Concat", "Resize" | |||||
| ] | |||||
| @register_module_name("UnSample", 1, 1) | |||||
| def _up_sampling_in_op10(): | |||||
| return [ | |||||
| "Shape", "Gather", "Cast", "Slice", "Mul", "Slice", "Cast", "Cast", "Div", "Concat", "Resize" | |||||
| ] | |||||
| @@ -31,6 +31,9 @@ class Pattern: | |||||
| self.out_degree = out_degree | self.out_degree = out_degree | ||||
| self.head = self.ptn_items[0] | self.head = self.ptn_items[0] | ||||
| self.tail = self.ptn_items[-1] | self.tail = self.ptn_items[-1] | ||||
| # If pattern in BUILD_IN_MODULE_NAME or BUILD_IN_PATTERN, | |||||
| # the pattern will get additional score. | |||||
| self.additional_score = 0 | |||||
| def insert(self, idx, seq_len): | def insert(self, idx, seq_len): | ||||
| """ | """ | ||||
| @@ -17,10 +17,10 @@ import copy | |||||
| import uuid | import uuid | ||||
| from typing import Dict, List, Callable, Union | from typing import Dict, List, Callable, Union | ||||
| from collections import OrderedDict | from collections import OrderedDict | ||||
| from .common import context, gen_hash_key, DagGraph, MAX_OUT_DEGREE | |||||
| from .known_module_name import BUILT_IN_MODULE_NAME | |||||
| from .common import context, gen_hash_key, DagGraph, MAX_OUT_DEGREE, cal_matching_score | |||||
| from .known_module_name import BUILT_IN_MODULE_NAME, is_built_in_module_name | |||||
| from .pattern import Pattern, scope_name_mapping | from .pattern import Pattern, scope_name_mapping | ||||
| from .built_in_pattern import BUILT_IN_PATTERN | |||||
| from .built_in_pattern import BUILT_IN_PATTERN, is_built_in_pattern | |||||
| from .pattern_fuzzy_matching import pattern_fuzzy_matching | from .pattern_fuzzy_matching import pattern_fuzzy_matching | ||||
| from ..third_party_graph.onnx_utils import OnnxNode, BaseNode | from ..third_party_graph.onnx_utils import OnnxNode, BaseNode | ||||
| @@ -376,6 +376,8 @@ def generate_pattern(topo_order: List[BaseNode], dag: DagGraph, | |||||
| if ptn_key not in pattern: | if ptn_key not in pattern: | ||||
| pattern[ptn_key] = Pattern(ptn, len(found_sequence), | pattern[ptn_key] = Pattern(ptn, len(found_sequence), | ||||
| in_degree=in_degree, out_degree=out_degree) | in_degree=in_degree, out_degree=out_degree) | ||||
| if is_built_in_pattern(pattern[ptn_key]): | |||||
| pattern[ptn_key].additional_score = cal_matching_score(pattern[ptn_key].ptn_length) | |||||
| pattern[ptn_key].insert(cur_idx - sub_graph_size + 1, len(found_sequence)) | pattern[ptn_key].insert(cur_idx - sub_graph_size + 1, len(found_sequence)) | ||||
| cur_idx = cur_idx + 1 | cur_idx = cur_idx + 1 | ||||
| @@ -425,6 +427,7 @@ class SearchPath: | |||||
| } | } | ||||
| self._created_modules[self.pattern.module_name] = self.pattern | self._created_modules[self.pattern.module_name] = self.pattern | ||||
| self.heuristic_v = self._heuristic_val() | self.heuristic_v = self._heuristic_val() | ||||
| self.replacement_ratio = self._repl_ratio() | |||||
| self.actual_v = self._actual_val() | self.actual_v = self._actual_val() | ||||
| def _create_new_order(self): | def _create_new_order(self): | ||||
| @@ -442,6 +445,8 @@ class SearchPath: | |||||
| else: | else: | ||||
| module_name = scope_name_mapping[self.pattern.pattern] | module_name = scope_name_mapping[self.pattern.pattern] | ||||
| self.pattern.module_name = module_name | self.pattern.module_name = module_name | ||||
| if is_built_in_module_name(module_name): | |||||
| self.pattern.additional_score += cal_matching_score(self.pattern.ptn_length) | |||||
| topo_order, inverted_index = self.replace_sub_graph_completely(self.pattern, self.topo_order_bef_repl) | topo_order, inverted_index = self.replace_sub_graph_completely(self.pattern, self.topo_order_bef_repl) | ||||
| return topo_order, inverted_index | return topo_order, inverted_index | ||||
| @@ -572,7 +577,12 @@ class SearchPath: | |||||
| self.graph.precursor_table[s_nd] = p_nodes | self.graph.precursor_table[s_nd] = p_nodes | ||||
| def evaluate_score(self): | def evaluate_score(self): | ||||
| """Evaluate path score.""" | |||||
| """ | |||||
| Evaluate path score. | |||||
| Expression = 0.7 * (0.1 * bonus + 0.9 * repl_ratio) + 0.3 * H | |||||
| = 0.07 * bonus + 0.63 * repl_ratio + 0.3 * H | |||||
| """ | |||||
| return .7 * self.actual_v + .3 * self.heuristic_v | return .7 * self.actual_v + .3 * self.heuristic_v | ||||
| def _cal_merged_module_length(self, ptn): | def _cal_merged_module_length(self, ptn): | ||||
| @@ -594,9 +604,21 @@ class SearchPath: | |||||
| return 1.0 | return 1.0 | ||||
| return max(res) | return max(res) | ||||
| def _cal_bonus(self): | |||||
| """Calculate total pattern length.""" | |||||
| score = self.pattern.additional_score | |||||
| for search_path in self.recursion_path: | |||||
| score += search_path.pattern.additional_score | |||||
| return score | |||||
| def _repl_ratio(self): | |||||
| """Calculate replacement ratio of current path.""" | |||||
| return (context.get_sequence_length() - len(self.topo_order_aft_repl)) / context.get_sequence_length() | |||||
| def _actual_val(self): | def _actual_val(self): | ||||
| """Calculate ground-truth score of the path.""" | """Calculate ground-truth score of the path.""" | ||||
| return (context.get_sequence_length() - len(self.topo_order_aft_repl)) / context.get_sequence_length() | |||||
| bonus = self._cal_bonus() | |||||
| return bonus * 0.1 + self.replacement_ratio * 0.9 | |||||
| def __lt__(self, other): | def __lt__(self, other): | ||||
| """Override `<` operator.""" | """Override `<` operator.""" | ||||