Browse Source

!803 Disable multi-output pattern and fix bugs in hash func

Merge pull request !803 from 刘崇鸣/disable_multi_opt
tags/v1.1.0
mindspore-ci-bot Gitee 5 years ago
parent
commit
6bfb5b3e6d
3 changed files with 17 additions and 10 deletions
  1. +4
    -3
      mindinsight/mindconverter/graph_based_converter/hierarchical_tree/hierarchical_tree.py
  2. +2
    -1
      mindinsight/mindconverter/graph_based_converter/sub_graph_searcher/common.py
  3. +11
    -6
      mindinsight/mindconverter/graph_based_converter/sub_graph_searcher/search_path.py

+ 4
- 3
mindinsight/mindconverter/graph_based_converter/hierarchical_tree/hierarchical_tree.py View File

@@ -619,11 +619,12 @@ class HierarchicalTree(Tree):
cur_nd = self.get_node(p_nd[0])
return ", ".join(self._find_all_previous_opt_var_(cur_nd, pre_nd))

def hash_key(self, node):
def hash_key(self, node, depth: int = 0):
"""
Generate hash key for each node.

Args:
depth (int): Recursion depth.
node (Node): Node.

Returns:
@@ -633,12 +634,12 @@ class HierarchicalTree(Tree):
for s in node.successors(self.tree_identifier):
cur_nd = self.get_node(s)
if cur_nd.data.hash_key:
scsr_topo_order.append(cur_nd.data.hash_key)
scsr_topo_order.append(f"{cur_nd.data.hash_key}[{depth}]")
continue
if cur_nd.data.node_type in {NodeType.MODULE.value,
NodeType.FUNC.value,
NodeType.CLASS.value}:
scsr_topo_order.append(self.hash_key(cur_nd))
scsr_topo_order.append(self.hash_key(cur_nd, depth + 1))
continue
unique_key = "->".join(scsr_topo_order)
node.data.hash_key = unique_key


+ 2
- 1
mindinsight/mindconverter/graph_based_converter/sub_graph_searcher/common.py View File

@@ -20,6 +20,8 @@ from typing import List

from mindinsight.mindconverter.graph_based_converter.third_party_graph.onnx_utils import BaseNode

MAX_OUT_DEGREE = 1


class CmpRelation:
"""Define cmp relation between `x` and `y`."""
@@ -123,7 +125,6 @@ class AlgorithmContext:

context = AlgorithmContext()


__all__ = ["context",
"gen_hash_key",
"DagGraph"]

+ 11
- 6
mindinsight/mindconverter/graph_based_converter/sub_graph_searcher/search_path.py View File

@@ -17,7 +17,7 @@ import copy
import uuid
from typing import Dict, List, Callable, Union
from collections import OrderedDict
from .common import context, gen_hash_key, DagGraph
from .common import context, gen_hash_key, DagGraph, MAX_OUT_DEGREE
from ..third_party_graph.onnx_utils import OnnxNode, BaseNode

scope_name_mapping = {}
@@ -198,18 +198,19 @@ def _get_pattern_degree(sequence: Union[OrderedDict, dict, list],
tuple[int, int], in degree and out degree.
"""
in_node = set()
out_node = set()
node_in_seq = set()
items = sequence if isinstance(sequence, list) else sequence.keys()
for _, item in enumerate(items):
for item in items:
node_in_seq.add(item.name if not isinstance(item, str) else item)
out_degree = 0
for item in items:
item = item.name if not isinstance(item, str) else item
for ipt in dag.precursor_table[item]:
in_node.add(ipt)
for opt in dag.successor_table[item]:
out_node.add(opt)
node_in_seq.add(item)
if opt not in node_in_seq:
out_degree += 1
in_degree = len(in_node - node_in_seq)
out_degree = len(out_node - node_in_seq)
return in_degree, out_degree


@@ -335,6 +336,10 @@ def generate_pattern(topo_order: List[BaseNode], dag: DagGraph,
dag=dag)

in_degree, out_degree = _get_pattern_degree(found_sequence, dag)
if out_degree > MAX_OUT_DEGREE:
cur_idx += 1
continue

ptn = '->'.join(found_sequence.values())
ptn_key = f"{ptn}[{in_degree}, {out_degree}]"
if ptn_key not in pattern:


Loading…
Cancel
Save