Browse Source

!945 [ADD & BUGFIX] catch exception raise from tensorflow in MindConverter module.

From: @moran3
Reviewed-by: 
Signed-off-by:
tags/v1.1.0
mindspore-ci-bot Gitee 5 years ago
parent
commit
a7bc02d6d1
7 changed files with 111 additions and 59 deletions
  1. +27
    -2
      mindinsight/mindconverter/common/exceptions.py
  2. +2
    -2
      mindinsight/mindconverter/graph_based_converter/framework.py
  3. +12
    -9
      mindinsight/mindconverter/graph_based_converter/third_party_graph/__init__.py
  4. +1
    -1
      mindinsight/mindconverter/graph_based_converter/third_party_graph/onnx_graph.py
  5. +3
    -3
      mindinsight/mindconverter/graph_based_converter/third_party_graph/pytorch_graph.py
  6. +60
    -0
      mindinsight/mindconverter/graph_based_converter/third_party_graph/pytorch_graph_parser.py
  7. +6
    -42
      mindinsight/mindconverter/graph_based_converter/third_party_graph/tf_graph_parser.py

+ 27
- 2
mindinsight/mindconverter/common/exceptions.py View File

@@ -15,6 +15,7 @@
"""Define custom exception."""
import sys
from enum import unique
from importlib import import_module

from lib2to3.pgen2 import parse
from treelib.exceptions import DuplicatedNodeIdError, MultipleRootError, NodeIDAbsentError
@@ -212,7 +213,8 @@ class GraphInitFail(MindConverterException):
ModuleNotFoundError,
ModelNotSupport,
TypeError,
ZeroDivisionError)
ZeroDivisionError,
RuntimeError)
return except_source

@classmethod
@@ -294,7 +296,7 @@ class ModelNotSupport(MindConverterException):
return except_source

@classmethod
def check_except(cls, msg):
def check_except_pytorch(cls, msg):
"""Check except."""

def decorator(func):
@@ -310,6 +312,29 @@ class ModelNotSupport(MindConverterException):
return _f
return decorator

@classmethod
def check_except_tf(cls, msg):
"""Check except."""
tf_error_module = import_module('tensorflow.python.framework.errors_impl')
tf_error = getattr(tf_error_module, 'OpError')

cls._error = cls.raise_from() + (tf_error,)

def decorator(func):
def _f(arch, model_path, **kwargs):
try:
output = func(arch, model_path=model_path, **kwargs)
except cls._error as e:
error = cls(msg=msg)
log.error(msg)
log.exception(e)
raise error from e
return output

return _f

return decorator


class NodeInputMissing(MindConverterException):
"""The node input missing error."""


+ 2
- 2
mindinsight/mindconverter/graph_based_converter/framework.py View File

@@ -119,10 +119,10 @@ def _extract_model_name(model_path):
return model_name


@torch_installation_validation
@GraphInitFail.check_except_pytorch("Error occurred when init graph object.")
@TreeCreateFail.check_except_pytorch("Error occurred when create hierarchical tree.")
@SourceFilesSaveFail.check_except_pytorch("Error occurred when save source files.")
@torch_installation_validation
def graph_based_converter_pytorch_to_ms(graph_path: str, sample_shape: tuple,
output_folder: str, report_folder: str = None):
"""
@@ -153,10 +153,10 @@ def graph_based_converter_pytorch_to_ms(graph_path: str, sample_shape: tuple,
report_folder=report_folder)


@tf_installation_validation
@GraphInitFail.check_except_tf("Error occurred when init graph object.")
@TreeCreateFail.check_except_tf("Error occurred when create hierarchical tree.")
@SourceFilesSaveFail.check_except_tf("Error occurred when save source files.")
@tf_installation_validation
def graph_based_converter_tf_to_ms(graph_path: str, sample_shape: tuple,
input_nodes: str, output_nodes: str,
output_folder: str, report_folder: str = None):


+ 12
- 9
mindinsight/mindconverter/graph_based_converter/third_party_graph/__init__.py View File

@@ -14,13 +14,10 @@
# ==============================================================================
"""Graph associated definition module."""

__all__ = ["GraphFactory", "PyTorchGraphNode"]
__all__ = ["GraphFactory"]
from importlib import import_module

from .base import Graph
from .pytorch_graph import PyTorchGraph
from .pytorch_graph_node import PyTorchGraphNode
from .onnx_graph import OnnxGraph
from .onnx_graph_node import OnnxGraphNode


class GraphFactory:
@@ -43,7 +40,13 @@ class GraphFactory:
Graph, graph instance.
"""
if all([input_nodes, output_nodes]):
return OnnxGraph.load(model_path=graph_path, input_nodes=input_nodes,
output_nodes=output_nodes, sample_shape=sample_shape)

return PyTorchGraph.load(model_path=graph_path, sample_shape=sample_shape)
onnx_graph_module = import_module(
'mindinsight.mindconverter.graph_based_converter.third_party_graph.onnx_graph')
onnx_graph = getattr(onnx_graph_module, 'OnnxGraph')
return onnx_graph.load(model_path=graph_path, input_nodes=input_nodes,
output_nodes=output_nodes, sample_shape=sample_shape)

pytorch_graph_module = import_module(
'mindinsight.mindconverter.graph_based_converter.third_party_graph.pytorch_graph')
pytorch_graph = getattr(pytorch_graph_module, 'PyTorchGraph')
return pytorch_graph.load(model_path=graph_path, sample_shape=sample_shape)

