Browse Source

Updating heuristic evaluation criteria to improve the effect of subgraph recognition.

tags/v1.1.0
liuchongming 5 years ago
parent
commit
94d6808552
6 changed files with 139 additions and 21 deletions
  1. +2
    -2
      mindinsight/mindconverter/graph_based_converter/sub_graph_searcher/__init__.py
  2. +36
    -1
      mindinsight/mindconverter/graph_based_converter/sub_graph_searcher/built_in_pattern.py
  3. +41
    -13
      mindinsight/mindconverter/graph_based_converter/sub_graph_searcher/common.py
  4. +30
    -0
      mindinsight/mindconverter/graph_based_converter/sub_graph_searcher/known_module_name.py
  5. +3
    -0
      mindinsight/mindconverter/graph_based_converter/sub_graph_searcher/pattern.py
  6. +27
    -5
      mindinsight/mindconverter/graph_based_converter/sub_graph_searcher/search_path.py

+ 2
- 2
mindinsight/mindconverter/graph_based_converter/sub_graph_searcher/__init__.py View File

@@ -13,6 +13,6 @@
# limitations under the License. # limitations under the License.
# ============================================================================== # ==============================================================================
"""Searcher of scope name.""" """Searcher of scope name."""
from .searcher import generate_scope_name

__all__ = ["generate_scope_name"] __all__ = ["generate_scope_name"]

from .searcher import generate_scope_name

+ 36
- 1
mindinsight/mindconverter/graph_based_converter/sub_graph_searcher/built_in_pattern.py View File

@@ -13,11 +13,34 @@
# limitations under the License. # limitations under the License.
# ============================================================================== # ==============================================================================
"""Introduce some standard pattern into MindConverter.""" """Introduce some standard pattern into MindConverter."""

__all__ = ["BUILT_IN_PATTERN", "register_pattern", "is_built_in_pattern"]

from .common import cal_matching_score
from .pattern import Pattern from .pattern import Pattern


BUILT_IN_PATTERN = dict() BUILT_IN_PATTERN = dict()




def is_built_in_pattern(pattern: Pattern):
    """
    Whether the pattern is a registered built-in pattern.

    Args:
        pattern (Pattern): Found pattern.

    Returns:
        bool, True if an identical pattern was registered, False otherwise.
    """
    # Two patterns are considered identical when their length, in/out
    # degree and operator sequence all match.
    for built_in in BUILT_IN_PATTERN.values():
        if built_in.ptn_length == pattern.ptn_length and \
                built_in.in_degree == pattern.in_degree and \
                built_in.out_degree == pattern.out_degree and \
                built_in.ptn_items == pattern.ptn_items:
            return True
    return False


def register_pattern(ptn_name, in_degree, out_degree): def register_pattern(ptn_name, in_degree, out_degree):
""" """
Register pattern to MindConverter. Register pattern to MindConverter.
@@ -40,6 +63,7 @@ def register_pattern(ptn_name, in_degree, out_degree):
BUILT_IN_PATTERN[ptn_name] = Pattern("->".join(result), len(result), BUILT_IN_PATTERN[ptn_name] = Pattern("->".join(result), len(result),
in_degree, out_degree, in_degree, out_degree,
ptn_items=result) ptn_items=result)
BUILT_IN_PATTERN[ptn_name].additional_score = cal_matching_score(BUILT_IN_PATTERN[ptn_name].ptn_length)


return _reg return _reg


@@ -76,4 +100,15 @@ def _convbnrelux3_convbn_add_relu():
"Conv", "BatchNormalization", "Add", "Relu"] "Conv", "BatchNormalization", "Add", "Relu"]




__all__ = ["BUILT_IN_PATTERN", "register_pattern"]
# Fixed typo in the registered name: "UnSampling-op12" -> "UpSampling-op12",
# consistent with the "UpSampling-op10" registration and the function name.
@register_pattern("UpSampling-op12", 1, 1)
def _up_sampling_in_op12():
    """Return the operator sequence of an up-sampling block (op12 variant, per the name)."""
    return [
        "Shape", "Slice", "Gather", "Cast", "Slice", "Mul", "Cast", "Concat", "Resize"
    ]


