Browse Source

Add module rename function and add exception define.

tags/v1.1.0
liuchongming 5 years ago
parent
commit
d58a5bcbbb
13 changed files with 249 additions and 321 deletions
  1. +54
    -51
      mindinsight/mindconverter/cli.py
  2. +114
    -185
      mindinsight/mindconverter/common/exceptions.py
  3. +3
    -1
      mindinsight/mindconverter/graph_based_converter/common/utils.py
  4. +25
    -14
      mindinsight/mindconverter/graph_based_converter/framework.py
  5. +1
    -0
      mindinsight/mindconverter/graph_based_converter/generator/args_translator.py
  6. +2
    -17
      mindinsight/mindconverter/graph_based_converter/generator/module_struct.py
  7. +1
    -0
      mindinsight/mindconverter/graph_based_converter/sub_graph_searcher/pattern.py
  8. +15
    -6
      mindinsight/mindconverter/graph_based_converter/sub_graph_searcher/search_path.py
  9. +22
    -0
      mindinsight/mindconverter/graph_based_converter/sub_graph_searcher/searcher.py
  10. +2
    -2
      mindinsight/mindconverter/graph_based_converter/third_party_graph/onnx_utils.py
  11. +6
    -6
      mindinsight/mindconverter/graph_based_converter/third_party_graph/pytorch_graph_parser.py
  12. +4
    -4
      mindinsight/mindconverter/graph_based_converter/third_party_graph/tf_graph_parser.py
  13. +0
    -35
      tests/st/func/mindconverter/test_converter.py

+ 54
- 51
mindinsight/mindconverter/cli.py View File

@@ -170,6 +170,9 @@ class InFileAction(argparse.Action):
if not os.path.isfile(outfile_dir):
parser_in.error(f'{option_string} {outfile_dir} is not a file')

if not os.path.basename(outfile_dir).endswith("py"):
parser_in.error(f'{option_string} {outfile_dir} is not a valid python file')

setattr(namespace, self.dest, outfile_dir)


@@ -282,32 +285,32 @@ class NodeAction(argparse.Action):


parser = argparse.ArgumentParser(
prog='mindconverter',
description='MindConverter CLI entry point (version: {})'.format(mindinsight.__version__),
allow_abbrev=False)
prog='mindconverter',
description='MindConverter CLI entry point (version: {})'.format(mindinsight.__version__),
allow_abbrev=False)

parser.add_argument(
'--version',
action='version',
version='%(prog)s ({})'.format(mindinsight.__version__))
'--version',
action='version',
version='%(prog)s ({})'.format(mindinsight.__version__))

parser.add_argument(
'--in_file',
type=str,
action=InFileAction,
required=False,
default=None,
help="""
'--in_file',
type=str,
action=InFileAction,
required=False,
default=None,
help="""
Specify path for script file to use AST schema to
do script conversation.
""")

parser.add_argument(
'--model_file',
type=str,
action=ModelFileAction,
required=False,
help="""
'--model_file',
type=str,
action=ModelFileAction,
required=False,
help="""
PyTorch .pth or Tensorflow .pb model file path to use graph
based schema to do script generation. When
`--in_file` and `--model_file` are both provided,
@@ -315,12 +318,12 @@ parser.add_argument(
""")

