From: @moran3 Reviewed-by: Signed-off-by:tags/v1.2.0-rc1
| @@ -1,4 +1,4 @@ | |||||
| # Copyright 2020 Huawei Technologies Co., Ltd.All Rights Reserved. | |||||
| # Copyright 2020-2021 Huawei Technologies Co., Ltd.All Rights Reserved. | |||||
| # | # | ||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | # Licensed under the Apache License, Version 2.0 (the "License"); | ||||
| # you may not use this file except in compliance with the License. | # you may not use this file except in compliance with the License. | ||||
| @@ -194,6 +194,9 @@ def convert_bytes_string_to_string(bytes_str): | |||||
| def get_framework_type(model_path): | def get_framework_type(model_path): | ||||
| """Get framework type.""" | """Get framework type.""" | ||||
| if model_path.endswith('.onnx'): | |||||
| return FrameworkType.PYTORCH.value | |||||
| try: | try: | ||||
| with open(model_path, 'rb') as f: | with open(model_path, 'rb') as f: | ||||
| if f.read(BINARY_HEADER_PYTORCH_BITS) == BINARY_HEADER_PYTORCH_FILE: | if f.read(BINARY_HEADER_PYTORCH_BITS) == BINARY_HEADER_PYTORCH_FILE: | ||||
| @@ -1,4 +1,4 @@ | |||||
| # Copyright 2020 Huawei Technologies Co., Ltd.All Rights Reserved. | |||||
| # Copyright 2020-2021 Huawei Technologies Co., Ltd.All Rights Reserved. | |||||
| # | # | ||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | # Licensed under the Apache License, Version 2.0 (the "License"); | ||||
| # you may not use this file except in compliance with the License. | # you may not use this file except in compliance with the License. | ||||
| @@ -24,10 +24,12 @@ from mindinsight.mindconverter.graph_based_converter.common.utils import lib_ver | |||||
| save_code_file_and_report, get_framework_type | save_code_file_and_report, get_framework_type | ||||
| from mindinsight.mindconverter.graph_based_converter.constant import FrameworkType, \ | from mindinsight.mindconverter.graph_based_converter.constant import FrameworkType, \ | ||||
| ONNX_MIN_VER, TF2ONNX_MIN_VER, ONNXRUNTIME_MIN_VER | ONNX_MIN_VER, TF2ONNX_MIN_VER, ONNXRUNTIME_MIN_VER | ||||
| from mindinsight.mindconverter.graph_based_converter.generator import batch_add_nodes | |||||
| from mindinsight.mindconverter.graph_based_converter.mapper import ONNXToMindSporeMapper | from mindinsight.mindconverter.graph_based_converter.mapper import ONNXToMindSporeMapper | ||||
| from mindinsight.mindconverter.common.log import logger as log, logger_console as log_console | from mindinsight.mindconverter.common.log import logger as log, logger_console as log_console | ||||
| from mindinsight.mindconverter.common.exceptions import GraphInitError, TreeCreationError, SourceFilesSaveError, \ | from mindinsight.mindconverter.common.exceptions import GraphInitError, TreeCreationError, SourceFilesSaveError, \ | ||||
| BaseConverterError, UnknownModelError, GeneratorError, TfRuntimeError, RuntimeIntegrityError, ParamMissingError | BaseConverterError, UnknownModelError, GeneratorError, TfRuntimeError, RuntimeIntegrityError, ParamMissingError | ||||
| from mindinsight.mindconverter.graph_based_converter.third_party_graph import GraphFactory | |||||
| permissions = os.R_OK | os.W_OK | os.X_OK | permissions = os.R_OK | os.W_OK | os.X_OK | ||||
| os.umask(permissions << 3 | permissions) | os.umask(permissions << 3 | permissions) | ||||
| @@ -62,6 +64,7 @@ def torch_installation_validation(func): | |||||
| """ | """ | ||||
| def _f(graph_path: str, sample_shape: tuple, | def _f(graph_path: str, sample_shape: tuple, | ||||
| input_nodes: str, output_nodes: str, | |||||
| output_folder: str, report_folder: str = None): | output_folder: str, report_folder: str = None): | ||||
| # Check whether pytorch is installed. | # Check whether pytorch is installed. | ||||
| if not find_spec("torch") or not find_spec("onnx") or not find_spec("onnxruntime"): | if not find_spec("torch") or not find_spec("onnx") or not find_spec("onnxruntime"): | ||||
| @@ -93,6 +96,7 @@ def torch_installation_validation(func): | |||||
| sys.exit(0) | sys.exit(0) | ||||
| func(graph_path=graph_path, sample_shape=sample_shape, | func(graph_path=graph_path, sample_shape=sample_shape, | ||||
| input_nodes=input_nodes, output_nodes=output_nodes, | |||||
| output_folder=output_folder, report_folder=report_folder) | output_folder=output_folder, report_folder=report_folder) | ||||
| return _f | return _f | ||||
| @@ -182,6 +186,7 @@ def _extract_model_name(model_path): | |||||
| @SourceFilesSaveError.uniform_catcher() | @SourceFilesSaveError.uniform_catcher() | ||||
| @GeneratorError.uniform_catcher() | @GeneratorError.uniform_catcher() | ||||
| def graph_based_converter_pytorch_to_ms(graph_path: str, sample_shape: tuple, | def graph_based_converter_pytorch_to_ms(graph_path: str, sample_shape: tuple, | ||||
| input_nodes: str, output_nodes: str, | |||||
| output_folder: str, report_folder: str = None): | output_folder: str, report_folder: str = None): | ||||
| """ | """ | ||||
| PyTorch to MindSpore based on Graph. | PyTorch to MindSpore based on Graph. | ||||
| @@ -189,26 +194,18 @@ def graph_based_converter_pytorch_to_ms(graph_path: str, sample_shape: tuple, | |||||
| Args: | Args: | ||||
| graph_path (str): Graph file path. | graph_path (str): Graph file path. | ||||
| sample_shape (tuple): Input shape of the model. | sample_shape (tuple): Input shape of the model. | ||||
| input_nodes (str): Input node(s) of the model. | |||||
| output_nodes (str): Output node(s) of the model. | |||||
| output_folder (str): Output folder. | output_folder (str): Output folder. | ||||
| report_folder (str): Report output folder path. | report_folder (str): Report output folder path. | ||||
| """ | """ | ||||
| third_party_graph_module = import_module( | |||||
| 'mindinsight.mindconverter.graph_based_converter.third_party_graph') | |||||
| hierarchical_tree_module = import_module( | |||||
| 'mindinsight.mindconverter.graph_based_converter.hierarchical_tree') | |||||
| cls_graph_factory = getattr(third_party_graph_module, 'GraphFactory') | |||||
| cls_hierarchical_tree_factory = getattr(hierarchical_tree_module, 'HierarchicalTreeFactory') | |||||
| graph_obj = cls_graph_factory.init(graph_path, sample_shape=sample_shape) | |||||
| hierarchical_tree = cls_hierarchical_tree_factory.create(graph_obj) | |||||
| graph_obj = GraphFactory.init(graph_path, sample_shape=sample_shape, | |||||
| input_nodes=input_nodes, output_nodes=output_nodes) | |||||
| generator_inst = batch_add_nodes(graph_obj, ONNXToMindSporeMapper) | |||||
| model_name = _extract_model_name(graph_path) | model_name = _extract_model_name(graph_path) | ||||
| hierarchical_tree.save_source_files(output_folder, mapper=ONNXToMindSporeMapper, | |||||
| model_name=model_name, | |||||
| report_folder=report_folder) | |||||
| code_fragments = generator_inst.generate() | |||||
| save_code_file_and_report(model_name, code_fragments, output_folder, report_folder) | |||||
| @tf_installation_validation | @tf_installation_validation | ||||
| @@ -230,18 +227,13 @@ def graph_based_converter_tf_to_ms(graph_path: str, sample_shape: tuple, | |||||
| output_nodes(str): Output node(s) of the model. | output_nodes(str): Output node(s) of the model. | ||||
| output_folder(str): Output folder. | output_folder(str): Output folder. | ||||
| report_folder(str): Report output folder path. | report_folder(str): Report output folder path. | ||||
| """ | """ | ||||
| third_party_graph_module = import_module( | |||||
| 'mindinsight.mindconverter.graph_based_converter.third_party_graph') | |||||
| cls_graph_factory = getattr(third_party_graph_module, 'GraphFactory') | |||||
| batch_add_nodes = getattr(import_module('mindinsight.mindconverter.graph_based_converter.generator'), | |||||
| "batch_add_nodes") | |||||
| # Close unnecessary log. | # Close unnecessary log. | ||||
| os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2' | os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2' | ||||
| graph_obj = cls_graph_factory.init(graph_path, sample_shape=sample_shape, | |||||
| input_nodes=input_nodes, output_nodes=output_nodes) | |||||
| graph_obj = GraphFactory.init(graph_path, sample_shape=sample_shape, | |||||
| input_nodes=input_nodes, output_nodes=output_nodes) | |||||
| generator_inst = batch_add_nodes(graph_obj, ONNXToMindSporeMapper) | generator_inst = batch_add_nodes(graph_obj, ONNXToMindSporeMapper) | ||||
| model_name = _extract_model_name(graph_path) | model_name = _extract_model_name(graph_path) | ||||
| code_fragments = generator_inst.generate() | code_fragments = generator_inst.generate() | ||||
| @@ -255,7 +247,6 @@ def main_graph_base_converter(file_config): | |||||
| Args: | Args: | ||||
| file_config (dict): The config of file which to convert. | file_config (dict): The config of file which to convert. | ||||
| """ | """ | ||||
| graph_path = file_config['model_file'] | graph_path = file_config['model_file'] | ||||
| frame_type = get_framework_type(graph_path) | frame_type = get_framework_type(graph_path) | ||||
| @@ -263,8 +254,12 @@ def main_graph_base_converter(file_config): | |||||
| raise ParamMissingError("Param missing, `--shape` is required when using graph mode.") | raise ParamMissingError("Param missing, `--shape` is required when using graph mode.") | ||||
| if frame_type == FrameworkType.PYTORCH.value: | if frame_type == FrameworkType.PYTORCH.value: | ||||
| check_params = ['input_nodes', 'output_nodes'] | |||||
| check_params_exist(check_params, file_config) | |||||
| graph_based_converter_pytorch_to_ms(graph_path=graph_path, | graph_based_converter_pytorch_to_ms(graph_path=graph_path, | ||||
| sample_shape=file_config['shape'], | sample_shape=file_config['shape'], | ||||
| input_nodes=file_config['input_nodes'], | |||||
| output_nodes=file_config['output_nodes'], | |||||
| output_folder=file_config['outfile_dir'], | output_folder=file_config['outfile_dir'], | ||||
| report_folder=file_config['report_dir']) | report_folder=file_config['report_dir']) | ||||
| elif frame_type == FrameworkType.TENSORFLOW.value: | elif frame_type == FrameworkType.TENSORFLOW.value: | ||||
| @@ -1,4 +1,4 @@ | |||||
| # Copyright 2020 Huawei Technologies Co., Ltd.All Rights Reserved. | |||||
| # Copyright 2020-2021 Huawei Technologies Co., Ltd.All Rights Reserved. | |||||
| # | # | ||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | # Licensed under the Apache License, Version 2.0 (the "License"); | ||||
| # you may not use this file except in compliance with the License. | # you may not use this file except in compliance with the License. | ||||
| @@ -213,10 +213,11 @@ class ArgsTranslationHelper: | |||||
| Returns: | Returns: | ||||
| list, name of args to be formal. | list, name of args to be formal. | ||||
| """ | """ | ||||
| ret = list() | |||||
| if len(args_translators) < 2: | if len(args_translators) < 2: | ||||
| # only one args_translator provided, no formal args. | # only one args_translator provided, no formal args. | ||||
| return None | |||||
| ret = [] | |||||
| return ret | |||||
| base_args_t = args_translators[0] | base_args_t = args_translators[0] | ||||
| for arg_name, arg_val in base_args_t.actual_args.items(): | for arg_name, arg_val in base_args_t.actual_args.items(): | ||||
| for args_t in args_translators[1:]: | for args_t in args_translators[1:]: | ||||
| @@ -1,4 +1,4 @@ | |||||
| # Copyright 2020 Huawei Technologies Co., Ltd.All Rights Reserved. | |||||
| # Copyright 2020-2021 Huawei Technologies Co., Ltd.All Rights Reserved. | |||||
| # | # | ||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | # Licensed under the Apache License, Version 2.0 (the "License"); | ||||
| # you may not use this file except in compliance with the License. | # you may not use this file except in compliance with the License. | ||||
| @@ -24,7 +24,7 @@ from .module_struct import ModuleStruct | |||||
| from .args_translator import ArgsTranslationHelper | from .args_translator import ArgsTranslationHelper | ||||
| from ..common.global_context import GlobalContext | from ..common.global_context import GlobalContext | ||||
| from ...common.exceptions import GeneratorError | from ...common.exceptions import GeneratorError | ||||
| from ..hierarchical_tree.name_mgr import GlobalVarNameMgr | |||||
| from ..common.name_mgr import GlobalVarNameMgr | |||||
| from ..constant import NEW_LINE, SECOND_LEVEL_INDENT, FIRST_LEVEL_INDENT, CodeFormatConfig, get_imported_module | from ..constant import NEW_LINE, SECOND_LEVEL_INDENT, FIRST_LEVEL_INDENT, CodeFormatConfig, get_imported_module | ||||
| from ..report_generator import ReportGenerator | from ..report_generator import ReportGenerator | ||||
| @@ -1,4 +1,4 @@ | |||||
| # Copyright 2020 Huawei Technologies Co., Ltd.All Rights Reserved. | |||||
| # Copyright 2020-2021 Huawei Technologies Co., Ltd.All Rights Reserved. | |||||
| # | # | ||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | # Licensed under the Apache License, Version 2.0 (the "License"); | ||||
| # you may not use this file except in compliance with the License. | # you may not use this file except in compliance with the License. | ||||
| @@ -22,7 +22,7 @@ from ..common.utils import get_dict_key_by_value | |||||
| from .args_translator import ArgsTranslation | from .args_translator import ArgsTranslation | ||||
| from ..common.code_fragment import ModuleFragment | from ..common.code_fragment import ModuleFragment | ||||
| from ..common.global_context import GlobalContext | from ..common.global_context import GlobalContext | ||||
| from ..hierarchical_tree.name_mgr import LocalVarNameMgr | |||||
| from ..common.name_mgr import LocalVarNameMgr | |||||
| class ModuleStruct: | class ModuleStruct: | ||||
| @@ -1,4 +1,4 @@ | |||||
| # Copyright 2020 Huawei Technologies Co., Ltd.All Rights Reserved. | |||||
| # Copyright 2020-2021 Huawei Technologies Co., Ltd.All Rights Reserved. | |||||
| # | # | ||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | # Licensed under the Apache License, Version 2.0 (the "License"); | ||||
| # you may not use this file except in compliance with the License. | # you may not use this file except in compliance with the License. | ||||
| @@ -18,7 +18,6 @@ from collections import OrderedDict | |||||
| from .scope_utils import Scope | from .scope_utils import Scope | ||||
| from .args_translator import ArgsTranslation | from .args_translator import ArgsTranslation | ||||
| from ..common.code_fragment import CodeFragment | from ..common.code_fragment import CodeFragment | ||||
| from ..third_party_graph.pytorch_graph_node import PyTorchGraphNode | |||||
| from ..third_party_graph.onnx_graph_node import OnnxGraphNode | from ..third_party_graph.onnx_graph_node import OnnxGraphNode | ||||
| from ..common.global_context import GlobalContext | from ..common.global_context import GlobalContext | ||||
| from ..constant import InputType | from ..constant import InputType | ||||
| @@ -110,11 +109,6 @@ class NodeStruct: | |||||
| self.graph_node_ref = gn | self.graph_node_ref = gn | ||||
| self.scope_name = gn.scope_name | self.scope_name = gn.scope_name | ||||
| def _update_from_pytorch_gn(self, gn: PyTorchGraphNode): | |||||
| """Update basic info from PyTorchGraphNode.""" | |||||
| self.node_type = "PyTorchGraphNode" | |||||
| self._update_basics_from_gn(gn) | |||||
| def _update_from_onnx_gn(self, gn: OnnxGraphNode): | def _update_from_onnx_gn(self, gn: OnnxGraphNode): | ||||
| """Update basic info from OnnxGraphNode.""" | """Update basic info from OnnxGraphNode.""" | ||||
| self.node_type = "OnnxGraphNode" | self.node_type = "OnnxGraphNode" | ||||
| @@ -177,9 +171,8 @@ class NodeStruct: | |||||
| arg (Union[PyTorchGraphNode, OnnxGraphNode, dict]): Node related obj. | arg (Union[PyTorchGraphNode, OnnxGraphNode, dict]): Node related obj. | ||||
| force_ready (bool): Force this NodeStruct is ready to generate. | force_ready (bool): Force this NodeStruct is ready to generate. | ||||
| """ | """ | ||||
| if isinstance(arg, PyTorchGraphNode): | |||||
| self._update_from_pytorch_gn(arg) | |||||
| elif isinstance(arg, OnnxGraphNode): | |||||
| if isinstance(arg, OnnxGraphNode): | |||||
| self._update_from_onnx_gn(arg) | self._update_from_onnx_gn(arg) | ||||
| elif isinstance(arg, (dict, OrderedDict)): | elif isinstance(arg, (dict, OrderedDict)): | ||||
| self._update_from_mapper(arg) | self._update_from_mapper(arg) | ||||
| @@ -246,7 +239,6 @@ class NodeStruct: | |||||
| """Return the output variable name of current node.""" | """Return the output variable name of current node.""" | ||||
| return "{}_opt".format(self.ms_var_name).lower() | return "{}_opt".format(self.ms_var_name).lower() | ||||
| @property | @property | ||||
| def args_translator(self): | def args_translator(self): | ||||
| """Return the args translator of this Node.""" | """Return the args translator of this Node.""" | ||||
| @@ -1,89 +0,0 @@ | |||||
| # Copyright 2020 Huawei Technologies Co., Ltd.All Rights Reserved. | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================== | |||||
| """Hierarchical tree module.""" | |||||
| __all__ = ["HierarchicalTreeFactory"] | |||||
| import re | |||||
| from mindinsight.mindconverter.common.log import logger as log | |||||
| from .hierarchical_tree import HierarchicalTree | |||||
| from ..third_party_graph.onnx_graph_node import OnnxGraphNode | |||||
| from ...common.exceptions import NodeInputMissingError, TreeNodeInsertError | |||||
| def _tf_model_node_name_reformat(node: OnnxGraphNode, node_name): | |||||
| """ | |||||
| Rename the node name by combining scope name and its original name. | |||||
| Args: | |||||
| node (OnnxGraphNode): OnnxGraphNode instance. | |||||
| node_name (str): node name saved in Graph. | |||||
| Returns: | |||||
| str, re-formatted node name. | |||||
| """ | |||||
| scope_name = node.scope_name | |||||
| new_name = None | |||||
| regex = r"(?P<parent>.+/)(?P<op>\w+)" | |||||
| match = re.match(regex, scope_name) | |||||
| parent = match.group("parent") | |||||
| node_name = '$' + node_name.replace('/', '::') + '$' | |||||
| if scope_name: | |||||
| new_name = parent + node_name | |||||
| if new_name: | |||||
| return new_name | |||||
| return node_name | |||||
| class HierarchicalTreeFactory: | |||||
| """Hierarchical tree factory.""" | |||||
| @classmethod | |||||
| @TreeNodeInsertError.check_except("Tree node inserts failed.") | |||||
| def create(cls, graph): | |||||
| """ | |||||
| Factory method of hierarchical tree. | |||||
| Args: | |||||
| graph: Graph obj. | |||||
| Returns: | |||||
| HierarchicalTree, tree. | |||||
| """ | |||||
| tree = HierarchicalTree() | |||||
| node_scope_name = dict() | |||||
| for _, node_name in enumerate(graph.nodes_in_topological_order): | |||||
| node_inst = graph.get_node(node_name) | |||||
| node_input = graph.get_input_shape(node_name) | |||||
| node_output = graph.get_output_shape(node_name) | |||||
| if node_input != 0 and not node_input: | |||||
| err_msg = f"This model is not supported now. " \ | |||||
| f"Cannot find {node_name}'s input shape." | |||||
| error = NodeInputMissingError(err_msg) | |||||
| log.error(str(error)) | |||||
| raise error | |||||
| if isinstance(node_inst, OnnxGraphNode): | |||||
| node_name_with_scope = _tf_model_node_name_reformat(node_inst, node_name) | |||||
| node_scope_name[node_name] = node_name_with_scope | |||||
| node_name = node_name_with_scope | |||||
| node_inst.add_input_and_output_shape(node_input, node_output) | |||||
| tree.insert(node_inst, node_name) | |||||
| if node_scope_name: | |||||
| return tree, node_scope_name | |||||
| return tree | |||||
| @@ -1,796 +0,0 @@ | |||||
| # Copyright 2020 Huawei Technologies Co., Ltd.All Rights Reserved. | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================== | |||||
| """Define hierarchical tree.""" | |||||
| import os | |||||
| import stat | |||||
| from copy import deepcopy | |||||
| from typing import NoReturn, Union | |||||
| from queue import Queue | |||||
| from yapf.yapflib.yapf_api import FormatCode | |||||
| from treelib import Tree, Node | |||||
| from mindinsight.mindconverter.common.log import logger as log | |||||
| from .name_mgr import ModuleNameMgr, GlobalVarNameMgr | |||||
| from ..common.utils import is_converted, save_code_file_and_report | |||||
| from ..mapper.base import Mapper | |||||
| from ..third_party_graph.pytorch_graph_node import PyTorchGraphNode | |||||
| from ..third_party_graph.onnx_graph_node import OnnxGraphNode | |||||
| from ..constant import SEPARATOR_IN_SCOPE, get_imported_module, NO_CONVERTED_OPERATORS | |||||
| from ..constant import CodeFormatConfig | |||||
| from ..constant import SEPARATOR_BTW_NAME_AND_ID, FIRST_LEVEL_INDENT | |||||
| from ..constant import NEW_LINE, SECOND_LEVEL_INDENT | |||||
| from ..constant import NodeType | |||||
| from ..report_generator import ReportGenerator | |||||
| from ...common.exceptions import ReportGenerationError, ScriptGenerationError, NodeInputTypeNotSupportError | |||||
| class HierarchicalTree(Tree): | |||||
| """Define hierarchical tree.""" | |||||
| flags = os.O_WRONLY | os.O_CREAT | os.O_EXCL | |||||
| modes = stat.S_IRUSR | stat.S_IWUSR | |||||
| modes_usr = stat.S_IRUSR | stat.S_IWUSR | stat.S_IXUSR | |||||
| _root_created = False | |||||
| ROOT_LEVEL = 0 | |||||
| GLOBAL_VAR_NAME_MGR = GlobalVarNameMgr() | |||||
| def __init__(self): | |||||
| super(HierarchicalTree, self).__init__() | |||||
| self._hierarchical_order = dict() | |||||
| # Manage mapping of unique key and module name. | |||||
| self._merged_module = dict() | |||||
| # Manage mapping of unique key and module args. | |||||
| self._merged_module_args = dict() | |||||
| # Record creation of module with unique key. | |||||
| self._created_module = dict() | |||||
| # Manage module name to used. | |||||
| self._module_mgr = ModuleNameMgr() | |||||
| # Manage variable name in a module. | |||||
| self._vars_mgr_in_module = dict() | |||||
| self._module_vars = dict() | |||||
| # scope name mapping record for easy node searching | |||||
| self._scope_name_map = dict() | |||||
| self.code_fragment_recorder = dict() | |||||
| @property | |||||
| def tree_identifier(self): | |||||
| """ | |||||
| Return identifier of tree. | |||||
| Returns: | |||||
| tree, id of tree. | |||||
| """ | |||||
| return self.identifier | |||||
| def get_node(self, nid): | |||||
| """Override get_node method to support tf ver. generated scope.""" | |||||
| if nid is None or not self.contains(nid): | |||||
| if self._scope_name_map and nid in self._scope_name_map: | |||||
| nid = self._scope_name_map.get(nid) | |||||
| else: | |||||
| return None | |||||
| return self._nodes[nid] | |||||
| def insert(self, node: Union[PyTorchGraphNode, OnnxGraphNode], node_name: str): | |||||
| """ | |||||
| Insert node into hierarchical tree. | |||||
| Args: | |||||
| node_name (str): Node name. | |||||
| node (Union[PyTorchGraphNode, OnnxGraphNode]): Node to be inserted. | |||||
| """ | |||||
| scopes = node_name.split(SEPARATOR_IN_SCOPE) | |||||
| for idx, scope in enumerate(scopes): | |||||
| parent = SEPARATOR_IN_SCOPE.join(scopes[:idx]) | |||||
| identifier = SEPARATOR_IN_SCOPE.join(scopes[:idx + 1]) | |||||
| try_parent = f"{parent}{SEPARATOR_IN_SCOPE}{scope}" \ | |||||
| if parent else scope | |||||
| if self.contains(try_parent): | |||||
| # Whether current node existed. | |||||
| parent = try_parent | |||||
| if not parent and not self._root_created: | |||||
| # If no root node, then create it and mark it. | |||||
| parent = None | |||||
| self._root_created = True | |||||
| elif not parent and self._root_created: | |||||
| # Already have root node, skip it. | |||||
| continue | |||||
| if not self.contains(identifier): | |||||
| # Insert node into tree. | |||||
| if isinstance(node, OnnxGraphNode): | |||||
| tgt_node = node if idx == len( | |||||
| scopes) - 1 else OnnxGraphNode() | |||||
| else: | |||||
| tgt_node = node if idx == len( | |||||
| scopes) - 1 else PyTorchGraphNode() | |||||
| tgt_node.successor_nodes = node.successor_nodes | |||||
| tgt_node.precursor_nodes = node.precursor_nodes | |||||
| tgt_node.node_type = (NodeType.OPERATION if idx == len(scopes) - 1 | |||||
| else NodeType.MODULE).value | |||||
| tgt_node.variable_name = self._get_var_name(identifier) | |||||
| self.create_node( | |||||
| tag=scope.split(SEPARATOR_BTW_NAME_AND_ID)[0], | |||||
| identifier=identifier, | |||||
| parent=parent, | |||||
| data=tgt_node | |||||
| ) | |||||
| def remove(self, node: Node, keep_sub=False): | |||||
| """ | |||||
| Remove node into hierarchical tree. | |||||
| Args: | |||||
| node (Node): Node to be removed. | |||||
| keep_sub (bool): Whether keep sub-tree. | |||||
| """ | |||||
| if not keep_sub: | |||||
| self.remove_node(node.identifier) | |||||
| return | |||||
| def shrink(self, node: Node): | |||||
| """ | |||||
| Shrink sub-tree into one node. | |||||
| Use child node to replace its ancestor. | |||||
| Args: | |||||
| node (Node): List of nodes to be merged. | |||||
| """ | |||||
| node_name = node.identifier | |||||
| parent_node = self[node.predecessor(self.tree_identifier)] | |||||
| # Keep successors of parent. | |||||
| brothers = deepcopy(parent_node.successors(self.tree_identifier)) | |||||
| # Because shrink occurs when node has only one child, | |||||
| # so we take index-0. | |||||
| child = node.successors(self.tree_identifier)[0] | |||||
| self.move_node(source=child, | |||||
| destination=node.predecessor(self.tree_identifier)) | |||||
| self.remove(node) | |||||
| brothers[brothers.index(node_name)] = child | |||||
| parent_node.set_successors(brothers, tree_id=self.tree_identifier) | |||||
| def save_source_files(self, out_folder: str, mapper: Mapper, | |||||
| model_name: str, | |||||
| report_folder: str = None, | |||||
| scope_name_map: dict = None) -> NoReturn: | |||||
| """ | |||||
| Save source codes to target folder. | |||||
| Args: | |||||
| report_folder (str): Report folder. | |||||
| mapper (Mapper): Mapper of third party framework and mindspore. | |||||
| model_name(str): Name of Converted model. | |||||
| out_folder (str): Output folder. | |||||
| scope_name_map(str): Scope name map of tensorflow. | |||||
| """ | |||||
| if scope_name_map: | |||||
| self._scope_name_map = scope_name_map | |||||
| try: | |||||
| self._adjust_structure() | |||||
| code_fragments = self._generate_codes(mapper) | |||||
| except (NodeInputTypeNotSupportError, ScriptGenerationError, ReportGenerationError) as e: | |||||
| log.error("Error occur when generating codes.") | |||||
| raise e | |||||
| save_code_file_and_report(model_name, code_fragments, out_folder, report_folder) | |||||
| def _preprocess_node_args(self, node, module_key): | |||||
| """ | |||||
| Remove unused args. | |||||
| Args: | |||||
| node (Node): Node instance. | |||||
| module_key (str): Nodule key. | |||||
| Returns: | |||||
| Node, node. | |||||
| """ | |||||
| if module_key in self._merged_module_args: | |||||
| node = self._clear_unused_args( | |||||
| node, self._merged_module_args[module_key]) | |||||
| else: | |||||
| node.data.clear_args_of_declaration() | |||||
| return node | |||||
| def _postprocess_node_args(self, node, precursor_module_key): | |||||
| """ | |||||
| Post process args in node. | |||||
| Args: | |||||
| node (Node): Node instance. | |||||
| precursor_module_key (str): Parent node module name. | |||||
| Returns: | |||||
| Node, node. | |||||
| """ | |||||
| if node.data.node_type in {NodeType.MODULE.value, NodeType.CLASS.value, | |||||
| NodeType.FUNC.value}: | |||||
| # If current node is class or function, then | |||||
| # remove unused args in __init__. | |||||
| cur_module_key = node.data.hash_key or self.hash_key(node) | |||||
| if cur_module_key in self._merged_module_args: | |||||
| node = self._clear_unused_args(node, | |||||
| self._merged_module_args[cur_module_key]) | |||||
| # `self._merged_module_args` records formal args. | |||||
| # We need to replace actual args. | |||||
| if precursor_module_key in self._merged_module_args: | |||||
| # If parent node is in `_merged_module_args`, then | |||||
| # replace current node args with arg name declared | |||||
| # in _merged_module_args. | |||||
| for arg in node.data.args_in_code.keys(): | |||||
| if arg in self._merged_module_args[precursor_module_key]: | |||||
| node.data.replace_with_arg(arg, arg) | |||||
| return node | |||||
| def _clear_unused_args(self, node, used_args): | |||||
| """ | |||||
| Clear unused args. | |||||
| Args: | |||||
| node (Node): Node. | |||||
| used_args (list): Args list. | |||||
| Returns: | |||||
| Node, node instance. | |||||
| """ | |||||
| args_in_code = list(node.data.args_in_code.keys()) | |||||
| for arg in args_in_code: | |||||
| ori_arg = arg.replace( | |||||
| f"_{self.code_fragment_recorder[node.identifier].declared_var_name}", "" | |||||
| ) | |||||
| if ori_arg not in used_args: | |||||
| node.data.args_in_code.pop(arg) | |||||
| return node | |||||
| @ScriptGenerationError.check_except("FormatCode run error. Check detailed information in log.") | |||||
| @ReportGenerationError.check_except("Not find valid operators in converted script.") | |||||
| def _generate_codes(self, mapper): | |||||
| """ | |||||
| Generate code files. | |||||
| - 1. Generate args. | |||||
| - 2. Merge module. | |||||
| - 3. Pre-process node args. | |||||
| - 4. Post-process child node args. | |||||
| - 5. Generate class/func code. | |||||
| - 6. Merge code snippets. | |||||
| Args: | |||||
| mapper (Mapper): Mapper of third party operation and mindspore. | |||||
| Returns: | |||||
| Dict, codes. | |||||
| """ | |||||
| code_blocks = [get_imported_module()] | |||||
| depths = sorted(list(self._hierarchical_order.keys()), reverse=True) | |||||
| for depth in depths: | |||||
| node_collection = self._hierarchical_order[depth] | |||||
| for node_name in node_collection: | |||||
| # Traverse nodes in topological order. | |||||
| node = self.get_node(node_name) | |||||
| # 1. Generate args for each node in this level. | |||||
| if node.data.node_type == NodeType.MODULE.value: | |||||
| self._create_module_args_and_vars(node, mapper) | |||||
| if depth == depths[-1]: | |||||
| self.code_fragment_recorder[node.identifier] = node.data.param_transform(mapper, "") | |||||
| # Module merging based on all nodes. | |||||
| self._module_merging() | |||||
| for depth in depths: | |||||
| node_collection = self._hierarchical_order[depth] | |||||
| snippets = set() | |||||
| for node_name in node_collection: | |||||
| nd_inst = self.get_node(node_name) | |||||
| if nd_inst.data.node_type != NodeType.MODULE.value: | |||||
| continue | |||||
| # Generate hash key for node. | |||||
| module_key = nd_inst.data.hash_key | |||||
| # Get code generation func. | |||||
| func, node_type = self._fetch_func_and_type(nd_inst) | |||||
| if module_key in self._created_module: | |||||
| # If the module has already been created, | |||||
| # then assign the created module name to current node, | |||||
| # and delete unused args. | |||||
| module_name = self._created_module[module_key] | |||||
| self.code_fragment_recorder[nd_inst.identifier].operation = module_name | |||||
| self.code_fragment_recorder[nd_inst.identifier].node_type = node_type | |||||
| self._preprocess_node_args(nd_inst, module_key) | |||||
| continue | |||||
| module_name = nd_inst.tag | |||||
| if node_type == NodeType.CLASS.value: | |||||
| module_name = f"{module_name[0].upper()}{module_name[1:]}" | |||||
| # After node_type and module_name is frozen, | |||||
| # then it's unchangeable. | |||||
| module_name = self._module_mgr.get_name(module_name) | |||||
| self.code_fragment_recorder[nd_inst.identifier].operation = module_name | |||||
| self.code_fragment_recorder[nd_inst.identifier].node_type = node_type | |||||
| # 3. Pre-process node args. | |||||
| nd_inst = self._preprocess_node_args(nd_inst, module_key) | |||||
| # 4. Post-process child node args. | |||||
| for _, scsr_nd_name in enumerate(nd_inst.successors(self.tree_identifier)): | |||||
| self._postprocess_node_args(self.get_node(scsr_nd_name), module_key) | |||||
| # 5. Generate code. | |||||
| snippets.add(func(nd_inst, self.code_fragment_recorder[nd_inst.identifier].operation, module_key)) | |||||
| code_blocks.extend(snippets) | |||||
| if self._scope_name_map: # from tf. conversion | |||||
| c_blocks = [] | |||||
| for s in code_blocks: | |||||
| s = s.replace('$', '') | |||||
| c_blocks.append(s) | |||||
| code_blocks = c_blocks | |||||
| formatted_code, _ = FormatCode("".join(code_blocks), | |||||
| style_config=CodeFormatConfig.PEP8.value) | |||||
| report_generator = ReportGenerator() | |||||
| report = report_generator.gen_report(formatted_code) | |||||
| return {"model": (formatted_code, report)} | |||||
| def _fetch_func_and_type(self, node) -> Union[object, str]: | |||||
| """ | |||||
| Generate code snippet. | |||||
| Args: | |||||
| node (Node): Node. | |||||
| Returns: | |||||
| Union[object, str], code snippet func. | |||||
| """ | |||||
| def _is_func(): | |||||
| """ | |||||
| The correct thought is to check whether have more than one | |||||
| path in this block. | |||||
| """ | |||||
| nonlocal node | |||||
| if node.predecessor(self.tree_identifier) is None: | |||||
| return False | |||||
| tgt_type = {NodeType.MODULE.value, | |||||
| NodeType.FUNC.value, NodeType.CLASS.value} | |||||
| md_type_lst = [self.get_node(child).data.node_type | |||||
| for child in node.successors(self.tree_identifier)] | |||||
| diff_set = set(md_type_lst) - tgt_type | |||||
| return not diff_set | |||||
| if _is_func(): | |||||
| return self._generate_func_snippet, NodeType.FUNC.value | |||||
| return self._generate_class_snippet, NodeType.CLASS.value | |||||
| def _generate_func_snippet(self, node, func_name, func_key): | |||||
| """ | |||||
| Generate function snippet. | |||||
| Args: | |||||
| node (Node): Node inst. | |||||
| Returns: | |||||
| str, code snippet. | |||||
| """ | |||||
| definition = "" | |||||
| if func_key.lower() in self._merged_module_args and \ | |||||
| self._merged_module_args[func_key.lower()]: | |||||
| definition = ", ".join(self._merged_module_args[func_key.lower()]) | |||||
| module_list = [] | |||||
| for node_name in node.successors(self.tree_identifier): | |||||
| c_nd = self.get_node(node_name) | |||||
| operator = self.code_fragment_recorder[c_nd.identifier].operation | |||||
| if c_nd.data.node_type != NodeType.OPERATION.value: | |||||
| hash_key = c_nd.data.hash_key or self.hash_key(c_nd) | |||||
| if hash_key in self._created_module: | |||||
| operator = self._created_module[hash_key] | |||||
| args = c_nd.data.args_in_code | |||||
| if c_nd.data.node_type == NodeType.OPERATION.value and not is_converted( | |||||
| self.code_fragment_recorder[c_nd.identifier].operation): | |||||
| args.update({"input_shape": c_nd.data.input_shape, | |||||
| "output_shape": c_nd.data.output_shape}) | |||||
| # Generate code statement. | |||||
| expr = ", ".join( | |||||
| [f"{k.replace(f'_{self.code_fragment_recorder[c_nd.identifier].declared_var_name}', '')}={v}" | |||||
| for k, v in args.items()] | |||||
| ) | |||||
| code_line = f"{operator}({expr})" | |||||
| module_list.append(code_line) | |||||
| body = f",{NEW_LINE}{SECOND_LEVEL_INDENT}".join(module_list) | |||||
| snippet = f"{FIRST_LEVEL_INDENT}module_list = [{NEW_LINE}" \ | |||||
| f"{SECOND_LEVEL_INDENT}{body}{NEW_LINE}" \ | |||||
| f"{FIRST_LEVEL_INDENT}]{NEW_LINE}" \ | |||||
| f"{FIRST_LEVEL_INDENT}return nn.SequentialCell(*module_list)" | |||||
| definition = f"def {func_name}({definition}):{NEW_LINE}" | |||||
| # Mark the structure has been created. | |||||
| self._created_module[func_key.lower()] = func_name | |||||
| return f"{definition}{snippet}{NEW_LINE * 3}" | |||||
| def _generate_class_snippet(self, node, class_name, class_key): | |||||
| """ | |||||
| Generate class-type code snippet. | |||||
| Args: | |||||
| node (Node): Node. | |||||
| Returns: | |||||
| str, code snippet. | |||||
| """ | |||||
| super_call = f"super({class_name}, self).__init__()" | |||||
| if class_key.lower() in self._merged_module_args and \ | |||||
| self._merged_module_args[class_key.lower()]: | |||||
| args = f"{', '.join(self._merged_module_args[class_key.lower()])}" | |||||
| class_init = f"{FIRST_LEVEL_INDENT}def __init__(self, " \ | |||||
| f"{args}):" \ | |||||
| f"{NEW_LINE}{SECOND_LEVEL_INDENT}" \ | |||||
| f"{super_call}{NEW_LINE}{SECOND_LEVEL_INDENT}" | |||||
| else: | |||||
| class_init = f"{FIRST_LEVEL_INDENT}def __init__(self):{NEW_LINE}{SECOND_LEVEL_INDENT}" \ | |||||
| f"{super_call}{NEW_LINE}{SECOND_LEVEL_INDENT}" | |||||
| init_block = [] | |||||
| construct_block = [] | |||||
| for idx, node_name in enumerate(node.successors(self.tree_identifier)): | |||||
| nd_inst = self.get_node(node_name) | |||||
| if nd_inst.data.op_name in NO_CONVERTED_OPERATORS: | |||||
| continue | |||||
| # Generate code statement. | |||||
| init, construct = self._generate_stat(nd_inst, node, idx) | |||||
| # support multiple construct and init block returns: | |||||
| if isinstance(construct, list): | |||||
| construct_block += construct | |||||
| else: | |||||
| construct_block.append(construct) | |||||
| if isinstance(init, list): | |||||
| init_block += init | |||||
| else: | |||||
| init_block.append(init) | |||||
| class_construct = f"{NEW_LINE}{FIRST_LEVEL_INDENT}def construct(self, x):" \ | |||||
| f"{NEW_LINE}{SECOND_LEVEL_INDENT}" | |||||
| init_body = f"{NEW_LINE}{SECOND_LEVEL_INDENT}".join(init_block) | |||||
| csrt_body = f"{NEW_LINE}{SECOND_LEVEL_INDENT}".join(construct_block) | |||||
| csrt_rtn = f"{NEW_LINE}{SECOND_LEVEL_INDENT}return output{NEW_LINE}" | |||||
| cls_definition = f"class {class_name}(nn.Cell):{NEW_LINE * 2}" | |||||
| # Mark the structure has been created. | |||||
| self._created_module[class_key.lower()] = class_name | |||||
| return f"{cls_definition}" \ | |||||
| f"{class_init}" \ | |||||
| f"{init_body}{NEW_LINE}" \ | |||||
| f"{class_construct}" \ | |||||
| f"{csrt_body}{csrt_rtn}{NEW_LINE * 2}" | |||||
| def _generate_stat(self, cur_nd_inst, pre_nd_inst, idx): | |||||
| """ | |||||
| Generate statements. | |||||
| Args: | |||||
| cur_nd_inst (Node): Current node instance. | |||||
| pre_nd_inst (Node): Precursor node instance. | |||||
| idx (int): Index of cur node. | |||||
| Returns: | |||||
| Tuple[str, str], declare in init and call in construct. | |||||
| """ | |||||
| ipt_args_in_construct = "x" | |||||
| opt_arg_in_construct = ["output"] | |||||
| if idx != 0: | |||||
| if cur_nd_inst.data.is_in_multi_opt_graph: | |||||
| ipt_args_in_construct = self._get_current_ipt_var(cur_nd_inst) | |||||
| else: | |||||
| # Get previous node output variable name. | |||||
| ipt_args_in_construct = self._get_previous_opt_var(cur_nd_inst, pre_nd_inst) | |||||
| if idx != len(pre_nd_inst.successors(self.tree_identifier)) - 1: | |||||
| # Set opt variable name. | |||||
| if cur_nd_inst.data.node_type == NodeType.MODULE.value or not cur_nd_inst.data.is_in_multi_opt_graph: | |||||
| opt_arg_in_construct = [ | |||||
| f"{self.code_fragment_recorder[cur_nd_inst.identifier].declared_var_name}_opt" | |||||
| ] | |||||
| else: | |||||
| opt_arg_in_construct = [ | |||||
| f"opt_{var_name}" | |||||
| for var_name in self.code_fragment_recorder[cur_nd_inst.identifier].output_var_name | |||||
| ] | |||||
| declare, call = cur_nd_inst.data.to_code(ipt_args_in_construct=ipt_args_in_construct, | |||||
| variable_name=self.code_fragment_recorder[ | |||||
| cur_nd_inst.identifier].declared_var_name, | |||||
| output_var=opt_arg_in_construct, | |||||
| code_fragment=self.code_fragment_recorder[cur_nd_inst.identifier]) | |||||
| return declare, call | |||||
| @staticmethod | |||||
| def _get_var_name(s): | |||||
| """ | |||||
| Get variable name using scope name. | |||||
| Args: | |||||
| s (str): String. | |||||
| Returns: | |||||
| str, variable name. | |||||
| """ | |||||
| return s.split(SEPARATOR_IN_SCOPE)[-1].lower().split(SEPARATOR_BTW_NAME_AND_ID)[0] | |||||
| def _get_current_ipt_var(self, cur_nd): | |||||
| """" | |||||
| Get current input variable name from node_id. | |||||
| Args: | |||||
| cur_nd (Node): Current node. | |||||
| Returns: | |||||
| str, needed var names. | |||||
| """ | |||||
| if cur_nd.data.node_type != NodeType.OPERATION.value: | |||||
| while True: | |||||
| p_nd = cur_nd.successors(self.tree_identifier) | |||||
| if not p_nd: | |||||
| break | |||||
| cur_nd = self.get_node(p_nd[0]) | |||||
| ipt_lst_raw = [] | |||||
| for operation_input in self.code_fragment_recorder[cur_nd.identifier].operation_inputs: | |||||
| ipt_lst_raw.append(f"{operation_input}") | |||||
| opt_var_names_p_nds = set() | |||||
| for e in cur_nd.data.precursor_nodes: | |||||
| p_nd = self.get_node(e) | |||||
| if p_nd.data.op_name in NO_CONVERTED_OPERATORS: | |||||
| continue | |||||
| opt_var_names_p_nd = set(p_nd.data.opt_var_names) | |||||
| opt_var_names_p_nds = set.union(opt_var_names_p_nds, opt_var_names_p_nd) | |||||
| ipt_lst = [f"opt_{ipt}" for ipt in set(ipt_lst_raw).intersection(opt_var_names_p_nds)] | |||||
| return ", ".join(ipt_lst) | |||||
| def _find_all_previous_opt_var_(self, cur_nd, pre_nd): | |||||
| """ | |||||
| Find all input variable names. | |||||
| Args: | |||||
| cur_nd (Node): Current node. | |||||
| pre_nd (Node): Precursor node. | |||||
| Returns: | |||||
| list, needed var names list. | |||||
| """ | |||||
| ipt_lst = [] | |||||
| if cur_nd.tag in NO_CONVERTED_OPERATORS: | |||||
| return ipt_lst | |||||
| for e in cur_nd.data.precursor_nodes: | |||||
| p_nd = self.get_node(e) | |||||
| if e not in pre_nd.successors(self.tree_identifier): | |||||
| while True: | |||||
| if p_nd.identifier in pre_nd.successors(self.tree_identifier): | |||||
| ipt_lst.append( | |||||
| f"{self.code_fragment_recorder[p_nd.identifier].declared_var_name}_opt" | |||||
| ) | |||||
| break | |||||
| pre_nd_name = p_nd.predecessor(self.tree_identifier) | |||||
| if not pre_nd_name: | |||||
| ipt_lst.append("x") | |||||
| break | |||||
| p_nd = self.get_node(pre_nd_name) | |||||
| continue | |||||
| ipt_lst.append( | |||||
| f"{self.code_fragment_recorder[p_nd.identifier].declared_var_name}_opt" | |||||
| ) | |||||
| return ipt_lst | |||||
| def _get_previous_opt_var(self, cur_nd, pre_nd): | |||||
| """ | |||||
| Get needed input variable names. | |||||
| Args: | |||||
| cur_nd (Node): Current node. | |||||
| pre_nd (Node): Precursor node. | |||||
| Returns: | |||||
| str, needed var names. | |||||
| """ | |||||
| if cur_nd.data.node_type != NodeType.OPERATION.value: | |||||
| while True: | |||||
| p_nd = cur_nd.successors(self.tree_identifier) | |||||
| if not p_nd: | |||||
| break | |||||
| cur_nd = self.get_node(p_nd[0]) | |||||
| return ", ".join(self._find_all_previous_opt_var_(cur_nd, pre_nd)) | |||||
| def hash_key(self, node): | |||||
| """ | |||||
| Generate hash key for each node. | |||||
| Args: | |||||
| node (Node): Node. | |||||
| Returns: | |||||
| str, hash key. | |||||
| """ | |||||
| scsr_topo_order = [] | |||||
| for s in node.successors(self.tree_identifier): | |||||
| cur_nd = self.get_node(s) | |||||
| if cur_nd.data.node_type in {NodeType.MODULE.value, | |||||
| NodeType.FUNC.value, | |||||
| NodeType.CLASS.value}: | |||||
| if cur_nd.data.hash_key: | |||||
| scsr_topo_order.append(f"({cur_nd.data.hash_key})") | |||||
| continue | |||||
| raise ValueError("Current node doesn't have hash key.") | |||||
| if cur_nd.data.hash_key: | |||||
| scsr_topo_order.append(cur_nd.data.hash_key) | |||||
| continue | |||||
| unique_key = "->".join(scsr_topo_order) | |||||
| node.data.hash_key = unique_key | |||||
| return unique_key | |||||
| def _module_merging(self): | |||||
| """Generate sub-module and corresponding params.""" | |||||
| merged_module_args = dict() | |||||
| for module_key, module_args in self._merged_module.items(): | |||||
| if module_key not in merged_module_args: | |||||
| merged_module_args[module_key] = [] | |||||
| # Take first element's args as base. | |||||
| keys = module_args[0].keys() | |||||
| for key in keys: | |||||
| for i in range(1, len(module_args)): | |||||
| if key in module_args[i] and module_args[0][key] != module_args[i][key]: | |||||
| merged_module_args[module_key].append(key) | |||||
| break | |||||
| if key not in module_args[i]: | |||||
| merged_module_args[module_key].append(key) | |||||
| break | |||||
| self._merged_module_args.update(merged_module_args) | |||||
| def _create_module_args_and_vars(self, node, mapper): | |||||
| """ | |||||
| Create module args and variables in current node. | |||||
| Args: | |||||
| node (Node): Node on tree. | |||||
| mapper (Mapper): Mapper of params. | |||||
| """ | |||||
| # All args and value pair in current node module. | |||||
| module_args = dict() | |||||
| module_key = self.hash_key(node) | |||||
| created = False | |||||
| if module_key not in self._vars_mgr_in_module: | |||||
| self._vars_mgr_in_module[module_key] = self.GLOBAL_VAR_NAME_MGR | |||||
| self._module_vars[module_key] = [] | |||||
| else: | |||||
| created = True | |||||
| # Sub-modules in the module could have arg name conflicts. | |||||
| for idx, successor_name in enumerate(node.successors(self.tree_identifier)): | |||||
| nd_inst = self.get_node(successor_name) | |||||
| if nd_inst.data.op_name in NO_CONVERTED_OPERATORS: | |||||
| continue | |||||
| # Generation of params must behind variable assigment. | |||||
| if created: | |||||
| variable_name = self._module_vars[module_key][idx] | |||||
| else: | |||||
| variable_name = nd_inst.data.op_name or nd_inst.tag | |||||
| variable_name = self._vars_mgr_in_module[module_key].get_name(variable_name) | |||||
| code_fragment = nd_inst.data.param_transform(mapper, variable_name) | |||||
| code_fragment.declared_var_name = variable_name | |||||
| code_fragment.output_var_name = nd_inst.data.opt_var_names | |||||
| code_fragment.update_operation_inputs(nd_inst.data.ipt_var_names) | |||||
| self.code_fragment_recorder[nd_inst.identifier] = code_fragment | |||||
| module_args.update(nd_inst.data.args_in_code) | |||||
| if not created: | |||||
| self._module_vars[module_key].append(variable_name) | |||||
| node.data.args_in_code = module_args | |||||
| # Collect module args of `module_key`. | |||||
| if module_key not in self._merged_module: | |||||
| self._merged_module[module_key] = [deepcopy(node.data.args_in_code)] | |||||
| else: | |||||
| self._merged_module[module_key].append(deepcopy(node.data.args_in_code)) | |||||
| @staticmethod | |||||
| def _create_operation_args(node, mapper): | |||||
| """ | |||||
| Create operation args. | |||||
| Args: | |||||
| node (Node): Node on tree. | |||||
| mapper (Mapper): Mapper of params. | |||||
| """ | |||||
| node.data.param_transform(mapper) | |||||
| def update_hierarchical_order(self) -> NoReturn: | |||||
| """ | |||||
| Update hierarchical order. | |||||
| """ | |||||
| hierarchical_order = dict() | |||||
| queue = Queue() | |||||
| queue.put(item=(self.root, self.ROOT_LEVEL), block=False) | |||||
| while not queue.empty(): | |||||
| node_name, cur_level = queue.get(block=False) | |||||
| node_inst = self[node_name] | |||||
| if cur_level not in hierarchical_order: | |||||
| hierarchical_order[cur_level] = [] | |||||
| hierarchical_order[cur_level].append(node_name) | |||||
| for successor_name in node_inst.successors(self.tree_identifier): | |||||
| queue.put(item=(successor_name, cur_level + 1), block=False) | |||||
| self._hierarchical_order = hierarchical_order | |||||
| def sub_graph_merging(self) -> NoReturn: | |||||
| """Shrink the module has only one child.""" | |||||
| self.update_hierarchical_order() | |||||
| depths = sorted(list(self._hierarchical_order.keys()), reverse=True) | |||||
| for depth in depths: | |||||
| for node_name in self._hierarchical_order[depth]: | |||||
| node_inst = self[node_name] | |||||
| # If the node type is module and has only one child, | |||||
| # then merge it with its child. | |||||
| if node_inst.data.node_type == NodeType.MODULE.value and \ | |||||
| len(node_inst.successors(self.tree_identifier)) == 1: | |||||
| self.shrink(node_inst) | |||||
| def _adjust_structure(self) -> NoReturn: | |||||
| """Adjust tree structure to generate source code.""" | |||||
| self.sub_graph_merging() | |||||
| self.update_hierarchical_order() | |||||
| @@ -171,3 +171,13 @@ class ONNXToMindSporeMapper(Mapper, abc.ABC): | |||||
| outputs_list = [f"opt_{{{variable_slot}}}"] | outputs_list = [f"opt_{{{variable_slot}}}"] | ||||
| outputs_mapping = ((0, 0),) | outputs_mapping = ((0, 0),) | ||||
| return template, exchange_msg, outputs_list, outputs_mapping | return template, exchange_msg, outputs_list, outputs_mapping | ||||
| @staticmethod | |||||
| def _find_val_by_index(loc_index, values_dict): | |||||
| """Find value by location index of values_dict.""" | |||||
| result = None | |||||
| for idx, dict_val in enumerate(values_dict.values()): | |||||
| if idx == loc_index: | |||||
| result = dict_val | |||||
| break | |||||
| return result | |||||
| @@ -42,7 +42,7 @@ class ConvMapper(ONNXToMindSporeMapper): | |||||
| """Convert params from PyTorch to MindSpore""" | """Convert params from PyTorch to MindSpore""" | ||||
| weights = kwargs['weights'] | weights = kwargs['weights'] | ||||
| params = kwargs['params'] | params = kwargs['params'] | ||||
| weight = weights['weight'].numpy() | |||||
| weight = weights['weight'] | |||||
| weight = np.transpose(weight, list(range(2, weight.ndim)) + [1, 0]) | weight = np.transpose(weight, list(range(2, weight.ndim)) + [1, 0]) | ||||
| if isinstance(params['dilations'], list): | if isinstance(params['dilations'], list): | ||||
| dilation = tuple(params['dilations']) | dilation = tuple(params['dilations']) | ||||
| @@ -130,7 +130,6 @@ class ConvMapper(ONNXToMindSporeMapper): | |||||
| dim = len(kernel_size) | dim = len(kernel_size) | ||||
| return f"nn.Conv{dim}d" | return f"nn.Conv{dim}d" | ||||
| weight = weight.numpy() | |||||
| dim = weight.ndim - 2 | dim = weight.ndim - 2 | ||||
| return f"nn.Conv{dim}d" | return f"nn.Conv{dim}d" | ||||
| @@ -13,6 +13,7 @@ | |||||
| # limitations under the License. | # limitations under the License. | ||||
| # ============================================================================== | # ============================================================================== | ||||
| """Mapper module.""" | """Mapper module.""" | ||||
| import numpy as np | |||||
| from mindinsight.mindconverter.graph_based_converter.mapper.base import ONNXToMindSporeMapper | from mindinsight.mindconverter.graph_based_converter.mapper.base import ONNXToMindSporeMapper | ||||
| from mindinsight.mindconverter.graph_based_converter.mapper.gen_setting import Setting | from mindinsight.mindconverter.graph_based_converter.mapper.gen_setting import Setting | ||||
| @@ -27,8 +28,11 @@ class DenseMapper(ONNXToMindSporeMapper): | |||||
| @staticmethod | @staticmethod | ||||
| def _convert_params(**kwargs): | def _convert_params(**kwargs): | ||||
| weights = kwargs['weights'] | weights = kwargs['weights'] | ||||
| has_bias = bool('bias' in weights) | |||||
| weight = weights['weight'].numpy().transpose() | |||||
| weight_index = 0 | |||||
| bias_index = 1 | |||||
| bias = DenseMapper._find_val_by_index(bias_index, weights) | |||||
| has_bias = isinstance(bias, np.ndarray) | |||||
| weight = DenseMapper._find_val_by_index(weight_index, weights).transpose() | |||||
| in_channels, out_channels = weight.shape | in_channels, out_channels = weight.shape | ||||
| return { | return { | ||||
| 'in_channels': in_channels, | 'in_channels': in_channels, | ||||
| @@ -13,6 +13,7 @@ | |||||
| # limitations under the License. | # limitations under the License. | ||||
| # ============================================================================== | # ============================================================================== | ||||
| """Mapper module.""" | """Mapper module.""" | ||||
| from mindinsight.mindconverter.graph_based_converter.common.utils import convert_bytes_string_to_string | |||||
| from mindinsight.mindconverter.graph_based_converter.mapper.base import ONNXToMindSporeMapper | from mindinsight.mindconverter.graph_based_converter.mapper.base import ONNXToMindSporeMapper | ||||
| from mindinsight.mindconverter.graph_based_converter.mapper.gen_setting import Setting | from mindinsight.mindconverter.graph_based_converter.mapper.gen_setting import Setting | ||||
| @@ -47,7 +48,7 @@ class PadMapper(ONNXToMindSporeMapper): | |||||
| def _convert_params(**kwargs): | def _convert_params(**kwargs): | ||||
| weights = kwargs.get("weights") | weights = kwargs.get("weights") | ||||
| params = kwargs.get("params") | params = kwargs.get("params") | ||||
| mode = params.get('mode', 'constant') | |||||
| mode = convert_bytes_string_to_string(params.get('mode', 'constant')) | |||||
| pads_onnx = params.get("pads") if params.get("pads") else list(weights.values())[0].tolist() | pads_onnx = params.get("pads") if params.get("pads") else list(weights.values())[0].tolist() | ||||
| if mode == 'constant' and params.get('value') is None: | if mode == 'constant' and params.get('value') is None: | ||||
| if params.get('pads') or weights: | if params.get('pads') or weights: | ||||
| @@ -36,7 +36,7 @@ class PoolMapper(ONNXToMindSporeMapper): | |||||
| transformed_params["kernel_size"] = tuple(params['kernel_shape']) | transformed_params["kernel_size"] = tuple(params['kernel_shape']) | ||||
| transformed_params["stride"] = tuple(params['strides']) | transformed_params["stride"] = tuple(params['strides']) | ||||
| if "pads" in params: | if "pads" in params: | ||||
| if sum(params['pads']) == 0: | |||||
| if sum(params['pads']) == 0 and not params.get('ceil_mode', None): | |||||
| pad_mode = '\"valid\"' | pad_mode = '\"valid\"' | ||||
| else: | else: | ||||
| pad_mode = '\"same\"' | pad_mode = '\"same\"' | ||||
| @@ -1,4 +1,4 @@ | |||||
| # Copyright 2020 Huawei Technologies Co., Ltd.All Rights Reserved. | |||||
| # Copyright 2020-2021 Huawei Technologies Co., Ltd.All Rights Reserved. | |||||
| # | # | ||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | # Licensed under the Apache License, Version 2.0 (the "License"); | ||||
| # you may not use this file except in compliance with the License. | # you may not use this file except in compliance with the License. | ||||
| @@ -39,14 +39,9 @@ class GraphFactory: | |||||
| Returns: | Returns: | ||||
| Graph, graph instance. | Graph, graph instance. | ||||
| """ | """ | ||||
| if all([input_nodes, output_nodes]): | |||||
| onnx_graph_module = import_module( | |||||
| 'mindinsight.mindconverter.graph_based_converter.third_party_graph.onnx_graph') | |||||
| onnx_graph = getattr(onnx_graph_module, 'OnnxGraph') | |||||
| return onnx_graph.load(model_path=graph_path, input_nodes=input_nodes, | |||||
| output_nodes=output_nodes, sample_shape=sample_shape) | |||||
| pytorch_graph_module = import_module( | |||||
| 'mindinsight.mindconverter.graph_based_converter.third_party_graph.pytorch_graph') | |||||
| pytorch_graph = getattr(pytorch_graph_module, 'PyTorchGraph') | |||||
| return pytorch_graph.load(model_path=graph_path, sample_shape=sample_shape) | |||||
| onnx_graph_module = import_module( | |||||
| 'mindinsight.mindconverter.graph_based_converter.third_party_graph.onnx_graph') | |||||
| onnx_graph = getattr(onnx_graph_module, 'OnnxGraph') | |||||
| return onnx_graph.load(model_path=graph_path, input_nodes=input_nodes, | |||||
| output_nodes=output_nodes, sample_shape=sample_shape) | |||||
| @@ -264,7 +264,7 @@ class Graph(BaseGraph, abc.ABC): | |||||
| Returns: | Returns: | ||||
| cls, graph instance. | cls, graph instance. | ||||
| """ | """ | ||||
| src_graph = cls.load_graph(graph_path=model_path, **kwargs) | |||||
| src_graph = cls.load_graph(graph_path=model_path, sample_shape=sample_shape, **kwargs) | |||||
| ckpt = cls.load_checkpoint(ckpt_path=checkpoint) if checkpoint else None | ckpt = cls.load_checkpoint(ckpt_path=checkpoint) if checkpoint else None | ||||
| if ckpt is not None: | if ckpt is not None: | ||||
| @@ -1,4 +1,4 @@ | |||||
| # Copyright 2020 Huawei Technologies Co., Ltd.All Rights Reserved. | |||||
| # Copyright 2020-2021 Huawei Technologies Co., Ltd.All Rights Reserved. | |||||
| # | # | ||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | # Licensed under the Apache License, Version 2.0 (the "License"); | ||||
| # you may not use this file except in compliance with the License. | # you may not use this file except in compliance with the License. | ||||
| @@ -13,12 +13,14 @@ | |||||
| # limitations under the License. | # limitations under the License. | ||||
| # ============================================================================== | # ============================================================================== | ||||
| """Define ONNX graph.""" | """Define ONNX graph.""" | ||||
| from importlib import import_module | |||||
| from typing import Dict, NoReturn | from typing import Dict, NoReturn | ||||
| from mindinsight.mindconverter.common.log import logger as log | from mindinsight.mindconverter.common.log import logger as log | ||||
| from .base import Graph | from .base import Graph | ||||
| from .input_node import InputNode | from .input_node import InputNode | ||||
| from .onnx_graph_node import OnnxGraphNode | from .onnx_graph_node import OnnxGraphNode | ||||
| from .pytorch_graph_parser import PyTorchGraphParser | |||||
| from .tf_graph_parser import TFGraphParser | from .tf_graph_parser import TFGraphParser | ||||
| from .onnx_utils import OnnxDataLoader | from .onnx_utils import OnnxDataLoader | ||||
| @@ -151,7 +153,7 @@ class OnnxGraph(Graph): | |||||
| input_shape (tuple): Input shape. | input_shape (tuple): Input shape. | ||||
| """ | """ | ||||
| input_node = InputNode(input_shape) | input_node = InputNode(input_shape) | ||||
| input_node_name = self._raw_input_nodes.replace(":0", "") | |||||
| input_node_name = self._raw_input_nodes | |||||
| for node_name, node in self._nodes_collection.items(): | for node_name, node in self._nodes_collection.items(): | ||||
| if node_name in self._input_nodes: | if node_name in self._input_nodes: | ||||
| ipt_nd_name = input_node_name.format(input_node.scope_name) | ipt_nd_name = input_node_name.format(input_node.scope_name) | ||||
| @@ -196,7 +198,18 @@ class OnnxGraph(Graph): | |||||
| """ | """ | ||||
| tf_input_nodes = kwargs.get('input_nodes') | tf_input_nodes = kwargs.get('input_nodes') | ||||
| tf_output_nodes = kwargs.get('output_nodes') | tf_output_nodes = kwargs.get('output_nodes') | ||||
| onnx_model = TFGraphParser.parse(graph_path, | |||||
| input_nodes=tf_input_nodes, | |||||
| output_nodes=tf_output_nodes) | |||||
| if graph_path.endswith('.pb'): | |||||
| onnx_model = TFGraphParser.parse(graph_path, | |||||
| input_nodes=tf_input_nodes, | |||||
| output_nodes=tf_output_nodes) | |||||
| elif graph_path.endswith('.onnx'): | |||||
| onnx = import_module('onnx') | |||||
| onnx_model = onnx.load(graph_path) | |||||
| optimizer = import_module( | |||||
| 'mindinsight.mindconverter.graph_based_converter.third_party_graph.optimizer') | |||||
| onnx_simplify = getattr(optimizer, 'OnnxSimplify')() | |||||
| onnx_model = onnx_simplify.run_onnx_simplify(onnx_model, kwargs['sample_shape']) | |||||
| else: | |||||
| onnx_model = PyTorchGraphParser.parse(graph_path, **kwargs) | |||||
| return onnx_model | return onnx_model | ||||
| @@ -1,4 +1,4 @@ | |||||
| # Copyright 2020 Huawei Technologies Co., Ltd.All Rights Reserved. | |||||
| # Copyright 2020-2021 Huawei Technologies Co., Ltd.All Rights Reserved. | |||||
| # | # | ||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | # Licensed under the Apache License, Version 2.0 (the "License"); | ||||
| # you may not use this file except in compliance with the License. | # you may not use this file except in compliance with the License. | ||||
| @@ -112,10 +112,10 @@ class OnnxTensor: | |||||
| def to_array(self): | def to_array(self): | ||||
| """Convert the tensor value from binary to np array.""" | """Convert the tensor value from binary to np array.""" | ||||
| onnx = import_module("onnx") | |||||
| numpy_helper = import_module("onnx.numpy_helper") | |||||
| # Convert binary data to np.array | # Convert binary data to np.array | ||||
| if not isinstance(self.raw_tensor, (np.ndarray, list, tuple, int, float)): | if not isinstance(self.raw_tensor, (np.ndarray, list, tuple, int, float)): | ||||
| return onnx.numpy_helper.to_array(self.raw_tensor) | |||||
| return numpy_helper.to_array(self.raw_tensor) | |||||
| return self.raw_tensor | return self.raw_tensor | ||||
| @@ -383,15 +383,24 @@ class OnnxDataLoader: | |||||
| """Parse each onnx nodes in the model.""" | """Parse each onnx nodes in the model.""" | ||||
| nodes_topo_idx = [] | nodes_topo_idx = [] | ||||
| for idx, node in enumerate(self.nodes): | for idx, node in enumerate(self.nodes): | ||||
| if not node.name: | |||||
| node.name = "_".join(node.output) | |||||
| n = OnnxNode(node) | n = OnnxNode(node) | ||||
| self._nodes_dict[n.name] = n | self._nodes_dict[n.name] = n | ||||
| nodes_topo_idx.append((idx, n.name)) | nodes_topo_idx.append((idx, n.name)) | ||||
| if len(node.output) > 1: | if len(node.output) > 1: | ||||
| raise ModelNotSupportError(msg=f"{node.name} has multi-outputs which is not supported now.") | raise ModelNotSupportError(msg=f"{node.name} has multi-outputs which is not supported now.") | ||||
| self.output_name_to_node_name[node.output[0]] = node.name | self.output_name_to_node_name[node.output[0]] = node.name | ||||
| for ipt_nd in node.input: | |||||
| if ipt_nd not in self.output_name_to_node_name: | |||||
| if self._global_context.onnx_node_inputs.get(n.name): | |||||
| self._global_context.onnx_node_inputs[n.name].append(ipt_nd) | |||||
| else: | |||||
| self._global_context.onnx_node_inputs[n.name] = [ipt_nd] | |||||
| self._global_context.onnx_node_name_to_topo_idx[n.name] = idx | self._global_context.onnx_node_name_to_topo_idx[n.name] = idx | ||||
| node_inputs = [i.replace(":0", "") for i in node.input] | |||||
| self._global_context.onnx_node_inputs[n.name] = node_inputs | |||||
| self._global_context.onnx_nodes_collection = self._nodes_dict | self._global_context.onnx_nodes_collection = self._nodes_dict | ||||
| self._global_context.onnx_nodes_topo_index = nodes_topo_idx | self._global_context.onnx_nodes_topo_index = nodes_topo_idx | ||||
| @@ -449,7 +458,11 @@ class OnnxDataLoader: | |||||
| input_node = self.get_node(input_node_name) | input_node = self.get_node(input_node_name) | ||||
| node.precursor_onnx_node_dict[input_node_name] = input_node | node.precursor_onnx_node_dict[input_node_name] = input_node | ||||
| input_node.successor_onnx_node_dict[node_name] = node | input_node.successor_onnx_node_dict[node_name] = node | ||||
| continue | |||||
| if self._global_context.onnx_node_inputs.get(node.name): | |||||
| self._global_context.onnx_node_inputs[node.name].append(input_node_name) | |||||
| else: | |||||
| self._global_context.onnx_node_inputs[node.name] = [input_node_name] | |||||
| def initialize(self): | def initialize(self): | ||||
| """Initialize the OnnxDataLoader.""" | """Initialize the OnnxDataLoader.""" | ||||
| @@ -0,0 +1,144 @@ | |||||
| # Copyright 2021 Huawei Technologies Co., Ltd.All Rights Reserved. | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================== | |||||
| """Define ONNX optimizer operations.""" | |||||
| import copy | |||||
| from importlib import import_module | |||||
| import numpy as np | |||||
| from ..common.utils import fetch_output_from_onnx_model | |||||
| class OnnxSimplify: | |||||
| """To simplify onnx model.""" | |||||
| def __init__(self): | |||||
| self._onnx_model = None | |||||
| self._constant_nodes = list() | |||||
| self._outputs_infer = dict() | |||||
| def run_onnx_simplify(self, onnx_model, sample_shape): | |||||
| """ | |||||
| Run to simplify onnx model. | |||||
| Args: | |||||
| onnx_model (onnx.ModelProto): Onnx Model. | |||||
| sample_shape (tuple): Sample shape of input. | |||||
| """ | |||||
| self._onnx_model = onnx_model | |||||
| self._optimizer() | |||||
| self._get_constant_nodes() | |||||
| self._onnx_infer(sample_shape) | |||||
| self._replace_constant_nodes() | |||||
| self._optimizer() | |||||
| return self._onnx_model | |||||
| def _optimizer(self): | |||||
| """Run optimizer from onnx to eliminate constant nodes.""" | |||||
| onnxoptimizer = import_module('onnxoptimizer') | |||||
| optimizers_list = [ | |||||
| 'eliminate_deadend', | |||||
| 'eliminate_nop_dropout', | |||||
| 'eliminate_nop_cast', | |||||
| 'eliminate_nop_monotone_argmax', | |||||
| 'eliminate_nop_pad', | |||||
| 'extract_constant_to_initializer', | |||||
| 'eliminate_unused_initializer', | |||||
| 'eliminate_nop_transpose', | |||||
| 'eliminate_identity', | |||||
| 'fuse_add_bias_into_conv', | |||||
| 'fuse_consecutive_concats', | |||||
| 'fuse_consecutive_log_softmax', | |||||
| 'fuse_consecutive_reduce_unsqueeze', | |||||
| 'fuse_consecutive_squeezes', | |||||
| 'fuse_consecutive_transposes', | |||||
| 'fuse_matmul_add_bias_into_gemm', | |||||
| 'fuse_pad_into_conv', | |||||
| 'fuse_transpose_into_gemm' | |||||
| ] | |||||
| input_num = len(self._onnx_model.graph.input) | |||||
| onnx_model_optimized = onnxoptimizer.optimize(self._onnx_model, optimizers_list, fixed_point=True) | |||||
| if self._onnx_model.ir_version > 3: | |||||
| del onnx_model_optimized.graph.input[input_num:] | |||||
| self._onnx_model = onnx_model_optimized | |||||
| def _get_constant_nodes(self): | |||||
| """Get constant nodes.""" | |||||
| const_nodes = list() | |||||
| const_tensors = [tensor_init.name for tensor_init in self._onnx_model.graph.initializer] | |||||
| const_tensors.append([node.output[0] | |||||
| for node in self._onnx_model.graph.node if node.op_type == 'Constant']) | |||||
| for node in self._onnx_model.graph.node: | |||||
| if node.op_type == 'Shape' or all([input_node in const_tensors for input_node in node.input]): | |||||
| const_nodes.append(node) | |||||
| const_tensors.extend(node.output) | |||||
| self._constant_nodes = copy.deepcopy(const_nodes) | |||||
| def _onnx_infer(self, infer_inputs_shape): | |||||
| """ | |||||
| Run onnx inference to get outputs of constant nodes. | |||||
| Args: | |||||
| infer_inputs_shape (tuple): Input shape for running inference. | |||||
| """ | |||||
| input_onnx = self._onnx_model.graph.input[0] | |||||
| input_onnx_name = input_onnx.name | |||||
| feed_dict = {input_onnx_name: np.random.rand(*infer_inputs_shape).astype(np.float32)} | |||||
| output_nodes_name = list() | |||||
| for node in self._constant_nodes: | |||||
| output_nodes_name.extend(node.output) | |||||
| self._outputs_infer = fetch_output_from_onnx_model(self._onnx_model, feed_dict, output_nodes_name) | |||||
| def _replace_constant_nodes(self): | |||||
| """Replace constant nodes to nodes with op_type 'Constant'.""" | |||||
| onnx = import_module('onnx') | |||||
| np_helper = import_module('onnx.numpy_helper') | |||||
| for i, node in enumerate(self._onnx_model.graph.node): | |||||
| if node in self._constant_nodes: | |||||
| for output in node.output: | |||||
| new_attr = onnx.helper.make_attribute( | |||||
| 'value', | |||||
| np_helper.from_array(self._outputs_infer[output], name=output) | |||||
| ) | |||||
| new_node = onnx.helper.make_node( | |||||
| op_type='Constant', | |||||
| inputs=list(), | |||||
| outputs=[output], | |||||
| name='_'.join(('node', output)) | |||||
| ) | |||||
| new_node.attribute.extend([new_attr]) | |||||
| self._insert_node(self._onnx_model.graph.node, i + 1, new_node) | |||||
| del self._onnx_model.graph.node[i] | |||||
| @staticmethod | |||||
| def _insert_node(repeated_container, index, node): | |||||
| """Insert node into onnx model.""" | |||||
| repeated_container.extend([repeated_container[-1]]) | |||||
| for i in reversed(range(index + 1, len(repeated_container) - 1)): | |||||
| repeated_container[i].CopyFrom(repeated_container[i - 1]) | |||||
| repeated_container[index].CopyFrom(node) | |||||
| @@ -1,691 +0,0 @@ | |||||
| # Copyright 2020 Huawei Technologies Co., Ltd.All Rights Reserved. | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================== | |||||
| """Define PyTorch graph.""" | |||||
| import os | |||||
| import re | |||||
| import warnings | |||||
| from copy import deepcopy | |||||
| from importlib import import_module | |||||
| from typing import Dict, NoReturn | |||||
| import numpy as np | |||||
| from mindinsight.conf import settings | |||||
| from mindinsight.mindconverter.common.log import logger as log | |||||
| from .base import Graph | |||||
| from .input_node import InputNode | |||||
| from .pytorch_graph_node import PyTorchGraphNode | |||||
| from .pytorch_graph_parser import PyTorchGraphParser | |||||
| from .torch_utils import set_opset_version | |||||
| from ..common.utils import fetch_output_from_onnx_model | |||||
| from ..constant import SEPARATOR_IN_SCOPE, LINK_IN_SCOPE, SEPARATOR_BTW_NAME_AND_ID, SCALAR_WITHOUT_SHAPE, \ | |||||
| MIN_SCOPE_LENGTH, SEPARATOR_TITLE_AND_CONTENT_IN_CONSTRUCT, ONNX_OPSET_VERSION, MODEL_INPUT_NAME | |||||
| from ..constant import LEFT_BUCKET, RIGHT_BUCKET | |||||
| from ...common.exceptions import ModelNotSupportError | |||||
| NONE_SCOPE_OP = { | |||||
| "onnx::Add": "Add", | |||||
| "onnx::Flatten": "Flatten", | |||||
| "onnx::Concat": "Concat", | |||||
| "onnx::Squeeze": "Squeeze", | |||||
| "onnx::Unsqueeze": "Unsqueeze", | |||||
| "onnx::Split": "Split", | |||||
| "onnx::Reshape": "Reshape", | |||||
| "onnx::Transpose": "Transpose", | |||||
| "onnx::Constant": "Constant", | |||||
| "onnx::ReduceMean": "ReduceMean", | |||||
| "onnx::Resize": "Resize", | |||||
| "onnx::Pad": "Pad" | |||||
| } | |||||
| CONSTANT_NODES_PATTERN = { | |||||
| "onnx::Resize": [ | |||||
| 'onnx::Concat', | |||||
| 'onnx::Slice', | |||||
| 'onnx::Cast', | |||||
| 'onnx::Concat', | |||||
| 'onnx::Unsqueeze', | |||||
| 'onnx::Floor', | |||||
| 'onnx::Mul', | |||||
| 'onnx::Cast', | |||||
| 'onnx::Gather', | |||||
| 'onnx::Shape' | |||||
| ], | |||||
| "onnx::Pad": [ | |||||
| 'onnx::Cast', | |||||
| 'onnx::Concat', | |||||
| 'onnx::ConstantOfShape', | |||||
| 'onnx::Sub', | |||||
| 'onnx::Mul', | |||||
| 'onnx::Div', | |||||
| 'onnx::Gather', | |||||
| 'onnx::Shape', | |||||
| 'onnx::Unsqueeze', | |||||
| 'onnx::Slice', | |||||
| 'onnx::Reshape', | |||||
| 'onnx::Transpose' | |||||
| ], | |||||
| "onnx::Constant": list() | |||||
| } | |||||
| def normalize_scope_name(node, scope_name_dict): | |||||
| """ | |||||
| Rename scope name into uniform. | |||||
| Args: | |||||
| node (Node): PyTorch node. | |||||
| scope_name_dict (dict): Dictionary of scope names with the key node_id. | |||||
| Returns: | |||||
| str, normalized scope name. | |||||
| """ | |||||
| global NONE_SCOPE_OP | |||||
| scope_name = node.scopeName() | |||||
| if not scope_name: | |||||
| name = [retrieve_scope_name(node, scope_name_dict)] | |||||
| else: | |||||
| name = scope_name.replace(SEPARATOR_BTW_NAME_AND_ID, '').split(SEPARATOR_IN_SCOPE) | |||||
| scopes = [] | |||||
| for segment in name: | |||||
| segment = segment.split(LINK_IN_SCOPE)[0] | |||||
| left = segment.find(LEFT_BUCKET) | |||||
| right = segment.find(RIGHT_BUCKET) | |||||
| if left != -1: | |||||
| if segment[left + 1: right].isdigit(): | |||||
| scopes.append(f"{segment[:left]}_{segment[left + 1: right]}") | |||||
| else: | |||||
| scopes.append(segment[left + 1: right]) | |||||
| else: | |||||
| scopes.append(segment) | |||||
| if node.kind() in NONE_SCOPE_OP.keys(): | |||||
| scopes.append(NONE_SCOPE_OP[node.kind()]) | |||||
| scopes = [s for s in scopes if s] | |||||
| node_id = PyTorchGraph.get_node_id(node) | |||||
| return f"{SEPARATOR_IN_SCOPE.join(scopes)}_{'&'.join(node_id)}" | |||||
| def retrieve_scope_name(node, scope_name_dict): | |||||
| """ | |||||
| Retrieve scope name from input nodes. | |||||
| Args: | |||||
| node (Node): PyTorch node. | |||||
| scope_name_dict (dict): Dictionary of scope names with the key node_id. | |||||
| Return: | |||||
| str: Scope name. | |||||
| """ | |||||
| node_content = \ | |||||
| SEPARATOR_TITLE_AND_CONTENT_IN_CONSTRUCT.join(str(node).split(SEPARATOR_TITLE_AND_CONTENT_IN_CONSTRUCT)[1:]) | |||||
| node_inputs = re.findall(r"[(](.*?)[)]", node_content)[0] | |||||
| node_inputs = re.sub(r"[\s%]", '', node_inputs).split(",") | |||||
| scope_name_ipt_nodes = list() | |||||
| for node_input in node_inputs: | |||||
| if not scope_name_dict.get(node_input, None): | |||||
| continue | |||||
| scope_name_ipt_nodes.append(scope_name_dict[node_input]) | |||||
| scope_name_split = list() | |||||
| for idx, _ in enumerate(scope_name_ipt_nodes): | |||||
| if not scope_name_split: | |||||
| scope_name_split = scope_name_ipt_nodes[idx] | |||||
| else: | |||||
| scope_name_split = [ | |||||
| sub_scope_name | |||||
| for sub_scope_name in scope_name_split if sub_scope_name in scope_name_ipt_nodes[idx] | |||||
| ] | |||||
| scope_name = SEPARATOR_IN_SCOPE.join(scope_name_split) | |||||
| return scope_name | |||||
| class PyTorchGraph(Graph): | |||||
| """ | |||||
| Define PyTorch graph. | |||||
| Args: | |||||
| model (Module): PyTorch model. | |||||
| sample_shape (tuple): Input shape of the model. | |||||
| """ | |||||
| def __init__(self, model, sample_shape: tuple): | |||||
| super(PyTorchGraph, self).__init__(model=model) | |||||
| from .torch_utils import unique_state_dict | |||||
| self._params_dict = unique_state_dict(model) | |||||
| self._original_shape = list() | |||||
| self._nodes = list() | |||||
| self._constant_nodes = list() | |||||
| self._dynamic_nodes = list() | |||||
| self._has_eliminated_nodes = False | |||||
| self._file_graph_onnx = os.path.join( | |||||
| settings.WORKSPACE, 'log/mindconverter/' | |||||
| ) | |||||
| self.build(sample_shape) | |||||
| @staticmethod | |||||
| def _check_input_shape(input_shape): | |||||
| """ | |||||
| Check input shape. | |||||
| Args: | |||||
| input_shape (tuple): Input tensor shape. | |||||
| """ | |||||
| if not input_shape: | |||||
| err_msg = "`input_shape` can not be None." | |||||
| log.error(err_msg) | |||||
| raise ValueError(err_msg) | |||||
| for item in input_shape: | |||||
| if not isinstance(item, int): | |||||
| err_msg = "Only support model with one input now, " \ | |||||
| "and each shape value in `input_shape` should be int." | |||||
| log.error(err_msg) | |||||
| raise ValueError(err_msg) | |||||
| @staticmethod | |||||
| def _extract_shape(shape): | |||||
| """ | |||||
| Extract shape from string-type shape. | |||||
| Args: | |||||
| shape (str): Shape value in string-type. | |||||
| Returns: | |||||
| list, shape. | |||||
| """ | |||||
| if "," not in shape: | |||||
| return [] | |||||
| shape_arr = [] | |||||
| for s in shape.split(","): | |||||
| s = s.strip() | |||||
| if not s: | |||||
| return [] | |||||
| if ":" in s: | |||||
| s = s.split(":")[0] | |||||
| s = s.replace("!", "") | |||||
| if not s.isdigit(): | |||||
| return [] | |||||
| shape_arr.append(int(s)) | |||||
| return shape_arr | |||||
| def _trace_torch_graph(self, input_shape): | |||||
| """ | |||||
| Trace torch computational graph. | |||||
| Args: | |||||
| input_shape (tuple): Shape. | |||||
| Returns: | |||||
| object, pytorch graph. | |||||
| """ | |||||
| import torch | |||||
| from torch.onnx import OperatorExportTypes | |||||
| from .torch_utils import OverloadTorchModuleTemporarily | |||||
| from .torch_utils import create_autograd_variable | |||||
| from .torch_utils import onnx_tracer | |||||
| warnings.simplefilter("ignore") | |||||
| batched_sample = create_autograd_variable(torch.rand(*input_shape)) | |||||
| try: | |||||
| try: | |||||
| # Assign execution mode to eval. | |||||
| self.model.eval() | |||||
| with OverloadTorchModuleTemporarily() as _: | |||||
| # In pytorch higher version, trace function has a known. | |||||
| graph = onnx_tracer(self.model, batched_sample, | |||||
| OperatorExportTypes.ONNX) | |||||
| return graph | |||||
| except RuntimeError: | |||||
| # Assign execution mode to eval. | |||||
| self.model.eval() | |||||
| with OverloadTorchModuleTemporarily() as _: | |||||
| # In pytorch higher version, trace function has a known. | |||||
| set_opset_version(ONNX_OPSET_VERSION) | |||||
| graph = onnx_tracer(self.model, batched_sample, | |||||
| OperatorExportTypes.ONNX) | |||||
| return graph | |||||
| except RuntimeError as error: | |||||
| log.error(str(error)) | |||||
| log.exception(error) | |||||
| raise error | |||||
| def build(self, input_shape): | |||||
| """ | |||||
| Build graph tree. | |||||
| Args: | |||||
| input_shape (tuple): Input shape of model. | |||||
| """ | |||||
| self._check_input_shape(input_shape) | |||||
| self._original_shape = input_shape | |||||
| feed_forward_ipt_shape = tuple(input_shape) | |||||
| graph = self._trace_torch_graph(feed_forward_ipt_shape) | |||||
| nodes = list(graph.nodes()) | |||||
| self._nodes = nodes | |||||
| scope_name_dict = dict() | |||||
| self._constant_nodes, self._dynamic_nodes = self._get_constant_nodes(nodes) | |||||
| for node in nodes: | |||||
| output_name = ', '.join(list(self._extract_node_name(output) for output in node.outputs())) | |||||
| if output_name in self._dynamic_nodes: | |||||
| continue | |||||
| node_name = normalize_scope_name(node, scope_name_dict) | |||||
| scope_name_dict[node_name.split(SEPARATOR_BTW_NAME_AND_ID)[-1]] \ | |||||
| = list(node_name.split(SEPARATOR_BTW_NAME_AND_ID)[0].split(SEPARATOR_IN_SCOPE)) | |||||
| output_shape_str_list = re.findall(r'[^()!]+', str(node)) | |||||
| output_shape_str = output_shape_str_list[1] | |||||
| output_shape = self._extract_shape(output_shape_str) | |||||
| weight_scope = '.'.join( | |||||
| re.findall(r'\[([\w\d.]+)]', node.scopeName()) | |||||
| ) | |||||
| if self._constant_nodes: | |||||
| node_weight = self._replace_constant_node(node) | |||||
| else: | |||||
| node_weight = {} | |||||
| for scope, weight in self._params_dict.items(): | |||||
| split_scope = scope.split('.') | |||||
| if '.'.join(split_scope[:-1]) == weight_scope: | |||||
| node_weight[split_scope[-1]] = weight | |||||
| if not node_weight and node.kind() == 'onnx::Conv': | |||||
| weight_names = list(self._params_dict.keys()) | |||||
| node_input_names = [self._extract_input_name(node_input) for node_input in node.inputs()] | |||||
| for node_input_name in node_input_names: | |||||
| if int(node_input_name) > len(weight_names): | |||||
| continue | |||||
| weight = self._params_dict[weight_names[int(node_input_name) - 1]] | |||||
| node_weight[weight_names[int(node_input_name) - 1]] = weight | |||||
| self._shape_dict[node_name] = output_shape | |||||
| self._nodes_collection[node_name] = PyTorchGraphNode(node, node_weight) | |||||
| self._nodes_record[node_name] = node_name | |||||
| for node_input in list(node.inputs()): | |||||
| if self._extract_input_name(node_input) in self._constant_nodes: | |||||
| continue | |||||
| # Connect input node and src node. | |||||
| nd_id = PyTorchGraph.get_node_id(node_input.node()) | |||||
| nd_scope_name = node_input.node().kind() in NONE_SCOPE_OP or \ | |||||
| node_input.node().scopeName() | |||||
| if nd_id and nd_scope_name: | |||||
| node_input_name = normalize_scope_name( | |||||
| node_input.node(), scope_name_dict | |||||
| ) | |||||
| self.build_connection(node_input_name, node_name) | |||||
| self._unmerge_multi_ipt_opt_script() | |||||
| super(PyTorchGraph, self).build(input_shape=input_shape) | |||||
| self._collect_ipt_shape_of_each_node(feed_forward_ipt_shape) | |||||
| @staticmethod | |||||
| def _extract_node_name(node): | |||||
| """Extract node name for node.""" | |||||
| result = re.match(r"\d+", str(node)) | |||||
| if result: | |||||
| return result.group(0) | |||||
| return None | |||||
| @staticmethod | |||||
| def _extract_input_name(node_input): | |||||
| """Extract node input name from node input.""" | |||||
| node_input_name = str(node_input).split('defined in')[0].strip() | |||||
| return node_input_name | |||||
| def _get_constant_nodes(self, nodes): | |||||
| """ | |||||
| Get constant nodes to be eliminated. | |||||
| Args: | |||||
| nodes (Nodes): Nodes in torch._C.Graph. | |||||
| Returns: | |||||
| Union(dict, list), output of constant_input_node_name and dynamic nodes name. | |||||
| """ | |||||
| constant_input_nodes = list() | |||||
| dynamic_nodes = list() | |||||
| for node in nodes: | |||||
| if node.kind() == 'onnx::Resize': | |||||
| self._has_eliminated_nodes = True | |||||
| constant_input_node, dynamic_node = self._generate_inputs_of(node) | |||||
| constant_input_nodes += constant_input_node | |||||
| dynamic_nodes += dynamic_node | |||||
| outputs = dict() | |||||
| if self._has_eliminated_nodes: | |||||
| torch = import_module('torch') | |||||
| device_target = 'cuda' if torch.cuda.is_available() else 'cpu' | |||||
| dump_input = torch.randn(*self._original_shape, device=device_target) | |||||
| temp_onnx_path = os.path.realpath(os.path.join(self._file_graph_onnx, | |||||
| '.graph_onnx.onnx')) | |||||
| symbolic_helper = import_module('torch.onnx.symbolic_helper') | |||||
| export_onnx_opset_version = getattr(symbolic_helper, '_export_onnx_opset_version') | |||||
| try: | |||||
| torch.onnx.export(self.model.to(device_target), dump_input, | |||||
| temp_onnx_path, opset_version=export_onnx_opset_version) | |||||
| outputs = self._onnx_infer(temp_onnx_path, constant_input_nodes, self._original_shape) | |||||
| finally: | |||||
| if os.path.exists(temp_onnx_path): | |||||
| os.remove(temp_onnx_path) | |||||
| return outputs, dynamic_nodes | |||||
| def _generate_inputs_of(self, node): | |||||
| """ | |||||
| Generate inputs of certain node. | |||||
| Args: | |||||
| node (Node): Node of torch._C.Graph. | |||||
| """ | |||||
| pattern_op_lst = CONSTANT_NODES_PATTERN.get(node.kind(), None) | |||||
| constant_input_nodes = list() | |||||
| dynamic_nodes = list() | |||||
| if not isinstance(pattern_op_lst, list): | |||||
| return constant_input_nodes, dynamic_nodes | |||||
| if not pattern_op_lst: | |||||
| dynamic_nodes += self.get_node_id(node) | |||||
| return constant_input_nodes, dynamic_nodes | |||||
| node_inputs_name = [self._extract_input_name(node_input) for node_input in node.inputs()] | |||||
| for node_input_name in node_inputs_name: | |||||
| node_name_path = self._search_node_path(node_input_name, pattern_op_lst) | |||||
| if node_name_path and self._get_node_from_graph(node_name_path[-1]).kind() == 'onnx::Shape': | |||||
| constant_input_nodes.append(node_input_name) | |||||
| dynamic_nodes += node_name_path | |||||
| return constant_input_nodes, dynamic_nodes | |||||
| def _search_node_path(self, node_name, pattern_op_lst): | |||||
| """ | |||||
| Search node path based on pattern_op_list. | |||||
| Args: | |||||
| node_name (str): Node name. | |||||
| pattern_op_lst (list): Pattern list of certain operator. | |||||
| Returns: | |||||
| list[str]: node names in pattern. | |||||
| """ | |||||
| node_type_lst = list() | |||||
| node_name_lst = list() | |||||
| node = self._get_node_from_graph(node_name) | |||||
| if node_name == MODEL_INPUT_NAME: | |||||
| return node_name_lst | |||||
| if node.kind() not in pattern_op_lst: | |||||
| return node_name_lst | |||||
| node_type_lst.append(node.kind()) | |||||
| node_name_lst.append(node_name) | |||||
| node_inputs_name = [self._extract_input_name(node_input) for node_input in node.inputs()] | |||||
| for node_input_name in node_inputs_name: | |||||
| node_name_lst += self._search_node_path(node_input_name, pattern_op_lst) | |||||
| return node_name_lst | |||||
| def _get_node_from_graph(self, node_name): | |||||
| """Get torch._C.Node from torch._C.Graph.""" | |||||
| for idx, node in enumerate(self._nodes): | |||||
| node_id = ', '.join(self.get_node_id(node)) | |||||
| if node_id == node_name: | |||||
| return self._nodes[idx] | |||||
| return None | |||||
| @staticmethod | |||||
| def _onnx_infer(file_graph_onnx, infer_outputs, infer_inputs_shape): | |||||
| """ | |||||
| Infer onnx model to get outputs of inner nodes. | |||||
| Args: | |||||
| file_graph_onnx (str): File path of onnx. | |||||
| infer_outputs (list): Outputs for infer. | |||||
| infer_inputs_shape (list): Input shape for infer. | |||||
| """ | |||||
| onnx = import_module('onnx') | |||||
| tensor_proto = getattr(onnx, 'TensorProto') | |||||
| onnx_model = onnx.load(file_graph_onnx) | |||||
| for onnx_node in onnx_model.graph.node: | |||||
| if set(onnx_node.output).issubset(set(infer_outputs)): | |||||
| onnx_node.name = ', '.join([f"{output_name}" for output_name in onnx_node.output]) | |||||
| input_onnx = onnx_model.graph.input[0] | |||||
| node_type = tensor_proto.DataType.Name(input_onnx.type.tensor_type.elem_type) | |||||
| if node_type != 'FLOAT': | |||||
| raise ModelNotSupportError(f"Input type should be FLOAT32, but got {node_type}. " | |||||
| f"Please report issue to us if extra input type is needed.") | |||||
| input_onnx_name = input_onnx.name | |||||
| feed_dict = {input_onnx_name: np.random.rand(*infer_inputs_shape).astype(np.float32)} | |||||
| outputs = fetch_output_from_onnx_model(onnx_model, feed_dict, infer_outputs) | |||||
| return outputs | |||||
| def _replace_constant_node(self, node): | |||||
| """Replace constant node.""" | |||||
| node_weight = dict() | |||||
| for node_input in list(node.inputs()): | |||||
| node_input_name = self._extract_input_name(node_input) | |||||
| if node_input_name in self._constant_nodes: | |||||
| node_weight[node_input_name] = self._constant_nodes[node_input_name] | |||||
| return node_weight | |||||
| def _collect_ipt_shape_of_each_node(self, input_shape): | |||||
| """ | |||||
| Collect input tensor shape of each node. | |||||
| Args: | |||||
| input_shape (tuple): Input shape. | |||||
| """ | |||||
| input_node = InputNode(input_shape) | |||||
| input_node_name = "{}InputNode" | |||||
| for node_name, node in self._nodes_collection.items(): | |||||
| if node_name in self._input_nodes: | |||||
| ipt_nd_name = input_node_name.format(input_node.scope_name) | |||||
| input_node.set_scope_name(node.scope_name) | |||||
| node.precursor_nodes.insert(0, ipt_nd_name) | |||||
| input_node.set_successor_nodes(node_name) | |||||
| self._shape_dict[ipt_nd_name] = input_node.output_shape | |||||
| if not self._shape_dict[node_name]: | |||||
| self._shape_dict[node_name] = SCALAR_WITHOUT_SHAPE | |||||
| ipt_shape = [] | |||||
| for p_nd in node.precursor_nodes: | |||||
| shp = self._shape_dict.get(p_nd) | |||||
| ipt_shape.append(tuple(shp) if isinstance(shp, list) else shp) | |||||
| self._input_shape[node_name] = ipt_shape[0] if len(ipt_shape) == 1 else ipt_shape | |||||
| def _generate_module(self): | |||||
| """Generate modules.""" | |||||
| module_dict = dict() | |||||
| for node_key, _ in self._nodes_collection.items(): | |||||
| node_key_in_scope = node_key.split(SEPARATOR_IN_SCOPE) | |||||
| if len(node_key_in_scope) < MIN_SCOPE_LENGTH: | |||||
| continue | |||||
| for idx in range(1, len(node_key_in_scope)): | |||||
| node_key_module = SEPARATOR_IN_SCOPE.join(node_key_in_scope[:idx]) | |||||
| node_name = SEPARATOR_IN_SCOPE.join(node_key_in_scope[:idx+1]) | |||||
| if not module_dict.get(node_key_module, None): | |||||
| module_dict[node_key_module] = {node_name} | |||||
| else: | |||||
| module_dict[node_key_module].add(node_name) | |||||
| return module_dict | |||||
| def _check_multi_ipt_opt(self): | |||||
| """Check whether multi-input exists.""" | |||||
| module_dict = self._generate_module() | |||||
| for _, nodes_per_module in module_dict.items(): | |||||
| prcs_nodes_out_from_module = set() | |||||
| for node_name in nodes_per_module: | |||||
| if re.search(r"[\d]+[&][\d]+", node_name): | |||||
| self._is_multi_opt_graph = True | |||||
| return True | |||||
| node = self._nodes_collection.get(node_name, None) | |||||
| if node: | |||||
| prcs_nodes = node.precursor_nodes | |||||
| else: | |||||
| continue | |||||
| for prcs_node in prcs_nodes: | |||||
| if prcs_node not in nodes_per_module: | |||||
| prcs_node_module = SEPARATOR_IN_SCOPE.join(prcs_node.split(SEPARATOR_IN_SCOPE)[:-1]) | |||||
| if prcs_node_module not in nodes_per_module: | |||||
| prcs_nodes_out_from_module.add(prcs_node) | |||||
| if len(prcs_nodes_out_from_module) > 1: | |||||
| return True | |||||
| return False | |||||
| def _unmerge_multi_ipt_opt_script(self): | |||||
| """Unmerge all submodule.""" | |||||
| if self._check_multi_ipt_opt() or self._has_eliminated_nodes: | |||||
| for node_key, node_inst in deepcopy(self._nodes_collection).items(): | |||||
| prsc_nodes = node_inst.precursor_nodes | |||||
| scsr_nodes = node_inst.successor_nodes | |||||
| node_inst.is_in_multi_opt_graph = self._is_multi_opt_graph | |||||
| node_inst.precursor_nodes = [SEPARATOR_IN_SCOPE.join((prsc_node.split(SEPARATOR_IN_SCOPE)[0], | |||||
| prsc_node.split(SEPARATOR_IN_SCOPE)[-1])) | |||||
| for prsc_node in deepcopy(prsc_nodes)] | |||||
| node_inst.successor_nodes = [SEPARATOR_IN_SCOPE.join((scsr_node.split(SEPARATOR_IN_SCOPE)[0], | |||||
| scsr_node.split(SEPARATOR_IN_SCOPE)[-1])) | |||||
| for scsr_node in deepcopy(scsr_nodes)] | |||||
| reduce_node_key = SEPARATOR_IN_SCOPE.join((node_key.split(SEPARATOR_IN_SCOPE)[0], | |||||
| node_key.split(SEPARATOR_IN_SCOPE)[-1])) | |||||
| del self._nodes_collection[node_key] | |||||
| self._nodes_collection[reduce_node_key] = node_inst | |||||
| for node_key, shape in deepcopy(self._shape_dict).items(): | |||||
| reduce_node_key = SEPARATOR_IN_SCOPE.join((node_key.split(SEPARATOR_IN_SCOPE)[0], | |||||
| node_key.split(SEPARATOR_IN_SCOPE)[-1])) | |||||
| del self._shape_dict[node_key] | |||||
| self._shape_dict[reduce_node_key] = shape | |||||
| def sub_graph_merging(self): | |||||
| """ | |||||
| Merge split operation into one. | |||||
| """ | |||||
| raise NotImplementedError() | |||||
| def build_connection(self, src, tgt) -> NoReturn: | |||||
| """ | |||||
| Build connection between source node and target node. | |||||
| Args: | |||||
| src (str): Source node name. | |||||
| tgt (str): Target node name. | |||||
| """ | |||||
| # If src and tgt are the same node, src not in node_collection or | |||||
| # tgt not in node_collection, then skip this edge. | |||||
| if src == tgt or src not in self._nodes_collection or tgt not in self._nodes_collection: | |||||
| if src.split(':')[0] not in self._nodes_collection: | |||||
| log.warning("Graph construct a self-loop node %s. Ignored.", src) | |||||
| return | |||||
| if tgt not in self._nodes_collection[src.split(':')[0]].successor_nodes: | |||||
| self._nodes_collection[src.split(':')[0]].successor_nodes.append(tgt) | |||||
| if src not in self._nodes_collection[tgt].precursor_nodes: | |||||
| self._nodes_collection[tgt.split(':')[0]].precursor_nodes.append(src) | |||||
| @staticmethod | |||||
| def load_checkpoint(ckpt_path: str) -> Dict: | |||||
| """ | |||||
| Load checkpoint. | |||||
| Args: | |||||
| ckpt_path (str): Checkpoint file path. | |||||
| Returns: | |||||
| dict, weights in model. | |||||
| """ | |||||
| @staticmethod | |||||
| def load_metadata(**kwargs): | |||||
| """ | |||||
| Load graph metadata. | |||||
| """ | |||||
| err_msg = "class `PyTorchGraph` has not implemented " \ | |||||
| "`load_metadata()`." | |||||
| log.error(err_msg) | |||||
| raise NotImplementedError(err_msg) | |||||
| @staticmethod | |||||
| def load_graph(graph_path: str, **kwargs): | |||||
| """ | |||||
| Load graph. | |||||
| Args: | |||||
| graph_path (str): Graph path. | |||||
| Returns: | |||||
| object, pytorch model. | |||||
| """ | |||||
| torch_model = PyTorchGraphParser.parse(graph_path) | |||||
| return torch_model | |||||
| @staticmethod | |||||
| def get_node_id(node): | |||||
| """ | |||||
| Get node id using regular expr. | |||||
| Args: | |||||
| node (Node): PyTorch node. | |||||
| Returns: | |||||
| str, node id. | |||||
| """ | |||||
| node_title = str(node).split(SEPARATOR_TITLE_AND_CONTENT_IN_CONSTRUCT)[0] | |||||
| node_id = re.findall(r"[%](.*?) [:]", node_title) | |||||
| return node_id | |||||
| @@ -1,236 +0,0 @@ | |||||
| # Copyright 2020 Huawei Technologies Co., Ltd.All Rights Reserved. | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================== | |||||
| """Define PyTorch graph node.""" | |||||
| import re | |||||
| from .base import GraphNode | |||||
| from ..common.utils import is_converted | |||||
| from ..constant import NodeType, SEPARATOR_IN_SCOPE, SEPARATOR_BTW_NAME_AND_ID, LEFT_BUCKET, RIGHT_BUCKET, \ | |||||
| SEPARATOR_IN_ONNX_OP, SEPARATOR_TITLE_AND_CONTENT_IN_CONSTRUCT | |||||
| class PyTorchGraphNode(GraphNode): | |||||
| """ | |||||
| PyTorch graph node. | |||||
| Args: | |||||
| node (torch._C.Node): Node in raw PyTorch graph. | |||||
| """ | |||||
| _type_frozen = False | |||||
| _module_name_frozen = False | |||||
| def __init__(self, node=None, weight=None): | |||||
| super(PyTorchGraphNode, self).__init__(node=node) | |||||
| self._op_params = self._get_raw_params(node) | |||||
| self._op_name = node.kind() if node else None | |||||
| self._scope_name = node.scopeName() if node else None | |||||
| self._weight = weight | |||||
| self._ipt_var_names, self._opt_var_names \ | |||||
| = self._extract_ipt_opt_var_names() if node else (list(), list()) | |||||
| def _extract_ipt_opt_var_names(self): | |||||
| """Extract ipt and opt var names.""" | |||||
| node_content = SEPARATOR_TITLE_AND_CONTENT_IN_CONSTRUCT.join( | |||||
| str(self._src_node).split(SEPARATOR_TITLE_AND_CONTENT_IN_CONSTRUCT)[1:] | |||||
| ) | |||||
| node_inputs = re.findall(r"[(](.*?)[)]", node_content)[0] | |||||
| node_inputs = re.sub(r"[\s%]", '', node_inputs).split(",") | |||||
| node_title = str(self._src_node).split(SEPARATOR_TITLE_AND_CONTENT_IN_CONSTRUCT)[0] | |||||
| node_outputs = re.findall(r"[%](.*?) [:]", node_title) | |||||
| return node_inputs, node_outputs | |||||
| def clear_args_of_declaration(self): | |||||
| """ | |||||
| Clear `self._args_in_code`. | |||||
| """ | |||||
| self._args_in_code = dict() | |||||
| def _get_arg_name(self, arg, variable_name): | |||||
| """ | |||||
| Get arg name. | |||||
| Args: | |||||
| arg (str): Generate arg name. | |||||
| Returns: | |||||
| str, arg name in function or class declaration. | |||||
| """ | |||||
| return f"{arg}_{variable_name}" | |||||
| @property | |||||
| def is_in_multi_opt_graph(self): | |||||
| return self._is_in_multi_opt_graph | |||||
| @is_in_multi_opt_graph.setter | |||||
| def is_in_multi_opt_graph(self, multi_opt_state): | |||||
| self._is_in_multi_opt_graph = multi_opt_state | |||||
| @property | |||||
| def hash_key(self): | |||||
| """ | |||||
| Return unique hash key of current node. | |||||
| Returns: | |||||
| str, hash key. | |||||
| """ | |||||
| if self._node_type not in {NodeType.CLASS.value, | |||||
| NodeType.FUNC.value, | |||||
| NodeType.MODULE.value}: | |||||
| self._hash_key = self._op_name.lower() | |||||
| return self._hash_key | |||||
| @hash_key.setter | |||||
| def hash_key(self, h): | |||||
| """ | |||||
| Setter of hash key. | |||||
| Args: | |||||
| h (str): Key. | |||||
| """ | |||||
| self._hash_key = h | |||||
| @property | |||||
| def op_name(self): | |||||
| """ | |||||
| Op name in torch. | |||||
| Returns: | |||||
| str, op name. | |||||
| """ | |||||
| return self._op_name | |||||
| @op_name.setter | |||||
| def op_name(self, name): | |||||
| """ | |||||
| Setter of op name. | |||||
| Args: | |||||
| name(str): op_name. | |||||
| """ | |||||
| self._op_name = name | |||||
| @property | |||||
| def real_name(self): | |||||
| return | |||||
| def add_input_and_output_shape(self, input_shape, output_shape): | |||||
| """ | |||||
| Add the node input shape. | |||||
| Args: | |||||
| output_shape (tuple): Output tensor shape. | |||||
| input_shape (tuple): Input tensor shape. | |||||
| """ | |||||
| self._ipt_shape = input_shape | |||||
| self._opt_shape = output_shape | |||||
| def to_code(self, ipt_args_in_construct: str, variable_name: str, output_var: list, code_fragment): | |||||
| """ | |||||
| Generate statements. | |||||
| Args: | |||||
| variable_name (str): Variable name. | |||||
| ipt_args_in_construct (str): Args of input. | |||||
| output_var (list): Output variable names in construct. | |||||
| code_fragment (CodeFragment): CodeFragment instance. | |||||
| Returns: | |||||
| Union[str, str], declare in init and call in construct. | |||||
| """ | |||||
| operator = code_fragment.operation | |||||
| args = self.args_in_code | |||||
| settings = code_fragment.code_setting | |||||
| if self._node_type == NodeType.OPERATION.value and not is_converted(code_fragment.operation): | |||||
| args.update({"input_shape": self.input_shape, | |||||
| "output_shape": self.output_shape}) | |||||
| if self._node_type == NodeType.OPERATION.value: | |||||
| expr = ", ".join([f"{k.replace(f'_{variable_name}', '')}={v}" | |||||
| for k, v in args.items()]) | |||||
| ipt_args_settings_in_construct = self._generate_ipt_args_settings_in_construct( | |||||
| ipt_args_in_construct, settings) | |||||
| else: | |||||
| # When it's type is module, class or func, | |||||
| # it's not necessary to replace var. | |||||
| expr = ", ".join([f"{k.replace(f'_{variable_name}', '')}={v}" | |||||
| for k, v in args.items()]) | |||||
| ipt_args_settings_in_construct = ipt_args_in_construct | |||||
| if SEPARATOR_IN_ONNX_OP in operator: | |||||
| operator = operator.replace(SEPARATOR_IN_ONNX_OP, ".") | |||||
| declare = f"self.{variable_name} = {operator}({expr})" | |||||
| call = f"{', '.join([output for output in output_var])}" \ | |||||
| f" = self.{variable_name}({ipt_args_settings_in_construct})" | |||||
| return declare, call | |||||
| def to_ir(self): | |||||
| """ | |||||
| No need to implement for now. | |||||
| """ | |||||
| raise NotImplementedError | |||||
| def _get_raw_params(self, node): | |||||
| """ | |||||
| Get params in onnx. | |||||
| Args: | |||||
| node (Any): Node. | |||||
| Returns: | |||||
| dict, raw params. | |||||
| """ | |||||
| from .torch_utils import getitem_of_node | |||||
| raw_params = dict() | |||||
| if not node: | |||||
| return raw_params | |||||
| for k in node.attributeNames(): | |||||
| raw_params[k] = getitem_of_node(node, k) | |||||
| return raw_params | |||||
| def replace_with_arg(self, src_arg, tgt_arg): | |||||
| """ | |||||
| Replace actual parameter with formal parameter. | |||||
| Args: | |||||
| src_arg (str): Original arg name. | |||||
| tgt_arg (str): Target arg name. | |||||
| """ | |||||
| self._args_in_code[src_arg] = tgt_arg | |||||
| @staticmethod | |||||
| def _extract_var_name(scope_name: str): | |||||
| """ | |||||
| Extract variable name from scope name. | |||||
| """ | |||||
| if not scope_name: | |||||
| return None | |||||
| var = scope_name.split(SEPARATOR_IN_SCOPE)[-1].lower() | |||||
| var = var.replace(LEFT_BUCKET, SEPARATOR_BTW_NAME_AND_ID).replace( | |||||
| RIGHT_BUCKET, "") | |||||
| return var | |||||
| @@ -1,4 +1,4 @@ | |||||
| # Copyright 2020 Huawei Technologies Co., Ltd.All Rights Reserved. | |||||
| # Copyright 2020-2021 Huawei Technologies Co., Ltd.All Rights Reserved. | |||||
| # | # | ||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | # Licensed under the Apache License, Version 2.0 (the "License"); | ||||
| # you may not use this file except in compliance with the License. | # you may not use this file except in compliance with the License. | ||||
| @@ -18,6 +18,7 @@ from importlib import import_module | |||||
| from mindinsight.mindconverter.common.log import logger as log | from mindinsight.mindconverter.common.log import logger as log | ||||
| from .base import GraphParser | from .base import GraphParser | ||||
| from .optimizer import OnnxSimplify | |||||
| from ...common.exceptions import ModelNotSupportError | from ...common.exceptions import ModelNotSupportError | ||||
| @@ -38,7 +39,6 @@ class PyTorchGraphParser(GraphParser): | |||||
| Returns: | Returns: | ||||
| object, torch model. | object, torch model. | ||||
| """ | """ | ||||
| torch = import_module("torch") | |||||
| if not os.path.exists(model_path): | if not os.path.exists(model_path): | ||||
| error = FileNotFoundError("`model_path` must be assigned with " | error = FileNotFoundError("`model_path` must be assigned with " | ||||
| @@ -47,14 +47,66 @@ class PyTorchGraphParser(GraphParser): | |||||
| raise error | raise error | ||||
| try: | try: | ||||
| if torch.cuda.is_available(): | |||||
| model = torch.load(f=model_path) | |||||
| else: | |||||
| model = torch.load(f=model_path, map_location="cpu") | |||||
| onnx_model_sim = cls._convert_pytorch_graph_to_onnx( | |||||
| model_path, kwargs['sample_shape'], opset_version=11) | |||||
| return onnx_model_sim | |||||
| except ModuleNotFoundError: | except ModuleNotFoundError: | ||||
| error_msg = "Cannot find model scripts in system path, " \ | error_msg = "Cannot find model scripts in system path, " \ | ||||
| "set `--project_path` to the path of model scripts folder correctly." | "set `--project_path` to the path of model scripts folder correctly." | ||||
| error = ModuleNotFoundError(error_msg) | error = ModuleNotFoundError(error_msg) | ||||
| raise error | raise error | ||||
| return model | |||||
| @staticmethod | |||||
| def _convert_pytorch_graph_to_onnx(model_path, sample_shape, opset_version=None): | |||||
| """ | |||||
| Convert Pytorch model to ONNX model. | |||||
| Args: | |||||
| model_path (str): Path to the Pytorch model. | |||||
| sample_shape (tuple): Input shape to generate onnx model. | |||||
| opset_version (int): Op set version of onnx. | |||||
| """ | |||||
| torch = import_module('torch') | |||||
| has_cuda = torch.cuda.is_available() | |||||
| if has_cuda: | |||||
| model = torch.load(f=model_path).cuda() | |||||
| dump_input = torch.randn(*sample_shape, device='cuda') | |||||
| else: | |||||
| model = torch.load(f=model_path, map_location="cpu") | |||||
| dump_input = torch.randn(*sample_shape, device='cpu') | |||||
| if isinstance(model, torch.nn.DataParallel): | |||||
| raise ValueError('torch.nn.DataParallel is not supported by ONNX exporter.') | |||||
| torch_onnx = import_module('torch.onnx') | |||||
| operator_export_types = getattr(torch_onnx, 'OperatorExportTypes') | |||||
| utils = import_module('torch.onnx.utils') | |||||
| model_to_graph = getattr(utils, '_model_to_graph') | |||||
| symbolic_helper = import_module('torch.onnx.symbolic_helper') | |||||
| default_onnx_opset_version = getattr(symbolic_helper, '_default_onnx_opset_version') | |||||
| set_opset_version = getattr(symbolic_helper, '_set_opset_version') | |||||
| set_operator_export_type = getattr(symbolic_helper, '_set_operator_export_type') | |||||
| if not opset_version: | |||||
| opset_version = default_onnx_opset_version | |||||
| operator_export_type = operator_export_types.ONNX | |||||
| set_opset_version(opset_version) | |||||
| set_operator_export_type(operator_export_type) | |||||
| graph, params_dict, _ = model_to_graph(model, dump_input, _retain_param_name=True) | |||||
| export_onnx = getattr(graph, '_export_onnx') | |||||
| proto, _ = export_onnx( | |||||
| params_dict, opset_version, dict(), False, | |||||
| operator_export_type, True, False, dict(), | |||||
| True, False) | |||||
| onnx = import_module('onnx') | |||||
| onnx_model = onnx.load_model_from_string(proto) | |||||
| onnx_simplify = OnnxSimplify() | |||||
| onnx_model_sim = onnx_simplify.run_onnx_simplify(onnx_model, sample_shape) | |||||
| return onnx_model_sim | |||||
| @@ -1,105 +0,0 @@ | |||||
| # Copyright 2020 Huawei Technologies Co., Ltd.All Rights Reserved. | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================== | |||||
| """Define pytorch tracer context manager.""" | |||||
| import importlib | |||||
| from torch.nn import Module | |||||
| from torch.onnx.utils import _trace | |||||
| from torch.onnx.utils import _node_getitem | |||||
| from torch.onnx.symbolic_helper import _set_opset_version | |||||
| SCRIPT_METHOD = getattr(importlib.import_module("torch._C"), | |||||
| "ScriptMethod") | |||||
| onnx_tracer = _trace | |||||
| getitem_of_node = _node_getitem | |||||
| set_opset_version = _set_opset_version | |||||
| def unique_state_dict(model): | |||||
| """ | |||||
| Wrapper of torch.jit._unique_state_dict. | |||||
| Args: | |||||
| model (Module): Torch model. | |||||
| Returns: | |||||
| dict, params. | |||||
| """ | |||||
| from torch.jit import _unique_state_dict | |||||
| return _unique_state_dict(model) | |||||
| def create_autograd_variable(tensor): | |||||
| """ | |||||
| Create autograd variable to trace the whole graph. | |||||
| Args: | |||||
| tensor (torch.Tensor): Tensor. | |||||
| Returns: | |||||
| torch.autograd.Variable, variable. | |||||
| """ | |||||
| variable = getattr(importlib.import_module("torch.autograd"), "Variable") | |||||
| return variable(tensor, requires_grad=False) | |||||
| class OverloadTorchModuleTemporarily: | |||||
| """ | |||||
| Fix bugs in new version of pytorch. | |||||
| PyTorch official solution. | |||||
| """ | |||||
| def __init__(self): | |||||
| self.backup = None | |||||
| def __enter__(self): | |||||
| def _tracing_name(traced_module, tracing_state): | |||||
| traced_module_stack = getattr(tracing_state, "_traced_module_stack") | |||||
| if not traced_module_stack: | |||||
| return None | |||||
| module = traced_module_stack[-1] | |||||
| for name, child in module.named_children(): | |||||
| if child is traced_module: | |||||
| return name | |||||
| return None | |||||
| def _slow_forward(self_, *inputs, **kwargs): | |||||
| tracing_state = getattr(importlib.import_module("torch._C"), | |||||
| "_get_tracing_state")() | |||||
| if not tracing_state or isinstance(self_.forward, SCRIPT_METHOD): | |||||
| return self_.forward(*inputs, **kwargs) | |||||
| if not hasattr(tracing_state, '_traced_module_stack'): | |||||
| tracing_state._traced_module_stack = [] | |||||
| name = _tracing_name(self_, tracing_state) | |||||
| get_name_func = getattr(self_, "_get_name") | |||||
| if name: | |||||
| tracing_state.push_scope('%s[%s]' % (get_name_func(), name)) | |||||
| else: | |||||
| tracing_state.push_scope(get_name_func()) | |||||
| tracing_state._traced_module_stack.append(self_) | |||||
| try: | |||||
| result = self_.forward(*inputs, **kwargs) | |||||
| finally: | |||||
| tracing_state.pop_scope() | |||||
| tracing_state._traced_module_stack.pop() | |||||
| return result | |||||
| self.backup = getattr(Module, "_slow_forward") | |||||
| setattr(Module, '_slow_forward', _slow_forward) | |||||
| def __exit__(self, exc_type, exc_val, exc_tb): | |||||
| setattr(Module, '_slow_forward', self.backup) | |||||
| @@ -1,4 +1,4 @@ | |||||
| # Copyright 2020 Huawei Technologies Co., Ltd.All Rights Reserved. | |||||
| # Copyright 2020-2021 Huawei Technologies Co., Ltd.All Rights Reserved. | |||||
| # | # | ||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | # Licensed under the Apache License, Version 2.0 (the "License"); | ||||
| # you may not use this file except in compliance with the License. | # you may not use this file except in compliance with the License. | ||||
| @@ -14,7 +14,7 @@ | |||||
| # ============================================================================== | # ============================================================================== | ||||
| """Test name manager module.""" | """Test name manager module.""" | ||||
| from unittest import TestCase | from unittest import TestCase | ||||
| from mindinsight.mindconverter.graph_based_converter.hierarchical_tree.name_mgr import NameMgr, GlobalVarNameMgr, \ | |||||
| from mindinsight.mindconverter.graph_based_converter.common.name_mgr import NameMgr, GlobalVarNameMgr, \ | |||||
| global_op_namespace | global_op_namespace | ||||
| @@ -1,15 +0,0 @@ | |||||
| # Copyright 2020 Huawei Technologies Co., Ltd.All Rights Reserved. | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================== | |||||
| """Unit test for mindconvert.graph_based_converter.hierarchical_tree interface.""" | |||||
| @@ -1,177 +0,0 @@ | |||||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================ | |||||
| """Test hierarchical tree module.""" | |||||
| import os | |||||
| import shutil | |||||
| from unittest import mock | |||||
| import pytest | |||||
| from mindinsight.mindconverter.graph_based_converter.hierarchical_tree.hierarchical_tree import HierarchicalTree | |||||
| from mindinsight.mindconverter.graph_based_converter.third_party_graph.pytorch_graph_node import PyTorchGraphNode | |||||
| from mindinsight.mindconverter.graph_based_converter.mapper.base import ONNXToMindSporeMapper | |||||
| from mindinsight.mindconverter.graph_based_converter.constant import NodeType | |||||
| from tests.ut.mindconverter.graph_based_converter.conftest import TEST_BASE_PATH | |||||
| class TestHierarchicalTree: | |||||
| """Test the class of HierarchicalTree.""" | |||||
| def test_tree_identifier(self): | |||||
| """Test tree_identifier""" | |||||
| tree = HierarchicalTree() | |||||
| assert isinstance(tree.tree_identifier, str) | |||||
| @mock.patch( | |||||
| '.'.join((TEST_BASE_PATH, 'third_party_graph.pytorch_graph_node.PyTorchGraphNode._get_raw_params'))) | |||||
| def test_insert(self, get_raw_params): | |||||
| """Test insert""" | |||||
| get_raw_params.return_value = [] | |||||
| tree = HierarchicalTree() | |||||
| pt_node = PyTorchGraphNode() | |||||
| tree.insert(pt_node, 'ResNet') | |||||
| assert tree.root == 'ResNet' | |||||
| def test_remove(self): | |||||
| """Test remove function.""" | |||||
| tree = HierarchicalTree() | |||||
| tree.create_node( | |||||
| tag='node_root', | |||||
| identifier='root', | |||||
| parent=None, | |||||
| data=None | |||||
| ) | |||||
| node = tree.get_node('root') | |||||
| tree.remove(node) | |||||
| assert tree.root is None | |||||
| @mock.patch( | |||||
| '.'.join((TEST_BASE_PATH, 'third_party_graph.pytorch_graph_node.PyTorchGraphNode._get_raw_params'))) | |||||
| def test_shrink(self, get_raw_params): | |||||
| """Test shrink function.""" | |||||
| params = {'root': {}, | |||||
| 'root/child0': {}, | |||||
| 'root/child0/child1': {}} | |||||
| tree = self._create_tree(get_raw_params=get_raw_params, params=params) | |||||
| node = tree.get_node('root/child0') | |||||
| tree.shrink(node) | |||||
| assert tree.leaves()[0].tag == 'child1' | |||||
| @pytest.mark.parametrize('params', [{ | |||||
| 'tree_params': {'root': {'op_name': 'Root', | |||||
| 'precursor_nodes': [], | |||||
| 'successor_nodes': ['root/relu'], | |||||
| 'node_type': NodeType.MODULE.value, | |||||
| 'input_shape': [1, 3, 224, 224], | |||||
| 'output_shape': [1, 1, 224, 224]}, | |||||
| 'root/relu': {'op_name': 'onnx::Relu', | |||||
| 'precursor_nodes': ['root'], | |||||
| 'successor_nodes': ['root/unknown'], | |||||
| 'node_type': NodeType.OPERATION.value, | |||||
| 'input_shape': [1, 3, 224, 224], | |||||
| 'output_shape': [1, 3, 224, 224]}, | |||||
| 'root/unknown': {'op_name': 'onnx::Unknown', | |||||
| 'precursor_nodes': ['root/relu'], | |||||
| 'successor_nodes': [], | |||||
| 'node_type': NodeType.OPERATION.value, | |||||
| 'input_shape': [1, 3, 224, 224], | |||||
| 'output_shape': [1, 1, 224, 224]}}, | |||||
| 'report_dir': 'report_folder' | |||||
| }, { | |||||
| 'tree_params': {'root': {'op_name': 'Root', | |||||
| 'precursor_nodes': [], | |||||
| 'successor_nodes': ['root/relu'], | |||||
| 'node_type': NodeType.MODULE.value, | |||||
| 'input_shape': [1, 3, 224, 224], | |||||
| 'output_shape': [1, 1, 224, 224]}, | |||||
| 'root/relu': {'op_name': 'onnx::Relu', | |||||
| 'precursor_nodes': ['root'], | |||||
| 'successor_nodes': ['root/unknown'], | |||||
| 'node_type': NodeType.OPERATION.value, | |||||
| 'input_shape': [1, 3, 224, 224], | |||||
| 'output_shape': [1, 3, 224, 224]}, | |||||
| 'root/unknown': {'op_name': 'onnx::Unknown', | |||||
| 'precursor_nodes': ['root/relu'], | |||||
| 'successor_nodes': [], | |||||
| 'node_type': NodeType.OPERATION.value, | |||||
| 'input_shape': [1, 3, 224, 224], | |||||
| 'output_shape': [1, 1, 224, 224]}}, | |||||
| 'report_dir': None | |||||
| }]) | |||||
| @mock.patch( | |||||
| '.'.join((TEST_BASE_PATH, 'third_party_graph.pytorch_graph_node.PyTorchGraphNode._get_raw_params'))) | |||||
| def test_save_source_file(self, get_raw_params, params): | |||||
| """Test save_source_file function.""" | |||||
| tree_params = params['tree_params'] | |||||
| out_folder = 'out_folder' | |||||
| report_folder = params['report_dir'] | |||||
| model_name = 'model_name' | |||||
| mapper = ONNXToMindSporeMapper() | |||||
| tree = self._create_tree(get_raw_params=get_raw_params, params=tree_params) | |||||
| tree.save_source_files(out_folder, mapper, model_name, report_folder) | |||||
| out_path = os.path.realpath(os.path.join(out_folder, f"{model_name}.py")) | |||||
| report_folder_test = report_folder if report_folder else out_folder | |||||
| report_path = os.path.realpath( | |||||
| os.path.join(report_folder_test, f"report_of_{model_name}.txt")) | |||||
| try: | |||||
| assert os.path.exists(out_path) | |||||
| assert os.path.exists(report_path) | |||||
| with open(out_path, 'r') as out_r: | |||||
| code = out_r.read() | |||||
| assert 'nn.ReLU' in code | |||||
| assert 'onnx.Unknown' in code | |||||
| with open(report_path, 'r') as report_r: | |||||
| report = report_r.read() | |||||
| assert "[UnConvert] 'onnx::Unknown' didn't convert." in report | |||||
| assert "Converted Rate: 50.00%." in report | |||||
| finally: | |||||
| shutil.rmtree(out_folder) | |||||
| if report_folder: | |||||
| shutil.rmtree(report_folder) | |||||
| @staticmethod | |||||
| def _create_node(key, val, weight, input_shape, output_shape): | |||||
| """Create node.""" | |||||
| node = PyTorchGraphNode(weight=weight) | |||||
| node.add_input_and_output_shape(input_shape, output_shape) | |||||
| node.tag = key.split('/')[-1] if len(key.split('/')) > 1 else key | |||||
| node.op_name = val['op_name'] if val.get('op_name') else None | |||||
| node.precursor_nodes = val['precursor_nodes'] if val.get('precursor_nodes') else [] | |||||
| node.successor_nodes = val['successor_nodes'] if val.get('successor_nodes') else [] | |||||
| node.node_type = val['node_type'] if val.get('node_type') else None | |||||
| return node | |||||
| @staticmethod | |||||
| def _create_tree(get_raw_params, params): | |||||
| """Create tree.""" | |||||
| tree = HierarchicalTree() | |||||
| for key, val in params.items(): | |||||
| input_shape = val['input_shape'] if val.get('input_shape') else [] | |||||
| output_shape = val['output_shape'] if val.get('output_shape') else [] | |||||
| get_raw_params.return_value = val['op_params'] if val.get('op_params') else dict() | |||||
| weight = val['weight'] if val.get('weight') else None | |||||
| node = TestHierarchicalTree._create_node(key, val, weight, input_shape, output_shape) | |||||
| tree.create_node( | |||||
| tag=node.tag, | |||||
| identifier=key, | |||||
| parent='/'.join(key.split('/')[:-1]) if len(key.split('/')) > 1 else None, | |||||
| data=node | |||||
| ) | |||||
| return tree | |||||
| @@ -1,4 +1,4 @@ | |||||
| # Copyright 2020 Huawei Technologies Co., Ltd.All Rights Reserved. | |||||
| # Copyright 2020-2021 Huawei Technologies Co., Ltd.All Rights Reserved. | |||||
| # | # | ||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | # Licensed under the Apache License, Version 2.0 (the "License"); | ||||
| # you may not use this file except in compliance with the License. | # you may not use this file except in compliance with the License. | ||||
| @@ -12,4 +12,4 @@ | |||||
| # See the License for the specific language governing permissions and | # See the License for the specific language governing permissions and | ||||
| # limitations under the License. | # limitations under the License. | ||||
| # ============================================================================== | # ============================================================================== | ||||
| """Unit test for mindconvert.graph_based_converter.mapper interface.""" | |||||
| """Unit test for mindconverter.graph_based_converter.mapper interface.""" | |||||
| @@ -1,4 +1,4 @@ | |||||
| # Copyright 2020 Huawei Technologies Co., Ltd.All Rights Reserved. | |||||
| # Copyright 2020-2021 Huawei Technologies Co., Ltd.All Rights Reserved. | |||||
| # | # | ||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | # Licensed under the Apache License, Version 2.0 (the "License"); | ||||
| # you may not use this file except in compliance with the License. | # you may not use this file except in compliance with the License. | ||||
| @@ -18,7 +18,6 @@ import pytest | |||||
| from mindinsight.mindconverter.graph_based_converter.mapper.base import ONNXToMindSporeMapper | from mindinsight.mindconverter.graph_based_converter.mapper.base import ONNXToMindSporeMapper | ||||
| from mindinsight.mindconverter.graph_based_converter.mapper.gen_setting import Setting | from mindinsight.mindconverter.graph_based_converter.mapper.gen_setting import Setting | ||||
| from tests.utils import mindspore | |||||
| class TestMappers: | class TestMappers: | ||||
| @@ -30,7 +29,7 @@ class TestMappers: | |||||
| 'group': 1, | 'group': 1, | ||||
| 'pads': [1, 2, 3, 4], | 'pads': [1, 2, 3, 4], | ||||
| 'strides': [1, 1]}, | 'strides': [1, 1]}, | ||||
| 'weights': {'weight': mindspore.Tensor(np.zeros([64, 3, 1, 1], dtype=np.int32))}}, | |||||
| 'weights': {'weight': np.zeros((64, 3, 1, 1), dtype=np.int32)}}, | |||||
| 'expected_output': {'converter_name': 'nn.Conv2d', | 'expected_output': {'converter_name': 'nn.Conv2d', | ||||
| 'converted_params': {'in_channels': 3, | 'converted_params': {'in_channels': 3, | ||||
| 'out_channels': 64, | 'out_channels': 64, | ||||
| @@ -47,7 +46,7 @@ class TestMappers: | |||||
| 'group': 1, | 'group': 1, | ||||
| 'pads': [0, 0, 0, 0], | 'pads': [0, 0, 0, 0], | ||||
| 'strides': [1, 1]}, | 'strides': [1, 1]}, | ||||
| 'weights': {'weight': mindspore.Tensor(np.zeros([64, 3, 2, 2], dtype=np.int32))}}, | |||||
| 'weights': {'weight': np.zeros((64, 3, 2, 2), dtype=np.int32)}}, | |||||
| 'expected_output': {'converter_name': 'nn.Conv2d', | 'expected_output': {'converter_name': 'nn.Conv2d', | ||||
| 'converted_params': {'in_channels': 3, | 'converted_params': {'in_channels': 3, | ||||
| 'out_channels': 64, | 'out_channels': 64, | ||||
| @@ -61,8 +60,8 @@ class TestMappers: | |||||
| }, { | }, { | ||||
| 'input': {'op_name': 'onnx::Gemm', | 'input': {'op_name': 'onnx::Gemm', | ||||
| 'params': dict(), | 'params': dict(), | ||||
| 'weights': {'weight': mindspore.Tensor(np.zeros([10, 3], dtype=np.int32)), | |||||
| 'bias': mindspore.Tensor(np.zeros([10, 1], dtype=np.int32))}}, | |||||
| 'weights': {'weight': np.zeros((10, 3), dtype=np.int32), | |||||
| 'bias': np.zeros((10, 1), dtype=np.int32)}}, | |||||
| 'expected_output': {'converter_name': 'nn.Dense', | 'expected_output': {'converter_name': 'nn.Dense', | ||||
| 'converted_params': {'in_channels': 3, | 'converted_params': {'in_channels': 3, | ||||
| 'out_channels': 10, | 'out_channels': 10, | ||||