Browse Source

!1013 Edit error code and refactor exception module

From: @liuchongming74
Reviewed-by: @lilongfei15,@wenkai_dist,@lilongfei15
Signed-off-by: @lilongfei15
tags/v1.1.0
mindspore-ci-bot Gitee 5 years ago
parent
commit
2ec4f1eb03
12 changed files with 221 additions and 156 deletions
  1. +160
    -95
      mindinsight/mindconverter/common/exceptions.py
  2. +4
    -4
      mindinsight/mindconverter/graph_based_converter/constant.py
  3. +31
    -31
      mindinsight/mindconverter/graph_based_converter/framework.py
  4. +4
    -4
      mindinsight/mindconverter/graph_based_converter/generator/generator.py
  5. +3
    -3
      mindinsight/mindconverter/graph_based_converter/generator/node_struct.py
  6. +3
    -3
      mindinsight/mindconverter/graph_based_converter/hierarchical_tree/__init__.py
  7. +4
    -4
      mindinsight/mindconverter/graph_based_converter/hierarchical_tree/hierarchical_tree.py
  8. +2
    -2
      mindinsight/mindconverter/graph_based_converter/sub_graph_searcher/searcher.py
  9. +2
    -2
      mindinsight/mindconverter/graph_based_converter/third_party_graph/base.py
  10. +4
    -4
      mindinsight/mindconverter/graph_based_converter/third_party_graph/onnx_utils.py
  11. +2
    -2
      mindinsight/mindconverter/graph_based_converter/third_party_graph/pytorch_graph_parser.py
  12. +2
    -2
      mindinsight/mindconverter/graph_based_converter/third_party_graph/tf_graph_parser.py

+ 160
- 95
mindinsight/mindconverter/common/exceptions.py View File

@@ -15,7 +15,7 @@
"""Define custom exception."""
import abc
import sys
from enum import unique
from enum import unique, Enum
from importlib import import_module

from lib2to3.pgen2 import parse
@@ -33,16 +33,6 @@ class ConverterErrors(ScriptConverterErrors):
SCRIPT_NOT_SUPPORT = 1
NODE_TYPE_NOT_SUPPORT = 2
CODE_SYNTAX_ERROR = 3
NODE_INPUT_TYPE_NOT_SUPPORT = 4
NODE_INPUT_MISSING = 5
TREE_NODE_INSERT_FAIL = 6
UNKNOWN_MODEL = 7
MODEL_NOT_SUPPORT = 8
SCRIPT_GENERATE_FAIL = 9
REPORT_GENERATE_FAIL = 10
NODE_CONVERSION_ERROR = 11
INPUT_SHAPE_ERROR = 12
TF_RUNTIME_ERROR = 13

BASE_CONVERTER_FAIL = 000
GRAPH_INIT_FAIL = 100
@@ -50,7 +40,6 @@ class ConverterErrors(ScriptConverterErrors):
SOURCE_FILES_SAVE_FAIL = 300
GENERATOR_FAIL = 400
SUB_GRAPH_SEARCHING_FAIL = 500
MODEL_LOADING_FAIL = 600


class ScriptNotSupport(MindInsightException):
@@ -82,21 +71,20 @@ class CodeSyntaxError(MindInsightException):

class MindConverterException(Exception):
"""MindConverter exception."""
BASE_ERROR_CODE = None # ConverterErrors.BASE_CONVERTER_FAIL.value
# ERROR_CODE should be declared in child exception.
ERROR_CODE = None

def __init__(self, **kwargs):
"""Initialization of MindInsightException."""
error = kwargs.get('error', None)
user_msg = kwargs.get('user_msg', '')
debug_msg = kwargs.get('debug_msg', '')
cls_code = kwargs.get('cls_code', 0)

if isinstance(user_msg, str):
user_msg = ' '.join(user_msg.split())

super(MindConverterException, self).__init__()
self.error = error
self.user_msg = user_msg
self.debug_msg = debug_msg
self.cls_code = cls_code
self.root_exception_error_code = None