@register_pattern("UpSampling-op10", 1, 1)
def _up_sampling_in_op10():
    """Return the operator sequence registered as pattern "UpSampling-op10"."""
    return [
        "Shape", "Gather", "Cast", "Slice", "Mul", "Slice", "Cast", "Cast", "Div", "Concat", "Resize"
    ]

+ 41
- 13
mindinsight/mindconverter/graph_based_converter/sub_graph_searcher/common.py View File

@@ -13,6 +13,18 @@
# limitations under the License. # limitations under the License.
# ============================================================================== # ==============================================================================
"""Declare generic variable and functions.""" """Declare generic variable and functions."""

__all__ = ["context",
"gen_hash_key",
"DagGraph",
"MAX_OUT_DEGREE",
"cal_matching_score",
"ACCEPTABLE_RESULT_COUNT",
"MINI_FREQUENCY",
"SATISFIED_SCORE",
"MAX_ITERATION_DEPTH"]

import math
import copy import copy
import functools import functools
from collections import OrderedDict from collections import OrderedDict
@@ -23,8 +35,21 @@ from mindinsight.mindconverter.graph_based_converter.third_party_graph.onnx_util
MAX_OUT_DEGREE = 1 MAX_OUT_DEGREE = 1
MINI_FREQUENCY = 4 MINI_FREQUENCY = 4
MAX_ITERATION_DEPTH = 4 MAX_ITERATION_DEPTH = 4
SATISFIED_SCORE = 0.55
SATISFIED_SCORE = 0.6
ACCEPTABLE_RESULT_COUNT = 16 ACCEPTABLE_RESULT_COUNT = 16
PTN_COVERAGE_THRESHOLD = 0.65
# If pattern length is shorter than `IGNORE_PTN_LEN`, then do not calculate the coverage.
IGNORE_PTN_LEN = 5


def cal_matching_score(sequence_len: int):
    """
    Calculate matching score of a pattern.

    The score is a shifted sigmoid, 2 / (1 + e^(-0.1 * len)) - 1:
    0 for an empty pattern, approaching 1 as the pattern grows, so a
    longer built-in pattern earns a larger bonus.

    Args:
        sequence_len (int): Pattern length.

    Returns:
        float, score in [0, 1).
    """
    # math.exp is the idiomatic (and more accurate) form of math.pow(math.e, x).
    return 2 / (1 + math.exp(-0.1 * sequence_len)) - 1




class CmpRelation: class CmpRelation:
@@ -79,8 +104,8 @@ class AlgorithmContext:
found_pattern = {} found_pattern = {}
visited = set() visited = set()
beam_width = 5 beam_width = 5
total_len = 0
MIN_FREQUENCY = 1 MIN_FREQUENCY = 1
total_len = 0
node_collection = None node_collection = None
precursor_table = {} precursor_table = {}
successor_table = {} successor_table = {}
@@ -120,6 +145,10 @@ class AlgorithmContext:
return CmpRelation.GREATER return CmpRelation.GREATER
if x[1].count < y[1].count: if x[1].count < y[1].count:
return CmpRelation.LESS return CmpRelation.LESS
if x[1].additional_score > y[1].additional_score:
return CmpRelation.GREATER
if x[1].additional_score < y[1].additional_score:
return CmpRelation.LESS
if x[1].ptn_length > y[1].ptn_length: if x[1].ptn_length > y[1].ptn_length:
return CmpRelation.GREATER return CmpRelation.GREATER
if x[1].ptn_length < y[1].ptn_length: if x[1].ptn_length < y[1].ptn_length:
@@ -136,11 +165,19 @@ class AlgorithmContext:
continue continue
skip = False skip = False
for j, (_, candidate) in enumerate(pattern_arr): for j, (_, candidate) in enumerate(pattern_arr):
if i == j:
if i == j or (ptn.additional_score > 0 and ptn.ptn_length > IGNORE_PTN_LEN):
continue continue
if candidate.ptn_length >= ptn.ptn_length and ptn.ptn_items == candidate.ptn_items[:ptn.ptn_length]:
if candidate.ptn_length >= ptn.ptn_length and ptn.count == candidate.count \
and ptn.pattern in candidate.pattern:
skip = True skip = True
break break
if candidate.ptn_length < ptn.ptn_length and candidate.additional_score != 0 \
and ptn.additional_score == 0 and candidate.pattern in ptn.pattern:
ratio = candidate.ptn_length / ptn.ptn_length
if ratio >= PTN_COVERAGE_THRESHOLD:
skip = True
break

