diff --git a/mindinsight/mindconverter/cli.py b/mindinsight/mindconverter/cli.py index aa4e93f4..44a1e07f 100644 --- a/mindinsight/mindconverter/cli.py +++ b/mindinsight/mindconverter/cli.py @@ -19,7 +19,9 @@ import argparse import mindinsight from mindinsight.mindconverter.converter import main -from mindinsight.mindconverter.graph_based_converter.constant import ARGUMENT_LENGTH_LIMIT, EXPECTED_NUMBER +from mindinsight.mindconverter.graph_based_converter.common.utils import get_framework_type +from mindinsight.mindconverter.graph_based_converter.constant import ARGUMENT_LENGTH_LIMIT, EXPECTED_NUMBER, \ + FrameworkType from mindinsight.mindconverter.graph_based_converter.framework import main_graph_base_converter from mindinsight.mindconverter.common.log import logger as log, logger_console as log_console @@ -198,6 +200,11 @@ class ModelFileAction(argparse.Action): if not os.path.isfile(outfile_dir): parser_in.error(f'{option_string} {outfile_dir} is not a file') + frame_type = get_framework_type(outfile_dir) + if frame_type == FrameworkType.UNKNOWN.value: + parser_in.error(f'{option_string} {outfile_dir} should be an valid ' + f'TensorFlow pb or PyTorch pth model file') + setattr(namespace, self.dest, outfile_dir) diff --git a/mindinsight/mindconverter/graph_based_converter/common/utils.py b/mindinsight/mindconverter/graph_based_converter/common/utils.py index ff26e276..4849d860 100644 --- a/mindinsight/mindconverter/graph_based_converter/common/utils.py +++ b/mindinsight/mindconverter/graph_based_converter/common/utils.py @@ -18,8 +18,10 @@ import stat from importlib import import_module from typing import List, Tuple, Mapping -from mindinsight.mindconverter.common.exceptions import ScriptGenerationError, ReportGenerationError -from mindinsight.mindconverter.graph_based_converter.constant import SEPARATOR_IN_ONNX_OP +from mindinsight.mindconverter.common.exceptions import ScriptGenerationError, ReportGenerationError, UnknownModelError +from mindinsight.mindconverter.common.log import logger as log +from mindinsight.mindconverter.graph_based_converter.constant import SEPARATOR_IN_ONNX_OP, BINARY_HEADER_PYTORCH_BITS, \ + FrameworkType, BINARY_HEADER_PYTORCH_FILE, TENSORFLOW_MODEL_SUFFIX def is_converted(operation: str): @@ -174,6 +176,7 @@ def get_dict_key_by_value(val, dic): return d_key return None + def convert_bytes_string_to_string(bytes_str): """ Convert a byte string to string by utf-8. @@ -186,4 +189,23 @@ def convert_bytes_string_to_string(bytes_str): """ if isinstance(bytes_str, bytes): return bytes_str.decode('utf-8') - return bytes_str + return bytes_str + + +def get_framework_type(model_path): + """Get framework type.""" + try: + with open(model_path, 'rb') as f: + if f.read(BINARY_HEADER_PYTORCH_BITS) == BINARY_HEADER_PYTORCH_FILE: + framework_type = FrameworkType.PYTORCH.value + elif os.path.basename(model_path).split(".")[-1].lower() == TENSORFLOW_MODEL_SUFFIX: + framework_type = FrameworkType.TENSORFLOW.value + else: + framework_type = FrameworkType.UNKNOWN.value + except IOError: + error_msg = "Get UNSUPPORTED model." + error = UnknownModelError(error_msg) + log.error(str(error)) + raise error + + return framework_type diff --git a/mindinsight/mindconverter/graph_based_converter/framework.py b/mindinsight/mindconverter/graph_based_converter/framework.py index 9e17ed2d..aeb9099a 100644 --- a/mindinsight/mindconverter/graph_based_converter/framework.py +++ b/mindinsight/mindconverter/graph_based_converter/framework.py @@ -22,9 +22,9 @@ from importlib.util import find_spec import mindinsight from mindinsight.mindconverter.graph_based_converter.common.utils import lib_version_satisfied, \ - save_code_file_and_report -from mindinsight.mindconverter.graph_based_converter.constant import BINARY_HEADER_PYTORCH_FILE, FrameworkType, \ - BINARY_HEADER_PYTORCH_BITS, ONNX_MIN_VER, TF2ONNX_MIN_VER, ONNXRUNTIME_MIN_VER, TENSORFLOW_MODEL_SUFFIX + save_code_file_and_report, get_framework_type +from mindinsight.mindconverter.graph_based_converter.constant import FrameworkType, \ + ONNX_MIN_VER, TF2ONNX_MIN_VER, ONNXRUNTIME_MIN_VER from mindinsight.mindconverter.graph_based_converter.mapper import ONNXToMindSporeMapper from mindinsight.mindconverter.common.log import logger as log, logger_console as log_console from mindinsight.mindconverter.common.exceptions import GraphInitError, TreeCreationError, SourceFilesSaveError, \ @@ -264,25 +264,6 @@ def main_graph_base_converter(file_config): raise error -def get_framework_type(model_path): - """Get framework type.""" - try: - with open(model_path, 'rb') as f: - if f.read(BINARY_HEADER_PYTORCH_BITS) == BINARY_HEADER_PYTORCH_FILE: - framework_type = FrameworkType.PYTORCH.value - elif os.path.basename(model_path).split(".")[-1].lower() == TENSORFLOW_MODEL_SUFFIX: - framework_type = FrameworkType.TENSORFLOW.value - else: - framework_type = FrameworkType.UNKNOWN.value - except IOError: - error_msg = "Get UNSUPPORTED model." - error = UnknownModelError(error_msg) - log.error(str(error)) - raise error - - return framework_type - - def check_params_exist(params: list, config): """Check params exist.""" miss_param_list = ''