def __str__(self):
return '[{}] code: {}, msg: {}'.format(self.__class__.__name__, self.error_code(), self.user_msg)
@@ -116,8 +104,12 @@ class MindConverterException(Exception):
Returns:
str, Hex string representing the composed MindConverter error code.
"""
num = 0xFFFF & self.error.value
error_code = ''.join((f'{self.cls_code}'.zfill(3), hex(num)[2:].zfill(4).upper()))
if self.root_exception_error_code:
return self.root_exception_error_code
if self.BASE_ERROR_CODE is None or self.ERROR_CODE is None:
raise ValueError("MindConverterException has not been initialized.")
num = 0xFFFF & self.ERROR_CODE # 0xFFFF & self.error.value
error_code = f"{str(self.BASE_ERROR_CODE).zfill(3)}{hex(num)[2:].zfill(4).upper()}"
return error_code

@classmethod
@@ -126,7 +118,7 @@ class MindConverterException(Exception):
"""Raise from below exceptions."""

@classmethod
def uniform_catcher(cls, msg):
def uniform_catcher(cls, msg: str = ""):
"""Uniform exception catcher."""

def decorator(func):
@@ -134,16 +126,20 @@ class MindConverterException(Exception):
try:
res = func(*args, **kwargs)
except cls.raise_from() as e:
error = cls(msg=msg)
detail_info = f"Error detail: {str(e)}"
log_console.error(str(error))
error = cls() if not msg else cls(msg=msg)
detail_info = str(e)
log.error(error)
log_console.error("\n")
log_console.error(detail_info)
log_console.error("\n")
log.exception(e)
sys.exit(0)
except ModuleNotFoundError as e:
detail_info = f"Error detail: Required package not found, please check the runtime environment."
detail_info = "Error detail: Required package not found, please check the runtime environment."
log_console.error("\n")
log_console.error(str(e))
log_console.error(detail_info)
log_console.error("\n")
log.exception(e)
sys.exit(0)
return res
@@ -161,9 +157,12 @@ class MindConverterException(Exception):
try:
output = func(*args, **kwargs)
except cls.raise_from() as e:
error = cls(msg=msg)
error_code = e.error_code() if isinstance(e, MindConverterException) else None
error.root_exception_error_code = error_code
log.error(msg)
log.exception(e)
raise cls(msg=msg)
raise error
except Exception as e:
log.error(msg)
log.exception(e)
@@ -175,12 +174,21 @@ class MindConverterException(Exception):
return decorator


class BaseConverterFail(MindConverterException):
class BaseConverterError(MindConverterException):
"""Base converter failed."""

def __init__(self, msg):
super(BaseConverterFail, self).__init__(error=ConverterErrors.BASE_CONVERTER_FAIL,
user_msg=msg)
@unique
class ErrCode(Enum):
"""Define error code of BaseConverterError."""
UNKNOWN_ERROR = 0
UNKNOWN_MODEL = 1

BASE_ERROR_CODE = ConverterErrors.BASE_CONVERTER_FAIL.value
ERROR_CODE = ErrCode.UNKNOWN_ERROR.value
DEFAULT_MSG = "Failed to start base converter."

def __init__(self, msg=DEFAULT_MSG):
super(BaseConverterError, self).__init__(user_msg=msg)

@classmethod
def raise_from(cls):
@@ -189,32 +197,45 @@ class BaseConverterFail(MindConverterException):
return except_source


class UnknownModel(MindConverterException):
class UnknownModelError(BaseConverterError):
"""The unknown model error."""
ERROR_CODE = BaseConverterError.ErrCode.UNKNOWN_MODEL.value

def __init__(self, msg):
super(UnknownModel, self).__init__(error=ConverterErrors.UNKNOWN_MODEL,
user_msg=msg)
super(UnknownModelError, self).__init__(msg=msg)

@classmethod
def raise_from(cls):
return cls


class GraphInitFail(MindConverterException):
class GraphInitError(MindConverterException):
"""The graph init fail error."""

def __init__(self, **kwargs):
super(GraphInitFail, self).__init__(error=ConverterErrors.GRAPH_INIT_FAIL,
user_msg=kwargs.get('msg', ''))
@unique
class ErrCode(Enum):
"""Define error code of GraphInitError."""
UNKNOWN_ERROR = 0
MODEL_NOT_SUPPORT = 1
TF_RUNTIME_ERROR = 2
INPUT_SHAPE_ERROR = 3
MI_RUNTIME_ERROR = 4

BASE_ERROR_CODE = ConverterErrors.GRAPH_INIT_FAIL.value
ERROR_CODE = ErrCode.UNKNOWN_ERROR.value
DEFAULT_MSG = "Error occurred when init graph object."

def __init__(self, msg=DEFAULT_MSG):
super(GraphInitError, self).__init__(user_msg=msg)

@classmethod
def raise_from(cls):
"""Raise from exceptions below."""
except_source = (FileNotFoundError,
ModuleNotFoundError,
ModelNotSupport,
SubGraphSearchingFail,
ModelNotSupportError,
ModelLoadingError,
RuntimeIntegrityError,
TypeError,
ZeroDivisionError,
RuntimeError,
@@ -222,45 +243,65 @@ class GraphInitFail(MindConverterException):
return except_source


class TreeCreateFail(MindConverterException):
class TreeCreationError(MindConverterException):
"""The tree create fail."""

def __init__(self, msg):
super(TreeCreateFail, self).__init__(error=ConverterErrors.TREE_CREATE_FAIL,
user_msg=msg)
@unique
class ErrCode(Enum):
"""Define error code of TreeCreationError."""
UNKNOWN_ERROR = 0
NODE_INPUT_MISSING = 1
TREE_NODE_INSERT_FAIL = 2

BASE_ERROR_CODE = ConverterErrors.TREE_CREATE_FAIL.value
ERROR_CODE = ErrCode.UNKNOWN_ERROR.value
DEFAULT_MSG = "Error occurred when create hierarchical tree."

def __init__(self, msg=DEFAULT_MSG):
super(TreeCreationError, self).__init__(user_msg=msg)

@classmethod
def raise_from(cls):
"""Raise from exceptions below."""
except_source = (NodeInputMissing,
TreeNodeInsertFail, cls)
except_source = NodeInputMissingError, TreeNodeInsertError, cls
return except_source


class SourceFilesSaveFail(MindConverterException):
class SourceFilesSaveError(MindConverterException):
"""The source files save fail error."""

def __init__(self, msg):
super(SourceFilesSaveFail, self).__init__(error=ConverterErrors.SOURCE_FILES_SAVE_FAIL,
user_msg=msg)
@unique
class ErrCode(Enum):
"""Define error code of SourceFilesSaveError."""
UNKNOWN_ERROR = 0
NODE_INPUT_TYPE_NOT_SUPPORT = 1
SCRIPT_GENERATE_FAIL = 2
REPORT_GENERATE_FAIL = 3

BASE_ERROR_CODE = ConverterErrors.SOURCE_FILES_SAVE_FAIL.value
ERROR_CODE = ErrCode.UNKNOWN_ERROR.value
DEFAULT_MSG = "Error occurred when save source files."

def __init__(self, msg=DEFAULT_MSG):
super(SourceFilesSaveError, self).__init__(user_msg=msg)

@classmethod
def raise_from(cls):
"""Raise from exceptions below."""
except_source = (NodeInputTypeNotSupport,
ScriptGenerateFail,
ReportGenerateFail,
except_source = (NodeInputTypeNotSupportError,
ScriptGenerationError,
ReportGenerationError,
IOError, cls)
return except_source


class ModelNotSupport(MindConverterException):
class ModelNotSupportError(GraphInitError):
"""The model not support error."""

ERROR_CODE = GraphInitError.ErrCode.MODEL_NOT_SUPPORT.value

def __init__(self, msg):
super(ModelNotSupport, self).__init__(error=ConverterErrors.MODEL_NOT_SUPPORT,
user_msg=msg,
cls_code=ConverterErrors.GRAPH_INIT_FAIL.value)
super(ModelNotSupportError, self).__init__(msg=msg)

@classmethod
def raise_from(cls):
@@ -275,13 +316,13 @@ class ModelNotSupport(MindConverterException):
return except_source


class TfRuntimeError(MindConverterException):
class TfRuntimeError(GraphInitError):
"""Catch tf runtime error."""
ERROR_CODE = GraphInitError.ErrCode.TF_RUNTIME_ERROR.value
DEFAULT_MSG = "Error occurred when init graph, TensorFlow runtime error."

def __init__(self, msg):
super(TfRuntimeError, self).__init__(error=ConverterErrors.TF_RUNTIME_ERROR,
user_msg=msg,
cls_code=ConverterErrors.GRAPH_INIT_FAIL.value)
def __init__(self, msg=DEFAULT_MSG):
super(TfRuntimeError, self).__init__(msg=msg)

@classmethod
def raise_from(cls):
@@ -290,26 +331,36 @@ class TfRuntimeError(MindConverterException):
return tf_error, ValueError, RuntimeError, cls


class NodeInputMissing(MindConverterException):
class RuntimeIntegrityError(GraphInitError):
"""Catch runtime error."""
ERROR_CODE = GraphInitError.ErrCode.MI_RUNTIME_ERROR.value

def __init__(self, msg):
super(RuntimeIntegrityError, self).__init__(msg=msg)

@classmethod
def raise_from(cls):
return RuntimeError, AttributeError, ImportError, ModuleNotFoundError, cls


class NodeInputMissingError(TreeCreationError):
"""The node input missing error."""
ERROR_CODE = TreeCreationError.ErrCode.NODE_INPUT_MISSING.value

def __init__(self, msg):
super(NodeInputMissing, self).__init__(error=ConverterErrors.NODE_INPUT_MISSING,
user_msg=msg,
cls_code=ConverterErrors.TREE_CREATE_FAIL.value)
super(NodeInputMissingError, self).__init__(msg=msg)

@classmethod
def raise_from(cls):
return ValueError, IndexError, KeyError, AttributeError, cls


class TreeNodeInsertFail(MindConverterException):
class TreeNodeInsertError(TreeCreationError):
"""The tree node create fail error."""
ERROR_CODE = TreeCreationError.ErrCode.TREE_NODE_INSERT_FAIL.value

def __init__(self, msg):
super(TreeNodeInsertFail, self).__init__(error=ConverterErrors.TREE_NODE_INSERT_FAIL,
user_msg=msg,
cls_code=ConverterErrors.TREE_CREATE_FAIL.value)
super(TreeNodeInsertError, self).__init__(msg=msg)

@classmethod
def raise_from(cls):
@@ -321,26 +372,24 @@ class TreeNodeInsertFail(MindConverterException):
return except_source


class NodeInputTypeNotSupport(MindConverterException):
class NodeInputTypeNotSupportError(SourceFilesSaveError):
"""The node input type NOT support error."""
ERROR_CODE = SourceFilesSaveError.ErrCode.NODE_INPUT_TYPE_NOT_SUPPORT.value

def __init__(self, msg):
super(NodeInputTypeNotSupport, self).__init__(error=ConverterErrors.NODE_INPUT_TYPE_NOT_SUPPORT,
user_msg=msg,
cls_code=ConverterErrors.SOURCE_FILES_SAVE_FAIL.value)
super(NodeInputTypeNotSupportError, self).__init__(msg=msg)

@classmethod
def raise_from(cls):
return ValueError, TypeError, IndexError, cls


class ScriptGenerateFail(MindConverterException):
class ScriptGenerationError(SourceFilesSaveError):
"""The script generate fail error."""
ERROR_CODE = SourceFilesSaveError.ErrCode.SCRIPT_GENERATE_FAIL.value

def __init__(self, msg):
super(ScriptGenerateFail, self).__init__(error=ConverterErrors.SCRIPT_GENERATE_FAIL,
user_msg=msg,
cls_code=ConverterErrors.SOURCE_FILES_SAVE_FAIL.value)
super(ScriptGenerationError, self).__init__(msg=msg)

@classmethod
def raise_from(cls):
@@ -351,13 +400,12 @@ class ScriptGenerateFail(MindConverterException):
return except_source


class ReportGenerateFail(MindConverterException):
class ReportGenerationError(SourceFilesSaveError):
"""The report generate fail error."""
ERROR_CODE = SourceFilesSaveError.ErrCode.REPORT_GENERATE_FAIL.value

def __init__(self, msg):
super(ReportGenerateFail, self).__init__(error=ConverterErrors.REPORT_GENERATE_FAIL,
user_msg=msg,
cls_code=ConverterErrors.SOURCE_FILES_SAVE_FAIL.value)
super(ReportGenerationError, self).__init__(msg=msg)

@classmethod
def raise_from(cls):
@@ -365,13 +413,22 @@ class ReportGenerateFail(MindConverterException):
return ZeroDivisionError, cls


class SubGraphSearchingFail(MindConverterException):
class SubGraphSearchingError(MindConverterException):
"""Sub-graph searching exception."""

def __init__(self, msg):
super(SubGraphSearchingFail, self).__init__(error=ConverterErrors.MODEL_NOT_SUPPORT,
cls_code=ConverterErrors.SUB_GRAPH_SEARCHING_FAIL.value,
user_msg=msg)
@unique
class ErrCode(Enum):
"""Define error code of SourceFilesSaveError."""
BASE_ERROR = 0
CANNOT_FIND_VALID_PATTERN = 1
MODEL_NOT_SUPPORT = 2

BASE_ERROR_CODE = ConverterErrors.SUB_GRAPH_SEARCHING_FAIL.value
ERROR_CODE = ErrCode.BASE_ERROR.value
DEFAULT_MSG = "Sub-Graph pattern searching fail."

def __init__(self, msg=DEFAULT_MSG):
super(SubGraphSearchingError, self).__init__(user_msg=msg)

@classmethod
def raise_from(cls):
@@ -379,13 +436,22 @@ class SubGraphSearchingFail(MindConverterException):
return IndexError, KeyError, ValueError, AttributeError, ZeroDivisionError, cls


class GeneratorFail(MindConverterException):
class GeneratorError(MindConverterException):
"""The Generator fail error."""

def __init__(self, msg):
super(GeneratorFail, self).__init__(error=ConverterErrors.NODE_CONVERSION_ERROR,
user_msg=msg,
cls_code=ConverterErrors.GENERATOR_FAIL.value)
@unique
class ErrCode(Enum):
"""Define error code of SourceFilesSaveError."""
BASE_ERROR = 0
STATEMENT_GENERATION_ERROR = 1
CONVERTED_OPERATOR_LOADING_ERROR = 2

BASE_ERROR_CODE = ConverterErrors.GENERATOR_FAIL.value
ERROR_CODE = ErrCode.BASE_ERROR.value
DEFAULT_MSG = "Error occurred when generate code."

def __init__(self, msg=DEFAULT_MSG):
super(GeneratorError, self).__init__(user_msg=msg)

@classmethod
def raise_from(cls):
@@ -394,13 +460,12 @@ class GeneratorFail(MindConverterException):
return except_source


class ModelLoadingFail(MindConverterException):
class ModelLoadingError(GraphInitError):
"""Model loading fail."""
ERROR_CODE = GraphInitError.ErrCode.INPUT_SHAPE_ERROR.value

def __init__(self, msg):
super(ModelLoadingFail, self).__init__(error=ConverterErrors.INPUT_SHAPE_ERROR,
cls_code=ConverterErrors.MODEL_LOADING_FAIL.value,
user_msg=msg)
super(ModelLoadingError, self).__init__(msg=msg)

@classmethod
def raise_from(cls):


+ 4
- 4
mindinsight/mindconverter/graph_based_converter/constant.py View File

@@ -87,7 +87,7 @@ def get_imported_module():
str, imported module.
"""
return f"import numpy as np{NEW_LINE}" \
f"import mindspore{NEW_LINE}" \
f"from mindspore import nn{NEW_LINE}" \
f"from mindspore import Tensor{NEW_LINE}" \
f"from mindspore.ops import operations as P{NEW_LINE * 3}"
f"import mindspore{NEW_LINE}" \
f"from mindspore import nn{NEW_LINE}" \
f"from mindspore import Tensor{NEW_LINE}" \
f"from mindspore.ops import operations as P{NEW_LINE * 3}"