if skip: if skip:
continue continue
res[key] = ptn res[key] = ptn
@@ -148,12 +185,3 @@ class AlgorithmContext:




context = AlgorithmContext() context = AlgorithmContext()

__all__ = ["context",
"gen_hash_key",
"DagGraph",
"MAX_OUT_DEGREE",
"MAX_ITERATION_DEPTH",
"SATISFIED_SCORE",
"MINI_FREQUENCY",
"ACCEPTABLE_RESULT_COUNT"]

+ 30
- 0
mindinsight/mindconverter/graph_based_converter/sub_graph_searcher/known_module_name.py View File

@@ -13,12 +13,28 @@
# limitations under the License. # limitations under the License.
# ============================================================================== # ==============================================================================
"""Introduce some standard pattern name into MindConverter.""" """Introduce some standard pattern name into MindConverter."""

__all__ = ["register_module_name", "is_built_in_module_name", "BUILT_IN_MODULE_NAME"]

from mindinsight.mindconverter.graph_based_converter.sub_graph_searcher.pattern import Pattern from mindinsight.mindconverter.graph_based_converter.sub_graph_searcher.pattern import Pattern


PLACEHOLDER = "PLC" PLACEHOLDER = "PLC"
BUILT_IN_MODULE_NAME = dict() BUILT_IN_MODULE_NAME = dict()




def is_built_in_module_name(module_name: str):
    """
    Check whether `module_name` refers to a registered built-in module.

    Only the part of the name before the first underscore is compared
    against the registry, so suffixed instances (e.g. "Name_0") match.

    Args:
        module_name (str): Module name to look up.

    Returns:
        bool, True if the base name is registered.
    """
    base_name = module_name.partition("_")[0]
    return base_name in BUILT_IN_MODULE_NAME


def register_module_name(md_name: str, in_degree: int, out_degree: int): def register_module_name(md_name: str, in_degree: int, out_degree: int):
""" """
Register pattern to MindConverter. Register pattern to MindConverter.
@@ -71,3 +87,17 @@ def _basic_conv_block_0():
def _conv_bn(): def _conv_bn():
"""Add basic conv block.""" """Add basic conv block."""
return ["Conv", "BatchNormalization"] return ["Conv", "BatchNormalization"]


# NOTE(review): "UnSample" looks like a typo of "UpSample" — confirm before renaming,
# since this string is matched by prefix against generated module names.
@register_module_name("UnSample", 1, 1)
def _up_sampling_in_op12():
    """Return the operator sequence mapped to module name "UnSample" (op12 variant, per the name)."""
    return [
        "Shape", "Slice", "Gather", "Cast", "Slice", "Mul", "Cast", "Concat", "Resize"
    ]


# NOTE(review): same probable "UnSample"/"UpSample" typo as the registration above — confirm.
@register_module_name("UnSample", 1, 1)
def _up_sampling_in_op10():
    """Return the operator sequence mapped to module name "UnSample" (op10 variant, per the name)."""
    return [
        "Shape", "Gather", "Cast", "Slice", "Mul", "Slice", "Cast", "Cast", "Div", "Concat", "Resize"
    ]

+ 3
- 0
mindinsight/mindconverter/graph_based_converter/sub_graph_searcher/pattern.py View File

@@ -31,6 +31,9 @@ class Pattern:
self.out_degree = out_degree self.out_degree = out_degree
self.head = self.ptn_items[0] self.head = self.ptn_items[0]
self.tail = self.ptn_items[-1] self.tail = self.ptn_items[-1]
# If pattern in BUILD_IN_MODULE_NAME or BUILD_IN_PATTERN,
# the pattern will get additional score.
self.additional_score = 0


def insert(self, idx, seq_len): def insert(self, idx, seq_len):
""" """


+ 27
- 5
mindinsight/mindconverter/graph_based_converter/sub_graph_searcher/search_path.py View File