parser.add_argument(
'--shape',
type=str,
action=ShapeAction,
default=None,
required=False,
help="""
'--shape',
type=str,
action=ShapeAction,
default=None,
required=False,
help="""
Optional, expected input tensor shape of
`--model_file`. It's required when use graph based
schema.
@@ -328,55 +331,55 @@ parser.add_argument(
""")

parser.add_argument(
'--input_nodes',
type=str,
action=NodeAction,
default=None,
required=False,
help="""
'--input_nodes',
type=str,
action=NodeAction,
default=None,
required=False,
help="""
Optional, input node(s) name of `--model_file`. It's required when use Tensorflow model.
Usage: --input_nodes input_1:0,input_2:0
""")

parser.add_argument(
'--output_nodes',
type=str,
action=NodeAction,
default=None,
required=False,
help="""
'--output_nodes',
type=str,
action=NodeAction,
default=None,
required=False,
help="""
Optional, output node(s) name of `--model_file`. It's required when use Tensorflow model.
Usage: --output_nodes output_1:0,output_2:0
""")

parser.add_argument(
'--output',
type=str,
action=OutputDirAction,
default=os.path.join(os.getcwd(), 'output'),
help="""
'--output',
type=str,
action=OutputDirAction,
default=os.path.join(os.getcwd(), 'output'),
help="""
Optional, specify path for converted script file
directory. Default output directory is `output` folder
in the current working directory.
""")

parser.add_argument(
'--report',
type=str,
action=LogFileAction,
default=None,
help="""
'--report',
type=str,
action=LogFileAction,
default=None,
help="""
Optional, specify report directory. Default is
converted script directory.
""")

parser.add_argument(
'--project_path',
type=str,
action=ProjectPathAction,
required=False,
default=None,
help="""
'--project_path',
type=str,
action=ProjectPathAction,
required=False,
default=None,
help="""
Optional, PyTorch scripts project path. If PyTorch
project is not in PYTHONPATH, please assign
`--project_path` when use graph based schema.


+ 114
- 185
mindinsight/mindconverter/common/exceptions.py View File

@@ -13,6 +13,7 @@
# limitations under the License.
# ============================================================================
"""Define custom exception."""
import abc
import sys
from enum import unique
from importlib import import_module
@@ -23,7 +24,7 @@ from treelib.exceptions import DuplicatedNodeIdError, MultipleRootError, NodeIDA
from mindinsight.mindconverter.common.log import logger as log, logger_console as log_console

from mindinsight.utils.constant import ScriptConverterErrors
from mindinsight.utils.exceptions import MindInsightException, ParamMissError
from mindinsight.utils.exceptions import MindInsightException


@unique
@@ -40,12 +41,16 @@ class ConverterErrors(ScriptConverterErrors):
SCRIPT_GENERATE_FAIL = 9
REPORT_GENERATE_FAIL = 10
NODE_CONVERSION_ERROR = 11
INPUT_SHAPE_ERROR = 12
TF_RUNTIME_ERROR = 13

BASE_CONVERTER_FAIL = 000
GRAPH_INIT_FAIL = 100
TREE_CREATE_FAIL = 200
SOURCE_FILES_SAVE_FAIL = 300
GENERATOR_FAIL = 400
SUB_GRAPH_SEARCHING_FAIL = 500
MODEL_LOADING_FAIL = 600


class ScriptNotSupport(MindInsightException):
@@ -80,7 +85,6 @@ class MindConverterException(Exception):

def __init__(self, **kwargs):
"""Initialization of MindInsightException."""

error = kwargs.get('error', None)
user_msg = kwargs.get('user_msg', '')
debug_msg = kwargs.get('debug_msg', '')
@@ -97,6 +101,9 @@ class MindConverterException(Exception):
def __str__(self):
return '[{}] code: {}, msg: {}'.format(self.__class__.__name__, self.error_code(), self.user_msg)

def __repr__(self):
return self.__str__()

def error_code(self):
""""
Calculate error code.
@@ -109,54 +116,59 @@ class MindConverterException(Exception):
Returns:
str, Hex string representing the composed MindConverter error code.
"""

num = 0xFFFF & self.error.value
error_code = ''.join((f'{self.cls_code}'.zfill(3), hex(num)[2:].zfill(4).upper()))
return error_code

@staticmethod
def raise_from():
@classmethod
@abc.abstractmethod
def raise_from(cls):
"""Raise from below exceptions."""
return None

@classmethod
def check_except_with_print_pytorch(cls, msg):
"""Check except in pytorch."""
def uniform_catcher(cls, msg):
"""Uniform exception catcher."""

def decorator(func):
def _f(graph_path, sample_shape, output_folder, report_folder):
def _f(*args, **kwargs):
try:
func(graph_path=graph_path, sample_shape=sample_shape,
output_folder=output_folder, report_folder=report_folder)
res = func(*args, **kwargs)
except cls.raise_from() as e:
error = cls(msg=msg)
detail_info = f"Error detail: {str(e)}"
log_console.error(str(error))
log_console.error(detail_info)
log.exception(e)
sys.exit(-1)
sys.exit(0)
except ModuleNotFoundError as e:
detail_info = f"Error detail: Required package not found, please check the runtime environment."
log_console.error(str(e))
log_console.error(detail_info)
log.exception(e)
sys.exit(0)
return res

return _f

return decorator

@classmethod
def check_except_with_print_tf(cls, msg):
"""Check except in tf."""
def check_except(cls, msg):
"""Check except."""

def decorator(func):
def _f(graph_path, sample_shape,
input_nodes, output_nodes,
output_folder, report_folder):
def _f(*args, **kwargs):
try:
func(graph_path=graph_path, sample_shape=sample_shape,
input_nodes=input_nodes, output_nodes=output_nodes,
output_folder=output_folder, report_folder=report_folder)
output = func(*args, **kwargs)
except cls.raise_from() as e:
error = cls(msg=msg)
detail_info = f"Error detail: {str(e)}"
log_console.error(str(error))
log_console.error(detail_info)
log.error(msg)
log.exception(e)
raise cls(msg=msg)
except Exception as e:
log.error(msg)
log.exception(e)
sys.exit(-1)
raise e
return output

return _f

@@ -170,31 +182,12 @@ class BaseConverterFail(MindConverterException):
super(BaseConverterFail, self).__init__(error=ConverterErrors.BASE_CONVERTER_FAIL,
user_msg=msg)

@staticmethod
def raise_from():
@classmethod
def raise_from(cls):
"""Raise from exceptions below."""
except_source = (UnknownModel,
ParamMissError)
except_source = Exception, cls
return except_source

@classmethod
def check_except(cls, msg):
"""Check except."""

def decorator(func):
def _f(file_config):
try:
func(file_config=file_config)
except cls.raise_from() as e:
error = cls(msg=msg)
detail_info = f"Error detail: {str(e)}"
log_console.error(str(error))
log_console.error(detail_info)
log.exception(e)
sys.exit(-1)
return _f
return decorator


class UnknownModel(MindConverterException):
"""The unknown model error."""
@@ -203,6 +196,10 @@ class UnknownModel(MindConverterException):
super(UnknownModel, self).__init__(error=ConverterErrors.UNKNOWN_MODEL,
user_msg=msg)

@classmethod
def raise_from(cls):
return cls


class GraphInitFail(MindConverterException):
"""The graph init fail error."""
@@ -211,27 +208,19 @@ class GraphInitFail(MindConverterException):
super(GraphInitFail, self).__init__(error=ConverterErrors.GRAPH_INIT_FAIL,
user_msg=kwargs.get('msg', ''))

@staticmethod
def raise_from():
@classmethod
def raise_from(cls):
"""Raise from exceptions below."""
except_source = (FileNotFoundError,
ModuleNotFoundError,
ModelNotSupport,
SubGraphSearchingFail,
TypeError,
ZeroDivisionError,
RuntimeError)
RuntimeError,
cls)
return except_source

@classmethod
def check_except_pytorch(cls, msg):
"""Check except for pytorch."""
return super().check_except_with_print_pytorch(msg)

@classmethod
def check_except_tf(cls, msg):
"""Check except for tf."""
return super().check_except_with_print_tf(msg)


class TreeCreateFail(MindConverterException):
"""The tree create fail."""
@@ -240,23 +229,13 @@ class TreeCreateFail(MindConverterException):
super(TreeCreateFail, self).__init__(error=ConverterErrors.TREE_CREATE_FAIL,
user_msg=msg)

@staticmethod
def raise_from():
@classmethod
def raise_from(cls):
"""Raise from exceptions below."""
except_source = (NodeInputMissing,
TreeNodeInsertFail)
TreeNodeInsertFail, cls)
return except_source

@classmethod
def check_except_pytorch(cls, msg):
"""Check except."""
return super().check_except_with_print_pytorch(msg)

@classmethod
def check_except_tf(cls, msg):
"""Check except for tf."""
return super().check_except_with_print_tf(msg)


class SourceFilesSaveFail(MindConverterException):
"""The source files save fail error."""
@@ -265,25 +244,15 @@ class SourceFilesSaveFail(MindConverterException):
super(SourceFilesSaveFail, self).__init__(error=ConverterErrors.SOURCE_FILES_SAVE_FAIL,
user_msg=msg)

@staticmethod
def raise_from():
@classmethod
def raise_from(cls):
"""Raise from exceptions below."""
except_source = (NodeInputTypeNotSupport,
ScriptGenerateFail,
ReportGenerateFail,
IOError)
IOError, cls)
return except_source

@classmethod
def check_except_pytorch(cls, msg):
"""Check except."""
return super().check_except_with_print_pytorch(msg)

@classmethod
def check_except_tf(cls, msg):
"""Check except for tf."""
return super().check_except_with_print_tf(msg)


class ModelNotSupport(MindConverterException):
"""The model not support error."""
@@ -293,55 +262,32 @@ class ModelNotSupport(MindConverterException):
user_msg=msg,
cls_code=ConverterErrors.GRAPH_INIT_FAIL.value)

@staticmethod
def raise_from():
@classmethod
def raise_from(cls):
"""Raise from exceptions below."""
except_source = (RuntimeError,
ModuleNotFoundError,
ValueError,
AssertionError,
TypeError,
OSError,
ZeroDivisionError)
ZeroDivisionError, cls)
return except_source

@classmethod
def check_except_pytorch(cls, msg):
"""Check except."""

def decorator(func):
def _f(arch, model_path, **kwargs):
try:
output = func(arch, model_path=model_path, **kwargs)
except cls.raise_from() as e:
error = cls(msg=msg)
log.error(msg)
log.exception(e)
raise error from e
return output
return _f
return decorator
class TfRuntimeError(MindConverterException):
"""Catch tf runtime error."""

def __init__(self, msg):
super(TfRuntimeError, self).__init__(error=ConverterErrors.TF_RUNTIME_ERROR,
user_msg=msg,
cls_code=ConverterErrors.GRAPH_INIT_FAIL.value)

@classmethod
def check_except_tf(cls, msg):
"""Check except."""
def raise_from(cls):
tf_error_module = import_module('tensorflow.python.framework.errors_impl')
tf_error = getattr(tf_error_module, 'OpError')

cls._error = cls.raise_from() + (tf_error,)

def decorator(func):
def _f(arch, model_path, **kwargs):
try:
output = func(arch, model_path=model_path, **kwargs)
except cls._error as e:
error = cls(msg=msg)
log.error(msg)
log.exception(e)
raise error from e
return output

return _f

return decorator
return tf_error, ValueError, RuntimeError, cls


class NodeInputMissing(MindConverterException):
@@ -352,6 +298,10 @@ class NodeInputMissing(MindConverterException):
user_msg=msg,
cls_code=ConverterErrors.TREE_CREATE_FAIL.value)

@classmethod
def raise_from(cls):
return ValueError, IndexError, KeyError, AttributeError, cls


class TreeNodeInsertFail(MindConverterException):
"""The tree node create fail error."""
@@ -361,32 +311,15 @@ class TreeNodeInsertFail(MindConverterException):
user_msg=msg,
cls_code=ConverterErrors.TREE_CREATE_FAIL.value)

@staticmethod
def raise_from():
@classmethod
def raise_from(cls):
"""Raise from exceptions below."""
except_source = (OSError,
DuplicatedNodeIdError,
MultipleRootError,
NodeIDAbsentError)
NodeIDAbsentError, cls)
return except_source

@classmethod
def check_except(cls, msg):
"""Check except."""

def decorator(func):
def _f(arch, graph):
try:
output = func(arch, graph=graph)
except cls.raise_from() as e:
error = cls(msg=msg)
log.error(msg)
log.exception(e)
raise error from e
return output
return _f
return decorator


class NodeInputTypeNotSupport(MindConverterException):
"""The node input type NOT support error."""
@@ -396,6 +329,10 @@ class NodeInputTypeNotSupport(MindConverterException):
user_msg=msg,
cls_code=ConverterErrors.SOURCE_FILES_SAVE_FAIL.value)

@classmethod
def raise_from(cls):
return ValueError, TypeError, IndexError, cls


class ScriptGenerateFail(MindConverterException):
"""The script generate fail error."""
@@ -405,31 +342,14 @@ class ScriptGenerateFail(MindConverterException):
user_msg=msg,
cls_code=ConverterErrors.SOURCE_FILES_SAVE_FAIL.value)

@staticmethod
def raise_from():
@classmethod
def raise_from(cls):
"""Raise from exceptions below."""
except_source = (RuntimeError,
parse.ParseError,
AttributeError)
AttributeError, cls)
return except_source

@classmethod
def check_except(cls, msg):
"""Check except."""

def decorator(func):
def _f(arch, mapper):
try:
output = func(arch, mapper=mapper)
except cls.raise_from() as e:
error = cls(msg=msg)
log.error(msg)
log.exception(e)
raise error from e
return output
return _f
return decorator


class ReportGenerateFail(MindConverterException):
"""The report generate fail error."""
@@ -439,28 +359,24 @@ class ReportGenerateFail(MindConverterException):
user_msg=msg,
cls_code=ConverterErrors.SOURCE_FILES_SAVE_FAIL.value)

@staticmethod
def raise_from():
@classmethod
def raise_from(cls):
"""Raise from exceptions below."""
except_source = ZeroDivisionError
return except_source
return ZeroDivisionError, cls

@classmethod
def check_except(cls, msg):
"""Check except."""

def decorator(func):
def _f(arch, mapper):
try:
output = func(arch, mapper=mapper)
except cls.raise_from() as e:
error = cls(msg=msg)
log.error(msg)
log.exception(e)
raise error from e
return output
return _f
return decorator
class SubGraphSearchingFail(MindConverterException):
"""Sub-graph searching exception."""
def __init__(self, msg):
super(SubGraphSearchingFail, self).__init__(error=ConverterErrors.MODEL_NOT_SUPPORT,
cls_code=ConverterErrors.SUB_GRAPH_SEARCHING_FAIL.value,
user_msg=msg)
@classmethod
def raise_from(cls):
"""Define exception in sub-graph searching module."""
return IndexError, KeyError, ValueError, AttributeError, ZeroDivisionError, cls


class GeneratorFail(MindConverterException):
@@ -470,10 +386,23 @@ class GeneratorFail(MindConverterException):
super(GeneratorFail, self).__init__(error=ConverterErrors.NODE_CONVERSION_ERROR,
user_msg=msg,
cls_code=ConverterErrors.GENERATOR_FAIL.value)

@classmethod
def raise_from(cls):
"""Raise from exceptions below."""
except_source = (ValueError,
TypeError,
cls)
except_source = (ValueError, TypeError, cls)
return except_source


class ModelLoadingFail(MindConverterException):
"""Model loading fail."""

def __init__(self, msg):
super(ModelLoadingFail, self).__init__(error=ConverterErrors.INPUT_SHAPE_ERROR,
cls_code=ConverterErrors.MODEL_LOADING_FAIL.value,
user_msg=msg)

@classmethod
def raise_from(cls):
"""Define exception when model loading fail."""
return ValueError, cls

+ 3
- 1
mindinsight/mindconverter/graph_based_converter/common/utils.py View File

@@ -135,9 +135,11 @@ def lib_version_satisfied(current_ver: str, mini_ver_limited: str,
if current_ver < mini_ver_limited or (newest_ver_limited and current_ver > newest_ver_limited):
return False
return True


def get_dict_key_by_value(val, dic):
"""
Return the first appeared key of a dictionay by given value.
Return the first appeared key of a dictionary by given value.

Args:
val (Any): Value of the key.


+ 25
- 14
mindinsight/mindconverter/graph_based_converter/framework.py View File

@@ -16,6 +16,7 @@
import os
import re
import argparse
import sys
from importlib import import_module
from importlib.util import find_spec

@@ -25,12 +26,11 @@ from mindinsight.mindconverter.graph_based_converter.common.utils import lib_ver
from mindinsight.mindconverter.graph_based_converter.constant import BINARY_HEADER_PYTORCH_FILE, FrameworkType, \
BINARY_HEADER_PYTORCH_BITS, ONNX_MIN_VER, TF2ONNX_MIN_VER, ONNXRUNTIME_MIN_VER
from mindinsight.mindconverter.graph_based_converter.mapper import ONNXToMindSporeMapper
from mindinsight.mindconverter.common.log import logger as log
from mindinsight.mindconverter.common.log import logger as log, logger_console as log_console
from mindinsight.mindconverter.common.exceptions import GraphInitFail, TreeCreateFail, SourceFilesSaveFail, \
BaseConverterFail, UnknownModel
BaseConverterFail, UnknownModel, GeneratorFail, TfRuntimeError
from mindinsight.utils.exceptions import ParamMissError


permissions = os.R_OK | os.W_OK | os.X_OK
os.umask(permissions << 3 | permissions)

@@ -71,8 +71,10 @@ def torch_installation_validation(func):
"scripts converter, and PyTorch vision must "
"be consisted with model generation runtime.")
log.error(str(error))
log.exception(error)
raise error
detail_info = f"Error detail: {str(error)}"
log_console.error(str(error))
log_console.error(detail_info)
sys.exit(0)

func(graph_path=graph_path, sample_shape=sample_shape,
output_folder=output_folder, report_folder=report_folder)
@@ -103,7 +105,10 @@ def tf_installation_validation(func):
f"based scripts converter for TensorFlow conversion."
)
log.error(str(error))
raise error
detail_info = f"Error detail: {str(error)}"
log_console.error(str(error))
log_console.error(detail_info)
sys.exit(0)