+ 31
- 31
mindinsight/mindconverter/graph_based_converter/framework.py View File

@@ -27,8 +27,8 @@ from mindinsight.mindconverter.graph_based_converter.constant import BINARY_HEAD
BINARY_HEADER_PYTORCH_BITS, ONNX_MIN_VER, TF2ONNX_MIN_VER, ONNXRUNTIME_MIN_VER
from mindinsight.mindconverter.graph_based_converter.mapper import ONNXToMindSporeMapper
from mindinsight.mindconverter.common.log import logger as log, logger_console as log_console
from mindinsight.mindconverter.common.exceptions import GraphInitFail, TreeCreateFail, SourceFilesSaveFail, \
BaseConverterFail, UnknownModel, GeneratorFail, TfRuntimeError
from mindinsight.mindconverter.common.exceptions import GraphInitError, TreeCreationError, SourceFilesSaveError, \
BaseConverterError, UnknownModelError, GeneratorError, TfRuntimeError, RuntimeIntegrityError
from mindinsight.utils.exceptions import ParamMissError

permissions = os.R_OK | os.W_OK | os.X_OK
@@ -67,13 +67,13 @@ def torch_installation_validation(func):
output_folder: str, report_folder: str = None):
# Check whether pytorch is installed.
if not find_spec("torch"):
error = ModuleNotFoundError("PyTorch is required when using graph based "
"scripts converter, and PyTorch vision must "
"be consisted with model generation runtime.")
log.error(str(error))
detail_info = f"Error detail: {str(error)}"
error = RuntimeIntegrityError("PyTorch is required when using graph based "
"scripts converter, and PyTorch vision must "
"be consisted with model generation runtime.")
log.error(error)
log_console.error("\n")
log_console.error(str(error))
log_console.error(detail_info)
log_console.error("\n")
sys.exit(0)

