| @@ -19,7 +19,9 @@ import argparse | |||
| import mindinsight | |||
| from mindinsight.mindconverter.converter import main | |||
| from mindinsight.mindconverter.graph_based_converter.constant import ARGUMENT_LENGTH_LIMIT, EXPECTED_NUMBER | |||
| from mindinsight.mindconverter.graph_based_converter.common.utils import get_framework_type | |||
| from mindinsight.mindconverter.graph_based_converter.constant import ARGUMENT_LENGTH_LIMIT, EXPECTED_NUMBER, \ | |||
| FrameworkType | |||
| from mindinsight.mindconverter.graph_based_converter.framework import main_graph_base_converter | |||
| from mindinsight.mindconverter.common.log import logger as log, logger_console as log_console | |||
| @@ -198,6 +200,11 @@ class ModelFileAction(argparse.Action): | |||
| if not os.path.isfile(outfile_dir): | |||
| parser_in.error(f'{option_string} {outfile_dir} is not a file') | |||
| frame_type = get_framework_type(outfile_dir) | |||
| if frame_type == FrameworkType.UNKNOWN.value: | |||
| parser_in.error(f'{option_string} {outfile_dir} should be an valid ' | |||
| f'TensorFlow pb or PyTorch pth model file') | |||
| setattr(namespace, self.dest, outfile_dir) | |||
| @@ -18,8 +18,10 @@ import stat | |||
| from importlib import import_module | |||
| from typing import List, Tuple, Mapping | |||
| from mindinsight.mindconverter.common.exceptions import ScriptGenerationError, ReportGenerationError | |||
| from mindinsight.mindconverter.graph_based_converter.constant import SEPARATOR_IN_ONNX_OP | |||
| from mindinsight.mindconverter.common.exceptions import ScriptGenerationError, ReportGenerationError, UnknownModelError | |||
| from mindinsight.mindconverter.common.log import logger as log | |||
| from mindinsight.mindconverter.graph_based_converter.constant import SEPARATOR_IN_ONNX_OP, BINARY_HEADER_PYTORCH_BITS, \ | |||
| FrameworkType, BINARY_HEADER_PYTORCH_FILE, TENSORFLOW_MODEL_SUFFIX | |||
| def is_converted(operation: str): | |||
| @@ -174,6 +176,7 @@ def get_dict_key_by_value(val, dic): | |||
| return d_key | |||
| return None | |||
| def convert_bytes_string_to_string(bytes_str): | |||
| """ | |||
| Convert a byte string to string by utf-8. | |||
| @@ -186,4 +189,23 @@ def convert_bytes_string_to_string(bytes_str): | |||
| """ | |||
| if isinstance(bytes_str, bytes): | |||
| return bytes_str.decode('utf-8') | |||
| return bytes_str | |||
| return bytes_str | |||
| def get_framework_type(model_path): | |||
| """Get framework type.""" | |||
| try: | |||
| with open(model_path, 'rb') as f: | |||
| if f.read(BINARY_HEADER_PYTORCH_BITS) == BINARY_HEADER_PYTORCH_FILE: | |||
| framework_type = FrameworkType.PYTORCH.value | |||
| elif os.path.basename(model_path).split(".")[-1].lower() == TENSORFLOW_MODEL_SUFFIX: | |||
| framework_type = FrameworkType.TENSORFLOW.value | |||
| else: | |||
| framework_type = FrameworkType.UNKNOWN.value | |||
| except IOError: | |||
| error_msg = "Get UNSUPPORTED model." | |||
| error = UnknownModelError(error_msg) | |||
| log.error(str(error)) | |||
| raise error | |||
| return framework_type | |||
| @@ -22,9 +22,9 @@ from importlib.util import find_spec | |||
| import mindinsight | |||
| from mindinsight.mindconverter.graph_based_converter.common.utils import lib_version_satisfied, \ | |||
| save_code_file_and_report | |||
| from mindinsight.mindconverter.graph_based_converter.constant import BINARY_HEADER_PYTORCH_FILE, FrameworkType, \ | |||
| BINARY_HEADER_PYTORCH_BITS, ONNX_MIN_VER, TF2ONNX_MIN_VER, ONNXRUNTIME_MIN_VER, TENSORFLOW_MODEL_SUFFIX | |||
| save_code_file_and_report, get_framework_type | |||
| from mindinsight.mindconverter.graph_based_converter.constant import FrameworkType, \ | |||
| ONNX_MIN_VER, TF2ONNX_MIN_VER, ONNXRUNTIME_MIN_VER | |||
| from mindinsight.mindconverter.graph_based_converter.mapper import ONNXToMindSporeMapper | |||
| from mindinsight.mindconverter.common.log import logger as log, logger_console as log_console | |||
| from mindinsight.mindconverter.common.exceptions import GraphInitError, TreeCreationError, SourceFilesSaveError, \ | |||
| @@ -264,25 +264,6 @@ def main_graph_base_converter(file_config): | |||
| raise error | |||
| def get_framework_type(model_path): | |||
| """Get framework type.""" | |||
| try: | |||
| with open(model_path, 'rb') as f: | |||
| if f.read(BINARY_HEADER_PYTORCH_BITS) == BINARY_HEADER_PYTORCH_FILE: | |||
| framework_type = FrameworkType.PYTORCH.value | |||
| elif os.path.basename(model_path).split(".")[-1].lower() == TENSORFLOW_MODEL_SUFFIX: | |||
| framework_type = FrameworkType.TENSORFLOW.value | |||
| else: | |||
| framework_type = FrameworkType.UNKNOWN.value | |||
| except IOError: | |||
| error_msg = "Get UNSUPPORTED model." | |||
| error = UnknownModelError(error_msg) | |||
| log.error(str(error)) | |||
| raise error | |||
| return framework_type | |||
| def check_params_exist(params: list, config): | |||
| """Check params exist.""" | |||
| miss_param_list = '' | |||