diff --git a/mindinsight/mindconverter/graph_based_converter/common/global_context.py b/mindinsight/mindconverter/graph_based_converter/common/global_context.py index c0ce4c71..836c911c 100644 --- a/mindinsight/mindconverter/graph_based_converter/common/global_context.py +++ b/mindinsight/mindconverter/graph_based_converter/common/global_context.py @@ -40,8 +40,10 @@ class GlobalContext(metaclass=Singleton): # Define data stored from onnx_utils # Key as Onnx Name self._onnx_nodes_collection = OrderedDict() - # key is topo_idx, value is onnx_node_name. + # key is topo_idx, value is onnx_node_name self._onnx_nodes_topo_index = dict() + self.onnx_node_name_to_topo_idx = dict() + self.onnx_node_inputs = dict() self._onnx_tensors_collection = dict() # Define data stored from generator @@ -50,7 +52,7 @@ class GlobalContext(metaclass=Singleton): self.node_struct_adder_counter = 0 # Define onnx_utils <---> generator mapping self.node_struct_to_onnx_node_map = dict() - self.onnx_node_to_node_struct_map = dict() + self.onnx_node_name_to_node_struct_map = dict() # Define Module pattern to customize name mapping self.module_customized_name = dict() @@ -59,6 +61,8 @@ class GlobalContext(metaclass=Singleton): self.node_fragments = OrderedDict() self.module_fragments = OrderedDict() + # Define Known module mapping + self.known_module_name = dict() # Define Structs # key is pattern_id, value is [ModuleStructs] self.module_structs = dict() @@ -83,7 +87,7 @@ class GlobalContext(metaclass=Singleton): def get_identifier_from_onnx_node_name(self, node_name): """Return the node identifier by Onnx Node name.""" - identifier = self.onnx_node_to_node_struct_map.get(node_name) + identifier = self.onnx_node_name_to_node_struct_map.get(node_name) return identifier @property @@ -98,9 +102,7 @@ class GlobalContext(metaclass=Singleton): @onnx_nodes_collection.setter def onnx_nodes_collection(self, arg): - """ - Set the onnx nodes collection. - """ + """Set the onnx nodes collection.""" if isinstance(arg, OrderedDict): self._onnx_nodes_collection = arg # arg must be nodes_dict in OnnxDataLoader else: @@ -108,11 +110,18 @@ class GlobalContext(metaclass=Singleton): @property def onnx_nodes_topo_index(self) -> dict: - "Return the onnx nodes and topological index." + """Return the onnx nodes and topological index.""" return self._onnx_nodes_topo_index @onnx_nodes_topo_index.setter def onnx_nodes_topo_index(self, index_list): + """ + Set the onnx nodes and topological index. + + Args: + index_list (list[tuple[int, str]]): a list of tuple contains the topological index and onnx node name. + + """ if not isinstance(index_list, list): raise TypeError("The argument index_list must be a list of tuple (index, onnx_node_name).") if not isinstance(index_list[0], tuple): @@ -122,10 +131,17 @@ class GlobalContext(metaclass=Singleton): @property def onnx_tensors_collection(self): + """Return the onnx tensors collection.""" return self.onnx_tensors_collection @onnx_tensors_collection.setter def onnx_tensors_collection(self, arg): + """ + Set the onnx tensors collection by OnnxDataLoader. + + Args: + arg (dict): The OnnxDataLoader generated tensors_dict. + """ if isinstance(arg, dict): self._onnx_tensors_collection = arg # arg must be tensors_dict in OnnxDataLoader else: @@ -133,6 +149,12 @@ class GlobalContext(metaclass=Singleton): @property def latest_node_struct_count(self): + """ + Return the latest node struct count. + + Note: + The counter will increase by 1 to tracking the number of nodes added. + """ ret = self.node_struct_adder_counter self.node_struct_adder_counter += 1 return ret @@ -184,18 +206,29 @@ class GlobalContext(metaclass=Singleton): self.module_customized_name[pattern_id] = customized_name def get_node_fragment(self, identifier): + """Return the node fragment by identifier.""" return self.node_fragments.get(identifier) def add_code_fragment(self, identifier, frag): + """Add the node fragment by identifier.""" self.node_fragments[identifier] = frag def get_module_fragment(self, identifier): + """Return the module fragment by identifier.""" return self.module_fragments.get(identifier) def add_module_fragment(self, identifier, frag): + """Add the module fragment by identifier.""" self.module_fragments[identifier] = frag def add_module_struct(self, pattern_id, module_struct): + """ + Add module struct by its pattern_id. + + Args: + pattern_id (int): The pattern which represents the structure of the module. + module_struct (ModuleStruct): The ModuleStruct instance. + """ if self.module_structs.get(pattern_id) is None: self.module_structs[pattern_id] = [module_struct] else: diff --git a/mindinsight/mindconverter/graph_based_converter/common/utils.py b/mindinsight/mindconverter/graph_based_converter/common/utils.py index b946e832..fd632ccf 100644 --- a/mindinsight/mindconverter/graph_based_converter/common/utils.py +++ b/mindinsight/mindconverter/graph_based_converter/common/utils.py @@ -135,3 +135,18 @@ def lib_version_satisfied(current_ver: str, mini_ver_limited: str, if current_ver < mini_ver_limited or (newest_ver_limited and current_ver > newest_ver_limited): return False return True +def get_dict_key_by_value(val, dic): + """ + Return the first appeared key of a dictionay by given value. + + Args: + val (Any): Value of the key. + dic (dict): Dictionary to be checked. + + Returns: + Any, key of the given value. + """ + for d_key, d_val in dic.items(): + if d_val == val: + return d_key + return None diff --git a/mindinsight/mindconverter/graph_based_converter/generator/__init__.py b/mindinsight/mindconverter/graph_based_converter/generator/__init__.py new file mode 100644 index 00000000..483a93b1 --- /dev/null +++ b/mindinsight/mindconverter/graph_based_converter/generator/__init__.py @@ -0,0 +1,111 @@ +# Copyright 2020 Huawei Technologies Co., Ltd.All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Generator module.""" +__all__ = ["batch_add_nodes"] + +import re +import copy + +from .generator import Generator, CodeStruct +from ..common.code_fragment import CodeFragment + + +def _tf_model_node_name_reformat(node, node_name): + """ + Rename the node name by combining scope name and its original name. + + Args: + node (OnnxGraphNode): OnnxGraphNode instance. + node_name (str): node name saved in Graph. + + Returns: + str, re-formatted node name. + """ + scope_name = node.scope_name + new_name = None + regex = r"(?P.+/)(?P\w+)" + match = re.match(regex, scope_name) + parent = match.group("parent") + node_name = '$' + node_name.replace('/', '::') + '$' + + if scope_name: + new_name = parent + node_name + return new_name + return node_name + + +def batch_add_nodes(graph_obj, mapper) -> Generator: + """ + Add nodes to Generator in batch mode. + + Args: + graph_obj (Graph): Graph obj. + mapper (Mapper): Mapper of third party framework and MindSpore. + + """ + generator_inst = Generator() + for node_name in graph_obj.nodes_in_topological_order: + node_inst = graph_obj.get_node(node_name) + node_input = graph_obj.get_input_shape(node_name) + node_output = graph_obj.get_output_shape(node_name) + if not node_input: + raise ValueError("Unable to get the node's inputs from Graph object.") + node_name_with_scope = _tf_model_node_name_reformat(node_inst, node_name) + node_name = node_name_with_scope + + node_inst.add_input_and_output_shape(node_input, node_output) + op_name, params, settings, weights = _convert_params(node_inst, mapper) + generator_inst.add_node( + node_name, + node_instance=node_inst, + node_fragment=CodeFragment(op_name, params, + settings, + node_inst.input_shape, + node_inst.output_shape, + weights) + ) + return generator_inst + + +def _convert_params(node, mapper): + """ + Call mapper to convert node's params from ONNX to MindSpore. + + Args: + node (GraphNode): Our defined GraphNode instance. + mapper (Mapper): The mapper instance which indicating conversion method. + + Returns: + str, op name in MindSpore + dict, MindSpore parameters + dict, MindSpore settings + dict, weights of the node + """ + params = copy.deepcopy(node.node_params) + params.update({"input_shape": node.input_shape, + "output_shape": node.output_shape}) + + op_in_ms, ms_params, ms_settings, weights = mapper.convert(op_name=node.op_name, + params=params, + weights=node.weight) + if "input_shape" in ms_params: + ms_params.pop("input_shape") + if "output_shape" in ms_params: + ms_params.pop("output_shape") + + if op_in_ms: + return op_in_ms, ms_params, ms_settings, weights + + return node.op_name, node.node_params, dict(), dict() diff --git a/mindinsight/mindconverter/graph_based_converter/generator/args_translator.py b/mindinsight/mindconverter/graph_based_converter/generator/args_translator.py new file mode 100644 index 00000000..3219527b --- /dev/null +++ b/mindinsight/mindconverter/graph_based_converter/generator/args_translator.py @@ -0,0 +1,248 @@ +# Copyright 2020 Huawei Technologies Co., Ltd.All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Define arguments translation related operations for params name changing.""" +import copy + + +class ArgsTranslation: + """Define a universal arguments translation manager.""" + + def __init__(self, original_actual_args: dict, var_name: str, translated_args: list): + """ + Init the ArgsTranslation. + + Args: + original_actual_args (dict): The full original args from fragments. + var_name (str): The var name for current Node / Module. + translated_args (list): The list of args need to translate to formal args. + """ + if not var_name: + raise ValueError("Initialize ArgsTranslation requires the var_name.") + + self.var_name = var_name + self.actual_args = dict() # e.g. key is 'num_features', value is 2048 + self.formal_args = dict() # e.g. key is 'num_features', value is 'var_name_num_features'} + self.formal_args_values = dict() # e.g. key 'var_name_num_features', value 2048. Value use for up-level + self.actual_args_backup = dict() # backup actual args before translation + + self.actual_args_to_str_list = list() + self.formal_args_to_str_list = list() + self.formal_args_values_to_str_list = list() + self.actual_args_backup_to_str_list = list() + + if all([original_actual_args, translated_args]): + # MUST ensure only one var_name in a scope. + for arg_name, arg_value in original_actual_args.items(): + if arg_name in translated_args: + formal_arg_name = '_'.join([var_name, arg_name]) + self.formal_args[arg_name] = formal_arg_name + self.formal_args_values[formal_arg_name] = arg_value + else: + self.actual_args[arg_name] = arg_value + + self.make_str() + + @staticmethod + def dict_data_to_args_str_list(any_dict): + """ + Output a list of string of dict data by "key=value" format. + + Args: + any_dict (dict): Any dictionary + + Returns: + list, the list of strings showing dictionary as "key=value" format. + """ + ret = [] + for key, val in any_dict.items(): + ret.append('='.join([key, str(val)])) + return ret + + def make_str(self): + """Make string used in code generation.""" + self.actual_args_to_str_list = list() + self.formal_args_to_str_list = list() + self.formal_args_values_to_str_list = list() + self.actual_args_backup_to_str_list = list() + + if self.actual_args: + self.actual_args_to_str_list = ArgsTranslation.dict_data_to_args_str_list(self.actual_args) + + if self.formal_args: + self.formal_args_to_str_list = ArgsTranslation.dict_data_to_args_str_list(self.formal_args) + + if self.formal_args_values: + self.formal_args_values_to_str_list = ArgsTranslation.dict_data_to_args_str_list(self.formal_args_values) + + if self.actual_args_backup: + self.actual_args_backup_to_str_list = ArgsTranslation.dict_data_to_args_str_list(self.actual_args_backup) + + def __repr__(self): + return str({ + "address": hex(id(self)), + "var_name": self.var_name, + "actual_args": self.actual_args, + "actual_bak": self.actual_args_backup, + "formal_args": self.formal_args, + "formal_val ": self.formal_args_values + }) + + def set_actual_args_backup(self): + """Backup the actual args before translating to formal.""" + self.actual_args_backup = copy.deepcopy(self.actual_args) + + def deepcopy(self): + """Return a deepcopy of self.""" + return copy.deepcopy(self) + + def make_actual_arg_to_formal(self, actual_arg_name): + """ + Make the actual arg to a formal arg. + + Args: + actual_arg_name (str): The name of the actual arg to be formal. + """ + val = self.actual_args.get(actual_arg_name) + if val is None: + raise ValueError("Unable to convert the actual arg to formal due to missing arg.") + formal_arg_name = ('_').join([self.var_name, actual_arg_name]) + self.actual_args.pop(actual_arg_name) + self.formal_args[actual_arg_name] = formal_arg_name + self.formal_args_values[formal_arg_name] = val + self.make_str() + + def _update_dict_for_upper_level(self, d, upper_level_var_name): + """Add upper level var name to key name of selected dictionary.""" + new_d = dict() + for arg_name, val in d.items(): + new_arg_name = '_'.join([upper_level_var_name, arg_name]) # e.g. conv2d_0_in_channels_Module_3_0 + new_d[new_arg_name] = val + return new_d + + def escalate_to_upper_level(self, upper_level_var_name): + """ + Escalate this args translator for upper level module use. + + Note: + You MUST deepcopy this translator first to avoid editing values in the original translator. + """ + # update all args name by adding upper_level_var_name. + tmp_actual_args = self._update_dict_for_upper_level(self.actual_args, upper_level_var_name) + tmp_formal_args = self._update_dict_for_upper_level(self.formal_args, upper_level_var_name) + tmp_formal_args_values = self._update_dict_for_upper_level(self.formal_args_values, upper_level_var_name) + + self.actual_args = tmp_actual_args + self.formal_args = tmp_formal_args + self.formal_args_values = tmp_formal_args_values + + self.make_str() + + def make_formal_args_back_to_actual(self, formal_arg): + """ + Move the formal arg back to actual arg. + + Note: + This does not reset the formal arg name back, + Only used for module init statement. + + Args: + formal_arg (str): formal argument name. + """ + if isinstance(formal_arg, str): + val = self.formal_args_values.pop(formal_arg) + self.actual_args[formal_arg] = val + if isinstance(formal_arg, list): + for arg in formal_arg: + val = self.formal_args_values.pop(arg) + self.actual_args[formal_arg] = val + + self.make_str() + + def take_formal_args_from_args_translator(self, args_translator, escalate_sub=False): + """ + Add submodule's or node's args translator to this translator. + + Args: + args_translator (ArgsTranslation): submodule's or node's args translator. + """ + if escalate_sub: + sub_args_translator = args_translator.deepcopy() + sub_args_translator.escalate_to_upper_level(self.var_name) + else: + sub_args_translator = args_translator + + original_actual_args = sub_args_translator.formal_args_values + self.actual_args.update(original_actual_args) + self.make_str() + + def take_formal_args_from_nodes_and_submodules(self, args_translators: list, escalate_sub=False): + """ + Take all formal args from nodes and submodules from passed in args_translators. + + Args: + args_translators (ArgsTranslation): A list of ArgsTranslation instances. + escalate_sub (Bool): should escalate all formal args. Default: False + """ + for arg_t in args_translators: + self.take_formal_args_from_args_translator(arg_t, escalate_sub=escalate_sub) + + +class ArgsTranslationHelper: + """Define operations related to ArgsTranslation instances.""" + @staticmethod + def find_formal_args_in_modules(args_translators): + """ + Find formal args among multiple args translators. + + Args: + args_translators(list[ArgsTranslation]): a list of args translator to be checked. + + Returns: + list, name of args to be formal. + """ + if len(args_translators) < 2: + # only one args_translator provided, no formal args. + return None + ret = [] + base_args_t = args_translators[0] + for arg_name, arg_val in base_args_t.actual_args.items(): + for args_t in args_translators[1:]: + val = args_t.actual_args.get(arg_name) + + if val is None: + raise ValueError("Unable to find the given args as the args translator is not consistent.") + if val != arg_val: # val not equal + ret.append(arg_name) + break + return ret + + @staticmethod + def change_args_to_formal_for_all_translators(args_name, args_translators): + """ + Change args to formal for all translators provided. + + Args: + args_name (str): The name of args to be changing. + args_translators (ArgsTranslation): The args to be changed in args translators. + """ + if isinstance(args_name, str): + args_name = [args_name] + if isinstance(args_translators, ArgsTranslation): + args_translators = [args_translators] + + for arg in args_name: + for args_t in args_translators: + args_t.set_actual_args_backup() + args_t.make_actual_arg_to_formal(arg) diff --git a/mindinsight/mindconverter/graph_based_converter/generator/generator.py b/mindinsight/mindconverter/graph_based_converter/generator/generator.py new file mode 100644 index 00000000..a9ba9d6e --- /dev/null +++ b/mindinsight/mindconverter/graph_based_converter/generator/generator.py @@ -0,0 +1,630 @@ +# Copyright 2020 Huawei Technologies Co., Ltd.All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Main Generator module.""" +import copy +from collections import OrderedDict + +from .scope_utils import Scope +from .node_struct import NodeStruct +from .module_struct import ModuleStruct +from .args_translator import ArgsTranslationHelper +from ..common.global_context import GlobalContext +from ..hierarchical_tree.name_mgr import GlobalVarNameMgr +from ..constant import NEW_LINE, SECOND_LEVEL_INDENT, FIRST_LEVEL_INDENT + + +class Singleton(type): + """Metaclass to make the generator to be single instance.""" + _instances = {} + + def __call__(cls, *args, **kwargs): + if cls not in cls._instances: + cls._instances[cls] = super(Singleton, cls).__call__(*args, **kwargs) + return cls._instances[cls] + + +class CodeStruct: + """ + Define the Code template for each module generated in the final output. + Each module has only one CodeStruct to its pattern. + """ + GLOBAL_CONTEXT = GlobalContext() + NOT_IN_SCOPE_OPT = dict() + + def __init__(self, struct, repeated_submodules=None): + """Initialize the CodeStruct.""" + self.output_order = None # output order + self.input = None # opt_var_name for prev. node + self.extra_input = list() # extra_input(s) at construct method args + self.output = None # opt_var_name for next node + self.extra_output = list() # extra_output(s) + self.extra_comment = None # comments for this code line / block. + self.code_line_list = list() # list of code line, a item is a line. + self._global_var_mgr = GlobalVarNameMgr() # var name procs within same module + + self.formal_args_collections = None + + if isinstance(struct, NodeStruct): + self.output_order = struct.topo_idx + if isinstance(struct, ModuleStruct): + self.output_order = struct.head_nd_struct_index + self._generate_from_module_struct(struct, repeated_submodules) + + def _add_line(self, s): + """Add line of code.""" + self.code_line_list.append(s) + + @property + def new_line(self): + """Return last generated line.""" + try: + return self.code_line_list[-1] + except IndexError: + return "" + + @new_line.setter + def new_line(self, s): + """Make a new line.""" + self._add_line(s) + + def _generate_from_module_struct(self, md_struct, repeated_submodules): + """ + Generate the code of current Module Struct, collecting data from submodules. + + Args: + md_struct (ModuleStruct): The ModuleStruct which generates codes. + repeated_submodules (dict): The dict contains all submodules which use repeatedly. + Can get this dict from generator. + """ + # Define tmp var for code generation. + opt_var_name_records = dict() # now only support multiple outputs within same scope. + return_value_records = dict() # save returned values for successor nodes/modules use. + # Define Module header code line below + if md_struct.pattern_id != -1: + class_name = f"Module{md_struct.pattern_id}" + else: + class_name = "Model" + # define a class declaration + self.new_line = f"class {class_name}(nn.Cell):" + + # Get all formal args from nodes + module_def_args = ['self'] + if md_struct.args_translator.actual_args: + for actual in md_struct.args_translator.actual_args.keys(): + module_def_args.append(actual) + if md_struct.args_translator.formal_args: + for formal in md_struct.args_translator.formal_args.keys(): + module_def_args.append(formal) + + # Collect extra inputs and outputs + + # For code line in init & construct blocks + init_lines = list() + cons_lines = list() + for (_, struct) in md_struct.get_generate_order(): + if isinstance(struct, NodeStruct): # Generate code line for Node. + code_line_init = struct.code_line_in_init() + code_line_construct = struct.code_line_in_construct(in_module_returns=return_value_records) + init_lines.append(f"{SECOND_LEVEL_INDENT}{' = '.join(code_line_init)}") + cons_lines.append(f"{SECOND_LEVEL_INDENT}{' = '.join(code_line_construct)}") + # add extra tensor + if struct.fragment.code_setting and struct.fragment.code_setting.op_extra_tensor: + code_extra_tensor = struct.add_extra_tensor() + init_lines.append(f"{SECOND_LEVEL_INDENT}{' = '.join(code_extra_tensor)}") + + # record opt_var_name for succ nodes input in same scope. + target_onnx_name = struct.graph_node_ref.successor_nodes + for name in target_onnx_name: + if opt_var_name_records.get(name): + opt_var_name_records.get(name).append(code_line_construct[0]) + else: + opt_var_name_records[name] = [code_line_construct[0]] + + if struct.successor_nodes_names_external: + for ret_user in struct.successor_nodes_names_external: + if return_value_records.get(ret_user) is not None: + return_value_records[ret_user].append((struct.onnx_name, code_line_construct[0])) + else: + return_value_records[ret_user] = [(struct.onnx_name, code_line_construct[0])] + + elif isinstance(struct, ModuleStruct): + # check if this instance generated CodeStruct + if self.GLOBAL_CONTEXT.code_structs.get(struct.pattern_id) is None: + CodeStruct(struct, repeated_submodules) + + code_line_init = struct.code_line_in_init() + code_line_construct = struct.code_line_in_construct(inputs=struct.matched_inputs) + init_lines.append(f"{SECOND_LEVEL_INDENT}{' = '.join(code_line_init)}") + cons_lines.append(f"{SECOND_LEVEL_INDENT}{' = '.join(code_line_construct)}") + + # record opt_var_name for succ nodes input in same scope. + target_onnx_name = struct.tail_nd_struct.graph_node_ref.successor_nodes + for name in target_onnx_name: + if opt_var_name_records.get(name): + opt_var_name_records.get(name).append(code_line_construct[0]) + else: + opt_var_name_records[name] = [code_line_construct[0]] + + # record submodule's local return map for following nodes / submodules use + if struct.external_successor_local_returns_map: + for ret_user, _ in struct.external_successor_local_returns_map.items(): + if return_value_records.get(ret_user) is not None: + # mulitple returns of a node may need modifiy the index. + return_value_records[ret_user].append((struct.identifier, code_line_construct[0])) + else: + return_value_records[ret_user] = [(struct.identifier, code_line_construct[0])] + else: + raise TypeError("Unable to generate code from args are not ModuleStruct or NodeStruct.") + + # define header of init block + self.new_line = f"{FIRST_LEVEL_INDENT}def __init__({', '.join(module_def_args)}):" + self.new_line = f"{SECOND_LEVEL_INDENT}super({class_name}, self).__init__()" + # add init code lines to code line list. + self.code_line_list += init_lines + self.new_line = f"{NEW_LINE * 2}" + + # define header of construct block + inputs = ['self'] + list(md_struct.construct_header_x.keys()) + self.new_line = f"{FIRST_LEVEL_INDENT}def construct({', '.join(inputs)}):" + # add construct code lines to code line list. + self.code_line_list += cons_lines + # define returns + returns = [] + if md_struct.external_successor_local_returns_map: + ret = list(md_struct.external_successor_local_returns_map.values()) + for r in ret: + if isinstance(r, tuple): # results return with index nth output + returns.append(r[0]) + else: + returns.append(r) + returns = list(set(returns)) + else: + returns = [code_line_construct[0]] + self.new_line = f"{SECOND_LEVEL_INDENT}return {', '.join(returns)}" + self.new_line = f"{NEW_LINE * 2}" + self.GLOBAL_CONTEXT.code_structs[md_struct.pattern_id] = self + + +class Generator(metaclass=Singleton): + """The generator controls all routines of code generation.""" + + def __init__(self): + """Init the generator.""" + # define basic attributes + self.framework = None + + # define MUST have params + self._node_struct_collections = OrderedDict() + self._module_struct_collections = OrderedDict() + self._module_depth_max = 0 + self._module_depth_min = 0 + + # define intermediate var. during conversion + self._module_map = OrderedDict() + self._global_context = GlobalContext() + self._global_context.node_struct_collections = self._node_struct_collections + self._repeated_submodules = set() + + def _form_bottom_submodule(self): + """Form the basic submodules, which only contains nodes.""" + # Form module map + curr_scope_path = None + nd_struct_list_in_submodule = [] + for nd_struct in self.node_structs.values(): + idx = nd_struct.topo_idx + if curr_scope_path is None: + curr_scope_path = nd_struct.scope.path + nd_struct_list_in_submodule.append((idx, nd_struct)) + elif curr_scope_path == nd_struct.scope.path: + nd_struct_list_in_submodule.append((idx, nd_struct)) + else: # curr_scope_path changed + # save this submodule + if self._module_map.get(str(curr_scope_path)) is not None: + self._module_map[str(curr_scope_path)] += nd_struct_list_in_submodule + else: + self._module_map[str(curr_scope_path)] = nd_struct_list_in_submodule + + # create a new one + curr_scope_path = nd_struct.scope.path + nd_struct_list_in_submodule = [(idx, nd_struct)] + + # save last submodule + if self._module_map.get(str(curr_scope_path)) is not None: + self._module_map[str(curr_scope_path)] += nd_struct_list_in_submodule + else: + self._module_map[str(curr_scope_path)] = nd_struct_list_in_submodule + + # Form bottom modules' ModuleStruct + for scope_path_str, nd_struct_list in self._module_map.items(): + self._module_struct_collections[scope_path_str] = ModuleStruct(nd_struct_list) + + def _list_repeated_submodules(self) -> OrderedDict: + """ + Return the repeated submodules by its depth and num. + For example, "Model/Module3_3" will return {1:(3)} + + Return: + OrderedDict, a dict contains collections of repeated submodules. + """ + ret = OrderedDict() + for depth_control in range(self._module_depth_max, 0, -1): + repeated_submodules_at_this_depth = set() + for scope_path in self._module_map.keys(): + path = Scope.path_str_to_list(scope_path) + if len(path) < depth_control: + continue + else: # depth control within path length + module_num = path[depth_control - 1][0] + repeated_submodules_at_this_depth.add(module_num) + ret[depth_control] = repeated_submodules_at_this_depth + + self._repeated_submodules = ret + return ret + + def _compare_with_base_parameters(self, nd_struct_list): + """ + Compare the parameter to check if it should be a formal args. + + Args: + nd_struct_list (list): A list of NodeStructs which contains + all same nodes in repeated submodules. + + Return: + set, a set of all formal args in this node. + """ + + formal_args = set() + if len(nd_struct_list) < 2: + return formal_args + (_, base_nd_struct) = nd_struct_list[0] + for (base_parameter, base_value) in base_nd_struct.fragment.actual_args.items(): # for each param + for (_, nd_struct) in nd_struct_list[1:]: + compared_value = nd_struct.fragment.actual_args.get(base_parameter) + if compared_value == base_value: + continue + else: + formal_args.add(base_parameter) + break + + return formal_args + + def _list_formal_parameters_in_a_module(self, module_filter_return): + """ + Find all formal args / params from nodes in a module. + + Args: + module_filter_return (dict): The filtered results from the module_map_filter. + + Return: + list, a list of sets or None indicates all formal args of each node in the module in order. + """ + formal_params_list = list() + transposed = [list(e) for e in zip(*module_filter_return)] + for operation in transposed: + formal_parameters = self._compare_with_base_parameters(operation) + if formal_parameters: + formal_params_list.append(formal_parameters) + else: + formal_params_list.append(None) + return formal_params_list + + def _list_formal_parameters(self, repeated_submodules) -> dict: + """ + Return a list of formal parameters in each submodule. + + Args: + repeated_submodules (dict): A dict which contains repeated submodules, + acquire this dict from _list_repeated_submodules() + + Return: + OrderedDict, a dict with each submodule's formal args. + + Example: + A return for ResNet50 could be: + + {0: # submoodule 0 + [set('stride', 'in_channels', 'out_channels'), # args of the first node to be set as formal + set('num_features'), # args of the second node to be set as formal + None, # args of third node to be set as formal, which does not have + set('in_channels', 'out_channels'), + set('num_features'), + None + ]}, + {3: # submodule 3 + [...], + {5: # submodule 5 + []} # empty returns means no nodes or it's a parent module of submodules. + } + """ + formal_args_in_each_submodule = OrderedDict() + checked_module = set() + # filter module_map by submodule_num (without depth) + for _, module_nums in repeated_submodules.items(): + for module_num in module_nums: + if module_num in checked_module: # module already checked + continue + else: + checked_module.add(module_num) + map_filtered = self.module_map_filter(module_num=module_num) + formal_args_in_this_module = self._list_formal_parameters_in_a_module(map_filtered) + formal_args_in_each_submodule[module_num] = formal_args_in_this_module + return formal_args_in_each_submodule + + def _add_submodule_to_parent(self): + """ + Recursively add all submodule to its parent module until Main module. + + Note: + This function deepcopy the first node of the submodule, and reset its params as parent module. + """ + depth = self._module_depth_max + while depth > 0: + for (scope_path_str, md_struct) in self.module_structs.copy().items(): + if scope_path_str == '[]': + continue # is main module, skip + if md_struct.scope_depth != depth: + continue # skip all submodules not at current depth + md_struct_scope = copy.deepcopy(md_struct.identifier) + md_struct_scope.pop() + parent_scope = md_struct_scope + # 1. check if this module has parent module + parent_md_struct = self.module_structs.get(str(parent_scope)) + if parent_md_struct is not None: + # 1A. has parent, directly add md_struct to its parent ModuleStruct. + parent_md_struct.add_submodule(md_struct) + self.module_structs[str(parent_scope)] = parent_md_struct + else: + # 1B. not has parent, generate a new ModuleStruct + parent_md_struct = copy.deepcopy(md_struct) # use this submodule to create a parent module + # rewrite parent md struct + parent_md_struct.reset_as_parent() + parent_md_struct.add_submodule(md_struct) + self.module_structs[str(parent_scope)] = parent_md_struct + sub = self.module_structs.pop(scope_path_str) # remove this submodule from collections + self._global_context.add_module_struct(sub.pattern_id, sub) + depth -= 1 + + def _recursive_form_module(self): + """Main routine in generator to build modules from bottom to top.""" + # 1. List repeated submodules + repeated_submodules = self._list_repeated_submodules() + # 2. List reused parameters + formal_parameters = self._list_formal_parameters(repeated_submodules) + # 3. Build base subdmodules and set in/ext params translation + for module_struct in self.module_structs.values(): + if module_struct.pattern_id == -1: # is main module + continue + formal_args = formal_parameters.get(module_struct.pattern_id) + module_struct.update_args_translation_list(formal_args) + + # 4. Form parent modules + md_collection_len = len(self.module_structs.keys()) + len_changes = True + while len_changes: + self._add_submodule_to_parent() + new_len = len(self.module_structs.keys()) + if md_collection_len != new_len: + md_collection_len = new_len + else: + len_changes = False + + # 5. Update all translated args from module map + self._update_all_modules_args_translator() + + # 6. Update all nodes and moudles input/output + self.module_structs.get('[]').allocate_construct_header_x() + self.module_structs.get('[]').collect_returns() + + def _update_all_modules_args_translator(self): + """Update all modules' args translators.""" + done_submodule = set() + for depth in range(self._module_depth_max, 0, -1): + # check modules from bottom to top + repeated_submodules = copy.deepcopy(self._repeated_submodules) + repeated_modules = repeated_submodules.get(depth) + if depth is None: + continue + for pattern_id in repeated_modules: + if pattern_id in done_submodule: + continue + # get all md_structs by same pattern + md_list = self._global_context.module_structs.get(pattern_id) + self._take_formal_args_from_updated_submodules(md_list) + args_translators = self.get_args_translator_from_module_structs_list(md_list) + formal_args_list = ArgsTranslationHelper.find_formal_args_in_modules(args_translators) + changed_args_translators = self.get_args_translator_from_module_structs_list( + md_list, exclude_root_son=True) + ArgsTranslationHelper.change_args_to_formal_for_all_translators( + formal_args_list, changed_args_translators) + done_submodule.add(pattern_id) + + def _take_formal_args_from_updated_submodules(self, md_list): + """ + Take formal args from provided modules' nodes and submodules. + + Args: + md_list (list): A list of ModuleStruct. + """ + if isinstance(md_list, ModuleStruct): + md_list = [md_list] + + for md in md_list: + md.args_translator.take_formal_args_from_nodes_and_submodules(md.get_all_sub_translators()) + + def _update_module_depth_max(self, nd_struct: NodeStruct): + """ + Update the Generator attribute module_depth_max, which is the maximum depth in the Model. + + Args: + nd_struct (NodeStruct): NodeStruct to be checked its depth. + """ + depth = nd_struct.scope.depth + if isinstance(depth, int): + if depth > self._module_depth_max: + self._module_depth_max = depth + else: + raise TypeError("Unable to update global depth due to TypeError in NodeStruct.scope.depth") + + def add_node(self, node_identifier, node_instance=None, node_fragment=None, mapper_dict=None): + """ + Add Node information to the generator. + + Args: + node_identifier (str): The unique identifier for the node passed in. + node_instance (GraphNode): The GraphNode instance of each node. + node_fragment (NodeFragment): The NodeFragment instance of this node passed in. + mapper_dict (dict): The dict contains converted attributes from mapper. + """ + + if node_identifier is None: + raise ValueError("Node Identifier should not be None.") + self._global_context.node_fragments[node_identifier] = node_fragment + args = [] + if node_instance is not None: + args.append(node_instance) + if mapper_dict is not None: + args.append(mapper_dict) + if node_fragment is not None: + args.append(node_fragment) + + nd_struct = self.node_structs.get(node_identifier) + if nd_struct: # NodeStruct already exists + nd_struct.update(args) + else: # create new Node Struct + nd_struct = NodeStruct(args) + nd_struct.identifier = node_identifier + self._update_module_depth_max(nd_struct) + self.node_structs[node_identifier] = nd_struct + + @property + def node_structs(self): + """Return all NodeStructs in this model.""" + return self._node_struct_collections + + @property + def module_structs(self): + """Return all ModuleStructs in this model.""" + return self._module_struct_collections + + def generate(self): + """ + Generate the final script file. + + Returns: + list, a list of each line in script file. + """ + self._form_bottom_submodule() + self._recursive_form_module() + code = CodeStruct(self.module_structs.get('[]'), self._repeated_submodules) + return code.code_line_list + + def get_node_struct(self, node_identifier): + """ + Get specific NodeStruct by node_identifier. + + Args: + node_identifier (str): The node unique identifier. + + Return: + NodeStruct, the node's NodeStruct. + """ + return self._node_struct_collections.get(node_identifier, None) + + def get_module_struct(self, module_identifier): + """ + Get specific ModuleStruct by module_identifier. + + Args: + module_identifier (str): The module unique identifier. + + Return: + ModuleStruct, the node's ModuleStruct. + """ + return self._module_struct_collections.get(module_identifier, None) + + def get_module_structs_by_pattern_under_same_parent_pattern(self, pattern_id, under_parent_pattern_id): + """ + Return a list of ModuleStruct by conditions of pattern and their parent parent's pattern. + + Args: + pattern_id (int): The pattern id the returned ModuleSturct is. + under_parent_pattern_id (int): The pattern id the returned ModuleStruct's parent is. + + Returns: + list, a list of MoudleStructs has the same pattern_id and the same parents' pattern. + """ + if not pattern_id: + raise ValueError("pattern_id is necessary to get the module struct.") + if not under_parent_pattern_id: + raise ValueError("under_parent_pattern_id is necessary to get the module struct.") + ret = [] + md_list = self._global_context.module_structs.get(pattern_id) + for md in md_list: + if md.parent_id == under_parent_pattern_id: + ret.append(md) + return ret + + def get_args_translator_from_module_structs_list(self, md_list, exclude_root_son=False): + """ + Return a list of args translators which belongs to given module structs. + + Args: + md_list (list): A list of ModuleStruct. + exclude_root_son (Bool): If the returned result should include args translator belongs to + modules under the Main module. + + Returns: + list, a list of args translators which belongs to given module structs. + """ + ret = [] + for md in md_list: + if exclude_root_son and md.parent_id == -1: + continue + if md.args_translator is not None: + ret.append(md.args_translator) + + return ret + + def module_map_filter(self, depth=None, module_num=None, uid=None): + """ + Filter the module map by given conditions. + + Args: + depth (int): Scope depth. + module_num (int): The submodule number. + uid (int): The unique identifier of a submodule. + + Return: + list, list of NodeStruct list of each submodule. + """ + ret = list() + for scope_path, nd_struct_list in self._module_map.items(): + path = Scope.path_str_to_list(scope_path) + if not path: # skip main + continue + + # if depth not equals to the indicated depth, skip + if depth is not None and len(path) != depth: + continue + + scope_at_depth = path[-1] + (m_num, m_uid) = scope_at_depth + if uid is not None: + if m_num == module_num and m_uid == uid: + ret.append(nd_struct_list) + else: + if m_num == module_num: + ret.append(nd_struct_list) + return ret diff --git a/mindinsight/mindconverter/graph_based_converter/generator/module_struct.py b/mindinsight/mindconverter/graph_based_converter/generator/module_struct.py new file mode 100644 index 00000000..ca0a0a90 --- /dev/null +++ b/mindinsight/mindconverter/graph_based_converter/generator/module_struct.py @@ -0,0 +1,710 @@ +# Copyright 2020 Huawei Technologies Co., Ltd.All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Define a struct for module converted and save all required information here.""" + +from collections import OrderedDict + +from .node_struct import NodeStruct +from .scope_utils import Scope +from ..common.utils import get_dict_key_by_value +from .args_translator import ArgsTranslation +from ..common.code_fragment import ModuleFragment +from ..common.global_context import GlobalContext +from ..hierarchical_tree.name_mgr import LocalVarNameMgr + + +class ModuleStruct: + """ + Define a module struct which stores all info. to generate statement. + + Args: + args (list): A list of node structs. + """ + GLOBAL_CONTEXT_MGR = GlobalContext() + + def __init__(self, nd_struct_list): + """Init. a module by NodeStructs.""" + self.pattern_id = -1 # pattern num, -1 as Main module + self.pattern_uid = -1 # unique module id for this pattern + self.parent_id = None # parent's pattern num + self.parent_uid = None # parent's pattern module unique id + self.initialized = False + self.identifier = None + self.module_name = None + self.scope_depth = None + self.head_nd_struct = None + self.head_nd_struct_index = None + self.tail_nd_struct = None + self.tail_nd_struct_index = None + self._node_structs = list() + self._module_structs = list() + + self._fragment = None + self._args_translator = None + self._setting = None + self._parent_module_struct = None + # only store original formal args name, not global + self._nodes_structs_formal_args_list = list() + # only store translated (globalized) formal args + self._nodes_structs_formal_args_translated_list = list() + + # define other settings here + self._node_args_translation_list = list() + self._var_name_mgr = LocalVarNameMgr() + self.construct_header_x = OrderedDict() # key is header x, value is precursors onnx name + self.inputs_in_construct_header = OrderedDict() # key is precursors onnx name, value is x in parent construct + self.inputs_in_parent_module = OrderedDict() # key is prec_node_name, value is its closet opt_var_name + + # key is node's onnx name(output provider), value is (provider_succ_name, opt_var_name) + self.outputs_collection = dict() + self.matched_inputs = list() # Matched inputs will can be directly used by code line generation + + # key is ext. succ node onnx name, value is local opt_var + self.external_successor_local_returns_map = OrderedDict() + + # key is node's onnx_name, value is (successor_name, opt_var_name) <- node's level + self.outputs_collection = dict() + + # start initialization + if not self.initialized: + self._init_module(nd_struct_list) + else: + self._update_module(nd_struct_list) + + # assign this module reference to node + for (_, nd_struct) in nd_struct_list: + nd_struct.parent_module_struct = self + + def reset_as_parent(self): + """ + Reset all attributes and filled as a parent module of this module. + + Note: + This function must be called only after a deepcopy of this instance! + """ + self.identifier.pop() + self.scope_depth = self.scope_depth - 1 + self._set_pattern_id() + self._find_parent_module() + self.module_name = Scope.scope_to_module_name(self.identifier) + self._node_structs = list() + self._module_structs = list() + self._fragment = None + self._args_translator = None + self.init_args_translator() + self._setting = None + self._parent_module_struct = None + self._nodes_structs_formal_args_list = list() + + self._node_args_translation_list = list() + + def _set_pattern_id(self): + """Set pattern id which matches the module fragment pattern.""" + if not self.initialized: + return + if self.scope_depth < 1: + self.pattern_id = -1 + self.pattern_uid = -1 + return + self.pattern_id = self.identifier[-1][0] + self.pattern_uid = self.identifier[-1][1] + + def _init_module(self, nd_struct_list): + """Init this ModuleStruct by a list of Nodes.""" + (nd_topo_idx, nd_struct) = nd_struct_list[0] + self.identifier = nd_struct.scope.path + self.module_name = nd_struct.scope.to_str + self.scope_depth = nd_struct.scope.depth + self.head_nd_struct = nd_struct + self.head_nd_struct_index = nd_topo_idx + self.tail_nd_struct = nd_struct_list[-1][1] + self.tail_nd_struct_index = nd_struct_list[-1][0] + self._node_structs = nd_struct_list + self.initialized = True + self._set_pattern_id() + self._find_parent_module() + self.init_args_translator() + + def _update_module(self, nd_struct_list): + """Update the ModuleStruct attributes from a list of Nodes.""" + (nd_topo_idx_head, nd_struct_head) = nd_struct_list[0] + (nd_topo_idx_tail, nd_struct_tail) = nd_struct_list[-1] + if self.identifier != nd_struct_head.scope.path: + raise ValueError("Unable to update this module struct {} due to different identifier {}".format( + self.identifier, nd_struct_head.scope.path)) + if nd_topo_idx_head < self.head_nd_struct_index: + self.head_nd_struct_index = nd_topo_idx_head + self.head_nd_struct = nd_struct_head + if nd_topo_idx_tail > self.tail_nd_struct_index: + self.tail_nd_struct_index = nd_topo_idx_tail + self.tail_nd_struct = nd_struct_tail + self._node_structs += nd_struct_list + + def _find_parent_module(self): + """Set the parent's module pattern and uid.""" + if not self.initialized: + return + if self.scope_depth == 0: # is Main Module + pass + elif self.scope_depth == 1: # parent pattern is Main module + self.parent_id = -1 + self.parent_uid = -1 + else: # this is a submodule in a module + (self.parent_id, self.parent_uid) = Scope.get_parent_module_num_and_uid( + self.identifier) + + def __repr__(self): + return str({ + "address": hex(id(self)), + "identifier": self.identifier, + "parent": (self.parent_id, self.parent_uid), + "name": self.module_name, + "pattern": self.pattern_id, + "scope_depth": self.scope_depth, + "nd_idx_range": "{} -> {}".format(self.head_nd_struct_index, self.tail_nd_struct_index), + "initialized": self.initialized + }) + + def init_module_fragment(self): + """Init the module fragment.""" + if not self.initialized: + return + # check if fragment exists in global context + op = "Module{}".format(self.pattern_id) + if op == "Module-1": # reset as Main Model's op name + op = "Model" + frag = GlobalContext().get_module_fragment(op) + if frag is not None: # use exists fragment + self._fragment = frag + else: + frag = ModuleFragment(operation=op, + actual_args=None, + input_shape=None, + output_shape=None, + settings=None) + self._fragment = frag + # set fragment pattern + self._fragment.pattern = self._node_structs + GlobalContext().add_module_fragment(op, frag) + + def init_args_translator(self): + """Initialize the Args Translator for the module.""" + var_name = "Module{}_{}".format(self.pattern_id, self.pattern_uid) + self._args_translator = ArgsTranslation(None, var_name, None) + + def update_module_fragment(self): + """Update this module's fragment.""" + if self._fragment is None: + return + + # update input output shape + self._fragment.input_shape = self.head_nd_struct.fragment.input_shape + self._fragment.output_shape = self.tail_nd_struct.fragment.output_shape + + # update formal args + self._fragment.formal_args.update(self._args_translator.formal_args) + self._fragment.formal_args_value.update(self._args_translator.formal_args_values) + # update actual args + self._fragment.actual_args.update(self._args_translator.actual_args) + # update others.. + + def add_submodule(self, md_structs): + """ + Add another module struct(s) to this ModuleStruct. + + Args: + md_structs ([ModuleStruct, list]): a (list) ModuleStruct to be added in this ModuleStruct. + """ + tail_md = md_structs + if isinstance(md_structs, ModuleStruct): + md_structs.args_translator.take_formal_args_from_nodes_and_submodules(md_structs.get_all_sub_translators()) + self._module_structs.append(md_structs) + md_structs.parent_module_struct = self + elif isinstance(md_structs, list): + for md_s in md_structs: + md_s.args_translator.take_formal_args_from_nodes_and_submodules(md_s.get_all_sub_translators()) + md_s.parent_module_struct = self + self._module_structs += md_structs + tail_md = md_structs[-1] + else: + raise TypeError("ModuleStruct cannot add an unsupport Type {} to module_structs list.".format( + type(md_structs))) + # update tail node and index + if self.tail_nd_struct_index < tail_md.tail_nd_struct_index: + self.tail_nd_struct = tail_md.tail_nd_struct + self.tail_nd_struct_index = tail_md.tail_nd_struct_index + + def _update_formal_args_for_all_nd_structs(self): + """ + Init nodes' args translator and find formal args. + And collect nodes' formal args. + """ + if len(self._node_args_translation_list) != len(self._node_structs): + raise ValueError( + "ModuleStruct cannot update nodes' formal args due to length inconsistent.") + for idx, (_, nd_struct) in enumerate(self._node_structs): + formal_arg_of_this_node = self._node_args_translation_list[idx] + # update var_name to ensure all node names' are unique in a module. + nd_struct.update_var_name(idx) + nd_struct.init_args_translator(formal_arg_of_this_node) + if nd_struct.args_translator is not None: + self._nodes_structs_formal_args_list.append( + nd_struct.args_translator.formal_args_values) + else: + self._nodes_structs_formal_args_list.append(None) + + def update_args_translation_list(self, formal_args): + """ + Receive a list of args name to be changed to formal args, and change them. + + Args: + formal_args (list[str]): a list of args name to be changed to formal args. + """ + self._node_args_translation_list = formal_args + self._update_formal_args_for_all_nd_structs() + + def get_all_sub_translators(self): + """ + Return a list of args_translators of submodules / nodes affiliated to this module. + + Note: + The order of returned list is followed by the actual topological order. + + Returns: + list, a list of args_translators. + """ + ret = [] + for (_, struct) in self.get_generate_order(): + if struct.args_translator is not None: + ret.append(struct.args_translator) + return ret + + def get_generate_order(self): + """ + Return the order of generated code by index. + + Return: + list, a list of reference of node_struct or module_struct. + """ + ret = list() + if not self._module_structs: + return self._node_structs + # Generate a list of tuple (idx, module_structs) + for md_struct in self._module_structs: + ret.append((md_struct.head_nd_struct_index, md_struct)) + if self.node_structs: + ret += self.node_structs + ret.sort(key=lambda x: x[0]) + return ret + + def code_line_in_init(self): + """ + Initialization line of code in module init block. + + Args: + override_formal_val (dict): Indicate which args should be renamed for passing value from upper level. + """ + left = "self.{}".format(self.ms_var_name) + args_list = list() + # Load args in init statement. + if self._args_translator is not None: # from args_translator + if self._args_translator.actual_args: # load actual args + args_list += self._args_translator.actual_args_to_str_list + elif self._args_translator.actual_args_backup and self.parent_id == -1: + # For modules repeated in multiple levels, the module under main model should + # not use formal args as it is unnecessary -> load from actual args backup + args_list += self._args_translator.actual_args_backup_to_str_list + args_list += self._args_translator.formal_args_to_str_list # load from formal args + else: + args_list += self._fragment.actual_args + right = f"{self.class_name}({', '.join(args_list)})" + return (left, right) + + def code_line_in_construct(self, inputs=None): + """Construct line of code in module construct block.""" + # check number of outputs this module has + opt_var_name_in_module = list(self.external_successor_local_returns_map.values()) + num_output = len(set(opt_var_name_in_module)) + if num_output == 1: # single output + left = f"{self.ms_opt_var_name}" + else: + left = [f"{self.ms_opt_var_name}_{num}" for num in range(num_output)] + + if inputs is None and self.matched_inputs: + inputs = self.matched_inputs + + if isinstance(inputs, str): + inputs = [inputs] + right = f"self.{self.ms_var_name}({', '.join(inputs)})" + return (left, right) + + @property + def node_structs(self): + """Return all node structs in this module.""" + return self._node_structs + + @property + def module_structs(self): + """Return all module structs in this module.""" + return self._module_structs + + @property + def parent_module_struct(self): + """Return this module's parent module struct.""" + return self._parent_module_struct + + @parent_module_struct.setter + def parent_module_struct(self, ref): + """Set this modu;e's parent module struct.""" + self._parent_module_struct = ref + + @property + def args_translator(self): + """Return the args translator.""" + return self._args_translator + + @property + def head_nd_struct_precursor_nodes_names(self) -> list: + """Return head node's precursor nodes names.""" + return self.head_nd_struct.precursor_nodes_names + + @property + def head_nd_struct_precursor_nodes_structs(self) -> list: + """Return head node's precursor nodes structs.""" + return self.head_nd_struct.precursor_nodes_structs + + @property + def tail_nd_struct_successor_nodes_names(self) -> list: + """Return tail node's successor nodes names.""" + return self.tail_nd_struct.successor_nodes_names + + @property + def tail_nd_struct_successor_nodes_structs(self) -> list: + """Return tail node's successor nodes structs.""" + return self.tail_nd_struct.successor_nodes_structs + + @property + def onnx_names_from_nodes(self) -> list: + """Return all nodes onnx names in this module.""" + ret = [] + for (_, node) in self.node_structs: + ret.append(node.onnx_name) + return ret + + @property + def onnx_names_from_submodules(self) -> list: + """Return all nodes onnx names in submodules of this module.""" + ret = [] + for md_struct in self.module_structs: + ret += md_struct.onnx_names + return ret + + @property + def onnx_names(self) -> list: + """Return all nodes' onnx names which contained by this module.""" + return self.onnx_names_from_nodes + self.onnx_names_from_submodules + + @property + def external_precursor_nodes_names(self) -> list: + """Return all precursors nodes names not in this module.""" + ret = [] + for _, struct in self.get_generate_order(): + if isinstance(struct, NodeStruct): + precursor_nodes_names = struct.precursor_nodes_names + + if isinstance(struct, ModuleStruct): + precursor_nodes_names = struct.external_precursor_nodes_names + + for p_name in precursor_nodes_names: + if p_name in self.onnx_names: + continue + ret.append(p_name) + return ret + + @property + def external_successor_nodes_names(self) -> list: + """Return all precursors nodes names not in this module.""" + ret = [] + for _, struct in self.get_generate_order(): + if isinstance(struct, NodeStruct): + successor_nodes_names = struct.successor_nodes_names + + if isinstance(struct, ModuleStruct): + successor_nodes_names = struct.external_successor_nodes_names + + for s_name in successor_nodes_names: + if s_name in self.onnx_names: + continue + ret.append(s_name) + return ret + + @property + def class_name(self) -> str: + """Return the class name for generating code of this module.""" + if self.pattern_id == -1: + return "Model" + return "Module{}".format(self.pattern_id) + + @property + def ms_var_name(self) -> str: + """Return the variable name for generated code statement of this module.""" + if self.pattern_id == -1: + return "Model" + return "Module{}_{}".format(self.pattern_id, self.pattern_uid).lower() + + @property + def ms_opt_var_name(self) -> str: + """Return the variable name for generated code statement of the output of this module.""" + return "{}_opt".format(self.ms_var_name).lower() + + # The following part will be resetting nodes' external inputs for supporting multi-in/out + # and should be called after generator.recursive_form_modules() + + def set_inputs_in_construct_header(self, header_x, onnx_precursor_node_name): + """ + Mark the registered external inputs for code generation. + + Note: + This function to be called by its parent (ModuleStruct). + + Args: + header_x (str): The `x` in module construct header. + onnx_precursor_node_name (str): The original onnx node name. + """ + if self.inputs_in_construct_header.get(onnx_precursor_node_name) is not None: + raise ValueError("The input from {} has already registered. Check this Module \ + {} has duplicate inputs or not.".format(onnx_precursor_node_name, self.identifier)) + self.inputs_in_construct_header[onnx_precursor_node_name] = header_x + + def allocate_construct_header_x(self, force_x=None): + """ + Allocate the x in construct header for each external input. + + Args: + force_x (str): Force the arg name to customized. + """ + local_x_name = 'x' + if force_x: # name of x indicated by external + local_x_name = force_x + + # set construct_header_x for current module + allocated = set() + for prec_name in self.external_precursor_nodes_names: + if prec_name in allocated: + continue + x_name_in_construct_header = self._var_name_mgr.get_name(local_x_name) + self.construct_header_x[x_name_in_construct_header] = prec_name + allocated.add(prec_name) + + # Assign these inputs to nodes and submodules + for _, struct in self.get_generate_order(): + if isinstance(struct, NodeStruct): # register node's ext input + self.reset_node_external_input_to_local(struct) + self.register_node_output_to_module(struct) + if isinstance(struct, ModuleStruct): # reg module's ext input + if not struct.construct_header_x: + struct.allocate_construct_header_x() + self.reset_submodule_external_input_to_local(struct) + self.register_submodule_output_to_module(struct) + + # remove parent module's ext. map if ext nodes in this module (no need return) + for user_name in self.external_successor_local_returns_map.copy().keys(): + if user_name in self.onnx_names: + self.external_successor_local_returns_map.pop(user_name) + + def _match_node_inputs(self, struct): + """Match node's inputs with its precursor nodes.""" + for output_provider in struct.precursor_nodes_names: + output_list = self.outputs_collection.get(output_provider) + if output_list is None: + # not in this module, check construct header + for (self_x_name, self_output_provider) in self.construct_header_x.items(): + if self_output_provider == output_provider: + struct.matched_inputs.append(self_x_name) + continue + for output in output_list: + (provider_succ, provider_closet_opt_var) = output + if provider_closet_opt_var in struct.matched_inputs: + continue # skip repeat + if provider_succ == struct.onnx_name: + struct.matched_inputs.append(provider_closet_opt_var) + + def _match_sub_modules_inputs(self): + """ + Match current module's submodules' inputs with corresponding outputs registered in current module. + + Description: + The function matches these inputs by the following steps: + 1. For each submodule in the current module, take submodule's construct header + 2. Check submodule's construct header element requires an input from current module's + construct header or outputs from other submodules. + 3. If from current module's construct header, assign corresponding x to the submodule. + If from other submodules, assign required submodule output name to the submodule. + """ + if not self.outputs_collection: + return # skip first node + for (_, struct) in self.get_generate_order(): + if isinstance(struct, NodeStruct): + self._match_node_inputs(struct) + continue # skip node + sub_construct_header = struct.construct_header_x + for (_, output_provider) in sub_construct_header.items(): + # check from outputs collection + output_list = self.outputs_collection.get(output_provider) + if output_list is None: + # not in this module, need from current module construct header + for (self_x_name, self_output_provider) in self.construct_header_x.items(): + if self_output_provider == output_provider: + struct.matched_inputs.append(self_x_name) + continue + for output in output_list: + (provider_succ, provider_closet_opt_var) = output + if provider_closet_opt_var in struct.matched_inputs: + continue # skip repeat + if provider_succ in struct.onnx_names: + struct.matched_inputs.append(provider_closet_opt_var) + + def _append_to_outputs_collection(self, provider_name, val): + """ + Helper function to add a nodes or submodules outputs to current module return statement. + + Args: + provider_name (str): The onnx name of the output provider. + val (list[tuple]): A list of tuple which contains + the output provider's successor name and its opt_var_name. + """ + exist_output = self.outputs_collection.get(provider_name) + if isinstance(val, tuple): + val = [val] + if exist_output is None: # add new entry + exist_output = list() + exist_output += (val) + self.outputs_collection[provider_name] = exist_output + + def collect_returns(self): + """ + Collect all nodes and submodules' returns in the module. + + Note: + The logic is to collect the return from nodes and submodules by the order + of topological index. + + For returns from a node, it will check if the return will be used externally. + If external (external means the successor a.k.a the return user has different scope with the node), + add this return to current module's outputs_collection, where + key is this node's original onnx_name, and value is a list of + tuple(successor_name, this node's opt_var_name) + + For returns from a submodule, it will check if the submodule has already collected returns, + If not, do it and then continue the following procedures. + Now we will check each element in submodule's outputs_collection. Note that we DO NOT check submodule's + returns should be continued returning, but just return them. + All these returns from submodules will be changes their original nodes output (a.k.a outputs provider) + `opt_var_name` to submodules' `opt_var_name`. + + Finally, we match the outputs and inputs in the current module level. + """ + for (_, struct) in self.get_generate_order(): + if isinstance(struct, NodeStruct): + outputs_list = [] + # add these successor nodes name to collection for future use + for succ in struct.successor_nodes_names: + outputs_list.append((succ, struct.ms_opt_var_name)) + if outputs_list: + self._append_to_outputs_collection(struct.onnx_name, outputs_list) + if isinstance(struct, ModuleStruct): + # Remove unnecessary returns, succ are all inside current + if not struct.outputs_collection: + struct.collect_returns() + sub_outputs_collection = struct.outputs_collection + # check each returns in sub + for provider_name, outputs_list in sub_outputs_collection.items(): + for output in outputs_list: + (succ, _) = output # (succ, provider_opt_var_name) in output + new_output = (succ, struct.ms_opt_var_name) + self._append_to_outputs_collection(provider_name, new_output) + self._match_sub_modules_inputs() + + def get_returned_opt_var_name(self) -> list: + """Return a list of returned output var of this module.""" + idx = 0 + added_to_return = set() + ret = [] + for ext_successor_requested, opt_var_name_in_this_module in self.external_successor_local_returns_map.items(): + if ext_successor_requested in added_to_return: + continue + ret.append((ext_successor_requested, opt_var_name_in_this_module, idx)) + added_to_return.add(ext_successor_requested) + return ret + + def reset_node_external_input_to_local(self, nd_struct): + """ + Reset node's input to module's construct args + """ + for prec_node_name in nd_struct.precursor_nodes_names_external: + if prec_node_name in self.onnx_names: # prec node in current module's. + continue + if prec_node_name in self.construct_header_x.values(): + # prec node assigned to construct header to passed in. + local_x = get_dict_key_by_value(prec_node_name, self.construct_header_x) + nd_struct.set_inputs_in_construct_header(local_x, prec_node_name) + else: # Extra precursor nodes, raise error + raise ValueError("Found external inputs of the Node but the module does not have it.") + + def reset_submodule_external_input_to_local(self, md_struct): + """ + Reset submodule's external input to current module. + + Args: + md_struct (ModuleStruct): The submodule in the current module. + """ + # check submodule's input + for _, submodule_precursor in md_struct.construct_header_x.items(): + if submodule_precursor in self.onnx_names: # if internal, match with local nodes/submodules return + # but do nothing here + continue + else: # if external, match with current module construct header x + if submodule_precursor in self.construct_header_x.values(): + local_x = get_dict_key_by_value(submodule_precursor, self.construct_header_x) + md_struct.set_inputs_in_construct_header(local_x, submodule_precursor) + else: # Extra precursor nodes, raise error + raise ValueError("Found external inputs of the submodule but the module does not have it.") + + def register_node_output_to_module(self, nd_struct): + """Register nodes outputs to this module's return.""" + for succ_node_name in nd_struct.successor_nodes_names_external: + self.external_successor_local_returns_map[succ_node_name] = nd_struct.ms_opt_var_name + + def register_submodule_output_to_module(self, md_struct): + """Register submodule outputs to this module's return.""" + submodule_returns = md_struct.get_returned_opt_var_name() + submodule_opt_var_name = md_struct.ms_opt_var_name + for (submodule_ext_succ, opt_var_name_in_this_module, ith_output) in submodule_returns: + self.external_successor_local_returns_map[submodule_ext_succ] = (submodule_opt_var_name, ith_output) + # edit external succ 's inputs in parent module + ext_node = self.GLOBAL_CONTEXT_MGR.onnx_node_name_to_node_struct_map.get(submodule_ext_succ) + ext_node_parent = ext_node.parent_module_struct + while ext_node_parent != self.parent_module_struct: + ext_node_parent.inputs_in_parent_module[ext_node.onnx_name] = md_struct.ms_opt_var_name + ext_node_parent = ext_node_parent.parent_module_struct + + # need find the prec_name? + for ext_node_prec, opt_var_name in ext_node.inputs_in_parent_module.copy().items(): + if isinstance(opt_var_name, str): + if opt_var_name == opt_var_name_in_this_module: + ext_node.inputs_in_parent_module[ext_node_prec] = (self.ms_opt_var_name, ith_output) + if isinstance(opt_var_name, tuple): + if opt_var_name[0] == opt_var_name_in_this_module: + ext_node.inputs_in_parent_module[ext_node_prec] = (self.ms_opt_var_name, ith_output) diff --git a/mindinsight/mindconverter/graph_based_converter/generator/node_struct.py b/mindinsight/mindconverter/graph_based_converter/generator/node_struct.py new file mode 100644 index 00000000..9764fe2f --- /dev/null +++ b/mindinsight/mindconverter/graph_based_converter/generator/node_struct.py @@ -0,0 +1,423 @@ +# Copyright 2020 Huawei Technologies Co., Ltd.All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Define the NodeStruct which stores all info. of a node.""" +from collections import OrderedDict + +from .scope_utils import Scope +from .args_translator import ArgsTranslation +from ..common.code_fragment import CodeFragment +from ..third_party_graph.pytorch_graph_node import PyTorchGraphNode +from ..third_party_graph.onnx_graph_node import OnnxGraphNode +from ..common.global_context import GlobalContext +from ..constant import InputType + + +class NodeStruct: + """ + Define a node struct which stores all info. to generate statement. + + Args: + args (Union[PyTorchGraphNode, OnnxGraphNode, dict]): Node related obj. + + Note: + You can pass as many args as possible and the Node Struct will update + by arguments order. + """ + GLOBAL_CONTEXT_MGR = GlobalContext() + + def __init__(self, args): + # define attributes here + self._identifier = None + self._fragment = None + self._args_translator = None + self._parent_module_struct = None + self.topo_idx = None + self.node_type = None + self.onnx_name = None + self.onnx_op = None + self.graph_node_ref = None # Our defined GraphNode + self.scope_name = None + self.ms_var_name = None + self.ms_opt_var_name = None # ms_opt_var_name = self.ms_var_name(...) + self.ms_op = None + self.ready_to_generate = False + + self.ms_params = dict() # converted params from mapper + self.ms_settings = dict() + self.ms_weights = dict() + self.ms_inputs = OrderedDict() + + self.scope = None # Defined Scope class + self.inputs_in_construct_header = OrderedDict() # key is prec_node_name, value is x; For code line use + self.inputs_in_parent_module = OrderedDict() # key is prec_node_name, value is its closet opt_var_name + self.matched_inputs = list() # Matched inputs will can be directly used by code line generation + + # initialize funcs. + for arg in args: + self.update(arg) + + def __repr__(self): + return str({ + "address": hex(id(self)), + "idx": self.topo_idx, + "identifier": self.identifier + }) + + def ori_topo_idx(self): + """Get the original topological index in the onnx graph.""" + ori_name = self.identifier.replace('$', '').split('/')[-1].replace("::", '/') + self.onnx_name = ori_name + return self.GLOBAL_CONTEXT_MGR.onnx_node_name_to_topo_idx.get(ori_name) + + def update_var_name(self, idx=None): + """ + Update the var_name of each node. + + Args: + idx (int): The index of the node in this module. + """ + if idx is not None: + self.ms_var_name = self.ms_op.replace('nn.', '').replace('P.', '').lower() + '_' + str(idx) + elif self.topo_idx is not None: + self.ms_var_name = self.ms_op.replace('nn.', '').replace('P.', '').lower() + '_' + str(self.topo_idx) + else: + raise ValueError("Unable to update var name when topo_idx is None.") + self.ms_opt_var_name = self.ms_var_name + '_opt' + + def _update_basics_from_gn(self, gn): + """Update basic info from GraphNode.""" + self.graph_node_ref = gn + self.scope_name = gn.scope_name + + def _update_from_pytorch_gn(self, gn: PyTorchGraphNode): + """Update basic info from PyTorchGraphNode.""" + self.node_type = "PyTorchGraphNode" + self._update_basics_from_gn(gn) + + def _update_from_onnx_gn(self, gn: OnnxGraphNode): + """Update basic info from OnnxGraphNode.""" + self.node_type = "OnnxGraphNode" + self._update_basics_from_gn(gn) + + def _update_from_mapper(self, d): + """Update info from mapper.""" + if d.get('op_name'): + self.ms_op = d.get('op_name') + if d.get('params'): + self.ms_params = d.get('params') + if d.get('settings'): + self.ms_settings = d.get('settings') + if d.get('weights'): + self.ms_weights = d.get('weights') + + def _update_from_fragment(self, frag: CodeFragment): + """Update info from CodeFragment.""" + self._fragment = frag + if frag.operation: + self.ms_op = frag.operation + idx = self.GLOBAL_CONTEXT_MGR.latest_node_struct_count + self.update_var_name(idx=idx) + + def _set_scope_from_identifier(self): + """Set the Node scope from identifier.""" + parsed_scope = Scope.parse_scope_from_node_identifier(self.identifier) + self.scope = Scope(parsed_scope) + + def init_args_translator(self, translated_args: list): + """ + Initialize the ArgsTranslator for each Node. + + Args: + translated_args (list): The list of args should be translated to formal args. + """ + if not self._fragment: + raise ValueError("Initialize argument translator failed.") + if self._fragment.actual_args and translated_args: + self._args_translator = ArgsTranslation(self._fragment.actual_args, self.ms_var_name, translated_args) + + def check_if_generate_ready(self): + """Check if the NodeStruct is able to generate code.""" + # check essential params exists + if all([self.identifier, + self.node_type, + self.scope_name, + self.ms_var_name, + self.ms_opt_var_name, + self.ms_op]): + self.ready_to_generate = True + + def update(self, arg, force_ready=False): + """ + Pass Node info. to generator NodeStruct. + + Args: + arg (Union[PyTorchGraphNode, OnnxGraphNode, dict]): Node related obj. + force_ready (bool): Force this NodeStruct is ready to generate. + """ + if isinstance(arg, PyTorchGraphNode): + self._update_from_pytorch_gn(arg) + elif isinstance(arg, OnnxGraphNode): + self._update_from_onnx_gn(arg) + elif isinstance(arg, (dict, OrderedDict)): + self._update_from_mapper(arg) + elif isinstance(arg, CodeFragment): + self._update_from_fragment(arg) + else: + raise TypeError("NodeStruct received an unsupported initializing argument.") + + if force_ready: + self.ready_to_generate = True + else: + self.check_if_generate_ready() + + @property + def identifier(self): + """Return the identifier of the node.""" + return self._identifier + + @identifier.setter + def identifier(self, s): + """ + Set the Node identifier, and update the scope. + + Args: + s (str): The node identifier string. + """ + self._identifier = s + self._set_scope_from_identifier() + self.topo_idx = self.ori_topo_idx() + self.GLOBAL_CONTEXT_MGR.onnx_node_name_to_node_struct_map[self.onnx_name] = self + + @property + def fragment(self): + """Return the fragment of the node.""" + return self._fragment + + @fragment.setter + def fragment(self, frag): + """ + Set the Node fragment. + + Args: + s (NodeFragment): The node identifier string. + """ + self._fragment = frag + + @property + def graph_node(self): + """Return the GraphNode reference.""" + return self.graph_node_ref + + @graph_node.setter + def graph_node(self, graphnode): + """Set the GraphNode reference.""" + self.graph_node_ref = graphnode + + @property + def onnx_node(self): + """Return the original onnx node reference.""" + return self.GLOBAL_CONTEXT_MGR.onnx_nodes_collection.get(self.onnx_name) + + @property + def args_translator(self): + """Return the args translator of this Node.""" + return self._args_translator + + @property + def precursor_nodes_names(self) -> list: + """Return the names of precursor nodes.""" + return self.graph_node_ref.precursor_nodes + + @property + def precursor_nodes_structs(self) -> list: + """Return the node struct instances of precursor nodes.""" + ret = [] + precursor_nodes_names = self.precursor_nodes_names + for pre_node_name in precursor_nodes_names: + nd_struct = self.GLOBAL_CONTEXT_MGR.onnx_node_name_to_node_struct_map.get(pre_node_name) + ret.append(nd_struct) + return ret + + @property + def successor_nodes_names(self) -> list: + """Return the names of successor nodes.""" + return self.graph_node_ref.successor_nodes + + @property + def successor_nodes_structs(self) -> list: + """Return the node struct instances of successor nodes.""" + ret = [] + for pre_node_name in self.successor_nodes_names: + nd_struct = self.GLOBAL_CONTEXT_MGR.onnx_node_name_to_node_struct_map.get(pre_node_name) + ret.append(nd_struct) + return ret + + @property + def parent_module_struct(self): + """Return the parent struct of this node.""" + return self._parent_module_struct + + @parent_module_struct.setter + def parent_module_struct(self, ref): + self._parent_module_struct = ref + + # Code Generation funcs below + + def code_line_in_init(self): + """Initialization line of code in module init block.""" + left = "self.{}".format(self.ms_var_name) + args_list = list() + if self._args_translator is not None: + args_list += self._args_translator.actual_args_to_str_list + args_list += self._args_translator.formal_args_to_str_list + else: + actual_args_str = ArgsTranslation.dict_data_to_args_str_list(self._fragment.actual_args) + args_list += actual_args_str + right = f"{self.ms_op}({', '.join(args_list)})" + return left, right + + def _get_correct_in_module_returns(self, prec_node, in_module_return): + """ + Find the correct precursor node name in return statement of its parent module. + + Args: + prec_node (str): The onnx name of the precursor node given. + in_module_return (list[tuple]): The list of outputs which contains parent module identifier + and module opt_var_name. + + Return: + str, correct opt_var_name to be passed in current node. + """ + found_return = False + for ret in in_module_return: + (md_identifier, input_name_to_use) = ret + p_node_struct = self.GLOBAL_CONTEXT_MGR.onnx_node_name_to_node_struct_map.get(prec_node) + # recursive check the p node parent + parent = p_node_struct + while not found_return: + parent = parent.parent_module_struct + if parent is None: + break + if parent.identifier == md_identifier: + return input_name_to_use + return None + + def code_line_in_construct(self, inputs=None, in_module_returns=None): + """Construct line of code in module construct block. """ + left = self.ms_opt_var_name + if inputs is None: + inputs = [] + for idx, prec_node in enumerate(self.precursor_nodes_names): + if self.inputs_in_construct_header.get(prec_node): + inputs.append(self.inputs_in_construct_header.get(prec_node)) + elif self._check_target_node_internal(prec_node): + inputs.append(self.precursor_nodes_structs[idx].ms_opt_var_name) + elif self.inputs_in_parent_module.get(prec_node): + inputs.append(self.inputs_in_parent_module.get(prec_node)) + elif in_module_returns and in_module_returns.get(self.onnx_name) \ + and (not self._check_target_node_internal(prec_node)): + inputs.append(self._get_correct_in_module_returns(prec_node, in_module_returns.get(self.onnx_name))) + else: + inputs.append("unk_{}_{}".format(idx, prec_node)) + + if self.matched_inputs: + inputs = self.matched_inputs + + # Check original onnx node's input to ensure double inputs are not ignored + original_inputs = self.GLOBAL_CONTEXT_MGR.onnx_node_inputs.get(self.onnx_name) + new_inputs = [] + for idx, prec_node in enumerate(self.precursor_nodes_names): + occurence = original_inputs.count(prec_node) + for _ in range(occurence): + new_inputs.append(inputs[idx]) + inputs = new_inputs + + if isinstance(inputs, str): + inputs = [inputs] + + if self._fragment.code_setting and self._fragment.code_setting.op_ipt_type == InputType.LIST.value: + inputs = [str(tuple(inputs)).replace("\'", "")] + + if self._fragment.code_setting and self._fragment.code_setting.op_extra_input: + for _, val in self._fragment.code_setting.op_extra_input.items(): + inputs.append(str(val)) + + if self._fragment.code_setting and self._fragment.code_setting.op_extra_tensor: + inputs.append(f"self.{self.ms_var_name}_w") + right = f"self.{self.ms_var_name}({', '.join(inputs)})" + return left, right + + def add_extra_tensor(self): + """ Add extra tensor.""" + left = "self.{}_w".format(self.ms_var_name) + shape = self._fragment.code_setting.op_extra_tensor.shape + right = f"Tensor(np.random.uniform(0, 1, {shape}), mindspore.float32)" + return left, right + + # The following functions are specified for multiple in/out support. + # and should be called only after generator._recursive_form_modules() + + def set_inputs_in_construct_header(self, header_x, onnx_precursor_node_name): + """ + Mark the registered external inputs for code generation. + + Note: + This function to be called by its parent (ModuleStruct). + + Args: + header_x (str): The `x` in module construct header. + onnx_precursor_node_name (str): The original onnx node name. + """ + if self.inputs_in_construct_header.get(onnx_precursor_node_name) is not None: + raise ValueError("The input from {} has already registered. Check this node \ + {} has duplicate inputs or not.".format(onnx_precursor_node_name, self.identifier)) + self.inputs_in_construct_header[onnx_precursor_node_name] = header_x + + def _check_target_node_internal(self, name: str) -> bool: + """ + Check given node under the same scope. + + Args: + name (str): Can accept both node identifier or original onnx node name. + """ + target_nd_struct = self.GLOBAL_CONTEXT_MGR.node_struct_collections.get(name) \ + or self.GLOBAL_CONTEXT_MGR.onnx_node_name_to_node_struct_map.get(name) + if target_nd_struct is None and self.topo_idx == 0: # First node always has external input + return False + + if target_nd_struct is None: + raise ValueError("Unable to find the NodeStruct of given target node {}.".format(name)) + return target_nd_struct.scope.path == self.scope.path + + @property + def has_successor_node_external(self) -> bool: + """Check if any successor_node is in external module.""" + for name in self.successor_nodes_names: + if not self._check_target_node_internal(name): + return False + + return True + + @property + def precursor_nodes_names_external(self) -> list: + """Return a list of external precursor nodes names.""" + return [name for name in self.precursor_nodes_names \ + if not self._check_target_node_internal(name)] + + @property + def successor_nodes_names_external(self) -> list: + """Return a list of external successor nodes names.""" + return [name for name in self.successor_nodes_names \ + if not self._check_target_node_internal(name)] diff --git a/mindinsight/mindconverter/graph_based_converter/generator/scope_utils.py b/mindinsight/mindconverter/graph_based_converter/generator/scope_utils.py new file mode 100644 index 00000000..43b8c28b --- /dev/null +++ b/mindinsight/mindconverter/graph_based_converter/generator/scope_utils.py @@ -0,0 +1,157 @@ +# Copyright 2020 Huawei Technologies Co., Ltd.All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Define a scope class processing all operations related to scope and scope name.""" +import re + + +class Scope(): + """Define scope related operations.""" + + def __init__(self, scope_str): + scopes = scope_str.split('/') + self.module_path = list() + self.scope_list = scopes[:-1] + self.head = self.scope_list[0] + self.tail = self.scope_list[-1] + self.initialization() + + def initialization(self): + """Init scope class.""" + self._update_module_path_from_scope_list() + + def _update_module_path_from_scope_list(self): + """Update the module scope path from a list of scope.""" + self.module_path = list() + for scope in self.scope_list: + if scope == 'Model': + continue + + if 'Module' in scope: + regex = r"Module(?P\d+)_(?P\d+)" + match = re.match(regex, scope) + if match: + module_num = match.group('num') + uid = match.group('curr_level_unique_id') + self.module_path.append((int(module_num), int(uid))) + + @property + def path(self): + """Return module scope path.""" + return self.module_path + + def set_path(self, ind, path_tuple: tuple): + """ + Set the module scope path. + + Args: + ind (int): The index of the scope path to be set. + path_tuple ((int, int)): The tuple of the scope path. + """ + self.module_path[ind] = path_tuple + + @property + def to_str(self): + """Return the full module scope as the string format.""" + full_str_list = ["Model"] + for (num, uid) in self.module_path: + local = "Module{}_{}".format(num, uid) + full_str_list.append(local) + + return "/".join(full_str_list) + + @property + def depth(self): + """Return the depth of the scope path.""" + return len(self.path) + + @staticmethod + def scope_to_module_name(path): + """ + Helper function to convert any scope path string to the full module scope. + + Args: + path (str): path string like "[(5, 0), (3, 0)]" + + Returns: + str, the full module scope with format like "Model/Module5_0/Module3_0/" + """ + scope_str_list = ["Model"] + if isinstance(path, str): + path = Scope.path_str_to_list(path) + if isinstance(path, list): + for (num, uid) in path: + local = "Module{}_{}".format(num, uid) + scope_str_list.append(local) + + return "/".join(scope_str_list) + + @staticmethod + def parse_scope_from_node_identifier(node_identifier: str): + """ + Helper function to parse the scope string from node identifier. + + Args: + node_identifier (str): The string of the node identifier. + + Returns: + str, parsed scope string from node identifier. + """ + regex = r"(?PModel/.*)\$\S+\$" + match = re.match(regex, node_identifier) + if not match: + return None + return match.group('scope') + + @staticmethod + def path_str_to_list(scope_path_str: str): + """ + Helper function to convert the scope path string back to list. + + Args: + scope_path_str (str): The scope path string like "[(5, 0), (3, 0)]". + + Returns: + list, a list of the scope path like [(5, 0), (3, 0)]. + """ + ret = [] + tmp = scope_path_str.strip('[').strip(']') + regex = r"\((?P\d+), (?P\d+)\)" + s_all = re.findall(regex, tmp) + for (num, uid) in s_all: + ret.append((int(num), int(uid))) + + return ret + + @staticmethod + def get_parent_module_num_and_uid(path): + """ + Helper function to return its parent's scope tuple. + + Args: + path (Union[str, list]): Module scope path string. e.g. "[(5, 0), (3, 0)]" + + Returns: + tuple, parent's scope level. e.g. [(5, 0)] + """ + if isinstance(path, str): + path = Scope.path_str_to_list(path) + if isinstance(path, list): + if len(path) == 1: # modules under the main module, (-1, -1) means main module. + return (-1, -1) + if len(path) > 1: # modules under another non-main module. Return parent's scope. + parent = path[-2] + return parent + + return None diff --git a/mindinsight/mindconverter/graph_based_converter/hierarchical_tree/name_mgr.py b/mindinsight/mindconverter/graph_based_converter/hierarchical_tree/name_mgr.py index 2811f762..f43e7a0a 100644 --- a/mindinsight/mindconverter/graph_based_converter/hierarchical_tree/name_mgr.py +++ b/mindinsight/mindconverter/graph_based_converter/hierarchical_tree/name_mgr.py @@ -106,3 +106,51 @@ class GlobalVarNameMgr: global_var_namespace.add(new_name) return new_name + + +class LocalVarNameMgr: + """Local variable name mgr.""" + + def __init__(self): + self.local_op_namespace = dict() + self.local_var_namespace = set() + + @staticmethod + def _get_name(name): + """Deal with op name.""" + if "::" in name: + return name.split("::")[1] + return name + + def get_name(self, op_type): + """ + Get module/variable name. + + If the module already existed, then add a suffix to it. + + conv1 onnx::conv + + Args: + op_type (str): Operator type in onnx. + + Returns: + str, module name. + """ + + def _gen(t): + t = t.lower() + if t not in self.local_op_namespace: + self.local_op_namespace[t] = START_IDX + suffix = "" + else: + self.local_op_namespace[t] += 1 + suffix = f"{self.local_op_namespace[t] - 1}" + + return f"{self._get_name(t)}{suffix}" + + new_name = _gen(op_type) + while new_name in self.local_var_namespace: + new_name = _gen(op_type) + + self.local_var_namespace.add(new_name) + return new_name diff --git a/mindinsight/mindconverter/graph_based_converter/third_party_graph/onnx_graph.py b/mindinsight/mindconverter/graph_based_converter/third_party_graph/onnx_graph.py index 047556fa..3ccf088e 100644 --- a/mindinsight/mindconverter/graph_based_converter/third_party_graph/onnx_graph.py +++ b/mindinsight/mindconverter/graph_based_converter/third_party_graph/onnx_graph.py @@ -151,7 +151,7 @@ class OnnxGraph(Graph): input_shape (tuple): Input shape. """ input_node = InputNode(input_shape) - input_node_name = "{}InputNode" + input_node_name = self._raw_input_nodes.replace(":0", "") for node_name, node in self._nodes_collection.items(): if node_name in self._input_nodes: ipt_nd_name = input_node_name.format(input_node.scope_name) diff --git a/mindinsight/mindconverter/graph_based_converter/third_party_graph/onnx_utils.py b/mindinsight/mindconverter/graph_based_converter/third_party_graph/onnx_utils.py index 6ba67c84..f76a371a 100644 --- a/mindinsight/mindconverter/graph_based_converter/third_party_graph/onnx_utils.py +++ b/mindinsight/mindconverter/graph_based_converter/third_party_graph/onnx_utils.py @@ -23,6 +23,7 @@ import numpy as np from mindinsight.mindconverter.common.log import logger as log from ..common.utils import fetch_output_from_onnx_model +from ..common.global_context import GlobalContext from ..constant import ONNX_TYPE_INT, ONNX_TYPE_INTS, ONNX_TYPE_STRING, \ ONNX_TYPE_FLOATS, ONNX_TYPE_FLOAT, SCALAR_WITHOUT_SHAPE, DYNAMIC_SHAPE, UNKNOWN_DIM_VAL @@ -110,6 +111,7 @@ class OnnxTensor: self.to_nodes = [] def to_array(self): + """Convert the tensor value from binary to np array.""" onnx = import_module("onnx") # Convert binary data to np.array if not isinstance(self.raw_tensor, (np.ndarray, list, tuple, int, float)): @@ -264,7 +266,7 @@ class OnnxDataLoader: self.output_nodes = output_nodes if isinstance(output_nodes, list) else [output_nodes] # args for init self._is_infer_shape = infer_shape - + self._global_context = GlobalContext() # params parsed in init self.inferred_model = None @@ -375,12 +377,19 @@ class OnnxDataLoader: def _parse_nodes(self): """Parse each onnx nodes in the model.""" - for node in self.nodes: + nodes_topo_idx = [] + for idx, node in enumerate(self.nodes): n = OnnxNode(node) self._nodes_dict[n.name] = n + nodes_topo_idx.append((idx, n.name)) if len(node.output) > 1: raise ModelNotSupport(msg=f"{node.name} has multi-outputs which is not supported now.") self.output_name_to_node_name[node.output[0]] = node.name + self._global_context.onnx_node_name_to_topo_idx[n.name] = idx + node_inputs = [i.replace(":0", "") for i in node.input] + self._global_context.onnx_node_inputs[n.name] = node_inputs + self._global_context.onnx_nodes_collection = self._nodes_dict + self._global_context.onnx_nodes_topo_index = nodes_topo_idx def _parse_tensors(self): """Parse each onnx tensors in the model.""" @@ -388,6 +397,7 @@ class OnnxDataLoader: for tensor in tensors: t = OnnxTensor(tensor) self.tensors_dict[t.name] = t + self._global_context.onnx_tensors_collection = self.tensors_dict def _parse_node_output_shape(self): """