From 03b2867978e8a1c342cd048b0fe7ebf7ea150519 Mon Sep 17 00:00:00 2001 From: liangtianshu Date: Sat, 5 Dec 2020 15:56:44 +0800 Subject: [PATCH] MindConverter adds generator and its required dependencies. Add global context to exchange information among multiple classes to reduce passing arguments through multiple procedures. Add Node and Module Struct to store all information converted for MindSpore script generation Add Args translator and scope utils to help process scope information and operators' arguments Add generator to generate the MindSpore script from information stored in Node and Module struct. --- .../common/global_context.py | 47 +- .../graph_based_converter/common/utils.py | 15 + .../generator/__init__.py | 111 +++ .../generator/args_translator.py | 248 ++++++ .../generator/generator.py | 630 ++++++++++++++++ .../generator/module_struct.py | 710 ++++++++++++++++++ .../generator/node_struct.py | 423 +++++++++++ .../generator/scope_utils.py | 157 ++++ .../hierarchical_tree/name_mgr.py | 48 ++ .../third_party_graph/onnx_graph.py | 2 +- .../third_party_graph/onnx_utils.py | 14 +- 11 files changed, 2395 insertions(+), 10 deletions(-) create mode 100644 mindinsight/mindconverter/graph_based_converter/generator/__init__.py create mode 100644 mindinsight/mindconverter/graph_based_converter/generator/args_translator.py create mode 100644 mindinsight/mindconverter/graph_based_converter/generator/generator.py create mode 100644 mindinsight/mindconverter/graph_based_converter/generator/module_struct.py create mode 100644 mindinsight/mindconverter/graph_based_converter/generator/node_struct.py create mode 100644 mindinsight/mindconverter/graph_based_converter/generator/scope_utils.py diff --git a/mindinsight/mindconverter/graph_based_converter/common/global_context.py b/mindinsight/mindconverter/graph_based_converter/common/global_context.py index c0ce4c71..836c911c 100644 --- a/mindinsight/mindconverter/graph_based_converter/common/global_context.py +++ 
b/mindinsight/mindconverter/graph_based_converter/common/global_context.py @@ -40,8 +40,10 @@ class GlobalContext(metaclass=Singleton): # Define data stored from onnx_utils # Key as Onnx Name self._onnx_nodes_collection = OrderedDict() - # key is topo_idx, value is onnx_node_name. + # key is topo_idx, value is onnx_node_name self._onnx_nodes_topo_index = dict() + self.onnx_node_name_to_topo_idx = dict() + self.onnx_node_inputs = dict() self._onnx_tensors_collection = dict() # Define data stored from generator @@ -50,7 +52,7 @@ class GlobalContext(metaclass=Singleton): self.node_struct_adder_counter = 0 # Define onnx_utils <---> generator mapping self.node_struct_to_onnx_node_map = dict() - self.onnx_node_to_node_struct_map = dict() + self.onnx_node_name_to_node_struct_map = dict() # Define Module pattern to customize name mapping self.module_customized_name = dict() @@ -59,6 +61,8 @@ class GlobalContext(metaclass=Singleton): self.node_fragments = OrderedDict() self.module_fragments = OrderedDict() + # Define Known module mapping + self.known_module_name = dict() # Define Structs # key is pattern_id, value is [ModuleStructs] self.module_structs = dict() @@ -83,7 +87,7 @@ class GlobalContext(metaclass=Singleton): def get_identifier_from_onnx_node_name(self, node_name): """Return the node identifier by Onnx Node name.""" - identifier = self.onnx_node_to_node_struct_map.get(node_name) + identifier = self.onnx_node_name_to_node_struct_map.get(node_name) return identifier @property @@ -98,9 +102,7 @@ class GlobalContext(metaclass=Singleton): @onnx_nodes_collection.setter def onnx_nodes_collection(self, arg): - """ - Set the onnx nodes collection. - """ + """Set the onnx nodes collection.""" if isinstance(arg, OrderedDict): self._onnx_nodes_collection = arg # arg must be nodes_dict in OnnxDataLoader else: @@ -108,11 +110,18 @@ class GlobalContext(metaclass=Singleton): @property def onnx_nodes_topo_index(self) -> dict: - "Return the onnx nodes and topological index." 
+ """Return the onnx nodes and topological index.""" return self._onnx_nodes_topo_index @onnx_nodes_topo_index.setter def onnx_nodes_topo_index(self, index_list): + """ + Set the onnx nodes and topological index. + + Args: + index_list (list[tuple[int, str]]): a list of tuple contains the topological index and onnx node name. + + """ if not isinstance(index_list, list): raise TypeError("The argument index_list must be a list of tuple (index, onnx_node_name).") if not isinstance(index_list[0], tuple): @@ -122,10 +131,17 @@ class GlobalContext(metaclass=Singleton): @property def onnx_tensors_collection(self): + """Return the onnx tensors collection.""" return self.onnx_tensors_collection @onnx_tensors_collection.setter def onnx_tensors_collection(self, arg): + """ + Set the onnx tensors collection by OnnxDataLoader. + + Args: + arg (dict): The OnnxDataLoader generated tensors_dict. + """ if isinstance(arg, dict): self._onnx_tensors_collection = arg # arg must be tensors_dict in OnnxDataLoader else: @@ -133,6 +149,12 @@ class GlobalContext(metaclass=Singleton): @property def latest_node_struct_count(self): + """ + Return the latest node struct count. + + Note: + The counter will increase by 1 to tracking the number of nodes added. 
+ """ ret = self.node_struct_adder_counter self.node_struct_adder_counter += 1 return ret @@ -184,18 +206,29 @@ class GlobalContext(metaclass=Singleton): self.module_customized_name[pattern_id] = customized_name def get_node_fragment(self, identifier): + """Return the node fragment by identifier.""" return self.node_fragments.get(identifier) def add_code_fragment(self, identifier, frag): + """Add the node fragment by identifier.""" self.node_fragments[identifier] = frag def get_module_fragment(self, identifier): + """Return the module fragment by identifier.""" return self.module_fragments.get(identifier) def add_module_fragment(self, identifier, frag): + """Add the module fragment by identifier.""" self.module_fragments[identifier] = frag def add_module_struct(self, pattern_id, module_struct): + """ + Add module struct by its pattern_id. + + Args: + pattern_id (int): The pattern which represents the structure of the module. + module_struct (ModuleStruct): The ModuleStruct instance. + """ if self.module_structs.get(pattern_id) is None: self.module_structs[pattern_id] = [module_struct] else: diff --git a/mindinsight/mindconverter/graph_based_converter/common/utils.py b/mindinsight/mindconverter/graph_based_converter/common/utils.py index b946e832..fd632ccf 100644 --- a/mindinsight/mindconverter/graph_based_converter/common/utils.py +++ b/mindinsight/mindconverter/graph_based_converter/common/utils.py @@ -135,3 +135,18 @@ def lib_version_satisfied(current_ver: str, mini_ver_limited: str, if current_ver < mini_ver_limited or (newest_ver_limited and current_ver > newest_ver_limited): return False return True +def get_dict_key_by_value(val, dic): + """ + Return the first appeared key of a dictionay by given value. + + Args: + val (Any): Value of the key. + dic (dict): Dictionary to be checked. + + Returns: + Any, key of the given value. 
+ """ + for d_key, d_val in dic.items(): + if d_val == val: + return d_key + return None diff --git a/mindinsight/mindconverter/graph_based_converter/generator/__init__.py b/mindinsight/mindconverter/graph_based_converter/generator/__init__.py new file mode 100644 index 00000000..483a93b1 --- /dev/null +++ b/mindinsight/mindconverter/graph_based_converter/generator/__init__.py @@ -0,0 +1,111 @@ +# Copyright 2020 Huawei Technologies Co., Ltd.All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Generator module.""" +__all__ = ["batch_add_nodes"] + +import re +import copy + +from .generator import Generator, CodeStruct +from ..common.code_fragment import CodeFragment + + +def _tf_model_node_name_reformat(node, node_name): + """ + Rename the node name by combining scope name and its original name. + + Args: + node (OnnxGraphNode): OnnxGraphNode instance. + node_name (str): node name saved in Graph. + + Returns: + str, re-formatted node name. + """ + scope_name = node.scope_name + new_name = None + regex = r"(?P.+/)(?P\w+)" + match = re.match(regex, scope_name) + parent = match.group("parent") + node_name = '$' + node_name.replace('/', '::') + '$' + + if scope_name: + new_name = parent + node_name + return new_name + return node_name + + +def batch_add_nodes(graph_obj, mapper) -> Generator: + """ + Add nodes to Generator in batch mode. 
+ + Args: + graph_obj (Graph): Graph obj. + mapper (Mapper): Mapper of third party framework and MindSpore. + + """ + generator_inst = Generator() + for node_name in graph_obj.nodes_in_topological_order: + node_inst = graph_obj.get_node(node_name) + node_input = graph_obj.get_input_shape(node_name) + node_output = graph_obj.get_output_shape(node_name) + if not node_input: + raise ValueError("Unable to get the node's inputs from Graph object.") + node_name_with_scope = _tf_model_node_name_reformat(node_inst, node_name) + node_name = node_name_with_scope + + node_inst.add_input_and_output_shape(node_input, node_output) + op_name, params, settings, weights = _convert_params(node_inst, mapper) + generator_inst.add_node( + node_name, + node_instance=node_inst, + node_fragment=CodeFragment(op_name, params, + settings, + node_inst.input_shape, + node_inst.output_shape, + weights) + ) + return generator_inst + + +def _convert_params(node, mapper): + """ + Call mapper to convert node's params from ONNX to MindSpore. + + Args: + node (GraphNode): Our defined GraphNode instance. + mapper (Mapper): The mapper instance which indicating conversion method. 
+ + Returns: + str, op name in MindSpore + dict, MindSpore parameters + dict, MindSpore settings + dict, weights of the node + """ + params = copy.deepcopy(node.node_params) + params.update({"input_shape": node.input_shape, + "output_shape": node.output_shape}) + + op_in_ms, ms_params, ms_settings, weights = mapper.convert(op_name=node.op_name, + params=params, + weights=node.weight) + if "input_shape" in ms_params: + ms_params.pop("input_shape") + if "output_shape" in ms_params: + ms_params.pop("output_shape") + + if op_in_ms: + return op_in_ms, ms_params, ms_settings, weights + + return node.op_name, node.node_params, dict(), dict() diff --git a/mindinsight/mindconverter/graph_based_converter/generator/args_translator.py b/mindinsight/mindconverter/graph_based_converter/generator/args_translator.py new file mode 100644 index 00000000..3219527b --- /dev/null +++ b/mindinsight/mindconverter/graph_based_converter/generator/args_translator.py @@ -0,0 +1,248 @@ +# Copyright 2020 Huawei Technologies Co., Ltd.All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Define arguments translation related operations for params name changing.""" +import copy + + +class ArgsTranslation: + """Define a universal arguments translation manager.""" + + def __init__(self, original_actual_args: dict, var_name: str, translated_args: list): + """ + Init the ArgsTranslation. 
class ArgsTranslation:
    """Define a universal arguments translation manager."""

    def __init__(self, original_actual_args: dict, var_name: str, translated_args: list):
        """
        Init the ArgsTranslation.

        Args:
            original_actual_args (dict): The full original args from fragments.
            var_name (str): The var name for current Node / Module.
            translated_args (list): The list of args need to translate to formal args.

        Raises:
            ValueError: If var_name is empty or None.
        """
        if not var_name:
            raise ValueError("Initialize ArgsTranslation requires the var_name.")

        self.var_name = var_name
        self.actual_args = dict()  # e.g. key is 'num_features', value is 2048
        self.formal_args = dict()  # e.g. key is 'num_features', value is 'var_name_num_features'
        self.formal_args_values = dict()  # e.g. key 'var_name_num_features', value 2048. Value used for up-level
        self.actual_args_backup = dict()  # backup actual args before translation

        self.actual_args_to_str_list = list()
        self.formal_args_to_str_list = list()
        self.formal_args_values_to_str_list = list()
        self.actual_args_backup_to_str_list = list()

        # NOTE(review): when translated_args is empty, original_actual_args is
        # dropped entirely (actual_args stays empty) -- confirm this is intended.
        if all([original_actual_args, translated_args]):
            # MUST ensure only one var_name in a scope.
            for arg_name, arg_value in original_actual_args.items():
                if arg_name in translated_args:
                    formal_arg_name = '_'.join([var_name, arg_name])
                    self.formal_args[arg_name] = formal_arg_name
                    self.formal_args_values[formal_arg_name] = arg_value
                else:
                    self.actual_args[arg_name] = arg_value

            self.make_str()

    @staticmethod
    def dict_data_to_args_str_list(any_dict):
        """
        Output a list of strings of dict data in "key=value" format.

        Args:
            any_dict (dict): Any dictionary.

        Returns:
            list, the list of strings showing the dictionary in "key=value" format.
        """
        ret = []
        for key, val in any_dict.items():
            ret.append('='.join([key, str(val)]))
        return ret

    def make_str(self):
        """Regenerate the cached "key=value" string lists used in code generation."""
        self.actual_args_to_str_list = list()
        self.formal_args_to_str_list = list()
        self.formal_args_values_to_str_list = list()
        self.actual_args_backup_to_str_list = list()

        if self.actual_args:
            self.actual_args_to_str_list = ArgsTranslation.dict_data_to_args_str_list(self.actual_args)

        if self.formal_args:
            self.formal_args_to_str_list = ArgsTranslation.dict_data_to_args_str_list(self.formal_args)

        if self.formal_args_values:
            self.formal_args_values_to_str_list = ArgsTranslation.dict_data_to_args_str_list(self.formal_args_values)

        if self.actual_args_backup:
            self.actual_args_backup_to_str_list = ArgsTranslation.dict_data_to_args_str_list(self.actual_args_backup)

    def __repr__(self):
        return str({
            "address": hex(id(self)),
            "var_name": self.var_name,
            "actual_args": self.actual_args,
            "actual_bak": self.actual_args_backup,
            "formal_args": self.formal_args,
            "formal_val ": self.formal_args_values
        })

    def set_actual_args_backup(self):
        """Backup the actual args before translating to formal."""
        self.actual_args_backup = copy.deepcopy(self.actual_args)

    def deepcopy(self):
        """Return a deepcopy of self."""
        return copy.deepcopy(self)

    def make_actual_arg_to_formal(self, actual_arg_name):
        """
        Make the actual arg to a formal arg.

        Args:
            actual_arg_name (str): The name of the actual arg to be formal.

        Raises:
            ValueError: If the arg is not present in actual_args.
        """
        val = self.actual_args.get(actual_arg_name)
        # NOTE(review): a legitimately-None actual value is indistinguishable
        # from a missing key here -- confirm None is never a real arg value.
        if val is None:
            raise ValueError("Unable to convert the actual arg to formal due to missing arg.")
        formal_arg_name = '_'.join([self.var_name, actual_arg_name])
        self.actual_args.pop(actual_arg_name)
        self.formal_args[actual_arg_name] = formal_arg_name
        self.formal_args_values[formal_arg_name] = val
        self.make_str()

    def _update_dict_for_upper_level(self, d, upper_level_var_name):
        """Return a copy of d whose keys are prefixed with the upper level var name."""
        new_d = dict()
        for arg_name, val in d.items():
            new_arg_name = '_'.join([upper_level_var_name, arg_name])
            new_d[new_arg_name] = val
        return new_d

    def escalate_to_upper_level(self, upper_level_var_name):
        """
        Escalate this args translator for upper level module use.

        Note:
            You MUST deepcopy this translator first to avoid editing values
            in the original translator.

        Args:
            upper_level_var_name (str): The parent module's var name to prefix with.
        """
        # update all args name by adding upper_level_var_name.
        tmp_actual_args = self._update_dict_for_upper_level(self.actual_args, upper_level_var_name)
        tmp_formal_args = self._update_dict_for_upper_level(self.formal_args, upper_level_var_name)
        tmp_formal_args_values = self._update_dict_for_upper_level(self.formal_args_values, upper_level_var_name)

        self.actual_args = tmp_actual_args
        self.formal_args = tmp_formal_args
        self.formal_args_values = tmp_formal_args_values

        self.make_str()

    def make_formal_args_back_to_actual(self, formal_arg):
        """
        Move the formal arg back to actual arg.

        Note:
            This does not reset the formal arg name back.
            Only used for module init statement.

        Args:
            formal_arg (Union[str, list]): Formal argument name(s).
        """
        if isinstance(formal_arg, str):
            val = self.formal_args_values.pop(formal_arg)
            self.actual_args[formal_arg] = val
        if isinstance(formal_arg, list):
            for arg in formal_arg:
                val = self.formal_args_values.pop(arg)
                # Fix: key each restored value by its own name; the original code
                # used the whole list as the dict key, which raises TypeError
                # (unhashable) on the first iteration.
                self.actual_args[arg] = val

        self.make_str()

    def take_formal_args_from_args_translator(self, args_translator, escalate_sub=False):
        """
        Add a submodule's or node's args translator to this translator.

        Args:
            args_translator (ArgsTranslation): Submodule's or node's args translator.
            escalate_sub (bool): Whether to escalate (prefix) the sub translator's
                args with this translator's var name first. Default: False.
        """
        if escalate_sub:
            # Work on a copy so the sub translator is left untouched.
            sub_args_translator = args_translator.deepcopy()
            sub_args_translator.escalate_to_upper_level(self.var_name)
        else:
            sub_args_translator = args_translator

        original_actual_args = sub_args_translator.formal_args_values
        self.actual_args.update(original_actual_args)
        self.make_str()

    def take_formal_args_from_nodes_and_submodules(self, args_translators: list, escalate_sub=False):
        """
        Take all formal args from nodes and submodules from passed-in args_translators.

        Args:
            args_translators (list[ArgsTranslation]): A list of ArgsTranslation instances.
            escalate_sub (bool): Whether to escalate all formal args. Default: False.
        """
        for arg_t in args_translators:
            self.take_formal_args_from_args_translator(arg_t, escalate_sub=escalate_sub)