func(graph_path=graph_path, sample_shape=sample_shape,
@@ -97,17 +97,17 @@ def tf_installation_validation(func):
output_folder: str, report_folder: str = None,
input_nodes: str = None, output_nodes: str = None):
# Check whether tensorflow is installed.
if not find_spec("tensorflow") or not find_spec("tf2onnx") or not find_spec("onnx") \
or not find_spec("onnxruntime"):
error = ModuleNotFoundError(
if not find_spec("tensorflow") or not find_spec("tensorflow-gpu") or not find_spec("tf2onnx") \
or not find_spec("onnx") or not find_spec("onnxruntime"):
error = RuntimeIntegrityError(
f"TensorFlow, tf2onnx(>={TF2ONNX_MIN_VER}), onnx(>={ONNX_MIN_VER}) and "
f"onnxruntime(>={ONNXRUNTIME_MIN_VER}) are required when using graph "
f"based scripts converter for TensorFlow conversion."
)
log.error(str(error))
detail_info = f"Error detail: {str(error)}"
log.error(error)
log_console.error("\n")
log_console.error(str(error))
log_console.error(detail_info)
log_console.error("\n")
sys.exit(0)

onnx, tf2onnx = import_module("onnx"), import_module("tf2onnx")
@@ -116,15 +116,15 @@ def tf_installation_validation(func):
if not lib_version_satisfied(getattr(onnx, "__version__"), ONNX_MIN_VER) \
or not lib_version_satisfied(getattr(ort, "__version__"), ONNXRUNTIME_MIN_VER) \
or not lib_version_satisfied(getattr(tf2onnx, "__version__"), TF2ONNX_MIN_VER):
error = ModuleNotFoundError(
error = RuntimeIntegrityError(
f"TensorFlow, tf2onnx(>={TF2ONNX_MIN_VER}), onnx(>={ONNX_MIN_VER}) and "
f"onnxruntime(>={ONNXRUNTIME_MIN_VER}) are required when using graph "
f"based scripts converter for TensorFlow conversion."
)
log.error(str(error))
detail_info = f"Error detail: {str(error)}"
log.error(error)
log_console.error("\n")
log_console.error(str(error))
log_console.error(detail_info)
log_console.error("\n")
sys.exit(0)

func(graph_path=graph_path, sample_shape=sample_shape,
@@ -150,14 +150,14 @@ def _extract_model_name(model_path):


@torch_installation_validation
@GraphInitFail.uniform_catcher("Error occurred when init graph object.")
@TreeCreateFail.uniform_catcher("Error occurred when create hierarchical tree.")
@SourceFilesSaveFail.uniform_catcher("Error occurred when save source files.")
@GeneratorFail.uniform_catcher("Error occurred when generate code.")
@GraphInitError.uniform_catcher()
@TreeCreationError.uniform_catcher()
@SourceFilesSaveError.uniform_catcher()
@GeneratorError.uniform_catcher()
def graph_based_converter_pytorch_to_ms(graph_path: str, sample_shape: tuple,
output_folder: str, report_folder: str = None):
"""
Pytoch to MindSpore based on Graph.
PyTorch to MindSpore based on Graph.

Args:
graph_path (str): Graph file path.
@@ -185,11 +185,11 @@ def graph_based_converter_pytorch_to_ms(graph_path: str, sample_shape: tuple,


@tf_installation_validation
@GraphInitFail.uniform_catcher("Error occurred when init graph object.")
@TfRuntimeError.uniform_catcher("Error occurred when init graph, TensorFlow runtime error.")
@TreeCreateFail.uniform_catcher("Error occurred when create hierarchical tree.")
@SourceFilesSaveFail.uniform_catcher("Error occurred when save source files.")
@GeneratorFail.uniform_catcher("Error occurred when generate code.")
@GraphInitError.uniform_catcher()
@TfRuntimeError.uniform_catcher()
@TreeCreationError.uniform_catcher()
@SourceFilesSaveError.uniform_catcher()
@GeneratorError.uniform_catcher()
def graph_based_converter_tf_to_ms(graph_path: str, sample_shape: tuple,
input_nodes: str, output_nodes: str,
output_folder: str, report_folder: str = None):
@@ -221,7 +221,7 @@ def graph_based_converter_tf_to_ms(graph_path: str, sample_shape: tuple,
save_code_file_and_report(model_name, code_fragments, output_folder, report_folder)


@BaseConverterFail.uniform_catcher("Failed to start base converter.")
@BaseConverterError.uniform_catcher()
def main_graph_base_converter(file_config):
"""
The entrance for converter, script files will be converted.
@@ -248,7 +248,7 @@ def main_graph_base_converter(file_config):
report_folder=file_config['report_dir'])
else:
error_msg = "Get UNSUPPORTED model."
error = UnknownModel(error_msg)
error = UnknownModelError(error_msg)
log.error(str(error))
raise error

@@ -263,7 +263,7 @@ def get_framework_type(model_path):
framework_type = FrameworkType.TENSORFLOW.value
except IOError:
error_msg = "Get UNSUPPORTED model."
error = UnknownModel(error_msg)
error = UnknownModelError(error_msg)
log.error(str(error))
raise error



+ 4
- 4
mindinsight/mindconverter/graph_based_converter/generator/generator.py View File

@@ -23,7 +23,7 @@ from .node_struct import NodeStruct
from .module_struct import ModuleStruct
from .args_translator import ArgsTranslationHelper
from ..common.global_context import GlobalContext
from ...common.exceptions import GeneratorFail
from ...common.exceptions import GeneratorError
from ..hierarchical_tree.name_mgr import GlobalVarNameMgr
from ..constant import NEW_LINE, SECOND_LEVEL_INDENT, FIRST_LEVEL_INDENT, CodeFormatConfig, get_imported_module
from ..report_generator import ReportGenerator
@@ -171,7 +171,7 @@ class Generator:
self._global_context.node_struct_collections = self._node_struct_collections
self._repeated_submodules = set()

@GeneratorFail.check_except("Generator occurs an error when forming base submodules.")
@GeneratorError.check_except("Generator occurs an error when forming base submodules.")
def _form_bottom_submodule(self):
"""Form the basic submodules, which only contains nodes."""
# Form module map
@@ -351,7 +351,7 @@ class Generator:
self._global_context.add_module_struct(sub.pattern_id, sub)
depth -= 1

@GeneratorFail.check_except("Generator occurs an error when building modules.")
@GeneratorError.check_except("Generator occurs an error when building modules.")
def _recursive_form_module(self):
"""Main routine in generator to build modules from bottom to top."""
# 1. List repeated submodules
@@ -474,7 +474,7 @@ class Generator:
"""Return all ModuleStructs in this model."""
return self._module_struct_collections

@GeneratorFail.check_except("Generator occurs an error when generating code statements.")
@GeneratorError.check_except("Generator occurs an error when generating code statements.")
def generate(self):
"""
Generate the final script file.


+ 3
- 3
mindinsight/mindconverter/graph_based_converter/generator/node_struct.py View File

@@ -22,7 +22,7 @@ from ..third_party_graph.pytorch_graph_node import PyTorchGraphNode
from ..third_party_graph.onnx_graph_node import OnnxGraphNode
from ..common.global_context import GlobalContext
from ..constant import InputType
from ...common.exceptions import GeneratorFail
from ...common.exceptions import GeneratorError


class NodeStruct:
@@ -146,7 +146,7 @@ class NodeStruct:
parsed_scope = Scope.parse_scope_from_node_identifier(self.identifier)
self.scope = Scope(parsed_scope)

@GeneratorFail.check_except("Generator occurs an error when initializing node's args translator.")
@GeneratorError.check_except("Generator occurs an error when initializing node's args translator.")
def init_args_translator(self, translated_args: list):
"""
Initialize the ArgsTranslator for each Node.
@@ -170,7 +170,7 @@ class NodeStruct:
self.ms_op]):
self.ready_to_generate = True

