Browse Source

Updating heuristic evaluation criteria to improve the effect of subgraph recognition.

tags/v1.1.0
liuchongming 5 years ago
parent
commit
94d6808552
6 changed files with 139 additions and 21 deletions
  1. +2
    -2
      mindinsight/mindconverter/graph_based_converter/sub_graph_searcher/__init__.py
  2. +36
    -1
      mindinsight/mindconverter/graph_based_converter/sub_graph_searcher/built_in_pattern.py
  3. +41
    -13
      mindinsight/mindconverter/graph_based_converter/sub_graph_searcher/common.py
  4. +30
    -0
      mindinsight/mindconverter/graph_based_converter/sub_graph_searcher/known_module_name.py
  5. +3
    -0
      mindinsight/mindconverter/graph_based_converter/sub_graph_searcher/pattern.py
  6. +27
    -5
      mindinsight/mindconverter/graph_based_converter/sub_graph_searcher/search_path.py

+ 2
- 2
mindinsight/mindconverter/graph_based_converter/sub_graph_searcher/__init__.py View File

@@ -13,6 +13,6 @@
# limitations under the License.
# ==============================================================================
"""Searcher of scope name."""
from .searcher import generate_scope_name

__all__ = ["generate_scope_name"]

from .searcher import generate_scope_name

+ 36
- 1
mindinsight/mindconverter/graph_based_converter/sub_graph_searcher/built_in_pattern.py View File

@@ -13,11 +13,34 @@
# limitations under the License.
# ==============================================================================
"""Introduce some standard pattern into MindConverter."""

__all__ = ["BUILT_IN_PATTERN", "register_pattern", "is_built_in_pattern"]

from .common import cal_matching_score
from .pattern import Pattern

BUILT_IN_PATTERN = dict()


def is_built_in_pattern(pattern: Pattern):
"""
Whether the module name was built-in.

Args:
pattern (Pattern): Found pattern.

Returns:
bool, true or false.
"""
for ptn in BUILT_IN_PATTERN:
if BUILT_IN_PATTERN[ptn].ptn_length == pattern.ptn_length and \
BUILT_IN_PATTERN[ptn].in_degree == pattern.in_degree and \
BUILT_IN_PATTERN[ptn].out_degree == pattern.out_degree and \
BUILT_IN_PATTERN[ptn].ptn_items == pattern.ptn_items:
return True
return False


def register_pattern(ptn_name, in_degree, out_degree):
"""
Register pattern to MindConverter.
@@ -40,6 +63,7 @@ def register_pattern(ptn_name, in_degree, out_degree):
BUILT_IN_PATTERN[ptn_name] = Pattern("->".join(result), len(result),
in_degree, out_degree,
ptn_items=result)
BUILT_IN_PATTERN[ptn_name].additional_score = cal_matching_score(BUILT_IN_PATTERN[ptn_name].ptn_length)

return _reg

@@ -76,4 +100,15 @@ def _convbnrelux3_convbn_add_relu():
"Conv", "BatchNormalization", "Add", "Relu"]


__all__ = ["BUILT_IN_PATTERN", "register_pattern"]
@register_pattern("UnSampling-op12", 1, 1)
def _up_sampling_in_op12():
return [
"Shape", "Slice", "Gather", "Cast", "Slice", "Mul", "Cast", "Concat", "Resize"
]


@register_pattern("UpSampling-op10", 1, 1)
def _up_sampling_in_op10():
return [
"Shape", "Gather", "Cast", "Slice", "Mul", "Slice", "Cast", "Cast", "Div", "Concat", "Resize"
]

+ 41
- 13
mindinsight/mindconverter/graph_based_converter/sub_graph_searcher/common.py View File

@@ -13,6 +13,18 @@
# limitations under the License.
# ==============================================================================
"""Declare generic variable and functions."""

__all__ = ["context",
"gen_hash_key",
"DagGraph",
"MAX_OUT_DEGREE",
"cal_matching_score",
"ACCEPTABLE_RESULT_COUNT",
"MINI_FREQUENCY",
"SATISFIED_SCORE",
"MAX_ITERATION_DEPTH"]

import math
import copy
import functools
from collections import OrderedDict
@@ -23,8 +35,21 @@ from mindinsight.mindconverter.graph_based_converter.third_party_graph.onnx_util
MAX_OUT_DEGREE = 1
MINI_FREQUENCY = 4
MAX_ITERATION_DEPTH = 4
SATISFIED_SCORE = 0.55
SATISFIED_SCORE = 0.6
ACCEPTABLE_RESULT_COUNT = 16
PTN_COVERAGE_THRESHOLD = 0.65
# If pattern length is short than `IGNORE_PTN_LEN`, then do not calculate the coverage.
IGNORE_PTN_LEN = 5


def cal_matching_score(sequence_len: int):
"""
Calculate matching score.

