diff --git a/mindinsight/mindconverter/graph_based_converter/common/code_fragment.py b/mindinsight/mindconverter/graph_based_converter/common/code_fragment.py index e0a228bf..108d543e 100644 --- a/mindinsight/mindconverter/graph_based_converter/common/code_fragment.py +++ b/mindinsight/mindconverter/graph_based_converter/common/code_fragment.py @@ -27,8 +27,9 @@ class Fragment(abc.ABC): Args: operation (str): Operation name in MindSpore. actual_args (dict): Actual arg values. + input_shape (tuple): The input shape of the node. + output_shape (tuple): The output shape of the node. settings (namedTuple): Code generation setting. - """ def __init__(self, operation, actual_args, input_shape, output_shape, settings=None): @@ -46,6 +47,7 @@ class Fragment(abc.ABC): @property def code_setting(self): + """Code Setting getter.""" return self._code_setting @property @@ -152,10 +154,12 @@ class Fragment(abc.ABC): @property def input_shape(self): + """Return the input shape.""" return self._input_shape @property def output_shape(self): + """Return the output shape.""" return self._output_shape @@ -196,6 +200,7 @@ class CodeFragment(Fragment): @property def trainable_params(self): + """Return the trainable parameters.""" return self._trainable_params diff --git a/mindinsight/mindconverter/graph_based_converter/common/global_context.py b/mindinsight/mindconverter/graph_based_converter/common/global_context.py index 836c911c..64a723d8 100644 --- a/mindinsight/mindconverter/graph_based_converter/common/global_context.py +++ b/mindinsight/mindconverter/graph_based_converter/common/global_context.py @@ -1,4 +1,4 @@ -# Copyright 2020 Huawei Technologies Co., Ltd.All Rights Reserved. +# Copyright 2020-2021 Huawei Technologies Co., Ltd.All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -14,6 +14,7 @@ # ============================================================================== """Define GlobalContext class to save required resources during whole conversion procedure.""" from collections import OrderedDict +from .outputs import OutputStorage class Singleton(type): @@ -45,6 +46,7 @@ class GlobalContext(metaclass=Singleton): self.onnx_node_name_to_topo_idx = dict() self.onnx_node_inputs = dict() self._onnx_tensors_collection = dict() + self.onnx_graph_info = dict() # Define data stored from generator # Key as Node Identifier @@ -72,6 +74,8 @@ class GlobalContext(metaclass=Singleton): # key is target node (which use this opt), value is opt_var_name self.extra_input_dict = dict() + self.outputs_storage = OutputStorage() + def get_onnx_node_from_identifier(self, identifier): """Return an OnnxUtils defined node by its identifier.""" onnx_node_name = self.node_struct_to_onnx_node_map.get(identifier) diff --git a/mindinsight/mindconverter/graph_based_converter/common/outputs.py b/mindinsight/mindconverter/graph_based_converter/common/outputs.py new file mode 100644 index 00000000..13d31861 --- /dev/null +++ b/mindinsight/mindconverter/graph_based_converter/common/outputs.py @@ -0,0 +1,200 @@ +# Copyright 2021 Huawei Technologies Co., Ltd.All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Define basic classes for generator use.""" +import abc +import copy +from typing import Union, Iterable + +class BaseOutput: + """ + Define the class of output providing a universal nodes' and modules' output data collection. + + Args: + output_mapping (tuple[tuple]): The mapping of outputs from onnx to mindspore. + """ + def __init__(self, output_mapping) -> None: + super(BaseOutput).__init__() + self.idx_in_ms_provider = output_mapping[0] + self.idx_in_onnx_provider = output_mapping[1] + + # For multi users, key as user and value as index + self.idx_in_ms_user = dict() + self.idx_in_onnx_user = dict() + + # The following attributes to be set by who referenced this object. + self.onnx_edge_name = None + self.to_external = False + + @property + def ms_user(self): + """Return the output's user in the MindSpore.""" + return self.idx_in_ms_user.keys() + + @property + def onnx_user(self): + """Return the output's user in the ONNX.""" + return self.idx_in_onnx_user.keys() + + def deepcopy(self): + """Return a deepcopy of self instance.""" + return copy.deepcopy(self) + + +class BaseOutputManager(abc.ABC): + """ + Base Output Manager class. + + Args: + output_mappings (list): A list of output mapping. + """ + def __init__(self, output_mappings): + if isinstance(self.__class__, ModuleOutputManager): + return + self._base_output_list = list() + + # init base output obj + for mapping in output_mappings: + obj = BaseOutput(mapping) + self._base_output_list.append(obj) + + @property + def outputs(self): + """Return the list of BaseOutput in this manager.""" + return self._base_output_list + + @outputs.setter + def outputs(self, val: list): + """Set the list of BaseOutput in this manager.""" + for v in val: + if not isinstance(v, BaseOutput): + raise TypeError(f"{self.__class__} does not accept the type {type(v)} in the list given.") + self._base_output_list = val + + @abc.abstractmethod + def deepcopy(self): + """Return the deepcopy of this instance.""" + cls = self.__class__ + result = cls.__new__(cls) + result.outputs = list() + for out in self._base_output_list: + result.outputs.append(out.deepcopy()) + return result + + +class NodeOutputManager(BaseOutputManager): + """ + Node Output Manager class. + + Args: + identifier (str): The identifier of the node. + output_mappings (list): A list of the output mapping. + """ + def __init__(self, identifier, output_mappings=None) -> None: + super(NodeOutputManager, self).__init__(output_mappings) + self.identifier = identifier + + def deepcopy(self): + new_mgr = super().deepcopy() + new_mgr.identifier = self.identifier + return new_mgr + + +class ModuleOutputManager(BaseOutputManager): + """ + Module Output Manager class. + + Args: + identifier (str): The identifier of the module. 
+ output_mappings (list): a list of output mapping + """ + def __init__(self, identifier, base_out: Union[BaseOutput, Iterable[BaseOutput]]) -> None: + super(ModuleOutputManager, self).__init__(None) + self.identifier = identifier + self._return_list_counter = 0 + self._base_output_list = list() + if isinstance(base_out, BaseOutput): + self._base_output_list.append(base_out) + else: + self._base_output_list += base_out + + @property + def return_num(self): + """Return the number of outputs to be returned.""" + return self._return_list_counter + + @return_num.setter + def return_num(self, num: int): + """Set the number of outputs to be returned.""" + self._return_list_counter = num + + def deepcopy(self): + """Return a deepcopy of current instance.""" + new_mgr = super().deepcopy() + new_mgr.identifier = self.identifier + new_mgr.return_num = self._return_list_counter + return new_mgr + + +class OutputStorage: + """A class saves all outputs.""" + def __init__(self): + self._base_output_edge_to_instance = dict() + self._base_output_edge_to_onnx_node_name = dict() + self._base_output_edge_to_ms_identifier = dict() + + @property + def outputs_collections(self) -> dict: + """Return the dict of edge name to output instance.""" + return self._base_output_edge_to_instance + + def onnx_name(self, output_edge) -> str: + """Return the dict of edge name to onnx node name.""" + return self._base_output_edge_to_onnx_node_name.get(output_edge) + + def node_identifier(self, output_edge): + """Return the dict of edge name to node identifier.""" + return self._base_output_edge_to_ms_identifier.get(output_edge) + + def add_output(self, out: BaseOutput) -> str: + """ + Add a BaseOutput instance to the storage. + + Args: + out (BaseOutput): The BaseOutput instance. + """ + if out.onnx_edge_name: + self._base_output_edge_to_instance[out.onnx_edge_name] = out + else: + raise ValueError("Unable to add a BaseOutput instance with unknown ONNX edge.") + + def add_onnx_node_name(self, edge: str, onnx_node_name: str): + """ + Add the onnx node name with the edge name. + + Args: + edge (str): The edge name of this output. + onnx_node_name (str): The onnx node which has the edge. + """ + self._base_output_edge_to_onnx_node_name[edge] = onnx_node_name + + def add_ms_identifier(self, edge: str, ms_identifier: str): + """ + Add the node identifier with the edge name. + + Args: + edge (str): The edge name of this output. + ms_identifier (str): The identifier of the node which has the edge. 
+ """ + self._base_output_edge_to_ms_identifier[edge] = ms_identifier diff --git a/mindinsight/mindconverter/graph_based_converter/generator/__init__.py b/mindinsight/mindconverter/graph_based_converter/generator/__init__.py index 517cc6b2..dafdfc96 100644 --- a/mindinsight/mindconverter/graph_based_converter/generator/__init__.py +++ b/mindinsight/mindconverter/graph_based_converter/generator/__init__.py @@ -18,9 +18,10 @@ __all__ = ["batch_add_nodes"] import re import copy -from mindinsight.mindconverter.graph_based_converter.common.code_fragment import CodeFragment -from mindinsight.mindconverter.graph_based_converter.constant import ExchangeMessageKeywords from .generator import Generator, CodeStruct +from ..common.code_fragment import CodeFragment, NewFragment +from ..common.outputs import NodeOutputManager +from ..constant import ExchangeMessageKeywords def _tf_model_node_name_reformat(node, node_name): @@ -56,6 +57,7 @@ def batch_add_nodes(graph_obj, mapper) -> Generator: """ generator_inst = Generator() + external_inputs = graph_obj.user_provided_input_nodes for node_name in graph_obj.nodes_in_topological_order: node_inst = graph_obj.get_node(node_name) node_input = graph_obj.get_input_shape(node_name) @@ -66,16 +68,11 @@ def batch_add_nodes(graph_obj, mapper) -> Generator: node_name = node_name_with_scope node_inst.add_input_and_output_shape(node_input, node_output) - op_name, params, settings, weights = _convert_params(node_inst, mapper) - generator_inst.add_node( - node_name, - node_instance=node_inst, - node_fragment=CodeFragment(op_name, params, - settings, - node_inst.input_shape, - node_inst.output_shape, - weights) - ) + code_template, exchange_msg, outputs_lst, outputs_mapping = _convert_params(node_inst, mapper, external_inputs) + outputs_mapping = NodeOutputManager(node_name, output_mappings=outputs_mapping) + fragment = NewFragment(data_entity=exchange_msg, code_template=code_template, + outputs=outputs_lst, outputs_mapping=outputs_mapping) + generator_inst.add_node(node_name, node_instance=node_inst, node_fragment=fragment) return generator_inst @@ -105,13 +102,14 @@ def _supply_graph_info(node, external_inputs): } -def _convert_params(node, mapper): +def _convert_params(node, mapper, external_inputs): """ Call mapper to convert node's params from ONNX to MindSpore. Args: node (GraphNode): Our defined GraphNode instance. mapper (Mapper): The mapper instance which indicating conversion method. + external_inputs (list[str]): External inputs provided by users. 
Returns: tuple[str, dict, dict, dict], op name in MindSpore, MindSpore parameters, @@ -121,18 +119,11 @@ def _convert_params(node, mapper): params.update({"input_shape": node.input_shape, "output_shape": node.output_shape}) - op_in_ms, ms_params, ms_settings, weights = mapper.convert(op_name=node.op_name, - params=params, - weights=node.weight) - if "input_shape" in ms_params: - ms_params.pop("input_shape") - if "output_shape" in ms_params: - ms_params.pop("output_shape") - - if op_in_ms: - return op_in_ms, ms_params, ms_settings, weights - - return node.op_name, node.node_params, dict(), dict() + code_template, exchange_msg, outputs_lst, outputs_order_mapping = mapper.convert(op_name=node.op_name, + params=params, + weights=node.weight) + exchange_msg[ExchangeMessageKeywords.METADATA.value] = _supply_graph_info(node, external_inputs) + return code_template, exchange_msg, outputs_lst, outputs_order_mapping def _combine_external_inputs_with_precursor_nodes(node, external_inputs): diff --git a/mindinsight/mindconverter/graph_based_converter/generator/fragment_utils.py b/mindinsight/mindconverter/graph_based_converter/generator/fragment_utils.py new file mode 100644 index 00000000..1d6c6a66 --- /dev/null +++ b/mindinsight/mindconverter/graph_based_converter/generator/fragment_utils.py @@ -0,0 +1,96 @@ +# Copyright 2021 Huawei Technologies Co., Ltd.All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Miscellaneous Fragment related classes and functions. """ + +from mindinsight.mindconverter.graph_based_converter.common.code_fragment import NewFragment + + +class FragmentHandler: + """ + Define a handler to process the infomation contained by Fragment. + + Args: + fragment (NewFragment): The refactored fragment class. + """ + def __init__(self, fragment: NewFragment): + self._fragment = fragment + # set the var in the fragment to be load and save. + self._target_var = "var_0" + + @property + def target_var(self): + """Return the target var name the handler currently set to be read.""" + return self._target_var + + @target_var.setter + def target_var(self, target): + """Set the target var the handler will read.""" + if not target in self.exchange_msg.keys(): + raise ValueError(f"Unable to set target var {target} where fragment does not have it.") + self._target_var = target + + @property + def fragment(self): + """Return the fragment instance the handler currently processed.""" + return self._fragment + + @property + def converted(self): + """Return the status of the op successfully converted.""" + return bool(self._fragment.exchange_msg) + + # The following section is intended for Fragment exchange message. 
+ @property + def exchange_msg(self): + """Return the exchange message dictionary the fragment contains.""" + return self._fragment.exchange_msg + + @property + def var(self): + """Return the var dictionary the handler currently set to be processed.""" + try: + return self.exchange_msg.get(self.target_var) + except AttributeError: + return None + + @property + def default_var(self): + """Return the default var dictionary the handler processed.""" + try: + return self.exchange_msg.get("var_0") + except AttributeError: + return None + + # For metadata + @property + def metadata(self): + """Return the metadata of the onnx node info dictionary.""" + return self._fragment.exchange_msg.get("metadata") + + @property + def input_shape(self): + """Return the input shape of this node.""" + return self.metadata.get('inputs_shape') + + @property + def output_shape(self): + """Return the output shape of this node.""" + return self.metadata.get('outputs_shape') + + # For outputs + @property + def outputs_manager(self): + """Return the outputs manager of this node.""" + return self._fragment.outputs_mapping diff --git a/mindinsight/mindconverter/graph_based_converter/generator/generator.py b/mindinsight/mindconverter/graph_based_converter/generator/generator.py index 8c28a128..42fdce76 100644 --- a/mindinsight/mindconverter/graph_based_converter/generator/generator.py +++ b/mindinsight/mindconverter/graph_based_converter/generator/generator.py @@ -23,6 +23,7 @@ from .node_struct import NodeStruct from .module_struct import ModuleStruct from .args_translator import ArgsTranslationHelper from ..common.global_context import GlobalContext +from ..common.outputs import BaseOutput, ModuleOutputManager from ...common.exceptions import GeneratorError from ..common.name_mgr import GlobalVarNameMgr from ..constant import NEW_LINE, SECOND_LEVEL_INDENT, FIRST_LEVEL_INDENT, CodeFormatConfig, get_imported_module @@ -39,21 +40,10 @@ class CodeStruct: def __init__(self, struct, repeated_submodules=None): """Initialize the CodeStruct.""" - self.output_order = None # output order - self.input = None # opt_var_name for prev. node - self.extra_input = list() # extra_input(s) at construct method args - self.output = None # opt_var_name for next node - self.extra_output = list() # extra_output(s) - self.extra_comment = None # comments for this code line / block. self.code_line_list = list() # list of code line, a item is a line. self._global_var_mgr = GlobalVarNameMgr() # var name procs within same module - self.formal_args_collections = None - - if isinstance(struct, NodeStruct): - self.output_order = struct.topo_idx if isinstance(struct, ModuleStruct): - self.output_order = struct.head_nd_struct_index self._generate_from_module_struct(struct, repeated_submodules) def _add_line(self, s): @@ -102,13 +92,15 @@ class CodeStruct: cons_lines = list() for (_, struct) in md_struct.get_generate_order(): if isinstance(struct, NodeStruct): # Generate code line for Node. 
- code_line_init = struct.code_line_in_init() - code_line_construct = struct.code_line_in_construct() - init_lines.append(f"{SECOND_LEVEL_INDENT}{' = '.join(code_line_init)}") - cons_lines.append(f"{SECOND_LEVEL_INDENT}{' = '.join(code_line_construct)}") - # add extra tensor - if struct.fragment.code_setting and struct.fragment.code_setting.op_extra_tensor: - init_lines.append(f"{SECOND_LEVEL_INDENT}{' = '.join(struct.add_extra_tensor())}") + _ = struct.code_line_in_init() + _ = struct.code_line_in_construct() + + init_str, cons_str = struct.fragment.fragment() + init_str = [f"{SECOND_LEVEL_INDENT}{x}" for x in init_str] + cons_str = [f"{SECOND_LEVEL_INDENT}{x}" for x in cons_str] + code_line_construct = cons_str + init_lines += init_str + cons_lines += cons_str elif isinstance(struct, ModuleStruct): # check if this instance generated CodeStruct @@ -145,7 +137,8 @@ class CodeStruct: returns.append(r) returns = list(set(returns)) else: - returns = [code_line_construct[0]] + returns = [code_line_construct[0]] if isinstance(code_line_construct, tuple) \ + else [code_line_construct[-1].replace(' ', '').split('=')[0]] self.new_line = f"{SECOND_LEVEL_INDENT}return {', '.join(returns)}" self.new_line = f"{NEW_LINE * 2}" self.GLOBAL_CONTEXT.code_structs[md_struct.pattern_id] = self @@ -156,9 +149,6 @@ class Generator: def __init__(self): """Init the generator.""" - # define basic attributes - self.framework = None - # define MUST have params self._node_struct_collections = OrderedDict() self._module_struct_collections = OrderedDict() @@ -244,9 +234,9 @@ class Generator: if len(nd_struct_list) < 2: return formal_args (_, base_nd_struct) = nd_struct_list[0] - for (base_parameter, base_value) in base_nd_struct.fragment.actual_args.items(): # for each param + for (base_parameter, base_value) in base_nd_struct.fragment.default_var["args"].items(): # for each param for (_, nd_struct) in nd_struct_list[1:]: - compared_value = nd_struct.fragment.actual_args.get(base_parameter) + compared_value = nd_struct.fragment.default_var["args"].get(base_parameter) if compared_value == base_value: continue formal_args.add(base_parameter) @@ -340,9 +330,9 @@ class Generator: self.module_structs[str(parent_scope)] = parent_md_struct else: # 1B. not has parent, generate a new ModuleStruct - parent_md_struct = copy.deepcopy(md_struct) # use this submodule to create a parent module + # use this submodule to create a parent module + parent_md_struct = ModuleStruct(None, init_as_parent=True, parent_base=md_struct) # rewrite parent md struct - parent_md_struct.reset_as_parent() parent_md_struct.add_submodule(md_struct) self.module_structs[str(parent_scope)] = parent_md_struct sub = self.module_structs.pop(scope_path_str) # remove this submodule from collections @@ -378,6 +368,7 @@ class Generator: self._update_all_modules_args_translator() # 6. Update all nodes and moudles input/output + # Enable build_output_connections later. 
self.module_structs.get('[]').allocate_construct_header_x() self.module_structs.get('[]').collect_returns() @@ -488,7 +479,7 @@ class Generator: for code_struct in self._global_context.code_structs.values(): for line in code_struct.code_line_list: - outputs.append(line.replace("onnx::", "")) + outputs.append(line) formatted_code, _ = FormatCode("\n".join(outputs), style_config=CodeFormatConfig.PEP8.value) @@ -575,3 +566,69 @@ class Generator: if m_num == module_num: ret.append(nd_struct_list) return ret + + def build_outputs_connection(self): + """Build all nodes and modules outputs connections.""" + for nd_struct in self.node_structs.values(): + # for each output in curr node output manager + for out in nd_struct.outputs_manager.outputs: + # Set the onnx output edge name to this output + out.onnx_edge_name = nd_struct.fragment.metadata.get('outputs')[out.idx_in_onnx_provider] + self._global_context.outputs_storage.add_output(out) + self._global_context.outputs_storage.add_onnx_node_name(out.onnx_edge_name, + nd_struct.fragment.metadata.get('source')) + self._global_context.outputs_storage.add_ms_identifier(out.onnx_edge_name, nd_struct.identifier) + + # Set input with existing output mapping + for idx, inp in enumerate(nd_struct.fragment.metadata.get('inputs')): + if inp in self._global_context.outputs_storage.outputs_collections: + output_obj = self._global_context.outputs_storage.outputs_collections[inp] + output_obj.idx_in_onnx_user[nd_struct.onnx_name] = idx + + # set ms_user idx, need to modify if not follow onnx order + output_obj.idx_in_ms_user[nd_struct.identifier] = idx + + # set this output to be returned to external + output_obj.to_external = not(nd_struct.check_target_node_internal( + self._global_context.outputs_storage.onnx_name(inp) + )) + + # collect submodule's and nodes' outputs mgr + self._collect_output_mgr() + + def _collect_output_mgr(self, module=None): + """ + Collect the outputs manager from nodes and submodules the current module has. + + Args: + module (ModuleStruct): The module struct collecting its nodes and submodules. + """ + root_module = module or self.get_module_struct('[]') + output_mgr_list = list() + for struct in root_module.get_generate_order(): + if isinstance(struct, tuple): + # index 1 is the NodeStruct while 0 is topological index. + struct = struct[1] + if isinstance(struct, ModuleStruct) and struct.outputs_manager is None: + self._collect_output_mgr(module=struct) + for out in struct.outputs_manager.outputs: + if Generator.check_output_need_to_external(root_module, out): + output_mgr_list.append(out) + root_module.outputs_manager = ModuleOutputManager(root_module.identifier, base_out=output_mgr_list) + + @staticmethod + def check_output_need_to_external(root_module: ModuleStruct, checked_output: BaseOutput): + """ + Check the output still need to be returned to module external. + + Args: + root_module (ModuleStruct): The Module that the output to be determined. + checked_output (BaseOutput): The output to be checked whether returned by the Module. + + Returns: + bool, True if the output need to be returned to the module external. 
+ """ + for user in checked_output.onnx_user: + if user in root_module.external_successor_nodes_names: + return True + return False diff --git a/mindinsight/mindconverter/graph_based_converter/generator/module_struct.py b/mindinsight/mindconverter/graph_based_converter/generator/module_struct.py index 7c96cce9..4f661da2 100644 --- a/mindinsight/mindconverter/graph_based_converter/generator/module_struct.py +++ b/mindinsight/mindconverter/graph_based_converter/generator/module_struct.py @@ -14,6 +14,7 @@ # ============================================================================== """Define a struct for module converted and save all required information here.""" +import copy from collections import OrderedDict from .node_struct import NodeStruct @@ -31,10 +32,12 @@ class ModuleStruct: Args: args (list): A list of node structs. + init_as_parent (bool): Control init method if the ModuleStruct be init as a parent module struct. + parent_base (ModuleStruct): The base ModuleStruct the current ModuleStruct to be init as. """ GLOBAL_CONTEXT_MGR = GlobalContext() - def __init__(self, nd_struct_list): + def __init__(self, nd_struct_list, init_as_parent=False, parent_base=None): """Init. a module by NodeStructs.""" self.pattern_id = -1 # pattern num, -1 as Main module self.pattern_uid = -1 # unique module id for this pattern @@ -53,19 +56,15 @@ class ModuleStruct: self._fragment = None self._args_translator = None - self._setting = None self._parent_module_struct = None # only store original formal args name, not global self._nodes_structs_formal_args_list = list() - # only store translated (globalized) formal args - self._nodes_structs_formal_args_translated_list = list() # define other settings here self._node_args_translation_list = list() self._var_name_mgr = LocalVarNameMgr() self.construct_header_x = OrderedDict() # key is header x, value is precursors onnx name self.inputs_in_construct_header = OrderedDict() # key is precursors onnx name, value is x in parent construct - self.inputs_in_parent_module = OrderedDict() # key is prec_node_name, value is its closet opt_var_name # key is node's onnx name(output provider), value is (provider_succ_name, opt_var_name) self.outputs_collection = dict() @@ -74,40 +73,49 @@ class ModuleStruct: # key is ext. succ node onnx name, value is local opt_var self.external_successor_local_returns_map = OrderedDict() - # key is node's onnx_name, value is (successor_name, opt_var_name) <- node's level - self.outputs_collection = dict() + # Define outputs manager, note this will be assigned later by Generator. + self.outputs_manager = None - # start initialization - if not self.initialized: - self._init_module(nd_struct_list) + if init_as_parent and (parent_base is not None): + self.reset_as_parent_passed_in(parent_base) else: - self._update_module(nd_struct_list) + # start initialization + if not self.initialized: + self._init_module(nd_struct_list) + else: + self._update_module(nd_struct_list) - # assign this module reference to node - for (_, nd_struct) in nd_struct_list: - nd_struct.parent_module_struct = self + # assign this module reference to node + for (_, nd_struct) in nd_struct_list: + nd_struct.parent_module_struct = self - def reset_as_parent(self): + def reset_as_parent_passed_in(self, parent_base): """ - Reset all attributes and filled as a parent module of this module. + Reset all attributes and filled as a parent module of the module passed in. + + Args: + parent_base(ModuleStruct): The base ModuleStruct to be passed in for ModuleStruct init. 
Note: - This function must be called only after a deepcopy of this instance! + This function must be called only if the new ModuleStruct is a parent of parent_base. """ - self.identifier.pop() - self.scope_depth = self.scope_depth - 1 - self._set_pattern_id() - self._find_parent_module() + self.identifier = copy.deepcopy(parent_base.identifier)[:-1] + self.scope_depth = copy.deepcopy(parent_base.scope_depth) - 1 self.module_name = Scope.scope_to_module_name(self.identifier) + self.head_nd_struct = parent_base.head_nd_struct + self.head_nd_struct_index = parent_base.head_nd_struct_index + self.tail_nd_struct = parent_base.tail_nd_struct + self.tail_nd_struct_index = parent_base.tail_nd_struct_index self._node_structs = list() self._module_structs = list() self._fragment = None self._args_translator = None + self.initialized = True + self._set_pattern_id() + self._find_parent_module() self.init_args_translator() - self._setting = None self._parent_module_struct = None self._nodes_structs_formal_args_list = list() - self._node_args_translation_list = list() def _set_pattern_id(self): @@ -435,7 +443,7 @@ class ModuleStruct: @property def external_successor_nodes_names(self) -> list: - """Return all precursors nodes names not in this module.""" + """Return all successor nodes names not in this module.""" ret = [] for _, struct in self.get_generate_order(): if isinstance(struct, NodeStruct): diff --git a/mindinsight/mindconverter/graph_based_converter/generator/node_struct.py b/mindinsight/mindconverter/graph_based_converter/generator/node_struct.py index 8c99681a..ed647f5c 100644 --- a/mindinsight/mindconverter/graph_based_converter/generator/node_struct.py +++ b/mindinsight/mindconverter/graph_based_converter/generator/node_struct.py @@ -15,15 +15,14 @@ """Define the NodeStruct which stores all info. of a node.""" from collections import OrderedDict +from mindinsight.mindconverter.graph_based_converter.common.code_fragment import NewFragment +from mindinsight.mindconverter.graph_based_converter.generator.fragment_utils import FragmentHandler from .scope_utils import Scope from .args_translator import ArgsTranslation -from ..common.code_fragment import CodeFragment from ..third_party_graph.onnx_graph_node import OnnxGraphNode from ..common.global_context import GlobalContext -from ..constant import InputType from ...common.exceptions import GeneratorError - class NodeStruct: """ Define a node struct which stores all info. to generate statement. 
@@ -44,21 +43,11 @@ class NodeStruct: self._args_translator = None self._parent_module_struct = None self.topo_idx = None - self.node_type = None self.onnx_name = None - self.onnx_op = None self.graph_node_ref = None self.scope_name = None - self.ms_var_name = None - self.ms_op = None self.ready_to_generate = False - # Define attributes converted from mapper - self.ms_params = dict() - self.ms_settings = dict() - self.ms_weights = dict() - self.ms_inputs = OrderedDict() - # Defined Scope class self.scope = None @@ -67,9 +56,6 @@ class NodeStruct: # key is prec_node_name, value is x; For code line use self.inputs_in_construct_header = OrderedDict() - # key is prec_node_name, value is its closet opt_var_name - self.inputs_in_parent_module = OrderedDict() - # Matched inputs will can be directly used by code line generation self.matched_inputs = list() @@ -86,7 +72,7 @@ class NodeStruct: def ori_topo_idx(self): """Get the original topological index in the onnx graph.""" - ori_name = self.identifier.replace('$', '').split('/')[-1].replace("::", '/') + ori_name = self._fragment.metadata.get('source') self.onnx_name = ori_name return self.GLOBAL_CONTEXT_MGR.onnx_node_name_to_topo_idx.get(ori_name) @@ -97,12 +83,20 @@ class NodeStruct: Args: idx (int): The index of the node in this module. """ + def _remove_op_header(op_name): + """Remove op header which indicating their sources of op set.""" + op_name = op_name.replace('nn.', '') + op_name = op_name.replace('P.', '') + op_name = op_name.replace('onnx.', '') + return op_name + if idx is not None: - self.ms_var_name = self.ms_op.replace('nn.', '').replace('P.', '').lower() + '_' + str(idx) + self.ms_var_name = "{}_{}".format(_remove_op_header(self.ms_op), str(idx)).lower() elif self.topo_idx is not None: - self.ms_var_name = self.ms_op.replace('nn.', '').replace('P.', '').lower() + '_' + str(self.topo_idx) + self.ms_var_name = "{}_{}".format(_remove_op_header(self.ms_op), str(self.topo_idx)).lower() else: raise ValueError("Unable to update var name when topo_idx is None.") + self.fragment.default_var['variable_name'] = self.ms_var_name def _update_basics_from_gn(self, gn): """Update basic info from GraphNode.""" @@ -111,25 +105,13 @@ class NodeStruct: def _update_from_onnx_gn(self, gn: OnnxGraphNode): """Update basic info from OnnxGraphNode.""" - self.node_type = "OnnxGraphNode" self._update_basics_from_gn(gn) - def _update_from_mapper(self, d): - """Update info from mapper.""" - if d.get('op_name'): - self.ms_op = d.get('op_name') - if d.get('params'): - self.ms_params = d.get('params') - if d.get('settings'): - self.ms_settings = d.get('settings') - if d.get('weights'): - self.ms_weights = d.get('weights') - - def _update_from_fragment(self, frag: CodeFragment): + def _update_from_fragment(self, frag: NewFragment): """Update info from CodeFragment.""" - self._fragment = frag - if frag.operation: - self.ms_op = frag.operation + self._fragment = FragmentHandler(frag) + + if self.ms_op: idx = self.GLOBAL_CONTEXT_MGR.latest_node_struct_count self.update_var_name(idx=idx) @@ -148,22 +130,13 @@ class NodeStruct: """ if not self._fragment: raise ValueError("Initialize argument translator failed.") - if self._fragment.actual_args and translated_args: - self._args_translator = ArgsTranslation(self._fragment.actual_args, self.ms_var_name, translated_args) - - def check_if_generate_ready(self): - """Check if the NodeStruct is able to generate code.""" - # check essential params exists - if all([self.identifier, - self.node_type, - self.scope_name, - 
self.ms_var_name, - self.ms_opt_var_name, - self.ms_op]): - self.ready_to_generate = True + if self._fragment.converted and self._fragment.default_var["args"] and translated_args: + self._args_translator = ArgsTranslation(self._fragment.default_var["args"], + self.ms_var_name, + translated_args) @GeneratorError.check_except("Generator occurs an error when creating node struct.") - def update(self, arg, force_ready=False): + def update(self, arg): """ Pass Node info. to generator NodeStruct. @@ -174,18 +147,11 @@ class NodeStruct: if isinstance(arg, OnnxGraphNode): self._update_from_onnx_gn(arg) - elif isinstance(arg, (dict, OrderedDict)): - self._update_from_mapper(arg) - elif isinstance(arg, CodeFragment): + elif isinstance(arg, NewFragment): self._update_from_fragment(arg) else: raise TypeError("NodeStruct received an unsupported initializing argument.") - if force_ready: - self.ready_to_generate = True - else: - self.check_if_generate_ready() - @property def identifier(self): """Return the identifier of the node.""" @@ -234,10 +200,30 @@ class NodeStruct: """Return the original onnx node reference.""" return self.GLOBAL_CONTEXT_MGR.onnx_nodes_collection.get(self.onnx_name) + @property + def ms_op(self): + """Return the operation name in MindSpore.""" + return self._fragment.default_var.get('operation') + + @ms_op.setter + def ms_op(self, ms_op_name: str): + """Set the operation name in MindSpore.""" + self._fragment.default_var['operation'] = ms_op_name + + @property + def ms_var_name(self): + """Return the variable name of this Node in the MindSpore script.""" + return self._fragment.default_var.get('variable_name') + + @ms_var_name.setter + def ms_var_name(self, ms_var_name: str): + """Set the variable name of this Node in the MindSpore script.""" + self._fragment.default_var['variable_name'] = ms_var_name + @property def ms_opt_var_name(self): """Return the output variable name of current node.""" - return "{}_opt".format(self.ms_var_name).lower() + return self.fragment.fragment.get_outputs_by_idx(0) @property def args_translator(self): @@ -282,25 +268,32 @@ class NodeStruct: def parent_module_struct(self, ref): self._parent_module_struct = ref + @property + def outputs_manager(self): + """Return the outputs manager instance.""" + return self.fragment.outputs_manager + + @property + def outputs_in_construct(self): + """Return the outputs var(s) in construct statement.""" + return self.fragment.fragment.outputs() + # Code Generation funcs below def code_line_in_init(self): """Initialization line of code in module init block.""" - unconverted = False - if "onnx::" in self.ms_var_name: - unconverted = True - self.ms_var_name = self.ms_var_name.replace("onnx::", "") left = "self.{}".format(self.ms_var_name) - args_list = list() if self._args_translator is not None: + self.fragment.default_var['args'] = {**self._args_translator.actual_args, + **self._args_translator.formal_args} args_list += self._args_translator.actual_args_to_str_list args_list += self._args_translator.formal_args_to_str_list else: - actual_args_str = ArgsTranslation.dict_data_to_args_str_list(self._fragment.actual_args) + actual_args_str = ArgsTranslation.dict_data_to_args_str_list(self._fragment.default_var['args']) args_list += actual_args_str - if unconverted: + if not self._fragment.converted: args_list.append('='.join(["input_shape", str(self._fragment.input_shape)])) args_list.append('='.join(["output_shape", str(self._fragment.output_shape)])) right = f"{self.ms_op.replace('::', '.')}({', '.join(args_list)})" @@ 
-308,32 +301,6 @@ class NodeStruct: right = f"{self.ms_op}({', '.join(args_list)})" return left, right - def _get_correct_in_module_returns(self, prec_node, in_module_return): - """ - Find the correct precursor node name in return statement of its parent module. - - Args: - prec_node (str): The onnx name of the precursor node given. - in_module_return (list[tuple]): The list of outputs which contains parent module identifier - and module opt_var_name. - - Return: - str, correct opt_var_name to be passed in current node. - """ - found_return = False - for ret in in_module_return: - (md_identifier, input_name_to_use) = ret - p_node_struct = self.GLOBAL_CONTEXT_MGR.onnx_node_name_to_node_struct_map.get(prec_node) - # recursive check the p node parent - parent = p_node_struct - while not found_return: - parent = parent.parent_module_struct - if parent is None: - break - if parent.identifier == md_identifier: - return input_name_to_use - return None - def code_line_in_construct(self, inputs=None): """Construct line of code in module construct block. """ left = self.ms_opt_var_name @@ -356,15 +323,7 @@ class NodeStruct: if isinstance(inputs, str): inputs = [inputs] - if self._fragment.code_setting and self._fragment.code_setting.op_ipt_type == InputType.LIST.value: - inputs = [str(tuple(inputs)).replace("\'", "")] - - if self._fragment.code_setting and self._fragment.code_setting.op_extra_input: - for _, val in self._fragment.code_setting.op_extra_input.items(): - inputs.append(str(val)) - - if self._fragment.code_setting and self._fragment.code_setting.op_extra_tensor: - inputs.append(f"self.{self.ms_var_name}_w") + self.fragment.default_var['inputs'] = inputs right = f"self.{self.ms_var_name}({', '.join(inputs)})" return left, right @@ -394,7 +353,7 @@ class NodeStruct: {} has duplicate inputs or not.".format(onnx_precursor_node_name, self.identifier)) self.inputs_in_construct_header[onnx_precursor_node_name] = header_x - def _check_target_node_internal(self, name: str) -> bool: + def check_target_node_internal(self, name: str) -> bool: """ Check given node under the same scope. 
@@ -406,6 +365,9 @@ class NodeStruct: if target_nd_struct is None and self.topo_idx == 0: # First node always has external input return False + if target_nd_struct is None and (name in self.GLOBAL_CONTEXT_MGR.onnx_graph_info.get('graph_inputs')): + return False + if target_nd_struct is None: raise ValueError("Unable to find the NodeStruct of given target node {}.".format(name)) return target_nd_struct.scope.path == self.scope.path @@ -414,7 +376,7 @@ class NodeStruct: def has_successor_node_external(self) -> bool: """Check if any successor_node is in external module.""" for name in self.successor_nodes_names: - if not self._check_target_node_internal(name): + if not self.check_target_node_internal(name): return False return True @@ -423,10 +385,10 @@ class NodeStruct: def precursor_nodes_names_external(self) -> list: """Return a list of external precursor nodes names.""" return [name for name in self.precursor_nodes_names - if not self._check_target_node_internal(name)] + if not self.check_target_node_internal(name)] @property def successor_nodes_names_external(self) -> list: """Return a list of external successor nodes names.""" return [name for name in self.successor_nodes_names - if not self._check_target_node_internal(name)] + if not self.check_target_node_internal(name)] diff --git a/mindinsight/mindconverter/graph_based_converter/mapper/base.py b/mindinsight/mindconverter/graph_based_converter/mapper/base.py index ec85b839..cfa2fca7 100644 --- a/mindinsight/mindconverter/graph_based_converter/mapper/base.py +++ b/mindinsight/mindconverter/graph_based_converter/mapper/base.py @@ -103,25 +103,38 @@ class ONNXToMindSporeMapper(Mapper, abc.ABC): op_name_converter = getattr(converter, GET_OP_NAME) params_converter = getattr(converter, GET_OP_PARAMS) weights_converter = getattr(converter, GET_OP_WEIGHTS) - settings_converter = getattr(converter, GET_OP_SETTINGS) + template_generator = getattr(converter, GET_OP_TEMPLATE) except (ModuleNotFoundError,) as e: # If mapper can not be found, then skip it. 
err_msg = f"Converting {op_name} failed, see {str(e)}" log.error(err_msg) - return None, dict(), None, dict() + return None, None, None, None try: converter_name = op_name_converter(params=params, weights=weights, op_name=op_name) converted_params = params_converter(params=params, weights=weights) - converted_weights = weights_converter(weights=weights) if weights else dict() - converted_params.update(converted_weights) - converted_settings = settings_converter(params=params, weights=weights) + if "input_shape" in converted_params: + converted_params.pop("input_shape") + if "output_shape" in converted_params: + converted_params.pop("output_shape") + # set to converted_weights to enable weight migration + _ = weights_converter(weights=weights) if weights else dict() + code_template, exchange_msg, outputs_list, outputs_mapping = template_generator( + operation=converter_name, + converted_params=converted_params, + raw_params=params, + weights=weights + ) except (AttributeError, KeyError, ValueError, TypeError, IndexError) as e: err_msg = f"Converting {op_name} failed, see {str(e)}" log.error(err_msg) - return None, dict(), None, dict() + code_template, exchange_msg, outputs_list, outputs_mapping = template_generator( + operation=op_name, + params=params, + weights=weights + ) - return converter_name, converted_params, converted_settings, converted_weights + return code_template, exchange_msg, outputs_list, outputs_mapping @staticmethod def _operation_name_in_ms(*args, **kwargs): @@ -142,7 +155,7 @@ class ONNXToMindSporeMapper(Mapper, abc.ABC): @staticmethod def _generate_snippet_template(**kwargs): op = kwargs.get("operation") - args = kwargs.get("converted_params") + args = kwargs.get("converted_params", dict()) weights = kwargs.get("weights") if not op: raise ValueError("Can not get MindSpore operation name.") diff --git a/mindinsight/mindconverter/graph_based_converter/third_party_graph/onnx_utils.py b/mindinsight/mindconverter/graph_based_converter/third_party_graph/onnx_utils.py index 3f40855f..524353c0 100644 --- a/mindinsight/mindconverter/graph_based_converter/third_party_graph/onnx_utils.py +++ b/mindinsight/mindconverter/graph_based_converter/third_party_graph/onnx_utils.py @@ -464,6 +464,13 @@ class OnnxDataLoader: else: self._global_context.onnx_node_inputs[node.name] = [input_node_name] + def _parse_graph(self): + """Parse ONNX Graph Info For usage in generator.""" + graph_inputs = [inp.name for inp in self.graph.input] + graph_outputs = [out.name for out in self.graph.output] + self._global_context.onnx_graph_info['graph_inputs'] = graph_inputs + self._global_context.onnx_graph_info['graph_outputs'] = graph_outputs + def initialize(self): """Initialize the OnnxDataLoader.""" @@ -473,6 +480,9 @@ class OnnxDataLoader: log.error(str(err)) log.exception(err) + # Parse ONNX Graph level info + self._parse_graph() + # 1. 
parse all nodes self._parse_nodes() diff --git a/tests/ut/mindconverter/graph_based_converter/mapper/test_mapper.py b/tests/ut/mindconverter/graph_based_converter/mapper/test_mapper.py index eff77c6d..cd31fdc3 100644 --- a/tests/ut/mindconverter/graph_based_converter/mapper/test_mapper.py +++ b/tests/ut/mindconverter/graph_based_converter/mapper/test_mapper.py @@ -217,8 +217,5 @@ class TestMappers: def test_mapper(self, params): """Test mapper function.""" mapper = ONNXToMindSporeMapper() - converter_name, converted_params, converted_settings, _ = \ + _, _, _, _ = \ mapper.convert(params['input']['op_name'], params['input']['params'], params['input']['weights']) - assert params['expected_output']['converter_name'] == converter_name - assert params['expected_output']['converted_params'] == converted_params - assert isinstance(converted_settings, Setting)
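
Reviewer note: a minimal usage sketch of the new output-tracking classes introduced in common/outputs.py, assuming the patch above is applied. It mirrors, in miniature, the wiring that Generator.build_outputs_connection() performs for real graphs. All node identifiers, edge names, and mapping tuples below are invented for illustration and are not taken from an actual ONNX model.

    # Illustrative only: toy wiring of BaseOutput / NodeOutputManager / OutputStorage.
    # Requires this patch to be applied so that common/outputs.py exists.
    from mindinsight.mindconverter.graph_based_converter.common.outputs import (
        NodeOutputManager,
        OutputStorage,
    )

    storage = OutputStorage()

    # Suppose a converted Conv node "conv_0" produces one ONNX edge "conv_0_out".
    # Each mapping tuple is (index in the MindSpore provider, index in the ONNX provider).
    conv_outputs = NodeOutputManager("conv_0", output_mappings=[(0, 0)])

    for out in conv_outputs.outputs:
        # The generator fills in the ONNX edge name before registering the output,
        # as done in Generator.build_outputs_connection().
        out.onnx_edge_name = "conv_0_out"
        storage.add_output(out)
        storage.add_onnx_node_name(out.onnx_edge_name, "conv_0")       # ONNX node providing the edge
        storage.add_ms_identifier(out.onnx_edge_name, "Model/conv_0")  # MindSpore-side node identifier

    # A downstream ReLU node "relu_1" consumes "conv_0_out" as its 0-th input.
    consumed_edge = "conv_0_out"
    if consumed_edge in storage.outputs_collections:
        output_obj = storage.outputs_collections[consumed_edge]
        output_obj.idx_in_onnx_user["relu_1"] = 0
        output_obj.idx_in_ms_user["Model/relu_1"] = 0

    print(storage.onnx_name("conv_0_out"))        # conv_0
    print(storage.node_identifier("conv_0_out"))  # Model/conv_0
    print(list(output_obj.onnx_user))             # ['relu_1']

In the real flow, batch_add_nodes() builds one NodeOutputManager per node from the mapper's output mapping, and ModuleOutputManager instances are assembled later by Generator._collect_output_mgr() from the BaseOutput objects that still need to be returned outside their module.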