From: @liuchongming74 Reviewed-by: @lilongfei15,@wenkai_dist,@lilongfei15 Signed-off-by: @lilongfei15tags/v1.1.0
| @@ -15,7 +15,7 @@ | |||
| """Define custom exception.""" | |||
| import abc | |||
| import sys | |||
| from enum import unique | |||
| from enum import unique, Enum | |||
| from importlib import import_module | |||
| from lib2to3.pgen2 import parse | |||
| @@ -33,16 +33,6 @@ class ConverterErrors(ScriptConverterErrors): | |||
| SCRIPT_NOT_SUPPORT = 1 | |||
| NODE_TYPE_NOT_SUPPORT = 2 | |||
| CODE_SYNTAX_ERROR = 3 | |||
| NODE_INPUT_TYPE_NOT_SUPPORT = 4 | |||
| NODE_INPUT_MISSING = 5 | |||
| TREE_NODE_INSERT_FAIL = 6 | |||
| UNKNOWN_MODEL = 7 | |||
| MODEL_NOT_SUPPORT = 8 | |||
| SCRIPT_GENERATE_FAIL = 9 | |||
| REPORT_GENERATE_FAIL = 10 | |||
| NODE_CONVERSION_ERROR = 11 | |||
| INPUT_SHAPE_ERROR = 12 | |||
| TF_RUNTIME_ERROR = 13 | |||
| BASE_CONVERTER_FAIL = 000 | |||
| GRAPH_INIT_FAIL = 100 | |||
| @@ -50,7 +40,6 @@ class ConverterErrors(ScriptConverterErrors): | |||
| SOURCE_FILES_SAVE_FAIL = 300 | |||
| GENERATOR_FAIL = 400 | |||
| SUB_GRAPH_SEARCHING_FAIL = 500 | |||
| MODEL_LOADING_FAIL = 600 | |||
| class ScriptNotSupport(MindInsightException): | |||
| @@ -82,21 +71,20 @@ class CodeSyntaxError(MindInsightException): | |||
| class MindConverterException(Exception): | |||
| """MindConverter exception.""" | |||
| BASE_ERROR_CODE = None # ConverterErrors.BASE_CONVERTER_FAIL.value | |||
| # ERROR_CODE should be declared in child exception. | |||
| ERROR_CODE = None | |||
| def __init__(self, **kwargs): | |||
| """Initialization of MindInsightException.""" | |||
| error = kwargs.get('error', None) | |||
| user_msg = kwargs.get('user_msg', '') | |||
| debug_msg = kwargs.get('debug_msg', '') | |||
| cls_code = kwargs.get('cls_code', 0) | |||
| if isinstance(user_msg, str): | |||
| user_msg = ' '.join(user_msg.split()) | |||
| super(MindConverterException, self).__init__() | |||
| self.error = error | |||
| self.user_msg = user_msg | |||
| self.debug_msg = debug_msg | |||
| self.cls_code = cls_code | |||
| self.root_exception_error_code = None | |||
| def __str__(self): | |||
| return '[{}] code: {}, msg: {}'.format(self.__class__.__name__, self.error_code(), self.user_msg) | |||
| @@ -116,8 +104,12 @@ class MindConverterException(Exception): | |||
| Returns: | |||
| str, Hex string representing the composed MindConverter error code. | |||
| """ | |||
| num = 0xFFFF & self.error.value | |||
| error_code = ''.join((f'{self.cls_code}'.zfill(3), hex(num)[2:].zfill(4).upper())) | |||
| if self.root_exception_error_code: | |||
| return self.root_exception_error_code | |||
| if self.BASE_ERROR_CODE is None or self.ERROR_CODE is None: | |||
| raise ValueError("MindConverterException has not been initialized.") | |||
| num = 0xFFFF & self.ERROR_CODE # 0xFFFF & self.error.value | |||
| error_code = f"{str(self.BASE_ERROR_CODE).zfill(3)}{hex(num)[2:].zfill(4).upper()}" | |||
| return error_code | |||
| @classmethod | |||
| @@ -126,7 +118,7 @@ class MindConverterException(Exception): | |||
| """Raise from below exceptions.""" | |||
| @classmethod | |||
| def uniform_catcher(cls, msg): | |||
| def uniform_catcher(cls, msg: str = ""): | |||
| """Uniform exception catcher.""" | |||
| def decorator(func): | |||
| @@ -134,16 +126,20 @@ class MindConverterException(Exception): | |||
| try: | |||
| res = func(*args, **kwargs) | |||
| except cls.raise_from() as e: | |||
| error = cls(msg=msg) | |||
| detail_info = f"Error detail: {str(e)}" | |||
| log_console.error(str(error)) | |||
| error = cls() if not msg else cls(msg=msg) | |||
| detail_info = str(e) | |||
| log.error(error) | |||
| log_console.error("\n") | |||
| log_console.error(detail_info) | |||
| log_console.error("\n") | |||
| log.exception(e) | |||
| sys.exit(0) | |||
| except ModuleNotFoundError as e: | |||
| detail_info = f"Error detail: Required package not found, please check the runtime environment." | |||
| detail_info = "Error detail: Required package not found, please check the runtime environment." | |||
| log_console.error("\n") | |||
| log_console.error(str(e)) | |||
| log_console.error(detail_info) | |||
| log_console.error("\n") | |||
| log.exception(e) | |||
| sys.exit(0) | |||
| return res | |||
| @@ -161,9 +157,12 @@ class MindConverterException(Exception): | |||
| try: | |||
| output = func(*args, **kwargs) | |||
| except cls.raise_from() as e: | |||
| error = cls(msg=msg) | |||
| error_code = e.error_code() if isinstance(e, MindConverterException) else None | |||
| error.root_exception_error_code = error_code | |||
| log.error(msg) | |||
| log.exception(e) | |||
| raise cls(msg=msg) | |||
| raise error | |||
| except Exception as e: | |||
| log.error(msg) | |||
| log.exception(e) | |||
| @@ -175,12 +174,21 @@ class MindConverterException(Exception): | |||
| return decorator | |||
| class BaseConverterFail(MindConverterException): | |||
| class BaseConverterError(MindConverterException): | |||
| """Base converter failed.""" | |||
| def __init__(self, msg): | |||
| super(BaseConverterFail, self).__init__(error=ConverterErrors.BASE_CONVERTER_FAIL, | |||
| user_msg=msg) | |||
| @unique | |||
| class ErrCode(Enum): | |||
| """Define error code of BaseConverterError.""" | |||
| UNKNOWN_ERROR = 0 | |||
| UNKNOWN_MODEL = 1 | |||
| BASE_ERROR_CODE = ConverterErrors.BASE_CONVERTER_FAIL.value | |||
| ERROR_CODE = ErrCode.UNKNOWN_ERROR.value | |||
| DEFAULT_MSG = "Failed to start base converter." | |||
| def __init__(self, msg=DEFAULT_MSG): | |||
| super(BaseConverterError, self).__init__(user_msg=msg) | |||
| @classmethod | |||
| def raise_from(cls): | |||
| @@ -189,32 +197,45 @@ class BaseConverterFail(MindConverterException): | |||
| return except_source | |||
| class UnknownModel(MindConverterException): | |||
| class UnknownModelError(BaseConverterError): | |||
| """The unknown model error.""" | |||
| ERROR_CODE = BaseConverterError.ErrCode.UNKNOWN_MODEL.value | |||
| def __init__(self, msg): | |||
| super(UnknownModel, self).__init__(error=ConverterErrors.UNKNOWN_MODEL, | |||
| user_msg=msg) | |||
| super(UnknownModelError, self).__init__(msg=msg) | |||
| @classmethod | |||
| def raise_from(cls): | |||
| return cls | |||
| class GraphInitFail(MindConverterException): | |||
| class GraphInitError(MindConverterException): | |||
| """The graph init fail error.""" | |||
| def __init__(self, **kwargs): | |||
| super(GraphInitFail, self).__init__(error=ConverterErrors.GRAPH_INIT_FAIL, | |||
| user_msg=kwargs.get('msg', '')) | |||
| @unique | |||
| class ErrCode(Enum): | |||
| """Define error code of GraphInitError.""" | |||
| UNKNOWN_ERROR = 0 | |||
| MODEL_NOT_SUPPORT = 1 | |||
| TF_RUNTIME_ERROR = 2 | |||
| INPUT_SHAPE_ERROR = 3 | |||
| MI_RUNTIME_ERROR = 4 | |||
| BASE_ERROR_CODE = ConverterErrors.GRAPH_INIT_FAIL.value | |||
| ERROR_CODE = ErrCode.UNKNOWN_ERROR.value | |||
| DEFAULT_MSG = "Error occurred when init graph object." | |||
| def __init__(self, msg=DEFAULT_MSG): | |||
| super(GraphInitError, self).__init__(user_msg=msg) | |||
| @classmethod | |||
| def raise_from(cls): | |||
| """Raise from exceptions below.""" | |||
| except_source = (FileNotFoundError, | |||
| ModuleNotFoundError, | |||
| ModelNotSupport, | |||
| SubGraphSearchingFail, | |||
| ModelNotSupportError, | |||
| ModelLoadingError, | |||
| RuntimeIntegrityError, | |||
| TypeError, | |||
| ZeroDivisionError, | |||
| RuntimeError, | |||
| @@ -222,45 +243,65 @@ class GraphInitFail(MindConverterException): | |||
| return except_source | |||
| class TreeCreateFail(MindConverterException): | |||
| class TreeCreationError(MindConverterException): | |||
| """The tree create fail.""" | |||
| def __init__(self, msg): | |||
| super(TreeCreateFail, self).__init__(error=ConverterErrors.TREE_CREATE_FAIL, | |||
| user_msg=msg) | |||
| @unique | |||
| class ErrCode(Enum): | |||
| """Define error code of TreeCreationError.""" | |||
| UNKNOWN_ERROR = 0 | |||
| NODE_INPUT_MISSING = 1 | |||
| TREE_NODE_INSERT_FAIL = 2 | |||
| BASE_ERROR_CODE = ConverterErrors.TREE_CREATE_FAIL.value | |||
| ERROR_CODE = ErrCode.UNKNOWN_ERROR.value | |||
| DEFAULT_MSG = "Error occurred when create hierarchical tree." | |||
| def __init__(self, msg=DEFAULT_MSG): | |||
| super(TreeCreationError, self).__init__(user_msg=msg) | |||
| @classmethod | |||
| def raise_from(cls): | |||
| """Raise from exceptions below.""" | |||
| except_source = (NodeInputMissing, | |||
| TreeNodeInsertFail, cls) | |||
| except_source = NodeInputMissingError, TreeNodeInsertError, cls | |||
| return except_source | |||
| class SourceFilesSaveFail(MindConverterException): | |||
| class SourceFilesSaveError(MindConverterException): | |||
| """The source files save fail error.""" | |||
| def __init__(self, msg): | |||
| super(SourceFilesSaveFail, self).__init__(error=ConverterErrors.SOURCE_FILES_SAVE_FAIL, | |||
| user_msg=msg) | |||
| @unique | |||
| class ErrCode(Enum): | |||
| """Define error code of SourceFilesSaveError.""" | |||
| UNKNOWN_ERROR = 0 | |||
| NODE_INPUT_TYPE_NOT_SUPPORT = 1 | |||
| SCRIPT_GENERATE_FAIL = 2 | |||
| REPORT_GENERATE_FAIL = 3 | |||
| BASE_ERROR_CODE = ConverterErrors.SOURCE_FILES_SAVE_FAIL.value | |||
| ERROR_CODE = ErrCode.UNKNOWN_ERROR.value | |||
| DEFAULT_MSG = "Error occurred when save source files." | |||
| def __init__(self, msg=DEFAULT_MSG): | |||
| super(SourceFilesSaveError, self).__init__(user_msg=msg) | |||
| @classmethod | |||
| def raise_from(cls): | |||
| """Raise from exceptions below.""" | |||
| except_source = (NodeInputTypeNotSupport, | |||
| ScriptGenerateFail, | |||
| ReportGenerateFail, | |||
| except_source = (NodeInputTypeNotSupportError, | |||
| ScriptGenerationError, | |||
| ReportGenerationError, | |||
| IOError, cls) | |||
| return except_source | |||
| class ModelNotSupport(MindConverterException): | |||
| class ModelNotSupportError(GraphInitError): | |||
| """The model not support error.""" | |||
| ERROR_CODE = GraphInitError.ErrCode.MODEL_NOT_SUPPORT.value | |||
| def __init__(self, msg): | |||
| super(ModelNotSupport, self).__init__(error=ConverterErrors.MODEL_NOT_SUPPORT, | |||
| user_msg=msg, | |||
| cls_code=ConverterErrors.GRAPH_INIT_FAIL.value) | |||
| super(ModelNotSupportError, self).__init__(msg=msg) | |||
| @classmethod | |||
| def raise_from(cls): | |||
| @@ -275,13 +316,13 @@ class ModelNotSupport(MindConverterException): | |||
| return except_source | |||
| class TfRuntimeError(MindConverterException): | |||
| class TfRuntimeError(GraphInitError): | |||
| """Catch tf runtime error.""" | |||
| ERROR_CODE = GraphInitError.ErrCode.TF_RUNTIME_ERROR.value | |||
| DEFAULT_MSG = "Error occurred when init graph, TensorFlow runtime error." | |||
| def __init__(self, msg): | |||
| super(TfRuntimeError, self).__init__(error=ConverterErrors.TF_RUNTIME_ERROR, | |||
| user_msg=msg, | |||
| cls_code=ConverterErrors.GRAPH_INIT_FAIL.value) | |||
| def __init__(self, msg=DEFAULT_MSG): | |||
| super(TfRuntimeError, self).__init__(msg=msg) | |||
| @classmethod | |||
| def raise_from(cls): | |||
| @@ -290,26 +331,36 @@ class TfRuntimeError(MindConverterException): | |||
| return tf_error, ValueError, RuntimeError, cls | |||
| class NodeInputMissing(MindConverterException): | |||
| class RuntimeIntegrityError(GraphInitError): | |||
| """Catch runtime error.""" | |||
| ERROR_CODE = GraphInitError.ErrCode.MI_RUNTIME_ERROR.value | |||
| def __init__(self, msg): | |||
| super(RuntimeIntegrityError, self).__init__(msg=msg) | |||
| @classmethod | |||
| def raise_from(cls): | |||
| return RuntimeError, AttributeError, ImportError, ModuleNotFoundError, cls | |||
| class NodeInputMissingError(TreeCreationError): | |||
| """The node input missing error.""" | |||
| ERROR_CODE = TreeCreationError.ErrCode.NODE_INPUT_MISSING.value | |||
| def __init__(self, msg): | |||
| super(NodeInputMissing, self).__init__(error=ConverterErrors.NODE_INPUT_MISSING, | |||
| user_msg=msg, | |||
| cls_code=ConverterErrors.TREE_CREATE_FAIL.value) | |||
| super(NodeInputMissingError, self).__init__(msg=msg) | |||
| @classmethod | |||
| def raise_from(cls): | |||
| return ValueError, IndexError, KeyError, AttributeError, cls | |||
| class TreeNodeInsertFail(MindConverterException): | |||
| class TreeNodeInsertError(TreeCreationError): | |||
| """The tree node create fail error.""" | |||
| ERROR_CODE = TreeCreationError.ErrCode.TREE_NODE_INSERT_FAIL.value | |||
| def __init__(self, msg): | |||
| super(TreeNodeInsertFail, self).__init__(error=ConverterErrors.TREE_NODE_INSERT_FAIL, | |||
| user_msg=msg, | |||
| cls_code=ConverterErrors.TREE_CREATE_FAIL.value) | |||
| super(TreeNodeInsertError, self).__init__(msg=msg) | |||
| @classmethod | |||
| def raise_from(cls): | |||
| @@ -321,26 +372,24 @@ class TreeNodeInsertFail(MindConverterException): | |||
| return except_source | |||
| class NodeInputTypeNotSupport(MindConverterException): | |||
| class NodeInputTypeNotSupportError(SourceFilesSaveError): | |||
| """The node input type NOT support error.""" | |||
| ERROR_CODE = SourceFilesSaveError.ErrCode.NODE_INPUT_TYPE_NOT_SUPPORT.value | |||
| def __init__(self, msg): | |||
| super(NodeInputTypeNotSupport, self).__init__(error=ConverterErrors.NODE_INPUT_TYPE_NOT_SUPPORT, | |||
| user_msg=msg, | |||
| cls_code=ConverterErrors.SOURCE_FILES_SAVE_FAIL.value) | |||
| super(NodeInputTypeNotSupportError, self).__init__(msg=msg) | |||
| @classmethod | |||
| def raise_from(cls): | |||
| return ValueError, TypeError, IndexError, cls | |||
| class ScriptGenerateFail(MindConverterException): | |||
| class ScriptGenerationError(SourceFilesSaveError): | |||
| """The script generate fail error.""" | |||
| ERROR_CODE = SourceFilesSaveError.ErrCode.SCRIPT_GENERATE_FAIL.value | |||
| def __init__(self, msg): | |||
| super(ScriptGenerateFail, self).__init__(error=ConverterErrors.SCRIPT_GENERATE_FAIL, | |||
| user_msg=msg, | |||
| cls_code=ConverterErrors.SOURCE_FILES_SAVE_FAIL.value) | |||
| super(ScriptGenerationError, self).__init__(msg=msg) | |||
| @classmethod | |||
| def raise_from(cls): | |||
| @@ -351,13 +400,12 @@ class ScriptGenerateFail(MindConverterException): | |||
| return except_source | |||
| class ReportGenerateFail(MindConverterException): | |||
| class ReportGenerationError(SourceFilesSaveError): | |||
| """The report generate fail error.""" | |||
| ERROR_CODE = SourceFilesSaveError.ErrCode.REPORT_GENERATE_FAIL.value | |||
| def __init__(self, msg): | |||
| super(ReportGenerateFail, self).__init__(error=ConverterErrors.REPORT_GENERATE_FAIL, | |||
| user_msg=msg, | |||
| cls_code=ConverterErrors.SOURCE_FILES_SAVE_FAIL.value) | |||
| super(ReportGenerationError, self).__init__(msg=msg) | |||
| @classmethod | |||
| def raise_from(cls): | |||
| @@ -365,13 +413,22 @@ class ReportGenerateFail(MindConverterException): | |||
| return ZeroDivisionError, cls | |||
| class SubGraphSearchingFail(MindConverterException): | |||
| class SubGraphSearchingError(MindConverterException): | |||
| """Sub-graph searching exception.""" | |||
| def __init__(self, msg): | |||
| super(SubGraphSearchingFail, self).__init__(error=ConverterErrors.MODEL_NOT_SUPPORT, | |||
| cls_code=ConverterErrors.SUB_GRAPH_SEARCHING_FAIL.value, | |||
| user_msg=msg) | |||
| @unique | |||
| class ErrCode(Enum): | |||
| """Define error code of SourceFilesSaveError.""" | |||
| BASE_ERROR = 0 | |||
| CANNOT_FIND_VALID_PATTERN = 1 | |||
| MODEL_NOT_SUPPORT = 2 | |||
| BASE_ERROR_CODE = ConverterErrors.SUB_GRAPH_SEARCHING_FAIL.value | |||
| ERROR_CODE = ErrCode.BASE_ERROR.value | |||
| DEFAULT_MSG = "Sub-Graph pattern searching fail." | |||
| def __init__(self, msg=DEFAULT_MSG): | |||
| super(SubGraphSearchingError, self).__init__(user_msg=msg) | |||
| @classmethod | |||
| def raise_from(cls): | |||
| @@ -379,13 +436,22 @@ class SubGraphSearchingFail(MindConverterException): | |||
| return IndexError, KeyError, ValueError, AttributeError, ZeroDivisionError, cls | |||
| class GeneratorFail(MindConverterException): | |||
| class GeneratorError(MindConverterException): | |||
| """The Generator fail error.""" | |||
| def __init__(self, msg): | |||
| super(GeneratorFail, self).__init__(error=ConverterErrors.NODE_CONVERSION_ERROR, | |||
| user_msg=msg, | |||
| cls_code=ConverterErrors.GENERATOR_FAIL.value) | |||
| @unique | |||
| class ErrCode(Enum): | |||
| """Define error code of SourceFilesSaveError.""" | |||
| BASE_ERROR = 0 | |||
| STATEMENT_GENERATION_ERROR = 1 | |||
| CONVERTED_OPERATOR_LOADING_ERROR = 2 | |||
| BASE_ERROR_CODE = ConverterErrors.GENERATOR_FAIL.value | |||
| ERROR_CODE = ErrCode.BASE_ERROR.value | |||
| DEFAULT_MSG = "Error occurred when generate code." | |||
| def __init__(self, msg=DEFAULT_MSG): | |||
| super(GeneratorError, self).__init__(user_msg=msg) | |||
| @classmethod | |||
| def raise_from(cls): | |||
| @@ -394,13 +460,12 @@ class GeneratorFail(MindConverterException): | |||
| return except_source | |||
| class ModelLoadingFail(MindConverterException): | |||
| class ModelLoadingError(GraphInitError): | |||
| """Model loading fail.""" | |||
| ERROR_CODE = GraphInitError.ErrCode.INPUT_SHAPE_ERROR.value | |||
| def __init__(self, msg): | |||
| super(ModelLoadingFail, self).__init__(error=ConverterErrors.INPUT_SHAPE_ERROR, | |||
| cls_code=ConverterErrors.MODEL_LOADING_FAIL.value, | |||
| user_msg=msg) | |||
| super(ModelLoadingError, self).__init__(msg=msg) | |||
| @classmethod | |||
| def raise_from(cls): | |||
| @@ -87,7 +87,7 @@ def get_imported_module(): | |||
| str, imported module. | |||
| """ | |||
| return f"import numpy as np{NEW_LINE}" \ | |||
| f"import mindspore{NEW_LINE}" \ | |||
| f"from mindspore import nn{NEW_LINE}" \ | |||
| f"from mindspore import Tensor{NEW_LINE}" \ | |||
| f"from mindspore.ops import operations as P{NEW_LINE * 3}" | |||
| f"import mindspore{NEW_LINE}" \ | |||
| f"from mindspore import nn{NEW_LINE}" \ | |||
| f"from mindspore import Tensor{NEW_LINE}" \ | |||
| f"from mindspore.ops import operations as P{NEW_LINE * 3}" | |||
| @@ -27,8 +27,8 @@ from mindinsight.mindconverter.graph_based_converter.constant import BINARY_HEAD | |||
| BINARY_HEADER_PYTORCH_BITS, ONNX_MIN_VER, TF2ONNX_MIN_VER, ONNXRUNTIME_MIN_VER | |||
| from mindinsight.mindconverter.graph_based_converter.mapper import ONNXToMindSporeMapper | |||
| from mindinsight.mindconverter.common.log import logger as log, logger_console as log_console | |||
| from mindinsight.mindconverter.common.exceptions import GraphInitFail, TreeCreateFail, SourceFilesSaveFail, \ | |||
| BaseConverterFail, UnknownModel, GeneratorFail, TfRuntimeError | |||
| from mindinsight.mindconverter.common.exceptions import GraphInitError, TreeCreationError, SourceFilesSaveError, \ | |||
| BaseConverterError, UnknownModelError, GeneratorError, TfRuntimeError, RuntimeIntegrityError | |||
| from mindinsight.utils.exceptions import ParamMissError | |||
| permissions = os.R_OK | os.W_OK | os.X_OK | |||
| @@ -67,13 +67,13 @@ def torch_installation_validation(func): | |||
| output_folder: str, report_folder: str = None): | |||
| # Check whether pytorch is installed. | |||
| if not find_spec("torch"): | |||
| error = ModuleNotFoundError("PyTorch is required when using graph based " | |||
| "scripts converter, and PyTorch vision must " | |||
| "be consisted with model generation runtime.") | |||
| log.error(str(error)) | |||
| detail_info = f"Error detail: {str(error)}" | |||
| error = RuntimeIntegrityError("PyTorch is required when using graph based " | |||
| "scripts converter, and PyTorch vision must " | |||
| "be consisted with model generation runtime.") | |||
| log.error(error) | |||
| log_console.error("\n") | |||
| log_console.error(str(error)) | |||
| log_console.error(detail_info) | |||
| log_console.error("\n") | |||
| sys.exit(0) | |||
| func(graph_path=graph_path, sample_shape=sample_shape, | |||
| @@ -97,17 +97,17 @@ def tf_installation_validation(func): | |||
| output_folder: str, report_folder: str = None, | |||
| input_nodes: str = None, output_nodes: str = None): | |||
| # Check whether tensorflow is installed. | |||
| if not find_spec("tensorflow") or not find_spec("tf2onnx") or not find_spec("onnx") \ | |||
| or not find_spec("onnxruntime"): | |||
| error = ModuleNotFoundError( | |||
| if not find_spec("tensorflow") or not find_spec("tensorflow-gpu") or not find_spec("tf2onnx") \ | |||
| or not find_spec("onnx") or not find_spec("onnxruntime"): | |||
| error = RuntimeIntegrityError( | |||
| f"TensorFlow, tf2onnx(>={TF2ONNX_MIN_VER}), onnx(>={ONNX_MIN_VER}) and " | |||
| f"onnxruntime(>={ONNXRUNTIME_MIN_VER}) are required when using graph " | |||
| f"based scripts converter for TensorFlow conversion." | |||
| ) | |||
| log.error(str(error)) | |||
| detail_info = f"Error detail: {str(error)}" | |||
| log.error(error) | |||
| log_console.error("\n") | |||
| log_console.error(str(error)) | |||
| log_console.error(detail_info) | |||
| log_console.error("\n") | |||
| sys.exit(0) | |||
| onnx, tf2onnx = import_module("onnx"), import_module("tf2onnx") | |||
| @@ -116,15 +116,15 @@ def tf_installation_validation(func): | |||
| if not lib_version_satisfied(getattr(onnx, "__version__"), ONNX_MIN_VER) \ | |||
| or not lib_version_satisfied(getattr(ort, "__version__"), ONNXRUNTIME_MIN_VER) \ | |||
| or not lib_version_satisfied(getattr(tf2onnx, "__version__"), TF2ONNX_MIN_VER): | |||
| error = ModuleNotFoundError( | |||
| error = RuntimeIntegrityError( | |||
| f"TensorFlow, tf2onnx(>={TF2ONNX_MIN_VER}), onnx(>={ONNX_MIN_VER}) and " | |||
| f"onnxruntime(>={ONNXRUNTIME_MIN_VER}) are required when using graph " | |||
| f"based scripts converter for TensorFlow conversion." | |||
| ) | |||
| log.error(str(error)) | |||
| detail_info = f"Error detail: {str(error)}" | |||
| log.error(error) | |||
| log_console.error("\n") | |||
| log_console.error(str(error)) | |||
| log_console.error(detail_info) | |||
| log_console.error("\n") | |||
| sys.exit(0) | |||
| func(graph_path=graph_path, sample_shape=sample_shape, | |||
| @@ -150,14 +150,14 @@ def _extract_model_name(model_path): | |||
| @torch_installation_validation | |||
| @GraphInitFail.uniform_catcher("Error occurred when init graph object.") | |||
| @TreeCreateFail.uniform_catcher("Error occurred when create hierarchical tree.") | |||
| @SourceFilesSaveFail.uniform_catcher("Error occurred when save source files.") | |||
| @GeneratorFail.uniform_catcher("Error occurred when generate code.") | |||
| @GraphInitError.uniform_catcher() | |||
| @TreeCreationError.uniform_catcher() | |||
| @SourceFilesSaveError.uniform_catcher() | |||
| @GeneratorError.uniform_catcher() | |||
| def graph_based_converter_pytorch_to_ms(graph_path: str, sample_shape: tuple, | |||
| output_folder: str, report_folder: str = None): | |||
| """ | |||
| Pytoch to MindSpore based on Graph. | |||
| PyTorch to MindSpore based on Graph. | |||
| Args: | |||
| graph_path (str): Graph file path. | |||
| @@ -185,11 +185,11 @@ def graph_based_converter_pytorch_to_ms(graph_path: str, sample_shape: tuple, | |||
| @tf_installation_validation | |||
| @GraphInitFail.uniform_catcher("Error occurred when init graph object.") | |||
| @TfRuntimeError.uniform_catcher("Error occurred when init graph, TensorFlow runtime error.") | |||
| @TreeCreateFail.uniform_catcher("Error occurred when create hierarchical tree.") | |||
| @SourceFilesSaveFail.uniform_catcher("Error occurred when save source files.") | |||
| @GeneratorFail.uniform_catcher("Error occurred when generate code.") | |||
| @GraphInitError.uniform_catcher() | |||
| @TfRuntimeError.uniform_catcher() | |||
| @TreeCreationError.uniform_catcher() | |||
| @SourceFilesSaveError.uniform_catcher() | |||
| @GeneratorError.uniform_catcher() | |||
| def graph_based_converter_tf_to_ms(graph_path: str, sample_shape: tuple, | |||
| input_nodes: str, output_nodes: str, | |||
| output_folder: str, report_folder: str = None): | |||
| @@ -221,7 +221,7 @@ def graph_based_converter_tf_to_ms(graph_path: str, sample_shape: tuple, | |||
| save_code_file_and_report(model_name, code_fragments, output_folder, report_folder) | |||
| @BaseConverterFail.uniform_catcher("Failed to start base converter.") | |||
| @BaseConverterError.uniform_catcher() | |||
| def main_graph_base_converter(file_config): | |||
| """ | |||
| The entrance for converter, script files will be converted. | |||
| @@ -248,7 +248,7 @@ def main_graph_base_converter(file_config): | |||
| report_folder=file_config['report_dir']) | |||
| else: | |||
| error_msg = "Get UNSUPPORTED model." | |||
| error = UnknownModel(error_msg) | |||
| error = UnknownModelError(error_msg) | |||
| log.error(str(error)) | |||
| raise error | |||
| @@ -263,7 +263,7 @@ def get_framework_type(model_path): | |||
| framework_type = FrameworkType.TENSORFLOW.value | |||
| except IOError: | |||
| error_msg = "Get UNSUPPORTED model." | |||
| error = UnknownModel(error_msg) | |||
| error = UnknownModelError(error_msg) | |||
| log.error(str(error)) | |||
| raise error | |||
| @@ -23,7 +23,7 @@ from .node_struct import NodeStruct | |||
| from .module_struct import ModuleStruct | |||
| from .args_translator import ArgsTranslationHelper | |||
| from ..common.global_context import GlobalContext | |||
| from ...common.exceptions import GeneratorFail | |||
| from ...common.exceptions import GeneratorError | |||
| from ..hierarchical_tree.name_mgr import GlobalVarNameMgr | |||
| from ..constant import NEW_LINE, SECOND_LEVEL_INDENT, FIRST_LEVEL_INDENT, CodeFormatConfig, get_imported_module | |||
| from ..report_generator import ReportGenerator | |||
| @@ -171,7 +171,7 @@ class Generator: | |||
| self._global_context.node_struct_collections = self._node_struct_collections | |||
| self._repeated_submodules = set() | |||
| @GeneratorFail.check_except("Generator occurs an error when forming base submodules.") | |||
| @GeneratorError.check_except("Generator occurs an error when forming base submodules.") | |||
| def _form_bottom_submodule(self): | |||
| """Form the basic submodules, which only contains nodes.""" | |||
| # Form module map | |||
| @@ -351,7 +351,7 @@ class Generator: | |||
| self._global_context.add_module_struct(sub.pattern_id, sub) | |||
| depth -= 1 | |||
| @GeneratorFail.check_except("Generator occurs an error when building modules.") | |||
| @GeneratorError.check_except("Generator occurs an error when building modules.") | |||
| def _recursive_form_module(self): | |||
| """Main routine in generator to build modules from bottom to top.""" | |||
| # 1. List repeated submodules | |||
| @@ -474,7 +474,7 @@ class Generator: | |||
| """Return all ModuleStructs in this model.""" | |||
| return self._module_struct_collections | |||
| @GeneratorFail.check_except("Generator occurs an error when generating code statements.") | |||
| @GeneratorError.check_except("Generator occurs an error when generating code statements.") | |||
| def generate(self): | |||
| """ | |||
| Generate the final script file. | |||
| @@ -22,7 +22,7 @@ from ..third_party_graph.pytorch_graph_node import PyTorchGraphNode | |||
| from ..third_party_graph.onnx_graph_node import OnnxGraphNode | |||
| from ..common.global_context import GlobalContext | |||
| from ..constant import InputType | |||
| from ...common.exceptions import GeneratorFail | |||
| from ...common.exceptions import GeneratorError | |||
| class NodeStruct: | |||
| @@ -146,7 +146,7 @@ class NodeStruct: | |||
| parsed_scope = Scope.parse_scope_from_node_identifier(self.identifier) | |||
| self.scope = Scope(parsed_scope) | |||
| @GeneratorFail.check_except("Generator occurs an error when initializing node's args translator.") | |||
| @GeneratorError.check_except("Generator occurs an error when initializing node's args translator.") | |||
| def init_args_translator(self, translated_args: list): | |||
| """ | |||
| Initialize the ArgsTranslator for each Node. | |||
| @@ -170,7 +170,7 @@ class NodeStruct: | |||
| self.ms_op]): | |||
| self.ready_to_generate = True | |||
| @GeneratorFail.check_except("Generator occurs an error when creating node struct.") | |||
| @GeneratorError.check_except("Generator occurs an error when creating node struct.") | |||
| def update(self, arg, force_ready=False): | |||
| """ | |||
| Pass Node info. to generator NodeStruct. | |||
| @@ -21,7 +21,7 @@ from mindinsight.mindconverter.common.log import logger as log | |||
| from .hierarchical_tree import HierarchicalTree | |||
| from ..third_party_graph.onnx_graph_node import OnnxGraphNode | |||
| from ...common.exceptions import NodeInputMissing, TreeNodeInsertFail | |||
| from ...common.exceptions import NodeInputMissingError, TreeNodeInsertError | |||
| def _tf_model_node_name_reformat(node: OnnxGraphNode, node_name): | |||
| @@ -53,7 +53,7 @@ class HierarchicalTreeFactory: | |||
| """Hierarchical tree factory.""" | |||
| @classmethod | |||
| @TreeNodeInsertFail.check_except("Tree node inserts failed.") | |||
| @TreeNodeInsertError.check_except("Tree node inserts failed.") | |||
| def create(cls, graph): | |||
| """ | |||
| Factory method of hierarchical tree. | |||
| @@ -73,7 +73,7 @@ class HierarchicalTreeFactory: | |||
| if node_input != 0 and not node_input: | |||
| err_msg = f"This model is not supported now. " \ | |||
| f"Cannot find {node_name}'s input shape." | |||
| error = NodeInputMissing(err_msg) | |||
| error = NodeInputMissingError(err_msg) | |||
| log.error(str(error)) | |||
| raise error | |||
| if isinstance(node_inst, OnnxGraphNode): | |||
| @@ -35,7 +35,7 @@ from ..constant import SEPARATOR_BTW_NAME_AND_ID, FIRST_LEVEL_INDENT | |||
| from ..constant import NEW_LINE, SECOND_LEVEL_INDENT | |||
| from ..constant import NodeType | |||
| from ..report_generator import ReportGenerator | |||
| from ...common.exceptions import ReportGenerateFail, ScriptGenerateFail, NodeInputTypeNotSupport | |||
| from ...common.exceptions import ReportGenerationError, ScriptGenerationError, NodeInputTypeNotSupportError | |||
| class HierarchicalTree(Tree): | |||
| @@ -189,7 +189,7 @@ class HierarchicalTree(Tree): | |||
| try: | |||
| self._adjust_structure() | |||
| code_fragments = self._generate_codes(mapper) | |||
| except (NodeInputTypeNotSupport, ScriptGenerateFail, ReportGenerateFail) as e: | |||
| except (NodeInputTypeNotSupportError, ScriptGenerationError, ReportGenerationError) as e: | |||
| log.error("Error occur when generating codes.") | |||
| raise e | |||
| @@ -264,8 +264,8 @@ class HierarchicalTree(Tree): | |||
| node.data.args_in_code.pop(arg) | |||
| return node | |||
| @ScriptGenerateFail.check_except("FormatCode run error. Check detailed information in log.") | |||
| @ReportGenerateFail.check_except("Not find valid operators in converted script.") | |||
| @ScriptGenerationError.check_except("FormatCode run error. Check detailed information in log.") | |||
| @ReportGenerationError.check_except("Not find valid operators in converted script.") | |||
| def _generate_codes(self, mapper): | |||
| """ | |||
| Generate code files. | |||
| @@ -21,7 +21,7 @@ from .common import MINI_FREQUENCY, MAX_ITERATION_DEPTH, SATISFIED_SCORE | |||
| from ..common.global_context import GlobalContext | |||
| from ..third_party_graph.onnx_utils import BaseNode | |||
| from .search_path import SearchPath, Pattern, generate_pattern, find_built_in_pattern | |||
| from ...common.exceptions import SubGraphSearchingFail | |||
| from ...common.exceptions import SubGraphSearchingError | |||
| def _is_satisfied(path): | |||
| @@ -267,7 +267,7 @@ def _add_known_module_name(search_path): | |||
| ctx.known_module_name[it.pattern.module_name] = it.pattern.known_module_name | |||
| @SubGraphSearchingFail.check_except("Sub-Graph searching fail.") | |||
| @SubGraphSearchingError.check_except("Sub-Graph pattern searching fail.") | |||
| def generate_scope_name(data_loader): | |||
| """ | |||
| Generate scope name according to computation graph. | |||
| @@ -21,7 +21,7 @@ from mindinsight.mindconverter.common.log import logger as log | |||
| from ..common.code_fragment import CodeFragment | |||
| from ..constant import NodeType, InputType | |||
| from ..mapper.base import Mapper | |||
| from ...common.exceptions import NodeInputTypeNotSupport | |||
| from ...common.exceptions import NodeInputTypeNotSupportError | |||
| class GraphParser(metaclass=abc.ABCMeta): | |||
| @@ -522,7 +522,7 @@ class GraphNode(abc.ABC): | |||
| elif input_type == InputType.LIST.value: | |||
| ipt_args_settings_in_construct = f"({ipt_args_in_construct},)" | |||
| else: | |||
| raise NodeInputTypeNotSupport(f"Input type[{input_type}] is not supported now.") | |||
| raise NodeInputTypeNotSupportError(f"Input type[{input_type}] is not supported now.") | |||
| else: | |||
| ipt_args_settings_in_construct = ipt_args_in_construct | |||
| @@ -27,7 +27,7 @@ from ..common.global_context import GlobalContext | |||
| from ..constant import ONNX_TYPE_INT, ONNX_TYPE_INTS, ONNX_TYPE_STRING, \ | |||
| ONNX_TYPE_FLOATS, ONNX_TYPE_FLOAT, SCALAR_WITHOUT_SHAPE, DYNAMIC_SHAPE, UNKNOWN_DIM_VAL | |||
| from ...common.exceptions import GraphInitFail, ModelNotSupport, ModelLoadingFail | |||
| from ...common.exceptions import GraphInitError, ModelNotSupportError, ModelLoadingError | |||
| def convert_tf_graph_to_onnx(model_path, model_inputs, model_outputs, opset=12): | |||
| @@ -308,7 +308,7 @@ class OnnxDataLoader: | |||
| w = int(match.group('w')) | |||
| c = int(match.group('c')) | |||
| if [h, w, c] != list(self.graph_input_shape)[1:4]: | |||
| raise ModelLoadingFail(f"Shape given should be (N, {h}, {w}, {c}) but got {self.graph_input_shape}") | |||
| raise ModelLoadingError(f"Shape given should be (N, {h}, {w}, {c}) but got {self.graph_input_shape}") | |||
| return True | |||
| return False | |||
| @@ -383,7 +383,7 @@ class OnnxDataLoader: | |||
| self._nodes_dict[n.name] = n | |||
| nodes_topo_idx.append((idx, n.name)) | |||
| if len(node.output) > 1: | |||
| raise ModelNotSupport(msg=f"{node.name} has multi-outputs which is not supported now.") | |||
| raise ModelNotSupportError(msg=f"{node.name} has multi-outputs which is not supported now.") | |||
| self.output_name_to_node_name[node.output[0]] = node.name | |||
| self._global_context.onnx_node_name_to_topo_idx[n.name] = idx | |||
| node_inputs = [i.replace(":0", "") for i in node.input] | |||
| @@ -423,7 +423,7 @@ class OnnxDataLoader: | |||
| shape[i] = int(shape[i]) | |||
| node_name = self.output_name_to_node_name[node_opt_name] | |||
| if not node_name: | |||
| raise GraphInitFail(user_msg=f"Cannot find where edge {node_opt_name} comes from.") | |||
| raise GraphInitError(msg=f"Cannot find where edge {node_opt_name} comes from.") | |||
| self.node_output_shape_dict[node_name] = shape | |||
| def get_node(self, node_name): | |||
| @@ -18,14 +18,14 @@ from importlib import import_module | |||
| from mindinsight.mindconverter.common.log import logger as log | |||
| from .base import GraphParser | |||
| from ...common.exceptions import ModelNotSupport | |||
| from ...common.exceptions import ModelNotSupportError | |||
| class PyTorchGraphParser(GraphParser): | |||
| """Define pytorch graph parser.""" | |||
| @classmethod | |||
| @ModelNotSupport.check_except( | |||
| @ModelNotSupportError.check_except( | |||
| "Error occurs in loading model, please check your model or runtime environment integrity." | |||
| ) | |||
| def parse(cls, model_path: str, **kwargs): | |||
| @@ -18,14 +18,14 @@ from importlib import import_module | |||
| from mindinsight.mindconverter.common.log import logger as log | |||
| from .base import GraphParser | |||
| from ...common.exceptions import ModelNotSupport | |||
| from ...common.exceptions import ModelNotSupportError | |||
| class TFGraphParser(GraphParser): | |||
| """Define TF graph parser.""" | |||
| @classmethod | |||
| @ModelNotSupport.check_except( | |||
| @ModelNotSupportError.check_except( | |||
| "Error occurs in loading model, please check your model or runtime environment integrity." | |||
| ) | |||
| def parse(cls, model_path: str, **kwargs): | |||