
Add patterns and remove the max-version limit of onnxoptimizer

tags/v1.2.0-rc1
liuchongming committed 4 years ago
commit 132bc150d2
3 changed files with 29 additions and 5 deletions:
  1. mindinsight/mindconverter/graph_based_converter/constant.py (+1, -1)
  2. mindinsight/mindconverter/graph_based_converter/framework.py (+2, -4)
  3. mindinsight/mindconverter/graph_based_converter/sub_graph_searcher/built_in_pattern.py (+26, -0)

mindinsight/mindconverter/graph_based_converter/constant.py (+1, -1)

@@ -139,7 +139,7 @@ THIRD_PART_VERSION = {
     "torch": (TORCH_MIN_VER,),
     "onnx": (ONNX_MIN_VER,),
     "onnxruntime": (ONNXRUNTIME_MIN_VER,),
-    "onnxoptimizer": (ONNXOPTIMIZER_MIN_VER, ONNXOPTIMIZER_MAX_VER),
+    "onnxoptimizer": (ONNXOPTIMIZER_MIN_VER,),
     "tf2onnx": (TF2ONNX_MIN_VER,)
 }
 

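With ONNXOPTIMIZER_MAX_VER dropped, the onnxoptimizer entry carries only a lower bound, so any release at or above ONNXOPTIMIZER_MIN_VER is accepted. A minimal sketch of a version gate with an optional upper bound, in the spirit of the lib_version_satisfied helper used in framework.py below; the dotted-integer parsing is an illustrative assumption, not the project's implementation:

def lib_version_satisfied(current, min_ver, max_ver=None):
    """Sketch only: True if current >= min_ver and, when max_ver is given, current <= max_ver."""
    def parse(ver):
        # Assumption: plain dotted numeric versions such as "0.46.0".
        return tuple(int(part) for part in ver.split(".")[:3])

    if parse(current) < parse(min_ver):
        return False
    if max_ver is not None and parse(current) > parse(max_ver):
        return False
    return True

# After this commit the onnxoptimizer check is called without max_ver, i.e. no upper limit.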


mindinsight/mindconverter/graph_based_converter/framework.py (+2, -4)

@@ -26,7 +26,7 @@ from mindinsight.mindconverter.graph_based_converter.common.global_context impor
 from mindinsight.mindconverter.graph_based_converter.common.utils import lib_version_satisfied, onnx_satisfied, \
     save_code_file_and_report, get_framework_type, check_dependency_integrity, get_third_part_lib_validation_error_info
 from mindinsight.mindconverter.graph_based_converter.constant import FrameworkType, \
-    ONNX_MIN_VER, TF2ONNX_MIN_VER, ONNXRUNTIME_MIN_VER, ONNXOPTIMIZER_MIN_VER, ONNXOPTIMIZER_MAX_VER, TORCH_MIN_VER
+    ONNX_MIN_VER, TF2ONNX_MIN_VER, ONNXRUNTIME_MIN_VER, ONNXOPTIMIZER_MIN_VER, TORCH_MIN_VER
 from mindinsight.mindconverter.graph_based_converter.generator import batch_add_nodes
 from mindinsight.mindconverter.graph_based_converter.mapper import ONNXToMindSporeMapper
 from mindinsight.mindconverter.common.log import logger as log, logger_console as log_console
@@ -35,8 +35,6 @@ from mindinsight.mindconverter.common.exceptions import GraphInitError, TreeCrea
     BadParamError
 from mindinsight.mindconverter.graph_based_converter.third_party_graph import GraphFactory
 
-
-
 check_common_dependency_integrity = partial(check_dependency_integrity,
                                             "onnx", "onnxruntime", "onnxoptimizer")
 
@@ -51,7 +49,7 @@ def onnx_lib_version_satisfied():
                     ONNXRUNTIME_MIN_VER, ort.__version__)
 
     if not lib_version_satisfied(getattr(onnx, "__version__"), ONNX_MIN_VER) \
-            or not lib_version_satisfied(getattr(optimizer, "version"), ONNXOPTIMIZER_MIN_VER, ONNXOPTIMIZER_MAX_VER):
+            or not lib_version_satisfied(getattr(optimizer, "version"), ONNXOPTIMIZER_MIN_VER):
         return False
     return True
 

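check_common_dependency_integrity pre-binds the three ONNX-related package names with functools.partial, so callers get a zero-argument probe. A self-contained sketch of that shape; the importlib-based body is an assumption for illustration, not MindConverter's actual check_dependency_integrity:

import importlib
from functools import partial


def check_dependency_integrity(*packages):
    """Return True only if every named package can be imported (illustrative body)."""
    try:
        for name in packages:
            importlib.import_module(name)
        return True
    except ImportError:
        return False


# Mirrors the partial() binding in framework.py: the ONNX tool chain is checked as a unit.
check_common_dependency_integrity = partial(check_dependency_integrity,
                                            "onnx", "onnxruntime", "onnxoptimizer")

if __name__ == "__main__":
    print(check_common_dependency_integrity())  # False if any of the three packages is missing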


mindinsight/mindconverter/graph_based_converter/sub_graph_searcher/built_in_pattern.py (+26, -0)

@@ -175,6 +175,16 @@ def _multi_head_attention_with_einsum():
     ]
 
 
+@register_pattern("Multi-Head-Attention-TF", 2, 1)
+@register_module_name("MultiHeadAttn", 2, 1)
+def _multi_head_attention_tf():
+    return [
+        "MatMul", "Reshape", "Transpose", "MatMul", "Reshape", "Transpose",
+        "MatMul", "Reshape", "Transpose", "MatMul",
+        "Mul", "Add", "Softmax", "MatMul", "Transpose", "Reshape", "MatMul"
+    ]
+
+
 @register_pattern("Layer-Normalization", 1, 1)
 @register_module_name("LayerNorm", 1, 1)
 def _layer_norm():
@@ -183,6 +193,22 @@ def _layer_norm():
     ]
 
 
+@register_pattern("Layer-Normalization-TF", 1, 1)
+@register_module_name("LayerNorm", 1, 1)
+def _layer_norm_tf():
+    return [
+        "ReduceMean", "Sub", "Mul", "ReduceMean", "Add", "Sqrt", "Reciprocal", "Mul", "Mul", "Neg", "Mul", "Add"
+    ]
+
+
+@register_pattern("Feed-Forward-Network-TF", 1, 1)
+@register_module_name("FFN", 1, 1)
+def _ffn_tf():
+    return [
+        "MatMul", "Pow", "Mul", "Add", "Mul", "Tanh", "Add", "Mul", "Mul", "MatMul"
+    ]
+
+
 @register_pattern("Layer-Normalization-with-cast", 1, 1)
 @register_module_name("LayerNorm", 1, 1)
 def _layer_norm_with_cast():

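The new entries extend the sub-graph searcher's built-in pattern table for TensorFlow-exported ONNX graphs: each decorated function returns the ordered list of ONNX op types that make up a fusible block (multi-head attention, layer normalization, and a feed-forward network whose Pow/Mul/Tanh chain is consistent with a tanh-approximated GELU between two MatMuls). A rough sketch of how such a decorator-driven registry can be wired up; the dictionary layout and field names below are assumptions, not the actual register_pattern/register_module_name internals:

# Assumed registry structure, for illustration only.
BUILT_IN_PATTERN = {}


def register_pattern(pattern_name, in_degree, out_degree):
    """Record the op sequence returned by the decorated function under pattern_name."""
    def decorator(func):
        BUILT_IN_PATTERN[pattern_name] = {
            "ops": func(),
            "in_degree": in_degree,
            "out_degree": out_degree,
        }
        return func
    return decorator


@register_pattern("Layer-Normalization-TF", 1, 1)
def _layer_norm_tf():
    return [
        "ReduceMean", "Sub", "Mul", "ReduceMean", "Add", "Sqrt",
        "Reciprocal", "Mul", "Mul", "Neg", "Mul", "Add"
    ]


print(BUILT_IN_PATTERN["Layer-Normalization-TF"]["ops"][:3])  # ['ReduceMean', 'Sub', 'Mul']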
