From cd726e36ed4e5b12001b2693d172a87071d0f404 Mon Sep 17 00:00:00 2001 From: liuchongming Date: Tue, 2 Feb 2021 21:14:02 +0800 Subject: [PATCH] Add dependency integrity check. --- .../graph_based_converter/common/utils.py | 10 +++ .../graph_based_converter/framework.py | 86 +++++++------------ 2 files changed, 41 insertions(+), 55 deletions(-) diff --git a/mindinsight/mindconverter/graph_based_converter/common/utils.py b/mindinsight/mindconverter/graph_based_converter/common/utils.py index 109dbbe2..35b85ab7 100644 --- a/mindinsight/mindconverter/graph_based_converter/common/utils.py +++ b/mindinsight/mindconverter/graph_based_converter/common/utils.py @@ -61,6 +61,16 @@ def _add_outputs_of_onnx_model(model, output_nodes: List[str]): return model +def check_dependency_integrity(*packages): + """Check dependency package integrity.""" + try: + for pkg in packages: + import_module(pkg) + return True + except ImportError: + return False + + def fetch_output_from_onnx_model(model, feed_dict: dict, output_nodes: List[str]): """ Fetch specific nodes output from onnx model. diff --git a/mindinsight/mindconverter/graph_based_converter/framework.py b/mindinsight/mindconverter/graph_based_converter/framework.py index 8fe0cd47..ac0f41aa 100644 --- a/mindinsight/mindconverter/graph_based_converter/framework.py +++ b/mindinsight/mindconverter/graph_based_converter/framework.py @@ -14,15 +14,14 @@ # ============================================================================== """Graph based scripts converter workflow.""" import os -import argparse import sys from importlib import import_module from importlib.util import find_spec +from functools import partial -import mindinsight from mindinsight.mindconverter.graph_based_converter.common.global_context import GlobalContext from mindinsight.mindconverter.graph_based_converter.common.utils import lib_version_satisfied, onnx_satisfied, \ - save_code_file_and_report, get_framework_type, get_third_part_lib_validation_error_info + save_code_file_and_report, get_framework_type, check_dependency_integrity, get_third_part_lib_validation_error_info from mindinsight.mindconverter.graph_based_converter.constant import FrameworkType, \ ONNX_MIN_VER, TF2ONNX_MIN_VER, ONNXRUNTIME_MIN_VER, ONNXOPTIMIZER_MIN_VER, ONNXOPTIMIZER_MAX_VER from mindinsight.mindconverter.graph_based_converter.generator import batch_add_nodes @@ -32,25 +31,8 @@ from mindinsight.mindconverter.common.exceptions import GraphInitError, TreeCrea BaseConverterError, UnknownModelError, GeneratorError, TfRuntimeError, RuntimeIntegrityError, ParamMissingError from mindinsight.mindconverter.graph_based_converter.third_party_graph import GraphFactory -permissions = os.R_OK | os.W_OK | os.X_OK -os.umask(permissions << 3 | permissions) - -parser = argparse.ArgumentParser( - prog="MindConverter", - description="Graph based MindConverter CLI entry point (version: {})".format( - mindinsight.__version__) -) - -parser.add_argument("--graph", type=str, required=True, - help="Third party framework's graph path.") -parser.add_argument("--sample_shape", nargs='+', type=int, required=True, - help="Input shape of the model.") -parser.add_argument("--ckpt", type=str, required=False, - help="Third party framework's checkpoint path.") -parser.add_argument("--output", type=str, required=True, - help="Generated scripts output folder path.") -parser.add_argument("--report", type=str, required=False, - help="Generated reports output folder path.") +check_common_dependency_integrity = partial(check_dependency_integrity, + "onnx", "onnxruntime", "onnxoptimizer") def onnx_lib_version_satisfied(): @@ -65,6 +47,14 @@ def onnx_lib_version_satisfied(): return True +def _print_error(err): + """Print error to stdout and record it.""" + log.error(err) + log_console.error("\n") + log_console.error(str(err)) + log_console.error("\n") + + def torch_installation_validation(func): """ Validate args of func. @@ -76,27 +66,23 @@ def torch_installation_validation(func): type, inner function. """ - def _f(graph_path: str, sample_shape: tuple, - input_nodes: str, output_nodes: str, + def _f(graph_path: str, sample_shape: tuple, input_nodes: str, output_nodes: str, output_folder: str, report_folder: str = None): # Check whether pytorch is installed. error_info = None if graph_path.endswith('.onnx'): - if not onnx_satisfied(): + if not onnx_satisfied() or not check_common_dependency_integrity(): error_info = f"{get_third_part_lib_validation_error_info(['onnx', 'onnxruntime', 'onnxoptimizer'])} " \ f"are required when using graph based scripts converter." else: - if not find_spec("torch") or not onnx_satisfied(): + if not find_spec("torch") or not onnx_satisfied() or not check_common_dependency_integrity("torch"): error_info = f"PyTorch, " \ f"{get_third_part_lib_validation_error_info(['onnx', 'onnxruntime', 'onnxoptimizer'])} " \ f"are required when using graph based scripts converter, and PyTorch version must " \ f"be consisted with model generation runtime." + if error_info: - error = RuntimeIntegrityError(error_info) - log.error(error) - log_console.error("\n") - log_console.error(str(error)) - log_console.error("\n") + _print_error(RuntimeIntegrityError(error_info)) sys.exit(0) if not onnx_lib_version_satisfied(): @@ -104,10 +90,7 @@ def torch_installation_validation(func): f"{get_third_part_lib_validation_error_info(['onnx', 'onnxruntime', 'onnxoptimizer'])} " f"are required when using graph based scripts converter." ) - log.error(error) - log_console.error("\n") - log_console.error(str(error)) - log_console.error("\n") + _print_error(error) sys.exit(0) func(graph_path=graph_path, sample_shape=sample_shape, @@ -138,35 +121,28 @@ def tf_installation_validation(func): type, inner function. """ - def _f(graph_path: str, sample_shape: tuple, - output_folder: str, report_folder: str = None, + def _f(graph_path: str, sample_shape: tuple, output_folder: str, report_folder: str = None, input_nodes: str = None, output_nodes: str = None): + not_integral_error = RuntimeIntegrityError( + f"TensorFlow, " + f"{get_third_part_lib_validation_error_info(['tf2onnx', 'onnx', 'onnxruntime', 'onnxoptimizer'])} " + f"are required when using graph based scripts converter for TensorFlow conversion." + ) # Check whether tensorflow is installed. if not _check_tf_installation() or not onnx_satisfied(): - error = RuntimeIntegrityError( - f"TensorFlow, " - f"{get_third_part_lib_validation_error_info(['tf2onnx', 'onnx', 'onnxruntime', 'onnxoptimizer'])} " - f"are required when using graph based scripts converter for TensorFlow conversion." - ) - log.error(error) - log_console.error("\n") - log_console.error(str(error)) - log_console.error("\n") + _print_error(not_integral_error) + sys.exit(0) + + if not any([check_common_dependency_integrity("tensorflow"), + check_common_dependency_integrity("tensorflow-gpu")]): + _print_error(not_integral_error) sys.exit(0) tf2onnx = import_module("tf2onnx") if not lib_version_satisfied(getattr(tf2onnx, "__version__"), TF2ONNX_MIN_VER) \ or not onnx_lib_version_satisfied(): - error = RuntimeIntegrityError( - f"TensorFlow, " - f"{get_third_part_lib_validation_error_info(['tf2onnx', 'onnx', 'onnxruntime', 'onnxoptimizer'])} " - f"are required when using graph based scripts converter for TensorFlow conversion." - ) - log.error(error) - log_console.error("\n") - log_console.error(str(error)) - log_console.error("\n") + _print_error(not_integral_error) sys.exit(0) func(graph_path=graph_path, sample_shape=sample_shape,