Browse Source

Optimize sub-graph searcher.

tags/v1.2.0-rc1
liuchongming 4 years ago
parent
commit
1afeca2015
3 changed files with 94 additions and 49 deletions
  1. +35
    -28
      mindinsight/mindconverter/graph_based_converter/sub_graph_searcher/common.py
  2. +2
    -2
      mindinsight/mindconverter/graph_based_converter/sub_graph_searcher/pattern.py
  3. +57
    -19
      mindinsight/mindconverter/graph_based_converter/sub_graph_searcher/searcher.py

+ 35
- 28
mindinsight/mindconverter/graph_based_converter/sub_graph_searcher/common.py View File

@@ -1,4 +1,4 @@
# Copyright 2020 Huawei Technologies Co., Ltd.All Rights Reserved.
# Copyright 2020-2021 Huawei Technologies Co., Ltd.All Rights Reserved.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
@@ -33,10 +33,10 @@ from typing import List
from mindinsight.mindconverter.graph_based_converter.third_party_graph.onnx_utils import BaseNode from mindinsight.mindconverter.graph_based_converter.third_party_graph.onnx_utils import BaseNode


MAX_OUT_DEGREE = 1 MAX_OUT_DEGREE = 1
MINI_FREQUENCY = 4
MAX_ITERATION_DEPTH = 4
SATISFIED_SCORE = 0.6
ACCEPTABLE_RESULT_COUNT = 16
MINI_FREQUENCY = 0.07
MAX_ITERATION_DEPTH = 16
SATISFIED_SCORE = 1.5
ACCEPTABLE_RESULT_COUNT = 32
PTN_COVERAGE_THRESHOLD = 0.65 PTN_COVERAGE_THRESHOLD = 0.65
# If pattern length is shorter than `IGNORE_PTN_LEN`, then do not calculate the coverage. # If pattern length is shorter than `IGNORE_PTN_LEN`, then do not calculate the coverage.
IGNORE_PTN_LEN = 5 IGNORE_PTN_LEN = 5
@@ -52,6 +52,23 @@ def cal_matching_score(sequence_len: int):
return 2 / (1 + math.pow(math.e, -0.1 * sequence_len)) - 1 return 2 / (1 + math.pow(math.e, -0.1 * sequence_len)) - 1




def _cmp(x, y):
    """
    Cmp function to sort pattern.

    Patterns are ordered by `count`, then `additional_score`,
    then `ptn_length`, all descending-significant.
    """
    # Tuple comparison is lexicographic, which reproduces the
    # count -> additional_score -> ptn_length priority exactly.
    lhs = (x[1].count, x[1].additional_score, x[1].ptn_length)
    rhs = (y[1].count, y[1].additional_score, y[1].ptn_length)
    if lhs > rhs:
        return CmpRelation.GREATER
    if lhs < rhs:
        return CmpRelation.LESS
    return CmpRelation.EQUAL


class CmpRelation: class CmpRelation:
"""Define cmp relation between `x` and `y`.""" """Define cmp relation between `x` and `y`."""
# When x is equal to y in logic. # When x is equal to y in logic.
@@ -138,39 +155,29 @@ class AlgorithmContext:
Returns: Returns:
OrderedDict, sorted pattern. OrderedDict, sorted pattern.
""" """

def _cmp(x, y):
"""Cmp function to sort pattern."""
if x[1].count > y[1].count:
return CmpRelation.GREATER
if x[1].count < y[1].count:
return CmpRelation.LESS
if x[1].additional_score > y[1].additional_score:
return CmpRelation.GREATER
if x[1].additional_score < y[1].additional_score:
return CmpRelation.LESS
if x[1].ptn_length > y[1].ptn_length:
return CmpRelation.GREATER
if x[1].ptn_length < y[1].ptn_length:
return CmpRelation.LESS
return CmpRelation.EQUAL

