@@ -19,7 +19,9 @@ import argparse
 import mindinsight
 from mindinsight.mindconverter.converter import main
-from mindinsight.mindconverter.graph_based_converter.constant import ARGUMENT_LENGTH_LIMIT, EXPECTED_NUMBER
+from mindinsight.mindconverter.graph_based_converter.common.utils import get_framework_type
+from mindinsight.mindconverter.graph_based_converter.constant import ARGUMENT_LENGTH_LIMIT, EXPECTED_NUMBER, \
+    FrameworkType
 from mindinsight.mindconverter.graph_based_converter.framework import main_graph_base_converter
 from mindinsight.mindconverter.common.log import logger as log, logger_console as log_console
@@ -198,6 +200,11 @@ class ModelFileAction(argparse.Action):
         if not os.path.isfile(outfile_dir):
             parser_in.error(f'{option_string} {outfile_dir} is not a file')
 
+        frame_type = get_framework_type(outfile_dir)
+        if frame_type == FrameworkType.UNKNOWN.value:
+            parser_in.error(f'{option_string} {outfile_dir} should be a valid '
+                            f'TensorFlow pb or PyTorch pth model file')
+
         setattr(namespace, self.dest, outfile_dir)
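
For orientation, below is a minimal, self-contained sketch of the validation flow this hunk adds to the argparse action. The FrameworkType member values and the suffix-only detector are assumptions made purely so the snippet runs on its own; the real enum and get_framework_type come from graph_based_converter, and the real detector also inspects the file's binary header for PyTorch.

import argparse
import os
from enum import Enum


class FrameworkType(Enum):
    # Assumed member values; only the UNKNOWN sentinel matters for this sketch.
    PYTORCH = 0
    TENSORFLOW = 1
    UNKNOWN = 2


def get_framework_type(model_path):
    # Stand-in for the real helper: classify by suffix only so the sketch stays runnable.
    suffix = os.path.basename(model_path).split(".")[-1].lower()
    return {"pth": FrameworkType.PYTORCH.value,
            "pb": FrameworkType.TENSORFLOW.value}.get(suffix, FrameworkType.UNKNOWN.value)


class ModelFileAction(argparse.Action):
    """Reject paths that are missing or not a recognizable TensorFlow/PyTorch model."""

    def __call__(self, parser_in, namespace, values, option_string=None):
        outfile_dir = values
        if not os.path.isfile(outfile_dir):
            parser_in.error(f'{option_string} {outfile_dir} is not a file')
        if get_framework_type(outfile_dir) == FrameworkType.UNKNOWN.value:
            parser_in.error(f'{option_string} {outfile_dir} should be a valid '
                            f'TensorFlow pb or PyTorch pth model file')
        setattr(namespace, self.dest, outfile_dir)

Registering an action like this on the model-file argument makes the CLI fail fast on unrecognized files, before any conversion work starts.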
@@ -18,8 +18,10 @@ import stat
 from importlib import import_module
 from typing import List, Tuple, Mapping
 
-from mindinsight.mindconverter.common.exceptions import ScriptGenerationError, ReportGenerationError
-from mindinsight.mindconverter.graph_based_converter.constant import SEPARATOR_IN_ONNX_OP
+from mindinsight.mindconverter.common.exceptions import ScriptGenerationError, ReportGenerationError, UnknownModelError
+from mindinsight.mindconverter.common.log import logger as log
+from mindinsight.mindconverter.graph_based_converter.constant import SEPARATOR_IN_ONNX_OP, BINARY_HEADER_PYTORCH_BITS, \
+    FrameworkType, BINARY_HEADER_PYTORCH_FILE, TENSORFLOW_MODEL_SUFFIX
 
 
 def is_converted(operation: str):
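
The three new constants feed the header/suffix check that this file gains further down. Their real values are defined in graph_based_converter/constant.py and are not reproduced here; the following is only a hypothetical sketch of their shapes, with placeholder values.

# Placeholder values only; the real definitions live in graph_based_converter/constant.py.
BINARY_HEADER_PYTORCH_BITS = 32                  # how many leading bytes get_framework_type reads
BINARY_HEADER_PYTORCH_FILE = b"<pytorch-magic>"  # signature those leading bytes are compared against
TENSORFLOW_MODEL_SUFFIX = "pb"                   # extension accepted as a TensorFlow frozen graph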
@@ -174,6 +176,7 @@ def get_dict_key_by_value(val, dic):
             return d_key
     return None
 
+
 def convert_bytes_string_to_string(bytes_str):
     """
     Convert a byte string to string by utf-8.
@@ -186,4 +189,23 @@ def convert_bytes_string_to_string(bytes_str):
     """
     if isinstance(bytes_str, bytes):
         return bytes_str.decode('utf-8')
-    return bytes_str
+    return bytes_str
+
+
+def get_framework_type(model_path):
+    """Get framework type."""
+    try:
+        with open(model_path, 'rb') as f:
+            if f.read(BINARY_HEADER_PYTORCH_BITS) == BINARY_HEADER_PYTORCH_FILE:
+                framework_type = FrameworkType.PYTORCH.value
+            elif os.path.basename(model_path).split(".")[-1].lower() == TENSORFLOW_MODEL_SUFFIX:
+                framework_type = FrameworkType.TENSORFLOW.value
+            else:
+                framework_type = FrameworkType.UNKNOWN.value
+    except IOError:
+        error_msg = "Get UNSUPPORTED model."
+        error = UnknownModelError(error_msg)
+        log.error(str(error))
+        raise error
+
+    return framework_type
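
With get_framework_type now exported from common.utils, the CLI action above and main_graph_base_converter can share one implementation (the old copy is deleted in the final hunk). A minimal usage sketch, assuming MindInsight is installed; describe_model is a hypothetical caller, not part of this change.

from mindinsight.mindconverter.graph_based_converter.common.utils import get_framework_type
from mindinsight.mindconverter.graph_based_converter.constant import FrameworkType


def describe_model(model_path):
    """Hypothetical caller: map the detected framework to a human-readable label."""
    frame_type = get_framework_type(model_path)  # raises UnknownModelError if the file cannot be read
    if frame_type == FrameworkType.PYTORCH.value:
        return "PyTorch checkpoint (.pth)"
    if frame_type == FrameworkType.TENSORFLOW.value:
        return "TensorFlow frozen graph (.pb)"
    return "unknown model format"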
@@ -22,9 +22,9 @@ from importlib.util import find_spec
 
 import mindinsight
 from mindinsight.mindconverter.graph_based_converter.common.utils import lib_version_satisfied, \
-    save_code_file_and_report
-from mindinsight.mindconverter.graph_based_converter.constant import BINARY_HEADER_PYTORCH_FILE, FrameworkType, \
-    BINARY_HEADER_PYTORCH_BITS, ONNX_MIN_VER, TF2ONNX_MIN_VER, ONNXRUNTIME_MIN_VER, TENSORFLOW_MODEL_SUFFIX
+    save_code_file_and_report, get_framework_type
+from mindinsight.mindconverter.graph_based_converter.constant import FrameworkType, \
+    ONNX_MIN_VER, TF2ONNX_MIN_VER, ONNXRUNTIME_MIN_VER
 from mindinsight.mindconverter.graph_based_converter.mapper import ONNXToMindSporeMapper
 from mindinsight.mindconverter.common.log import logger as log, logger_console as log_console
 from mindinsight.mindconverter.common.exceptions import GraphInitError, TreeCreationError, SourceFilesSaveError, \
@@ -264,25 +264,6 @@ def main_graph_base_converter(file_config):
         raise error
 
 
-def get_framework_type(model_path):
-    """Get framework type."""
-    try:
-        with open(model_path, 'rb') as f:
-            if f.read(BINARY_HEADER_PYTORCH_BITS) == BINARY_HEADER_PYTORCH_FILE:
-                framework_type = FrameworkType.PYTORCH.value
-            elif os.path.basename(model_path).split(".")[-1].lower() == TENSORFLOW_MODEL_SUFFIX:
-                framework_type = FrameworkType.TENSORFLOW.value
-            else:
-                framework_type = FrameworkType.UNKNOWN.value
-    except IOError:
-        error_msg = "Get UNSUPPORTED model."
-        error = UnknownModelError(error_msg)
-        log.error(str(error))
-        raise error
-
-    return framework_type
-
-
 def check_params_exist(params: list, config):
     """Check params exist."""
     miss_param_list = ''