From 94d6808552b609ae110fcc4594c41565df1f0313 Mon Sep 17 00:00:00 2001
From: liuchongming
Date: Sun, 29 Nov 2020 15:46:19 +0800
Subject: [PATCH] Update heuristic evaluation criteria to improve subgraph
 recognition.

---
 .../sub_graph_searcher/__init__.py          |  4 +-
 .../sub_graph_searcher/built_in_pattern.py  | 37 ++++++++++++-
 .../sub_graph_searcher/common.py            | 54 ++++++++++++++-----
 .../sub_graph_searcher/known_module_name.py | 30 +++++++++++
 .../sub_graph_searcher/pattern.py           |  3 ++
 .../sub_graph_searcher/search_path.py       | 32 +++++++++--
 6 files changed, 139 insertions(+), 21 deletions(-)

diff --git a/mindinsight/mindconverter/graph_based_converter/sub_graph_searcher/__init__.py b/mindinsight/mindconverter/graph_based_converter/sub_graph_searcher/__init__.py
index ff88c266..c4893f16 100644
--- a/mindinsight/mindconverter/graph_based_converter/sub_graph_searcher/__init__.py
+++ b/mindinsight/mindconverter/graph_based_converter/sub_graph_searcher/__init__.py
@@ -13,6 +13,6 @@
 # limitations under the License.
 # ==============================================================================
 """Searcher of scope name."""
-from .searcher import generate_scope_name
-
 __all__ = ["generate_scope_name"]
+
+from .searcher import generate_scope_name
diff --git a/mindinsight/mindconverter/graph_based_converter/sub_graph_searcher/built_in_pattern.py b/mindinsight/mindconverter/graph_based_converter/sub_graph_searcher/built_in_pattern.py
index 0cec997a..7049b74d 100644
--- a/mindinsight/mindconverter/graph_based_converter/sub_graph_searcher/built_in_pattern.py
+++ b/mindinsight/mindconverter/graph_based_converter/sub_graph_searcher/built_in_pattern.py
@@ -13,11 +13,34 @@
 # limitations under the License.
 # ==============================================================================
 """Introduce some standard pattern into MindConverter."""
+
+__all__ = ["BUILT_IN_PATTERN", "register_pattern", "is_built_in_pattern"]
+
+from .common import cal_matching_score
 from .pattern import Pattern
 
 BUILT_IN_PATTERN = dict()
 
 
+def is_built_in_pattern(pattern: Pattern):
+    """
+    Whether the pattern is a built-in pattern.
+
+    Args:
+        pattern (Pattern): Found pattern.
+
+    Returns:
+        bool, True if the pattern is built-in, otherwise False.
+    """
+    for ptn in BUILT_IN_PATTERN:
+        if BUILT_IN_PATTERN[ptn].ptn_length == pattern.ptn_length and \
+                BUILT_IN_PATTERN[ptn].in_degree == pattern.in_degree and \
+                BUILT_IN_PATTERN[ptn].out_degree == pattern.out_degree and \
+                BUILT_IN_PATTERN[ptn].ptn_items == pattern.ptn_items:
+            return True
+    return False
+
+
 def register_pattern(ptn_name, in_degree, out_degree):
     """
     Register pattern to MindConverter.
@@ -40,6 +63,7 @@ def register_pattern(ptn_name, in_degree, out_degree):
 
         BUILT_IN_PATTERN[ptn_name] = Pattern("->".join(result), len(result), in_degree, out_degree,
                                              ptn_items=result)
+        BUILT_IN_PATTERN[ptn_name].additional_score = cal_matching_score(BUILT_IN_PATTERN[ptn_name].ptn_length)
 
     return _reg
@@ -76,4 +100,15 @@ def _convbnrelux3_convbn_add_relu():
             "Conv", "BatchNormalization", "Add", "Relu"]
 
 
-__all__ = ["BUILT_IN_PATTERN", "register_pattern"]
+@register_pattern("UpSampling-op12", 1, 1)
+def _up_sampling_in_op12():
+    return [
+        "Shape", "Slice", "Gather", "Cast", "Slice", "Mul", "Cast", "Concat", "Resize"
+    ]
+
+
+@register_pattern("UpSampling-op10", 1, 1)
+def _up_sampling_in_op10():
+    return [
+        "Shape", "Gather", "Cast", "Slice", "Mul", "Slice", "Cast", "Cast", "Div", "Concat", "Resize"
+    ]
diff --git a/mindinsight/mindconverter/graph_based_converter/sub_graph_searcher/common.py b/mindinsight/mindconverter/graph_based_converter/sub_graph_searcher/common.py
index acb080a4..f66b2bc2 100644
--- a/mindinsight/mindconverter/graph_based_converter/sub_graph_searcher/common.py
+++ b/mindinsight/mindconverter/graph_based_converter/sub_graph_searcher/common.py
@@ -13,6 +13,18 @@
 # limitations under the License.
 # ==============================================================================
 """Declare generic variable and functions."""
+
+__all__ = ["context",
+           "gen_hash_key",
+           "DagGraph",
+           "MAX_OUT_DEGREE",
+           "cal_matching_score",
+           "ACCEPTABLE_RESULT_COUNT",
+           "MINI_FREQUENCY",
+           "SATISFIED_SCORE",
+           "MAX_ITERATION_DEPTH"]
+
+import math
 import copy
 import functools
 from collections import OrderedDict
@@ -23,8 +35,21 @@ from mindinsight.mindconverter.graph_based_converter.third_party_graph.onnx_util
 MAX_OUT_DEGREE = 1
 MINI_FREQUENCY = 4
 MAX_ITERATION_DEPTH = 4
-SATISFIED_SCORE = 0.55
+SATISFIED_SCORE = 0.6
 ACCEPTABLE_RESULT_COUNT = 16
+PTN_COVERAGE_THRESHOLD = 0.65
+# If the pattern length is shorter than `IGNORE_PTN_LEN`, do not calculate the coverage.
+IGNORE_PTN_LEN = 5
+
+
+def cal_matching_score(sequence_len: int):
+    """
+    Calculate matching score.
+
+    Args:
+        sequence_len (int): Pattern length.
+ """ + return 2 / (1 + math.pow(math.e, -0.1 * sequence_len)) - 1 class CmpRelation: @@ -79,8 +104,8 @@ class AlgorithmContext: found_pattern = {} visited = set() beam_width = 5 - total_len = 0 MIN_FREQUENCY = 1 + total_len = 0 node_collection = None precursor_table = {} successor_table = {} @@ -120,6 +145,10 @@ class AlgorithmContext: return CmpRelation.GREATER if x[1].count < y[1].count: return CmpRelation.LESS + if x[1].additional_score > y[1].additional_score: + return CmpRelation.GREATER + if x[1].additional_score < y[1].additional_score: + return CmpRelation.LESS if x[1].ptn_length > y[1].ptn_length: return CmpRelation.GREATER if x[1].ptn_length < y[1].ptn_length: @@ -136,11 +165,19 @@ class AlgorithmContext: continue skip = False for j, (_, candidate) in enumerate(pattern_arr): - if i == j: + if i == j or (ptn.additional_score > 0 and ptn.ptn_length > IGNORE_PTN_LEN): continue - if candidate.ptn_length >= ptn.ptn_length and ptn.ptn_items == candidate.ptn_items[:ptn.ptn_length]: + if candidate.ptn_length >= ptn.ptn_length and ptn.count == candidate.count \ + and ptn.pattern in candidate.pattern: skip = True break + if candidate.ptn_length < ptn.ptn_length and candidate.additional_score != 0 \ + and ptn.additional_score == 0 and candidate.pattern in ptn.pattern: + ratio = candidate.ptn_length / ptn.ptn_length + if ratio >= PTN_COVERAGE_THRESHOLD: + skip = True + break + if skip: continue res[key] = ptn @@ -148,12 +185,3 @@ class AlgorithmContext: context = AlgorithmContext() - -__all__ = ["context", - "gen_hash_key", - "DagGraph", - "MAX_OUT_DEGREE", - "MAX_ITERATION_DEPTH", - "SATISFIED_SCORE", - "MINI_FREQUENCY", - "ACCEPTABLE_RESULT_COUNT"] diff --git a/mindinsight/mindconverter/graph_based_converter/sub_graph_searcher/known_module_name.py b/mindinsight/mindconverter/graph_based_converter/sub_graph_searcher/known_module_name.py index 132c18ac..7947f60c 100644 --- a/mindinsight/mindconverter/graph_based_converter/sub_graph_searcher/known_module_name.py +++ b/mindinsight/mindconverter/graph_based_converter/sub_graph_searcher/known_module_name.py @@ -13,12 +13,28 @@ # limitations under the License. # ============================================================================== """Introduce some standard pattern name into MindConverter.""" + +__all__ = ["register_module_name", "is_built_in_module_name", "BUILT_IN_MODULE_NAME"] + from mindinsight.mindconverter.graph_based_converter.sub_graph_searcher.pattern import Pattern PLACEHOLDER = "PLC" BUILT_IN_MODULE_NAME = dict() +def is_built_in_module_name(module_name: str): + """ + Whether the module name was built-in. + + Args: + module_name (str): Module name. + + Returns: + bool, true or false. + """ + return module_name.split("_")[0] in BUILT_IN_MODULE_NAME + + def register_module_name(md_name: str, in_degree: int, out_degree: int): """ Register pattern to MindConverter. 
@@ -71,3 +87,17 @@ def _basic_conv_block_0():
 def _conv_bn():
     """Add basic conv block."""
     return ["Conv", "BatchNormalization"]
+
+
+@register_module_name("UpSample", 1, 1)
+def _up_sampling_in_op12():
+    return [
+        "Shape", "Slice", "Gather", "Cast", "Slice", "Mul", "Cast", "Concat", "Resize"
+    ]
+
+
+@register_module_name("UpSample", 1, 1)
+def _up_sampling_in_op10():
+    return [
+        "Shape", "Gather", "Cast", "Slice", "Mul", "Slice", "Cast", "Cast", "Div", "Concat", "Resize"
+    ]
diff --git a/mindinsight/mindconverter/graph_based_converter/sub_graph_searcher/pattern.py b/mindinsight/mindconverter/graph_based_converter/sub_graph_searcher/pattern.py
index d09476b5..39dfebbd 100644
--- a/mindinsight/mindconverter/graph_based_converter/sub_graph_searcher/pattern.py
+++ b/mindinsight/mindconverter/graph_based_converter/sub_graph_searcher/pattern.py
@@ -31,6 +31,9 @@ class Pattern:
         self.out_degree = out_degree
         self.head = self.ptn_items[0]
         self.tail = self.ptn_items[-1]
+        # If the pattern is in BUILT_IN_MODULE_NAME or BUILT_IN_PATTERN,
+        # it gets an additional score.
+        self.additional_score = 0
 
     def insert(self, idx, seq_len):
         """
diff --git a/mindinsight/mindconverter/graph_based_converter/sub_graph_searcher/search_path.py b/mindinsight/mindconverter/graph_based_converter/sub_graph_searcher/search_path.py
index 417db55e..0e39cb85 100644
--- a/mindinsight/mindconverter/graph_based_converter/sub_graph_searcher/search_path.py
+++ b/mindinsight/mindconverter/graph_based_converter/sub_graph_searcher/search_path.py
@@ -17,10 +17,10 @@
 import copy
 import uuid
 from typing import Dict, List, Callable, Union
 from collections import OrderedDict
-from .common import context, gen_hash_key, DagGraph, MAX_OUT_DEGREE
-from .known_module_name import BUILT_IN_MODULE_NAME
+from .common import context, gen_hash_key, DagGraph, MAX_OUT_DEGREE, cal_matching_score
+from .known_module_name import BUILT_IN_MODULE_NAME, is_built_in_module_name
 from .pattern import Pattern, scope_name_mapping
-from .built_in_pattern import BUILT_IN_PATTERN
+from .built_in_pattern import BUILT_IN_PATTERN, is_built_in_pattern
 from .pattern_fuzzy_matching import pattern_fuzzy_matching
 from ..third_party_graph.onnx_utils import OnnxNode, BaseNode
@@ -376,6 +376,8 @@ def generate_pattern(topo_order: List[BaseNode], dag: DagGraph,
         if ptn_key not in pattern:
             pattern[ptn_key] = Pattern(ptn, len(found_sequence), in_degree=in_degree,
                                        out_degree=out_degree)
+            if is_built_in_pattern(pattern[ptn_key]):
+                pattern[ptn_key].additional_score = cal_matching_score(pattern[ptn_key].ptn_length)
         pattern[ptn_key].insert(cur_idx - sub_graph_size + 1, len(found_sequence))
 
         cur_idx = cur_idx + 1
@@ -425,6 +427,7 @@ class SearchPath:
         }
         self._created_modules[self.pattern.module_name] = self.pattern
         self.heuristic_v = self._heuristic_val()
+        self.replacement_ratio = self._repl_ratio()
         self.actual_v = self._actual_val()
 
     def _create_new_order(self):
@@ -442,6 +445,8 @@ class SearchPath:
         else:
             module_name = scope_name_mapping[self.pattern.pattern]
         self.pattern.module_name = module_name
+        if is_built_in_module_name(module_name):
+            self.pattern.additional_score += cal_matching_score(self.pattern.ptn_length)
         topo_order, inverted_index = self.replace_sub_graph_completely(self.pattern, self.topo_order_bef_repl)
         return topo_order, inverted_index
@@ -572,7 +577,12 @@ class SearchPath:
             self.graph.precursor_table[s_nd] = p_nodes
 
     def evaluate_score(self):
-        """Evaluate path score."""
+        """
+        Evaluate path score.
+
+        Expression = 0.7 * (0.1 * bonus + 0.9 * repl_ratio) + 0.3 * H
+                   = 0.07 * bonus + 0.63 * repl_ratio + 0.3 * H
+        """
         return .7 * self.actual_v + .3 * self.heuristic_v
 
     def _cal_merged_module_length(self, ptn):
@@ -594,9 +604,21 @@ class SearchPath:
             return 1.0
         return max(res)
 
+    def _cal_bonus(self):
+        """Calculate the total additional score (bonus) of the path."""
+        score = self.pattern.additional_score
+        for search_path in self.recursion_path:
+            score += search_path.pattern.additional_score
+        return score
+
+    def _repl_ratio(self):
+        """Calculate the replacement ratio of the current path."""
+        return (context.get_sequence_length() - len(self.topo_order_aft_repl)) / context.get_sequence_length()
+
     def _actual_val(self):
         """Calculate ground-truth score of the path."""
-        return (context.get_sequence_length() - len(self.topo_order_aft_repl)) / context.get_sequence_length()
+        bonus = self._cal_bonus()
+        return bonus * 0.1 + self.replacement_ratio * 0.9
 
     def __lt__(self, other):
         """Override `<` operator."""
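
For reference, the arithmetic behind the new scoring can be checked with a small standalone sketch. The two functions below are hand-copied restatements of the expressions added to common.py and search_path.py; the helper names and sample numbers are illustrative only and not part of the diff.

import math


def cal_matching_score(sequence_len):
    # Same expression as the new helper in common.py: a logistic curve
    # rescaled to (0, 1) that grows with the pattern length.
    return 2 / (1 + math.pow(math.e, -0.1 * sequence_len)) - 1


def evaluate_score(bonus, repl_ratio, heuristic_v):
    # actual_v = 0.1 * bonus + 0.9 * repl_ratio, blended 70/30 with the
    # heuristic value, i.e. 0.07 * bonus + 0.63 * repl_ratio + 0.3 * H.
    actual_v = bonus * 0.1 + repl_ratio * 0.9
    return .7 * actual_v + .3 * heuristic_v


if __name__ == "__main__":
    # An 11-op built-in match (e.g. the opset-10 up-sampling block) earns a
    # bonus of about 0.50; with 60% of the topological order replaced and a
    # heuristic value of 0.5, the path scores roughly 0.56.
    bonus = cal_matching_score(11)
    print(round(bonus, 2), round(evaluate_score(bonus, 0.6, 0.5), 2))

Because the bonus enters actual_v with weight 0.1 against 0.9 for the replacement ratio, a built-in match nudges a path toward the raised SATISFIED_SCORE of 0.6 rather than dominating the ranking.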
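
Similarly, the coverage-based pruning added to AlgorithmContext can be exercised in isolation. The stand-in pattern type and helper below are hypothetical simplifications, not the real class:

from collections import namedtuple

# Reduced stand-in for sub_graph_searcher.pattern.Pattern.
Ptn = namedtuple("Ptn", ["pattern", "ptn_length", "additional_score"])

PTN_COVERAGE_THRESHOLD = 0.65


def skipped_for_built_in(ptn, candidate):
    # Mirror of the new branch: a longer pattern without a bonus is dropped
    # when a shorter built-in candidate is contained in it and covers at
    # least PTN_COVERAGE_THRESHOLD of its length.
    if candidate.ptn_length < ptn.ptn_length and candidate.additional_score != 0 \
            and ptn.additional_score == 0 and candidate.pattern in ptn.pattern:
        return candidate.ptn_length / ptn.ptn_length >= PTN_COVERAGE_THRESHOLD
    return False


if __name__ == "__main__":
    built_in = Ptn("Conv->BatchNormalization->Relu", 3, 0.15)  # bonus ~= cal_matching_score(3)
    found = Ptn("Conv->BatchNormalization->Relu->Add", 4, 0)
    print(skipped_for_built_in(found, built_in))  # True: coverage 3 / 4 = 0.75 >= 0.65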