@GeneratorFail.check_except("Generator occurs an error when creating node struct.")
@GeneratorError.check_except("Generator occurs an error when creating node struct.")
def update(self, arg, force_ready=False):
"""
Pass Node info. to generator NodeStruct.


+ 3
- 3
mindinsight/mindconverter/graph_based_converter/hierarchical_tree/__init__.py View File

@@ -21,7 +21,7 @@ from mindinsight.mindconverter.common.log import logger as log
from .hierarchical_tree import HierarchicalTree
from ..third_party_graph.onnx_graph_node import OnnxGraphNode

from ...common.exceptions import NodeInputMissing, TreeNodeInsertFail
from ...common.exceptions import NodeInputMissingError, TreeNodeInsertError


def _tf_model_node_name_reformat(node: OnnxGraphNode, node_name):
@@ -53,7 +53,7 @@ class HierarchicalTreeFactory:
"""Hierarchical tree factory."""

@classmethod
@TreeNodeInsertFail.check_except("Tree node inserts failed.")
@TreeNodeInsertError.check_except("Tree node inserts failed.")
def create(cls, graph):
"""
Factory method of hierarchical tree.
@@ -73,7 +73,7 @@ class HierarchicalTreeFactory:
if node_input != 0 and not node_input:
err_msg = f"This model is not supported now. " \
f"Cannot find {node_name}'s input shape."
error = NodeInputMissing(err_msg)
error = NodeInputMissingError(err_msg)
log.error(str(error))
raise error
if isinstance(node_inst, OnnxGraphNode):


