Browse Source

Updates on merging generator to MindConverter, and unnecessary logic clean

tags/v1.1.0
liangtianshu 5 years ago
parent
commit
666a0d35a3
7 changed files with 120 additions and 139 deletions
  1. +29
    -0
      mindinsight/mindconverter/common/exceptions.py
  2. +14
    -0
      mindinsight/mindconverter/graph_based_converter/constant.py
  3. +8
    -11
      mindinsight/mindconverter/graph_based_converter/framework.py
  4. +31
    -82
      mindinsight/mindconverter/graph_based_converter/generator/generator.py
  5. +7
    -3
      mindinsight/mindconverter/graph_based_converter/generator/module_struct.py
  6. +29
    -27
      mindinsight/mindconverter/graph_based_converter/generator/node_struct.py
  7. +2
    -16
      mindinsight/mindconverter/graph_based_converter/hierarchical_tree/hierarchical_tree.py

+ 29
- 0
mindinsight/mindconverter/common/exceptions.py View File

@@ -39,11 +39,13 @@ class ConverterErrors(ScriptConverterErrors):
MODEL_NOT_SUPPORT = 8 MODEL_NOT_SUPPORT = 8
SCRIPT_GENERATE_FAIL = 9 SCRIPT_GENERATE_FAIL = 9
REPORT_GENERATE_FAIL = 10 REPORT_GENERATE_FAIL = 10
NODE_CONVERSION_ERROR = 11


BASE_CONVERTER_FAIL = 000 BASE_CONVERTER_FAIL = 000
GRAPH_INIT_FAIL = 100 GRAPH_INIT_FAIL = 100
TREE_CREATE_FAIL = 200 TREE_CREATE_FAIL = 200
SOURCE_FILES_SAVE_FAIL = 300 SOURCE_FILES_SAVE_FAIL = 300
GENERATOR_FAIL = 400




class ScriptNotSupport(MindInsightException): class ScriptNotSupport(MindInsightException):
@@ -163,6 +165,7 @@ class MindConverterException(Exception):


class BaseConverterFail(MindConverterException): class BaseConverterFail(MindConverterException):
"""Base converter failed.""" """Base converter failed."""

def __init__(self, msg): def __init__(self, msg):
super(BaseConverterFail, self).__init__(error=ConverterErrors.BASE_CONVERTER_FAIL, super(BaseConverterFail, self).__init__(error=ConverterErrors.BASE_CONVERTER_FAIL,
user_msg=msg) user_msg=msg)
@@ -195,6 +198,7 @@ class BaseConverterFail(MindConverterException):


class UnknownModel(MindConverterException): class UnknownModel(MindConverterException):
"""The unknown model error.""" """The unknown model error."""

def __init__(self, msg): def __init__(self, msg):
super(UnknownModel, self).__init__(error=ConverterErrors.UNKNOWN_MODEL, super(UnknownModel, self).__init__(error=ConverterErrors.UNKNOWN_MODEL,
user_msg=msg) user_msg=msg)
@@ -202,6 +206,7 @@ class UnknownModel(MindConverterException):


class GraphInitFail(MindConverterException): class GraphInitFail(MindConverterException):
"""The graph init fail error.""" """The graph init fail error."""

def __init__(self, **kwargs): def __init__(self, **kwargs):
super(GraphInitFail, self).__init__(error=ConverterErrors.GRAPH_INIT_FAIL, super(GraphInitFail, self).__init__(error=ConverterErrors.GRAPH_INIT_FAIL,
user_msg=kwargs.get('msg', '')) user_msg=kwargs.get('msg', ''))
@@ -230,6 +235,7 @@ class GraphInitFail(MindConverterException):


class TreeCreateFail(MindConverterException): class TreeCreateFail(MindConverterException):
"""The tree create fail.""" """The tree create fail."""

def __init__(self, msg): def __init__(self, msg):
super(TreeCreateFail, self).__init__(error=ConverterErrors.TREE_CREATE_FAIL, super(TreeCreateFail, self).__init__(error=ConverterErrors.TREE_CREATE_FAIL,
user_msg=msg) user_msg=msg)
@@ -254,6 +260,7 @@ class TreeCreateFail(MindConverterException):


class SourceFilesSaveFail(MindConverterException): class SourceFilesSaveFail(MindConverterException):
"""The source files save fail error.""" """The source files save fail error."""

