| @@ -39,11 +39,13 @@ class ConverterErrors(ScriptConverterErrors): | |||
| MODEL_NOT_SUPPORT = 8 | |||
| SCRIPT_GENERATE_FAIL = 9 | |||
| REPORT_GENERATE_FAIL = 10 | |||
| NODE_CONVERSION_ERROR = 11 | |||
| BASE_CONVERTER_FAIL = 000 | |||
| GRAPH_INIT_FAIL = 100 | |||
| TREE_CREATE_FAIL = 200 | |||
| SOURCE_FILES_SAVE_FAIL = 300 | |||
| GENERATOR_FAIL = 400 | |||
| class ScriptNotSupport(MindInsightException): | |||
| @@ -163,6 +165,7 @@ class MindConverterException(Exception): | |||
| class BaseConverterFail(MindConverterException): | |||
| """Base converter failed.""" | |||
| def __init__(self, msg): | |||
| super(BaseConverterFail, self).__init__(error=ConverterErrors.BASE_CONVERTER_FAIL, | |||
| user_msg=msg) | |||
| @@ -195,6 +198,7 @@ class BaseConverterFail(MindConverterException): | |||
| class UnknownModel(MindConverterException): | |||
| """The unknown model error.""" | |||
| def __init__(self, msg): | |||
| super(UnknownModel, self).__init__(error=ConverterErrors.UNKNOWN_MODEL, | |||
| user_msg=msg) | |||
| @@ -202,6 +206,7 @@ class UnknownModel(MindConverterException): | |||
| class GraphInitFail(MindConverterException): | |||
| """The graph init fail error.""" | |||
| def __init__(self, **kwargs): | |||
| super(GraphInitFail, self).__init__(error=ConverterErrors.GRAPH_INIT_FAIL, | |||
| user_msg=kwargs.get('msg', '')) | |||
| @@ -230,6 +235,7 @@ class GraphInitFail(MindConverterException): | |||
| class TreeCreateFail(MindConverterException): | |||
| """The tree create fail.""" | |||
| def __init__(self, msg): | |||
| super(TreeCreateFail, self).__init__(error=ConverterErrors.TREE_CREATE_FAIL, | |||
| user_msg=msg) | |||
| @@ -254,6 +260,7 @@ class TreeCreateFail(MindConverterException): | |||
| class SourceFilesSaveFail(MindConverterException): | |||
| """The source files save fail error.""" | |||
| def __init__(self, msg): | |||
| super(SourceFilesSaveFail, self).__init__(error=ConverterErrors.SOURCE_FILES_SAVE_FAIL, | |||
| user_msg=msg) | |||
| @@ -280,6 +287,7 @@ class SourceFilesSaveFail(MindConverterException): | |||
| class ModelNotSupport(MindConverterException): | |||
| """The model not support error.""" | |||
| def __init__(self, msg): | |||
| super(ModelNotSupport, self).__init__(error=ConverterErrors.MODEL_NOT_SUPPORT, | |||
| user_msg=msg, | |||
| @@ -338,6 +346,7 @@ class ModelNotSupport(MindConverterException): | |||
| class NodeInputMissing(MindConverterException): | |||
| """The node input missing error.""" | |||
| def __init__(self, msg): | |||
| super(NodeInputMissing, self).__init__(error=ConverterErrors.NODE_INPUT_MISSING, | |||
| user_msg=msg, | |||
| @@ -346,6 +355,7 @@ class NodeInputMissing(MindConverterException): | |||
| class TreeNodeInsertFail(MindConverterException): | |||
| """The tree node create fail error.""" | |||
| def __init__(self, msg): | |||
| super(TreeNodeInsertFail, self).__init__(error=ConverterErrors.TREE_NODE_INSERT_FAIL, | |||
| user_msg=msg, | |||
| @@ -380,6 +390,7 @@ class TreeNodeInsertFail(MindConverterException): | |||
| class NodeInputTypeNotSupport(MindConverterException): | |||
| """The node input type NOT support error.""" | |||
| def __init__(self, msg): | |||
| super(NodeInputTypeNotSupport, self).__init__(error=ConverterErrors.NODE_INPUT_TYPE_NOT_SUPPORT, | |||
| user_msg=msg, | |||
| @@ -388,6 +399,7 @@ class NodeInputTypeNotSupport(MindConverterException): | |||
| class ScriptGenerateFail(MindConverterException): | |||
| """The script generate fail error.""" | |||
| def __init__(self, msg): | |||
| super(ScriptGenerateFail, self).__init__(error=ConverterErrors.SCRIPT_GENERATE_FAIL, | |||
| user_msg=msg, | |||
| @@ -421,6 +433,7 @@ class ScriptGenerateFail(MindConverterException): | |||
| class ReportGenerateFail(MindConverterException): | |||
| """The report generate fail error.""" | |||
| def __init__(self, msg): | |||
| super(ReportGenerateFail, self).__init__(error=ConverterErrors.REPORT_GENERATE_FAIL, | |||
| user_msg=msg, | |||
| @@ -448,3 +461,19 @@ class ReportGenerateFail(MindConverterException): | |||
| return output | |||
| return _f | |||
| return decorator | |||
| class GeneratorFail(MindConverterException): | |||
| """The Generator fail error.""" | |||
| def __init__(self, msg): | |||
| super(GeneratorFail, self).__init__(error=ConverterErrors.NODE_CONVERSION_ERROR, | |||
| user_msg=msg, | |||
| cls_code=ConverterErrors.GENERATOR_FAIL.value) | |||
| @classmethod | |||
| def raise_from(cls): | |||
| """Raise from exceptions below.""" | |||
| except_source = (ValueError, | |||
| TypeError, | |||
| cls) | |||
| return except_source | |||
| @@ -75,3 +75,17 @@ class InputType(Enum): | |||
| class FrameworkType(Enum): | |||
| PYTORCH = 0 | |||
| TENSORFLOW = 1 | |||
| def get_imported_module(): | |||
| """ | |||
| Generate imported module header. | |||
| Returns: | |||
| str, imported module. | |||
| """ | |||
| return f"import numpy as np{NEW_LINE}" \ | |||
| f"import mindspore{NEW_LINE}" \ | |||
| f"from mindspore import nn{NEW_LINE}" \ | |||
| f"from mindspore import Tensor{NEW_LINE}" \ | |||
| f"from mindspore.ops import operations as P{NEW_LINE * 3}" | |||
| @@ -20,7 +20,8 @@ from importlib import import_module | |||
| from importlib.util import find_spec | |||
| import mindinsight | |||
| from mindinsight.mindconverter.graph_based_converter.common.utils import lib_version_satisfied | |||
| from mindinsight.mindconverter.graph_based_converter.common.utils import lib_version_satisfied, \ | |||
| save_code_file_and_report | |||
| from mindinsight.mindconverter.graph_based_converter.constant import BINARY_HEADER_PYTORCH_FILE, FrameworkType, \ | |||
| BINARY_HEADER_PYTORCH_BITS, ONNX_MIN_VER, TF2ONNX_MIN_VER, ONNXRUNTIME_MIN_VER | |||
| from mindinsight.mindconverter.graph_based_converter.mapper import ONNXToMindSporeMapper | |||
| @@ -29,6 +30,7 @@ from mindinsight.mindconverter.common.exceptions import GraphInitFail, TreeCreat | |||
| BaseConverterFail, UnknownModel | |||
| from mindinsight.utils.exceptions import ParamMissError | |||
| permissions = os.R_OK | os.W_OK | os.X_OK | |||
| os.umask(permissions << 3 | permissions) | |||
| @@ -194,23 +196,18 @@ def graph_based_converter_tf_to_ms(graph_path: str, sample_shape: tuple, | |||
| """ | |||
| third_party_graph_module = import_module( | |||
| 'mindinsight.mindconverter.graph_based_converter.third_party_graph') | |||
| hierarchical_tree_module = import_module( | |||
| 'mindinsight.mindconverter.graph_based_converter.hierarchical_tree') | |||
| cls_graph_factory = getattr(third_party_graph_module, 'GraphFactory') | |||
| cls_hierarchical_tree_factory = getattr(hierarchical_tree_module, 'HierarchicalTreeFactory') | |||
| batch_add_nodes = getattr(import_module('mindinsight.mindconverter.graph_based_converter.generator'), | |||
| "batch_add_nodes") | |||
| # Close unnecessary log. | |||
| os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2' | |||
| graph_obj = cls_graph_factory.init(graph_path, sample_shape=sample_shape, | |||
| input_nodes=input_nodes, output_nodes=output_nodes) | |||
| hierarchical_tree, scope_name_map = cls_hierarchical_tree_factory.create(graph_obj) | |||
| generator_inst = batch_add_nodes(graph_obj, ONNXToMindSporeMapper) | |||
| model_name = _extract_model_name(graph_path) | |||
| hierarchical_tree.save_source_files(output_folder, mapper=ONNXToMindSporeMapper, | |||
| model_name=model_name, | |||
| report_folder=report_folder, | |||
| scope_name_map=scope_name_map) | |||
| code_fragments = generator_inst.generate() | |||
| save_code_file_and_report(model_name, code_fragments, output_folder, report_folder) | |||
| @BaseConverterFail.check_except("Failed to start base converter.") | |||
| @@ -16,23 +16,17 @@ | |||
| import copy | |||
| from collections import OrderedDict | |||
| from yapf.yapflib.yapf_api import FormatCode | |||
| from .scope_utils import Scope | |||
| from .node_struct import NodeStruct | |||
| from .module_struct import ModuleStruct | |||
| from .args_translator import ArgsTranslationHelper | |||
| from ..common.global_context import GlobalContext | |||
| from ...common.exceptions import GeneratorFail | |||
| from ..hierarchical_tree.name_mgr import GlobalVarNameMgr | |||
| from ..constant import NEW_LINE, SECOND_LEVEL_INDENT, FIRST_LEVEL_INDENT | |||
| class Singleton(type): | |||
| """Metaclass to make the generator to be single instance.""" | |||
| _instances = {} | |||
| def __call__(cls, *args, **kwargs): | |||
| if cls not in cls._instances: | |||
| cls._instances[cls] = super(Singleton, cls).__call__(*args, **kwargs) | |||
| return cls._instances[cls] | |||
| from ..constant import NEW_LINE, SECOND_LEVEL_INDENT, FIRST_LEVEL_INDENT, CodeFormatConfig, get_imported_module | |||
| from ..report_generator import ReportGenerator | |||
| class CodeStruct: | |||
| @@ -88,14 +82,9 @@ class CodeStruct: | |||
| repeated_submodules (dict): The dict contains all submodules which use repeatedly. | |||
| Can get this dict from generator. | |||
| """ | |||
| # Define tmp var for code generation. | |||
| opt_var_name_records = dict() # now only support multiple outputs within same scope. | |||
| return_value_records = dict() # save returned values for successor nodes/modules use. | |||
| # Define Module header code line below | |||
| if md_struct.pattern_id != -1: | |||
| class_name = f"Module{md_struct.pattern_id}" | |||
| else: | |||
| class_name = "Model" | |||
| class_name = md_struct.class_name | |||
| # define a class declaration | |||
| self.new_line = f"class {class_name}(nn.Cell):" | |||
| @@ -108,36 +97,18 @@ class CodeStruct: | |||
| for formal in md_struct.args_translator.formal_args.keys(): | |||
| module_def_args.append(formal) | |||
| # Collect extra inputs and outputs | |||
| # For code line in init & construct blocks | |||
| init_lines = list() | |||
| cons_lines = list() | |||
| for (_, struct) in md_struct.get_generate_order(): | |||
| if isinstance(struct, NodeStruct): # Generate code line for Node. | |||
| code_line_init = struct.code_line_in_init() | |||
| code_line_construct = struct.code_line_in_construct(in_module_returns=return_value_records) | |||
| code_line_construct = struct.code_line_in_construct() | |||
| init_lines.append(f"{SECOND_LEVEL_INDENT}{' = '.join(code_line_init)}") | |||
| cons_lines.append(f"{SECOND_LEVEL_INDENT}{' = '.join(code_line_construct)}") | |||
| # add extra tensor | |||
| if struct.fragment.code_setting and struct.fragment.code_setting.op_extra_tensor: | |||
| code_extra_tensor = struct.add_extra_tensor() | |||
| init_lines.append(f"{SECOND_LEVEL_INDENT}{' = '.join(code_extra_tensor)}") | |||
| # record opt_var_name for succ nodes input in same scope. | |||
| target_onnx_name = struct.graph_node_ref.successor_nodes | |||
| for name in target_onnx_name: | |||
| if opt_var_name_records.get(name): | |||
| opt_var_name_records.get(name).append(code_line_construct[0]) | |||
| else: | |||
| opt_var_name_records[name] = [code_line_construct[0]] | |||
| if struct.successor_nodes_names_external: | |||
| for ret_user in struct.successor_nodes_names_external: | |||
| if return_value_records.get(ret_user) is not None: | |||
| return_value_records[ret_user].append((struct.onnx_name, code_line_construct[0])) | |||
| else: | |||
| return_value_records[ret_user] = [(struct.onnx_name, code_line_construct[0])] | |||
| init_lines.append(f"{SECOND_LEVEL_INDENT}{' = '.join(struct.add_extra_tensor())}") | |||
| elif isinstance(struct, ModuleStruct): | |||
| # check if this instance generated CodeStruct | |||
| @@ -149,22 +120,6 @@ class CodeStruct: | |||
| init_lines.append(f"{SECOND_LEVEL_INDENT}{' = '.join(code_line_init)}") | |||
| cons_lines.append(f"{SECOND_LEVEL_INDENT}{' = '.join(code_line_construct)}") | |||
| # record opt_var_name for succ nodes input in same scope. | |||
| target_onnx_name = struct.tail_nd_struct.graph_node_ref.successor_nodes | |||
| for name in target_onnx_name: | |||
| if opt_var_name_records.get(name): | |||
| opt_var_name_records.get(name).append(code_line_construct[0]) | |||
| else: | |||
| opt_var_name_records[name] = [code_line_construct[0]] | |||
| # record submodule's local return map for following nodes / submodules use | |||
| if struct.external_successor_local_returns_map: | |||
| for ret_user, _ in struct.external_successor_local_returns_map.items(): | |||
| if return_value_records.get(ret_user) is not None: | |||
| # mulitple returns of a node may need modifiy the index. | |||
| return_value_records[ret_user].append((struct.identifier, code_line_construct[0])) | |||
| else: | |||
| return_value_records[ret_user] = [(struct.identifier, code_line_construct[0])] | |||
| else: | |||
| raise TypeError("Unable to generate code from args are not ModuleStruct or NodeStruct.") | |||
| @@ -183,8 +138,7 @@ class CodeStruct: | |||
| # define returns | |||
| returns = [] | |||
| if md_struct.external_successor_local_returns_map: | |||
| ret = list(md_struct.external_successor_local_returns_map.values()) | |||
| for r in ret: | |||
| for r in list(md_struct.external_successor_local_returns_map.values()): | |||
| if isinstance(r, tuple): # results return with index nth output | |||
| returns.append(r[0]) | |||
| else: | |||
| @@ -197,7 +151,7 @@ class CodeStruct: | |||
| self.GLOBAL_CONTEXT.code_structs[md_struct.pattern_id] = self | |||
| class Generator(metaclass=Singleton): | |||
| class Generator: | |||
| """The generator controls all routines of code generation.""" | |||
| def __init__(self): | |||
| @@ -217,6 +171,7 @@ class Generator(metaclass=Singleton): | |||
| self._global_context.node_struct_collections = self._node_struct_collections | |||
| self._repeated_submodules = set() | |||
| @GeneratorFail.check_except("Generator occurs an error when forming base submodules.") | |||
| def _form_bottom_submodule(self): | |||
| """Form the basic submodules, which only contains nodes.""" | |||
| # Form module map | |||
| @@ -375,7 +330,7 @@ class Generator(metaclass=Singleton): | |||
| if scope_path_str == '[]': | |||
| continue # is main module, skip | |||
| if md_struct.scope_depth != depth: | |||
| continue # skip all submodules not at current depth | |||
| continue # skip all submodules not at current depth | |||
| md_struct_scope = copy.deepcopy(md_struct.identifier) | |||
| md_struct_scope.pop() | |||
| parent_scope = md_struct_scope | |||
| @@ -396,6 +351,7 @@ class Generator(metaclass=Singleton): | |||
| self._global_context.add_module_struct(sub.pattern_id, sub) | |||
| depth -= 1 | |||
| @GeneratorFail.check_except("Generator occurs an error when building modules.") | |||
| def _recursive_form_module(self): | |||
| """Main routine in generator to build modules from bottom to top.""" | |||
| # 1. List repeated submodules | |||
| @@ -518,6 +474,7 @@ class Generator(metaclass=Singleton): | |||
| """Return all ModuleStructs in this model.""" | |||
| return self._module_struct_collections | |||
| @GeneratorFail.check_except("Generator occurs an error when generating code statements.") | |||
| def generate(self): | |||
| """ | |||
| Generate the final script file. | |||
| @@ -527,8 +484,22 @@ class Generator(metaclass=Singleton): | |||
| """ | |||
| self._form_bottom_submodule() | |||
| self._recursive_form_module() | |||
| code = CodeStruct(self.module_structs.get('[]'), self._repeated_submodules) | |||
| return code.code_line_list | |||
| CodeStruct(self.module_structs.get('[]'), self._repeated_submodules) | |||
| outputs = [get_imported_module()] | |||
| for code_struct in self._global_context.code_structs.values(): | |||
| for line in code_struct.code_line_list: | |||
| outputs.append(line) | |||
| formatted_code, _ = FormatCode("\n".join(outputs), | |||
| style_config=CodeFormatConfig.PEP8.value) | |||
| report_generator = ReportGenerator() | |||
| report = report_generator.gen_report(formatted_code) | |||
| del self._global_context | |||
| return {"model": (formatted_code, report)} | |||
| def get_node_struct(self, node_identifier): | |||
| """ | |||
| @@ -554,28 +525,6 @@ class Generator(metaclass=Singleton): | |||
| """ | |||
| return self._module_struct_collections.get(module_identifier, None) | |||
| def get_module_structs_by_pattern_under_same_parent_pattern(self, pattern_id, under_parent_pattern_id): | |||
| """ | |||
| Return a list of ModuleStruct by conditions of pattern and their parent parent's pattern. | |||
| Args: | |||
| pattern_id (int): The pattern id the returned ModuleSturct is. | |||
| under_parent_pattern_id (int): The pattern id the returned ModuleStruct's parent is. | |||
| Returns: | |||
| list, a list of MoudleStructs has the same pattern_id and the same parents' pattern. | |||
| """ | |||
| if not pattern_id: | |||
| raise ValueError("pattern_id is necessary to get the module struct.") | |||
| if not under_parent_pattern_id: | |||
| raise ValueError("under_parent_pattern_id is necessary to get the module struct.") | |||
| ret = [] | |||
| md_list = self._global_context.module_structs.get(pattern_id) | |||
| for md in md_list: | |||
| if md.parent_id == under_parent_pattern_id: | |||
| ret.append(md) | |||
| return ret | |||
| def get_args_translator_from_module_structs_list(self, md_list, exclude_root_son=False): | |||
| """ | |||
| Return a list of args translators which belongs to given module structs. | |||
| @@ -201,7 +201,7 @@ class ModuleStruct: | |||
| def init_args_translator(self): | |||
| """Initialize the Args Translator for the module.""" | |||
| var_name = "Module{}_{}".format(self.pattern_id, self.pattern_uid) | |||
| var_name = self.ms_var_name | |||
| self._args_translator = ArgsTranslation(None, var_name, None) | |||
| def update_module_fragment(self): | |||
| @@ -455,14 +455,18 @@ class ModuleStruct: | |||
| """Return the class name for generating code of this module.""" | |||
| if self.pattern_id == -1: | |||
| return "Model" | |||
| return "Module{}".format(self.pattern_id) | |||
| if self.GLOBAL_CONTEXT_MGR.known_module_name.get("Module{}".format(self.pattern_id)) is not None: | |||
| class_name = self.GLOBAL_CONTEXT_MGR.known_module_name.get("Module{}".format(self.pattern_id)) | |||
| else: | |||
| class_name = "Module{}".format(self.pattern_id) | |||
| return class_name | |||
| @property | |||
| def ms_var_name(self) -> str: | |||
| """Return the variable name for generated code statement of this module.""" | |||
| if self.pattern_id == -1: | |||
| return "Model" | |||
| return "Module{}_{}".format(self.pattern_id, self.pattern_uid).lower() | |||
| return f"{self.class_name}_{self.pattern_uid}".lower() | |||
| @property | |||
| def ms_opt_var_name(self) -> str: | |||
| @@ -22,6 +22,7 @@ from ..third_party_graph.pytorch_graph_node import PyTorchGraphNode | |||
| from ..third_party_graph.onnx_graph_node import OnnxGraphNode | |||
| from ..common.global_context import GlobalContext | |||
| from ..constant import InputType | |||
| from ...common.exceptions import GeneratorFail | |||
| class NodeStruct: | |||
| @@ -47,22 +48,32 @@ class NodeStruct: | |||
| self.node_type = None | |||
| self.onnx_name = None | |||
| self.onnx_op = None | |||
| self.graph_node_ref = None # Our defined GraphNode | |||
| self.graph_node_ref = None | |||
| self.scope_name = None | |||
| self.ms_var_name = None | |||
| self.ms_opt_var_name = None # ms_opt_var_name = self.ms_var_name(...) | |||
| self.ms_opt_var_name = None | |||
| self.ms_op = None | |||
| self.ready_to_generate = False | |||
| self.ms_params = dict() # converted params from mapper | |||
| # Define attributes converted from mapper | |||
| self.ms_params = dict() | |||
| self.ms_settings = dict() | |||
| self.ms_weights = dict() | |||
| self.ms_inputs = OrderedDict() | |||
| self.scope = None # Defined Scope class | |||
| self.inputs_in_construct_header = OrderedDict() # key is prec_node_name, value is x; For code line use | |||
| self.inputs_in_parent_module = OrderedDict() # key is prec_node_name, value is its closet opt_var_name | |||
| self.matched_inputs = list() # Matched inputs will can be directly used by code line generation | |||
| # Defined Scope class | |||
| self.scope = None | |||
| # Define attributes used for code generation | |||
| # key is prec_node_name, value is x; For code line use | |||
| self.inputs_in_construct_header = OrderedDict() | |||
| # key is prec_node_name, value is its closet opt_var_name | |||
| self.inputs_in_parent_module = OrderedDict() | |||
| # Matched inputs will can be directly used by code line generation | |||
| self.matched_inputs = list() | |||
| # initialize funcs. | |||
| for arg in args: | |||
| @@ -135,6 +146,7 @@ class NodeStruct: | |||
| parsed_scope = Scope.parse_scope_from_node_identifier(self.identifier) | |||
| self.scope = Scope(parsed_scope) | |||
| @GeneratorFail.check_except("Generator occurs an error when initializing node's args translator.") | |||
| def init_args_translator(self, translated_args: list): | |||
| """ | |||
| Initialize the ArgsTranslator for each Node. | |||
| @@ -158,6 +170,7 @@ class NodeStruct: | |||
| self.ms_op]): | |||
| self.ready_to_generate = True | |||
| @GeneratorFail.check_except("Generator occurs an error when creating node struct.") | |||
| def update(self, arg, force_ready=False): | |||
| """ | |||
| Pass Node info. to generator NodeStruct. | |||
| @@ -314,23 +327,12 @@ class NodeStruct: | |||
| return input_name_to_use | |||
| return None | |||
| def code_line_in_construct(self, inputs=None, in_module_returns=None): | |||
| def code_line_in_construct(self, inputs=None): | |||
| """Construct line of code in module construct block. """ | |||
| left = self.ms_opt_var_name | |||
| if inputs is None: | |||
| inputs = [] | |||
| for idx, prec_node in enumerate(self.precursor_nodes_names): | |||
| if self.inputs_in_construct_header.get(prec_node): | |||
| inputs.append(self.inputs_in_construct_header.get(prec_node)) | |||
| elif self._check_target_node_internal(prec_node): | |||
| inputs.append(self.precursor_nodes_structs[idx].ms_opt_var_name) | |||
| elif self.inputs_in_parent_module.get(prec_node): | |||
| inputs.append(self.inputs_in_parent_module.get(prec_node)) | |||
| elif in_module_returns and in_module_returns.get(self.onnx_name) \ | |||
| and (not self._check_target_node_internal(prec_node)): | |||
| inputs.append(self._get_correct_in_module_returns(prec_node, in_module_returns.get(self.onnx_name))) | |||
| else: | |||
| inputs.append("unk_{}_{}".format(idx, prec_node)) | |||
| if not self.matched_inputs and inputs is None: | |||
| raise ValueError("Unable to generate the code construct statement due to empty inputs.") | |||
| if self.matched_inputs: | |||
| inputs = self.matched_inputs | |||
| @@ -394,7 +396,7 @@ class NodeStruct: | |||
| """ | |||
| target_nd_struct = self.GLOBAL_CONTEXT_MGR.node_struct_collections.get(name) \ | |||
| or self.GLOBAL_CONTEXT_MGR.onnx_node_name_to_node_struct_map.get(name) | |||
| if target_nd_struct is None and self.topo_idx == 0: # First node always has external input | |||
| if target_nd_struct is None and self.topo_idx == 0: # First node always has external input | |||
| return False | |||
| if target_nd_struct is None: | |||
| @@ -413,11 +415,11 @@ class NodeStruct: | |||
| @property | |||
| def precursor_nodes_names_external(self) -> list: | |||
| """Return a list of external precursor nodes names.""" | |||
| return [name for name in self.precursor_nodes_names \ | |||
| if not self._check_target_node_internal(name)] | |||
| return [name for name in self.precursor_nodes_names | |||
| if not self._check_target_node_internal(name)] | |||
| @property | |||
| def successor_nodes_names_external(self) -> list: | |||
| """Return a list of external successor nodes names.""" | |||
| return [name for name in self.successor_nodes_names \ | |||
| if not self._check_target_node_internal(name)] | |||
| return [name for name in self.successor_nodes_names | |||
| if not self._check_target_node_internal(name)] | |||
| @@ -29,7 +29,7 @@ from ..common.utils import is_converted, save_code_file_and_report | |||
| from ..mapper.base import Mapper | |||
| from ..third_party_graph.pytorch_graph_node import PyTorchGraphNode | |||
| from ..third_party_graph.onnx_graph_node import OnnxGraphNode | |||
| from ..constant import SEPARATOR_IN_SCOPE | |||
| from ..constant import SEPARATOR_IN_SCOPE, get_imported_module | |||
| from ..constant import CodeFormatConfig | |||
| from ..constant import SEPARATOR_BTW_NAME_AND_ID, FIRST_LEVEL_INDENT | |||
| from ..constant import NEW_LINE, SECOND_LEVEL_INDENT | |||
| @@ -283,7 +283,7 @@ class HierarchicalTree(Tree): | |||
| Returns: | |||
| Dict, codes. | |||
| """ | |||
| code_blocks = [self._get_imported_module()] | |||
| code_blocks = [get_imported_module()] | |||
| depths = sorted(list(self._hierarchical_order.keys()), reverse=True) | |||
| for depth in depths: | |||
| @@ -741,17 +741,3 @@ class HierarchicalTree(Tree): | |||
| """Adjust tree structure to generate source code.""" | |||
| self.sub_graph_merging() | |||
| self.update_hierarchical_order() | |||
| @staticmethod | |||
| def _get_imported_module(): | |||
| """ | |||
| Generate imported module header. | |||
| Returns: | |||
| str, imported module. | |||
| """ | |||
| return f"import numpy as np{NEW_LINE}" \ | |||
| f"import mindspore{NEW_LINE}" \ | |||
| f"from mindspore import nn{NEW_LINE}" \ | |||
| f"from mindspore import Tensor{NEW_LINE}" \ | |||
| f"from mindspore.ops import operations as P{NEW_LINE * 3}" | |||