+ 4
- 4
mindinsight/mindconverter/graph_based_converter/hierarchical_tree/hierarchical_tree.py View File

@@ -35,7 +35,7 @@ from ..constant import SEPARATOR_BTW_NAME_AND_ID, FIRST_LEVEL_INDENT
from ..constant import NEW_LINE, SECOND_LEVEL_INDENT
from ..constant import NodeType
from ..report_generator import ReportGenerator
from ...common.exceptions import ReportGenerateFail, ScriptGenerateFail, NodeInputTypeNotSupport
from ...common.exceptions import ReportGenerationError, ScriptGenerationError, NodeInputTypeNotSupportError


class HierarchicalTree(Tree):
@@ -189,7 +189,7 @@ class HierarchicalTree(Tree):
try:
self._adjust_structure()
code_fragments = self._generate_codes(mapper)
except (NodeInputTypeNotSupport, ScriptGenerateFail, ReportGenerateFail) as e:
except (NodeInputTypeNotSupportError, ScriptGenerationError, ReportGenerationError) as e:
log.error("Error occur when generating codes.")
raise e

@@ -264,8 +264,8 @@ class HierarchicalTree(Tree):
node.data.args_in_code.pop(arg)
return node

@ScriptGenerateFail.check_except("FormatCode run error. Check detailed information in log.")
@ReportGenerateFail.check_except("Not find valid operators in converted script.")
@ScriptGenerationError.check_except("FormatCode run error. Check detailed information in log.")
@ReportGenerationError.check_except("Not find valid operators in converted script.")
def _generate_codes(self, mapper):
"""
Generate code files.


+ 2
- 2
mindinsight/mindconverter/graph_based_converter/sub_graph_searcher/searcher.py View File

@@ -21,7 +21,7 @@ from .common import MINI_FREQUENCY, MAX_ITERATION_DEPTH, SATISFIED_SCORE
from ..common.global_context import GlobalContext
from ..third_party_graph.onnx_utils import BaseNode
from .search_path import SearchPath, Pattern, generate_pattern, find_built_in_pattern
from ...common.exceptions import SubGraphSearchingFail
from ...common.exceptions import SubGraphSearchingError


def _is_satisfied(path):
@@ -267,7 +267,7 @@ def _add_known_module_name(search_path):
ctx.known_module_name[it.pattern.module_name] = it.pattern.known_module_name


@SubGraphSearchingFail.check_except("Sub-Graph searching fail.")
@SubGraphSearchingError.check_except("Sub-Graph pattern searching fail.")
def generate_scope_name(data_loader):
"""
Generate scope name according to computation graph.


