diff --git a/mindinsight/mindconverter/common/exceptions.py b/mindinsight/mindconverter/common/exceptions.py index 9853bee9..15e42e42 100644 --- a/mindinsight/mindconverter/common/exceptions.py +++ b/mindinsight/mindconverter/common/exceptions.py @@ -15,6 +15,7 @@ """Define custom exception.""" import sys from enum import unique +from importlib import import_module from lib2to3.pgen2 import parse from treelib.exceptions import DuplicatedNodeIdError, MultipleRootError, NodeIDAbsentError @@ -212,7 +213,8 @@ class GraphInitFail(MindConverterException): ModuleNotFoundError, ModelNotSupport, TypeError, - ZeroDivisionError) + ZeroDivisionError, + RuntimeError) return except_source @classmethod @@ -294,7 +296,7 @@ class ModelNotSupport(MindConverterException): return except_source @classmethod - def check_except(cls, msg): + def check_except_pytorch(cls, msg): """Check except.""" def decorator(func): @@ -310,6 +312,29 @@ class ModelNotSupport(MindConverterException): return _f return decorator + @classmethod + def check_except_tf(cls, msg): + """Check except.""" + tf_error_module = import_module('tensorflow.python.framework.errors_impl') + tf_error = getattr(tf_error_module, 'OpError') + + cls._error = cls.raise_from() + (tf_error,) + + def decorator(func): + def _f(arch, model_path, **kwargs): + try: + output = func(arch, model_path=model_path, **kwargs) + except cls._error as e: + error = cls(msg=msg) + log.error(msg) + log.exception(e) + raise error from e + return output + + return _f + + return decorator + class NodeInputMissing(MindConverterException): """The node input missing error.""" diff --git a/mindinsight/mindconverter/graph_based_converter/framework.py b/mindinsight/mindconverter/graph_based_converter/framework.py index 05355d05..151b4e01 100644 --- a/mindinsight/mindconverter/graph_based_converter/framework.py +++ b/mindinsight/mindconverter/graph_based_converter/framework.py @@ -119,10 +119,10 @@ def _extract_model_name(model_path): return model_name +@torch_installation_validation @GraphInitFail.check_except_pytorch("Error occurred when init graph object.") @TreeCreateFail.check_except_pytorch("Error occurred when create hierarchical tree.") @SourceFilesSaveFail.check_except_pytorch("Error occurred when save source files.") -@torch_installation_validation def graph_based_converter_pytorch_to_ms(graph_path: str, sample_shape: tuple, output_folder: str, report_folder: str = None): """ @@ -153,10 +153,10 @@ def graph_based_converter_pytorch_to_ms(graph_path: str, sample_shape: tuple, report_folder=report_folder) +@tf_installation_validation @GraphInitFail.check_except_tf("Error occurred when init graph object.") @TreeCreateFail.check_except_tf("Error occurred when create hierarchical tree.") @SourceFilesSaveFail.check_except_tf("Error occurred when save source files.") -@tf_installation_validation def graph_based_converter_tf_to_ms(graph_path: str, sample_shape: tuple, input_nodes: str, output_nodes: str, output_folder: str, report_folder: str = None): diff --git a/mindinsight/mindconverter/graph_based_converter/third_party_graph/__init__.py b/mindinsight/mindconverter/graph_based_converter/third_party_graph/__init__.py index da62def2..9dd5ba22 100644 --- a/mindinsight/mindconverter/graph_based_converter/third_party_graph/__init__.py +++ b/mindinsight/mindconverter/graph_based_converter/third_party_graph/__init__.py @@ -14,13 +14,10 @@ # ============================================================================== """Graph associated definition module.""" -__all__ = ["GraphFactory", "PyTorchGraphNode"] +__all__ = ["GraphFactory"] +from importlib import import_module from .base import Graph -from .pytorch_graph import PyTorchGraph -from .pytorch_graph_node import PyTorchGraphNode -from .onnx_graph import OnnxGraph -from .onnx_graph_node import OnnxGraphNode class GraphFactory: @@ -43,7 +40,13 @@ class GraphFactory: Graph, graph instance. """ if all([input_nodes, output_nodes]): - return OnnxGraph.load(model_path=graph_path, input_nodes=input_nodes, - output_nodes=output_nodes, sample_shape=sample_shape) - - return PyTorchGraph.load(model_path=graph_path, sample_shape=sample_shape) + onnx_graph_module = import_module( + 'mindinsight.mindconverter.graph_based_converter.third_party_graph.onnx_graph') + onnx_graph = getattr(onnx_graph_module, 'OnnxGraph') + return onnx_graph.load(model_path=graph_path, input_nodes=input_nodes, + output_nodes=output_nodes, sample_shape=sample_shape) + + pytorch_graph_module = import_module( + 'mindinsight.mindconverter.graph_based_converter.third_party_graph.pytorch_graph') + pytorch_graph = getattr(pytorch_graph_module, 'PyTorchGraph') + return pytorch_graph.load(model_path=graph_path, sample_shape=sample_shape) diff --git a/mindinsight/mindconverter/graph_based_converter/third_party_graph/onnx_graph.py b/mindinsight/mindconverter/graph_based_converter/third_party_graph/onnx_graph.py index f4befbc8..1114b95c 100644 --- a/mindinsight/mindconverter/graph_based_converter/third_party_graph/onnx_graph.py +++ b/mindinsight/mindconverter/graph_based_converter/third_party_graph/onnx_graph.py @@ -19,7 +19,7 @@ from mindinsight.mindconverter.common.log import logger as log from .base import Graph from .input_node import InputNode from .onnx_graph_node import OnnxGraphNode -from .graph_parser import TFGraphParser +from .tf_graph_parser import TFGraphParser from .onnx_utils import OnnxDataLoader NONE_SCOPE_OP = { diff --git a/mindinsight/mindconverter/graph_based_converter/third_party_graph/pytorch_graph.py b/mindinsight/mindconverter/graph_based_converter/third_party_graph/pytorch_graph.py index ae75de5c..e47519a1 100644 --- a/mindinsight/mindconverter/graph_based_converter/third_party_graph/pytorch_graph.py +++ b/mindinsight/mindconverter/graph_based_converter/third_party_graph/pytorch_graph.py @@ -20,9 +20,9 @@ from mindinsight.mindconverter.common.log import logger as log from .base import Graph from .input_node import InputNode from .pytorch_graph_node import PyTorchGraphNode -from .graph_parser import PyTorchGraphParser +from .pytorch_graph_parser import PyTorchGraphParser -from ..constant import SEPARATOR_IN_SCOPE, LINK_IN_SCOPE +from ..constant import SEPARATOR_IN_SCOPE, LINK_IN_SCOPE, SEPARATOR_BTW_NAME_AND_ID from ..constant import LEFT_BUCKET, RIGHT_BUCKET NONE_SCOPE_OP = { @@ -46,7 +46,7 @@ def normalize_scope_name(node): """ global NONE_SCOPE_OP - name = node.scopeName().split(SEPARATOR_IN_SCOPE) + name = node.scopeName().replace(SEPARATOR_BTW_NAME_AND_ID, '').split(SEPARATOR_IN_SCOPE) scopes = [] for segment in name: segment = segment.split(LINK_IN_SCOPE)[0] diff --git a/mindinsight/mindconverter/graph_based_converter/third_party_graph/pytorch_graph_parser.py b/mindinsight/mindconverter/graph_based_converter/third_party_graph/pytorch_graph_parser.py new file mode 100644 index 00000000..a6c3a47e --- /dev/null +++ b/mindinsight/mindconverter/graph_based_converter/third_party_graph/pytorch_graph_parser.py @@ -0,0 +1,60 @@ +# Copyright 2020 Huawei Technologies Co., Ltd.All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Third party graph parser.""" +import os +from importlib import import_module + +from mindinsight.mindconverter.common.log import logger as log +from .base import GraphParser +from ...common.exceptions import ModelNotSupport + + +class PyTorchGraphParser(GraphParser): + """Define pytorch graph parser.""" + + @classmethod + @ModelNotSupport.check_except_pytorch("Error occurs in loading model, make sure model.pth correct.") + def parse(cls, model_path: str, **kwargs): + """ + Parser pytorch graph. + + Args: + model_path (str): Model file path. + + Returns: + object, torch model. + """ + torch = import_module("torch") + + if not os.path.exists(model_path): + error = FileNotFoundError("`model_path` must be assigned with " + "an existed file path.") + log.error(str(error)) + raise error + + try: + if torch.cuda.is_available(): + model = torch.load(f=model_path) + else: + model = torch.load(f=model_path, map_location="cpu") + except ModuleNotFoundError: + error_msg = \ + "Cannot find model scripts in system path, " \ + "set `--project_path` to the path of model scripts folder correctly." + error = ModuleNotFoundError(error_msg) + log.error(str(error)) + raise error from None + + return model diff --git a/mindinsight/mindconverter/graph_based_converter/third_party_graph/graph_parser.py b/mindinsight/mindconverter/graph_based_converter/third_party_graph/tf_graph_parser.py similarity index 61% rename from mindinsight/mindconverter/graph_based_converter/third_party_graph/graph_parser.py rename to mindinsight/mindconverter/graph_based_converter/third_party_graph/tf_graph_parser.py index 8c5f87ea..26935a75 100644 --- a/mindinsight/mindconverter/graph_based_converter/third_party_graph/graph_parser.py +++ b/mindinsight/mindconverter/graph_based_converter/third_party_graph/tf_graph_parser.py @@ -14,55 +14,18 @@ # ============================================================================== """Third party graph parser.""" import os +from importlib import import_module + from mindinsight.mindconverter.common.log import logger as log from .base import GraphParser from ...common.exceptions import ModelNotSupport -class PyTorchGraphParser(GraphParser): - """Define pytorch graph parser.""" - - @classmethod - @ModelNotSupport.check_except("Error occurs in loading model, make sure model.pth correct.") - def parse(cls, model_path: str, **kwargs): - """ - Parser pytorch graph. - - Args: - model_path (str): Model file path. - - Returns: - object, torch model. - """ - import torch - - if not os.path.exists(model_path): - error = FileNotFoundError("`model_path` must be assigned with " - "an existed file path.") - log.error(str(error)) - raise error - - try: - if torch.cuda.is_available(): - model = torch.load(f=model_path) - else: - model = torch.load(f=model_path, map_location="cpu") - except ModuleNotFoundError: - error_msg = \ - "Cannot find model scripts in system path, " \ - "set `--project_path` to the path of model scripts folder correctly." - error = ModuleNotFoundError(error_msg) - log.error(str(error)) - raise error from None - - return model - - class TFGraphParser(GraphParser): """Define TF graph parser.""" @classmethod - @ModelNotSupport.check_except("Error occurs in loading model, make sure model.pb correct.") + @ModelNotSupport.check_except_tf("Error occurs in loading model, make sure model.pb correct.") def parse(cls, model_path: str, **kwargs): """ Parse TF Computational Graph File (.pb) @@ -74,8 +37,9 @@ class TFGraphParser(GraphParser): object, ONNX model. """ - from .onnx_utils import convert_tf_graph_to_onnx - + onnx_utils = import_module( + "mindinsight.mindconverter.graph_based_converter.third_party_graph.onnx_utils") + convert_tf_graph_to_onnx = getattr(onnx_utils, "convert_tf_graph_to_onnx") tf_input_nodes = kwargs.get('input_nodes') tf_output_nodes = kwargs.get('output_nodes') if not os.path.exists(model_path):