|
|
|
@@ -14,15 +14,14 @@
 # ==============================================================================
 """Graph based scripts converter workflow."""
 import os
-import argparse
 import sys
 from importlib import import_module
 from importlib.util import find_spec
+from functools import partial

-import mindinsight
 from mindinsight.mindconverter.graph_based_converter.common.global_context import GlobalContext
 from mindinsight.mindconverter.graph_based_converter.common.utils import lib_version_satisfied, onnx_satisfied, \
-    save_code_file_and_report, get_framework_type, get_third_part_lib_validation_error_info
+    save_code_file_and_report, get_framework_type, check_dependency_integrity, get_third_part_lib_validation_error_info
 from mindinsight.mindconverter.graph_based_converter.constant import FrameworkType, \
     ONNX_MIN_VER, TF2ONNX_MIN_VER, ONNXRUNTIME_MIN_VER, ONNXOPTIMIZER_MIN_VER, ONNXOPTIMIZER_MAX_VER
 from mindinsight.mindconverter.graph_based_converter.generator import batch_add_nodes
@@ -32,25 +31,8 @@ from mindinsight.mindconverter.common.exceptions import GraphInitError, TreeCrea
     BaseConverterError, UnknownModelError, GeneratorError, TfRuntimeError, RuntimeIntegrityError, ParamMissingError
 from mindinsight.mindconverter.graph_based_converter.third_party_graph import GraphFactory

 permissions = os.R_OK | os.W_OK | os.X_OK
 os.umask(permissions << 3 | permissions)

-parser = argparse.ArgumentParser(
-    prog="MindConverter",
-    description="Graph based MindConverter CLI entry point (version: {})".format(
-        mindinsight.__version__)
-)

-parser.add_argument("--graph", type=str, required=True,
-                    help="Third party framework's graph path.")
-parser.add_argument("--sample_shape", nargs='+', type=int, required=True,
-                    help="Input shape of the model.")
-parser.add_argument("--ckpt", type=str, required=False,
-                    help="Third party framework's checkpoint path.")
-parser.add_argument("--output", type=str, required=True,
-                    help="Generated scripts output folder path.")
-parser.add_argument("--report", type=str, required=False,
-                    help="Generated reports output folder path.")
+check_common_dependency_integrity = partial(check_dependency_integrity,
+                                            "onnx", "onnxruntime", "onnxoptimizer")


 def onnx_lib_version_satisfied():
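The check_common_dependency_integrity helper introduced above pre-binds the three ONNX-related package names with functools.partial, so each call site only has to name the framework-specific package on top of them. Below is a minimal sketch of that pre-binding with a hypothetical stand-in checker; the real check_dependency_integrity is imported from common.utils and may verify more than plain importability.

```python
# A minimal sketch, assuming the checker only verifies that each named package
# can be imported. _integrity_check_sketch is hypothetical; it only illustrates
# the functools.partial pre-binding used above.
from functools import partial
from importlib import import_module


def _integrity_check_sketch(*packages):
    """Return True only when every named package imports cleanly."""
    try:
        for package in packages:
            import_module(package)
    except ImportError:
        return False
    return True


check_common = partial(_integrity_check_sketch, "onnx", "onnxruntime", "onnxoptimizer")

check_common()              # pure ONNX path: only the common packages
check_common("torch")       # PyTorch path: common packages plus torch
check_common("tensorflow")  # TensorFlow path: common packages plus tensorflow
```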
|
|
|
@@ -65,6 +47,14 @@ def onnx_lib_version_satisfied():
     return True


+def _print_error(err):
+    """Print error to stdout and record it."""
+    log.error(err)
+    log_console.error("\n")
+    log_console.error(str(err))
+    log_console.error("\n")


 def torch_installation_validation(func):
     """
     Validate args of func.
@@ -76,27 +66,23 @@ def torch_installation_validation(func):
         type, inner function.
     """

-    def _f(graph_path: str, sample_shape: tuple,
-           input_nodes: str, output_nodes: str,
+    def _f(graph_path: str, sample_shape: tuple, input_nodes: str, output_nodes: str,
            output_folder: str, report_folder: str = None):
         # Check whether pytorch is installed.
         error_info = None
         if graph_path.endswith('.onnx'):
-            if not onnx_satisfied():
+            if not onnx_satisfied() or not check_common_dependency_integrity():
                 error_info = f"{get_third_part_lib_validation_error_info(['onnx', 'onnxruntime', 'onnxoptimizer'])} " \
                              f"are required when using graph based scripts converter."
         else:
-            if not find_spec("torch") or not onnx_satisfied():
+            if not find_spec("torch") or not onnx_satisfied() or not check_common_dependency_integrity("torch"):
                 error_info = f"PyTorch, " \
                              f"{get_third_part_lib_validation_error_info(['onnx', 'onnxruntime', 'onnxoptimizer'])} " \
                              f"are required when using graph based scripts converter, and PyTorch version must " \
                              f"be consistent with model generation runtime."

         if error_info:
-            error = RuntimeIntegrityError(error_info)
-            log.error(error)
-            log_console.error("\n")
-            log_console.error(str(error))
-            log_console.error("\n")
+            _print_error(RuntimeIntegrityError(error_info))
             sys.exit(0)

         if not onnx_lib_version_satisfied():
@@ -104,10 +90,7 @@ def torch_installation_validation(func):
                 f"{get_third_part_lib_validation_error_info(['onnx', 'onnxruntime', 'onnxoptimizer'])} "
                 f"are required when using graph based scripts converter."
             )
-            log.error(error)
-            log_console.error("\n")
-            log_console.error(str(error))
-            log_console.error("\n")
+            _print_error(error)
             sys.exit(0)

         func(graph_path=graph_path, sample_shape=sample_shape,
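Both torch_installation_validation and tf_installation_validation follow the same wrap-validate-forward pattern: the inner _f checks the runtime environment, prints a RuntimeIntegrityError and exits when something is missing, and otherwise forwards its keyword arguments to the wrapped converter. A condensed sketch of that pattern, with hypothetical names (installation_validation_sketch, environment_ok and convert_sketch are placeholders, not part of this module):

```python
import sys
from functools import wraps


def installation_validation_sketch(func):
    """Sketch of the wrap-validate-forward pattern used by the decorators in this file."""
    @wraps(func)
    def _f(graph_path: str, sample_shape: tuple, output_folder: str, report_folder: str = None):
        if not environment_ok():  # stand-in for the onnx/torch/tf2onnx checks above
            print("Required third party packages are missing.")
            sys.exit(0)
        func(graph_path=graph_path, sample_shape=sample_shape,
             output_folder=output_folder, report_folder=report_folder)
    return _f


def environment_ok():
    """Hypothetical placeholder for onnx_satisfied() and the integrity checks."""
    return True


@installation_validation_sketch
def convert_sketch(graph_path, sample_shape, output_folder, report_folder=None):
    print(f"converting {graph_path} with input shape {sample_shape}")


convert_sketch(graph_path="model.onnx", sample_shape=(1, 3, 224, 224),
               output_folder="./output", report_folder="./report")
```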
|
|
|
@@ -138,35 +121,28 @@ def tf_installation_validation(func):
         type, inner function.
     """

-    def _f(graph_path: str, sample_shape: tuple,
-           output_folder: str, report_folder: str = None,
+    def _f(graph_path: str, sample_shape: tuple, output_folder: str, report_folder: str = None,
            input_nodes: str = None, output_nodes: str = None):
+        not_integral_error = RuntimeIntegrityError(
+            f"TensorFlow, "
+            f"{get_third_part_lib_validation_error_info(['tf2onnx', 'onnx', 'onnxruntime', 'onnxoptimizer'])} "
+            f"are required when using graph based scripts converter for TensorFlow conversion."
+        )
         # Check whether tensorflow is installed.
         if not _check_tf_installation() or not onnx_satisfied():
-            error = RuntimeIntegrityError(
-                f"TensorFlow, "
-                f"{get_third_part_lib_validation_error_info(['tf2onnx', 'onnx', 'onnxruntime', 'onnxoptimizer'])} "
-                f"are required when using graph based scripts converter for TensorFlow conversion."
-            )
-            log.error(error)
-            log_console.error("\n")
-            log_console.error(str(error))
-            log_console.error("\n")
+            _print_error(not_integral_error)
             sys.exit(0)

+        if not any([check_common_dependency_integrity("tensorflow"),
+                    check_common_dependency_integrity("tensorflow-gpu")]):
+            _print_error(not_integral_error)
+            sys.exit(0)

         tf2onnx = import_module("tf2onnx")

         if not lib_version_satisfied(getattr(tf2onnx, "__version__"), TF2ONNX_MIN_VER) \
                 or not onnx_lib_version_satisfied():
-            error = RuntimeIntegrityError(
-                f"TensorFlow, "
-                f"{get_third_part_lib_validation_error_info(['tf2onnx', 'onnx', 'onnxruntime', 'onnxoptimizer'])} "
-                f"are required when using graph based scripts converter for TensorFlow conversion."
-            )
-            log.error(error)
-            log_console.error("\n")
-            log_console.error(str(error))
-            log_console.error("\n")
+            _print_error(not_integral_error)
             sys.exit(0)

         func(graph_path=graph_path, sample_shape=sample_shape,
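The new any([...]) guard accepts either the CPU or the GPU TensorFlow distribution, since only one of the two is normally installed. A self-contained sketch of that "any one of several distributions" check follows; it assumes Python 3.8+ so importlib.metadata is available, and the real check_common_dependency_integrity may use a different mechanism.

```python
# Hypothetical stand-in for the tensorflow / tensorflow-gpu guard above.
from importlib.metadata import version, PackageNotFoundError


def any_distribution_present(*candidates):
    """Return True when at least one of the candidate distributions is installed."""
    for name in candidates:
        try:
            version(name)
            return True
        except PackageNotFoundError:
            continue
    return False


if not any_distribution_present("tensorflow", "tensorflow-gpu"):
    print("TensorFlow (CPU or GPU build) is required for TensorFlow conversion.")
```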
|
|
|
|