|
|
|
@@ -17,7 +17,7 @@ import copy |
|
|
|
import uuid |
|
|
|
from typing import Dict, List, Callable, Union |
|
|
|
from collections import OrderedDict |
|
|
|
from .common import context, gen_hash_key, DagGraph |
|
|
|
from .common import context, gen_hash_key, DagGraph, MAX_OUT_DEGREE |
|
|
|
from ..third_party_graph.onnx_utils import OnnxNode, BaseNode |
|
|
|
|
|
|
|
scope_name_mapping = {} |
|
|
|
@@ -198,18 +198,19 @@ def _get_pattern_degree(sequence: Union[OrderedDict, dict, list], |
|
|
|
tuple[int, int], in degree and out degree. |
|
|
|
""" |
|
|
|
in_node = set() |
|
|
|
out_node = set() |
|
|
|
node_in_seq = set() |
|
|
|
items = sequence if isinstance(sequence, list) else sequence.keys() |
|
|
|
for _, item in enumerate(items): |
|
|
|
for item in items: |
|
|
|
node_in_seq.add(item.name if not isinstance(item, str) else item) |
|
|
|
out_degree = 0 |
|
|
|
for item in items: |
|
|
|
item = item.name if not isinstance(item, str) else item |
|
|
|
for ipt in dag.precursor_table[item]: |
|
|
|
in_node.add(ipt) |
|
|
|
for opt in dag.successor_table[item]: |
|
|
|
out_node.add(opt) |
|
|
|
node_in_seq.add(item) |
|
|
|
if opt not in node_in_seq: |
|
|
|
out_degree += 1 |
|
|
|
in_degree = len(in_node - node_in_seq) |
|
|
|
out_degree = len(out_node - node_in_seq) |
|
|
|
return in_degree, out_degree |
|
|
|
|
|
|
|
|
|
|
|
@@ -335,6 +336,10 @@ def generate_pattern(topo_order: List[BaseNode], dag: DagGraph, |
|
|
|
dag=dag) |
|
|
|
|
|
|
|
in_degree, out_degree = _get_pattern_degree(found_sequence, dag) |
|
|
|
if out_degree > MAX_OUT_DEGREE: |
|
|
|
cur_idx += 1 |
|
|
|
continue |
|
|
|
|
|
|
|
ptn = '->'.join(found_sequence.values()) |
|
|
|
ptn_key = f"{ptn}[{in_degree}, {out_degree}]" |
|
|
|
if ptn_key not in pattern: |
|
|
|
|