From: @liuchongming74 Reviewed-by: @lilongfei15,@wenkai_dist,@lilongfei15 Signed-off-by: @lilongfei15tags/v1.1.0
| @@ -15,7 +15,7 @@ | |||||
| """Define custom exception.""" | """Define custom exception.""" | ||||
| import abc | import abc | ||||
| import sys | import sys | ||||
| from enum import unique | |||||
| from enum import unique, Enum | |||||
| from importlib import import_module | from importlib import import_module | ||||
| from lib2to3.pgen2 import parse | from lib2to3.pgen2 import parse | ||||
| @@ -33,16 +33,6 @@ class ConverterErrors(ScriptConverterErrors): | |||||
| SCRIPT_NOT_SUPPORT = 1 | SCRIPT_NOT_SUPPORT = 1 | ||||
| NODE_TYPE_NOT_SUPPORT = 2 | NODE_TYPE_NOT_SUPPORT = 2 | ||||
| CODE_SYNTAX_ERROR = 3 | CODE_SYNTAX_ERROR = 3 | ||||
| NODE_INPUT_TYPE_NOT_SUPPORT = 4 | |||||
| NODE_INPUT_MISSING = 5 | |||||
| TREE_NODE_INSERT_FAIL = 6 | |||||
| UNKNOWN_MODEL = 7 | |||||
| MODEL_NOT_SUPPORT = 8 | |||||
| SCRIPT_GENERATE_FAIL = 9 | |||||
| REPORT_GENERATE_FAIL = 10 | |||||
| NODE_CONVERSION_ERROR = 11 | |||||
| INPUT_SHAPE_ERROR = 12 | |||||
| TF_RUNTIME_ERROR = 13 | |||||
| BASE_CONVERTER_FAIL = 000 | BASE_CONVERTER_FAIL = 000 | ||||
| GRAPH_INIT_FAIL = 100 | GRAPH_INIT_FAIL = 100 | ||||
| @@ -50,7 +40,6 @@ class ConverterErrors(ScriptConverterErrors): | |||||
| SOURCE_FILES_SAVE_FAIL = 300 | SOURCE_FILES_SAVE_FAIL = 300 | ||||
| GENERATOR_FAIL = 400 | GENERATOR_FAIL = 400 | ||||
| SUB_GRAPH_SEARCHING_FAIL = 500 | SUB_GRAPH_SEARCHING_FAIL = 500 | ||||
| MODEL_LOADING_FAIL = 600 | |||||
| class ScriptNotSupport(MindInsightException): | class ScriptNotSupport(MindInsightException): | ||||
| @@ -82,21 +71,20 @@ class CodeSyntaxError(MindInsightException): | |||||
| class MindConverterException(Exception): | class MindConverterException(Exception): | ||||
| """MindConverter exception.""" | """MindConverter exception.""" | ||||
| BASE_ERROR_CODE = None # ConverterErrors.BASE_CONVERTER_FAIL.value | |||||
| # ERROR_CODE should be declared in child exception. | |||||
| ERROR_CODE = None | |||||
| def __init__(self, **kwargs): | def __init__(self, **kwargs): | ||||
| """Initialization of MindInsightException.""" | """Initialization of MindInsightException.""" | ||||
| error = kwargs.get('error', None) | |||||
| user_msg = kwargs.get('user_msg', '') | user_msg = kwargs.get('user_msg', '') | ||||
| debug_msg = kwargs.get('debug_msg', '') | |||||
| cls_code = kwargs.get('cls_code', 0) | |||||
| if isinstance(user_msg, str): | if isinstance(user_msg, str): | ||||
| user_msg = ' '.join(user_msg.split()) | user_msg = ' '.join(user_msg.split()) | ||||
| super(MindConverterException, self).__init__() | super(MindConverterException, self).__init__() | ||||
| self.error = error | |||||
| self.user_msg = user_msg | self.user_msg = user_msg | ||||
| self.debug_msg = debug_msg | |||||
| self.cls_code = cls_code | |||||
| self.root_exception_error_code = None | |||||
| def __str__(self): | def __str__(self): | ||||
| return '[{}] code: {}, msg: {}'.format(self.__class__.__name__, self.error_code(), self.user_msg) | return '[{}] code: {}, msg: {}'.format(self.__class__.__name__, self.error_code(), self.user_msg) | ||||
| @@ -116,8 +104,12 @@ class MindConverterException(Exception): | |||||
| Returns: | Returns: | ||||
| str, Hex string representing the composed MindConverter error code. | str, Hex string representing the composed MindConverter error code. | ||||
| """ | """ | ||||
| num = 0xFFFF & self.error.value | |||||
| error_code = ''.join((f'{self.cls_code}'.zfill(3), hex(num)[2:].zfill(4).upper())) | |||||
| if self.root_exception_error_code: | |||||
| return self.root_exception_error_code | |||||
| if self.BASE_ERROR_CODE is None or self.ERROR_CODE is None: | |||||
| raise ValueError("MindConverterException has not been initialized.") | |||||
| num = 0xFFFF & self.ERROR_CODE # 0xFFFF & self.error.value | |||||
| error_code = f"{str(self.BASE_ERROR_CODE).zfill(3)}{hex(num)[2:].zfill(4).upper()}" | |||||
| return error_code | return error_code | ||||
| @classmethod | @classmethod | ||||
| @@ -126,7 +118,7 @@ class MindConverterException(Exception): | |||||
| """Raise from below exceptions.""" | """Raise from below exceptions.""" | ||||
| @classmethod | @classmethod | ||||
| def uniform_catcher(cls, msg): | |||||
| def uniform_catcher(cls, msg: str = ""): | |||||
| """Uniform exception catcher.""" | """Uniform exception catcher.""" | ||||
| def decorator(func): | def decorator(func): | ||||
| @@ -134,16 +126,20 @@ class MindConverterException(Exception): | |||||
| try: | try: | ||||
| res = func(*args, **kwargs) | res = func(*args, **kwargs) | ||||
| except cls.raise_from() as e: | except cls.raise_from() as e: | ||||
| error = cls(msg=msg) | |||||
| detail_info = f"Error detail: {str(e)}" | |||||
| log_console.error(str(error)) | |||||
| error = cls() if not msg else cls(msg=msg) | |||||
| detail_info = str(e) | |||||
| log.error(error) | |||||
| log_console.error("\n") | |||||
| log_console.error(detail_info) | log_console.error(detail_info) | ||||
| log_console.error("\n") | |||||
| log.exception(e) | log.exception(e) | ||||
| sys.exit(0) | sys.exit(0) | ||||
| except ModuleNotFoundError as e: | except ModuleNotFoundError as e: | ||||
| detail_info = f"Error detail: Required package not found, please check the runtime environment." | |||||
| detail_info = "Error detail: Required package not found, please check the runtime environment." | |||||
| log_console.error("\n") | |||||
| log_console.error(str(e)) | log_console.error(str(e)) | ||||
| log_console.error(detail_info) | log_console.error(detail_info) | ||||
| log_console.error("\n") | |||||
| log.exception(e) | log.exception(e) | ||||
| sys.exit(0) | sys.exit(0) | ||||
| return res | return res | ||||
| @@ -161,9 +157,12 @@ class MindConverterException(Exception): | |||||
| try: | try: | ||||
| output = func(*args, **kwargs) | output = func(*args, **kwargs) | ||||
| except cls.raise_from() as e: | except cls.raise_from() as e: | ||||
| error = cls(msg=msg) | |||||
| error_code = e.error_code() if isinstance(e, MindConverterException) else None | |||||
| error.root_exception_error_code = error_code | |||||
| log.error(msg) | log.error(msg) | ||||
| log.exception(e) | log.exception(e) | ||||
| raise cls(msg=msg) | |||||
| raise error | |||||
| except Exception as e: | except Exception as e: | ||||
| log.error(msg) | log.error(msg) | ||||
| log.exception(e) | log.exception(e) | ||||
| @@ -175,12 +174,21 @@ class MindConverterException(Exception): | |||||
| return decorator | return decorator | ||||
| class BaseConverterFail(MindConverterException): | |||||
| class BaseConverterError(MindConverterException): | |||||
| """Base converter failed.""" | """Base converter failed.""" | ||||
| def __init__(self, msg): | |||||
| super(BaseConverterFail, self).__init__(error=ConverterErrors.BASE_CONVERTER_FAIL, | |||||
| user_msg=msg) | |||||
| @unique | |||||
| class ErrCode(Enum): | |||||
| """Define error code of BaseConverterError.""" | |||||
| UNKNOWN_ERROR = 0 | |||||
| UNKNOWN_MODEL = 1 | |||||
| BASE_ERROR_CODE = ConverterErrors.BASE_CONVERTER_FAIL.value | |||||
| ERROR_CODE = ErrCode.UNKNOWN_ERROR.value | |||||
| DEFAULT_MSG = "Failed to start base converter." | |||||
| def __init__(self, msg=DEFAULT_MSG): | |||||
| super(BaseConverterError, self).__init__(user_msg=msg) | |||||
| @classmethod | @classmethod | ||||
| def raise_from(cls): | def raise_from(cls): | ||||
| @@ -189,32 +197,45 @@ class BaseConverterFail(MindConverterException): | |||||
| return except_source | return except_source | ||||
| class UnknownModel(MindConverterException): | |||||
| class UnknownModelError(BaseConverterError): | |||||
| """The unknown model error.""" | """The unknown model error.""" | ||||
| ERROR_CODE = BaseConverterError.ErrCode.UNKNOWN_MODEL.value | |||||
| def __init__(self, msg): | def __init__(self, msg): | ||||
| super(UnknownModel, self).__init__(error=ConverterErrors.UNKNOWN_MODEL, | |||||
| user_msg=msg) | |||||
| super(UnknownModelError, self).__init__(msg=msg) | |||||
| @classmethod | @classmethod | ||||
| def raise_from(cls): | def raise_from(cls): | ||||
| return cls | return cls | ||||
| class GraphInitFail(MindConverterException): | |||||
| class GraphInitError(MindConverterException): | |||||
| """The graph init fail error.""" | """The graph init fail error.""" | ||||
| def __init__(self, **kwargs): | |||||
| super(GraphInitFail, self).__init__(error=ConverterErrors.GRAPH_INIT_FAIL, | |||||
| user_msg=kwargs.get('msg', '')) | |||||
| @unique | |||||
| class ErrCode(Enum): | |||||
| """Define error code of GraphInitError.""" | |||||
| UNKNOWN_ERROR = 0 | |||||
| MODEL_NOT_SUPPORT = 1 | |||||
| TF_RUNTIME_ERROR = 2 | |||||
| INPUT_SHAPE_ERROR = 3 | |||||
| MI_RUNTIME_ERROR = 4 | |||||
| BASE_ERROR_CODE = ConverterErrors.GRAPH_INIT_FAIL.value | |||||
| ERROR_CODE = ErrCode.UNKNOWN_ERROR.value | |||||
| DEFAULT_MSG = "Error occurred when init graph object." | |||||
| def __init__(self, msg=DEFAULT_MSG): | |||||
| super(GraphInitError, self).__init__(user_msg=msg) | |||||
| @classmethod | @classmethod | ||||
| def raise_from(cls): | def raise_from(cls): | ||||
| """Raise from exceptions below.""" | """Raise from exceptions below.""" | ||||
| except_source = (FileNotFoundError, | except_source = (FileNotFoundError, | ||||
| ModuleNotFoundError, | ModuleNotFoundError, | ||||
| ModelNotSupport, | |||||
| SubGraphSearchingFail, | |||||
| ModelNotSupportError, | |||||
| ModelLoadingError, | |||||
| RuntimeIntegrityError, | |||||
| TypeError, | TypeError, | ||||
| ZeroDivisionError, | ZeroDivisionError, | ||||
| RuntimeError, | RuntimeError, | ||||
| @@ -222,45 +243,65 @@ class GraphInitFail(MindConverterException): | |||||
| return except_source | return except_source | ||||
| class TreeCreateFail(MindConverterException): | |||||
| class TreeCreationError(MindConverterException): | |||||
| """The tree create fail.""" | """The tree create fail.""" | ||||
| def __init__(self, msg): | |||||
| super(TreeCreateFail, self).__init__(error=ConverterErrors.TREE_CREATE_FAIL, | |||||
| user_msg=msg) | |||||
| @unique | |||||
| class ErrCode(Enum): | |||||
| """Define error code of TreeCreationError.""" | |||||
| UNKNOWN_ERROR = 0 | |||||
| NODE_INPUT_MISSING = 1 | |||||
| TREE_NODE_INSERT_FAIL = 2 | |||||
| BASE_ERROR_CODE = ConverterErrors.TREE_CREATE_FAIL.value | |||||
| ERROR_CODE = ErrCode.UNKNOWN_ERROR.value | |||||
| DEFAULT_MSG = "Error occurred when create hierarchical tree." | |||||
| def __init__(self, msg=DEFAULT_MSG): | |||||
| super(TreeCreationError, self).__init__(user_msg=msg) | |||||
| @classmethod | @classmethod | ||||
| def raise_from(cls): | def raise_from(cls): | ||||
| """Raise from exceptions below.""" | """Raise from exceptions below.""" | ||||
| except_source = (NodeInputMissing, | |||||
| TreeNodeInsertFail, cls) | |||||
| except_source = NodeInputMissingError, TreeNodeInsertError, cls | |||||
| return except_source | return except_source | ||||
| class SourceFilesSaveFail(MindConverterException): | |||||
| class SourceFilesSaveError(MindConverterException): | |||||
| """The source files save fail error.""" | """The source files save fail error.""" | ||||
| def __init__(self, msg): | |||||
| super(SourceFilesSaveFail, self).__init__(error=ConverterErrors.SOURCE_FILES_SAVE_FAIL, | |||||
| user_msg=msg) | |||||
| @unique | |||||
| class ErrCode(Enum): | |||||
| """Define error code of SourceFilesSaveError.""" | |||||
| UNKNOWN_ERROR = 0 | |||||
| NODE_INPUT_TYPE_NOT_SUPPORT = 1 | |||||
| SCRIPT_GENERATE_FAIL = 2 | |||||
| REPORT_GENERATE_FAIL = 3 | |||||
| BASE_ERROR_CODE = ConverterErrors.SOURCE_FILES_SAVE_FAIL.value | |||||
| ERROR_CODE = ErrCode.UNKNOWN_ERROR.value | |||||
| DEFAULT_MSG = "Error occurred when save source files." | |||||
| def __init__(self, msg=DEFAULT_MSG): | |||||
| super(SourceFilesSaveError, self).__init__(user_msg=msg) | |||||
| @classmethod | @classmethod | ||||
| def raise_from(cls): | def raise_from(cls): | ||||
| """Raise from exceptions below.""" | """Raise from exceptions below.""" | ||||
| except_source = (NodeInputTypeNotSupport, | |||||
| ScriptGenerateFail, | |||||
| ReportGenerateFail, | |||||
| except_source = (NodeInputTypeNotSupportError, | |||||
| ScriptGenerationError, | |||||
| ReportGenerationError, | |||||
| IOError, cls) | IOError, cls) | ||||
| return except_source | return except_source | ||||
| class ModelNotSupport(MindConverterException): | |||||
| class ModelNotSupportError(GraphInitError): | |||||
| """The model not support error.""" | """The model not support error.""" | ||||
| ERROR_CODE = GraphInitError.ErrCode.MODEL_NOT_SUPPORT.value | |||||
| def __init__(self, msg): | def __init__(self, msg): | ||||
| super(ModelNotSupport, self).__init__(error=ConverterErrors.MODEL_NOT_SUPPORT, | |||||
| user_msg=msg, | |||||
| cls_code=ConverterErrors.GRAPH_INIT_FAIL.value) | |||||
| super(ModelNotSupportError, self).__init__(msg=msg) | |||||
| @classmethod | @classmethod | ||||
| def raise_from(cls): | def raise_from(cls): | ||||
| @@ -275,13 +316,13 @@ class ModelNotSupport(MindConverterException): | |||||
| return except_source | return except_source | ||||
| class TfRuntimeError(MindConverterException): | |||||
| class TfRuntimeError(GraphInitError): | |||||
| """Catch tf runtime error.""" | """Catch tf runtime error.""" | ||||
| ERROR_CODE = GraphInitError.ErrCode.TF_RUNTIME_ERROR.value | |||||
| DEFAULT_MSG = "Error occurred when init graph, TensorFlow runtime error." | |||||
| def __init__(self, msg): | |||||
| super(TfRuntimeError, self).__init__(error=ConverterErrors.TF_RUNTIME_ERROR, | |||||
| user_msg=msg, | |||||
| cls_code=ConverterErrors.GRAPH_INIT_FAIL.value) | |||||
| def __init__(self, msg=DEFAULT_MSG): | |||||
| super(TfRuntimeError, self).__init__(msg=msg) | |||||
| @classmethod | @classmethod | ||||
| def raise_from(cls): | def raise_from(cls): | ||||
| @@ -290,26 +331,36 @@ class TfRuntimeError(MindConverterException): | |||||
| return tf_error, ValueError, RuntimeError, cls | return tf_error, ValueError, RuntimeError, cls | ||||
| class NodeInputMissing(MindConverterException): | |||||
| class RuntimeIntegrityError(GraphInitError): | |||||
| """Catch runtime error.""" | |||||
| ERROR_CODE = GraphInitError.ErrCode.MI_RUNTIME_ERROR.value | |||||
| def __init__(self, msg): | |||||
| super(RuntimeIntegrityError, self).__init__(msg=msg) | |||||
| @classmethod | |||||
| def raise_from(cls): | |||||
| return RuntimeError, AttributeError, ImportError, ModuleNotFoundError, cls | |||||
| class NodeInputMissingError(TreeCreationError): | |||||
| """The node input missing error.""" | """The node input missing error.""" | ||||
| ERROR_CODE = TreeCreationError.ErrCode.NODE_INPUT_MISSING.value | |||||
| def __init__(self, msg): | def __init__(self, msg): | ||||
| super(NodeInputMissing, self).__init__(error=ConverterErrors.NODE_INPUT_MISSING, | |||||
| user_msg=msg, | |||||
| cls_code=ConverterErrors.TREE_CREATE_FAIL.value) | |||||
| super(NodeInputMissingError, self).__init__(msg=msg) | |||||
| @classmethod | @classmethod | ||||
| def raise_from(cls): | def raise_from(cls): | ||||
| return ValueError, IndexError, KeyError, AttributeError, cls | return ValueError, IndexError, KeyError, AttributeError, cls | ||||
| class TreeNodeInsertFail(MindConverterException): | |||||
| class TreeNodeInsertError(TreeCreationError): | |||||
| """The tree node create fail error.""" | """The tree node create fail error.""" | ||||
| ERROR_CODE = TreeCreationError.ErrCode.TREE_NODE_INSERT_FAIL.value | |||||
| def __init__(self, msg): | def __init__(self, msg): | ||||
| super(TreeNodeInsertFail, self).__init__(error=ConverterErrors.TREE_NODE_INSERT_FAIL, | |||||
| user_msg=msg, | |||||
| cls_code=ConverterErrors.TREE_CREATE_FAIL.value) | |||||
| super(TreeNodeInsertError, self).__init__(msg=msg) | |||||
| @classmethod | @classmethod | ||||
| def raise_from(cls): | def raise_from(cls): | ||||
| @@ -321,26 +372,24 @@ class TreeNodeInsertFail(MindConverterException): | |||||
| return except_source | return except_source | ||||
| class NodeInputTypeNotSupport(MindConverterException): | |||||
| class NodeInputTypeNotSupportError(SourceFilesSaveError): | |||||
| """The node input type NOT support error.""" | """The node input type NOT support error.""" | ||||
| ERROR_CODE = SourceFilesSaveError.ErrCode.NODE_INPUT_TYPE_NOT_SUPPORT.value | |||||
| def __init__(self, msg): | def __init__(self, msg): | ||||
| super(NodeInputTypeNotSupport, self).__init__(error=ConverterErrors.NODE_INPUT_TYPE_NOT_SUPPORT, | |||||
| user_msg=msg, | |||||
| cls_code=ConverterErrors.SOURCE_FILES_SAVE_FAIL.value) | |||||
| super(NodeInputTypeNotSupportError, self).__init__(msg=msg) | |||||
| @classmethod | @classmethod | ||||
| def raise_from(cls): | def raise_from(cls): | ||||
| return ValueError, TypeError, IndexError, cls | return ValueError, TypeError, IndexError, cls | ||||
| class ScriptGenerateFail(MindConverterException): | |||||
| class ScriptGenerationError(SourceFilesSaveError): | |||||
| """The script generate fail error.""" | """The script generate fail error.""" | ||||
| ERROR_CODE = SourceFilesSaveError.ErrCode.SCRIPT_GENERATE_FAIL.value | |||||
| def __init__(self, msg): | def __init__(self, msg): | ||||
| super(ScriptGenerateFail, self).__init__(error=ConverterErrors.SCRIPT_GENERATE_FAIL, | |||||
| user_msg=msg, | |||||
| cls_code=ConverterErrors.SOURCE_FILES_SAVE_FAIL.value) | |||||
| super(ScriptGenerationError, self).__init__(msg=msg) | |||||
| @classmethod | @classmethod | ||||
| def raise_from(cls): | def raise_from(cls): | ||||
| @@ -351,13 +400,12 @@ class ScriptGenerateFail(MindConverterException): | |||||
| return except_source | return except_source | ||||
| class ReportGenerateFail(MindConverterException): | |||||
| class ReportGenerationError(SourceFilesSaveError): | |||||
| """The report generate fail error.""" | """The report generate fail error.""" | ||||
| ERROR_CODE = SourceFilesSaveError.ErrCode.REPORT_GENERATE_FAIL.value | |||||
| def __init__(self, msg): | def __init__(self, msg): | ||||
| super(ReportGenerateFail, self).__init__(error=ConverterErrors.REPORT_GENERATE_FAIL, | |||||
| user_msg=msg, | |||||
| cls_code=ConverterErrors.SOURCE_FILES_SAVE_FAIL.value) | |||||
| super(ReportGenerationError, self).__init__(msg=msg) | |||||
| @classmethod | @classmethod | ||||
| def raise_from(cls): | def raise_from(cls): | ||||
| @@ -365,13 +413,22 @@ class ReportGenerateFail(MindConverterException): | |||||
| return ZeroDivisionError, cls | return ZeroDivisionError, cls | ||||
| class SubGraphSearchingFail(MindConverterException): | |||||
| class SubGraphSearchingError(MindConverterException): | |||||
| """Sub-graph searching exception.""" | """Sub-graph searching exception.""" | ||||
| def __init__(self, msg): | |||||
| super(SubGraphSearchingFail, self).__init__(error=ConverterErrors.MODEL_NOT_SUPPORT, | |||||
| cls_code=ConverterErrors.SUB_GRAPH_SEARCHING_FAIL.value, | |||||
| user_msg=msg) | |||||
| @unique | |||||
| class ErrCode(Enum): | |||||
| """Define error code of SourceFilesSaveError.""" | |||||
| BASE_ERROR = 0 | |||||
| CANNOT_FIND_VALID_PATTERN = 1 | |||||
| MODEL_NOT_SUPPORT = 2 | |||||
| BASE_ERROR_CODE = ConverterErrors.SUB_GRAPH_SEARCHING_FAIL.value | |||||
| ERROR_CODE = ErrCode.BASE_ERROR.value | |||||
| DEFAULT_MSG = "Sub-Graph pattern searching fail." | |||||
| def __init__(self, msg=DEFAULT_MSG): | |||||
| super(SubGraphSearchingError, self).__init__(user_msg=msg) | |||||
| @classmethod | @classmethod | ||||
| def raise_from(cls): | def raise_from(cls): | ||||
| @@ -379,13 +436,22 @@ class SubGraphSearchingFail(MindConverterException): | |||||
| return IndexError, KeyError, ValueError, AttributeError, ZeroDivisionError, cls | return IndexError, KeyError, ValueError, AttributeError, ZeroDivisionError, cls | ||||
| class GeneratorFail(MindConverterException): | |||||
| class GeneratorError(MindConverterException): | |||||
| """The Generator fail error.""" | """The Generator fail error.""" | ||||
| def __init__(self, msg): | |||||
| super(GeneratorFail, self).__init__(error=ConverterErrors.NODE_CONVERSION_ERROR, | |||||
| user_msg=msg, | |||||
| cls_code=ConverterErrors.GENERATOR_FAIL.value) | |||||
| @unique | |||||
| class ErrCode(Enum): | |||||
| """Define error code of SourceFilesSaveError.""" | |||||
| BASE_ERROR = 0 | |||||
| STATEMENT_GENERATION_ERROR = 1 | |||||
| CONVERTED_OPERATOR_LOADING_ERROR = 2 | |||||
| BASE_ERROR_CODE = ConverterErrors.GENERATOR_FAIL.value | |||||
| ERROR_CODE = ErrCode.BASE_ERROR.value | |||||
| DEFAULT_MSG = "Error occurred when generate code." | |||||
| def __init__(self, msg=DEFAULT_MSG): | |||||
| super(GeneratorError, self).__init__(user_msg=msg) | |||||
| @classmethod | @classmethod | ||||
| def raise_from(cls): | def raise_from(cls): | ||||
| @@ -394,13 +460,12 @@ class GeneratorFail(MindConverterException): | |||||
| return except_source | return except_source | ||||
| class ModelLoadingFail(MindConverterException): | |||||
| class ModelLoadingError(GraphInitError): | |||||
| """Model loading fail.""" | """Model loading fail.""" | ||||
| ERROR_CODE = GraphInitError.ErrCode.INPUT_SHAPE_ERROR.value | |||||
| def __init__(self, msg): | def __init__(self, msg): | ||||
| super(ModelLoadingFail, self).__init__(error=ConverterErrors.INPUT_SHAPE_ERROR, | |||||
| cls_code=ConverterErrors.MODEL_LOADING_FAIL.value, | |||||
| user_msg=msg) | |||||
| super(ModelLoadingError, self).__init__(msg=msg) | |||||
| @classmethod | @classmethod | ||||
| def raise_from(cls): | def raise_from(cls): | ||||
| @@ -87,7 +87,7 @@ def get_imported_module(): | |||||
| str, imported module. | str, imported module. | ||||
| """ | """ | ||||
| return f"import numpy as np{NEW_LINE}" \ | return f"import numpy as np{NEW_LINE}" \ | ||||
| f"import mindspore{NEW_LINE}" \ | |||||
| f"from mindspore import nn{NEW_LINE}" \ | |||||
| f"from mindspore import Tensor{NEW_LINE}" \ | |||||
| f"from mindspore.ops import operations as P{NEW_LINE * 3}" | |||||
| f"import mindspore{NEW_LINE}" \ | |||||
| f"from mindspore import nn{NEW_LINE}" \ | |||||
| f"from mindspore import Tensor{NEW_LINE}" \ | |||||
| f"from mindspore.ops import operations as P{NEW_LINE * 3}" | |||||
| @@ -27,8 +27,8 @@ from mindinsight.mindconverter.graph_based_converter.constant import BINARY_HEAD | |||||
| BINARY_HEADER_PYTORCH_BITS, ONNX_MIN_VER, TF2ONNX_MIN_VER, ONNXRUNTIME_MIN_VER | BINARY_HEADER_PYTORCH_BITS, ONNX_MIN_VER, TF2ONNX_MIN_VER, ONNXRUNTIME_MIN_VER | ||||
| from mindinsight.mindconverter.graph_based_converter.mapper import ONNXToMindSporeMapper | from mindinsight.mindconverter.graph_based_converter.mapper import ONNXToMindSporeMapper | ||||
| from mindinsight.mindconverter.common.log import logger as log, logger_console as log_console | from mindinsight.mindconverter.common.log import logger as log, logger_console as log_console | ||||
| from mindinsight.mindconverter.common.exceptions import GraphInitFail, TreeCreateFail, SourceFilesSaveFail, \ | |||||
| BaseConverterFail, UnknownModel, GeneratorFail, TfRuntimeError | |||||
| from mindinsight.mindconverter.common.exceptions import GraphInitError, TreeCreationError, SourceFilesSaveError, \ | |||||
| BaseConverterError, UnknownModelError, GeneratorError, TfRuntimeError, RuntimeIntegrityError | |||||
| from mindinsight.utils.exceptions import ParamMissError | from mindinsight.utils.exceptions import ParamMissError | ||||
| permissions = os.R_OK | os.W_OK | os.X_OK | permissions = os.R_OK | os.W_OK | os.X_OK | ||||
| @@ -67,13 +67,13 @@ def torch_installation_validation(func): | |||||
| output_folder: str, report_folder: str = None): | output_folder: str, report_folder: str = None): | ||||
| # Check whether pytorch is installed. | # Check whether pytorch is installed. | ||||
| if not find_spec("torch"): | if not find_spec("torch"): | ||||
| error = ModuleNotFoundError("PyTorch is required when using graph based " | |||||
| "scripts converter, and PyTorch vision must " | |||||
| "be consisted with model generation runtime.") | |||||
| log.error(str(error)) | |||||
| detail_info = f"Error detail: {str(error)}" | |||||
| error = RuntimeIntegrityError("PyTorch is required when using graph based " | |||||
| "scripts converter, and PyTorch vision must " | |||||
| "be consisted with model generation runtime.") | |||||
| log.error(error) | |||||
| log_console.error("\n") | |||||
| log_console.error(str(error)) | log_console.error(str(error)) | ||||
| log_console.error(detail_info) | |||||
| log_console.error("\n") | |||||
| sys.exit(0) | sys.exit(0) | ||||
| func(graph_path=graph_path, sample_shape=sample_shape, | func(graph_path=graph_path, sample_shape=sample_shape, | ||||
| @@ -97,17 +97,17 @@ def tf_installation_validation(func): | |||||
| output_folder: str, report_folder: str = None, | output_folder: str, report_folder: str = None, | ||||
| input_nodes: str = None, output_nodes: str = None): | input_nodes: str = None, output_nodes: str = None): | ||||
| # Check whether tensorflow is installed. | # Check whether tensorflow is installed. | ||||
| if not find_spec("tensorflow") or not find_spec("tf2onnx") or not find_spec("onnx") \ | |||||
| or not find_spec("onnxruntime"): | |||||
| error = ModuleNotFoundError( | |||||
| if not find_spec("tensorflow") or not find_spec("tensorflow-gpu") or not find_spec("tf2onnx") \ | |||||
| or not find_spec("onnx") or not find_spec("onnxruntime"): | |||||
| error = RuntimeIntegrityError( | |||||
| f"TensorFlow, tf2onnx(>={TF2ONNX_MIN_VER}), onnx(>={ONNX_MIN_VER}) and " | f"TensorFlow, tf2onnx(>={TF2ONNX_MIN_VER}), onnx(>={ONNX_MIN_VER}) and " | ||||
| f"onnxruntime(>={ONNXRUNTIME_MIN_VER}) are required when using graph " | f"onnxruntime(>={ONNXRUNTIME_MIN_VER}) are required when using graph " | ||||
| f"based scripts converter for TensorFlow conversion." | f"based scripts converter for TensorFlow conversion." | ||||
| ) | ) | ||||
| log.error(str(error)) | |||||
| detail_info = f"Error detail: {str(error)}" | |||||
| log.error(error) | |||||
| log_console.error("\n") | |||||
| log_console.error(str(error)) | log_console.error(str(error)) | ||||
| log_console.error(detail_info) | |||||
| log_console.error("\n") | |||||
| sys.exit(0) | sys.exit(0) | ||||
| onnx, tf2onnx = import_module("onnx"), import_module("tf2onnx") | onnx, tf2onnx = import_module("onnx"), import_module("tf2onnx") | ||||
| @@ -116,15 +116,15 @@ def tf_installation_validation(func): | |||||
| if not lib_version_satisfied(getattr(onnx, "__version__"), ONNX_MIN_VER) \ | if not lib_version_satisfied(getattr(onnx, "__version__"), ONNX_MIN_VER) \ | ||||
| or not lib_version_satisfied(getattr(ort, "__version__"), ONNXRUNTIME_MIN_VER) \ | or not lib_version_satisfied(getattr(ort, "__version__"), ONNXRUNTIME_MIN_VER) \ | ||||
| or not lib_version_satisfied(getattr(tf2onnx, "__version__"), TF2ONNX_MIN_VER): | or not lib_version_satisfied(getattr(tf2onnx, "__version__"), TF2ONNX_MIN_VER): | ||||
| error = ModuleNotFoundError( | |||||
| error = RuntimeIntegrityError( | |||||
| f"TensorFlow, tf2onnx(>={TF2ONNX_MIN_VER}), onnx(>={ONNX_MIN_VER}) and " | f"TensorFlow, tf2onnx(>={TF2ONNX_MIN_VER}), onnx(>={ONNX_MIN_VER}) and " | ||||
| f"onnxruntime(>={ONNXRUNTIME_MIN_VER}) are required when using graph " | f"onnxruntime(>={ONNXRUNTIME_MIN_VER}) are required when using graph " | ||||
| f"based scripts converter for TensorFlow conversion." | f"based scripts converter for TensorFlow conversion." | ||||
| ) | ) | ||||
| log.error(str(error)) | |||||
| detail_info = f"Error detail: {str(error)}" | |||||
| log.error(error) | |||||
| log_console.error("\n") | |||||
| log_console.error(str(error)) | log_console.error(str(error)) | ||||
| log_console.error(detail_info) | |||||
| log_console.error("\n") | |||||
| sys.exit(0) | sys.exit(0) | ||||
| func(graph_path=graph_path, sample_shape=sample_shape, | func(graph_path=graph_path, sample_shape=sample_shape, | ||||
| @@ -150,14 +150,14 @@ def _extract_model_name(model_path): | |||||
| @torch_installation_validation | @torch_installation_validation | ||||
| @GraphInitFail.uniform_catcher("Error occurred when init graph object.") | |||||
| @TreeCreateFail.uniform_catcher("Error occurred when create hierarchical tree.") | |||||
| @SourceFilesSaveFail.uniform_catcher("Error occurred when save source files.") | |||||
| @GeneratorFail.uniform_catcher("Error occurred when generate code.") | |||||
| @GraphInitError.uniform_catcher() | |||||
| @TreeCreationError.uniform_catcher() | |||||
| @SourceFilesSaveError.uniform_catcher() | |||||
| @GeneratorError.uniform_catcher() | |||||
| def graph_based_converter_pytorch_to_ms(graph_path: str, sample_shape: tuple, | def graph_based_converter_pytorch_to_ms(graph_path: str, sample_shape: tuple, | ||||
| output_folder: str, report_folder: str = None): | output_folder: str, report_folder: str = None): | ||||
| """ | """ | ||||
| Pytoch to MindSpore based on Graph. | |||||
| PyTorch to MindSpore based on Graph. | |||||
| Args: | Args: | ||||
| graph_path (str): Graph file path. | graph_path (str): Graph file path. | ||||
| @@ -185,11 +185,11 @@ def graph_based_converter_pytorch_to_ms(graph_path: str, sample_shape: tuple, | |||||
| @tf_installation_validation | @tf_installation_validation | ||||
| @GraphInitFail.uniform_catcher("Error occurred when init graph object.") | |||||
| @TfRuntimeError.uniform_catcher("Error occurred when init graph, TensorFlow runtime error.") | |||||
| @TreeCreateFail.uniform_catcher("Error occurred when create hierarchical tree.") | |||||
| @SourceFilesSaveFail.uniform_catcher("Error occurred when save source files.") | |||||
| @GeneratorFail.uniform_catcher("Error occurred when generate code.") | |||||
| @GraphInitError.uniform_catcher() | |||||
| @TfRuntimeError.uniform_catcher() | |||||
| @TreeCreationError.uniform_catcher() | |||||
| @SourceFilesSaveError.uniform_catcher() | |||||
| @GeneratorError.uniform_catcher() | |||||
| def graph_based_converter_tf_to_ms(graph_path: str, sample_shape: tuple, | def graph_based_converter_tf_to_ms(graph_path: str, sample_shape: tuple, | ||||
| input_nodes: str, output_nodes: str, | input_nodes: str, output_nodes: str, | ||||
| output_folder: str, report_folder: str = None): | output_folder: str, report_folder: str = None): | ||||
| @@ -221,7 +221,7 @@ def graph_based_converter_tf_to_ms(graph_path: str, sample_shape: tuple, | |||||
| save_code_file_and_report(model_name, code_fragments, output_folder, report_folder) | save_code_file_and_report(model_name, code_fragments, output_folder, report_folder) | ||||
| @BaseConverterFail.uniform_catcher("Failed to start base converter.") | |||||
| @BaseConverterError.uniform_catcher() | |||||
| def main_graph_base_converter(file_config): | def main_graph_base_converter(file_config): | ||||
| """ | """ | ||||
| The entrance for converter, script files will be converted. | The entrance for converter, script files will be converted. | ||||
| @@ -248,7 +248,7 @@ def main_graph_base_converter(file_config): | |||||
| report_folder=file_config['report_dir']) | report_folder=file_config['report_dir']) | ||||
| else: | else: | ||||
| error_msg = "Get UNSUPPORTED model." | error_msg = "Get UNSUPPORTED model." | ||||
| error = UnknownModel(error_msg) | |||||
| error = UnknownModelError(error_msg) | |||||
| log.error(str(error)) | log.error(str(error)) | ||||
| raise error | raise error | ||||
| @@ -263,7 +263,7 @@ def get_framework_type(model_path): | |||||
| framework_type = FrameworkType.TENSORFLOW.value | framework_type = FrameworkType.TENSORFLOW.value | ||||
| except IOError: | except IOError: | ||||
| error_msg = "Get UNSUPPORTED model." | error_msg = "Get UNSUPPORTED model." | ||||
| error = UnknownModel(error_msg) | |||||
| error = UnknownModelError(error_msg) | |||||
| log.error(str(error)) | log.error(str(error)) | ||||
| raise error | raise error | ||||
| @@ -23,7 +23,7 @@ from .node_struct import NodeStruct | |||||
| from .module_struct import ModuleStruct | from .module_struct import ModuleStruct | ||||
| from .args_translator import ArgsTranslationHelper | from .args_translator import ArgsTranslationHelper | ||||
| from ..common.global_context import GlobalContext | from ..common.global_context import GlobalContext | ||||
| from ...common.exceptions import GeneratorFail | |||||
| from ...common.exceptions import GeneratorError | |||||
| from ..hierarchical_tree.name_mgr import GlobalVarNameMgr | from ..hierarchical_tree.name_mgr import GlobalVarNameMgr | ||||
| from ..constant import NEW_LINE, SECOND_LEVEL_INDENT, FIRST_LEVEL_INDENT, CodeFormatConfig, get_imported_module | from ..constant import NEW_LINE, SECOND_LEVEL_INDENT, FIRST_LEVEL_INDENT, CodeFormatConfig, get_imported_module | ||||
| from ..report_generator import ReportGenerator | from ..report_generator import ReportGenerator | ||||
| @@ -171,7 +171,7 @@ class Generator: | |||||
| self._global_context.node_struct_collections = self._node_struct_collections | self._global_context.node_struct_collections = self._node_struct_collections | ||||
| self._repeated_submodules = set() | self._repeated_submodules = set() | ||||
| @GeneratorFail.check_except("Generator occurs an error when forming base submodules.") | |||||
| @GeneratorError.check_except("Generator occurs an error when forming base submodules.") | |||||
| def _form_bottom_submodule(self): | def _form_bottom_submodule(self): | ||||
| """Form the basic submodules, which only contains nodes.""" | """Form the basic submodules, which only contains nodes.""" | ||||
| # Form module map | # Form module map | ||||
| @@ -351,7 +351,7 @@ class Generator: | |||||
| self._global_context.add_module_struct(sub.pattern_id, sub) | self._global_context.add_module_struct(sub.pattern_id, sub) | ||||
| depth -= 1 | depth -= 1 | ||||
| @GeneratorFail.check_except("Generator occurs an error when building modules.") | |||||
| @GeneratorError.check_except("Generator occurs an error when building modules.") | |||||
| def _recursive_form_module(self): | def _recursive_form_module(self): | ||||
| """Main routine in generator to build modules from bottom to top.""" | """Main routine in generator to build modules from bottom to top.""" | ||||
| # 1. List repeated submodules | # 1. List repeated submodules | ||||
| @@ -474,7 +474,7 @@ class Generator: | |||||
| """Return all ModuleStructs in this model.""" | """Return all ModuleStructs in this model.""" | ||||
| return self._module_struct_collections | return self._module_struct_collections | ||||
| @GeneratorFail.check_except("Generator occurs an error when generating code statements.") | |||||
| @GeneratorError.check_except("Generator occurs an error when generating code statements.") | |||||
| def generate(self): | def generate(self): | ||||
| """ | """ | ||||
| Generate the final script file. | Generate the final script file. | ||||
| @@ -22,7 +22,7 @@ from ..third_party_graph.pytorch_graph_node import PyTorchGraphNode | |||||
| from ..third_party_graph.onnx_graph_node import OnnxGraphNode | from ..third_party_graph.onnx_graph_node import OnnxGraphNode | ||||
| from ..common.global_context import GlobalContext | from ..common.global_context import GlobalContext | ||||
| from ..constant import InputType | from ..constant import InputType | ||||
| from ...common.exceptions import GeneratorFail | |||||
| from ...common.exceptions import GeneratorError | |||||
| class NodeStruct: | class NodeStruct: | ||||
| @@ -146,7 +146,7 @@ class NodeStruct: | |||||
| parsed_scope = Scope.parse_scope_from_node_identifier(self.identifier) | parsed_scope = Scope.parse_scope_from_node_identifier(self.identifier) | ||||
| self.scope = Scope(parsed_scope) | self.scope = Scope(parsed_scope) | ||||
| @GeneratorFail.check_except("Generator occurs an error when initializing node's args translator.") | |||||
| @GeneratorError.check_except("Generator occurs an error when initializing node's args translator.") | |||||
| def init_args_translator(self, translated_args: list): | def init_args_translator(self, translated_args: list): | ||||
| """ | """ | ||||
| Initialize the ArgsTranslator for each Node. | Initialize the ArgsTranslator for each Node. | ||||
| @@ -170,7 +170,7 @@ class NodeStruct: | |||||
| self.ms_op]): | self.ms_op]): | ||||
| self.ready_to_generate = True | self.ready_to_generate = True | ||||
| @GeneratorFail.check_except("Generator occurs an error when creating node struct.") | |||||
| @GeneratorError.check_except("Generator occurs an error when creating node struct.") | |||||
| def update(self, arg, force_ready=False): | def update(self, arg, force_ready=False): | ||||
| """ | """ | ||||
| Pass Node info. to generator NodeStruct. | Pass Node info. to generator NodeStruct. | ||||
| @@ -21,7 +21,7 @@ from mindinsight.mindconverter.common.log import logger as log | |||||
| from .hierarchical_tree import HierarchicalTree | from .hierarchical_tree import HierarchicalTree | ||||
| from ..third_party_graph.onnx_graph_node import OnnxGraphNode | from ..third_party_graph.onnx_graph_node import OnnxGraphNode | ||||
| from ...common.exceptions import NodeInputMissing, TreeNodeInsertFail | |||||
| from ...common.exceptions import NodeInputMissingError, TreeNodeInsertError | |||||
| def _tf_model_node_name_reformat(node: OnnxGraphNode, node_name): | def _tf_model_node_name_reformat(node: OnnxGraphNode, node_name): | ||||
| @@ -53,7 +53,7 @@ class HierarchicalTreeFactory: | |||||
| """Hierarchical tree factory.""" | """Hierarchical tree factory.""" | ||||
| @classmethod | @classmethod | ||||
| @TreeNodeInsertFail.check_except("Tree node inserts failed.") | |||||
| @TreeNodeInsertError.check_except("Tree node inserts failed.") | |||||
| def create(cls, graph): | def create(cls, graph): | ||||
| """ | """ | ||||
| Factory method of hierarchical tree. | Factory method of hierarchical tree. | ||||
| @@ -73,7 +73,7 @@ class HierarchicalTreeFactory: | |||||
| if node_input != 0 and not node_input: | if node_input != 0 and not node_input: | ||||
| err_msg = f"This model is not supported now. " \ | err_msg = f"This model is not supported now. " \ | ||||
| f"Cannot find {node_name}'s input shape." | f"Cannot find {node_name}'s input shape." | ||||
| error = NodeInputMissing(err_msg) | |||||
| error = NodeInputMissingError(err_msg) | |||||
| log.error(str(error)) | log.error(str(error)) | ||||
| raise error | raise error | ||||
| if isinstance(node_inst, OnnxGraphNode): | if isinstance(node_inst, OnnxGraphNode): | ||||
| @@ -35,7 +35,7 @@ from ..constant import SEPARATOR_BTW_NAME_AND_ID, FIRST_LEVEL_INDENT | |||||
| from ..constant import NEW_LINE, SECOND_LEVEL_INDENT | from ..constant import NEW_LINE, SECOND_LEVEL_INDENT | ||||
| from ..constant import NodeType | from ..constant import NodeType | ||||
| from ..report_generator import ReportGenerator | from ..report_generator import ReportGenerator | ||||
| from ...common.exceptions import ReportGenerateFail, ScriptGenerateFail, NodeInputTypeNotSupport | |||||
| from ...common.exceptions import ReportGenerationError, ScriptGenerationError, NodeInputTypeNotSupportError | |||||
| class HierarchicalTree(Tree): | class HierarchicalTree(Tree): | ||||
| @@ -189,7 +189,7 @@ class HierarchicalTree(Tree): | |||||
| try: | try: | ||||
| self._adjust_structure() | self._adjust_structure() | ||||
| code_fragments = self._generate_codes(mapper) | code_fragments = self._generate_codes(mapper) | ||||
| except (NodeInputTypeNotSupport, ScriptGenerateFail, ReportGenerateFail) as e: | |||||
| except (NodeInputTypeNotSupportError, ScriptGenerationError, ReportGenerationError) as e: | |||||
| log.error("Error occur when generating codes.") | log.error("Error occur when generating codes.") | ||||
| raise e | raise e | ||||
| @@ -264,8 +264,8 @@ class HierarchicalTree(Tree): | |||||
| node.data.args_in_code.pop(arg) | node.data.args_in_code.pop(arg) | ||||
| return node | return node | ||||
| @ScriptGenerateFail.check_except("FormatCode run error. Check detailed information in log.") | |||||
| @ReportGenerateFail.check_except("Not find valid operators in converted script.") | |||||
| @ScriptGenerationError.check_except("FormatCode run error. Check detailed information in log.") | |||||
| @ReportGenerationError.check_except("Not find valid operators in converted script.") | |||||
| def _generate_codes(self, mapper): | def _generate_codes(self, mapper): | ||||
| """ | """ | ||||
| Generate code files. | Generate code files. | ||||
| @@ -21,7 +21,7 @@ from .common import MINI_FREQUENCY, MAX_ITERATION_DEPTH, SATISFIED_SCORE | |||||
| from ..common.global_context import GlobalContext | from ..common.global_context import GlobalContext | ||||
| from ..third_party_graph.onnx_utils import BaseNode | from ..third_party_graph.onnx_utils import BaseNode | ||||
| from .search_path import SearchPath, Pattern, generate_pattern, find_built_in_pattern | from .search_path import SearchPath, Pattern, generate_pattern, find_built_in_pattern | ||||
| from ...common.exceptions import SubGraphSearchingFail | |||||
| from ...common.exceptions import SubGraphSearchingError | |||||
| def _is_satisfied(path): | def _is_satisfied(path): | ||||
| @@ -267,7 +267,7 @@ def _add_known_module_name(search_path): | |||||
| ctx.known_module_name[it.pattern.module_name] = it.pattern.known_module_name | ctx.known_module_name[it.pattern.module_name] = it.pattern.known_module_name | ||||
| @SubGraphSearchingFail.check_except("Sub-Graph searching fail.") | |||||
| @SubGraphSearchingError.check_except("Sub-Graph pattern searching fail.") | |||||
| def generate_scope_name(data_loader): | def generate_scope_name(data_loader): | ||||
| """ | """ | ||||
| Generate scope name according to computation graph. | Generate scope name according to computation graph. | ||||
| @@ -21,7 +21,7 @@ from mindinsight.mindconverter.common.log import logger as log | |||||
| from ..common.code_fragment import CodeFragment | from ..common.code_fragment import CodeFragment | ||||
| from ..constant import NodeType, InputType | from ..constant import NodeType, InputType | ||||
| from ..mapper.base import Mapper | from ..mapper.base import Mapper | ||||
| from ...common.exceptions import NodeInputTypeNotSupport | |||||
| from ...common.exceptions import NodeInputTypeNotSupportError | |||||
| class GraphParser(metaclass=abc.ABCMeta): | class GraphParser(metaclass=abc.ABCMeta): | ||||
| @@ -522,7 +522,7 @@ class GraphNode(abc.ABC): | |||||
| elif input_type == InputType.LIST.value: | elif input_type == InputType.LIST.value: | ||||
| ipt_args_settings_in_construct = f"({ipt_args_in_construct},)" | ipt_args_settings_in_construct = f"({ipt_args_in_construct},)" | ||||
| else: | else: | ||||
| raise NodeInputTypeNotSupport(f"Input type[{input_type}] is not supported now.") | |||||
| raise NodeInputTypeNotSupportError(f"Input type[{input_type}] is not supported now.") | |||||
| else: | else: | ||||
| ipt_args_settings_in_construct = ipt_args_in_construct | ipt_args_settings_in_construct = ipt_args_in_construct | ||||
| @@ -27,7 +27,7 @@ from ..common.global_context import GlobalContext | |||||
| from ..constant import ONNX_TYPE_INT, ONNX_TYPE_INTS, ONNX_TYPE_STRING, \ | from ..constant import ONNX_TYPE_INT, ONNX_TYPE_INTS, ONNX_TYPE_STRING, \ | ||||
| ONNX_TYPE_FLOATS, ONNX_TYPE_FLOAT, SCALAR_WITHOUT_SHAPE, DYNAMIC_SHAPE, UNKNOWN_DIM_VAL | ONNX_TYPE_FLOATS, ONNX_TYPE_FLOAT, SCALAR_WITHOUT_SHAPE, DYNAMIC_SHAPE, UNKNOWN_DIM_VAL | ||||
| from ...common.exceptions import GraphInitFail, ModelNotSupport, ModelLoadingFail | |||||
| from ...common.exceptions import GraphInitError, ModelNotSupportError, ModelLoadingError | |||||
| def convert_tf_graph_to_onnx(model_path, model_inputs, model_outputs, opset=12): | def convert_tf_graph_to_onnx(model_path, model_inputs, model_outputs, opset=12): | ||||
| @@ -308,7 +308,7 @@ class OnnxDataLoader: | |||||
| w = int(match.group('w')) | w = int(match.group('w')) | ||||
| c = int(match.group('c')) | c = int(match.group('c')) | ||||
| if [h, w, c] != list(self.graph_input_shape)[1:4]: | if [h, w, c] != list(self.graph_input_shape)[1:4]: | ||||
| raise ModelLoadingFail(f"Shape given should be (N, {h}, {w}, {c}) but got {self.graph_input_shape}") | |||||
| raise ModelLoadingError(f"Shape given should be (N, {h}, {w}, {c}) but got {self.graph_input_shape}") | |||||
| return True | return True | ||||
| return False | return False | ||||
| @@ -383,7 +383,7 @@ class OnnxDataLoader: | |||||
| self._nodes_dict[n.name] = n | self._nodes_dict[n.name] = n | ||||
| nodes_topo_idx.append((idx, n.name)) | nodes_topo_idx.append((idx, n.name)) | ||||
| if len(node.output) > 1: | if len(node.output) > 1: | ||||
| raise ModelNotSupport(msg=f"{node.name} has multi-outputs which is not supported now.") | |||||
| raise ModelNotSupportError(msg=f"{node.name} has multi-outputs which is not supported now.") | |||||
| self.output_name_to_node_name[node.output[0]] = node.name | self.output_name_to_node_name[node.output[0]] = node.name | ||||
| self._global_context.onnx_node_name_to_topo_idx[n.name] = idx | self._global_context.onnx_node_name_to_topo_idx[n.name] = idx | ||||
| node_inputs = [i.replace(":0", "") for i in node.input] | node_inputs = [i.replace(":0", "") for i in node.input] | ||||
| @@ -423,7 +423,7 @@ class OnnxDataLoader: | |||||
| shape[i] = int(shape[i]) | shape[i] = int(shape[i]) | ||||
| node_name = self.output_name_to_node_name[node_opt_name] | node_name = self.output_name_to_node_name[node_opt_name] | ||||
| if not node_name: | if not node_name: | ||||
| raise GraphInitFail(user_msg=f"Cannot find where edge {node_opt_name} comes from.") | |||||
| raise GraphInitError(msg=f"Cannot find where edge {node_opt_name} comes from.") | |||||
| self.node_output_shape_dict[node_name] = shape | self.node_output_shape_dict[node_name] = shape | ||||
| def get_node(self, node_name): | def get_node(self, node_name): | ||||
| @@ -18,14 +18,14 @@ from importlib import import_module | |||||
| from mindinsight.mindconverter.common.log import logger as log | from mindinsight.mindconverter.common.log import logger as log | ||||
| from .base import GraphParser | from .base import GraphParser | ||||
| from ...common.exceptions import ModelNotSupport | |||||
| from ...common.exceptions import ModelNotSupportError | |||||
| class PyTorchGraphParser(GraphParser): | class PyTorchGraphParser(GraphParser): | ||||
| """Define pytorch graph parser.""" | """Define pytorch graph parser.""" | ||||
| @classmethod | @classmethod | ||||
| @ModelNotSupport.check_except( | |||||
| @ModelNotSupportError.check_except( | |||||
| "Error occurs in loading model, please check your model or runtime environment integrity." | "Error occurs in loading model, please check your model or runtime environment integrity." | ||||
| ) | ) | ||||
| def parse(cls, model_path: str, **kwargs): | def parse(cls, model_path: str, **kwargs): | ||||
| @@ -18,14 +18,14 @@ from importlib import import_module | |||||
| from mindinsight.mindconverter.common.log import logger as log | from mindinsight.mindconverter.common.log import logger as log | ||||
| from .base import GraphParser | from .base import GraphParser | ||||
| from ...common.exceptions import ModelNotSupport | |||||
| from ...common.exceptions import ModelNotSupportError | |||||
| class TFGraphParser(GraphParser): | class TFGraphParser(GraphParser): | ||||
| """Define TF graph parser.""" | """Define TF graph parser.""" | ||||
| @classmethod | @classmethod | ||||
| @ModelNotSupport.check_except( | |||||
| @ModelNotSupportError.check_except( | |||||
| "Error occurs in loading model, please check your model or runtime environment integrity." | "Error occurs in loading model, please check your model or runtime environment integrity." | ||||
| ) | ) | ||||
| def parse(cls, model_path: str, **kwargs): | def parse(cls, model_path: str, **kwargs): | ||||