fix problem where incorrect logic for nodestruct's check target nodes internal when target node is graph inputs update copyright add a output mgr adopt the generator with new fragment exchange msg. still need to adopt old ver. code_settings etc. outputs mgr dev; adapt the new fragment. temp disable extra nodes and weights adapt the NewFragment; re-imlement the module struct reset method update ut mapper test for new fragmenttags/v1.2.0-rc1
| @@ -27,8 +27,9 @@ class Fragment(abc.ABC): | |||
| Args: | |||
| operation (str): Operation name in MindSpore. | |||
| actual_args (dict): Actual arg values. | |||
| input_shape (tuple): The input shape of the node. | |||
| output_shape (tuple): The output shape of the node. | |||
| settings (namedTuple): Code generation setting. | |||
| """ | |||
| def __init__(self, operation, actual_args, input_shape, output_shape, settings=None): | |||
| @@ -46,6 +47,7 @@ class Fragment(abc.ABC): | |||
| @property | |||
| def code_setting(self): | |||
| """Code Setting getter.""" | |||
| return self._code_setting | |||
| @property | |||
| @@ -152,10 +154,12 @@ class Fragment(abc.ABC): | |||
| @property | |||
| def input_shape(self): | |||
| """Return the input shape.""" | |||
| return self._input_shape | |||
| @property | |||
| def output_shape(self): | |||
| """Return the output shape.""" | |||
| return self._output_shape | |||
| @@ -196,6 +200,7 @@ class CodeFragment(Fragment): | |||
| @property | |||
| def trainable_params(self): | |||
| """Return the trainable parameters.""" | |||
| return self._trainable_params | |||
| @@ -1,4 +1,4 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd.All Rights Reserved. | |||
| # Copyright 2020-2021 Huawei Technologies Co., Ltd.All Rights Reserved. | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| @@ -14,6 +14,7 @@ | |||
| # ============================================================================== | |||
| """Define GlobalContext class to save required resources during whole conversion procedure.""" | |||
| from collections import OrderedDict | |||
| from .outputs import OutputStorage | |||
| class Singleton(type): | |||
| @@ -45,6 +46,7 @@ class GlobalContext(metaclass=Singleton): | |||
| self.onnx_node_name_to_topo_idx = dict() | |||
| self.onnx_node_inputs = dict() | |||
| self._onnx_tensors_collection = dict() | |||
| self.onnx_graph_info = dict() | |||
| # Define data stored from generator | |||
| # Key as Node Identifier | |||
| @@ -72,6 +74,8 @@ class GlobalContext(metaclass=Singleton): | |||
| # key is target node (which use this opt), value is opt_var_name | |||
| self.extra_input_dict = dict() | |||
| self.outputs_storage = OutputStorage() | |||
| def get_onnx_node_from_identifier(self, identifier): | |||
| """Return an OnnxUtils defined node by its identifier.""" | |||
| onnx_node_name = self.node_struct_to_onnx_node_map.get(identifier) | |||
| @@ -0,0 +1,200 @@ | |||
| # Copyright 2021 Huawei Technologies Co., Ltd.All Rights Reserved. | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================== | |||
| """Define basic classes for generator use.""" | |||
| import abc | |||
| import copy | |||
| from typing import Union, Iterable | |||
| class BaseOutput: | |||
| """ | |||
| Define the class of output providing a universal nodes' and modules' output data collection. | |||
| Args: | |||
| output_mapping (tuple[tuple]): The mapping of outputs from onnx to mindspore. | |||
| """ | |||
| def __init__(self, output_mapping) -> None: | |||
| super(BaseOutput).__init__() | |||
| self.idx_in_ms_provider = output_mapping[0] | |||
| self.idx_in_onnx_provider = output_mapping[1] | |||
| # For multi users, key as user and value as index | |||
| self.idx_in_ms_user = dict() | |||
| self.idx_in_onnx_user = dict() | |||
| # The following attributes to be set by who referenced this object. | |||
| self.onnx_edge_name = None | |||
| self.to_external = False | |||
| @property | |||
| def ms_user(self): | |||
| """Return the output's user in the MindSpore.""" | |||
| return self.idx_in_ms_user.keys() | |||
| @property | |||
| def onnx_user(self): | |||
| """Return the output's user in the ONNX.""" | |||
| return self.idx_in_onnx_user.keys() | |||
| def deepcopy(self): | |||
| """Return a deepcopy of self instance.""" | |||
| return copy.deepcopy(self) | |||
| class BaseOutputManager(abc.ABC): | |||
| """ | |||
| Base Output Manager class. | |||
| Args: | |||
| output_mappings (list): A list of output mapping. | |||
| """ | |||
| def __init__(self, output_mappings): | |||
| if isinstance(self.__class__, ModuleOutputManager): | |||
| return | |||
| self._base_output_list = list() | |||
| # init base output obj | |||
| for mapping in output_mappings: | |||
| obj = BaseOutput(mapping) | |||
| self._base_output_list.append(obj) | |||
| @property | |||
| def outputs(self): | |||
| """Return the list of BaseOutput in this manager.""" | |||
| return self._base_output_list | |||
| @outputs.setter | |||
| def outputs(self, val: list): | |||
| """Set the list of BaseOutput in this manager.""" | |||
| for v in val: | |||
| if not isinstance(v, BaseOutput): | |||
| raise TypeError(f"{self.__class__} does not accept the type {type(v)} in the list given.") | |||
| self._base_output_list = val | |||
| @abc.abstractmethod | |||
| def deepcopy(self): | |||
| """Return the deepcopy of this instance.""" | |||
| cls = self.__class__ | |||
| result = cls.__new__(cls) | |||
| result.outputs = list() | |||
| for out in self._base_output_list: | |||
| result.outputs.append(out.deepcopy()) | |||
| return result | |||
| class NodeOutputManager(BaseOutputManager): | |||
| """ | |||
| Node Output Manager class. | |||
| Args: | |||
| identifier (str): The identifier of the node. | |||
| output_mappings (list): A list of the output mapping. | |||
| """ | |||
| def __init__(self, identifier, output_mappings=None) -> None: | |||
| super(NodeOutputManager, self).__init__(output_mappings) | |||
| self.identifier = identifier | |||
| def deepcopy(self): | |||
| new_mgr = super().deepcopy() | |||
| new_mgr.identifier = self.identifier | |||
| return new_mgr | |||
| class ModuleOutputManager(BaseOutputManager): | |||
| """ | |||
| Module Output Manager class. | |||
| Args: | |||
| identifier (str): The identifier of the module. | |||
| output_mappings (list): a list of output mapping | |||
| """ | |||
| def __init__(self, identifier, base_out: Union[BaseOutput, Iterable[BaseOutput]]) -> None: | |||
| super(ModuleOutputManager, self).__init__(None) | |||
| self.identifier = identifier | |||
| self._return_list_counter = 0 | |||
| self._base_output_list = list() | |||
| if isinstance(base_out, BaseOutput): | |||
| self._base_output_list.append(base_out) | |||
| else: | |||
| self._base_output_list += base_out | |||
| @property | |||
| def return_num(self): | |||
| """Return the number of outputs to be returned.""" | |||
| return self._return_list_counter | |||
| @return_num.setter | |||
| def return_num(self, num: int): | |||
| """Set the number of outputs to be returned.""" | |||
| self._return_list_counter = num | |||
| def deepcopy(self): | |||
| """Return a deepcopy of current instance.""" | |||
| new_mgr = super().deepcopy() | |||
| new_mgr.identifier = self.identifier | |||
| new_mgr.return_num = self._return_list_counter | |||
| return new_mgr | |||
| class OutputStorage: | |||
| """A class saves all outputs.""" | |||
| def __init__(self): | |||
| self._base_output_edge_to_instance = dict() | |||
| self._base_output_edge_to_onnx_node_name = dict() | |||
| self._base_output_edge_to_ms_identifier = dict() | |||
| @property | |||
| def outputs_collections(self) -> dict: | |||
| """Return the dict of edge name to output instance.""" | |||
| return self._base_output_edge_to_instance | |||
| def onnx_name(self, output_edge) -> str: | |||
| """Return the dict of edge name to onnx node name.""" | |||
| return self._base_output_edge_to_onnx_node_name.get(output_edge) | |||
| def node_identifier(self, output_edge): | |||
| """Return the dict of edge name to node identifier.""" | |||
| return self._base_output_edge_to_ms_identifier.get(output_edge) | |||
| def add_output(self, out: BaseOutput) -> str: | |||
| """ | |||
| Add a BaseOutput instance to the storage. | |||
| Args: | |||
| out (BaseOutput): The BaseOutput instance. | |||
| """ | |||
| if out.onnx_edge_name: | |||
| self._base_output_edge_to_instance[out.onnx_edge_name] = out | |||
| else: | |||
| raise ValueError("Unable to add a BaseOutput instance with unknown ONNX edge.") | |||
| def add_onnx_node_name(self, edge: str, onnx_node_name: str): | |||
| """ | |||
| Add the onnx node name with the edge name. | |||
| Args: | |||
| edge (str): The edge name of this output. | |||
| onnx_node_name (str): The onnx node which has the edge. | |||
| """ | |||
| self._base_output_edge_to_onnx_node_name[edge] = onnx_node_name | |||
| def add_ms_identifier(self, edge: str, ms_identifier: str): | |||
| """ | |||
| Add the node identifier with the edge name. | |||
| Args: | |||
| edge (str): The edge name of this output. | |||
| ms_identifier (str): The identifier of the node which has the edge. | |||
| """ | |||
| self._base_output_edge_to_ms_identifier[edge] = ms_identifier | |||
| @@ -18,9 +18,10 @@ __all__ = ["batch_add_nodes"] | |||
| import re | |||
| import copy | |||
| from mindinsight.mindconverter.graph_based_converter.common.code_fragment import CodeFragment | |||
| from mindinsight.mindconverter.graph_based_converter.constant import ExchangeMessageKeywords | |||
| from .generator import Generator, CodeStruct | |||
| from ..common.code_fragment import CodeFragment, NewFragment | |||
| from ..common.outputs import NodeOutputManager | |||
| from ..constant import ExchangeMessageKeywords | |||
| def _tf_model_node_name_reformat(node, node_name): | |||
| @@ -56,6 +57,7 @@ def batch_add_nodes(graph_obj, mapper) -> Generator: | |||
| """ | |||
| generator_inst = Generator() | |||
| external_inputs = graph_obj.user_provided_input_nodes | |||
| for node_name in graph_obj.nodes_in_topological_order: | |||
| node_inst = graph_obj.get_node(node_name) | |||
| node_input = graph_obj.get_input_shape(node_name) | |||
| @@ -66,16 +68,11 @@ def batch_add_nodes(graph_obj, mapper) -> Generator: | |||
| node_name = node_name_with_scope | |||
| node_inst.add_input_and_output_shape(node_input, node_output) | |||
| op_name, params, settings, weights = _convert_params(node_inst, mapper) | |||
| generator_inst.add_node( | |||
| node_name, | |||
| node_instance=node_inst, | |||
| node_fragment=CodeFragment(op_name, params, | |||
| settings, | |||
| node_inst.input_shape, | |||
| node_inst.output_shape, | |||
| weights) | |||
| ) | |||
| code_template, exchange_msg, outputs_lst, outputs_mapping = _convert_params(node_inst, mapper, external_inputs) | |||
| outputs_mapping = NodeOutputManager(node_name, output_mappings=outputs_mapping) | |||
| fragment = NewFragment(data_entity=exchange_msg, code_template=code_template, | |||
| outputs=outputs_lst, outputs_mapping=outputs_mapping) | |||
| generator_inst.add_node(node_name, node_instance=node_inst, node_fragment=fragment) | |||
| return generator_inst | |||
| @@ -105,13 +102,14 @@ def _supply_graph_info(node, external_inputs): | |||
| } | |||
| def _convert_params(node, mapper): | |||
| def _convert_params(node, mapper, external_inputs): | |||
| """ | |||
| Call mapper to convert node's params from ONNX to MindSpore. | |||
| Args: | |||
| node (GraphNode): Our defined GraphNode instance. | |||
| mapper (Mapper): The mapper instance which indicating conversion method. | |||
| external_inputs (list[str]): External inputs provided by users. | |||
| Returns: | |||
| tuple[str, dict, dict, dict], op name in MindSpore, MindSpore parameters, | |||
| @@ -121,18 +119,11 @@ def _convert_params(node, mapper): | |||
| params.update({"input_shape": node.input_shape, | |||
| "output_shape": node.output_shape}) | |||
| op_in_ms, ms_params, ms_settings, weights = mapper.convert(op_name=node.op_name, | |||
| params=params, | |||
| weights=node.weight) | |||
| if "input_shape" in ms_params: | |||
| ms_params.pop("input_shape") | |||
| if "output_shape" in ms_params: | |||
| ms_params.pop("output_shape") | |||
| if op_in_ms: | |||
| return op_in_ms, ms_params, ms_settings, weights | |||
| return node.op_name, node.node_params, dict(), dict() | |||
| code_template, exchange_msg, outputs_lst, outputs_order_mapping = mapper.convert(op_name=node.op_name, | |||
| params=params, | |||
| weights=node.weight) | |||
| exchange_msg[ExchangeMessageKeywords.METADATA.value] = _supply_graph_info(node, external_inputs) | |||
| return code_template, exchange_msg, outputs_lst, outputs_order_mapping | |||
| def _combine_external_inputs_with_precursor_nodes(node, external_inputs): | |||
| @@ -0,0 +1,96 @@ | |||
| # Copyright 2021 Huawei Technologies Co., Ltd.All Rights Reserved. | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================== | |||
| """Miscellaneous Fragment related classes and functions. """ | |||
| from mindinsight.mindconverter.graph_based_converter.common.code_fragment import NewFragment | |||
| class FragmentHandler: | |||
| """ | |||
| Define a handler to process the infomation contained by Fragment. | |||
| Args: | |||
| fragment (NewFragment): The refactored fragment class. | |||
| """ | |||
| def __init__(self, fragment: NewFragment): | |||
| self._fragment = fragment | |||
| # set the var in the fragment to be load and save. | |||
| self._target_var = "var_0" | |||
| @property | |||
| def target_var(self): | |||
| """Return the target var name the handler currently set to be read.""" | |||
| return self._target_var | |||
| @target_var.setter | |||
| def target_var(self, target): | |||
| """Set the target var the handler will read.""" | |||
| if not target in self.exchange_msg.keys(): | |||
| raise ValueError(f"Unable to set target var {target} where fragment does not have it.") | |||
| self._target_var = target | |||
| @property | |||
| def fragment(self): | |||
| """Return the fragment instance the handler currently processed.""" | |||
| return self._fragment | |||
| @property | |||
| def converted(self): | |||
| """Return the status of the op successfully converted.""" | |||
| return bool(self._fragment.exchange_msg) | |||
| # The following section is intended for Fragment exchange message. | |||
| @property | |||
| def exchange_msg(self): | |||
| """Return the exchange message dictionary the fragment contains.""" | |||
| return self._fragment.exchange_msg | |||
| @property | |||
| def var(self): | |||
| """Return the var dictionary the handler currently set to be processed.""" | |||
| try: | |||
| return self.exchange_msg.get(self.target_var) | |||
| except AttributeError: | |||
| return None | |||
| @property | |||
| def default_var(self): | |||
| """Return the default var dictionary the handler processed.""" | |||
| try: | |||
| return self.exchange_msg.get("var_0") | |||
| except AttributeError: | |||
| return None | |||
| # For metadata | |||
| @property | |||
| def metadata(self): | |||
| """Return the metadata of the onnx node info dictionary.""" | |||
| return self._fragment.exchange_msg.get("metadata") | |||
| @property | |||
| def input_shape(self): | |||
| """Return the input shape of this node.""" | |||
| return self.metadata.get('inputs_shape') | |||
| @property | |||
| def output_shape(self): | |||
| """Return the output shape of this node.""" | |||
| return self.metadata.get('outputs_shape') | |||
| # For outputs | |||
| @property | |||
| def outputs_manager(self): | |||
| """Return the outputs manager of this node.""" | |||
| return self._fragment.outputs_mapping | |||
| @@ -23,6 +23,7 @@ from .node_struct import NodeStruct | |||
| from .module_struct import ModuleStruct | |||
| from .args_translator import ArgsTranslationHelper | |||
| from ..common.global_context import GlobalContext | |||
| from ..common.outputs import BaseOutput, ModuleOutputManager | |||
| from ...common.exceptions import GeneratorError | |||
| from ..common.name_mgr import GlobalVarNameMgr | |||
| from ..constant import NEW_LINE, SECOND_LEVEL_INDENT, FIRST_LEVEL_INDENT, CodeFormatConfig, get_imported_module | |||
| @@ -39,21 +40,10 @@ class CodeStruct: | |||
| def __init__(self, struct, repeated_submodules=None): | |||
| """Initialize the CodeStruct.""" | |||
| self.output_order = None # output order | |||
| self.input = None # opt_var_name for prev. node | |||
| self.extra_input = list() # extra_input(s) at construct method args | |||
| self.output = None # opt_var_name for next node | |||
| self.extra_output = list() # extra_output(s) | |||
| self.extra_comment = None # comments for this code line / block. | |||
| self.code_line_list = list() # list of code line, a item is a line. | |||
| self._global_var_mgr = GlobalVarNameMgr() # var name procs within same module | |||
| self.formal_args_collections = None | |||
| if isinstance(struct, NodeStruct): | |||
| self.output_order = struct.topo_idx | |||
| if isinstance(struct, ModuleStruct): | |||
| self.output_order = struct.head_nd_struct_index | |||
| self._generate_from_module_struct(struct, repeated_submodules) | |||
| def _add_line(self, s): | |||
| @@ -102,13 +92,15 @@ class CodeStruct: | |||
| cons_lines = list() | |||
| for (_, struct) in md_struct.get_generate_order(): | |||
| if isinstance(struct, NodeStruct): # Generate code line for Node. | |||
| code_line_init = struct.code_line_in_init() | |||
| code_line_construct = struct.code_line_in_construct() | |||
| init_lines.append(f"{SECOND_LEVEL_INDENT}{' = '.join(code_line_init)}") | |||
| cons_lines.append(f"{SECOND_LEVEL_INDENT}{' = '.join(code_line_construct)}") | |||
| # add extra tensor | |||
| if struct.fragment.code_setting and struct.fragment.code_setting.op_extra_tensor: | |||
| init_lines.append(f"{SECOND_LEVEL_INDENT}{' = '.join(struct.add_extra_tensor())}") | |||
| _ = struct.code_line_in_init() | |||
| _ = struct.code_line_in_construct() | |||
| init_str, cons_str = struct.fragment.fragment() | |||
| init_str = [f"{SECOND_LEVEL_INDENT}{x}" for x in init_str] | |||
| cons_str = [f"{SECOND_LEVEL_INDENT}{x}" for x in cons_str] | |||
| code_line_construct = cons_str | |||
| init_lines += init_str | |||
| cons_lines += cons_str | |||
| elif isinstance(struct, ModuleStruct): | |||
| # check if this instance generated CodeStruct | |||
| @@ -145,7 +137,8 @@ class CodeStruct: | |||
| returns.append(r) | |||
| returns = list(set(returns)) | |||
| else: | |||
| returns = [code_line_construct[0]] | |||
| returns = [code_line_construct[0]] if isinstance(code_line_construct, tuple) \ | |||
| else [code_line_construct[-1].replace(' ', '').split('=')[0]] | |||
| self.new_line = f"{SECOND_LEVEL_INDENT}return {', '.join(returns)}" | |||
| self.new_line = f"{NEW_LINE * 2}" | |||
| self.GLOBAL_CONTEXT.code_structs[md_struct.pattern_id] = self | |||
| @@ -156,9 +149,6 @@ class Generator: | |||
| def __init__(self): | |||
| """Init the generator.""" | |||
| # define basic attributes | |||
| self.framework = None | |||
| # define MUST have params | |||
| self._node_struct_collections = OrderedDict() | |||
| self._module_struct_collections = OrderedDict() | |||
| @@ -244,9 +234,9 @@ class Generator: | |||
| if len(nd_struct_list) < 2: | |||
| return formal_args | |||
| (_, base_nd_struct) = nd_struct_list[0] | |||
| for (base_parameter, base_value) in base_nd_struct.fragment.actual_args.items(): # for each param | |||
| for (base_parameter, base_value) in base_nd_struct.fragment.default_var["args"].items(): # for each param | |||
| for (_, nd_struct) in nd_struct_list[1:]: | |||
| compared_value = nd_struct.fragment.actual_args.get(base_parameter) | |||
| compared_value = nd_struct.fragment.default_var["args"].get(base_parameter) | |||
| if compared_value == base_value: | |||
| continue | |||
| formal_args.add(base_parameter) | |||
| @@ -340,9 +330,9 @@ class Generator: | |||
| self.module_structs[str(parent_scope)] = parent_md_struct | |||
| else: | |||
| # 1B. not has parent, generate a new ModuleStruct | |||
| parent_md_struct = copy.deepcopy(md_struct) # use this submodule to create a parent module | |||
| # use this submodule to create a parent module | |||
| parent_md_struct = ModuleStruct(None, init_as_parent=True, parent_base=md_struct) | |||
| # rewrite parent md struct | |||
| parent_md_struct.reset_as_parent() | |||
| parent_md_struct.add_submodule(md_struct) | |||
| self.module_structs[str(parent_scope)] = parent_md_struct | |||
| sub = self.module_structs.pop(scope_path_str) # remove this submodule from collections | |||
| @@ -378,6 +368,7 @@ class Generator: | |||
| self._update_all_modules_args_translator() | |||
| # 6. Update all nodes and moudles input/output | |||
| # Enable build_output_connections later. | |||
| self.module_structs.get('[]').allocate_construct_header_x() | |||
| self.module_structs.get('[]').collect_returns() | |||
| @@ -488,7 +479,7 @@ class Generator: | |||
| for code_struct in self._global_context.code_structs.values(): | |||
| for line in code_struct.code_line_list: | |||
| outputs.append(line.replace("onnx::", "")) | |||
| outputs.append(line) | |||
| formatted_code, _ = FormatCode("\n".join(outputs), | |||
| style_config=CodeFormatConfig.PEP8.value) | |||
| @@ -575,3 +566,69 @@ class Generator: | |||
| if m_num == module_num: | |||
| ret.append(nd_struct_list) | |||
| return ret | |||
| def build_outputs_connection(self): | |||
| """Build all nodes and modules outputs connections.""" | |||
| for nd_struct in self.node_structs.values(): | |||
| # for each output in curr node output manager | |||
| for out in nd_struct.outputs_manager.outputs: | |||
| # Set the onnx output edge name to this output | |||
| out.onnx_edge_name = nd_struct.fragment.metadata.get('outputs')[out.idx_in_onnx_provider] | |||
| self._global_context.outputs_storage.add_output(out) | |||
| self._global_context.outputs_storage.add_onnx_node_name(out.onnx_edge_name, | |||
| nd_struct.fragment.metadata.get('source')) | |||
| self._global_context.outputs_storage.add_ms_identifier(out.onnx_edge_name, nd_struct.identifier) | |||
| # Set input with existing output mapping | |||
| for idx, inp in enumerate(nd_struct.fragment.metadata.get('inputs')): | |||
| if inp in self._global_context.outputs_storage.outputs_collections: | |||
| output_obj = self._global_context.outputs_storage.outputs_collections[inp] | |||
| output_obj.idx_in_onnx_user[nd_struct.onnx_name] = idx | |||
| # set ms_user idx, need to modify if not follow onnx order | |||
| output_obj.idx_in_ms_user[nd_struct.identifier] = idx | |||
| # set this output to be returned to external | |||
| output_obj.to_external = not(nd_struct.check_target_node_internal( | |||
| self._global_context.outputs_storage.onnx_name(inp) | |||
| )) | |||
| # collect submodule's and nodes' outputs mgr | |||
| self._collect_output_mgr() | |||
| def _collect_output_mgr(self, module=None): | |||
| """ | |||
| Collect the outputs manager from nodes and submodules the current module has. | |||
| Args: | |||
| module (ModuleStruct): The module struct collecting its nodes and submodules. | |||
| """ | |||
| root_module = module or self.get_module_struct('[]') | |||
| output_mgr_list = list() | |||
| for struct in root_module.get_generate_order(): | |||
| if isinstance(struct, tuple): | |||
| # index 1 is the NodeStruct while 0 is topological index. | |||
| struct = struct[1] | |||
| if isinstance(struct, ModuleStruct) and struct.outputs_manager is None: | |||
| self._collect_output_mgr(module=struct) | |||
| for out in struct.outputs_manager.outputs: | |||
| if Generator.check_output_need_to_external(root_module, out): | |||
| output_mgr_list.append(out) | |||
| root_module.outputs_manager = ModuleOutputManager(root_module.identifier, base_out=output_mgr_list) | |||
| @staticmethod | |||
| def check_output_need_to_external(root_module: ModuleStruct, checked_output: BaseOutput): | |||
| """ | |||
| Check the output still need to be returned to module external. | |||
| Args: | |||
| root_module (ModuleStruct): The Module that the output to be determined. | |||
| checked_output (BaseOutput): The output to be checked whether returned by the Module. | |||
| Returns: | |||
| bool, True if the output need to be returned to the module external. | |||
| """ | |||
| for user in checked_output.onnx_user: | |||
| if user in root_module.external_successor_nodes_names: | |||
| return True | |||
| return False | |||
| @@ -14,6 +14,7 @@ | |||
| # ============================================================================== | |||
| """Define a struct for module converted and save all required information here.""" | |||
| import copy | |||
| from collections import OrderedDict | |||
| from .node_struct import NodeStruct | |||
| @@ -31,10 +32,12 @@ class ModuleStruct: | |||
| Args: | |||
| args (list): A list of node structs. | |||
| init_as_parent (bool): Control init method if the ModuleStruct be init as a parent module struct. | |||
| parent_base (ModuleStruct): The base ModuleStruct the current ModuleStruct to be init as. | |||
| """ | |||
| GLOBAL_CONTEXT_MGR = GlobalContext() | |||
| def __init__(self, nd_struct_list): | |||
| def __init__(self, nd_struct_list, init_as_parent=False, parent_base=None): | |||
| """Init. a module by NodeStructs.""" | |||
| self.pattern_id = -1 # pattern num, -1 as Main module | |||
| self.pattern_uid = -1 # unique module id for this pattern | |||
| @@ -53,19 +56,15 @@ class ModuleStruct: | |||
| self._fragment = None | |||
| self._args_translator = None | |||
| self._setting = None | |||
| self._parent_module_struct = None | |||
| # only store original formal args name, not global | |||
| self._nodes_structs_formal_args_list = list() | |||
| # only store translated (globalized) formal args | |||
| self._nodes_structs_formal_args_translated_list = list() | |||
| # define other settings here | |||
| self._node_args_translation_list = list() | |||
| self._var_name_mgr = LocalVarNameMgr() | |||
| self.construct_header_x = OrderedDict() # key is header x, value is precursors onnx name | |||
| self.inputs_in_construct_header = OrderedDict() # key is precursors onnx name, value is x in parent construct | |||
| self.inputs_in_parent_module = OrderedDict() # key is prec_node_name, value is its closet opt_var_name | |||
| # key is node's onnx name(output provider), value is (provider_succ_name, opt_var_name) | |||
| self.outputs_collection = dict() | |||
| @@ -74,40 +73,49 @@ class ModuleStruct: | |||
| # key is ext. succ node onnx name, value is local opt_var | |||
| self.external_successor_local_returns_map = OrderedDict() | |||
| # key is node's onnx_name, value is (successor_name, opt_var_name) <- node's level | |||
| self.outputs_collection = dict() | |||
| # Define outputs manager, note this will be assigned later by Generator. | |||
| self.outputs_manager = None | |||
| # start initialization | |||
| if not self.initialized: | |||
| self._init_module(nd_struct_list) | |||
| if init_as_parent and (parent_base is not None): | |||
| self.reset_as_parent_passed_in(parent_base) | |||
| else: | |||
| self._update_module(nd_struct_list) | |||
| # start initialization | |||
| if not self.initialized: | |||
| self._init_module(nd_struct_list) | |||
| else: | |||
| self._update_module(nd_struct_list) | |||
| # assign this module reference to node | |||
| for (_, nd_struct) in nd_struct_list: | |||
| nd_struct.parent_module_struct = self | |||
| # assign this module reference to node | |||
| for (_, nd_struct) in nd_struct_list: | |||
| nd_struct.parent_module_struct = self | |||
| def reset_as_parent(self): | |||
| def reset_as_parent_passed_in(self, parent_base): | |||
| """ | |||
| Reset all attributes and filled as a parent module of this module. | |||
| Reset all attributes and filled as a parent module of the module passed in. | |||
| Args: | |||
| parent_base(ModuleStruct): The base ModuleStruct to be passed in for ModuleStruct init. | |||
| Note: | |||
| This function must be called only after a deepcopy of this instance! | |||
| This function must be called only if the new ModuleStruct is a parent of parent_base. | |||
| """ | |||
| self.identifier.pop() | |||
| self.scope_depth = self.scope_depth - 1 | |||
| self._set_pattern_id() | |||
| self._find_parent_module() | |||
| self.identifier = copy.deepcopy(parent_base.identifier)[:-1] | |||
| self.scope_depth = copy.deepcopy(parent_base.scope_depth) - 1 | |||
| self.module_name = Scope.scope_to_module_name(self.identifier) | |||
| self.head_nd_struct = parent_base.head_nd_struct | |||
| self.head_nd_struct_index = parent_base.head_nd_struct_index | |||
| self.tail_nd_struct = parent_base.tail_nd_struct | |||
| self.tail_nd_struct_index = parent_base.tail_nd_struct_index | |||
| self._node_structs = list() | |||
| self._module_structs = list() | |||
| self._fragment = None | |||
| self._args_translator = None | |||
| self.initialized = True | |||
| self._set_pattern_id() | |||
| self._find_parent_module() | |||
| self.init_args_translator() | |||
| self._setting = None | |||
| self._parent_module_struct = None | |||
| self._nodes_structs_formal_args_list = list() | |||
| self._node_args_translation_list = list() | |||
| def _set_pattern_id(self): | |||
| @@ -435,7 +443,7 @@ class ModuleStruct: | |||
| @property | |||
| def external_successor_nodes_names(self) -> list: | |||
| """Return all precursors nodes names not in this module.""" | |||
| """Return all successor nodes names not in this module.""" | |||
| ret = [] | |||
| for _, struct in self.get_generate_order(): | |||
| if isinstance(struct, NodeStruct): | |||
| @@ -15,15 +15,14 @@ | |||
| """Define the NodeStruct which stores all info. of a node.""" | |||
| from collections import OrderedDict | |||
| from mindinsight.mindconverter.graph_based_converter.common.code_fragment import NewFragment | |||
| from mindinsight.mindconverter.graph_based_converter.generator.fragment_utils import FragmentHandler | |||
| from .scope_utils import Scope | |||
| from .args_translator import ArgsTranslation | |||
| from ..common.code_fragment import CodeFragment | |||
| from ..third_party_graph.onnx_graph_node import OnnxGraphNode | |||
| from ..common.global_context import GlobalContext | |||
| from ..constant import InputType | |||
| from ...common.exceptions import GeneratorError | |||
| class NodeStruct: | |||
| """ | |||
| Define a node struct which stores all info. to generate statement. | |||
| @@ -44,21 +43,11 @@ class NodeStruct: | |||
| self._args_translator = None | |||
| self._parent_module_struct = None | |||
| self.topo_idx = None | |||
| self.node_type = None | |||
| self.onnx_name = None | |||
| self.onnx_op = None | |||
| self.graph_node_ref = None | |||
| self.scope_name = None | |||
| self.ms_var_name = None | |||
| self.ms_op = None | |||
| self.ready_to_generate = False | |||
| # Define attributes converted from mapper | |||
| self.ms_params = dict() | |||
| self.ms_settings = dict() | |||
| self.ms_weights = dict() | |||
| self.ms_inputs = OrderedDict() | |||
| # Defined Scope class | |||
| self.scope = None | |||
| @@ -67,9 +56,6 @@ class NodeStruct: | |||
| # key is prec_node_name, value is x; For code line use | |||
| self.inputs_in_construct_header = OrderedDict() | |||
| # key is prec_node_name, value is its closet opt_var_name | |||
| self.inputs_in_parent_module = OrderedDict() | |||
| # Matched inputs will can be directly used by code line generation | |||
| self.matched_inputs = list() | |||
| @@ -86,7 +72,7 @@ class NodeStruct: | |||
| def ori_topo_idx(self): | |||
| """Get the original topological index in the onnx graph.""" | |||
| ori_name = self.identifier.replace('$', '').split('/')[-1].replace("::", '/') | |||
| ori_name = self._fragment.metadata.get('source') | |||
| self.onnx_name = ori_name | |||
| return self.GLOBAL_CONTEXT_MGR.onnx_node_name_to_topo_idx.get(ori_name) | |||
| @@ -97,12 +83,20 @@ class NodeStruct: | |||
| Args: | |||
| idx (int): The index of the node in this module. | |||
| """ | |||
| def _remove_op_header(op_name): | |||
| """Remove op header which indicating their sources of op set.""" | |||
| op_name = op_name.replace('nn.', '') | |||
| op_name = op_name.replace('P.', '') | |||
| op_name = op_name.replace('onnx.', '') | |||
| return op_name | |||
| if idx is not None: | |||
| self.ms_var_name = self.ms_op.replace('nn.', '').replace('P.', '').lower() + '_' + str(idx) | |||
| self.ms_var_name = "{}_{}".format(_remove_op_header(self.ms_op), str(idx)).lower() | |||
| elif self.topo_idx is not None: | |||
| self.ms_var_name = self.ms_op.replace('nn.', '').replace('P.', '').lower() + '_' + str(self.topo_idx) | |||
| self.ms_var_name = "{}_{}".format(_remove_op_header(self.ms_op), str(self.topo_idx)).lower() | |||
| else: | |||
| raise ValueError("Unable to update var name when topo_idx is None.") | |||
| self.fragment.default_var['variable_name'] = self.ms_var_name | |||
| def _update_basics_from_gn(self, gn): | |||
| """Update basic info from GraphNode.""" | |||
| @@ -111,25 +105,13 @@ class NodeStruct: | |||
| def _update_from_onnx_gn(self, gn: OnnxGraphNode): | |||
| """Update basic info from OnnxGraphNode.""" | |||
| self.node_type = "OnnxGraphNode" | |||
| self._update_basics_from_gn(gn) | |||
| def _update_from_mapper(self, d): | |||
| """Update info from mapper.""" | |||
| if d.get('op_name'): | |||
| self.ms_op = d.get('op_name') | |||
| if d.get('params'): | |||
| self.ms_params = d.get('params') | |||
| if d.get('settings'): | |||
| self.ms_settings = d.get('settings') | |||
| if d.get('weights'): | |||
| self.ms_weights = d.get('weights') | |||
| def _update_from_fragment(self, frag: CodeFragment): | |||
| def _update_from_fragment(self, frag: NewFragment): | |||
| """Update info from CodeFragment.""" | |||
| self._fragment = frag | |||
| if frag.operation: | |||
| self.ms_op = frag.operation | |||
| self._fragment = FragmentHandler(frag) | |||
| if self.ms_op: | |||
| idx = self.GLOBAL_CONTEXT_MGR.latest_node_struct_count | |||
| self.update_var_name(idx=idx) | |||
| @@ -148,22 +130,13 @@ class NodeStruct: | |||
| """ | |||
| if not self._fragment: | |||
| raise ValueError("Initialize argument translator failed.") | |||
| if self._fragment.actual_args and translated_args: | |||
| self._args_translator = ArgsTranslation(self._fragment.actual_args, self.ms_var_name, translated_args) | |||
| def check_if_generate_ready(self): | |||
| """Check if the NodeStruct is able to generate code.""" | |||
| # check essential params exists | |||
| if all([self.identifier, | |||
| self.node_type, | |||
| self.scope_name, | |||
| self.ms_var_name, | |||
| self.ms_opt_var_name, | |||
| self.ms_op]): | |||
| self.ready_to_generate = True | |||
| if self._fragment.converted and self._fragment.default_var["args"] and translated_args: | |||
| self._args_translator = ArgsTranslation(self._fragment.default_var["args"], | |||
| self.ms_var_name, | |||
| translated_args) | |||
| @GeneratorError.check_except("Generator occurs an error when creating node struct.") | |||
| def update(self, arg, force_ready=False): | |||
| def update(self, arg): | |||
| """ | |||
| Pass Node info. to generator NodeStruct. | |||
| @@ -174,18 +147,11 @@ class NodeStruct: | |||
| if isinstance(arg, OnnxGraphNode): | |||
| self._update_from_onnx_gn(arg) | |||
| elif isinstance(arg, (dict, OrderedDict)): | |||
| self._update_from_mapper(arg) | |||
| elif isinstance(arg, CodeFragment): | |||
| elif isinstance(arg, NewFragment): | |||
| self._update_from_fragment(arg) | |||
| else: | |||
| raise TypeError("NodeStruct received an unsupported initializing argument.") | |||
| if force_ready: | |||
| self.ready_to_generate = True | |||
| else: | |||
| self.check_if_generate_ready() | |||
| @property | |||
| def identifier(self): | |||
| """Return the identifier of the node.""" | |||
| @@ -234,10 +200,30 @@ class NodeStruct: | |||
| """Return the original onnx node reference.""" | |||
| return self.GLOBAL_CONTEXT_MGR.onnx_nodes_collection.get(self.onnx_name) | |||
| @property | |||
| def ms_op(self): | |||
| """Return the operation name in MindSpore.""" | |||
| return self._fragment.default_var.get('operation') | |||
| @ms_op.setter | |||
| def ms_op(self, ms_op_name: str): | |||
| """Set the operation name in MindSpore.""" | |||
| self._fragment.default_var['operation'] = ms_op_name | |||
| @property | |||
| def ms_var_name(self): | |||
| """Return the variable name of this Node in the MindSpore script.""" | |||
| return self._fragment.default_var.get('variable_name') | |||
| @ms_var_name.setter | |||
| def ms_var_name(self, ms_var_name: str): | |||
| """Set the variable name of this Node in the MindSpore script.""" | |||
| self._fragment.default_var['variable_name'] = ms_var_name | |||
| @property | |||
| def ms_opt_var_name(self): | |||
| """Return the output variable name of current node.""" | |||
| return "{}_opt".format(self.ms_var_name).lower() | |||
| return self.fragment.fragment.get_outputs_by_idx(0) | |||
| @property | |||
| def args_translator(self): | |||
| @@ -282,25 +268,32 @@ class NodeStruct: | |||
| def parent_module_struct(self, ref): | |||
| self._parent_module_struct = ref | |||
| @property | |||
| def outputs_manager(self): | |||
| """Return the outputs manager instance.""" | |||
| return self.fragment.outputs_manager | |||
| @property | |||
| def outputs_in_construct(self): | |||
| """Return the outputs var(s) in construct statement.""" | |||
| return self.fragment.fragment.outputs() | |||
| # Code Generation funcs below | |||
| def code_line_in_init(self): | |||
| """Initialization line of code in module init block.""" | |||
| unconverted = False | |||
| if "onnx::" in self.ms_var_name: | |||
| unconverted = True | |||
| self.ms_var_name = self.ms_var_name.replace("onnx::", "") | |||
| left = "self.{}".format(self.ms_var_name) | |||
| args_list = list() | |||
| if self._args_translator is not None: | |||
| self.fragment.default_var['args'] = {**self._args_translator.actual_args, | |||
| **self._args_translator.formal_args} | |||
| args_list += self._args_translator.actual_args_to_str_list | |||
| args_list += self._args_translator.formal_args_to_str_list | |||
| else: | |||
| actual_args_str = ArgsTranslation.dict_data_to_args_str_list(self._fragment.actual_args) | |||
| actual_args_str = ArgsTranslation.dict_data_to_args_str_list(self._fragment.default_var['args']) | |||
| args_list += actual_args_str | |||
| if unconverted: | |||
| if not self._fragment.converted: | |||
| args_list.append('='.join(["input_shape", str(self._fragment.input_shape)])) | |||
| args_list.append('='.join(["output_shape", str(self._fragment.output_shape)])) | |||
| right = f"{self.ms_op.replace('::', '.')}({', '.join(args_list)})" | |||
| @@ -308,32 +301,6 @@ class NodeStruct: | |||
| right = f"{self.ms_op}({', '.join(args_list)})" | |||
| return left, right | |||
| def _get_correct_in_module_returns(self, prec_node, in_module_return): | |||
| """ | |||
| Find the correct precursor node name in return statement of its parent module. | |||
| Args: | |||
| prec_node (str): The onnx name of the precursor node given. | |||
| in_module_return (list[tuple]): The list of outputs which contains parent module identifier | |||
| and module opt_var_name. | |||
| Return: | |||
| str, correct opt_var_name to be passed in current node. | |||
| """ | |||
| found_return = False | |||
| for ret in in_module_return: | |||
| (md_identifier, input_name_to_use) = ret | |||
| p_node_struct = self.GLOBAL_CONTEXT_MGR.onnx_node_name_to_node_struct_map.get(prec_node) | |||
| # recursive check the p node parent | |||
| parent = p_node_struct | |||
| while not found_return: | |||
| parent = parent.parent_module_struct | |||
| if parent is None: | |||
| break | |||
| if parent.identifier == md_identifier: | |||
| return input_name_to_use | |||
| return None | |||
| def code_line_in_construct(self, inputs=None): | |||
| """Construct line of code in module construct block. """ | |||
| left = self.ms_opt_var_name | |||
| @@ -356,15 +323,7 @@ class NodeStruct: | |||
| if isinstance(inputs, str): | |||
| inputs = [inputs] | |||
| if self._fragment.code_setting and self._fragment.code_setting.op_ipt_type == InputType.LIST.value: | |||
| inputs = [str(tuple(inputs)).replace("\'", "")] | |||
| if self._fragment.code_setting and self._fragment.code_setting.op_extra_input: | |||
| for _, val in self._fragment.code_setting.op_extra_input.items(): | |||
| inputs.append(str(val)) | |||
| if self._fragment.code_setting and self._fragment.code_setting.op_extra_tensor: | |||
| inputs.append(f"self.{self.ms_var_name}_w") | |||
| self.fragment.default_var['inputs'] = inputs | |||
| right = f"self.{self.ms_var_name}({', '.join(inputs)})" | |||
| return left, right | |||
| @@ -394,7 +353,7 @@ class NodeStruct: | |||
| {} has duplicate inputs or not.".format(onnx_precursor_node_name, self.identifier)) | |||
| self.inputs_in_construct_header[onnx_precursor_node_name] = header_x | |||
| def _check_target_node_internal(self, name: str) -> bool: | |||
| def check_target_node_internal(self, name: str) -> bool: | |||
| """ | |||
| Check given node under the same scope. | |||
| @@ -406,6 +365,9 @@ class NodeStruct: | |||
| if target_nd_struct is None and self.topo_idx == 0: # First node always has external input | |||
| return False | |||
| if target_nd_struct is None and (name in self.GLOBAL_CONTEXT_MGR.onnx_graph_info.get('graph_inputs')): | |||
| return False | |||
| if target_nd_struct is None: | |||
| raise ValueError("Unable to find the NodeStruct of given target node {}.".format(name)) | |||
| return target_nd_struct.scope.path == self.scope.path | |||
| @@ -414,7 +376,7 @@ class NodeStruct: | |||
| def has_successor_node_external(self) -> bool: | |||
| """Check if any successor_node is in external module.""" | |||
| for name in self.successor_nodes_names: | |||
| if not self._check_target_node_internal(name): | |||
| if not self.check_target_node_internal(name): | |||
| return False | |||
| return True | |||
| @@ -423,10 +385,10 @@ class NodeStruct: | |||
| def precursor_nodes_names_external(self) -> list: | |||
| """Return a list of external precursor nodes names.""" | |||
| return [name for name in self.precursor_nodes_names | |||
| if not self._check_target_node_internal(name)] | |||
| if not self.check_target_node_internal(name)] | |||
| @property | |||
| def successor_nodes_names_external(self) -> list: | |||
| """Return a list of external successor nodes names.""" | |||
| return [name for name in self.successor_nodes_names | |||
| if not self._check_target_node_internal(name)] | |||
| if not self.check_target_node_internal(name)] | |||
| @@ -103,25 +103,38 @@ class ONNXToMindSporeMapper(Mapper, abc.ABC): | |||
| op_name_converter = getattr(converter, GET_OP_NAME) | |||
| params_converter = getattr(converter, GET_OP_PARAMS) | |||
| weights_converter = getattr(converter, GET_OP_WEIGHTS) | |||
| settings_converter = getattr(converter, GET_OP_SETTINGS) | |||
| template_generator = getattr(converter, GET_OP_TEMPLATE) | |||
| except (ModuleNotFoundError,) as e: | |||
| # If mapper can not be found, then skip it. | |||
| err_msg = f"Converting {op_name} failed, see {str(e)}" | |||
| log.error(err_msg) | |||
| return None, dict(), None, dict() | |||
| return None, None, None, None | |||
| try: | |||
| converter_name = op_name_converter(params=params, weights=weights, op_name=op_name) | |||
| converted_params = params_converter(params=params, weights=weights) | |||
| converted_weights = weights_converter(weights=weights) if weights else dict() | |||
| converted_params.update(converted_weights) | |||
| converted_settings = settings_converter(params=params, weights=weights) | |||
| if "input_shape" in converted_params: | |||
| converted_params.pop("input_shape") | |||
| if "output_shape" in converted_params: | |||
| converted_params.pop("output_shape") | |||
| # set to converted_weights to enable weight migration | |||
| _ = weights_converter(weights=weights) if weights else dict() | |||
| code_template, exchange_msg, outputs_list, outputs_mapping = template_generator( | |||
| operation=converter_name, | |||
| converted_params=converted_params, | |||
| raw_params=params, | |||
| weights=weights | |||
| ) | |||
| except (AttributeError, KeyError, ValueError, TypeError, IndexError) as e: | |||
| err_msg = f"Converting {op_name} failed, see {str(e)}" | |||
| log.error(err_msg) | |||
| return None, dict(), None, dict() | |||
| code_template, exchange_msg, outputs_list, outputs_mapping = template_generator( | |||
| operation=op_name, | |||
| params=params, | |||
| weights=weights | |||
| ) | |||
| return converter_name, converted_params, converted_settings, converted_weights | |||
| return code_template, exchange_msg, outputs_list, outputs_mapping | |||
| @staticmethod | |||
| def _operation_name_in_ms(*args, **kwargs): | |||
| @@ -142,7 +155,7 @@ class ONNXToMindSporeMapper(Mapper, abc.ABC): | |||
| @staticmethod | |||
| def _generate_snippet_template(**kwargs): | |||
| op = kwargs.get("operation") | |||
| args = kwargs.get("converted_params") | |||
| args = kwargs.get("converted_params", dict()) | |||
| weights = kwargs.get("weights") | |||
| if not op: | |||
| raise ValueError("Can not get MindSpore operation name.") | |||
| @@ -464,6 +464,13 @@ class OnnxDataLoader: | |||
| else: | |||
| self._global_context.onnx_node_inputs[node.name] = [input_node_name] | |||
| def _parse_graph(self): | |||
| """Parse ONNX Graph Info For usage in generator.""" | |||
| graph_inputs = [inp.name for inp in self.graph.input] | |||
| graph_outputs = [out.name for out in self.graph.output] | |||
| self._global_context.onnx_graph_info['graph_inputs'] = graph_inputs | |||
| self._global_context.onnx_graph_info['graph_outputs'] = graph_outputs | |||
| def initialize(self): | |||
| """Initialize the OnnxDataLoader.""" | |||
| @@ -473,6 +480,9 @@ class OnnxDataLoader: | |||
| log.error(str(err)) | |||
| log.exception(err) | |||
| # Parse ONNX Graph level info | |||
| self._parse_graph() | |||
| # 1. parse all nodes | |||
| self._parse_nodes() | |||
| @@ -217,8 +217,5 @@ class TestMappers: | |||
| def test_mapper(self, params): | |||
| """Test mapper function.""" | |||
| mapper = ONNXToMindSporeMapper() | |||
| converter_name, converted_params, converted_settings, _ = \ | |||
| _, _, _, _ = \ | |||
| mapper.convert(params['input']['op_name'], params['input']['params'], params['input']['weights']) | |||
| assert params['expected_output']['converter_name'] == converter_name | |||
| assert params['expected_output']['converted_params'] == converted_params | |||
| assert isinstance(converted_settings, Setting) | |||