From: @liuchongming74
Reviewed-by: @lilongfei15, @wangyue01
Signed-off-by: @wangyue01
tags/v1.1.0
@@ -182,6 +182,7 @@ class BaseConverterError(MindConverterException):
         """Define error code of BaseConverterError."""
         UNKNOWN_ERROR = 0
         UNKNOWN_MODEL = 1
+        PARAM_MISSING = 2

     BASE_ERROR_CODE = ConverterErrors.BASE_CONVERTER_FAIL.value
     ERROR_CODE = ErrCode.UNKNOWN_ERROR.value
@@ -193,7 +194,7 @@ class BaseConverterError(MindConverterException):
     @classmethod
     def raise_from(cls):
         """Raise from exceptions below."""
-        except_source = Exception, cls
+        except_source = Exception, UnknownModelError, ParamMissingError, cls
         return except_source
@@ -209,6 +210,18 @@ class UnknownModelError(BaseConverterError):
         return cls


+class ParamMissingError(BaseConverterError):
+    """Define cli params missing error."""
+
+    ERROR_CODE = BaseConverterError.ErrCode.PARAM_MISSING.value
+
+    def __init__(self, msg):
+        super(ParamMissingError, self).__init__(msg=msg)
+
+    @classmethod
+    def raise_from(cls):
+        return cls
+
+
 class GraphInitError(MindConverterException):
     """The graph init fail error."""
@@ -77,8 +77,7 @@ class Converter:
             raise error
         finally:
             if self._report:
-                dest_report_file = os.path.join(report_dir,
-                                                '_'.join(os.path.basename(infile).split('.')[:-1]) + '_report.txt')
+                dest_report_file = os.path.join(report_dir, f"report_of_{os.path.basename(infile).split('.')[0]}.txt")
                 with os.fdopen(os.open(dest_report_file, self.flags, self.modes), 'a') as file:
                     file.write('\n'.join(self._report))
                 logger.info("Convert report is saved in %s", dest_report_file)
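Worth noting in review: besides the rename, the f-string changes which stem is kept for multi-dot filenames, since it takes only the text before the first dot rather than everything before the last one. A quick comparison, with an illustrative input path:

```python
import os

infile = "/tmp/lenet.v2.onnx"   # illustrative input
base = os.path.basename(infile)

old_name = '_'.join(base.split('.')[:-1]) + '_report.txt'   # 'lenet_v2_report.txt'
new_name = f"report_of_{base.split('.')[0]}.txt"            # 'report_of_lenet.txt'
```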
@@ -180,7 +179,7 @@ def _path_split(file):
     """
     file_dir, name = os.path.split(file)
     if file_dir:
-        sep = file[len(file_dir)-1]
+        sep = file[len(file_dir) - 1]
         if file_dir.startswith(sep):
             return file.split(sep)[1:]
@@ -44,7 +44,7 @@ ONNXRUNTIME_MIN_VER = "1.5.2"
 BINARY_HEADER_PYTORCH_FILE = \
     b'\x80\x02\x8a\nl\xfc\x9cF\xf9 j\xa8P\x19.\x80\x02M\xe9\x03.\x80\x02}q\x00(X\x10\x00\x00\x00'
+TENSORFLOW_MODEL_SUFFIX = "pb"
 BINARY_HEADER_PYTORCH_BITS = 32

 ARGUMENT_LENGTH_LIMIT = 512
@@ -82,6 +82,7 @@ class InputType(Enum):
 class FrameworkType(Enum):
     PYTORCH = 0
     TENSORFLOW = 1
+    UNKNOWN = 2


 def get_imported_module():
@@ -24,12 +24,11 @@ import mindinsight
 from mindinsight.mindconverter.graph_based_converter.common.utils import lib_version_satisfied, \
     save_code_file_and_report
 from mindinsight.mindconverter.graph_based_converter.constant import BINARY_HEADER_PYTORCH_FILE, FrameworkType, \
-    BINARY_HEADER_PYTORCH_BITS, ONNX_MIN_VER, TF2ONNX_MIN_VER, ONNXRUNTIME_MIN_VER
+    BINARY_HEADER_PYTORCH_BITS, ONNX_MIN_VER, TF2ONNX_MIN_VER, ONNXRUNTIME_MIN_VER, TENSORFLOW_MODEL_SUFFIX
 from mindinsight.mindconverter.graph_based_converter.mapper import ONNXToMindSporeMapper
 from mindinsight.mindconverter.common.log import logger as log, logger_console as log_console
 from mindinsight.mindconverter.common.exceptions import GraphInitError, TreeCreationError, SourceFilesSaveError, \
-    BaseConverterError, UnknownModelError, GeneratorError, TfRuntimeError, RuntimeIntegrityError
-from mindinsight.utils.exceptions import ParamMissError
+    BaseConverterError, UnknownModelError, GeneratorError, TfRuntimeError, RuntimeIntegrityError, ParamMissingError

 permissions = os.R_OK | os.W_OK | os.X_OK
 os.umask(permissions << 3 | permissions)
@@ -68,7 +67,7 @@ def torch_installation_validation(func):
         # Check whether pytorch is installed.
         if not find_spec("torch"):
             error = RuntimeIntegrityError("PyTorch is required when using graph based "
-                                          "scripts converter, and PyTorch vision must "
+                                          "scripts converter, and PyTorch version must "
                                           "be consisted with model generation runtime.")
             log.error(error)
             log_console.error("\n")
@@ -242,6 +241,9 @@ def main_graph_base_converter(file_config):
     """
     graph_path = file_config['model_file']
     frame_type = get_framework_type(graph_path)
+    if not file_config.get("shape"):
+        raise ParamMissingError("Param missing, `--shape` is required when using graph mode.")
+
     if frame_type == FrameworkType.PYTORCH.value:
         graph_based_converter_pytorch_to_ms(graph_path=graph_path,
                                             sample_shape=file_config['shape'],
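The guard sits right after framework detection and before dispatch to either conversion backend, so a missing `--shape` now fails fast with the new error type. A minimal sketch of the behavior, reusing the stand-in `ParamMissingError` from the earlier sketch; the `file_config` contents are illustrative:

```python
# Fail-fast guard from main_graph_base_converter, exercised in isolation.
# Assumes the ParamMissingError stand-in defined in the earlier sketch.
file_config = {"model_file": "lenet.pb"}   # parsed CLI args; no "shape" entry

if not file_config.get("shape"):
    raise ParamMissingError("Param missing, `--shape` is required when using graph mode.")
```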
@@ -259,7 +261,6 @@ def main_graph_base_converter(file_config):
     else:
         error_msg = "Get UNSUPPORTED model."
         error = UnknownModelError(error_msg)
-        log.error(str(error))
         raise error
@@ -269,8 +270,10 @@ def get_framework_type(model_path):
         with open(model_path, 'rb') as f:
             if f.read(BINARY_HEADER_PYTORCH_BITS) == BINARY_HEADER_PYTORCH_FILE:
                 framework_type = FrameworkType.PYTORCH.value
-            else:
+            elif os.path.basename(model_path).split(".")[-1].lower() == TENSORFLOW_MODEL_SUFFIX:
                 framework_type = FrameworkType.TENSORFLOW.value
+            else:
+                framework_type = FrameworkType.UNKNOWN.value
     except IOError:
         error_msg = "Get UNSUPPORTED model."
         error = UnknownModelError(error_msg)
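Condensed into a standalone form, the detection order after this change reads as follows. The constants are copied from the diff; the original's `IOError` handling is omitted here for brevity:

```python
import os
from enum import Enum

# Constants mirrored from the constant.py hunk above.
BINARY_HEADER_PYTORCH_BITS = 32
BINARY_HEADER_PYTORCH_FILE = \
    b'\x80\x02\x8a\nl\xfc\x9cF\xf9 j\xa8P\x19.\x80\x02M\xe9\x03.\x80\x02}q\x00(X\x10\x00\x00\x00'
TENSORFLOW_MODEL_SUFFIX = "pb"

class FrameworkType(Enum):
    PYTORCH = 0
    TENSORFLOW = 1
    UNKNOWN = 2

def get_framework_type(model_path):
    """Detect by binary magic first, then by file suffix, else UNKNOWN."""
    with open(model_path, 'rb') as f:
        if f.read(BINARY_HEADER_PYTORCH_BITS) == BINARY_HEADER_PYTORCH_FILE:
            return FrameworkType.PYTORCH.value
    if os.path.basename(model_path).split(".")[-1].lower() == TENSORFLOW_MODEL_SUFFIX:
        return FrameworkType.TENSORFLOW.value
    return FrameworkType.UNKNOWN.value
```

The explicit `UNKNOWN` value lets the caller raise `UnknownModelError` itself instead of silently assuming any non-PyTorch file is a TensorFlow model, which is what the old `else` branch did.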
@@ -288,6 +291,4 @@ def check_params_exist(params: list, config):
             miss_param_list = ', '.join((miss_param_list, param)) if miss_param_list else param

     if miss_param_list:
-        error = ParamMissError(miss_param_list)
-        log.error(str(error))
-        raise error
+        raise ParamMissingError(f"Param(s) missing, {miss_param_list} is(are) required when using graph mode.")
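As a sanity check on the new message format, the accumulation expression from this hunk can be exercised directly; the parameter names below are illustrative, not the converter's real CLI set:

```python
# Exercise the message-building expression from check_params_exist.
config = {}
miss_param_list = ''
for param in ("model_file", "shape", "output"):   # illustrative names
    if param not in config:
        miss_param_list = ', '.join((miss_param_list, param)) if miss_param_list else param

print(f"Param(s) missing, {miss_param_list} is(are) required when using graph mode.")
# -> Param(s) missing, model_file, shape, output is(are) required when using graph mode.
```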
@@ -220,9 +220,9 @@ class Generator:
             path = Scope.path_str_to_list(scope_path)
             if len(path) < depth_control:
                 continue
-            else:  # depth control within path length
-                module_num = path[depth_control - 1][0]
-                repeated_submodules_at_this_depth.add(module_num)
+            # depth control within path length.
+            module_num = path[depth_control - 1][0]
+            repeated_submodules_at_this_depth.add(module_num)
             ret[depth_control] = repeated_submodules_at_this_depth
         self._repeated_submodules = ret
@@ -249,9 +249,8 @@ class Generator:
                 compared_value = nd_struct.fragment.actual_args.get(base_parameter)
                 if compared_value == base_value:
                     continue
-                else:
-                    formal_args.add(base_parameter)
-                    break
+                formal_args.add(base_parameter)
+                break
         return formal_args
@@ -310,8 +309,7 @@ class Generator:
         for module_num in module_nums:
             if module_num in checked_module:  # module already checked
                 continue
-            else:
-                checked_module.add(module_num)
+            checked_module.add(module_num)
             map_filtered = self.module_map_filter(module_num=module_num)
             formal_args_in_this_module = self._list_formal_parameters_in_a_module(map_filtered)
             formal_args_in_each_submodule[module_num] = formal_args_in_this_module
@@ -679,12 +679,12 @@ class ModuleStruct:
             if submodule_precursor in self.onnx_names:  # if internal, match with local nodes/submodules return
                 # but do nothing here
                 continue
-            else:  # if external, match with current module construct header x
-                if submodule_precursor in self.construct_header_x.values():
-                    local_x = get_dict_key_by_value(submodule_precursor, self.construct_header_x)
-                    md_struct.set_inputs_in_construct_header(local_x, submodule_precursor)
-                else:  # Extra precursor nodes, raise error
-                    raise ValueError("Found external inputs of the submodule but the module does not have it.")
+            # if external, match with current module construct header x
+            if submodule_precursor in self.construct_header_x.values():
+                local_x = get_dict_key_by_value(submodule_precursor, self.construct_header_x)
+                md_struct.set_inputs_in_construct_header(local_x, submodule_precursor)
+            else:  # Extra precursor nodes, raise error
+                raise ValueError("Found external inputs of the submodule but the module does not have it.")

     def register_node_output_to_module(self, nd_struct):
         """Register nodes outputs to this module's return."""