@@ -13,10 +13,16 @@
 # limitations under the License.
 # ============================================================================
 """Define custom exception."""
+import sys
 from enum import unique
+from lib2to3.pgen2 import parse
+
+from treelib.exceptions import DuplicatedNodeIdError, MultipleRootError, NodeIDAbsentError
+
+from mindinsight.mindconverter.common.log import logger as log, logger_console as log_console
 from mindinsight.utils.constant import ScriptConverterErrors
-from mindinsight.utils.exceptions import MindInsightException
+from mindinsight.utils.exceptions import MindInsightException, ParamMissError


 @unique
@@ -26,7 +32,17 @@ class ConverterErrors(ScriptConverterErrors):
     NODE_TYPE_NOT_SUPPORT = 2
     CODE_SYNTAX_ERROR = 3
     NODE_INPUT_TYPE_NOT_SUPPORT = 4
-    UNKNOWN_MODEL = 5
+    NODE_INPUT_MISSING = 5
+    TREE_NODE_INSERT_FAIL = 6
+    UNKNOWN_MODEL = 7
+    MODEL_NOT_SUPPORT = 8
+    SCRIPT_GENERATE_FAIL = 9
+    REPORT_GENERATE_FAIL = 10
+
+    BASE_CONVERTER_FAIL = 000
+    GRAPH_INIT_FAIL = 100
+    TREE_CREATE_FAIL = 200
+    SOURCE_FILES_SAVE_FAIL = 300


 class ScriptNotSupport(MindInsightException):
@@ -56,19 +72,354 @@ class CodeSyntaxError(MindInsightException):
                                              http_code=400)


-class NodeInputTypeNotSupport(MindInsightException):
-    """The node input type NOT support error."""
+class MindConverterException(Exception):
+    """MindConverter exception."""
+
+    def __init__(self, **kwargs):
+        """Initialization of MindConverterException."""
+        error = kwargs.get('error', None)
+        user_msg = kwargs.get('user_msg', '')
+        debug_msg = kwargs.get('debug_msg', '')
+        cls_code = kwargs.get('cls_code', 0)
+
+        if isinstance(user_msg, str):
+            user_msg = ' '.join(user_msg.split())
+        super(MindConverterException, self).__init__()
+        self.error = error
+        self.user_msg = user_msg
+        self.debug_msg = debug_msg
+        self.cls_code = cls_code
+    def __str__(self):
+        return '[{}] code: {}, msg: {}'.format(self.__class__.__name__, self.error_code(), self.user_msg)
+
+    def error_code(self):
+        """
+        Calculate the composed error code.
+
+        Code composition:
+            cls_code: class-level code, zero-padded to 3 decimal digits.
+            error: low 16 bits of the error enum value (0xFFFF & error), rendered as 4 hex digits.
+
+        Returns:
+            str, hex string representing the composed MindConverter error code.
+        """
+        num = 0xFFFF & self.error.value
+        error_code = ''.join((f'{self.cls_code}'.zfill(3), hex(num)[2:].zfill(4).upper()))
+        return error_code
+    @staticmethod
+    def raise_from():
+        """Raise from below exceptions."""
+        return None
+
+    @classmethod
+    def check_except_with_print_pytorch(cls, msg):
+        """Check except in pytorch."""
+
+        def decorator(func):
+            def _f(graph_path, sample_shape, output_folder, report_folder):
+                try:
+                    func(graph_path=graph_path, sample_shape=sample_shape,
+                         output_folder=output_folder, report_folder=report_folder)
+                except cls.raise_from() as e:
+                    # Print a short message to the console, keep the full traceback
+                    # in the log file, then terminate the CLI process.
+                    error = cls(msg=msg)
+                    detail_info = f"Error detail: {str(e)}"
+                    log_console.error(str(error))
+                    log_console.error(detail_info)
+                    log.exception(e)
+                    sys.exit(-1)
+            return _f
+        return decorator
+
+    @classmethod
+    def check_except_with_print_tf(cls, msg):
+        """Check except in tf."""
+
+        def decorator(func):
+            def _f(graph_path, sample_shape,
+                   input_nodes, output_nodes,
+                   output_folder, report_folder):
+                try:
+                    func(graph_path=graph_path, sample_shape=sample_shape,
+                         input_nodes=input_nodes, output_nodes=output_nodes,
+                         output_folder=output_folder, report_folder=report_folder)
+                except cls.raise_from() as e:
+                    error = cls(msg=msg)
+                    detail_info = f"Error detail: {str(e)}"
+                    log_console.error(str(error))
+                    log_console.error(detail_info)
+                    log.exception(e)
+                    sys.exit(-1)
+            return _f
+        return decorator
+
+
+class BaseConverterFail(MindConverterException):
+    """Base converter failed."""

     def __init__(self, msg):
-        super(NodeInputTypeNotSupport, self).__init__(ConverterErrors.NODE_INPUT_TYPE_NOT_SUPPORT,
-                                                      msg,
-                                                      http_code=400)
+        super(BaseConverterFail, self).__init__(error=ConverterErrors.BASE_CONVERTER_FAIL,
+                                                user_msg=msg)
+
+    @staticmethod
+    def raise_from():
+        """Raise from exceptions below."""
+        except_source = (UnknownModel,
+                         ParamMissError)
+        return except_source
+
+    @classmethod
+    def check_except(cls, msg):
+        """Check except."""
+
+        def decorator(func):
+            def _f(file_config):
+                try:
+                    func(file_config=file_config)
+                except cls.raise_from() as e:
+                    error = cls(msg=msg)
+                    detail_info = f"Error detail: {str(e)}"
+                    log_console.error(str(error))
+                    log_console.error(detail_info)
+                    log.exception(e)
+                    sys.exit(-1)
+            return _f
+        return decorator


-class UnknownModel(MindInsightException):
+class UnknownModel(MindConverterException):
     """The unknown model error."""
+
+    def __init__(self, msg):
+        super(UnknownModel, self).__init__(error=ConverterErrors.UNKNOWN_MODEL,
+                                           user_msg=msg)
+
+
+class GraphInitFail(MindConverterException):
+    """The graph init fail error."""
+
+    def __init__(self, **kwargs):
+        super(GraphInitFail, self).__init__(error=ConverterErrors.GRAPH_INIT_FAIL,
+                                            user_msg=kwargs.get('msg', ''))
+
+    @staticmethod
+    def raise_from():
+        """Raise from exceptions below."""
+        except_source = (FileNotFoundError,
+                         ModuleNotFoundError,
+                         ModelNotSupport,
+                         TypeError,
+                         ZeroDivisionError)
+        return except_source
+
+    @classmethod
+    def check_except_pytorch(cls, msg):
+        """Check except for pytorch."""
+        return super().check_except_with_print_pytorch(msg)
+
+    @classmethod
+    def check_except_tf(cls, msg):
+        """Check except for tf."""
+        return super().check_except_with_print_tf(msg)
+
+
+class TreeCreateFail(MindConverterException):
+    """The tree create fail error."""
+
+    def __init__(self, msg):
+        super(TreeCreateFail, self).__init__(error=ConverterErrors.TREE_CREATE_FAIL,
+                                             user_msg=msg)
+
+    @staticmethod
+    def raise_from():
+        """Raise from exceptions below."""
+        except_source = (NodeInputMissing,
+                         TreeNodeInsertFail)
+        return except_source
+
+    @classmethod
+    def check_except_pytorch(cls, msg):
+        """Check except."""
+        return super().check_except_with_print_pytorch(msg)
+
+    @classmethod
+    def check_except_tf(cls, msg):
+        """Check except for tf."""
+        return super().check_except_with_print_tf(msg)
+
+
+class SourceFilesSaveFail(MindConverterException):
+    """The source files save fail error."""
+
+    def __init__(self, msg):
+        super(SourceFilesSaveFail, self).__init__(error=ConverterErrors.SOURCE_FILES_SAVE_FAIL,
+                                                  user_msg=msg)
+
+    @staticmethod
+    def raise_from():
+        """Raise from exceptions below."""
+        except_source = (NodeInputTypeNotSupport,
+                         ScriptGenerateFail,
+                         ReportGenerateFail,
+                         IOError)
+        return except_source
+
+    @classmethod
+    def check_except_pytorch(cls, msg):
+        """Check except."""
+        return super().check_except_with_print_pytorch(msg)
+
+    @classmethod
+    def check_except_tf(cls, msg):
+        """Check except for tf."""
+        return super().check_except_with_print_tf(msg)
+
+
+class ModelNotSupport(MindConverterException):
+    """The model not support error."""
+
+    def __init__(self, msg):
+        super(ModelNotSupport, self).__init__(error=ConverterErrors.MODEL_NOT_SUPPORT,
+                                              user_msg=msg,
+                                              cls_code=ConverterErrors.GRAPH_INIT_FAIL.value)
+
+    @staticmethod
+    def raise_from():
+        """Raise from exceptions below."""
+        except_source = (RuntimeError,
+                         ValueError,
+                         TypeError,
+                         OSError,
+                         ZeroDivisionError)
+        return except_source
+
+    @classmethod
+    def check_except(cls, msg):
+        """Check except."""
+
+        def decorator(func):
+            def _f(arch, model_path, **kwargs):
+                try:
+                    output = func(arch, model_path=model_path, **kwargs)
+                except cls.raise_from() as e:
+                    error = cls(msg=msg)
+                    log.error(msg)
+                    log.exception(e)
+                    raise error from e
+                return output
+            return _f
+        return decorator
+
+
+class NodeInputMissing(MindConverterException):
+    """The node input missing error."""

     def __init__(self, msg):
-        super(UnknownModel, self).__init__(ConverterErrors.UNKNOWN_MODEL,
-                                           msg,
-                                           http_code=400)
+        super(NodeInputMissing, self).__init__(error=ConverterErrors.NODE_INPUT_MISSING,
+                                               user_msg=msg,
+                                               cls_code=ConverterErrors.TREE_CREATE_FAIL.value)
+
+
+class TreeNodeInsertFail(MindConverterException):
+    """The tree node insert fail error."""
+
+    def __init__(self, msg):
+        super(TreeNodeInsertFail, self).__init__(error=ConverterErrors.TREE_NODE_INSERT_FAIL,
+                                                 user_msg=msg,
+                                                 cls_code=ConverterErrors.TREE_CREATE_FAIL.value)
+
+    @staticmethod
+    def raise_from():
+        """Raise from exceptions below."""
+        except_source = (OSError,
+                         DuplicatedNodeIdError,
+                         MultipleRootError,
+                         NodeIDAbsentError)
+        return except_source
+
+    @classmethod
+    def check_except(cls, msg):
+        """Check except."""
+
+        def decorator(func):
+            def _f(arch, graph):
+                try:
+                    output = func(arch, graph=graph)
+                except cls.raise_from() as e:
+                    error = cls(msg=msg)
+                    log.error(msg)
+                    log.exception(e)
+                    raise error from e
+                return output
+            return _f
+        return decorator
+
+
+class NodeInputTypeNotSupport(MindConverterException):
+    """The node input type NOT support error."""
+
+    def __init__(self, msg):
+        super(NodeInputTypeNotSupport, self).__init__(error=ConverterErrors.NODE_INPUT_TYPE_NOT_SUPPORT,
+                                                      user_msg=msg,
+                                                      cls_code=ConverterErrors.SOURCE_FILES_SAVE_FAIL.value)
+
+
+class ScriptGenerateFail(MindConverterException):
+    """The script generate fail error."""
+
+    def __init__(self, msg):
+        super(ScriptGenerateFail, self).__init__(error=ConverterErrors.SCRIPT_GENERATE_FAIL,
+                                                 user_msg=msg,
+                                                 cls_code=ConverterErrors.SOURCE_FILES_SAVE_FAIL.value)
+
+    @staticmethod
+    def raise_from():
+        """Raise from exceptions below."""
+        except_source = (RuntimeError,
+                         parse.ParseError,
+                         AttributeError)
+        return except_source
+
+    @classmethod
+    def check_except(cls, msg):
+        """Check except."""
+
+        def decorator(func):
+            def _f(arch, mapper):
+                try:
+                    output = func(arch, mapper=mapper)
+                except cls.raise_from() as e:
+                    error = cls(msg=msg)
+                    log.error(msg)
+                    log.exception(e)
+                    raise error from e
+                return output
+            return _f
+        return decorator
+
+
+class ReportGenerateFail(MindConverterException):
+    """The report generate fail error."""
+
+    def __init__(self, msg):
+        super(ReportGenerateFail, self).__init__(error=ConverterErrors.REPORT_GENERATE_FAIL,
+                                                 user_msg=msg,
+                                                 cls_code=ConverterErrors.SOURCE_FILES_SAVE_FAIL.value)
+
+    @staticmethod
+    def raise_from():
+        """Raise from exceptions below."""
+        except_source = ZeroDivisionError
+        return except_source
+
+    @classmethod
+    def check_except(cls, msg):
+        """Check except."""
+
+        def decorator(func):
+            def _f(arch, mapper):
+                try:
+                    output = func(arch, mapper=mapper)
+                except cls.raise_from() as e:
+                    error = cls(msg=msg)
+                    log.error(msg)
+                    log.exception(e)
+                    raise error from e
+                return output
+            return _f
+        return decorator
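A minimal usage sketch of the new exception machinery (not part of the patch; the function name `load` and the sample messages are illustrative). It shows how the composed error code is built from `cls_code` plus the error enum value, and how `check_except` re-raises a known low-level exception as the MindConverter-specific error, while the `check_except_with_print_*` variants print to the console and exit instead of re-raising:

# Illustrative sketch only; assumes the classes defined in the diff above.
error = ModelNotSupport(msg="demo message")
print(error.error_code())   # '1000008': cls_code 100 (GRAPH_INIT_FAIL) + '0008' (MODEL_NOT_SUPPORT = 8)
print(str(error))           # "[ModelNotSupport] code: 1000008, msg: demo message"

@ModelNotSupport.check_except("Error occurs in loading model.")
def load(arch, model_path, **kwargs):        # matches the _f(arch, model_path, **kwargs) wrapper
    raise RuntimeError("broken checkpoint")  # RuntimeError is listed in ModelNotSupport.raise_from()

# load(None, model_path="model.pth") logs the message and traceback, then raises
# ModelNotSupport chained to the original RuntimeError ("raise error from e").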
@@ -16,3 +16,5 @@
 from mindinsight.utils.log import setup_logger

 logger = setup_logger("mindconverter", "mindconverter", console=False)
+logger_console = setup_logger("mindconverter", "mindconverter", console=True,
+                              sub_log_name="logger_console", formatter="%(message)s")
@@ -20,11 +20,12 @@ from importlib import import_module
 from importlib.util import find_spec

 import mindinsight
-from mindinsight.mindconverter.common.log import logger as log
 from mindinsight.mindconverter.graph_based_converter.constant import BINARY_HEADER_PYTORCH_FILE, FrameworkType, \
     BINARY_HEADER_PYTORCH_BITS
 from mindinsight.mindconverter.graph_based_converter.mapper import ONNXToMindSporeMapper
-from mindinsight.mindconverter.common.exceptions import NodeTypeNotSupport, UnknownModel
+from mindinsight.mindconverter.common.log import logger as log
+from mindinsight.mindconverter.common.exceptions import GraphInitFail, TreeCreateFail, SourceFilesSaveFail, \
+    BaseConverterFail, UnknownModel
 from mindinsight.utils.exceptions import ParamMissError

 permissions = os.R_OK | os.W_OK | os.X_OK
@@ -118,6 +119,9 @@ def _extract_model_name(model_path):
     return model_name


+@GraphInitFail.check_except_pytorch("Error occurred when init graph object.")
+@TreeCreateFail.check_except_pytorch("Error occurred when create hierarchical tree.")
+@SourceFilesSaveFail.check_except_pytorch("Error occurred when save source files.")
 @torch_installation_validation
 def graph_based_converter_pytorch_to_ms(graph_path: str, sample_shape: tuple,
                                         output_folder: str, report_folder: str = None):
@@ -139,12 +143,8 @@ def graph_based_converter_pytorch_to_ms(graph_path: str, sample_shape: tuple,
     cls_hierarchical_tree_factory = getattr(hierarchical_tree_module, 'HierarchicalTreeFactory')

     graph_obj = cls_graph_factory.init(graph_path, sample_shape=sample_shape)
-    try:
-        hierarchical_tree = cls_hierarchical_tree_factory.create(graph_obj)
-    except Exception as e:
-        log.exception(e)
-        log.error("Error occur when create hierarchical tree.")
-        raise NodeTypeNotSupport("This model is not supported now.")
+
+    hierarchical_tree = cls_hierarchical_tree_factory.create(graph_obj)

     model_name = _extract_model_name(graph_path)
@@ -153,6 +153,9 @@ def graph_based_converter_pytorch_to_ms(graph_path: str, sample_shape: tuple,
                                         report_folder=report_folder)


+@GraphInitFail.check_except_tf("Error occurred when init graph object.")
+@TreeCreateFail.check_except_tf("Error occurred when create hierarchical tree.")
+@SourceFilesSaveFail.check_except_tf("Error occurred when save source files.")
 @tf_installation_validation
 def graph_based_converter_tf_to_ms(graph_path: str, sample_shape: tuple,
                                    input_nodes: str, output_nodes: str,
@@ -181,28 +184,24 @@ def graph_based_converter_tf_to_ms(graph_path: str, sample_shape: tuple,
     graph_obj = cls_graph_factory.init(graph_path, sample_shape=sample_shape,
                                        input_nodes=input_nodes, output_nodes=output_nodes)
-    try:
-        hierarchical_tree, scope_name_map = cls_hierarchical_tree_factory.create(graph_obj)
-    except Exception as e:
-        log.exception(e)
-        log.error("Error occur when create hierarchical tree.")
-        raise NodeTypeNotSupport("This model is not supported now.")
+    hierarchical_tree, scope_name_map = cls_hierarchical_tree_factory.create(graph_obj)

     model_name = _extract_model_name(graph_path)

     hierarchical_tree.save_source_files(output_folder, mapper=ONNXToMindSporeMapper,
                                         model_name=model_name,
                                         report_folder=report_folder,
                                         scope_name_map=scope_name_map)


+@BaseConverterFail.check_except("Failed to start base converter.")
 def main_graph_base_converter(file_config):
     """
-        The entrance for converter, script files will be converted.
+    The entrance for converter, script files will be converted.

-        Args:
-            file_config (dict): The config of file which to convert.
-        """
+    Args:
+        file_config (dict): The config of file which to convert.
+    """
     graph_path = file_config['model_file']
     frame_type = get_framework_type(graph_path)
     if frame_type == FrameworkType.PYTORCH.value:
@@ -223,7 +222,6 @@ def main_graph_base_converter(file_config):
         error_msg = "Get UNSUPPORTED model."
         error = UnknownModel(error_msg)
         log.error(str(error))
-        log.exception(error)
         raise error
@@ -239,7 +237,6 @@ def get_framework_type(model_path):
         error_msg = "Get UNSUPPORTED model."
         error = UnknownModel(error_msg)
         log.error(str(error))
-        log.exception(error)
         raise error

     return framework_type
@@ -253,4 +250,6 @@ def check_params_exist(params: list, config):
             miss_param_list = ', '.join((miss_param_list, param)) if miss_param_list else param

     if miss_param_list:
-        raise ParamMissError(miss_param_list)
+        error = ParamMissError(miss_param_list)
+        log.error(str(error))
+        raise error
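The wrapping pattern above is the same for all entry points: the decorator only reacts to the exception types listed in the class's raise_from() tuple, reports a short message on the console, keeps the traceback in the log, and exits. A self-contained sketch of the main entry wrapper (the function name `run` and the empty config are illustrative, not part of the patch):

# Sketch of the wrapping used on main_graph_base_converter.
# BaseConverterFail.raise_from() is (UnknownModel, ParamMissError), so a missing
# 'model_file' key surfaces as a short console message while the full traceback
# stays in the mindconverter log.
@BaseConverterFail.check_except("Failed to start base converter.")
def run(file_config):                      # mirrors _f(file_config) in the decorator
    if 'model_file' not in file_config:
        raise ParamMissError('model_file')
    ...                                    # hand off to the framework-specific converter

# run(file_config={}) prints "[BaseConverterFail] code: 0000000, msg: Failed to start
# base converter." plus an "Error detail: ..." line, then exits with a non-zero status.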
@@ -23,6 +23,8 @@ __all__ = [
     "HierarchicalTreeFactory"
 ]

+from ...common.exceptions import NodeInputMissing, TreeNodeInsertFail
+

 def _tf_model_node_name_reformat(node: OnnxGraphNode, node_name):
     """
@@ -53,6 +55,7 @@ class HierarchicalTreeFactory:
     """Hierarchical tree factory."""

     @classmethod
+    @TreeNodeInsertFail.check_except("Tree node inserts failed.")
     def create(cls, graph):
         """
         Factory method of hierarchical tree.
@@ -72,7 +75,9 @@ class HierarchicalTreeFactory:
             if not node_input:
                 err_msg = f"This model is not supported now. " \
                           f"Cannot find {node_name}'s input shape."
-                log.error(err_msg)
+                error = NodeInputMissing(err_msg)
+                log.error(str(error))
+                raise error
             if isinstance(node_inst, OnnxGraphNode):
                 node_name_with_scope = _tf_model_node_name_reformat(node_inst, node_name)
                 node_scope_name[node_name] = node_name_with_scope
@@ -35,7 +35,7 @@ from ..constant import SEPARATOR_BTW_NAME_AND_ID, FIRST_LEVEL_INDENT
 from ..constant import NEW_LINE, SECOND_LEVEL_INDENT
 from ..constant import NodeType
 from ..report_generator import ReportGenerator
-from ...common.exceptions import NodeTypeNotSupport
+from ...common.exceptions import ReportGenerateFail, ScriptGenerateFail, NodeInputTypeNotSupport


 class HierarchicalTree(Tree):
@@ -189,10 +189,9 @@ class HierarchicalTree(Tree):
         try:
             self._adjust_structure()
             code_fragments = self._generate_codes(mapper)
-        except Exception as e:
-            log.exception(e)
-            log.error("Error occur when create hierarchical tree.")
-            raise NodeTypeNotSupport("This model is not supported now.")
+        except (NodeInputTypeNotSupport, ScriptGenerateFail, ReportGenerateFail) as e:
+            log.error("Error occur when generating codes.")
+            raise e

         out_folder = os.path.realpath(out_folder)
         if not report_folder:
@@ -295,6 +294,8 @@ class HierarchicalTree(Tree):
                     node.data.args_in_code.pop(arg)
         return node

+    @ScriptGenerateFail.check_except("FormatCode run error. Check detailed information in log.")
+    @ReportGenerateFail.check_except("Not find valid operators in converted script.")
    def _generate_codes(self, mapper):
         """
         Generate code files.
@@ -382,6 +383,7 @@ class HierarchicalTree(Tree):
         formatted_code, _ = FormatCode("".join(code_blocks),
                                        style_config=CodeFormatConfig.PEP8.value)
+
         report_generator = ReportGenerator()
         report = report_generator.gen_report(formatted_code)
@@ -16,12 +16,14 @@
 import os

 from mindinsight.mindconverter.common.log import logger as log
 from .base import GraphParser
+from ...common.exceptions import ModelNotSupport


 class PyTorchGraphParser(GraphParser):
     """Define pytorch graph parser."""

     @classmethod
+    @ModelNotSupport.check_except("Error occurs in loading model, make sure model.pth correct.")
     def parse(cls, model_path: str, **kwargs):
         """
         Parser pytorch graph.
@@ -51,13 +53,7 @@ class PyTorchGraphParser(GraphParser):
                         "set `--project_path` to the path of model scripts folder correctly."
             error = ModuleNotFoundError(error_msg)
             log.error(str(error))
-            log.exception(error)
-            raise error
-        except Exception as e:
-            error_msg = "Error occurs in loading model, make sure model.pth correct."
-            log.error(error_msg)
-            log.exception(e)
-            raise Exception(error_msg)
+            raise error from None

         return model
@@ -66,6 +62,7 @@ class TFGraphParser(GraphParser):
     """Define TF graph parser."""

     @classmethod
+    @ModelNotSupport.check_except("Error occurs in loading model, make sure model.pb correct.")
     def parse(cls, model_path: str, **kwargs):
         """
         Parse TF Computational Graph File (.pb)
@@ -85,7 +82,6 @@
             error = FileNotFoundError("`model_path` must be assigned with "
                                       "an existed file path.")
             log.error(str(error))
-            log.exception(error)
             raise error

         try:
@@ -99,13 +95,7 @@
                         "Cannot find model scripts in system path, " \
                         "set `--project_path` to the path of model scripts folder correctly."
             error = ModuleNotFoundError(error_msg)
-            log.error(error_msg)
-            log.exception(error)
-            raise error
-        except Exception as e:
-            error_msg = "Error occurs in loading model, make sure model.pb correct."
-            log.error(error_msg)
-            log.exception(e)
-            raise Exception(error_msg)
+            log.error(str(error))
+            raise error from None

         return model
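A side note on the `raise error from None` form now used in both parsers: it suppresses the implicit exception chaining, so the user sees only the actionable ModuleNotFoundError hint instead of the original import traceback as well. A standalone illustration (the module name is hypothetical, not from the patch):

# Standalone illustration of `raise ... from None` versus a plain re-raise.
try:
    import some_missing_user_module              # hypothetical module name
except ImportError:
    hint = ModuleNotFoundError("set `--project_path` to the model scripts folder.")
    # A plain `raise hint` here would print "During handling of the above exception,
    # another exception occurred" with both tracebacks; `from None` hides the
    # original ImportError and surfaces only the hint.
    raise hint from None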
@@ -175,7 +175,10 @@ def setup_logger(sub_module, log_name, **kwargs):
         >>> logger = logging.getLogger('datavisual.flask.request')
     """
-    logger = get_logger(sub_module, log_name)
+    if kwargs.get('sub_log_name', False):
+        logger = get_logger(sub_module, kwargs['sub_log_name'])
+    else:
+        logger = get_logger(sub_module, log_name)
+
     if logger.hasHandlers():
         return logger
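With the `sub_log_name` hook in place, one sub-module can own two separately configured loggers, which is what the mindconverter log module earlier in the patch relies on. A sketch of the intended behaviour (the sample messages are illustrative; it assumes setup_logger's usual semantics of attaching a file handler when console=False and a console handler when console=True):

# Sketch: two loggers for the same sub-module, as used by mindconverter above.
from mindinsight.utils.log import setup_logger

file_logger = setup_logger("mindconverter", "mindconverter", console=False)
console_logger = setup_logger("mindconverter", "mindconverter", console=True,
                              sub_log_name="logger_console", formatter="%(message)s")

# Without sub_log_name, both calls would resolve to the same logger name and the
# second call would return early at logger.hasHandlers(); with it, the console
# logger is a distinct object that prints bare messages for the user, while the
# file logger keeps detailed records (and tracebacks via log.exception) on disk.
console_logger.error("user facing message")      # console only, message-only format
file_logger.error("recorded in the log file")    # file only, full log format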