+ 2
- 2
mindinsight/mindconverter/graph_based_converter/third_party_graph/base.py View File

@@ -21,7 +21,7 @@ from mindinsight.mindconverter.common.log import logger as log
from ..common.code_fragment import CodeFragment
from ..constant import NodeType, InputType
from ..mapper.base import Mapper
from ...common.exceptions import NodeInputTypeNotSupport
from ...common.exceptions import NodeInputTypeNotSupportError


class GraphParser(metaclass=abc.ABCMeta):
@@ -522,7 +522,7 @@ class GraphNode(abc.ABC):
elif input_type == InputType.LIST.value:
ipt_args_settings_in_construct = f"({ipt_args_in_construct},)"
else:
raise NodeInputTypeNotSupport(f"Input type[{input_type}] is not supported now.")
raise NodeInputTypeNotSupportError(f"Input type[{input_type}] is not supported now.")
else:
ipt_args_settings_in_construct = ipt_args_in_construct



+ 4
- 4
mindinsight/mindconverter/graph_based_converter/third_party_graph/onnx_utils.py View File

@@ -27,7 +27,7 @@ from ..common.global_context import GlobalContext

from ..constant import ONNX_TYPE_INT, ONNX_TYPE_INTS, ONNX_TYPE_STRING, \
ONNX_TYPE_FLOATS, ONNX_TYPE_FLOAT, SCALAR_WITHOUT_SHAPE, DYNAMIC_SHAPE, UNKNOWN_DIM_VAL
from ...common.exceptions import GraphInitFail, ModelNotSupport, ModelLoadingFail
from ...common.exceptions import GraphInitError, ModelNotSupportError, ModelLoadingError


