
!1013 Edit error code and refactor exception module

From: @liuchongming74
Reviewed-by: @lilongfei15,@wenkai_dist,@lilongfei15
Signed-off-by: @lilongfei15
tags/v1.1.0
mindspore-ci-bot (Gitee), 5 years ago
commit 2ec4f1eb03
12 changed files with 221 additions and 156 deletions
  1. +160  -95   mindinsight/mindconverter/common/exceptions.py
  2.   +4   -4   mindinsight/mindconverter/graph_based_converter/constant.py
  3.  +31  -31   mindinsight/mindconverter/graph_based_converter/framework.py
  4.   +4   -4   mindinsight/mindconverter/graph_based_converter/generator/generator.py
  5.   +3   -3   mindinsight/mindconverter/graph_based_converter/generator/node_struct.py
  6.   +3   -3   mindinsight/mindconverter/graph_based_converter/hierarchical_tree/__init__.py
  7.   +4   -4   mindinsight/mindconverter/graph_based_converter/hierarchical_tree/hierarchical_tree.py
  8.   +2   -2   mindinsight/mindconverter/graph_based_converter/sub_graph_searcher/searcher.py
  9.   +2   -2   mindinsight/mindconverter/graph_based_converter/third_party_graph/base.py
 10.   +4   -4   mindinsight/mindconverter/graph_based_converter/third_party_graph/onnx_utils.py
 11.   +2   -2   mindinsight/mindconverter/graph_based_converter/third_party_graph/pytorch_graph_parser.py
 12.   +2   -2   mindinsight/mindconverter/graph_based_converter/third_party_graph/tf_graph_parser.py

+160 -95  mindinsight/mindconverter/common/exceptions.py

@@ -15,7 +15,7 @@
"""Define custom exception.""" """Define custom exception."""
import abc import abc
import sys import sys
from enum import unique
from enum import unique, Enum
from importlib import import_module from importlib import import_module


from lib2to3.pgen2 import parse from lib2to3.pgen2 import parse
@@ -33,16 +33,6 @@ class ConverterErrors(ScriptConverterErrors):
    SCRIPT_NOT_SUPPORT = 1
    NODE_TYPE_NOT_SUPPORT = 2
    CODE_SYNTAX_ERROR = 3
-    NODE_INPUT_TYPE_NOT_SUPPORT = 4
-    NODE_INPUT_MISSING = 5
-    TREE_NODE_INSERT_FAIL = 6
-    UNKNOWN_MODEL = 7
-    MODEL_NOT_SUPPORT = 8
-    SCRIPT_GENERATE_FAIL = 9
-    REPORT_GENERATE_FAIL = 10
-    NODE_CONVERSION_ERROR = 11
-    INPUT_SHAPE_ERROR = 12
-    TF_RUNTIME_ERROR = 13

    BASE_CONVERTER_FAIL = 000
    GRAPH_INIT_FAIL = 100
@@ -50,7 +40,6 @@ class ConverterErrors(ScriptConverterErrors):
    SOURCE_FILES_SAVE_FAIL = 300
    GENERATOR_FAIL = 400
    SUB_GRAPH_SEARCHING_FAIL = 500
-    MODEL_LOADING_FAIL = 600


class ScriptNotSupport(MindInsightException):
@@ -82,21 +71,20 @@ class CodeSyntaxError(MindInsightException):


class MindConverterException(Exception):
    """MindConverter exception."""
+    BASE_ERROR_CODE = None  # ConverterErrors.BASE_CONVERTER_FAIL.value
+    # ERROR_CODE should be declared in child exception.
+    ERROR_CODE = None

    def __init__(self, **kwargs):
        """Initialization of MindInsightException."""
-        error = kwargs.get('error', None)
        user_msg = kwargs.get('user_msg', '')
-        debug_msg = kwargs.get('debug_msg', '')
-        cls_code = kwargs.get('cls_code', 0)

        if isinstance(user_msg, str):
            user_msg = ' '.join(user_msg.split())

        super(MindConverterException, self).__init__()
-        self.error = error
        self.user_msg = user_msg
-        self.debug_msg = debug_msg
-        self.cls_code = cls_code
+        self.root_exception_error_code = None

    def __str__(self):
        return '[{}] code: {}, msg: {}'.format(self.__class__.__name__, self.error_code(), self.user_msg)