class ArgsTranslationHelper:
    """Define operations related to ArgsTranslation instances."""

    @staticmethod
    def find_formal_args_in_modules(args_translators):
        """
        Find formal args among multiple args translators.

        Args:
            args_translators (list[ArgsTranslation]): A list of args translators to be checked.

        Returns:
            list, names of args to be formal, or None when fewer than two
            translators are supplied (nothing to compare against).

        Raises:
            ValueError: If a translator lacks an arg present in the first one.
        """
        if len(args_translators) < 2:
            # With a single translator there is nothing to diff against.
            return None

        base, *others = args_translators
        differing_names = []
        for arg_name, reference_val in base.actual_args.items():
            for candidate in others:
                candidate_val = candidate.actual_args.get(arg_name)
                if candidate_val is None:
                    raise ValueError("Unable to find the given args as the args translator is not consistent.")
                if candidate_val != reference_val:
                    # Any mismatch promotes this arg to formal; stop scanning it.
                    differing_names.append(arg_name)
                    break
        return differing_names

    @staticmethod
    def change_args_to_formal_for_all_translators(args_name, args_translators):
        """
        Change args to formal for all translators provided.

        Args:
            args_name (Union[str, list]): The name(s) of args to be changed.
            args_translators (Union[ArgsTranslation, list]): The translator(s) to update.
        """
        names = [args_name] if isinstance(args_name, str) else args_name
        translators = ([args_translators]
                       if isinstance(args_translators, ArgsTranslation)
                       else args_translators)

        for name in names:
            for translator in translators:
                # Snapshot before mutating so the pre-translation state survives.
                translator.set_actual_args_backup()
                translator.make_actual_arg_to_formal(name)
class Singleton(type):
    """Metaclass to make the generator to be single instance."""
    # Cache of the one instance per class using this metaclass.
    _instances = {}

    def __call__(cls, *args, **kwargs):
        # Construct at most one object per class; later calls return the cached one.
        if cls not in cls._instances:
            cls._instances[cls] = super(Singleton, cls).__call__(*args, **kwargs)
        return cls._instances[cls]


