| @@ -15,6 +15,7 @@ | |||
| """Define custom exception.""" | |||
| import sys | |||
| from enum import unique | |||
| from importlib import import_module | |||
| from lib2to3.pgen2 import parse | |||
| from treelib.exceptions import DuplicatedNodeIdError, MultipleRootError, NodeIDAbsentError | |||
| @@ -212,7 +213,8 @@ class GraphInitFail(MindConverterException): | |||
| ModuleNotFoundError, | |||
| ModelNotSupport, | |||
| TypeError, | |||
| ZeroDivisionError) | |||
| ZeroDivisionError, | |||
| RuntimeError) | |||
| return except_source | |||
| @classmethod | |||
| @@ -294,7 +296,7 @@ class ModelNotSupport(MindConverterException): | |||
| return except_source | |||
| @classmethod | |||
| def check_except(cls, msg): | |||
| def check_except_pytorch(cls, msg): | |||
| """Check except.""" | |||
| def decorator(func): | |||
| @@ -310,6 +312,29 @@ class ModelNotSupport(MindConverterException): | |||
| return _f | |||
| return decorator | |||
| @classmethod | |||
| def check_except_tf(cls, msg): | |||
| """Check except.""" | |||
| tf_error_module = import_module('tensorflow.python.framework.errors_impl') | |||
| tf_error = getattr(tf_error_module, 'OpError') | |||
| cls._error = cls.raise_from() + (tf_error,) | |||
| def decorator(func): | |||
| def _f(arch, model_path, **kwargs): | |||
| try: | |||
| output = func(arch, model_path=model_path, **kwargs) | |||
| except cls._error as e: | |||
| error = cls(msg=msg) | |||
| log.error(msg) | |||
| log.exception(e) | |||
| raise error from e | |||
| return output | |||
| return _f | |||
| return decorator | |||
| class NodeInputMissing(MindConverterException): | |||
| """The node input missing error.""" | |||
| @@ -119,10 +119,10 @@ def _extract_model_name(model_path): | |||
| return model_name | |||
| @torch_installation_validation | |||
| @GraphInitFail.check_except_pytorch("Error occurred when init graph object.") | |||
| @TreeCreateFail.check_except_pytorch("Error occurred when create hierarchical tree.") | |||
| @SourceFilesSaveFail.check_except_pytorch("Error occurred when save source files.") | |||
| @torch_installation_validation | |||
| def graph_based_converter_pytorch_to_ms(graph_path: str, sample_shape: tuple, | |||
| output_folder: str, report_folder: str = None): | |||
| """ | |||
| @@ -153,10 +153,10 @@ def graph_based_converter_pytorch_to_ms(graph_path: str, sample_shape: tuple, | |||
| report_folder=report_folder) | |||
| @tf_installation_validation | |||
| @GraphInitFail.check_except_tf("Error occurred when init graph object.") | |||
| @TreeCreateFail.check_except_tf("Error occurred when create hierarchical tree.") | |||
| @SourceFilesSaveFail.check_except_tf("Error occurred when save source files.") | |||
| @tf_installation_validation | |||
| def graph_based_converter_tf_to_ms(graph_path: str, sample_shape: tuple, | |||
| input_nodes: str, output_nodes: str, | |||
| output_folder: str, report_folder: str = None): | |||
| @@ -14,13 +14,10 @@ | |||
| # ============================================================================== | |||
| """Graph associated definition module.""" | |||
| __all__ = ["GraphFactory", "PyTorchGraphNode"] | |||
| __all__ = ["GraphFactory"] | |||
| from importlib import import_module | |||
| from .base import Graph | |||
| from .pytorch_graph import PyTorchGraph | |||
| from .pytorch_graph_node import PyTorchGraphNode | |||
| from .onnx_graph import OnnxGraph | |||
| from .onnx_graph_node import OnnxGraphNode | |||
| class GraphFactory: | |||
| @@ -43,7 +40,13 @@ class GraphFactory: | |||
| Graph, graph instance. | |||
| """ | |||
| if all([input_nodes, output_nodes]): | |||
| return OnnxGraph.load(model_path=graph_path, input_nodes=input_nodes, | |||
| output_nodes=output_nodes, sample_shape=sample_shape) | |||
| return PyTorchGraph.load(model_path=graph_path, sample_shape=sample_shape) | |||
| onnx_graph_module = import_module( | |||
| 'mindinsight.mindconverter.graph_based_converter.third_party_graph.onnx_graph') | |||
| onnx_graph = getattr(onnx_graph_module, 'OnnxGraph') | |||
| return onnx_graph.load(model_path=graph_path, input_nodes=input_nodes, | |||
| output_nodes=output_nodes, sample_shape=sample_shape) | |||
| pytorch_graph_module = import_module( | |||
| 'mindinsight.mindconverter.graph_based_converter.third_party_graph.pytorch_graph') | |||
| pytorch_graph = getattr(pytorch_graph_module, 'PyTorchGraph') | |||
| return pytorch_graph.load(model_path=graph_path, sample_shape=sample_shape) | |||
| @@ -19,7 +19,7 @@ from mindinsight.mindconverter.common.log import logger as log | |||
| from .base import Graph | |||
| from .input_node import InputNode | |||
| from .onnx_graph_node import OnnxGraphNode | |||
| from .graph_parser import TFGraphParser | |||
| from .tf_graph_parser import TFGraphParser | |||
| from .onnx_utils import OnnxDataLoader | |||
| NONE_SCOPE_OP = { | |||
| @@ -20,9 +20,9 @@ from mindinsight.mindconverter.common.log import logger as log | |||
| from .base import Graph | |||
| from .input_node import InputNode | |||
| from .pytorch_graph_node import PyTorchGraphNode | |||
| from .graph_parser import PyTorchGraphParser | |||
| from .pytorch_graph_parser import PyTorchGraphParser | |||
| from ..constant import SEPARATOR_IN_SCOPE, LINK_IN_SCOPE | |||
| from ..constant import SEPARATOR_IN_SCOPE, LINK_IN_SCOPE, SEPARATOR_BTW_NAME_AND_ID | |||
| from ..constant import LEFT_BUCKET, RIGHT_BUCKET | |||
| NONE_SCOPE_OP = { | |||
| @@ -46,7 +46,7 @@ def normalize_scope_name(node): | |||
| """ | |||
| global NONE_SCOPE_OP | |||
| name = node.scopeName().split(SEPARATOR_IN_SCOPE) | |||
| name = node.scopeName().replace(SEPARATOR_BTW_NAME_AND_ID, '').split(SEPARATOR_IN_SCOPE) | |||
| scopes = [] | |||
| for segment in name: | |||
| segment = segment.split(LINK_IN_SCOPE)[0] | |||
| @@ -0,0 +1,60 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd.All Rights Reserved. | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================== | |||
| """Third party graph parser.""" | |||
| import os | |||
| from importlib import import_module | |||
| from mindinsight.mindconverter.common.log import logger as log | |||
| from .base import GraphParser | |||
| from ...common.exceptions import ModelNotSupport | |||
| class PyTorchGraphParser(GraphParser): | |||
| """Define pytorch graph parser.""" | |||
| @classmethod | |||
| @ModelNotSupport.check_except_pytorch("Error occurs in loading model, make sure model.pth correct.") | |||
| def parse(cls, model_path: str, **kwargs): | |||
| """ | |||
| Parser pytorch graph. | |||
| Args: | |||
| model_path (str): Model file path. | |||
| Returns: | |||
| object, torch model. | |||
| """ | |||
| torch = import_module("torch") | |||
| if not os.path.exists(model_path): | |||
| error = FileNotFoundError("`model_path` must be assigned with " | |||
| "an existed file path.") | |||
| log.error(str(error)) | |||
| raise error | |||
| try: | |||
| if torch.cuda.is_available(): | |||
| model = torch.load(f=model_path) | |||
| else: | |||
| model = torch.load(f=model_path, map_location="cpu") | |||
| except ModuleNotFoundError: | |||
| error_msg = \ | |||
| "Cannot find model scripts in system path, " \ | |||
| "set `--project_path` to the path of model scripts folder correctly." | |||
| error = ModuleNotFoundError(error_msg) | |||
| log.error(str(error)) | |||
| raise error from None | |||
| return model | |||
| @@ -14,55 +14,18 @@ | |||
| # ============================================================================== | |||
| """Third party graph parser.""" | |||
| import os | |||
| from importlib import import_module | |||
| from mindinsight.mindconverter.common.log import logger as log | |||
| from .base import GraphParser | |||
| from ...common.exceptions import ModelNotSupport | |||
| class PyTorchGraphParser(GraphParser): | |||
| """Define pytorch graph parser.""" | |||
| @classmethod | |||
| @ModelNotSupport.check_except("Error occurs in loading model, make sure model.pth correct.") | |||
| def parse(cls, model_path: str, **kwargs): | |||
| """ | |||
| Parser pytorch graph. | |||
| Args: | |||
| model_path (str): Model file path. | |||
| Returns: | |||
| object, torch model. | |||
| """ | |||
| import torch | |||
| if not os.path.exists(model_path): | |||
| error = FileNotFoundError("`model_path` must be assigned with " | |||
| "an existed file path.") | |||
| log.error(str(error)) | |||
| raise error | |||
| try: | |||
| if torch.cuda.is_available(): | |||
| model = torch.load(f=model_path) | |||
| else: | |||
| model = torch.load(f=model_path, map_location="cpu") | |||
| except ModuleNotFoundError: | |||
| error_msg = \ | |||
| "Cannot find model scripts in system path, " \ | |||
| "set `--project_path` to the path of model scripts folder correctly." | |||
| error = ModuleNotFoundError(error_msg) | |||
| log.error(str(error)) | |||
| raise error from None | |||
| return model | |||
| class TFGraphParser(GraphParser): | |||
| """Define TF graph parser.""" | |||
| @classmethod | |||
| @ModelNotSupport.check_except("Error occurs in loading model, make sure model.pb correct.") | |||
| @ModelNotSupport.check_except_tf("Error occurs in loading model, make sure model.pb correct.") | |||
| def parse(cls, model_path: str, **kwargs): | |||
| """ | |||
| Parse TF Computational Graph File (.pb) | |||
| @@ -74,8 +37,9 @@ class TFGraphParser(GraphParser): | |||
| object, ONNX model. | |||
| """ | |||
| from .onnx_utils import convert_tf_graph_to_onnx | |||
| onnx_utils = import_module( | |||
| "mindinsight.mindconverter.graph_based_converter.third_party_graph.onnx_utils") | |||
| convert_tf_graph_to_onnx = getattr(onnx_utils, "convert_tf_graph_to_onnx") | |||
| tf_input_nodes = kwargs.get('input_nodes') | |||
| tf_output_nodes = kwargs.get('output_nodes') | |||
| if not os.path.exists(model_path): | |||