diff --git a/mindinsight/mindconverter/graph_based_converter/constant.py b/mindinsight/mindconverter/graph_based_converter/constant.py index 28d52a5b..8d78b382 100644 --- a/mindinsight/mindconverter/graph_based_converter/constant.py +++ b/mindinsight/mindconverter/graph_based_converter/constant.py @@ -139,7 +139,7 @@ THIRD_PART_VERSION = { "torch": (TORCH_MIN_VER,), "onnx": (ONNX_MIN_VER,), "onnxruntime": (ONNXRUNTIME_MIN_VER,), - "onnxoptimizer": (ONNXOPTIMIZER_MIN_VER, ONNXOPTIMIZER_MAX_VER), + "onnxoptimizer": (ONNXOPTIMIZER_MIN_VER,), "tf2onnx": (TF2ONNX_MIN_VER,) } diff --git a/mindinsight/mindconverter/graph_based_converter/framework.py b/mindinsight/mindconverter/graph_based_converter/framework.py index da4f1e9d..15c0ac53 100644 --- a/mindinsight/mindconverter/graph_based_converter/framework.py +++ b/mindinsight/mindconverter/graph_based_converter/framework.py @@ -26,7 +26,7 @@ from mindinsight.mindconverter.graph_based_converter.common.global_context impor from mindinsight.mindconverter.graph_based_converter.common.utils import lib_version_satisfied, onnx_satisfied, \ save_code_file_and_report, get_framework_type, check_dependency_integrity, get_third_part_lib_validation_error_info from mindinsight.mindconverter.graph_based_converter.constant import FrameworkType, \ - ONNX_MIN_VER, TF2ONNX_MIN_VER, ONNXRUNTIME_MIN_VER, ONNXOPTIMIZER_MIN_VER, ONNXOPTIMIZER_MAX_VER, TORCH_MIN_VER + ONNX_MIN_VER, TF2ONNX_MIN_VER, ONNXRUNTIME_MIN_VER, ONNXOPTIMIZER_MIN_VER, TORCH_MIN_VER from mindinsight.mindconverter.graph_based_converter.generator import batch_add_nodes from mindinsight.mindconverter.graph_based_converter.mapper import ONNXToMindSporeMapper from mindinsight.mindconverter.common.log import logger as log, logger_console as log_console @@ -35,8 +35,6 @@ from mindinsight.mindconverter.common.exceptions import GraphInitError, TreeCrea BadParamError from mindinsight.mindconverter.graph_based_converter.third_party_graph import GraphFactory - - check_common_dependency_integrity = partial(check_dependency_integrity, "onnx", "onnxruntime", "onnxoptimizer") @@ -51,7 +49,7 @@ def onnx_lib_version_satisfied(): ONNXRUNTIME_MIN_VER, ort.__version__) if not lib_version_satisfied(getattr(onnx, "__version__"), ONNX_MIN_VER) \ - or not lib_version_satisfied(getattr(optimizer, "version"), ONNXOPTIMIZER_MIN_VER, ONNXOPTIMIZER_MAX_VER): + or not lib_version_satisfied(getattr(optimizer, "version"), ONNXOPTIMIZER_MIN_VER): return False return True diff --git a/mindinsight/mindconverter/graph_based_converter/sub_graph_searcher/built_in_pattern.py b/mindinsight/mindconverter/graph_based_converter/sub_graph_searcher/built_in_pattern.py index 5389e083..52be8e1e 100644 --- a/mindinsight/mindconverter/graph_based_converter/sub_graph_searcher/built_in_pattern.py +++ b/mindinsight/mindconverter/graph_based_converter/sub_graph_searcher/built_in_pattern.py @@ -175,6 +175,16 @@ def _multi_head_attention_with_einsum(): ] +@register_pattern("Multi-Head-Attention-TF", 2, 1) +@register_module_name("MultiHeadAttn", 2, 1) +def _multi_head_attention_tf(): + return [ + "MatMul", "Reshape", "Transpose", "MatMul", "Reshape", "Transpose", + "MatMul", "Reshape", "Transpose", "MatMul", + "Mul", "Add", "Softmax", "MatMul", "Transpose", "Reshape", "MatMul" + ] + + @register_pattern("Layer-Normalization", 1, 1) @register_module_name("LayerNorm", 1, 1) def _layer_norm(): @@ -183,6 +193,22 @@ def _layer_norm(): ] +@register_pattern("Layer-Normalization-TF", 1, 1) +@register_module_name("LayerNorm", 1, 1) +def _layer_norm_tf(): + return [ + "ReduceMean", "Sub", "Mul", "ReduceMean", "Add", "Sqrt", "Reciprocal", "Mul", "Mul", "Neg", "Mul", "Add" + ] + + +@register_pattern("Feed-Forward-Network-TF", 1, 1) +@register_module_name("FFN", 1, 1) +def _ffn_tf(): + return [ + "MatMul", "Pow", "Mul", "Add", "Mul", "Tanh", "Add", "Mul", "Mul", "MatMul" + ] + + @register_pattern("Layer-Normalization-with-cast", 1, 1) @register_module_name("LayerNorm", 1, 1) def _layer_norm_with_cast():