From: @moran3 Reviewed-by: Signed-off-by:tags/v1.2.0-rc1
| @@ -1,4 +1,4 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd.All Rights Reserved. | |||
| # Copyright 2020-2021 Huawei Technologies Co., Ltd.All Rights Reserved. | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| @@ -194,6 +194,9 @@ def convert_bytes_string_to_string(bytes_str): | |||
| def get_framework_type(model_path): | |||
| """Get framework type.""" | |||
| if model_path.endswith('.onnx'): | |||
| return FrameworkType.PYTORCH.value | |||
| try: | |||
| with open(model_path, 'rb') as f: | |||
| if f.read(BINARY_HEADER_PYTORCH_BITS) == BINARY_HEADER_PYTORCH_FILE: | |||
| @@ -1,4 +1,4 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd.All Rights Reserved. | |||
| # Copyright 2020-2021 Huawei Technologies Co., Ltd.All Rights Reserved. | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| @@ -24,10 +24,12 @@ from mindinsight.mindconverter.graph_based_converter.common.utils import lib_ver | |||
| save_code_file_and_report, get_framework_type | |||
| from mindinsight.mindconverter.graph_based_converter.constant import FrameworkType, \ | |||
| ONNX_MIN_VER, TF2ONNX_MIN_VER, ONNXRUNTIME_MIN_VER | |||
| from mindinsight.mindconverter.graph_based_converter.generator import batch_add_nodes | |||
| from mindinsight.mindconverter.graph_based_converter.mapper import ONNXToMindSporeMapper | |||
| from mindinsight.mindconverter.common.log import logger as log, logger_console as log_console | |||
| from mindinsight.mindconverter.common.exceptions import GraphInitError, TreeCreationError, SourceFilesSaveError, \ | |||
| BaseConverterError, UnknownModelError, GeneratorError, TfRuntimeError, RuntimeIntegrityError, ParamMissingError | |||
| from mindinsight.mindconverter.graph_based_converter.third_party_graph import GraphFactory | |||
| permissions = os.R_OK | os.W_OK | os.X_OK | |||
| os.umask(permissions << 3 | permissions) | |||
| @@ -62,6 +64,7 @@ def torch_installation_validation(func): | |||
| """ | |||
| def _f(graph_path: str, sample_shape: tuple, | |||
| input_nodes: str, output_nodes: str, | |||
| output_folder: str, report_folder: str = None): | |||
| # Check whether pytorch is installed. | |||
| if not find_spec("torch") or not find_spec("onnx") or not find_spec("onnxruntime"): | |||
| @@ -93,6 +96,7 @@ def torch_installation_validation(func): | |||
| sys.exit(0) | |||
| func(graph_path=graph_path, sample_shape=sample_shape, | |||
| input_nodes=input_nodes, output_nodes=output_nodes, | |||
| output_folder=output_folder, report_folder=report_folder) | |||
| return _f | |||
| @@ -182,6 +186,7 @@ def _extract_model_name(model_path): | |||
| @SourceFilesSaveError.uniform_catcher() | |||
| @GeneratorError.uniform_catcher() | |||
| def graph_based_converter_pytorch_to_ms(graph_path: str, sample_shape: tuple, | |||
| input_nodes: str, output_nodes: str, | |||
| output_folder: str, report_folder: str = None): | |||
| """ | |||
| PyTorch to MindSpore based on Graph. | |||
| @@ -189,26 +194,18 @@ def graph_based_converter_pytorch_to_ms(graph_path: str, sample_shape: tuple, | |||
| Args: | |||
| graph_path (str): Graph file path. | |||
| sample_shape (tuple): Input shape of the model. | |||
| input_nodes (str): Input node(s) of the model. | |||
| output_nodes (str): Output node(s) of the model. | |||
| output_folder (str): Output folder. | |||
| report_folder (str): Report output folder path. | |||
| """ | |||
| third_party_graph_module = import_module( | |||
| 'mindinsight.mindconverter.graph_based_converter.third_party_graph') | |||
| hierarchical_tree_module = import_module( | |||
| 'mindinsight.mindconverter.graph_based_converter.hierarchical_tree') | |||
| cls_graph_factory = getattr(third_party_graph_module, 'GraphFactory') | |||
| cls_hierarchical_tree_factory = getattr(hierarchical_tree_module, 'HierarchicalTreeFactory') | |||
| graph_obj = cls_graph_factory.init(graph_path, sample_shape=sample_shape) | |||
| hierarchical_tree = cls_hierarchical_tree_factory.create(graph_obj) | |||
| graph_obj = GraphFactory.init(graph_path, sample_shape=sample_shape, | |||
| input_nodes=input_nodes, output_nodes=output_nodes) | |||
| generator_inst = batch_add_nodes(graph_obj, ONNXToMindSporeMapper) | |||
| model_name = _extract_model_name(graph_path) | |||
| hierarchical_tree.save_source_files(output_folder, mapper=ONNXToMindSporeMapper, | |||
| model_name=model_name, | |||
| report_folder=report_folder) | |||
| code_fragments = generator_inst.generate() | |||
| save_code_file_and_report(model_name, code_fragments, output_folder, report_folder) | |||
| @tf_installation_validation | |||
| @@ -230,18 +227,13 @@ def graph_based_converter_tf_to_ms(graph_path: str, sample_shape: tuple, | |||
| output_nodes(str): Output node(s) of the model. | |||
| output_folder(str): Output folder. | |||
| report_folder(str): Report output folder path. | |||
| """ | |||
| third_party_graph_module = import_module( | |||
| 'mindinsight.mindconverter.graph_based_converter.third_party_graph') | |||
| cls_graph_factory = getattr(third_party_graph_module, 'GraphFactory') | |||
| batch_add_nodes = getattr(import_module('mindinsight.mindconverter.graph_based_converter.generator'), | |||
| "batch_add_nodes") | |||
| # Close unnecessary log. | |||
| os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2' | |||
| graph_obj = cls_graph_factory.init(graph_path, sample_shape=sample_shape, | |||
| input_nodes=input_nodes, output_nodes=output_nodes) | |||
| graph_obj = GraphFactory.init(graph_path, sample_shape=sample_shape, | |||
| input_nodes=input_nodes, output_nodes=output_nodes) | |||
| generator_inst = batch_add_nodes(graph_obj, ONNXToMindSporeMapper) | |||
| model_name = _extract_model_name(graph_path) | |||
| code_fragments = generator_inst.generate() | |||
| @@ -255,7 +247,6 @@ def main_graph_base_converter(file_config): | |||
| Args: | |||
| file_config (dict): The config of file which to convert. | |||
| """ | |||
| graph_path = file_config['model_file'] | |||
| frame_type = get_framework_type(graph_path) | |||
| @@ -263,8 +254,12 @@ def main_graph_base_converter(file_config): | |||
| raise ParamMissingError("Param missing, `--shape` is required when using graph mode.") | |||
| if frame_type == FrameworkType.PYTORCH.value: | |||
| check_params = ['input_nodes', 'output_nodes'] | |||
| check_params_exist(check_params, file_config) | |||
| graph_based_converter_pytorch_to_ms(graph_path=graph_path, | |||
| sample_shape=file_config['shape'], | |||
| input_nodes=file_config['input_nodes'], | |||
| output_nodes=file_config['output_nodes'], | |||
| output_folder=file_config['outfile_dir'], | |||
| report_folder=file_config['report_dir']) | |||
| elif frame_type == FrameworkType.TENSORFLOW.value: | |||
| @@ -1,4 +1,4 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd.All Rights Reserved. | |||
| # Copyright 2020-2021 Huawei Technologies Co., Ltd.All Rights Reserved. | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| @@ -213,10 +213,11 @@ class ArgsTranslationHelper: | |||
| Returns: | |||
| list, name of args to be formal. | |||
| """ | |||
| ret = list() | |||
| if len(args_translators) < 2: | |||
| # only one args_translator provided, no formal args. | |||
| return None | |||
| ret = [] | |||
| return ret | |||
| base_args_t = args_translators[0] | |||
| for arg_name, arg_val in base_args_t.actual_args.items(): | |||
| for args_t in args_translators[1:]: | |||
| @@ -1,4 +1,4 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd.All Rights Reserved. | |||
| # Copyright 2020-2021 Huawei Technologies Co., Ltd.All Rights Reserved. | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| @@ -24,7 +24,7 @@ from .module_struct import ModuleStruct | |||
| from .args_translator import ArgsTranslationHelper | |||
| from ..common.global_context import GlobalContext | |||
| from ...common.exceptions import GeneratorError | |||
| from ..hierarchical_tree.name_mgr import GlobalVarNameMgr | |||
| from ..common.name_mgr import GlobalVarNameMgr | |||
| from ..constant import NEW_LINE, SECOND_LEVEL_INDENT, FIRST_LEVEL_INDENT, CodeFormatConfig, get_imported_module | |||
| from ..report_generator import ReportGenerator | |||
| @@ -1,4 +1,4 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd.All Rights Reserved. | |||
| # Copyright 2020-2021 Huawei Technologies Co., Ltd.All Rights Reserved. | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| @@ -22,7 +22,7 @@ from ..common.utils import get_dict_key_by_value | |||
| from .args_translator import ArgsTranslation | |||
| from ..common.code_fragment import ModuleFragment | |||
| from ..common.global_context import GlobalContext | |||
| from ..hierarchical_tree.name_mgr import LocalVarNameMgr | |||
| from ..common.name_mgr import LocalVarNameMgr | |||
| class ModuleStruct: | |||
| @@ -1,4 +1,4 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd.All Rights Reserved. | |||
| # Copyright 2020-2021 Huawei Technologies Co., Ltd.All Rights Reserved. | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| @@ -18,7 +18,6 @@ from collections import OrderedDict | |||
| from .scope_utils import Scope | |||
| from .args_translator import ArgsTranslation | |||
| from ..common.code_fragment import CodeFragment | |||
| from ..third_party_graph.pytorch_graph_node import PyTorchGraphNode | |||
| from ..third_party_graph.onnx_graph_node import OnnxGraphNode | |||
| from ..common.global_context import GlobalContext | |||
| from ..constant import InputType | |||
| @@ -110,11 +109,6 @@ class NodeStruct: | |||
| self.graph_node_ref = gn | |||
| self.scope_name = gn.scope_name | |||
| def _update_from_pytorch_gn(self, gn: PyTorchGraphNode): | |||
| """Update basic info from PyTorchGraphNode.""" | |||
| self.node_type = "PyTorchGraphNode" | |||
| self._update_basics_from_gn(gn) | |||
| def _update_from_onnx_gn(self, gn: OnnxGraphNode): | |||
| """Update basic info from OnnxGraphNode.""" | |||
| self.node_type = "OnnxGraphNode" | |||
| @@ -177,9 +171,8 @@ class NodeStruct: | |||
| arg (Union[PyTorchGraphNode, OnnxGraphNode, dict]): Node related obj. | |||
| force_ready (bool): Force this NodeStruct is ready to generate. | |||
| """ | |||
| if isinstance(arg, PyTorchGraphNode): | |||
| self._update_from_pytorch_gn(arg) | |||
| elif isinstance(arg, OnnxGraphNode): | |||
| if isinstance(arg, OnnxGraphNode): | |||
| self._update_from_onnx_gn(arg) | |||
| elif isinstance(arg, (dict, OrderedDict)): | |||
| self._update_from_mapper(arg) | |||
| @@ -246,7 +239,6 @@ class NodeStruct: | |||
| """Return the output variable name of current node.""" | |||
| return "{}_opt".format(self.ms_var_name).lower() | |||
| @property | |||
| def args_translator(self): | |||
| """Return the args translator of this Node.""" | |||
| @@ -1,89 +0,0 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd.All Rights Reserved. | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================== | |||
| """Hierarchical tree module.""" | |||
| __all__ = ["HierarchicalTreeFactory"] | |||
| import re | |||
| from mindinsight.mindconverter.common.log import logger as log | |||
| from .hierarchical_tree import HierarchicalTree | |||
| from ..third_party_graph.onnx_graph_node import OnnxGraphNode | |||
| from ...common.exceptions import NodeInputMissingError, TreeNodeInsertError | |||
| def _tf_model_node_name_reformat(node: OnnxGraphNode, node_name): | |||
| """ | |||
| Rename the node name by combining scope name and its original name. | |||
| Args: | |||
| node (OnnxGraphNode): OnnxGraphNode instance. | |||
| node_name (str): node name saved in Graph. | |||
| Returns: | |||
| str, re-formatted node name. | |||
| """ | |||
| scope_name = node.scope_name | |||
| new_name = None | |||
| regex = r"(?P<parent>.+/)(?P<op>\w+)" | |||
| match = re.match(regex, scope_name) | |||
| parent = match.group("parent") | |||
| node_name = '$' + node_name.replace('/', '::') + '$' | |||
| if scope_name: | |||
| new_name = parent + node_name | |||
| if new_name: | |||
| return new_name | |||
| return node_name | |||
| class HierarchicalTreeFactory: | |||
| """Hierarchical tree factory.""" | |||
| @classmethod | |||
| @TreeNodeInsertError.check_except("Tree node inserts failed.") | |||
| def create(cls, graph): | |||
| """ | |||
| Factory method of hierarchical tree. | |||
| Args: | |||
| graph: Graph obj. | |||
| Returns: | |||
| HierarchicalTree, tree. | |||
| """ | |||
| tree = HierarchicalTree() | |||
| node_scope_name = dict() | |||
| for _, node_name in enumerate(graph.nodes_in_topological_order): | |||
| node_inst = graph.get_node(node_name) | |||
| node_input = graph.get_input_shape(node_name) | |||
| node_output = graph.get_output_shape(node_name) | |||
| if node_input != 0 and not node_input: | |||
| err_msg = f"This model is not supported now. " \ | |||
| f"Cannot find {node_name}'s input shape." | |||
| error = NodeInputMissingError(err_msg) | |||
| log.error(str(error)) | |||
| raise error | |||
| if isinstance(node_inst, OnnxGraphNode): | |||
| node_name_with_scope = _tf_model_node_name_reformat(node_inst, node_name) | |||
| node_scope_name[node_name] = node_name_with_scope | |||
| node_name = node_name_with_scope | |||
| node_inst.add_input_and_output_shape(node_input, node_output) | |||
| tree.insert(node_inst, node_name) | |||
| if node_scope_name: | |||
| return tree, node_scope_name | |||
| return tree | |||
| @@ -1,796 +0,0 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd.All Rights Reserved. | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================== | |||
| """Define hierarchical tree.""" | |||
| import os | |||
| import stat | |||
| from copy import deepcopy | |||
| from typing import NoReturn, Union | |||
| from queue import Queue | |||
| from yapf.yapflib.yapf_api import FormatCode | |||
| from treelib import Tree, Node | |||
| from mindinsight.mindconverter.common.log import logger as log | |||
| from .name_mgr import ModuleNameMgr, GlobalVarNameMgr | |||
| from ..common.utils import is_converted, save_code_file_and_report | |||
| from ..mapper.base import Mapper | |||
| from ..third_party_graph.pytorch_graph_node import PyTorchGraphNode | |||
| from ..third_party_graph.onnx_graph_node import OnnxGraphNode | |||
| from ..constant import SEPARATOR_IN_SCOPE, get_imported_module, NO_CONVERTED_OPERATORS | |||
| from ..constant import CodeFormatConfig | |||
| from ..constant import SEPARATOR_BTW_NAME_AND_ID, FIRST_LEVEL_INDENT | |||
| from ..constant import NEW_LINE, SECOND_LEVEL_INDENT | |||
| from ..constant import NodeType | |||
| from ..report_generator import ReportGenerator | |||
| from ...common.exceptions import ReportGenerationError, ScriptGenerationError, NodeInputTypeNotSupportError | |||
| class HierarchicalTree(Tree): | |||
| """Define hierarchical tree.""" | |||
| flags = os.O_WRONLY | os.O_CREAT | os.O_EXCL | |||
| modes = stat.S_IRUSR | stat.S_IWUSR | |||
| modes_usr = stat.S_IRUSR | stat.S_IWUSR | stat.S_IXUSR | |||
| _root_created = False | |||
| ROOT_LEVEL = 0 | |||
| GLOBAL_VAR_NAME_MGR = GlobalVarNameMgr() | |||
| def __init__(self): | |||
| super(HierarchicalTree, self).__init__() | |||
| self._hierarchical_order = dict() | |||
| # Manage mapping of unique key and module name. | |||
| self._merged_module = dict() | |||
| # Manage mapping of unique key and module args. | |||
| self._merged_module_args = dict() | |||
| # Record creation of module with unique key. | |||
| self._created_module = dict() | |||
| # Manage module name to used. | |||
| self._module_mgr = ModuleNameMgr() | |||
| # Manage variable name in a module. | |||
| self._vars_mgr_in_module = dict() | |||
| self._module_vars = dict() | |||
| # scope name mapping record for easy node searching | |||
| self._scope_name_map = dict() | |||
| self.code_fragment_recorder = dict() | |||
| @property | |||
| def tree_identifier(self): | |||
| """ | |||
| Return identifier of tree. | |||
| Returns: | |||
| tree, id of tree. | |||
| """ | |||
| return self.identifier | |||
| def get_node(self, nid): | |||
| """Override get_node method to support tf ver. generated scope.""" | |||
| if nid is None or not self.contains(nid): | |||
| if self._scope_name_map and nid in self._scope_name_map: | |||
| nid = self._scope_name_map.get(nid) | |||
| else: | |||
| return None | |||
| return self._nodes[nid] | |||
| def insert(self, node: Union[PyTorchGraphNode, OnnxGraphNode], node_name: str): | |||
| """ | |||
| Insert node into hierarchical tree. | |||
| Args: | |||
| node_name (str): Node name. | |||
| node (Union[PyTorchGraphNode, OnnxGraphNode]): Node to be inserted. | |||
| """ | |||
| scopes = node_name.split(SEPARATOR_IN_SCOPE) | |||
| for idx, scope in enumerate(scopes): | |||
| parent = SEPARATOR_IN_SCOPE.join(scopes[:idx]) | |||
| identifier = SEPARATOR_IN_SCOPE.join(scopes[:idx + 1]) | |||
| try_parent = f"{parent}{SEPARATOR_IN_SCOPE}{scope}" \ | |||
| if parent else scope | |||
| if self.contains(try_parent): | |||
| # Whether current node existed. | |||
| parent = try_parent | |||
| if not parent and not self._root_created: | |||
| # If no root node, then create it and mark it. | |||
| parent = None | |||
| self._root_created = True | |||
| elif not parent and self._root_created: | |||
| # Already have root node, skip it. | |||
| continue | |||
| if not self.contains(identifier): | |||
| # Insert node into tree. | |||
| if isinstance(node, OnnxGraphNode): | |||
| tgt_node = node if idx == len( | |||
| scopes) - 1 else OnnxGraphNode() | |||
| else: | |||
| tgt_node = node if idx == len( | |||
| scopes) - 1 else PyTorchGraphNode() | |||
| tgt_node.successor_nodes = node.successor_nodes | |||
| tgt_node.precursor_nodes = node.precursor_nodes | |||
| tgt_node.node_type = (NodeType.OPERATION if idx == len(scopes) - 1 | |||
| else NodeType.MODULE).value | |||
| tgt_node.variable_name = self._get_var_name(identifier) | |||
| self.create_node( | |||
| tag=scope.split(SEPARATOR_BTW_NAME_AND_ID)[0], | |||
| identifier=identifier, | |||
| parent=parent, | |||
| data=tgt_node | |||
| ) | |||
| def remove(self, node: Node, keep_sub=False): | |||
| """ | |||
| Remove node into hierarchical tree. | |||
| Args: | |||
| node (Node): Node to be removed. | |||
| keep_sub (bool): Whether keep sub-tree. | |||
| """ | |||
| if not keep_sub: | |||
| self.remove_node(node.identifier) | |||
| return | |||
| def shrink(self, node: Node): | |||
| """ | |||
| Shrink sub-tree into one node. | |||
| Use child node to replace its ancestor. | |||
| Args: | |||
| node (Node): List of nodes to be merged. | |||
| """ | |||
| node_name = node.identifier | |||
| parent_node = self[node.predecessor(self.tree_identifier)] | |||
| # Keep successors of parent. | |||
| brothers = deepcopy(parent_node.successors(self.tree_identifier)) | |||
| # Because shrink occurs when node has only one child, | |||
| # so we take index-0. | |||
| child = node.successors(self.tree_identifier)[0] | |||
| self.move_node(source=child, | |||
| destination=node.predecessor(self.tree_identifier)) | |||
| self.remove(node) | |||
| brothers[brothers.index(node_name)] = child | |||
| parent_node.set_successors(brothers, tree_id=self.tree_identifier) | |||
| def save_source_files(self, out_folder: str, mapper: Mapper, | |||
| model_name: str, | |||
| report_folder: str = None, | |||
| scope_name_map: dict = None) -> NoReturn: | |||
| """ | |||
| Save source codes to target folder. | |||
| Args: | |||
| report_folder (str): Report folder. | |||
| mapper (Mapper): Mapper of third party framework and mindspore. | |||
| model_name(str): Name of Converted model. | |||
| out_folder (str): Output folder. | |||
| scope_name_map(str): Scope name map of tensorflow. | |||
| """ | |||
| if scope_name_map: | |||
| self._scope_name_map = scope_name_map | |||
| try: | |||
| self._adjust_structure() | |||
| code_fragments = self._generate_codes(mapper) | |||
| except (NodeInputTypeNotSupportError, ScriptGenerationError, ReportGenerationError) as e: | |||
| log.error("Error occur when generating codes.") | |||
| raise e | |||
| save_code_file_and_report(model_name, code_fragments, out_folder, report_folder) | |||
| def _preprocess_node_args(self, node, module_key): | |||
| """ | |||
| Remove unused args. | |||
| Args: | |||
| node (Node): Node instance. | |||
| module_key (str): Nodule key. | |||
| Returns: | |||
| Node, node. | |||
| """ | |||
| if module_key in self._merged_module_args: | |||
| node = self._clear_unused_args( | |||
| node, self._merged_module_args[module_key]) | |||
| else: | |||
| node.data.clear_args_of_declaration() | |||
| return node | |||
| def _postprocess_node_args(self, node, precursor_module_key): | |||
| """ | |||
| Post process args in node. | |||
| Args: | |||
| node (Node): Node instance. | |||
| precursor_module_key (str): Parent node module name. | |||
| Returns: | |||
| Node, node. | |||
| """ | |||
| if node.data.node_type in {NodeType.MODULE.value, NodeType.CLASS.value, | |||
| NodeType.FUNC.value}: | |||
| # If current node is class or function, then | |||
| # remove unused args in __init__. | |||
| cur_module_key = node.data.hash_key or self.hash_key(node) | |||
| if cur_module_key in self._merged_module_args: | |||
| node = self._clear_unused_args(node, | |||
| self._merged_module_args[cur_module_key]) | |||
| # `self._merged_module_args` records formal args. | |||
| # We need to replace actual args. | |||
| if precursor_module_key in self._merged_module_args: | |||
| # If parent node is in `_merged_module_args`, then | |||
| # replace current node args with arg name declared | |||
| # in _merged_module_args. | |||
| for arg in node.data.args_in_code.keys(): | |||
| if arg in self._merged_module_args[precursor_module_key]: | |||
| node.data.replace_with_arg(arg, arg) | |||
| return node | |||
| def _clear_unused_args(self, node, used_args): | |||
| """ | |||
| Clear unused args. | |||
| Args: | |||
| node (Node): Node. | |||
| used_args (list): Args list. | |||
| Returns: | |||
| Node, node instance. | |||
| """ | |||
| args_in_code = list(node.data.args_in_code.keys()) | |||
| for arg in args_in_code: | |||
| ori_arg = arg.replace( | |||
| f"_{self.code_fragment_recorder[node.identifier].declared_var_name}", "" | |||
| ) | |||
| if ori_arg not in used_args: | |||
| node.data.args_in_code.pop(arg) | |||
| return node | |||
| @ScriptGenerationError.check_except("FormatCode run error. Check detailed information in log.") | |||
| @ReportGenerationError.check_except("Not find valid operators in converted script.") | |||
| def _generate_codes(self, mapper): | |||
| """ | |||
| Generate code files. | |||
| - 1. Generate args. | |||
| - 2. Merge module. | |||
| - 3. Pre-process node args. | |||
| - 4. Post-process child node args. | |||
| - 5. Generate class/func code. | |||
| - 6. Merge code snippets. | |||
| Args: | |||
| mapper (Mapper): Mapper of third party operation and mindspore. | |||
| Returns: | |||
| Dict, codes. | |||
| """ | |||
| code_blocks = [get_imported_module()] | |||
| depths = sorted(list(self._hierarchical_order.keys()), reverse=True) | |||
| for depth in depths: | |||
| node_collection = self._hierarchical_order[depth] | |||
| for node_name in node_collection: | |||
| # Traverse nodes in topological order. | |||
| node = self.get_node(node_name) | |||
| # 1. Generate args for each node in this level. | |||
| if node.data.node_type == NodeType.MODULE.value: | |||
| self._create_module_args_and_vars(node, mapper) | |||
| if depth == depths[-1]: | |||
| self.code_fragment_recorder[node.identifier] = node.data.param_transform(mapper, "") | |||
| # Module merging based on all nodes. | |||
| self._module_merging() | |||
| for depth in depths: | |||
| node_collection = self._hierarchical_order[depth] | |||
| snippets = set() | |||
| for node_name in node_collection: | |||
| nd_inst = self.get_node(node_name) | |||
| if nd_inst.data.node_type != NodeType.MODULE.value: | |||
| continue | |||
| # Generate hash key for node. | |||
| module_key = nd_inst.data.hash_key | |||
| # Get code generation func. | |||
| func, node_type = self._fetch_func_and_type(nd_inst) | |||
| if module_key in self._created_module: | |||
| # If the module has already been created, | |||
| # then assign the created module name to current node, | |||
| # and delete unused args. | |||
| module_name = self._created_module[module_key] | |||
| self.code_fragment_recorder[nd_inst.identifier].operation = module_name | |||
| self.code_fragment_recorder[nd_inst.identifier].node_type = node_type | |||
| self._preprocess_node_args(nd_inst, module_key) | |||
| continue | |||
| module_name = nd_inst.tag | |||
| if node_type == NodeType.CLASS.value: | |||
| module_name = f"{module_name[0].upper()}{module_name[1:]}" | |||
| # After node_type and module_name is frozen, | |||
| # then it's unchangeable. | |||
| module_name = self._module_mgr.get_name(module_name) | |||
| self.code_fragment_recorder[nd_inst.identifier].operation = module_name | |||
| self.code_fragment_recorder[nd_inst.identifier].node_type = node_type | |||
| # 3. Pre-process node args. | |||
| nd_inst = self._preprocess_node_args(nd_inst, module_key) | |||
| # 4. Post-process child node args. | |||
| for _, scsr_nd_name in enumerate(nd_inst.successors(self.tree_identifier)): | |||
| self._postprocess_node_args(self.get_node(scsr_nd_name), module_key) | |||
| # 5. Generate code. | |||
| snippets.add(func(nd_inst, self.code_fragment_recorder[nd_inst.identifier].operation, module_key)) | |||
| code_blocks.extend(snippets) | |||
| if self._scope_name_map: # from tf. conversion | |||
| c_blocks = [] | |||
| for s in code_blocks: | |||
| s = s.replace('$', '') | |||
| c_blocks.append(s) | |||
| code_blocks = c_blocks | |||
| formatted_code, _ = FormatCode("".join(code_blocks), | |||
| style_config=CodeFormatConfig.PEP8.value) | |||
| report_generator = ReportGenerator() | |||
| report = report_generator.gen_report(formatted_code) | |||
| return {"model": (formatted_code, report)} | |||
| def _fetch_func_and_type(self, node) -> Union[object, str]: | |||
| """ | |||
| Generate code snippet. | |||
| Args: | |||
| node (Node): Node. | |||
| Returns: | |||
| Union[object, str], code snippet func. | |||
| """ | |||
| def _is_func(): | |||
| """ | |||
| The correct thought is to check whether have more than one | |||
| path in this block. | |||
| """ | |||
| nonlocal node | |||
| if node.predecessor(self.tree_identifier) is None: | |||
| return False | |||
| tgt_type = {NodeType.MODULE.value, | |||
| NodeType.FUNC.value, NodeType.CLASS.value} | |||
| md_type_lst = [self.get_node(child).data.node_type | |||
| for child in node.successors(self.tree_identifier)] | |||
| diff_set = set(md_type_lst) - tgt_type | |||
| return not diff_set | |||
| if _is_func(): | |||
| return self._generate_func_snippet, NodeType.FUNC.value | |||
| return self._generate_class_snippet, NodeType.CLASS.value | |||
| def _generate_func_snippet(self, node, func_name, func_key): | |||
| """ | |||
| Generate function snippet. | |||
| Args: | |||
| node (Node): Node inst. | |||
| Returns: | |||
| str, code snippet. | |||
| """ | |||
| definition = "" | |||
| if func_key.lower() in self._merged_module_args and \ | |||
| self._merged_module_args[func_key.lower()]: | |||
| definition = ", ".join(self._merged_module_args[func_key.lower()]) | |||
| module_list = [] | |||
| for node_name in node.successors(self.tree_identifier): | |||
| c_nd = self.get_node(node_name) | |||
| operator = self.code_fragment_recorder[c_nd.identifier].operation | |||
| if c_nd.data.node_type != NodeType.OPERATION.value: | |||
| hash_key = c_nd.data.hash_key or self.hash_key(c_nd) | |||
| if hash_key in self._created_module: | |||
| operator = self._created_module[hash_key] | |||
| args = c_nd.data.args_in_code | |||
| if c_nd.data.node_type == NodeType.OPERATION.value and not is_converted( | |||
| self.code_fragment_recorder[c_nd.identifier].operation): | |||
| args.update({"input_shape": c_nd.data.input_shape, | |||
| "output_shape": c_nd.data.output_shape}) | |||
| # Generate code statement. | |||
| expr = ", ".join( | |||
| [f"{k.replace(f'_{self.code_fragment_recorder[c_nd.identifier].declared_var_name}', '')}={v}" | |||
| for k, v in args.items()] | |||
| ) | |||
| code_line = f"{operator}({expr})" | |||
| module_list.append(code_line) | |||
| body = f",{NEW_LINE}{SECOND_LEVEL_INDENT}".join(module_list) | |||
| snippet = f"{FIRST_LEVEL_INDENT}module_list = [{NEW_LINE}" \ | |||
| f"{SECOND_LEVEL_INDENT}{body}{NEW_LINE}" \ | |||
| f"{FIRST_LEVEL_INDENT}]{NEW_LINE}" \ | |||
| f"{FIRST_LEVEL_INDENT}return nn.SequentialCell(*module_list)" | |||
| definition = f"def {func_name}({definition}):{NEW_LINE}" | |||
| # Mark the structure has been created. | |||
| self._created_module[func_key.lower()] = func_name | |||
| return f"{definition}{snippet}{NEW_LINE * 3}" | |||
| def _generate_class_snippet(self, node, class_name, class_key): | |||
| """ | |||
| Generate class-type code snippet. | |||
| Args: | |||
| node (Node): Node. | |||
| Returns: | |||
| str, code snippet. | |||
| """ | |||
| super_call = f"super({class_name}, self).__init__()" | |||
| if class_key.lower() in self._merged_module_args and \ | |||
| self._merged_module_args[class_key.lower()]: | |||
| args = f"{', '.join(self._merged_module_args[class_key.lower()])}" | |||
| class_init = f"{FIRST_LEVEL_INDENT}def __init__(self, " \ | |||
| f"{args}):" \ | |||
| f"{NEW_LINE}{SECOND_LEVEL_INDENT}" \ | |||
| f"{super_call}{NEW_LINE}{SECOND_LEVEL_INDENT}" | |||
| else: | |||
| class_init = f"{FIRST_LEVEL_INDENT}def __init__(self):{NEW_LINE}{SECOND_LEVEL_INDENT}" \ | |||
| f"{super_call}{NEW_LINE}{SECOND_LEVEL_INDENT}" | |||
| init_block = [] | |||
| construct_block = [] | |||
| for idx, node_name in enumerate(node.successors(self.tree_identifier)): | |||
| nd_inst = self.get_node(node_name) | |||
| if nd_inst.data.op_name in NO_CONVERTED_OPERATORS: | |||
| continue | |||
| # Generate code statement. | |||
| init, construct = self._generate_stat(nd_inst, node, idx) | |||
| # support multiple construct and init block returns: | |||
| if isinstance(construct, list): | |||
| construct_block += construct | |||
| else: | |||
| construct_block.append(construct) | |||
| if isinstance(init, list): | |||
| init_block += init | |||
| else: | |||
| init_block.append(init) | |||
| class_construct = f"{NEW_LINE}{FIRST_LEVEL_INDENT}def construct(self, x):" \ | |||
| f"{NEW_LINE}{SECOND_LEVEL_INDENT}" | |||
| init_body = f"{NEW_LINE}{SECOND_LEVEL_INDENT}".join(init_block) | |||
| csrt_body = f"{NEW_LINE}{SECOND_LEVEL_INDENT}".join(construct_block) | |||
| csrt_rtn = f"{NEW_LINE}{SECOND_LEVEL_INDENT}return output{NEW_LINE}" | |||
| cls_definition = f"class {class_name}(nn.Cell):{NEW_LINE * 2}" | |||
| # Mark the structure has been created. | |||
| self._created_module[class_key.lower()] = class_name | |||
| return f"{cls_definition}" \ | |||
| f"{class_init}" \ | |||
| f"{init_body}{NEW_LINE}" \ | |||
| f"{class_construct}" \ | |||
| f"{csrt_body}{csrt_rtn}{NEW_LINE * 2}" | |||
| def _generate_stat(self, cur_nd_inst, pre_nd_inst, idx): | |||
| """ | |||
| Generate statements. | |||
| Args: | |||
| cur_nd_inst (Node): Current node instance. | |||
| pre_nd_inst (Node): Precursor node instance. | |||
| idx (int): Index of cur node. | |||
| Returns: | |||
| Tuple[str, str], declare in init and call in construct. | |||
| """ | |||
| ipt_args_in_construct = "x" | |||
| opt_arg_in_construct = ["output"] | |||
| if idx != 0: | |||
| if cur_nd_inst.data.is_in_multi_opt_graph: | |||
| ipt_args_in_construct = self._get_current_ipt_var(cur_nd_inst) | |||
| else: | |||
| # Get previous node output variable name. | |||
| ipt_args_in_construct = self._get_previous_opt_var(cur_nd_inst, pre_nd_inst) | |||
| if idx != len(pre_nd_inst.successors(self.tree_identifier)) - 1: | |||
| # Set opt variable name. | |||
| if cur_nd_inst.data.node_type == NodeType.MODULE.value or not cur_nd_inst.data.is_in_multi_opt_graph: | |||
| opt_arg_in_construct = [ | |||
| f"{self.code_fragment_recorder[cur_nd_inst.identifier].declared_var_name}_opt" | |||
| ] | |||
| else: | |||
| opt_arg_in_construct = [ | |||
| f"opt_{var_name}" | |||
| for var_name in self.code_fragment_recorder[cur_nd_inst.identifier].output_var_name | |||
| ] | |||
| declare, call = cur_nd_inst.data.to_code(ipt_args_in_construct=ipt_args_in_construct, | |||
| variable_name=self.code_fragment_recorder[ | |||
| cur_nd_inst.identifier].declared_var_name, | |||
| output_var=opt_arg_in_construct, | |||
| code_fragment=self.code_fragment_recorder[cur_nd_inst.identifier]) | |||
| return declare, call | |||
| @staticmethod | |||
| def _get_var_name(s): | |||
| """ | |||
| Get variable name using scope name. | |||
| Args: | |||
| s (str): String. | |||
| Returns: | |||
| str, variable name. | |||
| """ | |||
| return s.split(SEPARATOR_IN_SCOPE)[-1].lower().split(SEPARATOR_BTW_NAME_AND_ID)[0] | |||
| def _get_current_ipt_var(self, cur_nd): | |||
| """" | |||
| Get current input variable name from node_id. | |||
| Args: | |||
| cur_nd (Node): Current node. | |||
| Returns: | |||
| str, needed var names. | |||
| """ | |||
| if cur_nd.data.node_type != NodeType.OPERATION.value: | |||
| while True: | |||
| p_nd = cur_nd.successors(self.tree_identifier) | |||
| if not p_nd: | |||
| break | |||
| cur_nd = self.get_node(p_nd[0]) | |||
| ipt_lst_raw = [] | |||
| for operation_input in self.code_fragment_recorder[cur_nd.identifier].operation_inputs: | |||
| ipt_lst_raw.append(f"{operation_input}") | |||
| opt_var_names_p_nds = set() | |||
| for e in cur_nd.data.precursor_nodes: | |||
| p_nd = self.get_node(e) | |||
| if p_nd.data.op_name in NO_CONVERTED_OPERATORS: | |||
| continue | |||
| opt_var_names_p_nd = set(p_nd.data.opt_var_names) | |||
| opt_var_names_p_nds = set.union(opt_var_names_p_nds, opt_var_names_p_nd) | |||
| ipt_lst = [f"opt_{ipt}" for ipt in set(ipt_lst_raw).intersection(opt_var_names_p_nds)] | |||
| return ", ".join(ipt_lst) | |||
| def _find_all_previous_opt_var_(self, cur_nd, pre_nd): | |||
| """ | |||
| Find all input variable names. | |||
| Args: | |||
| cur_nd (Node): Current node. | |||
| pre_nd (Node): Precursor node. | |||
| Returns: | |||
| list, needed var names list. | |||
| """ | |||
| ipt_lst = [] | |||
| if cur_nd.tag in NO_CONVERTED_OPERATORS: | |||
| return ipt_lst | |||
| for e in cur_nd.data.precursor_nodes: | |||
| p_nd = self.get_node(e) | |||
| if e not in pre_nd.successors(self.tree_identifier): | |||
| while True: | |||
| if p_nd.identifier in pre_nd.successors(self.tree_identifier): | |||
| ipt_lst.append( | |||
| f"{self.code_fragment_recorder[p_nd.identifier].declared_var_name}_opt" | |||
| ) | |||
| break | |||
| pre_nd_name = p_nd.predecessor(self.tree_identifier) | |||
| if not pre_nd_name: | |||
| ipt_lst.append("x") | |||
| break | |||
| p_nd = self.get_node(pre_nd_name) | |||
| continue | |||
| ipt_lst.append( | |||
| f"{self.code_fragment_recorder[p_nd.identifier].declared_var_name}_opt" | |||
| ) | |||
| return ipt_lst | |||
| def _get_previous_opt_var(self, cur_nd, pre_nd): | |||
| """ | |||
| Get needed input variable names. | |||
| Args: | |||
| cur_nd (Node): Current node. | |||
| pre_nd (Node): Precursor node. | |||
| Returns: | |||
| str, needed var names. | |||
| """ | |||
| if cur_nd.data.node_type != NodeType.OPERATION.value: | |||
| while True: | |||
| p_nd = cur_nd.successors(self.tree_identifier) | |||
| if not p_nd: | |||
| break | |||
| cur_nd = self.get_node(p_nd[0]) | |||
| return ", ".join(self._find_all_previous_opt_var_(cur_nd, pre_nd)) | |||
| def hash_key(self, node): | |||
| """ | |||
| Generate hash key for each node. | |||
| Args: | |||
| node (Node): Node. | |||
| Returns: | |||
| str, hash key. | |||
| """ | |||
| scsr_topo_order = [] | |||
| for s in node.successors(self.tree_identifier): | |||
| cur_nd = self.get_node(s) | |||
| if cur_nd.data.node_type in {NodeType.MODULE.value, | |||
| NodeType.FUNC.value, | |||
| NodeType.CLASS.value}: | |||
| if cur_nd.data.hash_key: | |||
| scsr_topo_order.append(f"({cur_nd.data.hash_key})") | |||
| continue | |||
| raise ValueError("Current node doesn't have hash key.") | |||
| if cur_nd.data.hash_key: | |||
| scsr_topo_order.append(cur_nd.data.hash_key) | |||
| continue | |||
| unique_key = "->".join(scsr_topo_order) | |||
| node.data.hash_key = unique_key | |||
| return unique_key | |||
| def _module_merging(self): | |||
| """Generate sub-module and corresponding params.""" | |||
| merged_module_args = dict() | |||
| for module_key, module_args in self._merged_module.items(): | |||
| if module_key not in merged_module_args: | |||
| merged_module_args[module_key] = [] | |||
| # Take first element's args as base. | |||
| keys = module_args[0].keys() | |||
| for key in keys: | |||
| for i in range(1, len(module_args)): | |||
| if key in module_args[i] and module_args[0][key] != module_args[i][key]: | |||
| merged_module_args[module_key].append(key) | |||
| break | |||
| if key not in module_args[i]: | |||
| merged_module_args[module_key].append(key) | |||
| break | |||
| self._merged_module_args.update(merged_module_args) | |||
| def _create_module_args_and_vars(self, node, mapper): | |||
| """ | |||
| Create module args and variables in current node. | |||
| Args: | |||
| node (Node): Node on tree. | |||
| mapper (Mapper): Mapper of params. | |||
| """ | |||
| # All args and value pair in current node module. | |||
| module_args = dict() | |||
| module_key = self.hash_key(node) | |||
| created = False | |||
| if module_key not in self._vars_mgr_in_module: | |||
| self._vars_mgr_in_module[module_key] = self.GLOBAL_VAR_NAME_MGR | |||
| self._module_vars[module_key] = [] | |||
| else: | |||
| created = True | |||
| # Sub-modules in the module could have arg name conflicts. | |||
| for idx, successor_name in enumerate(node.successors(self.tree_identifier)): | |||
| nd_inst = self.get_node(successor_name) | |||
| if nd_inst.data.op_name in NO_CONVERTED_OPERATORS: | |||
| continue | |||
| # Generation of params must behind variable assigment. | |||
| if created: | |||
| variable_name = self._module_vars[module_key][idx] | |||
| else: | |||
| variable_name = nd_inst.data.op_name or nd_inst.tag | |||
| variable_name = self._vars_mgr_in_module[module_key].get_name(variable_name) | |||
| code_fragment = nd_inst.data.param_transform(mapper, variable_name) | |||
| code_fragment.declared_var_name = variable_name | |||
| code_fragment.output_var_name = nd_inst.data.opt_var_names | |||
| code_fragment.update_operation_inputs(nd_inst.data.ipt_var_names) | |||
| self.code_fragment_recorder[nd_inst.identifier] = code_fragment | |||
| module_args.update(nd_inst.data.args_in_code) | |||
| if not created: | |||
| self._module_vars[module_key].append(variable_name) | |||
| node.data.args_in_code = module_args | |||
| # Collect module args of `module_key`. | |||
| if module_key not in self._merged_module: | |||
| self._merged_module[module_key] = [deepcopy(node.data.args_in_code)] | |||
| else: | |||
| self._merged_module[module_key].append(deepcopy(node.data.args_in_code)) | |||
| @staticmethod | |||
| def _create_operation_args(node, mapper): | |||
| """ | |||
| Create operation args. | |||
| Args: | |||
| node (Node): Node on tree. | |||
| mapper (Mapper): Mapper of params. | |||
| """ | |||
| node.data.param_transform(mapper) | |||
| def update_hierarchical_order(self) -> NoReturn: | |||
| """ | |||
| Update hierarchical order. | |||
| """ | |||
| hierarchical_order = dict() | |||
| queue = Queue() | |||
| queue.put(item=(self.root, self.ROOT_LEVEL), block=False) | |||
| while not queue.empty(): | |||
| node_name, cur_level = queue.get(block=False) | |||
| node_inst = self[node_name] | |||
| if cur_level not in hierarchical_order: | |||
| hierarchical_order[cur_level] = [] | |||
| hierarchical_order[cur_level].append(node_name) | |||
| for successor_name in node_inst.successors(self.tree_identifier): | |||
| queue.put(item=(successor_name, cur_level + 1), block=False) | |||
| self._hierarchical_order = hierarchical_order | |||
| def sub_graph_merging(self) -> NoReturn: | |||
| """Shrink the module has only one child.""" | |||
| self.update_hierarchical_order() | |||
| depths = sorted(list(self._hierarchical_order.keys()), reverse=True) | |||
| for depth in depths: | |||
| for node_name in self._hierarchical_order[depth]: | |||
| node_inst = self[node_name] | |||
| # If the node type is module and has only one child, | |||
| # then merge it with its child. | |||
| if node_inst.data.node_type == NodeType.MODULE.value and \ | |||
| len(node_inst.successors(self.tree_identifier)) == 1: | |||
| self.shrink(node_inst) | |||
| def _adjust_structure(self) -> NoReturn: | |||
| """Adjust tree structure to generate source code.""" | |||
| self.sub_graph_merging() | |||
| self.update_hierarchical_order() | |||
| @@ -171,3 +171,13 @@ class ONNXToMindSporeMapper(Mapper, abc.ABC): | |||
| outputs_list = [f"opt_{{{variable_slot}}}"] | |||
| outputs_mapping = ((0, 0),) | |||
| return template, exchange_msg, outputs_list, outputs_mapping | |||
| @staticmethod | |||
| def _find_val_by_index(loc_index, values_dict): | |||
| """Find value by location index of values_dict.""" | |||
| result = None | |||
| for idx, dict_val in enumerate(values_dict.values()): | |||
| if idx == loc_index: | |||
| result = dict_val | |||
| break | |||
| return result | |||
| @@ -42,7 +42,7 @@ class ConvMapper(ONNXToMindSporeMapper): | |||
| """Convert params from PyTorch to MindSpore""" | |||
| weights = kwargs['weights'] | |||
| params = kwargs['params'] | |||
| weight = weights['weight'].numpy() | |||
| weight = weights['weight'] | |||
| weight = np.transpose(weight, list(range(2, weight.ndim)) + [1, 0]) | |||
| if isinstance(params['dilations'], list): | |||
| dilation = tuple(params['dilations']) | |||
| @@ -130,7 +130,6 @@ class ConvMapper(ONNXToMindSporeMapper): | |||
| dim = len(kernel_size) | |||
| return f"nn.Conv{dim}d" | |||
| weight = weight.numpy() | |||
| dim = weight.ndim - 2 | |||
| return f"nn.Conv{dim}d" | |||
| @@ -13,6 +13,7 @@ | |||
| # limitations under the License. | |||
| # ============================================================================== | |||
| """Mapper module.""" | |||
| import numpy as np | |||
| from mindinsight.mindconverter.graph_based_converter.mapper.base import ONNXToMindSporeMapper | |||
| from mindinsight.mindconverter.graph_based_converter.mapper.gen_setting import Setting | |||
| @@ -27,8 +28,11 @@ class DenseMapper(ONNXToMindSporeMapper): | |||
| @staticmethod | |||
| def _convert_params(**kwargs): | |||
| weights = kwargs['weights'] | |||
| has_bias = bool('bias' in weights) | |||
| weight = weights['weight'].numpy().transpose() | |||
| weight_index = 0 | |||
| bias_index = 1 | |||
| bias = DenseMapper._find_val_by_index(bias_index, weights) | |||
| has_bias = isinstance(bias, np.ndarray) | |||
| weight = DenseMapper._find_val_by_index(weight_index, weights).transpose() | |||
| in_channels, out_channels = weight.shape | |||
| return { | |||
| 'in_channels': in_channels, | |||
| @@ -13,6 +13,7 @@ | |||
| # limitations under the License. | |||
| # ============================================================================== | |||
| """Mapper module.""" | |||
| from mindinsight.mindconverter.graph_based_converter.common.utils import convert_bytes_string_to_string | |||
| from mindinsight.mindconverter.graph_based_converter.mapper.base import ONNXToMindSporeMapper | |||
| from mindinsight.mindconverter.graph_based_converter.mapper.gen_setting import Setting | |||
| @@ -47,7 +48,7 @@ class PadMapper(ONNXToMindSporeMapper): | |||
| def _convert_params(**kwargs): | |||
| weights = kwargs.get("weights") | |||
| params = kwargs.get("params") | |||
| mode = params.get('mode', 'constant') | |||
| mode = convert_bytes_string_to_string(params.get('mode', 'constant')) | |||
| pads_onnx = params.get("pads") if params.get("pads") else list(weights.values())[0].tolist() | |||
| if mode == 'constant' and params.get('value') is None: | |||
| if params.get('pads') or weights: | |||
| @@ -36,7 +36,7 @@ class PoolMapper(ONNXToMindSporeMapper): | |||
| transformed_params["kernel_size"] = tuple(params['kernel_shape']) | |||
| transformed_params["stride"] = tuple(params['strides']) | |||
| if "pads" in params: | |||
| if sum(params['pads']) == 0: | |||
| if sum(params['pads']) == 0 and not params.get('ceil_mode', None): | |||
| pad_mode = '\"valid\"' | |||
| else: | |||
| pad_mode = '\"same\"' | |||
| @@ -1,4 +1,4 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd.All Rights Reserved. | |||
| # Copyright 2020-2021 Huawei Technologies Co., Ltd.All Rights Reserved. | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| @@ -39,14 +39,9 @@ class GraphFactory: | |||
| Returns: | |||
| Graph, graph instance. | |||
| """ | |||
| if all([input_nodes, output_nodes]): | |||
| onnx_graph_module = import_module( | |||
| 'mindinsight.mindconverter.graph_based_converter.third_party_graph.onnx_graph') | |||
| onnx_graph = getattr(onnx_graph_module, 'OnnxGraph') | |||
| return onnx_graph.load(model_path=graph_path, input_nodes=input_nodes, | |||
| output_nodes=output_nodes, sample_shape=sample_shape) | |||
| pytorch_graph_module = import_module( | |||
| 'mindinsight.mindconverter.graph_based_converter.third_party_graph.pytorch_graph') | |||
| pytorch_graph = getattr(pytorch_graph_module, 'PyTorchGraph') | |||
| return pytorch_graph.load(model_path=graph_path, sample_shape=sample_shape) | |||
| onnx_graph_module = import_module( | |||
| 'mindinsight.mindconverter.graph_based_converter.third_party_graph.onnx_graph') | |||
| onnx_graph = getattr(onnx_graph_module, 'OnnxGraph') | |||
| return onnx_graph.load(model_path=graph_path, input_nodes=input_nodes, | |||
| output_nodes=output_nodes, sample_shape=sample_shape) | |||
| @@ -264,7 +264,7 @@ class Graph(BaseGraph, abc.ABC): | |||
| Returns: | |||
| cls, graph instance. | |||
| """ | |||
| src_graph = cls.load_graph(graph_path=model_path, **kwargs) | |||
| src_graph = cls.load_graph(graph_path=model_path, sample_shape=sample_shape, **kwargs) | |||
| ckpt = cls.load_checkpoint(ckpt_path=checkpoint) if checkpoint else None | |||
| if ckpt is not None: | |||
| @@ -1,4 +1,4 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd.All Rights Reserved. | |||
| # Copyright 2020-2021 Huawei Technologies Co., Ltd.All Rights Reserved. | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| @@ -13,12 +13,14 @@ | |||
| # limitations under the License. | |||
| # ============================================================================== | |||
| """Define ONNX graph.""" | |||
| from importlib import import_module | |||
| from typing import Dict, NoReturn | |||
| from mindinsight.mindconverter.common.log import logger as log | |||
| from .base import Graph | |||
| from .input_node import InputNode | |||
| from .onnx_graph_node import OnnxGraphNode | |||
| from .pytorch_graph_parser import PyTorchGraphParser | |||
| from .tf_graph_parser import TFGraphParser | |||
| from .onnx_utils import OnnxDataLoader | |||
| @@ -151,7 +153,7 @@ class OnnxGraph(Graph): | |||
| input_shape (tuple): Input shape. | |||
| """ | |||
| input_node = InputNode(input_shape) | |||
| input_node_name = self._raw_input_nodes.replace(":0", "") | |||
| input_node_name = self._raw_input_nodes | |||
| for node_name, node in self._nodes_collection.items(): | |||
| if node_name in self._input_nodes: | |||
| ipt_nd_name = input_node_name.format(input_node.scope_name) | |||
| @@ -196,7 +198,18 @@ class OnnxGraph(Graph): | |||
| """ | |||
| tf_input_nodes = kwargs.get('input_nodes') | |||
| tf_output_nodes = kwargs.get('output_nodes') | |||
| onnx_model = TFGraphParser.parse(graph_path, | |||
| input_nodes=tf_input_nodes, | |||
| output_nodes=tf_output_nodes) | |||
| if graph_path.endswith('.pb'): | |||
| onnx_model = TFGraphParser.parse(graph_path, | |||
| input_nodes=tf_input_nodes, | |||
| output_nodes=tf_output_nodes) | |||
| elif graph_path.endswith('.onnx'): | |||
| onnx = import_module('onnx') | |||
| onnx_model = onnx.load(graph_path) | |||
| optimizer = import_module( | |||
| 'mindinsight.mindconverter.graph_based_converter.third_party_graph.optimizer') | |||
| onnx_simplify = getattr(optimizer, 'OnnxSimplify')() | |||
| onnx_model = onnx_simplify.run_onnx_simplify(onnx_model, kwargs['sample_shape']) | |||
| else: | |||
| onnx_model = PyTorchGraphParser.parse(graph_path, **kwargs) | |||
| return onnx_model | |||
| @@ -1,4 +1,4 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd.All Rights Reserved. | |||
| # Copyright 2020-2021 Huawei Technologies Co., Ltd.All Rights Reserved. | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| @@ -112,10 +112,10 @@ class OnnxTensor: | |||
| def to_array(self): | |||
| """Convert the tensor value from binary to np array.""" | |||
| onnx = import_module("onnx") | |||
| numpy_helper = import_module("onnx.numpy_helper") | |||
| # Convert binary data to np.array | |||
| if not isinstance(self.raw_tensor, (np.ndarray, list, tuple, int, float)): | |||
| return onnx.numpy_helper.to_array(self.raw_tensor) | |||
| return numpy_helper.to_array(self.raw_tensor) | |||
| return self.raw_tensor | |||
| @@ -383,15 +383,24 @@ class OnnxDataLoader: | |||
| """Parse each onnx nodes in the model.""" | |||
| nodes_topo_idx = [] | |||
| for idx, node in enumerate(self.nodes): | |||
| if not node.name: | |||
| node.name = "_".join(node.output) | |||
| n = OnnxNode(node) | |||
| self._nodes_dict[n.name] = n | |||
| nodes_topo_idx.append((idx, n.name)) | |||
| if len(node.output) > 1: | |||
| raise ModelNotSupportError(msg=f"{node.name} has multi-outputs which is not supported now.") | |||
| self.output_name_to_node_name[node.output[0]] = node.name | |||
| for ipt_nd in node.input: | |||
| if ipt_nd not in self.output_name_to_node_name: | |||
| if self._global_context.onnx_node_inputs.get(n.name): | |||
| self._global_context.onnx_node_inputs[n.name].append(ipt_nd) | |||
| else: | |||
| self._global_context.onnx_node_inputs[n.name] = [ipt_nd] | |||
| self._global_context.onnx_node_name_to_topo_idx[n.name] = idx | |||
| node_inputs = [i.replace(":0", "") for i in node.input] | |||
| self._global_context.onnx_node_inputs[n.name] = node_inputs | |||
| self._global_context.onnx_nodes_collection = self._nodes_dict | |||
| self._global_context.onnx_nodes_topo_index = nodes_topo_idx | |||
| @@ -449,7 +458,11 @@ class OnnxDataLoader: | |||
| input_node = self.get_node(input_node_name) | |||
| node.precursor_onnx_node_dict[input_node_name] = input_node | |||
| input_node.successor_onnx_node_dict[node_name] = node | |||
| continue | |||
| if self._global_context.onnx_node_inputs.get(node.name): | |||
| self._global_context.onnx_node_inputs[node.name].append(input_node_name) | |||
| else: | |||
| self._global_context.onnx_node_inputs[node.name] = [input_node_name] | |||
| def initialize(self): | |||
| """Initialize the OnnxDataLoader.""" | |||
| @@ -0,0 +1,144 @@ | |||
| # Copyright 2021 Huawei Technologies Co., Ltd.All Rights Reserved. | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================== | |||
| """Define ONNX optimizer operations.""" | |||
| import copy | |||
| from importlib import import_module | |||
| import numpy as np | |||
| from ..common.utils import fetch_output_from_onnx_model | |||
| class OnnxSimplify: | |||
| """To simplify onnx model.""" | |||
| def __init__(self): | |||
| self._onnx_model = None | |||
| self._constant_nodes = list() | |||
| self._outputs_infer = dict() | |||
| def run_onnx_simplify(self, onnx_model, sample_shape): | |||
| """ | |||
| Run to simplify onnx model. | |||
| Args: | |||
| onnx_model (onnx.ModelProto): Onnx Model. | |||
| sample_shape (tuple): Sample shape of input. | |||
| """ | |||
| self._onnx_model = onnx_model | |||
| self._optimizer() | |||
| self._get_constant_nodes() | |||
| self._onnx_infer(sample_shape) | |||
| self._replace_constant_nodes() | |||
| self._optimizer() | |||
| return self._onnx_model | |||
| def _optimizer(self): | |||
| """Run optimizer from onnx to eliminate constant nodes.""" | |||
| onnxoptimizer = import_module('onnxoptimizer') | |||
| optimizers_list = [ | |||
| 'eliminate_deadend', | |||
| 'eliminate_nop_dropout', | |||
| 'eliminate_nop_cast', | |||
| 'eliminate_nop_monotone_argmax', | |||
| 'eliminate_nop_pad', | |||
| 'extract_constant_to_initializer', | |||
| 'eliminate_unused_initializer', | |||
| 'eliminate_nop_transpose', | |||
| 'eliminate_identity', | |||
| 'fuse_add_bias_into_conv', | |||
| 'fuse_consecutive_concats', | |||
| 'fuse_consecutive_log_softmax', | |||
| 'fuse_consecutive_reduce_unsqueeze', | |||
| 'fuse_consecutive_squeezes', | |||
| 'fuse_consecutive_transposes', | |||
| 'fuse_matmul_add_bias_into_gemm', | |||
| 'fuse_pad_into_conv', | |||
| 'fuse_transpose_into_gemm' | |||
| ] | |||
| input_num = len(self._onnx_model.graph.input) | |||
| onnx_model_optimized = onnxoptimizer.optimize(self._onnx_model, optimizers_list, fixed_point=True) | |||
| if self._onnx_model.ir_version > 3: | |||
| del onnx_model_optimized.graph.input[input_num:] | |||
| self._onnx_model = onnx_model_optimized | |||
| def _get_constant_nodes(self): | |||
| """Get constant nodes.""" | |||
| const_nodes = list() | |||
| const_tensors = [tensor_init.name for tensor_init in self._onnx_model.graph.initializer] | |||
| const_tensors.append([node.output[0] | |||
| for node in self._onnx_model.graph.node if node.op_type == 'Constant']) | |||
| for node in self._onnx_model.graph.node: | |||
| if node.op_type == 'Shape' or all([input_node in const_tensors for input_node in node.input]): | |||
| const_nodes.append(node) | |||
| const_tensors.extend(node.output) | |||
| self._constant_nodes = copy.deepcopy(const_nodes) | |||
| def _onnx_infer(self, infer_inputs_shape): | |||
| """ | |||
| Run onnx inference to get outputs of constant nodes. | |||
| Args: | |||
| infer_inputs_shape (tuple): Input shape for running inference. | |||
| """ | |||
| input_onnx = self._onnx_model.graph.input[0] | |||
| input_onnx_name = input_onnx.name | |||
| feed_dict = {input_onnx_name: np.random.rand(*infer_inputs_shape).astype(np.float32)} | |||
| output_nodes_name = list() | |||
| for node in self._constant_nodes: | |||
| output_nodes_name.extend(node.output) | |||
| self._outputs_infer = fetch_output_from_onnx_model(self._onnx_model, feed_dict, output_nodes_name) | |||
| def _replace_constant_nodes(self): | |||
| """Replace constant nodes to nodes with op_type 'Constant'.""" | |||
| onnx = import_module('onnx') | |||
| np_helper = import_module('onnx.numpy_helper') | |||
| for i, node in enumerate(self._onnx_model.graph.node): | |||
| if node in self._constant_nodes: | |||
| for output in node.output: | |||
| new_attr = onnx.helper.make_attribute( | |||
| 'value', | |||
| np_helper.from_array(self._outputs_infer[output], name=output) | |||
| ) | |||
| new_node = onnx.helper.make_node( | |||
| op_type='Constant', | |||
| inputs=list(), | |||
| outputs=[output], | |||
| name='_'.join(('node', output)) | |||
| ) | |||
| new_node.attribute.extend([new_attr]) | |||
| self._insert_node(self._onnx_model.graph.node, i + 1, new_node) | |||
| del self._onnx_model.graph.node[i] | |||
| @staticmethod | |||
| def _insert_node(repeated_container, index, node): | |||
| """Insert node into onnx model.""" | |||
| repeated_container.extend([repeated_container[-1]]) | |||
| for i in reversed(range(index + 1, len(repeated_container) - 1)): | |||
| repeated_container[i].CopyFrom(repeated_container[i - 1]) | |||
| repeated_container[index].CopyFrom(node) | |||
| @@ -1,691 +0,0 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd.All Rights Reserved. | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================== | |||
| """Define PyTorch graph.""" | |||
| import os | |||
| import re | |||
| import warnings | |||
| from copy import deepcopy | |||
| from importlib import import_module | |||
| from typing import Dict, NoReturn | |||
| import numpy as np | |||
| from mindinsight.conf import settings | |||
| from mindinsight.mindconverter.common.log import logger as log | |||
| from .base import Graph | |||
| from .input_node import InputNode | |||
| from .pytorch_graph_node import PyTorchGraphNode | |||
| from .pytorch_graph_parser import PyTorchGraphParser | |||
| from .torch_utils import set_opset_version | |||
| from ..common.utils import fetch_output_from_onnx_model | |||
| from ..constant import SEPARATOR_IN_SCOPE, LINK_IN_SCOPE, SEPARATOR_BTW_NAME_AND_ID, SCALAR_WITHOUT_SHAPE, \ | |||
| MIN_SCOPE_LENGTH, SEPARATOR_TITLE_AND_CONTENT_IN_CONSTRUCT, ONNX_OPSET_VERSION, MODEL_INPUT_NAME | |||
| from ..constant import LEFT_BUCKET, RIGHT_BUCKET | |||
| from ...common.exceptions import ModelNotSupportError | |||
| NONE_SCOPE_OP = { | |||
| "onnx::Add": "Add", | |||
| "onnx::Flatten": "Flatten", | |||
| "onnx::Concat": "Concat", | |||
| "onnx::Squeeze": "Squeeze", | |||
| "onnx::Unsqueeze": "Unsqueeze", | |||
| "onnx::Split": "Split", | |||
| "onnx::Reshape": "Reshape", | |||
| "onnx::Transpose": "Transpose", | |||
| "onnx::Constant": "Constant", | |||
| "onnx::ReduceMean": "ReduceMean", | |||
| "onnx::Resize": "Resize", | |||
| "onnx::Pad": "Pad" | |||
| } | |||
| CONSTANT_NODES_PATTERN = { | |||
| "onnx::Resize": [ | |||
| 'onnx::Concat', | |||
| 'onnx::Slice', | |||
| 'onnx::Cast', | |||
| 'onnx::Concat', | |||
| 'onnx::Unsqueeze', | |||
| 'onnx::Floor', | |||
| 'onnx::Mul', | |||
| 'onnx::Cast', | |||
| 'onnx::Gather', | |||
| 'onnx::Shape' | |||
| ], | |||
| "onnx::Pad": [ | |||
| 'onnx::Cast', | |||
| 'onnx::Concat', | |||
| 'onnx::ConstantOfShape', | |||
| 'onnx::Sub', | |||
| 'onnx::Mul', | |||
| 'onnx::Div', | |||
| 'onnx::Gather', | |||
| 'onnx::Shape', | |||
| 'onnx::Unsqueeze', | |||
| 'onnx::Slice', | |||
| 'onnx::Reshape', | |||
| 'onnx::Transpose' | |||
| ], | |||
| "onnx::Constant": list() | |||
| } | |||
| def normalize_scope_name(node, scope_name_dict): | |||
| """ | |||
| Rename scope name into uniform. | |||
| Args: | |||
| node (Node): PyTorch node. | |||
| scope_name_dict (dict): Dictionary of scope names with the key node_id. | |||
| Returns: | |||
| str, normalized scope name. | |||
| """ | |||
| global NONE_SCOPE_OP | |||
| scope_name = node.scopeName() | |||
| if not scope_name: | |||
| name = [retrieve_scope_name(node, scope_name_dict)] | |||
| else: | |||
| name = scope_name.replace(SEPARATOR_BTW_NAME_AND_ID, '').split(SEPARATOR_IN_SCOPE) | |||
| scopes = [] | |||
| for segment in name: | |||
| segment = segment.split(LINK_IN_SCOPE)[0] | |||
| left = segment.find(LEFT_BUCKET) | |||
| right = segment.find(RIGHT_BUCKET) | |||
| if left != -1: | |||
| if segment[left + 1: right].isdigit(): | |||
| scopes.append(f"{segment[:left]}_{segment[left + 1: right]}") | |||
| else: | |||
| scopes.append(segment[left + 1: right]) | |||
| else: | |||
| scopes.append(segment) | |||
| if node.kind() in NONE_SCOPE_OP.keys(): | |||
| scopes.append(NONE_SCOPE_OP[node.kind()]) | |||
| scopes = [s for s in scopes if s] | |||
| node_id = PyTorchGraph.get_node_id(node) | |||
| return f"{SEPARATOR_IN_SCOPE.join(scopes)}_{'&'.join(node_id)}" | |||
| def retrieve_scope_name(node, scope_name_dict): | |||
| """ | |||
| Retrieve scope name from input nodes. | |||
| Args: | |||
| node (Node): PyTorch node. | |||
| scope_name_dict (dict): Dictionary of scope names with the key node_id. | |||
| Return: | |||
| str: Scope name. | |||
| """ | |||
| node_content = \ | |||
| SEPARATOR_TITLE_AND_CONTENT_IN_CONSTRUCT.join(str(node).split(SEPARATOR_TITLE_AND_CONTENT_IN_CONSTRUCT)[1:]) | |||
| node_inputs = re.findall(r"[(](.*?)[)]", node_content)[0] | |||
| node_inputs = re.sub(r"[\s%]", '', node_inputs).split(",") | |||
| scope_name_ipt_nodes = list() | |||
| for node_input in node_inputs: | |||
| if not scope_name_dict.get(node_input, None): | |||
| continue | |||
| scope_name_ipt_nodes.append(scope_name_dict[node_input]) | |||
| scope_name_split = list() | |||
| for idx, _ in enumerate(scope_name_ipt_nodes): | |||
| if not scope_name_split: | |||
| scope_name_split = scope_name_ipt_nodes[idx] | |||
| else: | |||
| scope_name_split = [ | |||
| sub_scope_name | |||
| for sub_scope_name in scope_name_split if sub_scope_name in scope_name_ipt_nodes[idx] | |||
| ] | |||
| scope_name = SEPARATOR_IN_SCOPE.join(scope_name_split) | |||
| return scope_name | |||
| class PyTorchGraph(Graph): | |||
| """ | |||
| Define PyTorch graph. | |||
| Args: | |||
| model (Module): PyTorch model. | |||
| sample_shape (tuple): Input shape of the model. | |||
| """ | |||
| def __init__(self, model, sample_shape: tuple): | |||
| super(PyTorchGraph, self).__init__(model=model) | |||
| from .torch_utils import unique_state_dict | |||
| self._params_dict = unique_state_dict(model) | |||
| self._original_shape = list() | |||
| self._nodes = list() | |||
| self._constant_nodes = list() | |||
| self._dynamic_nodes = list() | |||
| self._has_eliminated_nodes = False | |||
| self._file_graph_onnx = os.path.join( | |||
| settings.WORKSPACE, 'log/mindconverter/' | |||
| ) | |||
| self.build(sample_shape) | |||
| @staticmethod | |||
| def _check_input_shape(input_shape): | |||
| """ | |||
| Check input shape. | |||
| Args: | |||
| input_shape (tuple): Input tensor shape. | |||
| """ | |||
| if not input_shape: | |||
| err_msg = "`input_shape` can not be None." | |||
| log.error(err_msg) | |||
| raise ValueError(err_msg) | |||
| for item in input_shape: | |||
| if not isinstance(item, int): | |||
| err_msg = "Only support model with one input now, " \ | |||
| "and each shape value in `input_shape` should be int." | |||
| log.error(err_msg) | |||
| raise ValueError(err_msg) | |||
| @staticmethod | |||
| def _extract_shape(shape): | |||
| """ | |||
| Extract shape from string-type shape. | |||
| Args: | |||
| shape (str): Shape value in string-type. | |||
| Returns: | |||
| list, shape. | |||
| """ | |||
| if "," not in shape: | |||
| return [] | |||
| shape_arr = [] | |||
| for s in shape.split(","): | |||
| s = s.strip() | |||
| if not s: | |||
| return [] | |||
| if ":" in s: | |||
| s = s.split(":")[0] | |||
| s = s.replace("!", "") | |||
| if not s.isdigit(): | |||
| return [] | |||
| shape_arr.append(int(s)) | |||
| return shape_arr | |||
| def _trace_torch_graph(self, input_shape): | |||
| """ | |||
| Trace torch computational graph. | |||
| Args: | |||
| input_shape (tuple): Shape. | |||
| Returns: | |||
| object, pytorch graph. | |||
| """ | |||
| import torch | |||
| from torch.onnx import OperatorExportTypes | |||
| from .torch_utils import OverloadTorchModuleTemporarily | |||
| from .torch_utils import create_autograd_variable | |||
| from .torch_utils import onnx_tracer | |||
| warnings.simplefilter("ignore") | |||
| batched_sample = create_autograd_variable(torch.rand(*input_shape)) | |||
| try: | |||
| try: | |||
| # Assign execution mode to eval. | |||
| self.model.eval() | |||
| with OverloadTorchModuleTemporarily() as _: | |||
| # In pytorch higher version, trace function has a known. | |||
| graph = onnx_tracer(self.model, batched_sample, | |||
| OperatorExportTypes.ONNX) | |||
| return graph | |||
| except RuntimeError: | |||
| # Assign execution mode to eval. | |||
| self.model.eval() | |||
| with OverloadTorchModuleTemporarily() as _: | |||
| # In pytorch higher version, trace function has a known. | |||
| set_opset_version(ONNX_OPSET_VERSION) | |||
| graph = onnx_tracer(self.model, batched_sample, | |||
| OperatorExportTypes.ONNX) | |||
| return graph | |||
| except RuntimeError as error: | |||
| log.error(str(error)) | |||
| log.exception(error) | |||
| raise error | |||
| def build(self, input_shape): | |||
| """ | |||
| Build graph tree. | |||
| Args: | |||
| input_shape (tuple): Input shape of model. | |||
| """ | |||
| self._check_input_shape(input_shape) | |||
| self._original_shape = input_shape | |||
| feed_forward_ipt_shape = tuple(input_shape) | |||
| graph = self._trace_torch_graph(feed_forward_ipt_shape) | |||
| nodes = list(graph.nodes()) | |||
| self._nodes = nodes | |||
| scope_name_dict = dict() | |||
| self._constant_nodes, self._dynamic_nodes = self._get_constant_nodes(nodes) | |||
| for node in nodes: | |||
| output_name = ', '.join(list(self._extract_node_name(output) for output in node.outputs())) | |||
| if output_name in self._dynamic_nodes: | |||
| continue | |||
| node_name = normalize_scope_name(node, scope_name_dict) | |||
| scope_name_dict[node_name.split(SEPARATOR_BTW_NAME_AND_ID)[-1]] \ | |||
| = list(node_name.split(SEPARATOR_BTW_NAME_AND_ID)[0].split(SEPARATOR_IN_SCOPE)) | |||
| output_shape_str_list = re.findall(r'[^()!]+', str(node)) | |||
| output_shape_str = output_shape_str_list[1] | |||
| output_shape = self._extract_shape(output_shape_str) | |||
| weight_scope = '.'.join( | |||
| re.findall(r'\[([\w\d.]+)]', node.scopeName()) | |||
| ) | |||
| if self._constant_nodes: | |||
| node_weight = self._replace_constant_node(node) | |||
| else: | |||
| node_weight = {} | |||
| for scope, weight in self._params_dict.items(): | |||
| split_scope = scope.split('.') | |||
| if '.'.join(split_scope[:-1]) == weight_scope: | |||
| node_weight[split_scope[-1]] = weight | |||
| if not node_weight and node.kind() == 'onnx::Conv': | |||
| weight_names = list(self._params_dict.keys()) | |||
| node_input_names = [self._extract_input_name(node_input) for node_input in node.inputs()] | |||
| for node_input_name in node_input_names: | |||
| if int(node_input_name) > len(weight_names): | |||
| continue | |||
| weight = self._params_dict[weight_names[int(node_input_name) - 1]] | |||
| node_weight[weight_names[int(node_input_name) - 1]] = weight | |||
| self._shape_dict[node_name] = output_shape | |||
| self._nodes_collection[node_name] = PyTorchGraphNode(node, node_weight) | |||
| self._nodes_record[node_name] = node_name | |||
| for node_input in list(node.inputs()): | |||
| if self._extract_input_name(node_input) in self._constant_nodes: | |||
| continue | |||
| # Connect input node and src node. | |||
| nd_id = PyTorchGraph.get_node_id(node_input.node()) | |||
| nd_scope_name = node_input.node().kind() in NONE_SCOPE_OP or \ | |||
| node_input.node().scopeName() | |||
| if nd_id and nd_scope_name: | |||
| node_input_name = normalize_scope_name( | |||
| node_input.node(), scope_name_dict | |||
| ) | |||
| self.build_connection(node_input_name, node_name) | |||
| self._unmerge_multi_ipt_opt_script() | |||
| super(PyTorchGraph, self).build(input_shape=input_shape) | |||
| self._collect_ipt_shape_of_each_node(feed_forward_ipt_shape) | |||
| @staticmethod | |||
| def _extract_node_name(node): | |||
| """Extract node name for node.""" | |||
| result = re.match(r"\d+", str(node)) | |||
| if result: | |||
| return result.group(0) | |||
| return None | |||
| @staticmethod | |||
| def _extract_input_name(node_input): | |||
| """Extract node input name from node input.""" | |||
| node_input_name = str(node_input).split('defined in')[0].strip() | |||
| return node_input_name | |||
| def _get_constant_nodes(self, nodes): | |||
| """ | |||
| Get constant nodes to be eliminated. | |||
| Args: | |||
| nodes (Nodes): Nodes in torch._C.Graph. | |||
| Returns: | |||
| Union(dict, list), output of constant_input_node_name and dynamic nodes name. | |||
| """ | |||
| constant_input_nodes = list() | |||
| dynamic_nodes = list() | |||
| for node in nodes: | |||
| if node.kind() == 'onnx::Resize': | |||
| self._has_eliminated_nodes = True | |||
| constant_input_node, dynamic_node = self._generate_inputs_of(node) | |||
| constant_input_nodes += constant_input_node | |||
| dynamic_nodes += dynamic_node | |||
| outputs = dict() | |||
| if self._has_eliminated_nodes: | |||
| torch = import_module('torch') | |||
| device_target = 'cuda' if torch.cuda.is_available() else 'cpu' | |||
| dump_input = torch.randn(*self._original_shape, device=device_target) | |||
| temp_onnx_path = os.path.realpath(os.path.join(self._file_graph_onnx, | |||
| '.graph_onnx.onnx')) | |||
| symbolic_helper = import_module('torch.onnx.symbolic_helper') | |||
| export_onnx_opset_version = getattr(symbolic_helper, '_export_onnx_opset_version') | |||
| try: | |||
| torch.onnx.export(self.model.to(device_target), dump_input, | |||
| temp_onnx_path, opset_version=export_onnx_opset_version) | |||
| outputs = self._onnx_infer(temp_onnx_path, constant_input_nodes, self._original_shape) | |||
| finally: | |||
| if os.path.exists(temp_onnx_path): | |||
| os.remove(temp_onnx_path) | |||
| return outputs, dynamic_nodes | |||
| def _generate_inputs_of(self, node): | |||
| """ | |||
| Generate inputs of certain node. | |||
| Args: | |||
| node (Node): Node of torch._C.Graph. | |||
| """ | |||
| pattern_op_lst = CONSTANT_NODES_PATTERN.get(node.kind(), None) | |||
| constant_input_nodes = list() | |||
| dynamic_nodes = list() | |||
| if not isinstance(pattern_op_lst, list): | |||
| return constant_input_nodes, dynamic_nodes | |||
| if not pattern_op_lst: | |||
| dynamic_nodes += self.get_node_id(node) | |||
| return constant_input_nodes, dynamic_nodes | |||
| node_inputs_name = [self._extract_input_name(node_input) for node_input in node.inputs()] | |||
| for node_input_name in node_inputs_name: | |||
| node_name_path = self._search_node_path(node_input_name, pattern_op_lst) | |||
| if node_name_path and self._get_node_from_graph(node_name_path[-1]).kind() == 'onnx::Shape': | |||
| constant_input_nodes.append(node_input_name) | |||
| dynamic_nodes += node_name_path | |||
| return constant_input_nodes, dynamic_nodes | |||
| def _search_node_path(self, node_name, pattern_op_lst): | |||
| """ | |||
| Search node path based on pattern_op_list. | |||
| Args: | |||
| node_name (str): Node name. | |||
| pattern_op_lst (list): Pattern list of certain operator. | |||
| Returns: | |||
| list[str]: node names in pattern. | |||
| """ | |||
| node_type_lst = list() | |||
| node_name_lst = list() | |||
| node = self._get_node_from_graph(node_name) | |||
| if node_name == MODEL_INPUT_NAME: | |||
| return node_name_lst | |||
| if node.kind() not in pattern_op_lst: | |||
| return node_name_lst | |||
| node_type_lst.append(node.kind()) | |||
| node_name_lst.append(node_name) | |||
| node_inputs_name = [self._extract_input_name(node_input) for node_input in node.inputs()] | |||
| for node_input_name in node_inputs_name: | |||
| node_name_lst += self._search_node_path(node_input_name, pattern_op_lst) | |||
| return node_name_lst | |||
| def _get_node_from_graph(self, node_name): | |||
| """Get torch._C.Node from torch._C.Graph.""" | |||
| for idx, node in enumerate(self._nodes): | |||
| node_id = ', '.join(self.get_node_id(node)) | |||
| if node_id == node_name: | |||
| return self._nodes[idx] | |||
| return None | |||
| @staticmethod | |||
| def _onnx_infer(file_graph_onnx, infer_outputs, infer_inputs_shape): | |||
| """ | |||
| Infer onnx model to get outputs of inner nodes. | |||
| Args: | |||
| file_graph_onnx (str): File path of onnx. | |||
| infer_outputs (list): Outputs for infer. | |||
| infer_inputs_shape (list): Input shape for infer. | |||
| """ | |||
| onnx = import_module('onnx') | |||
| tensor_proto = getattr(onnx, 'TensorProto') | |||
| onnx_model = onnx.load(file_graph_onnx) | |||
| for onnx_node in onnx_model.graph.node: | |||
| if set(onnx_node.output).issubset(set(infer_outputs)): | |||
| onnx_node.name = ', '.join([f"{output_name}" for output_name in onnx_node.output]) | |||
| input_onnx = onnx_model.graph.input[0] | |||
| node_type = tensor_proto.DataType.Name(input_onnx.type.tensor_type.elem_type) | |||
| if node_type != 'FLOAT': | |||
| raise ModelNotSupportError(f"Input type should be FLOAT32, but got {node_type}. " | |||
| f"Please report issue to us if extra input type is needed.") | |||
| input_onnx_name = input_onnx.name | |||
| feed_dict = {input_onnx_name: np.random.rand(*infer_inputs_shape).astype(np.float32)} | |||
| outputs = fetch_output_from_onnx_model(onnx_model, feed_dict, infer_outputs) | |||
| return outputs | |||
| def _replace_constant_node(self, node): | |||
| """Replace constant node.""" | |||
| node_weight = dict() | |||
| for node_input in list(node.inputs()): | |||
| node_input_name = self._extract_input_name(node_input) | |||
| if node_input_name in self._constant_nodes: | |||
| node_weight[node_input_name] = self._constant_nodes[node_input_name] | |||
| return node_weight | |||
| def _collect_ipt_shape_of_each_node(self, input_shape): | |||
| """ | |||
| Collect input tensor shape of each node. | |||
| Args: | |||
| input_shape (tuple): Input shape. | |||
| """ | |||
| input_node = InputNode(input_shape) | |||
| input_node_name = "{}InputNode" | |||
| for node_name, node in self._nodes_collection.items(): | |||
| if node_name in self._input_nodes: | |||
| ipt_nd_name = input_node_name.format(input_node.scope_name) | |||
| input_node.set_scope_name(node.scope_name) | |||
| node.precursor_nodes.insert(0, ipt_nd_name) | |||
| input_node.set_successor_nodes(node_name) | |||
| self._shape_dict[ipt_nd_name] = input_node.output_shape | |||
| if not self._shape_dict[node_name]: | |||
| self._shape_dict[node_name] = SCALAR_WITHOUT_SHAPE | |||
| ipt_shape = [] | |||
| for p_nd in node.precursor_nodes: | |||
| shp = self._shape_dict.get(p_nd) | |||
| ipt_shape.append(tuple(shp) if isinstance(shp, list) else shp) | |||
| self._input_shape[node_name] = ipt_shape[0] if len(ipt_shape) == 1 else ipt_shape | |||
| def _generate_module(self): | |||
| """Generate modules.""" | |||
| module_dict = dict() | |||
| for node_key, _ in self._nodes_collection.items(): | |||
| node_key_in_scope = node_key.split(SEPARATOR_IN_SCOPE) | |||
| if len(node_key_in_scope) < MIN_SCOPE_LENGTH: | |||
| continue | |||
| for idx in range(1, len(node_key_in_scope)): | |||
| node_key_module = SEPARATOR_IN_SCOPE.join(node_key_in_scope[:idx]) | |||
| node_name = SEPARATOR_IN_SCOPE.join(node_key_in_scope[:idx+1]) | |||
| if not module_dict.get(node_key_module, None): | |||
| module_dict[node_key_module] = {node_name} | |||
| else: | |||
| module_dict[node_key_module].add(node_name) | |||
| return module_dict | |||
| def _check_multi_ipt_opt(self): | |||
| """Check whether multi-input exists.""" | |||
| module_dict = self._generate_module() | |||
| for _, nodes_per_module in module_dict.items(): | |||
| prcs_nodes_out_from_module = set() | |||
| for node_name in nodes_per_module: | |||
| if re.search(r"[\d]+[&][\d]+", node_name): | |||
| self._is_multi_opt_graph = True | |||
| return True | |||
| node = self._nodes_collection.get(node_name, None) | |||
| if node: | |||
| prcs_nodes = node.precursor_nodes | |||
| else: | |||
| continue | |||
| for prcs_node in prcs_nodes: | |||
| if prcs_node not in nodes_per_module: | |||
| prcs_node_module = SEPARATOR_IN_SCOPE.join(prcs_node.split(SEPARATOR_IN_SCOPE)[:-1]) | |||
| if prcs_node_module not in nodes_per_module: | |||
| prcs_nodes_out_from_module.add(prcs_node) | |||
| if len(prcs_nodes_out_from_module) > 1: | |||
| return True | |||
| return False | |||
| def _unmerge_multi_ipt_opt_script(self): | |||
| """Unmerge all submodule.""" | |||
| if self._check_multi_ipt_opt() or self._has_eliminated_nodes: | |||
| for node_key, node_inst in deepcopy(self._nodes_collection).items(): | |||
| prsc_nodes = node_inst.precursor_nodes | |||
| scsr_nodes = node_inst.successor_nodes | |||
| node_inst.is_in_multi_opt_graph = self._is_multi_opt_graph | |||
| node_inst.precursor_nodes = [SEPARATOR_IN_SCOPE.join((prsc_node.split(SEPARATOR_IN_SCOPE)[0], | |||
| prsc_node.split(SEPARATOR_IN_SCOPE)[-1])) | |||
| for prsc_node in deepcopy(prsc_nodes)] | |||
| node_inst.successor_nodes = [SEPARATOR_IN_SCOPE.join((scsr_node.split(SEPARATOR_IN_SCOPE)[0], | |||
| scsr_node.split(SEPARATOR_IN_SCOPE)[-1])) | |||
| for scsr_node in deepcopy(scsr_nodes)] | |||
| reduce_node_key = SEPARATOR_IN_SCOPE.join((node_key.split(SEPARATOR_IN_SCOPE)[0], | |||
| node_key.split(SEPARATOR_IN_SCOPE)[-1])) | |||
| del self._nodes_collection[node_key] | |||
| self._nodes_collection[reduce_node_key] = node_inst | |||
| for node_key, shape in deepcopy(self._shape_dict).items(): | |||
| reduce_node_key = SEPARATOR_IN_SCOPE.join((node_key.split(SEPARATOR_IN_SCOPE)[0], | |||
| node_key.split(SEPARATOR_IN_SCOPE)[-1])) | |||
| del self._shape_dict[node_key] | |||
| self._shape_dict[reduce_node_key] = shape | |||
| def sub_graph_merging(self): | |||
| """ | |||
| Merge split operation into one. | |||
| """ | |||
| raise NotImplementedError() | |||
| def build_connection(self, src, tgt) -> NoReturn: | |||
| """ | |||
| Build connection between source node and target node. | |||
| Args: | |||
| src (str): Source node name. | |||
| tgt (str): Target node name. | |||
| """ | |||
| # If src and tgt are the same node, src not in node_collection or | |||
| # tgt not in node_collection, then skip this edge. | |||
| if src == tgt or src not in self._nodes_collection or tgt not in self._nodes_collection: | |||
| if src.split(':')[0] not in self._nodes_collection: | |||
| log.warning("Graph construct a self-loop node %s. Ignored.", src) | |||
| return | |||
| if tgt not in self._nodes_collection[src.split(':')[0]].successor_nodes: | |||
| self._nodes_collection[src.split(':')[0]].successor_nodes.append(tgt) | |||
| if src not in self._nodes_collection[tgt].precursor_nodes: | |||
| self._nodes_collection[tgt.split(':')[0]].precursor_nodes.append(src) | |||
| @staticmethod | |||
| def load_checkpoint(ckpt_path: str) -> Dict: | |||
| """ | |||
| Load checkpoint. | |||
| Args: | |||
| ckpt_path (str): Checkpoint file path. | |||
| Returns: | |||
| dict, weights in model. | |||
| """ | |||
| @staticmethod | |||
| def load_metadata(**kwargs): | |||
| """ | |||
| Load graph metadata. | |||
| """ | |||
| err_msg = "class `PyTorchGraph` has not implemented " \ | |||
| "`load_metadata()`." | |||
| log.error(err_msg) | |||
| raise NotImplementedError(err_msg) | |||
| @staticmethod | |||
| def load_graph(graph_path: str, **kwargs): | |||
| """ | |||
| Load graph. | |||
| Args: | |||
| graph_path (str): Graph path. | |||
| Returns: | |||
| object, pytorch model. | |||
| """ | |||
| torch_model = PyTorchGraphParser.parse(graph_path) | |||
| return torch_model | |||
| @staticmethod | |||
| def get_node_id(node): | |||
| """ | |||
| Get node id using regular expr. | |||
| Args: | |||
| node (Node): PyTorch node. | |||
| Returns: | |||
| str, node id. | |||
| """ | |||
| node_title = str(node).split(SEPARATOR_TITLE_AND_CONTENT_IN_CONSTRUCT)[0] | |||
| node_id = re.findall(r"[%](.*?) [:]", node_title) | |||
| return node_id | |||
| @@ -1,236 +0,0 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd.All Rights Reserved. | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================== | |||
| """Define PyTorch graph node.""" | |||
| import re | |||
| from .base import GraphNode | |||
| from ..common.utils import is_converted | |||
| from ..constant import NodeType, SEPARATOR_IN_SCOPE, SEPARATOR_BTW_NAME_AND_ID, LEFT_BUCKET, RIGHT_BUCKET, \ | |||
| SEPARATOR_IN_ONNX_OP, SEPARATOR_TITLE_AND_CONTENT_IN_CONSTRUCT | |||
| class PyTorchGraphNode(GraphNode): | |||
| """ | |||
| PyTorch graph node. | |||
| Args: | |||
| node (torch._C.Node): Node in raw PyTorch graph. | |||
| """ | |||
| _type_frozen = False | |||
| _module_name_frozen = False | |||
| def __init__(self, node=None, weight=None): | |||
| super(PyTorchGraphNode, self).__init__(node=node) | |||
| self._op_params = self._get_raw_params(node) | |||
| self._op_name = node.kind() if node else None | |||
| self._scope_name = node.scopeName() if node else None | |||
| self._weight = weight | |||
| self._ipt_var_names, self._opt_var_names \ | |||
| = self._extract_ipt_opt_var_names() if node else (list(), list()) | |||
| def _extract_ipt_opt_var_names(self): | |||
| """Extract ipt and opt var names.""" | |||
| node_content = SEPARATOR_TITLE_AND_CONTENT_IN_CONSTRUCT.join( | |||
| str(self._src_node).split(SEPARATOR_TITLE_AND_CONTENT_IN_CONSTRUCT)[1:] | |||
| ) | |||
| node_inputs = re.findall(r"[(](.*?)[)]", node_content)[0] | |||
| node_inputs = re.sub(r"[\s%]", '', node_inputs).split(",") | |||
| node_title = str(self._src_node).split(SEPARATOR_TITLE_AND_CONTENT_IN_CONSTRUCT)[0] | |||
| node_outputs = re.findall(r"[%](.*?) [:]", node_title) | |||
| return node_inputs, node_outputs | |||
| def clear_args_of_declaration(self): | |||
| """ | |||
| Clear `self._args_in_code`. | |||
| """ | |||
| self._args_in_code = dict() | |||
| def _get_arg_name(self, arg, variable_name): | |||
| """ | |||
| Get arg name. | |||
| Args: | |||
| arg (str): Generate arg name. | |||
| Returns: | |||
| str, arg name in function or class declaration. | |||
| """ | |||
| return f"{arg}_{variable_name}" | |||
| @property | |||
| def is_in_multi_opt_graph(self): | |||
| return self._is_in_multi_opt_graph | |||
| @is_in_multi_opt_graph.setter | |||
| def is_in_multi_opt_graph(self, multi_opt_state): | |||
| self._is_in_multi_opt_graph = multi_opt_state | |||
| @property | |||
| def hash_key(self): | |||
| """ | |||
| Return unique hash key of current node. | |||
| Returns: | |||
| str, hash key. | |||
| """ | |||
| if self._node_type not in {NodeType.CLASS.value, | |||
| NodeType.FUNC.value, | |||
| NodeType.MODULE.value}: | |||
| self._hash_key = self._op_name.lower() | |||
| return self._hash_key | |||
| @hash_key.setter | |||
| def hash_key(self, h): | |||
| """ | |||
| Setter of hash key. | |||
| Args: | |||
| h (str): Key. | |||
| """ | |||
| self._hash_key = h | |||
| @property | |||
| def op_name(self): | |||
| """ | |||
| Op name in torch. | |||
| Returns: | |||
| str, op name. | |||
| """ | |||
| return self._op_name | |||
| @op_name.setter | |||
| def op_name(self, name): | |||
| """ | |||
| Setter of op name. | |||
| Args: | |||
| name(str): op_name. | |||
| """ | |||
| self._op_name = name | |||
| @property | |||
| def real_name(self): | |||
| return | |||
| def add_input_and_output_shape(self, input_shape, output_shape): | |||
| """ | |||
| Add the node input shape. | |||
| Args: | |||
| output_shape (tuple): Output tensor shape. | |||
| input_shape (tuple): Input tensor shape. | |||
| """ | |||
| self._ipt_shape = input_shape | |||
| self._opt_shape = output_shape | |||
| def to_code(self, ipt_args_in_construct: str, variable_name: str, output_var: list, code_fragment): | |||
| """ | |||
| Generate statements. | |||
| Args: | |||
| variable_name (str): Variable name. | |||
| ipt_args_in_construct (str): Args of input. | |||
| output_var (list): Output variable names in construct. | |||
| code_fragment (CodeFragment): CodeFragment instance. | |||
| Returns: | |||
| Union[str, str], declare in init and call in construct. | |||
| """ | |||
| operator = code_fragment.operation | |||
| args = self.args_in_code | |||
| settings = code_fragment.code_setting | |||
| if self._node_type == NodeType.OPERATION.value and not is_converted(code_fragment.operation): | |||
| args.update({"input_shape": self.input_shape, | |||
| "output_shape": self.output_shape}) | |||
| if self._node_type == NodeType.OPERATION.value: | |||
| expr = ", ".join([f"{k.replace(f'_{variable_name}', '')}={v}" | |||
| for k, v in args.items()]) | |||
| ipt_args_settings_in_construct = self._generate_ipt_args_settings_in_construct( | |||
| ipt_args_in_construct, settings) | |||
| else: | |||
| # When it's type is module, class or func, | |||
| # it's not necessary to replace var. | |||
| expr = ", ".join([f"{k.replace(f'_{variable_name}', '')}={v}" | |||
| for k, v in args.items()]) | |||
| ipt_args_settings_in_construct = ipt_args_in_construct | |||
| if SEPARATOR_IN_ONNX_OP in operator: | |||
| operator = operator.replace(SEPARATOR_IN_ONNX_OP, ".") | |||
| declare = f"self.{variable_name} = {operator}({expr})" | |||
| call = f"{', '.join([output for output in output_var])}" \ | |||
| f" = self.{variable_name}({ipt_args_settings_in_construct})" | |||
| return declare, call | |||
| def to_ir(self): | |||
| """ | |||
| No need to implement for now. | |||
| """ | |||
| raise NotImplementedError | |||
| def _get_raw_params(self, node): | |||
| """ | |||
| Get params in onnx. | |||
| Args: | |||
| node (Any): Node. | |||
| Returns: | |||
| dict, raw params. | |||
| """ | |||
| from .torch_utils import getitem_of_node | |||
| raw_params = dict() | |||
| if not node: | |||
| return raw_params | |||
| for k in node.attributeNames(): | |||
| raw_params[k] = getitem_of_node(node, k) | |||
| return raw_params | |||
| def replace_with_arg(self, src_arg, tgt_arg): | |||
| """ | |||
| Replace actual parameter with formal parameter. | |||
| Args: | |||
| src_arg (str): Original arg name. | |||
| tgt_arg (str): Target arg name. | |||
| """ | |||
| self._args_in_code[src_arg] = tgt_arg | |||
| @staticmethod | |||
| def _extract_var_name(scope_name: str): | |||
| """ | |||
| Extract variable name from scope name. | |||
| """ | |||
| if not scope_name: | |||
| return None | |||
| var = scope_name.split(SEPARATOR_IN_SCOPE)[-1].lower() | |||
| var = var.replace(LEFT_BUCKET, SEPARATOR_BTW_NAME_AND_ID).replace( | |||
| RIGHT_BUCKET, "") | |||
| return var | |||
| @@ -1,4 +1,4 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd.All Rights Reserved. | |||
| # Copyright 2020-2021 Huawei Technologies Co., Ltd.All Rights Reserved. | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| @@ -18,6 +18,7 @@ from importlib import import_module | |||
| from mindinsight.mindconverter.common.log import logger as log | |||
| from .base import GraphParser | |||
| from .optimizer import OnnxSimplify | |||
| from ...common.exceptions import ModelNotSupportError | |||
| @@ -38,7 +39,6 @@ class PyTorchGraphParser(GraphParser): | |||
| Returns: | |||
| object, torch model. | |||
| """ | |||
| torch = import_module("torch") | |||
| if not os.path.exists(model_path): | |||
| error = FileNotFoundError("`model_path` must be assigned with " | |||
| @@ -47,14 +47,66 @@ class PyTorchGraphParser(GraphParser): | |||
| raise error | |||
| try: | |||
| if torch.cuda.is_available(): | |||
| model = torch.load(f=model_path) | |||
| else: | |||
| model = torch.load(f=model_path, map_location="cpu") | |||
| onnx_model_sim = cls._convert_pytorch_graph_to_onnx( | |||
| model_path, kwargs['sample_shape'], opset_version=11) | |||
| return onnx_model_sim | |||
| except ModuleNotFoundError: | |||
| error_msg = "Cannot find model scripts in system path, " \ | |||
| "set `--project_path` to the path of model scripts folder correctly." | |||
| error = ModuleNotFoundError(error_msg) | |||
| raise error | |||
| return model | |||
| @staticmethod | |||
| def _convert_pytorch_graph_to_onnx(model_path, sample_shape, opset_version=None): | |||
| """ | |||
| Convert Pytorch model to ONNX model. | |||
| Args: | |||
| model_path (str): Path to the Pytorch model. | |||
| sample_shape (tuple): Input shape to generate onnx model. | |||
| opset_version (int): Op set version of onnx. | |||
| """ | |||
| torch = import_module('torch') | |||
| has_cuda = torch.cuda.is_available() | |||
| if has_cuda: | |||
| model = torch.load(f=model_path).cuda() | |||
| dump_input = torch.randn(*sample_shape, device='cuda') | |||
| else: | |||
| model = torch.load(f=model_path, map_location="cpu") | |||
| dump_input = torch.randn(*sample_shape, device='cpu') | |||
| if isinstance(model, torch.nn.DataParallel): | |||
| raise ValueError('torch.nn.DataParallel is not supported by ONNX exporter.') | |||
| torch_onnx = import_module('torch.onnx') | |||
| operator_export_types = getattr(torch_onnx, 'OperatorExportTypes') | |||
| utils = import_module('torch.onnx.utils') | |||
| model_to_graph = getattr(utils, '_model_to_graph') | |||
| symbolic_helper = import_module('torch.onnx.symbolic_helper') | |||
| default_onnx_opset_version = getattr(symbolic_helper, '_default_onnx_opset_version') | |||
| set_opset_version = getattr(symbolic_helper, '_set_opset_version') | |||
| set_operator_export_type = getattr(symbolic_helper, '_set_operator_export_type') | |||
| if not opset_version: | |||
| opset_version = default_onnx_opset_version | |||
| operator_export_type = operator_export_types.ONNX | |||
| set_opset_version(opset_version) | |||
| set_operator_export_type(operator_export_type) | |||
| graph, params_dict, _ = model_to_graph(model, dump_input, _retain_param_name=True) | |||
| export_onnx = getattr(graph, '_export_onnx') | |||
| proto, _ = export_onnx( | |||
| params_dict, opset_version, dict(), False, | |||
| operator_export_type, True, False, dict(), | |||
| True, False) | |||
| onnx = import_module('onnx') | |||
| onnx_model = onnx.load_model_from_string(proto) | |||
| onnx_simplify = OnnxSimplify() | |||
| onnx_model_sim = onnx_simplify.run_onnx_simplify(onnx_model, sample_shape) | |||
| return onnx_model_sim | |||
| @@ -1,105 +0,0 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd.All Rights Reserved. | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================== | |||
| """Define pytorch tracer context manager.""" | |||
| import importlib | |||
| from torch.nn import Module | |||
| from torch.onnx.utils import _trace | |||
| from torch.onnx.utils import _node_getitem | |||
| from torch.onnx.symbolic_helper import _set_opset_version | |||
| SCRIPT_METHOD = getattr(importlib.import_module("torch._C"), | |||
| "ScriptMethod") | |||
| onnx_tracer = _trace | |||
| getitem_of_node = _node_getitem | |||
| set_opset_version = _set_opset_version | |||
| def unique_state_dict(model): | |||
| """ | |||
| Wrapper of torch.jit._unique_state_dict. | |||
| Args: | |||
| model (Module): Torch model. | |||
| Returns: | |||
| dict, params. | |||
| """ | |||
| from torch.jit import _unique_state_dict | |||
| return _unique_state_dict(model) | |||
| def create_autograd_variable(tensor): | |||
| """ | |||
| Create autograd variable to trace the whole graph. | |||
| Args: | |||
| tensor (torch.Tensor): Tensor. | |||
| Returns: | |||
| torch.autograd.Variable, variable. | |||
| """ | |||
| variable = getattr(importlib.import_module("torch.autograd"), "Variable") | |||
| return variable(tensor, requires_grad=False) | |||
| class OverloadTorchModuleTemporarily: | |||
| """ | |||
| Fix bugs in new version of pytorch. | |||
| PyTorch official solution. | |||
| """ | |||
| def __init__(self): | |||
| self.backup = None | |||
| def __enter__(self): | |||
| def _tracing_name(traced_module, tracing_state): | |||
| traced_module_stack = getattr(tracing_state, "_traced_module_stack") | |||
| if not traced_module_stack: | |||
| return None | |||
| module = traced_module_stack[-1] | |||
| for name, child in module.named_children(): | |||
| if child is traced_module: | |||
| return name | |||
| return None | |||
| def _slow_forward(self_, *inputs, **kwargs): | |||
| tracing_state = getattr(importlib.import_module("torch._C"), | |||
| "_get_tracing_state")() | |||
| if not tracing_state or isinstance(self_.forward, SCRIPT_METHOD): | |||
| return self_.forward(*inputs, **kwargs) | |||
| if not hasattr(tracing_state, '_traced_module_stack'): | |||
| tracing_state._traced_module_stack = [] | |||
| name = _tracing_name(self_, tracing_state) | |||
| get_name_func = getattr(self_, "_get_name") | |||
| if name: | |||
| tracing_state.push_scope('%s[%s]' % (get_name_func(), name)) | |||
| else: | |||
| tracing_state.push_scope(get_name_func()) | |||
| tracing_state._traced_module_stack.append(self_) | |||
| try: | |||
| result = self_.forward(*inputs, **kwargs) | |||
| finally: | |||
| tracing_state.pop_scope() | |||
| tracing_state._traced_module_stack.pop() | |||
| return result | |||
| self.backup = getattr(Module, "_slow_forward") | |||
| setattr(Module, '_slow_forward', _slow_forward) | |||
| def __exit__(self, exc_type, exc_val, exc_tb): | |||
| setattr(Module, '_slow_forward', self.backup) | |||
| @@ -1,4 +1,4 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd.All Rights Reserved. | |||
| # Copyright 2020-2021 Huawei Technologies Co., Ltd.All Rights Reserved. | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| @@ -14,7 +14,7 @@ | |||
| # ============================================================================== | |||
| """Test name manager module.""" | |||
| from unittest import TestCase | |||
| from mindinsight.mindconverter.graph_based_converter.hierarchical_tree.name_mgr import NameMgr, GlobalVarNameMgr, \ | |||
| from mindinsight.mindconverter.graph_based_converter.common.name_mgr import NameMgr, GlobalVarNameMgr, \ | |||
| global_op_namespace | |||
| @@ -1,15 +0,0 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd.All Rights Reserved. | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================== | |||
| """Unit test for mindconvert.graph_based_converter.hierarchical_tree interface.""" | |||
| @@ -1,177 +0,0 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| """Test hierarchical tree module.""" | |||
| import os | |||
| import shutil | |||
| from unittest import mock | |||
| import pytest | |||
| from mindinsight.mindconverter.graph_based_converter.hierarchical_tree.hierarchical_tree import HierarchicalTree | |||
| from mindinsight.mindconverter.graph_based_converter.third_party_graph.pytorch_graph_node import PyTorchGraphNode | |||
| from mindinsight.mindconverter.graph_based_converter.mapper.base import ONNXToMindSporeMapper | |||
| from mindinsight.mindconverter.graph_based_converter.constant import NodeType | |||
| from tests.ut.mindconverter.graph_based_converter.conftest import TEST_BASE_PATH | |||
| class TestHierarchicalTree: | |||
| """Test the class of HierarchicalTree.""" | |||
| def test_tree_identifier(self): | |||
| """Test tree_identifier""" | |||
| tree = HierarchicalTree() | |||
| assert isinstance(tree.tree_identifier, str) | |||
| @mock.patch( | |||
| '.'.join((TEST_BASE_PATH, 'third_party_graph.pytorch_graph_node.PyTorchGraphNode._get_raw_params'))) | |||
| def test_insert(self, get_raw_params): | |||
| """Test insert""" | |||
| get_raw_params.return_value = [] | |||
| tree = HierarchicalTree() | |||
| pt_node = PyTorchGraphNode() | |||
| tree.insert(pt_node, 'ResNet') | |||
| assert tree.root == 'ResNet' | |||
| def test_remove(self): | |||
| """Test remove function.""" | |||
| tree = HierarchicalTree() | |||
| tree.create_node( | |||
| tag='node_root', | |||
| identifier='root', | |||
| parent=None, | |||
| data=None | |||
| ) | |||
| node = tree.get_node('root') | |||
| tree.remove(node) | |||
| assert tree.root is None | |||
| @mock.patch( | |||
| '.'.join((TEST_BASE_PATH, 'third_party_graph.pytorch_graph_node.PyTorchGraphNode._get_raw_params'))) | |||
| def test_shrink(self, get_raw_params): | |||
| """Test shrink function.""" | |||
| params = {'root': {}, | |||
| 'root/child0': {}, | |||
| 'root/child0/child1': {}} | |||
| tree = self._create_tree(get_raw_params=get_raw_params, params=params) | |||
| node = tree.get_node('root/child0') | |||
| tree.shrink(node) | |||
| assert tree.leaves()[0].tag == 'child1' | |||
| @pytest.mark.parametrize('params', [{ | |||
| 'tree_params': {'root': {'op_name': 'Root', | |||
| 'precursor_nodes': [], | |||
| 'successor_nodes': ['root/relu'], | |||
| 'node_type': NodeType.MODULE.value, | |||
| 'input_shape': [1, 3, 224, 224], | |||
| 'output_shape': [1, 1, 224, 224]}, | |||
| 'root/relu': {'op_name': 'onnx::Relu', | |||
| 'precursor_nodes': ['root'], | |||
| 'successor_nodes': ['root/unknown'], | |||
| 'node_type': NodeType.OPERATION.value, | |||
| 'input_shape': [1, 3, 224, 224], | |||
| 'output_shape': [1, 3, 224, 224]}, | |||
| 'root/unknown': {'op_name': 'onnx::Unknown', | |||
| 'precursor_nodes': ['root/relu'], | |||
| 'successor_nodes': [], | |||
| 'node_type': NodeType.OPERATION.value, | |||
| 'input_shape': [1, 3, 224, 224], | |||
| 'output_shape': [1, 1, 224, 224]}}, | |||
| 'report_dir': 'report_folder' | |||
| }, { | |||
| 'tree_params': {'root': {'op_name': 'Root', | |||
| 'precursor_nodes': [], | |||
| 'successor_nodes': ['root/relu'], | |||
| 'node_type': NodeType.MODULE.value, | |||
| 'input_shape': [1, 3, 224, 224], | |||
| 'output_shape': [1, 1, 224, 224]}, | |||
| 'root/relu': {'op_name': 'onnx::Relu', | |||
| 'precursor_nodes': ['root'], | |||
| 'successor_nodes': ['root/unknown'], | |||
| 'node_type': NodeType.OPERATION.value, | |||
| 'input_shape': [1, 3, 224, 224], | |||
| 'output_shape': [1, 3, 224, 224]}, | |||
| 'root/unknown': {'op_name': 'onnx::Unknown', | |||
| 'precursor_nodes': ['root/relu'], | |||
| 'successor_nodes': [], | |||
| 'node_type': NodeType.OPERATION.value, | |||
| 'input_shape': [1, 3, 224, 224], | |||
| 'output_shape': [1, 1, 224, 224]}}, | |||
| 'report_dir': None | |||
| }]) | |||
| @mock.patch( | |||
| '.'.join((TEST_BASE_PATH, 'third_party_graph.pytorch_graph_node.PyTorchGraphNode._get_raw_params'))) | |||
| def test_save_source_file(self, get_raw_params, params): | |||
| """Test save_source_file function.""" | |||
| tree_params = params['tree_params'] | |||
| out_folder = 'out_folder' | |||
| report_folder = params['report_dir'] | |||
| model_name = 'model_name' | |||
| mapper = ONNXToMindSporeMapper() | |||
| tree = self._create_tree(get_raw_params=get_raw_params, params=tree_params) | |||
| tree.save_source_files(out_folder, mapper, model_name, report_folder) | |||
| out_path = os.path.realpath(os.path.join(out_folder, f"{model_name}.py")) | |||
| report_folder_test = report_folder if report_folder else out_folder | |||
| report_path = os.path.realpath( | |||
| os.path.join(report_folder_test, f"report_of_{model_name}.txt")) | |||
| try: | |||
| assert os.path.exists(out_path) | |||
| assert os.path.exists(report_path) | |||
| with open(out_path, 'r') as out_r: | |||
| code = out_r.read() | |||
| assert 'nn.ReLU' in code | |||
| assert 'onnx.Unknown' in code | |||
| with open(report_path, 'r') as report_r: | |||
| report = report_r.read() | |||
| assert "[UnConvert] 'onnx::Unknown' didn't convert." in report | |||
| assert "Converted Rate: 50.00%." in report | |||
| finally: | |||
| shutil.rmtree(out_folder) | |||
| if report_folder: | |||
| shutil.rmtree(report_folder) | |||
| @staticmethod | |||
| def _create_node(key, val, weight, input_shape, output_shape): | |||
| """Create node.""" | |||
| node = PyTorchGraphNode(weight=weight) | |||
| node.add_input_and_output_shape(input_shape, output_shape) | |||
| node.tag = key.split('/')[-1] if len(key.split('/')) > 1 else key | |||
| node.op_name = val['op_name'] if val.get('op_name') else None | |||
| node.precursor_nodes = val['precursor_nodes'] if val.get('precursor_nodes') else [] | |||
| node.successor_nodes = val['successor_nodes'] if val.get('successor_nodes') else [] | |||
| node.node_type = val['node_type'] if val.get('node_type') else None | |||
| return node | |||
| @staticmethod | |||
| def _create_tree(get_raw_params, params): | |||
| """Create tree.""" | |||
| tree = HierarchicalTree() | |||
| for key, val in params.items(): | |||
| input_shape = val['input_shape'] if val.get('input_shape') else [] | |||
| output_shape = val['output_shape'] if val.get('output_shape') else [] | |||
| get_raw_params.return_value = val['op_params'] if val.get('op_params') else dict() | |||
| weight = val['weight'] if val.get('weight') else None | |||
| node = TestHierarchicalTree._create_node(key, val, weight, input_shape, output_shape) | |||
| tree.create_node( | |||
| tag=node.tag, | |||
| identifier=key, | |||
| parent='/'.join(key.split('/')[:-1]) if len(key.split('/')) > 1 else None, | |||
| data=node | |||
| ) | |||
| return tree | |||
| @@ -1,4 +1,4 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd.All Rights Reserved. | |||
| # Copyright 2020-2021 Huawei Technologies Co., Ltd.All Rights Reserved. | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| @@ -12,4 +12,4 @@ | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================== | |||
| """Unit test for mindconvert.graph_based_converter.mapper interface.""" | |||
| """Unit test for mindconverter.graph_based_converter.mapper interface.""" | |||
| @@ -1,4 +1,4 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd.All Rights Reserved. | |||
| # Copyright 2020-2021 Huawei Technologies Co., Ltd.All Rights Reserved. | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| @@ -18,7 +18,6 @@ import pytest | |||
| from mindinsight.mindconverter.graph_based_converter.mapper.base import ONNXToMindSporeMapper | |||
| from mindinsight.mindconverter.graph_based_converter.mapper.gen_setting import Setting | |||
| from tests.utils import mindspore | |||
| class TestMappers: | |||
| @@ -30,7 +29,7 @@ class TestMappers: | |||
| 'group': 1, | |||
| 'pads': [1, 2, 3, 4], | |||
| 'strides': [1, 1]}, | |||
| 'weights': {'weight': mindspore.Tensor(np.zeros([64, 3, 1, 1], dtype=np.int32))}}, | |||
| 'weights': {'weight': np.zeros((64, 3, 1, 1), dtype=np.int32)}}, | |||
| 'expected_output': {'converter_name': 'nn.Conv2d', | |||
| 'converted_params': {'in_channels': 3, | |||
| 'out_channels': 64, | |||
| @@ -47,7 +46,7 @@ class TestMappers: | |||
| 'group': 1, | |||
| 'pads': [0, 0, 0, 0], | |||
| 'strides': [1, 1]}, | |||
| 'weights': {'weight': mindspore.Tensor(np.zeros([64, 3, 2, 2], dtype=np.int32))}}, | |||
| 'weights': {'weight': np.zeros((64, 3, 2, 2), dtype=np.int32)}}, | |||
| 'expected_output': {'converter_name': 'nn.Conv2d', | |||
| 'converted_params': {'in_channels': 3, | |||
| 'out_channels': 64, | |||
| @@ -61,8 +60,8 @@ class TestMappers: | |||
| }, { | |||
| 'input': {'op_name': 'onnx::Gemm', | |||
| 'params': dict(), | |||
| 'weights': {'weight': mindspore.Tensor(np.zeros([10, 3], dtype=np.int32)), | |||
| 'bias': mindspore.Tensor(np.zeros([10, 1], dtype=np.int32))}}, | |||
| 'weights': {'weight': np.zeros((10, 3), dtype=np.int32), | |||
| 'bias': np.zeros((10, 1), dtype=np.int32)}}, | |||
| 'expected_output': {'converter_name': 'nn.Dense', | |||
| 'converted_params': {'in_channels': 3, | |||
| 'out_channels': 10, | |||