onnx, tf2onnx = import_module("onnx"), import_module("tf2onnx")
ort = import_module("onnxruntime")
@@ -117,7 +122,10 @@ def tf_installation_validation(func):
f"based scripts converter for TensorFlow conversion."
)
log.error(str(error))
raise error
detail_info = f"Error detail: {str(error)}"
log_console.error(str(error))
log_console.error(detail_info)
sys.exit(0)

func(graph_path=graph_path, sample_shape=sample_shape,
output_folder=output_folder, report_folder=report_folder,
@@ -142,9 +150,10 @@ def _extract_model_name(model_path):


@torch_installation_validation
@GraphInitFail.check_except_pytorch("Error occurred when init graph object.")
@TreeCreateFail.check_except_pytorch("Error occurred when create hierarchical tree.")
@SourceFilesSaveFail.check_except_pytorch("Error occurred when save source files.")
@GraphInitFail.uniform_catcher("Error occurred when init graph object.")
@TreeCreateFail.uniform_catcher("Error occurred when create hierarchical tree.")
@SourceFilesSaveFail.uniform_catcher("Error occurred when save source files.")
@GeneratorFail.uniform_catcher("Error occurred when generate code.")
def graph_based_converter_pytorch_to_ms(graph_path: str, sample_shape: tuple,
output_folder: str, report_folder: str = None):
"""
@@ -176,9 +185,11 @@ def graph_based_converter_pytorch_to_ms(graph_path: str, sample_shape: tuple,


@tf_installation_validation
@GraphInitFail.check_except_tf("Error occurred when init graph object.")
@TreeCreateFail.check_except_tf("Error occurred when create hierarchical tree.")
@SourceFilesSaveFail.check_except_tf("Error occurred when save source files.")
@GraphInitFail.uniform_catcher("Error occurred when init graph object.")
@TfRuntimeError.uniform_catcher("Error occurred when init graph, TensorFlow runtime error.")
@TreeCreateFail.uniform_catcher("Error occurred when create hierarchical tree.")
@SourceFilesSaveFail.uniform_catcher("Error occurred when save source files.")
@GeneratorFail.uniform_catcher("Error occurred when generate code.")
def graph_based_converter_tf_to_ms(graph_path: str, sample_shape: tuple,
input_nodes: str, output_nodes: str,
output_folder: str, report_folder: str = None):
@@ -210,7 +221,7 @@ def graph_based_converter_tf_to_ms(graph_path: str, sample_shape: tuple,
save_code_file_and_report(model_name, code_fragments, output_folder, report_folder)


@BaseConverterFail.check_except("Failed to start base converter.")
@BaseConverterFail.uniform_catcher("Failed to start base converter.")
def main_graph_base_converter(file_config):
"""
The entrance for converter, script files will be converted.


+ 1
- 0
mindinsight/mindconverter/graph_based_converter/generator/args_translator.py View File

@@ -201,6 +201,7 @@ class ArgsTranslation:

class ArgsTranslationHelper:
"""Define operations related to ArgsTranslation instances."""

@staticmethod
def find_formal_args_in_modules(args_translators):
"""


+ 2
- 17
mindinsight/mindconverter/graph_based_converter/generator/module_struct.py View File

@@ -541,7 +541,7 @@ class ModuleStruct:
for output in output_list:
(provider_succ, provider_closet_opt_var) = output
if provider_closet_opt_var in struct.matched_inputs:
continue # skip repeat
continue # skip repeat
if provider_succ == struct.onnx_name:
struct.matched_inputs.append(provider_closet_opt_var)

@@ -695,20 +695,5 @@ class ModuleStruct:
"""Register submodule outputs to this module's return."""
submodule_returns = md_struct.get_returned_opt_var_name()
submodule_opt_var_name = md_struct.ms_opt_var_name
for (submodule_ext_succ, opt_var_name_in_this_module, ith_output) in submodule_returns:
for (submodule_ext_succ, _, ith_output) in submodule_returns:
self.external_successor_local_returns_map[submodule_ext_succ] = (submodule_opt_var_name, ith_output)
# edit external succ 's inputs in parent module
ext_node = self.GLOBAL_CONTEXT_MGR.onnx_node_name_to_node_struct_map.get(submodule_ext_succ)
ext_node_parent = ext_node.parent_module_struct
while ext_node_parent != self.parent_module_struct:
ext_node_parent.inputs_in_parent_module[ext_node.onnx_name] = md_struct.ms_opt_var_name
ext_node_parent = ext_node_parent.parent_module_struct

# need find the prec_name?
for ext_node_prec, opt_var_name in ext_node.inputs_in_parent_module.copy().items():
if isinstance(opt_var_name, str):
if opt_var_name == opt_var_name_in_this_module:
ext_node.inputs_in_parent_module[ext_node_prec] = (self.ms_opt_var_name, ith_output)
if isinstance(opt_var_name, tuple):
if opt_var_name[0] == opt_var_name_in_this_module:
ext_node.inputs_in_parent_module[ext_node_prec] = (self.ms_opt_var_name, ith_output)

+ 1
- 0
mindinsight/mindconverter/graph_based_converter/sub_graph_searcher/pattern.py View File

@@ -34,6 +34,7 @@ class Pattern:
# If pattern in BUILD_IN_MODULE_NAME or BUILD_IN_PATTERN,
# the pattern will get additional score.
self.additional_score = 0
self.know_module_name = None

def insert(self, idx, seq_len):
"""


+ 15
- 6
mindinsight/mindconverter/graph_based_converter/sub_graph_searcher/search_path.py View File

@@ -18,7 +18,7 @@ import uuid
from typing import Dict, List, Callable, Union
from collections import OrderedDict
from .common import context, gen_hash_key, DagGraph, MAX_OUT_DEGREE, cal_matching_score
from .known_module_name import BUILT_IN_MODULE_NAME, is_built_in_module_name
from .known_module_name import BUILT_IN_MODULE_NAME
from .pattern import Pattern, scope_name_mapping
from .built_in_pattern import BUILT_IN_PATTERN, is_built_in_pattern
from .pattern_fuzzy_matching import pattern_fuzzy_matching
@@ -85,13 +85,15 @@ def _is_valid_pattern(pattern, dag):
return True


def generate_module_name(pattern):
def match_known_module_name(pattern):
"""
Generate module name.
Matching with know module name.

Args:
pattern (Pattern): To be replaced pattern.

Returns:
str, matched module name, return None if not matched.
"""
matched_result = []
for ptn, module_name in BUILT_IN_MODULE_NAME.items():
@@ -109,7 +111,11 @@ def generate_module_name(pattern):
module_name = f"{module_name}{used_module_name[pattern.pattern]}"
used_module_name[pattern.pattern] += 1
return module_name
return None


def generate_module_name():
"""Generate module name."""
global global_idx
name = f"Module{global_idx}"
global_idx += 1
@@ -439,13 +445,16 @@ class SearchPath:
to recover the sequence.
"""
if self.pattern.pattern not in scope_name_mapping:
module_name = generate_module_name(self.pattern)
scope_name_mapping[self.pattern.pattern] = module_name
module_name = generate_module_name()
known_module_name = match_known_module_name(self.pattern)
scope_name_mapping[self.pattern] = module_name
module_name_to_src[module_name] = self.pattern.pattern
else:
module_name = scope_name_mapping[self.pattern.pattern]
known_module_name = module_name_to_src[module_name].known_module_name
self.pattern.module_name = module_name
if is_built_in_module_name(module_name):
self.pattern.known_module_name = known_module_name
if known_module_name:
self.pattern.additional_score += cal_matching_score(self.pattern.ptn_length)
topo_order, inverted_index = self.replace_sub_graph_completely(self.pattern, self.topo_order_bef_repl)
return topo_order, inverted_index


+ 22
- 0
mindinsight/mindconverter/graph_based_converter/sub_graph_searcher/searcher.py View File

@@ -18,8 +18,10 @@ from typing import Dict, List

from .common import context, DagGraph, gen_hash_key, ACCEPTABLE_RESULT_COUNT
from .common import MINI_FREQUENCY, MAX_ITERATION_DEPTH, SATISFIED_SCORE
from ..common.global_context import GlobalContext
from ..third_party_graph.onnx_utils import BaseNode
from .search_path import SearchPath, Pattern, generate_pattern, find_built_in_pattern
from ...common.exceptions import SubGraphSearchingFail


def _is_satisfied(path):
@@ -249,6 +251,23 @@ def validate_topo_order_succession():
return True


def _add_known_module_name(search_path):
"""
Add known module name to GlobalContext.

Args:
search_path (SearchPath): Search path.

"""
ctx = GlobalContext()
if search_path.pattern.known_module_name:
ctx.known_module_name[search_path.pattern.module_name] = search_path.pattern.known_module_name
for it in search_path.recursion_path:
if it.pattern.known_module_name:
ctx.known_module_name[it.pattern.module_name] = it.pattern.known_module_name


@SubGraphSearchingFail.check_except("Sub-Graph searching fail.")
def generate_scope_name(data_loader):
"""
Generate scope name according to computation graph.
@@ -270,6 +289,9 @@ def generate_scope_name(data_loader):
if len(topo_order_with_scope_name_list) != len(data_loader.nodes_dict):
topo_order_with_scope_name_list = flatten_graph(init_dag)

if result:
_add_known_module_name(result)

except (ValueError, IndexError, AttributeError, KeyError) as _:
topo_order_with_scope_name_list = flatten_graph(init_dag)
return topo_order_with_scope_name_list

+ 2
- 2
mindinsight/mindconverter/graph_based_converter/third_party_graph/onnx_utils.py View File

@@ -27,7 +27,7 @@ from ..common.global_context import GlobalContext

from ..constant import ONNX_TYPE_INT, ONNX_TYPE_INTS, ONNX_TYPE_STRING, \
ONNX_TYPE_FLOATS, ONNX_TYPE_FLOAT, SCALAR_WITHOUT_SHAPE, DYNAMIC_SHAPE, UNKNOWN_DIM_VAL
from ...common.exceptions import GraphInitFail, ModelNotSupport
from ...common.exceptions import GraphInitFail, ModelNotSupport, ModelLoadingFail


def convert_tf_graph_to_onnx(model_path, model_inputs, model_outputs, opset=12):
@@ -308,7 +308,7 @@ class OnnxDataLoader:
w = int(match.group('w'))
c = int(match.group('c'))
if [h, w, c] != list(self.graph_input_shape)[1:4]:
raise ValueError(f"Shape given should be (N, {h}, {w}, {c}) but got {self.graph_input_shape}")
raise ModelLoadingFail(f"Shape given should be (N, {h}, {w}, {c}) but got {self.graph_input_shape}")
return True
return False



+ 6
- 6
mindinsight/mindconverter/graph_based_converter/third_party_graph/pytorch_graph_parser.py View File

@@ -25,7 +25,9 @@ class PyTorchGraphParser(GraphParser):
"""Define pytorch graph parser."""

@classmethod
@ModelNotSupport.check_except_pytorch("Error occurs in loading model, make sure model.pth correct.")
@ModelNotSupport.check_except(
"Error occurs in loading model, please check your model or runtime environment integrity."
)
def parse(cls, model_path: str, **kwargs):
"""
Parser pytorch graph.
@@ -50,11 +52,9 @@ class PyTorchGraphParser(GraphParser):
else:
model = torch.load(f=model_path, map_location="cpu")
except ModuleNotFoundError:
error_msg = \
"Cannot find model scripts in system path, " \
"set `--project_path` to the path of model scripts folder correctly."
error_msg = "Cannot find model scripts in system path, " \
"set `--project_path` to the path of model scripts folder correctly."
error = ModuleNotFoundError(error_msg)
log.error(str(error))
raise error from None
raise error

return model

+ 4
- 4
mindinsight/mindconverter/graph_based_converter/third_party_graph/tf_graph_parser.py View File

@@ -25,7 +25,9 @@ class TFGraphParser(GraphParser):
"""Define TF graph parser."""

@classmethod
@ModelNotSupport.check_except_tf("Error occurs in loading model, make sure model.pb correct.")
@ModelNotSupport.check_except(
"Error occurs in loading model, please check your model or runtime environment integrity."
)
def parse(cls, model_path: str, **kwargs):
"""
Parse TF Computational Graph File (.pb)
@@ -36,7 +38,6 @@ class TFGraphParser(GraphParser):
Returns:
object, ONNX model.
"""

onnx_utils = import_module(
"mindinsight.mindconverter.graph_based_converter.third_party_graph.onnx_utils")
convert_tf_graph_to_onnx = getattr(onnx_utils, "convert_tf_graph_to_onnx")
@@ -50,6 +51,5 @@ class TFGraphParser(GraphParser):

model = convert_tf_graph_to_onnx(model_path,
model_inputs=tf_input_nodes,
model_outputs=tf_output_nodes,
)
model_outputs=tf_output_nodes)
return model

