Browse Source

Updates on merging generator to MindConverter, and unnecessary logic clean

tags/v1.1.0
liangtianshu 5 years ago
parent
commit
666a0d35a3
7 changed files with 120 additions and 139 deletions
  1. +29
    -0
      mindinsight/mindconverter/common/exceptions.py
  2. +14
    -0
      mindinsight/mindconverter/graph_based_converter/constant.py
  3. +8
    -11
      mindinsight/mindconverter/graph_based_converter/framework.py
  4. +31
    -82
      mindinsight/mindconverter/graph_based_converter/generator/generator.py
  5. +7
    -3
      mindinsight/mindconverter/graph_based_converter/generator/module_struct.py
  6. +29
    -27
      mindinsight/mindconverter/graph_based_converter/generator/node_struct.py
  7. +2
    -16
      mindinsight/mindconverter/graph_based_converter/hierarchical_tree/hierarchical_tree.py

+ 29
- 0
mindinsight/mindconverter/common/exceptions.py View File

@@ -39,11 +39,13 @@ class ConverterErrors(ScriptConverterErrors):
MODEL_NOT_SUPPORT = 8
SCRIPT_GENERATE_FAIL = 9
REPORT_GENERATE_FAIL = 10
NODE_CONVERSION_ERROR = 11

BASE_CONVERTER_FAIL = 000
GRAPH_INIT_FAIL = 100
TREE_CREATE_FAIL = 200
SOURCE_FILES_SAVE_FAIL = 300
GENERATOR_FAIL = 400


class ScriptNotSupport(MindInsightException):
@@ -163,6 +165,7 @@ class MindConverterException(Exception):

class BaseConverterFail(MindConverterException):
"""Base converter failed."""

def __init__(self, msg):
super(BaseConverterFail, self).__init__(error=ConverterErrors.BASE_CONVERTER_FAIL,
user_msg=msg)
@@ -195,6 +198,7 @@ class BaseConverterFail(MindConverterException):

class UnknownModel(MindConverterException):
"""The unknown model error."""

def __init__(self, msg):
super(UnknownModel, self).__init__(error=ConverterErrors.UNKNOWN_MODEL,
user_msg=msg)
@@ -202,6 +206,7 @@ class UnknownModel(MindConverterException):

class GraphInitFail(MindConverterException):
"""The graph init fail error."""

def __init__(self, **kwargs):
super(GraphInitFail, self).__init__(error=ConverterErrors.GRAPH_INIT_FAIL,
user_msg=kwargs.get('msg', ''))
@@ -230,6 +235,7 @@ class GraphInitFail(MindConverterException):

class TreeCreateFail(MindConverterException):
"""The tree create fail."""

def __init__(self, msg):
super(TreeCreateFail, self).__init__(error=ConverterErrors.TREE_CREATE_FAIL,
user_msg=msg)
@@ -254,6 +260,7 @@ class TreeCreateFail(MindConverterException):

class SourceFilesSaveFail(MindConverterException):
"""The source files save fail error."""

def __init__(self, msg):
super(SourceFilesSaveFail, self).__init__(error=ConverterErrors.SOURCE_FILES_SAVE_FAIL,
user_msg=msg)
@@ -280,6 +287,7 @@ class SourceFilesSaveFail(MindConverterException):

class ModelNotSupport(MindConverterException):
"""The model not support error."""