pattern_arr = sorted(pattern_arr.items(), key=functools.cmp_to_key(_cmp), pattern_arr = sorted(pattern_arr.items(), key=functools.cmp_to_key(_cmp),
reverse=True) reverse=True)
if len(pattern_arr) > self.beam_width:
if len(pattern_arr) > AlgorithmContext.beam_width:
pattern_arr = pattern_arr[:self.beam_width] pattern_arr = pattern_arr[:self.beam_width]
res = OrderedDict() res = OrderedDict()
for i, (key, ptn) in enumerate(pattern_arr): for i, (key, ptn) in enumerate(pattern_arr):
if ptn.count <= self.MIN_FREQUENCY:
if ptn.count <= AlgorithmContext.MIN_FREQUENCY:
continue
if ptn.additional_score > 0 and ptn.ptn_length > IGNORE_PTN_LEN:
res[key] = ptn
continue continue
skip = False skip = False
for j, (_, candidate) in enumerate(pattern_arr): for j, (_, candidate) in enumerate(pattern_arr):
if i == j or (ptn.additional_score > 0 and ptn.ptn_length > IGNORE_PTN_LEN):
if i == j:
continue continue
# If `ptn` is a sub-pattern of `candidate`, and `ptn` count equals to `candidate`,
# then reject the `ptn`.
if candidate.ptn_length >= ptn.ptn_length and ptn.count == candidate.count \ if candidate.ptn_length >= ptn.ptn_length and ptn.count == candidate.count \
and ptn.pattern in candidate.pattern: and ptn.pattern in candidate.pattern:
skip = True skip = True
break break
# If `candidate` is sub-pattern of `ptn`, `candidate` has additional score,
# and `ptn` has no additional score, then calculate its replacement ratio.
if candidate.ptn_length < ptn.ptn_length and candidate.additional_score != 0 \ if candidate.ptn_length < ptn.ptn_length and candidate.additional_score != 0 \
and ptn.additional_score == 0 and candidate.pattern in ptn.pattern: and ptn.additional_score == 0 and candidate.pattern in ptn.pattern:
ratio = candidate.ptn_length / ptn.ptn_length ratio = candidate.ptn_length / ptn.ptn_length
@@ -178,9 +185,9 @@ class AlgorithmContext:
skip = True skip = True
break break


if skip:
continue
res[key] = ptn
if not skip:
res[key] = ptn
return res return res






+ 2
- 2
mindinsight/mindconverter/graph_based_converter/sub_graph_searcher/pattern.py View File

@@ -1,4 +1,4 @@
# Copyright 2020 Huawei Technologies Co., Ltd.All Rights Reserved.
# Copyright 2020-2021 Huawei Technologies Co., Ltd.All Rights Reserved.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
@@ -34,7 +34,7 @@ class Pattern:
# If pattern in BUILD_IN_MODULE_NAME or BUILD_IN_PATTERN, # If pattern in BUILD_IN_MODULE_NAME or BUILD_IN_PATTERN,
# the pattern will get additional score. # the pattern will get additional score.
self.additional_score = 0 self.additional_score = 0
self.know_module_name = None
self.known_module_name = None


def insert(self, idx, seq_len): def insert(self, idx, seq_len):
""" """


+ 57
- 19
mindinsight/mindconverter/graph_based_converter/sub_graph_searcher/searcher.py View File

@@ -1,4 +1,4 @@
# Copyright 2020 Huawei Technologies Co., Ltd.All Rights Reserved.
# Copyright 2020-2021 Huawei Technologies Co., Ltd.All Rights Reserved.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
@@ -36,13 +36,23 @@ def _is_satisfied(path):
""" """
if len(path.recursion_path) > MAX_ITERATION_DEPTH: if len(path.recursion_path) > MAX_ITERATION_DEPTH:
return True return True
if not path.new_pattern or max([p.count for _, p in path.new_pattern.items()]) < MINI_FREQUENCY:
if not path.new_pattern or not any([is_pattern_satisfied(p, path) for p in path.new_pattern.values()]):
return True return True
if path.evaluate_score() > SATISFIED_SCORE: if path.evaluate_score() > SATISFIED_SCORE:
return True return True
return False return False




