From 666a0d35a359ddf8c180af4b887a0b3e53e93d92 Mon Sep 17 00:00:00 2001 From: liangtianshu Date: Wed, 9 Dec 2020 15:33:35 +0800 Subject: [PATCH] Updates on merging generator to MindConverter, and unnecessary logic clean --- .../mindconverter/common/exceptions.py | 29 +++++ .../graph_based_converter/constant.py | 14 +++ .../graph_based_converter/framework.py | 19 ++- .../generator/generator.py | 113 +++++------------- .../generator/module_struct.py | 10 +- .../generator/node_struct.py | 56 ++++----- .../hierarchical_tree/hierarchical_tree.py | 18 +-- 7 files changed, 120 insertions(+), 139 deletions(-) diff --git a/mindinsight/mindconverter/common/exceptions.py b/mindinsight/mindconverter/common/exceptions.py index 15e42e42..fd22fbf5 100644 --- a/mindinsight/mindconverter/common/exceptions.py +++ b/mindinsight/mindconverter/common/exceptions.py @@ -39,11 +39,13 @@ class ConverterErrors(ScriptConverterErrors): MODEL_NOT_SUPPORT = 8 SCRIPT_GENERATE_FAIL = 9 REPORT_GENERATE_FAIL = 10 + NODE_CONVERSION_ERROR = 11 BASE_CONVERTER_FAIL = 000 GRAPH_INIT_FAIL = 100 TREE_CREATE_FAIL = 200 SOURCE_FILES_SAVE_FAIL = 300 + GENERATOR_FAIL = 400 class ScriptNotSupport(MindInsightException): @@ -163,6 +165,7 @@ class MindConverterException(Exception): class BaseConverterFail(MindConverterException): """Base converter failed.""" + def __init__(self, msg): super(BaseConverterFail, self).__init__(error=ConverterErrors.BASE_CONVERTER_FAIL, user_msg=msg) @@ -195,6 +198,7 @@ class BaseConverterFail(MindConverterException): class UnknownModel(MindConverterException): """The unknown model error.""" + def __init__(self, msg): super(UnknownModel, self).__init__(error=ConverterErrors.UNKNOWN_MODEL, user_msg=msg) @@ -202,6 +206,7 @@ class UnknownModel(MindConverterException): class GraphInitFail(MindConverterException): """The graph init fail error.""" + def __init__(self, **kwargs): super(GraphInitFail, self).__init__(error=ConverterErrors.GRAPH_INIT_FAIL, user_msg=kwargs.get('msg', '')) @@ -230,6 +235,7 @@ class GraphInitFail(MindConverterException): class TreeCreateFail(MindConverterException): """The tree create fail.""" + def __init__(self, msg): super(TreeCreateFail, self).__init__(error=ConverterErrors.TREE_CREATE_FAIL, user_msg=msg) @@ -254,6 +260,7 @@ class TreeCreateFail(MindConverterException): class SourceFilesSaveFail(MindConverterException): """The source files save fail error.""" + def __init__(self, msg): super(SourceFilesSaveFail, self).__init__(error=ConverterErrors.SOURCE_FILES_SAVE_FAIL, user_msg=msg) @@ -280,6 +287,7 @@ class SourceFilesSaveFail(MindConverterException): class ModelNotSupport(MindConverterException): """The model not support error.""" + def __init__(self, msg): super(ModelNotSupport, self).__init__(error=ConverterErrors.MODEL_NOT_SUPPORT, user_msg=msg, @@ -338,6 +346,7 @@ class ModelNotSupport(MindConverterException): class NodeInputMissing(MindConverterException): """The node input missing error.""" + def __init__(self, msg): super(NodeInputMissing, self).__init__(error=ConverterErrors.NODE_INPUT_MISSING, user_msg=msg, @@ -346,6 +355,7 @@ class NodeInputMissing(MindConverterException): class TreeNodeInsertFail(MindConverterException): """The tree node create fail error.""" + def __init__(self, msg): super(TreeNodeInsertFail, self).__init__(error=ConverterErrors.TREE_NODE_INSERT_FAIL, user_msg=msg, @@ -380,6 +390,7 @@ class TreeNodeInsertFail(MindConverterException): class NodeInputTypeNotSupport(MindConverterException): """The node input type NOT support error.""" + def __init__(self, msg): super(NodeInputTypeNotSupport, self).__init__(error=ConverterErrors.NODE_INPUT_TYPE_NOT_SUPPORT, user_msg=msg, @@ -388,6 +399,7 @@ class NodeInputTypeNotSupport(MindConverterException): class ScriptGenerateFail(MindConverterException): """The script generate fail error.""" + def __init__(self, msg): super(ScriptGenerateFail, self).__init__(error=ConverterErrors.SCRIPT_GENERATE_FAIL, user_msg=msg, @@ -421,6 +433,7 @@ class ScriptGenerateFail(MindConverterException): class ReportGenerateFail(MindConverterException): """The report generate fail error.""" + def __init__(self, msg): super(ReportGenerateFail, self).__init__(error=ConverterErrors.REPORT_GENERATE_FAIL, user_msg=msg, @@ -448,3 +461,19 @@ class ReportGenerateFail(MindConverterException): return output return _f return decorator + + +class GeneratorFail(MindConverterException): + """The Generator fail error.""" + + def __init__(self, msg): + super(GeneratorFail, self).__init__(error=ConverterErrors.NODE_CONVERSION_ERROR, + user_msg=msg, + cls_code=ConverterErrors.GENERATOR_FAIL.value) + @classmethod + def raise_from(cls): + """Raise from exceptions below.""" + except_source = (ValueError, + TypeError, + cls) + return except_source diff --git a/mindinsight/mindconverter/graph_based_converter/constant.py b/mindinsight/mindconverter/graph_based_converter/constant.py index 590e67cd..bf9a83f9 100644 --- a/mindinsight/mindconverter/graph_based_converter/constant.py +++ b/mindinsight/mindconverter/graph_based_converter/constant.py @@ -75,3 +75,17 @@ class InputType(Enum): class FrameworkType(Enum): PYTORCH = 0 TENSORFLOW = 1 + + +def get_imported_module(): + """ + Generate imported module header. + + Returns: + str, imported module. + """ + return f"import numpy as np{NEW_LINE}" \ + f"import mindspore{NEW_LINE}" \ + f"from mindspore import nn{NEW_LINE}" \ + f"from mindspore import Tensor{NEW_LINE}" \ + f"from mindspore.ops import operations as P{NEW_LINE * 3}" diff --git a/mindinsight/mindconverter/graph_based_converter/framework.py b/mindinsight/mindconverter/graph_based_converter/framework.py index 28f792eb..1971fc1e 100644 --- a/mindinsight/mindconverter/graph_based_converter/framework.py +++ b/mindinsight/mindconverter/graph_based_converter/framework.py @@ -20,7 +20,8 @@ from importlib import import_module from importlib.util import find_spec import mindinsight -from mindinsight.mindconverter.graph_based_converter.common.utils import lib_version_satisfied +from mindinsight.mindconverter.graph_based_converter.common.utils import lib_version_satisfied, \ + save_code_file_and_report from mindinsight.mindconverter.graph_based_converter.constant import BINARY_HEADER_PYTORCH_FILE, FrameworkType, \ BINARY_HEADER_PYTORCH_BITS, ONNX_MIN_VER, TF2ONNX_MIN_VER, ONNXRUNTIME_MIN_VER from mindinsight.mindconverter.graph_based_converter.mapper import ONNXToMindSporeMapper @@ -29,6 +30,7 @@ from mindinsight.mindconverter.common.exceptions import GraphInitFail, TreeCreat BaseConverterFail, UnknownModel from mindinsight.utils.exceptions import ParamMissError + permissions = os.R_OK | os.W_OK | os.X_OK os.umask(permissions << 3 | permissions) @@ -194,23 +196,18 @@ def graph_based_converter_tf_to_ms(graph_path: str, sample_shape: tuple, """ third_party_graph_module = import_module( 'mindinsight.mindconverter.graph_based_converter.third_party_graph') - hierarchical_tree_module = import_module( - 'mindinsight.mindconverter.graph_based_converter.hierarchical_tree') cls_graph_factory = getattr(third_party_graph_module, 'GraphFactory') - cls_hierarchical_tree_factory = getattr(hierarchical_tree_module, 'HierarchicalTreeFactory') + batch_add_nodes = getattr(import_module('mindinsight.mindconverter.graph_based_converter.generator'), + "batch_add_nodes") # Close unnecessary log. os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2' graph_obj = cls_graph_factory.init(graph_path, sample_shape=sample_shape, input_nodes=input_nodes, output_nodes=output_nodes) - - hierarchical_tree, scope_name_map = cls_hierarchical_tree_factory.create(graph_obj) - + generator_inst = batch_add_nodes(graph_obj, ONNXToMindSporeMapper) model_name = _extract_model_name(graph_path) - hierarchical_tree.save_source_files(output_folder, mapper=ONNXToMindSporeMapper, - model_name=model_name, - report_folder=report_folder, - scope_name_map=scope_name_map) + code_fragments = generator_inst.generate() + save_code_file_and_report(model_name, code_fragments, output_folder, report_folder) @BaseConverterFail.check_except("Failed to start base converter.") diff --git a/mindinsight/mindconverter/graph_based_converter/generator/generator.py b/mindinsight/mindconverter/graph_based_converter/generator/generator.py index a9ba9d6e..c563dc01 100644 --- a/mindinsight/mindconverter/graph_based_converter/generator/generator.py +++ b/mindinsight/mindconverter/graph_based_converter/generator/generator.py @@ -16,23 +16,17 @@ import copy from collections import OrderedDict +from yapf.yapflib.yapf_api import FormatCode + from .scope_utils import Scope from .node_struct import NodeStruct from .module_struct import ModuleStruct from .args_translator import ArgsTranslationHelper from ..common.global_context import GlobalContext +from ...common.exceptions import GeneratorFail from ..hierarchical_tree.name_mgr import GlobalVarNameMgr -from ..constant import NEW_LINE, SECOND_LEVEL_INDENT, FIRST_LEVEL_INDENT - - -class Singleton(type): - """Metaclass to make the generator to be single instance.""" - _instances = {} - - def __call__(cls, *args, **kwargs): - if cls not in cls._instances: - cls._instances[cls] = super(Singleton, cls).__call__(*args, **kwargs) - return cls._instances[cls] +from ..constant import NEW_LINE, SECOND_LEVEL_INDENT, FIRST_LEVEL_INDENT, CodeFormatConfig, get_imported_module +from ..report_generator import ReportGenerator class CodeStruct: @@ -88,14 +82,9 @@ class CodeStruct: repeated_submodules (dict): The dict contains all submodules which use repeatedly. Can get this dict from generator. """ - # Define tmp var for code generation. - opt_var_name_records = dict() # now only support multiple outputs within same scope. - return_value_records = dict() # save returned values for successor nodes/modules use. + # Define Module header code line below - if md_struct.pattern_id != -1: - class_name = f"Module{md_struct.pattern_id}" - else: - class_name = "Model" + class_name = md_struct.class_name # define a class declaration self.new_line = f"class {class_name}(nn.Cell):" @@ -108,36 +97,18 @@ class CodeStruct: for formal in md_struct.args_translator.formal_args.keys(): module_def_args.append(formal) - # Collect extra inputs and outputs - # For code line in init & construct blocks init_lines = list() cons_lines = list() for (_, struct) in md_struct.get_generate_order(): if isinstance(struct, NodeStruct): # Generate code line for Node. code_line_init = struct.code_line_in_init() - code_line_construct = struct.code_line_in_construct(in_module_returns=return_value_records) + code_line_construct = struct.code_line_in_construct() init_lines.append(f"{SECOND_LEVEL_INDENT}{' = '.join(code_line_init)}") cons_lines.append(f"{SECOND_LEVEL_INDENT}{' = '.join(code_line_construct)}") # add extra tensor if struct.fragment.code_setting and struct.fragment.code_setting.op_extra_tensor: - code_extra_tensor = struct.add_extra_tensor() - init_lines.append(f"{SECOND_LEVEL_INDENT}{' = '.join(code_extra_tensor)}") - - # record opt_var_name for succ nodes input in same scope. - target_onnx_name = struct.graph_node_ref.successor_nodes - for name in target_onnx_name: - if opt_var_name_records.get(name): - opt_var_name_records.get(name).append(code_line_construct[0]) - else: - opt_var_name_records[name] = [code_line_construct[0]] - - if struct.successor_nodes_names_external: - for ret_user in struct.successor_nodes_names_external: - if return_value_records.get(ret_user) is not None: - return_value_records[ret_user].append((struct.onnx_name, code_line_construct[0])) - else: - return_value_records[ret_user] = [(struct.onnx_name, code_line_construct[0])] + init_lines.append(f"{SECOND_LEVEL_INDENT}{' = '.join(struct.add_extra_tensor())}") elif isinstance(struct, ModuleStruct): # check if this instance generated CodeStruct @@ -149,22 +120,6 @@ class CodeStruct: init_lines.append(f"{SECOND_LEVEL_INDENT}{' = '.join(code_line_init)}") cons_lines.append(f"{SECOND_LEVEL_INDENT}{' = '.join(code_line_construct)}") - # record opt_var_name for succ nodes input in same scope. - target_onnx_name = struct.tail_nd_struct.graph_node_ref.successor_nodes - for name in target_onnx_name: - if opt_var_name_records.get(name): - opt_var_name_records.get(name).append(code_line_construct[0]) - else: - opt_var_name_records[name] = [code_line_construct[0]] - - # record submodule's local return map for following nodes / submodules use - if struct.external_successor_local_returns_map: - for ret_user, _ in struct.external_successor_local_returns_map.items(): - if return_value_records.get(ret_user) is not None: - # mulitple returns of a node may need modifiy the index. - return_value_records[ret_user].append((struct.identifier, code_line_construct[0])) - else: - return_value_records[ret_user] = [(struct.identifier, code_line_construct[0])] else: raise TypeError("Unable to generate code from args are not ModuleStruct or NodeStruct.") @@ -183,8 +138,7 @@ class CodeStruct: # define returns returns = [] if md_struct.external_successor_local_returns_map: - ret = list(md_struct.external_successor_local_returns_map.values()) - for r in ret: + for r in list(md_struct.external_successor_local_returns_map.values()): if isinstance(r, tuple): # results return with index nth output returns.append(r[0]) else: @@ -197,7 +151,7 @@ class CodeStruct: self.GLOBAL_CONTEXT.code_structs[md_struct.pattern_id] = self -class Generator(metaclass=Singleton): +class Generator: """The generator controls all routines of code generation.""" def __init__(self): @@ -217,6 +171,7 @@ class Generator(metaclass=Singleton): self._global_context.node_struct_collections = self._node_struct_collections self._repeated_submodules = set() + @GeneratorFail.check_except("Generator occurs an error when forming base submodules.") def _form_bottom_submodule(self): """Form the basic submodules, which only contains nodes.""" # Form module map @@ -375,7 +330,7 @@ class Generator(metaclass=Singleton): if scope_path_str == '[]': continue # is main module, skip if md_struct.scope_depth != depth: - continue # skip all submodules not at current depth + continue # skip all submodules not at current depth md_struct_scope = copy.deepcopy(md_struct.identifier) md_struct_scope.pop() parent_scope = md_struct_scope @@ -396,6 +351,7 @@ class Generator(metaclass=Singleton): self._global_context.add_module_struct(sub.pattern_id, sub) depth -= 1 + @GeneratorFail.check_except("Generator occurs an error when building modules.") def _recursive_form_module(self): """Main routine in generator to build modules from bottom to top.""" # 1. List repeated submodules @@ -518,6 +474,7 @@ class Generator(metaclass=Singleton): """Return all ModuleStructs in this model.""" return self._module_struct_collections + @GeneratorFail.check_except("Generator occurs an error when generating code statements.") def generate(self): """ Generate the final script file. @@ -527,8 +484,22 @@ class Generator(metaclass=Singleton): """ self._form_bottom_submodule() self._recursive_form_module() - code = CodeStruct(self.module_structs.get('[]'), self._repeated_submodules) - return code.code_line_list + CodeStruct(self.module_structs.get('[]'), self._repeated_submodules) + + outputs = [get_imported_module()] + + for code_struct in self._global_context.code_structs.values(): + for line in code_struct.code_line_list: + outputs.append(line) + + formatted_code, _ = FormatCode("\n".join(outputs), + style_config=CodeFormatConfig.PEP8.value) + + report_generator = ReportGenerator() + report = report_generator.gen_report(formatted_code) + del self._global_context + + return {"model": (formatted_code, report)} def get_node_struct(self, node_identifier): """ @@ -554,28 +525,6 @@ class Generator(metaclass=Singleton): """ return self._module_struct_collections.get(module_identifier, None) - def get_module_structs_by_pattern_under_same_parent_pattern(self, pattern_id, under_parent_pattern_id): - """ - Return a list of ModuleStruct by conditions of pattern and their parent parent's pattern. - - Args: - pattern_id (int): The pattern id the returned ModuleSturct is. - under_parent_pattern_id (int): The pattern id the returned ModuleStruct's parent is. - - Returns: - list, a list of MoudleStructs has the same pattern_id and the same parents' pattern. - """ - if not pattern_id: - raise ValueError("pattern_id is necessary to get the module struct.") - if not under_parent_pattern_id: - raise ValueError("under_parent_pattern_id is necessary to get the module struct.") - ret = [] - md_list = self._global_context.module_structs.get(pattern_id) - for md in md_list: - if md.parent_id == under_parent_pattern_id: - ret.append(md) - return ret - def get_args_translator_from_module_structs_list(self, md_list, exclude_root_son=False): """ Return a list of args translators which belongs to given module structs. diff --git a/mindinsight/mindconverter/graph_based_converter/generator/module_struct.py b/mindinsight/mindconverter/graph_based_converter/generator/module_struct.py index ca0a0a90..d26149af 100644 --- a/mindinsight/mindconverter/graph_based_converter/generator/module_struct.py +++ b/mindinsight/mindconverter/graph_based_converter/generator/module_struct.py @@ -201,7 +201,7 @@ class ModuleStruct: def init_args_translator(self): """Initialize the Args Translator for the module.""" - var_name = "Module{}_{}".format(self.pattern_id, self.pattern_uid) + var_name = self.ms_var_name self._args_translator = ArgsTranslation(None, var_name, None) def update_module_fragment(self): @@ -455,14 +455,18 @@ class ModuleStruct: """Return the class name for generating code of this module.""" if self.pattern_id == -1: return "Model" - return "Module{}".format(self.pattern_id) + if self.GLOBAL_CONTEXT_MGR.known_module_name.get("Module{}".format(self.pattern_id)) is not None: + class_name = self.GLOBAL_CONTEXT_MGR.known_module_name.get("Module{}".format(self.pattern_id)) + else: + class_name = "Module{}".format(self.pattern_id) + return class_name @property def ms_var_name(self) -> str: """Return the variable name for generated code statement of this module.""" if self.pattern_id == -1: return "Model" - return "Module{}_{}".format(self.pattern_id, self.pattern_uid).lower() + return f"{self.class_name}_{self.pattern_uid}".lower() @property def ms_opt_var_name(self) -> str: diff --git a/mindinsight/mindconverter/graph_based_converter/generator/node_struct.py b/mindinsight/mindconverter/graph_based_converter/generator/node_struct.py index 9764fe2f..196330c4 100644 --- a/mindinsight/mindconverter/graph_based_converter/generator/node_struct.py +++ b/mindinsight/mindconverter/graph_based_converter/generator/node_struct.py @@ -22,6 +22,7 @@ from ..third_party_graph.pytorch_graph_node import PyTorchGraphNode from ..third_party_graph.onnx_graph_node import OnnxGraphNode from ..common.global_context import GlobalContext from ..constant import InputType +from ...common.exceptions import GeneratorFail class NodeStruct: @@ -47,22 +48,32 @@ class NodeStruct: self.node_type = None self.onnx_name = None self.onnx_op = None - self.graph_node_ref = None # Our defined GraphNode + self.graph_node_ref = None self.scope_name = None self.ms_var_name = None - self.ms_opt_var_name = None # ms_opt_var_name = self.ms_var_name(...) + self.ms_opt_var_name = None self.ms_op = None self.ready_to_generate = False - self.ms_params = dict() # converted params from mapper + # Define attributes converted from mapper + self.ms_params = dict() self.ms_settings = dict() self.ms_weights = dict() self.ms_inputs = OrderedDict() - self.scope = None # Defined Scope class - self.inputs_in_construct_header = OrderedDict() # key is prec_node_name, value is x; For code line use - self.inputs_in_parent_module = OrderedDict() # key is prec_node_name, value is its closet opt_var_name - self.matched_inputs = list() # Matched inputs will can be directly used by code line generation + # Defined Scope class + self.scope = None + + # Define attributes used for code generation + + # key is prec_node_name, value is x; For code line use + self.inputs_in_construct_header = OrderedDict() + + # key is prec_node_name, value is its closet opt_var_name + self.inputs_in_parent_module = OrderedDict() + + # Matched inputs will can be directly used by code line generation + self.matched_inputs = list() # initialize funcs. for arg in args: @@ -135,6 +146,7 @@ class NodeStruct: parsed_scope = Scope.parse_scope_from_node_identifier(self.identifier) self.scope = Scope(parsed_scope) + @GeneratorFail.check_except("Generator occurs an error when initializing node's args translator.") def init_args_translator(self, translated_args: list): """ Initialize the ArgsTranslator for each Node. @@ -158,6 +170,7 @@ class NodeStruct: self.ms_op]): self.ready_to_generate = True + @GeneratorFail.check_except("Generator occurs an error when creating node struct.") def update(self, arg, force_ready=False): """ Pass Node info. to generator NodeStruct. @@ -314,23 +327,12 @@ class NodeStruct: return input_name_to_use return None - def code_line_in_construct(self, inputs=None, in_module_returns=None): + def code_line_in_construct(self, inputs=None): """Construct line of code in module construct block. """ left = self.ms_opt_var_name - if inputs is None: - inputs = [] - for idx, prec_node in enumerate(self.precursor_nodes_names): - if self.inputs_in_construct_header.get(prec_node): - inputs.append(self.inputs_in_construct_header.get(prec_node)) - elif self._check_target_node_internal(prec_node): - inputs.append(self.precursor_nodes_structs[idx].ms_opt_var_name) - elif self.inputs_in_parent_module.get(prec_node): - inputs.append(self.inputs_in_parent_module.get(prec_node)) - elif in_module_returns and in_module_returns.get(self.onnx_name) \ - and (not self._check_target_node_internal(prec_node)): - inputs.append(self._get_correct_in_module_returns(prec_node, in_module_returns.get(self.onnx_name))) - else: - inputs.append("unk_{}_{}".format(idx, prec_node)) + + if not self.matched_inputs and inputs is None: + raise ValueError("Unable to generate the code construct statement due to empty inputs.") if self.matched_inputs: inputs = self.matched_inputs @@ -394,7 +396,7 @@ class NodeStruct: """ target_nd_struct = self.GLOBAL_CONTEXT_MGR.node_struct_collections.get(name) \ or self.GLOBAL_CONTEXT_MGR.onnx_node_name_to_node_struct_map.get(name) - if target_nd_struct is None and self.topo_idx == 0: # First node always has external input + if target_nd_struct is None and self.topo_idx == 0: # First node always has external input return False if target_nd_struct is None: @@ -413,11 +415,11 @@ class NodeStruct: @property def precursor_nodes_names_external(self) -> list: """Return a list of external precursor nodes names.""" - return [name for name in self.precursor_nodes_names \ - if not self._check_target_node_internal(name)] + return [name for name in self.precursor_nodes_names + if not self._check_target_node_internal(name)] @property def successor_nodes_names_external(self) -> list: """Return a list of external successor nodes names.""" - return [name for name in self.successor_nodes_names \ - if not self._check_target_node_internal(name)] + return [name for name in self.successor_nodes_names + if not self._check_target_node_internal(name)] diff --git a/mindinsight/mindconverter/graph_based_converter/hierarchical_tree/hierarchical_tree.py b/mindinsight/mindconverter/graph_based_converter/hierarchical_tree/hierarchical_tree.py index 92176882..f90edc02 100644 --- a/mindinsight/mindconverter/graph_based_converter/hierarchical_tree/hierarchical_tree.py +++ b/mindinsight/mindconverter/graph_based_converter/hierarchical_tree/hierarchical_tree.py @@ -29,7 +29,7 @@ from ..common.utils import is_converted, save_code_file_and_report from ..mapper.base import Mapper from ..third_party_graph.pytorch_graph_node import PyTorchGraphNode from ..third_party_graph.onnx_graph_node import OnnxGraphNode -from ..constant import SEPARATOR_IN_SCOPE +from ..constant import SEPARATOR_IN_SCOPE, get_imported_module from ..constant import CodeFormatConfig from ..constant import SEPARATOR_BTW_NAME_AND_ID, FIRST_LEVEL_INDENT from ..constant import NEW_LINE, SECOND_LEVEL_INDENT @@ -283,7 +283,7 @@ class HierarchicalTree(Tree): Returns: Dict, codes. """ - code_blocks = [self._get_imported_module()] + code_blocks = [get_imported_module()] depths = sorted(list(self._hierarchical_order.keys()), reverse=True) for depth in depths: @@ -741,17 +741,3 @@ class HierarchicalTree(Tree): """Adjust tree structure to generate source code.""" self.sub_graph_merging() self.update_hierarchical_order() - - @staticmethod - def _get_imported_module(): - """ - Generate imported module header. - - Returns: - str, imported module. - """ - return f"import numpy as np{NEW_LINE}" \ - f"import mindspore{NEW_LINE}" \ - f"from mindspore import nn{NEW_LINE}" \ - f"from mindspore import Tensor{NEW_LINE}" \ - f"from mindspore.ops import operations as P{NEW_LINE * 3}"