class CodeStruct:
    """
    Define the Code template for each module generated in the final output.
    Each module has only one CodeStruct to its pattern.
    """
    GLOBAL_CONTEXT = GlobalContext()
    NOT_IN_SCOPE_OPT = dict()

    def __init__(self, struct, repeated_submodules=None):
        """
        Initialize the CodeStruct.

        Args:
            struct (Union[NodeStruct, ModuleStruct]): The struct to generate code for.
                Only a ModuleStruct triggers actual code generation here.
            repeated_submodules (dict): Submodules used repeatedly, keyed by pattern.
                Obtained from the generator. Default: None.
        """
        self.output_order = None  # output order
        self.input = None  # opt_var_name for prev. node
        self.extra_input = list()  # extra_input(s) at construct method args
        self.output = None  # opt_var_name for next node
        self.extra_output = list()  # extra_output(s)
        self.extra_comment = None  # comments for this code line / block.
        self.code_line_list = list()  # list of code lines; one item is one line.
        self._global_var_mgr = GlobalVarNameMgr()  # var name procs within same module

        self.formal_args_collections = None

        if isinstance(struct, NodeStruct):
            self.output_order = struct.topo_idx
        if isinstance(struct, ModuleStruct):
            self.output_order = struct.head_nd_struct_index
            self._generate_from_module_struct(struct, repeated_submodules)

    def _add_line(self, s):
        """Add a line of code."""
        self.code_line_list.append(s)

    @property
    def new_line(self):
        """Return the last generated line (empty string if none yet)."""
        try:
            return self.code_line_list[-1]
        except IndexError:
            return ""

    @new_line.setter
    def new_line(self, s):
        """Append a new line of code."""
        self._add_line(s)

    def _generate_from_module_struct(self, md_struct, repeated_submodules):
        """
        Generate the code of the current ModuleStruct, collecting data from submodules.

        Args:
            md_struct (ModuleStruct): The ModuleStruct which generates codes.
            repeated_submodules (dict): The dict contains all submodules which are used
                repeatedly. Can get this dict from generator.
        """
        # Temporary bookkeeping for code generation:
        opt_var_name_records = dict()  # NOTE: currently only supports multiple outputs within the same scope.
        return_value_records = dict()  # saved return values for successor nodes/modules use.
        # Module header code line below: pattern_id -1 marks the top-level network.
        if md_struct.pattern_id != -1:
            class_name = f"Module{md_struct.pattern_id}"
        else:
            class_name = "Model"
        # define a class declaration
        self.new_line = f"class {class_name}(nn.Cell):"

        # Get all formal args from nodes; they become __init__ parameters.
        module_def_args = ['self']
        if md_struct.args_translator.actual_args:
            for actual in md_struct.args_translator.actual_args.keys():
                module_def_args.append(actual)
        if md_struct.args_translator.formal_args:
            for formal in md_struct.args_translator.formal_args.keys():
                module_def_args.append(formal)

        # Collect extra inputs and outputs

        # Build code lines for the init & construct blocks, walking children in order.
        init_lines = list()
        cons_lines = list()
        for (_, struct) in md_struct.get_generate_order():
            if isinstance(struct, NodeStruct):  # Generate code line for Node.
                code_line_init = struct.code_line_in_init()
                code_line_construct = struct.code_line_in_construct(in_module_returns=return_value_records)
                init_lines.append(f"{SECOND_LEVEL_INDENT}{' = '.join(code_line_init)}")
                cons_lines.append(f"{SECOND_LEVEL_INDENT}{' = '.join(code_line_construct)}")
                # add extra tensor declared by the node's code setting
                if struct.fragment.code_setting and struct.fragment.code_setting.op_extra_tensor:
                    code_extra_tensor = struct.add_extra_tensor()
                    init_lines.append(f"{SECOND_LEVEL_INDENT}{' = '.join(code_extra_tensor)}")

                # record opt_var_name for successor nodes' input in the same scope.
                target_onnx_name = struct.graph_node_ref.successor_nodes
                for name in target_onnx_name:
                    if opt_var_name_records.get(name):
                        opt_var_name_records.get(name).append(code_line_construct[0])
                    else:
                        opt_var_name_records[name] = [code_line_construct[0]]

                # record returns consumed outside this module.
                if struct.successor_nodes_names_external:
                    for ret_user in struct.successor_nodes_names_external:
                        if return_value_records.get(ret_user) is not None:
                            return_value_records[ret_user].append((struct.onnx_name, code_line_construct[0]))
                        else:
                            return_value_records[ret_user] = [(struct.onnx_name, code_line_construct[0])]

            elif isinstance(struct, ModuleStruct):
                # Ensure the submodule's pattern already has a CodeStruct generated
                # (constructing one registers it in the global context).
                if self.GLOBAL_CONTEXT.code_structs.get(struct.pattern_id) is None:
                    CodeStruct(struct, repeated_submodules)

                code_line_init = struct.code_line_in_init()
                code_line_construct = struct.code_line_in_construct(inputs=struct.matched_inputs)
                init_lines.append(f"{SECOND_LEVEL_INDENT}{' = '.join(code_line_init)}")
                cons_lines.append(f"{SECOND_LEVEL_INDENT}{' = '.join(code_line_construct)}")

                # record opt_var_name for successor nodes' input in the same scope.
                target_onnx_name = struct.tail_nd_struct.graph_node_ref.successor_nodes
                for name in target_onnx_name:
                    if opt_var_name_records.get(name):
                        opt_var_name_records.get(name).append(code_line_construct[0])
                    else:
                        opt_var_name_records[name] = [code_line_construct[0]]

                # record submodule's local return map for following nodes / submodules use
                if struct.external_successor_local_returns_map:
                    for ret_user, _ in struct.external_successor_local_returns_map.items():
                        if return_value_records.get(ret_user) is not None:
                            # multiple returns of a node may require adjusting the index.
                            return_value_records[ret_user].append((struct.identifier, code_line_construct[0]))
                        else:
                            return_value_records[ret_user] = [(struct.identifier, code_line_construct[0])]
            else:
                raise TypeError("Unable to generate code from args are not ModuleStruct or NodeStruct.")

        # define header of init block
        self.new_line = f"{FIRST_LEVEL_INDENT}def __init__({', '.join(module_def_args)}):"
        self.new_line = f"{SECOND_LEVEL_INDENT}super({class_name}, self).__init__()"
        # add init code lines to code line list.
        self.code_line_list += init_lines
        self.new_line = f"{NEW_LINE * 2}"

        # define header of construct block
        inputs = ['self'] + list(md_struct.construct_header_x.keys())
        self.new_line = f"{FIRST_LEVEL_INDENT}def construct({', '.join(inputs)}):"
        # add construct code lines to code line list.
        self.code_line_list += cons_lines
        # define returns
        returns = []
        if md_struct.external_successor_local_returns_map:
            ret = list(md_struct.external_successor_local_returns_map.values())
            for r in ret:
                if isinstance(r, tuple):  # result returns with index of the nth output
                    returns.append(r[0])
                else:
                    returns.append(r)
            returns = list(set(returns))
        else:
            # NOTE(review): falls back to the last child's opt_var_name; this
            # reads the loop variable after the loop, so a module with no
            # children would raise NameError here -- confirm that cannot happen.
            returns = [code_line_construct[0]]
        self.new_line = f"{SECOND_LEVEL_INDENT}return {', '.join(returns)}"
        self.new_line = f"{NEW_LINE * 2}"
        # Register this generated CodeStruct so the pattern is not generated twice.
        self.GLOBAL_CONTEXT.code_structs[md_struct.pattern_id] = self
