From: @moran3 Reviewed-by: Signed-off-by:tags/v1.1.0
| @@ -15,6 +15,7 @@ | |||||
| """Define custom exception.""" | """Define custom exception.""" | ||||
| import sys | import sys | ||||
| from enum import unique | from enum import unique | ||||
| from importlib import import_module | |||||
| from lib2to3.pgen2 import parse | from lib2to3.pgen2 import parse | ||||
| from treelib.exceptions import DuplicatedNodeIdError, MultipleRootError, NodeIDAbsentError | from treelib.exceptions import DuplicatedNodeIdError, MultipleRootError, NodeIDAbsentError | ||||
| @@ -212,7 +213,8 @@ class GraphInitFail(MindConverterException): | |||||
| ModuleNotFoundError, | ModuleNotFoundError, | ||||
| ModelNotSupport, | ModelNotSupport, | ||||
| TypeError, | TypeError, | ||||
| ZeroDivisionError) | |||||
| ZeroDivisionError, | |||||
| RuntimeError) | |||||
| return except_source | return except_source | ||||
| @classmethod | @classmethod | ||||
| @@ -294,7 +296,7 @@ class ModelNotSupport(MindConverterException): | |||||
| return except_source | return except_source | ||||
| @classmethod | @classmethod | ||||
| def check_except(cls, msg): | |||||
| def check_except_pytorch(cls, msg): | |||||
| """Check except.""" | """Check except.""" | ||||
| def decorator(func): | def decorator(func): | ||||
| @@ -310,6 +312,29 @@ class ModelNotSupport(MindConverterException): | |||||
| return _f | return _f | ||||
| return decorator | return decorator | ||||
| @classmethod | |||||
| def check_except_tf(cls, msg): | |||||
| """Check except.""" | |||||
| tf_error_module = import_module('tensorflow.python.framework.errors_impl') | |||||
| tf_error = getattr(tf_error_module, 'OpError') | |||||
| cls._error = cls.raise_from() + (tf_error,) | |||||
| def decorator(func): | |||||
| def _f(arch, model_path, **kwargs): | |||||
| try: | |||||
| output = func(arch, model_path=model_path, **kwargs) | |||||
| except cls._error as e: | |||||
| error = cls(msg=msg) | |||||
| log.error(msg) | |||||
| log.exception(e) | |||||
| raise error from e | |||||
| return output | |||||
| return _f | |||||
| return decorator | |||||
| class NodeInputMissing(MindConverterException): | class NodeInputMissing(MindConverterException): | ||||
| """The node input missing error.""" | """The node input missing error.""" | ||||
| @@ -119,10 +119,10 @@ def _extract_model_name(model_path): | |||||
| return model_name | return model_name | ||||
| @torch_installation_validation | |||||
| @GraphInitFail.check_except_pytorch("Error occurred when init graph object.") | @GraphInitFail.check_except_pytorch("Error occurred when init graph object.") | ||||
| @TreeCreateFail.check_except_pytorch("Error occurred when create hierarchical tree.") | @TreeCreateFail.check_except_pytorch("Error occurred when create hierarchical tree.") | ||||
| @SourceFilesSaveFail.check_except_pytorch("Error occurred when save source files.") | @SourceFilesSaveFail.check_except_pytorch("Error occurred when save source files.") | ||||
| @torch_installation_validation | |||||
| def graph_based_converter_pytorch_to_ms(graph_path: str, sample_shape: tuple, | def graph_based_converter_pytorch_to_ms(graph_path: str, sample_shape: tuple, | ||||
| output_folder: str, report_folder: str = None): | output_folder: str, report_folder: str = None): | ||||
| """ | """ | ||||
| @@ -153,10 +153,10 @@ def graph_based_converter_pytorch_to_ms(graph_path: str, sample_shape: tuple, | |||||
| report_folder=report_folder) | report_folder=report_folder) | ||||
| @tf_installation_validation | |||||
| @GraphInitFail.check_except_tf("Error occurred when init graph object.") | @GraphInitFail.check_except_tf("Error occurred when init graph object.") | ||||
| @TreeCreateFail.check_except_tf("Error occurred when create hierarchical tree.") | @TreeCreateFail.check_except_tf("Error occurred when create hierarchical tree.") | ||||
| @SourceFilesSaveFail.check_except_tf("Error occurred when save source files.") | @SourceFilesSaveFail.check_except_tf("Error occurred when save source files.") | ||||
| @tf_installation_validation | |||||
| def graph_based_converter_tf_to_ms(graph_path: str, sample_shape: tuple, | def graph_based_converter_tf_to_ms(graph_path: str, sample_shape: tuple, | ||||
| input_nodes: str, output_nodes: str, | input_nodes: str, output_nodes: str, | ||||
| output_folder: str, report_folder: str = None): | output_folder: str, report_folder: str = None): | ||||
| @@ -14,13 +14,10 @@ | |||||
| # ============================================================================== | # ============================================================================== | ||||
| """Graph associated definition module.""" | """Graph associated definition module.""" | ||||
| __all__ = ["GraphFactory", "PyTorchGraphNode"] | |||||
| __all__ = ["GraphFactory"] | |||||
| from importlib import import_module | |||||
| from .base import Graph | from .base import Graph | ||||
| from .pytorch_graph import PyTorchGraph | |||||
| from .pytorch_graph_node import PyTorchGraphNode | |||||
| from .onnx_graph import OnnxGraph | |||||
| from .onnx_graph_node import OnnxGraphNode | |||||
| class GraphFactory: | class GraphFactory: | ||||
| @@ -43,7 +40,13 @@ class GraphFactory: | |||||
| Graph, graph instance. | Graph, graph instance. | ||||
| """ | """ | ||||
| if all([input_nodes, output_nodes]): | if all([input_nodes, output_nodes]): | ||||
| return OnnxGraph.load(model_path=graph_path, input_nodes=input_nodes, | |||||
| output_nodes=output_nodes, sample_shape=sample_shape) | |||||
| return PyTorchGraph.load(model_path=graph_path, sample_shape=sample_shape) | |||||
| onnx_graph_module = import_module( | |||||
| 'mindinsight.mindconverter.graph_based_converter.third_party_graph.onnx_graph') | |||||
| onnx_graph = getattr(onnx_graph_module, 'OnnxGraph') | |||||
| return onnx_graph.load(model_path=graph_path, input_nodes=input_nodes, | |||||
| output_nodes=output_nodes, sample_shape=sample_shape) | |||||
| pytorch_graph_module = import_module( | |||||
| 'mindinsight.mindconverter.graph_based_converter.third_party_graph.pytorch_graph') | |||||
| pytorch_graph = getattr(pytorch_graph_module, 'PyTorchGraph') | |||||
| return pytorch_graph.load(model_path=graph_path, sample_shape=sample_shape) | |||||
| @@ -19,7 +19,7 @@ from mindinsight.mindconverter.common.log import logger as log | |||||
| from .base import Graph | from .base import Graph | ||||
| from .input_node import InputNode | from .input_node import InputNode | ||||
| from .onnx_graph_node import OnnxGraphNode | from .onnx_graph_node import OnnxGraphNode | ||||
| from .graph_parser import TFGraphParser | |||||
| from .tf_graph_parser import TFGraphParser | |||||
| from .onnx_utils import OnnxDataLoader | from .onnx_utils import OnnxDataLoader | ||||
| NONE_SCOPE_OP = { | NONE_SCOPE_OP = { | ||||
| @@ -20,9 +20,9 @@ from mindinsight.mindconverter.common.log import logger as log | |||||
| from .base import Graph | from .base import Graph | ||||
| from .input_node import InputNode | from .input_node import InputNode | ||||
| from .pytorch_graph_node import PyTorchGraphNode | from .pytorch_graph_node import PyTorchGraphNode | ||||
| from .graph_parser import PyTorchGraphParser | |||||
| from .pytorch_graph_parser import PyTorchGraphParser | |||||
| from ..constant import SEPARATOR_IN_SCOPE, LINK_IN_SCOPE | |||||
| from ..constant import SEPARATOR_IN_SCOPE, LINK_IN_SCOPE, SEPARATOR_BTW_NAME_AND_ID | |||||
| from ..constant import LEFT_BUCKET, RIGHT_BUCKET | from ..constant import LEFT_BUCKET, RIGHT_BUCKET | ||||
| NONE_SCOPE_OP = { | NONE_SCOPE_OP = { | ||||
| @@ -46,7 +46,7 @@ def normalize_scope_name(node): | |||||
| """ | """ | ||||
| global NONE_SCOPE_OP | global NONE_SCOPE_OP | ||||
| name = node.scopeName().split(SEPARATOR_IN_SCOPE) | |||||
| name = node.scopeName().replace(SEPARATOR_BTW_NAME_AND_ID, '').split(SEPARATOR_IN_SCOPE) | |||||
| scopes = [] | scopes = [] | ||||
| for segment in name: | for segment in name: | ||||
| segment = segment.split(LINK_IN_SCOPE)[0] | segment = segment.split(LINK_IN_SCOPE)[0] | ||||
| @@ -0,0 +1,60 @@ | |||||
| # Copyright 2020 Huawei Technologies Co., Ltd.All Rights Reserved. | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================== | |||||
| """Third party graph parser.""" | |||||
| import os | |||||
| from importlib import import_module | |||||
| from mindinsight.mindconverter.common.log import logger as log | |||||
| from .base import GraphParser | |||||
| from ...common.exceptions import ModelNotSupport | |||||
| class PyTorchGraphParser(GraphParser): | |||||
| """Define pytorch graph parser.""" | |||||
| @classmethod | |||||
| @ModelNotSupport.check_except_pytorch("Error occurs in loading model, make sure model.pth correct.") | |||||
| def parse(cls, model_path: str, **kwargs): | |||||
| """ | |||||
| Parser pytorch graph. | |||||
| Args: | |||||
| model_path (str): Model file path. | |||||
| Returns: | |||||
| object, torch model. | |||||
| """ | |||||
| torch = import_module("torch") | |||||
| if not os.path.exists(model_path): | |||||
| error = FileNotFoundError("`model_path` must be assigned with " | |||||
| "an existed file path.") | |||||
| log.error(str(error)) | |||||
| raise error | |||||
| try: | |||||
| if torch.cuda.is_available(): | |||||
| model = torch.load(f=model_path) | |||||
| else: | |||||
| model = torch.load(f=model_path, map_location="cpu") | |||||
| except ModuleNotFoundError: | |||||
| error_msg = \ | |||||
| "Cannot find model scripts in system path, " \ | |||||
| "set `--project_path` to the path of model scripts folder correctly." | |||||
| error = ModuleNotFoundError(error_msg) | |||||
| log.error(str(error)) | |||||
| raise error from None | |||||
| return model | |||||
| @@ -14,55 +14,18 @@ | |||||
| # ============================================================================== | # ============================================================================== | ||||
| """Third party graph parser.""" | """Third party graph parser.""" | ||||
| import os | import os | ||||
| from importlib import import_module | |||||
| from mindinsight.mindconverter.common.log import logger as log | from mindinsight.mindconverter.common.log import logger as log | ||||
| from .base import GraphParser | from .base import GraphParser | ||||
| from ...common.exceptions import ModelNotSupport | from ...common.exceptions import ModelNotSupport | ||||
| class PyTorchGraphParser(GraphParser): | |||||
| """Define pytorch graph parser.""" | |||||
| @classmethod | |||||
| @ModelNotSupport.check_except("Error occurs in loading model, make sure model.pth correct.") | |||||
| def parse(cls, model_path: str, **kwargs): | |||||
| """ | |||||
| Parser pytorch graph. | |||||
| Args: | |||||
| model_path (str): Model file path. | |||||
| Returns: | |||||
| object, torch model. | |||||
| """ | |||||
| import torch | |||||
| if not os.path.exists(model_path): | |||||
| error = FileNotFoundError("`model_path` must be assigned with " | |||||
| "an existed file path.") | |||||
| log.error(str(error)) | |||||
| raise error | |||||
| try: | |||||
| if torch.cuda.is_available(): | |||||
| model = torch.load(f=model_path) | |||||
| else: | |||||
| model = torch.load(f=model_path, map_location="cpu") | |||||
| except ModuleNotFoundError: | |||||
| error_msg = \ | |||||
| "Cannot find model scripts in system path, " \ | |||||
| "set `--project_path` to the path of model scripts folder correctly." | |||||
| error = ModuleNotFoundError(error_msg) | |||||
| log.error(str(error)) | |||||
| raise error from None | |||||
| return model | |||||
| class TFGraphParser(GraphParser): | class TFGraphParser(GraphParser): | ||||
| """Define TF graph parser.""" | """Define TF graph parser.""" | ||||
| @classmethod | @classmethod | ||||
| @ModelNotSupport.check_except("Error occurs in loading model, make sure model.pb correct.") | |||||
| @ModelNotSupport.check_except_tf("Error occurs in loading model, make sure model.pb correct.") | |||||
| def parse(cls, model_path: str, **kwargs): | def parse(cls, model_path: str, **kwargs): | ||||
| """ | """ | ||||
| Parse TF Computational Graph File (.pb) | Parse TF Computational Graph File (.pb) | ||||
| @@ -74,8 +37,9 @@ class TFGraphParser(GraphParser): | |||||
| object, ONNX model. | object, ONNX model. | ||||
| """ | """ | ||||
| from .onnx_utils import convert_tf_graph_to_onnx | |||||
| onnx_utils = import_module( | |||||
| "mindinsight.mindconverter.graph_based_converter.third_party_graph.onnx_utils") | |||||
| convert_tf_graph_to_onnx = getattr(onnx_utils, "convert_tf_graph_to_onnx") | |||||
| tf_input_nodes = kwargs.get('input_nodes') | tf_input_nodes = kwargs.get('input_nodes') | ||||
| tf_output_nodes = kwargs.get('output_nodes') | tf_output_nodes = kwargs.get('output_nodes') | ||||
| if not os.path.exists(model_path): | if not os.path.exists(model_path): | ||||