+ 0
- 35
tests/st/func/mindconverter/test_converter.py View File

@@ -21,13 +21,10 @@ Usage:
"""
import difflib
import os
import re
import sys

import pytest

from mindinsight.mindconverter.converter import main
from mindinsight.mindconverter.graph_based_converter.framework import main_graph_base_converter


@pytest.mark.usefixtures('create_output_dir')
@@ -82,35 +79,3 @@ class TestConverter:

converted_ratio = 100 - (diff_lines * 100) / (len(expect_source))
assert converted_ratio >= 80

@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.platform_x86_cpu
@pytest.mark.env_single
def test_main_graph_based_converter(self, output):
"""Test main graph based converter."""
pytorch_filename = "resnet50.pth"
expected_model_filename = "resnet50.py"
expected_report_filename = "report_of_resnet50.txt"
file_config = {
'model_file': os.path.join(self.pytorch_dir, pytorch_filename),
'shape': (1, 3, 224, 224),
'outfile_dir': output,
'report_dir': output
}
with pytest.raises(ValueError) as e:
main_graph_base_converter(file_config=file_config)

assert os.path.isfile(os.path.join(output, expected_model_filename))
assert os.path.isfile(os.path.join(output, expected_report_filename))

with open(os.path.join(output, expected_report_filename)) as converted_r:
converted_report = converted_r.readlines()
converted_rate = re.findall(r".*(?:Converted Rate: )(.*)[.]", converted_report[-1])

assert converted_rate[0] == '100.00%'

exec_msg = e.value.args[0]
assert exec_msg == "torch.__spec__ is None"

Loading…
Cancel
Save