Browse Source

!1169 Add 3rd party package integrity check

From: @liuchongming74
Reviewed-by: @ouwenchang,@yelihua
Signed-off-by: @yelihua
tags/v1.2.0-rc1
mindspore-ci-bot Gitee 4 years ago
parent
commit
911b613649
2 changed files with 41 additions and 55 deletions
  1. +10
    -0
      mindinsight/mindconverter/graph_based_converter/common/utils.py
  2. +31
    -55
      mindinsight/mindconverter/graph_based_converter/framework.py

+ 10
- 0
mindinsight/mindconverter/graph_based_converter/common/utils.py View File

@@ -61,6 +61,16 @@ def _add_outputs_of_onnx_model(model, output_nodes: List[str]):
return model




def check_dependency_integrity(*packages):
    """Check dependency package integrity.

    Attempts to import every package named in `packages`; success for ALL of
    them is required (the scraped original ambiguously placed `return True`,
    which could short-circuit after the first package — this version checks
    each one before reporting success).

    Args:
        *packages (str): importable module names, e.g. "onnx", "onnxruntime".

    Returns:
        bool, True only if every package can be imported, False otherwise.
    """
    for pkg in packages:
        try:
            import_module(pkg)
        except ImportError:
            # Missing or broken installation of `pkg`; caller decides how to report.
            return False
    return True