+ 1
- 1
mindinsight/mindconverter/graph_based_converter/third_party_graph/onnx_graph.py View File

@@ -19,7 +19,7 @@ from mindinsight.mindconverter.common.log import logger as log
from .base import Graph
from .input_node import InputNode
from .onnx_graph_node import OnnxGraphNode
from .graph_parser import TFGraphParser
from .tf_graph_parser import TFGraphParser
from .onnx_utils import OnnxDataLoader

NONE_SCOPE_OP = {


+ 3
- 3
mindinsight/mindconverter/graph_based_converter/third_party_graph/pytorch_graph.py View File

@@ -20,9 +20,9 @@ from mindinsight.mindconverter.common.log import logger as log
from .base import Graph
from .input_node import InputNode
from .pytorch_graph_node import PyTorchGraphNode
from .graph_parser import PyTorchGraphParser
from .pytorch_graph_parser import PyTorchGraphParser

from ..constant import SEPARATOR_IN_SCOPE, LINK_IN_SCOPE
from ..constant import SEPARATOR_IN_SCOPE, LINK_IN_SCOPE, SEPARATOR_BTW_NAME_AND_ID
from ..constant import LEFT_BUCKET, RIGHT_BUCKET

NONE_SCOPE_OP = {
@@ -46,7 +46,7 @@ def normalize_scope_name(node):
"""
global NONE_SCOPE_OP

name = node.scopeName().split(SEPARATOR_IN_SCOPE)
name = node.scopeName().replace(SEPARATOR_BTW_NAME_AND_ID, '').split(SEPARATOR_IN_SCOPE)
scopes = []
for segment in name:
segment = segment.split(LINK_IN_SCOPE)[0]


+ 60
- 0
mindinsight/mindconverter/graph_based_converter/third_party_graph/pytorch_graph_parser.py View File

@@ -0,0 +1,60 @@
# Copyright 2020 Huawei Technologies Co., Ltd.All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Third party graph parser."""
import os
from importlib import import_module

from mindinsight.mindconverter.common.log import logger as log
from .base import GraphParser
from ...common.exceptions import ModelNotSupport


class PyTorchGraphParser(GraphParser):
"""Define pytorch graph parser."""

@classmethod
@ModelNotSupport.check_except_pytorch("Error occurs in loading model, make sure model.pth correct.")
def parse(cls, model_path: str, **kwargs):
"""
Parser pytorch graph.

Args:
model_path (str): Model file path.

Returns:
object, torch model.
"""
torch = import_module("torch")

if not os.path.exists(model_path):
error = FileNotFoundError("`model_path` must be assigned with "
"an existed file path.")
log.error(str(error))
raise error

try:
if torch.cuda.is_available():
model = torch.load(f=model_path)
else:
model = torch.load(f=model_path, map_location="cpu")
except ModuleNotFoundError:
error_msg = \
"Cannot find model scripts in system path, " \
"set `--project_path` to the path of model scripts folder correctly."
error = ModuleNotFoundError(error_msg)
log.error(str(error))
raise error from None

return model

mindinsight/mindconverter/graph_based_converter/third_party_graph/graph_parser.py → mindinsight/mindconverter/graph_based_converter/third_party_graph/tf_graph_parser.py View File

@@ -14,55 +14,18 @@
# ==============================================================================
"""Third party graph parser."""
import os
from importlib import import_module

from mindinsight.mindconverter.common.log import logger as log
from .base import GraphParser
from ...common.exceptions import ModelNotSupport


class PyTorchGraphParser(GraphParser):
"""Define pytorch graph parser."""

@classmethod
@ModelNotSupport.check_except("Error occurs in loading model, make sure model.pth correct.")
def parse(cls, model_path: str, **kwargs):
"""
Parser pytorch graph.

Args:
model_path (str): Model file path.

Returns:
object, torch model.
"""
import torch

if not os.path.exists(model_path):
error = FileNotFoundError("`model_path` must be assigned with "
"an existed file path.")
log.error(str(error))
raise error

try:
if torch.cuda.is_available():
model = torch.load(f=model_path)
else:
model = torch.load(f=model_path, map_location="cpu")
except ModuleNotFoundError:
error_msg = \
"Cannot find model scripts in system path, " \
"set `--project_path` to the path of model scripts folder correctly."
error = ModuleNotFoundError(error_msg)
log.error(str(error))
raise error from None

return model


class TFGraphParser(GraphParser):
"""Define TF graph parser."""

@classmethod
@ModelNotSupport.check_except("Error occurs in loading model, make sure model.pb correct.")
@ModelNotSupport.check_except_tf("Error occurs in loading model, make sure model.pb correct.")
def parse(cls, model_path: str, **kwargs):
"""
Parse TF Computational Graph File (.pb)
@@ -74,8 +37,9 @@ class TFGraphParser(GraphParser):
object, ONNX model.
"""

from .onnx_utils import convert_tf_graph_to_onnx

onnx_utils = import_module(
"mindinsight.mindconverter.graph_based_converter.third_party_graph.onnx_utils")
convert_tf_graph_to_onnx = getattr(onnx_utils, "convert_tf_graph_to_onnx")
tf_input_nodes = kwargs.get('input_nodes')
tf_output_nodes = kwargs.get('output_nodes')
if not os.path.exists(model_path):

Loading…
Cancel
Save