
Upgrade searcher to adapt to Transformer models.

tags/v1.2.0-rc1
liuchongming committed 4 years ago
commit 42c73686b4
4 changed files with 154 additions and 12 deletions
  1. +28   -0   mindinsight/mindconverter/graph_based_converter/sub_graph_searcher/built_in_pattern.py
  2. +118  -11  mindinsight/mindconverter/graph_based_converter/sub_graph_searcher/searcher.py
  3. +0    -1   mindinsight/mindconverter/graph_based_converter/third_party_graph/base.py
  4. +8    -0   mindinsight/mindconverter/graph_based_converter/third_party_graph/optimizer.py

+28 -0   mindinsight/mindconverter/graph_based_converter/sub_graph_searcher/built_in_pattern.py

@@ -131,6 +131,26 @@ def _multi_head_attention():
    ]


@register_pattern("Multi-Head-Attention-1", 2, 1)
@register_module_name("MultiHeadAttn", 2, 1)
def _multi_head_attention_1():
    return [
        "MatMul", "Add", "MatMul", "Add", "MatMul", "Add", "Reshape",
        "Transpose", "Reshape", "Reshape", "Transpose", "Transpose", "MatMul", "Div", "Add",
        "Softmax", "MatMul", "Transpose", "Reshape", "MatMul", "Add"
    ]


@register_pattern("Multi-Head-Attention-with-Einsum", 2, 1)
@register_module_name("MultiHeadAttn", 2, 1)
def _multi_head_attention_with_einsum():
    return [
        "MatMul", "Add", "MatMul", "Add", "MatMul", "Add", "Reshape",
        "Transpose", "Reshape", "Reshape", "Transpose", "Transpose", "MatMul", "Div", "Add", "Softmax",
        "MatMul", "Transpose", "Einsum", "Add"
    ]


@register_pattern("Layer-Normalization", 1, 1)
@register_module_name("LayerNorm", 1, 1)
def _layer_norm():
@@ -161,3 +181,11 @@ def _linear():
    return [
        "MatMul", "Add"
    ]


@register_pattern("New-GeLU", 1, 1)
@register_module_name("NewGeLU", 1, 1)
def _new_gelu():
    return [
        "Mul", "Pow", "Mul", "Add", "Mul", "Tanh", "Add", "Mul"
    ]
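
For reference, this operator sequence corresponds to the tanh-based ("new") GELU approximation. A standalone numerical sanity check (illustration only, not converter code):

import math

def new_gelu(x):
    """Tanh-based GELU approximation: 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x**3)))."""
    return 0.5 * x * (1.0 + math.tanh(math.sqrt(2.0 / math.pi) * (x + 0.044715 * x ** 3)))

print(round(new_gelu(1.0), 4))  # 0.8412, close to the exact erf-based GELU(1.0) ≈ 0.8413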

+118 -11   mindinsight/mindconverter/graph_based_converter/sub_graph_searcher/searcher.py

@@ -13,9 +13,11 @@
# limitations under the License.
# ==============================================================================
"""Definition of search entry."""
-from queue import PriorityQueue
+from queue import PriorityQueue, Queue
from typing import Dict, List

+from mindinsight.mindconverter.graph_based_converter.sub_graph_searcher.pattern_fuzzy_matching import \
+    pattern_fuzzy_matching
from mindinsight.mindconverter.graph_based_converter.sub_graph_searcher.common import context, DagGraph, gen_hash_key, \
    ACCEPTABLE_RESULT_COUNT, MAX_ITERATION_DEPTH_OF_SINGLE_IPT
from mindinsight.mindconverter.graph_based_converter.sub_graph_searcher.common import MINI_FREQUENCY, \
@@ -196,6 +198,109 @@ def _scope_name_deduplication(key, scope_names, memo) -> list:
    return result


def _is_attn_layer(split_module):
    """
    Check whether the sub-module is an attention layer.

    An attention layer is defined as: attn-add-norm-fc-gelu-fc-add-norm.

    Args:
        split_module (list[list[str]]): Operations list in module.

    Returns:
        Union[str, None], found module name, or None if nothing matches.
    """

    def _matched(modules):
        """If the similarity score between the sub-module and an attention pattern is greater than 0.95, take it."""
        threshold = 0.95
        leaf_node = [m[-1] for m in modules]
        attn_layer_ptn_with_gelu = [
            "MatMul", "Add", "MatMul", "Add", "Reshape", "MatMul", "Add", "Reshape", "Transpose", "Reshape",
            "Transpose", "Transpose", "MatMul", "Div", "Add", "Softmax", "MatMul", "Transpose", "Reshape", "MatMul",
            "Add", "Add", "ReduceMean", "Sub", "Cast", "Pow", "ReduceMean", "Add", "Sqrt", "Div", "Mul", "Add",
            "MatMul", "Add", "Div", "Erf", "Add", "Mul", "Mul", "MatMul", "Add", "Add", "ReduceMean", "Sub", "Cast",
            "Pow", "ReduceMean", "Add", "Sqrt", "Div", "Mul", "Add"
        ]
        attn_layer_ptn_with_new_gelu = [
            "MatMul", "Add", "MatMul", "Add", "MatMul", "Add", "Reshape", "Transpose", "Reshape", "Reshape",
            "Transpose", "Transpose", "MatMul", "Div", "Add", "Softmax", "MatMul", "Transpose", "Einsum", "Add", "Add",
            "ReduceMean", "Sub", "Cast", "Pow", "ReduceMean", "Add", "Sqrt", "Div", "Mul", "Add", "MatMul", "Add",
            "Mul", "Pow", "Mul", "Add", "Mul", "Tanh", "Add", "Mul", "MatMul", "Add", "Add", "ReduceMean", "Sub",
            "Cast", "Pow", "ReduceMean", "Add", "Sqrt", "Div", "Mul", "Add"
        ]
        matched = max(pattern_fuzzy_matching(leaf_node, attn_layer_ptn_with_gelu)[1],
                      pattern_fuzzy_matching(leaf_node, attn_layer_ptn_with_new_gelu)[1]) > threshold
        return matched

    candidates = Queue()
    candidates.put(split_module, block=False)
    while not candidates.empty():
        candidate = candidates.get(block=False)
        if _matched(candidate):
            return candidate[0][0].split("_")[0]
        cur_scope = candidate[0][1]
        split_sub_module = []
        for item in candidate:
            # It's not necessary to scan a module whose depth is 2.
            if len(item) == 2:
                continue
            if item[1] != cur_scope:
                cur_scope = item[1]
                if split_sub_module:
                    candidates.put(split_sub_module[:], block=False)
                    split_sub_module.clear()
            split_sub_module.append(item[1:])
        if split_sub_module:
            candidates.put(split_sub_module[:], block=False)
    return None
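
Note: pattern_fuzzy_matching is used above as a scorer whose second return value is a similarity in [0, 1]. A rough stand-in built on difflib conveys the idea; the converter's actual algorithm and signature may differ:

import difflib

def fuzzy_score(op_seq, pattern):
    """Similarity in [0, 1] between two operator-type sequences."""
    return difflib.SequenceMatcher(None, op_seq, pattern).ratio()

print(fuzzy_score(["MatMul", "Add", "Softmax"], ["MatMul", "Add", "Softmax"]))  # 1.0
print(fuzzy_score(["MatMul", "Add"], ["MatMul", "Add", "Softmax"]) > 0.95)      # False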


def _lift_each_module(sub_module):
    """Lift each module in the sub-module."""
    lifted_module = []
    split_module = []
    cur_scope = sub_module[0].split("/")[0]
    segmented_pos = 0

    def _lift(modules):
        nonlocal lifted_module, split_module
        exceed_max_depth = max(len(m.split("/")) for m in modules) > 2
        if not exceed_max_depth:
            for _ in range(len(split_module)):
                lifted_module.append((False, 0))
            return
        # attn_module_name is normalized without the "_idx" suffix; it holds only the raw module name.
        attn_module_name = _is_attn_layer(split_module)
        for s_md in split_module:
            if attn_module_name:
                md_name = [md for md in s_md if attn_module_name in md]
                if md_name:
                    md_name = md_name[0]
                    attn_idx = s_md.index(md_name)
                    if attn_idx > 0:
                        lifted_module.append((True, attn_idx))
                        continue
                lifted_module.append((False, 0))
                continue
            lifted_module.append((True, 0))

    for i, m in enumerate(sub_module):
        split_md = m.split("/")
        # Find one module.
        if cur_scope != split_md[0]:
            _lift(sub_module[segmented_pos:i])
            # Clean up.
            cur_scope = split_md[0]
            segmented_pos = i
            split_module.clear()
        split_module.append(split_md)

    # Do lift on the last module.
    _lift(sub_module[segmented_pos:])
    return lifted_module
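
The loop above segments consecutive scope names whenever the top-level scope changes, then decides lifting per segment. A minimal standalone sketch of that grouping step (scope names invented):

def group_by_top_scope(scope_names):
    """Split a flat list of scope names into runs sharing the same top-level scope."""
    groups, start = [], 0
    for i in range(1, len(scope_names)):
        if scope_names[i].split("/")[0] != scope_names[start].split("/")[0]:
            groups.append(scope_names[start:i])
            start = i
    groups.append(scope_names[start:])
    return groups

names = ["Module3_0/MatMul", "Module3_0/Add", "Module5_0/Module3_1/Gemm"]
print(group_by_top_scope(names))
# [['Module3_0/MatMul', 'Module3_0/Add'], ['Module5_0/Module3_1/Gemm']]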


def _retrieve_operators(module_path, module_dict):
"""
Retrieve operators from path.
@@ -208,26 +313,29 @@ def _retrieve_operators(module_path, module_dict):
        str: module_name, operators in module.
    """

-    def _whether_to_lift(sub_module):
-        """Whether to lift a scope according to its depth."""
-        return max(*[len(m.split("/")) for m in sub_module]) > 2
-
    def _lift(sub_module):
        """Lift nodes upper."""
        nonlocal added_module
        lifted_submodule = []
        record = dict()
-        lift_needed = _whether_to_lift(sub_module)
-        for m in sub_module:
+        # DO NOT lift on attn-add-norm-fc (with GeLU)-fc-add-norm;
+        # it is a fixed pattern in Transformer models.
+        lift_on_each_module = _lift_each_module(sub_module)
+        for i, m in enumerate(sub_module):
+            lift_needed, lift_from = lift_on_each_module[i]
            scopes = m.split("/")
-            if lift_needed and len(scopes) == 3:
+            if lift_needed and len(scopes) >= 3:
                # If the scope depth is at least 3, e.g. ModuleX/ModuleY/Gemm,
                # then we lift ModuleY to the top level.
-                md_name, md_idx = scopes[-2].split("_")
+                md_name, md_idx = scopes[-2 if lift_from == 0 else lift_from].split("_")
                if record.get(md_name, -1) != md_idx:
                    record[md_name] = md_idx
                    added_module[md_name] = added_module.setdefault(md_name, -1) + 1
-                lifted_submodule.append(f"{md_name}_{added_module.setdefault(md_name, 0)}/{scopes[-1]}")
+                if lift_from != 0:
+                    lifted_md = "/".join([f"{md_name}_{added_module.setdefault(md_name, 0)}"] + scopes[lift_from + 1:])
+                else:
+                    lifted_md = f"{md_name}_{added_module.setdefault(md_name, 0)}/{scopes[-1]}"
+                lifted_submodule.append(lifted_md)
                continue
            if lift_needed and len(scopes) == 2:
                # If the module is required to be lifted, then lift the leaf node to its parent.
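
To make the lifting concrete: in the lift_from == 0 case, a depth-3 scope is re-rooted at its second level and given a fresh index. A simplified standalone sketch (names and counter semantics illustrative):

def lift_scope(scope, added_module):
    """Lift 'ModuleX_i/ModuleY_j/Op' to 'ModuleY_k/Op', where k comes from a per-name counter."""
    parts = scope.split("/")
    if len(parts) < 3:
        return scope  # nothing to lift
    md_name = parts[-2].split("_")[0]
    return f"{md_name}_{added_module.setdefault(md_name, 0)}/{parts[-1]}"

print(lift_scope("ModuleX_0/ModuleY_3/Gemm", {}))  # ModuleY_0/Gemm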
@@ -263,7 +371,6 @@ def _build_connection(loader):
        context.precursor_table[node_name] = list(node.get_precursor_dict().keys())
        context.successor_table[node_name] = list(node.get_successor_dict().keys())
        context.outputs_table[node_name] = node.output_name_list

    # Record the model's input count; it is used to control the search algorithm.
    context.has_multi_inputs = len(loader.input_nodes) > 1
    dag = DagGraph(nodes=context.node_collection.copy(),


+0 -1   mindinsight/mindconverter/graph_based_converter/third_party_graph/base.py

@@ -192,7 +192,6 @@ class Graph(BaseGraph, abc.ABC):
"""
for name, node in self._nodes_collection.items():
if node.in_degree == 0:
# NOTICE: what's usage of `scope`?
self._input_nodes.append(name)

if node.out_degree == 0:


+8 -0   mindinsight/mindconverter/graph_based_converter/third_party_graph/optimizer.py

@@ -107,8 +107,16 @@ class OnnxSimplify:
        output_nodes_name = list()
        for node in self._constant_nodes:
            output_nodes_name.extend(node.output)
        original_outputs = [nd.name for nd in self._onnx_model.graph.output]
        self._outputs_infer = fetch_output_from_onnx_model(self._onnx_model,
                                                           feed_dict, output_nodes_name)
        # Inference may append the constant-fold targets to the graph outputs;
        # drop every output that was not in the original graph.
        idx = 0
        while idx < len(self._onnx_model.graph.output):
            cur_opt = self._onnx_model.graph.output[idx]
            if cur_opt.name not in original_outputs:
                self._onnx_model.graph.output.remove(cur_opt)
                continue
            idx += 1

    def _replace_constant_nodes(self):
        """Replace constant nodes with nodes of op_type 'Constant'."""

