| @@ -13,6 +13,6 @@ | |||
| # limitations under the License. | |||
| # ============================================================================== | |||
| """Searcher of scope name.""" | |||
| from .searcher import generate_scope_name | |||
| __all__ = ["generate_scope_name"] | |||
| from .searcher import generate_scope_name | |||
| @@ -13,11 +13,34 @@ | |||
| # limitations under the License. | |||
| # ============================================================================== | |||
| """Introduce some standard pattern into MindConverter.""" | |||
| __all__ = ["BUILT_IN_PATTERN", "register_pattern", "is_built_in_pattern"] | |||
| from .common import cal_matching_score | |||
| from .pattern import Pattern | |||
| BUILT_IN_PATTERN = dict() | |||
| def is_built_in_pattern(pattern: Pattern): | |||
| """ | |||
| Whether the module name was built-in. | |||
| Args: | |||
| pattern (Pattern): Found pattern. | |||
| Returns: | |||
| bool, true or false. | |||
| """ | |||
| for ptn in BUILT_IN_PATTERN: | |||
| if BUILT_IN_PATTERN[ptn].ptn_length == pattern.ptn_length and \ | |||
| BUILT_IN_PATTERN[ptn].in_degree == pattern.in_degree and \ | |||
| BUILT_IN_PATTERN[ptn].out_degree == pattern.out_degree and \ | |||
| BUILT_IN_PATTERN[ptn].ptn_items == pattern.ptn_items: | |||
| return True | |||
| return False | |||
| def register_pattern(ptn_name, in_degree, out_degree): | |||
| """ | |||
| Register pattern to MindConverter. | |||
| @@ -40,6 +63,7 @@ def register_pattern(ptn_name, in_degree, out_degree): | |||
| BUILT_IN_PATTERN[ptn_name] = Pattern("->".join(result), len(result), | |||
| in_degree, out_degree, | |||
| ptn_items=result) | |||
| BUILT_IN_PATTERN[ptn_name].additional_score = cal_matching_score(BUILT_IN_PATTERN[ptn_name].ptn_length) | |||
| return _reg | |||
| @@ -76,4 +100,15 @@ def _convbnrelux3_convbn_add_relu(): | |||
| "Conv", "BatchNormalization", "Add", "Relu"] | |||
| __all__ = ["BUILT_IN_PATTERN", "register_pattern"] | |||
| @register_pattern("UnSampling-op12", 1, 1) | |||
| def _up_sampling_in_op12(): | |||
| return [ | |||
| "Shape", "Slice", "Gather", "Cast", "Slice", "Mul", "Cast", "Concat", "Resize" | |||
| ] | |||
| @register_pattern("UpSampling-op10", 1, 1) | |||
| def _up_sampling_in_op10(): | |||
| return [ | |||
| "Shape", "Gather", "Cast", "Slice", "Mul", "Slice", "Cast", "Cast", "Div", "Concat", "Resize" | |||
| ] | |||
| @@ -13,6 +13,18 @@ | |||
| # limitations under the License. | |||
| # ============================================================================== | |||
| """Declare generic variable and functions.""" | |||
| __all__ = ["context", | |||
| "gen_hash_key", | |||
| "DagGraph", | |||
| "MAX_OUT_DEGREE", | |||
| "cal_matching_score", | |||
| "ACCEPTABLE_RESULT_COUNT", | |||
| "MINI_FREQUENCY", | |||
| "SATISFIED_SCORE", | |||
| "MAX_ITERATION_DEPTH"] | |||
| import math | |||
| import copy | |||
| import functools | |||
| from collections import OrderedDict | |||
| @@ -23,8 +35,21 @@ from mindinsight.mindconverter.graph_based_converter.third_party_graph.onnx_util | |||
| MAX_OUT_DEGREE = 1 | |||
| MINI_FREQUENCY = 4 | |||
| MAX_ITERATION_DEPTH = 4 | |||
| SATISFIED_SCORE = 0.55 | |||
| SATISFIED_SCORE = 0.6 | |||
| ACCEPTABLE_RESULT_COUNT = 16 | |||
| PTN_COVERAGE_THRESHOLD = 0.65 | |||
| # If pattern length is short than `IGNORE_PTN_LEN`, then do not calculate the coverage. | |||
| IGNORE_PTN_LEN = 5 | |||
| def cal_matching_score(sequence_len: int): | |||
| """ | |||
| Calculate matching score. | |||
| Args: | |||
| sequence_len (int): Pattern length. | |||
| """ | |||
| return 2 / (1 + math.pow(math.e, -0.1 * sequence_len)) - 1 | |||
| class CmpRelation: | |||
| @@ -79,8 +104,8 @@ class AlgorithmContext: | |||
| found_pattern = {} | |||
| visited = set() | |||
| beam_width = 5 | |||
| total_len = 0 | |||
| MIN_FREQUENCY = 1 | |||
| total_len = 0 | |||
| node_collection = None | |||
| precursor_table = {} | |||
| successor_table = {} | |||
| @@ -120,6 +145,10 @@ class AlgorithmContext: | |||
| return CmpRelation.GREATER | |||
| if x[1].count < y[1].count: | |||
| return CmpRelation.LESS | |||
| if x[1].additional_score > y[1].additional_score: | |||
| return CmpRelation.GREATER | |||
| if x[1].additional_score < y[1].additional_score: | |||
| return CmpRelation.LESS | |||
| if x[1].ptn_length > y[1].ptn_length: | |||
| return CmpRelation.GREATER | |||
| if x[1].ptn_length < y[1].ptn_length: | |||
| @@ -136,11 +165,19 @@ class AlgorithmContext: | |||
| continue | |||
| skip = False | |||
| for j, (_, candidate) in enumerate(pattern_arr): | |||
| if i == j: | |||
| if i == j or (ptn.additional_score > 0 and ptn.ptn_length > IGNORE_PTN_LEN): | |||
| continue | |||
| if candidate.ptn_length >= ptn.ptn_length and ptn.ptn_items == candidate.ptn_items[:ptn.ptn_length]: | |||
| if candidate.ptn_length >= ptn.ptn_length and ptn.count == candidate.count \ | |||
| and ptn.pattern in candidate.pattern: | |||
| skip = True | |||
| break | |||
| if candidate.ptn_length < ptn.ptn_length and candidate.additional_score != 0 \ | |||
| and ptn.additional_score == 0 and candidate.pattern in ptn.pattern: | |||
| ratio = candidate.ptn_length / ptn.ptn_length | |||
| if ratio >= PTN_COVERAGE_THRESHOLD: | |||
| skip = True | |||
| break | |||
| if skip: | |||
| continue | |||
| res[key] = ptn | |||
| @@ -148,12 +185,3 @@ class AlgorithmContext: | |||
| context = AlgorithmContext() | |||
| __all__ = ["context", | |||
| "gen_hash_key", | |||
| "DagGraph", | |||
| "MAX_OUT_DEGREE", | |||
| "MAX_ITERATION_DEPTH", | |||
| "SATISFIED_SCORE", | |||
| "MINI_FREQUENCY", | |||
| "ACCEPTABLE_RESULT_COUNT"] | |||
| @@ -13,12 +13,28 @@ | |||
| # limitations under the License. | |||
| # ============================================================================== | |||
| """Introduce some standard pattern name into MindConverter.""" | |||
| __all__ = ["register_module_name", "is_built_in_module_name", "BUILT_IN_MODULE_NAME"] | |||
| from mindinsight.mindconverter.graph_based_converter.sub_graph_searcher.pattern import Pattern | |||
| PLACEHOLDER = "PLC" | |||
| BUILT_IN_MODULE_NAME = dict() | |||
| def is_built_in_module_name(module_name: str): | |||
| """ | |||
| Whether the module name was built-in. | |||
| Args: | |||
| module_name (str): Module name. | |||
| Returns: | |||
| bool, true or false. | |||
| """ | |||
| return module_name.split("_")[0] in BUILT_IN_MODULE_NAME | |||
| def register_module_name(md_name: str, in_degree: int, out_degree: int): | |||
| """ | |||
| Register pattern to MindConverter. | |||
| @@ -71,3 +87,17 @@ def _basic_conv_block_0(): | |||
| def _conv_bn(): | |||
| """Add basic conv block.""" | |||
| return ["Conv", "BatchNormalization"] | |||
| @register_module_name("UnSample", 1, 1) | |||
| def _up_sampling_in_op12(): | |||
| return [ | |||
| "Shape", "Slice", "Gather", "Cast", "Slice", "Mul", "Cast", "Concat", "Resize" | |||
| ] | |||
| @register_module_name("UnSample", 1, 1) | |||
| def _up_sampling_in_op10(): | |||
| return [ | |||
| "Shape", "Gather", "Cast", "Slice", "Mul", "Slice", "Cast", "Cast", "Div", "Concat", "Resize" | |||
| ] | |||
| @@ -31,6 +31,9 @@ class Pattern: | |||
| self.out_degree = out_degree | |||
| self.head = self.ptn_items[0] | |||
| self.tail = self.ptn_items[-1] | |||
| # If pattern in BUILD_IN_MODULE_NAME or BUILD_IN_PATTERN, | |||
| # the pattern will get additional score. | |||
| self.additional_score = 0 | |||
| def insert(self, idx, seq_len): | |||
| """ | |||
| @@ -17,10 +17,10 @@ import copy | |||
| import uuid | |||
| from typing import Dict, List, Callable, Union | |||
| from collections import OrderedDict | |||
| from .common import context, gen_hash_key, DagGraph, MAX_OUT_DEGREE | |||
| from .known_module_name import BUILT_IN_MODULE_NAME | |||
| from .common import context, gen_hash_key, DagGraph, MAX_OUT_DEGREE, cal_matching_score | |||
| from .known_module_name import BUILT_IN_MODULE_NAME, is_built_in_module_name | |||
| from .pattern import Pattern, scope_name_mapping | |||
| from .built_in_pattern import BUILT_IN_PATTERN | |||
| from .built_in_pattern import BUILT_IN_PATTERN, is_built_in_pattern | |||
| from .pattern_fuzzy_matching import pattern_fuzzy_matching | |||
| from ..third_party_graph.onnx_utils import OnnxNode, BaseNode | |||
| @@ -376,6 +376,8 @@ def generate_pattern(topo_order: List[BaseNode], dag: DagGraph, | |||
| if ptn_key not in pattern: | |||
| pattern[ptn_key] = Pattern(ptn, len(found_sequence), | |||
| in_degree=in_degree, out_degree=out_degree) | |||
| if is_built_in_pattern(pattern[ptn_key]): | |||
| pattern[ptn_key].additional_score = cal_matching_score(pattern[ptn_key].ptn_length) | |||
| pattern[ptn_key].insert(cur_idx - sub_graph_size + 1, len(found_sequence)) | |||
| cur_idx = cur_idx + 1 | |||
| @@ -425,6 +427,7 @@ class SearchPath: | |||
| } | |||
| self._created_modules[self.pattern.module_name] = self.pattern | |||
| self.heuristic_v = self._heuristic_val() | |||
| self.replacement_ratio = self._repl_ratio() | |||
| self.actual_v = self._actual_val() | |||
| def _create_new_order(self): | |||
| @@ -442,6 +445,8 @@ class SearchPath: | |||
| else: | |||
| module_name = scope_name_mapping[self.pattern.pattern] | |||
| self.pattern.module_name = module_name | |||
| if is_built_in_module_name(module_name): | |||
| self.pattern.additional_score += cal_matching_score(self.pattern.ptn_length) | |||
| topo_order, inverted_index = self.replace_sub_graph_completely(self.pattern, self.topo_order_bef_repl) | |||
| return topo_order, inverted_index | |||
| @@ -572,7 +577,12 @@ class SearchPath: | |||
| self.graph.precursor_table[s_nd] = p_nodes | |||
| def evaluate_score(self): | |||
| """Evaluate path score.""" | |||
| """ | |||
| Evaluate path score. | |||
| Expression = 0.7 * (0.1 * bonus + 0.9 * repl_ratio) + 0.3 * H | |||
| = 0.07 * bonus + 0.63 * repl_ratio + 0.3 * H | |||
| """ | |||
| return .7 * self.actual_v + .3 * self.heuristic_v | |||
| def _cal_merged_module_length(self, ptn): | |||
| @@ -594,9 +604,21 @@ class SearchPath: | |||
| return 1.0 | |||
| return max(res) | |||
| def _cal_bonus(self): | |||
| """Calculate total pattern length.""" | |||
| score = self.pattern.additional_score | |||
| for search_path in self.recursion_path: | |||
| score += search_path.pattern.additional_score | |||
| return score | |||
| def _repl_ratio(self): | |||
| """Calculate replacement ratio of current path.""" | |||
| return (context.get_sequence_length() - len(self.topo_order_aft_repl)) / context.get_sequence_length() | |||
| def _actual_val(self): | |||
| """Calculate ground-truth score of the path.""" | |||
| return (context.get_sequence_length() - len(self.topo_order_aft_repl)) / context.get_sequence_length() | |||
| bonus = self._cal_bonus() | |||
| return bonus * 0.1 + self.replacement_ratio * 0.9 | |||
| def __lt__(self, other): | |||
| """Override `<` operator.""" | |||