def __init__(self, msg): def __init__(self, msg):
super(SourceFilesSaveFail, self).__init__(error=ConverterErrors.SOURCE_FILES_SAVE_FAIL, super(SourceFilesSaveFail, self).__init__(error=ConverterErrors.SOURCE_FILES_SAVE_FAIL,
user_msg=msg) user_msg=msg)
@@ -280,6 +287,7 @@ class SourceFilesSaveFail(MindConverterException):


class ModelNotSupport(MindConverterException): class ModelNotSupport(MindConverterException):
"""The model not support error.""" """The model not support error."""

def __init__(self, msg): def __init__(self, msg):
super(ModelNotSupport, self).__init__(error=ConverterErrors.MODEL_NOT_SUPPORT, super(ModelNotSupport, self).__init__(error=ConverterErrors.MODEL_NOT_SUPPORT,
user_msg=msg, user_msg=msg,
@@ -338,6 +346,7 @@ class ModelNotSupport(MindConverterException):


class NodeInputMissing(MindConverterException): class NodeInputMissing(MindConverterException):
"""The node input missing error.""" """The node input missing error."""

def __init__(self, msg): def __init__(self, msg):
super(NodeInputMissing, self).__init__(error=ConverterErrors.NODE_INPUT_MISSING, super(NodeInputMissing, self).__init__(error=ConverterErrors.NODE_INPUT_MISSING,
user_msg=msg, user_msg=msg,
@@ -346,6 +355,7 @@ class NodeInputMissing(MindConverterException):


class TreeNodeInsertFail(MindConverterException): class TreeNodeInsertFail(MindConverterException):
"""The tree node create fail error.""" """The tree node create fail error."""

def __init__(self, msg): def __init__(self, msg):
super(TreeNodeInsertFail, self).__init__(error=ConverterErrors.TREE_NODE_INSERT_FAIL, super(TreeNodeInsertFail, self).__init__(error=ConverterErrors.TREE_NODE_INSERT_FAIL,
user_msg=msg, user_msg=msg,
@@ -380,6 +390,7 @@ class TreeNodeInsertFail(MindConverterException):


class NodeInputTypeNotSupport(MindConverterException): class NodeInputTypeNotSupport(MindConverterException):
"""The node input type NOT support error.""" """The node input type NOT support error."""

def __init__(self, msg): def __init__(self, msg):
super(NodeInputTypeNotSupport, self).__init__(error=ConverterErrors.NODE_INPUT_TYPE_NOT_SUPPORT, super(NodeInputTypeNotSupport, self).__init__(error=ConverterErrors.NODE_INPUT_TYPE_NOT_SUPPORT,
user_msg=msg, user_msg=msg,
@@ -388,6 +399,7 @@ class NodeInputTypeNotSupport(MindConverterException):


class ScriptGenerateFail(MindConverterException): class ScriptGenerateFail(MindConverterException):
"""The script generate fail error.""" """The script generate fail error."""

def __init__(self, msg): def __init__(self, msg):
super(ScriptGenerateFail, self).__init__(error=ConverterErrors.SCRIPT_GENERATE_FAIL, super(ScriptGenerateFail, self).__init__(error=ConverterErrors.SCRIPT_GENERATE_FAIL,
user_msg=msg, user_msg=msg,
@@ -421,6 +433,7 @@ class ScriptGenerateFail(MindConverterException):


class ReportGenerateFail(MindConverterException): class ReportGenerateFail(MindConverterException):
"""The report generate fail error.""" """The report generate fail error."""

def __init__(self, msg): def __init__(self, msg):
super(ReportGenerateFail, self).__init__(error=ConverterErrors.REPORT_GENERATE_FAIL, super(ReportGenerateFail, self).__init__(error=ConverterErrors.REPORT_GENERATE_FAIL,
user_msg=msg, user_msg=msg,
@@ -448,3 +461,19 @@ class ReportGenerateFail(MindConverterException):
return output return output
return _f return _f
return decorator return decorator


class GeneratorFail(MindConverterException):
"""The Generator fail error."""

def __init__(self, msg):
super(GeneratorFail, self).__init__(error=ConverterErrors.NODE_CONVERSION_ERROR,
user_msg=msg,
cls_code=ConverterErrors.GENERATOR_FAIL.value)
@classmethod
def raise_from(cls):
"""Raise from exceptions below."""
except_source = (ValueError,
TypeError,
cls)
return except_source

+ 14
- 0
mindinsight/mindconverter/graph_based_converter/constant.py View File

@@ -75,3 +75,17 @@ class InputType(Enum):
class FrameworkType(Enum): class FrameworkType(Enum):
PYTORCH = 0 PYTORCH = 0
TENSORFLOW = 1 TENSORFLOW = 1


def get_imported_module():
"""
Generate imported module header.

Returns:
str, imported module.
"""
return f"import numpy as np{NEW_LINE}" \
f"import mindspore{NEW_LINE}" \
f"from mindspore import nn{NEW_LINE}" \
f"from mindspore import Tensor{NEW_LINE}" \
f"from mindspore.ops import operations as P{NEW_LINE * 3}"

+ 8
- 11
mindinsight/mindconverter/graph_based_converter/framework.py View File

@@ -20,7 +20,8 @@ from importlib import import_module
from importlib.util import find_spec from importlib.util import find_spec


import mindinsight import mindinsight
from mindinsight.mindconverter.graph_based_converter.common.utils import lib_version_satisfied
from mindinsight.mindconverter.graph_based_converter.common.utils import lib_version_satisfied, \
save_code_file_and_report
from mindinsight.mindconverter.graph_based_converter.constant import BINARY_HEADER_PYTORCH_FILE, FrameworkType, \ from mindinsight.mindconverter.graph_based_converter.constant import BINARY_HEADER_PYTORCH_FILE, FrameworkType, \
BINARY_HEADER_PYTORCH_BITS, ONNX_MIN_VER, TF2ONNX_MIN_VER, ONNXRUNTIME_MIN_VER BINARY_HEADER_PYTORCH_BITS, ONNX_MIN_VER, TF2ONNX_MIN_VER, ONNXRUNTIME_MIN_VER
from mindinsight.mindconverter.graph_based_converter.mapper import ONNXToMindSporeMapper from mindinsight.mindconverter.graph_based_converter.mapper import ONNXToMindSporeMapper
@@ -29,6 +30,7 @@ from mindinsight.mindconverter.common.exceptions import GraphInitFail, TreeCreat
BaseConverterFail, UnknownModel BaseConverterFail, UnknownModel
from mindinsight.utils.exceptions import ParamMissError from mindinsight.utils.exceptions import ParamMissError



permissions = os.R_OK | os.W_OK | os.X_OK permissions = os.R_OK | os.W_OK | os.X_OK
os.umask(permissions << 3 | permissions) os.umask(permissions << 3 | permissions)


@@ -194,23 +196,18 @@ def graph_based_converter_tf_to_ms(graph_path: str, sample_shape: tuple,
""" """
third_party_graph_module = import_module( third_party_graph_module = import_module(
'mindinsight.mindconverter.graph_based_converter.third_party_graph') 'mindinsight.mindconverter.graph_based_converter.third_party_graph')
hierarchical_tree_module = import_module(
'mindinsight.mindconverter.graph_based_converter.hierarchical_tree')
cls_graph_factory = getattr(third_party_graph_module, 'GraphFactory') cls_graph_factory = getattr(third_party_graph_module, 'GraphFactory')
cls_hierarchical_tree_factory = getattr(hierarchical_tree_module, 'HierarchicalTreeFactory')
batch_add_nodes = getattr(import_module('mindinsight.mindconverter.graph_based_converter.generator'),
"batch_add_nodes")
# Close unnecessary log. # Close unnecessary log.
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2' os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'


graph_obj = cls_graph_factory.init(graph_path, sample_shape=sample_shape, graph_obj = cls_graph_factory.init(graph_path, sample_shape=sample_shape,
input_nodes=input_nodes, output_nodes=output_nodes) input_nodes=input_nodes, output_nodes=output_nodes)

hierarchical_tree, scope_name_map = cls_hierarchical_tree_factory.create(graph_obj)

generator_inst = batch_add_nodes(graph_obj, ONNXToMindSporeMapper)
model_name = _extract_model_name(graph_path) model_name = _extract_model_name(graph_path)
hierarchical_tree.save_source_files(output_folder, mapper=ONNXToMindSporeMapper,
model_name=model_name,
report_folder=report_folder,
scope_name_map=scope_name_map)
code_fragments = generator_inst.generate()
save_code_file_and_report(model_name, code_fragments, output_folder, report_folder)




@BaseConverterFail.check_except("Failed to start base converter.") @BaseConverterFail.check_except("Failed to start base converter.")


+ 31
- 82
mindinsight/mindconverter/graph_based_converter/generator/generator.py View File

@@ -16,23 +16,17 @@
import copy import copy
from collections import OrderedDict from collections import OrderedDict


from yapf.yapflib.yapf_api import FormatCode

from .scope_utils import Scope from .scope_utils import Scope
from .node_struct import NodeStruct from .node_struct import NodeStruct
from .module_struct import ModuleStruct from .module_struct import ModuleStruct
from .args_translator import ArgsTranslationHelper from .args_translator import ArgsTranslationHelper
from ..common.global_context import GlobalContext from ..common.global_context import GlobalContext
from ...common.exceptions import GeneratorFail
from ..hierarchical_tree.name_mgr import GlobalVarNameMgr from ..hierarchical_tree.name_mgr import GlobalVarNameMgr
from ..constant import NEW_LINE, SECOND_LEVEL_INDENT, FIRST_LEVEL_INDENT


class Singleton(type):
"""Metaclass to make the generator to be single instance."""
_instances = {}

def __call__(cls, *args, **kwargs):
if cls not in cls._instances:
cls._instances[cls] = super(Singleton, cls).__call__(*args, **kwargs)
return cls._instances[cls]
from ..constant import NEW_LINE, SECOND_LEVEL_INDENT, FIRST_LEVEL_INDENT, CodeFormatConfig, get_imported_module
from ..report_generator import ReportGenerator




class CodeStruct: class CodeStruct:
@@ -88,14 +82,9 @@ class CodeStruct:
repeated_submodules (dict): The dict contains all submodules which use repeatedly. repeated_submodules (dict): The dict contains all submodules which use repeatedly.
Can get this dict from generator. Can get this dict from generator.
""" """
# Define tmp var for code generation.
opt_var_name_records = dict() # now only support multiple outputs within same scope.
return_value_records = dict() # save returned values for successor nodes/modules use.

# Define Module header code line below # Define Module header code line below
if md_struct.pattern_id != -1:
class_name = f"Module{md_struct.pattern_id}"
else:
class_name = "Model"
class_name = md_struct.class_name
# define a class declaration # define a class declaration
self.new_line = f"class {class_name}(nn.Cell):" self.new_line = f"class {class_name}(nn.Cell):"


@@ -108,36 +97,18 @@ class CodeStruct:
for formal in md_struct.args_translator.formal_args.keys(): for formal in md_struct.args_translator.formal_args.keys():
module_def_args.append(formal) module_def_args.append(formal)


# Collect extra inputs and outputs

# For code line in init & construct blocks # For code line in init & construct blocks
init_lines = list() init_lines = list()
cons_lines = list() cons_lines = list()
for (_, struct) in md_struct.get_generate_order(): for (_, struct) in md_struct.get_generate_order():
if isinstance(struct, NodeStruct): # Generate code line for Node. if isinstance(struct, NodeStruct): # Generate code line for Node.
code_line_init = struct.code_line_in_init() code_line_init = struct.code_line_in_init()
code_line_construct = struct.code_line_in_construct(in_module_returns=return_value_records)
code_line_construct = struct.code_line_in_construct()
init_lines.append(f"{SECOND_LEVEL_INDENT}{' = '.join(code_line_init)}") init_lines.append(f"{SECOND_LEVEL_INDENT}{' = '.join(code_line_init)}")
cons_lines.append(f"{SECOND_LEVEL_INDENT}{' = '.join(code_line_construct)}") cons_lines.append(f"{SECOND_LEVEL_INDENT}{' = '.join(code_line_construct)}")
# add extra tensor # add extra tensor
if struct.fragment.code_setting and struct.fragment.code_setting.op_extra_tensor: if struct.fragment.code_setting and struct.fragment.code_setting.op_extra_tensor:
code_extra_tensor = struct.add_extra_tensor()
init_lines.append(f"{SECOND_LEVEL_INDENT}{' = '.join(code_extra_tensor)}")

# record opt_var_name for succ nodes input in same scope.
target_onnx_name = struct.graph_node_ref.successor_nodes
for name in target_onnx_name:
if opt_var_name_records.get(name):
opt_var_name_records.get(name).append(code_line_construct[0])
else:
opt_var_name_records[name] = [code_line_construct[0]]

if struct.successor_nodes_names_external:
for ret_user in struct.successor_nodes_names_external:
if return_value_records.get(ret_user) is not None:
return_value_records[ret_user].append((struct.onnx_name, code_line_construct[0]))
else:
return_value_records[ret_user] = [(struct.onnx_name, code_line_construct[0])]
init_lines.append(f"{SECOND_LEVEL_INDENT}{' = '.join(struct.add_extra_tensor())}")


elif isinstance(struct, ModuleStruct): elif isinstance(struct, ModuleStruct):
# check if this instance generated CodeStruct # check if this instance generated CodeStruct
@@ -149,22 +120,6 @@ class CodeStruct:
init_lines.append(f"{SECOND_LEVEL_INDENT}{' = '.join(code_line_init)}") init_lines.append(f"{SECOND_LEVEL_INDENT}{' = '.join(code_line_init)}")
cons_lines.append(f"{SECOND_LEVEL_INDENT}{' = '.join(code_line_construct)}") cons_lines.append(f"{SECOND_LEVEL_INDENT}{' = '.join(code_line_construct)}")


# record opt_var_name for succ nodes input in same scope.
target_onnx_name = struct.tail_nd_struct.graph_node_ref.successor_nodes
for name in target_onnx_name:
if opt_var_name_records.get(name):
opt_var_name_records.get(name).append(code_line_construct[0])
else:
opt_var_name_records[name] = [code_line_construct[0]]

# record submodule's local return map for following nodes / submodules use
if struct.external_successor_local_returns_map:
for ret_user, _ in struct.external_successor_local_returns_map.items():
if return_value_records.get(ret_user) is not None:
# mulitple returns of a node may need modifiy the index.
return_value_records[ret_user].append((struct.identifier, code_line_construct[0]))
else:
return_value_records[ret_user] = [(struct.identifier, code_line_construct[0])]
else: else:
raise TypeError("Unable to generate code from args are not ModuleStruct or NodeStruct.") raise TypeError("Unable to generate code from args are not ModuleStruct or NodeStruct.")


@@ -183,8 +138,7 @@ class CodeStruct:
# define returns # define returns
returns = [] returns = []
if md_struct.external_successor_local_returns_map: if md_struct.external_successor_local_returns_map:
ret = list(md_struct.external_successor_local_returns_map.values())
for r in ret:
for r in list(md_struct.external_successor_local_returns_map.values()):
if isinstance(r, tuple): # results return with index nth output if isinstance(r, tuple): # results return with index nth output
returns.append(r[0]) returns.append(r[0])
else: else:
@@ -197,7 +151,7 @@ class CodeStruct:
self.GLOBAL_CONTEXT.code_structs[md_struct.pattern_id] = self self.GLOBAL_CONTEXT.code_structs[md_struct.pattern_id] = self




class Generator(metaclass=Singleton):
class Generator:
"""The generator controls all routines of code generation.""" """The generator controls all routines of code generation."""


def __init__(self): def __init__(self):
@@ -217,6 +171,7 @@ class Generator(metaclass=Singleton):
self._global_context.node_struct_collections = self._node_struct_collections self._global_context.node_struct_collections = self._node_struct_collections
self._repeated_submodules = set() self._repeated_submodules = set()


@GeneratorFail.check_except("Generator occurs an error when forming base submodules.")
def _form_bottom_submodule(self): def _form_bottom_submodule(self):
"""Form the basic submodules, which only contains nodes.""" """Form the basic submodules, which only contains nodes."""
# Form module map # Form module map
@@ -375,7 +330,7 @@ class Generator(metaclass=Singleton):
if scope_path_str == '[]': if scope_path_str == '[]':
continue # is main module, skip continue # is main module, skip
if md_struct.scope_depth != depth: if md_struct.scope_depth != depth:
continue # skip all submodules not at current depth
continue # skip all submodules not at current depth
md_struct_scope = copy.deepcopy(md_struct.identifier) md_struct_scope = copy.deepcopy(md_struct.identifier)
md_struct_scope.pop() md_struct_scope.pop()
parent_scope = md_struct_scope parent_scope = md_struct_scope
@@ -396,6 +351,7 @@ class Generator(metaclass=Singleton):
self._global_context.add_module_struct(sub.pattern_id, sub) self._global_context.add_module_struct(sub.pattern_id, sub)
depth -= 1 depth -= 1


@GeneratorFail.check_except("Generator occurs an error when building modules.")
def _recursive_form_module(self): def _recursive_form_module(self):
"""Main routine in generator to build modules from bottom to top.""" """Main routine in generator to build modules from bottom to top."""
# 1. List repeated submodules # 1. List repeated submodules
@@ -518,6 +474,7 @@ class Generator(metaclass=Singleton):
"""Return all ModuleStructs in this model.""" """Return all ModuleStructs in this model."""
return self._module_struct_collections return self._module_struct_collections


@GeneratorFail.check_except("Generator occurs an error when generating code statements.")
def generate(self): def generate(self):
""" """
Generate the final script file. Generate the final script file.
@@ -527,8 +484,22 @@ class Generator(metaclass=Singleton):
""" """
self._form_bottom_submodule() self._form_bottom_submodule()
self._recursive_form_module() self._recursive_form_module()
code = CodeStruct(self.module_structs.get('[]'), self._repeated_submodules)
return code.code_line_list
CodeStruct(self.module_structs.get('[]'), self._repeated_submodules)

outputs = [get_imported_module()]

for code_struct in self._global_context.code_structs.values():
for line in code_struct.code_line_list:
outputs.append(line)

formatted_code, _ = FormatCode("\n".join(outputs),
style_config=CodeFormatConfig.PEP8.value)

report_generator = ReportGenerator()
report = report_generator.gen_report(formatted_code)
del self._global_context

return {"model": (formatted_code, report)}


def get_node_struct(self, node_identifier): def get_node_struct(self, node_identifier):
""" """
@@ -554,28 +525,6 @@ class Generator(metaclass=Singleton):
""" """
return self._module_struct_collections.get(module_identifier, None) return self._module_struct_collections.get(module_identifier, None)


def get_module_structs_by_pattern_under_same_parent_pattern(self, pattern_id, under_parent_pattern_id):
"""
Return a list of ModuleStruct by conditions of pattern and their parent parent's pattern.

Args:
pattern_id (int): The pattern id the returned ModuleSturct is.
under_parent_pattern_id (int): The pattern id the returned ModuleStruct's parent is.

Returns:
list, a list of MoudleStructs has the same pattern_id and the same parents' pattern.
"""
if not pattern_id:
raise ValueError("pattern_id is necessary to get the module struct.")
if not under_parent_pattern_id:
raise ValueError("under_parent_pattern_id is necessary to get the module struct.")
ret = []
md_list = self._global_context.module_structs.get(pattern_id)
for md in md_list:
if md.parent_id == under_parent_pattern_id:
ret.append(md)
return ret

def get_args_translator_from_module_structs_list(self, md_list, exclude_root_son=False): def get_args_translator_from_module_structs_list(self, md_list, exclude_root_son=False):
""" """
Return a list of args translators which belongs to given module structs. Return a list of args translators which belongs to given module structs.


+ 7
- 3
mindinsight/mindconverter/graph_based_converter/generator/module_struct.py View File

@@ -201,7 +201,7 @@ class ModuleStruct:


def init_args_translator(self): def init_args_translator(self):
"""Initialize the Args Translator for the module.""" """Initialize the Args Translator for the module."""
var_name = "Module{}_{}".format(self.pattern_id, self.pattern_uid)
var_name = self.ms_var_name
self._args_translator = ArgsTranslation(None, var_name, None) self._args_translator = ArgsTranslation(None, var_name, None)


def update_module_fragment(self): def update_module_fragment(self):
@@ -455,14 +455,18 @@ class ModuleStruct:
"""Return the class name for generating code of this module.""" """Return the class name for generating code of this module."""
if self.pattern_id == -1: if self.pattern_id == -1:
return "Model" return "Model"
return "Module{}".format(self.pattern_id)
if self.GLOBAL_CONTEXT_MGR.known_module_name.get("Module{}".format(self.pattern_id)) is not None:
class_name = self.GLOBAL_CONTEXT_MGR.known_module_name.get("Module{}".format(self.pattern_id))
else:
class_name = "Module{}".format(self.pattern_id)
return class_name


@property @property
def ms_var_name(self) -> str: def ms_var_name(self) -> str:
"""Return the variable name for generated code statement of this module.""" """Return the variable name for generated code statement of this module."""
if self.pattern_id == -1: if self.pattern_id == -1:
return "Model" return "Model"
return "Module{}_{}".format(self.pattern_id, self.pattern_uid).lower()
return f"{self.class_name}_{self.pattern_uid}".lower()


@property @property
def ms_opt_var_name(self) -> str: def ms_opt_var_name(self) -> str:


+ 29
- 27
mindinsight/mindconverter/graph_based_converter/generator/node_struct.py View File

@@ -22,6 +22,7 @@ from ..third_party_graph.pytorch_graph_node import PyTorchGraphNode
from ..third_party_graph.onnx_graph_node import OnnxGraphNode from ..third_party_graph.onnx_graph_node import OnnxGraphNode
from ..common.global_context import GlobalContext from ..common.global_context import GlobalContext
from ..constant import InputType from ..constant import InputType
from ...common.exceptions import GeneratorFail




class NodeStruct: class NodeStruct:
@@ -47,22 +48,32 @@ class NodeStruct:
self.node_type = None self.node_type = None
self.onnx_name = None self.onnx_name = None
self.onnx_op = None self.onnx_op = None
self.graph_node_ref = None # Our defined GraphNode
self.graph_node_ref = None
self.scope_name = None self.scope_name = None
self.ms_var_name = None self.ms_var_name = None
self.ms_opt_var_name = None # ms_opt_var_name = self.ms_var_name(...)
self.ms_opt_var_name = None
self.ms_op = None self.ms_op = None
self.ready_to_generate = False self.ready_to_generate = False


self.ms_params = dict() # converted params from mapper
# Define attributes converted from mapper
self.ms_params = dict()
self.ms_settings = dict() self.ms_settings = dict()
self.ms_weights = dict() self.ms_weights = dict()
self.ms_inputs = OrderedDict() self.ms_inputs = OrderedDict()


self.scope = None # Defined Scope class
self.inputs_in_construct_header = OrderedDict() # key is prec_node_name, value is x; For code line use
self.inputs_in_parent_module = OrderedDict() # key is prec_node_name, value is its closet opt_var_name
self.matched_inputs = list() # Matched inputs will can be directly used by code line generation
# Defined Scope class
self.scope = None

# Define attributes used for code generation

# key is prec_node_name, value is x; For code line use
self.inputs_in_construct_header = OrderedDict()

# key is prec_node_name, value is its closet opt_var_name
self.inputs_in_parent_module = OrderedDict()

# Matched inputs will can be directly used by code line generation
self.matched_inputs = list()


# initialize funcs. # initialize funcs.
for arg in args: for arg in args:
@@ -135,6 +146,7 @@ class NodeStruct:
parsed_scope = Scope.parse_scope_from_node_identifier(self.identifier) parsed_scope = Scope.parse_scope_from_node_identifier(self.identifier)
self.scope = Scope(parsed_scope) self.scope = Scope(parsed_scope)


@GeneratorFail.check_except("Generator occurs an error when initializing node's args translator.")
def init_args_translator(self, translated_args: list): def init_args_translator(self, translated_args: list):
""" """
Initialize the ArgsTranslator for each Node. Initialize the ArgsTranslator for each Node.
@@ -158,6 +170,7 @@ class NodeStruct:
self.ms_op]): self.ms_op]):
self.ready_to_generate = True self.ready_to_generate = True


@GeneratorFail.check_except("Generator occurs an error when creating node struct.")
def update(self, arg, force_ready=False): def update(self, arg, force_ready=False):
""" """
Pass Node info. to generator NodeStruct. Pass Node info. to generator NodeStruct.
@@ -314,23 +327,12 @@ class NodeStruct:
return input_name_to_use return input_name_to_use
return None return None


def code_line_in_construct(self, inputs=None, in_module_returns=None):
def code_line_in_construct(self, inputs=None):
"""Construct line of code in module construct block. """ """Construct line of code in module construct block. """
left = self.ms_opt_var_name left = self.ms_opt_var_name
if inputs is None:
inputs = []
for idx, prec_node in enumerate(self.precursor_nodes_names):
if self.inputs_in_construct_header.get(prec_node):
inputs.append(self.inputs_in_construct_header.get(prec_node))
elif self._check_target_node_internal(prec_node):
inputs.append(self.precursor_nodes_structs[idx].ms_opt_var_name)
elif self.inputs_in_parent_module.get(prec_node):
inputs.append(self.inputs_in_parent_module.get(prec_node))
elif in_module_returns and in_module_returns.get(self.onnx_name) \
and (not self._check_target_node_internal(prec_node)):
inputs.append(self._get_correct_in_module_returns(prec_node, in_module_returns.get(self.onnx_name)))
else:
inputs.append("unk_{}_{}".format(idx, prec_node))

if not self.matched_inputs and inputs is None:
raise ValueError("Unable to generate the code construct statement due to empty inputs.")


if self.matched_inputs: if self.matched_inputs:
inputs = self.matched_inputs inputs = self.matched_inputs
@@ -394,7 +396,7 @@ class NodeStruct:
""" """
target_nd_struct = self.GLOBAL_CONTEXT_MGR.node_struct_collections.get(name) \ target_nd_struct = self.GLOBAL_CONTEXT_MGR.node_struct_collections.get(name) \
or self.GLOBAL_CONTEXT_MGR.onnx_node_name_to_node_struct_map.get(name) or self.GLOBAL_CONTEXT_MGR.onnx_node_name_to_node_struct_map.get(name)
if target_nd_struct is None and self.topo_idx == 0: # First node always has external input
if target_nd_struct is None and self.topo_idx == 0: # First node always has external input
return False return False


if target_nd_struct is None: if target_nd_struct is None:
@@ -413,11 +415,11 @@ class NodeStruct:
@property @property
def precursor_nodes_names_external(self) -> list: def precursor_nodes_names_external(self) -> list:
"""Return a list of external precursor nodes names.""" """Return a list of external precursor nodes names."""
return [name for name in self.precursor_nodes_names \
if not self._check_target_node_internal(name)]
return [name for name in self.precursor_nodes_names
if not self._check_target_node_internal(name)]


@property @property
def successor_nodes_names_external(self) -> list: def successor_nodes_names_external(self) -> list:
"""Return a list of external successor nodes names.""" """Return a list of external successor nodes names."""
return [name for name in self.successor_nodes_names \
if not self._check_target_node_internal(name)]
return [name for name in self.successor_nodes_names
if not self._check_target_node_internal(name)]

+ 2
- 16
mindinsight/mindconverter/graph_based_converter/hierarchical_tree/hierarchical_tree.py View File

@@ -29,7 +29,7 @@ from ..common.utils import is_converted, save_code_file_and_report
from ..mapper.base import Mapper from ..mapper.base import Mapper
from ..third_party_graph.pytorch_graph_node import PyTorchGraphNode from ..third_party_graph.pytorch_graph_node import PyTorchGraphNode
from ..third_party_graph.onnx_graph_node import OnnxGraphNode from ..third_party_graph.onnx_graph_node import OnnxGraphNode
from ..constant import SEPARATOR_IN_SCOPE
from ..constant import SEPARATOR_IN_SCOPE, get_imported_module
from ..constant import CodeFormatConfig from ..constant import CodeFormatConfig
from ..constant import SEPARATOR_BTW_NAME_AND_ID, FIRST_LEVEL_INDENT from ..constant import SEPARATOR_BTW_NAME_AND_ID, FIRST_LEVEL_INDENT
from ..constant import NEW_LINE, SECOND_LEVEL_INDENT from ..constant import NEW_LINE, SECOND_LEVEL_INDENT
@@ -283,7 +283,7 @@ class HierarchicalTree(Tree):
Returns: Returns:
Dict, codes. Dict, codes.
""" """
code_blocks = [self._get_imported_module()]
code_blocks = [get_imported_module()]
depths = sorted(list(self._hierarchical_order.keys()), reverse=True) depths = sorted(list(self._hierarchical_order.keys()), reverse=True)


for depth in depths: for depth in depths:
@@ -741,17 +741,3 @@ class HierarchicalTree(Tree):
"""Adjust tree structure to generate source code.""" """Adjust tree structure to generate source code."""
self.sub_graph_merging() self.sub_graph_merging()
self.update_hierarchical_order() self.update_hierarchical_order()

@staticmethod
def _get_imported_module():
"""
Generate imported module header.

Returns:
str, imported module.
"""
return f"import numpy as np{NEW_LINE}" \
f"import mindspore{NEW_LINE}" \
f"from mindspore import nn{NEW_LINE}" \
f"from mindspore import Tensor{NEW_LINE}" \
f"from mindspore.ops import operations as P{NEW_LINE * 3}"

Loading…
Cancel
Save