def __init__(self, msg):
super(ModelNotSupport, self).__init__(error=ConverterErrors.MODEL_NOT_SUPPORT,
user_msg=msg,
@@ -338,6 +346,7 @@ class ModelNotSupport(MindConverterException):

class NodeInputMissing(MindConverterException):
"""The node input missing error."""

def __init__(self, msg):
super(NodeInputMissing, self).__init__(error=ConverterErrors.NODE_INPUT_MISSING,
user_msg=msg,
@@ -346,6 +355,7 @@ class NodeInputMissing(MindConverterException):

class TreeNodeInsertFail(MindConverterException):
"""The tree node create fail error."""

def __init__(self, msg):
super(TreeNodeInsertFail, self).__init__(error=ConverterErrors.TREE_NODE_INSERT_FAIL,
user_msg=msg,
@@ -380,6 +390,7 @@ class TreeNodeInsertFail(MindConverterException):

class NodeInputTypeNotSupport(MindConverterException):
"""The node input type NOT support error."""

def __init__(self, msg):
super(NodeInputTypeNotSupport, self).__init__(error=ConverterErrors.NODE_INPUT_TYPE_NOT_SUPPORT,
user_msg=msg,
@@ -388,6 +399,7 @@ class NodeInputTypeNotSupport(MindConverterException):

class ScriptGenerateFail(MindConverterException):
"""The script generate fail error."""

def __init__(self, msg):
super(ScriptGenerateFail, self).__init__(error=ConverterErrors.SCRIPT_GENERATE_FAIL,
user_msg=msg,
@@ -421,6 +433,7 @@ class ScriptGenerateFail(MindConverterException):

class ReportGenerateFail(MindConverterException):
"""The report generate fail error."""

def __init__(self, msg):
super(ReportGenerateFail, self).__init__(error=ConverterErrors.REPORT_GENERATE_FAIL,
user_msg=msg,
@@ -448,3 +461,19 @@ class ReportGenerateFail(MindConverterException):
return output
return _f
return decorator


class GeneratorFail(MindConverterException):
"""The Generator fail error."""

def __init__(self, msg):
super(GeneratorFail, self).__init__(error=ConverterErrors.NODE_CONVERSION_ERROR,
user_msg=msg,
cls_code=ConverterErrors.GENERATOR_FAIL.value)
@classmethod
def raise_from(cls):
"""Raise from exceptions below."""
except_source = (ValueError,
TypeError,
cls)
return except_source

+ 14
- 0
mindinsight/mindconverter/graph_based_converter/constant.py View File

@@ -75,3 +75,17 @@ class InputType(Enum):
class FrameworkType(Enum):
PYTORCH = 0
TENSORFLOW = 1


def get_imported_module():
"""
Generate imported module header.

Returns:
str, imported module.
"""
return f"import numpy as np{NEW_LINE}" \
f"import mindspore{NEW_LINE}" \
f"from mindspore import nn{NEW_LINE}" \
f"from mindspore import Tensor{NEW_LINE}" \
f"from mindspore.ops import operations as P{NEW_LINE * 3}"

+ 8
- 11
mindinsight/mindconverter/graph_based_converter/framework.py View File

@@ -20,7 +20,8 @@ from importlib import import_module
from importlib.util import find_spec

import mindinsight
from mindinsight.mindconverter.graph_based_converter.common.utils import lib_version_satisfied
from mindinsight.mindconverter.graph_based_converter.common.utils import lib_version_satisfied, \
save_code_file_and_report
from mindinsight.mindconverter.graph_based_converter.constant import BINARY_HEADER_PYTORCH_FILE, FrameworkType, \
BINARY_HEADER_PYTORCH_BITS, ONNX_MIN_VER, TF2ONNX_MIN_VER, ONNXRUNTIME_MIN_VER
from mindinsight.mindconverter.graph_based_converter.mapper import ONNXToMindSporeMapper
@@ -29,6 +30,7 @@ from mindinsight.mindconverter.common.exceptions import GraphInitFail, TreeCreat
BaseConverterFail, UnknownModel
from mindinsight.utils.exceptions import ParamMissError


permissions = os.R_OK | os.W_OK | os.X_OK
os.umask(permissions << 3 | permissions)

@@ -194,23 +196,18 @@ def graph_based_converter_tf_to_ms(graph_path: str, sample_shape: tuple,
"""
third_party_graph_module = import_module(
'mindinsight.mindconverter.graph_based_converter.third_party_graph')
hierarchical_tree_module = import_module(
'mindinsight.mindconverter.graph_based_converter.hierarchical_tree')
cls_graph_factory = getattr(third_party_graph_module, 'GraphFactory')
cls_hierarchical_tree_factory = getattr(hierarchical_tree_module, 'HierarchicalTreeFactory')
batch_add_nodes = getattr(import_module('mindinsight.mindconverter.graph_based_converter.generator'),
"batch_add_nodes")
# Close unnecessary log.
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'

graph_obj = cls_graph_factory.init(graph_path, sample_shape=sample_shape,
input_nodes=input_nodes, output_nodes=output_nodes)

hierarchical_tree, scope_name_map = cls_hierarchical_tree_factory.create(graph_obj)

generator_inst = batch_add_nodes(graph_obj, ONNXToMindSporeMapper)
model_name = _extract_model_name(graph_path)
hierarchical_tree.save_source_files(output_folder, mapper=ONNXToMindSporeMapper,
model_name=model_name,
report_folder=report_folder,
scope_name_map=scope_name_map)
code_fragments = generator_inst.generate()
save_code_file_and_report(model_name, code_fragments, output_folder, report_folder)


@BaseConverterFail.check_except("Failed to start base converter.")


+ 31
- 82
mindinsight/mindconverter/graph_based_converter/generator/generator.py View File

@@ -16,23 +16,17 @@
import copy
from collections import OrderedDict

from yapf.yapflib.yapf_api import FormatCode

from .scope_utils import Scope
from .node_struct import NodeStruct
from .module_struct import ModuleStruct
from .args_translator import ArgsTranslationHelper
from ..common.global_context import GlobalContext
from ...common.exceptions import GeneratorFail
from ..hierarchical_tree.name_mgr import GlobalVarNameMgr
from ..constant import NEW_LINE, SECOND_LEVEL_INDENT, FIRST_LEVEL_INDENT


class Singleton(type):
"""Metaclass to make the generator to be single instance."""
_instances = {}

def __call__(cls, *args, **kwargs):
if cls not in cls._instances:
cls._instances[cls] = super(Singleton, cls).__call__(*args, **kwargs)
return cls._instances[cls]
from ..constant import NEW_LINE, SECOND_LEVEL_INDENT, FIRST_LEVEL_INDENT, CodeFormatConfig, get_imported_module
from ..report_generator import ReportGenerator


class CodeStruct:
@@ -88,14 +82,9 @@ class CodeStruct:
repeated_submodules (dict): The dict contains all submodules which use repeatedly.
Can get this dict from generator.
"""
# Define tmp var for code generation.
opt_var_name_records = dict() # now only support multiple outputs within same scope.
return_value_records = dict() # save returned values for successor nodes/modules use.

# Define Module header code line below
if md_struct.pattern_id != -1:
class_name = f"Module{md_struct.pattern_id}"
else:
class_name = "Model"
class_name = md_struct.class_name
# define a class declaration
self.new_line = f"class {class_name}(nn.Cell):"

@@ -108,36 +97,18 @@ class CodeStruct:
for formal in md_struct.args_translator.formal_args.keys():
module_def_args.append(formal)

# Collect extra inputs and outputs

# For code line in init & construct blocks
init_lines = list()
cons_lines = list()
for (_, struct) in md_struct.get_generate_order():
if isinstance(struct, NodeStruct): # Generate code line for Node.
code_line_init = struct.code_line_in_init()
code_line_construct = struct.code_line_in_construct(in_module_returns=return_value_records)
code_line_construct = struct.code_line_in_construct()
init_lines.append(f"{SECOND_LEVEL_INDENT}{' = '.join(code_line_init)}")
cons_lines.append(f"{SECOND_LEVEL_INDENT}{' = '.join(code_line_construct)}")
# add extra tensor
if struct.fragment.code_setting and struct.fragment.code_setting.op_extra_tensor:
code_extra_tensor = struct.add_extra_tensor()
init_lines.append(f"{SECOND_LEVEL_INDENT}{' = '.join(code_extra_tensor)}")

# record opt_var_name for succ nodes input in same scope.
target_onnx_name = struct.graph_node_ref.successor_nodes
for name in target_onnx_name:
if opt_var_name_records.get(name):
opt_var_name_records.get(name).append(code_line_construct[0])
else:
opt_var_name_records[name] = [code_line_construct[0]]

if struct.successor_nodes_names_external:
for ret_user in struct.successor_nodes_names_external:
if return_value_records.get(ret_user) is not None:
return_value_records[ret_user].append((struct.onnx_name, code_line_construct[0]))
else:
return_value_records[ret_user] = [(struct.onnx_name, code_line_construct[0])]
init_lines.append(f"{SECOND_LEVEL_INDENT}{' = '.join(struct.add_extra_tensor())}")

elif isinstance(struct, ModuleStruct):
# check if this instance generated CodeStruct
@@ -149,22 +120,6 @@ class CodeStruct:
init_lines.append(f"{SECOND_LEVEL_INDENT}{' = '.join(code_line_init)}")
cons_lines.append(f"{SECOND_LEVEL_INDENT}{' = '.join(code_line_construct)}")

# record opt_var_name for succ nodes input in same scope.
target_onnx_name = struct.tail_nd_struct.graph_node_ref.successor_nodes
for name in target_onnx_name:
if opt_var_name_records.get(name):
opt_var_name_records.get(name).append(code_line_construct[0])
else:
opt_var_name_records[name] = [code_line_construct[0]]

# record submodule's local return map for following nodes / submodules use
if struct.external_successor_local_returns_map:
for ret_user, _ in struct.external_successor_local_returns_map.items():
if return_value_records.get(ret_user) is not None:
# mulitple returns of a node may need modifiy the index.
return_value_records[ret_user].append((struct.identifier, code_line_construct[0]))
else:
return_value_records[ret_user] = [(struct.identifier, code_line_construct[0])]
else:
raise TypeError("Unable to generate code from args are not ModuleStruct or NodeStruct.")

@@ -183,8 +138,7 @@ class CodeStruct:
# define returns
returns = []
if md_struct.external_successor_local_returns_map:
ret = list(md_struct.external_successor_local_returns_map.values())
for r in ret:
for r in list(md_struct.external_successor_local_returns_map.values()):
if isinstance(r, tuple): # results return with index nth output
returns.append(r[0])
else:
@@ -197,7 +151,7 @@ class CodeStruct:
self.GLOBAL_CONTEXT.code_structs[md_struct.pattern_id] = self


class Generator(metaclass=Singleton):
class Generator:
"""The generator controls all routines of code generation."""

def __init__(self):
@@ -217,6 +171,7 @@ class Generator(metaclass=Singleton):
self._global_context.node_struct_collections = self._node_struct_collections
self._repeated_submodules = set()

@GeneratorFail.check_except("Generator occurs an error when forming base submodules.")
def _form_bottom_submodule(self):
"""Form the basic submodules, which only contains nodes."""
# Form module map
@@ -375,7 +330,7 @@ class Generator(metaclass=Singleton):
if scope_path_str == '[]':
continue # is main module, skip
if md_struct.scope_depth != depth:
continue # skip all submodules not at current depth
continue # skip all submodules not at current depth
md_struct_scope = copy.deepcopy(md_struct.identifier)
md_struct_scope.pop()
parent_scope = md_struct_scope
@@ -396,6 +351,7 @@ class Generator(metaclass=Singleton):
self._global_context.add_module_struct(sub.pattern_id, sub)
depth -= 1

@GeneratorFail.check_except("Generator occurs an error when building modules.")
def _recursive_form_module(self):
"""Main routine in generator to build modules from bottom to top."""
# 1. List repeated submodules
@@ -518,6 +474,7 @@ class Generator(metaclass=Singleton):
"""Return all ModuleStructs in this model."""
return self._module_struct_collections

@GeneratorFail.check_except("Generator occurs an error when generating code statements.")
def generate(self):
"""
Generate the final script file.
@@ -527,8 +484,22 @@ class Generator(metaclass=Singleton):
"""
self._form_bottom_submodule()
self._recursive_form_module()
code = CodeStruct(self.module_structs.get('[]'), self._repeated_submodules)
return code.code_line_list
CodeStruct(self.module_structs.get('[]'), self._repeated_submodules)

outputs = [get_imported_module()]

for code_struct in self._global_context.code_structs.values():
for line in code_struct.code_line_list:
outputs.append(line)

formatted_code, _ = FormatCode("\n".join(outputs),
style_config=CodeFormatConfig.PEP8.value)

report_generator = ReportGenerator()
report = report_generator.gen_report(formatted_code)
del self._global_context

return {"model": (formatted_code, report)}

def get_node_struct(self, node_identifier):
"""
@@ -554,28 +525,6 @@ class Generator(metaclass=Singleton):
"""
return self._module_struct_collections.get(module_identifier, None)

def get_module_structs_by_pattern_under_same_parent_pattern(self, pattern_id, under_parent_pattern_id):
"""
Return a list of ModuleStruct by conditions of pattern and their parent parent's pattern.

Args:
pattern_id (int): The pattern id the returned ModuleSturct is.
under_parent_pattern_id (int): The pattern id the returned ModuleStruct's parent is.

Returns:
list, a list of MoudleStructs has the same pattern_id and the same parents' pattern.
"""
if not pattern_id:
raise ValueError("pattern_id is necessary to get the module struct.")
if not under_parent_pattern_id:
raise ValueError("under_parent_pattern_id is necessary to get the module struct.")
ret = []
md_list = self._global_context.module_structs.get(pattern_id)
for md in md_list:
if md.parent_id == under_parent_pattern_id:
ret.append(md)
return ret

def get_args_translator_from_module_structs_list(self, md_list, exclude_root_son=False):
"""
Return a list of args translators which belongs to given module structs.


+ 7
- 3
mindinsight/mindconverter/graph_based_converter/generator/module_struct.py View File

@@ -201,7 +201,7 @@ class ModuleStruct:

def init_args_translator(self):
"""Initialize the Args Translator for the module."""
var_name = "Module{}_{}".format(self.pattern_id, self.pattern_uid)
var_name = self.ms_var_name
self._args_translator = ArgsTranslation(None, var_name, None)

def update_module_fragment(self):
@@ -455,14 +455,18 @@ class ModuleStruct:
"""Return the class name for generating code of this module."""
if self.pattern_id == -1:
return "Model"
return "Module{}".format(self.pattern_id)
if self.GLOBAL_CONTEXT_MGR.known_module_name.get("Module{}".format(self.pattern_id)) is not None:
class_name = self.GLOBAL_CONTEXT_MGR.known_module_name.get("Module{}".format(self.pattern_id))
else:
class_name = "Module{}".format(self.pattern_id)
return class_name

@property
def ms_var_name(self) -> str:
"""Return the variable name for generated code statement of this module."""
if self.pattern_id == -1:
return "Model"
return "Module{}_{}".format(self.pattern_id, self.pattern_uid).lower()
return f"{self.class_name}_{self.pattern_uid}".lower()

@property
def ms_opt_var_name(self) -> str:


+ 29
- 27
mindinsight/mindconverter/graph_based_converter/generator/node_struct.py View File

@@ -22,6 +22,7 @@ from ..third_party_graph.pytorch_graph_node import PyTorchGraphNode
from ..third_party_graph.onnx_graph_node import OnnxGraphNode
from ..common.global_context import GlobalContext
from ..constant import InputType
from ...common.exceptions import GeneratorFail


class NodeStruct:
@@ -47,22 +48,32 @@ class NodeStruct:
self.node_type = None
self.onnx_name = None
self.onnx_op = None
self.graph_node_ref = None # Our defined GraphNode
self.graph_node_ref = None
self.scope_name = None
self.ms_var_name = None
self.ms_opt_var_name = None # ms_opt_var_name = self.ms_var_name(...)
self.ms_opt_var_name = None
self.ms_op = None
self.ready_to_generate = False

self.ms_params = dict() # converted params from mapper
# Define attributes converted from mapper
self.ms_params = dict()
self.ms_settings = dict()
self.ms_weights = dict()
self.ms_inputs = OrderedDict()

self.scope = None # Defined Scope class
self.inputs_in_construct_header = OrderedDict() # key is prec_node_name, value is x; For code line use
self.inputs_in_parent_module = OrderedDict() # key is prec_node_name, value is its closet opt_var_name
self.matched_inputs = list() # Matched inputs will can be directly used by code line generation
# Defined Scope class
self.scope = None

# Define attributes used for code generation

# key is prec_node_name, value is x; For code line use
self.inputs_in_construct_header = OrderedDict()

# key is prec_node_name, value is its closet opt_var_name
self.inputs_in_parent_module = OrderedDict()

# Matched inputs will can be directly used by code line generation
self.matched_inputs = list()

# initialize funcs.
for arg in args:
@@ -135,6 +146,7 @@ class NodeStruct:
parsed_scope = Scope.parse_scope_from_node_identifier(self.identifier)
self.scope = Scope(parsed_scope)

@GeneratorFail.check_except("Generator occurs an error when initializing node's args translator.")
def init_args_translator(self, translated_args: list):
"""
Initialize the ArgsTranslator for each Node.
@@ -158,6 +170,7 @@ class NodeStruct:
self.ms_op]):
self.ready_to_generate = True

@GeneratorFail.check_except("Generator occurs an error when creating node struct.")
def update(self, arg, force_ready=False):
"""
Pass Node info. to generator NodeStruct.
@@ -314,23 +327,12 @@ class NodeStruct:
return input_name_to_use
return None

def code_line_in_construct(self, inputs=None, in_module_returns=None):
def code_line_in_construct(self, inputs=None):
"""Construct line of code in module construct block. """
left = self.ms_opt_var_name
if inputs is None:
inputs = []
for idx, prec_node in enumerate(self.precursor_nodes_names):
if self.inputs_in_construct_header.get(prec_node):
inputs.append(self.inputs_in_construct_header.get(prec_node))
elif self._check_target_node_internal(prec_node):
inputs.append(self.precursor_nodes_structs[idx].ms_opt_var_name)
elif self.inputs_in_parent_module.get(prec_node):
inputs.append(self.inputs_in_parent_module.get(prec_node))
elif in_module_returns and in_module_returns.get(self.onnx_name) \
and (not self._check_target_node_internal(prec_node)):
inputs.append(self._get_correct_in_module_returns(prec_node, in_module_returns.get(self.onnx_name)))
else:
inputs.append("unk_{}_{}".format(idx, prec_node))

if not self.matched_inputs and inputs is None:
raise ValueError("Unable to generate the code construct statement due to empty inputs.")

if self.matched_inputs:
inputs = self.matched_inputs
@@ -394,7 +396,7 @@ class NodeStruct:
"""
target_nd_struct = self.GLOBAL_CONTEXT_MGR.node_struct_collections.get(name) \
or self.GLOBAL_CONTEXT_MGR.onnx_node_name_to_node_struct_map.get(name)
if target_nd_struct is None and self.topo_idx == 0: # First node always has external input
if target_nd_struct is None and self.topo_idx == 0: # First node always has external input
return False

if target_nd_struct is None:
@@ -413,11 +415,11 @@ class NodeStruct:
@property
def precursor_nodes_names_external(self) -> list:
"""Return a list of external precursor nodes names."""
return [name for name in self.precursor_nodes_names \
if not self._check_target_node_internal(name)]
return [name for name in self.precursor_nodes_names
if not self._check_target_node_internal(name)]

@property
def successor_nodes_names_external(self) -> list:
"""Return a list of external successor nodes names."""
return [name for name in self.successor_nodes_names \
if not self._check_target_node_internal(name)]
return [name for name in self.successor_nodes_names
if not self._check_target_node_internal(name)]

+ 2
- 16
mindinsight/mindconverter/graph_based_converter/hierarchical_tree/hierarchical_tree.py View File

@@ -29,7 +29,7 @@ from ..common.utils import is_converted, save_code_file_and_report
from ..mapper.base import Mapper
from ..third_party_graph.pytorch_graph_node import PyTorchGraphNode
from ..third_party_graph.onnx_graph_node import OnnxGraphNode
from ..constant import SEPARATOR_IN_SCOPE
from ..constant import SEPARATOR_IN_SCOPE, get_imported_module
from ..constant import CodeFormatConfig
from ..constant import SEPARATOR_BTW_NAME_AND_ID, FIRST_LEVEL_INDENT
from ..constant import NEW_LINE, SECOND_LEVEL_INDENT
@@ -283,7 +283,7 @@ class HierarchicalTree(Tree):
Returns:
Dict, codes.
"""
code_blocks = [self._get_imported_module()]
code_blocks = [get_imported_module()]
depths = sorted(list(self._hierarchical_order.keys()), reverse=True)

for depth in depths:
@@ -741,17 +741,3 @@ class HierarchicalTree(Tree):
"""Adjust tree structure to generate source code."""
self.sub_graph_merging()
self.update_hierarchical_order()

@staticmethod
def _get_imported_module():
"""
Generate imported module header.

Returns:
str, imported module.
"""
return f"import numpy as np{NEW_LINE}" \
f"import mindspore{NEW_LINE}" \
f"from mindspore import nn{NEW_LINE}" \
f"from mindspore import Tensor{NEW_LINE}" \
f"from mindspore.ops import operations as P{NEW_LINE * 3}"

Loading…
Cancel
Save