Args:
sequence_len (int): Pattern length.
"""
return 2 / (1 + math.pow(math.e, -0.1 * sequence_len)) - 1


class CmpRelation:
@@ -79,8 +104,8 @@ class AlgorithmContext:
found_pattern = {}
visited = set()
beam_width = 5
total_len = 0
MIN_FREQUENCY = 1
total_len = 0
node_collection = None
precursor_table = {}
successor_table = {}
@@ -120,6 +145,10 @@ class AlgorithmContext:
return CmpRelation.GREATER
if x[1].count < y[1].count:
return CmpRelation.LESS
if x[1].additional_score > y[1].additional_score:
return CmpRelation.GREATER
if x[1].additional_score < y[1].additional_score:
return CmpRelation.LESS
if x[1].ptn_length > y[1].ptn_length:
return CmpRelation.GREATER
if x[1].ptn_length < y[1].ptn_length:
@@ -136,11 +165,19 @@ class AlgorithmContext:
continue
skip = False
for j, (_, candidate) in enumerate(pattern_arr):
if i == j:
if i == j or (ptn.additional_score > 0 and ptn.ptn_length > IGNORE_PTN_LEN):
continue
if candidate.ptn_length >= ptn.ptn_length and ptn.ptn_items == candidate.ptn_items[:ptn.ptn_length]:
if candidate.ptn_length >= ptn.ptn_length and ptn.count == candidate.count \
and ptn.pattern in candidate.pattern:
skip = True
break
if candidate.ptn_length < ptn.ptn_length and candidate.additional_score != 0 \
and ptn.additional_score == 0 and candidate.pattern in ptn.pattern:
ratio = candidate.ptn_length / ptn.ptn_length
if ratio >= PTN_COVERAGE_THRESHOLD:
skip = True
break

if skip:
continue
res[key] = ptn
@@ -148,12 +185,3 @@ class AlgorithmContext:


context = AlgorithmContext()

__all__ = ["context",
"gen_hash_key",
"DagGraph",
"MAX_OUT_DEGREE",
"MAX_ITERATION_DEPTH",
"SATISFIED_SCORE",
"MINI_FREQUENCY",
"ACCEPTABLE_RESULT_COUNT"]

+ 30
- 0
mindinsight/mindconverter/graph_based_converter/sub_graph_searcher/known_module_name.py View File

@@ -13,12 +13,28 @@
# limitations under the License.
# ==============================================================================
"""Introduce some standard pattern name into MindConverter."""

__all__ = ["register_module_name", "is_built_in_module_name", "BUILT_IN_MODULE_NAME"]

from mindinsight.mindconverter.graph_based_converter.sub_graph_searcher.pattern import Pattern

PLACEHOLDER = "PLC"
BUILT_IN_MODULE_NAME = dict()


def is_built_in_module_name(module_name: str):
"""
Whether the module name was built-in.

Args:
module_name (str): Module name.

Returns:
bool, true or false.
"""
return module_name.split("_")[0] in BUILT_IN_MODULE_NAME


def register_module_name(md_name: str, in_degree: int, out_degree: int):
"""
Register pattern to MindConverter.
@@ -71,3 +87,17 @@ def _basic_conv_block_0():
def _conv_bn():
"""Add basic conv block."""
return ["Conv", "BatchNormalization"]


@register_module_name("UnSample", 1, 1)
def _up_sampling_in_op12():
return [
"Shape", "Slice", "Gather", "Cast", "Slice", "Mul", "Cast", "Concat", "Resize"
]