@@ -17,10 +17,10 @@ import copy
import uuid import uuid
from typing import Dict, List, Callable, Union from typing import Dict, List, Callable, Union
from collections import OrderedDict from collections import OrderedDict
from .common import context, gen_hash_key, DagGraph, MAX_OUT_DEGREE
from .known_module_name import BUILT_IN_MODULE_NAME
from .common import context, gen_hash_key, DagGraph, MAX_OUT_DEGREE, cal_matching_score
from .known_module_name import BUILT_IN_MODULE_NAME, is_built_in_module_name
from .pattern import Pattern, scope_name_mapping from .pattern import Pattern, scope_name_mapping
from .built_in_pattern import BUILT_IN_PATTERN
from .built_in_pattern import BUILT_IN_PATTERN, is_built_in_pattern
from .pattern_fuzzy_matching import pattern_fuzzy_matching from .pattern_fuzzy_matching import pattern_fuzzy_matching
from ..third_party_graph.onnx_utils import OnnxNode, BaseNode from ..third_party_graph.onnx_utils import OnnxNode, BaseNode


@@ -376,6 +376,8 @@ def generate_pattern(topo_order: List[BaseNode], dag: DagGraph,
if ptn_key not in pattern: if ptn_key not in pattern:
pattern[ptn_key] = Pattern(ptn, len(found_sequence), pattern[ptn_key] = Pattern(ptn, len(found_sequence),
in_degree=in_degree, out_degree=out_degree) in_degree=in_degree, out_degree=out_degree)
if is_built_in_pattern(pattern[ptn_key]):
pattern[ptn_key].additional_score = cal_matching_score(pattern[ptn_key].ptn_length)


pattern[ptn_key].insert(cur_idx - sub_graph_size + 1, len(found_sequence)) pattern[ptn_key].insert(cur_idx - sub_graph_size + 1, len(found_sequence))
cur_idx = cur_idx + 1 cur_idx = cur_idx + 1
@@ -425,6 +427,7 @@ class SearchPath:
} }
self._created_modules[self.pattern.module_name] = self.pattern self._created_modules[self.pattern.module_name] = self.pattern
self.heuristic_v = self._heuristic_val() self.heuristic_v = self._heuristic_val()
self.replacement_ratio = self._repl_ratio()
self.actual_v = self._actual_val() self.actual_v = self._actual_val()


def _create_new_order(self): def _create_new_order(self):
@@ -442,6 +445,8 @@ class SearchPath:
else: else:
module_name = scope_name_mapping[self.pattern.pattern] module_name = scope_name_mapping[self.pattern.pattern]
self.pattern.module_name = module_name self.pattern.module_name = module_name
if is_built_in_module_name(module_name):
self.pattern.additional_score += cal_matching_score(self.pattern.ptn_length)
topo_order, inverted_index = self.replace_sub_graph_completely(self.pattern, self.topo_order_bef_repl) topo_order, inverted_index = self.replace_sub_graph_completely(self.pattern, self.topo_order_bef_repl)
return topo_order, inverted_index return topo_order, inverted_index


@@ -572,7 +577,12 @@ class SearchPath:
self.graph.precursor_table[s_nd] = p_nodes self.graph.precursor_table[s_nd] = p_nodes


def evaluate_score(self): def evaluate_score(self):
"""Evaluate path score."""
"""
Evaluate path score.

Expression = 0.7 * (0.1 * bonus + 0.9 * repl_ratio) + 0.3 * H
= 0.07 * bonus + 0.63 * repl_ratio + 0.3 * H
"""
return .7 * self.actual_v + .3 * self.heuristic_v return .7 * self.actual_v + .3 * self.heuristic_v


def _cal_merged_module_length(self, ptn): def _cal_merged_module_length(self, ptn):
@@ -594,9 +604,21 @@ class SearchPath:
return 1.0 return 1.0
return max(res) return max(res)


def _cal_bonus(self):
"""Calculate total pattern length."""
score = self.pattern.additional_score
for search_path in self.recursion_path:
score += search_path.pattern.additional_score
return score

def _repl_ratio(self):
"""Calculate replacement ratio of current path."""
return (context.get_sequence_length() - len(self.topo_order_aft_repl)) / context.get_sequence_length()

def _actual_val(self): def _actual_val(self):
"""Calculate ground-truth score of the path.""" """Calculate ground-truth score of the path."""
return (context.get_sequence_length() - len(self.topo_order_aft_repl)) / context.get_sequence_length()
bonus = self._cal_bonus()
return bonus * 0.1 + self.replacement_ratio * 0.9


def __lt__(self, other): def __lt__(self, other):
"""Override `<` operator.""" """Override `<` operator."""


Loading…
Cancel
Save