def fetch_output_from_onnx_model(model, feed_dict: dict, output_nodes: List[str]):
    """
    Fetch specific nodes output from onnx model.


+ 31
- 55
mindinsight/mindconverter/graph_based_converter/framework.py View File

@@ -14,15 +14,14 @@
# ============================================================================== # ==============================================================================
"""Graph based scripts converter workflow.""" """Graph based scripts converter workflow."""
import os import os
import argparse
import sys import sys
from importlib import import_module from importlib import import_module
from importlib.util import find_spec from importlib.util import find_spec
from functools import partial


import mindinsight
from mindinsight.mindconverter.graph_based_converter.common.global_context import GlobalContext from mindinsight.mindconverter.graph_based_converter.common.global_context import GlobalContext
from mindinsight.mindconverter.graph_based_converter.common.utils import lib_version_satisfied, onnx_satisfied, \ from mindinsight.mindconverter.graph_based_converter.common.utils import lib_version_satisfied, onnx_satisfied, \
save_code_file_and_report, get_framework_type, get_third_part_lib_validation_error_info
save_code_file_and_report, get_framework_type, check_dependency_integrity, get_third_part_lib_validation_error_info
from mindinsight.mindconverter.graph_based_converter.constant import FrameworkType, \ from mindinsight.mindconverter.graph_based_converter.constant import FrameworkType, \
ONNX_MIN_VER, TF2ONNX_MIN_VER, ONNXRUNTIME_MIN_VER, ONNXOPTIMIZER_MIN_VER, ONNXOPTIMIZER_MAX_VER ONNX_MIN_VER, TF2ONNX_MIN_VER, ONNXRUNTIME_MIN_VER, ONNXOPTIMIZER_MIN_VER, ONNXOPTIMIZER_MAX_VER
from mindinsight.mindconverter.graph_based_converter.generator import batch_add_nodes from mindinsight.mindconverter.graph_based_converter.generator import batch_add_nodes
@@ -32,25 +31,8 @@ from mindinsight.mindconverter.common.exceptions import GraphInitError, TreeCrea
BaseConverterError, UnknownModelError, GeneratorError, TfRuntimeError, RuntimeIntegrityError, ParamMissingError BaseConverterError, UnknownModelError, GeneratorError, TfRuntimeError, RuntimeIntegrityError, ParamMissingError
from mindinsight.mindconverter.graph_based_converter.third_party_graph import GraphFactory from mindinsight.mindconverter.graph_based_converter.third_party_graph import GraphFactory


permissions = os.R_OK | os.W_OK | os.X_OK
os.umask(permissions << 3 | permissions)

parser = argparse.ArgumentParser(
prog="MindConverter",
description="Graph based MindConverter CLI entry point (version: {})".format(
mindinsight.__version__)
)

parser.add_argument("--graph", type=str, required=True,
help="Third party framework's graph path.")
parser.add_argument("--sample_shape", nargs='+', type=int, required=True,
help="Input shape of the model.")
parser.add_argument("--ckpt", type=str, required=False,
help="Third party framework's checkpoint path.")
parser.add_argument("--output", type=str, required=True,
help="Generated scripts output folder path.")
parser.add_argument("--report", type=str, required=False,
help="Generated reports output folder path.")
# Pre-bound check for the ONNX toolchain packages required by every conversion
# path; call sites append the framework-specific extra (e.g. "torch",
# "tensorflow") as an additional positional argument.
check_common_dependency_integrity = partial(check_dependency_integrity,
                                            "onnx", "onnxruntime", "onnxoptimizer")




def onnx_lib_version_satisfied(): def onnx_lib_version_satisfied():
@@ -65,6 +47,14 @@ def onnx_lib_version_satisfied():
return True return True




def _print_error(err):
    """Record *err* in the log file and echo it to the console.

    The error is surrounded by blank console lines so it stands out in
    terminal output.
    """
    log.error(err)
    for fragment in ("\n", str(err), "\n"):
        log_console.error(fragment)


def torch_installation_validation(func): def torch_installation_validation(func):
""" """
Validate args of func. Validate args of func.
@@ -76,27 +66,23 @@ def torch_installation_validation(func):
type, inner function. type, inner function.
""" """


def _f(graph_path: str, sample_shape: tuple,
input_nodes: str, output_nodes: str,
def _f(graph_path: str, sample_shape: tuple, input_nodes: str, output_nodes: str,
output_folder: str, report_folder: str = None): output_folder: str, report_folder: str = None):
# Check whether pytorch is installed. # Check whether pytorch is installed.
error_info = None error_info = None
if graph_path.endswith('.onnx'): if graph_path.endswith('.onnx'):
if not onnx_satisfied():
if not onnx_satisfied() or not check_common_dependency_integrity():
error_info = f"{get_third_part_lib_validation_error_info(['onnx', 'onnxruntime', 'onnxoptimizer'])} " \ error_info = f"{get_third_part_lib_validation_error_info(['onnx', 'onnxruntime', 'onnxoptimizer'])} " \
f"are required when using graph based scripts converter." f"are required when using graph based scripts converter."
else: else:
if not find_spec("torch") or not onnx_satisfied():
if not find_spec("torch") or not onnx_satisfied() or not check_common_dependency_integrity("torch"):
error_info = f"PyTorch, " \ error_info = f"PyTorch, " \
f"{get_third_part_lib_validation_error_info(['onnx', 'onnxruntime', 'onnxoptimizer'])} " \ f"{get_third_part_lib_validation_error_info(['onnx', 'onnxruntime', 'onnxoptimizer'])} " \
f"are required when using graph based scripts converter, and PyTorch version must " \ f"are required when using graph based scripts converter, and PyTorch version must " \
f"be consisted with model generation runtime." f"be consisted with model generation runtime."

if error_info: if error_info:
error = RuntimeIntegrityError(error_info)
log.error(error)
log_console.error("\n")
log_console.error(str(error))
log_console.error("\n")
_print_error(RuntimeIntegrityError(error_info))
sys.exit(0) sys.exit(0)


if not onnx_lib_version_satisfied(): if not onnx_lib_version_satisfied():
@@ -104,10 +90,7 @@ def torch_installation_validation(func):
f"{get_third_part_lib_validation_error_info(['onnx', 'onnxruntime', 'onnxoptimizer'])} " f"{get_third_part_lib_validation_error_info(['onnx', 'onnxruntime', 'onnxoptimizer'])} "
f"are required when using graph based scripts converter." f"are required when using graph based scripts converter."
) )
log.error(error)
log_console.error("\n")
log_console.error(str(error))
log_console.error("\n")
_print_error(error)
sys.exit(0) sys.exit(0)


func(graph_path=graph_path, sample_shape=sample_shape, func(graph_path=graph_path, sample_shape=sample_shape,
@@ -138,35 +121,28 @@ def tf_installation_validation(func):
type, inner function. type, inner function.
""" """


def _f(graph_path: str, sample_shape: tuple,
output_folder: str, report_folder: str = None,
def _f(graph_path: str, sample_shape: tuple, output_folder: str, report_folder: str = None,
input_nodes: str = None, output_nodes: str = None): input_nodes: str = None, output_nodes: str = None):
not_integral_error = RuntimeIntegrityError(
f"TensorFlow, "
f"{get_third_part_lib_validation_error_info(['tf2onnx', 'onnx', 'onnxruntime', 'onnxoptimizer'])} "
f"are required when using graph based scripts converter for TensorFlow conversion."
)
# Check whether tensorflow is installed. # Check whether tensorflow is installed.
if not _check_tf_installation() or not onnx_satisfied(): if not _check_tf_installation() or not onnx_satisfied():
error = RuntimeIntegrityError(
f"TensorFlow, "
f"{get_third_part_lib_validation_error_info(['tf2onnx', 'onnx', 'onnxruntime', 'onnxoptimizer'])} "
f"are required when using graph based scripts converter for TensorFlow conversion."
)
log.error(error)
log_console.error("\n")
log_console.error(str(error))
log_console.error("\n")
_print_error(not_integral_error)
sys.exit(0)

if not any([check_common_dependency_integrity("tensorflow"),
check_common_dependency_integrity("tensorflow-gpu")]):
_print_error(not_integral_error)
sys.exit(0) sys.exit(0)


tf2onnx = import_module("tf2onnx") tf2onnx = import_module("tf2onnx")


if not lib_version_satisfied(getattr(tf2onnx, "__version__"), TF2ONNX_MIN_VER) \ if not lib_version_satisfied(getattr(tf2onnx, "__version__"), TF2ONNX_MIN_VER) \
or not onnx_lib_version_satisfied(): or not onnx_lib_version_satisfied():
error = RuntimeIntegrityError(
f"TensorFlow, "
f"{get_third_part_lib_validation_error_info(['tf2onnx', 'onnx', 'onnxruntime', 'onnxoptimizer'])} "
f"are required when using graph based scripts converter for TensorFlow conversion."
)
log.error(error)
log_console.error("\n")
log_console.error(str(error))
log_console.error("\n")
_print_error(not_integral_error)
sys.exit(0) sys.exit(0)


func(graph_path=graph_path, sample_shape=sample_shape, func(graph_path=graph_path, sample_shape=sample_shape,


Loading…
Cancel
Save