| @@ -0,0 +1,15 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd.All Rights Reserved. | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================== | |||
| """Common instance and utils of graph based converter.""" | |||
| @@ -0,0 +1,218 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd.All Rights Reserved. | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================== | |||
| """Define CodeLine object.""" | |||
| import abc | |||
| class TrainableParams: | |||
| """Trainable parameters.""" | |||
| def __init__(self, shape, dtype, reference): | |||
| self.param_name = None | |||
| self.shape = shape | |||
| self.dtype = dtype | |||
| self.reference = reference # Weight name in global npy. | |||
| class CodeSetting: | |||
| """Code generation settings.""" | |||
| def __init__(self): | |||
| self.output_vars_suffix = [] | |||
| self.operation_input_type = None # Construct input type, tensor or list. | |||
| self.operation_extra_input = dict() # `values` in original setting dict. | |||
| self.operation_extra_tensor = None # For `MatMul`, `BiasAdd` op, need a tensor | |||
| class Fragment(abc.ABC): | |||
| """ | |||
| Define comment attributes of code generation. | |||
| Args: | |||
| operation (str): Operation name in MindSpore. | |||
| actual_args (dict): Actual arg values. | |||
| settings (namedTuple): Code generation setting. | |||
| """ | |||
| def __init__(self, operation, actual_args, input_shape, output_shape, settings=None): | |||
| self._operation = operation | |||
| self._input_shape = input_shape | |||
| self._output_shape = output_shape | |||
| self._declared_variable_name = None | |||
| self._output_var_name = list() # Output variable name(could be multi-opt). | |||
| self._operation_inputs = list() # Index indices the order of input. | |||
| self._operation_extra_inputs = settings | |||
| self._code_setting = settings | |||
| self._formal_args_list = dict() | |||
| self._actual_args_list = actual_args # Key is the param_key, value is the corresponding value. | |||
| self._node_type = "" | |||
| @property | |||
| def code_setting(self): | |||
| return self._code_setting | |||
| @property | |||
| def node_type(self): | |||
| """Node type getter.""" | |||
| return self._node_type | |||
| @node_type.setter | |||
| def node_type(self, t): | |||
| """Node type setter.""" | |||
| self._node_type = t | |||
| @property | |||
| def operation_extra_inputs(self): | |||
| """Getter of extra operation inputs.""" | |||
| return self._operation_extra_inputs | |||
| @property | |||
| def declared_var_name(self): | |||
| """Declared variable name getter.""" | |||
| return self._declared_variable_name | |||
| @declared_var_name.setter | |||
| def declared_var_name(self, var): | |||
| """Setter of declared variable name.""" | |||
| self._declared_variable_name = var | |||
| @property | |||
| def output_var_name(self) -> str: | |||
| """Getter of output variable name.""" | |||
| return ", ".join(self._output_var_name) | |||
| @output_var_name.setter | |||
| def output_var_name(self, opt_vars): | |||
| """ | |||
| Output variable name setter. | |||
| Args: | |||
| opt_vars (list[str]): Output variable name. | |||
| """ | |||
| self._output_var_name = opt_vars | |||
| @property | |||
| def operation_inputs(self): | |||
| """ | |||
| Operation getter. | |||
| Returns: | |||
| list[Fragment], list of inputs. | |||
| """ | |||
| return self._operation_inputs | |||
| def update_operation_inputs(self, ipt): | |||
| """ | |||
| Update operation inputs. | |||
| Args: | |||
| ipt (Fragment): Where input comes from. | |||
| """ | |||
| self._operation_inputs.append(ipt) | |||
| @property | |||
| def operation(self): | |||
| """ | |||
| Operation getter. | |||
| Returns: | |||
| str, operation name to be initialized. | |||
| """ | |||
| return self._operation | |||
| @operation.setter | |||
| def operation(self, op: str): | |||
| """ | |||
| Operation setter. | |||
| Args: | |||
| op (str): Operation name. | |||
| """ | |||
| self._operation = op | |||
| @property | |||
| def actual_args(self) -> dict: | |||
| """Getter of actual args.""" | |||
| return self._actual_args_list | |||
| @property | |||
| def formal_args(self) -> dict: | |||
| """Get formal args.""" | |||
| return self._formal_args_list | |||
| def update_formal_args(self, formal_args: dict): | |||
| """ | |||
| Update formal args. | |||
| Args: | |||
| formal_args (dict): To be updated args. | |||
| """ | |||
| return self._formal_args_list.update(formal_args) | |||
| @property | |||
| def input_shape(self): | |||
| return self._input_shape | |||
| @property | |||
| def output_shape(self): | |||
| return self._output_shape | |||
| class CodeFragment(Fragment): | |||
| """ | |||
| Manage the variables related with code generation. | |||
| For single operation type node, the variables in `CodeLine` stands for: | |||
| ```python | |||
| class Module(nn.Cell): | |||
| def __init__ (self, ...): | |||
| super(Module, self).__init__() | |||
| self.<CodeLine.declared_variable_name> = <CodeLine.operation>(<CodeLine.scalar_args>, | |||
| <CodeLine.init_trainable_params>) | |||
| self.<CodeLine.trainable_params[k].param_name> = Tensor(<CodeLine.trainable_params[k].shape>, | |||
| dtype=<CodeLine._trainable_params[k].dtype>) | |||
| def construct(self, x, ...): | |||
| <CodeLine.output_var_name> = self.<CodeLine.declared_variable_name>(<CodeLine.operation_inputs>) | |||
| ... | |||
| return output | |||
| ``` | |||
| Args: | |||
| operation (str): Operation name in MindSpore. | |||
| actual_args (dict): Actual arg values. | |||
| settings (namedTuple): Code generation setting. | |||
| """ | |||
| def __init__(self, operation, actual_args, settings, input_shape, output_shape, | |||
| trainable_params=None): | |||
| super(CodeFragment, self).__init__(operation=operation, actual_args=actual_args, | |||
| input_shape=input_shape, output_shape=output_shape, | |||
| settings=settings) | |||
| self._trainable_params = dict() # External weights, like Matmul. | |||
| self._init_trainable_params = trainable_params # Can put into operation init method, like Conv2d. | |||
| @property | |||
| def trainable_params(self): | |||
| return self._trainable_params | |||
| class ModuleFragment(Fragment): | |||
| """Manage module type code variables.""" | |||
| def __init__(self, operation, actual_args, settings, input_shape, output_shape): | |||
| super(ModuleFragment, self).__init__(operation=operation, actual_args=actual_args, | |||
| input_shape=input_shape, output_shape=output_shape, | |||
| settings=settings) | |||
| @@ -0,0 +1,29 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| """Define common utils.""" | |||
| from mindinsight.mindconverter.graph_based_converter.constant import SEPARATOR_IN_ONNX_OP | |||
| def is_converted(operation: str): | |||
| """ | |||
| Whether convert successful. | |||
| Args: | |||
| operation (str): Operation name. | |||
| Returns: | |||
| bool, true or false. | |||
| """ | |||
| return operation and SEPARATOR_IN_ONNX_OP not in operation | |||
| @@ -14,6 +14,7 @@ | |||
| # ============================================================================== | |||
| """Hierarchical tree module.""" | |||
| import re | |||
| from mindinsight.mindconverter.common.log import logger as log | |||
| from .hierarchical_tree import HierarchicalTree | |||
| from ..third_party_graph.onnx_graph_node import OnnxGraphNode | |||
| @@ -36,7 +37,6 @@ def _tf_model_node_name_reformat(node: OnnxGraphNode, node_name): | |||
| """ | |||
| scope_name = node.scope_name | |||
| new_name = None | |||
| parent = "" | |||
| regex = r"(?P<parent>.+/)(?P<op>\w+)" | |||
| match = re.match(regex, scope_name) | |||
| parent = match.group("parent") | |||
| @@ -74,12 +74,13 @@ class HierarchicalTreeFactory: | |||
| f"Cannot find {node_name}'s input shape." | |||
| log.error(err_msg) | |||
| if isinstance(node_inst, OnnxGraphNode): | |||
| node_name_with_scope = _tf_model_node_name_reformat( | |||
| node_inst, node_name) | |||
| node_name_with_scope = _tf_model_node_name_reformat(node_inst, node_name) | |||
| node_scope_name[node_name] = node_name_with_scope | |||
| node_name = node_name_with_scope | |||
| tree.insert(node_inst, node_name, node_input, node_output) | |||
| node_inst.add_input_and_output_shape(node_input, node_output) | |||
| tree.insert(node_inst, node_name) | |||
| if node_scope_name: | |||
| return tree, node_scope_name | |||
| return tree | |||
| @@ -25,17 +25,18 @@ from treelib import Tree, Node | |||
| from mindinsight.mindconverter.common.log import logger as log | |||
| from .name_mgr import ModuleNameMgr, GlobalVarNameMgr | |||
| from ..common.utils import is_converted | |||
| from ..mapper.base import Mapper | |||
| from ..third_party_graph.pytorch_graph_node import PyTorchGraphNode | |||
| from ..third_party_graph.onnx_graph_node import OnnxGraphNode | |||
| from ..constant import SEPARATOR_IN_SCOPE, SEPARATOR_BTW_NAME_AND_ID, FIRST_LEVEL_INDENT, CodeFormatConfig | |||
| from ..constant import SEPARATOR_IN_SCOPE | |||
| from ..constant import CodeFormatConfig | |||
| from ..constant import SEPARATOR_BTW_NAME_AND_ID, FIRST_LEVEL_INDENT | |||
| from ..constant import NEW_LINE, SECOND_LEVEL_INDENT | |||
| from ..constant import NodeType | |||
| from ..report_generator import ReportGenerator | |||
| from ...common.exceptions import NodeTypeNotSupport | |||
| GLOBAL_VAR_NAME_MGR = GlobalVarNameMgr() | |||
| class HierarchicalTree(Tree): | |||
| """Define hierarchical tree.""" | |||
| @@ -46,6 +47,8 @@ class HierarchicalTree(Tree): | |||
| _root_created = False | |||
| ROOT_LEVEL = 0 | |||
| GLOBAL_VAR_NAME_MGR = GlobalVarNameMgr() | |||
| def __init__(self): | |||
| super(HierarchicalTree, self).__init__() | |||
| self._hierarchical_order = dict() | |||
| @@ -62,6 +65,7 @@ class HierarchicalTree(Tree): | |||
| self._module_vars = dict() | |||
| # scope name mapping record for easy node searching | |||
| self._scope_name_map = dict() | |||
| self.code_fragment_recorder = dict() | |||
| @property | |||
| def tree_identifier(self): | |||
| @@ -82,19 +86,15 @@ class HierarchicalTree(Tree): | |||
| return None | |||
| return self._nodes[nid] | |||
| def insert(self, node: Union[PyTorchGraphNode, OnnxGraphNode], | |||
| node_name: str, input_shape, output_shape): | |||
| def insert(self, node: Union[PyTorchGraphNode, OnnxGraphNode], node_name: str): | |||
| """ | |||
| Insert node into hierarchical tree. | |||
| Args: | |||
| node_name (str): Node name. | |||
| node (Union[PyTorchGraphNode, OnnxGraphNode]): Node to be inserted. | |||
| output_shape (tuple): Output tensor shape. | |||
| input_shape (tuple): Input tensor shape. | |||
| """ | |||
| node.add_input_and_output_shape(input_shape, output_shape) | |||
| scopes = node_name.split(SEPARATOR_IN_SCOPE) | |||
| for idx, scope in enumerate(scopes): | |||
| parent = SEPARATOR_IN_SCOPE.join(scopes[:idx]) | |||
| @@ -125,10 +125,9 @@ class HierarchicalTree(Tree): | |||
| tgt_node.precursor_nodes = node.precursor_nodes | |||
| tgt_node.node_type = (NodeType.OPERATION if idx == len(scopes) - 1 | |||
| else NodeType.MODULE).value | |||
| tgt_node.tag = scope.split(SEPARATOR_BTW_NAME_AND_ID)[0] | |||
| tgt_node.variable_name = self._get_var_name(identifier) | |||
| self.create_node( | |||
| tag=tgt_node.tag, | |||
| tag=scope.split(SEPARATOR_BTW_NAME_AND_ID)[0], | |||
| identifier=identifier, | |||
| parent=parent, | |||
| data=tgt_node | |||
| @@ -276,8 +275,7 @@ class HierarchicalTree(Tree): | |||
| node.data.replace_with_arg(arg, arg) | |||
| return node | |||
| @staticmethod | |||
| def _clear_unused_args(node, used_args): | |||
| def _clear_unused_args(self, node, used_args): | |||
| """ | |||
| Clear unused args. | |||
| @@ -290,7 +288,9 @@ class HierarchicalTree(Tree): | |||
| """ | |||
| args_in_code = list(node.data.args_in_code.keys()) | |||
| for arg in args_in_code: | |||
| ori_arg = arg.replace(f"_{node.data.variable_name}", "") | |||
| ori_arg = arg.replace( | |||
| f"_{self.code_fragment_recorder[node.identifier].declared_var_name}", "" | |||
| ) | |||
| if ori_arg not in used_args: | |||
| node.data.args_in_code.pop(arg) | |||
| return node | |||
| @@ -323,6 +323,8 @@ class HierarchicalTree(Tree): | |||
| # 1. Generate args for each node in this level. | |||
| if node.data.node_type == NodeType.MODULE.value: | |||
| self._create_module_args_and_vars(node, mapper) | |||
| if depth == depths[-1]: | |||
| self.code_fragment_recorder[node.identifier] = node.data.param_transform(mapper, "") | |||
| # Module merging based on all nodes. | |||
| self._module_merging() | |||
| @@ -345,30 +347,29 @@ class HierarchicalTree(Tree): | |||
| # then assign the created module name to current node, | |||
| # and delete unused args. | |||
| module_name = self._created_module[module_key] | |||
| nd_inst.data.froze_node_type_and_module_name(node_type, | |||
| module_name) | |||
| self.code_fragment_recorder[nd_inst.identifier].operation = module_name | |||
| self.code_fragment_recorder[nd_inst.identifier].node_type = node_type | |||
| self._preprocess_node_args(nd_inst, module_key) | |||
| continue | |||
| module_name = nd_inst.data.module_name | |||
| module_name = nd_inst.tag | |||
| if node_type == NodeType.CLASS.value: | |||
| module_name = f"{module_name[0].upper()}{module_name[1:]}" | |||
| # After node_type and module_name is frozen, | |||
| # then it's unchangeable. | |||
| module_name = self._module_mgr.get_name(module_name) | |||
| nd_inst.data.froze_node_type_and_module_name(node_type, | |||
| module_name) | |||
| self.code_fragment_recorder[nd_inst.identifier].operation = module_name | |||
| self.code_fragment_recorder[nd_inst.identifier].node_type = node_type | |||
| # 3. Pre-process node args. | |||
| nd_inst = self._preprocess_node_args(nd_inst, module_key) | |||
| # 4. Post-process child node args. | |||
| for _, scsr_nd_name in enumerate(nd_inst.successors(self.tree_identifier)): | |||
| self._postprocess_node_args( | |||
| self.get_node(scsr_nd_name), module_key) | |||
| self._postprocess_node_args(self.get_node(scsr_nd_name), module_key) | |||
| # 5. Generate code. | |||
| snippets.add( | |||
| func(nd_inst, nd_inst.data.module_name, module_key)) | |||
| snippets.add(func(nd_inst, self.code_fragment_recorder[nd_inst.identifier].operation, module_key)) | |||
| code_blocks.extend(snippets) | |||
| @@ -437,7 +438,7 @@ class HierarchicalTree(Tree): | |||
| module_list = [] | |||
| for node_name in node.successors(self.tree_identifier): | |||
| c_nd = self.get_node(node_name) | |||
| operator = c_nd.data.op_in_ms or c_nd.data.module_name | |||
| operator = self.code_fragment_recorder[c_nd.identifier].operation | |||
| if c_nd.data.node_type != NodeType.OPERATION.value: | |||
| hash_key = c_nd.data.hash_key or self.hash_key(c_nd) | |||
| @@ -445,14 +446,16 @@ class HierarchicalTree(Tree): | |||
| operator = self._created_module[hash_key] | |||
| args = c_nd.data.args_in_code | |||
| if c_nd.data.node_type == NodeType.OPERATION.value and \ | |||
| not c_nd.data.convert_successful(): | |||
| if c_nd.data.node_type == NodeType.OPERATION.value and not is_converted( | |||
| self.code_fragment_recorder[c_nd.identifier].operation): | |||
| args.update({"input_shape": c_nd.data.input_shape, | |||
| "output_shape": c_nd.data.output_shape}) | |||
| # Generate code statement. | |||
| expr = ", ".join([f"{k.replace(f'_{c_nd.data.variable_name}', '')}={v}" | |||
| for k, v in args.items()]) | |||
| expr = ", ".join( | |||
| [f"{k.replace(f'_{self.code_fragment_recorder[c_nd.identifier].declared_var_name}', '')}={v}" | |||
| for k, v in args.items()] | |||
| ) | |||
| code_line = f"{operator}({expr})" | |||
| module_list.append(code_line) | |||
| @@ -547,14 +550,16 @@ class HierarchicalTree(Tree): | |||
| if idx != 0: | |||
| # Get previous node output variable name. | |||
| ipt_args_in_construct = self._get_previous_opt_var( | |||
| cur_nd_inst, pre_nd_inst) | |||
| ipt_args_in_construct = self._get_previous_opt_var(cur_nd_inst, pre_nd_inst) | |||
| if idx != len(pre_nd_inst.successors(self.tree_identifier)) - 1: | |||
| # Set opt variable name. | |||
| opt_arg_in_construct = cur_nd_inst.data.opt_var_name | |||
| opt_arg_in_construct = f"{self.code_fragment_recorder[cur_nd_inst.identifier].declared_var_name}_opt" | |||
| declare, call = cur_nd_inst.data.to_code(ipt_args_in_construct=ipt_args_in_construct, | |||
| output_var=opt_arg_in_construct) | |||
| variable_name=self.code_fragment_recorder[ | |||
| cur_nd_inst.identifier].declared_var_name, | |||
| output_var=opt_arg_in_construct, | |||
| code_fragment=self.code_fragment_recorder[cur_nd_inst.identifier]) | |||
| return declare, call | |||
| @@ -588,7 +593,9 @@ class HierarchicalTree(Tree): | |||
| if e not in pre_nd.successors(self.tree_identifier): | |||
| while True: | |||
| if p_nd.identifier in pre_nd.successors(self.tree_identifier): | |||
| ipt_lst.append(p_nd.data.opt_var_name) | |||
| ipt_lst.append( | |||
| f"{self.code_fragment_recorder[p_nd.identifier].declared_var_name}_opt" | |||
| ) | |||
| break | |||
| pre_nd_name = p_nd.predecessor(self.tree_identifier) | |||
| if not pre_nd_name: | |||
| @@ -597,7 +604,9 @@ class HierarchicalTree(Tree): | |||
| p_nd = self.get_node(pre_nd_name) | |||
| continue | |||
| ipt_lst.append(p_nd.data.opt_var_name) | |||
| ipt_lst.append( | |||
| f"{self.code_fragment_recorder[p_nd.identifier].declared_var_name}_opt" | |||
| ) | |||
| return ipt_lst | |||
| def _get_previous_opt_var(self, cur_nd, pre_nd): | |||
| @@ -619,12 +628,11 @@ class HierarchicalTree(Tree): | |||
| cur_nd = self.get_node(p_nd[0]) | |||
| return ", ".join(self._find_all_previous_opt_var_(cur_nd, pre_nd)) | |||
| def hash_key(self, node, depth: int = 0): | |||
| def hash_key(self, node): | |||
| """ | |||
| Generate hash key for each node. | |||
| Args: | |||
| depth (int): Recursion depth. | |||
| node (Node): Node. | |||
| Returns: | |||
| @@ -633,13 +641,17 @@ class HierarchicalTree(Tree): | |||
| scsr_topo_order = [] | |||
| for s in node.successors(self.tree_identifier): | |||
| cur_nd = self.get_node(s) | |||
| if cur_nd.data.hash_key: | |||
| scsr_topo_order.append(f"{cur_nd.data.hash_key}[{depth}]") | |||
| continue | |||
| if cur_nd.data.node_type in {NodeType.MODULE.value, | |||
| NodeType.FUNC.value, | |||
| NodeType.CLASS.value}: | |||
| scsr_topo_order.append(self.hash_key(cur_nd, depth + 1)) | |||
| if cur_nd.data.hash_key: | |||
| scsr_topo_order.append(f"({cur_nd.data.hash_key})") | |||
| continue | |||
| raise ValueError("Current node doesn't have hash key.") | |||
| if cur_nd.data.hash_key: | |||
| scsr_topo_order.append(cur_nd.data.hash_key) | |||
| continue | |||
| unique_key = "->".join(scsr_topo_order) | |||
| node.data.hash_key = unique_key | |||
| @@ -675,12 +687,11 @@ class HierarchicalTree(Tree): | |||
| """ | |||
| # All args and value pair in current node module. | |||
| module_args = dict() | |||
| module_settings = dict() | |||
| module_key = self.hash_key(node) | |||
| created = False | |||
| if module_key not in self._vars_mgr_in_module: | |||
| self._vars_mgr_in_module[module_key] = GLOBAL_VAR_NAME_MGR | |||
| self._vars_mgr_in_module[module_key] = self.GLOBAL_VAR_NAME_MGR | |||
| self._module_vars[module_key] = [] | |||
| else: | |||
| created = True | |||
| @@ -688,33 +699,29 @@ class HierarchicalTree(Tree): | |||
| # Sub-modules in the module could have arg name conflicts. | |||
| for idx, successor_name in enumerate(node.successors(self.tree_identifier)): | |||
| nd_inst = self.get_node(successor_name) | |||
| # Generate variable name here, then | |||
| # to generate args. | |||
| # Generation of params must behind variable assigment. | |||
| if created: | |||
| nd_inst.data.variable_name = self._module_vars[module_key][idx] | |||
| variable_name = self._module_vars[module_key][idx] | |||
| else: | |||
| variable_name = nd_inst.data.op_name or nd_inst.data.module_name | |||
| variable_name = self._vars_mgr_in_module[module_key].get_name( | |||
| variable_name) | |||
| nd_inst.data.variable_name = variable_name | |||
| variable_name = nd_inst.data.op_name or nd_inst.tag | |||
| variable_name = self._vars_mgr_in_module[module_key].get_name(variable_name) | |||
| # Generation of params must behind variable assigment. | |||
| nd_inst.data.param_transform(mapper) | |||
| code_fragment = nd_inst.data.param_transform(mapper, variable_name) | |||
| code_fragment.declared_var_name = variable_name | |||
| self.code_fragment_recorder[nd_inst.identifier] = code_fragment | |||
| module_args.update(nd_inst.data.args_in_code) | |||
| module_settings.update(nd_inst.data.settings_in_code) | |||
| if not created: | |||
| self._module_vars[module_key].append( | |||
| nd_inst.data.variable_name) | |||
| self._module_vars[module_key].append(variable_name) | |||
| node.data.args_in_code = module_args | |||
| # Collect module args of `module_key`. | |||
| if module_key not in self._merged_module: | |||
| self._merged_module[module_key] = [node.data.args_in_code] | |||
| self._merged_module[module_key] = [deepcopy(node.data.args_in_code)] | |||
| else: | |||
| self._merged_module[module_key].append(node.data.args_in_code) | |||
| self._merged_module[module_key].append(deepcopy(node.data.args_in_code)) | |||
| @staticmethod | |||
| def _create_operation_args(node, mapper): | |||
| @@ -63,6 +63,10 @@ START_IDX = 0 | |||
| class GlobalVarNameMgr: | |||
| """Global variable name mgr.""" | |||
| def __init__(self): | |||
| global_op_namespace.clear() | |||
| global_var_namespace.clear() | |||
| @staticmethod | |||
| def _get_name(name): | |||
| """Deal with op name.""" | |||
| @@ -87,7 +87,7 @@ class ONNXToMindSporeMapper(Mapper, abc.ABC): | |||
| module_name = TABLE.get(op_name) | |||
| if not module_name: | |||
| return None, dict(), dict() | |||
| return None, dict(), None, dict() | |||
| pos = module_name.rfind(".") | |||
| try: | |||
| @@ -101,7 +101,7 @@ class ONNXToMindSporeMapper(Mapper, abc.ABC): | |||
| # If mapper can not be found, then skip it. | |||
| err_msg = f"Converting {op_name} failed, see {str(e)}" | |||
| log.error(err_msg) | |||
| return None, dict(), dict() | |||
| return None, dict(), None, dict() | |||
| try: | |||
| converter_name = op_name_converter( | |||
| @@ -110,13 +110,13 @@ class ONNXToMindSporeMapper(Mapper, abc.ABC): | |||
| converted_weights = weights_converter( | |||
| weights=weights) if weights else dict() | |||
| converted_params.update(converted_weights) | |||
| converted_settings = settings_converter(params=params) | |||
| converted_settings = settings_converter(params=params, weights=weights) | |||
| except (AttributeError, KeyError, ValueError, TypeError, IndexError) as e: | |||
| err_msg = f"Converting {op_name} failed, see {str(e)}" | |||
| log.error(err_msg) | |||
| return None, dict(), dict() | |||
| return None, dict(), None, dict() | |||
| return converter_name, converted_params, converted_settings | |||
| return converter_name, converted_params, converted_settings, converted_weights | |||
| @staticmethod | |||
| def _operation_name_in_ms(*args, **kwargs): | |||
| @@ -0,0 +1,34 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd.All Rights Reserved. | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================== | |||
| """Operation mapping setting.""" | |||
| from collections import namedtuple | |||
| import numpy as np | |||
| from mindinsight.mindconverter.graph_based_converter.constant import InputType | |||
| Tensor = namedtuple("Tensor", ["shape", "dtype", "reference"]) | |||
| Setting = namedtuple("Setting", ["opt_vars_suffix", | |||
| "op_ipt_type", | |||
| "op_extra_input", | |||
| "op_extra_tensor"]) | |||
| Setting.__new__.__defaults__ = ("_opt", InputType.TENSOR.value, dict(), None) | |||
| def get_dtype(tensor: np.ndarray): | |||
| """Get tensor dtype.""" | |||
| if tensor.dtype == np.float16: | |||
| return "mindspore.float16" | |||
| return "mindspore.float32" | |||
| @@ -14,6 +14,7 @@ | |||
| # ============================================================================== | |||
| """Mapper module.""" | |||
| from ...base import ONNXToMindSporeMapper | |||
| from ...gen_setting import Setting | |||
| class BatchNormMapper(ONNXToMindSporeMapper): | |||
| @@ -39,4 +40,4 @@ class BatchNormMapper(ONNXToMindSporeMapper): | |||
| @staticmethod | |||
| def _convert_settings(**kwargs): | |||
| return dict() | |||
| return Setting() | |||
| @@ -16,6 +16,7 @@ | |||
| import re | |||
| import numpy as np | |||
| from ...base import ONNXToMindSporeMapper | |||
| from ...gen_setting import Setting | |||
| def _convert_padding(**kwargs): | |||
| @@ -35,6 +36,7 @@ def _convert_padding(**kwargs): | |||
| class ConvMapper(ONNXToMindSporeMapper): | |||
| """Conv2d mapper.""" | |||
| @staticmethod | |||
| def convert_params_torch(**kwargs): | |||
| """Convert params from PyTorch to MindSpore""" | |||
| @@ -148,4 +150,4 @@ class ConvMapper(ONNXToMindSporeMapper): | |||
| @staticmethod | |||
| def _convert_settings(**kwargs): | |||
| return dict() | |||
| return Setting() | |||
| @@ -14,6 +14,7 @@ | |||
| # ============================================================================== | |||
| """Mapper module.""" | |||
| from ...base import ONNXToMindSporeMapper | |||
| from ...gen_setting import Setting | |||
| class DenseMapper(ONNXToMindSporeMapper): | |||
| @@ -41,4 +42,4 @@ class DenseMapper(ONNXToMindSporeMapper): | |||
| @staticmethod | |||
| def _convert_settings(**kwargs): | |||
| return dict() | |||
| return Setting() | |||
| @@ -14,6 +14,7 @@ | |||
| # ============================================================================== | |||
| """Mapper module.""" | |||
| from ...base import ONNXToMindSporeMapper | |||
| from ...gen_setting import Setting | |||
| class FlattenMapper(ONNXToMindSporeMapper): | |||
| @@ -33,4 +34,4 @@ class FlattenMapper(ONNXToMindSporeMapper): | |||
| @staticmethod | |||
| def _convert_settings(**kwargs): | |||
| return dict() | |||
| return Setting() | |||
| @@ -14,6 +14,7 @@ | |||
| # ============================================================================== | |||
| """Mapper module.""" | |||
| from ...base import ONNXToMindSporeMapper | |||
| from ...gen_setting import Setting | |||
| class GlobalPoolMapper(ONNXToMindSporeMapper): | |||
| @@ -25,8 +26,7 @@ class GlobalPoolMapper(ONNXToMindSporeMapper): | |||
| op_name = 'nn.AvgPool{}d' | |||
| else: | |||
| op_name = 'nn.MaxPool{}d' | |||
| dim = 1 if len(kwargs['params']['input_shape']) == 3\ | |||
| else 2 | |||
| dim = 1 if len(kwargs['params']['input_shape']) == 3 else 2 | |||
| return op_name.format(dim) | |||
| @staticmethod | |||
| @@ -49,4 +49,4 @@ class GlobalPoolMapper(ONNXToMindSporeMapper): | |||
| @staticmethod | |||
| def _convert_settings(**kwargs): | |||
| return dict() | |||
| return Setting() | |||
| @@ -14,6 +14,7 @@ | |||
| # ============================================================================== | |||
| """Mapper module.""" | |||
| from ...base import ONNXToMindSporeMapper | |||
| from ...gen_setting import Setting, Tensor, get_dtype | |||
| class MatMulMapper(ONNXToMindSporeMapper): | |||
| @@ -33,4 +34,12 @@ class MatMulMapper(ONNXToMindSporeMapper): | |||
| @staticmethod | |||
| def _convert_settings(**kwargs): | |||
| return dict() | |||
| weights = kwargs.get("weights") | |||
| if not weights: | |||
| return Setting() | |||
| tensor, ref = None, "" | |||
| for t_name, t_value in weights.items(): | |||
| tensor = t_value | |||
| ref = t_name | |||
| return Setting(op_extra_tensor=Tensor(shape=tensor.shape, | |||
| dtype=get_dtype(tensor), reference=ref)) | |||
| @@ -14,6 +14,7 @@ | |||
| # ============================================================================== | |||
| """Mapper module.""" | |||
| from ...base import ONNXToMindSporeMapper | |||
| from ...gen_setting import Setting | |||
| def _padding_format_convert(padding: list): | |||
| @@ -77,4 +78,4 @@ class PadMapper(ONNXToMindSporeMapper): | |||
| @staticmethod | |||
| def _convert_settings(**kwargs): | |||
| return dict() | |||
| return Setting() | |||
| @@ -14,6 +14,7 @@ | |||
| # ============================================================================== | |||
| """Mapper module.""" | |||
| from ...base import ONNXToMindSporeMapper | |||
| from ...gen_setting import Setting | |||
| class PoolMapper(ONNXToMindSporeMapper): | |||
| @@ -49,4 +50,4 @@ class PoolMapper(ONNXToMindSporeMapper): | |||
| @staticmethod | |||
| def _convert_settings(**kwargs): | |||
| return dict() | |||
| return Setting() | |||
| @@ -14,6 +14,7 @@ | |||
| # ============================================================================== | |||
| """Mapper module.""" | |||
| from ...base import ONNXToMindSporeMapper | |||
| from ...gen_setting import Setting | |||
| class ReLUMapper(ONNXToMindSporeMapper): | |||
| @@ -45,4 +46,4 @@ class ReLUMapper(ONNXToMindSporeMapper): | |||
| @staticmethod | |||
| def _convert_settings(**kwargs): | |||
| return dict() | |||
| return Setting() | |||
| @@ -14,6 +14,7 @@ | |||
| # ============================================================================== | |||
| """Mapper module.""" | |||
| from ...base import ONNXToMindSporeMapper | |||
| from ...gen_setting import Setting | |||
| class SoftmaxMapper(ONNXToMindSporeMapper): | |||
| @@ -37,4 +38,4 @@ class SoftmaxMapper(ONNXToMindSporeMapper): | |||
| @staticmethod | |||
| def _convert_settings(**kwargs): | |||
| return dict() | |||
| return Setting() | |||
| @@ -14,6 +14,7 @@ | |||
| # ============================================================================== | |||
| """Mapper module.""" | |||
| from ...base import ONNXToMindSporeMapper | |||
| from ...gen_setting import Setting, Tensor, get_dtype | |||
| class AddMapper(ONNXToMindSporeMapper): | |||
| @@ -33,4 +34,12 @@ class AddMapper(ONNXToMindSporeMapper): | |||
| @staticmethod | |||
| def _convert_settings(**kwargs): | |||
| return dict() | |||
| weights = kwargs.get("weights") | |||
| if not weights: | |||
| return Setting() | |||
| tensor, ref = None, "" | |||
| for t_name, t_value in weights.items(): | |||
| tensor = t_value | |||
| ref = t_name | |||
| return Setting(op_extra_tensor=Tensor(shape=tensor.shape, | |||
| dtype=get_dtype(tensor), reference=ref)) | |||
| @@ -15,6 +15,7 @@ | |||
| """Mapper module.""" | |||
| from mindinsight.mindconverter.graph_based_converter.constant import InputType | |||
| from ...base import ONNXToMindSporeMapper | |||
| from ...gen_setting import Setting | |||
| class ConcatMapper(ONNXToMindSporeMapper): | |||
| @@ -36,4 +37,4 @@ class ConcatMapper(ONNXToMindSporeMapper): | |||
| @staticmethod | |||
| def _convert_settings(**kwargs): | |||
| input_type = InputType.LIST.value | |||
| return {'input_type': input_type} | |||
| return Setting(op_ipt_type=input_type) | |||
| @@ -14,6 +14,7 @@ | |||
| # ============================================================================== | |||
| """Mapper module.""" | |||
| from ...base import ONNXToMindSporeMapper | |||
| from ...gen_setting import Setting | |||
| class ReduceMeanMapper(ONNXToMindSporeMapper): | |||
| @@ -40,4 +41,4 @@ class ReduceMeanMapper(ONNXToMindSporeMapper): | |||
| axis = params['axes'][0] if len(params['axes']) == 1 else tuple(params['axes']) | |||
| else: | |||
| axis = tuple() | |||
| return {'values': {'axis': axis}} | |||
| return Setting(op_extra_input={'axis': axis}) | |||
| @@ -14,6 +14,7 @@ | |||
| # ============================================================================== | |||
| """Mapper module.""" | |||
| from ...base import ONNXToMindSporeMapper | |||
| from ...gen_setting import Setting | |||
| class TransposeMapper(ONNXToMindSporeMapper): | |||
| @@ -40,4 +41,4 @@ class TransposeMapper(ONNXToMindSporeMapper): | |||
| perm = tuple(perm) | |||
| converted_params['input_perm'] = perm | |||
| return {'values': converted_params} | |||
| return Setting(op_extra_input=converted_params) | |||
| @@ -15,10 +15,13 @@ | |||
| """Define graph entity.""" | |||
| import abc | |||
| from collections import OrderedDict | |||
| from copy import deepcopy | |||
| from mindinsight.mindconverter.common.log import logger as log | |||
| from ..constant import SEPARATOR_IN_ONNX_OP | |||
| from ..common.code_fragment import CodeFragment | |||
| from ..constant import NodeType, InputType | |||
| from ..mapper.base import Mapper | |||
| from ...common.exceptions import NodeInputTypeNotSupport | |||
| class GraphParser(metaclass=abc.ABCMeta): | |||
| @@ -287,26 +290,10 @@ class GraphNode(abc.ABC): | |||
| self._op_params = dict() | |||
| self._scope_name = None | |||
| self._op_shape = None | |||
| # Operation in mindspore. | |||
| self._op_in_ms = None | |||
| # Params in mindspore. | |||
| self._params_in_ms = dict() | |||
| # Settings in mindspore. | |||
| self._settings_in_ms = dict() | |||
| # Node type of current node, e.g. class, module, operation. | |||
| self._node_type = None | |||
| # Tag name on tree. | |||
| self._tag_on_tree = None | |||
| # Function, class or operation needed args. | |||
| self._args_in_code = dict() | |||
| # Operation needed settings. | |||
| self._settings_in_code = dict() | |||
| # Variable name declared in init block. | |||
| self._variable_name = None | |||
| # Output variable name declared in construct block. | |||
| self._opt_var_name = None | |||
| # Function or class name in code. | |||
| self._module_name = None | |||
| # Unique key of node. | |||
| self._hash_key = None | |||
| # Input shape of current op. | |||
| @@ -317,37 +304,18 @@ class GraphNode(abc.ABC): | |||
| self._weight = None | |||
| @property | |||
| def opt_var_name(self): | |||
| def weight(self): | |||
| return self._weight | |||
| @staticmethod | |||
| def get_opt_var_name(variable_name): | |||
| """ | |||
| Output variable name. | |||
| Returns: | |||
| str, variable name. | |||
| """ | |||
| return f"{self.variable_name}_opt" | |||
| @opt_var_name.setter | |||
| def opt_var_name(self, v): | |||
| """ | |||
| Set variable name. | |||
| Args: | |||
| v (str): Name. | |||
| """ | |||
| self._opt_var_name = v | |||
| @property | |||
| def op_in_ms(self): | |||
| """ | |||
| Operation in mindspore. | |||
| Returns: | |||
| str, operation name. | |||
| """ | |||
| if self._op_in_ms and SEPARATOR_IN_ONNX_OP in self._op_in_ms: | |||
| return self._op_in_ms.replace(SEPARATOR_IN_ONNX_OP, ".") | |||
| return self._op_in_ms | |||
| return f"{variable_name}_opt" | |||
| @property | |||
| def args_in_code(self): | |||
| @@ -370,27 +338,6 @@ class GraphNode(abc.ABC): | |||
| """ | |||
| self._args_in_code = args | |||
| @property | |||
| def settings_in_code(self): | |||
| """ | |||
| Settings in code. | |||
| Returns: | |||
| dict, settings. | |||
| """ | |||
| return self._settings_in_code | |||
| @settings_in_code.setter | |||
| def settings_in_code(self, settings): | |||
| """ | |||
| Settings in code. | |||
| Args: | |||
| settings(dict): Settings. | |||
| """ | |||
| self._settings_in_code = settings | |||
| @property | |||
| def input_shape(self): | |||
| """ | |||
| @@ -411,16 +358,6 @@ class GraphNode(abc.ABC): | |||
| """ | |||
| return self._opt_shape | |||
| @property | |||
| def tag(self): | |||
| """Tag on hierarchical tree.""" | |||
| return self._tag_on_tree | |||
| @tag.setter | |||
| def tag(self, t): | |||
| """Tag on hierarchical tree.""" | |||
| self._tag_on_tree = t | |||
| def is_empty(self): | |||
| """ | |||
| Whether is empty. | |||
| @@ -536,7 +473,7 @@ class GraphNode(abc.ABC): | |||
| """Replace actual parameter with formal parameter.""" | |||
| @abc.abstractmethod | |||
| def _get_arg_name(self, arg): | |||
| def _get_arg_name(self, arg, variable_name): | |||
| """Get arg name for func or class.""" | |||
| @abc.abstractmethod | |||
| @@ -553,13 +490,8 @@ class GraphNode(abc.ABC): | |||
| def real_name(self, **kwargs): | |||
| """Setter of `real_name`.""" | |||
| @property | |||
| @abc.abstractmethod | |||
| def variable_name(self): | |||
| """Getter of `variable_name`.""" | |||
| @abc.abstractmethod | |||
| def to_code(self, ipt_args_in_construct: str, output_var: str): | |||
| def to_code(self, ipt_args_in_construct: str, variable_name: str, output_var: str, code_fragment): | |||
| """Graph node to MindSpore code.""" | |||
| @abc.abstractmethod | |||
| @@ -570,40 +502,86 @@ class GraphNode(abc.ABC): | |||
| def add_input_and_output_shape(self, input_shape, output_shape): | |||
| """Add the node input shape.""" | |||
| @abc.abstractmethod | |||
| def froze_node_type_and_module_name(self, node_type, module_name): | |||
| """Make node_type can not be changed.""" | |||
| @staticmethod | |||
| def _generate_ipt_args_settings_in_construct(ipt_args_in_construct, settings): | |||
| """ | |||
| Generate input with args and settings in construct. | |||
| @abc.abstractmethod | |||
| def convert_successful(self): | |||
| """Whether convert successful.""" | |||
| Args: | |||
| ipt_args_in_construct (str): Input args in construct. | |||
| settings (Setting): Settings in operator. | |||
| Returns: | |||
| str, args of each node in generated construct statement. | |||
| """ | |||
| if settings and settings.op_ipt_type: | |||
| input_type = settings.op_ipt_type | |||
| if input_type == InputType.TENSOR.value: | |||
| ipt_args_settings_in_construct = ipt_args_in_construct | |||
| elif input_type == InputType.LIST.value: | |||
| ipt_args_settings_in_construct = f"({ipt_args_in_construct})" | |||
| else: | |||
| raise NodeInputTypeNotSupport(f"Input type[{input_type}] is not supported now.") | |||
| else: | |||
| ipt_args_settings_in_construct = ipt_args_in_construct | |||
| if settings and settings.op_extra_input: | |||
| settings_value = settings.op_extra_input | |||
| if settings_value: | |||
| settings_in_construct = ', '.join([f"{setting_val}" for _, setting_val in settings_value.items()]) | |||
| ipt_args_settings_in_construct = ', '.join((ipt_args_settings_in_construct, settings_in_construct)) | |||
| def param_transform(self, mapper: Mapper): | |||
| return ipt_args_settings_in_construct | |||
| def param_transform(self, mapper: Mapper, variable_name): | |||
| """ | |||
| Transform param in pytorch operation into mindspore. | |||
| Transform param in PyTorch operation into MindSpore. | |||
| Args: | |||
| variable_name (str): Variable name. | |||
| mapper (ONNXToMindSporeMapper): Mapper between onnx operation | |||
| and mindspore. | |||
| and MindSpore. | |||
| Returns: | |||
| dict, transformed params. | |||
| """ | |||
| import copy | |||
| params = copy.deepcopy(self._op_params) | |||
| if self._node_type != NodeType.OPERATION.value: | |||
| args = deepcopy(self._args_in_code) | |||
| self._args_in_code = dict() | |||
| for arg, value in args.items(): | |||
| self._args_in_code[self._get_arg_name(arg, variable_name)] = value | |||
| return CodeFragment(operation="", actual_args=args, settings=None, | |||
| input_shape=self.input_shape, output_shape=self.output_shape) | |||
| if self.transformed: | |||
| raise ValueError("Already transformed.") | |||
| params = deepcopy(self._op_params) | |||
| params.update({"input_shape": self.input_shape, | |||
| "output_shape": self.output_shape}) | |||
| op_name_in_mindspore, ms_params, ms_settings = mapper.convert(op_name=self.op_name, | |||
| params=params, | |||
| weights=self._weight) | |||
| if op_name_in_mindspore: | |||
| self._op_in_ms = op_name_in_mindspore | |||
| self._params_in_ms = ms_params | |||
| self._settings_in_ms = ms_settings | |||
| ms_op, ms_params, ms_settings, ms_weights = mapper.convert(op_name=self.op_name, | |||
| params=params, | |||
| weights=self._weight) | |||
| if ms_op: | |||
| code_fragment = CodeFragment(operation=ms_op, | |||
| actual_args=ms_params, | |||
| settings=ms_settings, | |||
| input_shape=self.input_shape, | |||
| output_shape=self.output_shape, | |||
| trainable_params=ms_weights) | |||
| else: | |||
| self._op_in_ms = self._op_name | |||
| self._params_in_ms = self._op_params | |||
| self._settings_in_ms = dict() | |||
| code_fragment = CodeFragment(operation=self._op_name, | |||
| actual_args=self._op_params, | |||
| settings=None, | |||
| input_shape=self.input_shape, | |||
| output_shape=self.output_shape, | |||
| trainable_params=self._weight) | |||
| for arg, value in code_fragment.actual_args.items(): | |||
| self._args_in_code[self._get_arg_name(arg, variable_name)] = value | |||
| self.transformed = True | |||
| return self._op_in_ms, self._params_in_ms, self._settings_in_ms | |||
| return code_fragment | |||
| @@ -38,7 +38,6 @@ class PyTorchGraphParser(GraphParser): | |||
| error = FileNotFoundError("`model_path` must be assigned with " | |||
| "an existed file path.") | |||
| log.error(str(error)) | |||
| log.exception(error) | |||
| raise error | |||
| try: | |||
| @@ -21,24 +21,18 @@ from ..constant import SEPARATOR_IN_SCOPE, NodeType | |||
| class InputNode(GraphNode): | |||
| """ | |||
| Pytorch Input Node. | |||
| PyTorch Input Node. | |||
| Args: | |||
| input_shape: Input shape of module. | |||
| """ | |||
| def convert_successful(self): | |||
| """ | |||
| Whether convert successful. | |||
| Returns: | |||
| bool, true or false. | |||
| """ | |||
| return False | |||
| def _get_arg_name(self, arg, variable_name): | |||
| raise NotImplementedError() | |||
| def froze_node_type_and_module_name(self, node_type, module_name): | |||
| pass | |||
| def to_code(self, ipt_args_in_construct: str, variable_name: str, output_var: str, code_fragment): | |||
| raise NotImplementedError() | |||
| def _get_raw_params(self, node): | |||
| pass | |||
| @@ -56,9 +50,6 @@ class InputNode(GraphNode): | |||
| def replace_with_arg(self, src_arg, tgt_arg): | |||
| pass | |||
| def _get_arg_name(self, arg): | |||
| pass | |||
| def add_input_and_output_shape(self, input_shape, output_shape): | |||
| pass | |||
| @@ -116,15 +107,8 @@ class InputNode(GraphNode): | |||
| def real_name(self): | |||
| return | |||
| @property | |||
| def variable_name(self): | |||
| return | |||
| def to_ir(self): | |||
| """ | |||
| No need to implement for now. | |||
| """ | |||
| raise NotImplementedError() | |||
| def to_code(self, ipt_args_in_construct: str, output_var: str): | |||
| raise NotImplementedError() | |||
| @@ -22,7 +22,6 @@ from .onnx_graph_node import OnnxGraphNode | |||
| from .graph_parser import TFGraphParser | |||
| from .onnx_utils import OnnxDataLoader | |||
| NONE_SCOPE_OP = { | |||
| "onnx::Add": "Add", | |||
| "onnx::Flatten": "Flatten", | |||
| @@ -13,14 +13,13 @@ | |||
| # limitations under the License. | |||
| # ============================================================================== | |||
| """Define ONNX graph node.""" | |||
| from importlib import import_module | |||
| from copy import deepcopy | |||
| from .base import GraphNode | |||
| from ..common.utils import is_converted | |||
| from ..constant import NodeType, SEPARATOR_IN_SCOPE, \ | |||
| SEPARATOR_BTW_NAME_AND_ID, LEFT_BUCKET, RIGHT_BUCKET, SEPARATOR_IN_ONNX_OP, InputType | |||
| from ..mapper.base import Mapper | |||
| from ...common.exceptions import NodeInputTypeNotSupport | |||
| from ..constant import NodeType, SEPARATOR_IN_SCOPE, SEPARATOR_BTW_NAME_AND_ID, LEFT_BUCKET, RIGHT_BUCKET, \ | |||
| SEPARATOR_IN_ONNX_OP | |||
| class OnnxGraphNode(GraphNode): | |||
| @@ -39,16 +38,13 @@ class OnnxGraphNode(GraphNode): | |||
| self._op_params = self._get_raw_params(node.raw_node) if node else None | |||
| self._op_name = "onnx::" + node.op_type if node else None | |||
| self._scope_name = node.scope_name if node else None | |||
| self._opt_var_name = None | |||
| self._variable_name = self._extract_var_name(self._scope_name) | |||
| self._module_name = None | |||
| self._weight = weight | |||
| def clear_args_of_declaration(self): | |||
| """Clear `self._args_in_code`.""" | |||
| self._args_in_code = dict() | |||
| def _get_arg_name(self, arg): | |||
| def _get_arg_name(self, arg, variable_name): | |||
| """ | |||
| Get arg name. | |||
| @@ -58,7 +54,7 @@ class OnnxGraphNode(GraphNode): | |||
| Returns: | |||
| str, arg name in function or class declaration. | |||
| """ | |||
| return f"{arg}_{self._variable_name}" | |||
| return f"{arg}_{variable_name}" | |||
| @property | |||
| def hash_key(self): | |||
| @@ -84,51 +80,6 @@ class OnnxGraphNode(GraphNode): | |||
| """ | |||
| self._hash_key = h | |||
| @property | |||
| def variable_name(self): | |||
| """ | |||
| Variable name. | |||
| Returns: | |||
| str, variable name declared in init. | |||
| """ | |||
| return self._variable_name | |||
| @variable_name.setter | |||
| def variable_name(self, v): | |||
| """ | |||
| Setter of variable name. | |||
| Args: | |||
| v (str): Variable name. | |||
| """ | |||
| self._variable_name = v | |||
| @property | |||
| def module_name(self): | |||
| """ | |||
| Module name. | |||
| Returns: | |||
| str, module name. | |||
| """ | |||
| if not self._module_name_frozen: | |||
| module_name = self.tag | |||
| return module_name | |||
| return self._module_name | |||
| def _froze_module_name(self, m): | |||
| """ | |||
| Once module_name is set, then it's unchangeable. | |||
| Args: | |||
| m (str): Module name. | |||
| """ | |||
| if not self._module_name_frozen: | |||
| self._module_name = m | |||
| self._module_name_frozen = True | |||
| @property | |||
| def op_name(self): | |||
| """ | |||
| @@ -154,15 +105,13 @@ class OnnxGraphNode(GraphNode): | |||
| self._ipt_shape = input_shape | |||
| self._opt_shape = output_shape | |||
| def _add_tensor_args_to_code(self, op_name: str, t_identifier: str, declare, args): | |||
| def _add_tensor_args_to_code(self, op_name: str, settings, declare, args, variable_name): | |||
| """ | |||
| Add nn used tensors to args in init and construct blocks. | |||
| Args: | |||
| op_name (str): Add the tensor to args if the current node has this | |||
| op_name. | |||
| t_identifier (str): The unique string appeared in the target tensor | |||
| name. | |||
| op_name. | |||
| declare (str): Declare statement generated in to_code(). | |||
| args (str): Args statement generated in to_code(). | |||
| @@ -172,103 +121,68 @@ class OnnxGraphNode(GraphNode): | |||
| """ | |||
| if not self._op_name == op_name: | |||
| return declare, args | |||
| declare_list = [] | |||
| tensor = None | |||
| # find target tensor | |||
| for t_name, t_value in self._weight.items(): | |||
| if t_identifier in t_name: | |||
| tensor = t_value | |||
| break | |||
| if tensor is None: | |||
| if not settings or not settings.op_extra_tensor: | |||
| return declare, args | |||
| declare_list.append(declare) | |||
| declare_t = f"self.{self._variable_name}_w = Tensor(" \ | |||
| f"np.random.uniform(0, 1, {str(tensor.shape)}), mindspore.float32)" | |||
| declare_list = [declare] | |||
| declare_t = f"self.{variable_name}_w = Tensor(" \ | |||
| f"np.random.uniform(0, 1, {str(settings.op_extra_tensor.shape)}), " \ | |||
| f"{settings.op_extra_tensor.dtype})" | |||
| declare_list.append(declare_t) | |||
| args += f", self.{self._variable_name}_w" | |||
| args += f", self.{variable_name}_w" | |||
| return declare_list, args | |||
| def to_code(self, ipt_args_in_construct: str, output_var: str): | |||
| def to_code(self, ipt_args_in_construct: str, variable_name: str, output_var: str, | |||
| code_fragment): | |||
| """ | |||
| Generate statements. | |||
| Args: | |||
| variable_name (str): Variable name. | |||
| ipt_args_in_construct (str): Args of input. | |||
| output_var (str): Output variable name in construct. | |||
| code_fragment (CodeFragment): CodeFragment instance. | |||
| Returns: | |||
| Union[str, str], declare in init and call in construct. | |||
| """ | |||
| operator = self.op_in_ms or self.module_name | |||
| self._opt_var_name = output_var | |||
| operator = code_fragment.operation | |||
| args = self.args_in_code | |||
| settings = self.settings_in_code | |||
| if self._node_type == NodeType.OPERATION.value and not self.convert_successful(): | |||
| settings = code_fragment.code_setting | |||
| if self._node_type == NodeType.OPERATION.value and not is_converted(code_fragment.operation): | |||
| args.update({"input_shape": self.input_shape, | |||
| "output_shape": self.output_shape}) | |||
| if self._node_type == NodeType.OPERATION.value: | |||
| expr = ", ".join([f"{k.replace(f'_{self._variable_name}', '')}={v}" | |||
| expr = ", ".join([f"{k.replace(f'_{variable_name}', '')}={v}" | |||
| for k, v in args.items()]) | |||
| ipt_args_settings_in_construct = \ | |||
| self._generate_ipt_args_settings_in_construct( | |||
| ipt_args_in_construct, | |||
| settings) | |||
| ipt_args_settings_in_construct = self._generate_ipt_args_settings_in_construct( | |||
| ipt_args_in_construct, settings) | |||
| else: | |||
| # When it's type is module, class or func, | |||
| # it's not necessary to replace var. | |||
| expr = ", ".join([f"{k.replace(f'_{self._variable_name}', '')}={v}" | |||
| expr = ", ".join([f"{k.replace(f'_{variable_name}', '')}={v}" | |||
| for k, v in args.items()]) | |||
| ipt_args_settings_in_construct = ipt_args_in_construct | |||
| declare = f"self.{self._variable_name} = {operator}({expr})" | |||
| if SEPARATOR_IN_ONNX_OP in operator: | |||
| operator = operator.replace(SEPARATOR_IN_ONNX_OP, ".") | |||
| declare = f"self.{variable_name} = {operator}({expr})" | |||
| # Extra Tensor generator for nn.MatMul | |||
| declare, ipt_args_settings_in_construct = self._add_tensor_args_to_code( | |||
| 'onnx::MatMul', 'MatMul', declare, ipt_args_settings_in_construct) | |||
| 'onnx::MatMul', settings, declare, ipt_args_settings_in_construct, variable_name) | |||
| # Extra Tensor generator for onnx::Add | |||
| declare, ipt_args_settings_in_construct = self._add_tensor_args_to_code( | |||
| 'onnx::Add', 'BiasAdd', declare, ipt_args_settings_in_construct) | |||
| 'onnx::Add', settings, declare, ipt_args_settings_in_construct, variable_name) | |||
| call = f"{self._opt_var_name} = self.{self._variable_name}({ipt_args_settings_in_construct})" | |||
| call = f"{output_var} = self.{variable_name}({ipt_args_settings_in_construct})" | |||
| return declare, call | |||
| @staticmethod | |||
| def _generate_ipt_args_settings_in_construct(ipt_args_in_construct, settings): | |||
| """ | |||
| Generate input with args and settings in construct. | |||
| Args: | |||
| ipt_args_in_construct(str): Input args in construct. | |||
| settings(dict): Settings in operator. | |||
| Returns: | |||
| str, args of each node in generated construct statement. | |||
| """ | |||
| if settings.get('input_type'): | |||
| input_type = settings['input_type'] | |||
| if input_type == InputType.TENSOR.value: | |||
| ipt_args_settings_in_construct = ipt_args_in_construct | |||
| elif input_type == InputType.LIST.value: | |||
| ipt_args_settings_in_construct = f"({ipt_args_in_construct})" | |||
| else: | |||
| raise NodeInputTypeNotSupport( | |||
| f"Input type[{input_type}] is not supported now.") | |||
| else: | |||
| ipt_args_settings_in_construct = ipt_args_in_construct | |||
| if settings.get('values'): | |||
| settings_value = settings['values'] | |||
| if settings_value: | |||
| settings_in_construct = ', '.join( | |||
| [f"{setting_val}" for _, setting_val in settings_value.items()]) | |||
| ipt_args_settings_in_construct = ', '.join( | |||
| (ipt_args_settings_in_construct, settings_in_construct)) | |||
| return ipt_args_settings_in_construct | |||
| def to_ir(self): | |||
| """No need to implement for now.""" | |||
| raise NotImplementedError | |||
| @@ -284,7 +198,7 @@ class OnnxGraphNode(GraphNode): | |||
| Returns: | |||
| dict, raw params. | |||
| """ | |||
| import onnx | |||
| onnx = import_module("onnx") | |||
| raw_params = dict() | |||
| @@ -318,62 +232,3 @@ class OnnxGraphNode(GraphNode): | |||
| var = var.replace(LEFT_BUCKET, SEPARATOR_BTW_NAME_AND_ID).replace( | |||
| RIGHT_BUCKET, "") | |||
| return var | |||
| def param_transform(self, mapper: Mapper): | |||
| """ | |||
| Transform tensorflow params into mindspore. | |||
| Args: | |||
| mapper (Mapper): Mapper of params. | |||
| """ | |||
| if self._node_type != NodeType.OPERATION.value: | |||
| args = deepcopy(self._args_in_code) | |||
| self._args_in_code = dict() | |||
| for arg, value in args.items(): | |||
| self._args_in_code[self._get_arg_name(arg)] = value | |||
| return None, None | |||
| if not self.transformed: | |||
| _, _, _ = super(OnnxGraphNode, self).param_transform(mapper) | |||
| for arg, value in self._params_in_ms.items(): | |||
| self._args_in_code[self._get_arg_name(arg)] = value | |||
| for arg, value in self._settings_in_ms.items(): | |||
| self._settings_in_code[arg] = value | |||
| self.transformed = True | |||
| return self._op_in_ms, self._params_in_ms, self._settings_in_ms | |||
| def froze_node_type_and_module_name(self, node_type, module_name): | |||
| """ | |||
| Froze node type and module name. | |||
| After node_type is frozen, then the `module_name` | |||
| will be affected when `node_type` is `class`. | |||
| Thus, this line must be placed before `nd_inst.data.module_name`. | |||
| Args: | |||
| module_name: Modified module name. | |||
| node_type (str): Node type, class of func. | |||
| """ | |||
| if not self._type_frozen: | |||
| self._node_type = node_type | |||
| self._type_frozen = True | |||
| if not self._module_name_frozen: | |||
| self._froze_module_name(module_name) | |||
| def convert_successful(self): | |||
| """ | |||
| Whether convert successfully. | |||
| Returns: | |||
| bool, true or false. | |||
| """ | |||
| if self._op_in_ms and SEPARATOR_IN_ONNX_OP not in self._op_in_ms: | |||
| return True | |||
| return False | |||
| @@ -87,7 +87,8 @@ def convert_tf_graph_to_onnx(model_path, model_inputs, model_outputs, opset=None | |||
| inputs_as_nchw=None | |||
| ) | |||
| opt_map = getattr(optimizer.back_to_back_optimizer, '_func_map') | |||
| opt_map.pop(('Conv', 'BatchNormalization')) | |||
| if ('Conv', 'BatchNormalization') in opt_map: | |||
| opt_map.pop(('Conv', 'BatchNormalization')) | |||
| onnx_graph = optimizer.optimize_graph(g) | |||
| model_proto = onnx_graph.make_model("converted from {}".format(model_path)) | |||
| @@ -228,8 +229,7 @@ class OnnxNode(BaseNode): | |||
| """ | |||
| def __init__(self, raw_node): | |||
| super(OnnxNode, self).__init__( | |||
| node_name=raw_node.name, op_type=raw_node.op_type) | |||
| super(OnnxNode, self).__init__(node_name=raw_node.name, op_type=raw_node.op_type) | |||
| self.raw_node = raw_node | |||
| self.params = ParamsAttribute(raw_node.attribute, raw_node) | |||
| self.scope_name = None | |||
| @@ -99,8 +99,8 @@ class PyTorchGraph(Graph): | |||
| for item in input_shape: | |||
| if not isinstance(item, int): | |||
| err_msg = f"Only support model with one input now, " \ | |||
| f"and each shape value in `input_shape` should be int." | |||
| err_msg = "Only support model with one input now, " \ | |||
| "and each shape value in `input_shape` should be int." | |||
| log.error(err_msg) | |||
| raise ValueError(err_msg) | |||
| @@ -13,14 +13,11 @@ | |||
| # limitations under the License. | |||
| # ============================================================================== | |||
| """Define PyTorch graph node.""" | |||
| from copy import deepcopy | |||
| from .base import GraphNode | |||
| from ..common.utils import is_converted | |||
| from ..constant import NodeType, SEPARATOR_IN_SCOPE, SEPARATOR_BTW_NAME_AND_ID, LEFT_BUCKET, RIGHT_BUCKET, \ | |||
| SEPARATOR_IN_ONNX_OP, InputType | |||
| from ..mapper.base import Mapper | |||
| from ...common.exceptions import NodeInputTypeNotSupport | |||
| SEPARATOR_IN_ONNX_OP | |||
| class PyTorchGraphNode(GraphNode): | |||
| @@ -40,9 +37,6 @@ class PyTorchGraphNode(GraphNode): | |||
| self._op_params = self._get_raw_params(node) | |||
| self._op_name = node.kind() if node else None | |||
| self._scope_name = node.scopeName() if node else None | |||
| self._opt_var_name = None | |||
| self._variable_name = self._extract_var_name(self._scope_name) | |||
| self._module_name = None | |||
| self._weight = weight | |||
| def clear_args_of_declaration(self): | |||
| @@ -51,7 +45,7 @@ class PyTorchGraphNode(GraphNode): | |||
| """ | |||
| self._args_in_code = dict() | |||
| def _get_arg_name(self, arg): | |||
| def _get_arg_name(self, arg, variable_name): | |||
| """ | |||
| Get arg name. | |||
| @@ -61,7 +55,7 @@ class PyTorchGraphNode(GraphNode): | |||
| Returns: | |||
| str, arg name in function or class declaration. | |||
| """ | |||
| return f"{arg}_{self._variable_name}" | |||
| return f"{arg}_{variable_name}" | |||
| @property | |||
| def hash_key(self): | |||
| @@ -88,53 +82,6 @@ class PyTorchGraphNode(GraphNode): | |||
| """ | |||
| self._hash_key = h | |||
| @property | |||
| def variable_name(self): | |||
| """ | |||
| Variable name. | |||
| Returns: | |||
| str, variable name declared in init. | |||
| """ | |||
| return self._variable_name | |||
| @variable_name.setter | |||
| def variable_name(self, v): | |||
| """ | |||
| Setter of variable name. | |||
| Args: | |||
| v (str): Variable name. | |||
| """ | |||
| self._variable_name = v | |||
| @property | |||
| def module_name(self): | |||
| """ | |||
| Module name. | |||
| Returns: | |||
| str, module name. | |||
| """ | |||
| if not self._module_name_frozen: | |||
| module_name = self.tag | |||
| return module_name | |||
| return self._module_name | |||
| def _froze_module_name(self, m): | |||
| """ | |||
| Once module_name is set, then it's unchangeable. | |||
| Args: | |||
| m (str): Module name. | |||
| """ | |||
| if not self._module_name_frozen: | |||
| self._module_name = m | |||
| self._module_name_frozen = True | |||
| @property | |||
| def op_name(self): | |||
| """ | |||
| @@ -172,72 +119,47 @@ class PyTorchGraphNode(GraphNode): | |||
| self._ipt_shape = input_shape | |||
| self._opt_shape = output_shape | |||
| def to_code(self, ipt_args_in_construct: str, output_var: str): | |||
| def to_code(self, ipt_args_in_construct: str, variable_name: str, output_var: str, code_fragment): | |||
| """ | |||
| Generate statements. | |||
| Args: | |||
| variable_name (str): Variable name. | |||
| ipt_args_in_construct (str): Args of input. | |||
| output_var (str): Output variable name in construct. | |||
| code_fragment (CodeFragment): CodeFragment instance. | |||
| Returns: | |||
| Union[str, str], declare in init and call in construct. | |||
| """ | |||
| operator = self.op_in_ms or self.module_name | |||
| self._opt_var_name = output_var | |||
| operator = code_fragment.operation | |||
| args = self.args_in_code | |||
| settings = self.settings_in_code | |||
| settings = code_fragment.code_setting | |||
| if self._node_type == NodeType.OPERATION.value and not self.convert_successful(): | |||
| if self._node_type == NodeType.OPERATION.value and not is_converted(code_fragment.operation): | |||
| args.update({"input_shape": self.input_shape, | |||
| "output_shape": self.output_shape}) | |||
| if self._node_type == NodeType.OPERATION.value: | |||
| expr = ", ".join([f"{k.replace(f'_{self._variable_name}', '')}={v}" | |||
| expr = ", ".join([f"{k.replace(f'_{variable_name}', '')}={v}" | |||
| for k, v in args.items()]) | |||
| ipt_args_settings_in_construct = self._generate_ipt_args_settings_in_construct(ipt_args_in_construct, | |||
| settings) | |||
| ipt_args_settings_in_construct = self._generate_ipt_args_settings_in_construct( | |||
| ipt_args_in_construct, settings) | |||
| else: | |||
| # When it's type is module, class or func, | |||
| # it's not necessary to replace var. | |||
| expr = ", ".join([f"{k.replace(f'_{self._variable_name}', '')}={v}" | |||
| expr = ", ".join([f"{k.replace(f'_{variable_name}', '')}={v}" | |||
| for k, v in args.items()]) | |||
| ipt_args_settings_in_construct = ipt_args_in_construct | |||
| declare = f"self.{self._variable_name} = {operator}({expr})" | |||
| call = f"{self._opt_var_name} = self.{self._variable_name}({ipt_args_settings_in_construct})" | |||
| return declare, call | |||
| @staticmethod | |||
| def _generate_ipt_args_settings_in_construct(ipt_args_in_construct, settings): | |||
| """ | |||
| Generate input with args and settings in construct. | |||
| Args: | |||
| ipt_args_in_construct(str): input args in construct. | |||
| settings(dict): settings in operator. | |||
| """ | |||
| if settings.get('input_type'): | |||
| input_type = settings['input_type'] | |||
| if input_type == InputType.TENSOR.value: | |||
| ipt_args_settings_in_construct = ipt_args_in_construct | |||
| elif input_type == InputType.LIST.value: | |||
| ipt_args_settings_in_construct = f"({ipt_args_in_construct})" | |||
| else: | |||
| raise NodeInputTypeNotSupport(f"Input type[{input_type}] is not supported now.") | |||
| else: | |||
| ipt_args_settings_in_construct = ipt_args_in_construct | |||
| if SEPARATOR_IN_ONNX_OP in operator: | |||
| operator = operator.replace(SEPARATOR_IN_ONNX_OP, ".") | |||
| if settings.get('values'): | |||
| settings_value = settings['values'] | |||
| if settings_value: | |||
| settings_in_construct = ', '.join([f"{setting_val}" for _, setting_val in settings_value.items()]) | |||
| ipt_args_settings_in_construct = ', '.join((ipt_args_settings_in_construct, settings_in_construct)) | |||
| declare = f"self.{variable_name} = {operator}({expr})" | |||
| call = f"{output_var} = self.{variable_name}({ipt_args_settings_in_construct})" | |||
| return ipt_args_settings_in_construct | |||
| return declare, call | |||
| def to_ir(self): | |||
| """ | |||
| @@ -288,62 +210,3 @@ class PyTorchGraphNode(GraphNode): | |||
| var = var.replace(LEFT_BUCKET, SEPARATOR_BTW_NAME_AND_ID).replace( | |||
| RIGHT_BUCKET, "") | |||
| return var | |||
| def param_transform(self, mapper: Mapper): | |||
| """ | |||
| Transform torch params into mindspore. | |||
| Args: | |||
| mapper (Mapper): Mapper of params. | |||
| """ | |||
| if self._node_type != NodeType.OPERATION.value: | |||
| args = deepcopy(self._args_in_code) | |||
| self._args_in_code = dict() | |||
| for arg, value in args.items(): | |||
| self._args_in_code[self._get_arg_name(arg)] = value | |||
| return None, None, None | |||
| if not self.transformed: | |||
| _, _, _ = super(PyTorchGraphNode, self).param_transform(mapper) | |||
| for arg, value in self._params_in_ms.items(): | |||
| self._args_in_code[self._get_arg_name(arg)] = value | |||
| for arg, value in self._settings_in_ms.items(): | |||
| self._settings_in_code[arg] = value | |||
| self.transformed = True | |||
| return self._op_in_ms, self._params_in_ms, self._settings_in_ms | |||
| def froze_node_type_and_module_name(self, node_type, module_name): | |||
| """ | |||
| Froze node type and module name. | |||
| After node_type is frozen, then the `module_name` | |||
| will be affected when `node_type` is `class`. | |||
| Thus, this line must be placed before `nd_inst.data.module_name`. | |||
| Args: | |||
| module_name: Modified module name. | |||
| node_type (str): Node type, class of func. | |||
| """ | |||
| if not self._type_frozen: | |||
| self._node_type = node_type | |||
| self._type_frozen = True | |||
| if not self._module_name_frozen: | |||
| self._froze_module_name(module_name) | |||
| def convert_successful(self): | |||
| """ | |||
| Whether convert successfully. | |||
| Returns: | |||
| bool, true or false. | |||
| """ | |||
| if self._op_in_ms and SEPARATOR_IN_ONNX_OP not in self._op_in_ms: | |||
| return True | |||
| return False | |||
| @@ -42,7 +42,7 @@ class TestHierarchicalTree: | |||
| get_raw_params.return_value = [] | |||
| tree = HierarchicalTree() | |||
| pt_node = PyTorchGraphNode() | |||
| tree.insert(pt_node, 'ResNet', (1, 3, 224, 224), (1, 64, 112, 112)) | |||
| tree.insert(pt_node, 'ResNet') | |||
| assert tree.root == 'ResNet' | |||
| def test_remove(self): | |||
| @@ -17,11 +17,13 @@ import numpy as np | |||
| import pytest | |||
| from mindinsight.mindconverter.graph_based_converter.mapper.base import ONNXToMindSporeMapper | |||
| from mindinsight.mindconverter.graph_based_converter.mapper.gen_setting import Setting | |||
| from tests.utils import mindspore | |||
| class TestMappers: | |||
| """Test Mappers.""" | |||
| @pytest.mark.parametrize('params', [{ | |||
| 'input': {'op_name': 'onnx::Conv', | |||
| 'params': {'dilations': [1, 1], | |||
| @@ -38,7 +40,7 @@ class TestMappers: | |||
| 'pad_mode': '\"pad\"', | |||
| 'dilation': (1, 1), | |||
| 'group': 1}, | |||
| 'converted_settings': dict()} | |||
| 'converted_settings': Setting()} | |||
| }, { | |||
| 'input': {'op_name': 'onnx::Conv', | |||
| 'params': {'dilations': [1, 1], | |||
| @@ -55,7 +57,7 @@ class TestMappers: | |||
| 'pad_mode': '\"valid\"', | |||
| 'dilation': (1, 1), | |||
| 'group': 1}, | |||
| 'converted_settings': dict()} | |||
| 'converted_settings': Setting()} | |||
| }, { | |||
| 'input': {'op_name': 'onnx::Gemm', | |||
| 'params': dict(), | |||
| @@ -65,7 +67,7 @@ class TestMappers: | |||
| 'converted_params': {'in_channels': 3, | |||
| 'out_channels': 10, | |||
| 'has_bias': True}, | |||
| 'converted_settings': dict()} | |||
| 'converted_settings': Setting()} | |||
| }, { | |||
| 'input': {'op_name': 'onnx::BatchNormalization', | |||
| 'params': {'epsilon': 1e-5, | |||
| @@ -76,14 +78,14 @@ class TestMappers: | |||
| 'converted_params': {'num_features': 6, | |||
| 'eps': 1e-5, | |||
| 'momentum': 0.9}, | |||
| 'converted_settings': dict()} | |||
| 'converted_settings': Setting()} | |||
| }, { | |||
| 'input': {'op_name': 'onnx::Relu', | |||
| 'params': dict(), | |||
| 'weights': dict()}, | |||
| 'expected_output': {'converter_name': 'nn.ReLU', | |||
| 'converted_params': dict(), | |||
| 'converted_settings': dict()} | |||
| 'converted_settings': Setting()} | |||
| }, { | |||
| 'input': {'op_name': 'onnx::MaxPool', | |||
| 'params': {'kernel_shape': [3, 3], | |||
| @@ -94,7 +96,7 @@ class TestMappers: | |||
| 'converted_params': {'kernel_size': (3, 3), | |||
| 'stride': (2, 2), | |||
| 'pad_mode': '"same"'}, | |||
| 'converted_settings': dict()} | |||
| 'converted_settings': Setting()} | |||
| }, { | |||
| 'input': {'op_name': 'onnx::AveragePool', | |||
| 'params': {'kernel_shape': [3, 3], | |||
| @@ -105,7 +107,7 @@ class TestMappers: | |||
| 'converted_params': {'kernel_size': (3, 3), | |||
| 'stride': (2, 2), | |||
| 'pad_mode': '"same"'}, | |||
| 'converted_settings': dict()} | |||
| 'converted_settings': Setting()} | |||
| }, { | |||
| 'input': {'op_name': 'onnx::GlobalAveragePool', | |||
| 'params': {'input_shape': (1, 3, 10, 10), | |||
| @@ -113,21 +115,21 @@ class TestMappers: | |||
| 'weights': ''}, | |||
| 'expected_output': {'converter_name': 'nn.AvgPool2d', | |||
| 'converted_params': {'kernel_size': (10, 10)}, | |||
| 'converted_settings': dict()} | |||
| 'converted_settings': Setting()} | |||
| }, { | |||
| 'input': {'op_name': 'onnx::Flatten', | |||
| 'params': dict(), | |||
| 'weights': dict()}, | |||
| 'expected_output': {'converter_name': 'nn.Flatten', | |||
| 'converted_params': dict(), | |||
| 'converted_settings': dict()} | |||
| 'converted_settings': Setting()} | |||
| }, { | |||
| 'input': {'op_name': 'onnx::Add', | |||
| 'params': dict(), | |||
| 'weights': dict()}, | |||
| 'expected_output': {'converter_name': 'P.TensorAdd', | |||
| 'converted_params': dict(), | |||
| 'converted_settings': dict()} | |||
| 'converted_settings': Setting()} | |||
| }, { | |||
| 'input': {'op_name': 'onnx::Pad', | |||
| 'params': {'pads': [0, 1, 2, 3], | |||
| @@ -137,7 +139,7 @@ class TestMappers: | |||
| 'expected_output': {'converter_name': 'nn.Pad', | |||
| 'converted_params': {'paddings': ((0, 2), (1, 3)), | |||
| 'mode': '\"CONSTANT\"'}, | |||
| 'converted_settings': dict()} | |||
| 'converted_settings': Setting()} | |||
| }, { | |||
| 'input': {'op_name': 'onnx::Pad', | |||
| 'params': {'pads': [0, 1, 2, 3], | |||
| @@ -146,7 +148,7 @@ class TestMappers: | |||
| 'expected_output': {'converter_name': 'nn.Pad', | |||
| 'converted_params': {'paddings': ((0, 2), (1, 3)), | |||
| 'mode': '\"REFLECT\"'}, | |||
| 'converted_settings': dict()} | |||
| 'converted_settings': Setting()} | |||
| }, { | |||
| 'input': {'op_name': 'onnx::Pad', | |||
| 'params': {'pads': [0, 1, 2, 3], | |||
| @@ -156,7 +158,7 @@ class TestMappers: | |||
| 'expected_output': {'converter_name': 'nn.Pad', | |||
| 'converted_params': {'paddings': ((0, 2), (1, 3)), | |||
| 'mode': '{UNSUPPORTED: value is NOT 0}\"CONSTANT\"'}, | |||
| 'converted_settings': dict()} | |||
| 'converted_settings': Setting()} | |||
| }, { | |||
| 'input': {'op_name': 'onnx::Pad', | |||
| 'params': {'pads': [0, 1, 2, 3], | |||
| @@ -165,7 +167,7 @@ class TestMappers: | |||
| 'expected_output': {'converter_name': 'nn.Pad', | |||
| 'converted_params': {'paddings': ((0, 2), (1, 3)), | |||
| 'mode': '{UNSUPPORTED: \"edge\"}\"UNKNOWN\"'}, | |||
| 'converted_settings': dict()} | |||
| 'converted_settings': Setting()} | |||
| }, { | |||
| 'input': {'op_name': 'onnx::ReduceMean', | |||
| 'params': {'keepdims': 0, | |||
| @@ -196,14 +198,14 @@ class TestMappers: | |||
| 'weights': dict()}, | |||
| 'expected_output': {'converter_name': 'nn.ReLU6', | |||
| 'converted_params': dict(), | |||
| 'converted_settings': dict()} | |||
| 'converted_settings': Setting()} | |||
| }, { | |||
| 'input': {'op_name': 'onnx::Clip', | |||
| 'params': dict(), | |||
| 'weights': dict()}, | |||
| 'expected_output': {'converter_name': 'nn.ReLU', | |||
| 'converted_params': dict(), | |||
| 'converted_settings': dict()} | |||
| 'converted_settings': Setting()} | |||
| }, { | |||
| 'input': {'op_name': 'onnx::Clip', | |||
| 'params': {'max': 3, | |||
| @@ -211,13 +213,13 @@ class TestMappers: | |||
| 'weights': dict()}, | |||
| 'expected_output': {'converter_name': None, | |||
| 'converted_params': dict(), | |||
| 'converted_settings': dict()} | |||
| 'converted_settings': Setting()} | |||
| }]) | |||
| def test_mapper(self, params): | |||
| """Test mapper function.""" | |||
| mapper = ONNXToMindSporeMapper() | |||
| converter_name, converted_params, converted_settings = \ | |||
| converter_name, converted_params, converted_settings, _ = \ | |||
| mapper.convert(params['input']['op_name'], params['input']['params'], params['input']['weights']) | |||
| assert params['expected_output']['converter_name'] == converter_name | |||
| assert params['expected_output']['converted_params'] == converted_params | |||
| assert params['expected_output']['converted_settings'] == converted_settings | |||
| assert isinstance(converted_settings, Setting) | |||