@register_module_name("UnSample", 1, 1)
def _up_sampling_in_op10():
return [
"Shape", "Gather", "Cast", "Slice", "Mul", "Slice", "Cast", "Cast", "Div", "Concat", "Resize"
]

+ 3
- 0
mindinsight/mindconverter/graph_based_converter/sub_graph_searcher/pattern.py View File

@@ -31,6 +31,9 @@ class Pattern:
self.out_degree = out_degree
self.head = self.ptn_items[0]
self.tail = self.ptn_items[-1]
# If pattern in BUILD_IN_MODULE_NAME or BUILD_IN_PATTERN,
# the pattern will get additional score.
self.additional_score = 0

def insert(self, idx, seq_len):
"""


+ 27
- 5
mindinsight/mindconverter/graph_based_converter/sub_graph_searcher/search_path.py View File

@@ -17,10 +17,10 @@ import copy
import uuid
from typing import Dict, List, Callable, Union
from collections import OrderedDict
from .common import context, gen_hash_key, DagGraph, MAX_OUT_DEGREE
from .known_module_name import BUILT_IN_MODULE_NAME
from .common import context, gen_hash_key, DagGraph, MAX_OUT_DEGREE, cal_matching_score
from .known_module_name import BUILT_IN_MODULE_NAME, is_built_in_module_name
from .pattern import Pattern, scope_name_mapping
from .built_in_pattern import BUILT_IN_PATTERN
from .built_in_pattern import BUILT_IN_PATTERN, is_built_in_pattern
from .pattern_fuzzy_matching import pattern_fuzzy_matching
from ..third_party_graph.onnx_utils import OnnxNode, BaseNode

@@ -376,6 +376,8 @@ def generate_pattern(topo_order: List[BaseNode], dag: DagGraph,
if ptn_key not in pattern:
pattern[ptn_key] = Pattern(ptn, len(found_sequence),
in_degree=in_degree, out_degree=out_degree)
if is_built_in_pattern(pattern[ptn_key]):
pattern[ptn_key].additional_score = cal_matching_score(pattern[ptn_key].ptn_length)

pattern[ptn_key].insert(cur_idx - sub_graph_size + 1, len(found_sequence))
cur_idx = cur_idx + 1
@@ -425,6 +427,7 @@ class SearchPath:
}
self._created_modules[self.pattern.module_name] = self.pattern
self.heuristic_v = self._heuristic_val()
self.replacement_ratio = self._repl_ratio()
self.actual_v = self._actual_val()

def _create_new_order(self):
@@ -442,6 +445,8 @@ class SearchPath:
else:
module_name = scope_name_mapping[self.pattern.pattern]
self.pattern.module_name = module_name
if is_built_in_module_name(module_name):
self.pattern.additional_score += cal_matching_score(self.pattern.ptn_length)
topo_order, inverted_index = self.replace_sub_graph_completely(self.pattern, self.topo_order_bef_repl)
return topo_order, inverted_index

@@ -572,7 +577,12 @@ class SearchPath:
self.graph.precursor_table[s_nd] = p_nodes

def evaluate_score(self):
"""Evaluate path score."""
"""
Evaluate path score.

Expression = 0.7 * (0.1 * bonus + 0.9 * repl_ratio) + 0.3 * H
= 0.07 * bonus + 0.63 * repl_ratio + 0.3 * H
"""
return .7 * self.actual_v + .3 * self.heuristic_v

def _cal_merged_module_length(self, ptn):
@@ -594,9 +604,21 @@ class SearchPath:
return 1.0
return max(res)

def _cal_bonus(self):
"""Calculate total pattern length."""
score = self.pattern.additional_score
for search_path in self.recursion_path:
score += search_path.pattern.additional_score
return score

def _repl_ratio(self):
"""Calculate replacement ratio of current path."""
return (context.get_sequence_length() - len(self.topo_order_aft_repl)) / context.get_sequence_length()

def _actual_val(self):
"""Calculate ground-truth score of the path."""
return (context.get_sequence_length() - len(self.topo_order_aft_repl)) / context.get_sequence_length()
bonus = self._cal_bonus()
return bonus * 0.1 + self.replacement_ratio * 0.9

def __lt__(self, other):
"""Override `<` operator."""


Loading…
Cancel
Save