def is_pattern_satisfied(pattern, seq):
    """
    Whether a pattern is valid.

    Args:
        pattern (Pattern): Candidate pattern, providing `count`
            (occurrences) and `ptn_length` (nodes per occurrence).
        seq (SearchPath): Path whose `topo_order_aft_repl` is the node
            sequence after replacement.

    Returns:
        bool, True if the pattern's replacement ratio reaches
        `MINI_FREQUENCY` (7%).
    """
    # Fraction of the replaced topological order covered by this pattern.
    # `/` is true division in Python 3, so the former `1.0 *` factor and
    # the `if ...: return True / return False` ladder were redundant.
    rpl_ratio = pattern.count * pattern.ptn_length / len(seq.topo_order_aft_repl)
    return rpl_ratio >= MINI_FREQUENCY


def _search(init_pattern: Dict[str, Pattern], init_topo_order: List[BaseNode], def _search(init_pattern: Dict[str, Pattern], init_topo_order: List[BaseNode],
init_graph, sub_graph_size: int = 2) -> List[SearchPath]: init_graph, sub_graph_size: int = 2) -> List[SearchPath]:
""" """
@@ -61,7 +71,7 @@ def _search(init_pattern: Dict[str, Pattern], init_topo_order: List[BaseNode],
sorted_pattern = context.sort_with_beam(init_pattern) sorted_pattern = context.sort_with_beam(init_pattern)
# 2. Put pattern into queue. # 2. Put pattern into queue.
queue = PriorityQueue() queue = PriorityQueue()
for _, pattern_inst in sorted_pattern.items():
for pattern_inst in sorted_pattern.values():
queue.put( queue.put(
SearchPath(pattern=pattern_inst, sequence=init_topo_order, SearchPath(pattern=pattern_inst, sequence=init_topo_order,
graph=init_graph, graph=init_graph,
@@ -70,6 +80,7 @@ def _search(init_pattern: Dict[str, Pattern], init_topo_order: List[BaseNode],
) )


available_path = [] available_path = []
deduplicate_path = set()
while not queue.empty(): while not queue.empty():
# a. replace pattern in current topo order. # a. replace pattern in current topo order.
cur_path = queue.get(block=False) cur_path = queue.get(block=False)
@@ -77,13 +88,18 @@ def _search(init_pattern: Dict[str, Pattern], init_topo_order: List[BaseNode],
# b. generate new pattern based on replaced topo order. # b. generate new pattern based on replaced topo order.
if _is_satisfied(cur_path): if _is_satisfied(cur_path):
available_path.append(cur_path) available_path.append(cur_path)
deduplicate_path.add(cur_path.hash_of_aft_repl)
continue continue


if len(available_path) >= ACCEPTABLE_RESULT_COUNT: if len(available_path) >= ACCEPTABLE_RESULT_COUNT:
break break


for _, cur_pattern in cur_path.new_pattern.items():
if cur_pattern.count < MINI_FREQUENCY:
for cur_pattern in cur_path.new_pattern.values():
if not is_pattern_satisfied(cur_pattern, cur_path):
if cur_path.hash_of_aft_repl in deduplicate_path:
continue
available_path.append(cur_path)
deduplicate_path.add(cur_path.hash_of_aft_repl)
continue continue
key = "/".join([f"{cur_pattern.pattern}[{cur_pattern.in_degree},{cur_pattern.out_degree}]", key = "/".join([f"{cur_pattern.pattern}[{cur_pattern.in_degree},{cur_pattern.out_degree}]",
gen_hash_key(cur_topo_order, without_module=True)]) gen_hash_key(cur_topo_order, without_module=True)])
@@ -167,13 +183,8 @@ def _scope_name_deduplication(key, scope_names, memo) -> list:
Returns: Returns:
list, renamed scope name. list, renamed scope name.
""" """
result = []
if key not in memo:
memo[key] = 0
for item in scope_names:
item = item.replace(key, f"{key}_{memo.get(key)}")
result.append(item)
memo[key] += 1
memo[key] = memo.setdefault(key, -1) + 1
result = [item.replace(key, f"{key}_{memo.get(key)}") for item in scope_names]
return result return result




@@ -188,14 +199,43 @@ def _retrieve_operators(module_path, module_dict):
Returns: Returns:
str: module_name, operators in module. str: module_name, operators in module.
""" """

def _whether_to_lift(sub_module):
"""Whether to lift a scope according to its depth."""
return max(*[len(m.split("/")) for m in sub_module]) > 2

def _lift(sub_module):
    """
    Lift nodes upper.

    Flattens overly deep scope names: "ModuleX/ModuleY_k/Op" becomes
    "ModuleY_n/Op" and "ModuleX/Op" becomes "Op" when lifting is needed.
    """
    # `added_module` tracks the next numbering suffix per module name;
    # it is the same dict passed to `_scope_name_deduplication`.
    nonlocal added_module
    lifted_submodule = []
    # Index of the last inner module seen; a different index means a new
    # module instance, which gets a fresh numbering suffix.
    continuity_idx = -1
    lift_needed = _whether_to_lift(sub_module)
    for m in sub_module:
        scopes = m.split("/")
        if lift_needed and len(scopes) == 3:
            # If the scope depth is 3, like ModuleX/ModuleY/Gemm,
            # then we lift ModuleY to top level.
            md_name, md_idx = scopes[-2].split("_")
            if continuity_idx != int(md_idx):
                # New instance of this inner module: advance its suffix.
                continuity_idx = int(md_idx)
                added_module[md_name] = added_module.setdefault(md_name, -1) + 1
            lifted_submodule.append(f"{md_name}_{added_module.setdefault(md_name, 0)}/{scopes[-1]}")
            continue
        if lift_needed and len(scopes) == 2:
            # If the module is required to be lifted, then lift leaf node to parent.
            lifted_submodule.append(scopes[-1])
            continue
        # If lift is not required, then add it directly.
        lifted_submodule.append(m)
    return lifted_submodule

added_module = dict() added_module = dict()
node_in_pattern = module_path.pattern.ptn_items node_in_pattern = module_path.pattern.ptn_items
node_list = [] node_list = []
for node in node_in_pattern: for node in node_in_pattern:
if module_dict.get(node): if module_dict.get(node):
node_list += _scope_name_deduplication(node,
module_dict[node],
added_module)
sub_scope = _scope_name_deduplication(node, module_dict[node], added_module)
node_list += _lift(sub_scope)
else: else:
node_list.append(node) node_list.append(node)
val = [f"{module_path.pattern.module_name}/{node}" for node in node_list] val = [f"{module_path.pattern.module_name}/{node}" for node in node_list]
@@ -231,7 +271,7 @@ def flatten_graph(graph):
Returns: Returns:
list[str], corresponding scope name. list[str], corresponding scope name.
""" """
return [f"Model/{node.op_type}" for _, node in graph.node_collection.items()]
return [f"Model/{node.op_type}" for node in graph.node_collection.values()]




def validate_topo_order_succession(): def validate_topo_order_succession():
@@ -280,9 +320,6 @@ def generate_scope_name(data_loader):
""" """
init_dag = _build_connection(data_loader) init_dag = _build_connection(data_loader)
try: try:
if not validate_topo_order_succession():
raise ValueError("Topological order is not successive.")

result = _sub_graph_matching(init_dag, beam_width=5, sub_graph_size=6) result = _sub_graph_matching(init_dag, beam_width=5, sub_graph_size=6)
topo_order_with_scope_name_list = _retrieve_scope_name(result) if result else flatten_graph(init_dag) topo_order_with_scope_name_list = _retrieve_scope_name(result) if result else flatten_graph(init_dag)


@@ -294,4 +331,5 @@ def generate_scope_name(data_loader):


except (ValueError, IndexError, AttributeError, KeyError) as _: except (ValueError, IndexError, AttributeError, KeyError) as _:
topo_order_with_scope_name_list = flatten_graph(init_dag) topo_order_with_scope_name_list = flatten_graph(init_dag)

return topo_order_with_scope_name_list return topo_order_with_scope_name_list

Loading…
Cancel
Save