@@ -116,8 +104,12 @@ class MindConverterException(Exception):
        Returns:
            str, Hex string representing the composed MindConverter error code.
        """
-        num = 0xFFFF & self.error.value
-        error_code = ''.join((f'{self.cls_code}'.zfill(3), hex(num)[2:].zfill(4).upper()))
+        if self.root_exception_error_code:
+            return self.root_exception_error_code
+        if self.BASE_ERROR_CODE is None or self.ERROR_CODE is None:
+            raise ValueError("MindConverterException has not been initialized.")
+        num = 0xFFFF & self.ERROR_CODE  # 0xFFFF & self.error.value
+        error_code = f"{str(self.BASE_ERROR_CODE).zfill(3)}{hex(num)[2:].zfill(4).upper()}"
        return error_code


    @classmethod
@@ -126,7 +118,7 @@ class MindConverterException(Exception):
"""Raise from below exceptions.""" """Raise from below exceptions."""


@classmethod @classmethod
def uniform_catcher(cls, msg):
def uniform_catcher(cls, msg: str = ""):
"""Uniform exception catcher.""" """Uniform exception catcher."""


def decorator(func): def decorator(func):
@@ -134,16 +126,20 @@ class MindConverterException(Exception):
                try:
                    res = func(*args, **kwargs)
                except cls.raise_from() as e:
-                    error = cls(msg=msg)
-                    detail_info = f"Error detail: {str(e)}"
-                    log_console.error(str(error))
+                    error = cls() if not msg else cls(msg=msg)
+                    detail_info = str(e)
+                    log.error(error)
+                    log_console.error("\n")
                    log_console.error(detail_info)
+                    log_console.error("\n")
                    log.exception(e)
                    sys.exit(0)
                except ModuleNotFoundError as e:
-                    detail_info = f"Error detail: Required package not found, please check the runtime environment."
+                    detail_info = "Error detail: Required package not found, please check the runtime environment."
+                    log_console.error("\n")
                    log_console.error(str(e))
                    log_console.error(detail_info)
+                    log_console.error("\n")
                    log.exception(e)
                    sys.exit(0)
                return res
@@ -161,9 +157,12 @@ class MindConverterException(Exception):
                try:
                    output = func(*args, **kwargs)
                except cls.raise_from() as e:
+                    error = cls(msg=msg)
+                    error_code = e.error_code() if isinstance(e, MindConverterException) else None
+                    error.root_exception_error_code = error_code
                    log.error(msg)
                    log.exception(e)
-                    raise cls(msg=msg)
+                    raise error
                except Exception as e:
                    log.error(msg)
                    log.exception(e)
@@ -175,12 +174,21 @@ class MindConverterException(Exception):
        return decorator


-class BaseConverterFail(MindConverterException):
+class BaseConverterError(MindConverterException):
    """Base converter failed."""

-    def __init__(self, msg):
-        super(BaseConverterFail, self).__init__(error=ConverterErrors.BASE_CONVERTER_FAIL,
-                                                user_msg=msg)
+    @unique
+    class ErrCode(Enum):
+        """Define error code of BaseConverterError."""
+        UNKNOWN_ERROR = 0
+        UNKNOWN_MODEL = 1
+
+    BASE_ERROR_CODE = ConverterErrors.BASE_CONVERTER_FAIL.value
+    ERROR_CODE = ErrCode.UNKNOWN_ERROR.value
+    DEFAULT_MSG = "Failed to start base converter."
+
+    def __init__(self, msg=DEFAULT_MSG):
+        super(BaseConverterError, self).__init__(user_msg=msg)

    @classmethod
    def raise_from(cls):
@@ -189,32 +197,45 @@ class BaseConverterFail(MindConverterException):
        return except_source


-class UnknownModel(MindConverterException):
+class UnknownModelError(BaseConverterError):
    """The unknown model error."""
+    ERROR_CODE = BaseConverterError.ErrCode.UNKNOWN_MODEL.value

    def __init__(self, msg):
-        super(UnknownModel, self).__init__(error=ConverterErrors.UNKNOWN_MODEL,
-                                           user_msg=msg)
+        super(UnknownModelError, self).__init__(msg=msg)

    @classmethod
    def raise_from(cls):
        return cls




-class GraphInitFail(MindConverterException):
+class GraphInitError(MindConverterException):
    """The graph init fail error."""

-    def __init__(self, **kwargs):
-        super(GraphInitFail, self).__init__(error=ConverterErrors.GRAPH_INIT_FAIL,
-                                            user_msg=kwargs.get('msg', ''))
+    @unique
+    class ErrCode(Enum):
+        """Define error code of GraphInitError."""
+        UNKNOWN_ERROR = 0
+        MODEL_NOT_SUPPORT = 1
+        TF_RUNTIME_ERROR = 2
+        INPUT_SHAPE_ERROR = 3
+        MI_RUNTIME_ERROR = 4
+
+    BASE_ERROR_CODE = ConverterErrors.GRAPH_INIT_FAIL.value
+    ERROR_CODE = ErrCode.UNKNOWN_ERROR.value
+    DEFAULT_MSG = "Error occurred when init graph object."
+
+    def __init__(self, msg=DEFAULT_MSG):
+        super(GraphInitError, self).__init__(user_msg=msg)

    @classmethod
    def raise_from(cls):
        """Raise from exceptions below."""
        except_source = (FileNotFoundError,
                         ModuleNotFoundError,
-                         ModelNotSupport,
-                         SubGraphSearchingFail,
+                         ModelNotSupportError,
+                         ModelLoadingError,
+                         RuntimeIntegrityError,
                         TypeError,
                         ZeroDivisionError,
                         RuntimeError,
@@ -222,45 +243,65 @@ class GraphInitFail(MindConverterException):
        return except_source


-class TreeCreateFail(MindConverterException):
+class TreeCreationError(MindConverterException):
    """The tree create fail."""

-    def __init__(self, msg):
-        super(TreeCreateFail, self).__init__(error=ConverterErrors.TREE_CREATE_FAIL,
-                                             user_msg=msg)
+    @unique
+    class ErrCode(Enum):
+        """Define error code of TreeCreationError."""
+        UNKNOWN_ERROR = 0
+        NODE_INPUT_MISSING = 1
+        TREE_NODE_INSERT_FAIL = 2
+
+    BASE_ERROR_CODE = ConverterErrors.TREE_CREATE_FAIL.value
+    ERROR_CODE = ErrCode.UNKNOWN_ERROR.value
+    DEFAULT_MSG = "Error occurred when create hierarchical tree."
+
+    def __init__(self, msg=DEFAULT_MSG):
+        super(TreeCreationError, self).__init__(user_msg=msg)

    @classmethod
    def raise_from(cls):
        """Raise from exceptions below."""
-        except_source = (NodeInputMissing,
-                         TreeNodeInsertFail, cls)
+        except_source = NodeInputMissingError, TreeNodeInsertError, cls
        return except_source




-class SourceFilesSaveFail(MindConverterException):
+class SourceFilesSaveError(MindConverterException):
    """The source files save fail error."""

-    def __init__(self, msg):
-        super(SourceFilesSaveFail, self).__init__(error=ConverterErrors.SOURCE_FILES_SAVE_FAIL,
-                                                  user_msg=msg)
+    @unique
+    class ErrCode(Enum):
+        """Define error code of SourceFilesSaveError."""
+        UNKNOWN_ERROR = 0
+        NODE_INPUT_TYPE_NOT_SUPPORT = 1
+        SCRIPT_GENERATE_FAIL = 2
+        REPORT_GENERATE_FAIL = 3
+
+    BASE_ERROR_CODE = ConverterErrors.SOURCE_FILES_SAVE_FAIL.value
+    ERROR_CODE = ErrCode.UNKNOWN_ERROR.value
+    DEFAULT_MSG = "Error occurred when save source files."
+
+    def __init__(self, msg=DEFAULT_MSG):
+        super(SourceFilesSaveError, self).__init__(user_msg=msg)

    @classmethod
    def raise_from(cls):
        """Raise from exceptions below."""
-        except_source = (NodeInputTypeNotSupport,
-                         ScriptGenerateFail,
-                         ReportGenerateFail,
+        except_source = (NodeInputTypeNotSupportError,
+                         ScriptGenerationError,
+                         ReportGenerationError,
                         IOError, cls)
        return except_source




-class ModelNotSupport(MindConverterException):
+class ModelNotSupportError(GraphInitError):
    """The model not support error."""
+    ERROR_CODE = GraphInitError.ErrCode.MODEL_NOT_SUPPORT.value

    def __init__(self, msg):
-        super(ModelNotSupport, self).__init__(error=ConverterErrors.MODEL_NOT_SUPPORT,
-                                              user_msg=msg,
-                                              cls_code=ConverterErrors.GRAPH_INIT_FAIL.value)
+        super(ModelNotSupportError, self).__init__(msg=msg)

    @classmethod
    def raise_from(cls):
@@ -275,13 +316,13 @@ class ModelNotSupport(MindConverterException):
        return except_source


-class TfRuntimeError(MindConverterException):
+class TfRuntimeError(GraphInitError):
    """Catch tf runtime error."""
+    ERROR_CODE = GraphInitError.ErrCode.TF_RUNTIME_ERROR.value
+    DEFAULT_MSG = "Error occurred when init graph, TensorFlow runtime error."

-    def __init__(self, msg):
-        super(TfRuntimeError, self).__init__(error=ConverterErrors.TF_RUNTIME_ERROR,
-                                             user_msg=msg,
-                                             cls_code=ConverterErrors.GRAPH_INIT_FAIL.value)
+    def __init__(self, msg=DEFAULT_MSG):
+        super(TfRuntimeError, self).__init__(msg=msg)

    @classmethod
    def raise_from(cls):
@@ -290,26 +331,36 @@ class TfRuntimeError(MindConverterException):
        return tf_error, ValueError, RuntimeError, cls


-class NodeInputMissing(MindConverterException):
+class RuntimeIntegrityError(GraphInitError):
+    """Catch runtime error."""
+    ERROR_CODE = GraphInitError.ErrCode.MI_RUNTIME_ERROR.value
+
+    def __init__(self, msg):
+        super(RuntimeIntegrityError, self).__init__(msg=msg)
+
+    @classmethod
+    def raise_from(cls):
+        return RuntimeError, AttributeError, ImportError, ModuleNotFoundError, cls
+
+
+class NodeInputMissingError(TreeCreationError):
    """The node input missing error."""
+    ERROR_CODE = TreeCreationError.ErrCode.NODE_INPUT_MISSING.value

    def __init__(self, msg):
-        super(NodeInputMissing, self).__init__(error=ConverterErrors.NODE_INPUT_MISSING,
-                                               user_msg=msg,
-                                               cls_code=ConverterErrors.TREE_CREATE_FAIL.value)
+        super(NodeInputMissingError, self).__init__(msg=msg)

    @classmethod
    def raise_from(cls):
        return ValueError, IndexError, KeyError, AttributeError, cls


-class TreeNodeInsertFail(MindConverterException):
+class TreeNodeInsertError(TreeCreationError):
    """The tree node create fail error."""
+    ERROR_CODE = TreeCreationError.ErrCode.TREE_NODE_INSERT_FAIL.value

    def __init__(self, msg):
-        super(TreeNodeInsertFail, self).__init__(error=ConverterErrors.TREE_NODE_INSERT_FAIL,
-                                                 user_msg=msg,
-                                                 cls_code=ConverterErrors.TREE_CREATE_FAIL.value)
+        super(TreeNodeInsertError, self).__init__(msg=msg)

    @classmethod
    def raise_from(cls):
@@ -321,26 +372,24 @@ class TreeNodeInsertFail(MindConverterException):
        return except_source


-class NodeInputTypeNotSupport(MindConverterException):
+class NodeInputTypeNotSupportError(SourceFilesSaveError):
    """The node input type NOT support error."""
+    ERROR_CODE = SourceFilesSaveError.ErrCode.NODE_INPUT_TYPE_NOT_SUPPORT.value

    def __init__(self, msg):
-        super(NodeInputTypeNotSupport, self).__init__(error=ConverterErrors.NODE_INPUT_TYPE_NOT_SUPPORT,
-                                                      user_msg=msg,
-                                                      cls_code=ConverterErrors.SOURCE_FILES_SAVE_FAIL.value)
+        super(NodeInputTypeNotSupportError, self).__init__(msg=msg)

    @classmethod
    def raise_from(cls):
        return ValueError, TypeError, IndexError, cls


-class ScriptGenerateFail(MindConverterException):
+class ScriptGenerationError(SourceFilesSaveError):
    """The script generate fail error."""
+    ERROR_CODE = SourceFilesSaveError.ErrCode.SCRIPT_GENERATE_FAIL.value

    def __init__(self, msg):
-        super(ScriptGenerateFail, self).__init__(error=ConverterErrors.SCRIPT_GENERATE_FAIL,
-                                                 user_msg=msg,
-                                                 cls_code=ConverterErrors.SOURCE_FILES_SAVE_FAIL.value)
+        super(ScriptGenerationError, self).__init__(msg=msg)

    @classmethod
    def raise_from(cls):
@@ -351,13 +400,12 @@ class ScriptGenerateFail(MindConverterException):
        return except_source


-class ReportGenerateFail(MindConverterException):
+class ReportGenerationError(SourceFilesSaveError):
    """The report generate fail error."""
+    ERROR_CODE = SourceFilesSaveError.ErrCode.REPORT_GENERATE_FAIL.value

    def __init__(self, msg):
-        super(ReportGenerateFail, self).__init__(error=ConverterErrors.REPORT_GENERATE_FAIL,
-                                                 user_msg=msg,
-                                                 cls_code=ConverterErrors.SOURCE_FILES_SAVE_FAIL.value)
+        super(ReportGenerationError, self).__init__(msg=msg)

    @classmethod
    def raise_from(cls):
@@ -365,13 +413,22 @@ class ReportGenerateFail(MindConverterException):
        return ZeroDivisionError, cls


-class SubGraphSearchingFail(MindConverterException):
+class SubGraphSearchingError(MindConverterException):
    """Sub-graph searching exception."""

-    def __init__(self, msg):
-        super(SubGraphSearchingFail, self).__init__(error=ConverterErrors.MODEL_NOT_SUPPORT,
-                                                    cls_code=ConverterErrors.SUB_GRAPH_SEARCHING_FAIL.value,
-                                                    user_msg=msg)
+    @unique
+    class ErrCode(Enum):
+        """Define error code of SourceFilesSaveError."""
+        BASE_ERROR = 0
+        CANNOT_FIND_VALID_PATTERN = 1
+        MODEL_NOT_SUPPORT = 2
+
+    BASE_ERROR_CODE = ConverterErrors.SUB_GRAPH_SEARCHING_FAIL.value
+    ERROR_CODE = ErrCode.BASE_ERROR.value
+    DEFAULT_MSG = "Sub-Graph pattern searching fail."
+
+    def __init__(self, msg=DEFAULT_MSG):
+        super(SubGraphSearchingError, self).__init__(user_msg=msg)

    @classmethod
    def raise_from(cls):
@@ -379,13 +436,22 @@ class SubGraphSearchingFail(MindConverterException):
        return IndexError, KeyError, ValueError, AttributeError, ZeroDivisionError, cls


-class GeneratorFail(MindConverterException):
+class GeneratorError(MindConverterException):
    """The Generator fail error."""

-    def __init__(self, msg):
-        super(GeneratorFail, self).__init__(error=ConverterErrors.NODE_CONVERSION_ERROR,
-                                            user_msg=msg,
-                                            cls_code=ConverterErrors.GENERATOR_FAIL.value)
+    @unique
+    class ErrCode(Enum):
+        """Define error code of SourceFilesSaveError."""
+        BASE_ERROR = 0
+        STATEMENT_GENERATION_ERROR = 1
+        CONVERTED_OPERATOR_LOADING_ERROR = 2
+
+    BASE_ERROR_CODE = ConverterErrors.GENERATOR_FAIL.value
+    ERROR_CODE = ErrCode.BASE_ERROR.value
+    DEFAULT_MSG = "Error occurred when generate code."
+
+    def __init__(self, msg=DEFAULT_MSG):
+        super(GeneratorError, self).__init__(user_msg=msg)

    @classmethod
    def raise_from(cls):
@@ -394,13 +460,12 @@ class GeneratorFail(MindConverterException):
        return except_source


-class ModelLoadingFail(MindConverterException):
+class ModelLoadingError(GraphInitError):
    """Model loading fail."""
+    ERROR_CODE = GraphInitError.ErrCode.INPUT_SHAPE_ERROR.value

    def __init__(self, msg):
-        super(ModelLoadingFail, self).__init__(error=ConverterErrors.INPUT_SHAPE_ERROR,
-                                               cls_code=ConverterErrors.MODEL_LOADING_FAIL.value,
-                                               user_msg=msg)
+        super(ModelLoadingError, self).__init__(msg=msg)

    @classmethod
    def raise_from(cls):


+4 -4  mindinsight/mindconverter/graph_based_converter/constant.py

@@ -87,7 +87,7 @@ def get_imported_module():
        str, imported module.
    """
    return f"import numpy as np{NEW_LINE}" \
