| @@ -0,0 +1,15 @@ | |||||
| # Copyright 2020 Huawei Technologies Co., Ltd. All Rights Reserved. | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================== | |||||
| """Common instance and utils of graph based converter.""" | |||||
| @@ -0,0 +1,218 @@ | |||||
| # Copyright 2020 Huawei Technologies Co., Ltd. All Rights Reserved. | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================== | |||||
| """Define CodeLine object.""" | |||||
| import abc | |||||
| class TrainableParams: | |||||
| """Trainable parameters.""" | |||||
| def __init__(self, shape, dtype, reference): | |||||
| self.param_name = None | |||||
| self.shape = shape | |||||
| self.dtype = dtype | |||||
| self.reference = reference # Weight name in global npy. | |||||
| class CodeSetting: | |||||
| """Code generation settings.""" | |||||
| def __init__(self): | |||||
| self.output_vars_suffix = [] | |||||
| self.operation_input_type = None # Construct input type, tensor or list. | |||||
| self.operation_extra_input = dict() # `values` in original setting dict. | |||||
| self.operation_extra_tensor = None # For ops like `MatMul` and `BiasAdd` that need an extra tensor. | |||||
| class Fragment(abc.ABC): | |||||
| """ | |||||
| Define common attributes for code generation. | |||||
| Args: | |||||
| operation (str): Operation name in MindSpore. | |||||
| actual_args (dict): Actual arg values. | |||||
| settings (namedtuple): Code generation settings. | |||||
| input_shape (tuple): Input tensor shape. | |||||
| output_shape (tuple): Output tensor shape. | |||||
| """ | |||||
| def __init__(self, operation, actual_args, input_shape, output_shape, settings=None): | |||||
| self._operation = operation | |||||
| self._input_shape = input_shape | |||||
| self._output_shape = output_shape | |||||
| self._declared_variable_name = None | |||||
| self._output_var_name = list() # Output variable names (an op could have multiple outputs). | |||||
| self._operation_inputs = list() # Index indicates the order of inputs. | |||||
| self._operation_extra_inputs = settings | |||||
| self._code_setting = settings | |||||
| self._formal_args_list = dict() | |||||
| self._actual_args_list = actual_args # Key is the param_key, value is the corresponding value. | |||||
| self._node_type = "" | |||||
| @property | |||||
| def code_setting(self): | |||||
| """Code setting getter.""" | |||||
| return self._code_setting | |||||
| @property | |||||
| def node_type(self): | |||||
| """Node type getter.""" | |||||
| return self._node_type | |||||
| @node_type.setter | |||||
| def node_type(self, t): | |||||
| """Node type setter.""" | |||||
| self._node_type = t | |||||
| @property | |||||
| def operation_extra_inputs(self): | |||||
| """Getter of extra operation inputs.""" | |||||
| return self._operation_extra_inputs | |||||
| @property | |||||
| def declared_var_name(self): | |||||
| """Declared variable name getter.""" | |||||
| return self._declared_variable_name | |||||
| @declared_var_name.setter | |||||
| def declared_var_name(self, var): | |||||
| """Setter of declared variable name.""" | |||||
| self._declared_variable_name = var | |||||
| @property | |||||
| def output_var_name(self) -> str: | |||||
| """Getter of output variable name.""" | |||||
| return ", ".join(self._output_var_name) | |||||
| @output_var_name.setter | |||||
| def output_var_name(self, opt_vars): | |||||
| """ | |||||
| Output variable name setter. | |||||
| Args: | |||||
| opt_vars (list[str]): Output variable names. | |||||
| """ | |||||
| self._output_var_name = opt_vars | |||||
| @property | |||||
| def operation_inputs(self): | |||||
| """ | |||||
| Operation inputs getter. | |||||
| Returns: | |||||
| list[Fragment], list of inputs. | |||||
| """ | |||||
| return self._operation_inputs | |||||
| def update_operation_inputs(self, ipt): | |||||
| """ | |||||
| Update operation inputs. | |||||
| Args: | |||||
| ipt (Fragment): Where input comes from. | |||||
| """ | |||||
| self._operation_inputs.append(ipt) | |||||
| @property | |||||
| def operation(self): | |||||
| """ | |||||
| Operation getter. | |||||
| Returns: | |||||
| str, operation name to be initialized. | |||||
| """ | |||||
| return self._operation | |||||
| @operation.setter | |||||
| def operation(self, op: str): | |||||
| """ | |||||
| Operation setter. | |||||
| Args: | |||||
| op (str): Operation name. | |||||
| """ | |||||
| self._operation = op | |||||
| @property | |||||
| def actual_args(self) -> dict: | |||||
| """Getter of actual args.""" | |||||
| return self._actual_args_list | |||||
| @property | |||||
| def formal_args(self) -> dict: | |||||
| """Get formal args.""" | |||||
| return self._formal_args_list | |||||
| def update_formal_args(self, formal_args: dict): | |||||
| """ | |||||
| Update formal args. | |||||
| Args: | |||||
| formal_args (dict): To be updated args. | |||||
| """ | |||||
| self._formal_args_list.update(formal_args) | |||||
| @property | |||||
| def input_shape(self): | |||||
| """Input shape getter.""" | |||||
| return self._input_shape | |||||
| @property | |||||
| def output_shape(self): | |||||
| """Output shape getter.""" | |||||
| return self._output_shape | |||||
| class CodeFragment(Fragment): | |||||
| """ | |||||
| Manage the variables related to code generation. | |||||
| For a node of single operation type, the placeholders below show where each `CodeFragment` variable lands in the generated code: | |||||
| ```python | |||||
| class Module(nn.Cell): | |||||
| def __init__(self, ...): | |||||
| super(Module, self).__init__() | |||||
| self.<CodeFragment.declared_var_name> = <CodeFragment.operation>(<CodeFragment.actual_args>, | |||||
| <CodeFragment.init_trainable_params>) | |||||
| self.<CodeFragment.trainable_params[k].param_name> = Tensor(<CodeFragment.trainable_params[k].shape>, | |||||
| dtype=<CodeFragment.trainable_params[k].dtype>) | |||||
| def construct(self, x, ...): | |||||
| <CodeFragment.output_var_name> = self.<CodeFragment.declared_var_name>(<CodeFragment.operation_inputs>) | |||||
| ... | |||||
| return output | |||||
| ``` | |||||
| Args: | |||||
| operation (str): Operation name in MindSpore. | |||||
| actual_args (dict): Actual arg values. | |||||
| settings (namedtuple): Code generation settings. | |||||
| input_shape (tuple): Input tensor shape. | |||||
| output_shape (tuple): Output tensor shape. | |||||
| trainable_params (dict): Trainable parameters of the operation. | |||||
| """ | |||||
| def __init__(self, operation, actual_args, settings, input_shape, output_shape, | |||||
| trainable_params=None): | |||||
| super(CodeFragment, self).__init__(operation=operation, actual_args=actual_args, | |||||
| input_shape=input_shape, output_shape=output_shape, | |||||
| settings=settings) | |||||
| self._trainable_params = dict() # External weights, e.g. for MatMul. | |||||
| self._init_trainable_params = trainable_params # Weights that can be passed to the operation's init method, e.g. for Conv2d. | |||||
| @property | |||||
| def trainable_params(self): | |||||
| """Trainable params getter.""" | |||||
| return self._trainable_params | |||||
| class ModuleFragment(Fragment): | |||||
| """Manage module type code variables.""" | |||||
| def __init__(self, operation, actual_args, settings, input_shape, output_shape): | |||||
| super(ModuleFragment, self).__init__(operation=operation, actual_args=actual_args, | |||||
| input_shape=input_shape, output_shape=output_shape, | |||||
| settings=settings) | |||||
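To make the intended flow concrete, here is a hypothetical sketch (not part of the change) of how a generator could populate and read back a `CodeFragment`; the `nn.Conv2d` argument values and variable names are illustrative only:

```python
# Hypothetical usage sketch for CodeFragment; all values are illustrative.
fragment = CodeFragment(operation="nn.Conv2d",
                        actual_args={"in_channels": 3, "out_channels": 64,
                                     "kernel_size": 7, "stride": 2},
                        settings=None,
                        input_shape=(1, 3, 224, 224),
                        output_shape=(1, 64, 112, 112))
fragment.declared_var_name = "conv2d_0"       # name declared in __init__
fragment.output_var_name = ["conv2d_0_opt"]   # output vars in construct
assert fragment.operation == "nn.Conv2d"
assert fragment.output_var_name == "conv2d_0_opt"  # getter joins the list
```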
| @@ -0,0 +1,29 @@ | |||||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================ | |||||
| """Define common utils.""" | |||||
| from mindinsight.mindconverter.graph_based_converter.constant import SEPARATOR_IN_ONNX_OP | |||||
| def is_converted(operation: str): | |||||
| """ | |||||
| Check whether the operation was converted successfully. | |||||
| Args: | |||||
| operation (str): Operation name. | |||||
| Returns: | |||||
| bool, True if the operation name no longer contains the ONNX separator. | |||||
| """ | |||||
| return operation and SEPARATOR_IN_ONNX_OP not in operation | |||||
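As a quick illustration of the contract above, assuming `SEPARATOR_IN_ONNX_OP` is the `::` separator carried by unconverted ONNX op names such as `onnx::Conv`:

```python
# Sketch only; assumes SEPARATOR_IN_ONNX_OP == "::".
assert is_converted("nn.Conv2d")         # mapped to a MindSpore operator
assert not is_converted("onnx::Conv")    # still carries the ONNX separator
assert not is_converted("")              # empty name counts as not converted
```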
| @@ -14,6 +14,7 @@ | |||||
| # ============================================================================== | # ============================================================================== | ||||
| """Hierarchical tree module.""" | """Hierarchical tree module.""" | ||||
| import re | import re | ||||
| from mindinsight.mindconverter.common.log import logger as log | from mindinsight.mindconverter.common.log import logger as log | ||||
| from .hierarchical_tree import HierarchicalTree | from .hierarchical_tree import HierarchicalTree | ||||
| from ..third_party_graph.onnx_graph_node import OnnxGraphNode | from ..third_party_graph.onnx_graph_node import OnnxGraphNode | ||||
| @@ -36,7 +37,6 @@ def _tf_model_node_name_reformat(node: OnnxGraphNode, node_name): | |||||
| """ | """ | ||||
| scope_name = node.scope_name | scope_name = node.scope_name | ||||
| new_name = None | new_name = None | ||||
| parent = "" | |||||
| regex = r"(?P<parent>.+/)(?P<op>\w+)" | regex = r"(?P<parent>.+/)(?P<op>\w+)" | ||||
| match = re.match(regex, scope_name) | match = re.match(regex, scope_name) | ||||
| parent = match.group("parent") | parent = match.group("parent") | ||||
| @@ -74,12 +74,13 @@ class HierarchicalTreeFactory: | |||||
| f"Cannot find {node_name}'s input shape." | f"Cannot find {node_name}'s input shape." | ||||
| log.error(err_msg) | log.error(err_msg) | ||||
| if isinstance(node_inst, OnnxGraphNode): | if isinstance(node_inst, OnnxGraphNode): | ||||
| node_name_with_scope = _tf_model_node_name_reformat( | |||||
| node_inst, node_name) | |||||
| node_name_with_scope = _tf_model_node_name_reformat(node_inst, node_name) | |||||
| node_scope_name[node_name] = node_name_with_scope | node_scope_name[node_name] = node_name_with_scope | ||||
| node_name = node_name_with_scope | node_name = node_name_with_scope | ||||
| tree.insert(node_inst, node_name, node_input, node_output) | |||||
| node_inst.add_input_and_output_shape(node_input, node_output) | |||||
| tree.insert(node_inst, node_name) | |||||
| if node_scope_name: | if node_scope_name: | ||||
| return tree, node_scope_name | return tree, node_scope_name | ||||
| return tree | return tree | ||||
| @@ -25,17 +25,18 @@ from treelib import Tree, Node | |||||
| from mindinsight.mindconverter.common.log import logger as log | from mindinsight.mindconverter.common.log import logger as log | ||||
| from .name_mgr import ModuleNameMgr, GlobalVarNameMgr | from .name_mgr import ModuleNameMgr, GlobalVarNameMgr | ||||
| from ..common.utils import is_converted | |||||
| from ..mapper.base import Mapper | from ..mapper.base import Mapper | ||||
| from ..third_party_graph.pytorch_graph_node import PyTorchGraphNode | from ..third_party_graph.pytorch_graph_node import PyTorchGraphNode | ||||
| from ..third_party_graph.onnx_graph_node import OnnxGraphNode | from ..third_party_graph.onnx_graph_node import OnnxGraphNode | ||||
| from ..constant import SEPARATOR_IN_SCOPE, SEPARATOR_BTW_NAME_AND_ID, FIRST_LEVEL_INDENT, CodeFormatConfig | |||||
| from ..constant import SEPARATOR_IN_SCOPE | |||||
| from ..constant import CodeFormatConfig | |||||
| from ..constant import SEPARATOR_BTW_NAME_AND_ID, FIRST_LEVEL_INDENT | |||||
| from ..constant import NEW_LINE, SECOND_LEVEL_INDENT | from ..constant import NEW_LINE, SECOND_LEVEL_INDENT | ||||
| from ..constant import NodeType | from ..constant import NodeType | ||||
| from ..report_generator import ReportGenerator | from ..report_generator import ReportGenerator | ||||
| from ...common.exceptions import NodeTypeNotSupport | from ...common.exceptions import NodeTypeNotSupport | ||||
| GLOBAL_VAR_NAME_MGR = GlobalVarNameMgr() | |||||
| class HierarchicalTree(Tree): | class HierarchicalTree(Tree): | ||||
| """Define hierarchical tree.""" | """Define hierarchical tree.""" | ||||
| @@ -46,6 +47,8 @@ class HierarchicalTree(Tree): | |||||
| _root_created = False | _root_created = False | ||||
| ROOT_LEVEL = 0 | ROOT_LEVEL = 0 | ||||
| GLOBAL_VAR_NAME_MGR = GlobalVarNameMgr() | |||||
| def __init__(self): | def __init__(self): | ||||
| super(HierarchicalTree, self).__init__() | super(HierarchicalTree, self).__init__() | ||||
| self._hierarchical_order = dict() | self._hierarchical_order = dict() | ||||
| @@ -62,6 +65,7 @@ class HierarchicalTree(Tree): | |||||
| self._module_vars = dict() | self._module_vars = dict() | ||||
| # scope name mapping record for easy node searching | # scope name mapping record for easy node searching | ||||
| self._scope_name_map = dict() | self._scope_name_map = dict() | ||||
| self.code_fragment_recorder = dict() | |||||
| @property | @property | ||||
| def tree_identifier(self): | def tree_identifier(self): | ||||
| @@ -82,19 +86,15 @@ class HierarchicalTree(Tree): | |||||
| return None | return None | ||||
| return self._nodes[nid] | return self._nodes[nid] | ||||
| def insert(self, node: Union[PyTorchGraphNode, OnnxGraphNode], | |||||
| node_name: str, input_shape, output_shape): | |||||
| def insert(self, node: Union[PyTorchGraphNode, OnnxGraphNode], node_name: str): | |||||
| """ | """ | ||||
| Insert node into hierarchical tree. | Insert node into hierarchical tree. | ||||
| Args: | Args: | ||||
| node_name (str): Node name. | node_name (str): Node name. | ||||
| node (Union[PyTorchGraphNode, OnnxGraphNode]): Node to be inserted. | node (Union[PyTorchGraphNode, OnnxGraphNode]): Node to be inserted. | ||||
| output_shape (tuple): Output tensor shape. | |||||
| input_shape (tuple): Input tensor shape. | |||||
| """ | """ | ||||
| node.add_input_and_output_shape(input_shape, output_shape) | |||||
| scopes = node_name.split(SEPARATOR_IN_SCOPE) | scopes = node_name.split(SEPARATOR_IN_SCOPE) | ||||
| for idx, scope in enumerate(scopes): | for idx, scope in enumerate(scopes): | ||||
| parent = SEPARATOR_IN_SCOPE.join(scopes[:idx]) | parent = SEPARATOR_IN_SCOPE.join(scopes[:idx]) | ||||
| @@ -125,10 +125,9 @@ class HierarchicalTree(Tree): | |||||
| tgt_node.precursor_nodes = node.precursor_nodes | tgt_node.precursor_nodes = node.precursor_nodes | ||||
| tgt_node.node_type = (NodeType.OPERATION if idx == len(scopes) - 1 | tgt_node.node_type = (NodeType.OPERATION if idx == len(scopes) - 1 | ||||
| else NodeType.MODULE).value | else NodeType.MODULE).value | ||||
| tgt_node.tag = scope.split(SEPARATOR_BTW_NAME_AND_ID)[0] | |||||
| tgt_node.variable_name = self._get_var_name(identifier) | tgt_node.variable_name = self._get_var_name(identifier) | ||||
| self.create_node( | self.create_node( | ||||
| tag=tgt_node.tag, | |||||
| tag=scope.split(SEPARATOR_BTW_NAME_AND_ID)[0], | |||||
| identifier=identifier, | identifier=identifier, | ||||
| parent=parent, | parent=parent, | ||||
| data=tgt_node | data=tgt_node | ||||
| @@ -276,8 +275,7 @@ class HierarchicalTree(Tree): | |||||
| node.data.replace_with_arg(arg, arg) | node.data.replace_with_arg(arg, arg) | ||||
| return node | return node | ||||
| @staticmethod | |||||
| def _clear_unused_args(node, used_args): | |||||
| def _clear_unused_args(self, node, used_args): | |||||
| """ | """ | ||||
| Clear unused args. | Clear unused args. | ||||
| @@ -290,7 +288,9 @@ class HierarchicalTree(Tree): | |||||
| """ | """ | ||||
| args_in_code = list(node.data.args_in_code.keys()) | args_in_code = list(node.data.args_in_code.keys()) | ||||
| for arg in args_in_code: | for arg in args_in_code: | ||||
| ori_arg = arg.replace(f"_{node.data.variable_name}", "") | |||||
| ori_arg = arg.replace( | |||||
| f"_{self.code_fragment_recorder[node.identifier].declared_var_name}", "" | |||||
| ) | |||||
| if ori_arg not in used_args: | if ori_arg not in used_args: | ||||
| node.data.args_in_code.pop(arg) | node.data.args_in_code.pop(arg) | ||||
| return node | return node | ||||
| @@ -323,6 +323,8 @@ class HierarchicalTree(Tree): | |||||
| # 1. Generate args for each node in this level. | # 1. Generate args for each node in this level. | ||||
| if node.data.node_type == NodeType.MODULE.value: | if node.data.node_type == NodeType.MODULE.value: | ||||
| self._create_module_args_and_vars(node, mapper) | self._create_module_args_and_vars(node, mapper) | ||||
| if depth == depths[-1]: | |||||
| self.code_fragment_recorder[node.identifier] = node.data.param_transform(mapper, "") | |||||
| # Module merging based on all nodes. | # Module merging based on all nodes. | ||||
| self._module_merging() | self._module_merging() | ||||
| @@ -345,30 +347,29 @@ class HierarchicalTree(Tree): | |||||
| # then assign the created module name to current node, | # then assign the created module name to current node, | ||||
| # and delete unused args. | # and delete unused args. | ||||
| module_name = self._created_module[module_key] | module_name = self._created_module[module_key] | ||||
| nd_inst.data.froze_node_type_and_module_name(node_type, | |||||
| module_name) | |||||
| self.code_fragment_recorder[nd_inst.identifier].operation = module_name | |||||
| self.code_fragment_recorder[nd_inst.identifier].node_type = node_type | |||||
| self._preprocess_node_args(nd_inst, module_key) | self._preprocess_node_args(nd_inst, module_key) | ||||
| continue | continue | ||||
| module_name = nd_inst.data.module_name | |||||
| module_name = nd_inst.tag | |||||
| if node_type == NodeType.CLASS.value: | if node_type == NodeType.CLASS.value: | ||||
| module_name = f"{module_name[0].upper()}{module_name[1:]}" | module_name = f"{module_name[0].upper()}{module_name[1:]}" | ||||
| # After node_type and module_name is frozen, | # After node_type and module_name is frozen, | ||||
| # then it's unchangeable. | # then it's unchangeable. | ||||
| module_name = self._module_mgr.get_name(module_name) | module_name = self._module_mgr.get_name(module_name) | ||||
| nd_inst.data.froze_node_type_and_module_name(node_type, | |||||
| module_name) | |||||
| self.code_fragment_recorder[nd_inst.identifier].operation = module_name | |||||
| self.code_fragment_recorder[nd_inst.identifier].node_type = node_type | |||||
| # 3. Pre-process node args. | # 3. Pre-process node args. | ||||
| nd_inst = self._preprocess_node_args(nd_inst, module_key) | nd_inst = self._preprocess_node_args(nd_inst, module_key) | ||||
| # 4. Post-process child node args. | # 4. Post-process child node args. | ||||
| for _, scsr_nd_name in enumerate(nd_inst.successors(self.tree_identifier)): | for _, scsr_nd_name in enumerate(nd_inst.successors(self.tree_identifier)): | ||||
| self._postprocess_node_args( | |||||
| self.get_node(scsr_nd_name), module_key) | |||||
| self._postprocess_node_args(self.get_node(scsr_nd_name), module_key) | |||||
| # 5. Generate code. | # 5. Generate code. | ||||
| snippets.add( | |||||
| func(nd_inst, nd_inst.data.module_name, module_key)) | |||||
| snippets.add(func(nd_inst, self.code_fragment_recorder[nd_inst.identifier].operation, module_key)) | |||||
| code_blocks.extend(snippets) | code_blocks.extend(snippets) | ||||
| @@ -437,7 +438,7 @@ class HierarchicalTree(Tree): | |||||
| module_list = [] | module_list = [] | ||||
| for node_name in node.successors(self.tree_identifier): | for node_name in node.successors(self.tree_identifier): | ||||
| c_nd = self.get_node(node_name) | c_nd = self.get_node(node_name) | ||||
| operator = c_nd.data.op_in_ms or c_nd.data.module_name | |||||
| operator = self.code_fragment_recorder[c_nd.identifier].operation | |||||
| if c_nd.data.node_type != NodeType.OPERATION.value: | if c_nd.data.node_type != NodeType.OPERATION.value: | ||||
| hash_key = c_nd.data.hash_key or self.hash_key(c_nd) | hash_key = c_nd.data.hash_key or self.hash_key(c_nd) | ||||
| @@ -445,14 +446,16 @@ class HierarchicalTree(Tree): | |||||
| operator = self._created_module[hash_key] | operator = self._created_module[hash_key] | ||||
| args = c_nd.data.args_in_code | args = c_nd.data.args_in_code | ||||
| if c_nd.data.node_type == NodeType.OPERATION.value and \ | |||||
| not c_nd.data.convert_successful(): | |||||
| if c_nd.data.node_type == NodeType.OPERATION.value and not is_converted( | |||||
| self.code_fragment_recorder[c_nd.identifier].operation): | |||||
| args.update({"input_shape": c_nd.data.input_shape, | args.update({"input_shape": c_nd.data.input_shape, | ||||
| "output_shape": c_nd.data.output_shape}) | "output_shape": c_nd.data.output_shape}) | ||||
| # Generate code statement. | # Generate code statement. | ||||
| expr = ", ".join([f"{k.replace(f'_{c_nd.data.variable_name}', '')}={v}" | |||||
| for k, v in args.items()]) | |||||
| expr = ", ".join( | |||||
| [f"{k.replace(f'_{self.code_fragment_recorder[c_nd.identifier].declared_var_name}', '')}={v}" | |||||
| for k, v in args.items()] | |||||
| ) | |||||
| code_line = f"{operator}({expr})" | code_line = f"{operator}({expr})" | ||||
| module_list.append(code_line) | module_list.append(code_line) | ||||
| @@ -547,14 +550,16 @@ class HierarchicalTree(Tree): | |||||
| if idx != 0: | if idx != 0: | ||||
| # Get previous node output variable name. | # Get previous node output variable name. | ||||
| ipt_args_in_construct = self._get_previous_opt_var( | |||||
| cur_nd_inst, pre_nd_inst) | |||||
| ipt_args_in_construct = self._get_previous_opt_var(cur_nd_inst, pre_nd_inst) | |||||
| if idx != len(pre_nd_inst.successors(self.tree_identifier)) - 1: | if idx != len(pre_nd_inst.successors(self.tree_identifier)) - 1: | ||||
| # Set opt variable name. | # Set opt variable name. | ||||
| opt_arg_in_construct = cur_nd_inst.data.opt_var_name | |||||
| opt_arg_in_construct = f"{self.code_fragment_recorder[cur_nd_inst.identifier].declared_var_name}_opt" | |||||
| declare, call = cur_nd_inst.data.to_code(ipt_args_in_construct=ipt_args_in_construct, | declare, call = cur_nd_inst.data.to_code(ipt_args_in_construct=ipt_args_in_construct, | ||||
| output_var=opt_arg_in_construct) | |||||
| variable_name=self.code_fragment_recorder[ | |||||
| cur_nd_inst.identifier].declared_var_name, | |||||
| output_var=opt_arg_in_construct, | |||||
| code_fragment=self.code_fragment_recorder[cur_nd_inst.identifier]) | |||||
| return declare, call | return declare, call | ||||
| @@ -588,7 +593,9 @@ class HierarchicalTree(Tree): | |||||
| if e not in pre_nd.successors(self.tree_identifier): | if e not in pre_nd.successors(self.tree_identifier): | ||||
| while True: | while True: | ||||
| if p_nd.identifier in pre_nd.successors(self.tree_identifier): | if p_nd.identifier in pre_nd.successors(self.tree_identifier): | ||||
| ipt_lst.append(p_nd.data.opt_var_name) | |||||
| ipt_lst.append( | |||||
| f"{self.code_fragment_recorder[p_nd.identifier].declared_var_name}_opt" | |||||
| ) | |||||
| break | break | ||||
| pre_nd_name = p_nd.predecessor(self.tree_identifier) | pre_nd_name = p_nd.predecessor(self.tree_identifier) | ||||
| if not pre_nd_name: | if not pre_nd_name: | ||||
| @@ -597,7 +604,9 @@ class HierarchicalTree(Tree): | |||||
| p_nd = self.get_node(pre_nd_name) | p_nd = self.get_node(pre_nd_name) | ||||
| continue | continue | ||||
| ipt_lst.append(p_nd.data.opt_var_name) | |||||
| ipt_lst.append( | |||||
| f"{self.code_fragment_recorder[p_nd.identifier].declared_var_name}_opt" | |||||
| ) | |||||
| return ipt_lst | return ipt_lst | ||||
| def _get_previous_opt_var(self, cur_nd, pre_nd): | def _get_previous_opt_var(self, cur_nd, pre_nd): | ||||
| @@ -619,12 +628,11 @@ class HierarchicalTree(Tree): | |||||
| cur_nd = self.get_node(p_nd[0]) | cur_nd = self.get_node(p_nd[0]) | ||||
| return ", ".join(self._find_all_previous_opt_var_(cur_nd, pre_nd)) | return ", ".join(self._find_all_previous_opt_var_(cur_nd, pre_nd)) | ||||
| def hash_key(self, node, depth: int = 0): | |||||
| def hash_key(self, node): | |||||
| """ | """ | ||||
| Generate hash key for each node. | Generate hash key for each node. | ||||
| Args: | Args: | ||||
| depth (int): Recursion depth. | |||||
| node (Node): Node. | node (Node): Node. | ||||
| Returns: | Returns: | ||||
| @@ -633,13 +641,17 @@ class HierarchicalTree(Tree): | |||||
| scsr_topo_order = [] | scsr_topo_order = [] | ||||
| for s in node.successors(self.tree_identifier): | for s in node.successors(self.tree_identifier): | ||||
| cur_nd = self.get_node(s) | cur_nd = self.get_node(s) | ||||
| if cur_nd.data.hash_key: | |||||
| scsr_topo_order.append(f"{cur_nd.data.hash_key}[{depth}]") | |||||
| continue | |||||
| if cur_nd.data.node_type in {NodeType.MODULE.value, | if cur_nd.data.node_type in {NodeType.MODULE.value, | ||||
| NodeType.FUNC.value, | NodeType.FUNC.value, | ||||
| NodeType.CLASS.value}: | NodeType.CLASS.value}: | ||||
| scsr_topo_order.append(self.hash_key(cur_nd, depth + 1)) | |||||
| if cur_nd.data.hash_key: | |||||
| scsr_topo_order.append(f"({cur_nd.data.hash_key})") | |||||
| continue | |||||
| raise ValueError("Current node doesn't have hash key.") | |||||
| if cur_nd.data.hash_key: | |||||
| scsr_topo_order.append(cur_nd.data.hash_key) | |||||
| continue | continue | ||||
| unique_key = "->".join(scsr_topo_order) | unique_key = "->".join(scsr_topo_order) | ||||
| node.data.hash_key = unique_key | node.data.hash_key = unique_key | ||||
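The reworked `hash_key` drops the recursion-depth suffix: each module-type successor's key is parenthesized and all successor keys are joined with `->`. A toy illustration of the resulting keys, assuming leaf operators simply carry their op names as hash keys:

```python
# Illustration only: a module of [Conv, ReLU] nested inside a larger module.
inner = "->".join(["Conv", "ReLU"])        # "Conv->ReLU"
outer = "->".join([f"({inner})", "Pool"])  # module keys get parenthesized
assert outer == "(Conv->ReLU)->Pool"
```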
| @@ -675,12 +687,11 @@ class HierarchicalTree(Tree): | |||||
| """ | """ | ||||
| # All args and value pair in current node module. | # All args and value pair in current node module. | ||||
| module_args = dict() | module_args = dict() | ||||
| module_settings = dict() | |||||
| module_key = self.hash_key(node) | module_key = self.hash_key(node) | ||||
| created = False | created = False | ||||
| if module_key not in self._vars_mgr_in_module: | if module_key not in self._vars_mgr_in_module: | ||||
| self._vars_mgr_in_module[module_key] = GLOBAL_VAR_NAME_MGR | |||||
| self._vars_mgr_in_module[module_key] = self.GLOBAL_VAR_NAME_MGR | |||||
| self._module_vars[module_key] = [] | self._module_vars[module_key] = [] | ||||
| else: | else: | ||||
| created = True | created = True | ||||
| @@ -688,33 +699,29 @@ class HierarchicalTree(Tree): | |||||
| # Sub-modules in the module could have arg name conflicts. | # Sub-modules in the module could have arg name conflicts. | ||||
| for idx, successor_name in enumerate(node.successors(self.tree_identifier)): | for idx, successor_name in enumerate(node.successors(self.tree_identifier)): | ||||
| nd_inst = self.get_node(successor_name) | nd_inst = self.get_node(successor_name) | ||||
| # Generate variable name here, then | |||||
| # to generate args. | |||||
| # Generation of params must behind variable assigment. | |||||
| if created: | if created: | ||||
| nd_inst.data.variable_name = self._module_vars[module_key][idx] | |||||
| variable_name = self._module_vars[module_key][idx] | |||||
| else: | else: | ||||
| variable_name = nd_inst.data.op_name or nd_inst.data.module_name | |||||
| variable_name = self._vars_mgr_in_module[module_key].get_name( | |||||
| variable_name) | |||||
| nd_inst.data.variable_name = variable_name | |||||
| variable_name = nd_inst.data.op_name or nd_inst.tag | |||||
| variable_name = self._vars_mgr_in_module[module_key].get_name(variable_name) | |||||
| # Generation of params must come after variable assignment. | |||||
| nd_inst.data.param_transform(mapper) | |||||
| code_fragment = nd_inst.data.param_transform(mapper, variable_name) | |||||
| code_fragment.declared_var_name = variable_name | |||||
| self.code_fragment_recorder[nd_inst.identifier] = code_fragment | |||||
| module_args.update(nd_inst.data.args_in_code) | module_args.update(nd_inst.data.args_in_code) | ||||
| module_settings.update(nd_inst.data.settings_in_code) | |||||
| if not created: | if not created: | ||||
| self._module_vars[module_key].append( | |||||
| nd_inst.data.variable_name) | |||||
| self._module_vars[module_key].append(variable_name) | |||||
| node.data.args_in_code = module_args | node.data.args_in_code = module_args | ||||
| # Collect module args of `module_key`. | # Collect module args of `module_key`. | ||||
| if module_key not in self._merged_module: | if module_key not in self._merged_module: | ||||
| self._merged_module[module_key] = [node.data.args_in_code] | |||||
| self._merged_module[module_key] = [deepcopy(node.data.args_in_code)] | |||||
| else: | else: | ||||
| self._merged_module[module_key].append(node.data.args_in_code) | |||||
| self._merged_module[module_key].append(deepcopy(node.data.args_in_code)) | |||||
| @staticmethod | @staticmethod | ||||
| def _create_operation_args(node, mapper): | def _create_operation_args(node, mapper): | ||||
| @@ -63,6 +63,10 @@ START_IDX = 0 | |||||
| class GlobalVarNameMgr: | class GlobalVarNameMgr: | ||||
| """Global variable name mgr.""" | """Global variable name mgr.""" | ||||
| def __init__(self): | |||||
| global_op_namespace.clear() | |||||
| global_var_namespace.clear() | |||||
| @staticmethod | @staticmethod | ||||
| def _get_name(name): | def _get_name(name): | ||||
| """Deal with op name.""" | """Deal with op name.""" | ||||
| @@ -87,7 +87,7 @@ class ONNXToMindSporeMapper(Mapper, abc.ABC): | |||||
| module_name = TABLE.get(op_name) | module_name = TABLE.get(op_name) | ||||
| if not module_name: | if not module_name: | ||||
| return None, dict(), dict() | |||||
| return None, dict(), None, dict() | |||||
| pos = module_name.rfind(".") | pos = module_name.rfind(".") | ||||
| try: | try: | ||||
| @@ -101,7 +101,7 @@ class ONNXToMindSporeMapper(Mapper, abc.ABC): | |||||
| # If mapper can not be found, then skip it. | # If mapper can not be found, then skip it. | ||||
| err_msg = f"Converting {op_name} failed, see {str(e)}" | err_msg = f"Converting {op_name} failed, see {str(e)}" | ||||
| log.error(err_msg) | log.error(err_msg) | ||||
| return None, dict(), dict() | |||||
| return None, dict(), None, dict() | |||||
| try: | try: | ||||
| converter_name = op_name_converter( | converter_name = op_name_converter( | ||||
| @@ -110,13 +110,13 @@ class ONNXToMindSporeMapper(Mapper, abc.ABC): | |||||
| converted_weights = weights_converter( | converted_weights = weights_converter( | ||||
| weights=weights) if weights else dict() | weights=weights) if weights else dict() | ||||
| converted_params.update(converted_weights) | converted_params.update(converted_weights) | ||||
| converted_settings = settings_converter(params=params) | |||||
| converted_settings = settings_converter(params=params, weights=weights) | |||||
| except (AttributeError, KeyError, ValueError, TypeError, IndexError) as e: | except (AttributeError, KeyError, ValueError, TypeError, IndexError) as e: | ||||
| err_msg = f"Converting {op_name} failed, see {str(e)}" | err_msg = f"Converting {op_name} failed, see {str(e)}" | ||||
| log.error(err_msg) | log.error(err_msg) | ||||
| return None, dict(), dict() | |||||
| return None, dict(), None, dict() | |||||
| return converter_name, converted_params, converted_settings | |||||
| return converter_name, converted_params, converted_settings, converted_weights | |||||
| @staticmethod | @staticmethod | ||||
| def _operation_name_in_ms(*args, **kwargs): | def _operation_name_in_ms(*args, **kwargs): | ||||
| @@ -0,0 +1,34 @@ | |||||
| # Copyright 2020 Huawei Technologies Co., Ltd. All Rights Reserved. | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================== | |||||
| """Operation mapping setting.""" | |||||
| from collections import namedtuple | |||||
| import numpy as np | |||||
| from mindinsight.mindconverter.graph_based_converter.constant import InputType | |||||
| Tensor = namedtuple("Tensor", ["shape", "dtype", "reference"]) | |||||
| Setting = namedtuple("Setting", ["opt_vars_suffix", | |||||
| "op_ipt_type", | |||||
| "op_extra_input", | |||||
| "op_extra_tensor"]) | |||||
| Setting.__new__.__defaults__ = ("_opt", InputType.TENSOR.value, dict(), None) | |||||
| def get_dtype(tensor: np.ndarray): | |||||
| """Get tensor dtype.""" | |||||
| if tensor.dtype == np.float16: | |||||
| return "mindspore.float16" | |||||
| return "mindspore.float32" | |||||
| @@ -14,6 +14,7 @@ | |||||
| # ============================================================================== | # ============================================================================== | ||||
| """Mapper module.""" | """Mapper module.""" | ||||
| from ...base import ONNXToMindSporeMapper | from ...base import ONNXToMindSporeMapper | ||||
| from ...gen_setting import Setting | |||||
| class BatchNormMapper(ONNXToMindSporeMapper): | class BatchNormMapper(ONNXToMindSporeMapper): | ||||
| @@ -39,4 +40,4 @@ class BatchNormMapper(ONNXToMindSporeMapper): | |||||
| @staticmethod | @staticmethod | ||||
| def _convert_settings(**kwargs): | def _convert_settings(**kwargs): | ||||
| return dict() | |||||
| return Setting() | |||||
| @@ -16,6 +16,7 @@ | |||||
| import re | import re | ||||
| import numpy as np | import numpy as np | ||||
| from ...base import ONNXToMindSporeMapper | from ...base import ONNXToMindSporeMapper | ||||
| from ...gen_setting import Setting | |||||
| def _convert_padding(**kwargs): | def _convert_padding(**kwargs): | ||||
| @@ -35,6 +36,7 @@ def _convert_padding(**kwargs): | |||||
| class ConvMapper(ONNXToMindSporeMapper): | class ConvMapper(ONNXToMindSporeMapper): | ||||
| """Conv2d mapper.""" | """Conv2d mapper.""" | ||||
| @staticmethod | @staticmethod | ||||
| def convert_params_torch(**kwargs): | def convert_params_torch(**kwargs): | ||||
| """Convert params from PyTorch to MindSpore""" | """Convert params from PyTorch to MindSpore""" | ||||
| @@ -148,4 +150,4 @@ class ConvMapper(ONNXToMindSporeMapper): | |||||
| @staticmethod | @staticmethod | ||||
| def _convert_settings(**kwargs): | def _convert_settings(**kwargs): | ||||
| return dict() | |||||
| return Setting() | |||||
| @@ -14,6 +14,7 @@ | |||||
| # ============================================================================== | # ============================================================================== | ||||
| """Mapper module.""" | """Mapper module.""" | ||||
| from ...base import ONNXToMindSporeMapper | from ...base import ONNXToMindSporeMapper | ||||
| from ...gen_setting import Setting | |||||
| class DenseMapper(ONNXToMindSporeMapper): | class DenseMapper(ONNXToMindSporeMapper): | ||||
| @@ -41,4 +42,4 @@ class DenseMapper(ONNXToMindSporeMapper): | |||||
| @staticmethod | @staticmethod | ||||
| def _convert_settings(**kwargs): | def _convert_settings(**kwargs): | ||||
| return dict() | |||||
| return Setting() | |||||
| @@ -14,6 +14,7 @@ | |||||
| # ============================================================================== | # ============================================================================== | ||||
| """Mapper module.""" | """Mapper module.""" | ||||
| from ...base import ONNXToMindSporeMapper | from ...base import ONNXToMindSporeMapper | ||||
| from ...gen_setting import Setting | |||||
| class FlattenMapper(ONNXToMindSporeMapper): | class FlattenMapper(ONNXToMindSporeMapper): | ||||
| @@ -33,4 +34,4 @@ class FlattenMapper(ONNXToMindSporeMapper): | |||||
| @staticmethod | @staticmethod | ||||
| def _convert_settings(**kwargs): | def _convert_settings(**kwargs): | ||||
| return dict() | |||||
| return Setting() | |||||
| @@ -14,6 +14,7 @@ | |||||
| # ============================================================================== | # ============================================================================== | ||||
| """Mapper module.""" | """Mapper module.""" | ||||
| from ...base import ONNXToMindSporeMapper | from ...base import ONNXToMindSporeMapper | ||||
| from ...gen_setting import Setting | |||||
| class GlobalPoolMapper(ONNXToMindSporeMapper): | class GlobalPoolMapper(ONNXToMindSporeMapper): | ||||
| @@ -25,8 +26,7 @@ class GlobalPoolMapper(ONNXToMindSporeMapper): | |||||
| op_name = 'nn.AvgPool{}d' | op_name = 'nn.AvgPool{}d' | ||||
| else: | else: | ||||
| op_name = 'nn.MaxPool{}d' | op_name = 'nn.MaxPool{}d' | ||||
| dim = 1 if len(kwargs['params']['input_shape']) == 3\ | |||||
| else 2 | |||||
| dim = 1 if len(kwargs['params']['input_shape']) == 3 else 2 | |||||
| return op_name.format(dim) | return op_name.format(dim) | ||||
| @staticmethod | @staticmethod | ||||
| @@ -49,4 +49,4 @@ class GlobalPoolMapper(ONNXToMindSporeMapper): | |||||
| @staticmethod | @staticmethod | ||||
| def _convert_settings(**kwargs): | def _convert_settings(**kwargs): | ||||
| return dict() | |||||
| return Setting() | |||||
| @@ -14,6 +14,7 @@ | |||||
| # ============================================================================== | # ============================================================================== | ||||
| """Mapper module.""" | """Mapper module.""" | ||||
| from ...base import ONNXToMindSporeMapper | from ...base import ONNXToMindSporeMapper | ||||
| from ...gen_setting import Setting, Tensor, get_dtype | |||||
| class MatMulMapper(ONNXToMindSporeMapper): | class MatMulMapper(ONNXToMindSporeMapper): | ||||
| @@ -33,4 +34,12 @@ class MatMulMapper(ONNXToMindSporeMapper): | |||||
| @staticmethod | @staticmethod | ||||
| def _convert_settings(**kwargs): | def _convert_settings(**kwargs): | ||||
| return dict() | |||||
| weights = kwargs.get("weights") | |||||
| if not weights: | |||||
| return Setting() | |||||
| tensor, ref = None, "" | |||||
| for t_name, t_value in weights.items(): | |||||
| tensor = t_value | |||||
| ref = t_name | |||||
| return Setting(op_extra_tensor=Tensor(shape=tensor.shape, | |||||
| dtype=get_dtype(tensor), reference=ref)) | |||||
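For reference, a sketch of what this `_convert_settings` yields for a node carrying a single external weight; the `matmul_0.weight` name is hypothetical:

```python
import numpy as np

weights = {"matmul_0.weight": np.ones((128, 10), dtype=np.float32)}
setting = MatMulMapper._convert_settings(weights=weights)
assert setting.op_extra_tensor.shape == (128, 10)
assert setting.op_extra_tensor.reference == "matmul_0.weight"
```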
| @@ -14,6 +14,7 @@ | |||||
| # ============================================================================== | # ============================================================================== | ||||
| """Mapper module.""" | """Mapper module.""" | ||||
| from ...base import ONNXToMindSporeMapper | from ...base import ONNXToMindSporeMapper | ||||
| from ...gen_setting import Setting | |||||
| def _padding_format_convert(padding: list): | def _padding_format_convert(padding: list): | ||||
| @@ -77,4 +78,4 @@ class PadMapper(ONNXToMindSporeMapper): | |||||
| @staticmethod | @staticmethod | ||||
| def _convert_settings(**kwargs): | def _convert_settings(**kwargs): | ||||
| return dict() | |||||
| return Setting() | |||||
| @@ -14,6 +14,7 @@ | |||||
| # ============================================================================== | # ============================================================================== | ||||
| """Mapper module.""" | """Mapper module.""" | ||||
| from ...base import ONNXToMindSporeMapper | from ...base import ONNXToMindSporeMapper | ||||
| from ...gen_setting import Setting | |||||
| class PoolMapper(ONNXToMindSporeMapper): | class PoolMapper(ONNXToMindSporeMapper): | ||||
| @@ -49,4 +50,4 @@ class PoolMapper(ONNXToMindSporeMapper): | |||||
| @staticmethod | @staticmethod | ||||
| def _convert_settings(**kwargs): | def _convert_settings(**kwargs): | ||||
| return dict() | |||||
| return Setting() | |||||
| @@ -14,6 +14,7 @@ | |||||
| # ============================================================================== | # ============================================================================== | ||||
| """Mapper module.""" | """Mapper module.""" | ||||
| from ...base import ONNXToMindSporeMapper | from ...base import ONNXToMindSporeMapper | ||||
| from ...gen_setting import Setting | |||||
| class ReLUMapper(ONNXToMindSporeMapper): | class ReLUMapper(ONNXToMindSporeMapper): | ||||
| @@ -45,4 +46,4 @@ class ReLUMapper(ONNXToMindSporeMapper): | |||||
| @staticmethod | @staticmethod | ||||
| def _convert_settings(**kwargs): | def _convert_settings(**kwargs): | ||||
| return dict() | |||||
| return Setting() | |||||
| @@ -14,6 +14,7 @@ | |||||
| # ============================================================================== | # ============================================================================== | ||||
| """Mapper module.""" | """Mapper module.""" | ||||
| from ...base import ONNXToMindSporeMapper | from ...base import ONNXToMindSporeMapper | ||||
| from ...gen_setting import Setting | |||||
| class SoftmaxMapper(ONNXToMindSporeMapper): | class SoftmaxMapper(ONNXToMindSporeMapper): | ||||
| @@ -37,4 +38,4 @@ class SoftmaxMapper(ONNXToMindSporeMapper): | |||||
| @staticmethod | @staticmethod | ||||
| def _convert_settings(**kwargs): | def _convert_settings(**kwargs): | ||||
| return dict() | |||||
| return Setting() | |||||
| @@ -14,6 +14,7 @@ | |||||
| # ============================================================================== | # ============================================================================== | ||||
| """Mapper module.""" | """Mapper module.""" | ||||
| from ...base import ONNXToMindSporeMapper | from ...base import ONNXToMindSporeMapper | ||||
| from ...gen_setting import Setting, Tensor, get_dtype | |||||
| class AddMapper(ONNXToMindSporeMapper): | class AddMapper(ONNXToMindSporeMapper): | ||||
| @@ -33,4 +34,12 @@ class AddMapper(ONNXToMindSporeMapper): | |||||
| @staticmethod | @staticmethod | ||||
| def _convert_settings(**kwargs): | def _convert_settings(**kwargs): | ||||
| return dict() | |||||
| weights = kwargs.get("weights") | |||||
| if not weights: | |||||
| return Setting() | |||||
| tensor, ref = None, "" | |||||
| for t_name, t_value in weights.items(): | |||||
| tensor = t_value | |||||
| ref = t_name | |||||
| return Setting(op_extra_tensor=Tensor(shape=tensor.shape, | |||||
| dtype=get_dtype(tensor), reference=ref)) | |||||
| @@ -15,6 +15,7 @@ | |||||
| """Mapper module.""" | """Mapper module.""" | ||||
| from mindinsight.mindconverter.graph_based_converter.constant import InputType | from mindinsight.mindconverter.graph_based_converter.constant import InputType | ||||
| from ...base import ONNXToMindSporeMapper | from ...base import ONNXToMindSporeMapper | ||||
| from ...gen_setting import Setting | |||||
| class ConcatMapper(ONNXToMindSporeMapper): | class ConcatMapper(ONNXToMindSporeMapper): | ||||
| @@ -36,4 +37,4 @@ class ConcatMapper(ONNXToMindSporeMapper): | |||||
| @staticmethod | @staticmethod | ||||
| def _convert_settings(**kwargs): | def _convert_settings(**kwargs): | ||||
| input_type = InputType.LIST.value | input_type = InputType.LIST.value | ||||
| return {'input_type': input_type} | |||||
| return Setting(op_ipt_type=input_type) | |||||
| @@ -14,6 +14,7 @@ | |||||
| # ============================================================================== | # ============================================================================== | ||||
| """Mapper module.""" | """Mapper module.""" | ||||
| from ...base import ONNXToMindSporeMapper | from ...base import ONNXToMindSporeMapper | ||||
| from ...gen_setting import Setting | |||||
| class ReduceMeanMapper(ONNXToMindSporeMapper): | class ReduceMeanMapper(ONNXToMindSporeMapper): | ||||
| @@ -40,4 +41,4 @@ class ReduceMeanMapper(ONNXToMindSporeMapper): | |||||
| axis = params['axes'][0] if len(params['axes']) == 1 else tuple(params['axes']) | axis = params['axes'][0] if len(params['axes']) == 1 else tuple(params['axes']) | ||||
| else: | else: | ||||
| axis = tuple() | axis = tuple() | ||||
| return {'values': {'axis': axis}} | |||||
| return Setting(op_extra_input={'axis': axis}) | |||||
| @@ -14,6 +14,7 @@ | |||||
| # ============================================================================== | # ============================================================================== | ||||
| """Mapper module.""" | """Mapper module.""" | ||||
| from ...base import ONNXToMindSporeMapper | from ...base import ONNXToMindSporeMapper | ||||
| from ...gen_setting import Setting | |||||
| class TransposeMapper(ONNXToMindSporeMapper): | class TransposeMapper(ONNXToMindSporeMapper): | ||||
| @@ -40,4 +41,4 @@ class TransposeMapper(ONNXToMindSporeMapper): | |||||
| perm = tuple(perm) | perm = tuple(perm) | ||||
| converted_params['input_perm'] = perm | converted_params['input_perm'] = perm | ||||
| return {'values': converted_params} | |||||
| return Setting(op_extra_input=converted_params) | |||||
| @@ -15,10 +15,13 @@ | |||||
| """Define graph entity.""" | """Define graph entity.""" | ||||
| import abc | import abc | ||||
| from collections import OrderedDict | from collections import OrderedDict | ||||
| from copy import deepcopy | |||||
| from mindinsight.mindconverter.common.log import logger as log | from mindinsight.mindconverter.common.log import logger as log | ||||
| from ..constant import SEPARATOR_IN_ONNX_OP | |||||
| from ..common.code_fragment import CodeFragment | |||||
| from ..constant import NodeType, InputType | |||||
| from ..mapper.base import Mapper | from ..mapper.base import Mapper | ||||
| from ...common.exceptions import NodeInputTypeNotSupport | |||||
| class GraphParser(metaclass=abc.ABCMeta): | class GraphParser(metaclass=abc.ABCMeta): | ||||
| @@ -287,26 +290,10 @@ class GraphNode(abc.ABC): | |||||
| self._op_params = dict() | self._op_params = dict() | ||||
| self._scope_name = None | self._scope_name = None | ||||
| self._op_shape = None | self._op_shape = None | ||||
| # Operation in mindspore. | |||||
| self._op_in_ms = None | |||||
| # Params in mindspore. | |||||
| self._params_in_ms = dict() | |||||
| # Settings in mindspore. | |||||
| self._settings_in_ms = dict() | |||||
| # Node type of current node, e.g. class, module, operation. | # Node type of current node, e.g. class, module, operation. | ||||
| self._node_type = None | self._node_type = None | ||||
| # Tag name on tree. | |||||
| self._tag_on_tree = None | |||||
| # Function, class or operation needed args. | # Function, class or operation needed args. | ||||
| self._args_in_code = dict() | self._args_in_code = dict() | ||||
| # Operation needed settings. | |||||
| self._settings_in_code = dict() | |||||
| # Variable name declared in init block. | |||||
| self._variable_name = None | |||||
| # Output variable name declared in construct block. | |||||
| self._opt_var_name = None | |||||
| # Function or class name in code. | |||||
| self._module_name = None | |||||
| # Unique key of node. | # Unique key of node. | ||||
| self._hash_key = None | self._hash_key = None | ||||
| # Input shape of current op. | # Input shape of current op. | ||||
| @@ -317,37 +304,18 @@ class GraphNode(abc.ABC): | |||||
| self._weight = None | self._weight = None | ||||
| @property | @property | ||||
| def opt_var_name(self): | |||||
| def weight(self): | |||||
| return self._weight | |||||
| @staticmethod | |||||
| def get_opt_var_name(variable_name): | |||||
| """ | """ | ||||
| Output variable name. | Output variable name. | ||||
| Returns: | Returns: | ||||
| str, variable name. | str, variable name. | ||||
| """ | """ | ||||
| return f"{self.variable_name}_opt" | |||||
| @opt_var_name.setter | |||||
| def opt_var_name(self, v): | |||||
| """ | |||||
| Set variable name. | |||||
| Args: | |||||
| v (str): Name. | |||||
| """ | |||||
| self._opt_var_name = v | |||||
| @property | |||||
| def op_in_ms(self): | |||||
| """ | |||||
| Operation in mindspore. | |||||
| Returns: | |||||
| str, operation name. | |||||
| """ | |||||
| if self._op_in_ms and SEPARATOR_IN_ONNX_OP in self._op_in_ms: | |||||
| return self._op_in_ms.replace(SEPARATOR_IN_ONNX_OP, ".") | |||||
| return self._op_in_ms | |||||
| return f"{variable_name}_opt" | |||||
| @property | @property | ||||
| def args_in_code(self): | def args_in_code(self): | ||||
| @@ -370,27 +338,6 @@ class GraphNode(abc.ABC): | |||||
| """ | """ | ||||
| self._args_in_code = args | self._args_in_code = args | ||||
| @property | |||||
| def settings_in_code(self): | |||||
| """ | |||||
| Settings in code. | |||||
| Returns: | |||||
| dict, settings. | |||||
| """ | |||||
| return self._settings_in_code | |||||
| @settings_in_code.setter | |||||
| def settings_in_code(self, settings): | |||||
| """ | |||||
| Settings in code. | |||||
| Args: | |||||
| settings(dict): Settings. | |||||
| """ | |||||
| self._settings_in_code = settings | |||||
| @property | @property | ||||
| def input_shape(self): | def input_shape(self): | ||||
| """ | """ | ||||
| @@ -411,16 +358,6 @@ class GraphNode(abc.ABC): | |||||
| """ | """ | ||||
| return self._opt_shape | return self._opt_shape | ||||
| @property | |||||
| def tag(self): | |||||
| """Tag on hierarchical tree.""" | |||||
| return self._tag_on_tree | |||||
| @tag.setter | |||||
| def tag(self, t): | |||||
| """Tag on hierarchical tree.""" | |||||
| self._tag_on_tree = t | |||||
| def is_empty(self): | def is_empty(self): | ||||
| """ | """ | ||||
| Whether is empty. | Whether is empty. | ||||
| @@ -536,7 +473,7 @@ class GraphNode(abc.ABC): | |||||
| """Replace actual parameter with formal parameter.""" | """Replace actual parameter with formal parameter.""" | ||||
| @abc.abstractmethod | @abc.abstractmethod | ||||
| def _get_arg_name(self, arg): | |||||
| def _get_arg_name(self, arg, variable_name): | |||||
| """Get arg name for func or class.""" | """Get arg name for func or class.""" | ||||
| @abc.abstractmethod | @abc.abstractmethod | ||||
| @@ -553,13 +490,8 @@ class GraphNode(abc.ABC): | |||||
| def real_name(self, **kwargs): | def real_name(self, **kwargs): | ||||
| """Setter of `real_name`.""" | """Setter of `real_name`.""" | ||||
| @property | |||||
| @abc.abstractmethod | |||||
| def variable_name(self): | |||||
| """Getter of `variable_name`.""" | |||||
| @abc.abstractmethod | @abc.abstractmethod | ||||
| def to_code(self, ipt_args_in_construct: str, output_var: str): | |||||
| def to_code(self, ipt_args_in_construct: str, variable_name: str, output_var: str, code_fragment): | |||||
| """Graph node to MindSpore code.""" | """Graph node to MindSpore code.""" | ||||
| @abc.abstractmethod | @abc.abstractmethod | ||||
| @@ -570,40 +502,86 @@ class GraphNode(abc.ABC): | |||||
| def add_input_and_output_shape(self, input_shape, output_shape): | def add_input_and_output_shape(self, input_shape, output_shape): | ||||
| """Add the node input shape.""" | """Add the node input shape.""" | ||||
| @abc.abstractmethod | |||||
| def froze_node_type_and_module_name(self, node_type, module_name): | |||||
| """Make node_type can not be changed.""" | |||||
| @staticmethod | |||||
| def _generate_ipt_args_settings_in_construct(ipt_args_in_construct, settings): | |||||
| """ | |||||
| Generate the input argument string, with settings applied, for the construct statement. | |||||
| @abc.abstractmethod | |||||
| def convert_successful(self): | |||||
| """Whether convert successful.""" | |||||
| Args: | |||||
| ipt_args_in_construct (str): Input args in construct. | |||||
| settings (Setting): Settings in operator. | |||||
| Returns: | |||||
| str, args of each node in generated construct statement. | |||||
| """ | |||||
| if settings and settings.op_ipt_type: | |||||
| input_type = settings.op_ipt_type | |||||
| if input_type == InputType.TENSOR.value: | |||||
| ipt_args_settings_in_construct = ipt_args_in_construct | |||||
| elif input_type == InputType.LIST.value: | |||||
| ipt_args_settings_in_construct = f"({ipt_args_in_construct})" | |||||
| else: | |||||
| raise NodeInputTypeNotSupport(f"Input type[{input_type}] is not supported now.") | |||||
| else: | |||||
| ipt_args_settings_in_construct = ipt_args_in_construct | |||||
| if settings and settings.op_extra_input: | |||||
| settings_value = settings.op_extra_input | |||||
| if settings_value: | |||||
| settings_in_construct = ', '.join([f"{setting_val}" for _, setting_val in settings_value.items()]) | |||||
| ipt_args_settings_in_construct = ', '.join((ipt_args_settings_in_construct, settings_in_construct)) | |||||
| def param_transform(self, mapper: Mapper): | |||||
| return ipt_args_settings_in_construct | |||||
| def param_transform(self, mapper: Mapper, variable_name): | |||||
| """ | """ | ||||
| Transform param in pytorch operation into mindspore. | |||||
| Transform param in PyTorch operation into MindSpore. | |||||
| Args: | Args: | ||||
| variable_name (str): Variable name. | |||||
| mapper (ONNXToMindSporeMapper): Mapper between onnx operation | mapper (ONNXToMindSporeMapper): Mapper between onnx operation | ||||
| and mindspore. | |||||
| and MindSpore. | |||||
| Returns: | Returns: | ||||
| dict, transformed params. | |||||
| CodeFragment, transformed code fragment with params and settings filled in. | |||||
| """ | """ | ||||
| import copy | |||||
| params = copy.deepcopy(self._op_params) | |||||
| if self._node_type != NodeType.OPERATION.value: | |||||
| args = deepcopy(self._args_in_code) | |||||
| self._args_in_code = dict() | |||||
| for arg, value in args.items(): | |||||
| self._args_in_code[self._get_arg_name(arg, variable_name)] = value | |||||
| return CodeFragment(operation="", actual_args=args, settings=None, | |||||
| input_shape=self.input_shape, output_shape=self.output_shape) | |||||
| if self.transformed: | |||||
| raise ValueError("Already transformed.") | |||||
| params = deepcopy(self._op_params) | |||||
| params.update({"input_shape": self.input_shape, | params.update({"input_shape": self.input_shape, | ||||
| "output_shape": self.output_shape}) | "output_shape": self.output_shape}) | ||||
| op_name_in_mindspore, ms_params, ms_settings = mapper.convert(op_name=self.op_name, | |||||
| params=params, | |||||
| weights=self._weight) | |||||
| if op_name_in_mindspore: | |||||
| self._op_in_ms = op_name_in_mindspore | |||||
| self._params_in_ms = ms_params | |||||
| self._settings_in_ms = ms_settings | |||||
| ms_op, ms_params, ms_settings, ms_weights = mapper.convert(op_name=self.op_name, | |||||
| params=params, | |||||
| weights=self._weight) | |||||
| if ms_op: | |||||
| code_fragment = CodeFragment(operation=ms_op, | |||||
| actual_args=ms_params, | |||||
| settings=ms_settings, | |||||
| input_shape=self.input_shape, | |||||
| output_shape=self.output_shape, | |||||
| trainable_params=ms_weights) | |||||
| else: | else: | ||||
| self._op_in_ms = self._op_name | |||||
| self._params_in_ms = self._op_params | |||||
| self._settings_in_ms = dict() | |||||
| code_fragment = CodeFragment(operation=self._op_name, | |||||
| actual_args=self._op_params, | |||||
| settings=None, | |||||
| input_shape=self.input_shape, | |||||
| output_shape=self.output_shape, | |||||
| trainable_params=self._weight) | |||||
| for arg, value in code_fragment.actual_args.items(): | |||||
| self._args_in_code[self._get_arg_name(arg, variable_name)] = value | |||||
| self.transformed = True | |||||
| return self._op_in_ms, self._params_in_ms, self._settings_in_ms | |||||
| return code_fragment | |||||
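To make the new construct-argument assembly concrete, a sketch of what `_generate_ipt_args_settings_in_construct` yields for each supported setting; the variable names are illustrative:

```python
from mindinsight.mindconverter.graph_based_converter.constant import InputType

# Tensor input (default): inputs pass through unchanged.
assert GraphNode._generate_ipt_args_settings_in_construct(
    "x, y", Setting()) == "x, y"
# List input (e.g. Concat): inputs are wrapped into a tuple literal.
assert GraphNode._generate_ipt_args_settings_in_construct(
    "x, y", Setting(op_ipt_type=InputType.LIST.value)) == "(x, y)"
# Extra inputs (e.g. ReduceMean's axis) are appended after the inputs.
assert GraphNode._generate_ipt_args_settings_in_construct(
    "x", Setting(op_extra_input={"axis": (2, 3)})) == "x, (2, 3)"
```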
| @@ -38,7 +38,6 @@ class PyTorchGraphParser(GraphParser): | |||||
| error = FileNotFoundError("`model_path` must be assigned with " | error = FileNotFoundError("`model_path` must be assigned with " | ||||
| "an existed file path.") | "an existed file path.") | ||||
| log.error(str(error)) | log.error(str(error)) | ||||
| log.exception(error) | |||||
| raise error | raise error | ||||
| try: | try: | ||||
@@ -21,24 +21,18 @@ from ..constant import SEPARATOR_IN_SCOPE, NodeType

 class InputNode(GraphNode):
     """
-    Pytorch Input Node.
+    PyTorch Input Node.

     Args:
         input_shape: Input shape of module.
     """

-    def convert_successful(self):
-        """
-        Whether convert successful.
-
-        Returns:
-            bool, true or false.
-        """
-        return False
+    def _get_arg_name(self, arg, variable_name):
+        raise NotImplementedError()

-    def froze_node_type_and_module_name(self, node_type, module_name):
-        pass
+    def to_code(self, ipt_args_in_construct: str, variable_name: str, output_var: str, code_fragment):
+        raise NotImplementedError()

     def _get_raw_params(self, node):
         pass

@@ -56,9 +50,6 @@ class InputNode(GraphNode):
     def replace_with_arg(self, src_arg, tgt_arg):
         pass

-    def _get_arg_name(self, arg):
-        pass
-
     def add_input_and_output_shape(self, input_shape, output_shape):
         pass

@@ -116,15 +107,8 @@ class InputNode(GraphNode):
     def real_name(self):
         return

-    @property
-    def variable_name(self):
-        return
-
     def to_ir(self):
         """
         No need to implement for now.
         """
         raise NotImplementedError()
-
-    def to_code(self, ipt_args_in_construct: str, output_var: str):
-        raise NotImplementedError()
@@ -22,7 +22,6 @@ from .onnx_graph_node import OnnxGraphNode
 from .graph_parser import TFGraphParser
 from .onnx_utils import OnnxDataLoader

 NONE_SCOPE_OP = {
     "onnx::Add": "Add",
     "onnx::Flatten": "Flatten",
@@ -13,14 +13,13 @@
 # limitations under the License.
 # ==============================================================================
 """Define ONNX graph node."""
-from copy import deepcopy
+from importlib import import_module

 from .base import GraphNode
-from ..constant import NodeType, SEPARATOR_IN_SCOPE, \
-    SEPARATOR_BTW_NAME_AND_ID, LEFT_BUCKET, RIGHT_BUCKET, SEPARATOR_IN_ONNX_OP, InputType
-from ..mapper.base import Mapper
-from ...common.exceptions import NodeInputTypeNotSupport
+from ..common.utils import is_converted
+from ..constant import NodeType, SEPARATOR_IN_SCOPE, SEPARATOR_BTW_NAME_AND_ID, LEFT_BUCKET, RIGHT_BUCKET, \
+    SEPARATOR_IN_ONNX_OP


 class OnnxGraphNode(GraphNode):
@@ -39,16 +38,13 @@ class OnnxGraphNode(GraphNode):
         self._op_params = self._get_raw_params(node.raw_node) if node else None
         self._op_name = "onnx::" + node.op_type if node else None
         self._scope_name = node.scope_name if node else None
-        self._opt_var_name = None
-        self._variable_name = self._extract_var_name(self._scope_name)
-        self._module_name = None
         self._weight = weight

     def clear_args_of_declaration(self):
         """Clear `self._args_in_code`."""
         self._args_in_code = dict()
-    def _get_arg_name(self, arg):
+    def _get_arg_name(self, arg, variable_name):
         """
         Get arg name.

@@ -58,7 +54,7 @@ class OnnxGraphNode(GraphNode):
         Returns:
             str, arg name in function or class declaration.
         """
-        return f"{arg}_{self._variable_name}"
+        return f"{arg}_{variable_name}"

     @property
     def hash_key(self):
@@ -84,51 +80,6 @@ class OnnxGraphNode(GraphNode):
         """
         self._hash_key = h

-    @property
-    def variable_name(self):
-        """
-        Variable name.
-
-        Returns:
-            str, variable name declared in init.
-        """
-        return self._variable_name
-
-    @variable_name.setter
-    def variable_name(self, v):
-        """
-        Setter of variable name.
-
-        Args:
-            v (str): Variable name.
-        """
-        self._variable_name = v
-
-    @property
-    def module_name(self):
-        """
-        Module name.
-
-        Returns:
-            str, module name.
-        """
-        if not self._module_name_frozen:
-            module_name = self.tag
-            return module_name
-        return self._module_name
-
-    def _froze_module_name(self, m):
-        """
-        Once module_name is set, then it's unchangeable.
-
-        Args:
-            m (str): Module name.
-        """
-        if not self._module_name_frozen:
-            self._module_name = m
-            self._module_name_frozen = True
-
     @property
     def op_name(self):
         """
@@ -154,15 +105,13 @@ class OnnxGraphNode(GraphNode):
         self._ipt_shape = input_shape
         self._opt_shape = output_shape

-    def _add_tensor_args_to_code(self, op_name: str, t_identifier: str, declare, args):
+    def _add_tensor_args_to_code(self, op_name: str, settings, declare, args, variable_name):
         """
         Add nn used tensors to args in init and construct blocks.

         Args:
             op_name (str): Add the tensor to args if the current node has this
                 op_name.
-            t_identifier (str): The unique string appeared in the target tensor
-                name.
             declare (str): Declare statement generated in to_code().
             args (str): Args statement generated in to_code().

@@ -172,103 +121,68 @@ class OnnxGraphNode(GraphNode):
         """
         if not self._op_name == op_name:
             return declare, args
-        declare_list = []
-        tensor = None
-        # find target tensor
-        for t_name, t_value in self._weight.items():
-            if t_identifier in t_name:
-                tensor = t_value
-                break
-        if tensor is None:
+        if not settings or not settings.op_extra_tensor:
             return declare, args
-        declare_list.append(declare)
-        declare_t = f"self.{self._variable_name}_w = Tensor(" \
-                    f"np.random.uniform(0, 1, {str(tensor.shape)}), mindspore.float32)"
+        declare_list = [declare]
+        declare_t = f"self.{variable_name}_w = Tensor(" \
+                    f"np.random.uniform(0, 1, {str(settings.op_extra_tensor.shape)}), " \
+                    f"{settings.op_extra_tensor.dtype})"
         declare_list.append(declare_t)
-        args += f", self.{self._variable_name}_w"
+        args += f", self.{variable_name}_w"

         return declare_list, args
-    def to_code(self, ipt_args_in_construct: str, output_var: str):
+    def to_code(self, ipt_args_in_construct: str, variable_name: str, output_var: str,
+                code_fragment):
         """
         Generate statements.

         Args:
+            variable_name (str): Variable name.
             ipt_args_in_construct (str): Args of input.
             output_var (str): Output variable name in construct.
+            code_fragment (CodeFragment): CodeFragment instance.

         Returns:
             Union[str, str], declare in init and call in construct.
         """
-        operator = self.op_in_ms or self.module_name
-        self._opt_var_name = output_var
+        operator = code_fragment.operation

         args = self.args_in_code
-        settings = self.settings_in_code
+        settings = code_fragment.code_setting

-        if self._node_type == NodeType.OPERATION.value and not self.convert_successful():
+        if self._node_type == NodeType.OPERATION.value and not is_converted(code_fragment.operation):
             args.update({"input_shape": self.input_shape,
                          "output_shape": self.output_shape})

         if self._node_type == NodeType.OPERATION.value:
-            expr = ", ".join([f"{k.replace(f'_{self._variable_name}', '')}={v}"
+            expr = ", ".join([f"{k.replace(f'_{variable_name}', '')}={v}"
                               for k, v in args.items()])
-            ipt_args_settings_in_construct = \
-                self._generate_ipt_args_settings_in_construct(
-                    ipt_args_in_construct,
-                    settings)
+            ipt_args_settings_in_construct = self._generate_ipt_args_settings_in_construct(
+                ipt_args_in_construct, settings)
         else:
             # When its type is module, class or func,
             # it's not necessary to replace var.
-            expr = ", ".join([f"{k.replace(f'_{self._variable_name}', '')}={v}"
+            expr = ", ".join([f"{k.replace(f'_{variable_name}', '')}={v}"
                               for k, v in args.items()])
             ipt_args_settings_in_construct = ipt_args_in_construct

-        declare = f"self.{self._variable_name} = {operator}({expr})"
+        if SEPARATOR_IN_ONNX_OP in operator:
+            operator = operator.replace(SEPARATOR_IN_ONNX_OP, ".")
+
+        declare = f"self.{variable_name} = {operator}({expr})"

         # Extra Tensor generator for nn.MatMul
         declare, ipt_args_settings_in_construct = self._add_tensor_args_to_code(
-            'onnx::MatMul', 'MatMul', declare, ipt_args_settings_in_construct)
+            'onnx::MatMul', settings, declare, ipt_args_settings_in_construct, variable_name)

         # Extra Tensor generator for onnx::Add
         declare, ipt_args_settings_in_construct = self._add_tensor_args_to_code(
-            'onnx::Add', 'BiasAdd', declare, ipt_args_settings_in_construct)
+            'onnx::Add', settings, declare, ipt_args_settings_in_construct, variable_name)

-        call = f"{self._opt_var_name} = self.{self._variable_name}({ipt_args_settings_in_construct})"
+        call = f"{output_var} = self.{variable_name}({ipt_args_settings_in_construct})"

         return declare, call
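Concretely, the rewritten to_code boils down to two generated strings per node. A sketch of what a converted onnx::Conv node would produce, assuming the mapper yielded nn.Conv2d and `fragment` came from param_transform as above (the argument values shown are illustrative only):

    def emit_statements(node, fragment):
        """Sketch: produce the init declaration and construct call for a node."""
        declare, call = node.to_code(ipt_args_in_construct="x",
                                     variable_name="conv2d_0",
                                     output_var="opt_conv2d_0",
                                     code_fragment=fragment)
        # declare -> "self.conv2d_0 = nn.Conv2d(in_channels=3, ...)"
        # call    -> "opt_conv2d_0 = self.conv2d_0(x)"
        return declare, call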
-    @staticmethod
-    def _generate_ipt_args_settings_in_construct(ipt_args_in_construct, settings):
-        """
-        Generate input with args and settings in construct.
-
-        Args:
-            ipt_args_in_construct(str): Input args in construct.
-            settings(dict): Settings in operator.
-
-        Returns:
-            str, args of each node in generated construct statement.
-        """
-        if settings.get('input_type'):
-            input_type = settings['input_type']
-            if input_type == InputType.TENSOR.value:
-                ipt_args_settings_in_construct = ipt_args_in_construct
-            elif input_type == InputType.LIST.value:
-                ipt_args_settings_in_construct = f"({ipt_args_in_construct})"
-            else:
-                raise NodeInputTypeNotSupport(
-                    f"Input type[{input_type}] is not supported now.")
-        else:
-            ipt_args_settings_in_construct = ipt_args_in_construct
-
-        if settings.get('values'):
-            settings_value = settings['values']
-            if settings_value:
-                settings_in_construct = ', '.join(
-                    [f"{setting_val}" for _, setting_val in settings_value.items()])
-                ipt_args_settings_in_construct = ', '.join(
-                    (ipt_args_settings_in_construct, settings_in_construct))
-
-        return ipt_args_settings_in_construct

     def to_ir(self):
         """No need to implement for now."""
         raise NotImplementedError
@@ -284,7 +198,7 @@ class OnnxGraphNode(GraphNode):
         Returns:
             dict, raw params.
         """
-        import onnx
+        onnx = import_module("onnx")

         raw_params = dict()
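The switch to `import_module("onnx")` defers the heavyweight dependency from module import time to first use. A minimal sketch of the same pattern with an explicit error message; the helper name and message here are hypothetical, not from the patch:

    from importlib import import_module

    def load_onnx():
        # Hypothetical helper: resolve onnx lazily so the converter module
        # still imports when the package is absent.
        try:
            return import_module("onnx")
        except ImportError as err:
            raise ImportError("onnx is required to read raw ONNX node params") from err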
@@ -318,62 +232,3 @@ class OnnxGraphNode(GraphNode):
         var = var.replace(LEFT_BUCKET, SEPARATOR_BTW_NAME_AND_ID).replace(
             RIGHT_BUCKET, "")
         return var
-    def param_transform(self, mapper: Mapper):
-        """
-        Transform tensorflow params into mindspore.
-
-        Args:
-            mapper (Mapper): Mapper of params.
-        """
-        if self._node_type != NodeType.OPERATION.value:
-            args = deepcopy(self._args_in_code)
-            self._args_in_code = dict()
-            for arg, value in args.items():
-                self._args_in_code[self._get_arg_name(arg)] = value
-            return None, None
-
-        if not self.transformed:
-            _, _, _ = super(OnnxGraphNode, self).param_transform(mapper)
-            for arg, value in self._params_in_ms.items():
-                self._args_in_code[self._get_arg_name(arg)] = value
-            for arg, value in self._settings_in_ms.items():
-                self._settings_in_code[arg] = value
-            self.transformed = True
-        return self._op_in_ms, self._params_in_ms, self._settings_in_ms
-
-    def froze_node_type_and_module_name(self, node_type, module_name):
-        """
-        Froze node type and module name.
-
-        After node_type is frozen, then the `module_name`
-        will be affected when `node_type` is `class`.
-        Thus, this line must be placed before `nd_inst.data.module_name`.
-
-        Args:
-            module_name: Modified module name.
-            node_type (str): Node type, class of func.
-        """
-        if not self._type_frozen:
-            self._node_type = node_type
-            self._type_frozen = True
-
-        if not self._module_name_frozen:
-            self._froze_module_name(module_name)
-
-    def convert_successful(self):
-        """
-        Whether convert successfully.
-
-        Returns:
-            bool, true or false.
-        """
-        if self._op_in_ms and SEPARATOR_IN_ONNX_OP not in self._op_in_ms:
-            return True
-        return False
@@ -87,7 +87,8 @@ def convert_tf_graph_to_onnx(model_path, model_inputs, model_outputs, opset=None
         inputs_as_nchw=None
     )
     opt_map = getattr(optimizer.back_to_back_optimizer, '_func_map')
-    opt_map.pop(('Conv', 'BatchNormalization'))
+    if ('Conv', 'BatchNormalization') in opt_map:
+        opt_map.pop(('Conv', 'BatchNormalization'))

     onnx_graph = optimizer.optimize_graph(g)
     model_proto = onnx_graph.make_model("converted from {}".format(model_path))
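The membership check above keeps the converter from raising KeyError on tf2onnx versions that no longer register that optimizer pair. dict.pop with a default achieves the same effect in one line, shown here purely as an alternative:

    # Never raises KeyError, whether or not the pair is registered.
    opt_map.pop(('Conv', 'BatchNormalization'), None)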
@@ -228,8 +229,7 @@ class OnnxNode(BaseNode):
     """

     def __init__(self, raw_node):
-        super(OnnxNode, self).__init__(
-            node_name=raw_node.name, op_type=raw_node.op_type)
+        super(OnnxNode, self).__init__(node_name=raw_node.name, op_type=raw_node.op_type)
         self.raw_node = raw_node
         self.params = ParamsAttribute(raw_node.attribute, raw_node)
         self.scope_name = None
@@ -99,8 +99,8 @@ class PyTorchGraph(Graph):
         for item in input_shape:
             if not isinstance(item, int):
-                err_msg = f"Only support model with one input now, " \
-                          f"and each shape value in `input_shape` should be int."
+                err_msg = "Only support model with one input now, " \
+                          "and each shape value in `input_shape` should be int."
                 log.error(err_msg)
                 raise ValueError(err_msg)
@@ -13,14 +13,11 @@
 # limitations under the License.
 # ==============================================================================
 """Define PyTorch graph node."""
-from copy import deepcopy
-
 from .base import GraphNode
+from ..common.utils import is_converted
 from ..constant import NodeType, SEPARATOR_IN_SCOPE, SEPARATOR_BTW_NAME_AND_ID, LEFT_BUCKET, RIGHT_BUCKET, \
-    SEPARATOR_IN_ONNX_OP, InputType
-from ..mapper.base import Mapper
-from ...common.exceptions import NodeInputTypeNotSupport
+    SEPARATOR_IN_ONNX_OP
 class PyTorchGraphNode(GraphNode):
@@ -40,9 +37,6 @@ class PyTorchGraphNode(GraphNode):
         self._op_params = self._get_raw_params(node)
         self._op_name = node.kind() if node else None
         self._scope_name = node.scopeName() if node else None
-        self._opt_var_name = None
-        self._variable_name = self._extract_var_name(self._scope_name)
-        self._module_name = None
         self._weight = weight
     def clear_args_of_declaration(self):
@@ -51,7 +45,7 @@ class PyTorchGraphNode(GraphNode):
         """
         self._args_in_code = dict()

-    def _get_arg_name(self, arg):
+    def _get_arg_name(self, arg, variable_name):
         """
         Get arg name.

@@ -61,7 +55,7 @@ class PyTorchGraphNode(GraphNode):
         Returns:
             str, arg name in function or class declaration.
         """
-        return f"{arg}_{self._variable_name}"
+        return f"{arg}_{variable_name}"
     @property
     def hash_key(self):
@@ -88,53 +82,6 @@ class PyTorchGraphNode(GraphNode):
         """
         self._hash_key = h

-    @property
-    def variable_name(self):
-        """
-        Variable name.
-
-        Returns:
-            str, variable name declared in init.
-        """
-        return self._variable_name
-
-    @variable_name.setter
-    def variable_name(self, v):
-        """
-        Setter of variable name.
-
-        Args:
-            v (str): Variable name.
-        """
-        self._variable_name = v
-
-    @property
-    def module_name(self):
-        """
-        Module name.
-
-        Returns:
-            str, module name.
-        """
-        if not self._module_name_frozen:
-            module_name = self.tag
-            return module_name
-        return self._module_name
-
-    def _froze_module_name(self, m):
-        """
-        Once module_name is set, then it's unchangeable.
-
-        Args:
-            m (str): Module name.
-        """
-        if not self._module_name_frozen:
-            self._module_name = m
-            self._module_name_frozen = True
-
     @property
     def op_name(self):
         """
@@ -172,72 +119,47 @@ class PyTorchGraphNode(GraphNode):
         self._ipt_shape = input_shape
         self._opt_shape = output_shape

-    def to_code(self, ipt_args_in_construct: str, output_var: str):
+    def to_code(self, ipt_args_in_construct: str, variable_name: str, output_var: str, code_fragment):
         """
         Generate statements.

         Args:
+            variable_name (str): Variable name.
             ipt_args_in_construct (str): Args of input.
             output_var (str): Output variable name in construct.
+            code_fragment (CodeFragment): CodeFragment instance.

         Returns:
             Union[str, str], declare in init and call in construct.
         """
-        operator = self.op_in_ms or self.module_name
-        self._opt_var_name = output_var
+        operator = code_fragment.operation

         args = self.args_in_code
-        settings = self.settings_in_code
+        settings = code_fragment.code_setting

-        if self._node_type == NodeType.OPERATION.value and not self.convert_successful():
+        if self._node_type == NodeType.OPERATION.value and not is_converted(code_fragment.operation):
             args.update({"input_shape": self.input_shape,
                          "output_shape": self.output_shape})

         if self._node_type == NodeType.OPERATION.value:
-            expr = ", ".join([f"{k.replace(f'_{self._variable_name}', '')}={v}"
+            expr = ", ".join([f"{k.replace(f'_{variable_name}', '')}={v}"
                               for k, v in args.items()])
-            ipt_args_settings_in_construct = self._generate_ipt_args_settings_in_construct(ipt_args_in_construct,
-                                                                                           settings)
+            ipt_args_settings_in_construct = self._generate_ipt_args_settings_in_construct(
+                ipt_args_in_construct, settings)
         else:
             # When its type is module, class or func,
             # it's not necessary to replace var.
-            expr = ", ".join([f"{k.replace(f'_{self._variable_name}', '')}={v}"
+            expr = ", ".join([f"{k.replace(f'_{variable_name}', '')}={v}"
                               for k, v in args.items()])
             ipt_args_settings_in_construct = ipt_args_in_construct

-        declare = f"self.{self._variable_name} = {operator}({expr})"
-        call = f"{self._opt_var_name} = self.{self._variable_name}({ipt_args_settings_in_construct})"
-
-        return declare, call
-
-    @staticmethod
-    def _generate_ipt_args_settings_in_construct(ipt_args_in_construct, settings):
-        """
-        Generate input with args and settings in construct.
-
-        Args:
-            ipt_args_in_construct(str): input args in construct.
-            settings(dict): settings in operator.
-        """
-        if settings.get('input_type'):
-            input_type = settings['input_type']
-            if input_type == InputType.TENSOR.value:
-                ipt_args_settings_in_construct = ipt_args_in_construct
-            elif input_type == InputType.LIST.value:
-                ipt_args_settings_in_construct = f"({ipt_args_in_construct})"
-            else:
-                raise NodeInputTypeNotSupport(f"Input type[{input_type}] is not supported now.")
-        else:
-            ipt_args_settings_in_construct = ipt_args_in_construct
-
-        if settings.get('values'):
-            settings_value = settings['values']
-            if settings_value:
-                settings_in_construct = ', '.join([f"{setting_val}" for _, setting_val in settings_value.items()])
-                ipt_args_settings_in_construct = ', '.join((ipt_args_settings_in_construct, settings_in_construct))
-
-        return ipt_args_settings_in_construct
+        if SEPARATOR_IN_ONNX_OP in operator:
+            operator = operator.replace(SEPARATOR_IN_ONNX_OP, ".")
+
+        declare = f"self.{variable_name} = {operator}({expr})"
+        call = f"{output_var} = self.{variable_name}({ipt_args_settings_in_construct})"
+
+        return declare, call
     def to_ir(self):
         """
@@ -288,62 +210,3 @@ class PyTorchGraphNode(GraphNode):
         var = var.replace(LEFT_BUCKET, SEPARATOR_BTW_NAME_AND_ID).replace(
             RIGHT_BUCKET, "")
         return var
-    def param_transform(self, mapper: Mapper):
-        """
-        Transform torch params into mindspore.
-
-        Args:
-            mapper (Mapper): Mapper of params.
-        """
-        if self._node_type != NodeType.OPERATION.value:
-            args = deepcopy(self._args_in_code)
-            self._args_in_code = dict()
-            for arg, value in args.items():
-                self._args_in_code[self._get_arg_name(arg)] = value
-            return None, None, None
-
-        if not self.transformed:
-            _, _, _ = super(PyTorchGraphNode, self).param_transform(mapper)
-            for arg, value in self._params_in_ms.items():
-                self._args_in_code[self._get_arg_name(arg)] = value
-            for arg, value in self._settings_in_ms.items():
-                self._settings_in_code[arg] = value
-            self.transformed = True
-        return self._op_in_ms, self._params_in_ms, self._settings_in_ms
-
-    def froze_node_type_and_module_name(self, node_type, module_name):
-        """
-        Froze node type and module name.
-
-        After node_type is frozen, then the `module_name`
-        will be affected when `node_type` is `class`.
-        Thus, this line must be placed before `nd_inst.data.module_name`.
-
-        Args:
-            module_name: Modified module name.
-            node_type (str): Node type, class of func.
-        """
-        if not self._type_frozen:
-            self._node_type = node_type
-            self._type_frozen = True
-
-        if not self._module_name_frozen:
-            self._froze_module_name(module_name)
-
-    def convert_successful(self):
-        """
-        Whether convert successfully.
-
-        Returns:
-            bool, true or false.
-        """
-        if self._op_in_ms and SEPARATOR_IN_ONNX_OP not in self._op_in_ms:
-            return True
-        return False
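With both per-class convert_successful methods removed, the shared is_converted helper now imported from ..common.utils presumably centralizes the same check. A minimal sketch consistent with the removed logic, assuming SEPARATOR_IN_ONNX_OP is the "::" that appears in raw op names such as onnx::Conv (the helper's real body lives outside this diff):

    SEPARATOR_IN_ONNX_OP = "::"  # assumption: separator in raw framework op names

    def is_converted(operation: str) -> bool:
        """Sketch: an op converted to MindSpore no longer carries "onnx::"."""
        return bool(operation) and SEPARATOR_IN_ONNX_OP not in operation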
@@ -42,7 +42,7 @@ class TestHierarchicalTree:
         get_raw_params.return_value = []
         tree = HierarchicalTree()
         pt_node = PyTorchGraphNode()
-        tree.insert(pt_node, 'ResNet', (1, 3, 224, 224), (1, 64, 112, 112))
+        tree.insert(pt_node, 'ResNet')
         assert tree.root == 'ResNet'

     def test_remove(self):
@@ -17,11 +17,13 @@ import numpy as np
 import pytest

 from mindinsight.mindconverter.graph_based_converter.mapper.base import ONNXToMindSporeMapper
+from mindinsight.mindconverter.graph_based_converter.mapper.gen_setting import Setting
 from tests.utils import mindspore


 class TestMappers:
     """Test Mappers."""
     @pytest.mark.parametrize('params', [{
         'input': {'op_name': 'onnx::Conv',
                   'params': {'dilations': [1, 1],
@@ -38,7 +40,7 @@ class TestMappers:
                                                  'pad_mode': '\"pad\"',
                                                  'dilation': (1, 1),
                                                  'group': 1},
-                            'converted_settings': dict()}
+                            'converted_settings': Setting()}
     }, {
         'input': {'op_name': 'onnx::Conv',
                   'params': {'dilations': [1, 1],
@@ -55,7 +57,7 @@ class TestMappers:
                                                  'pad_mode': '\"valid\"',
                                                  'dilation': (1, 1),
                                                  'group': 1},
-                            'converted_settings': dict()}
+                            'converted_settings': Setting()}
     }, {
         'input': {'op_name': 'onnx::Gemm',
                   'params': dict(),
@@ -65,7 +67,7 @@ class TestMappers:
                             'converted_params': {'in_channels': 3,
                                                  'out_channels': 10,
                                                  'has_bias': True},
-                            'converted_settings': dict()}
+                            'converted_settings': Setting()}
     }, {
         'input': {'op_name': 'onnx::BatchNormalization',
                   'params': {'epsilon': 1e-5,
@@ -76,14 +78,14 @@ class TestMappers:
                             'converted_params': {'num_features': 6,
                                                  'eps': 1e-5,
                                                  'momentum': 0.9},
-                            'converted_settings': dict()}
+                            'converted_settings': Setting()}
     }, {
         'input': {'op_name': 'onnx::Relu',
                   'params': dict(),
                   'weights': dict()},
         'expected_output': {'converter_name': 'nn.ReLU',
                             'converted_params': dict(),
-                            'converted_settings': dict()}
+                            'converted_settings': Setting()}
     }, {
         'input': {'op_name': 'onnx::MaxPool',
                   'params': {'kernel_shape': [3, 3],
@@ -94,7 +96,7 @@ class TestMappers:
                             'converted_params': {'kernel_size': (3, 3),
                                                  'stride': (2, 2),
                                                  'pad_mode': '"same"'},
-                            'converted_settings': dict()}
+                            'converted_settings': Setting()}
     }, {
         'input': {'op_name': 'onnx::AveragePool',
                   'params': {'kernel_shape': [3, 3],
@@ -105,7 +107,7 @@ class TestMappers:
                             'converted_params': {'kernel_size': (3, 3),
                                                  'stride': (2, 2),
                                                  'pad_mode': '"same"'},
-                            'converted_settings': dict()}
+                            'converted_settings': Setting()}
     }, {
         'input': {'op_name': 'onnx::GlobalAveragePool',
                   'params': {'input_shape': (1, 3, 10, 10),
@@ -113,21 +115,21 @@ class TestMappers:
                   'weights': ''},
         'expected_output': {'converter_name': 'nn.AvgPool2d',
                             'converted_params': {'kernel_size': (10, 10)},
-                            'converted_settings': dict()}
+                            'converted_settings': Setting()}
     }, {
         'input': {'op_name': 'onnx::Flatten',
                   'params': dict(),
                   'weights': dict()},
         'expected_output': {'converter_name': 'nn.Flatten',
                             'converted_params': dict(),
-                            'converted_settings': dict()}
+                            'converted_settings': Setting()}
     }, {
         'input': {'op_name': 'onnx::Add',
                   'params': dict(),
                   'weights': dict()},
         'expected_output': {'converter_name': 'P.TensorAdd',
                             'converted_params': dict(),
-                            'converted_settings': dict()}
+                            'converted_settings': Setting()}
     }, {
         'input': {'op_name': 'onnx::Pad',
                   'params': {'pads': [0, 1, 2, 3],
@@ -137,7 +139,7 @@ class TestMappers:
         'expected_output': {'converter_name': 'nn.Pad',
                             'converted_params': {'paddings': ((0, 2), (1, 3)),
                                                  'mode': '\"CONSTANT\"'},
-                            'converted_settings': dict()}
+                            'converted_settings': Setting()}
     }, {
         'input': {'op_name': 'onnx::Pad',
                   'params': {'pads': [0, 1, 2, 3],
@@ -146,7 +148,7 @@ class TestMappers:
         'expected_output': {'converter_name': 'nn.Pad',
                             'converted_params': {'paddings': ((0, 2), (1, 3)),
                                                  'mode': '\"REFLECT\"'},
-                            'converted_settings': dict()}
+                            'converted_settings': Setting()}
     }, {
         'input': {'op_name': 'onnx::Pad',
                   'params': {'pads': [0, 1, 2, 3],
@@ -156,7 +158,7 @@ class TestMappers:
         'expected_output': {'converter_name': 'nn.Pad',
                             'converted_params': {'paddings': ((0, 2), (1, 3)),
                                                  'mode': '{UNSUPPORTED: value is NOT 0}\"CONSTANT\"'},
-                            'converted_settings': dict()}
+                            'converted_settings': Setting()}
     }, {
         'input': {'op_name': 'onnx::Pad',
                   'params': {'pads': [0, 1, 2, 3],
@@ -165,7 +167,7 @@ class TestMappers:
         'expected_output': {'converter_name': 'nn.Pad',
                             'converted_params': {'paddings': ((0, 2), (1, 3)),
                                                  'mode': '{UNSUPPORTED: \"edge\"}\"UNKNOWN\"'},
-                            'converted_settings': dict()}
+                            'converted_settings': Setting()}
     }, {
         'input': {'op_name': 'onnx::ReduceMean',
                   'params': {'keepdims': 0,
@@ -196,14 +198,14 @@ class TestMappers:
                   'weights': dict()},
         'expected_output': {'converter_name': 'nn.ReLU6',
                             'converted_params': dict(),
-                            'converted_settings': dict()}
+                            'converted_settings': Setting()}
     }, {
         'input': {'op_name': 'onnx::Clip',
                   'params': dict(),
                   'weights': dict()},
         'expected_output': {'converter_name': 'nn.ReLU',
                             'converted_params': dict(),
-                            'converted_settings': dict()}
+                            'converted_settings': Setting()}
     }, {
         'input': {'op_name': 'onnx::Clip',
                   'params': {'max': 3,
@@ -211,13 +213,13 @@ class TestMappers:
                   'weights': dict()},
         'expected_output': {'converter_name': None,
                             'converted_params': dict(),
-                            'converted_settings': dict()}
+                            'converted_settings': Setting()}
     }])
     def test_mapper(self, params):
         """Test mapper function."""
         mapper = ONNXToMindSporeMapper()
-        converter_name, converted_params, converted_settings = \
+        converter_name, converted_params, converted_settings, _ = \
             mapper.convert(params['input']['op_name'], params['input']['params'], params['input']['weights'])
         assert params['expected_output']['converter_name'] == converter_name
         assert params['expected_output']['converted_params'] == converted_params
-        assert params['expected_output']['converted_settings'] == converted_settings
+        assert isinstance(converted_settings, Setting)