| @@ -170,6 +170,9 @@ class InFileAction(argparse.Action): | |||
| if not os.path.isfile(outfile_dir): | |||
| parser_in.error(f'{option_string} {outfile_dir} is not a file') | |||
| if not os.path.basename(outfile_dir).endswith("py"): | |||
| parser_in.error(f'{option_string} {outfile_dir} is not a valid python file') | |||
| setattr(namespace, self.dest, outfile_dir) | |||
| @@ -282,32 +285,32 @@ class NodeAction(argparse.Action): | |||
| parser = argparse.ArgumentParser( | |||
| prog='mindconverter', | |||
| description='MindConverter CLI entry point (version: {})'.format(mindinsight.__version__), | |||
| allow_abbrev=False) | |||
| prog='mindconverter', | |||
| description='MindConverter CLI entry point (version: {})'.format(mindinsight.__version__), | |||
| allow_abbrev=False) | |||
| parser.add_argument( | |||
| '--version', | |||
| action='version', | |||
| version='%(prog)s ({})'.format(mindinsight.__version__)) | |||
| '--version', | |||
| action='version', | |||
| version='%(prog)s ({})'.format(mindinsight.__version__)) | |||
| parser.add_argument( | |||
| '--in_file', | |||
| type=str, | |||
| action=InFileAction, | |||
| required=False, | |||
| default=None, | |||
| help=""" | |||
| '--in_file', | |||
| type=str, | |||
| action=InFileAction, | |||
| required=False, | |||
| default=None, | |||
| help=""" | |||
| Specify path for script file to use AST schema to | |||
| do script conversation. | |||
| """) | |||
| parser.add_argument( | |||
| '--model_file', | |||
| type=str, | |||
| action=ModelFileAction, | |||
| required=False, | |||
| help=""" | |||
| '--model_file', | |||
| type=str, | |||
| action=ModelFileAction, | |||
| required=False, | |||
| help=""" | |||
| PyTorch .pth or Tensorflow .pb model file path to use graph | |||
| based schema to do script generation. When | |||
| `--in_file` and `--model_file` are both provided, | |||
| @@ -315,12 +318,12 @@ parser.add_argument( | |||
| """) | |||
| parser.add_argument( | |||
| '--shape', | |||
| type=str, | |||
| action=ShapeAction, | |||
| default=None, | |||
| required=False, | |||
| help=""" | |||
| '--shape', | |||
| type=str, | |||
| action=ShapeAction, | |||
| default=None, | |||
| required=False, | |||
| help=""" | |||
| Optional, expected input tensor shape of | |||
| `--model_file`. It's required when use graph based | |||
| schema. | |||
| @@ -328,55 +331,55 @@ parser.add_argument( | |||
| """) | |||
| parser.add_argument( | |||
| '--input_nodes', | |||
| type=str, | |||
| action=NodeAction, | |||
| default=None, | |||
| required=False, | |||
| help=""" | |||
| '--input_nodes', | |||
| type=str, | |||
| action=NodeAction, | |||
| default=None, | |||
| required=False, | |||
| help=""" | |||
| Optional, input node(s) name of `--model_file`. It's required when use Tensorflow model. | |||
| Usage: --input_nodes input_1:0,input_2:0 | |||
| """) | |||
| parser.add_argument( | |||
| '--output_nodes', | |||
| type=str, | |||
| action=NodeAction, | |||
| default=None, | |||
| required=False, | |||
| help=""" | |||
| '--output_nodes', | |||
| type=str, | |||
| action=NodeAction, | |||
| default=None, | |||
| required=False, | |||
| help=""" | |||
| Optional, output node(s) name of `--model_file`. It's required when use Tensorflow model. | |||
| Usage: --output_nodes output_1:0,output_2:0 | |||
| """) | |||
| parser.add_argument( | |||
| '--output', | |||
| type=str, | |||
| action=OutputDirAction, | |||
| default=os.path.join(os.getcwd(), 'output'), | |||
| help=""" | |||
| '--output', | |||
| type=str, | |||
| action=OutputDirAction, | |||
| default=os.path.join(os.getcwd(), 'output'), | |||
| help=""" | |||
| Optional, specify path for converted script file | |||
| directory. Default output directory is `output` folder | |||
| in the current working directory. | |||
| """) | |||
| parser.add_argument( | |||
| '--report', | |||
| type=str, | |||
| action=LogFileAction, | |||
| default=None, | |||
| help=""" | |||
| '--report', | |||
| type=str, | |||
| action=LogFileAction, | |||
| default=None, | |||
| help=""" | |||
| Optional, specify report directory. Default is | |||
| converted script directory. | |||
| """) | |||
| parser.add_argument( | |||
| '--project_path', | |||
| type=str, | |||
| action=ProjectPathAction, | |||
| required=False, | |||
| default=None, | |||
| help=""" | |||
| '--project_path', | |||
| type=str, | |||
| action=ProjectPathAction, | |||
| required=False, | |||
| default=None, | |||
| help=""" | |||
| Optional, PyTorch scripts project path. If PyTorch | |||
| project is not in PYTHONPATH, please assign | |||
| `--project_path` when use graph based schema. | |||
| @@ -13,6 +13,7 @@ | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| """Define custom exception.""" | |||
| import abc | |||
| import sys | |||
| from enum import unique | |||
| from importlib import import_module | |||
| @@ -23,7 +24,7 @@ from treelib.exceptions import DuplicatedNodeIdError, MultipleRootError, NodeIDA | |||
| from mindinsight.mindconverter.common.log import logger as log, logger_console as log_console | |||
| from mindinsight.utils.constant import ScriptConverterErrors | |||
| from mindinsight.utils.exceptions import MindInsightException, ParamMissError | |||
| from mindinsight.utils.exceptions import MindInsightException | |||
| @unique | |||
| @@ -40,12 +41,16 @@ class ConverterErrors(ScriptConverterErrors): | |||
| SCRIPT_GENERATE_FAIL = 9 | |||
| REPORT_GENERATE_FAIL = 10 | |||
| NODE_CONVERSION_ERROR = 11 | |||
| INPUT_SHAPE_ERROR = 12 | |||
| TF_RUNTIME_ERROR = 13 | |||
| BASE_CONVERTER_FAIL = 000 | |||
| GRAPH_INIT_FAIL = 100 | |||
| TREE_CREATE_FAIL = 200 | |||
| SOURCE_FILES_SAVE_FAIL = 300 | |||
| GENERATOR_FAIL = 400 | |||
| SUB_GRAPH_SEARCHING_FAIL = 500 | |||
| MODEL_LOADING_FAIL = 600 | |||
| class ScriptNotSupport(MindInsightException): | |||
| @@ -80,7 +85,6 @@ class MindConverterException(Exception): | |||
| def __init__(self, **kwargs): | |||
| """Initialization of MindInsightException.""" | |||
| error = kwargs.get('error', None) | |||
| user_msg = kwargs.get('user_msg', '') | |||
| debug_msg = kwargs.get('debug_msg', '') | |||
| @@ -97,6 +101,9 @@ class MindConverterException(Exception): | |||
| def __str__(self): | |||
| return '[{}] code: {}, msg: {}'.format(self.__class__.__name__, self.error_code(), self.user_msg) | |||
| def __repr__(self): | |||
| return self.__str__() | |||
| def error_code(self): | |||
| """" | |||
| Calculate error code. | |||
| @@ -109,54 +116,59 @@ class MindConverterException(Exception): | |||
| Returns: | |||
| str, Hex string representing the composed MindConverter error code. | |||
| """ | |||
| num = 0xFFFF & self.error.value | |||
| error_code = ''.join((f'{self.cls_code}'.zfill(3), hex(num)[2:].zfill(4).upper())) | |||
| return error_code | |||
| @staticmethod | |||
| def raise_from(): | |||
| @classmethod | |||
| @abc.abstractmethod | |||
| def raise_from(cls): | |||
| """Raise from below exceptions.""" | |||
| return None | |||
| @classmethod | |||
| def check_except_with_print_pytorch(cls, msg): | |||
| """Check except in pytorch.""" | |||
| def uniform_catcher(cls, msg): | |||
| """Uniform exception catcher.""" | |||
| def decorator(func): | |||
| def _f(graph_path, sample_shape, output_folder, report_folder): | |||
| def _f(*args, **kwargs): | |||
| try: | |||
| func(graph_path=graph_path, sample_shape=sample_shape, | |||
| output_folder=output_folder, report_folder=report_folder) | |||
| res = func(*args, **kwargs) | |||
| except cls.raise_from() as e: | |||
| error = cls(msg=msg) | |||
| detail_info = f"Error detail: {str(e)}" | |||
| log_console.error(str(error)) | |||
| log_console.error(detail_info) | |||
| log.exception(e) | |||
| sys.exit(-1) | |||
| sys.exit(0) | |||
| except ModuleNotFoundError as e: | |||
| detail_info = f"Error detail: Required package not found, please check the runtime environment." | |||
| log_console.error(str(e)) | |||
| log_console.error(detail_info) | |||
| log.exception(e) | |||
| sys.exit(0) | |||
| return res | |||
| return _f | |||
| return decorator | |||
| @classmethod | |||
| def check_except_with_print_tf(cls, msg): | |||
| """Check except in tf.""" | |||
| def check_except(cls, msg): | |||
| """Check except.""" | |||
| def decorator(func): | |||
| def _f(graph_path, sample_shape, | |||
| input_nodes, output_nodes, | |||
| output_folder, report_folder): | |||
| def _f(*args, **kwargs): | |||
| try: | |||
| func(graph_path=graph_path, sample_shape=sample_shape, | |||
| input_nodes=input_nodes, output_nodes=output_nodes, | |||
| output_folder=output_folder, report_folder=report_folder) | |||
| output = func(*args, **kwargs) | |||
| except cls.raise_from() as e: | |||
| error = cls(msg=msg) | |||
| detail_info = f"Error detail: {str(e)}" | |||
| log_console.error(str(error)) | |||
| log_console.error(detail_info) | |||
| log.error(msg) | |||
| log.exception(e) | |||
| raise cls(msg=msg) | |||
| except Exception as e: | |||
| log.error(msg) | |||
| log.exception(e) | |||
| sys.exit(-1) | |||
| raise e | |||
| return output | |||
| return _f | |||
| @@ -170,31 +182,12 @@ class BaseConverterFail(MindConverterException): | |||
| super(BaseConverterFail, self).__init__(error=ConverterErrors.BASE_CONVERTER_FAIL, | |||
| user_msg=msg) | |||
| @staticmethod | |||
| def raise_from(): | |||
| @classmethod | |||
| def raise_from(cls): | |||
| """Raise from exceptions below.""" | |||
| except_source = (UnknownModel, | |||
| ParamMissError) | |||
| except_source = Exception, cls | |||
| return except_source | |||
| @classmethod | |||
| def check_except(cls, msg): | |||
| """Check except.""" | |||
| def decorator(func): | |||
| def _f(file_config): | |||
| try: | |||
| func(file_config=file_config) | |||
| except cls.raise_from() as e: | |||
| error = cls(msg=msg) | |||
| detail_info = f"Error detail: {str(e)}" | |||
| log_console.error(str(error)) | |||
| log_console.error(detail_info) | |||
| log.exception(e) | |||
| sys.exit(-1) | |||
| return _f | |||
| return decorator | |||
| class UnknownModel(MindConverterException): | |||
| """The unknown model error.""" | |||
| @@ -203,6 +196,10 @@ class UnknownModel(MindConverterException): | |||
| super(UnknownModel, self).__init__(error=ConverterErrors.UNKNOWN_MODEL, | |||
| user_msg=msg) | |||
| @classmethod | |||
| def raise_from(cls): | |||
| return cls | |||
| class GraphInitFail(MindConverterException): | |||
| """The graph init fail error.""" | |||
| @@ -211,27 +208,19 @@ class GraphInitFail(MindConverterException): | |||
| super(GraphInitFail, self).__init__(error=ConverterErrors.GRAPH_INIT_FAIL, | |||
| user_msg=kwargs.get('msg', '')) | |||
| @staticmethod | |||
| def raise_from(): | |||
| @classmethod | |||
| def raise_from(cls): | |||
| """Raise from exceptions below.""" | |||
| except_source = (FileNotFoundError, | |||
| ModuleNotFoundError, | |||
| ModelNotSupport, | |||
| SubGraphSearchingFail, | |||
| TypeError, | |||
| ZeroDivisionError, | |||
| RuntimeError) | |||
| RuntimeError, | |||
| cls) | |||
| return except_source | |||
| @classmethod | |||
| def check_except_pytorch(cls, msg): | |||
| """Check except for pytorch.""" | |||
| return super().check_except_with_print_pytorch(msg) | |||
| @classmethod | |||
| def check_except_tf(cls, msg): | |||
| """Check except for tf.""" | |||
| return super().check_except_with_print_tf(msg) | |||
| class TreeCreateFail(MindConverterException): | |||
| """The tree create fail.""" | |||
| @@ -240,23 +229,13 @@ class TreeCreateFail(MindConverterException): | |||
| super(TreeCreateFail, self).__init__(error=ConverterErrors.TREE_CREATE_FAIL, | |||
| user_msg=msg) | |||
| @staticmethod | |||
| def raise_from(): | |||
| @classmethod | |||
| def raise_from(cls): | |||
| """Raise from exceptions below.""" | |||
| except_source = (NodeInputMissing, | |||
| TreeNodeInsertFail) | |||
| TreeNodeInsertFail, cls) | |||
| return except_source | |||
| @classmethod | |||
| def check_except_pytorch(cls, msg): | |||
| """Check except.""" | |||
| return super().check_except_with_print_pytorch(msg) | |||
| @classmethod | |||
| def check_except_tf(cls, msg): | |||
| """Check except for tf.""" | |||
| return super().check_except_with_print_tf(msg) | |||
| class SourceFilesSaveFail(MindConverterException): | |||
| """The source files save fail error.""" | |||
| @@ -265,25 +244,15 @@ class SourceFilesSaveFail(MindConverterException): | |||
| super(SourceFilesSaveFail, self).__init__(error=ConverterErrors.SOURCE_FILES_SAVE_FAIL, | |||
| user_msg=msg) | |||
| @staticmethod | |||
| def raise_from(): | |||
| @classmethod | |||
| def raise_from(cls): | |||
| """Raise from exceptions below.""" | |||
| except_source = (NodeInputTypeNotSupport, | |||
| ScriptGenerateFail, | |||
| ReportGenerateFail, | |||
| IOError) | |||
| IOError, cls) | |||
| return except_source | |||
| @classmethod | |||
| def check_except_pytorch(cls, msg): | |||
| """Check except.""" | |||
| return super().check_except_with_print_pytorch(msg) | |||
| @classmethod | |||
| def check_except_tf(cls, msg): | |||
| """Check except for tf.""" | |||
| return super().check_except_with_print_tf(msg) | |||
| class ModelNotSupport(MindConverterException): | |||
| """The model not support error.""" | |||
| @@ -293,55 +262,32 @@ class ModelNotSupport(MindConverterException): | |||
| user_msg=msg, | |||
| cls_code=ConverterErrors.GRAPH_INIT_FAIL.value) | |||
| @staticmethod | |||
| def raise_from(): | |||
| @classmethod | |||
| def raise_from(cls): | |||
| """Raise from exceptions below.""" | |||
| except_source = (RuntimeError, | |||
| ModuleNotFoundError, | |||
| ValueError, | |||
| AssertionError, | |||
| TypeError, | |||
| OSError, | |||
| ZeroDivisionError) | |||
| ZeroDivisionError, cls) | |||
| return except_source | |||
| @classmethod | |||
| def check_except_pytorch(cls, msg): | |||
| """Check except.""" | |||
| def decorator(func): | |||
| def _f(arch, model_path, **kwargs): | |||
| try: | |||
| output = func(arch, model_path=model_path, **kwargs) | |||
| except cls.raise_from() as e: | |||
| error = cls(msg=msg) | |||
| log.error(msg) | |||
| log.exception(e) | |||
| raise error from e | |||
| return output | |||
| return _f | |||
| return decorator | |||
| class TfRuntimeError(MindConverterException): | |||
| """Catch tf runtime error.""" | |||
| def __init__(self, msg): | |||
| super(TfRuntimeError, self).__init__(error=ConverterErrors.TF_RUNTIME_ERROR, | |||
| user_msg=msg, | |||
| cls_code=ConverterErrors.GRAPH_INIT_FAIL.value) | |||
| @classmethod | |||
| def check_except_tf(cls, msg): | |||
| """Check except.""" | |||
| def raise_from(cls): | |||
| tf_error_module = import_module('tensorflow.python.framework.errors_impl') | |||
| tf_error = getattr(tf_error_module, 'OpError') | |||
| cls._error = cls.raise_from() + (tf_error,) | |||
| def decorator(func): | |||
| def _f(arch, model_path, **kwargs): | |||
| try: | |||
| output = func(arch, model_path=model_path, **kwargs) | |||
| except cls._error as e: | |||
| error = cls(msg=msg) | |||
| log.error(msg) | |||
| log.exception(e) | |||
| raise error from e | |||
| return output | |||
| return _f | |||
| return decorator | |||
| return tf_error, ValueError, RuntimeError, cls | |||
| class NodeInputMissing(MindConverterException): | |||
| @@ -352,6 +298,10 @@ class NodeInputMissing(MindConverterException): | |||
| user_msg=msg, | |||
| cls_code=ConverterErrors.TREE_CREATE_FAIL.value) | |||
| @classmethod | |||
| def raise_from(cls): | |||
| return ValueError, IndexError, KeyError, AttributeError, cls | |||
| class TreeNodeInsertFail(MindConverterException): | |||
| """The tree node create fail error.""" | |||
| @@ -361,32 +311,15 @@ class TreeNodeInsertFail(MindConverterException): | |||
| user_msg=msg, | |||
| cls_code=ConverterErrors.TREE_CREATE_FAIL.value) | |||
| @staticmethod | |||
| def raise_from(): | |||
| @classmethod | |||
| def raise_from(cls): | |||
| """Raise from exceptions below.""" | |||
| except_source = (OSError, | |||
| DuplicatedNodeIdError, | |||
| MultipleRootError, | |||
| NodeIDAbsentError) | |||
| NodeIDAbsentError, cls) | |||
| return except_source | |||
| @classmethod | |||
| def check_except(cls, msg): | |||
| """Check except.""" | |||
| def decorator(func): | |||
| def _f(arch, graph): | |||
| try: | |||
| output = func(arch, graph=graph) | |||
| except cls.raise_from() as e: | |||
| error = cls(msg=msg) | |||
| log.error(msg) | |||
| log.exception(e) | |||
| raise error from e | |||
| return output | |||
| return _f | |||
| return decorator | |||
| class NodeInputTypeNotSupport(MindConverterException): | |||
| """The node input type NOT support error.""" | |||
| @@ -396,6 +329,10 @@ class NodeInputTypeNotSupport(MindConverterException): | |||
| user_msg=msg, | |||
| cls_code=ConverterErrors.SOURCE_FILES_SAVE_FAIL.value) | |||
| @classmethod | |||
| def raise_from(cls): | |||
| return ValueError, TypeError, IndexError, cls | |||
| class ScriptGenerateFail(MindConverterException): | |||
| """The script generate fail error.""" | |||
| @@ -405,31 +342,14 @@ class ScriptGenerateFail(MindConverterException): | |||
| user_msg=msg, | |||
| cls_code=ConverterErrors.SOURCE_FILES_SAVE_FAIL.value) | |||
| @staticmethod | |||
| def raise_from(): | |||
| @classmethod | |||
| def raise_from(cls): | |||
| """Raise from exceptions below.""" | |||
| except_source = (RuntimeError, | |||
| parse.ParseError, | |||
| AttributeError) | |||
| AttributeError, cls) | |||
| return except_source | |||
| @classmethod | |||
| def check_except(cls, msg): | |||
| """Check except.""" | |||
| def decorator(func): | |||
| def _f(arch, mapper): | |||
| try: | |||
| output = func(arch, mapper=mapper) | |||
| except cls.raise_from() as e: | |||
| error = cls(msg=msg) | |||
| log.error(msg) | |||
| log.exception(e) | |||
| raise error from e | |||
| return output | |||
| return _f | |||
| return decorator | |||
| class ReportGenerateFail(MindConverterException): | |||
| """The report generate fail error.""" | |||
| @@ -439,28 +359,24 @@ class ReportGenerateFail(MindConverterException): | |||
| user_msg=msg, | |||
| cls_code=ConverterErrors.SOURCE_FILES_SAVE_FAIL.value) | |||
| @staticmethod | |||
| def raise_from(): | |||
| @classmethod | |||
| def raise_from(cls): | |||
| """Raise from exceptions below.""" | |||
| except_source = ZeroDivisionError | |||
| return except_source | |||
| return ZeroDivisionError, cls | |||
| @classmethod | |||
| def check_except(cls, msg): | |||
| """Check except.""" | |||
| def decorator(func): | |||
| def _f(arch, mapper): | |||
| try: | |||
| output = func(arch, mapper=mapper) | |||
| except cls.raise_from() as e: | |||
| error = cls(msg=msg) | |||
| log.error(msg) | |||
| log.exception(e) | |||
| raise error from e | |||
| return output | |||
| return _f | |||
| return decorator | |||
| class SubGraphSearchingFail(MindConverterException): | |||
| """Sub-graph searching exception.""" | |||
| def __init__(self, msg): | |||
| super(SubGraphSearchingFail, self).__init__(error=ConverterErrors.MODEL_NOT_SUPPORT, | |||
| cls_code=ConverterErrors.SUB_GRAPH_SEARCHING_FAIL.value, | |||
| user_msg=msg) | |||
| @classmethod | |||
| def raise_from(cls): | |||
| """Define exception in sub-graph searching module.""" | |||
| return IndexError, KeyError, ValueError, AttributeError, ZeroDivisionError, cls | |||
| class GeneratorFail(MindConverterException): | |||
| @@ -470,10 +386,23 @@ class GeneratorFail(MindConverterException): | |||
| super(GeneratorFail, self).__init__(error=ConverterErrors.NODE_CONVERSION_ERROR, | |||
| user_msg=msg, | |||
| cls_code=ConverterErrors.GENERATOR_FAIL.value) | |||
| @classmethod | |||
| def raise_from(cls): | |||
| """Raise from exceptions below.""" | |||
| except_source = (ValueError, | |||
| TypeError, | |||
| cls) | |||
| except_source = (ValueError, TypeError, cls) | |||
| return except_source | |||
| class ModelLoadingFail(MindConverterException): | |||
| """Model loading fail.""" | |||
| def __init__(self, msg): | |||
| super(ModelLoadingFail, self).__init__(error=ConverterErrors.INPUT_SHAPE_ERROR, | |||
| cls_code=ConverterErrors.MODEL_LOADING_FAIL.value, | |||
| user_msg=msg) | |||
| @classmethod | |||
| def raise_from(cls): | |||
| """Define exception when model loading fail.""" | |||
| return ValueError, cls | |||
| @@ -135,9 +135,11 @@ def lib_version_satisfied(current_ver: str, mini_ver_limited: str, | |||
| if current_ver < mini_ver_limited or (newest_ver_limited and current_ver > newest_ver_limited): | |||
| return False | |||
| return True | |||
| def get_dict_key_by_value(val, dic): | |||
| """ | |||
| Return the first appeared key of a dictionay by given value. | |||
| Return the first appeared key of a dictionary by given value. | |||
| Args: | |||
| val (Any): Value of the key. | |||
| @@ -16,6 +16,7 @@ | |||
| import os | |||
| import re | |||
| import argparse | |||
| import sys | |||
| from importlib import import_module | |||
| from importlib.util import find_spec | |||
| @@ -25,12 +26,11 @@ from mindinsight.mindconverter.graph_based_converter.common.utils import lib_ver | |||
| from mindinsight.mindconverter.graph_based_converter.constant import BINARY_HEADER_PYTORCH_FILE, FrameworkType, \ | |||
| BINARY_HEADER_PYTORCH_BITS, ONNX_MIN_VER, TF2ONNX_MIN_VER, ONNXRUNTIME_MIN_VER | |||
| from mindinsight.mindconverter.graph_based_converter.mapper import ONNXToMindSporeMapper | |||
| from mindinsight.mindconverter.common.log import logger as log | |||
| from mindinsight.mindconverter.common.log import logger as log, logger_console as log_console | |||
| from mindinsight.mindconverter.common.exceptions import GraphInitFail, TreeCreateFail, SourceFilesSaveFail, \ | |||
| BaseConverterFail, UnknownModel | |||
| BaseConverterFail, UnknownModel, GeneratorFail, TfRuntimeError | |||
| from mindinsight.utils.exceptions import ParamMissError | |||
| permissions = os.R_OK | os.W_OK | os.X_OK | |||
| os.umask(permissions << 3 | permissions) | |||
| @@ -71,8 +71,10 @@ def torch_installation_validation(func): | |||
| "scripts converter, and PyTorch vision must " | |||
| "be consisted with model generation runtime.") | |||
| log.error(str(error)) | |||
| log.exception(error) | |||
| raise error | |||
| detail_info = f"Error detail: {str(error)}" | |||
| log_console.error(str(error)) | |||
| log_console.error(detail_info) | |||
| sys.exit(0) | |||
| func(graph_path=graph_path, sample_shape=sample_shape, | |||
| output_folder=output_folder, report_folder=report_folder) | |||
| @@ -103,7 +105,10 @@ def tf_installation_validation(func): | |||
| f"based scripts converter for TensorFlow conversion." | |||
| ) | |||
| log.error(str(error)) | |||
| raise error | |||
| detail_info = f"Error detail: {str(error)}" | |||
| log_console.error(str(error)) | |||
| log_console.error(detail_info) | |||
| sys.exit(0) | |||
| onnx, tf2onnx = import_module("onnx"), import_module("tf2onnx") | |||
| ort = import_module("onnxruntime") | |||
| @@ -117,7 +122,10 @@ def tf_installation_validation(func): | |||
| f"based scripts converter for TensorFlow conversion." | |||
| ) | |||
| log.error(str(error)) | |||
| raise error | |||
| detail_info = f"Error detail: {str(error)}" | |||
| log_console.error(str(error)) | |||
| log_console.error(detail_info) | |||
| sys.exit(0) | |||
| func(graph_path=graph_path, sample_shape=sample_shape, | |||
| output_folder=output_folder, report_folder=report_folder, | |||
| @@ -142,9 +150,10 @@ def _extract_model_name(model_path): | |||
| @torch_installation_validation | |||
| @GraphInitFail.check_except_pytorch("Error occurred when init graph object.") | |||
| @TreeCreateFail.check_except_pytorch("Error occurred when create hierarchical tree.") | |||
| @SourceFilesSaveFail.check_except_pytorch("Error occurred when save source files.") | |||
| @GraphInitFail.uniform_catcher("Error occurred when init graph object.") | |||
| @TreeCreateFail.uniform_catcher("Error occurred when create hierarchical tree.") | |||
| @SourceFilesSaveFail.uniform_catcher("Error occurred when save source files.") | |||
| @GeneratorFail.uniform_catcher("Error occurred when generate code.") | |||
| def graph_based_converter_pytorch_to_ms(graph_path: str, sample_shape: tuple, | |||
| output_folder: str, report_folder: str = None): | |||
| """ | |||
| @@ -176,9 +185,11 @@ def graph_based_converter_pytorch_to_ms(graph_path: str, sample_shape: tuple, | |||
| @tf_installation_validation | |||
| @GraphInitFail.check_except_tf("Error occurred when init graph object.") | |||
| @TreeCreateFail.check_except_tf("Error occurred when create hierarchical tree.") | |||
| @SourceFilesSaveFail.check_except_tf("Error occurred when save source files.") | |||
| @GraphInitFail.uniform_catcher("Error occurred when init graph object.") | |||
| @TfRuntimeError.uniform_catcher("Error occurred when init graph, TensorFlow runtime error.") | |||
| @TreeCreateFail.uniform_catcher("Error occurred when create hierarchical tree.") | |||
| @SourceFilesSaveFail.uniform_catcher("Error occurred when save source files.") | |||
| @GeneratorFail.uniform_catcher("Error occurred when generate code.") | |||
| def graph_based_converter_tf_to_ms(graph_path: str, sample_shape: tuple, | |||
| input_nodes: str, output_nodes: str, | |||
| output_folder: str, report_folder: str = None): | |||
| @@ -210,7 +221,7 @@ def graph_based_converter_tf_to_ms(graph_path: str, sample_shape: tuple, | |||
| save_code_file_and_report(model_name, code_fragments, output_folder, report_folder) | |||
| @BaseConverterFail.check_except("Failed to start base converter.") | |||
| @BaseConverterFail.uniform_catcher("Failed to start base converter.") | |||
| def main_graph_base_converter(file_config): | |||
| """ | |||
| The entrance for converter, script files will be converted. | |||
| @@ -201,6 +201,7 @@ class ArgsTranslation: | |||
| class ArgsTranslationHelper: | |||
| """Define operations related to ArgsTranslation instances.""" | |||
| @staticmethod | |||
| def find_formal_args_in_modules(args_translators): | |||
| """ | |||
| @@ -541,7 +541,7 @@ class ModuleStruct: | |||
| for output in output_list: | |||
| (provider_succ, provider_closet_opt_var) = output | |||
| if provider_closet_opt_var in struct.matched_inputs: | |||
| continue # skip repeat | |||
| continue # skip repeat | |||
| if provider_succ == struct.onnx_name: | |||
| struct.matched_inputs.append(provider_closet_opt_var) | |||
| @@ -695,20 +695,5 @@ class ModuleStruct: | |||
| """Register submodule outputs to this module's return.""" | |||
| submodule_returns = md_struct.get_returned_opt_var_name() | |||
| submodule_opt_var_name = md_struct.ms_opt_var_name | |||
| for (submodule_ext_succ, opt_var_name_in_this_module, ith_output) in submodule_returns: | |||
| for (submodule_ext_succ, _, ith_output) in submodule_returns: | |||
| self.external_successor_local_returns_map[submodule_ext_succ] = (submodule_opt_var_name, ith_output) | |||
| # edit external succ 's inputs in parent module | |||
| ext_node = self.GLOBAL_CONTEXT_MGR.onnx_node_name_to_node_struct_map.get(submodule_ext_succ) | |||
| ext_node_parent = ext_node.parent_module_struct | |||
| while ext_node_parent != self.parent_module_struct: | |||
| ext_node_parent.inputs_in_parent_module[ext_node.onnx_name] = md_struct.ms_opt_var_name | |||
| ext_node_parent = ext_node_parent.parent_module_struct | |||
| # need find the prec_name? | |||
| for ext_node_prec, opt_var_name in ext_node.inputs_in_parent_module.copy().items(): | |||
| if isinstance(opt_var_name, str): | |||
| if opt_var_name == opt_var_name_in_this_module: | |||
| ext_node.inputs_in_parent_module[ext_node_prec] = (self.ms_opt_var_name, ith_output) | |||
| if isinstance(opt_var_name, tuple): | |||
| if opt_var_name[0] == opt_var_name_in_this_module: | |||
| ext_node.inputs_in_parent_module[ext_node_prec] = (self.ms_opt_var_name, ith_output) | |||
| @@ -34,6 +34,7 @@ class Pattern: | |||
| # If pattern in BUILD_IN_MODULE_NAME or BUILD_IN_PATTERN, | |||
| # the pattern will get additional score. | |||
| self.additional_score = 0 | |||
| self.know_module_name = None | |||
| def insert(self, idx, seq_len): | |||
| """ | |||
| @@ -18,7 +18,7 @@ import uuid | |||
| from typing import Dict, List, Callable, Union | |||
| from collections import OrderedDict | |||
| from .common import context, gen_hash_key, DagGraph, MAX_OUT_DEGREE, cal_matching_score | |||
| from .known_module_name import BUILT_IN_MODULE_NAME, is_built_in_module_name | |||
| from .known_module_name import BUILT_IN_MODULE_NAME | |||
| from .pattern import Pattern, scope_name_mapping | |||
| from .built_in_pattern import BUILT_IN_PATTERN, is_built_in_pattern | |||
| from .pattern_fuzzy_matching import pattern_fuzzy_matching | |||
| @@ -85,13 +85,15 @@ def _is_valid_pattern(pattern, dag): | |||
| return True | |||
| def generate_module_name(pattern): | |||
| def match_known_module_name(pattern): | |||
| """ | |||
| Generate module name. | |||
| Matching with know module name. | |||
| Args: | |||
| pattern (Pattern): To be replaced pattern. | |||
| Returns: | |||
| str, matched module name, return None if not matched. | |||
| """ | |||
| matched_result = [] | |||
| for ptn, module_name in BUILT_IN_MODULE_NAME.items(): | |||
| @@ -109,7 +111,11 @@ def generate_module_name(pattern): | |||
| module_name = f"{module_name}{used_module_name[pattern.pattern]}" | |||
| used_module_name[pattern.pattern] += 1 | |||
| return module_name | |||
| return None | |||
| def generate_module_name(): | |||
| """Generate module name.""" | |||
| global global_idx | |||
| name = f"Module{global_idx}" | |||
| global_idx += 1 | |||
| @@ -439,13 +445,16 @@ class SearchPath: | |||
| to recover the sequence. | |||
| """ | |||
| if self.pattern.pattern not in scope_name_mapping: | |||
| module_name = generate_module_name(self.pattern) | |||
| scope_name_mapping[self.pattern.pattern] = module_name | |||
| module_name = generate_module_name() | |||
| known_module_name = match_known_module_name(self.pattern) | |||
| scope_name_mapping[self.pattern] = module_name | |||
| module_name_to_src[module_name] = self.pattern.pattern | |||
| else: | |||
| module_name = scope_name_mapping[self.pattern.pattern] | |||
| known_module_name = module_name_to_src[module_name].known_module_name | |||
| self.pattern.module_name = module_name | |||
| if is_built_in_module_name(module_name): | |||
| self.pattern.known_module_name = known_module_name | |||
| if known_module_name: | |||
| self.pattern.additional_score += cal_matching_score(self.pattern.ptn_length) | |||
| topo_order, inverted_index = self.replace_sub_graph_completely(self.pattern, self.topo_order_bef_repl) | |||
| return topo_order, inverted_index | |||
| @@ -18,8 +18,10 @@ from typing import Dict, List | |||
| from .common import context, DagGraph, gen_hash_key, ACCEPTABLE_RESULT_COUNT | |||
| from .common import MINI_FREQUENCY, MAX_ITERATION_DEPTH, SATISFIED_SCORE | |||
| from ..common.global_context import GlobalContext | |||
| from ..third_party_graph.onnx_utils import BaseNode | |||
| from .search_path import SearchPath, Pattern, generate_pattern, find_built_in_pattern | |||
| from ...common.exceptions import SubGraphSearchingFail | |||
| def _is_satisfied(path): | |||
| @@ -249,6 +251,23 @@ def validate_topo_order_succession(): | |||
| return True | |||
| def _add_known_module_name(search_path): | |||
| """ | |||
| Add known module name to GlobalContext. | |||
| Args: | |||
| search_path (SearchPath): Search path. | |||
| """ | |||
| ctx = GlobalContext() | |||
| if search_path.pattern.known_module_name: | |||
| ctx.known_module_name[search_path.pattern.module_name] = search_path.pattern.known_module_name | |||
| for it in search_path.recursion_path: | |||
| if it.pattern.known_module_name: | |||
| ctx.known_module_name[it.pattern.module_name] = it.pattern.known_module_name | |||
| @SubGraphSearchingFail.check_except("Sub-Graph searching fail.") | |||
| def generate_scope_name(data_loader): | |||
| """ | |||
| Generate scope name according to computation graph. | |||
| @@ -270,6 +289,9 @@ def generate_scope_name(data_loader): | |||
| if len(topo_order_with_scope_name_list) != len(data_loader.nodes_dict): | |||
| topo_order_with_scope_name_list = flatten_graph(init_dag) | |||
| if result: | |||
| _add_known_module_name(result) | |||
| except (ValueError, IndexError, AttributeError, KeyError) as _: | |||
| topo_order_with_scope_name_list = flatten_graph(init_dag) | |||
| return topo_order_with_scope_name_list | |||
| @@ -27,7 +27,7 @@ from ..common.global_context import GlobalContext | |||
| from ..constant import ONNX_TYPE_INT, ONNX_TYPE_INTS, ONNX_TYPE_STRING, \ | |||
| ONNX_TYPE_FLOATS, ONNX_TYPE_FLOAT, SCALAR_WITHOUT_SHAPE, DYNAMIC_SHAPE, UNKNOWN_DIM_VAL | |||
| from ...common.exceptions import GraphInitFail, ModelNotSupport | |||
| from ...common.exceptions import GraphInitFail, ModelNotSupport, ModelLoadingFail | |||
| def convert_tf_graph_to_onnx(model_path, model_inputs, model_outputs, opset=12): | |||
| @@ -308,7 +308,7 @@ class OnnxDataLoader: | |||
| w = int(match.group('w')) | |||
| c = int(match.group('c')) | |||
| if [h, w, c] != list(self.graph_input_shape)[1:4]: | |||
| raise ValueError(f"Shape given should be (N, {h}, {w}, {c}) but got {self.graph_input_shape}") | |||
| raise ModelLoadingFail(f"Shape given should be (N, {h}, {w}, {c}) but got {self.graph_input_shape}") | |||
| return True | |||
| return False | |||
| @@ -25,7 +25,9 @@ class PyTorchGraphParser(GraphParser): | |||
| """Define pytorch graph parser.""" | |||
| @classmethod | |||
| @ModelNotSupport.check_except_pytorch("Error occurs in loading model, make sure model.pth correct.") | |||
| @ModelNotSupport.check_except( | |||
| "Error occurs in loading model, please check your model or runtime environment integrity." | |||
| ) | |||
| def parse(cls, model_path: str, **kwargs): | |||
| """ | |||
| Parser pytorch graph. | |||
| @@ -50,11 +52,9 @@ class PyTorchGraphParser(GraphParser): | |||
| else: | |||
| model = torch.load(f=model_path, map_location="cpu") | |||
| except ModuleNotFoundError: | |||
| error_msg = \ | |||
| "Cannot find model scripts in system path, " \ | |||
| "set `--project_path` to the path of model scripts folder correctly." | |||
| error_msg = "Cannot find model scripts in system path, " \ | |||
| "set `--project_path` to the path of model scripts folder correctly." | |||
| error = ModuleNotFoundError(error_msg) | |||
| log.error(str(error)) | |||
| raise error from None | |||
| raise error | |||
| return model | |||
| @@ -25,7 +25,9 @@ class TFGraphParser(GraphParser): | |||
| """Define TF graph parser.""" | |||
| @classmethod | |||
| @ModelNotSupport.check_except_tf("Error occurs in loading model, make sure model.pb correct.") | |||
| @ModelNotSupport.check_except( | |||
| "Error occurs in loading model, please check your model or runtime environment integrity." | |||
| ) | |||
| def parse(cls, model_path: str, **kwargs): | |||
| """ | |||
| Parse TF Computational Graph File (.pb) | |||
| @@ -36,7 +38,6 @@ class TFGraphParser(GraphParser): | |||
| Returns: | |||
| object, ONNX model. | |||
| """ | |||
| onnx_utils = import_module( | |||
| "mindinsight.mindconverter.graph_based_converter.third_party_graph.onnx_utils") | |||
| convert_tf_graph_to_onnx = getattr(onnx_utils, "convert_tf_graph_to_onnx") | |||
| @@ -50,6 +51,5 @@ class TFGraphParser(GraphParser): | |||
| model = convert_tf_graph_to_onnx(model_path, | |||
| model_inputs=tf_input_nodes, | |||
| model_outputs=tf_output_nodes, | |||
| ) | |||
| model_outputs=tf_output_nodes) | |||
| return model | |||
| @@ -21,13 +21,10 @@ Usage: | |||
| """ | |||
| import difflib | |||
| import os | |||
| import re | |||
| import sys | |||
| import pytest | |||
| from mindinsight.mindconverter.converter import main | |||
| from mindinsight.mindconverter.graph_based_converter.framework import main_graph_base_converter | |||
| @pytest.mark.usefixtures('create_output_dir') | |||
| @@ -82,35 +79,3 @@ class TestConverter: | |||
| converted_ratio = 100 - (diff_lines * 100) / (len(expect_source)) | |||
| assert converted_ratio >= 80 | |||
| @pytest.mark.level0 | |||
| @pytest.mark.platform_arm_ascend_training | |||
| @pytest.mark.platform_x86_gpu_training | |||
| @pytest.mark.platform_x86_ascend_training | |||
| @pytest.mark.platform_x86_cpu | |||
| @pytest.mark.env_single | |||
| def test_main_graph_based_converter(self, output): | |||
| """Test main graph based converter.""" | |||
| pytorch_filename = "resnet50.pth" | |||
| expected_model_filename = "resnet50.py" | |||
| expected_report_filename = "report_of_resnet50.txt" | |||
| file_config = { | |||
| 'model_file': os.path.join(self.pytorch_dir, pytorch_filename), | |||
| 'shape': (1, 3, 224, 224), | |||
| 'outfile_dir': output, | |||
| 'report_dir': output | |||
| } | |||
| with pytest.raises(ValueError) as e: | |||
| main_graph_base_converter(file_config=file_config) | |||
| assert os.path.isfile(os.path.join(output, expected_model_filename)) | |||
| assert os.path.isfile(os.path.join(output, expected_report_filename)) | |||
| with open(os.path.join(output, expected_report_filename)) as converted_r: | |||
| converted_report = converted_r.readlines() | |||
| converted_rate = re.findall(r".*(?:Converted Rate: )(.*)[.]", converted_report[-1]) | |||
| assert converted_rate[0] == '100.00%' | |||
| exec_msg = e.value.args[0] | |||
| assert exec_msg == "torch.__spec__ is None" | |||