during conversion + self._module_map = OrderedDict() + self._global_context = GlobalContext() + self._global_context.node_struct_collections = self._node_struct_collections + self._repeated_submodules = set() + + def _form_bottom_submodule(self): + """Form the basic submodules, which only contains nodes.""" + # Form module map + curr_scope_path = None + nd_struct_list_in_submodule = [] + for nd_struct in self.node_structs.values(): + idx = nd_struct.topo_idx + if curr_scope_path is None: + curr_scope_path = nd_struct.scope.path + nd_struct_list_in_submodule.append((idx, nd_struct)) + elif curr_scope_path == nd_struct.scope.path: + nd_struct_list_in_submodule.append((idx, nd_struct)) + else: # curr_scope_path changed + # save this submodule + if self._module_map.get(str(curr_scope_path)) is not None: + self._module_map[str(curr_scope_path)] += nd_struct_list_in_submodule + else: + self._module_map[str(curr_scope_path)] = nd_struct_list_in_submodule + + # create a new one + curr_scope_path = nd_struct.scope.path + nd_struct_list_in_submodule = [(idx, nd_struct)] + + # save last submodule + if self._module_map.get(str(curr_scope_path)) is not None: + self._module_map[str(curr_scope_path)] += nd_struct_list_in_submodule + else: + self._module_map[str(curr_scope_path)] = nd_struct_list_in_submodule + + # Form bottom modules' ModuleStruct + for scope_path_str, nd_struct_list in self._module_map.items(): + self._module_struct_collections[scope_path_str] = ModuleStruct(nd_struct_list) + + def _list_repeated_submodules(self) -> OrderedDict: + """ + Return the repeated submodules by its depth and num. + For example, "Model/Module3_3" will return {1:(3)} + + Return: + OrderedDict, a dict contains collections of repeated submodules. 
+ """ + ret = OrderedDict() + for depth_control in range(self._module_depth_max, 0, -1): + repeated_submodules_at_this_depth = set() + for scope_path in self._module_map.keys(): + path = Scope.path_str_to_list(scope_path) + if len(path) < depth_control: + continue + else: # depth control within path length + module_num = path[depth_control - 1][0] + repeated_submodules_at_this_depth.add(module_num) + ret[depth_control] = repeated_submodules_at_this_depth + + self._repeated_submodules = ret + return ret + + def _compare_with_base_parameters(self, nd_struct_list): + """ + Compare the parameter to check if it should be a formal args. + + Args: + nd_struct_list (list): A list of NodeStructs which contains + all same nodes in repeated submodules. + + Return: + set, a set of all formal args in this node. + """ + + formal_args = set() + if len(nd_struct_list) < 2: + return formal_args + (_, base_nd_struct) = nd_struct_list[0] + for (base_parameter, base_value) in base_nd_struct.fragment.actual_args.items(): # for each param + for (_, nd_struct) in nd_struct_list[1:]: + compared_value = nd_struct.fragment.actual_args.get(base_parameter) + if compared_value == base_value: + continue + else: + formal_args.add(base_parameter) + break + + return formal_args + + def _list_formal_parameters_in_a_module(self, module_filter_return): + """ + Find all formal args / params from nodes in a module. + + Args: + module_filter_return (dict): The filtered results from the module_map_filter. + + Return: + list, a list of sets or None indicates all formal args of each node in the module in order. 
+ """ + formal_params_list = list() + transposed = [list(e) for e in zip(*module_filter_return)] + for operation in transposed: + formal_parameters = self._compare_with_base_parameters(operation) + if formal_parameters: + formal_params_list.append(formal_parameters) + else: + formal_params_list.append(None) + return formal_params_list + + def _list_formal_parameters(self, repeated_submodules) -> dict: + """ + Return a list of formal parameters in each submodule. + + Args: + repeated_submodules (dict): A dict which contains repeated submodules, + acquire this dict from _list_repeated_submodules() + + Return: + OrderedDict, a dict with each submodule's formal args. + + Example: + A return for ResNet50 could be: + + {0: # submoodule 0 + [set('stride', 'in_channels', 'out_channels'), # args of the first node to be set as formal + set('num_features'), # args of the second node to be set as formal + None, # args of third node to be set as formal, which does not have + set('in_channels', 'out_channels'), + set('num_features'), + None + ]}, + {3: # submodule 3 + [...], + {5: # submodule 5 + []} # empty returns means no nodes or it's a parent module of submodules. + } + """ + formal_args_in_each_submodule = OrderedDict() + checked_module = set() + # filter module_map by submodule_num (without depth) + for _, module_nums in repeated_submodules.items(): + for module_num in module_nums: + if module_num in checked_module: # module already checked + continue + else: + checked_module.add(module_num) + map_filtered = self.module_map_filter(module_num=module_num) + formal_args_in_this_module = self._list_formal_parameters_in_a_module(map_filtered) + formal_args_in_each_submodule[module_num] = formal_args_in_this_module + return formal_args_in_each_submodule + + def _add_submodule_to_parent(self): + """ + Recursively add all submodule to its parent module until Main module. + + Note: + This function deepcopy the first node of the submodule, and reset its params as parent module. 
+ """ + depth = self._module_depth_max + while depth > 0: + for (scope_path_str, md_struct) in self.module_structs.copy().items(): + if scope_path_str == '[]': + continue # is main module, skip + if md_struct.scope_depth != depth: + continue # skip all submodules not at current depth + md_struct_scope = copy.deepcopy(md_struct.identifier) + md_struct_scope.pop() + parent_scope = md_struct_scope + # 1. check if this module has parent module + parent_md_struct = self.module_structs.get(str(parent_scope)) + if parent_md_struct is not None: + # 1A. has parent, directly add md_struct to its parent ModuleStruct. + parent_md_struct.add_submodule(md_struct) + self.module_structs[str(parent_scope)] = parent_md_struct + else: + # 1B. not has parent, generate a new ModuleStruct + parent_md_struct = copy.deepcopy(md_struct) # use this submodule to create a parent module + # rewrite parent md struct + parent_md_struct.reset_as_parent() + parent_md_struct.add_submodule(md_struct) + self.module_structs[str(parent_scope)] = parent_md_struct + sub = self.module_structs.pop(scope_path_str) # remove this submodule from collections + self._global_context.add_module_struct(sub.pattern_id, sub) + depth -= 1 + + def _recursive_form_module(self): + """Main routine in generator to build modules from bottom to top.""" + # 1. List repeated submodules + repeated_submodules = self._list_repeated_submodules() + # 2. List reused parameters + formal_parameters = self._list_formal_parameters(repeated_submodules) + # 3. Build base subdmodules and set in/ext params translation + for module_struct in self.module_structs.values(): + if module_struct.pattern_id == -1: # is main module + continue + formal_args = formal_parameters.get(module_struct.pattern_id) + module_struct.update_args_translation_list(formal_args) + + # 4. 
Form parent modules + md_collection_len = len(self.module_structs.keys()) + len_changes = True + while len_changes: + self._add_submodule_to_parent() + new_len = len(self.module_structs.keys()) + if md_collection_len != new_len: + md_collection_len = new_len + else: + len_changes = False + + # 5. Update all translated args from module map + self._update_all_modules_args_translator() + + # 6. Update all nodes and moudles input/output + self.module_structs.get('[]').allocate_construct_header_x() + self.module_structs.get('[]').collect_returns() + + def _update_all_modules_args_translator(self): + """Update all modules' args translators.""" + done_submodule = set() + for depth in range(self._module_depth_max, 0, -1): + # check modules from bottom to top + repeated_submodules = copy.deepcopy(self._repeated_submodules) + repeated_modules = repeated_submodules.get(depth) + if depth is None: + continue + for pattern_id in repeated_modules: + if pattern_id in done_submodule: + continue + # get all md_structs by same pattern + md_list = self._global_context.module_structs.get(pattern_id) + self._take_formal_args_from_updated_submodules(md_list) + args_translators = self.get_args_translator_from_module_structs_list(md_list) + formal_args_list = ArgsTranslationHelper.find_formal_args_in_modules(args_translators) + changed_args_translators = self.get_args_translator_from_module_structs_list( + md_list, exclude_root_son=True) + ArgsTranslationHelper.change_args_to_formal_for_all_translators( + formal_args_list, changed_args_translators) + done_submodule.add(pattern_id) + + def _take_formal_args_from_updated_submodules(self, md_list): + """ + Take formal args from provided modules' nodes and submodules. + + Args: + md_list (list): A list of ModuleStruct. 
+ """ + if isinstance(md_list, ModuleStruct): + md_list = [md_list] + + for md in md_list: + md.args_translator.take_formal_args_from_nodes_and_submodules(md.get_all_sub_translators()) + + def _update_module_depth_max(self, nd_struct: NodeStruct): + """ + Update the Generator attribute module_depth_max, which is the maximum depth in the Model. + + Args: + nd_struct (NodeStruct): NodeStruct to be checked its depth. + """ + depth = nd_struct.scope.depth + if isinstance(depth, int): + if depth > self._module_depth_max: + self._module_depth_max = depth + else: + raise TypeError("Unable to update global depth due to TypeError in NodeStruct.scope.depth") + + def add_node(self, node_identifier, node_instance=None, node_fragment=None, mapper_dict=None): + """ + Add Node information to the generator. + + Args: + node_identifier (str): The unique identifier for the node passed in. + node_instance (GraphNode): The GraphNode instance of each node. + node_fragment (NodeFragment): The NodeFragment instance of this node passed in. + mapper_dict (dict): The dict contains converted attributes from mapper. 
+ """ + + if node_identifier is None: + raise ValueError("Node Identifier should not be None.") + self._global_context.node_fragments[node_identifier] = node_fragment + args = [] + if node_instance is not None: + args.append(node_instance) + if mapper_dict is not None: + args.append(mapper_dict) + if node_fragment is not None: + args.append(node_fragment) + + nd_struct = self.node_structs.get(node_identifier) + if nd_struct: # NodeStruct already exists + nd_struct.update(args) + else: # create new Node Struct + nd_struct = NodeStruct(args) + nd_struct.identifier = node_identifier + self._update_module_depth_max(nd_struct) + self.node_structs[node_identifier] = nd_struct + + @property + def node_structs(self): + """Return all NodeStructs in this model.""" + return self._node_struct_collections + + @property + def module_structs(self): + """Return all ModuleStructs in this model.""" + return self._module_struct_collections + + def generate(self): + """ + Generate the final script file. + + Returns: + list, a list of each line in script file. + """ + self._form_bottom_submodule() + self._recursive_form_module() + code = CodeStruct(self.module_structs.get('[]'), self._repeated_submodules) + return code.code_line_list + + def get_node_struct(self, node_identifier): + """ + Get specific NodeStruct by node_identifier. + + Args: + node_identifier (str): The node unique identifier. + + Return: + NodeStruct, the node's NodeStruct. + """ + return self._node_struct_collections.get(node_identifier, None) + + def get_module_struct(self, module_identifier): + """ + Get specific ModuleStruct by module_identifier. + + Args: + module_identifier (str): The module unique identifier. + + Return: + ModuleStruct, the node's ModuleStruct. 
+ """ + return self._module_struct_collections.get(module_identifier, None) + + def get_module_structs_by_pattern_under_same_parent_pattern(self, pattern_id, under_parent_pattern_id): + """ + Return a list of ModuleStruct by conditions of pattern and their parent parent's pattern. + + Args: + pattern_id (int): The pattern id the returned ModuleSturct is. + under_parent_pattern_id (int): The pattern id the returned ModuleStruct's parent is. + + Returns: + list, a list of MoudleStructs has the same pattern_id and the same parents' pattern. + """ + if not pattern_id: + raise ValueError("pattern_id is necessary to get the module struct.") + if not under_parent_pattern_id: + raise ValueError("under_parent_pattern_id is necessary to get the module struct.") + ret = [] + md_list = self._global_context.module_structs.get(pattern_id) + for md in md_list: + if md.parent_id == under_parent_pattern_id: + ret.append(md) + return ret + + def get_args_translator_from_module_structs_list(self, md_list, exclude_root_son=False): + """ + Return a list of args translators which belongs to given module structs. + + Args: + md_list (list): A list of ModuleStruct. + exclude_root_son (Bool): If the returned result should include args translator belongs to + modules under the Main module. + + Returns: + list, a list of args translators which belongs to given module structs. + """ + ret = [] + for md in md_list: + if exclude_root_son and md.parent_id == -1: + continue + if md.args_translator is not None: + ret.append(md.args_translator) + + return ret + + def module_map_filter(self, depth=None, module_num=None, uid=None): + """ + Filter the module map by given conditions. + + Args: + depth (int): Scope depth. + module_num (int): The submodule number. + uid (int): The unique identifier of a submodule. + + Return: + list, list of NodeStruct list of each submodule. 
+ """ + ret = list() + for scope_path, nd_struct_list in self._module_map.items(): + path = Scope.path_str_to_list(scope_path) + if not path: # skip main + continue + + # if depth not equals to the indicated depth, skip + if depth is not None and len(path) != depth: + continue + + scope_at_depth = path[-1] + (m_num, m_uid) = scope_at_depth + if uid is not None: + if m_num == module_num and m_uid == uid: + ret.append(nd_struct_list) + else: + if m_num == module_num: + ret.append(nd_struct_list) + return ret diff --git a/mindinsight/mindconverter/graph_based_converter/generator/module_struct.py b/mindinsight/mindconverter/graph_based_converter/generator/module_struct.py new file mode 100644 index 00000000..ca0a0a90 --- /dev/null +++ b/mindinsight/mindconverter/graph_based_converter/generator/module_struct.py @@ -0,0 +1,710 @@ +# Copyright 2020 Huawei Technologies Co., Ltd.All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== +"""Define a struct for module converted and save all required information here.""" + +from collections import OrderedDict + +from .node_struct import NodeStruct +from .scope_utils import Scope +from ..common.utils import get_dict_key_by_value +from .args_translator import ArgsTranslation +from ..common.code_fragment import ModuleFragment +from ..common.global_context import GlobalContext +from ..hierarchical_tree.name_mgr import LocalVarNameMgr + + +class ModuleStruct: + """ + Define a module struct which stores all info. to generate statement. + + Args: + args (list): A list of node structs. + """ + GLOBAL_CONTEXT_MGR = GlobalContext() + + def __init__(self, nd_struct_list): + """Init. a module by NodeStructs.""" + self.pattern_id = -1 # pattern num, -1 as Main module + self.pattern_uid = -1 # unique module id for this pattern + self.parent_id = None # parent's pattern num + self.parent_uid = None # parent's pattern module unique id + self.initialized = False + self.identifier = None + self.module_name = None + self.scope_depth = None + self.head_nd_struct = None + self.head_nd_struct_index = None + self.tail_nd_struct = None + self.tail_nd_struct_index = None + self._node_structs = list() + self._module_structs = list() + + self._fragment = None + self._args_translator = None + self._setting = None + self._parent_module_struct = None + # only store original formal args name, not global + self._nodes_structs_formal_args_list = list() + # only store translated (globalized) formal args + self._nodes_structs_formal_args_translated_list = list() + + # define other settings here + self._node_args_translation_list = list() + self._var_name_mgr = LocalVarNameMgr() + self.construct_header_x = OrderedDict() # key is header x, value is precursors onnx name + self.inputs_in_construct_header = OrderedDict() # key is precursors onnx name, value is x in parent construct + 
self.inputs_in_parent_module = OrderedDict() # key is prec_node_name, value is its closet opt_var_name + + # key is node's onnx name(output provider), value is (provider_succ_name, opt_var_name) + self.outputs_collection = dict() + self.matched_inputs = list() # Matched inputs will can be directly used by code line generation + + # key is ext. succ node onnx name, value is local opt_var + self.external_successor_local_returns_map = OrderedDict() + + # key is node's onnx_name, value is (successor_name, opt_var_name) <- node's level + self.outputs_collection = dict() + + # start initialization + if not self.initialized: + self._init_module(nd_struct_list) + else: + self._update_module(nd_struct_list) + + # assign this module reference to node + for (_, nd_struct) in nd_struct_list: + nd_struct.parent_module_struct = self + + def reset_as_parent(self): + """ + Reset all attributes and filled as a parent module of this module. + + Note: + This function must be called only after a deepcopy of this instance! 
+ """ + self.identifier.pop() + self.scope_depth = self.scope_depth - 1 + self._set_pattern_id() + self._find_parent_module() + self.module_name = Scope.scope_to_module_name(self.identifier) + self._node_structs = list() + self._module_structs = list() + self._fragment = None + self._args_translator = None + self.init_args_translator() + self._setting = None + self._parent_module_struct = None + self._nodes_structs_formal_args_list = list() + + self._node_args_translation_list = list() + + def _set_pattern_id(self): + """Set pattern id which matches the module fragment pattern.""" + if not self.initialized: + return + if self.scope_depth < 1: + self.pattern_id = -1 + self.pattern_uid = -1 + return + self.pattern_id = self.identifier[-1][0] + self.pattern_uid = self.identifier[-1][1] + + def _init_module(self, nd_struct_list): + """Init this ModuleStruct by a list of Nodes.""" + (nd_topo_idx, nd_struct) = nd_struct_list[0] + self.identifier = nd_struct.scope.path + self.module_name = nd_struct.scope.to_str + self.scope_depth = nd_struct.scope.depth + self.head_nd_struct = nd_struct + self.head_nd_struct_index = nd_topo_idx + self.tail_nd_struct = nd_struct_list[-1][1] + self.tail_nd_struct_index = nd_struct_list[-1][0] + self._node_structs = nd_struct_list + self.initialized = True + self._set_pattern_id() + self._find_parent_module() + self.init_args_translator() + + def _update_module(self, nd_struct_list): + """Update the ModuleStruct attributes from a list of Nodes.""" + (nd_topo_idx_head, nd_struct_head) = nd_struct_list[0] + (nd_topo_idx_tail, nd_struct_tail) = nd_struct_list[-1] + if self.identifier != nd_struct_head.scope.path: + raise ValueError("Unable to update this module struct {} due to different identifier {}".format( + self.identifier, nd_struct_head.scope.path)) + if nd_topo_idx_head < self.head_nd_struct_index: + self.head_nd_struct_index = nd_topo_idx_head + self.head_nd_struct = nd_struct_head + if nd_topo_idx_tail > self.tail_nd_struct_index: + 
self.tail_nd_struct_index = nd_topo_idx_tail + self.tail_nd_struct = nd_struct_tail + self._node_structs += nd_struct_list + + def _find_parent_module(self): + """Set the parent's module pattern and uid.""" + if not self.initialized: + return + if self.scope_depth == 0: # is Main Module + pass + elif self.scope_depth == 1: # parent pattern is Main module + self.parent_id = -1 + self.parent_uid = -1 + else: # this is a submodule in a module + (self.parent_id, self.parent_uid) = Scope.get_parent_module_num_and_uid( + self.identifier) + + def __repr__(self): + return str({ + "address": hex(id(self)), + "identifier": self.identifier, + "parent": (self.parent_id, self.parent_uid), + "name": self.module_name, + "pattern": self.pattern_id, + "scope_depth": self.scope_depth, + "nd_idx_range": "{} -> {}".format(self.head_nd_struct_index, self.tail_nd_struct_index), + "initialized": self.initialized + }) + + def init_module_fragment(self): + """Init the module fragment.""" + if not self.initialized: + return + # check if fragment exists in global context + op = "Module{}".format(self.pattern_id) + if op == "Module-1": # reset as Main Model's op name + op = "Model" + frag = GlobalContext().get_module_fragment(op) + if frag is not None: # use exists fragment + self._fragment = frag + else: + frag = ModuleFragment(operation=op, + actual_args=None, + input_shape=None, + output_shape=None, + settings=None) + self._fragment = frag + # set fragment pattern + self._fragment.pattern = self._node_structs + GlobalContext().add_module_fragment(op, frag) + + def init_args_translator(self): + """Initialize the Args Translator for the module.""" + var_name = "Module{}_{}".format(self.pattern_id, self.pattern_uid) + self._args_translator = ArgsTranslation(None, var_name, None) + + def update_module_fragment(self): + """Update this module's fragment.""" + if self._fragment is None: + return + + # update input output shape + self._fragment.input_shape = 
self.head_nd_struct.fragment.input_shape + self._fragment.output_shape = self.tail_nd_struct.fragment.output_shape + + # update formal args + self._fragment.formal_args.update(self._args_translator.formal_args) + self._fragment.formal_args_value.update(self._args_translator.formal_args_values) + # update actual args + self._fragment.actual_args.update(self._args_translator.actual_args) + # update others.. + + def add_submodule(self, md_structs): + """ + Add another module struct(s) to this ModuleStruct. + + Args: + md_structs ([ModuleStruct, list]): a (list) ModuleStruct to be added in this ModuleStruct. + """ + tail_md = md_structs + if isinstance(md_structs, ModuleStruct): + md_structs.args_translator.take_formal_args_from_nodes_and_submodules(md_structs.get_all_sub_translators()) + self._module_structs.append(md_structs) + md_structs.parent_module_struct = self + elif isinstance(md_structs, list): + for md_s in md_structs: + md_s.args_translator.take_formal_args_from_nodes_and_submodules(md_s.get_all_sub_translators()) + md_s.parent_module_struct = self + self._module_structs += md_structs + tail_md = md_structs[-1] + else: + raise TypeError("ModuleStruct cannot add an unsupport Type {} to module_structs list.".format( + type(md_structs))) + # update tail node and index + if self.tail_nd_struct_index < tail_md.tail_nd_struct_index: + self.tail_nd_struct = tail_md.tail_nd_struct + self.tail_nd_struct_index = tail_md.tail_nd_struct_index + + def _update_formal_args_for_all_nd_structs(self): + """ + Init nodes' args translator and find formal args. + And collect nodes' formal args. + """ + if len(self._node_args_translation_list) != len(self._node_structs): + raise ValueError( + "ModuleStruct cannot update nodes' formal args due to length inconsistent.") + for idx, (_, nd_struct) in enumerate(self._node_structs): + formal_arg_of_this_node = self._node_args_translation_list[idx] + # update var_name to ensure all node names' are unique in a module. 
+ nd_struct.update_var_name(idx) + nd_struct.init_args_translator(formal_arg_of_this_node) + if nd_struct.args_translator is not None: + self._nodes_structs_formal_args_list.append( + nd_struct.args_translator.formal_args_values) + else: + self._nodes_structs_formal_args_list.append(None) + + def update_args_translation_list(self, formal_args): + """ + Receive a list of args name to be changed to formal args, and change them. + + Args: + formal_args (list[str]): a list of args name to be changed to formal args. + """ + self._node_args_translation_list = formal_args + self._update_formal_args_for_all_nd_structs() + + def get_all_sub_translators(self): + """ + Return a list of args_translators of submodules / nodes affiliated to this module. + + Note: + The order of returned list is followed by the actual topological order. + + Returns: + list, a list of args_translators. + """ + ret = [] + for (_, struct) in self.get_generate_order(): + if struct.args_translator is not None: + ret.append(struct.args_translator) + return ret + + def get_generate_order(self): + """ + Return the order of generated code by index. + + Return: + list, a list of reference of node_struct or module_struct. + """ + ret = list() + if not self._module_structs: + return self._node_structs + # Generate a list of tuple (idx, module_structs) + for md_struct in self._module_structs: + ret.append((md_struct.head_nd_struct_index, md_struct)) + if self.node_structs: + ret += self.node_structs + ret.sort(key=lambda x: x[0]) + return ret + + def code_line_in_init(self): + """ + Initialization line of code in module init block. + + Args: + override_formal_val (dict): Indicate which args should be renamed for passing value from upper level. + """ + left = "self.{}".format(self.ms_var_name) + args_list = list() + # Load args in init statement. 
+ if self._args_translator is not None: # from args_translator + if self._args_translator.actual_args: # load actual args + args_list += self._args_translator.actual_args_to_str_list + elif self._args_translator.actual_args_backup and self.parent_id == -1: + # For modules repeated in multiple levels, the module under main model should + # not use formal args as it is unnecessary -> load from actual args backup + args_list += self._args_translator.actual_args_backup_to_str_list + args_list += self._args_translator.formal_args_to_str_list # load from formal args + else: + args_list += self._fragment.actual_args + right = f"{self.class_name}({', '.join(args_list)})" + return (left, right) + + def code_line_in_construct(self, inputs=None): + """Construct line of code in module construct block.""" + # check number of outputs this module has + opt_var_name_in_module = list(self.external_successor_local_returns_map.values()) + num_output = len(set(opt_var_name_in_module)) + if num_output == 1: # single output + left = f"{self.ms_opt_var_name}" + else: + left = [f"{self.ms_opt_var_name}_{num}" for num in range(num_output)] + + if inputs is None and self.matched_inputs: + inputs = self.matched_inputs + + if isinstance(inputs, str): + inputs = [inputs] + right = f"self.{self.ms_var_name}({', '.join(inputs)})" + return (left, right) + + @property + def node_structs(self): + """Return all node structs in this module.""" + return self._node_structs + + @property + def module_structs(self): + """Return all module structs in this module.""" + return self._module_structs + + @property + def parent_module_struct(self): + """Return this module's parent module struct.""" + return self._parent_module_struct + + @parent_module_struct.setter + def parent_module_struct(self, ref): + """Set this modu;e's parent module struct.""" + self._parent_module_struct = ref + + @property + def args_translator(self): + """Return the args translator.""" + return self._args_translator + + @property + 
def head_nd_struct_precursor_nodes_names(self) -> list: + """Return head node's precursor nodes names.""" + return self.head_nd_struct.precursor_nodes_names + + @property + def head_nd_struct_precursor_nodes_structs(self) -> list: + """Return head node's precursor nodes structs.""" + return self.head_nd_struct.precursor_nodes_structs + + @property + def tail_nd_struct_successor_nodes_names(self) -> list: + """Return tail node's successor nodes names.""" + return self.tail_nd_struct.successor_nodes_names + + @property + def tail_nd_struct_successor_nodes_structs(self) -> list: + """Return tail node's successor nodes structs.""" + return self.tail_nd_struct.successor_nodes_structs + + @property + def onnx_names_from_nodes(self) -> list: + """Return all nodes onnx names in this module.""" + ret = [] + for (_, node) in self.node_structs: + ret.append(node.onnx_name) + return ret + + @property + def onnx_names_from_submodules(self) -> list: + """Return all nodes onnx names in submodules of this module.""" + ret = [] + for md_struct in self.module_structs: + ret += md_struct.onnx_names + return ret + + @property + def onnx_names(self) -> list: + """Return all nodes' onnx names which contained by this module.""" + return self.onnx_names_from_nodes + self.onnx_names_from_submodules + + @property + def external_precursor_nodes_names(self) -> list: + """Return all precursors nodes names not in this module.""" + ret = [] + for _, struct in self.get_generate_order(): + if isinstance(struct, NodeStruct): + precursor_nodes_names = struct.precursor_nodes_names + + if isinstance(struct, ModuleStruct): + precursor_nodes_names = struct.external_precursor_nodes_names + + for p_name in precursor_nodes_names: + if p_name in self.onnx_names: + continue + ret.append(p_name) + return ret + + @property + def external_successor_nodes_names(self) -> list: + """Return all precursors nodes names not in this module.""" + ret = [] + for _, struct in self.get_generate_order(): + if 
isinstance(struct, NodeStruct): + successor_nodes_names = struct.successor_nodes_names + + if isinstance(struct, ModuleStruct): + successor_nodes_names = struct.external_successor_nodes_names + + for s_name in successor_nodes_names: + if s_name in self.onnx_names: + continue + ret.append(s_name) + return ret + + @property + def class_name(self) -> str: + """Return the class name for generating code of this module.""" + if self.pattern_id == -1: + return "Model" + return "Module{}".format(self.pattern_id) + + @property + def ms_var_name(self) -> str: + """Return the variable name for generated code statement of this module.""" + if self.pattern_id == -1: + return "Model" + return "Module{}_{}".format(self.pattern_id, self.pattern_uid).lower() + + @property + def ms_opt_var_name(self) -> str: + """Return the variable name for generated code statement of the output of this module.""" + return "{}_opt".format(self.ms_var_name).lower() + + # The following part will be resetting nodes' external inputs for supporting multi-in/out + # and should be called after generator.recursive_form_modules() + + def set_inputs_in_construct_header(self, header_x, onnx_precursor_node_name): + """ + Mark the registered external inputs for code generation. + + Note: + This function to be called by its parent (ModuleStruct). + + Args: + header_x (str): The `x` in module construct header. + onnx_precursor_node_name (str): The original onnx node name. + """ + if self.inputs_in_construct_header.get(onnx_precursor_node_name) is not None: + raise ValueError("The input from {} has already registered. Check this Module \ + {} has duplicate inputs or not.".format(onnx_precursor_node_name, self.identifier)) + self.inputs_in_construct_header[onnx_precursor_node_name] = header_x + + def allocate_construct_header_x(self, force_x=None): + """ + Allocate the x in construct header for each external input. + + Args: + force_x (str): Force the arg name to customized. 
+ """ + local_x_name = 'x' + if force_x: # name of x indicated by external + local_x_name = force_x + + # set construct_header_x for current module + allocated = set() + for prec_name in self.external_precursor_nodes_names: + if prec_name in allocated: + continue + x_name_in_construct_header = self._var_name_mgr.get_name(local_x_name) + self.construct_header_x[x_name_in_construct_header] = prec_name + allocated.add(prec_name) + + # Assign these inputs to nodes and submodules + for _, struct in self.get_generate_order(): + if isinstance(struct, NodeStruct): # register node's ext input + self.reset_node_external_input_to_local(struct) + self.register_node_output_to_module(struct) + if isinstance(struct, ModuleStruct): # reg module's ext input + if not struct.construct_header_x: + struct.allocate_construct_header_x() + self.reset_submodule_external_input_to_local(struct) + self.register_submodule_output_to_module(struct) + + # remove parent module's ext. map if ext nodes in this module (no need return) + for user_name in self.external_successor_local_returns_map.copy().keys(): + if user_name in self.onnx_names: + self.external_successor_local_returns_map.pop(user_name) + + def _match_node_inputs(self, struct): + """Match node's inputs with its precursor nodes.""" + for output_provider in struct.precursor_nodes_names: + output_list = self.outputs_collection.get(output_provider) + if output_list is None: + # not in this module, check construct header + for (self_x_name, self_output_provider) in self.construct_header_x.items(): + if self_output_provider == output_provider: + struct.matched_inputs.append(self_x_name) + continue + for output in output_list: + (provider_succ, provider_closet_opt_var) = output + if provider_closet_opt_var in struct.matched_inputs: + continue # skip repeat + if provider_succ == struct.onnx_name: + struct.matched_inputs.append(provider_closet_opt_var) + + def _match_sub_modules_inputs(self): + """ + Match current module's submodules' inputs 
with corresponding outputs registered in current module. + + Description: + The function matches these inputs by the following steps: + 1. For each submodule in the current module, take submodule's construct header + 2. Check submodule's construct header element requires an input from current module's + construct header or outputs from other submodules. + 3. If from current module's construct header, assign corresponding x to the submodule. + If from other submodules, assign required submodule output name to the submodule. + """ + if not self.outputs_collection: + return # skip first node + for (_, struct) in self.get_generate_order(): + if isinstance(struct, NodeStruct): + self._match_node_inputs(struct) + continue # skip node + sub_construct_header = struct.construct_header_x + for (_, output_provider) in sub_construct_header.items(): + # check from outputs collection + output_list = self.outputs_collection.get(output_provider) + if output_list is None: + # not in this module, need from current module construct header + for (self_x_name, self_output_provider) in self.construct_header_x.items(): + if self_output_provider == output_provider: + struct.matched_inputs.append(self_x_name) + continue + for output in output_list: + (provider_succ, provider_closet_opt_var) = output + if provider_closet_opt_var in struct.matched_inputs: + continue # skip repeat + if provider_succ in struct.onnx_names: + struct.matched_inputs.append(provider_closet_opt_var) + + def _append_to_outputs_collection(self, provider_name, val): + """ + Helper function to add a nodes or submodules outputs to current module return statement. + + Args: + provider_name (str): The onnx name of the output provider. + val (list[tuple]): A list of tuple which contains + the output provider's successor name and its opt_var_name. 
+ """ + exist_output = self.outputs_collection.get(provider_name) + if isinstance(val, tuple): + val = [val] + if exist_output is None: # add new entry + exist_output = list() + exist_output += (val) + self.outputs_collection[provider_name] = exist_output + + def collect_returns(self): + """ + Collect all nodes and submodules' returns in the module. + + Note: + The logic is to collect the return from nodes and submodules by the order + of topological index. + + For returns from a node, it will check if the return will be used externally. + If external (external means the successor a.k.a the return user has different scope with the node), + add this return to current module's outputs_collection, where + key is this node's original onnx_name, and value is a list of + tuple(successor_name, this node's opt_var_name) + + For returns from a submodule, it will check if the submodule has already collected returns, + If not, do it and then continue the following procedures. + Now we will check each element in submodule's outputs_collection. Note that we DO NOT check submodule's + returns should be continued returning, but just return them. + All these returns from submodules will be changes their original nodes output (a.k.a outputs provider) + `opt_var_name` to submodules' `opt_var_name`. + + Finally, we match the outputs and inputs in the current module level. 
+ """ + for (_, struct) in self.get_generate_order(): + if isinstance(struct, NodeStruct): + outputs_list = [] + # add these successor nodes name to collection for future use + for succ in struct.successor_nodes_names: + outputs_list.append((succ, struct.ms_opt_var_name)) + if outputs_list: + self._append_to_outputs_collection(struct.onnx_name, outputs_list) + if isinstance(struct, ModuleStruct): + # Remove unnecessary returns, succ are all inside current + if not struct.outputs_collection: + struct.collect_returns() + sub_outputs_collection = struct.outputs_collection + # check each returns in sub + for provider_name, outputs_list in sub_outputs_collection.items(): + for output in outputs_list: + (succ, _) = output # (succ, provider_opt_var_name) in output + new_output = (succ, struct.ms_opt_var_name) + self._append_to_outputs_collection(provider_name, new_output) + self._match_sub_modules_inputs() + + def get_returned_opt_var_name(self) -> list: + """Return a list of returned output var of this module.""" + idx = 0 + added_to_return = set() + ret = [] + for ext_successor_requested, opt_var_name_in_this_module in self.external_successor_local_returns_map.items(): + if ext_successor_requested in added_to_return: + continue + ret.append((ext_successor_requested, opt_var_name_in_this_module, idx)) + added_to_return.add(ext_successor_requested) + return ret + + def reset_node_external_input_to_local(self, nd_struct): + """ + Reset node's input to module's construct args + """ + for prec_node_name in nd_struct.precursor_nodes_names_external: + if prec_node_name in self.onnx_names: # prec node in current module's. + continue + if prec_node_name in self.construct_header_x.values(): + # prec node assigned to construct header to passed in. 
+ local_x = get_dict_key_by_value(prec_node_name, self.construct_header_x) + nd_struct.set_inputs_in_construct_header(local_x, prec_node_name) + else: # Extra precursor nodes, raise error + raise ValueError("Found external inputs of the Node but the module does not have it.") + + def reset_submodule_external_input_to_local(self, md_struct): + """ + Reset submodule's external input to current module. + + Args: + md_struct (ModuleStruct): The submodule in the current module. + """ + # check submodule's input + for _, submodule_precursor in md_struct.construct_header_x.items(): + if submodule_precursor in self.onnx_names: # if internal, match with local nodes/submodules return + # but do nothing here + continue + else: # if external, match with current module construct header x + if submodule_precursor in self.construct_header_x.values(): + local_x = get_dict_key_by_value(submodule_precursor, self.construct_header_x) + md_struct.set_inputs_in_construct_header(local_x, submodule_precursor) + else: # Extra precursor nodes, raise error + raise ValueError("Found external inputs of the submodule but the module does not have it.") + + def register_node_output_to_module(self, nd_struct): + """Register nodes outputs to this module's return.""" + for succ_node_name in nd_struct.successor_nodes_names_external: + self.external_successor_local_returns_map[succ_node_name] = nd_struct.ms_opt_var_name + + def register_submodule_output_to_module(self, md_struct): + """Register submodule outputs to this module's return.""" + submodule_returns = md_struct.get_returned_opt_var_name() + submodule_opt_var_name = md_struct.ms_opt_var_name + for (submodule_ext_succ, opt_var_name_in_this_module, ith_output) in submodule_returns: + self.external_successor_local_returns_map[submodule_ext_succ] = (submodule_opt_var_name, ith_output) + # edit external succ 's inputs in parent module + ext_node = self.GLOBAL_CONTEXT_MGR.onnx_node_name_to_node_struct_map.get(submodule_ext_succ) + ext_node_parent = 
ext_node.parent_module_struct + while ext_node_parent != self.parent_module_struct: + ext_node_parent.inputs_in_parent_module[ext_node.onnx_name] = md_struct.ms_opt_var_name + ext_node_parent = ext_node_parent.parent_module_struct + + # need find the prec_name? + for ext_node_prec, opt_var_name in ext_node.inputs_in_parent_module.copy().items(): + if isinstance(opt_var_name, str): + if opt_var_name == opt_var_name_in_this_module: + ext_node.inputs_in_parent_module[ext_node_prec] = (self.ms_opt_var_name, ith_output) + if isinstance(opt_var_name, tuple): + if opt_var_name[0] == opt_var_name_in_this_module: + ext_node.inputs_in_parent_module[ext_node_prec] = (self.ms_opt_var_name, ith_output) diff --git a/mindinsight/mindconverter/graph_based_converter/generator/node_struct.py b/mindinsight/mindconverter/graph_based_converter/generator/node_struct.py new file mode 100644 index 00000000..9764fe2f --- /dev/null +++ b/mindinsight/mindconverter/graph_based_converter/generator/node_struct.py @@ -0,0 +1,423 @@ +# Copyright 2020 Huawei Technologies Co., Ltd.All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Define the NodeStruct which stores all info. 
of a node.""" +from collections import OrderedDict + +from .scope_utils import Scope +from .args_translator import ArgsTranslation +from ..common.code_fragment import CodeFragment +from ..third_party_graph.pytorch_graph_node import PyTorchGraphNode +from ..third_party_graph.onnx_graph_node import OnnxGraphNode +from ..common.global_context import GlobalContext +from ..constant import InputType + + +class NodeStruct: + """ + Define a node struct which stores all info. to generate statement. + + Args: + args (Union[PyTorchGraphNode, OnnxGraphNode, dict]): Node related obj. + + Note: + You can pass as many args as possible and the Node Struct will update + by arguments order. + """ + GLOBAL_CONTEXT_MGR = GlobalContext() + + def __init__(self, args): + # define attributes here + self._identifier = None + self._fragment = None + self._args_translator = None + self._parent_module_struct = None + self.topo_idx = None + self.node_type = None + self.onnx_name = None + self.onnx_op = None + self.graph_node_ref = None # Our defined GraphNode + self.scope_name = None + self.ms_var_name = None + self.ms_opt_var_name = None # ms_opt_var_name = self.ms_var_name(...) + self.ms_op = None + self.ready_to_generate = False + + self.ms_params = dict() # converted params from mapper + self.ms_settings = dict() + self.ms_weights = dict() + self.ms_inputs = OrderedDict() + + self.scope = None # Defined Scope class + self.inputs_in_construct_header = OrderedDict() # key is prec_node_name, value is x; For code line use + self.inputs_in_parent_module = OrderedDict() # key is prec_node_name, value is its closet opt_var_name + self.matched_inputs = list() # Matched inputs will can be directly used by code line generation + + # initialize funcs. 
+ for arg in args: + self.update(arg) + + def __repr__(self): + return str({ + "address": hex(id(self)), + "idx": self.topo_idx, + "identifier": self.identifier + }) + + def ori_topo_idx(self): + """Get the original topological index in the onnx graph.""" + ori_name = self.identifier.replace('$', '').split('/')[-1].replace("::", '/') + self.onnx_name = ori_name + return self.GLOBAL_CONTEXT_MGR.onnx_node_name_to_topo_idx.get(ori_name) + + def update_var_name(self, idx=None): + """ + Update the var_name of each node. + + Args: + idx (int): The index of the node in this module. + """ + if idx is not None: + self.ms_var_name = self.ms_op.replace('nn.', '').replace('P.', '').lower() + '_' + str(idx) + elif self.topo_idx is not None: + self.ms_var_name = self.ms_op.replace('nn.', '').replace('P.', '').lower() + '_' + str(self.topo_idx) + else: + raise ValueError("Unable to update var name when topo_idx is None.") + self.ms_opt_var_name = self.ms_var_name + '_opt' + + def _update_basics_from_gn(self, gn): + """Update basic info from GraphNode.""" + self.graph_node_ref = gn + self.scope_name = gn.scope_name + + def _update_from_pytorch_gn(self, gn: PyTorchGraphNode): + """Update basic info from PyTorchGraphNode.""" + self.node_type = "PyTorchGraphNode" + self._update_basics_from_gn(gn) + + def _update_from_onnx_gn(self, gn: OnnxGraphNode): + """Update basic info from OnnxGraphNode.""" + self.node_type = "OnnxGraphNode" + self._update_basics_from_gn(gn) + + def _update_from_mapper(self, d): + """Update info from mapper.""" + if d.get('op_name'): + self.ms_op = d.get('op_name') + if d.get('params'): + self.ms_params = d.get('params') + if d.get('settings'): + self.ms_settings = d.get('settings') + if d.get('weights'): + self.ms_weights = d.get('weights') + + def _update_from_fragment(self, frag: CodeFragment): + """Update info from CodeFragment.""" + self._fragment = frag + if frag.operation: + self.ms_op = frag.operation + idx = 
self.GLOBAL_CONTEXT_MGR.latest_node_struct_count + self.update_var_name(idx=idx) + + def _set_scope_from_identifier(self): + """Set the Node scope from identifier.""" + parsed_scope = Scope.parse_scope_from_node_identifier(self.identifier) + self.scope = Scope(parsed_scope) + + def init_args_translator(self, translated_args: list): + """ + Initialize the ArgsTranslator for each Node. + + Args: + translated_args (list): The list of args should be translated to formal args. + """ + if not self._fragment: + raise ValueError("Initialize argument translator failed.") + if self._fragment.actual_args and translated_args: + self._args_translator = ArgsTranslation(self._fragment.actual_args, self.ms_var_name, translated_args) + + def check_if_generate_ready(self): + """Check if the NodeStruct is able to generate code.""" + # check essential params exists + if all([self.identifier, + self.node_type, + self.scope_name, + self.ms_var_name, + self.ms_opt_var_name, + self.ms_op]): + self.ready_to_generate = True + + def update(self, arg, force_ready=False): + """ + Pass Node info. to generator NodeStruct. + + Args: + arg (Union[PyTorchGraphNode, OnnxGraphNode, dict]): Node related obj. + force_ready (bool): Force this NodeStruct is ready to generate. + """ + if isinstance(arg, PyTorchGraphNode): + self._update_from_pytorch_gn(arg) + elif isinstance(arg, OnnxGraphNode): + self._update_from_onnx_gn(arg) + elif isinstance(arg, (dict, OrderedDict)): + self._update_from_mapper(arg) + elif isinstance(arg, CodeFragment): + self._update_from_fragment(arg) + else: + raise TypeError("NodeStruct received an unsupported initializing argument.") + + if force_ready: + self.ready_to_generate = True + else: + self.check_if_generate_ready() + + @property + def identifier(self): + """Return the identifier of the node.""" + return self._identifier + + @identifier.setter + def identifier(self, s): + """ + Set the Node identifier, and update the scope. 
+ + Args: + s (str): The node identifier string. + """ + self._identifier = s + self._set_scope_from_identifier() + self.topo_idx = self.ori_topo_idx() + self.GLOBAL_CONTEXT_MGR.onnx_node_name_to_node_struct_map[self.onnx_name] = self + + @property + def fragment(self): + """Return the fragment of the node.""" + return self._fragment + + @fragment.setter + def fragment(self, frag): + """ + Set the Node fragment. + + Args: + s (NodeFragment): The node identifier string. + """ + self._fragment = frag + + @property + def graph_node(self): + """Return the GraphNode reference.""" + return self.graph_node_ref + + @graph_node.setter + def graph_node(self, graphnode): + """Set the GraphNode reference.""" + self.graph_node_ref = graphnode + + @property + def onnx_node(self): + """Return the original onnx node reference.""" + return self.GLOBAL_CONTEXT_MGR.onnx_nodes_collection.get(self.onnx_name) + + @property + def args_translator(self): + """Return the args translator of this Node.""" + return self._args_translator + + @property + def precursor_nodes_names(self) -> list: + """Return the names of precursor nodes.""" + return self.graph_node_ref.precursor_nodes + + @property + def precursor_nodes_structs(self) -> list: + """Return the node struct instances of precursor nodes.""" + ret = [] + precursor_nodes_names = self.precursor_nodes_names + for pre_node_name in precursor_nodes_names: + nd_struct = self.GLOBAL_CONTEXT_MGR.onnx_node_name_to_node_struct_map.get(pre_node_name) + ret.append(nd_struct) + return ret + + @property + def successor_nodes_names(self) -> list: + """Return the names of successor nodes.""" + return self.graph_node_ref.successor_nodes + + @property + def successor_nodes_structs(self) -> list: + """Return the node struct instances of successor nodes.""" + ret = [] + for pre_node_name in self.successor_nodes_names: + nd_struct = self.GLOBAL_CONTEXT_MGR.onnx_node_name_to_node_struct_map.get(pre_node_name) + ret.append(nd_struct) + return ret + + @property 
+ def parent_module_struct(self): + """Return the parent struct of this node.""" + return self._parent_module_struct + + @parent_module_struct.setter + def parent_module_struct(self, ref): + self._parent_module_struct = ref + + # Code Generation funcs below + + def code_line_in_init(self): + """Initialization line of code in module init block.""" + left = "self.{}".format(self.ms_var_name) + args_list = list() + if self._args_translator is not None: + args_list += self._args_translator.actual_args_to_str_list + args_list += self._args_translator.formal_args_to_str_list + else: + actual_args_str = ArgsTranslation.dict_data_to_args_str_list(self._fragment.actual_args) + args_list += actual_args_str + right = f"{self.ms_op}({', '.join(args_list)})" + return left, right + + def _get_correct_in_module_returns(self, prec_node, in_module_return): + """ + Find the correct precursor node name in return statement of its parent module. + + Args: + prec_node (str): The onnx name of the precursor node given. + in_module_return (list[tuple]): The list of outputs which contains parent module identifier + and module opt_var_name. + + Return: + str, correct opt_var_name to be passed in current node. + """ + found_return = False + for ret in in_module_return: + (md_identifier, input_name_to_use) = ret + p_node_struct = self.GLOBAL_CONTEXT_MGR.onnx_node_name_to_node_struct_map.get(prec_node) + # recursive check the p node parent + parent = p_node_struct + while not found_return: + parent = parent.parent_module_struct + if parent is None: + break + if parent.identifier == md_identifier: + return input_name_to_use + return None + + def code_line_in_construct(self, inputs=None, in_module_returns=None): + """Construct line of code in module construct block. 
""" + left = self.ms_opt_var_name + if inputs is None: + inputs = [] + for idx, prec_node in enumerate(self.precursor_nodes_names): + if self.inputs_in_construct_header.get(prec_node): + inputs.append(self.inputs_in_construct_header.get(prec_node)) + elif self._check_target_node_internal(prec_node): + inputs.append(self.precursor_nodes_structs[idx].ms_opt_var_name) + elif self.inputs_in_parent_module.get(prec_node): + inputs.append(self.inputs_in_parent_module.get(prec_node)) + elif in_module_returns and in_module_returns.get(self.onnx_name) \ + and (not self._check_target_node_internal(prec_node)): + inputs.append(self._get_correct_in_module_returns(prec_node, in_module_returns.get(self.onnx_name))) + else: + inputs.append("unk_{}_{}".format(idx, prec_node)) + + if self.matched_inputs: + inputs = self.matched_inputs + + # Check original onnx node's input to ensure double inputs are not ignored + original_inputs = self.GLOBAL_CONTEXT_MGR.onnx_node_inputs.get(self.onnx_name) + new_inputs = [] + for idx, prec_node in enumerate(self.precursor_nodes_names): + occurence = original_inputs.count(prec_node) + for _ in range(occurence): + new_inputs.append(inputs[idx]) + inputs = new_inputs + + if isinstance(inputs, str): + inputs = [inputs] + + if self._fragment.code_setting and self._fragment.code_setting.op_ipt_type == InputType.LIST.value: + inputs = [str(tuple(inputs)).replace("\'", "")] + + if self._fragment.code_setting and self._fragment.code_setting.op_extra_input: + for _, val in self._fragment.code_setting.op_extra_input.items(): + inputs.append(str(val)) + + if self._fragment.code_setting and self._fragment.code_setting.op_extra_tensor: + inputs.append(f"self.{self.ms_var_name}_w") + right = f"self.{self.ms_var_name}({', '.join(inputs)})" + return left, right + + def add_extra_tensor(self): + """ Add extra tensor.""" + left = "self.{}_w".format(self.ms_var_name) + shape = self._fragment.code_setting.op_extra_tensor.shape + right = f"Tensor(np.random.uniform(0, 
1, {shape}), mindspore.float32)" + return left, right + + # The following functions are specified for multiple in/out support. + # and should be called only after generator._recursive_form_modules() + + def set_inputs_in_construct_header(self, header_x, onnx_precursor_node_name): + """ + Mark the registered external inputs for code generation. + + Note: + This function to be called by its parent (ModuleStruct). + + Args: + header_x (str): The `x` in module construct header. + onnx_precursor_node_name (str): The original onnx node name. + """ + if self.inputs_in_construct_header.get(onnx_precursor_node_name) is not None: + raise ValueError("The input from {} has already registered. Check this node \ + {} has duplicate inputs or not.".format(onnx_precursor_node_name, self.identifier)) + self.inputs_in_construct_header[onnx_precursor_node_name] = header_x + + def _check_target_node_internal(self, name: str) -> bool: + """ + Check given node under the same scope. + + Args: + name (str): Can accept both node identifier or original onnx node name. 
+ """ + target_nd_struct = self.GLOBAL_CONTEXT_MGR.node_struct_collections.get(name) \ + or self.GLOBAL_CONTEXT_MGR.onnx_node_name_to_node_struct_map.get(name) + if target_nd_struct is None and self.topo_idx == 0: # First node always has external input + return False + + if target_nd_struct is None: + raise ValueError("Unable to find the NodeStruct of given target node {}.".format(name)) + return target_nd_struct.scope.path == self.scope.path + + @property + def has_successor_node_external(self) -> bool: + """Check if any successor_node is in external module.""" + for name in self.successor_nodes_names: + if not self._check_target_node_internal(name): + return False + + return True + + @property + def precursor_nodes_names_external(self) -> list: + """Return a list of external precursor nodes names.""" + return [name for name in self.precursor_nodes_names \ + if not self._check_target_node_internal(name)] + + @property + def successor_nodes_names_external(self) -> list: + """Return a list of external successor nodes names.""" + return [name for name in self.successor_nodes_names \ + if not self._check_target_node_internal(name)] diff --git a/mindinsight/mindconverter/graph_based_converter/generator/scope_utils.py b/mindinsight/mindconverter/graph_based_converter/generator/scope_utils.py new file mode 100644 index 00000000..43b8c28b --- /dev/null +++ b/mindinsight/mindconverter/graph_based_converter/generator/scope_utils.py @@ -0,0 +1,157 @@ +# Copyright 2020 Huawei Technologies Co., Ltd.All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
"""Define a scope class processing all operations related to scope and scope name."""
import re


class Scope:
    """Define scope related operations."""

    def __init__(self, scope_str):
        # the last path element is the node itself; only keep the module chain
        scopes = scope_str.split('/')
        self.module_path = list()
        self.scope_list = scopes[:-1]
        self.head = self.scope_list[0]
        self.tail = self.scope_list[-1]
        self.initialization()

    def initialization(self):
        """Init scope class."""
        self._update_module_path_from_scope_list()

    def _update_module_path_from_scope_list(self):
        """Update the module scope path from a list of scope."""
        self.module_path = list()
        for scope in self.scope_list:
            if scope == 'Model':
                continue

            if 'Module' in scope:
                # BUGFIX: the named groups had been stripped to the invalid
                # pattern "(?P\d+)"; restored to match group('num') and
                # group('curr_level_unique_id') below.
                regex = r"Module(?P<num>\d+)_(?P<curr_level_unique_id>\d+)"
                match = re.match(regex, scope)
                if match:
                    module_num = match.group('num')
                    uid = match.group('curr_level_unique_id')
                    self.module_path.append((int(module_num), int(uid)))

    @property
    def path(self):
        """Return module scope path."""
        return self.module_path

    def set_path(self, ind, path_tuple: tuple):
        """
        Set the module scope path.

        Args:
            ind (int): The index of the scope path to be set.
            path_tuple ((int, int)): The tuple of the scope path.
        """
        self.module_path[ind] = path_tuple

    @property
    def to_str(self):
        """Return the full module scope as the string format."""
        full_str_list = ["Model"]
        for (num, uid) in self.module_path:
            local = "Module{}_{}".format(num, uid)
            full_str_list.append(local)

        return "/".join(full_str_list)

    @property
    def depth(self):
        """Return the depth of the scope path."""
        return len(self.path)

    @staticmethod
    def scope_to_module_name(path):
        """
        Helper function to convert any scope path string to the full module scope.

        Args:
            path (Union[str, list]): path string like "[(5, 0), (3, 0)]" or the parsed list.

        Returns:
            str, the full module scope with format like "Model/Module5_0/Module3_0"
        """
        scope_str_list = ["Model"]
        if isinstance(path, str):
            path = Scope.path_str_to_list(path)
        if isinstance(path, list):
            for (num, uid) in path:
                local = "Module{}_{}".format(num, uid)
                scope_str_list.append(local)

        return "/".join(scope_str_list)

    @staticmethod
    def parse_scope_from_node_identifier(node_identifier: str):
        """
        Helper function to parse the scope string from node identifier.

        Args:
            node_identifier (str): The string of the node identifier.

        Returns:
            str, parsed scope string from node identifier, or None on mismatch.
        """
        # BUGFIX: restored the stripped named group "(?P<scope>...)".
        regex = r"(?P<scope>Model/.*)\$\S+\$"
        match = re.match(regex, node_identifier)
        if not match:
            return None
        return match.group('scope')

    @staticmethod
    def path_str_to_list(scope_path_str: str):
        """
        Helper function to convert the scope path string back to list.

        Args:
            scope_path_str (str): The scope path string like "[(5, 0), (3, 0)]".

        Returns:
            list, a list of the scope path like [(5, 0), (3, 0)].
        """
        ret = []
        tmp = scope_path_str.strip('[').strip(']')
        # findall with multiple groups yields (num, uid) tuples
        regex = r"\((?P<num>\d+), (?P<uid>\d+)\)"
        s_all = re.findall(regex, tmp)
        for (num, uid) in s_all:
            ret.append((int(num), int(uid)))

        return ret

    @staticmethod
    def get_parent_module_num_and_uid(path):
        """
        Helper function to return its parent's scope tuple.

        Args:
            path (Union[str, list]): Module scope path string. e.g. "[(5, 0), (3, 0)]"

        Returns:
            tuple, parent's scope level. e.g. (5, 0), or (-1, -1) for the main module.
        """
        if isinstance(path, str):
            path = Scope.path_str_to_list(path)
        if isinstance(path, list):
            if len(path) == 1:  # modules under the main module, (-1, -1) means main module.
                return (-1, -1)
            if len(path) > 1:  # modules under another non-main module. Return parent's scope.
                parent = path[-2]
                return parent

        return None


class LocalVarNameMgr:
    """Local variable name mgr."""

    def __init__(self):
        self.local_op_namespace = dict()
        self.local_var_namespace = set()

    @staticmethod
    def _get_name(name):
        """Strip the namespace prefix from an op name, e.g. "onnx::conv" -> "conv"."""
        if "::" in name:
            return name.split("::")[1]
        return name

    def get_name(self, op_type):
        """
        Get module/variable name.

        If the name already exists, a numeric suffix is appended
        (e.g. "conv", then "conv1", "conv2", ...).

        Args:
            op_type (str): Operator type in onnx.

        Returns:
            str, module name.
        """

        def _gen(t):
            t = t.lower()
            if t not in self.local_op_namespace:
                self.local_op_namespace[t] = START_IDX
                suffix = ""
            else:
                self.local_op_namespace[t] += 1
                suffix = f"{self.local_op_namespace[t] - 1}"

            return f"{self._get_name(t)}{suffix}"

        new_name = _gen(op_type)
        # keep generating until the name is free in this namespace
        while new_name in self.local_var_namespace:
            new_name = _gen(op_type)

        self.local_var_namespace.add(new_name)
        return new_name
""" input_node = InputNode(input_shape) - input_node_name = "{}InputNode" + input_node_name = self._raw_input_nodes.replace(":0", "") for node_name, node in self._nodes_collection.items(): if node_name in self._input_nodes: ipt_nd_name = input_node_name.format(input_node.scope_name) diff --git a/mindinsight/mindconverter/graph_based_converter/third_party_graph/onnx_utils.py b/mindinsight/mindconverter/graph_based_converter/third_party_graph/onnx_utils.py index 6ba67c84..f76a371a 100644 --- a/mindinsight/mindconverter/graph_based_converter/third_party_graph/onnx_utils.py +++ b/mindinsight/mindconverter/graph_based_converter/third_party_graph/onnx_utils.py @@ -23,6 +23,7 @@ import numpy as np from mindinsight.mindconverter.common.log import logger as log from ..common.utils import fetch_output_from_onnx_model +from ..common.global_context import GlobalContext from ..constant import ONNX_TYPE_INT, ONNX_TYPE_INTS, ONNX_TYPE_STRING, \ ONNX_TYPE_FLOATS, ONNX_TYPE_FLOAT, SCALAR_WITHOUT_SHAPE, DYNAMIC_SHAPE, UNKNOWN_DIM_VAL @@ -110,6 +111,7 @@ class OnnxTensor: self.to_nodes = [] def to_array(self): + """Convert the tensor value from binary to np array.""" onnx = import_module("onnx") # Convert binary data to np.array if not isinstance(self.raw_tensor, (np.ndarray, list, tuple, int, float)): @@ -264,7 +266,7 @@ class OnnxDataLoader: self.output_nodes = output_nodes if isinstance(output_nodes, list) else [output_nodes] # args for init self._is_infer_shape = infer_shape - + self._global_context = GlobalContext() # params parsed in init self.inferred_model = None @@ -375,12 +377,19 @@ class OnnxDataLoader: def _parse_nodes(self): """Parse each onnx nodes in the model.""" - for node in self.nodes: + nodes_topo_idx = [] + for idx, node in enumerate(self.nodes): n = OnnxNode(node) self._nodes_dict[n.name] = n + nodes_topo_idx.append((idx, n.name)) if len(node.output) > 1: raise ModelNotSupport(msg=f"{node.name} has multi-outputs which is not supported now.") 
self.output_name_to_node_name[node.output[0]] = node.name + self._global_context.onnx_node_name_to_topo_idx[n.name] = idx + node_inputs = [i.replace(":0", "") for i in node.input] + self._global_context.onnx_node_inputs[n.name] = node_inputs + self._global_context.onnx_nodes_collection = self._nodes_dict + self._global_context.onnx_nodes_topo_index = nodes_topo_idx def _parse_tensors(self): """Parse each onnx tensor in the model.""" @@ -388,6 +397,7 @@ class OnnxDataLoader: for tensor in tensors: t = OnnxTensor(tensor) self.tensors_dict[t.name] = t + self._global_context.onnx_tensors_collection = self.tensors_dict def _parse_node_output_shape(self): """