| @@ -39,11 +39,13 @@ class ConverterErrors(ScriptConverterErrors): | |||||
| MODEL_NOT_SUPPORT = 8 | MODEL_NOT_SUPPORT = 8 | ||||
| SCRIPT_GENERATE_FAIL = 9 | SCRIPT_GENERATE_FAIL = 9 | ||||
| REPORT_GENERATE_FAIL = 10 | REPORT_GENERATE_FAIL = 10 | ||||
| NODE_CONVERSION_ERROR = 11 | |||||
| BASE_CONVERTER_FAIL = 000 | BASE_CONVERTER_FAIL = 000 | ||||
| GRAPH_INIT_FAIL = 100 | GRAPH_INIT_FAIL = 100 | ||||
| TREE_CREATE_FAIL = 200 | TREE_CREATE_FAIL = 200 | ||||
| SOURCE_FILES_SAVE_FAIL = 300 | SOURCE_FILES_SAVE_FAIL = 300 | ||||
| GENERATOR_FAIL = 400 | |||||
| class ScriptNotSupport(MindInsightException): | class ScriptNotSupport(MindInsightException): | ||||
| @@ -163,6 +165,7 @@ class MindConverterException(Exception): | |||||
| class BaseConverterFail(MindConverterException): | class BaseConverterFail(MindConverterException): | ||||
| """Base converter failed.""" | """Base converter failed.""" | ||||
| def __init__(self, msg): | def __init__(self, msg): | ||||
| super(BaseConverterFail, self).__init__(error=ConverterErrors.BASE_CONVERTER_FAIL, | super(BaseConverterFail, self).__init__(error=ConverterErrors.BASE_CONVERTER_FAIL, | ||||
| user_msg=msg) | user_msg=msg) | ||||
| @@ -195,6 +198,7 @@ class BaseConverterFail(MindConverterException): | |||||
| class UnknownModel(MindConverterException): | class UnknownModel(MindConverterException): | ||||
| """The unknown model error.""" | """The unknown model error.""" | ||||
| def __init__(self, msg): | def __init__(self, msg): | ||||
| super(UnknownModel, self).__init__(error=ConverterErrors.UNKNOWN_MODEL, | super(UnknownModel, self).__init__(error=ConverterErrors.UNKNOWN_MODEL, | ||||
| user_msg=msg) | user_msg=msg) | ||||
| @@ -202,6 +206,7 @@ class UnknownModel(MindConverterException): | |||||
| class GraphInitFail(MindConverterException): | class GraphInitFail(MindConverterException): | ||||
| """The graph init fail error.""" | """The graph init fail error.""" | ||||
| def __init__(self, **kwargs): | def __init__(self, **kwargs): | ||||
| super(GraphInitFail, self).__init__(error=ConverterErrors.GRAPH_INIT_FAIL, | super(GraphInitFail, self).__init__(error=ConverterErrors.GRAPH_INIT_FAIL, | ||||
| user_msg=kwargs.get('msg', '')) | user_msg=kwargs.get('msg', '')) | ||||
| @@ -230,6 +235,7 @@ class GraphInitFail(MindConverterException): | |||||
| class TreeCreateFail(MindConverterException): | class TreeCreateFail(MindConverterException): | ||||
| """The tree create fail.""" | """The tree create fail.""" | ||||
| def __init__(self, msg): | def __init__(self, msg): | ||||
| super(TreeCreateFail, self).__init__(error=ConverterErrors.TREE_CREATE_FAIL, | super(TreeCreateFail, self).__init__(error=ConverterErrors.TREE_CREATE_FAIL, | ||||
| user_msg=msg) | user_msg=msg) | ||||
| @@ -254,6 +260,7 @@ class TreeCreateFail(MindConverterException): | |||||
| class SourceFilesSaveFail(MindConverterException): | class SourceFilesSaveFail(MindConverterException): | ||||
| """The source files save fail error.""" | """The source files save fail error.""" | ||||
| def __init__(self, msg): | def __init__(self, msg): | ||||
| super(SourceFilesSaveFail, self).__init__(error=ConverterErrors.SOURCE_FILES_SAVE_FAIL, | super(SourceFilesSaveFail, self).__init__(error=ConverterErrors.SOURCE_FILES_SAVE_FAIL, | ||||
| user_msg=msg) | user_msg=msg) | ||||
| @@ -280,6 +287,7 @@ class SourceFilesSaveFail(MindConverterException): | |||||
| class ModelNotSupport(MindConverterException): | class ModelNotSupport(MindConverterException): | ||||
| """The model not support error.""" | """The model not support error.""" | ||||
| def __init__(self, msg): | def __init__(self, msg): | ||||
| super(ModelNotSupport, self).__init__(error=ConverterErrors.MODEL_NOT_SUPPORT, | super(ModelNotSupport, self).__init__(error=ConverterErrors.MODEL_NOT_SUPPORT, | ||||
| user_msg=msg, | user_msg=msg, | ||||
| @@ -338,6 +346,7 @@ class ModelNotSupport(MindConverterException): | |||||
| class NodeInputMissing(MindConverterException): | class NodeInputMissing(MindConverterException): | ||||
| """The node input missing error.""" | """The node input missing error.""" | ||||
| def __init__(self, msg): | def __init__(self, msg): | ||||
| super(NodeInputMissing, self).__init__(error=ConverterErrors.NODE_INPUT_MISSING, | super(NodeInputMissing, self).__init__(error=ConverterErrors.NODE_INPUT_MISSING, | ||||
| user_msg=msg, | user_msg=msg, | ||||
| @@ -346,6 +355,7 @@ class NodeInputMissing(MindConverterException): | |||||
| class TreeNodeInsertFail(MindConverterException): | class TreeNodeInsertFail(MindConverterException): | ||||
| """The tree node create fail error.""" | """The tree node create fail error.""" | ||||
| def __init__(self, msg): | def __init__(self, msg): | ||||
| super(TreeNodeInsertFail, self).__init__(error=ConverterErrors.TREE_NODE_INSERT_FAIL, | super(TreeNodeInsertFail, self).__init__(error=ConverterErrors.TREE_NODE_INSERT_FAIL, | ||||
| user_msg=msg, | user_msg=msg, | ||||
| @@ -380,6 +390,7 @@ class TreeNodeInsertFail(MindConverterException): | |||||
| class NodeInputTypeNotSupport(MindConverterException): | class NodeInputTypeNotSupport(MindConverterException): | ||||
| """The node input type NOT support error.""" | """The node input type NOT support error.""" | ||||
| def __init__(self, msg): | def __init__(self, msg): | ||||
| super(NodeInputTypeNotSupport, self).__init__(error=ConverterErrors.NODE_INPUT_TYPE_NOT_SUPPORT, | super(NodeInputTypeNotSupport, self).__init__(error=ConverterErrors.NODE_INPUT_TYPE_NOT_SUPPORT, | ||||
| user_msg=msg, | user_msg=msg, | ||||
| @@ -388,6 +399,7 @@ class NodeInputTypeNotSupport(MindConverterException): | |||||
| class ScriptGenerateFail(MindConverterException): | class ScriptGenerateFail(MindConverterException): | ||||
| """The script generate fail error.""" | """The script generate fail error.""" | ||||
| def __init__(self, msg): | def __init__(self, msg): | ||||
| super(ScriptGenerateFail, self).__init__(error=ConverterErrors.SCRIPT_GENERATE_FAIL, | super(ScriptGenerateFail, self).__init__(error=ConverterErrors.SCRIPT_GENERATE_FAIL, | ||||
| user_msg=msg, | user_msg=msg, | ||||
| @@ -421,6 +433,7 @@ class ScriptGenerateFail(MindConverterException): | |||||
| class ReportGenerateFail(MindConverterException): | class ReportGenerateFail(MindConverterException): | ||||
| """The report generate fail error.""" | """The report generate fail error.""" | ||||
| def __init__(self, msg): | def __init__(self, msg): | ||||
| super(ReportGenerateFail, self).__init__(error=ConverterErrors.REPORT_GENERATE_FAIL, | super(ReportGenerateFail, self).__init__(error=ConverterErrors.REPORT_GENERATE_FAIL, | ||||
| user_msg=msg, | user_msg=msg, | ||||
| @@ -448,3 +461,19 @@ class ReportGenerateFail(MindConverterException): | |||||
| return output | return output | ||||
| return _f | return _f | ||||
| return decorator | return decorator | ||||
| class GeneratorFail(MindConverterException): | |||||
| """The Generator fail error.""" | |||||
| def __init__(self, msg): | |||||
| super(GeneratorFail, self).__init__(error=ConverterErrors.NODE_CONVERSION_ERROR, | |||||
| user_msg=msg, | |||||
| cls_code=ConverterErrors.GENERATOR_FAIL.value) | |||||
| @classmethod | |||||
| def raise_from(cls): | |||||
| """Raise from exceptions below.""" | |||||
| except_source = (ValueError, | |||||
| TypeError, | |||||
| cls) | |||||
| return except_source | |||||
| @@ -75,3 +75,17 @@ class InputType(Enum): | |||||
| class FrameworkType(Enum): | class FrameworkType(Enum): | ||||
| PYTORCH = 0 | PYTORCH = 0 | ||||
| TENSORFLOW = 1 | TENSORFLOW = 1 | ||||
| def get_imported_module(): | |||||
| """ | |||||
| Generate imported module header. | |||||
| Returns: | |||||
| str, imported module. | |||||
| """ | |||||
| return f"import numpy as np{NEW_LINE}" \ | |||||
| f"import mindspore{NEW_LINE}" \ | |||||
| f"from mindspore import nn{NEW_LINE}" \ | |||||
| f"from mindspore import Tensor{NEW_LINE}" \ | |||||
| f"from mindspore.ops import operations as P{NEW_LINE * 3}" | |||||
| @@ -20,7 +20,8 @@ from importlib import import_module | |||||
| from importlib.util import find_spec | from importlib.util import find_spec | ||||
| import mindinsight | import mindinsight | ||||
| from mindinsight.mindconverter.graph_based_converter.common.utils import lib_version_satisfied | |||||
| from mindinsight.mindconverter.graph_based_converter.common.utils import lib_version_satisfied, \ | |||||
| save_code_file_and_report | |||||
| from mindinsight.mindconverter.graph_based_converter.constant import BINARY_HEADER_PYTORCH_FILE, FrameworkType, \ | from mindinsight.mindconverter.graph_based_converter.constant import BINARY_HEADER_PYTORCH_FILE, FrameworkType, \ | ||||
| BINARY_HEADER_PYTORCH_BITS, ONNX_MIN_VER, TF2ONNX_MIN_VER, ONNXRUNTIME_MIN_VER | BINARY_HEADER_PYTORCH_BITS, ONNX_MIN_VER, TF2ONNX_MIN_VER, ONNXRUNTIME_MIN_VER | ||||
| from mindinsight.mindconverter.graph_based_converter.mapper import ONNXToMindSporeMapper | from mindinsight.mindconverter.graph_based_converter.mapper import ONNXToMindSporeMapper | ||||
| @@ -29,6 +30,7 @@ from mindinsight.mindconverter.common.exceptions import GraphInitFail, TreeCreat | |||||
| BaseConverterFail, UnknownModel | BaseConverterFail, UnknownModel | ||||
| from mindinsight.utils.exceptions import ParamMissError | from mindinsight.utils.exceptions import ParamMissError | ||||
| permissions = os.R_OK | os.W_OK | os.X_OK | permissions = os.R_OK | os.W_OK | os.X_OK | ||||
| os.umask(permissions << 3 | permissions) | os.umask(permissions << 3 | permissions) | ||||
| @@ -194,23 +196,18 @@ def graph_based_converter_tf_to_ms(graph_path: str, sample_shape: tuple, | |||||
| """ | """ | ||||
| third_party_graph_module = import_module( | third_party_graph_module = import_module( | ||||
| 'mindinsight.mindconverter.graph_based_converter.third_party_graph') | 'mindinsight.mindconverter.graph_based_converter.third_party_graph') | ||||
| hierarchical_tree_module = import_module( | |||||
| 'mindinsight.mindconverter.graph_based_converter.hierarchical_tree') | |||||
| cls_graph_factory = getattr(third_party_graph_module, 'GraphFactory') | cls_graph_factory = getattr(third_party_graph_module, 'GraphFactory') | ||||
| cls_hierarchical_tree_factory = getattr(hierarchical_tree_module, 'HierarchicalTreeFactory') | |||||
| batch_add_nodes = getattr(import_module('mindinsight.mindconverter.graph_based_converter.generator'), | |||||
| "batch_add_nodes") | |||||
| # Close unnecessary log. | # Close unnecessary log. | ||||
| os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2' | os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2' | ||||
| graph_obj = cls_graph_factory.init(graph_path, sample_shape=sample_shape, | graph_obj = cls_graph_factory.init(graph_path, sample_shape=sample_shape, | ||||
| input_nodes=input_nodes, output_nodes=output_nodes) | input_nodes=input_nodes, output_nodes=output_nodes) | ||||
| hierarchical_tree, scope_name_map = cls_hierarchical_tree_factory.create(graph_obj) | |||||
| generator_inst = batch_add_nodes(graph_obj, ONNXToMindSporeMapper) | |||||
| model_name = _extract_model_name(graph_path) | model_name = _extract_model_name(graph_path) | ||||
| hierarchical_tree.save_source_files(output_folder, mapper=ONNXToMindSporeMapper, | |||||
| model_name=model_name, | |||||
| report_folder=report_folder, | |||||
| scope_name_map=scope_name_map) | |||||
| code_fragments = generator_inst.generate() | |||||
| save_code_file_and_report(model_name, code_fragments, output_folder, report_folder) | |||||
| @BaseConverterFail.check_except("Failed to start base converter.") | @BaseConverterFail.check_except("Failed to start base converter.") | ||||
| @@ -16,23 +16,17 @@ | |||||
| import copy | import copy | ||||
| from collections import OrderedDict | from collections import OrderedDict | ||||
| from yapf.yapflib.yapf_api import FormatCode | |||||
| from .scope_utils import Scope | from .scope_utils import Scope | ||||
| from .node_struct import NodeStruct | from .node_struct import NodeStruct | ||||
| from .module_struct import ModuleStruct | from .module_struct import ModuleStruct | ||||
| from .args_translator import ArgsTranslationHelper | from .args_translator import ArgsTranslationHelper | ||||
| from ..common.global_context import GlobalContext | from ..common.global_context import GlobalContext | ||||
| from ...common.exceptions import GeneratorFail | |||||
| from ..hierarchical_tree.name_mgr import GlobalVarNameMgr | from ..hierarchical_tree.name_mgr import GlobalVarNameMgr | ||||
| from ..constant import NEW_LINE, SECOND_LEVEL_INDENT, FIRST_LEVEL_INDENT | |||||
| class Singleton(type): | |||||
| """Metaclass to make the generator to be single instance.""" | |||||
| _instances = {} | |||||
| def __call__(cls, *args, **kwargs): | |||||
| if cls not in cls._instances: | |||||
| cls._instances[cls] = super(Singleton, cls).__call__(*args, **kwargs) | |||||
| return cls._instances[cls] | |||||
| from ..constant import NEW_LINE, SECOND_LEVEL_INDENT, FIRST_LEVEL_INDENT, CodeFormatConfig, get_imported_module | |||||
| from ..report_generator import ReportGenerator | |||||
| class CodeStruct: | class CodeStruct: | ||||
| @@ -88,14 +82,9 @@ class CodeStruct: | |||||
| repeated_submodules (dict): The dict contains all submodules which use repeatedly. | repeated_submodules (dict): The dict contains all submodules which use repeatedly. | ||||
| Can get this dict from generator. | Can get this dict from generator. | ||||
| """ | """ | ||||
| # Define tmp var for code generation. | |||||
| opt_var_name_records = dict() # now only support multiple outputs within same scope. | |||||
| return_value_records = dict() # save returned values for successor nodes/modules use. | |||||
| # Define Module header code line below | # Define Module header code line below | ||||
| if md_struct.pattern_id != -1: | |||||
| class_name = f"Module{md_struct.pattern_id}" | |||||
| else: | |||||
| class_name = "Model" | |||||
| class_name = md_struct.class_name | |||||
| # define a class declaration | # define a class declaration | ||||
| self.new_line = f"class {class_name}(nn.Cell):" | self.new_line = f"class {class_name}(nn.Cell):" | ||||
| @@ -108,36 +97,18 @@ class CodeStruct: | |||||
| for formal in md_struct.args_translator.formal_args.keys(): | for formal in md_struct.args_translator.formal_args.keys(): | ||||
| module_def_args.append(formal) | module_def_args.append(formal) | ||||
| # Collect extra inputs and outputs | |||||
| # For code line in init & construct blocks | # For code line in init & construct blocks | ||||
| init_lines = list() | init_lines = list() | ||||
| cons_lines = list() | cons_lines = list() | ||||
| for (_, struct) in md_struct.get_generate_order(): | for (_, struct) in md_struct.get_generate_order(): | ||||
| if isinstance(struct, NodeStruct): # Generate code line for Node. | if isinstance(struct, NodeStruct): # Generate code line for Node. | ||||
| code_line_init = struct.code_line_in_init() | code_line_init = struct.code_line_in_init() | ||||
| code_line_construct = struct.code_line_in_construct(in_module_returns=return_value_records) | |||||
| code_line_construct = struct.code_line_in_construct() | |||||
| init_lines.append(f"{SECOND_LEVEL_INDENT}{' = '.join(code_line_init)}") | init_lines.append(f"{SECOND_LEVEL_INDENT}{' = '.join(code_line_init)}") | ||||
| cons_lines.append(f"{SECOND_LEVEL_INDENT}{' = '.join(code_line_construct)}") | cons_lines.append(f"{SECOND_LEVEL_INDENT}{' = '.join(code_line_construct)}") | ||||
| # add extra tensor | # add extra tensor | ||||
| if struct.fragment.code_setting and struct.fragment.code_setting.op_extra_tensor: | if struct.fragment.code_setting and struct.fragment.code_setting.op_extra_tensor: | ||||
| code_extra_tensor = struct.add_extra_tensor() | |||||
| init_lines.append(f"{SECOND_LEVEL_INDENT}{' = '.join(code_extra_tensor)}") | |||||
| # record opt_var_name for succ nodes input in same scope. | |||||
| target_onnx_name = struct.graph_node_ref.successor_nodes | |||||
| for name in target_onnx_name: | |||||
| if opt_var_name_records.get(name): | |||||
| opt_var_name_records.get(name).append(code_line_construct[0]) | |||||
| else: | |||||
| opt_var_name_records[name] = [code_line_construct[0]] | |||||
| if struct.successor_nodes_names_external: | |||||
| for ret_user in struct.successor_nodes_names_external: | |||||
| if return_value_records.get(ret_user) is not None: | |||||
| return_value_records[ret_user].append((struct.onnx_name, code_line_construct[0])) | |||||
| else: | |||||
| return_value_records[ret_user] = [(struct.onnx_name, code_line_construct[0])] | |||||
| init_lines.append(f"{SECOND_LEVEL_INDENT}{' = '.join(struct.add_extra_tensor())}") | |||||
| elif isinstance(struct, ModuleStruct): | elif isinstance(struct, ModuleStruct): | ||||
| # check if this instance generated CodeStruct | # check if this instance generated CodeStruct | ||||
| @@ -149,22 +120,6 @@ class CodeStruct: | |||||
| init_lines.append(f"{SECOND_LEVEL_INDENT}{' = '.join(code_line_init)}") | init_lines.append(f"{SECOND_LEVEL_INDENT}{' = '.join(code_line_init)}") | ||||
| cons_lines.append(f"{SECOND_LEVEL_INDENT}{' = '.join(code_line_construct)}") | cons_lines.append(f"{SECOND_LEVEL_INDENT}{' = '.join(code_line_construct)}") | ||||
| # record opt_var_name for succ nodes input in same scope. | |||||
| target_onnx_name = struct.tail_nd_struct.graph_node_ref.successor_nodes | |||||
| for name in target_onnx_name: | |||||
| if opt_var_name_records.get(name): | |||||
| opt_var_name_records.get(name).append(code_line_construct[0]) | |||||
| else: | |||||
| opt_var_name_records[name] = [code_line_construct[0]] | |||||
| # record submodule's local return map for following nodes / submodules use | |||||
| if struct.external_successor_local_returns_map: | |||||
| for ret_user, _ in struct.external_successor_local_returns_map.items(): | |||||
| if return_value_records.get(ret_user) is not None: | |||||
| # mulitple returns of a node may need modifiy the index. | |||||
| return_value_records[ret_user].append((struct.identifier, code_line_construct[0])) | |||||
| else: | |||||
| return_value_records[ret_user] = [(struct.identifier, code_line_construct[0])] | |||||
| else: | else: | ||||
| raise TypeError("Unable to generate code from args are not ModuleStruct or NodeStruct.") | raise TypeError("Unable to generate code from args are not ModuleStruct or NodeStruct.") | ||||
| @@ -183,8 +138,7 @@ class CodeStruct: | |||||
| # define returns | # define returns | ||||
| returns = [] | returns = [] | ||||
| if md_struct.external_successor_local_returns_map: | if md_struct.external_successor_local_returns_map: | ||||
| ret = list(md_struct.external_successor_local_returns_map.values()) | |||||
| for r in ret: | |||||
| for r in list(md_struct.external_successor_local_returns_map.values()): | |||||
| if isinstance(r, tuple): # results return with index nth output | if isinstance(r, tuple): # results return with index nth output | ||||
| returns.append(r[0]) | returns.append(r[0]) | ||||
| else: | else: | ||||
| @@ -197,7 +151,7 @@ class CodeStruct: | |||||
| self.GLOBAL_CONTEXT.code_structs[md_struct.pattern_id] = self | self.GLOBAL_CONTEXT.code_structs[md_struct.pattern_id] = self | ||||
| class Generator(metaclass=Singleton): | |||||
| class Generator: | |||||
| """The generator controls all routines of code generation.""" | """The generator controls all routines of code generation.""" | ||||
| def __init__(self): | def __init__(self): | ||||
| @@ -217,6 +171,7 @@ class Generator(metaclass=Singleton): | |||||
| self._global_context.node_struct_collections = self._node_struct_collections | self._global_context.node_struct_collections = self._node_struct_collections | ||||
| self._repeated_submodules = set() | self._repeated_submodules = set() | ||||
| @GeneratorFail.check_except("Generator occurs an error when forming base submodules.") | |||||
| def _form_bottom_submodule(self): | def _form_bottom_submodule(self): | ||||
| """Form the basic submodules, which only contains nodes.""" | """Form the basic submodules, which only contains nodes.""" | ||||
| # Form module map | # Form module map | ||||
| @@ -375,7 +330,7 @@ class Generator(metaclass=Singleton): | |||||
| if scope_path_str == '[]': | if scope_path_str == '[]': | ||||
| continue # is main module, skip | continue # is main module, skip | ||||
| if md_struct.scope_depth != depth: | if md_struct.scope_depth != depth: | ||||
| continue # skip all submodules not at current depth | |||||
| continue # skip all submodules not at current depth | |||||
| md_struct_scope = copy.deepcopy(md_struct.identifier) | md_struct_scope = copy.deepcopy(md_struct.identifier) | ||||
| md_struct_scope.pop() | md_struct_scope.pop() | ||||
| parent_scope = md_struct_scope | parent_scope = md_struct_scope | ||||
| @@ -396,6 +351,7 @@ class Generator(metaclass=Singleton): | |||||
| self._global_context.add_module_struct(sub.pattern_id, sub) | self._global_context.add_module_struct(sub.pattern_id, sub) | ||||
| depth -= 1 | depth -= 1 | ||||
| @GeneratorFail.check_except("Generator occurs an error when building modules.") | |||||
| def _recursive_form_module(self): | def _recursive_form_module(self): | ||||
| """Main routine in generator to build modules from bottom to top.""" | """Main routine in generator to build modules from bottom to top.""" | ||||
| # 1. List repeated submodules | # 1. List repeated submodules | ||||
| @@ -518,6 +474,7 @@ class Generator(metaclass=Singleton): | |||||
| """Return all ModuleStructs in this model.""" | """Return all ModuleStructs in this model.""" | ||||
| return self._module_struct_collections | return self._module_struct_collections | ||||
| @GeneratorFail.check_except("Generator occurs an error when generating code statements.") | |||||
| def generate(self): | def generate(self): | ||||
| """ | """ | ||||
| Generate the final script file. | Generate the final script file. | ||||
| @@ -527,8 +484,22 @@ class Generator(metaclass=Singleton): | |||||
| """ | """ | ||||
| self._form_bottom_submodule() | self._form_bottom_submodule() | ||||
| self._recursive_form_module() | self._recursive_form_module() | ||||
| code = CodeStruct(self.module_structs.get('[]'), self._repeated_submodules) | |||||
| return code.code_line_list | |||||
| CodeStruct(self.module_structs.get('[]'), self._repeated_submodules) | |||||
| outputs = [get_imported_module()] | |||||
| for code_struct in self._global_context.code_structs.values(): | |||||
| for line in code_struct.code_line_list: | |||||
| outputs.append(line) | |||||
| formatted_code, _ = FormatCode("\n".join(outputs), | |||||
| style_config=CodeFormatConfig.PEP8.value) | |||||
| report_generator = ReportGenerator() | |||||
| report = report_generator.gen_report(formatted_code) | |||||
| del self._global_context | |||||
| return {"model": (formatted_code, report)} | |||||
| def get_node_struct(self, node_identifier): | def get_node_struct(self, node_identifier): | ||||
| """ | """ | ||||
| @@ -554,28 +525,6 @@ class Generator(metaclass=Singleton): | |||||
| """ | """ | ||||
| return self._module_struct_collections.get(module_identifier, None) | return self._module_struct_collections.get(module_identifier, None) | ||||
| def get_module_structs_by_pattern_under_same_parent_pattern(self, pattern_id, under_parent_pattern_id): | |||||
| """ | |||||
| Return a list of ModuleStruct by conditions of pattern and their parent parent's pattern. | |||||
| Args: | |||||
| pattern_id (int): The pattern id the returned ModuleSturct is. | |||||
| under_parent_pattern_id (int): The pattern id the returned ModuleStruct's parent is. | |||||
| Returns: | |||||
| list, a list of MoudleStructs has the same pattern_id and the same parents' pattern. | |||||
| """ | |||||
| if not pattern_id: | |||||
| raise ValueError("pattern_id is necessary to get the module struct.") | |||||
| if not under_parent_pattern_id: | |||||
| raise ValueError("under_parent_pattern_id is necessary to get the module struct.") | |||||
| ret = [] | |||||
| md_list = self._global_context.module_structs.get(pattern_id) | |||||
| for md in md_list: | |||||
| if md.parent_id == under_parent_pattern_id: | |||||
| ret.append(md) | |||||
| return ret | |||||
| def get_args_translator_from_module_structs_list(self, md_list, exclude_root_son=False): | def get_args_translator_from_module_structs_list(self, md_list, exclude_root_son=False): | ||||
| """ | """ | ||||
| Return a list of args translators which belongs to given module structs. | Return a list of args translators which belongs to given module structs. | ||||
| @@ -201,7 +201,7 @@ class ModuleStruct: | |||||
| def init_args_translator(self): | def init_args_translator(self): | ||||
| """Initialize the Args Translator for the module.""" | """Initialize the Args Translator for the module.""" | ||||
| var_name = "Module{}_{}".format(self.pattern_id, self.pattern_uid) | |||||
| var_name = self.ms_var_name | |||||
| self._args_translator = ArgsTranslation(None, var_name, None) | self._args_translator = ArgsTranslation(None, var_name, None) | ||||
| def update_module_fragment(self): | def update_module_fragment(self): | ||||
| @@ -455,14 +455,18 @@ class ModuleStruct: | |||||
| """Return the class name for generating code of this module.""" | """Return the class name for generating code of this module.""" | ||||
| if self.pattern_id == -1: | if self.pattern_id == -1: | ||||
| return "Model" | return "Model" | ||||
| return "Module{}".format(self.pattern_id) | |||||
| if self.GLOBAL_CONTEXT_MGR.known_module_name.get("Module{}".format(self.pattern_id)) is not None: | |||||
| class_name = self.GLOBAL_CONTEXT_MGR.known_module_name.get("Module{}".format(self.pattern_id)) | |||||
| else: | |||||
| class_name = "Module{}".format(self.pattern_id) | |||||
| return class_name | |||||
| @property | @property | ||||
| def ms_var_name(self) -> str: | def ms_var_name(self) -> str: | ||||
| """Return the variable name for generated code statement of this module.""" | """Return the variable name for generated code statement of this module.""" | ||||
| if self.pattern_id == -1: | if self.pattern_id == -1: | ||||
| return "Model" | return "Model" | ||||
| return "Module{}_{}".format(self.pattern_id, self.pattern_uid).lower() | |||||
| return f"{self.class_name}_{self.pattern_uid}".lower() | |||||
| @property | @property | ||||
| def ms_opt_var_name(self) -> str: | def ms_opt_var_name(self) -> str: | ||||
| @@ -22,6 +22,7 @@ from ..third_party_graph.pytorch_graph_node import PyTorchGraphNode | |||||
| from ..third_party_graph.onnx_graph_node import OnnxGraphNode | from ..third_party_graph.onnx_graph_node import OnnxGraphNode | ||||
| from ..common.global_context import GlobalContext | from ..common.global_context import GlobalContext | ||||
| from ..constant import InputType | from ..constant import InputType | ||||
| from ...common.exceptions import GeneratorFail | |||||
| class NodeStruct: | class NodeStruct: | ||||
| @@ -47,22 +48,32 @@ class NodeStruct: | |||||
| self.node_type = None | self.node_type = None | ||||
| self.onnx_name = None | self.onnx_name = None | ||||
| self.onnx_op = None | self.onnx_op = None | ||||
| self.graph_node_ref = None # Our defined GraphNode | |||||
| self.graph_node_ref = None | |||||
| self.scope_name = None | self.scope_name = None | ||||
| self.ms_var_name = None | self.ms_var_name = None | ||||
| self.ms_opt_var_name = None # ms_opt_var_name = self.ms_var_name(...) | |||||
| self.ms_opt_var_name = None | |||||
| self.ms_op = None | self.ms_op = None | ||||
| self.ready_to_generate = False | self.ready_to_generate = False | ||||
| self.ms_params = dict() # converted params from mapper | |||||
| # Define attributes converted from mapper | |||||
| self.ms_params = dict() | |||||
| self.ms_settings = dict() | self.ms_settings = dict() | ||||
| self.ms_weights = dict() | self.ms_weights = dict() | ||||
| self.ms_inputs = OrderedDict() | self.ms_inputs = OrderedDict() | ||||
| self.scope = None # Defined Scope class | |||||
| self.inputs_in_construct_header = OrderedDict() # key is prec_node_name, value is x; For code line use | |||||
| self.inputs_in_parent_module = OrderedDict() # key is prec_node_name, value is its closet opt_var_name | |||||
| self.matched_inputs = list() # Matched inputs will can be directly used by code line generation | |||||
| # Defined Scope class | |||||
| self.scope = None | |||||
| # Define attributes used for code generation | |||||
| # key is prec_node_name, value is x; For code line use | |||||
| self.inputs_in_construct_header = OrderedDict() | |||||
| # key is prec_node_name, value is its closet opt_var_name | |||||
| self.inputs_in_parent_module = OrderedDict() | |||||
| # Matched inputs will can be directly used by code line generation | |||||
| self.matched_inputs = list() | |||||
| # initialize funcs. | # initialize funcs. | ||||
| for arg in args: | for arg in args: | ||||
| @@ -135,6 +146,7 @@ class NodeStruct: | |||||
| parsed_scope = Scope.parse_scope_from_node_identifier(self.identifier) | parsed_scope = Scope.parse_scope_from_node_identifier(self.identifier) | ||||
| self.scope = Scope(parsed_scope) | self.scope = Scope(parsed_scope) | ||||
| @GeneratorFail.check_except("Generator occurs an error when initializing node's args translator.") | |||||
| def init_args_translator(self, translated_args: list): | def init_args_translator(self, translated_args: list): | ||||
| """ | """ | ||||
| Initialize the ArgsTranslator for each Node. | Initialize the ArgsTranslator for each Node. | ||||
| @@ -158,6 +170,7 @@ class NodeStruct: | |||||
| self.ms_op]): | self.ms_op]): | ||||
| self.ready_to_generate = True | self.ready_to_generate = True | ||||
| @GeneratorFail.check_except("Generator occurs an error when creating node struct.") | |||||
| def update(self, arg, force_ready=False): | def update(self, arg, force_ready=False): | ||||
| """ | """ | ||||
| Pass Node info. to generator NodeStruct. | Pass Node info. to generator NodeStruct. | ||||
| @@ -314,23 +327,12 @@ class NodeStruct: | |||||
| return input_name_to_use | return input_name_to_use | ||||
| return None | return None | ||||
| def code_line_in_construct(self, inputs=None, in_module_returns=None): | |||||
| def code_line_in_construct(self, inputs=None): | |||||
| """Construct line of code in module construct block. """ | """Construct line of code in module construct block. """ | ||||
| left = self.ms_opt_var_name | left = self.ms_opt_var_name | ||||
| if inputs is None: | |||||
| inputs = [] | |||||
| for idx, prec_node in enumerate(self.precursor_nodes_names): | |||||
| if self.inputs_in_construct_header.get(prec_node): | |||||
| inputs.append(self.inputs_in_construct_header.get(prec_node)) | |||||
| elif self._check_target_node_internal(prec_node): | |||||
| inputs.append(self.precursor_nodes_structs[idx].ms_opt_var_name) | |||||
| elif self.inputs_in_parent_module.get(prec_node): | |||||
| inputs.append(self.inputs_in_parent_module.get(prec_node)) | |||||
| elif in_module_returns and in_module_returns.get(self.onnx_name) \ | |||||
| and (not self._check_target_node_internal(prec_node)): | |||||
| inputs.append(self._get_correct_in_module_returns(prec_node, in_module_returns.get(self.onnx_name))) | |||||
| else: | |||||
| inputs.append("unk_{}_{}".format(idx, prec_node)) | |||||
| if not self.matched_inputs and inputs is None: | |||||
| raise ValueError("Unable to generate the code construct statement due to empty inputs.") | |||||
| if self.matched_inputs: | if self.matched_inputs: | ||||
| inputs = self.matched_inputs | inputs = self.matched_inputs | ||||
| @@ -394,7 +396,7 @@ class NodeStruct: | |||||
| """ | """ | ||||
| target_nd_struct = self.GLOBAL_CONTEXT_MGR.node_struct_collections.get(name) \ | target_nd_struct = self.GLOBAL_CONTEXT_MGR.node_struct_collections.get(name) \ | ||||
| or self.GLOBAL_CONTEXT_MGR.onnx_node_name_to_node_struct_map.get(name) | or self.GLOBAL_CONTEXT_MGR.onnx_node_name_to_node_struct_map.get(name) | ||||
| if target_nd_struct is None and self.topo_idx == 0: # First node always has external input | |||||
| if target_nd_struct is None and self.topo_idx == 0: # First node always has external input | |||||
| return False | return False | ||||
| if target_nd_struct is None: | if target_nd_struct is None: | ||||
| @@ -413,11 +415,11 @@ class NodeStruct: | |||||
| @property | @property | ||||
| def precursor_nodes_names_external(self) -> list: | def precursor_nodes_names_external(self) -> list: | ||||
| """Return a list of external precursor nodes names.""" | """Return a list of external precursor nodes names.""" | ||||
| return [name for name in self.precursor_nodes_names \ | |||||
| if not self._check_target_node_internal(name)] | |||||
| return [name for name in self.precursor_nodes_names | |||||
| if not self._check_target_node_internal(name)] | |||||
| @property | @property | ||||
| def successor_nodes_names_external(self) -> list: | def successor_nodes_names_external(self) -> list: | ||||
| """Return a list of external successor nodes names.""" | """Return a list of external successor nodes names.""" | ||||
| return [name for name in self.successor_nodes_names \ | |||||
| if not self._check_target_node_internal(name)] | |||||
| return [name for name in self.successor_nodes_names | |||||
| if not self._check_target_node_internal(name)] | |||||
| @@ -29,7 +29,7 @@ from ..common.utils import is_converted, save_code_file_and_report | |||||
| from ..mapper.base import Mapper | from ..mapper.base import Mapper | ||||
| from ..third_party_graph.pytorch_graph_node import PyTorchGraphNode | from ..third_party_graph.pytorch_graph_node import PyTorchGraphNode | ||||
| from ..third_party_graph.onnx_graph_node import OnnxGraphNode | from ..third_party_graph.onnx_graph_node import OnnxGraphNode | ||||
| from ..constant import SEPARATOR_IN_SCOPE | |||||
| from ..constant import SEPARATOR_IN_SCOPE, get_imported_module | |||||
| from ..constant import CodeFormatConfig | from ..constant import CodeFormatConfig | ||||
| from ..constant import SEPARATOR_BTW_NAME_AND_ID, FIRST_LEVEL_INDENT | from ..constant import SEPARATOR_BTW_NAME_AND_ID, FIRST_LEVEL_INDENT | ||||
| from ..constant import NEW_LINE, SECOND_LEVEL_INDENT | from ..constant import NEW_LINE, SECOND_LEVEL_INDENT | ||||
| @@ -283,7 +283,7 @@ class HierarchicalTree(Tree): | |||||
| Returns: | Returns: | ||||
| Dict, codes. | Dict, codes. | ||||
| """ | """ | ||||
| code_blocks = [self._get_imported_module()] | |||||
| code_blocks = [get_imported_module()] | |||||
| depths = sorted(list(self._hierarchical_order.keys()), reverse=True) | depths = sorted(list(self._hierarchical_order.keys()), reverse=True) | ||||
| for depth in depths: | for depth in depths: | ||||
| @@ -741,17 +741,3 @@ class HierarchicalTree(Tree): | |||||
| """Adjust tree structure to generate source code.""" | """Adjust tree structure to generate source code.""" | ||||
| self.sub_graph_merging() | self.sub_graph_merging() | ||||
| self.update_hierarchical_order() | self.update_hierarchical_order() | ||||
| @staticmethod | |||||
| def _get_imported_module(): | |||||
| """ | |||||
| Generate imported module header. | |||||
| Returns: | |||||
| str, imported module. | |||||
| """ | |||||
| return f"import numpy as np{NEW_LINE}" \ | |||||
| f"import mindspore{NEW_LINE}" \ | |||||
| f"from mindspore import nn{NEW_LINE}" \ | |||||
| f"from mindspore import Tensor{NEW_LINE}" \ | |||||
| f"from mindspore.ops import operations as P{NEW_LINE * 3}" | |||||