def convert_tf_graph_to_onnx(model_path, model_inputs, model_outputs, opset=12):
@@ -308,7 +308,7 @@ class OnnxDataLoader:
w = int(match.group('w'))
c = int(match.group('c'))
if [h, w, c] != list(self.graph_input_shape)[1:4]:
raise ModelLoadingFail(f"Shape given should be (N, {h}, {w}, {c}) but got {self.graph_input_shape}")
raise ModelLoadingError(f"Shape given should be (N, {h}, {w}, {c}) but got {self.graph_input_shape}")
return True
return False

@@ -383,7 +383,7 @@ class OnnxDataLoader:
self._nodes_dict[n.name] = n
nodes_topo_idx.append((idx, n.name))
if len(node.output) > 1:
raise ModelNotSupport(msg=f"{node.name} has multi-outputs which is not supported now.")
raise ModelNotSupportError(msg=f"{node.name} has multi-outputs which is not supported now.")
self.output_name_to_node_name[node.output[0]] = node.name
self._global_context.onnx_node_name_to_topo_idx[n.name] = idx
node_inputs = [i.replace(":0", "") for i in node.input]
@@ -423,7 +423,7 @@ class OnnxDataLoader:
shape[i] = int(shape[i])
node_name = self.output_name_to_node_name[node_opt_name]
if not node_name:
raise GraphInitFail(user_msg=f"Cannot find where edge {node_opt_name} comes from.")
raise GraphInitError(msg=f"Cannot find where edge {node_opt_name} comes from.")
self.node_output_shape_dict[node_name] = shape

def get_node(self, node_name):


+ 2
- 2
mindinsight/mindconverter/graph_based_converter/third_party_graph/pytorch_graph_parser.py View File

@@ -18,14 +18,14 @@ from importlib import import_module

from mindinsight.mindconverter.common.log import logger as log
from .base import GraphParser
from ...common.exceptions import ModelNotSupport
from ...common.exceptions import ModelNotSupportError


class PyTorchGraphParser(GraphParser):
"""Define pytorch graph parser."""

@classmethod
@ModelNotSupport.check_except(
@ModelNotSupportError.check_except(
"Error occurs in loading model, please check your model or runtime environment integrity."
)
def parse(cls, model_path: str, **kwargs):


+ 2
- 2
mindinsight/mindconverter/graph_based_converter/third_party_graph/tf_graph_parser.py View File

@@ -18,14 +18,14 @@ from importlib import import_module

from mindinsight.mindconverter.common.log import logger as log
from .base import GraphParser
from ...common.exceptions import ModelNotSupport
from ...common.exceptions import ModelNotSupportError


class TFGraphParser(GraphParser):
"""Define TF graph parser."""

@classmethod
@ModelNotSupport.check_except(
@ModelNotSupportError.check_except(
"Error occurs in loading model, please check your model or runtime environment integrity."
)
def parse(cls, model_path: str, **kwargs):


Loading…
Cancel
Save