diff --git a/mindinsight/mindconverter/graph_based_converter/common/__init__.py b/mindinsight/mindconverter/graph_based_converter/common/__init__.py new file mode 100644 index 00000000..5abd50b8 --- /dev/null +++ b/mindinsight/mindconverter/graph_based_converter/common/__init__.py @@ -0,0 +1,15 @@ +# Copyright 2020 Huawei Technologies Co., Ltd.All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Common instance and utils of graph based converter.""" diff --git a/mindinsight/mindconverter/graph_based_converter/common/code_fragment.py b/mindinsight/mindconverter/graph_based_converter/common/code_fragment.py new file mode 100644 index 00000000..705e10ce --- /dev/null +++ b/mindinsight/mindconverter/graph_based_converter/common/code_fragment.py @@ -0,0 +1,218 @@ +# Copyright 2020 Huawei Technologies Co., Ltd.All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Define CodeLine object.""" +import abc + + +class TrainableParams: + """Trainable parameters.""" + + def __init__(self, shape, dtype, reference): + self.param_name = None + self.shape = shape + self.dtype = dtype + self.reference = reference # Weight name in global npy. + + +class CodeSetting: + """Code generation settings.""" + + def __init__(self): + self.output_vars_suffix = [] + self.operation_input_type = None # Construct input type, tensor or list. + self.operation_extra_input = dict() # `values` in original setting dict. + self.operation_extra_tensor = None # For `MatMul`, `BiasAdd` op, need a tensor + + +class Fragment(abc.ABC): + """ + Define comment attributes of code generation. + + Args: + operation (str): Operation name in MindSpore. + actual_args (dict): Actual arg values. + settings (namedTuple): Code generation setting. + """ + + def __init__(self, operation, actual_args, input_shape, output_shape, settings=None): + self._operation = operation + self._input_shape = input_shape + self._output_shape = output_shape + self._declared_variable_name = None + self._output_var_name = list() # Output variable name(could be multi-opt). + self._operation_inputs = list() # Index indices the order of input. + self._operation_extra_inputs = settings + self._code_setting = settings + self._formal_args_list = dict() + self._actual_args_list = actual_args # Key is the param_key, value is the corresponding value. + self._node_type = "" + + @property + def code_setting(self): + return self._code_setting + + @property + def node_type(self): + """Node type getter.""" + return self._node_type + + @node_type.setter + def node_type(self, t): + """Node type setter.""" + self._node_type = t + + @property + def operation_extra_inputs(self): + """Getter of extra operation inputs.""" + return self._operation_extra_inputs + + @property + def declared_var_name(self): + """Declared variable name getter.""" + return self._declared_variable_name + + @declared_var_name.setter + def declared_var_name(self, var): + """Setter of declared variable name.""" + self._declared_variable_name = var + + @property + def output_var_name(self) -> str: + """Getter of output variable name.""" + return ", ".join(self._output_var_name) + + @output_var_name.setter + def output_var_name(self, opt_vars): + """ + Output variable name setter. + + Args: + opt_vars (list[str]): Output variable name. + """ + self._output_var_name = opt_vars + + @property + def operation_inputs(self): + """ + Operation getter. + + Returns: + list[Fragment], list of inputs. + """ + return self._operation_inputs + + def update_operation_inputs(self, ipt): + """ + Update operation inputs. + + Args: + ipt (Fragment): Where input comes from. + """ + self._operation_inputs.append(ipt) + + @property + def operation(self): + """ + Operation getter. + + Returns: + str, operation name to be initialized. + """ + return self._operation + + @operation.setter + def operation(self, op: str): + """ + Operation setter. + + Args: + op (str): Operation name. + """ + self._operation = op + + @property + def actual_args(self) -> dict: + """Getter of actual args.""" + return self._actual_args_list + + @property + def formal_args(self) -> dict: + """Get formal args.""" + return self._formal_args_list + + def update_formal_args(self, formal_args: dict): + """ + Update formal args. + + Args: + formal_args (dict): To be updated args. + """ + return self._formal_args_list.update(formal_args) + + @property + def input_shape(self): + return self._input_shape + + @property + def output_shape(self): + return self._output_shape + + +class CodeFragment(Fragment): + """ + Manage the variables related with code generation. + + For single operation type node, the variables in `CodeLine` stands for: + ```python + class Module(nn.Cell): + def __init__ (self, ...): + super(Module, self).__init__() + self. = (, + ) + self. = Tensor(, + dtype=) + + def construct(self, x, ...): + = self.() + ... + return output + ``` + + Args: + operation (str): Operation name in MindSpore. + actual_args (dict): Actual arg values. + settings (namedTuple): Code generation setting. + """ + + def __init__(self, operation, actual_args, settings, input_shape, output_shape, + trainable_params=None): + super(CodeFragment, self).__init__(operation=operation, actual_args=actual_args, + input_shape=input_shape, output_shape=output_shape, + settings=settings) + self._trainable_params = dict() # External weights, like Matmul. + self._init_trainable_params = trainable_params # Can put into operation init method, like Conv2d. + + @property + def trainable_params(self): + return self._trainable_params + + +class ModuleFragment(Fragment): + """Manage module type code variables.""" + + def __init__(self, operation, actual_args, settings, input_shape, output_shape): + super(ModuleFragment, self).__init__(operation=operation, actual_args=actual_args, + input_shape=input_shape, output_shape=output_shape, + settings=settings) diff --git a/mindinsight/mindconverter/graph_based_converter/common/utils.py b/mindinsight/mindconverter/graph_based_converter/common/utils.py new file mode 100644 index 00000000..5fcbd82e --- /dev/null +++ b/mindinsight/mindconverter/graph_based_converter/common/utils.py @@ -0,0 +1,29 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""Define common utils.""" +from mindinsight.mindconverter.graph_based_converter.constant import SEPARATOR_IN_ONNX_OP + + +def is_converted(operation: str): + """ + Whether convert successful. + + Args: + operation (str): Operation name. + + Returns: + bool, true or false. + """ + return operation and SEPARATOR_IN_ONNX_OP not in operation diff --git a/mindinsight/mindconverter/graph_based_converter/hierarchical_tree/__init__.py b/mindinsight/mindconverter/graph_based_converter/hierarchical_tree/__init__.py index 44613454..e06e781b 100644 --- a/mindinsight/mindconverter/graph_based_converter/hierarchical_tree/__init__.py +++ b/mindinsight/mindconverter/graph_based_converter/hierarchical_tree/__init__.py @@ -14,6 +14,7 @@ # ============================================================================== """Hierarchical tree module.""" import re + from mindinsight.mindconverter.common.log import logger as log from .hierarchical_tree import HierarchicalTree from ..third_party_graph.onnx_graph_node import OnnxGraphNode @@ -36,7 +37,6 @@ def _tf_model_node_name_reformat(node: OnnxGraphNode, node_name): """ scope_name = node.scope_name new_name = None - parent = "" regex = r"(?P.+/)(?P\w+)" match = re.match(regex, scope_name) parent = match.group("parent") @@ -74,12 +74,13 @@ class HierarchicalTreeFactory: f"Cannot find {node_name}'s input shape." log.error(err_msg) if isinstance(node_inst, OnnxGraphNode): - node_name_with_scope = _tf_model_node_name_reformat( - node_inst, node_name) + node_name_with_scope = _tf_model_node_name_reformat(node_inst, node_name) node_scope_name[node_name] = node_name_with_scope node_name = node_name_with_scope - tree.insert(node_inst, node_name, node_input, node_output) + node_inst.add_input_and_output_shape(node_input, node_output) + tree.insert(node_inst, node_name) + if node_scope_name: return tree, node_scope_name return tree diff --git a/mindinsight/mindconverter/graph_based_converter/hierarchical_tree/hierarchical_tree.py b/mindinsight/mindconverter/graph_based_converter/hierarchical_tree/hierarchical_tree.py index e67fc33a..a4802ec9 100644 --- a/mindinsight/mindconverter/graph_based_converter/hierarchical_tree/hierarchical_tree.py +++ b/mindinsight/mindconverter/graph_based_converter/hierarchical_tree/hierarchical_tree.py @@ -25,17 +25,18 @@ from treelib import Tree, Node from mindinsight.mindconverter.common.log import logger as log from .name_mgr import ModuleNameMgr, GlobalVarNameMgr +from ..common.utils import is_converted from ..mapper.base import Mapper from ..third_party_graph.pytorch_graph_node import PyTorchGraphNode from ..third_party_graph.onnx_graph_node import OnnxGraphNode -from ..constant import SEPARATOR_IN_SCOPE, SEPARATOR_BTW_NAME_AND_ID, FIRST_LEVEL_INDENT, CodeFormatConfig +from ..constant import SEPARATOR_IN_SCOPE +from ..constant import CodeFormatConfig +from ..constant import SEPARATOR_BTW_NAME_AND_ID, FIRST_LEVEL_INDENT from ..constant import NEW_LINE, SECOND_LEVEL_INDENT from ..constant import NodeType from ..report_generator import ReportGenerator from ...common.exceptions import NodeTypeNotSupport -GLOBAL_VAR_NAME_MGR = GlobalVarNameMgr() - class HierarchicalTree(Tree): """Define hierarchical tree.""" @@ -46,6 +47,8 @@ class HierarchicalTree(Tree): _root_created = False ROOT_LEVEL = 0 + GLOBAL_VAR_NAME_MGR = GlobalVarNameMgr() + def __init__(self): super(HierarchicalTree, self).__init__() self._hierarchical_order = dict() @@ -62,6 +65,7 @@ class HierarchicalTree(Tree): self._module_vars = dict() # scope name mapping record for easy node searching self._scope_name_map = dict() + self.code_fragment_recorder = dict() @property def tree_identifier(self): @@ -82,19 +86,15 @@ class HierarchicalTree(Tree): return None return self._nodes[nid] - def insert(self, node: Union[PyTorchGraphNode, OnnxGraphNode], - node_name: str, input_shape, output_shape): + def insert(self, node: Union[PyTorchGraphNode, OnnxGraphNode], node_name: str): """ Insert node into hierarchical tree. Args: node_name (str): Node name. node (Union[PyTorchGraphNode, OnnxGraphNode]): Node to be inserted. - output_shape (tuple): Output tensor shape. - input_shape (tuple): Input tensor shape. """ - node.add_input_and_output_shape(input_shape, output_shape) scopes = node_name.split(SEPARATOR_IN_SCOPE) for idx, scope in enumerate(scopes): parent = SEPARATOR_IN_SCOPE.join(scopes[:idx]) @@ -125,10 +125,9 @@ class HierarchicalTree(Tree): tgt_node.precursor_nodes = node.precursor_nodes tgt_node.node_type = (NodeType.OPERATION if idx == len(scopes) - 1 else NodeType.MODULE).value - tgt_node.tag = scope.split(SEPARATOR_BTW_NAME_AND_ID)[0] tgt_node.variable_name = self._get_var_name(identifier) self.create_node( - tag=tgt_node.tag, + tag=scope.split(SEPARATOR_BTW_NAME_AND_ID)[0], identifier=identifier, parent=parent, data=tgt_node @@ -276,8 +275,7 @@ class HierarchicalTree(Tree): node.data.replace_with_arg(arg, arg) return node - @staticmethod - def _clear_unused_args(node, used_args): + def _clear_unused_args(self, node, used_args): """ Clear unused args. @@ -290,7 +288,9 @@ class HierarchicalTree(Tree): """ args_in_code = list(node.data.args_in_code.keys()) for arg in args_in_code: - ori_arg = arg.replace(f"_{node.data.variable_name}", "") + ori_arg = arg.replace( + f"_{self.code_fragment_recorder[node.identifier].declared_var_name}", "" + ) if ori_arg not in used_args: node.data.args_in_code.pop(arg) return node @@ -323,6 +323,8 @@ class HierarchicalTree(Tree): # 1. Generate args for each node in this level. if node.data.node_type == NodeType.MODULE.value: self._create_module_args_and_vars(node, mapper) + if depth == depths[-1]: + self.code_fragment_recorder[node.identifier] = node.data.param_transform(mapper, "") # Module merging based on all nodes. self._module_merging() @@ -345,30 +347,29 @@ class HierarchicalTree(Tree): # then assign the created module name to current node, # and delete unused args. module_name = self._created_module[module_key] - nd_inst.data.froze_node_type_and_module_name(node_type, - module_name) + self.code_fragment_recorder[nd_inst.identifier].operation = module_name + self.code_fragment_recorder[nd_inst.identifier].node_type = node_type self._preprocess_node_args(nd_inst, module_key) continue - module_name = nd_inst.data.module_name + module_name = nd_inst.tag + if node_type == NodeType.CLASS.value: module_name = f"{module_name[0].upper()}{module_name[1:]}" # After node_type and module_name is frozen, # then it's unchangeable. module_name = self._module_mgr.get_name(module_name) - nd_inst.data.froze_node_type_and_module_name(node_type, - module_name) + self.code_fragment_recorder[nd_inst.identifier].operation = module_name + self.code_fragment_recorder[nd_inst.identifier].node_type = node_type # 3. Pre-process node args. nd_inst = self._preprocess_node_args(nd_inst, module_key) # 4. Post-process child node args. for _, scsr_nd_name in enumerate(nd_inst.successors(self.tree_identifier)): - self._postprocess_node_args( - self.get_node(scsr_nd_name), module_key) + self._postprocess_node_args(self.get_node(scsr_nd_name), module_key) # 5. Generate code. - snippets.add( - func(nd_inst, nd_inst.data.module_name, module_key)) + snippets.add(func(nd_inst, self.code_fragment_recorder[nd_inst.identifier].operation, module_key)) code_blocks.extend(snippets) @@ -437,7 +438,7 @@ class HierarchicalTree(Tree): module_list = [] for node_name in node.successors(self.tree_identifier): c_nd = self.get_node(node_name) - operator = c_nd.data.op_in_ms or c_nd.data.module_name + operator = self.code_fragment_recorder[c_nd.identifier].operation if c_nd.data.node_type != NodeType.OPERATION.value: hash_key = c_nd.data.hash_key or self.hash_key(c_nd) @@ -445,14 +446,16 @@ class HierarchicalTree(Tree): operator = self._created_module[hash_key] args = c_nd.data.args_in_code - if c_nd.data.node_type == NodeType.OPERATION.value and \ - not c_nd.data.convert_successful(): + if c_nd.data.node_type == NodeType.OPERATION.value and not is_converted( + self.code_fragment_recorder[c_nd.identifier].operation): args.update({"input_shape": c_nd.data.input_shape, "output_shape": c_nd.data.output_shape}) # Generate code statement. - expr = ", ".join([f"{k.replace(f'_{c_nd.data.variable_name}', '')}={v}" - for k, v in args.items()]) + expr = ", ".join( + [f"{k.replace(f'_{self.code_fragment_recorder[c_nd.identifier].declared_var_name}', '')}={v}" + for k, v in args.items()] + ) code_line = f"{operator}({expr})" module_list.append(code_line) @@ -547,14 +550,16 @@ class HierarchicalTree(Tree): if idx != 0: # Get previous node output variable name. - ipt_args_in_construct = self._get_previous_opt_var( - cur_nd_inst, pre_nd_inst) + ipt_args_in_construct = self._get_previous_opt_var(cur_nd_inst, pre_nd_inst) if idx != len(pre_nd_inst.successors(self.tree_identifier)) - 1: # Set opt variable name. - opt_arg_in_construct = cur_nd_inst.data.opt_var_name + opt_arg_in_construct = f"{self.code_fragment_recorder[cur_nd_inst.identifier].declared_var_name}_opt" declare, call = cur_nd_inst.data.to_code(ipt_args_in_construct=ipt_args_in_construct, - output_var=opt_arg_in_construct) + variable_name=self.code_fragment_recorder[ + cur_nd_inst.identifier].declared_var_name, + output_var=opt_arg_in_construct, + code_fragment=self.code_fragment_recorder[cur_nd_inst.identifier]) return declare, call @@ -588,7 +593,9 @@ class HierarchicalTree(Tree): if e not in pre_nd.successors(self.tree_identifier): while True: if p_nd.identifier in pre_nd.successors(self.tree_identifier): - ipt_lst.append(p_nd.data.opt_var_name) + ipt_lst.append( + f"{self.code_fragment_recorder[p_nd.identifier].declared_var_name}_opt" + ) break pre_nd_name = p_nd.predecessor(self.tree_identifier) if not pre_nd_name: @@ -597,7 +604,9 @@ class HierarchicalTree(Tree): p_nd = self.get_node(pre_nd_name) continue - ipt_lst.append(p_nd.data.opt_var_name) + ipt_lst.append( + f"{self.code_fragment_recorder[p_nd.identifier].declared_var_name}_opt" + ) return ipt_lst def _get_previous_opt_var(self, cur_nd, pre_nd): @@ -619,12 +628,11 @@ class HierarchicalTree(Tree): cur_nd = self.get_node(p_nd[0]) return ", ".join(self._find_all_previous_opt_var_(cur_nd, pre_nd)) - def hash_key(self, node, depth: int = 0): + def hash_key(self, node): """ Generate hash key for each node. Args: - depth (int): Recursion depth. node (Node): Node. Returns: @@ -633,13 +641,17 @@ class HierarchicalTree(Tree): scsr_topo_order = [] for s in node.successors(self.tree_identifier): cur_nd = self.get_node(s) - if cur_nd.data.hash_key: - scsr_topo_order.append(f"{cur_nd.data.hash_key}[{depth}]") - continue if cur_nd.data.node_type in {NodeType.MODULE.value, NodeType.FUNC.value, NodeType.CLASS.value}: - scsr_topo_order.append(self.hash_key(cur_nd, depth + 1)) + if cur_nd.data.hash_key: + scsr_topo_order.append(f"({cur_nd.data.hash_key})") + continue + + raise ValueError("Current node doesn't have hash key.") + + if cur_nd.data.hash_key: + scsr_topo_order.append(cur_nd.data.hash_key) continue unique_key = "->".join(scsr_topo_order) node.data.hash_key = unique_key @@ -675,12 +687,11 @@ class HierarchicalTree(Tree): """ # All args and value pair in current node module. module_args = dict() - module_settings = dict() module_key = self.hash_key(node) created = False if module_key not in self._vars_mgr_in_module: - self._vars_mgr_in_module[module_key] = GLOBAL_VAR_NAME_MGR + self._vars_mgr_in_module[module_key] = self.GLOBAL_VAR_NAME_MGR self._module_vars[module_key] = [] else: created = True @@ -688,33 +699,29 @@ class HierarchicalTree(Tree): # Sub-modules in the module could have arg name conflicts. for idx, successor_name in enumerate(node.successors(self.tree_identifier)): nd_inst = self.get_node(successor_name) - # Generate variable name here, then - # to generate args. + # Generation of params must behind variable assigment. if created: - nd_inst.data.variable_name = self._module_vars[module_key][idx] + variable_name = self._module_vars[module_key][idx] else: - variable_name = nd_inst.data.op_name or nd_inst.data.module_name - variable_name = self._vars_mgr_in_module[module_key].get_name( - variable_name) - nd_inst.data.variable_name = variable_name + variable_name = nd_inst.data.op_name or nd_inst.tag + variable_name = self._vars_mgr_in_module[module_key].get_name(variable_name) - # Generation of params must behind variable assigment. - nd_inst.data.param_transform(mapper) + code_fragment = nd_inst.data.param_transform(mapper, variable_name) + code_fragment.declared_var_name = variable_name + self.code_fragment_recorder[nd_inst.identifier] = code_fragment module_args.update(nd_inst.data.args_in_code) - module_settings.update(nd_inst.data.settings_in_code) if not created: - self._module_vars[module_key].append( - nd_inst.data.variable_name) + self._module_vars[module_key].append(variable_name) node.data.args_in_code = module_args # Collect module args of `module_key`. if module_key not in self._merged_module: - self._merged_module[module_key] = [node.data.args_in_code] + self._merged_module[module_key] = [deepcopy(node.data.args_in_code)] else: - self._merged_module[module_key].append(node.data.args_in_code) + self._merged_module[module_key].append(deepcopy(node.data.args_in_code)) @staticmethod def _create_operation_args(node, mapper): diff --git a/mindinsight/mindconverter/graph_based_converter/hierarchical_tree/name_mgr.py b/mindinsight/mindconverter/graph_based_converter/hierarchical_tree/name_mgr.py index 62d8aab4..2811f762 100644 --- a/mindinsight/mindconverter/graph_based_converter/hierarchical_tree/name_mgr.py +++ b/mindinsight/mindconverter/graph_based_converter/hierarchical_tree/name_mgr.py @@ -63,6 +63,10 @@ START_IDX = 0 class GlobalVarNameMgr: """Global variable name mgr.""" + def __init__(self): + global_op_namespace.clear() + global_var_namespace.clear() + @staticmethod def _get_name(name): """Deal with op name.""" diff --git a/mindinsight/mindconverter/graph_based_converter/mapper/base.py b/mindinsight/mindconverter/graph_based_converter/mapper/base.py index c1dba8bc..e0871852 100644 --- a/mindinsight/mindconverter/graph_based_converter/mapper/base.py +++ b/mindinsight/mindconverter/graph_based_converter/mapper/base.py @@ -87,7 +87,7 @@ class ONNXToMindSporeMapper(Mapper, abc.ABC): module_name = TABLE.get(op_name) if not module_name: - return None, dict(), dict() + return None, dict(), None, dict() pos = module_name.rfind(".") try: @@ -101,7 +101,7 @@ class ONNXToMindSporeMapper(Mapper, abc.ABC): # If mapper can not be found, then skip it. err_msg = f"Converting {op_name} failed, see {str(e)}" log.error(err_msg) - return None, dict(), dict() + return None, dict(), None, dict() try: converter_name = op_name_converter( @@ -110,13 +110,13 @@ class ONNXToMindSporeMapper(Mapper, abc.ABC): converted_weights = weights_converter( weights=weights) if weights else dict() converted_params.update(converted_weights) - converted_settings = settings_converter(params=params) + converted_settings = settings_converter(params=params, weights=weights) except (AttributeError, KeyError, ValueError, TypeError, IndexError) as e: err_msg = f"Converting {op_name} failed, see {str(e)}" log.error(err_msg) - return None, dict(), dict() + return None, dict(), None, dict() - return converter_name, converted_params, converted_settings + return converter_name, converted_params, converted_settings, converted_weights @staticmethod def _operation_name_in_ms(*args, **kwargs): diff --git a/mindinsight/mindconverter/graph_based_converter/mapper/gen_setting.py b/mindinsight/mindconverter/graph_based_converter/mapper/gen_setting.py new file mode 100644 index 00000000..08a5f98c --- /dev/null +++ b/mindinsight/mindconverter/graph_based_converter/mapper/gen_setting.py @@ -0,0 +1,34 @@ +# Copyright 2020 Huawei Technologies Co., Ltd.All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Operation mapping setting.""" +from collections import namedtuple +import numpy as np + +from mindinsight.mindconverter.graph_based_converter.constant import InputType + +Tensor = namedtuple("Tensor", ["shape", "dtype", "reference"]) + +Setting = namedtuple("Setting", ["opt_vars_suffix", + "op_ipt_type", + "op_extra_input", + "op_extra_tensor"]) +Setting.__new__.__defaults__ = ("_opt", InputType.TENSOR.value, dict(), None) + + +def get_dtype(tensor: np.ndarray): + """Get tensor dtype.""" + if tensor.dtype == np.float16: + return "mindspore.float16" + return "mindspore.float32" diff --git a/mindinsight/mindconverter/graph_based_converter/mapper/impl/nn/batch_norm_mapper.py b/mindinsight/mindconverter/graph_based_converter/mapper/impl/nn/batch_norm_mapper.py index 2f8d0e57..5d02b9b2 100644 --- a/mindinsight/mindconverter/graph_based_converter/mapper/impl/nn/batch_norm_mapper.py +++ b/mindinsight/mindconverter/graph_based_converter/mapper/impl/nn/batch_norm_mapper.py @@ -14,6 +14,7 @@ # ============================================================================== """Mapper module.""" from ...base import ONNXToMindSporeMapper +from ...gen_setting import Setting class BatchNormMapper(ONNXToMindSporeMapper): @@ -39,4 +40,4 @@ class BatchNormMapper(ONNXToMindSporeMapper): @staticmethod def _convert_settings(**kwargs): - return dict() + return Setting() diff --git a/mindinsight/mindconverter/graph_based_converter/mapper/impl/nn/conv_mapper.py b/mindinsight/mindconverter/graph_based_converter/mapper/impl/nn/conv_mapper.py index 6513a7c8..9a9a2200 100644 --- a/mindinsight/mindconverter/graph_based_converter/mapper/impl/nn/conv_mapper.py +++ b/mindinsight/mindconverter/graph_based_converter/mapper/impl/nn/conv_mapper.py @@ -16,6 +16,7 @@ import re import numpy as np from ...base import ONNXToMindSporeMapper +from ...gen_setting import Setting def _convert_padding(**kwargs): @@ -35,6 +36,7 @@ def _convert_padding(**kwargs): class ConvMapper(ONNXToMindSporeMapper): """Conv2d mapper.""" + @staticmethod def convert_params_torch(**kwargs): """Convert params from PyTorch to MindSpore""" @@ -148,4 +150,4 @@ class ConvMapper(ONNXToMindSporeMapper): @staticmethod def _convert_settings(**kwargs): - return dict() + return Setting() diff --git a/mindinsight/mindconverter/graph_based_converter/mapper/impl/nn/dense_mapper.py b/mindinsight/mindconverter/graph_based_converter/mapper/impl/nn/dense_mapper.py index 18f68716..2f4eb387 100644 --- a/mindinsight/mindconverter/graph_based_converter/mapper/impl/nn/dense_mapper.py +++ b/mindinsight/mindconverter/graph_based_converter/mapper/impl/nn/dense_mapper.py @@ -14,6 +14,7 @@ # ============================================================================== """Mapper module.""" from ...base import ONNXToMindSporeMapper +from ...gen_setting import Setting class DenseMapper(ONNXToMindSporeMapper): @@ -41,4 +42,4 @@ class DenseMapper(ONNXToMindSporeMapper): @staticmethod def _convert_settings(**kwargs): - return dict() + return Setting() diff --git a/mindinsight/mindconverter/graph_based_converter/mapper/impl/nn/flatten_mapper.py b/mindinsight/mindconverter/graph_based_converter/mapper/impl/nn/flatten_mapper.py index 679ea812..024cf499 100644 --- a/mindinsight/mindconverter/graph_based_converter/mapper/impl/nn/flatten_mapper.py +++ b/mindinsight/mindconverter/graph_based_converter/mapper/impl/nn/flatten_mapper.py @@ -14,6 +14,7 @@ # ============================================================================== """Mapper module.""" from ...base import ONNXToMindSporeMapper +from ...gen_setting import Setting class FlattenMapper(ONNXToMindSporeMapper): @@ -33,4 +34,4 @@ class FlattenMapper(ONNXToMindSporeMapper): @staticmethod def _convert_settings(**kwargs): - return dict() + return Setting() diff --git a/mindinsight/mindconverter/graph_based_converter/mapper/impl/nn/global_pool_mapper.py b/mindinsight/mindconverter/graph_based_converter/mapper/impl/nn/global_pool_mapper.py index c29a971d..29bc8550 100644 --- a/mindinsight/mindconverter/graph_based_converter/mapper/impl/nn/global_pool_mapper.py +++ b/mindinsight/mindconverter/graph_based_converter/mapper/impl/nn/global_pool_mapper.py @@ -14,6 +14,7 @@ # ============================================================================== """Mapper module.""" from ...base import ONNXToMindSporeMapper +from ...gen_setting import Setting class GlobalPoolMapper(ONNXToMindSporeMapper): @@ -25,8 +26,7 @@ class GlobalPoolMapper(ONNXToMindSporeMapper): op_name = 'nn.AvgPool{}d' else: op_name = 'nn.MaxPool{}d' - dim = 1 if len(kwargs['params']['input_shape']) == 3\ - else 2 + dim = 1 if len(kwargs['params']['input_shape']) == 3 else 2 return op_name.format(dim) @staticmethod @@ -49,4 +49,4 @@ class GlobalPoolMapper(ONNXToMindSporeMapper): @staticmethod def _convert_settings(**kwargs): - return dict() + return Setting() diff --git a/mindinsight/mindconverter/graph_based_converter/mapper/impl/nn/mat_mul_mapper.py b/mindinsight/mindconverter/graph_based_converter/mapper/impl/nn/mat_mul_mapper.py index 2d2782cf..603a4fd8 100644 --- a/mindinsight/mindconverter/graph_based_converter/mapper/impl/nn/mat_mul_mapper.py +++ b/mindinsight/mindconverter/graph_based_converter/mapper/impl/nn/mat_mul_mapper.py @@ -14,6 +14,7 @@ # ============================================================================== """Mapper module.""" from ...base import ONNXToMindSporeMapper +from ...gen_setting import Setting, Tensor, get_dtype class MatMulMapper(ONNXToMindSporeMapper): @@ -33,4 +34,12 @@ class MatMulMapper(ONNXToMindSporeMapper): @staticmethod def _convert_settings(**kwargs): - return dict() + weights = kwargs.get("weights") + if not weights: + return Setting() + tensor, ref = None, "" + for t_name, t_value in weights.items(): + tensor = t_value + ref = t_name + return Setting(op_extra_tensor=Tensor(shape=tensor.shape, + dtype=get_dtype(tensor), reference=ref)) diff --git a/mindinsight/mindconverter/graph_based_converter/mapper/impl/nn/pad_mapper.py b/mindinsight/mindconverter/graph_based_converter/mapper/impl/nn/pad_mapper.py index e0ff4225..4c3e0715 100644 --- a/mindinsight/mindconverter/graph_based_converter/mapper/impl/nn/pad_mapper.py +++ b/mindinsight/mindconverter/graph_based_converter/mapper/impl/nn/pad_mapper.py @@ -14,6 +14,7 @@ # ============================================================================== """Mapper module.""" from ...base import ONNXToMindSporeMapper +from ...gen_setting import Setting def _padding_format_convert(padding: list): @@ -77,4 +78,4 @@ class PadMapper(ONNXToMindSporeMapper): @staticmethod def _convert_settings(**kwargs): - return dict() + return Setting() diff --git a/mindinsight/mindconverter/graph_based_converter/mapper/impl/nn/pool_mapper.py b/mindinsight/mindconverter/graph_based_converter/mapper/impl/nn/pool_mapper.py index 1c248a75..b33ed715 100644 --- a/mindinsight/mindconverter/graph_based_converter/mapper/impl/nn/pool_mapper.py +++ b/mindinsight/mindconverter/graph_based_converter/mapper/impl/nn/pool_mapper.py @@ -14,6 +14,7 @@ # ============================================================================== """Mapper module.""" from ...base import ONNXToMindSporeMapper +from ...gen_setting import Setting class PoolMapper(ONNXToMindSporeMapper): @@ -49,4 +50,4 @@ class PoolMapper(ONNXToMindSporeMapper): @staticmethod def _convert_settings(**kwargs): - return dict() + return Setting() diff --git a/mindinsight/mindconverter/graph_based_converter/mapper/impl/nn/relu_mapper.py b/mindinsight/mindconverter/graph_based_converter/mapper/impl/nn/relu_mapper.py index b5a24717..e89052ca 100644 --- a/mindinsight/mindconverter/graph_based_converter/mapper/impl/nn/relu_mapper.py +++ b/mindinsight/mindconverter/graph_based_converter/mapper/impl/nn/relu_mapper.py @@ -14,6 +14,7 @@ # ============================================================================== """Mapper module.""" from ...base import ONNXToMindSporeMapper +from ...gen_setting import Setting class ReLUMapper(ONNXToMindSporeMapper): @@ -45,4 +46,4 @@ class ReLUMapper(ONNXToMindSporeMapper): @staticmethod def _convert_settings(**kwargs): - return dict() + return Setting() diff --git a/mindinsight/mindconverter/graph_based_converter/mapper/impl/nn/softmax_mapper.py b/mindinsight/mindconverter/graph_based_converter/mapper/impl/nn/softmax_mapper.py index fbbe781a..be029109 100644 --- a/mindinsight/mindconverter/graph_based_converter/mapper/impl/nn/softmax_mapper.py +++ b/mindinsight/mindconverter/graph_based_converter/mapper/impl/nn/softmax_mapper.py @@ -14,6 +14,7 @@ # ============================================================================== """Mapper module.""" from ...base import ONNXToMindSporeMapper +from ...gen_setting import Setting class SoftmaxMapper(ONNXToMindSporeMapper): @@ -37,4 +38,4 @@ class SoftmaxMapper(ONNXToMindSporeMapper): @staticmethod def _convert_settings(**kwargs): - return dict() + return Setting() diff --git a/mindinsight/mindconverter/graph_based_converter/mapper/impl/ops/add_mapper.py b/mindinsight/mindconverter/graph_based_converter/mapper/impl/ops/add_mapper.py index 7b6dea75..83808984 100644 --- a/mindinsight/mindconverter/graph_based_converter/mapper/impl/ops/add_mapper.py +++ b/mindinsight/mindconverter/graph_based_converter/mapper/impl/ops/add_mapper.py @@ -14,6 +14,7 @@ # ============================================================================== """Mapper module.""" from ...base import ONNXToMindSporeMapper +from ...gen_setting import Setting, Tensor, get_dtype class AddMapper(ONNXToMindSporeMapper): @@ -33,4 +34,12 @@ class AddMapper(ONNXToMindSporeMapper): @staticmethod def _convert_settings(**kwargs): - return dict() + weights = kwargs.get("weights") + if not weights: + return Setting() + tensor, ref = None, "" + for t_name, t_value in weights.items(): + tensor = t_value + ref = t_name + return Setting(op_extra_tensor=Tensor(shape=tensor.shape, + dtype=get_dtype(tensor), reference=ref)) diff --git a/mindinsight/mindconverter/graph_based_converter/mapper/impl/ops/concat_mapper.py b/mindinsight/mindconverter/graph_based_converter/mapper/impl/ops/concat_mapper.py index b0a32a9e..eb1205f9 100644 --- a/mindinsight/mindconverter/graph_based_converter/mapper/impl/ops/concat_mapper.py +++ b/mindinsight/mindconverter/graph_based_converter/mapper/impl/ops/concat_mapper.py @@ -15,6 +15,7 @@ """Mapper module.""" from mindinsight.mindconverter.graph_based_converter.constant import InputType from ...base import ONNXToMindSporeMapper +from ...gen_setting import Setting class ConcatMapper(ONNXToMindSporeMapper): @@ -36,4 +37,4 @@ class ConcatMapper(ONNXToMindSporeMapper): @staticmethod def _convert_settings(**kwargs): input_type = InputType.LIST.value - return {'input_type': input_type} + return Setting(op_ipt_type=input_type) diff --git a/mindinsight/mindconverter/graph_based_converter/mapper/impl/ops/reduce_mean_mapper.py b/mindinsight/mindconverter/graph_based_converter/mapper/impl/ops/reduce_mean_mapper.py index 68623457..239d07a6 100644 --- a/mindinsight/mindconverter/graph_based_converter/mapper/impl/ops/reduce_mean_mapper.py +++ b/mindinsight/mindconverter/graph_based_converter/mapper/impl/ops/reduce_mean_mapper.py @@ -14,6 +14,7 @@ # ============================================================================== """Mapper module.""" from ...base import ONNXToMindSporeMapper +from ...gen_setting import Setting class ReduceMeanMapper(ONNXToMindSporeMapper): @@ -40,4 +41,4 @@ class ReduceMeanMapper(ONNXToMindSporeMapper): axis = params['axes'][0] if len(params['axes']) == 1 else tuple(params['axes']) else: axis = tuple() - return {'values': {'axis': axis}} + return Setting(op_extra_input={'axis': axis}) diff --git a/mindinsight/mindconverter/graph_based_converter/mapper/impl/ops/transpose_mapper.py b/mindinsight/mindconverter/graph_based_converter/mapper/impl/ops/transpose_mapper.py index cb51153c..d294d9d1 100644 --- a/mindinsight/mindconverter/graph_based_converter/mapper/impl/ops/transpose_mapper.py +++ b/mindinsight/mindconverter/graph_based_converter/mapper/impl/ops/transpose_mapper.py @@ -14,6 +14,7 @@ # ============================================================================== """Mapper module.""" from ...base import ONNXToMindSporeMapper +from ...gen_setting import Setting class TransposeMapper(ONNXToMindSporeMapper): @@ -40,4 +41,4 @@ class TransposeMapper(ONNXToMindSporeMapper): perm = tuple(perm) converted_params['input_perm'] = perm - return {'values': converted_params} + return Setting(op_extra_input=converted_params) diff --git a/mindinsight/mindconverter/graph_based_converter/third_party_graph/base.py b/mindinsight/mindconverter/graph_based_converter/third_party_graph/base.py index 811f1f8a..34e63569 100644 --- a/mindinsight/mindconverter/graph_based_converter/third_party_graph/base.py +++ b/mindinsight/mindconverter/graph_based_converter/third_party_graph/base.py @@ -15,10 +15,13 @@ """Define graph entity.""" import abc from collections import OrderedDict +from copy import deepcopy from mindinsight.mindconverter.common.log import logger as log -from ..constant import SEPARATOR_IN_ONNX_OP +from ..common.code_fragment import CodeFragment +from ..constant import NodeType, InputType from ..mapper.base import Mapper +from ...common.exceptions import NodeInputTypeNotSupport class GraphParser(metaclass=abc.ABCMeta): @@ -287,26 +290,10 @@ class GraphNode(abc.ABC): self._op_params = dict() self._scope_name = None self._op_shape = None - # Operation in mindspore. - self._op_in_ms = None - # Params in mindspore. - self._params_in_ms = dict() - # Settings in mindspore. - self._settings_in_ms = dict() # Node type of current node, e.g. class, module, operation. self._node_type = None - # Tag name on tree. - self._tag_on_tree = None # Function, class or operation needed args. self._args_in_code = dict() - # Operation needed settings. - self._settings_in_code = dict() - # Variable name declared in init block. - self._variable_name = None - # Output variable name declared in construct block. - self._opt_var_name = None - # Function or class name in code. - self._module_name = None # Unique key of node. self._hash_key = None # Input shape of current op. @@ -317,37 +304,18 @@ class GraphNode(abc.ABC): self._weight = None @property - def opt_var_name(self): + def weight(self): + return self._weight + + @staticmethod + def get_opt_var_name(variable_name): """ Output variable name. Returns: str, variable name. """ - return f"{self.variable_name}_opt" - - @opt_var_name.setter - def opt_var_name(self, v): - """ - Set variable name. - - Args: - v (str): Name. - - """ - self._opt_var_name = v - - @property - def op_in_ms(self): - """ - Operation in mindspore. - - Returns: - str, operation name. - """ - if self._op_in_ms and SEPARATOR_IN_ONNX_OP in self._op_in_ms: - return self._op_in_ms.replace(SEPARATOR_IN_ONNX_OP, ".") - return self._op_in_ms + return f"{variable_name}_opt" @property def args_in_code(self): @@ -370,27 +338,6 @@ class GraphNode(abc.ABC): """ self._args_in_code = args - @property - def settings_in_code(self): - """ - Settings in code. - - Returns: - dict, settings. - """ - return self._settings_in_code - - @settings_in_code.setter - def settings_in_code(self, settings): - """ - Settings in code. - - Args: - settings(dict): Settings. - - """ - self._settings_in_code = settings - @property def input_shape(self): """ @@ -411,16 +358,6 @@ class GraphNode(abc.ABC): """ return self._opt_shape - @property - def tag(self): - """Tag on hierarchical tree.""" - return self._tag_on_tree - - @tag.setter - def tag(self, t): - """Tag on hierarchical tree.""" - self._tag_on_tree = t - def is_empty(self): """ Whether is empty. @@ -536,7 +473,7 @@ class GraphNode(abc.ABC): """Replace actual parameter with formal parameter.""" @abc.abstractmethod - def _get_arg_name(self, arg): + def _get_arg_name(self, arg, variable_name): """Get arg name for func or class.""" @abc.abstractmethod @@ -553,13 +490,8 @@ class GraphNode(abc.ABC): def real_name(self, **kwargs): """Setter of `real_name`.""" - @property - @abc.abstractmethod - def variable_name(self): - """Getter of `variable_name`.""" - @abc.abstractmethod - def to_code(self, ipt_args_in_construct: str, output_var: str): + def to_code(self, ipt_args_in_construct: str, variable_name: str, output_var: str, code_fragment): """Graph node to MindSpore code.""" @abc.abstractmethod @@ -570,40 +502,86 @@ class GraphNode(abc.ABC): def add_input_and_output_shape(self, input_shape, output_shape): """Add the node input shape.""" - @abc.abstractmethod - def froze_node_type_and_module_name(self, node_type, module_name): - """Make node_type can not be changed.""" + @staticmethod + def _generate_ipt_args_settings_in_construct(ipt_args_in_construct, settings): + """ + Generate input with args and settings in construct. - @abc.abstractmethod - def convert_successful(self): - """Whether convert successful.""" + Args: + ipt_args_in_construct (str): Input args in construct. + settings (Setting): Settings in operator. + + Returns: + str, args of each node in generated construct statement. + """ + if settings and settings.op_ipt_type: + input_type = settings.op_ipt_type + if input_type == InputType.TENSOR.value: + ipt_args_settings_in_construct = ipt_args_in_construct + elif input_type == InputType.LIST.value: + ipt_args_settings_in_construct = f"({ipt_args_in_construct})" + else: + raise NodeInputTypeNotSupport(f"Input type[{input_type}] is not supported now.") + else: + ipt_args_settings_in_construct = ipt_args_in_construct + + if settings and settings.op_extra_input: + settings_value = settings.op_extra_input + if settings_value: + settings_in_construct = ', '.join([f"{setting_val}" for _, setting_val in settings_value.items()]) + ipt_args_settings_in_construct = ', '.join((ipt_args_settings_in_construct, settings_in_construct)) - def param_transform(self, mapper: Mapper): + return ipt_args_settings_in_construct + + def param_transform(self, mapper: Mapper, variable_name): """ - Transform param in pytorch operation into mindspore. + Transform param in PyTorch operation into MindSpore. Args: + variable_name (str): Variable name. mapper (ONNXToMindSporeMapper): Mapper between onnx operation - and mindspore. + and MindSpore. Returns: dict, transformed params. """ - import copy - params = copy.deepcopy(self._op_params) + if self._node_type != NodeType.OPERATION.value: + args = deepcopy(self._args_in_code) + self._args_in_code = dict() + for arg, value in args.items(): + self._args_in_code[self._get_arg_name(arg, variable_name)] = value + return CodeFragment(operation="", actual_args=args, settings=None, + input_shape=self.input_shape, output_shape=self.output_shape) + + if self.transformed: + raise ValueError("Already transformed.") + + params = deepcopy(self._op_params) params.update({"input_shape": self.input_shape, "output_shape": self.output_shape}) - op_name_in_mindspore, ms_params, ms_settings = mapper.convert(op_name=self.op_name, - params=params, - weights=self._weight) - if op_name_in_mindspore: - self._op_in_ms = op_name_in_mindspore - self._params_in_ms = ms_params - self._settings_in_ms = ms_settings + ms_op, ms_params, ms_settings, ms_weights = mapper.convert(op_name=self.op_name, + params=params, + weights=self._weight) + + if ms_op: + code_fragment = CodeFragment(operation=ms_op, + actual_args=ms_params, + settings=ms_settings, + input_shape=self.input_shape, + output_shape=self.output_shape, + trainable_params=ms_weights) else: - self._op_in_ms = self._op_name - self._params_in_ms = self._op_params - self._settings_in_ms = dict() + code_fragment = CodeFragment(operation=self._op_name, + actual_args=self._op_params, + settings=None, + input_shape=self.input_shape, + output_shape=self.output_shape, + trainable_params=self._weight) + + for arg, value in code_fragment.actual_args.items(): + self._args_in_code[self._get_arg_name(arg, variable_name)] = value + + self.transformed = True - return self._op_in_ms, self._params_in_ms, self._settings_in_ms + return code_fragment diff --git a/mindinsight/mindconverter/graph_based_converter/third_party_graph/graph_parser.py b/mindinsight/mindconverter/graph_based_converter/third_party_graph/graph_parser.py index 5a6644cf..9ec692c3 100644 --- a/mindinsight/mindconverter/graph_based_converter/third_party_graph/graph_parser.py +++ b/mindinsight/mindconverter/graph_based_converter/third_party_graph/graph_parser.py @@ -38,7 +38,6 @@ class PyTorchGraphParser(GraphParser): error = FileNotFoundError("`model_path` must be assigned with " "an existed file path.") log.error(str(error)) - log.exception(error) raise error try: diff --git a/mindinsight/mindconverter/graph_based_converter/third_party_graph/input_node.py b/mindinsight/mindconverter/graph_based_converter/third_party_graph/input_node.py index be93ad03..e92d7035 100644 --- a/mindinsight/mindconverter/graph_based_converter/third_party_graph/input_node.py +++ b/mindinsight/mindconverter/graph_based_converter/third_party_graph/input_node.py @@ -21,24 +21,18 @@ from ..constant import SEPARATOR_IN_SCOPE, NodeType class InputNode(GraphNode): """ - Pytorch Input Node. + PyTorch Input Node. Args: input_shape: Input shape of module. """ - def convert_successful(self): - """ - Whether convert successful. - - Returns: - bool, true or false. - """ - return False + def _get_arg_name(self, arg, variable_name): + raise NotImplementedError() - def froze_node_type_and_module_name(self, node_type, module_name): - pass + def to_code(self, ipt_args_in_construct: str, variable_name: str, output_var: str, code_fragment): + raise NotImplementedError() def _get_raw_params(self, node): pass @@ -56,9 +50,6 @@ class InputNode(GraphNode): def replace_with_arg(self, src_arg, tgt_arg): pass - def _get_arg_name(self, arg): - pass - def add_input_and_output_shape(self, input_shape, output_shape): pass @@ -116,15 +107,8 @@ class InputNode(GraphNode): def real_name(self): return - @property - def variable_name(self): - return - def to_ir(self): """ No need to implement for now. """ raise NotImplementedError() - - def to_code(self, ipt_args_in_construct: str, output_var: str): - raise NotImplementedError() diff --git a/mindinsight/mindconverter/graph_based_converter/third_party_graph/onnx_graph.py b/mindinsight/mindconverter/graph_based_converter/third_party_graph/onnx_graph.py index 27e52c68..f4befbc8 100644 --- a/mindinsight/mindconverter/graph_based_converter/third_party_graph/onnx_graph.py +++ b/mindinsight/mindconverter/graph_based_converter/third_party_graph/onnx_graph.py @@ -22,7 +22,6 @@ from .onnx_graph_node import OnnxGraphNode from .graph_parser import TFGraphParser from .onnx_utils import OnnxDataLoader - NONE_SCOPE_OP = { "onnx::Add": "Add", "onnx::Flatten": "Flatten", diff --git a/mindinsight/mindconverter/graph_based_converter/third_party_graph/onnx_graph_node.py b/mindinsight/mindconverter/graph_based_converter/third_party_graph/onnx_graph_node.py index 4b4383e0..62cab356 100644 --- a/mindinsight/mindconverter/graph_based_converter/third_party_graph/onnx_graph_node.py +++ b/mindinsight/mindconverter/graph_based_converter/third_party_graph/onnx_graph_node.py @@ -13,14 +13,13 @@ # limitations under the License. # ============================================================================== """Define ONNX graph node.""" +from importlib import import_module -from copy import deepcopy from .base import GraphNode +from ..common.utils import is_converted -from ..constant import NodeType, SEPARATOR_IN_SCOPE, \ - SEPARATOR_BTW_NAME_AND_ID, LEFT_BUCKET, RIGHT_BUCKET, SEPARATOR_IN_ONNX_OP, InputType -from ..mapper.base import Mapper -from ...common.exceptions import NodeInputTypeNotSupport +from ..constant import NodeType, SEPARATOR_IN_SCOPE, SEPARATOR_BTW_NAME_AND_ID, LEFT_BUCKET, RIGHT_BUCKET, \ + SEPARATOR_IN_ONNX_OP class OnnxGraphNode(GraphNode): @@ -39,16 +38,13 @@ class OnnxGraphNode(GraphNode): self._op_params = self._get_raw_params(node.raw_node) if node else None self._op_name = "onnx::" + node.op_type if node else None self._scope_name = node.scope_name if node else None - self._opt_var_name = None - self._variable_name = self._extract_var_name(self._scope_name) - self._module_name = None self._weight = weight def clear_args_of_declaration(self): """Clear `self._args_in_code`.""" self._args_in_code = dict() - def _get_arg_name(self, arg): + def _get_arg_name(self, arg, variable_name): """ Get arg name. @@ -58,7 +54,7 @@ class OnnxGraphNode(GraphNode): Returns: str, arg name in function or class declaration. """ - return f"{arg}_{self._variable_name}" + return f"{arg}_{variable_name}" @property def hash_key(self): @@ -84,51 +80,6 @@ class OnnxGraphNode(GraphNode): """ self._hash_key = h - @property - def variable_name(self): - """ - Variable name. - - Returns: - str, variable name declared in init. - """ - return self._variable_name - - @variable_name.setter - def variable_name(self, v): - """ - Setter of variable name. - - Args: - v (str): Variable name. - """ - self._variable_name = v - - @property - def module_name(self): - """ - Module name. - - Returns: - str, module name. - """ - if not self._module_name_frozen: - module_name = self.tag - return module_name - - return self._module_name - - def _froze_module_name(self, m): - """ - Once module_name is set, then it's unchangeable. - - Args: - m (str): Module name. - """ - if not self._module_name_frozen: - self._module_name = m - self._module_name_frozen = True - @property def op_name(self): """ @@ -154,15 +105,13 @@ class OnnxGraphNode(GraphNode): self._ipt_shape = input_shape self._opt_shape = output_shape - def _add_tensor_args_to_code(self, op_name: str, t_identifier: str, declare, args): + def _add_tensor_args_to_code(self, op_name: str, settings, declare, args, variable_name): """ Add nn used tensors to args in init and construct blocks. Args: op_name (str): Add the tensor to args if the current node has this - op_name. - t_identifier (str): The unique string appeared in the target tensor - name. + op_name. declare (str): Declare statement generated in to_code(). args (str): Args statement generated in to_code(). @@ -172,103 +121,68 @@ class OnnxGraphNode(GraphNode): """ if not self._op_name == op_name: return declare, args - declare_list = [] - tensor = None - # find target tensor - for t_name, t_value in self._weight.items(): - if t_identifier in t_name: - tensor = t_value - break - if tensor is None: + if not settings or not settings.op_extra_tensor: return declare, args - declare_list.append(declare) - declare_t = f"self.{self._variable_name}_w = Tensor(" \ - f"np.random.uniform(0, 1, {str(tensor.shape)}), mindspore.float32)" + declare_list = [declare] + declare_t = f"self.{variable_name}_w = Tensor(" \ + f"np.random.uniform(0, 1, {str(settings.op_extra_tensor.shape)}), " \ + f"{settings.op_extra_tensor.dtype})" declare_list.append(declare_t) - args += f", self.{self._variable_name}_w" + args += f", self.{variable_name}_w" return declare_list, args - def to_code(self, ipt_args_in_construct: str, output_var: str): + def to_code(self, ipt_args_in_construct: str, variable_name: str, output_var: str, + code_fragment): """ Generate statements. Args: + variable_name (str): Variable name. ipt_args_in_construct (str): Args of input. output_var (str): Output variable name in construct. + code_fragment (CodeFragment): CodeFragment instance. Returns: Union[str, str], declare in init and call in construct. """ - operator = self.op_in_ms or self.module_name - self._opt_var_name = output_var + operator = code_fragment.operation args = self.args_in_code - settings = self.settings_in_code - if self._node_type == NodeType.OPERATION.value and not self.convert_successful(): + settings = code_fragment.code_setting + + if self._node_type == NodeType.OPERATION.value and not is_converted(code_fragment.operation): args.update({"input_shape": self.input_shape, "output_shape": self.output_shape}) if self._node_type == NodeType.OPERATION.value: - expr = ", ".join([f"{k.replace(f'_{self._variable_name}', '')}={v}" + expr = ", ".join([f"{k.replace(f'_{variable_name}', '')}={v}" for k, v in args.items()]) - ipt_args_settings_in_construct = \ - self._generate_ipt_args_settings_in_construct( - ipt_args_in_construct, - settings) + ipt_args_settings_in_construct = self._generate_ipt_args_settings_in_construct( + ipt_args_in_construct, settings) else: # When it's type is module, class or func, # it's not necessary to replace var. - expr = ", ".join([f"{k.replace(f'_{self._variable_name}', '')}={v}" + expr = ", ".join([f"{k.replace(f'_{variable_name}', '')}={v}" for k, v in args.items()]) ipt_args_settings_in_construct = ipt_args_in_construct - declare = f"self.{self._variable_name} = {operator}({expr})" + + if SEPARATOR_IN_ONNX_OP in operator: + operator = operator.replace(SEPARATOR_IN_ONNX_OP, ".") + + declare = f"self.{variable_name} = {operator}({expr})" # Extra Tensor generator for nn.MatMul declare, ipt_args_settings_in_construct = self._add_tensor_args_to_code( - 'onnx::MatMul', 'MatMul', declare, ipt_args_settings_in_construct) + 'onnx::MatMul', settings, declare, ipt_args_settings_in_construct, variable_name) # Extra Tensor generator for onnx::Add declare, ipt_args_settings_in_construct = self._add_tensor_args_to_code( - 'onnx::Add', 'BiasAdd', declare, ipt_args_settings_in_construct) + 'onnx::Add', settings, declare, ipt_args_settings_in_construct, variable_name) - call = f"{self._opt_var_name} = self.{self._variable_name}({ipt_args_settings_in_construct})" + call = f"{output_var} = self.{variable_name}({ipt_args_settings_in_construct})" return declare, call - @staticmethod - def _generate_ipt_args_settings_in_construct(ipt_args_in_construct, settings): - """ - Generate input with args and settings in construct. - - Args: - ipt_args_in_construct(str): Input args in construct. - settings(dict): Settings in operator. - - Returns: - str, args of each node in generated construct statement. - """ - if settings.get('input_type'): - input_type = settings['input_type'] - if input_type == InputType.TENSOR.value: - ipt_args_settings_in_construct = ipt_args_in_construct - elif input_type == InputType.LIST.value: - ipt_args_settings_in_construct = f"({ipt_args_in_construct})" - else: - raise NodeInputTypeNotSupport( - f"Input type[{input_type}] is not supported now.") - else: - ipt_args_settings_in_construct = ipt_args_in_construct - - if settings.get('values'): - settings_value = settings['values'] - if settings_value: - settings_in_construct = ', '.join( - [f"{setting_val}" for _, setting_val in settings_value.items()]) - ipt_args_settings_in_construct = ', '.join( - (ipt_args_settings_in_construct, settings_in_construct)) - - return ipt_args_settings_in_construct - def to_ir(self): """No need to implement for now.""" raise NotImplementedError @@ -284,7 +198,7 @@ class OnnxGraphNode(GraphNode): Returns: dict, raw params. """ - import onnx + onnx = import_module("onnx") raw_params = dict() @@ -318,62 +232,3 @@ class OnnxGraphNode(GraphNode): var = var.replace(LEFT_BUCKET, SEPARATOR_BTW_NAME_AND_ID).replace( RIGHT_BUCKET, "") return var - - def param_transform(self, mapper: Mapper): - """ - Transform tensorflow params into mindspore. - - Args: - mapper (Mapper): Mapper of params. - - """ - if self._node_type != NodeType.OPERATION.value: - args = deepcopy(self._args_in_code) - self._args_in_code = dict() - for arg, value in args.items(): - self._args_in_code[self._get_arg_name(arg)] = value - return None, None - - if not self.transformed: - _, _, _ = super(OnnxGraphNode, self).param_transform(mapper) - - for arg, value in self._params_in_ms.items(): - self._args_in_code[self._get_arg_name(arg)] = value - - for arg, value in self._settings_in_ms.items(): - self._settings_in_code[arg] = value - - self.transformed = True - - return self._op_in_ms, self._params_in_ms, self._settings_in_ms - - def froze_node_type_and_module_name(self, node_type, module_name): - """ - Froze node type and module name. - - After node_type is frozen, then the `module_name` - will be affected when `node_type` is `class`. - Thus, this line must be placed before `nd_inst.data.module_name`. - - Args: - module_name: Modified module name. - node_type (str): Node type, class of func. - - """ - if not self._type_frozen: - self._node_type = node_type - self._type_frozen = True - - if not self._module_name_frozen: - self._froze_module_name(module_name) - - def convert_successful(self): - """ - Whether convert successfully. - - Returns: - bool, true or false. - """ - if self._op_in_ms and SEPARATOR_IN_ONNX_OP not in self._op_in_ms: - return True - return False diff --git a/mindinsight/mindconverter/graph_based_converter/third_party_graph/onnx_utils.py b/mindinsight/mindconverter/graph_based_converter/third_party_graph/onnx_utils.py index 00349046..348bc658 100644 --- a/mindinsight/mindconverter/graph_based_converter/third_party_graph/onnx_utils.py +++ b/mindinsight/mindconverter/graph_based_converter/third_party_graph/onnx_utils.py @@ -87,7 +87,8 @@ def convert_tf_graph_to_onnx(model_path, model_inputs, model_outputs, opset=None inputs_as_nchw=None ) opt_map = getattr(optimizer.back_to_back_optimizer, '_func_map') - opt_map.pop(('Conv', 'BatchNormalization')) + if ('Conv', 'BatchNormalization') in opt_map: + opt_map.pop(('Conv', 'BatchNormalization')) onnx_graph = optimizer.optimize_graph(g) model_proto = onnx_graph.make_model("converted from {}".format(model_path)) @@ -228,8 +229,7 @@ class OnnxNode(BaseNode): """ def __init__(self, raw_node): - super(OnnxNode, self).__init__( - node_name=raw_node.name, op_type=raw_node.op_type) + super(OnnxNode, self).__init__(node_name=raw_node.name, op_type=raw_node.op_type) self.raw_node = raw_node self.params = ParamsAttribute(raw_node.attribute, raw_node) self.scope_name = None diff --git a/mindinsight/mindconverter/graph_based_converter/third_party_graph/pytorch_graph.py b/mindinsight/mindconverter/graph_based_converter/third_party_graph/pytorch_graph.py index 0e4e6164..ae75de5c 100644 --- a/mindinsight/mindconverter/graph_based_converter/third_party_graph/pytorch_graph.py +++ b/mindinsight/mindconverter/graph_based_converter/third_party_graph/pytorch_graph.py @@ -99,8 +99,8 @@ class PyTorchGraph(Graph): for item in input_shape: if not isinstance(item, int): - err_msg = f"Only support model with one input now, " \ - f"and each shape value in `input_shape` should be int." + err_msg = "Only support model with one input now, " \ + "and each shape value in `input_shape` should be int." log.error(err_msg) raise ValueError(err_msg) diff --git a/mindinsight/mindconverter/graph_based_converter/third_party_graph/pytorch_graph_node.py b/mindinsight/mindconverter/graph_based_converter/third_party_graph/pytorch_graph_node.py index 4d6a9428..4153a9ca 100644 --- a/mindinsight/mindconverter/graph_based_converter/third_party_graph/pytorch_graph_node.py +++ b/mindinsight/mindconverter/graph_based_converter/third_party_graph/pytorch_graph_node.py @@ -13,14 +13,11 @@ # limitations under the License. # ============================================================================== """Define PyTorch graph node.""" -from copy import deepcopy - from .base import GraphNode +from ..common.utils import is_converted from ..constant import NodeType, SEPARATOR_IN_SCOPE, SEPARATOR_BTW_NAME_AND_ID, LEFT_BUCKET, RIGHT_BUCKET, \ - SEPARATOR_IN_ONNX_OP, InputType -from ..mapper.base import Mapper -from ...common.exceptions import NodeInputTypeNotSupport + SEPARATOR_IN_ONNX_OP class PyTorchGraphNode(GraphNode): @@ -40,9 +37,6 @@ class PyTorchGraphNode(GraphNode): self._op_params = self._get_raw_params(node) self._op_name = node.kind() if node else None self._scope_name = node.scopeName() if node else None - self._opt_var_name = None - self._variable_name = self._extract_var_name(self._scope_name) - self._module_name = None self._weight = weight def clear_args_of_declaration(self): @@ -51,7 +45,7 @@ class PyTorchGraphNode(GraphNode): """ self._args_in_code = dict() - def _get_arg_name(self, arg): + def _get_arg_name(self, arg, variable_name): """ Get arg name. @@ -61,7 +55,7 @@ class PyTorchGraphNode(GraphNode): Returns: str, arg name in function or class declaration. """ - return f"{arg}_{self._variable_name}" + return f"{arg}_{variable_name}" @property def hash_key(self): @@ -88,53 +82,6 @@ class PyTorchGraphNode(GraphNode): """ self._hash_key = h - @property - def variable_name(self): - """ - Variable name. - - Returns: - str, variable name declared in init. - """ - return self._variable_name - - @variable_name.setter - def variable_name(self, v): - """ - Setter of variable name. - - Args: - v (str): Variable name. - - """ - self._variable_name = v - - @property - def module_name(self): - """ - Module name. - - Returns: - str, module name. - """ - if not self._module_name_frozen: - module_name = self.tag - return module_name - - return self._module_name - - def _froze_module_name(self, m): - """ - Once module_name is set, then it's unchangeable. - - Args: - m (str): Module name. - - """ - if not self._module_name_frozen: - self._module_name = m - self._module_name_frozen = True - @property def op_name(self): """ @@ -172,72 +119,47 @@ class PyTorchGraphNode(GraphNode): self._ipt_shape = input_shape self._opt_shape = output_shape - def to_code(self, ipt_args_in_construct: str, output_var: str): + def to_code(self, ipt_args_in_construct: str, variable_name: str, output_var: str, code_fragment): """ Generate statements. Args: + variable_name (str): Variable name. ipt_args_in_construct (str): Args of input. output_var (str): Output variable name in construct. + code_fragment (CodeFragment): CodeFragment instance. Returns: Union[str, str], declare in init and call in construct. """ - operator = self.op_in_ms or self.module_name - self._opt_var_name = output_var + operator = code_fragment.operation args = self.args_in_code - settings = self.settings_in_code + settings = code_fragment.code_setting - if self._node_type == NodeType.OPERATION.value and not self.convert_successful(): + if self._node_type == NodeType.OPERATION.value and not is_converted(code_fragment.operation): args.update({"input_shape": self.input_shape, "output_shape": self.output_shape}) if self._node_type == NodeType.OPERATION.value: - expr = ", ".join([f"{k.replace(f'_{self._variable_name}', '')}={v}" + expr = ", ".join([f"{k.replace(f'_{variable_name}', '')}={v}" for k, v in args.items()]) - ipt_args_settings_in_construct = self._generate_ipt_args_settings_in_construct(ipt_args_in_construct, - settings) + ipt_args_settings_in_construct = self._generate_ipt_args_settings_in_construct( + ipt_args_in_construct, settings) else: # When it's type is module, class or func, # it's not necessary to replace var. - expr = ", ".join([f"{k.replace(f'_{self._variable_name}', '')}={v}" + expr = ", ".join([f"{k.replace(f'_{variable_name}', '')}={v}" for k, v in args.items()]) ipt_args_settings_in_construct = ipt_args_in_construct - declare = f"self.{self._variable_name} = {operator}({expr})" - call = f"{self._opt_var_name} = self.{self._variable_name}({ipt_args_settings_in_construct})" - - return declare, call - - @staticmethod - def _generate_ipt_args_settings_in_construct(ipt_args_in_construct, settings): - """ - Generate input with args and settings in construct. - - Args: - ipt_args_in_construct(str): input args in construct. - settings(dict): settings in operator. - - """ - if settings.get('input_type'): - input_type = settings['input_type'] - if input_type == InputType.TENSOR.value: - ipt_args_settings_in_construct = ipt_args_in_construct - elif input_type == InputType.LIST.value: - ipt_args_settings_in_construct = f"({ipt_args_in_construct})" - else: - raise NodeInputTypeNotSupport(f"Input type[{input_type}] is not supported now.") - else: - ipt_args_settings_in_construct = ipt_args_in_construct + if SEPARATOR_IN_ONNX_OP in operator: + operator = operator.replace(SEPARATOR_IN_ONNX_OP, ".") - if settings.get('values'): - settings_value = settings['values'] - if settings_value: - settings_in_construct = ', '.join([f"{setting_val}" for _, setting_val in settings_value.items()]) - ipt_args_settings_in_construct = ', '.join((ipt_args_settings_in_construct, settings_in_construct)) + declare = f"self.{variable_name} = {operator}({expr})" + call = f"{output_var} = self.{variable_name}({ipt_args_settings_in_construct})" - return ipt_args_settings_in_construct + return declare, call def to_ir(self): """ @@ -288,62 +210,3 @@ class PyTorchGraphNode(GraphNode): var = var.replace(LEFT_BUCKET, SEPARATOR_BTW_NAME_AND_ID).replace( RIGHT_BUCKET, "") return var - - def param_transform(self, mapper: Mapper): - """ - Transform torch params into mindspore. - - Args: - mapper (Mapper): Mapper of params. - - """ - if self._node_type != NodeType.OPERATION.value: - args = deepcopy(self._args_in_code) - self._args_in_code = dict() - for arg, value in args.items(): - self._args_in_code[self._get_arg_name(arg)] = value - return None, None, None - - if not self.transformed: - _, _, _ = super(PyTorchGraphNode, self).param_transform(mapper) - - for arg, value in self._params_in_ms.items(): - self._args_in_code[self._get_arg_name(arg)] = value - - for arg, value in self._settings_in_ms.items(): - self._settings_in_code[arg] = value - - self.transformed = True - - return self._op_in_ms, self._params_in_ms, self._settings_in_ms - - def froze_node_type_and_module_name(self, node_type, module_name): - """ - Froze node type and module name. - - After node_type is frozen, then the `module_name` - will be affected when `node_type` is `class`. - Thus, this line must be placed before `nd_inst.data.module_name`. - - Args: - module_name: Modified module name. - node_type (str): Node type, class of func. - - """ - if not self._type_frozen: - self._node_type = node_type - self._type_frozen = True - - if not self._module_name_frozen: - self._froze_module_name(module_name) - - def convert_successful(self): - """ - Whether convert successfully. - - Returns: - bool, true or false. - """ - if self._op_in_ms and SEPARATOR_IN_ONNX_OP not in self._op_in_ms: - return True - return False diff --git a/tests/ut/mindconverter/graph_based_converter/hierarchical_tree/test_hierarchical_tree.py b/tests/ut/mindconverter/graph_based_converter/hierarchical_tree/test_hierarchical_tree.py index 7f5987fb..45775f86 100644 --- a/tests/ut/mindconverter/graph_based_converter/hierarchical_tree/test_hierarchical_tree.py +++ b/tests/ut/mindconverter/graph_based_converter/hierarchical_tree/test_hierarchical_tree.py @@ -42,7 +42,7 @@ class TestHierarchicalTree: get_raw_params.return_value = [] tree = HierarchicalTree() pt_node = PyTorchGraphNode() - tree.insert(pt_node, 'ResNet', (1, 3, 224, 224), (1, 64, 112, 112)) + tree.insert(pt_node, 'ResNet') assert tree.root == 'ResNet' def test_remove(self): diff --git a/tests/ut/mindconverter/graph_based_converter/mapper/test_mapper.py b/tests/ut/mindconverter/graph_based_converter/mapper/test_mapper.py index 8f8d21d3..1dd6aa51 100644 --- a/tests/ut/mindconverter/graph_based_converter/mapper/test_mapper.py +++ b/tests/ut/mindconverter/graph_based_converter/mapper/test_mapper.py @@ -17,11 +17,13 @@ import numpy as np import pytest from mindinsight.mindconverter.graph_based_converter.mapper.base import ONNXToMindSporeMapper +from mindinsight.mindconverter.graph_based_converter.mapper.gen_setting import Setting from tests.utils import mindspore class TestMappers: """Test Mappers.""" + @pytest.mark.parametrize('params', [{ 'input': {'op_name': 'onnx::Conv', 'params': {'dilations': [1, 1], @@ -38,7 +40,7 @@ class TestMappers: 'pad_mode': '\"pad\"', 'dilation': (1, 1), 'group': 1}, - 'converted_settings': dict()} + 'converted_settings': Setting()} }, { 'input': {'op_name': 'onnx::Conv', 'params': {'dilations': [1, 1], @@ -55,7 +57,7 @@ class TestMappers: 'pad_mode': '\"valid\"', 'dilation': (1, 1), 'group': 1}, - 'converted_settings': dict()} + 'converted_settings': Setting()} }, { 'input': {'op_name': 'onnx::Gemm', 'params': dict(), @@ -65,7 +67,7 @@ class TestMappers: 'converted_params': {'in_channels': 3, 'out_channels': 10, 'has_bias': True}, - 'converted_settings': dict()} + 'converted_settings': Setting()} }, { 'input': {'op_name': 'onnx::BatchNormalization', 'params': {'epsilon': 1e-5, @@ -76,14 +78,14 @@ class TestMappers: 'converted_params': {'num_features': 6, 'eps': 1e-5, 'momentum': 0.9}, - 'converted_settings': dict()} + 'converted_settings': Setting()} }, { 'input': {'op_name': 'onnx::Relu', 'params': dict(), 'weights': dict()}, 'expected_output': {'converter_name': 'nn.ReLU', 'converted_params': dict(), - 'converted_settings': dict()} + 'converted_settings': Setting()} }, { 'input': {'op_name': 'onnx::MaxPool', 'params': {'kernel_shape': [3, 3], @@ -94,7 +96,7 @@ class TestMappers: 'converted_params': {'kernel_size': (3, 3), 'stride': (2, 2), 'pad_mode': '"same"'}, - 'converted_settings': dict()} + 'converted_settings': Setting()} }, { 'input': {'op_name': 'onnx::AveragePool', 'params': {'kernel_shape': [3, 3], @@ -105,7 +107,7 @@ class TestMappers: 'converted_params': {'kernel_size': (3, 3), 'stride': (2, 2), 'pad_mode': '"same"'}, - 'converted_settings': dict()} + 'converted_settings': Setting()} }, { 'input': {'op_name': 'onnx::GlobalAveragePool', 'params': {'input_shape': (1, 3, 10, 10), @@ -113,21 +115,21 @@ class TestMappers: 'weights': ''}, 'expected_output': {'converter_name': 'nn.AvgPool2d', 'converted_params': {'kernel_size': (10, 10)}, - 'converted_settings': dict()} + 'converted_settings': Setting()} }, { 'input': {'op_name': 'onnx::Flatten', 'params': dict(), 'weights': dict()}, 'expected_output': {'converter_name': 'nn.Flatten', 'converted_params': dict(), - 'converted_settings': dict()} + 'converted_settings': Setting()} }, { 'input': {'op_name': 'onnx::Add', 'params': dict(), 'weights': dict()}, 'expected_output': {'converter_name': 'P.TensorAdd', 'converted_params': dict(), - 'converted_settings': dict()} + 'converted_settings': Setting()} }, { 'input': {'op_name': 'onnx::Pad', 'params': {'pads': [0, 1, 2, 3], @@ -137,7 +139,7 @@ class TestMappers: 'expected_output': {'converter_name': 'nn.Pad', 'converted_params': {'paddings': ((0, 2), (1, 3)), 'mode': '\"CONSTANT\"'}, - 'converted_settings': dict()} + 'converted_settings': Setting()} }, { 'input': {'op_name': 'onnx::Pad', 'params': {'pads': [0, 1, 2, 3], @@ -146,7 +148,7 @@ class TestMappers: 'expected_output': {'converter_name': 'nn.Pad', 'converted_params': {'paddings': ((0, 2), (1, 3)), 'mode': '\"REFLECT\"'}, - 'converted_settings': dict()} + 'converted_settings': Setting()} }, { 'input': {'op_name': 'onnx::Pad', 'params': {'pads': [0, 1, 2, 3], @@ -156,7 +158,7 @@ class TestMappers: 'expected_output': {'converter_name': 'nn.Pad', 'converted_params': {'paddings': ((0, 2), (1, 3)), 'mode': '{UNSUPPORTED: value is NOT 0}\"CONSTANT\"'}, - 'converted_settings': dict()} + 'converted_settings': Setting()} }, { 'input': {'op_name': 'onnx::Pad', 'params': {'pads': [0, 1, 2, 3], @@ -165,7 +167,7 @@ class TestMappers: 'expected_output': {'converter_name': 'nn.Pad', 'converted_params': {'paddings': ((0, 2), (1, 3)), 'mode': '{UNSUPPORTED: \"edge\"}\"UNKNOWN\"'}, - 'converted_settings': dict()} + 'converted_settings': Setting()} }, { 'input': {'op_name': 'onnx::ReduceMean', 'params': {'keepdims': 0, @@ -196,14 +198,14 @@ class TestMappers: 'weights': dict()}, 'expected_output': {'converter_name': 'nn.ReLU6', 'converted_params': dict(), - 'converted_settings': dict()} + 'converted_settings': Setting()} }, { 'input': {'op_name': 'onnx::Clip', 'params': dict(), 'weights': dict()}, 'expected_output': {'converter_name': 'nn.ReLU', 'converted_params': dict(), - 'converted_settings': dict()} + 'converted_settings': Setting()} }, { 'input': {'op_name': 'onnx::Clip', 'params': {'max': 3, @@ -211,13 +213,13 @@ class TestMappers: 'weights': dict()}, 'expected_output': {'converter_name': None, 'converted_params': dict(), - 'converted_settings': dict()} + 'converted_settings': Setting()} }]) def test_mapper(self, params): """Test mapper function.""" mapper = ONNXToMindSporeMapper() - converter_name, converted_params, converted_settings = \ + converter_name, converted_params, converted_settings, _ = \ mapper.convert(params['input']['op_name'], params['input']['params'], params['input']['weights']) assert params['expected_output']['converter_name'] == converter_name assert params['expected_output']['converted_params'] == converted_params - assert params['expected_output']['converted_settings'] == converted_settings + assert isinstance(converted_settings, Setting)