-        f"import mindspore{NEW_LINE}" \
-        f"from mindspore import nn{NEW_LINE}" \
-        f"from mindspore import Tensor{NEW_LINE}" \
-        f"from mindspore.ops import operations as P{NEW_LINE * 3}"
+           f"import mindspore{NEW_LINE}" \
+           f"from mindspore import nn{NEW_LINE}" \
+           f"from mindspore import Tensor{NEW_LINE}" \
+           f"from mindspore.ops import operations as P{NEW_LINE * 3}"

+31 -31  mindinsight/mindconverter/graph_based_converter/framework.py

@@ -27,8 +27,8 @@ from mindinsight.mindconverter.graph_based_converter.constant import BINARY_HEAD
    BINARY_HEADER_PYTORCH_BITS, ONNX_MIN_VER, TF2ONNX_MIN_VER, ONNXRUNTIME_MIN_VER
from mindinsight.mindconverter.graph_based_converter.mapper import ONNXToMindSporeMapper
from mindinsight.mindconverter.common.log import logger as log, logger_console as log_console
-from mindinsight.mindconverter.common.exceptions import GraphInitFail, TreeCreateFail, SourceFilesSaveFail, \
-    BaseConverterFail, UnknownModel, GeneratorFail, TfRuntimeError
+from mindinsight.mindconverter.common.exceptions import GraphInitError, TreeCreationError, SourceFilesSaveError, \
+    BaseConverterError, UnknownModelError, GeneratorError, TfRuntimeError, RuntimeIntegrityError
from mindinsight.utils.exceptions import ParamMissError


