From d58a5bcbbb3b9c81f816822b3399364a90655b05 Mon Sep 17 00:00:00 2001 From: liuchongming Date: Wed, 9 Dec 2020 11:02:21 +0800 Subject: [PATCH] Add module rename function and add exception define. --- mindinsight/mindconverter/cli.py | 105 +++--- .../mindconverter/common/exceptions.py | 299 +++++++----------- .../graph_based_converter/common/utils.py | 4 +- .../graph_based_converter/framework.py | 39 ++- .../generator/args_translator.py | 1 + .../generator/module_struct.py | 19 +- .../sub_graph_searcher/pattern.py | 1 + .../sub_graph_searcher/search_path.py | 21 +- .../sub_graph_searcher/searcher.py | 22 ++ .../third_party_graph/onnx_utils.py | 4 +- .../third_party_graph/pytorch_graph_parser.py | 12 +- .../third_party_graph/tf_graph_parser.py | 8 +- tests/st/func/mindconverter/test_converter.py | 35 -- 13 files changed, 249 insertions(+), 321 deletions(-) diff --git a/mindinsight/mindconverter/cli.py b/mindinsight/mindconverter/cli.py index c9b6bfa2..b9943194 100644 --- a/mindinsight/mindconverter/cli.py +++ b/mindinsight/mindconverter/cli.py @@ -170,6 +170,9 @@ class InFileAction(argparse.Action): if not os.path.isfile(outfile_dir): parser_in.error(f'{option_string} {outfile_dir} is not a file') + if not os.path.basename(outfile_dir).endswith("py"): + parser_in.error(f'{option_string} {outfile_dir} is not a valid python file') + setattr(namespace, self.dest, outfile_dir) @@ -282,32 +285,32 @@ class NodeAction(argparse.Action): parser = argparse.ArgumentParser( - prog='mindconverter', - description='MindConverter CLI entry point (version: {})'.format(mindinsight.__version__), - allow_abbrev=False) + prog='mindconverter', + description='MindConverter CLI entry point (version: {})'.format(mindinsight.__version__), + allow_abbrev=False) parser.add_argument( - '--version', - action='version', - version='%(prog)s ({})'.format(mindinsight.__version__)) + '--version', + action='version', + version='%(prog)s ({})'.format(mindinsight.__version__)) parser.add_argument( - '--in_file', - type=str, - action=InFileAction, - required=False, - default=None, - help=""" + '--in_file', + type=str, + action=InFileAction, + required=False, + default=None, + help=""" Specify path for script file to use AST schema to do script conversation. """) parser.add_argument( - '--model_file', - type=str, - action=ModelFileAction, - required=False, - help=""" + '--model_file', + type=str, + action=ModelFileAction, + required=False, + help=""" PyTorch .pth or Tensorflow .pb model file path to use graph based schema to do script generation. When `--in_file` and `--model_file` are both provided, @@ -315,12 +318,12 @@ parser.add_argument( """) parser.add_argument( - '--shape', - type=str, - action=ShapeAction, - default=None, - required=False, - help=""" + '--shape', + type=str, + action=ShapeAction, + default=None, + required=False, + help=""" Optional, expected input tensor shape of `--model_file`. It's required when use graph based schema. @@ -328,55 +331,55 @@ parser.add_argument( """) parser.add_argument( - '--input_nodes', - type=str, - action=NodeAction, - default=None, - required=False, - help=""" + '--input_nodes', + type=str, + action=NodeAction, + default=None, + required=False, + help=""" Optional, input node(s) name of `--model_file`. It's required when use Tensorflow model. Usage: --input_nodes input_1:0,input_2:0 """) parser.add_argument( - '--output_nodes', - type=str, - action=NodeAction, - default=None, - required=False, - help=""" + '--output_nodes', + type=str, + action=NodeAction, + default=None, + required=False, + help=""" Optional, output node(s) name of `--model_file`. It's required when use Tensorflow model. Usage: --output_nodes output_1:0,output_2:0 """) parser.add_argument( - '--output', - type=str, - action=OutputDirAction, - default=os.path.join(os.getcwd(), 'output'), - help=""" + '--output', + type=str, + action=OutputDirAction, + default=os.path.join(os.getcwd(), 'output'), + help=""" Optional, specify path for converted script file directory. Default output directory is `output` folder in the current working directory. """) parser.add_argument( - '--report', - type=str, - action=LogFileAction, - default=None, - help=""" + '--report', + type=str, + action=LogFileAction, + default=None, + help=""" Optional, specify report directory. Default is converted script directory. """) parser.add_argument( - '--project_path', - type=str, - action=ProjectPathAction, - required=False, - default=None, - help=""" + '--project_path', + type=str, + action=ProjectPathAction, + required=False, + default=None, + help=""" Optional, PyTorch scripts project path. If PyTorch project is not in PYTHONPATH, please assign `--project_path` when use graph based schema. diff --git a/mindinsight/mindconverter/common/exceptions.py b/mindinsight/mindconverter/common/exceptions.py index fd22fbf5..a3cab505 100644 --- a/mindinsight/mindconverter/common/exceptions.py +++ b/mindinsight/mindconverter/common/exceptions.py @@ -13,6 +13,7 @@ # limitations under the License. # ============================================================================ """Define custom exception.""" +import abc import sys from enum import unique from importlib import import_module @@ -23,7 +24,7 @@ from treelib.exceptions import DuplicatedNodeIdError, MultipleRootError, NodeIDA from mindinsight.mindconverter.common.log import logger as log, logger_console as log_console from mindinsight.utils.constant import ScriptConverterErrors -from mindinsight.utils.exceptions import MindInsightException, ParamMissError +from mindinsight.utils.exceptions import MindInsightException @unique @@ -40,12 +41,16 @@ class ConverterErrors(ScriptConverterErrors): SCRIPT_GENERATE_FAIL = 9 REPORT_GENERATE_FAIL = 10 NODE_CONVERSION_ERROR = 11 + INPUT_SHAPE_ERROR = 12 + TF_RUNTIME_ERROR = 13 BASE_CONVERTER_FAIL = 000 GRAPH_INIT_FAIL = 100 TREE_CREATE_FAIL = 200 SOURCE_FILES_SAVE_FAIL = 300 GENERATOR_FAIL = 400 + SUB_GRAPH_SEARCHING_FAIL = 500 + MODEL_LOADING_FAIL = 600 class ScriptNotSupport(MindInsightException): @@ -80,7 +85,6 @@ class MindConverterException(Exception): def __init__(self, **kwargs): """Initialization of MindInsightException.""" - error = kwargs.get('error', None) user_msg = kwargs.get('user_msg', '') debug_msg = kwargs.get('debug_msg', '') @@ -97,6 +101,9 @@ class MindConverterException(Exception): def __str__(self): return '[{}] code: {}, msg: {}'.format(self.__class__.__name__, self.error_code(), self.user_msg) + def __repr__(self): + return self.__str__() + def error_code(self): """" Calculate error code. @@ -109,54 +116,59 @@ class MindConverterException(Exception): Returns: str, Hex string representing the composed MindConverter error code. """ - num = 0xFFFF & self.error.value error_code = ''.join((f'{self.cls_code}'.zfill(3), hex(num)[2:].zfill(4).upper())) return error_code - @staticmethod - def raise_from(): + @classmethod + @abc.abstractmethod + def raise_from(cls): """Raise from below exceptions.""" - return None @classmethod - def check_except_with_print_pytorch(cls, msg): - """Check except in pytorch.""" + def uniform_catcher(cls, msg): + """Uniform exception catcher.""" def decorator(func): - def _f(graph_path, sample_shape, output_folder, report_folder): + def _f(*args, **kwargs): try: - func(graph_path=graph_path, sample_shape=sample_shape, - output_folder=output_folder, report_folder=report_folder) + res = func(*args, **kwargs) except cls.raise_from() as e: error = cls(msg=msg) detail_info = f"Error detail: {str(e)}" log_console.error(str(error)) log_console.error(detail_info) log.exception(e) - sys.exit(-1) + sys.exit(0) + except ModuleNotFoundError as e: + detail_info = f"Error detail: Required package not found, please check the runtime environment." + log_console.error(str(e)) + log_console.error(detail_info) + log.exception(e) + sys.exit(0) + return res + return _f + return decorator @classmethod - def check_except_with_print_tf(cls, msg): - """Check except in tf.""" + def check_except(cls, msg): + """Check except.""" def decorator(func): - def _f(graph_path, sample_shape, - input_nodes, output_nodes, - output_folder, report_folder): + def _f(*args, **kwargs): try: - func(graph_path=graph_path, sample_shape=sample_shape, - input_nodes=input_nodes, output_nodes=output_nodes, - output_folder=output_folder, report_folder=report_folder) + output = func(*args, **kwargs) except cls.raise_from() as e: - error = cls(msg=msg) - detail_info = f"Error detail: {str(e)}" - log_console.error(str(error)) - log_console.error(detail_info) + log.error(msg) + log.exception(e) + raise cls(msg=msg) + except Exception as e: + log.error(msg) log.exception(e) - sys.exit(-1) + raise e + return output return _f @@ -170,31 +182,12 @@ class BaseConverterFail(MindConverterException): super(BaseConverterFail, self).__init__(error=ConverterErrors.BASE_CONVERTER_FAIL, user_msg=msg) - @staticmethod - def raise_from(): + @classmethod + def raise_from(cls): """Raise from exceptions below.""" - except_source = (UnknownModel, - ParamMissError) + except_source = Exception, cls return except_source - @classmethod - def check_except(cls, msg): - """Check except.""" - - def decorator(func): - def _f(file_config): - try: - func(file_config=file_config) - except cls.raise_from() as e: - error = cls(msg=msg) - detail_info = f"Error detail: {str(e)}" - log_console.error(str(error)) - log_console.error(detail_info) - log.exception(e) - sys.exit(-1) - return _f - return decorator - class UnknownModel(MindConverterException): """The unknown model error.""" @@ -203,6 +196,10 @@ class UnknownModel(MindConverterException): super(UnknownModel, self).__init__(error=ConverterErrors.UNKNOWN_MODEL, user_msg=msg) + @classmethod + def raise_from(cls): + return cls + class GraphInitFail(MindConverterException): """The graph init fail error.""" @@ -211,27 +208,19 @@ class GraphInitFail(MindConverterException): super(GraphInitFail, self).__init__(error=ConverterErrors.GRAPH_INIT_FAIL, user_msg=kwargs.get('msg', '')) - @staticmethod - def raise_from(): + @classmethod + def raise_from(cls): """Raise from exceptions below.""" except_source = (FileNotFoundError, ModuleNotFoundError, ModelNotSupport, + SubGraphSearchingFail, TypeError, ZeroDivisionError, - RuntimeError) + RuntimeError, + cls) return except_source - @classmethod - def check_except_pytorch(cls, msg): - """Check except for pytorch.""" - return super().check_except_with_print_pytorch(msg) - - @classmethod - def check_except_tf(cls, msg): - """Check except for tf.""" - return super().check_except_with_print_tf(msg) - class TreeCreateFail(MindConverterException): """The tree create fail.""" @@ -240,23 +229,13 @@ class TreeCreateFail(MindConverterException): super(TreeCreateFail, self).__init__(error=ConverterErrors.TREE_CREATE_FAIL, user_msg=msg) - @staticmethod - def raise_from(): + @classmethod + def raise_from(cls): """Raise from exceptions below.""" except_source = (NodeInputMissing, - TreeNodeInsertFail) + TreeNodeInsertFail, cls) return except_source - @classmethod - def check_except_pytorch(cls, msg): - """Check except.""" - return super().check_except_with_print_pytorch(msg) - - @classmethod - def check_except_tf(cls, msg): - """Check except for tf.""" - return super().check_except_with_print_tf(msg) - class SourceFilesSaveFail(MindConverterException): """The source files save fail error.""" @@ -265,25 +244,15 @@ class SourceFilesSaveFail(MindConverterException): super(SourceFilesSaveFail, self).__init__(error=ConverterErrors.SOURCE_FILES_SAVE_FAIL, user_msg=msg) - @staticmethod - def raise_from(): + @classmethod + def raise_from(cls): """Raise from exceptions below.""" except_source = (NodeInputTypeNotSupport, ScriptGenerateFail, ReportGenerateFail, - IOError) + IOError, cls) return except_source - @classmethod - def check_except_pytorch(cls, msg): - """Check except.""" - return super().check_except_with_print_pytorch(msg) - - @classmethod - def check_except_tf(cls, msg): - """Check except for tf.""" - return super().check_except_with_print_tf(msg) - class ModelNotSupport(MindConverterException): """The model not support error.""" @@ -293,55 +262,32 @@ class ModelNotSupport(MindConverterException): user_msg=msg, cls_code=ConverterErrors.GRAPH_INIT_FAIL.value) - @staticmethod - def raise_from(): + @classmethod + def raise_from(cls): """Raise from exceptions below.""" except_source = (RuntimeError, + ModuleNotFoundError, ValueError, + AssertionError, TypeError, OSError, - ZeroDivisionError) + ZeroDivisionError, cls) return except_source - @classmethod - def check_except_pytorch(cls, msg): - """Check except.""" - def decorator(func): - def _f(arch, model_path, **kwargs): - try: - output = func(arch, model_path=model_path, **kwargs) - except cls.raise_from() as e: - error = cls(msg=msg) - log.error(msg) - log.exception(e) - raise error from e - return output - return _f - return decorator +class TfRuntimeError(MindConverterException): + """Catch tf runtime error.""" + + def __init__(self, msg): + super(TfRuntimeError, self).__init__(error=ConverterErrors.TF_RUNTIME_ERROR, + user_msg=msg, + cls_code=ConverterErrors.GRAPH_INIT_FAIL.value) @classmethod - def check_except_tf(cls, msg): - """Check except.""" + def raise_from(cls): tf_error_module = import_module('tensorflow.python.framework.errors_impl') tf_error = getattr(tf_error_module, 'OpError') - - cls._error = cls.raise_from() + (tf_error,) - - def decorator(func): - def _f(arch, model_path, **kwargs): - try: - output = func(arch, model_path=model_path, **kwargs) - except cls._error as e: - error = cls(msg=msg) - log.error(msg) - log.exception(e) - raise error from e - return output - - return _f - - return decorator + return tf_error, ValueError, RuntimeError, cls class NodeInputMissing(MindConverterException): @@ -352,6 +298,10 @@ class NodeInputMissing(MindConverterException): user_msg=msg, cls_code=ConverterErrors.TREE_CREATE_FAIL.value) + @classmethod + def raise_from(cls): + return ValueError, IndexError, KeyError, AttributeError, cls + class TreeNodeInsertFail(MindConverterException): """The tree node create fail error.""" @@ -361,32 +311,15 @@ class TreeNodeInsertFail(MindConverterException): user_msg=msg, cls_code=ConverterErrors.TREE_CREATE_FAIL.value) - @staticmethod - def raise_from(): + @classmethod + def raise_from(cls): """Raise from exceptions below.""" except_source = (OSError, DuplicatedNodeIdError, MultipleRootError, - NodeIDAbsentError) + NodeIDAbsentError, cls) return except_source - @classmethod - def check_except(cls, msg): - """Check except.""" - - def decorator(func): - def _f(arch, graph): - try: - output = func(arch, graph=graph) - except cls.raise_from() as e: - error = cls(msg=msg) - log.error(msg) - log.exception(e) - raise error from e - return output - return _f - return decorator - class NodeInputTypeNotSupport(MindConverterException): """The node input type NOT support error.""" @@ -396,6 +329,10 @@ class NodeInputTypeNotSupport(MindConverterException): user_msg=msg, cls_code=ConverterErrors.SOURCE_FILES_SAVE_FAIL.value) + @classmethod + def raise_from(cls): + return ValueError, TypeError, IndexError, cls + class ScriptGenerateFail(MindConverterException): """The script generate fail error.""" @@ -405,31 +342,14 @@ class ScriptGenerateFail(MindConverterException): user_msg=msg, cls_code=ConverterErrors.SOURCE_FILES_SAVE_FAIL.value) - @staticmethod - def raise_from(): + @classmethod + def raise_from(cls): """Raise from exceptions below.""" except_source = (RuntimeError, parse.ParseError, - AttributeError) + AttributeError, cls) return except_source - @classmethod - def check_except(cls, msg): - """Check except.""" - - def decorator(func): - def _f(arch, mapper): - try: - output = func(arch, mapper=mapper) - except cls.raise_from() as e: - error = cls(msg=msg) - log.error(msg) - log.exception(e) - raise error from e - return output - return _f - return decorator - class ReportGenerateFail(MindConverterException): """The report generate fail error.""" @@ -439,28 +359,24 @@ class ReportGenerateFail(MindConverterException): user_msg=msg, cls_code=ConverterErrors.SOURCE_FILES_SAVE_FAIL.value) - @staticmethod - def raise_from(): + @classmethod + def raise_from(cls): """Raise from exceptions below.""" - except_source = ZeroDivisionError - return except_source + return ZeroDivisionError, cls - @classmethod - def check_except(cls, msg): - """Check except.""" - def decorator(func): - def _f(arch, mapper): - try: - output = func(arch, mapper=mapper) - except cls.raise_from() as e: - error = cls(msg=msg) - log.error(msg) - log.exception(e) - raise error from e - return output - return _f - return decorator +class SubGraphSearchingFail(MindConverterException): + """Sub-graph searching exception.""" + + def __init__(self, msg): + super(SubGraphSearchingFail, self).__init__(error=ConverterErrors.MODEL_NOT_SUPPORT, + cls_code=ConverterErrors.SUB_GRAPH_SEARCHING_FAIL.value, + user_msg=msg) + + @classmethod + def raise_from(cls): + """Define exception in sub-graph searching module.""" + return IndexError, KeyError, ValueError, AttributeError, ZeroDivisionError, cls class GeneratorFail(MindConverterException): @@ -470,10 +386,23 @@ class GeneratorFail(MindConverterException): super(GeneratorFail, self).__init__(error=ConverterErrors.NODE_CONVERSION_ERROR, user_msg=msg, cls_code=ConverterErrors.GENERATOR_FAIL.value) + @classmethod def raise_from(cls): """Raise from exceptions below.""" - except_source = (ValueError, - TypeError, - cls) + except_source = (ValueError, TypeError, cls) return except_source + + +class ModelLoadingFail(MindConverterException): + """Model loading fail.""" + + def __init__(self, msg): + super(ModelLoadingFail, self).__init__(error=ConverterErrors.INPUT_SHAPE_ERROR, + cls_code=ConverterErrors.MODEL_LOADING_FAIL.value, + user_msg=msg) + + @classmethod + def raise_from(cls): + """Define exception when model loading fail.""" + return ValueError, cls diff --git a/mindinsight/mindconverter/graph_based_converter/common/utils.py b/mindinsight/mindconverter/graph_based_converter/common/utils.py index fd632ccf..811da4e2 100644 --- a/mindinsight/mindconverter/graph_based_converter/common/utils.py +++ b/mindinsight/mindconverter/graph_based_converter/common/utils.py @@ -135,9 +135,11 @@ def lib_version_satisfied(current_ver: str, mini_ver_limited: str, if current_ver < mini_ver_limited or (newest_ver_limited and current_ver > newest_ver_limited): return False return True + + def get_dict_key_by_value(val, dic): """ - Return the first appeared key of a dictionay by given value. + Return the first appeared key of a dictionary by given value. Args: val (Any): Value of the key. diff --git a/mindinsight/mindconverter/graph_based_converter/framework.py b/mindinsight/mindconverter/graph_based_converter/framework.py index 1971fc1e..cedc2647 100644 --- a/mindinsight/mindconverter/graph_based_converter/framework.py +++ b/mindinsight/mindconverter/graph_based_converter/framework.py @@ -16,6 +16,7 @@ import os import re import argparse +import sys from importlib import import_module from importlib.util import find_spec @@ -25,12 +26,11 @@ from mindinsight.mindconverter.graph_based_converter.common.utils import lib_ver from mindinsight.mindconverter.graph_based_converter.constant import BINARY_HEADER_PYTORCH_FILE, FrameworkType, \ BINARY_HEADER_PYTORCH_BITS, ONNX_MIN_VER, TF2ONNX_MIN_VER, ONNXRUNTIME_MIN_VER from mindinsight.mindconverter.graph_based_converter.mapper import ONNXToMindSporeMapper -from mindinsight.mindconverter.common.log import logger as log +from mindinsight.mindconverter.common.log import logger as log, logger_console as log_console from mindinsight.mindconverter.common.exceptions import GraphInitFail, TreeCreateFail, SourceFilesSaveFail, \ - BaseConverterFail, UnknownModel + BaseConverterFail, UnknownModel, GeneratorFail, TfRuntimeError from mindinsight.utils.exceptions import ParamMissError - permissions = os.R_OK | os.W_OK | os.X_OK os.umask(permissions << 3 | permissions) @@ -71,8 +71,10 @@ def torch_installation_validation(func): "scripts converter, and PyTorch vision must " "be consisted with model generation runtime.") log.error(str(error)) - log.exception(error) - raise error + detail_info = f"Error detail: {str(error)}" + log_console.error(str(error)) + log_console.error(detail_info) + sys.exit(0) func(graph_path=graph_path, sample_shape=sample_shape, output_folder=output_folder, report_folder=report_folder) @@ -103,7 +105,10 @@ def tf_installation_validation(func): f"based scripts converter for TensorFlow conversion." ) log.error(str(error)) - raise error + detail_info = f"Error detail: {str(error)}" + log_console.error(str(error)) + log_console.error(detail_info) + sys.exit(0) onnx, tf2onnx = import_module("onnx"), import_module("tf2onnx") ort = import_module("onnxruntime") @@ -117,7 +122,10 @@ def tf_installation_validation(func): f"based scripts converter for TensorFlow conversion." ) log.error(str(error)) - raise error + detail_info = f"Error detail: {str(error)}" + log_console.error(str(error)) + log_console.error(detail_info) + sys.exit(0) func(graph_path=graph_path, sample_shape=sample_shape, output_folder=output_folder, report_folder=report_folder, @@ -142,9 +150,10 @@ def _extract_model_name(model_path): @torch_installation_validation -@GraphInitFail.check_except_pytorch("Error occurred when init graph object.") -@TreeCreateFail.check_except_pytorch("Error occurred when create hierarchical tree.") -@SourceFilesSaveFail.check_except_pytorch("Error occurred when save source files.") +@GraphInitFail.uniform_catcher("Error occurred when init graph object.") +@TreeCreateFail.uniform_catcher("Error occurred when create hierarchical tree.") +@SourceFilesSaveFail.uniform_catcher("Error occurred when save source files.") +@GeneratorFail.uniform_catcher("Error occurred when generate code.") def graph_based_converter_pytorch_to_ms(graph_path: str, sample_shape: tuple, output_folder: str, report_folder: str = None): """ @@ -176,9 +185,11 @@ def graph_based_converter_pytorch_to_ms(graph_path: str, sample_shape: tuple, @tf_installation_validation -@GraphInitFail.check_except_tf("Error occurred when init graph object.") -@TreeCreateFail.check_except_tf("Error occurred when create hierarchical tree.") -@SourceFilesSaveFail.check_except_tf("Error occurred when save source files.") +@GraphInitFail.uniform_catcher("Error occurred when init graph object.") +@TfRuntimeError.uniform_catcher("Error occurred when init graph, TensorFlow runtime error.") +@TreeCreateFail.uniform_catcher("Error occurred when create hierarchical tree.") +@SourceFilesSaveFail.uniform_catcher("Error occurred when save source files.") +@GeneratorFail.uniform_catcher("Error occurred when generate code.") def graph_based_converter_tf_to_ms(graph_path: str, sample_shape: tuple, input_nodes: str, output_nodes: str, output_folder: str, report_folder: str = None): @@ -210,7 +221,7 @@ def graph_based_converter_tf_to_ms(graph_path: str, sample_shape: tuple, save_code_file_and_report(model_name, code_fragments, output_folder, report_folder) -@BaseConverterFail.check_except("Failed to start base converter.") +@BaseConverterFail.uniform_catcher("Failed to start base converter.") def main_graph_base_converter(file_config): """ The entrance for converter, script files will be converted. diff --git a/mindinsight/mindconverter/graph_based_converter/generator/args_translator.py b/mindinsight/mindconverter/graph_based_converter/generator/args_translator.py index 3219527b..082596e0 100644 --- a/mindinsight/mindconverter/graph_based_converter/generator/args_translator.py +++ b/mindinsight/mindconverter/graph_based_converter/generator/args_translator.py @@ -201,6 +201,7 @@ class ArgsTranslation: class ArgsTranslationHelper: """Define operations related to ArgsTranslation instances.""" + @staticmethod def find_formal_args_in_modules(args_translators): """ diff --git a/mindinsight/mindconverter/graph_based_converter/generator/module_struct.py b/mindinsight/mindconverter/graph_based_converter/generator/module_struct.py index d26149af..69e7728d 100644 --- a/mindinsight/mindconverter/graph_based_converter/generator/module_struct.py +++ b/mindinsight/mindconverter/graph_based_converter/generator/module_struct.py @@ -541,7 +541,7 @@ class ModuleStruct: for output in output_list: (provider_succ, provider_closet_opt_var) = output if provider_closet_opt_var in struct.matched_inputs: - continue # skip repeat + continue # skip repeat if provider_succ == struct.onnx_name: struct.matched_inputs.append(provider_closet_opt_var) @@ -695,20 +695,5 @@ class ModuleStruct: """Register submodule outputs to this module's return.""" submodule_returns = md_struct.get_returned_opt_var_name() submodule_opt_var_name = md_struct.ms_opt_var_name - for (submodule_ext_succ, opt_var_name_in_this_module, ith_output) in submodule_returns: + for (submodule_ext_succ, _, ith_output) in submodule_returns: self.external_successor_local_returns_map[submodule_ext_succ] = (submodule_opt_var_name, ith_output) - # edit external succ 's inputs in parent module - ext_node = self.GLOBAL_CONTEXT_MGR.onnx_node_name_to_node_struct_map.get(submodule_ext_succ) - ext_node_parent = ext_node.parent_module_struct - while ext_node_parent != self.parent_module_struct: - ext_node_parent.inputs_in_parent_module[ext_node.onnx_name] = md_struct.ms_opt_var_name - ext_node_parent = ext_node_parent.parent_module_struct - - # need find the prec_name? - for ext_node_prec, opt_var_name in ext_node.inputs_in_parent_module.copy().items(): - if isinstance(opt_var_name, str): - if opt_var_name == opt_var_name_in_this_module: - ext_node.inputs_in_parent_module[ext_node_prec] = (self.ms_opt_var_name, ith_output) - if isinstance(opt_var_name, tuple): - if opt_var_name[0] == opt_var_name_in_this_module: - ext_node.inputs_in_parent_module[ext_node_prec] = (self.ms_opt_var_name, ith_output) diff --git a/mindinsight/mindconverter/graph_based_converter/sub_graph_searcher/pattern.py b/mindinsight/mindconverter/graph_based_converter/sub_graph_searcher/pattern.py index 39dfebbd..27ca4426 100644 --- a/mindinsight/mindconverter/graph_based_converter/sub_graph_searcher/pattern.py +++ b/mindinsight/mindconverter/graph_based_converter/sub_graph_searcher/pattern.py @@ -34,6 +34,7 @@ class Pattern: # If pattern in BUILD_IN_MODULE_NAME or BUILD_IN_PATTERN, # the pattern will get additional score. self.additional_score = 0 + self.know_module_name = None def insert(self, idx, seq_len): """ diff --git a/mindinsight/mindconverter/graph_based_converter/sub_graph_searcher/search_path.py b/mindinsight/mindconverter/graph_based_converter/sub_graph_searcher/search_path.py index 6088edec..bc573ae8 100644 --- a/mindinsight/mindconverter/graph_based_converter/sub_graph_searcher/search_path.py +++ b/mindinsight/mindconverter/graph_based_converter/sub_graph_searcher/search_path.py @@ -18,7 +18,7 @@ import uuid from typing import Dict, List, Callable, Union from collections import OrderedDict from .common import context, gen_hash_key, DagGraph, MAX_OUT_DEGREE, cal_matching_score -from .known_module_name import BUILT_IN_MODULE_NAME, is_built_in_module_name +from .known_module_name import BUILT_IN_MODULE_NAME from .pattern import Pattern, scope_name_mapping from .built_in_pattern import BUILT_IN_PATTERN, is_built_in_pattern from .pattern_fuzzy_matching import pattern_fuzzy_matching @@ -85,13 +85,15 @@ def _is_valid_pattern(pattern, dag): return True -def generate_module_name(pattern): +def match_known_module_name(pattern): """ - Generate module name. + Matching with know module name. Args: pattern (Pattern): To be replaced pattern. + Returns: + str, matched module name, return None if not matched. """ matched_result = [] for ptn, module_name in BUILT_IN_MODULE_NAME.items(): @@ -109,7 +111,11 @@ def generate_module_name(pattern): module_name = f"{module_name}{used_module_name[pattern.pattern]}" used_module_name[pattern.pattern] += 1 return module_name + return None + +def generate_module_name(): + """Generate module name.""" global global_idx name = f"Module{global_idx}" global_idx += 1 @@ -439,13 +445,16 @@ class SearchPath: to recover the sequence. """ if self.pattern.pattern not in scope_name_mapping: - module_name = generate_module_name(self.pattern) - scope_name_mapping[self.pattern.pattern] = module_name + module_name = generate_module_name() + known_module_name = match_known_module_name(self.pattern) + scope_name_mapping[self.pattern] = module_name module_name_to_src[module_name] = self.pattern.pattern else: module_name = scope_name_mapping[self.pattern.pattern] + known_module_name = module_name_to_src[module_name].known_module_name self.pattern.module_name = module_name - if is_built_in_module_name(module_name): + self.pattern.known_module_name = known_module_name + if known_module_name: self.pattern.additional_score += cal_matching_score(self.pattern.ptn_length) topo_order, inverted_index = self.replace_sub_graph_completely(self.pattern, self.topo_order_bef_repl) return topo_order, inverted_index diff --git a/mindinsight/mindconverter/graph_based_converter/sub_graph_searcher/searcher.py b/mindinsight/mindconverter/graph_based_converter/sub_graph_searcher/searcher.py index efa307f5..a9211f76 100644 --- a/mindinsight/mindconverter/graph_based_converter/sub_graph_searcher/searcher.py +++ b/mindinsight/mindconverter/graph_based_converter/sub_graph_searcher/searcher.py @@ -18,8 +18,10 @@ from typing import Dict, List from .common import context, DagGraph, gen_hash_key, ACCEPTABLE_RESULT_COUNT from .common import MINI_FREQUENCY, MAX_ITERATION_DEPTH, SATISFIED_SCORE +from ..common.global_context import GlobalContext from ..third_party_graph.onnx_utils import BaseNode from .search_path import SearchPath, Pattern, generate_pattern, find_built_in_pattern +from ...common.exceptions import SubGraphSearchingFail def _is_satisfied(path): @@ -249,6 +251,23 @@ def validate_topo_order_succession(): return True +def _add_known_module_name(search_path): + """ + Add known module name to GlobalContext. + + Args: + search_path (SearchPath): Search path. + + """ + ctx = GlobalContext() + if search_path.pattern.known_module_name: + ctx.known_module_name[search_path.pattern.module_name] = search_path.pattern.known_module_name + for it in search_path.recursion_path: + if it.pattern.known_module_name: + ctx.known_module_name[it.pattern.module_name] = it.pattern.known_module_name + + +@SubGraphSearchingFail.check_except("Sub-Graph searching fail.") def generate_scope_name(data_loader): """ Generate scope name according to computation graph. @@ -270,6 +289,9 @@ def generate_scope_name(data_loader): if len(topo_order_with_scope_name_list) != len(data_loader.nodes_dict): topo_order_with_scope_name_list = flatten_graph(init_dag) + if result: + _add_known_module_name(result) + except (ValueError, IndexError, AttributeError, KeyError) as _: topo_order_with_scope_name_list = flatten_graph(init_dag) return topo_order_with_scope_name_list diff --git a/mindinsight/mindconverter/graph_based_converter/third_party_graph/onnx_utils.py b/mindinsight/mindconverter/graph_based_converter/third_party_graph/onnx_utils.py index f76a371a..87d6fa7b 100644 --- a/mindinsight/mindconverter/graph_based_converter/third_party_graph/onnx_utils.py +++ b/mindinsight/mindconverter/graph_based_converter/third_party_graph/onnx_utils.py @@ -27,7 +27,7 @@ from ..common.global_context import GlobalContext from ..constant import ONNX_TYPE_INT, ONNX_TYPE_INTS, ONNX_TYPE_STRING, \ ONNX_TYPE_FLOATS, ONNX_TYPE_FLOAT, SCALAR_WITHOUT_SHAPE, DYNAMIC_SHAPE, UNKNOWN_DIM_VAL -from ...common.exceptions import GraphInitFail, ModelNotSupport +from ...common.exceptions import GraphInitFail, ModelNotSupport, ModelLoadingFail def convert_tf_graph_to_onnx(model_path, model_inputs, model_outputs, opset=12): @@ -308,7 +308,7 @@ class OnnxDataLoader: w = int(match.group('w')) c = int(match.group('c')) if [h, w, c] != list(self.graph_input_shape)[1:4]: - raise ValueError(f"Shape given should be (N, {h}, {w}, {c}) but got {self.graph_input_shape}") + raise ModelLoadingFail(f"Shape given should be (N, {h}, {w}, {c}) but got {self.graph_input_shape}") return True return False diff --git a/mindinsight/mindconverter/graph_based_converter/third_party_graph/pytorch_graph_parser.py b/mindinsight/mindconverter/graph_based_converter/third_party_graph/pytorch_graph_parser.py index a6c3a47e..8b2d1cee 100644 --- a/mindinsight/mindconverter/graph_based_converter/third_party_graph/pytorch_graph_parser.py +++ b/mindinsight/mindconverter/graph_based_converter/third_party_graph/pytorch_graph_parser.py @@ -25,7 +25,9 @@ class PyTorchGraphParser(GraphParser): """Define pytorch graph parser.""" @classmethod - @ModelNotSupport.check_except_pytorch("Error occurs in loading model, make sure model.pth correct.") + @ModelNotSupport.check_except( + "Error occurs in loading model, please check your model or runtime environment integrity." + ) def parse(cls, model_path: str, **kwargs): """ Parser pytorch graph. @@ -50,11 +52,9 @@ class PyTorchGraphParser(GraphParser): else: model = torch.load(f=model_path, map_location="cpu") except ModuleNotFoundError: - error_msg = \ - "Cannot find model scripts in system path, " \ - "set `--project_path` to the path of model scripts folder correctly." + error_msg = "Cannot find model scripts in system path, " \ + "set `--project_path` to the path of model scripts folder correctly." error = ModuleNotFoundError(error_msg) - log.error(str(error)) - raise error from None + raise error return model diff --git a/mindinsight/mindconverter/graph_based_converter/third_party_graph/tf_graph_parser.py b/mindinsight/mindconverter/graph_based_converter/third_party_graph/tf_graph_parser.py index bfa3fff3..fea0d038 100644 --- a/mindinsight/mindconverter/graph_based_converter/third_party_graph/tf_graph_parser.py +++ b/mindinsight/mindconverter/graph_based_converter/third_party_graph/tf_graph_parser.py @@ -25,7 +25,9 @@ class TFGraphParser(GraphParser): """Define TF graph parser.""" @classmethod - @ModelNotSupport.check_except_tf("Error occurs in loading model, make sure model.pb correct.") + @ModelNotSupport.check_except( + "Error occurs in loading model, please check your model or runtime environment integrity." + ) def parse(cls, model_path: str, **kwargs): """ Parse TF Computational Graph File (.pb) @@ -36,7 +38,6 @@ class TFGraphParser(GraphParser): Returns: object, ONNX model. """ - onnx_utils = import_module( "mindinsight.mindconverter.graph_based_converter.third_party_graph.onnx_utils") convert_tf_graph_to_onnx = getattr(onnx_utils, "convert_tf_graph_to_onnx") @@ -50,6 +51,5 @@ class TFGraphParser(GraphParser): model = convert_tf_graph_to_onnx(model_path, model_inputs=tf_input_nodes, - model_outputs=tf_output_nodes, - ) + model_outputs=tf_output_nodes) return model diff --git a/tests/st/func/mindconverter/test_converter.py b/tests/st/func/mindconverter/test_converter.py index 699b44f0..e5c804ad 100644 --- a/tests/st/func/mindconverter/test_converter.py +++ b/tests/st/func/mindconverter/test_converter.py @@ -21,13 +21,10 @@ Usage: """ import difflib import os -import re import sys - import pytest from mindinsight.mindconverter.converter import main -from mindinsight.mindconverter.graph_based_converter.framework import main_graph_base_converter @pytest.mark.usefixtures('create_output_dir') @@ -82,35 +79,3 @@ class TestConverter: converted_ratio = 100 - (diff_lines * 100) / (len(expect_source)) assert converted_ratio >= 80 - - @pytest.mark.level0 - @pytest.mark.platform_arm_ascend_training - @pytest.mark.platform_x86_gpu_training - @pytest.mark.platform_x86_ascend_training - @pytest.mark.platform_x86_cpu - @pytest.mark.env_single - def test_main_graph_based_converter(self, output): - """Test main graph based converter.""" - pytorch_filename = "resnet50.pth" - expected_model_filename = "resnet50.py" - expected_report_filename = "report_of_resnet50.txt" - file_config = { - 'model_file': os.path.join(self.pytorch_dir, pytorch_filename), - 'shape': (1, 3, 224, 224), - 'outfile_dir': output, - 'report_dir': output - } - with pytest.raises(ValueError) as e: - main_graph_base_converter(file_config=file_config) - - assert os.path.isfile(os.path.join(output, expected_model_filename)) - assert os.path.isfile(os.path.join(output, expected_report_filename)) - - with open(os.path.join(output, expected_report_filename)) as converted_r: - converted_report = converted_r.readlines() - converted_rate = re.findall(r".*(?:Converted Rate: )(.*)[.]", converted_report[-1]) - - assert converted_rate[0] == '100.00%' - - exec_msg = e.value.args[0] - assert exec_msg == "torch.__spec__ is None"