
!945 [ADD & BUGFIX] Catch exceptions raised from TensorFlow in the MindConverter module.

From: @moran3
Reviewed-by: 
Signed-off-by:
tags/v1.1.0
mindspore-ci-bot committed 5 years ago
commit a7bc02d6d1
7 changed files with 111 additions and 59 deletions

  1. mindinsight/mindconverter/common/exceptions.py (+27 -2)
  2. mindinsight/mindconverter/graph_based_converter/framework.py (+2 -2)
  3. mindinsight/mindconverter/graph_based_converter/third_party_graph/__init__.py (+12 -9)
  4. mindinsight/mindconverter/graph_based_converter/third_party_graph/onnx_graph.py (+1 -1)
  5. mindinsight/mindconverter/graph_based_converter/third_party_graph/pytorch_graph.py (+3 -3)
  6. mindinsight/mindconverter/graph_based_converter/third_party_graph/pytorch_graph_parser.py (+60 -0)
  7. mindinsight/mindconverter/graph_based_converter/third_party_graph/tf_graph_parser.py (+6 -42)
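Overview note: the recurring pattern in this change is to replace top-level framework imports with importlib.import_module/getattr lookups performed where they are needed, so TensorFlow (or PyTorch) is only touched when the corresponding conversion path actually runs, and TensorFlow errors can then be caught and rewrapped. A minimal sketch of that lookup pattern follows, assuming only that the named module is importable in the environment; the helper name and the RuntimeError message are illustrative, not MindConverter code:

from importlib import import_module


def resolve_onnx_graph():
    """Resolve OnnxGraph only when an ONNX/TF model is actually converted.

    Illustrative helper mirroring GraphFactory.init in this diff; the error
    wording below is not MindConverter's.
    """
    try:
        # Importing here, not at package import time, keeps optional
        # dependencies (tensorflow, tf2onnx, onnx) out of the default path.
        module = import_module(
            "mindinsight.mindconverter.graph_based_converter.third_party_graph.onnx_graph")
    except ImportError as err:
        raise RuntimeError("ONNX graph support is unavailable in this environment.") from err
    return getattr(module, "OnnxGraph")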

mindinsight/mindconverter/common/exceptions.py (+27 -2)

@@ -15,6 +15,7 @@
 """Define custom exception."""
 import sys
 from enum import unique
+from importlib import import_module

 from lib2to3.pgen2 import parse
 from treelib.exceptions import DuplicatedNodeIdError, MultipleRootError, NodeIDAbsentError

@@ -212,7 +213,8 @@ class GraphInitFail(MindConverterException):
                          ModuleNotFoundError,
                          ModelNotSupport,
                          TypeError,
-                         ZeroDivisionError)
+                         ZeroDivisionError,
+                         RuntimeError)
         return except_source

     @classmethod

@@ -294,7 +296,7 @@ class ModelNotSupport(MindConverterException):
         return except_source

     @classmethod
-    def check_except(cls, msg):
+    def check_except_pytorch(cls, msg):
         """Check except."""

         def decorator(func):

@@ -310,6 +312,29 @@ class ModelNotSupport(MindConverterException):
             return _f
         return decorator

+    @classmethod
+    def check_except_tf(cls, msg):
+        """Check except."""
+        tf_error_module = import_module('tensorflow.python.framework.errors_impl')
+        tf_error = getattr(tf_error_module, 'OpError')
+
+        cls._error = cls.raise_from() + (tf_error,)
+
+        def decorator(func):
+            def _f(arch, model_path, **kwargs):
+                try:
+                    output = func(arch, model_path=model_path, **kwargs)
+                except cls._error as e:
+                    error = cls(msg=msg)
+                    log.error(msg)
+                    log.exception(e)
+                    raise error from e
+                return output
+
+            return _f
+
+        return decorator
+

 class NodeInputMissing(MindConverterException):
     """The node input missing error."""


mindinsight/mindconverter/graph_based_converter/framework.py (+2 -2)

@@ -119,10 +119,10 @@ def _extract_model_name(model_path):
     return model_name


+@torch_installation_validation
 @GraphInitFail.check_except_pytorch("Error occurred when init graph object.")
 @TreeCreateFail.check_except_pytorch("Error occurred when create hierarchical tree.")
 @SourceFilesSaveFail.check_except_pytorch("Error occurred when save source files.")
-@torch_installation_validation
 def graph_based_converter_pytorch_to_ms(graph_path: str, sample_shape: tuple,
                                         output_folder: str, report_folder: str = None):
     """

@@ -153,10 +153,10 @@ def graph_based_converter_pytorch_to_ms(graph_path: str, sample_shape: tuple,
                                    report_folder=report_folder)


+@tf_installation_validation
 @GraphInitFail.check_except_tf("Error occurred when init graph object.")
 @TreeCreateFail.check_except_tf("Error occurred when create hierarchical tree.")
 @SourceFilesSaveFail.check_except_tf("Error occurred when save source files.")
-@tf_installation_validation
 def graph_based_converter_tf_to_ms(graph_path: str, sample_shape: tuple,
                                    input_nodes: str, output_nodes: str,
                                    output_folder: str, report_folder: str = None):


mindinsight/mindconverter/graph_based_converter/third_party_graph/__init__.py (+12 -9)

@@ -14,13 +14,10 @@
 # ==============================================================================
 """Graph associated definition module."""

-__all__ = ["GraphFactory", "PyTorchGraphNode"]
+__all__ = ["GraphFactory"]
+from importlib import import_module

 from .base import Graph
-from .pytorch_graph import PyTorchGraph
-from .pytorch_graph_node import PyTorchGraphNode
-from .onnx_graph import OnnxGraph
-from .onnx_graph_node import OnnxGraphNode


 class GraphFactory:

@@ -43,7 +40,13 @@ class GraphFactory:
             Graph, graph instance.
         """
         if all([input_nodes, output_nodes]):
-            return OnnxGraph.load(model_path=graph_path, input_nodes=input_nodes,
-                                  output_nodes=output_nodes, sample_shape=sample_shape)
+            onnx_graph_module = import_module(
+                'mindinsight.mindconverter.graph_based_converter.third_party_graph.onnx_graph')
+            onnx_graph = getattr(onnx_graph_module, 'OnnxGraph')
+            return onnx_graph.load(model_path=graph_path, input_nodes=input_nodes,
+                                   output_nodes=output_nodes, sample_shape=sample_shape)

-        return PyTorchGraph.load(model_path=graph_path, sample_shape=sample_shape)
+        pytorch_graph_module = import_module(
+            'mindinsight.mindconverter.graph_based_converter.third_party_graph.pytorch_graph')
+        pytorch_graph = getattr(pytorch_graph_module, 'PyTorchGraph')
+        return pytorch_graph.load(model_path=graph_path, sample_shape=sample_shape)

mindinsight/mindconverter/graph_based_converter/third_party_graph/onnx_graph.py (+1 -1)

@@ -19,7 +19,7 @@ from mindinsight.mindconverter.common.log import logger as log
 from .base import Graph
 from .input_node import InputNode
 from .onnx_graph_node import OnnxGraphNode
-from .graph_parser import TFGraphParser
+from .tf_graph_parser import TFGraphParser
 from .onnx_utils import OnnxDataLoader

 NONE_SCOPE_OP = {


mindinsight/mindconverter/graph_based_converter/third_party_graph/pytorch_graph.py (+3 -3)

@@ -20,9 +20,9 @@ from mindinsight.mindconverter.common.log import logger as log
 from .base import Graph
 from .input_node import InputNode
 from .pytorch_graph_node import PyTorchGraphNode
-from .graph_parser import PyTorchGraphParser
+from .pytorch_graph_parser import PyTorchGraphParser

-from ..constant import SEPARATOR_IN_SCOPE, LINK_IN_SCOPE
+from ..constant import SEPARATOR_IN_SCOPE, LINK_IN_SCOPE, SEPARATOR_BTW_NAME_AND_ID
 from ..constant import LEFT_BUCKET, RIGHT_BUCKET

 NONE_SCOPE_OP = {

@@ -46,7 +46,7 @@ def normalize_scope_name(node):
     """
     global NONE_SCOPE_OP

-    name = node.scopeName().split(SEPARATOR_IN_SCOPE)
+    name = node.scopeName().replace(SEPARATOR_BTW_NAME_AND_ID, '').split(SEPARATOR_IN_SCOPE)
     scopes = []
     for segment in name:
         segment = segment.split(LINK_IN_SCOPE)[0]


mindinsight/mindconverter/graph_based_converter/third_party_graph/pytorch_graph_parser.py (+60 -0, new file)

@@ -0,0 +1,60 @@
+# Copyright 2020 Huawei Technologies Co., Ltd.All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Third party graph parser."""
+import os
+from importlib import import_module
+
+from mindinsight.mindconverter.common.log import logger as log
+from .base import GraphParser
+from ...common.exceptions import ModelNotSupport
+
+
+class PyTorchGraphParser(GraphParser):
+    """Define pytorch graph parser."""
+
+    @classmethod
+    @ModelNotSupport.check_except_pytorch("Error occurs in loading model, make sure model.pth correct.")
+    def parse(cls, model_path: str, **kwargs):
+        """
+        Parser pytorch graph.
+
+        Args:
+            model_path (str): Model file path.
+
+        Returns:
+            object, torch model.
+        """
+        torch = import_module("torch")
+
+        if not os.path.exists(model_path):
+            error = FileNotFoundError("`model_path` must be assigned with "
+                                      "an existed file path.")
+            log.error(str(error))
+            raise error
+
+        try:
+            if torch.cuda.is_available():
+                model = torch.load(f=model_path)
+            else:
+                model = torch.load(f=model_path, map_location="cpu")
+        except ModuleNotFoundError:
+            error_msg = \
+                "Cannot find model scripts in system path, " \
+                "set `--project_path` to the path of model scripts folder correctly."
+            error = ModuleNotFoundError(error_msg)
+            log.error(str(error))
+            raise error from None
+
+        return model

mindinsight/mindconverter/graph_based_converter/third_party_graph/graph_parser.py → mindinsight/mindconverter/graph_based_converter/third_party_graph/tf_graph_parser.py (+6 -42, renamed)

@@ -14,55 +14,18 @@
 # ==============================================================================
 """Third party graph parser."""
 import os
+from importlib import import_module

 from mindinsight.mindconverter.common.log import logger as log
 from .base import GraphParser
 from ...common.exceptions import ModelNotSupport


-class PyTorchGraphParser(GraphParser):
-    """Define pytorch graph parser."""
-
-    @classmethod
-    @ModelNotSupport.check_except("Error occurs in loading model, make sure model.pth correct.")
-    def parse(cls, model_path: str, **kwargs):
-        """
-        Parser pytorch graph.
-
-        Args:
-            model_path (str): Model file path.
-
-        Returns:
-            object, torch model.
-        """
-        import torch
-
-        if not os.path.exists(model_path):
-            error = FileNotFoundError("`model_path` must be assigned with "
-                                      "an existed file path.")
-            log.error(str(error))
-            raise error
-
-        try:
-            if torch.cuda.is_available():
-                model = torch.load(f=model_path)
-            else:
-                model = torch.load(f=model_path, map_location="cpu")
-        except ModuleNotFoundError:
-            error_msg = \
-                "Cannot find model scripts in system path, " \
-                "set `--project_path` to the path of model scripts folder correctly."
-            error = ModuleNotFoundError(error_msg)
-            log.error(str(error))
-            raise error from None
-
-        return model
-
-
 class TFGraphParser(GraphParser):
     """Define TF graph parser."""

     @classmethod
-    @ModelNotSupport.check_except("Error occurs in loading model, make sure model.pb correct.")
+    @ModelNotSupport.check_except_tf("Error occurs in loading model, make sure model.pb correct.")
     def parse(cls, model_path: str, **kwargs):
         """
         Parse TF Computational Graph File (.pb)

@@ -74,8 +37,9 @@ class TFGraphParser(GraphParser):
             object, ONNX model.
         """

-        from .onnx_utils import convert_tf_graph_to_onnx
-
+        onnx_utils = import_module(
+            "mindinsight.mindconverter.graph_based_converter.third_party_graph.onnx_utils")
+        convert_tf_graph_to_onnx = getattr(onnx_utils, "convert_tf_graph_to_onnx")
         tf_input_nodes = kwargs.get('input_nodes')
         tf_output_nodes = kwargs.get('output_nodes')
         if not os.path.exists(model_path):