permissions = os.R_OK | os.W_OK | os.X_OK
@@ -67,13 +67,13 @@ def torch_installation_validation(func):
               output_folder: str, report_folder: str = None):
        # Check whether pytorch is installed.
        if not find_spec("torch"):
-            error = ModuleNotFoundError("PyTorch is required when using graph based "
-                                        "scripts converter, and PyTorch vision must "
-                                        "be consisted with model generation runtime.")
-            log.error(str(error))
-            detail_info = f"Error detail: {str(error)}"
+            error = RuntimeIntegrityError("PyTorch is required when using graph based "
+                                          "scripts converter, and PyTorch vision must "
+                                          "be consisted with model generation runtime.")
+            log.error(error)
+            log_console.error("\n")
            log_console.error(str(error))
-            log_console.error(detail_info)
+            log_console.error("\n")
            sys.exit(0)

        func(graph_path=graph_path, sample_shape=sample_shape,
@@ -97,17 +97,17 @@ def tf_installation_validation(func):
               output_folder: str, report_folder: str = None,
               input_nodes: str = None, output_nodes: str = None):
        # Check whether tensorflow is installed.
-        if not find_spec("tensorflow") or not find_spec("tf2onnx") or not find_spec("onnx") \
-                or not find_spec("onnxruntime"):
-            error = ModuleNotFoundError(
+        if not find_spec("tensorflow") or not find_spec("tensorflow-gpu") or not find_spec("tf2onnx") \
+                or not find_spec("onnx") or not find_spec("onnxruntime"):
+            error = RuntimeIntegrityError(
                f"TensorFlow, tf2onnx(>={TF2ONNX_MIN_VER}), onnx(>={ONNX_MIN_VER}) and "
                f"onnxruntime(>={ONNXRUNTIME_MIN_VER}) are required when using graph "
                f"based scripts converter for TensorFlow conversion."
            )
-            log.error(str(error))
-            detail_info = f"Error detail: {str(error)}"
+            log.error(error)
+            log_console.error("\n")
            log_console.error(str(error))
-            log_console.error(detail_info)
+            log_console.error("\n")
            sys.exit(0)

        onnx, tf2onnx = import_module("onnx"), import_module("tf2onnx")
@@ -116,15 +116,15 @@ def tf_installation_validation(func):
        if not lib_version_satisfied(getattr(onnx, "__version__"), ONNX_MIN_VER) \
                or not lib_version_satisfied(getattr(ort, "__version__"), ONNXRUNTIME_MIN_VER) \
                or not lib_version_satisfied(getattr(tf2onnx, "__version__"), TF2ONNX_MIN_VER):
-            error = ModuleNotFoundError(
+            error = RuntimeIntegrityError(
                f"TensorFlow, tf2onnx(>={TF2ONNX_MIN_VER}), onnx(>={ONNX_MIN_VER}) and "
                f"onnxruntime(>={ONNXRUNTIME_MIN_VER}) are required when using graph "
                f"based scripts converter for TensorFlow conversion."
            )
-            log.error(str(error))
-            detail_info = f"Error detail: {str(error)}"
+            log.error(error)
+            log_console.error("\n")
            log_console.error(str(error))
-            log_console.error(detail_info)
+            log_console.error("\n")
            sys.exit(0)

        func(graph_path=graph_path, sample_shape=sample_shape,
@@ -150,14 +150,14 @@ def _extract_model_name(model_path):




@torch_installation_validation
-@GraphInitFail.uniform_catcher("Error occurred when init graph object.")
-@TreeCreateFail.uniform_catcher("Error occurred when create hierarchical tree.")
-@SourceFilesSaveFail.uniform_catcher("Error occurred when save source files.")
-@GeneratorFail.uniform_catcher("Error occurred when generate code.")
+@GraphInitError.uniform_catcher()
+@TreeCreationError.uniform_catcher()
+@SourceFilesSaveError.uniform_catcher()
+@GeneratorError.uniform_catcher()
def graph_based_converter_pytorch_to_ms(graph_path: str, sample_shape: tuple,
                                        output_folder: str, report_folder: str = None):
    """
-    Pytoch to MindSpore based on Graph.
+    PyTorch to MindSpore based on Graph.

    Args:
        graph_path (str): Graph file path.
@@ -185,11 +185,11 @@ def graph_based_converter_pytorch_to_ms(graph_path: str, sample_shape: tuple,




@tf_installation_validation
-@GraphInitFail.uniform_catcher("Error occurred when init graph object.")
-@TfRuntimeError.uniform_catcher("Error occurred when init graph, TensorFlow runtime error.")
-@TreeCreateFail.uniform_catcher("Error occurred when create hierarchical tree.")
-@SourceFilesSaveFail.uniform_catcher("Error occurred when save source files.")
-@GeneratorFail.uniform_catcher("Error occurred when generate code.")
+@GraphInitError.uniform_catcher()
+@TfRuntimeError.uniform_catcher()
+@TreeCreationError.uniform_catcher()
+@SourceFilesSaveError.uniform_catcher()
+@GeneratorError.uniform_catcher()
def graph_based_converter_tf_to_ms(graph_path: str, sample_shape: tuple,
                                   input_nodes: str, output_nodes: str,
                                   output_folder: str, report_folder: str = None):
@@ -221,7 +221,7 @@ def graph_based_converter_tf_to_ms(graph_path: str, sample_shape: tuple,
    save_code_file_and_report(model_name, code_fragments, output_folder, report_folder)


-@BaseConverterFail.uniform_catcher("Failed to start base converter.")
+@BaseConverterError.uniform_catcher()
def main_graph_base_converter(file_config):
    """
    The entrance for converter, script files will be converted.
@@ -248,7 +248,7 @@ def main_graph_base_converter(file_config):
                                        report_folder=file_config['report_dir'])
    else:
        error_msg = "Get UNSUPPORTED model."
-        error = UnknownModel(error_msg)
+        error = UnknownModelError(error_msg)
        log.error(str(error))
        raise error


@@ -263,7 +263,7 @@ def get_framework_type(model_path):
        framework_type = FrameworkType.TENSORFLOW.value
    except IOError:
        error_msg = "Get UNSUPPORTED model."
-        error = UnknownModel(error_msg)
+        error = UnknownModelError(error_msg)
        log.error(str(error))
        raise error




+4 -4  mindinsight/mindconverter/graph_based_converter/generator/generator.py

@@ -23,7 +23,7 @@ from .node_struct import NodeStruct
from .module_struct import ModuleStruct
from .args_translator import ArgsTranslationHelper
from ..common.global_context import GlobalContext
-from ...common.exceptions import GeneratorFail
+from ...common.exceptions import GeneratorError
from ..hierarchical_tree.name_mgr import GlobalVarNameMgr
from ..constant import NEW_LINE, SECOND_LEVEL_INDENT, FIRST_LEVEL_INDENT, CodeFormatConfig, get_imported_module
from ..report_generator import ReportGenerator
@@ -171,7 +171,7 @@ class Generator:
        self._global_context.node_struct_collections = self._node_struct_collections
        self._repeated_submodules = set()

-    @GeneratorFail.check_except("Generator occurs an error when forming base submodules.")
+    @GeneratorError.check_except("Generator occurs an error when forming base submodules.")
    def _form_bottom_submodule(self):
        """Form the basic submodules, which only contains nodes."""
        # Form module map
@@ -351,7 +351,7 @@ class Generator:
            self._global_context.add_module_struct(sub.pattern_id, sub)
            depth -= 1

-    @GeneratorFail.check_except("Generator occurs an error when building modules.")
+    @GeneratorError.check_except("Generator occurs an error when building modules.")
    def _recursive_form_module(self):
        """Main routine in generator to build modules from bottom to top."""
        # 1. List repeated submodules
@@ -474,7 +474,7 @@ class Generator:
"""Return all ModuleStructs in this model.""" """Return all ModuleStructs in this model."""
return self._module_struct_collections return self._module_struct_collections


@GeneratorFail.check_except("Generator occurs an error when generating code statements.")
@GeneratorError.check_except("Generator occurs an error when generating code statements.")
def generate(self): def generate(self):
""" """
Generate the final script file. Generate the final script file.


+3 -3  mindinsight/mindconverter/graph_based_converter/generator/node_struct.py

@@ -22,7 +22,7 @@ from ..third_party_graph.pytorch_graph_node import PyTorchGraphNode
from ..third_party_graph.onnx_graph_node import OnnxGraphNode
from ..common.global_context import GlobalContext
from ..constant import InputType
-from ...common.exceptions import GeneratorFail
+from ...common.exceptions import GeneratorError


class NodeStruct:
@@ -146,7 +146,7 @@ class NodeStruct:
        parsed_scope = Scope.parse_scope_from_node_identifier(self.identifier)
        self.scope = Scope(parsed_scope)

-    @GeneratorFail.check_except("Generator occurs an error when initializing node's args translator.")
+    @GeneratorError.check_except("Generator occurs an error when initializing node's args translator.")
    def init_args_translator(self, translated_args: list):
        """
        Initialize the ArgsTranslator for each Node.
@@ -170,7 +170,7 @@ class NodeStruct:
                        self.ms_op]):
            self.ready_to_generate = True

-    @GeneratorFail.check_except("Generator occurs an error when creating node struct.")
+    @GeneratorError.check_except("Generator occurs an error when creating node struct.")
    def update(self, arg, force_ready=False):
        """
        Pass Node info. to generator NodeStruct.


+3 -3  mindinsight/mindconverter/graph_based_converter/hierarchical_tree/__init__.py

@@ -21,7 +21,7 @@ from mindinsight.mindconverter.common.log import logger as log
from .hierarchical_tree import HierarchicalTree
from ..third_party_graph.onnx_graph_node import OnnxGraphNode

-from ...common.exceptions import NodeInputMissing, TreeNodeInsertFail
+from ...common.exceptions import NodeInputMissingError, TreeNodeInsertError


def _tf_model_node_name_reformat(node: OnnxGraphNode, node_name):
@@ -53,7 +53,7 @@ class HierarchicalTreeFactory:
"""Hierarchical tree factory.""" """Hierarchical tree factory."""


@classmethod @classmethod
@TreeNodeInsertFail.check_except("Tree node inserts failed.")
@TreeNodeInsertError.check_except("Tree node inserts failed.")
def create(cls, graph): def create(cls, graph):
""" """
Factory method of hierarchical tree. Factory method of hierarchical tree.
@@ -73,7 +73,7 @@ class HierarchicalTreeFactory:
            if node_input != 0 and not node_input:
                err_msg = f"This model is not supported now. " \
                          f"Cannot find {node_name}'s input shape."
-                error = NodeInputMissing(err_msg)
+                error = NodeInputMissingError(err_msg)
                log.error(str(error))
                raise error
            if isinstance(node_inst, OnnxGraphNode):


+4 -4  mindinsight/mindconverter/graph_based_converter/hierarchical_tree/hierarchical_tree.py

@@ -35,7 +35,7 @@ from ..constant import SEPARATOR_BTW_NAME_AND_ID, FIRST_LEVEL_INDENT
from ..constant import NEW_LINE, SECOND_LEVEL_INDENT
from ..constant import NodeType
from ..report_generator import ReportGenerator
-from ...common.exceptions import ReportGenerateFail, ScriptGenerateFail, NodeInputTypeNotSupport
+from ...common.exceptions import ReportGenerationError, ScriptGenerationError, NodeInputTypeNotSupportError


class HierarchicalTree(Tree):
@@ -189,7 +189,7 @@ class HierarchicalTree(Tree):
        try:
            self._adjust_structure()
            code_fragments = self._generate_codes(mapper)
-        except (NodeInputTypeNotSupport, ScriptGenerateFail, ReportGenerateFail) as e:
+        except (NodeInputTypeNotSupportError, ScriptGenerationError, ReportGenerationError) as e:
            log.error("Error occur when generating codes.")
            raise e


@@ -264,8 +264,8 @@ class HierarchicalTree(Tree):
                node.data.args_in_code.pop(arg)
        return node

-    @ScriptGenerateFail.check_except("FormatCode run error. Check detailed information in log.")
-    @ReportGenerateFail.check_except("Not find valid operators in converted script.")
+    @ScriptGenerationError.check_except("FormatCode run error. Check detailed information in log.")
+    @ReportGenerationError.check_except("Not find valid operators in converted script.")
    def _generate_codes(self, mapper):
        """
        Generate code files.


+2 -2  mindinsight/mindconverter/graph_based_converter/sub_graph_searcher/searcher.py

@@ -21,7 +21,7 @@ from .common import MINI_FREQUENCY, MAX_ITERATION_DEPTH, SATISFIED_SCORE
from ..common.global_context import GlobalContext
from ..third_party_graph.onnx_utils import BaseNode
from .search_path import SearchPath, Pattern, generate_pattern, find_built_in_pattern
-from ...common.exceptions import SubGraphSearchingFail
+from ...common.exceptions import SubGraphSearchingError


def _is_satisfied(path):
@@ -267,7 +267,7 @@ def _add_known_module_name(search_path):
        ctx.known_module_name[it.pattern.module_name] = it.pattern.known_module_name


-@SubGraphSearchingFail.check_except("Sub-Graph searching fail.")
+@SubGraphSearchingError.check_except("Sub-Graph pattern searching fail.")
def generate_scope_name(data_loader):
    """
    Generate scope name according to computation graph.


+2 -2  mindinsight/mindconverter/graph_based_converter/third_party_graph/base.py

@@ -21,7 +21,7 @@ from mindinsight.mindconverter.common.log import logger as log
from ..common.code_fragment import CodeFragment
from ..constant import NodeType, InputType
from ..mapper.base import Mapper
-from ...common.exceptions import NodeInputTypeNotSupport
+from ...common.exceptions import NodeInputTypeNotSupportError


class GraphParser(metaclass=abc.ABCMeta):
@@ -522,7 +522,7 @@ class GraphNode(abc.ABC):
            elif input_type == InputType.LIST.value:
                ipt_args_settings_in_construct = f"({ipt_args_in_construct},)"
            else:
-                raise NodeInputTypeNotSupport(f"Input type[{input_type}] is not supported now.")
+                raise NodeInputTypeNotSupportError(f"Input type[{input_type}] is not supported now.")
        else:
            ipt_args_settings_in_construct = ipt_args_in_construct




+4 -4  mindinsight/mindconverter/graph_based_converter/third_party_graph/onnx_utils.py

@@ -27,7 +27,7 @@ from ..common.global_context import GlobalContext


from ..constant import ONNX_TYPE_INT, ONNX_TYPE_INTS, ONNX_TYPE_STRING, \
    ONNX_TYPE_FLOATS, ONNX_TYPE_FLOAT, SCALAR_WITHOUT_SHAPE, DYNAMIC_SHAPE, UNKNOWN_DIM_VAL
-from ...common.exceptions import GraphInitFail, ModelNotSupport, ModelLoadingFail
+from ...common.exceptions import GraphInitError, ModelNotSupportError, ModelLoadingError


def convert_tf_graph_to_onnx(model_path, model_inputs, model_outputs, opset=12):
@@ -308,7 +308,7 @@ class OnnxDataLoader:
            w = int(match.group('w'))
            c = int(match.group('c'))
            if [h, w, c] != list(self.graph_input_shape)[1:4]:
-                raise ModelLoadingFail(f"Shape given should be (N, {h}, {w}, {c}) but got {self.graph_input_shape}")
+                raise ModelLoadingError(f"Shape given should be (N, {h}, {w}, {c}) but got {self.graph_input_shape}")
            return True
        return False


@@ -383,7 +383,7 @@ class OnnxDataLoader:
                self._nodes_dict[n.name] = n
                nodes_topo_idx.append((idx, n.name))
                if len(node.output) > 1:
-                    raise ModelNotSupport(msg=f"{node.name} has multi-outputs which is not supported now.")
+                    raise ModelNotSupportError(msg=f"{node.name} has multi-outputs which is not supported now.")
                self.output_name_to_node_name[node.output[0]] = node.name
                self._global_context.onnx_node_name_to_topo_idx[n.name] = idx
                node_inputs = [i.replace(":0", "") for i in node.input]
@@ -423,7 +423,7 @@ class OnnxDataLoader:
                    shape[i] = int(shape[i])
            node_name = self.output_name_to_node_name[node_opt_name]
            if not node_name:
-                raise GraphInitFail(user_msg=f"Cannot find where edge {node_opt_name} comes from.")
+                raise GraphInitError(msg=f"Cannot find where edge {node_opt_name} comes from.")
            self.node_output_shape_dict[node_name] = shape

    def get_node(self, node_name):


+2 -2  mindinsight/mindconverter/graph_based_converter/third_party_graph/pytorch_graph_parser.py

@@ -18,14 +18,14 @@ from importlib import import_module


from mindinsight.mindconverter.common.log import logger as log
from .base import GraphParser
-from ...common.exceptions import ModelNotSupport
+from ...common.exceptions import ModelNotSupportError


class PyTorchGraphParser(GraphParser):
    """Define pytorch graph parser."""

    @classmethod
-    @ModelNotSupport.check_except(
+    @ModelNotSupportError.check_except(
        "Error occurs in loading model, please check your model or runtime environment integrity."
    )
    def parse(cls, model_path: str, **kwargs):


+2 -2  mindinsight/mindconverter/graph_based_converter/third_party_graph/tf_graph_parser.py

@@ -18,14 +18,14 @@ from importlib import import_module


from mindinsight.mindconverter.common.log import logger as log
from .base import GraphParser
-from ...common.exceptions import ModelNotSupport
+from ...common.exceptions import ModelNotSupportError


class TFGraphParser(GraphParser):
    """Define TF graph parser."""

    @classmethod
-    @ModelNotSupport.check_except(
+    @ModelNotSupportError.check_except(
        "Error occurs in loading model, please check your model or runtime environment integrity."
    )
    def parse(cls, model_path: str, **kwargs):

