From: @liangtianshu
Reviewed-by:
Signed-off-by:
tags/v1.2.0-rc1
| @@ -310,11 +310,14 @@ class NewFragment: | |||
| rewrite_data.update(data[ExchangeMessageKeywords.VariableScope.value.TRAINABLE_PARAMS.value]) | |||
| if ExchangeMessageKeywords.VariableScope.value.PARAMETERS_DECLARED.value in data: | |||
| rewrite_params = { | |||
| f"{var}/{slot}": data[ExchangeMessageKeywords.VariableScope.value.PARAMETERS_DECLARED.value][slot] | |||
| f"{var}/{slot}": data[ExchangeMessageKeywords.VariableScope.value.PARAMETERS_DECLARED.value].get(slot) | |||
| for slot in data[ExchangeMessageKeywords.VariableScope.value.PARAMETERS_DECLARED.value] | |||
| } | |||
| rewrite_data.update(rewrite_params) | |||
| rewrite_data.update(data[ExchangeMessageKeywords.VariableScope.value.ARGS.value]) | |||
| template = template.format(**{ | |||
| k: str(rewrite_data[k]) for k in rewrite_data | |||
| }) | |||
| return template.format(**{ | |||
| k: str(rewrite_data[k]) for k in rewrite_data | |||
| }) | |||
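For context, a minimal, self-contained sketch of the `{var}/{slot}` flattening performed above; the names and values are illustrative, not the converter's real exchange-message contents. It shows why a single `str.format(**rewrite_data)` pass can fill both plain slots and the new parameter slots:

```python
# Sketch only: illustrative slot names/values, not the real exchange message.
declared = {"w": "Parameter(Tensor(...), name=None)"}   # slot -> declaration string
var = "var_0"

rewrite_data = {"var_0": "matmul_0"}
# Flatten declared parameters into "{var}/{slot}" keys; .get() tolerates empty slots.
rewrite_data.update({f"{var}/{slot}": declared.get(slot) for slot in declared})

template = "self.{var_0}_w = {var_0/w}"
print(template.format(**{k: str(rewrite_data[k]) for k in rewrite_data}))
# -> self.matmul_0_w = Parameter(Tensor(...), name=None)
```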
| @@ -83,6 +83,9 @@ class GlobalContext(metaclass=Singleton): | |||
# Record weight names that are used many times.
| self.repeated_weights = dict() | |||
| self.repeated_weights_declaration = dict() | |||
| # Define Module Struct Build Status | |||
| self.build_struct_finished = False | |||
| def get_onnx_node_from_identifier(self, identifier): | |||
| """Return an OnnxUtils defined node by its identifier.""" | |||
| @@ -144,7 +147,7 @@ class GlobalContext(metaclass=Singleton): | |||
| @property | |||
| def onnx_tensors_collection(self): | |||
| """Return the onnx tensors collection.""" | |||
| return self.onnx_tensors_collection | |||
| return self._onnx_tensors_collection | |||
| @onnx_tensors_collection.setter | |||
| def onnx_tensors_collection(self, arg): | |||
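The one-character fix above matters because a property that returns its own name re-enters itself and recurses until a RecursionError; returning the underscore-prefixed backing attribute is the intended pattern. A simplified stand-in, not the real GlobalContext:

```python
class Broken:
    @property
    def onnx_tensors_collection(self):
        # looks up the property again -> infinite recursion (RecursionError)
        return self.onnx_tensors_collection

class Fixed:
    def __init__(self):
        self._onnx_tensors_collection = {}

    @property
    def onnx_tensors_collection(self):
        # returns the backing attribute instead
        return self._onnx_tensors_collection
```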
| @@ -17,6 +17,9 @@ import abc | |||
| import copy | |||
| from typing import Union, Iterable | |||
| from mindinsight.mindconverter.graph_based_converter.common.code_fragment import NewFragment | |||
| class BaseOutput: | |||
| """ | |||
| Define the class of output providing a universal nodes' and modules' output data collection. | |||
| @@ -36,6 +39,9 @@ class BaseOutput: | |||
| # The following attributes to be set by who referenced this object. | |||
| self.onnx_edge_name = None | |||
| self.to_external = False | |||
| self.opt_var_name = None | |||
# Only for a module output edge: its return name inside the module
| self.inner_ret_name = None | |||
| @property | |||
| def ms_user(self): | |||
| @@ -59,28 +65,41 @@ class BaseOutputManager(abc.ABC): | |||
| Args: | |||
| output_mappings (list): A list of output mapping. | |||
| """ | |||
| def __init__(self, output_mappings): | |||
| if isinstance(self.__class__, ModuleOutputManager): | |||
| def __init__(self, identifier, output_mappings: Iterable): | |||
| if isinstance(self, ModuleOutputManager): | |||
| return | |||
| self._base_output_list = list() | |||
| self._base_output_dict = dict() | |||
| self.identifier = identifier | |||
| # init base output obj | |||
| for mapping in output_mappings: | |||
| for (onnx_edge_name, mapping) in output_mappings: | |||
| obj = BaseOutput(mapping) | |||
| self._base_output_list.append(obj) | |||
| self._base_output_dict[onnx_edge_name] = obj | |||
| obj.onnx_edge_name = onnx_edge_name | |||
| @property | |||
| def outputs(self): | |||
| """Return the list of BaseOutput in this manager.""" | |||
| return self._base_output_list | |||
| return self._base_output_dict.values() | |||
| @property | |||
| def outputs_edges(self): | |||
| """Return the list of outputs edge names in this manager.""" | |||
| return self._base_output_dict.keys() | |||
| @outputs.setter | |||
| def outputs(self, val: list): | |||
| """Set the list of BaseOutput in this manager.""" | |||
| tmp = dict() | |||
| for v in val: | |||
| if not isinstance(v, BaseOutput): | |||
| raise TypeError(f"{self.__class__} does not accept the type {type(v)} in the list given.") | |||
| self._base_output_list = val | |||
| tmp[v.onnx_edge_name] = v | |||
| self._base_output_dict = tmp | |||
| def get_base_out(self, onnx_edge_name: str) -> BaseOutput: | |||
"""Return the BaseOutput by its ONNX edge name."""
| return self._base_output_dict.get(onnx_edge_name) | |||
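A rough sketch of the edge-keyed storage this manager now uses; the stand-in class below replaces BaseOutput (whose real constructor takes an output mapping) just to show the lookup pattern:

```python
class FakeOut:
    """Stand-in for BaseOutput; only the attributes needed for this sketch."""
    def __init__(self, edge):
        self.onnx_edge_name = edge
        self.opt_var_name = None

base_output_dict = {out.onnx_edge_name: out
                    for out in (FakeOut("conv_0_out"), FakeOut("conv_0_indices"))}
# get_base_out analogue: fetch an output by its ONNX edge name.
assert base_output_dict.get("conv_0_out").onnx_edge_name == "conv_0_out"
```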
| @abc.abstractmethod | |||
| def deepcopy(self): | |||
| @@ -88,7 +107,7 @@ class BaseOutputManager(abc.ABC): | |||
| cls = self.__class__ | |||
| result = cls.__new__(cls) | |||
| result.outputs = list() | |||
| for out in self._base_output_list: | |||
| for out in self._base_output_dict.values(): | |||
| result.outputs.append(out.deepcopy()) | |||
| return result | |||
| @@ -102,14 +121,19 @@ class NodeOutputManager(BaseOutputManager): | |||
| output_mappings (list): A list of the output mapping. | |||
| """ | |||
| def __init__(self, identifier, output_mappings=None) -> None: | |||
| super(NodeOutputManager, self).__init__(output_mappings) | |||
| self.identifier = identifier | |||
| super(NodeOutputManager, self).__init__(identifier, output_mappings) | |||
| def deepcopy(self): | |||
| """Self defined deepcopy method.""" | |||
| new_mgr = super().deepcopy() | |||
| new_mgr.identifier = self.identifier | |||
| return new_mgr | |||
| def bind_opt_var_names(self, fragment: NewFragment): | |||
"""Bind each output's opt_var_name from the fragment's return statement."""
| for base_out in self._base_output_dict.values(): | |||
| base_out.opt_var_name = fragment.get_outputs_by_idx(base_out.idx_in_ms_provider) | |||
| class ModuleOutputManager(BaseOutputManager): | |||
| """ | |||
| @@ -120,14 +144,13 @@ class ModuleOutputManager(BaseOutputManager): | |||
| output_mappings (list): a list of output mapping | |||
| """ | |||
| def __init__(self, identifier, base_out: Union[BaseOutput, Iterable[BaseOutput]]) -> None: | |||
| super(ModuleOutputManager, self).__init__(None) | |||
| self.identifier = identifier | |||
| super(ModuleOutputManager, self).__init__(identifier, None) | |||
| self._return_list_counter = 0 | |||
| self._base_output_list = list() | |||
| self._base_output_dict = dict() | |||
| if isinstance(base_out, BaseOutput): | |||
| self._base_output_list.append(base_out) | |||
| self.outputs = [base_out] | |||
| else: | |||
| self._base_output_list += base_out | |||
| self.outputs = base_out | |||
| @property | |||
| def return_num(self): | |||
| @@ -139,6 +162,12 @@ class ModuleOutputManager(BaseOutputManager): | |||
| """Set the number of outputs to be returned.""" | |||
| self._return_list_counter = num | |||
| def assign_opt_var_name_to_each_output(self, opt_var_name_base: str): | |||
| """Assign opt_var_name for each output.""" | |||
| for idx, base_out in enumerate(self._base_output_dict.values()): | |||
| postfix = str(idx) if idx > 0 else "" | |||
| base_out.opt_var_name = '_'.join([opt_var_name_base, postfix]) if idx > 0 else opt_var_name_base | |||
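The naming this produces, shown with an illustrative base name: the first output keeps the base name and later outputs get numeric suffixes.

```python
opt_var_name_base = "module_3_opt"   # illustrative base name
names = [opt_var_name_base if idx == 0 else f"{opt_var_name_base}_{idx}"
         for idx in range(3)]
print(names)   # ['module_3_opt', 'module_3_opt_1', 'module_3_opt_2']
```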
| def deepcopy(self): | |||
| """Return a deepcopy of current instance.""" | |||
| new_mgr = super().deepcopy() | |||
| @@ -146,6 +175,30 @@ class ModuleOutputManager(BaseOutputManager): | |||
| new_mgr.return_num = self._return_list_counter | |||
| return new_mgr | |||
| def bind_module_outputs_internal_name(self, outputs_register: dict): | |||
| """ | |||
Bind each output's internal (in-module) return name.
Args:
outputs_register (dict): The module outputs register, filled by submodules and nodes.
| """ | |||
| for base_out in self._base_output_dict.values(): | |||
| # bind the edge name inside module | |||
| base_out.inner_ret_name = outputs_register.get(base_out.onnx_edge_name) | |||
| def bind_opt_var_name(self, opt_var_names: list): | |||
| """ | |||
| Assign the opt_var_name for outputs of this module. | |||
| Args: | |||
| opt_var_names (list): A list of opt_var_name of this module, generated by module itself. | |||
| """ | |||
if len(opt_var_names) != len(self._base_output_dict):
raise ValueError(f"Unable to bind opt_var_name: Module {self.identifier}"
f" has an inconsistent number of outputs.")
| for idx, base_out in enumerate(self._base_output_dict.values()): | |||
| base_out.opt_var_name = opt_var_names[idx] | |||
| class OutputStorage: | |||
| """A class saves all outputs.""" | |||
| @@ -18,10 +18,10 @@ __all__ = ["batch_add_nodes"] | |||
| import re | |||
| import copy | |||
| from mindinsight.mindconverter.graph_based_converter.generator.generator import Generator, CodeStruct | |||
| from mindinsight.mindconverter.graph_based_converter.constant import ExchangeMessageKeywords | |||
| from mindinsight.mindconverter.graph_based_converter.common.code_fragment import NewFragment | |||
| from mindinsight.mindconverter.graph_based_converter.common.outputs import NodeOutputManager | |||
| from mindinsight.mindconverter.graph_based_converter.constant import ExchangeMessageKeywords | |||
| from mindinsight.mindconverter.graph_based_converter.generator.generator import Generator | |||
| def _tf_model_node_name_reformat(node, node_name): | |||
| @@ -123,6 +123,7 @@ def _convert_params(node, mapper, external_inputs): | |||
| params=params, | |||
| weights=node.weight) | |||
| exchange_msg[ExchangeMessageKeywords.METADATA.value] = _supply_graph_info(node, external_inputs) | |||
| outputs_order_mapping = _bind_outputs_edges(exchange_msg=exchange_msg, outputs_order_mapping=outputs_order_mapping) | |||
| return code_template, exchange_msg, outputs_lst, outputs_order_mapping | |||
| @@ -145,3 +146,22 @@ def _combine_external_inputs_with_precursor_nodes(node, external_inputs): | |||
| node_idx = node.ir_node_inputs.index(item) | |||
| precursor.insert(node_idx, item) | |||
| return precursor | |||
| def _bind_outputs_edges(exchange_msg, outputs_order_mapping): | |||
| """ | |||
Bind the output edge names to the outputs order mapping.
| Args: | |||
| exchange_msg (dict): The dict of exchange messages of this node. | |||
| outputs_order_mapping (tuple): The outputs mapping of this node. | |||
Returns:
zip, a zip object pairing each output edge with its mapping entry.
| """ | |||
| outputs_edges = exchange_msg.get('metadata').get('outputs') | |||
| if not outputs_edges: | |||
| raise ValueError(f"ONNX Node {exchange_msg.get('metadata').get('source')} has no outputs info.") | |||
| if len(outputs_edges) != len(outputs_order_mapping): | |||
raise ValueError(f"ONNX Node {exchange_msg.get('metadata').get('source')} has an inconsistent "
f"number of output edges and mappings.")
| return zip(outputs_edges, outputs_order_mapping) | |||
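The returned zip pairs each ONNX output edge with its mapping entry, which `BaseOutputManager.__init__` above unpacks as `(onnx_edge_name, mapping)`. Roughly, with made-up values:

```python
outputs_edges = ["conv_0_out", "conv_0_indices"]      # from exchange_msg metadata
outputs_order_mapping = ((0, 0), (1, 1))              # made-up mapping entries
for onnx_edge_name, mapping in zip(outputs_edges, outputs_order_mapping):
    print(onnx_edge_name, mapping)
# conv_0_out (0, 0)
# conv_0_indices (1, 1)
```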
| @@ -17,6 +17,7 @@ import copy | |||
| from collections import OrderedDict | |||
| from importlib import import_module | |||
| import numpy as np | |||
| from yapf.yapflib.yapf_api import FormatCode | |||
| from mindinsight.mindconverter.common.exceptions import GeneratorError | |||
| @@ -32,6 +33,8 @@ from mindinsight.mindconverter.graph_based_converter.constant import NEW_LINE, S | |||
| FIRST_LEVEL_INDENT, get_imported_module, SEPARATOR_BTW_NAME_AND_ID, WeightType, LINK_IN_WEIGHT_NAME | |||
| from mindinsight.mindconverter.graph_based_converter.report_generator import ReportGenerator | |||
| from mindinsight.mindconverter.graph_based_converter.common.utils import replace_string_in_list | |||
| from mindinsight.mindconverter.graph_based_converter.generator.matcher import MatcherLauncher | |||
| from mindinsight.mindconverter.graph_based_converter.generator.shared_weights import SharedWeightHelper | |||
| class CodeStruct: | |||
| @@ -90,6 +93,10 @@ class CodeStruct: | |||
| for formal in md_struct.args_translator.formal_args.keys(): | |||
| module_def_args.append(formal) | |||
# add passthrough args for shared weights; not needed for the main model
| if md_struct.identifier != []: | |||
| module_def_args = SharedWeightHelper.add_shared_weights_in_init_statement(md_struct, module_def_args) | |||
| # For code line in init & construct blocks | |||
| init_lines = list() | |||
| cons_lines = list() | |||
| @@ -105,7 +112,7 @@ class CodeStruct: | |||
| init_lines += init_str | |||
| cons_lines += cons_str | |||
| else: # is ModuleStruct | |||
| else: # is ModuleStruct | |||
| # check if this instance generated CodeStruct | |||
| if GlobalContext().code_structs.get(struct.pattern_id) is None: | |||
| CodeStruct(struct, repeated_submodules) | |||
| @@ -118,6 +125,13 @@ class CodeStruct: | |||
| # define header of init block | |||
| self.new_line = f"{FIRST_LEVEL_INDENT}def __init__({', '.join(module_def_args)}):" | |||
| self.new_line = f"{SECOND_LEVEL_INDENT}super({class_name}, self).__init__()" | |||
# add shared weights declaration to the init code block
| if md_struct.identifier == []: | |||
| passthrough_w_declaration = SharedWeightHelper.public_module_shared_weight_statement_generation(md_struct) | |||
| for s in passthrough_w_declaration: | |||
| self.new_line = f"{SECOND_LEVEL_INDENT}{s}" | |||
| # add init code lines to code line list. | |||
| self.code_line_list += init_lines | |||
| self.new_line = f"{NEW_LINE * 2}" | |||
| @@ -129,16 +143,14 @@ class CodeStruct: | |||
| self.code_line_list += cons_lines | |||
| # define returns | |||
| returns = [] | |||
| if md_struct.external_successor_local_returns_map: | |||
| for r in list(md_struct.external_successor_local_returns_map.values()): | |||
| if isinstance(r, tuple): # results return with index nth output | |||
| returns.append(r[0]) | |||
| else: | |||
| returns.append(r) | |||
| returns = list(set(returns)) | |||
| else: | |||
| returns = [code_line_construct[0]] if isinstance(code_line_construct, tuple) \ | |||
| else [code_line_construct[-1].replace(' ', '').split('=')[0]] | |||
# collect each output edge's opt_var_name into the return list
| for output_edge in md_struct.outputs_register.keys(): | |||
| opt_var_name = md_struct.internal_outputs_collection.get(output_edge) | |||
| if opt_var_name is None: | |||
raise ValueError(f"Module {md_struct.identifier} has an output {output_edge} with an unknown opt_var_name.")
| returns.append(opt_var_name) | |||
| self.new_line = f"{SECOND_LEVEL_INDENT}return {', '.join(returns)}" | |||
| self.new_line = f"{NEW_LINE * 2}" | |||
| GlobalContext().code_structs[md_struct.pattern_id] = self | |||
| @@ -244,6 +256,83 @@ class Generator: | |||
| return formal_args | |||
| @staticmethod | |||
| def _set_translated_args_for_unshared_weights(t_param_postfix, nd_struct_list): | |||
| """Set the weight with given param postfix to args translation.""" | |||
| for _, nd_struct in nd_struct_list: | |||
| nparr = nd_struct.fragment.default_var["trainable_params"].get(t_param_postfix).get('data') | |||
| nd_struct.fragment.default_var["args"][f"{t_param_postfix}_shape"] = nparr.shape | |||
| nd_struct.fragment.default_var["args"][f"{t_param_postfix}_dtype"] = nparr.dtype | |||
| init_tensor_template = f"Parameter(Tensor(np.random.uniform(0, 1, "\ | |||
| f"{{{t_param_postfix}_shape}}).astype(np.{{{t_param_postfix}_dtype}})), "\ | |||
| f"name=None)" | |||
| nd_struct.fragment.default_var["parameters"][t_param_postfix] = init_tensor_template | |||
| def _get_same_trainable_params_onnx_name_from_repeated_nodes(self, | |||
| t_param_postfix, | |||
| t_param_data_dict, | |||
| nd_struct_list: list): | |||
| """Return all onnx names from the same weights in repeated nodes.""" | |||
| (_, base_nd_struct) = nd_struct_list[0] | |||
| t_base_name = t_param_data_dict.get('onnx_name') | |||
| t_onnx_names = [t_base_name] | |||
| for (_, nd_struct) in nd_struct_list[1:]: | |||
| compared_t_param_data_dict = nd_struct.fragment.default_var["trainable_params"].get(t_param_postfix) | |||
| if not compared_t_param_data_dict: | |||
| raise ValueError(f"Inconsistent trainable params detected for node "\ | |||
| f"{nd_struct.topo_idx} with base node {base_nd_struct.topo_idx}") | |||
| compared_t_name = compared_t_param_data_dict.get('onnx_name') | |||
| t_onnx_names.append(compared_t_name) | |||
| return t_onnx_names | |||
| def _partial_shared_weights_in_repeated_submodule_procs(self, nd_struct_list): | |||
| """ | |||
Check each node in repeated submodules to determine whether its weights are fully or partially shared.
Args:
nd_struct_list (list): A list of node structs which are the same node across repeated modules.
| """ | |||
# Nodes that are not repeated skip this processing
| if len(nd_struct_list) < 2: | |||
| return | |||
| (_, base_nd_struct) = nd_struct_list[0] | |||
| shared_w_list = self._global_context.repeated_weights.keys() | |||
| if not shared_w_list: | |||
| if base_nd_struct.fragment.default_var.get("parameters"): | |||
# set only if the node has parameters, since that requires rewriting.
| for (t_param_postfix, t_param_data_dict) in \ | |||
| base_nd_struct.fragment.default_var["trainable_params"].items(): | |||
| if not isinstance(t_param_data_dict.get('data'), np.ndarray): | |||
| continue | |||
| Generator._set_translated_args_for_unshared_weights(t_param_postfix, nd_struct_list) | |||
| return | |||
| for (t_param_postfix, t_param_data_dict) in base_nd_struct.fragment.default_var["trainable_params"].items(): | |||
# check whether each weight is partially or fully shared
| if not t_param_data_dict: | |||
| continue | |||
| t_onnx_names = self._get_same_trainable_params_onnx_name_from_repeated_nodes(t_param_postfix, | |||
| t_param_data_dict, | |||
| nd_struct_list) | |||
| t_shared_status = [name in shared_w_list for name in t_onnx_names] | |||
| if True in t_shared_status and False in t_shared_status: | |||
# partially shared: mark the unshared copies as shared in GlobalContext
| for idx, (name, status) in enumerate(zip(t_onnx_names, t_shared_status)): | |||
| if status: | |||
| # actual shared, do nothing, skip | |||
| continue | |||
| node_onnx_name = nd_struct_list[idx][1].onnx_name | |||
| if not self._global_context.repeated_weights.get(name): | |||
| self._global_context.repeated_weights[name] = [node_onnx_name] | |||
| else: | |||
| self._global_context.repeated_weights[name] += [node_onnx_name] | |||
| if True not in t_shared_status and base_nd_struct.fragment.default_var.get("parameters"): | |||
# if the repeated node has no shared weight and the mapper accepts parameter rewriting.
| if not isinstance(t_param_data_dict.get('data'), np.ndarray): | |||
| continue | |||
| Generator._set_translated_args_for_unshared_weights(t_param_postfix, nd_struct_list) | |||
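The core of the partial-sharing check above, reduced to plain data with made-up weight names: a weight slot repeated across node copies is "partially shared" when only some copies refer to a weight known to be shared.

```python
shared_w_list = {"fc.weight"}                        # weights recorded as shared
t_onnx_names = ["fc.weight", "fc_copy.weight"]       # same slot across repeated nodes
t_shared_status = [name in shared_w_list for name in t_onnx_names]

if True in t_shared_status and False in t_shared_status:
    print("partially shared: also register the unshared copies in GlobalContext")
elif True not in t_shared_status:
    print("not shared: rewrite the declaration with per-node shape/dtype args")
```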
| def _list_formal_parameters_in_a_module(self, module_filter_return): | |||
| """ | |||
| Find all formal args / params from nodes in a module. | |||
| @@ -256,6 +345,9 @@ class Generator: | |||
| """ | |||
| formal_params_list = list() | |||
| transposed = [list(e) for e in zip(*module_filter_return)] | |||
| for operation in transposed: | |||
# use the module-map filtered result for partial shared weight processing
| self._partial_shared_weights_in_repeated_submodule_procs(operation) | |||
| for operation in transposed: | |||
| formal_parameters = self._compare_with_base_parameters(operation) | |||
| if formal_parameters: | |||
| @@ -363,19 +455,34 @@ class Generator: | |||
| md_collection_len = new_len | |||
| else: | |||
| len_changes = False | |||
| GlobalContext().build_struct_finished = True | |||
| # 5. Update all translated args from module map | |||
| self._update_all_modules_args_translator() | |||
# 6. Update all nodes' and modules' inputs/outputs
| # Enable build_output_connections later. | |||
| self.build_outputs_connection() | |||
| self.module_structs.get('[]').allocate_construct_header_x() | |||
| self.module_structs.get('[]').collect_returns() | |||
| matcher = MatcherLauncher(self.module_structs.get('[]')) | |||
| matcher.matching_process() | |||
| for nd_struct in self.node_structs.values(): | |||
| if nd_struct.fragment.metadata.get("operation") == "Split": | |||
| self._split_op_procs(nd_struct) | |||
| def _shared_weights_processing(self): | |||
| """Process shared weights.""" | |||
# check whether each node has shared weights
| for nd_struct in self.node_structs.values(): | |||
| shared_weights = SharedWeightHelper.check_node_has_shared_weight(nd_struct) | |||
| if shared_weights: | |||
| # register each shared weight to public module | |||
| for shared_w in shared_weights: | |||
| SharedWeightHelper.register_shared_weight_to_public_parent(nd_struct, | |||
| shared_w, | |||
| pub_module_identifier=[]) | |||
| def _update_all_modules_args_translator(self): | |||
| """Update all modules' args translators.""" | |||
| done_submodule = set() | |||
| @@ -426,7 +533,7 @@ class Generator: | |||
| else: | |||
| raise TypeError("Unable to update global depth due to TypeError in NodeStruct.scope.depth") | |||
| def add_node(self, node_identifier, node_instance=None, node_fragment=None, mapper_dict=None): | |||
| def add_node(self, node_identifier, node_instance=None, node_fragment=None): | |||
| """ | |||
| Add Node information to the generator. | |||
| @@ -434,7 +541,6 @@ class Generator: | |||
| node_identifier (str): The unique identifier for the node passed in. | |||
| node_instance (GraphNode): The GraphNode instance of each node. | |||
| node_fragment (NodeFragment): The NodeFragment instance of this node passed in. | |||
| mapper_dict (dict): The dict contains converted attributes from mapper. | |||
| """ | |||
| if node_identifier is None: | |||
| @@ -443,8 +549,6 @@ class Generator: | |||
| args = [] | |||
| if node_instance is not None: | |||
| args.append(node_instance) | |||
| if mapper_dict is not None: | |||
| args.append(mapper_dict) | |||
| if node_fragment is not None: | |||
| args.append(node_fragment) | |||
| @@ -551,6 +655,7 @@ class Generator: | |||
| """ | |||
| self._form_bottom_submodule() | |||
| self._recursive_form_module() | |||
| self._shared_weights_processing() | |||
| ckpt_data_list, weight_map = self.generate_checkpoint() | |||
| @@ -654,14 +759,13 @@ class Generator: | |||
| # for each output in curr node output manager | |||
| for out in nd_struct.outputs_manager.outputs: | |||
| # Set the onnx output edge name to this output | |||
| out.onnx_edge_name = nd_struct.fragment.metadata.get('outputs')[out.idx_in_onnx_provider] | |||
| self._global_context.outputs_storage.add_output(out) | |||
| self._global_context.outputs_storage.add_onnx_node_name(out.onnx_edge_name, | |||
| nd_struct.fragment.metadata.get('source')) | |||
| self._global_context.outputs_storage.add_ms_identifier(out.onnx_edge_name, nd_struct.identifier) | |||
| # Set input with existing output mapping | |||
| for idx, inp in enumerate(nd_struct.fragment.metadata.get('inputs')): | |||
| for idx, inp in enumerate(nd_struct.inputs_edges_names): | |||
| if inp in self._global_context.outputs_storage.outputs_collections: | |||
| output_obj = self._global_context.outputs_storage.outputs_collections[inp] | |||
| output_obj.idx_in_onnx_user[nd_struct.onnx_name] = idx | |||
| @@ -694,8 +798,10 @@ class Generator: | |||
| self._collect_output_mgr(module=struct) | |||
| for out in struct.outputs_manager.outputs: | |||
| if Generator.check_output_need_to_external(root_module, out): | |||
| output_mgr_list.append(out) | |||
| root_module.outputs_manager = ModuleOutputManager(root_module.identifier, base_out=output_mgr_list) | |||
| output_mgr_list.append(out.deepcopy()) | |||
| root_module.outputs_manager = ModuleOutputManager(root_module.identifier, | |||
| base_out=output_mgr_list) | |||
| root_module.outputs_manager.assign_opt_var_name_to_each_output(root_module.ms_opt_var_name) | |||
| @staticmethod | |||
| def check_output_need_to_external(root_module: ModuleStruct, checked_output: BaseOutput): | |||
| @@ -0,0 +1,172 @@ | |||
| # Copyright 2021 Huawei Technologies Co., Ltd.All Rights Reserved. | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================== | |||
"""Match the inputs and outputs of nodes and modules."""
| from mindinsight.mindconverter.graph_based_converter.generator.node_struct import NodeStruct | |||
| from mindinsight.mindconverter.graph_based_converter.generator.module_struct import ModuleStruct | |||
| from mindinsight.mindconverter.graph_based_converter.common.global_context import GlobalContext | |||
| class MatcherHelper: | |||
| """ | |||
Helper functions for the matching process.
| """ | |||
| @staticmethod | |||
| def main_model_special_process_inputs(main_model: ModuleStruct): | |||
"""Allocate the main model's construct header inputs; called during preprocessing."""
| # allocate main model construct x | |||
| prec_edges = main_model.external_precursor_nodes_names | |||
| default_x_str = "x" | |||
| inputs = dict() | |||
| for idx, edge in enumerate(prec_edges): | |||
if edge not in inputs:
# use idx-1 so the names run x, x0, x1, ...: the unindexed x is the first input and x0 is the second.
inputs[edge] = "".join([default_x_str, str(idx-1)]) if idx > 0 else default_x_str
| main_model.inputs_register = inputs | |||
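The resulting construct header naming, with illustrative graph-input edge names: the first input is plain `x`, later ones are `x0`, `x1`, and so on.

```python
prec_edges = ["input_ids", "attention_mask", "token_type_ids"]   # illustrative edges
inputs = {}
for idx, edge in enumerate(prec_edges):
    if edge not in inputs:
        inputs[edge] = "x" if idx == 0 else f"x{idx - 1}"
print(inputs)   # {'input_ids': 'x', 'attention_mask': 'x0', 'token_type_ids': 'x1'}
```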
| @staticmethod | |||
| def get_public_parent_module(node_a: NodeStruct, node_b: NodeStruct): | |||
| """Return the public parent module of both Node A and Node B.""" | |||
| find = False | |||
| b_onnx_name = node_b.onnx_name | |||
| tmp = node_a | |||
| while not find: | |||
| parent_struct = tmp.parent_module_struct | |||
| if b_onnx_name in parent_struct.onnx_names: | |||
| find = True | |||
| tmp = parent_struct | |||
| return tmp | |||
| @staticmethod | |||
| def get_submodule_has_out_user_under_public_parent(public_module: ModuleStruct, node_out_user: NodeStruct): | |||
"""Return the ModuleStruct under the public module that contains the provided NodeStruct."""
| for module_struct in public_module.module_structs: | |||
| if node_out_user.onnx_name in module_struct.onnx_names: | |||
| return module_struct | |||
| return None | |||
| @staticmethod | |||
| def register_outputs_to_main_model(output_edge_name: str, output_edge_provider: NodeStruct): | |||
| """ | |||
Register the output edge up through all parent modules to the main model.
| Args: | |||
| output_edge_name (str): The name of this output edge. | |||
| output_edge_provider (NodeStruct): The node which produces this output. | |||
| """ | |||
| base_out = output_edge_provider.outputs_manager.get_base_out(output_edge_name) | |||
| nd_parent = output_edge_provider.parent_module_struct | |||
| while nd_parent: | |||
| nd_parent.add_outputs_edge(base_out.onnx_edge_name) | |||
| nd_parent = nd_parent.parent_module_struct | |||
| @staticmethod | |||
| def register_inputs_to_main_model(input_edge_name: str, input_edge_user: NodeStruct): | |||
| """ | |||
Register the input edge up through all parent modules to the main model.
| Args: | |||
| input_edge_name (str): The name of this input edge. | |||
input_edge_user (NodeStruct): The node that uses this input.
| """ | |||
| nd_parent = input_edge_user.parent_module_struct | |||
| while nd_parent: | |||
| nd_parent.add_inputs_edge(input_edge_name) | |||
| nd_parent = nd_parent.parent_module_struct | |||
| class MatcherLauncher: | |||
| """Process Node-to-Node inputs outputs matching.""" | |||
| def __init__(self, main_model: ModuleStruct): | |||
super(MatcherLauncher, self).__init__()
| self.main_model = main_model | |||
| self._global_context = GlobalContext() | |||
| self._graph_inputs = self._global_context.onnx_graph_info.get("graph_inputs") | |||
| self._graph_outputs = self._global_context.onnx_graph_info.get("graph_outputs") | |||
| def matching_process(self): | |||
| """The matching process.""" | |||
| # 0. Pre-process | |||
| MatcherHelper.main_model_special_process_inputs(self.main_model) | |||
# 1. Register module inputs to each module's construct header
| self._register_module_inputs_x_header() | |||
| # 2. Set module returns | |||
| self._register_module_returns() | |||
| def _register_module_inputs_x_header(self): | |||
"""Recursively register node inputs up to each module's construct header."""
| # Use nearest parent module algorithm | |||
| for nd_struct in self._global_context.node_struct_collections.values(): | |||
| if not nd_struct.precursor_nodes_names_external: | |||
# no external precursor nodes, but still check whether any input is a graph-level input
| has_graph_input = False | |||
| for edge in nd_struct.inputs_edges_names: | |||
| if edge in self._global_context.onnx_graph_info.get('graph_inputs'): | |||
| has_graph_input = True | |||
| break | |||
| if not has_graph_input: | |||
| continue # avoid unnecessary checking | |||
| for inp in nd_struct.inputs_edges_names: | |||
| if inp in self._global_context.onnx_graph_info.get('graph_inputs'): | |||
| # when the input edge is from graph level. | |||
| MatcherHelper.register_inputs_to_main_model(inp, nd_struct) | |||
| continue | |||
| out_provider_onnx_name = self._global_context.outputs_storage.onnx_name(inp) | |||
| out_provider_struct = \ | |||
| self._global_context.onnx_node_name_to_node_struct_map.get(out_provider_onnx_name) | |||
| if out_provider_struct is None: | |||
raise ValueError(f"The matcher detected an unknown provider for the input edge {inp}.")
| public_parent = MatcherHelper.get_public_parent_module(nd_struct, out_provider_struct) | |||
| nd_parent = nd_struct.parent_module_struct | |||
| # Recursively register x in all parents until the public module | |||
| while public_parent.identifier != nd_parent.identifier: | |||
| nd_parent.add_inputs_edge(inp) | |||
| nd_parent = nd_parent.parent_module_struct | |||
| def _register_module_returns(self): | |||
| """Recursively register the node outputs to parent modules.""" | |||
| # Use nearest parent module algorithm | |||
| for nd_struct in self._global_context.node_struct_collections.values(): | |||
| if not nd_struct.successor_nodes_names_external: | |||
| # check if any edge to graph output | |||
| has_graph_output = False | |||
| for edge in nd_struct.fragment.metadata.get('outputs'): | |||
| if edge in self._global_context.onnx_graph_info.get('graph_outputs'): | |||
| has_graph_output = True | |||
| break | |||
| if not has_graph_output: | |||
| continue # avoid unnecessary checking | |||
| for base_out in nd_struct.outputs_manager.outputs: | |||
| if base_out.onnx_edge_name in self._global_context.onnx_graph_info.get('graph_outputs'): | |||
| MatcherHelper.register_outputs_to_main_model(base_out.onnx_edge_name, nd_struct) | |||
| continue | |||
| out_user_onnx_names = base_out.onnx_user | |||
| for out_user_onnx_name in out_user_onnx_names: | |||
| out_user_struct = \ | |||
| self._global_context.onnx_node_name_to_node_struct_map.get(out_user_onnx_name) | |||
| if out_user_struct is None: | |||
raise ValueError(f"The matcher detected an unknown user for the output edge "
f"{base_out.onnx_edge_name}.")
| public_parent = MatcherHelper.get_public_parent_module(nd_struct, out_user_struct) | |||
| nd_parent = nd_struct.parent_module_struct | |||
| # Recursively register outputs to parents until the public module | |||
| while public_parent.identifier != nd_parent.identifier: | |||
| nd_parent.add_outputs_edge(base_out.onnx_edge_name) | |||
| nd_parent = nd_parent.parent_module_struct | |||
| @@ -34,7 +34,6 @@ class ModuleStruct: | |||
| init_as_parent (bool): Control init method if the ModuleStruct be init as a parent module struct. | |||
| parent_base (ModuleStruct): The base ModuleStruct the current ModuleStruct to be init as. | |||
| """ | |||
| def __init__(self, nd_struct_list, init_as_parent=False, parent_base=None): | |||
| """Init. a module by NodeStructs.""" | |||
| self.pattern_id = -1 # pattern num, -1 as Main module | |||
| @@ -74,6 +73,20 @@ class ModuleStruct: | |||
| # Define outputs manager, note this will be assigned later by Generator. | |||
| self.outputs_manager = None | |||
| self._global_context = GlobalContext() | |||
# Define a dict to cache references for quick lookup
| self.rapid_reference = dict() | |||
| # new vars for matcher | |||
self.inputs_register = dict()  # registered by submodules
self.outputs_register = OrderedDict()  # registered by submodules
self.internal_outputs_collection = dict()  # registered by submodules
# new vars for shared weights
self.shared_weights_collection = dict()  # registered by submodules
self.shared_weights_counter = 0  # updated by submodules
| if init_as_parent and (parent_base is not None): | |||
| self.reset_as_parent_passed_in(parent_base) | |||
| else: | |||
| @@ -293,8 +306,26 @@ class ModuleStruct: | |||
| ret.sort(key=lambda x: x[0]) | |||
| return ret | |||
| def _code_line_init_statement_shared_weights_args(self): | |||
"""Generate the shared-weight args used when calling this module."""
| args_list = list() | |||
| for passthrough_w_onnx_name, passthrough_w_var_name in self.shared_weights_collection.items(): | |||
| passthrough_w_var_name_in_parent = \ | |||
| self.parent_module_struct.shared_weights_collection.get(passthrough_w_onnx_name) | |||
if self.parent_module_struct.identifier == []:  # for now, only consider declaration in the main model
| args_list.append(f"{passthrough_w_var_name}=self.{passthrough_w_var_name_in_parent}") | |||
| else: | |||
| args_list.append(f"{passthrough_w_var_name}={passthrough_w_var_name_in_parent}") | |||
| return args_list | |||
| def _code_line_init_generate_shared_w_declaration_for_repeated(self): | |||
"""Regenerate sub-node init code lines so that shared weight declarations are collected for the main model."""
| for _, nd_struct in self._node_structs: | |||
| nd_struct.code_line_in_init() | |||
| def code_line_in_init(self): | |||
| """Initialization line of code in module init block.""" | |||
| self._code_line_init_generate_shared_w_declaration_for_repeated() | |||
| left = "self.{}".format(self.ms_var_name) | |||
| args_list = list() | |||
| # Load args in init statement. | |||
| @@ -308,26 +339,40 @@ class ModuleStruct: | |||
| args_list += self._args_translator.formal_args_to_str_list # load from formal args | |||
| else: | |||
| args_list += self._fragment.actual_args | |||
| args_list += self._code_line_init_statement_shared_weights_args() | |||
| right = f"{self.class_name}({', '.join(args_list)})" | |||
| return left, right | |||
| def code_line_in_construct(self, inputs=None): | |||
| """Construct line of code in module construct block.""" | |||
| # check number of outputs this module has | |||
| opt_var_name_in_module = list(self.external_successor_local_returns_map.values()) | |||
| num_output = len(set(opt_var_name_in_module)) | |||
| outputs_edges = list(self.outputs_register.keys()) | |||
| num_output = len(outputs_edges) | |||
| # Allocate opt_var_name | |||
| if num_output == 1: # single output | |||
| left = f"{self.ms_opt_var_name}" | |||
| left = [f"{self.ms_opt_var_name}"] | |||
| else: | |||
| left = [f"{self.ms_opt_var_name}_{num}" for num in range(num_output)] | |||
| if inputs is None and self.matched_inputs: | |||
| inputs = self.matched_inputs | |||
| inputs = [] | |||
| # Update self's outputs mgr | |||
| for idx, edge in enumerate(outputs_edges): | |||
| base_out = self.outputs_manager.get_base_out(edge) | |||
| if base_out.opt_var_name is None: | |||
print(f"ModuleStruct {self.identifier} has an output {base_out.onnx_edge_name} with no opt_var_name")
| base_out.opt_var_name = left[idx] | |||
| self.parent_module_struct.internal_outputs_collection[base_out.onnx_edge_name] = base_out.opt_var_name | |||
# Take inputs from the parent module and previously generated outputs
| for input_edge in self.inputs_register: | |||
| if input_edge in self.parent_module_struct.inputs_register: | |||
| inputs.append(self.parent_module_struct.inputs_register.get(input_edge)) | |||
| elif input_edge in self.parent_module_struct.internal_outputs_collection: | |||
| inputs.append(self.parent_module_struct.internal_outputs_collection.get(input_edge)) | |||
| if isinstance(inputs, str): | |||
| inputs = [inputs] | |||
| right = f"self.{self.ms_var_name}({', '.join(inputs)})" | |||
| return left, right | |||
| left = ", ".join(left) | |||
| return (left, right) | |||
| @property | |||
| def node_structs(self): | |||
| @@ -377,23 +422,36 @@ class ModuleStruct: | |||
| @property | |||
| def onnx_names_from_nodes(self) -> list: | |||
| """Return all nodes onnx names in this module.""" | |||
| ret = [] | |||
| for (_, node) in self.node_structs: | |||
| ret.append(node.onnx_name) | |||
| if self._global_context.build_struct_finished and "_onnx_names_from_nodes" in self.rapid_reference: | |||
| return self.rapid_reference["_onnx_names_from_nodes"] | |||
| ret = [node.onnx_name for (_, node) in self.node_structs] | |||
| if self._global_context.build_struct_finished: | |||
| self.rapid_reference["_onnx_names_from_nodes"] = ret | |||
| return ret | |||
| @property | |||
| def onnx_names_from_submodules(self) -> list: | |||
| """Return all nodes onnx names in submodules of this module.""" | |||
| if self._global_context.build_struct_finished and "_onnx_names_from_submodules" in self.rapid_reference: | |||
| return self.rapid_reference["_onnx_names_from_submodules"] | |||
| ret = [] | |||
| for md_struct in self.module_structs: | |||
| ret += md_struct.onnx_names | |||
| if self._global_context.build_struct_finished: | |||
| self.rapid_reference["_onnx_names_from_submodules"] = ret | |||
| return ret | |||
| @property | |||
| def onnx_names(self) -> list: | |||
| """Return all nodes' onnx names which contained by this module.""" | |||
| return self.onnx_names_from_nodes + self.onnx_names_from_submodules | |||
| if self._global_context.build_struct_finished and "_onnx_names" in self.rapid_reference: | |||
| return self.rapid_reference["_onnx_names"] | |||
| ret = self.onnx_names_from_nodes + self.onnx_names_from_submodules | |||
| if self._global_context.build_struct_finished: | |||
| self.rapid_reference["_onnx_names"] = ret | |||
| return ret | |||
| @property | |||
| def external_precursor_nodes_names(self) -> list: | |||
| @@ -434,8 +492,8 @@ class ModuleStruct: | |||
| """Return the class name for generating code of this module.""" | |||
| if self.pattern_id == -1: | |||
| return "Model" | |||
| if GlobalContext().known_module_name.get("Module{}".format(self.pattern_id)) is not None: | |||
| class_name = GlobalContext().known_module_name.get("Module{}".format(self.pattern_id)) | |||
| if self._global_context.known_module_name.get("Module{}".format(self.pattern_id)) is not None: | |||
| class_name = self._global_context.known_module_name.get("Module{}".format(self.pattern_id)) | |||
| else: | |||
| class_name = "Module{}".format(self.pattern_id) | |||
| return class_name | |||
| @@ -676,3 +734,26 @@ class ModuleStruct: | |||
| submodule_opt_var_name = md_struct.ms_opt_var_name | |||
| for (submodule_ext_succ, _, ith_output) in submodule_returns: | |||
| self.external_successor_local_returns_map[submodule_ext_succ] = (submodule_opt_var_name, ith_output) | |||
| # The following funcs are designated to be invoked by matcher. | |||
def add_inputs_edge(self, edge_name: str):
"""Register an input edge and assign it a construct header name (x, x0, x1, ...)."""
construct_header_length = len(self.inputs_register.values())
default_x_str = "x"
if edge_name not in self.inputs_register:
self.inputs_register[edge_name] = "".join([default_x_str, str(construct_header_length-1)]) \
if construct_header_length > 0 else default_x_str
def add_outputs_edge(self, edge_name: str):
"""Register an output edge; its var name is filled later from the submodule's opt_var_name."""
if edge_name in self.outputs_register:
return
self.outputs_register[edge_name] = "<placeholder>"
def fill_outputs_edge(self, edge_name: str, opt_var_name: str):
"""Fill the output edge once the opt_var_name of the corresponding node is known."""
if edge_name not in self.outputs_register:
raise ValueError(f"ModuleStruct {self.identifier} does not have edge "
f"{edge_name}, so its output var name cannot be filled.")
if self.outputs_register[edge_name] != "<placeholder>":
raise ValueError(f"The edge has already been filled as {self.outputs_register[edge_name]}"
f" instead of {opt_var_name}.")
self.outputs_register[edge_name] = opt_var_name
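Rough lifecycle of an output edge in `outputs_register`, using a plain dict in place of a ModuleStruct: the matcher registers a placeholder, and code generation later fills in the producing node's opt_var_name.

```python
outputs_register = {}
outputs_register.setdefault("conv_0_out", "<placeholder>")   # add_outputs_edge
assert outputs_register["conv_0_out"] == "<placeholder>"
outputs_register["conv_0_out"] = "opt_conv_0"                # fill_outputs_edge
print(outputs_register)   # {'conv_0_out': 'opt_conv_0'}
```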
| @@ -35,7 +35,6 @@ class NodeStruct: | |||
| You can pass as many args as possible and the Node Struct will update | |||
| by arguments order. | |||
| """ | |||
| def __init__(self, args): | |||
| # define attributes here | |||
| self.global_context_mgr = GlobalContext() | |||
| @@ -43,6 +42,7 @@ class NodeStruct: | |||
| self._fragment = None | |||
| self._args_translator = None | |||
| self._parent_module_struct = None | |||
| self._global_context = GlobalContext() | |||
| self.topo_idx = None | |||
| self.onnx_name = None | |||
| self.graph_node_ref = None | |||
| @@ -75,7 +75,7 @@ class NodeStruct: | |||
| """Get the original topological index in the onnx graph.""" | |||
| ori_name = self._fragment.metadata.get('source') | |||
| self.onnx_name = ori_name | |||
| return GlobalContext().onnx_node_name_to_topo_idx.get(ori_name) | |||
| return self._global_context.onnx_node_name_to_topo_idx.get(ori_name) | |||
| def update_var_name(self, idx=None): | |||
| """ | |||
| @@ -114,7 +114,7 @@ class NodeStruct: | |||
| self._fragment = FragmentHandler(frag) | |||
| if self.ms_op: | |||
| idx = GlobalContext().latest_node_struct_count | |||
| idx = self._global_context.latest_node_struct_count | |||
| self.update_var_name(idx=idx) | |||
| def _set_scope_from_identifier(self): | |||
| @@ -168,7 +168,7 @@ class NodeStruct: | |||
| self._identifier = s | |||
| self._set_scope_from_identifier() | |||
| self.topo_idx = self.ori_topo_idx() | |||
| GlobalContext().onnx_node_name_to_node_struct_map[self.onnx_name] = self | |||
| self._global_context.onnx_node_name_to_node_struct_map[self.onnx_name] = self | |||
| @property | |||
| def fragment(self): | |||
| @@ -198,7 +198,7 @@ class NodeStruct: | |||
| @property | |||
| def onnx_node(self): | |||
| """Return the original onnx node reference.""" | |||
| return GlobalContext().onnx_nodes_collection.get(self.onnx_name) | |||
| return self._global_context.onnx_nodes_collection.get(self.onnx_name) | |||
| @property | |||
| def ms_op(self): | |||
| @@ -241,7 +241,7 @@ class NodeStruct: | |||
| ret = [] | |||
| precursor_nodes_names = self.precursor_nodes_names | |||
| for pre_node_name in precursor_nodes_names: | |||
| nd_struct = GlobalContext().onnx_node_name_to_node_struct_map.get(pre_node_name) | |||
| nd_struct = self._global_context.onnx_node_name_to_node_struct_map.get(pre_node_name) | |||
| ret.append(nd_struct) | |||
| return ret | |||
| @@ -255,7 +255,7 @@ class NodeStruct: | |||
| """Return the node struct instances of successor nodes.""" | |||
| ret = [] | |||
| for pre_node_name in self.successor_nodes_names: | |||
| nd_struct = GlobalContext().onnx_node_name_to_node_struct_map.get(pre_node_name) | |||
| nd_struct = self._global_context.onnx_node_name_to_node_struct_map.get(pre_node_name) | |||
| ret.append(nd_struct) | |||
| return ret | |||
| @@ -278,54 +278,110 @@ class NodeStruct: | |||
| """Return the outputs var(s) in construct statement.""" | |||
| return self.fragment.fragment.outputs() | |||
| @property | |||
| def inputs_edges_names(self): | |||
| """Return the inputs edges of this node.""" | |||
| # Consider moving this process to metadata. | |||
| ret = [] | |||
| for edge in self.fragment.metadata.get('inputs'): | |||
| if not self._global_context.get_onnx_tensor(edge): | |||
| ret.append(edge) | |||
| return ret | |||
| @property | |||
| def shared_weights(self): | |||
| """Return the shared weights in this node.""" | |||
| shared_weight_names = [] | |||
| for shared_weight_name, repeated_node_list in self._global_context.repeated_weights.items(): | |||
| if self.onnx_name in repeated_node_list: | |||
| shared_weight_names.append(shared_weight_name) | |||
| return shared_weight_names | |||
| # Code Generation funcs below | |||
| def _get_shared_weight_var_names_from_parent(self, onnx_name=None): | |||
| """ | |||
| Get shared weight var name in the parent module. | |||
| Args: | |||
| onnx_name (str): The onnx name of this weight. Default None. | |||
| Returns: | |||
Union[list, str], all shared weight var names of this node, or the var name for the given onnx_name.
| """ | |||
| if onnx_name is None: | |||
| shared_weights_var_name_in_module = [] | |||
| for shared_w in self.shared_weights: | |||
| for passthrough_w, passthrough_w_var_name in \ | |||
| self._parent_module_struct.shared_weights_collection.items(): | |||
| if shared_w == passthrough_w: | |||
| shared_weights_var_name_in_module.append(passthrough_w_var_name) | |||
| return shared_weights_var_name_in_module | |||
| if isinstance(onnx_name, str): | |||
| return self._parent_module_struct.shared_weights_collection.get(onnx_name) | |||
| return [] | |||
| def code_line_in_init(self): | |||
| """Initialization line of code in module init block.""" | |||
| left = "self.{}".format(self.ms_var_name) | |||
| args_list = list() | |||
| if self._args_translator is not None: | |||
| self.fragment.default_var['args'] = {**self._args_translator.actual_args, | |||
| **self._args_translator.formal_args} | |||
| args_list += self._args_translator.actual_args_to_str_list | |||
| args_list += self._args_translator.formal_args_to_str_list | |||
| else: | |||
| actual_args_str = ArgsTranslation.dict_data_to_args_str_list(self._fragment.default_var['args']) | |||
| args_list += actual_args_str | |||
| if not self._fragment.converted: | |||
| args_list.append('='.join(["input_shape", str(self._fragment.input_shape)])) | |||
| args_list.append('='.join(["output_shape", str(self._fragment.output_shape)])) | |||
| right = f"{self.ms_op.replace('::', '.')}({', '.join(args_list)})" | |||
| else: | |||
| right = f"{self.ms_op}({', '.join(args_list)})" | |||
| return left, right | |||
| # create a parameter for shared weight scenario | |||
| trainable_params = self.fragment.default_var.get("trainable_params") | |||
| if trainable_params and self.fragment.default_var.get("parameters"): | |||
# only if there are trainable params and the mapper accepts parameter declaration rewriting.
| for trainable_param_postfix, data_dict in trainable_params.items(): | |||
| onnx_name = data_dict.get('onnx_name') | |||
| nparray = data_dict.get('data') | |||
| try: | |||
| shape = nparray.shape | |||
| dtype = nparray.dtype | |||
except AttributeError:
raise ValueError("Parameter data has an inconsistent data type.")
| # set declare statement | |||
| declare_statement = self.fragment.fragment.create_parameter(shape, dtype) | |||
| if onnx_name not in self._global_context.repeated_weights.keys(): | |||
| # if the weight is not a shared weight, set to actual declaration. | |||
| if not self.fragment.default_var["parameters"].get(trainable_param_postfix): | |||
| self.fragment.default_var["parameters"][trainable_param_postfix] = declare_statement | |||
| continue # not a shared weight, skip the rest | |||
| if onnx_name in self._global_context.repeated_weights_declaration.keys(): | |||
| continue # already declared, skip | |||
| self._global_context.repeated_weights_declaration[onnx_name] = declare_statement | |||
| # set template to mapper parameter rewritten. | |||
| shared_w_var_in_parent = self._get_shared_weight_var_names_from_parent(onnx_name=onnx_name) | |||
# add the 'self.' prefix for nodes directly under the public parent module
if self.parent_module_struct.identifier == []:
# for now, only consider declaration in the main model
| shared_w_var_in_parent = f"self.{shared_w_var_in_parent}" | |||
| self.fragment.default_var["parameters"][trainable_param_postfix] = shared_w_var_in_parent | |||
| def code_line_in_construct(self, inputs=None): | |||
| """Construct line of code in module construct block. """ | |||
| left = self.ms_opt_var_name | |||
| if not self.matched_inputs and inputs is None: | |||
| raise ValueError("Unable to generate the code construct statement due to empty inputs.") | |||
| inputs = [] | |||
| if self.matched_inputs: | |||
| inputs = self.matched_inputs | |||
| # Bind current node opt_var_name & register to parent | |||
| self.outputs_manager.bind_opt_var_names(self.fragment.fragment) | |||
| for base_out in self.outputs_manager.outputs: | |||
| opt_var = base_out.opt_var_name | |||
| self.parent_module_struct.internal_outputs_collection[base_out.onnx_edge_name] = opt_var | |||
# Check the original onnx node's inputs to ensure duplicated inputs are not ignored
| original_inputs = GlobalContext().onnx_node_inputs.get(self.onnx_name) | |||
| new_inputs = [] | |||
| for idx, prec_node in enumerate(self.precursor_nodes_names): | |||
| occurrence = original_inputs.count(prec_node) | |||
| for _ in range(occurrence): | |||
| new_inputs.append(inputs[idx]) | |||
| inputs = new_inputs | |||
| if isinstance(inputs, str): | |||
| inputs = [inputs] | |||
| # Take inputs from parents module | |||
| for input_edge in self.inputs_edges_names: | |||
| if input_edge in self.parent_module_struct.inputs_register: | |||
| inputs.append(self.parent_module_struct.inputs_register.get(input_edge)) | |||
| elif input_edge in self.parent_module_struct.internal_outputs_collection: | |||
| inputs.append(self.parent_module_struct.internal_outputs_collection.get(input_edge)) | |||
| self.fragment.default_var['inputs'] = inputs | |||
| right = f"self.{self.ms_var_name}({', '.join(inputs)})" | |||
| return left, right | |||
| return left | |||
| def add_extra_tensor(self): | |||
| """ Add extra tensor.""" | |||
| @@ -360,12 +416,12 @@ class NodeStruct: | |||
| Args: | |||
| name (str): Can accept both node identifier or original onnx node name. | |||
| """ | |||
| target_nd_struct = GlobalContext().node_struct_collections.get(name) \ | |||
| or GlobalContext().onnx_node_name_to_node_struct_map.get(name) | |||
| target_nd_struct = self._global_context.node_struct_collections.get(name) \ | |||
| or self._global_context.onnx_node_name_to_node_struct_map.get(name) | |||
| if target_nd_struct is None and self.topo_idx == 0: # First node always has external input | |||
| return False | |||
| if target_nd_struct is None and (name in GlobalContext().onnx_graph_info.get('graph_inputs')): | |||
| if target_nd_struct is None and (name in self._global_context.onnx_graph_info.get('graph_inputs')): | |||
| return False | |||
| if target_nd_struct is None: | |||
| @@ -0,0 +1,91 @@ | |||
| # Copyright 2021 Huawei Technologies Co., Ltd.All Rights Reserved. | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================== | |||
"""Module processing for shared weights."""
| from mindinsight.mindconverter.graph_based_converter.common.global_context import GlobalContext | |||
| from mindinsight.mindconverter.graph_based_converter.generator.node_struct import NodeStruct | |||
| from mindinsight.mindconverter.graph_based_converter.generator.module_struct import ModuleStruct | |||
| class SharedWeightHelper: | |||
| """Helper function to process shared weights.""" | |||
| @staticmethod | |||
| def check_node_has_shared_weight(node: NodeStruct): | |||
| """ | |||
Check whether the node has shared weights and return all of them.
| Args: | |||
| node (NodeStruct): NodeStruct instance. | |||
| Returns: | |||
| list, a list of shared weight onnx names | |||
| """ | |||
| shared_weight_names = [] | |||
| for shared_weight_name, repeated_node_list in GlobalContext().repeated_weights.items(): | |||
| if node.onnx_name in repeated_node_list: | |||
| shared_weight_names.append(shared_weight_name) | |||
| return shared_weight_names | |||
| @staticmethod | |||
| def add_shared_weight_to_parent_module(shared_weight_name: str, module_to_be_registered: ModuleStruct): | |||
"""Register the shared weight name to the module and assign it a local var name."""
| default_weight_name = f"passthrough_w_{module_to_be_registered.shared_weights_counter}" | |||
| if shared_weight_name not in module_to_be_registered.shared_weights_collection: | |||
| module_to_be_registered.shared_weights_collection[shared_weight_name] = default_weight_name | |||
| module_to_be_registered.shared_weights_counter += 1 | |||
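The naming scheme this produces for a module's passthrough weights, shown with made-up ONNX weight names:

```python
shared_weights_collection, counter = {}, 0
for onnx_name in ["fc.weight", "embed.weight"]:          # made-up weight names
    if onnx_name not in shared_weights_collection:
        shared_weights_collection[onnx_name] = f"passthrough_w_{counter}"
        counter += 1
print(shared_weights_collection)
# {'fc.weight': 'passthrough_w_0', 'embed.weight': 'passthrough_w_1'}
```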
| @staticmethod | |||
| def register_shared_weight_to_public_parent(node: NodeStruct, shared_weight_name: str, pub_module_identifier): | |||
| """ | |||
| Register shared weight from bottom to top until its public module. | |||
| Note: | |||
For now we always treat the main model as the public module, since searching for the public module
among multiple nodes takes a long time.
Args:
node (NodeStruct): The NodeStruct instance which has the shared weight.
shared_weight_name (str): The onnx name of the shared weight.
pub_module_identifier (list): The identifier of the public module where the shared weight is used.
| """ | |||
| parent_module = node.parent_module_struct | |||
| exit_flag = False | |||
| while True: | |||
| if parent_module.identifier == pub_module_identifier: | |||
| exit_flag = True | |||
| SharedWeightHelper.add_shared_weight_to_parent_module(shared_weight_name, parent_module) | |||
| parent_module = parent_module.parent_module_struct | |||
| if exit_flag: | |||
| break | |||
| if parent_module is None: | |||
| break | |||
| @staticmethod | |||
| def add_shared_weights_in_init_statement(md_struct: ModuleStruct, module_def_args: list): | |||
"""Add shared weights to the module init statement."""
| if md_struct.shared_weights_collection: | |||
| return module_def_args + list(md_struct.shared_weights_collection.values()) | |||
| return module_def_args | |||
| @staticmethod | |||
| def public_module_shared_weight_statement_generation(public_module: ModuleStruct): | |||
"""Return the declaration statements of shared weights in their public module."""
| statements = [] | |||
| for passthrough_w_onnx_name, passthrough_w_var_name in public_module.shared_weights_collection.items(): | |||
| parameter_statement = GlobalContext().repeated_weights_declaration.get(passthrough_w_onnx_name) | |||
| declare_statement = f"self.{passthrough_w_var_name} = {parameter_statement}" | |||
| statements.append(declare_statement) | |||
| return statements | |||
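Shape of the statements returned for the main model, assuming the shared weight was declared elsewhere as a randomly initialized Parameter (illustrative strings, not real converter output):

```python
repeated_weights_declaration = {
    "fc.weight": "Parameter(Tensor(np.random.uniform(0, 1, (10, 10))"
                 ".astype(np.float32)), name=None)",
}
shared_weights_collection = {"fc.weight": "passthrough_w_0"}
for onnx_name, var_name in shared_weights_collection.items():
    print(f"self.{var_name} = {repeated_weights_declaration[onnx_name]}")
# self.passthrough_w_0 = Parameter(Tensor(np.random.uniform(0, 1, (10, 10)).astype(np.float32)), name=None)
```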
| @@ -181,9 +181,9 @@ class ONNXToMindSporeMapper(Mapper, abc.ABC): | |||
| return template, exchange_msg, outputs_list, outputs_mapping | |||
| @staticmethod | |||
| def _find_val_by_index(loc_index, weights_list, default_value=None): | |||
| def _find_val_by_index(loc_index, weights_list, default_val=None): | |||
| """Find value by location index of weights_list.""" | |||
| result = default_value | |||
| result = default_val | |||
| if loc_index < 0: | |||
| return weights_list[loc_index].value | |||
| @@ -196,7 +196,6 @@ class ONNXToMindSporeMapper(Mapper, abc.ABC): | |||
| @staticmethod | |||
| def _find_location_by_index(loc_index, weights_list): | |||
| """Find weight location in inputs of Node.""" | |||
| result = -1 | |||
| if loc_index < 0: | |||
| return weights_list[loc_index].location | |||
| @@ -206,3 +205,16 @@ class ONNXToMindSporeMapper(Mapper, abc.ABC): | |||
| result = weight.location | |||
| break | |||
| return result | |||
| @staticmethod | |||
| def _find_onnx_name_by_index(loc_index, weights_list): | |||
| """Find weight onnx name in inputs of Node.""" | |||
| result = -1 | |||
| if loc_index < 0: | |||
| return weights_list[loc_index].name | |||
| for idx, weight in enumerate(weights_list): | |||
| if idx == loc_index: | |||
| result = weight.name | |||
| break | |||
| return result | |||
| @@ -33,7 +33,8 @@ class MatMulMapper(ONNXToMindSporeMapper): | |||
| def _convert_trained_weights(**kwargs): | |||
| weights = kwargs['weights'] | |||
| weight = MatMulMapper._find_val_by_index(0, weights) | |||
| return {'w': {'data': weight, 'type': WeightType.PARAMETER.value}} | |||
| onnx_name = MatMulMapper._find_onnx_name_by_index(0, weights) | |||
| return {'w': {'data': weight, 'type': WeightType.PARAMETER.value, 'onnx_name': onnx_name}} | |||
| @staticmethod | |||
| def _generate_snippet_template(**kwargs): | |||
| @@ -48,15 +49,11 @@ class MatMulMapper(ONNXToMindSporeMapper): | |||
| if not weights: | |||
| return template, exchange_msg, outputs_list, outputs_mapping | |||
| tensor = MatMulMapper._find_val_by_index(0, weights) | |||
| variable_slot = "var_0" | |||
| init_template = f"self.{{{variable_slot}}} = {op}({', '.join(['%s={%s}' % (p, p) for p in args])})" | |||
| args["weight_shape"] = tensor.shape | |||
| args["weight_dtype"] = tensor.dtype | |||
| init_tensor = f"self.{{{variable_slot}}}_w = " \ | |||
| f"Parameter(Tensor(np.random.uniform(0, 1, {{weight_shape}}).astype(np.{{weight_dtype}})), " \ | |||
| f"name=None)" | |||
| # Note: adding the weight shape to args is deprecated due to a conflict with partial weight-sharing processing. | |||
| variable_slot_param_name = f"{variable_slot}/w" | |||
| init_tensor = f"self.{{{variable_slot}}}_w = {{{variable_slot_param_name}}}" | |||
| construct_template = f"opt_{{{variable_slot}}} = self.{{{variable_slot}}}" \ | |||
| f"({{{ExchangeMessageKeywords.VariableScope.value.INPUTS.value}}}," \ | |||
| f"self.{{{variable_slot}}}_w)" | |||
| @@ -75,7 +72,10 @@ class MatMulMapper(ONNXToMindSporeMapper): | |||
| ExchangeMessageKeywords.VariableScope.value.INPUTS.value: [], | |||
| ExchangeMessageKeywords.VariableScope.value.ARGS.value: args, | |||
| ExchangeMessageKeywords.VariableScope.value.WEIGHTS.value: weights, | |||
| ExchangeMessageKeywords.VariableScope.value.TRAINABLE_PARAMS.value: trainable_params | |||
| ExchangeMessageKeywords.VariableScope.value.TRAINABLE_PARAMS.value: trainable_params, | |||
| ExchangeMessageKeywords.VariableScope.value.PARAMETERS_DECLARED.value: { | |||
| "w": "" | |||
| } | |||
| } | |||
| } | |||
| outputs_list = [f"opt_{{{variable_slot}}}"] | |||
| @@ -35,8 +35,9 @@ class AddMapper(ONNXToMindSporeMapper): | |||
| def _convert_trained_weights(**kwargs): | |||
| weights = kwargs.get('weights', list()) | |||
| tensor = AddMapper._find_val_by_index(0, weights) | |||
| onnx_name = AddMapper._find_onnx_name_by_index(0, weights) | |||
| if isinstance(tensor, np.ndarray) and tensor.shape: | |||
| return {'bias': {'data': tensor, 'type': WeightType.PARAMETER.value}} | |||
| return {'bias': {'data': tensor, 'type': WeightType.PARAMETER.value, 'onnx_name': onnx_name}} | |||
| return dict() | |||
| @staticmethod | |||
| @@ -54,7 +55,6 @@ class AddMapper(ONNXToMindSporeMapper): | |||
| tensor = AddMapper._find_val_by_index(0, weights) | |||
| bias_shape = tensor.shape | |||
| bias_dtype = tensor.dtype | |||
| bias_location = AddMapper._find_location_by_index(0, weights) | |||
| variable_slot = "var_0" | |||
| @@ -64,11 +64,10 @@ class AddMapper(ONNXToMindSporeMapper): | |||
| inputs_in_construct.insert(bias_location, f"self.{{{variable_slot}}}_bias") | |||
| if bias_shape: | |||
| args["bias_shape"] = bias_shape | |||
| args["bias_dtype"] = bias_dtype | |||
| init_tensor = f"self.{{{variable_slot}}}_bias = " \ | |||
| f"Parameter(Tensor(np.random.uniform(0, 1, {{bias_shape}}).astype(np.{{bias_dtype}})), " \ | |||
| f"name=None)" | |||
| # Note: adding the weight shape to args is deprecated due to a conflict with partial weight-sharing processing. | |||
| variable_slot_param_name = f"{variable_slot}/bias" | |||
| init_tensor = f"self.{{{variable_slot}}}_bias = {{{variable_slot_param_name}}}" | |||
| else: | |||
| args["bias_value"] = tensor.tolist() | |||
| init_tensor = f"self.{{{variable_slot}}}_bias = {{bias_value}}" | |||
| @@ -93,6 +92,10 @@ class AddMapper(ONNXToMindSporeMapper): | |||
| ExchangeMessageKeywords.VariableScope.value.TRAINABLE_PARAMS.value: trainable_params | |||
| } | |||
| } | |||
| if bias_shape: | |||
| exchange_msg[variable_slot][ExchangeMessageKeywords.VariableScope.value.PARAMETERS_DECLARED.value] = { | |||
| "bias": "" | |||
| } | |||
| outputs_list = [f"opt_{{{variable_slot}}}"] | |||
| outputs_mapping = ((0, 0),) | |||
| return template, exchange_msg, outputs_list, outputs_mapping | |||
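AddMapper and MulMapper share the same shape-vs-scalar branch: a weight with a shape is deferred to a declared parameter placeholder, while a scalar is inlined as a literal. A simplified, hedged stand-in of that branch (not the mapper API itself):

```python
import numpy as np


def _init_statement_for(tensor, variable_slot="var_0", slot_name="bias"):
    """Return (init template, parameters_declared) for one weight tensor."""
    if isinstance(tensor, np.ndarray) and tensor.shape:
        # Non-scalar: leave a placeholder to be filled with the shared declaration later.
        init_tensor = f"self.{{{variable_slot}}}_{slot_name} = {{{variable_slot}/{slot_name}}}"
        parameters_declared = {slot_name: ""}
    else:
        # Scalar: inline the literal value directly.
        init_tensor = f"self.{{{variable_slot}}}_{slot_name} = {tensor.tolist()}"
        parameters_declared = {}
    return init_tensor, parameters_declared


print(_init_statement_for(np.ones((3,), dtype=np.float32)))
# ('self.{var_0}_bias = {var_0/bias}', {'bias': ''})
print(_init_statement_for(np.float32(0.5)))
# ('self.{var_0}_bias = 0.5', {})
```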
| @@ -36,8 +36,9 @@ class MulMapper(ONNXToMindSporeMapper): | |||
| def _convert_trained_weights(**kwargs): | |||
| weights = kwargs.get('weights', list()) | |||
| tensor = MulMapper._find_val_by_index(0, weights) | |||
| onnx_name = MulMapper._find_onnx_name_by_index(0, weights) | |||
| if isinstance(tensor, np.ndarray) and tensor.shape: | |||
| return {'w': {'data': tensor, 'type': WeightType.PARAMETER.value}} | |||
| return {'w': {'data': tensor, 'type': WeightType.PARAMETER.value, 'onnx_name': onnx_name}} | |||
| return dict() | |||
| @staticmethod | |||
| @@ -47,13 +48,12 @@ class MulMapper(ONNXToMindSporeMapper): | |||
| op = kwargs.get("operation") | |||
| args = kwargs.get("converted_params") | |||
| weights = kwargs.get("weights") | |||
| trainable_params = kwargs.get("trainable_params", dict()) | |||
| trainable_params = kwargs.get('trainable_params', dict()) | |||
| if not weights: | |||
| return template, exchange_msg, outputs_list, outputs_mapping | |||
| tensor = MulMapper._find_val_by_index(0, weights) | |||
| w_shape = tensor.shape | |||
| w_dtype = tensor.dtype | |||
| w_location = MulMapper._find_location_by_index(0, weights) | |||
| variable_slot = "var_0" | |||
| @@ -63,11 +63,9 @@ class MulMapper(ONNXToMindSporeMapper): | |||
| inputs_in_construct.insert(w_location, f"self.{{{variable_slot}}}_w") | |||
| if w_shape: | |||
| args["w_shape"] = w_shape | |||
| args["w_dtype"] = w_dtype | |||
| init_tensor = f"self.{{{variable_slot}}}_w = " \ | |||
| f"Parameter(Tensor(np.random.uniform(0, 1, {{w_shape}}).astype(np.{{w_dtype}})), " \ | |||
| f"name=None)" | |||
| # Note: adding the weight shape to args is deprecated due to a conflict with partial weight-sharing processing. | |||
| variable_slot_param_name = f"{variable_slot}/w" | |||
| init_tensor = f"self.{{{variable_slot}}}_w = {{{variable_slot_param_name}}}" | |||
| else: | |||
| args["w_value"] = tensor.tolist() | |||
| init_tensor = f"self.{{{variable_slot}}}_w = {{w_value}}" | |||
| @@ -90,6 +88,10 @@ class MulMapper(ONNXToMindSporeMapper): | |||
| ExchangeMessageKeywords.VariableScope.value.TRAINABLE_PARAMS.value: trainable_params | |||
| } | |||
| } | |||
| if w_shape: | |||
| exchange_msg[variable_slot][ExchangeMessageKeywords.VariableScope.value.PARAMETERS_DECLARED.value] = { | |||
| "w": "" | |||
| } | |||
| outputs_list = [f"opt_{{{variable_slot}}}"] | |||
| outputs_mapping = ((0, 0),) | |||
| return template, exchange_msg, outputs_list, outputs_mapping | |||
| @@ -421,11 +421,16 @@ class OnnxDataLoader: | |||
| self._global_context.onnx_node_name_to_topo_idx[n.name] = idx | |||
| for k in self.repeated_weight: | |||
| if not self.tensors_dict.get(k).to_array().shape: | |||
| # scalar does not have shape info | |||
| continue | |||
| self.repeated_weight[k] = record_tensors[k][:] | |||
| self._global_context.onnx_nodes_collection = self._nodes_dict | |||
| self._global_context.onnx_nodes_topo_index = nodes_topo_idx | |||
| self._global_context.repeated_weights = self.repeated_weight | |||
| # Currently, shared weights are only processed for multi-input models. | |||
| if len(self.input_nodes) > 1: | |||
| self._global_context.repeated_weights = self.repeated_weight | |||
| def _parse_tensors(self): | |||
| """Parse each onnx tensors in the model.""" | |||
| @@ -500,10 +505,14 @@ class OnnxDataLoader: | |||
| # Parse ONNX Graph level info | |||
| self._parse_graph() | |||
| # 1. parse all nodes | |||
| # 1. parse all tensors | |||
| self._parse_tensors() | |||
| # 2. parse all nodes; note that tensors must be parsed first, since nodes need the tensor info | |||
| # to process node weight sharing. | |||
| self._parse_nodes() | |||
| # 2. parse value info (incl. node output shape) | |||
| # 3. parse value info (incl. node output shape) | |||
| if self._is_infer_shape: | |||
| try: | |||
| self._infer_model() | |||
| @@ -514,9 +523,6 @@ class OnnxDataLoader: | |||
| log.exception(e) | |||
| raise e | |||
| # 3. parse all tensors | |||
| self._parse_tensors() | |||
| # 4. Optimize graph to eliminate some nodes. | |||
| self._find_nodes_to_be_eliminated() | |||
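Reduced to a runnable skeleton, the revised initialization order looks roughly like this (a stand-in class with stubbed steps, not the real `OnnxDataLoader`):

```python
class _Loader:
    """Stand-in sketch of the revised parse order; all steps are stubs."""

    def _parse_graph(self): print("parse graph-level info")
    def _parse_tensors(self): print("1. parse tensors (needed before nodes for weight sharing)")
    def _parse_nodes(self): print("2. parse nodes")
    def _infer_model(self): print("3. infer value info, incl. node output shapes")
    def _find_nodes_to_be_eliminated(self): print("4. optimize graph, eliminate some nodes")

    def initialize(self, infer_shape=True):
        self._parse_graph()
        self._parse_tensors()
        self._parse_nodes()
        if infer_shape:
            self._infer_model()
        self._find_nodes_to_be_eliminated()


_Loader().initialize()
```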