From 2ce2c82bc91d8fc6cc6d1517730accdb889d1763 Mon Sep 17 00:00:00 2001 From: moran Date: Sat, 26 Dec 2020 17:11:30 +0800 Subject: [PATCH] Optimize pytorchTOms module --- .../{hierarchical_tree => common}/name_mgr.py | 2 +- .../graph_based_converter/common/utils.py | 3 + .../graph_based_converter/framework.py | 45 +- .../generator/args_translator.py | 7 +- .../generator/generator.py | 4 +- .../generator/module_struct.py | 4 +- .../generator/node_struct.py | 14 +- .../hierarchical_tree/__init__.py | 89 -- .../hierarchical_tree/hierarchical_tree.py | 796 ------------------ .../graph_based_converter/mapper/base.py | 10 + .../mapper/impl/nn/conv_mapper.py | 3 +- .../mapper/impl/nn/dense_mapper.py | 8 +- .../mapper/impl/nn/pad_mapper.py | 3 +- .../mapper/impl/nn/pool_mapper.py | 2 +- .../third_party_graph/__init__.py | 19 +- .../third_party_graph/base.py | 2 +- .../third_party_graph/onnx_graph.py | 23 +- .../third_party_graph/onnx_utils.py | 25 +- .../third_party_graph/optimizer.py | 144 ++++ .../third_party_graph/pytorch_graph.py | 691 --------------- .../third_party_graph/pytorch_graph_node.py | 236 ------ .../third_party_graph/pytorch_graph_parser.py | 66 +- .../third_party_graph/torch_utils.py | 105 --- .../test_name_mgr.py | 4 +- .../hierarchical_tree/__init__.py | 15 - .../test_hierarchical_tree.py | 177 ---- .../graph_based_converter/mapper/__init__.py | 4 +- .../mapper/test_mapper.py | 11 +- 28 files changed, 312 insertions(+), 2200 deletions(-) rename mindinsight/mindconverter/graph_based_converter/{hierarchical_tree => common}/name_mgr.py (98%) delete mode 100644 mindinsight/mindconverter/graph_based_converter/hierarchical_tree/__init__.py delete mode 100644 mindinsight/mindconverter/graph_based_converter/hierarchical_tree/hierarchical_tree.py create mode 100644 mindinsight/mindconverter/graph_based_converter/third_party_graph/optimizer.py delete mode 100644 mindinsight/mindconverter/graph_based_converter/third_party_graph/pytorch_graph.py delete mode 100644 mindinsight/mindconverter/graph_based_converter/third_party_graph/pytorch_graph_node.py delete mode 100644 mindinsight/mindconverter/graph_based_converter/third_party_graph/torch_utils.py rename tests/ut/mindconverter/graph_based_converter/{hierarchical_tree => common}/test_name_mgr.py (89%) delete mode 100644 tests/ut/mindconverter/graph_based_converter/hierarchical_tree/__init__.py delete mode 100644 tests/ut/mindconverter/graph_based_converter/hierarchical_tree/test_hierarchical_tree.py diff --git a/mindinsight/mindconverter/graph_based_converter/hierarchical_tree/name_mgr.py b/mindinsight/mindconverter/graph_based_converter/common/name_mgr.py similarity index 98% rename from mindinsight/mindconverter/graph_based_converter/hierarchical_tree/name_mgr.py rename to mindinsight/mindconverter/graph_based_converter/common/name_mgr.py index f43e7a0a..7546674b 100644 --- a/mindinsight/mindconverter/graph_based_converter/hierarchical_tree/name_mgr.py +++ b/mindinsight/mindconverter/graph_based_converter/common/name_mgr.py @@ -1,4 +1,4 @@ -# Copyright 2020 Huawei Technologies Co., Ltd.All Rights Reserved. +# Copyright 2020-2021 Huawei Technologies Co., Ltd.All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/mindinsight/mindconverter/graph_based_converter/common/utils.py b/mindinsight/mindconverter/graph_based_converter/common/utils.py index b9e86f6e..62514e7f 100644 --- a/mindinsight/mindconverter/graph_based_converter/common/utils.py +++ b/mindinsight/mindconverter/graph_based_converter/common/utils.py @@ -194,6 +194,9 @@ def convert_bytes_string_to_string(bytes_str): def get_framework_type(model_path): """Get framework type.""" + if model_path.endswith('.onnx'): + return FrameworkType.PYTORCH.value + try: with open(model_path, 'rb') as f: if f.read(BINARY_HEADER_PYTORCH_BITS) == BINARY_HEADER_PYTORCH_FILE: diff --git a/mindinsight/mindconverter/graph_based_converter/framework.py b/mindinsight/mindconverter/graph_based_converter/framework.py index 31b70d26..882b825d 100644 --- a/mindinsight/mindconverter/graph_based_converter/framework.py +++ b/mindinsight/mindconverter/graph_based_converter/framework.py @@ -1,4 +1,4 @@ -# Copyright 2020 Huawei Technologies Co., Ltd.All Rights Reserved. +# Copyright 2020-2021 Huawei Technologies Co., Ltd.All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -24,10 +24,12 @@ from mindinsight.mindconverter.graph_based_converter.common.utils import lib_ver save_code_file_and_report, get_framework_type from mindinsight.mindconverter.graph_based_converter.constant import FrameworkType, \ ONNX_MIN_VER, TF2ONNX_MIN_VER, ONNXRUNTIME_MIN_VER +from mindinsight.mindconverter.graph_based_converter.generator import batch_add_nodes from mindinsight.mindconverter.graph_based_converter.mapper import ONNXToMindSporeMapper from mindinsight.mindconverter.common.log import logger as log, logger_console as log_console from mindinsight.mindconverter.common.exceptions import GraphInitError, TreeCreationError, SourceFilesSaveError, \ BaseConverterError, UnknownModelError, GeneratorError, TfRuntimeError, RuntimeIntegrityError, ParamMissingError +from mindinsight.mindconverter.graph_based_converter.third_party_graph import GraphFactory permissions = os.R_OK | os.W_OK | os.X_OK os.umask(permissions << 3 | permissions) @@ -62,6 +64,7 @@ def torch_installation_validation(func): """ def _f(graph_path: str, sample_shape: tuple, + input_nodes: str, output_nodes: str, output_folder: str, report_folder: str = None): # Check whether pytorch is installed. if not find_spec("torch") or not find_spec("onnx") or not find_spec("onnxruntime"): @@ -93,6 +96,7 @@ def torch_installation_validation(func): sys.exit(0) func(graph_path=graph_path, sample_shape=sample_shape, + input_nodes=input_nodes, output_nodes=output_nodes, output_folder=output_folder, report_folder=report_folder) return _f @@ -182,6 +186,7 @@ def _extract_model_name(model_path): @SourceFilesSaveError.uniform_catcher() @GeneratorError.uniform_catcher() def graph_based_converter_pytorch_to_ms(graph_path: str, sample_shape: tuple, + input_nodes: str, output_nodes: str, output_folder: str, report_folder: str = None): """ PyTorch to MindSpore based on Graph. @@ -189,26 +194,18 @@ def graph_based_converter_pytorch_to_ms(graph_path: str, sample_shape: tuple, Args: graph_path (str): Graph file path. sample_shape (tuple): Input shape of the model. + input_nodes (str): Input node(s) of the model. + output_nodes (str): Output node(s) of the model. output_folder (str): Output folder. report_folder (str): Report output folder path. - """ - third_party_graph_module = import_module( - 'mindinsight.mindconverter.graph_based_converter.third_party_graph') - hierarchical_tree_module = import_module( - 'mindinsight.mindconverter.graph_based_converter.hierarchical_tree') - cls_graph_factory = getattr(third_party_graph_module, 'GraphFactory') - cls_hierarchical_tree_factory = getattr(hierarchical_tree_module, 'HierarchicalTreeFactory') - - graph_obj = cls_graph_factory.init(graph_path, sample_shape=sample_shape) - - hierarchical_tree = cls_hierarchical_tree_factory.create(graph_obj) + graph_obj = GraphFactory.init(graph_path, sample_shape=sample_shape, + input_nodes=input_nodes, output_nodes=output_nodes) + generator_inst = batch_add_nodes(graph_obj, ONNXToMindSporeMapper) model_name = _extract_model_name(graph_path) - - hierarchical_tree.save_source_files(output_folder, mapper=ONNXToMindSporeMapper, - model_name=model_name, - report_folder=report_folder) + code_fragments = generator_inst.generate() + save_code_file_and_report(model_name, code_fragments, output_folder, report_folder) @tf_installation_validation @@ -230,18 +227,13 @@ def graph_based_converter_tf_to_ms(graph_path: str, sample_shape: tuple, output_nodes(str): Output node(s) of the model. output_folder(str): Output folder. report_folder(str): Report output folder path. - """ - third_party_graph_module = import_module( - 'mindinsight.mindconverter.graph_based_converter.third_party_graph') - cls_graph_factory = getattr(third_party_graph_module, 'GraphFactory') - batch_add_nodes = getattr(import_module('mindinsight.mindconverter.graph_based_converter.generator'), - "batch_add_nodes") + # Close unnecessary log. os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2' - graph_obj = cls_graph_factory.init(graph_path, sample_shape=sample_shape, - input_nodes=input_nodes, output_nodes=output_nodes) + graph_obj = GraphFactory.init(graph_path, sample_shape=sample_shape, + input_nodes=input_nodes, output_nodes=output_nodes) generator_inst = batch_add_nodes(graph_obj, ONNXToMindSporeMapper) model_name = _extract_model_name(graph_path) code_fragments = generator_inst.generate() @@ -255,7 +247,6 @@ def main_graph_base_converter(file_config): Args: file_config (dict): The config of file which to convert. - """ graph_path = file_config['model_file'] frame_type = get_framework_type(graph_path) @@ -263,8 +254,12 @@ def main_graph_base_converter(file_config): raise ParamMissingError("Param missing, `--shape` is required when using graph mode.") if frame_type == FrameworkType.PYTORCH.value: + check_params = ['input_nodes', 'output_nodes'] + check_params_exist(check_params, file_config) graph_based_converter_pytorch_to_ms(graph_path=graph_path, sample_shape=file_config['shape'], + input_nodes=file_config['input_nodes'], + output_nodes=file_config['output_nodes'], output_folder=file_config['outfile_dir'], report_folder=file_config['report_dir']) elif frame_type == FrameworkType.TENSORFLOW.value: diff --git a/mindinsight/mindconverter/graph_based_converter/generator/args_translator.py b/mindinsight/mindconverter/graph_based_converter/generator/args_translator.py index 082596e0..b744da7f 100644 --- a/mindinsight/mindconverter/graph_based_converter/generator/args_translator.py +++ b/mindinsight/mindconverter/graph_based_converter/generator/args_translator.py @@ -1,4 +1,4 @@ -# Copyright 2020 Huawei Technologies Co., Ltd.All Rights Reserved. +# Copyright 2020-2021 Huawei Technologies Co., Ltd.All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -213,10 +213,11 @@ class ArgsTranslationHelper: Returns: list, name of args to be formal. """ + ret = list() if len(args_translators) < 2: # only one args_translator provided, no formal args. - return None - ret = [] + return ret + base_args_t = args_translators[0] for arg_name, arg_val in base_args_t.actual_args.items(): for args_t in args_translators[1:]: diff --git a/mindinsight/mindconverter/graph_based_converter/generator/generator.py b/mindinsight/mindconverter/graph_based_converter/generator/generator.py index d18449fe..8c28a128 100644 --- a/mindinsight/mindconverter/graph_based_converter/generator/generator.py +++ b/mindinsight/mindconverter/graph_based_converter/generator/generator.py @@ -1,4 +1,4 @@ -# Copyright 2020 Huawei Technologies Co., Ltd.All Rights Reserved. +# Copyright 2020-2021 Huawei Technologies Co., Ltd.All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -24,7 +24,7 @@ from .module_struct import ModuleStruct from .args_translator import ArgsTranslationHelper from ..common.global_context import GlobalContext from ...common.exceptions import GeneratorError -from ..hierarchical_tree.name_mgr import GlobalVarNameMgr +from ..common.name_mgr import GlobalVarNameMgr from ..constant import NEW_LINE, SECOND_LEVEL_INDENT, FIRST_LEVEL_INDENT, CodeFormatConfig, get_imported_module from ..report_generator import ReportGenerator diff --git a/mindinsight/mindconverter/graph_based_converter/generator/module_struct.py b/mindinsight/mindconverter/graph_based_converter/generator/module_struct.py index 03a6dda9..7c96cce9 100644 --- a/mindinsight/mindconverter/graph_based_converter/generator/module_struct.py +++ b/mindinsight/mindconverter/graph_based_converter/generator/module_struct.py @@ -1,4 +1,4 @@ -# Copyright 2020 Huawei Technologies Co., Ltd.All Rights Reserved. +# Copyright 2020-2021 Huawei Technologies Co., Ltd.All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -22,7 +22,7 @@ from ..common.utils import get_dict_key_by_value from .args_translator import ArgsTranslation from ..common.code_fragment import ModuleFragment from ..common.global_context import GlobalContext -from ..hierarchical_tree.name_mgr import LocalVarNameMgr +from ..common.name_mgr import LocalVarNameMgr class ModuleStruct: diff --git a/mindinsight/mindconverter/graph_based_converter/generator/node_struct.py b/mindinsight/mindconverter/graph_based_converter/generator/node_struct.py index 62e70ba1..8c99681a 100644 --- a/mindinsight/mindconverter/graph_based_converter/generator/node_struct.py +++ b/mindinsight/mindconverter/graph_based_converter/generator/node_struct.py @@ -1,4 +1,4 @@ -# Copyright 2020 Huawei Technologies Co., Ltd.All Rights Reserved. +# Copyright 2020-2021 Huawei Technologies Co., Ltd.All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -18,7 +18,6 @@ from collections import OrderedDict from .scope_utils import Scope from .args_translator import ArgsTranslation from ..common.code_fragment import CodeFragment -from ..third_party_graph.pytorch_graph_node import PyTorchGraphNode from ..third_party_graph.onnx_graph_node import OnnxGraphNode from ..common.global_context import GlobalContext from ..constant import InputType @@ -110,11 +109,6 @@ class NodeStruct: self.graph_node_ref = gn self.scope_name = gn.scope_name - def _update_from_pytorch_gn(self, gn: PyTorchGraphNode): - """Update basic info from PyTorchGraphNode.""" - self.node_type = "PyTorchGraphNode" - self._update_basics_from_gn(gn) - def _update_from_onnx_gn(self, gn: OnnxGraphNode): """Update basic info from OnnxGraphNode.""" self.node_type = "OnnxGraphNode" @@ -177,9 +171,8 @@ class NodeStruct: arg (Union[PyTorchGraphNode, OnnxGraphNode, dict]): Node related obj. force_ready (bool): Force this NodeStruct is ready to generate. """ - if isinstance(arg, PyTorchGraphNode): - self._update_from_pytorch_gn(arg) - elif isinstance(arg, OnnxGraphNode): + + if isinstance(arg, OnnxGraphNode): self._update_from_onnx_gn(arg) elif isinstance(arg, (dict, OrderedDict)): self._update_from_mapper(arg) @@ -246,7 +239,6 @@ class NodeStruct: """Return the output variable name of current node.""" return "{}_opt".format(self.ms_var_name).lower() - @property def args_translator(self): """Return the args translator of this Node.""" diff --git a/mindinsight/mindconverter/graph_based_converter/hierarchical_tree/__init__.py b/mindinsight/mindconverter/graph_based_converter/hierarchical_tree/__init__.py deleted file mode 100644 index fb10aed7..00000000 --- a/mindinsight/mindconverter/graph_based_converter/hierarchical_tree/__init__.py +++ /dev/null @@ -1,89 +0,0 @@ -# Copyright 2020 Huawei Technologies Co., Ltd.All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""Hierarchical tree module.""" -__all__ = ["HierarchicalTreeFactory"] - -import re - -from mindinsight.mindconverter.common.log import logger as log -from .hierarchical_tree import HierarchicalTree -from ..third_party_graph.onnx_graph_node import OnnxGraphNode - -from ...common.exceptions import NodeInputMissingError, TreeNodeInsertError - - -def _tf_model_node_name_reformat(node: OnnxGraphNode, node_name): - """ - Rename the node name by combining scope name and its original name. - - Args: - node (OnnxGraphNode): OnnxGraphNode instance. - node_name (str): node name saved in Graph. - - Returns: - str, re-formatted node name. - """ - scope_name = node.scope_name - new_name = None - regex = r"(?P.+/)(?P\w+)" - match = re.match(regex, scope_name) - parent = match.group("parent") - node_name = '$' + node_name.replace('/', '::') + '$' - - if scope_name: - new_name = parent + node_name - if new_name: - return new_name - return node_name - - -class HierarchicalTreeFactory: - """Hierarchical tree factory.""" - - @classmethod - @TreeNodeInsertError.check_except("Tree node inserts failed.") - def create(cls, graph): - """ - Factory method of hierarchical tree. - - Args: - graph: Graph obj. - - Returns: - HierarchicalTree, tree. - """ - tree = HierarchicalTree() - node_scope_name = dict() - for _, node_name in enumerate(graph.nodes_in_topological_order): - node_inst = graph.get_node(node_name) - node_input = graph.get_input_shape(node_name) - node_output = graph.get_output_shape(node_name) - if node_input != 0 and not node_input: - err_msg = f"This model is not supported now. " \ - f"Cannot find {node_name}'s input shape." - error = NodeInputMissingError(err_msg) - log.error(str(error)) - raise error - if isinstance(node_inst, OnnxGraphNode): - node_name_with_scope = _tf_model_node_name_reformat(node_inst, node_name) - node_scope_name[node_name] = node_name_with_scope - node_name = node_name_with_scope - - node_inst.add_input_and_output_shape(node_input, node_output) - tree.insert(node_inst, node_name) - - if node_scope_name: - return tree, node_scope_name - return tree diff --git a/mindinsight/mindconverter/graph_based_converter/hierarchical_tree/hierarchical_tree.py b/mindinsight/mindconverter/graph_based_converter/hierarchical_tree/hierarchical_tree.py deleted file mode 100644 index b7c848a4..00000000 --- a/mindinsight/mindconverter/graph_based_converter/hierarchical_tree/hierarchical_tree.py +++ /dev/null @@ -1,796 +0,0 @@ -# Copyright 2020 Huawei Technologies Co., Ltd.All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""Define hierarchical tree.""" -import os -import stat -from copy import deepcopy -from typing import NoReturn, Union -from queue import Queue - -from yapf.yapflib.yapf_api import FormatCode -from treelib import Tree, Node - -from mindinsight.mindconverter.common.log import logger as log - -from .name_mgr import ModuleNameMgr, GlobalVarNameMgr -from ..common.utils import is_converted, save_code_file_and_report -from ..mapper.base import Mapper -from ..third_party_graph.pytorch_graph_node import PyTorchGraphNode -from ..third_party_graph.onnx_graph_node import OnnxGraphNode -from ..constant import SEPARATOR_IN_SCOPE, get_imported_module, NO_CONVERTED_OPERATORS -from ..constant import CodeFormatConfig -from ..constant import SEPARATOR_BTW_NAME_AND_ID, FIRST_LEVEL_INDENT -from ..constant import NEW_LINE, SECOND_LEVEL_INDENT -from ..constant import NodeType -from ..report_generator import ReportGenerator -from ...common.exceptions import ReportGenerationError, ScriptGenerationError, NodeInputTypeNotSupportError - - -class HierarchicalTree(Tree): - """Define hierarchical tree.""" - flags = os.O_WRONLY | os.O_CREAT | os.O_EXCL - modes = stat.S_IRUSR | stat.S_IWUSR - modes_usr = stat.S_IRUSR | stat.S_IWUSR | stat.S_IXUSR - - _root_created = False - ROOT_LEVEL = 0 - - GLOBAL_VAR_NAME_MGR = GlobalVarNameMgr() - - def __init__(self): - super(HierarchicalTree, self).__init__() - self._hierarchical_order = dict() - # Manage mapping of unique key and module name. - self._merged_module = dict() - # Manage mapping of unique key and module args. - self._merged_module_args = dict() - # Record creation of module with unique key. - self._created_module = dict() - # Manage module name to used. - self._module_mgr = ModuleNameMgr() - # Manage variable name in a module. - self._vars_mgr_in_module = dict() - self._module_vars = dict() - # scope name mapping record for easy node searching - self._scope_name_map = dict() - self.code_fragment_recorder = dict() - - @property - def tree_identifier(self): - """ - Return identifier of tree. - - Returns: - tree, id of tree. - """ - return self.identifier - - def get_node(self, nid): - """Override get_node method to support tf ver. generated scope.""" - if nid is None or not self.contains(nid): - if self._scope_name_map and nid in self._scope_name_map: - nid = self._scope_name_map.get(nid) - else: - return None - return self._nodes[nid] - - def insert(self, node: Union[PyTorchGraphNode, OnnxGraphNode], node_name: str): - """ - Insert node into hierarchical tree. - - Args: - node_name (str): Node name. - node (Union[PyTorchGraphNode, OnnxGraphNode]): Node to be inserted. - - """ - scopes = node_name.split(SEPARATOR_IN_SCOPE) - for idx, scope in enumerate(scopes): - parent = SEPARATOR_IN_SCOPE.join(scopes[:idx]) - identifier = SEPARATOR_IN_SCOPE.join(scopes[:idx + 1]) - try_parent = f"{parent}{SEPARATOR_IN_SCOPE}{scope}" \ - if parent else scope - if self.contains(try_parent): - # Whether current node existed. - parent = try_parent - - if not parent and not self._root_created: - # If no root node, then create it and mark it. - parent = None - self._root_created = True - elif not parent and self._root_created: - # Already have root node, skip it. - continue - - if not self.contains(identifier): - # Insert node into tree. - if isinstance(node, OnnxGraphNode): - tgt_node = node if idx == len( - scopes) - 1 else OnnxGraphNode() - else: - tgt_node = node if idx == len( - scopes) - 1 else PyTorchGraphNode() - tgt_node.successor_nodes = node.successor_nodes - tgt_node.precursor_nodes = node.precursor_nodes - tgt_node.node_type = (NodeType.OPERATION if idx == len(scopes) - 1 - else NodeType.MODULE).value - tgt_node.variable_name = self._get_var_name(identifier) - self.create_node( - tag=scope.split(SEPARATOR_BTW_NAME_AND_ID)[0], - identifier=identifier, - parent=parent, - data=tgt_node - ) - - def remove(self, node: Node, keep_sub=False): - """ - Remove node into hierarchical tree. - - Args: - node (Node): Node to be removed. - keep_sub (bool): Whether keep sub-tree. - - """ - if not keep_sub: - self.remove_node(node.identifier) - return - - def shrink(self, node: Node): - """ - Shrink sub-tree into one node. - - Use child node to replace its ancestor. - - Args: - node (Node): List of nodes to be merged. - - """ - node_name = node.identifier - parent_node = self[node.predecessor(self.tree_identifier)] - # Keep successors of parent. - brothers = deepcopy(parent_node.successors(self.tree_identifier)) - # Because shrink occurs when node has only one child, - # so we take index-0. - child = node.successors(self.tree_identifier)[0] - self.move_node(source=child, - destination=node.predecessor(self.tree_identifier)) - self.remove(node) - brothers[brothers.index(node_name)] = child - parent_node.set_successors(brothers, tree_id=self.tree_identifier) - - def save_source_files(self, out_folder: str, mapper: Mapper, - model_name: str, - report_folder: str = None, - scope_name_map: dict = None) -> NoReturn: - """ - Save source codes to target folder. - - Args: - report_folder (str): Report folder. - mapper (Mapper): Mapper of third party framework and mindspore. - model_name(str): Name of Converted model. - out_folder (str): Output folder. - scope_name_map(str): Scope name map of tensorflow. - - """ - if scope_name_map: - self._scope_name_map = scope_name_map - try: - self._adjust_structure() - code_fragments = self._generate_codes(mapper) - except (NodeInputTypeNotSupportError, ScriptGenerationError, ReportGenerationError) as e: - log.error("Error occur when generating codes.") - raise e - - save_code_file_and_report(model_name, code_fragments, out_folder, report_folder) - - def _preprocess_node_args(self, node, module_key): - """ - Remove unused args. - - Args: - node (Node): Node instance. - module_key (str): Nodule key. - - Returns: - Node, node. - """ - if module_key in self._merged_module_args: - node = self._clear_unused_args( - node, self._merged_module_args[module_key]) - else: - node.data.clear_args_of_declaration() - return node - - def _postprocess_node_args(self, node, precursor_module_key): - """ - Post process args in node. - - Args: - node (Node): Node instance. - precursor_module_key (str): Parent node module name. - - Returns: - Node, node. - """ - if node.data.node_type in {NodeType.MODULE.value, NodeType.CLASS.value, - NodeType.FUNC.value}: - # If current node is class or function, then - # remove unused args in __init__. - cur_module_key = node.data.hash_key or self.hash_key(node) - if cur_module_key in self._merged_module_args: - node = self._clear_unused_args(node, - self._merged_module_args[cur_module_key]) - - # `self._merged_module_args` records formal args. - # We need to replace actual args. - if precursor_module_key in self._merged_module_args: - # If parent node is in `_merged_module_args`, then - # replace current node args with arg name declared - # in _merged_module_args. - for arg in node.data.args_in_code.keys(): - if arg in self._merged_module_args[precursor_module_key]: - node.data.replace_with_arg(arg, arg) - return node - - def _clear_unused_args(self, node, used_args): - """ - Clear unused args. - - Args: - node (Node): Node. - used_args (list): Args list. - - Returns: - Node, node instance. - """ - args_in_code = list(node.data.args_in_code.keys()) - for arg in args_in_code: - ori_arg = arg.replace( - f"_{self.code_fragment_recorder[node.identifier].declared_var_name}", "" - ) - if ori_arg not in used_args: - node.data.args_in_code.pop(arg) - return node - - @ScriptGenerationError.check_except("FormatCode run error. Check detailed information in log.") - @ReportGenerationError.check_except("Not find valid operators in converted script.") - def _generate_codes(self, mapper): - """ - Generate code files. - - - 1. Generate args. - - 2. Merge module. - - 3. Pre-process node args. - - 4. Post-process child node args. - - 5. Generate class/func code. - - 6. Merge code snippets. - - Args: - mapper (Mapper): Mapper of third party operation and mindspore. - - Returns: - Dict, codes. - """ - code_blocks = [get_imported_module()] - depths = sorted(list(self._hierarchical_order.keys()), reverse=True) - - for depth in depths: - node_collection = self._hierarchical_order[depth] - for node_name in node_collection: - # Traverse nodes in topological order. - node = self.get_node(node_name) - # 1. Generate args for each node in this level. - if node.data.node_type == NodeType.MODULE.value: - self._create_module_args_and_vars(node, mapper) - if depth == depths[-1]: - self.code_fragment_recorder[node.identifier] = node.data.param_transform(mapper, "") - - # Module merging based on all nodes. - self._module_merging() - - for depth in depths: - node_collection = self._hierarchical_order[depth] - snippets = set() - for node_name in node_collection: - nd_inst = self.get_node(node_name) - if nd_inst.data.node_type != NodeType.MODULE.value: - continue - - # Generate hash key for node. - module_key = nd_inst.data.hash_key - # Get code generation func. - func, node_type = self._fetch_func_and_type(nd_inst) - - if module_key in self._created_module: - # If the module has already been created, - # then assign the created module name to current node, - # and delete unused args. - module_name = self._created_module[module_key] - self.code_fragment_recorder[nd_inst.identifier].operation = module_name - self.code_fragment_recorder[nd_inst.identifier].node_type = node_type - self._preprocess_node_args(nd_inst, module_key) - continue - - module_name = nd_inst.tag - - if node_type == NodeType.CLASS.value: - module_name = f"{module_name[0].upper()}{module_name[1:]}" - - # After node_type and module_name is frozen, - # then it's unchangeable. - module_name = self._module_mgr.get_name(module_name) - self.code_fragment_recorder[nd_inst.identifier].operation = module_name - self.code_fragment_recorder[nd_inst.identifier].node_type = node_type - - # 3. Pre-process node args. - nd_inst = self._preprocess_node_args(nd_inst, module_key) - # 4. Post-process child node args. - for _, scsr_nd_name in enumerate(nd_inst.successors(self.tree_identifier)): - self._postprocess_node_args(self.get_node(scsr_nd_name), module_key) - # 5. Generate code. - snippets.add(func(nd_inst, self.code_fragment_recorder[nd_inst.identifier].operation, module_key)) - - code_blocks.extend(snippets) - - if self._scope_name_map: # from tf. conversion - c_blocks = [] - for s in code_blocks: - s = s.replace('$', '') - c_blocks.append(s) - code_blocks = c_blocks - - formatted_code, _ = FormatCode("".join(code_blocks), - style_config=CodeFormatConfig.PEP8.value) - - report_generator = ReportGenerator() - report = report_generator.gen_report(formatted_code) - - return {"model": (formatted_code, report)} - - def _fetch_func_and_type(self, node) -> Union[object, str]: - """ - Generate code snippet. - - Args: - node (Node): Node. - - Returns: - Union[object, str], code snippet func. - """ - - def _is_func(): - """ - The correct thought is to check whether have more than one - path in this block. - """ - nonlocal node - - if node.predecessor(self.tree_identifier) is None: - return False - - tgt_type = {NodeType.MODULE.value, - NodeType.FUNC.value, NodeType.CLASS.value} - md_type_lst = [self.get_node(child).data.node_type - for child in node.successors(self.tree_identifier)] - diff_set = set(md_type_lst) - tgt_type - return not diff_set - - if _is_func(): - return self._generate_func_snippet, NodeType.FUNC.value - return self._generate_class_snippet, NodeType.CLASS.value - - def _generate_func_snippet(self, node, func_name, func_key): - """ - Generate function snippet. - - Args: - node (Node): Node inst. - - Returns: - str, code snippet. - """ - definition = "" - - if func_key.lower() in self._merged_module_args and \ - self._merged_module_args[func_key.lower()]: - definition = ", ".join(self._merged_module_args[func_key.lower()]) - - module_list = [] - for node_name in node.successors(self.tree_identifier): - c_nd = self.get_node(node_name) - operator = self.code_fragment_recorder[c_nd.identifier].operation - - if c_nd.data.node_type != NodeType.OPERATION.value: - hash_key = c_nd.data.hash_key or self.hash_key(c_nd) - if hash_key in self._created_module: - operator = self._created_module[hash_key] - - args = c_nd.data.args_in_code - if c_nd.data.node_type == NodeType.OPERATION.value and not is_converted( - self.code_fragment_recorder[c_nd.identifier].operation): - args.update({"input_shape": c_nd.data.input_shape, - "output_shape": c_nd.data.output_shape}) - - # Generate code statement. - expr = ", ".join( - [f"{k.replace(f'_{self.code_fragment_recorder[c_nd.identifier].declared_var_name}', '')}={v}" - for k, v in args.items()] - ) - code_line = f"{operator}({expr})" - module_list.append(code_line) - - body = f",{NEW_LINE}{SECOND_LEVEL_INDENT}".join(module_list) - snippet = f"{FIRST_LEVEL_INDENT}module_list = [{NEW_LINE}" \ - f"{SECOND_LEVEL_INDENT}{body}{NEW_LINE}" \ - f"{FIRST_LEVEL_INDENT}]{NEW_LINE}" \ - f"{FIRST_LEVEL_INDENT}return nn.SequentialCell(*module_list)" - definition = f"def {func_name}({definition}):{NEW_LINE}" - - # Mark the structure has been created. - self._created_module[func_key.lower()] = func_name - - return f"{definition}{snippet}{NEW_LINE * 3}" - - def _generate_class_snippet(self, node, class_name, class_key): - """ - Generate class-type code snippet. - - Args: - node (Node): Node. - - Returns: - str, code snippet. - """ - super_call = f"super({class_name}, self).__init__()" - - if class_key.lower() in self._merged_module_args and \ - self._merged_module_args[class_key.lower()]: - args = f"{', '.join(self._merged_module_args[class_key.lower()])}" - - class_init = f"{FIRST_LEVEL_INDENT}def __init__(self, " \ - f"{args}):" \ - f"{NEW_LINE}{SECOND_LEVEL_INDENT}" \ - f"{super_call}{NEW_LINE}{SECOND_LEVEL_INDENT}" - else: - class_init = f"{FIRST_LEVEL_INDENT}def __init__(self):{NEW_LINE}{SECOND_LEVEL_INDENT}" \ - f"{super_call}{NEW_LINE}{SECOND_LEVEL_INDENT}" - - init_block = [] - construct_block = [] - - for idx, node_name in enumerate(node.successors(self.tree_identifier)): - nd_inst = self.get_node(node_name) - if nd_inst.data.op_name in NO_CONVERTED_OPERATORS: - continue - - # Generate code statement. - init, construct = self._generate_stat(nd_inst, node, idx) - - # support multiple construct and init block returns: - if isinstance(construct, list): - construct_block += construct - else: - construct_block.append(construct) - - if isinstance(init, list): - init_block += init - else: - init_block.append(init) - - class_construct = f"{NEW_LINE}{FIRST_LEVEL_INDENT}def construct(self, x):" \ - f"{NEW_LINE}{SECOND_LEVEL_INDENT}" - init_body = f"{NEW_LINE}{SECOND_LEVEL_INDENT}".join(init_block) - csrt_body = f"{NEW_LINE}{SECOND_LEVEL_INDENT}".join(construct_block) - csrt_rtn = f"{NEW_LINE}{SECOND_LEVEL_INDENT}return output{NEW_LINE}" - - cls_definition = f"class {class_name}(nn.Cell):{NEW_LINE * 2}" - - # Mark the structure has been created. - self._created_module[class_key.lower()] = class_name - - return f"{cls_definition}" \ - f"{class_init}" \ - f"{init_body}{NEW_LINE}" \ - f"{class_construct}" \ - f"{csrt_body}{csrt_rtn}{NEW_LINE * 2}" - - def _generate_stat(self, cur_nd_inst, pre_nd_inst, idx): - """ - Generate statements. - - Args: - cur_nd_inst (Node): Current node instance. - pre_nd_inst (Node): Precursor node instance. - idx (int): Index of cur node. - - Returns: - Tuple[str, str], declare in init and call in construct. - """ - - ipt_args_in_construct = "x" - opt_arg_in_construct = ["output"] - - if idx != 0: - if cur_nd_inst.data.is_in_multi_opt_graph: - ipt_args_in_construct = self._get_current_ipt_var(cur_nd_inst) - else: - # Get previous node output variable name. - ipt_args_in_construct = self._get_previous_opt_var(cur_nd_inst, pre_nd_inst) - if idx != len(pre_nd_inst.successors(self.tree_identifier)) - 1: - # Set opt variable name. - if cur_nd_inst.data.node_type == NodeType.MODULE.value or not cur_nd_inst.data.is_in_multi_opt_graph: - opt_arg_in_construct = [ - f"{self.code_fragment_recorder[cur_nd_inst.identifier].declared_var_name}_opt" - ] - else: - opt_arg_in_construct = [ - f"opt_{var_name}" - for var_name in self.code_fragment_recorder[cur_nd_inst.identifier].output_var_name - ] - - declare, call = cur_nd_inst.data.to_code(ipt_args_in_construct=ipt_args_in_construct, - variable_name=self.code_fragment_recorder[ - cur_nd_inst.identifier].declared_var_name, - output_var=opt_arg_in_construct, - code_fragment=self.code_fragment_recorder[cur_nd_inst.identifier]) - - return declare, call - - @staticmethod - def _get_var_name(s): - """ - Get variable name using scope name. - - Args: - s (str): String. - - Returns: - str, variable name. - """ - return s.split(SEPARATOR_IN_SCOPE)[-1].lower().split(SEPARATOR_BTW_NAME_AND_ID)[0] - - def _get_current_ipt_var(self, cur_nd): - """" - Get current input variable name from node_id. - - Args: - cur_nd (Node): Current node. - - Returns: - str, needed var names. - """ - if cur_nd.data.node_type != NodeType.OPERATION.value: - while True: - p_nd = cur_nd.successors(self.tree_identifier) - if not p_nd: - break - cur_nd = self.get_node(p_nd[0]) - - ipt_lst_raw = [] - for operation_input in self.code_fragment_recorder[cur_nd.identifier].operation_inputs: - ipt_lst_raw.append(f"{operation_input}") - - opt_var_names_p_nds = set() - for e in cur_nd.data.precursor_nodes: - p_nd = self.get_node(e) - if p_nd.data.op_name in NO_CONVERTED_OPERATORS: - continue - - opt_var_names_p_nd = set(p_nd.data.opt_var_names) - opt_var_names_p_nds = set.union(opt_var_names_p_nds, opt_var_names_p_nd) - - ipt_lst = [f"opt_{ipt}" for ipt in set(ipt_lst_raw).intersection(opt_var_names_p_nds)] - return ", ".join(ipt_lst) - - def _find_all_previous_opt_var_(self, cur_nd, pre_nd): - """ - Find all input variable names. - - Args: - cur_nd (Node): Current node. - pre_nd (Node): Precursor node. - - Returns: - list, needed var names list. - """ - ipt_lst = [] - if cur_nd.tag in NO_CONVERTED_OPERATORS: - return ipt_lst - - for e in cur_nd.data.precursor_nodes: - p_nd = self.get_node(e) - if e not in pre_nd.successors(self.tree_identifier): - while True: - if p_nd.identifier in pre_nd.successors(self.tree_identifier): - ipt_lst.append( - f"{self.code_fragment_recorder[p_nd.identifier].declared_var_name}_opt" - ) - break - pre_nd_name = p_nd.predecessor(self.tree_identifier) - if not pre_nd_name: - ipt_lst.append("x") - break - p_nd = self.get_node(pre_nd_name) - continue - ipt_lst.append( - f"{self.code_fragment_recorder[p_nd.identifier].declared_var_name}_opt" - ) - return ipt_lst - - def _get_previous_opt_var(self, cur_nd, pre_nd): - """ - Get needed input variable names. - - Args: - cur_nd (Node): Current node. - pre_nd (Node): Precursor node. - - Returns: - str, needed var names. - """ - if cur_nd.data.node_type != NodeType.OPERATION.value: - while True: - p_nd = cur_nd.successors(self.tree_identifier) - if not p_nd: - break - cur_nd = self.get_node(p_nd[0]) - return ", ".join(self._find_all_previous_opt_var_(cur_nd, pre_nd)) - - def hash_key(self, node): - """ - Generate hash key for each node. - - Args: - node (Node): Node. - - Returns: - str, hash key. - """ - scsr_topo_order = [] - for s in node.successors(self.tree_identifier): - cur_nd = self.get_node(s) - if cur_nd.data.node_type in {NodeType.MODULE.value, - NodeType.FUNC.value, - NodeType.CLASS.value}: - if cur_nd.data.hash_key: - scsr_topo_order.append(f"({cur_nd.data.hash_key})") - continue - - raise ValueError("Current node doesn't have hash key.") - - if cur_nd.data.hash_key: - scsr_topo_order.append(cur_nd.data.hash_key) - continue - unique_key = "->".join(scsr_topo_order) - node.data.hash_key = unique_key - return unique_key - - def _module_merging(self): - """Generate sub-module and corresponding params.""" - merged_module_args = dict() - for module_key, module_args in self._merged_module.items(): - if module_key not in merged_module_args: - merged_module_args[module_key] = [] - # Take first element's args as base. - keys = module_args[0].keys() - for key in keys: - for i in range(1, len(module_args)): - if key in module_args[i] and module_args[0][key] != module_args[i][key]: - merged_module_args[module_key].append(key) - break - if key not in module_args[i]: - merged_module_args[module_key].append(key) - break - - self._merged_module_args.update(merged_module_args) - - def _create_module_args_and_vars(self, node, mapper): - """ - Create module args and variables in current node. - - Args: - node (Node): Node on tree. - mapper (Mapper): Mapper of params. - - """ - # All args and value pair in current node module. - module_args = dict() - module_key = self.hash_key(node) - created = False - - if module_key not in self._vars_mgr_in_module: - self._vars_mgr_in_module[module_key] = self.GLOBAL_VAR_NAME_MGR - self._module_vars[module_key] = [] - else: - created = True - - # Sub-modules in the module could have arg name conflicts. - for idx, successor_name in enumerate(node.successors(self.tree_identifier)): - nd_inst = self.get_node(successor_name) - if nd_inst.data.op_name in NO_CONVERTED_OPERATORS: - continue - - # Generation of params must behind variable assigment. - if created: - variable_name = self._module_vars[module_key][idx] - else: - variable_name = nd_inst.data.op_name or nd_inst.tag - variable_name = self._vars_mgr_in_module[module_key].get_name(variable_name) - - code_fragment = nd_inst.data.param_transform(mapper, variable_name) - code_fragment.declared_var_name = variable_name - code_fragment.output_var_name = nd_inst.data.opt_var_names - code_fragment.update_operation_inputs(nd_inst.data.ipt_var_names) - self.code_fragment_recorder[nd_inst.identifier] = code_fragment - - module_args.update(nd_inst.data.args_in_code) - - if not created: - self._module_vars[module_key].append(variable_name) - - node.data.args_in_code = module_args - - # Collect module args of `module_key`. - if module_key not in self._merged_module: - self._merged_module[module_key] = [deepcopy(node.data.args_in_code)] - else: - self._merged_module[module_key].append(deepcopy(node.data.args_in_code)) - - @staticmethod - def _create_operation_args(node, mapper): - """ - Create operation args. - - Args: - node (Node): Node on tree. - mapper (Mapper): Mapper of params. - - """ - node.data.param_transform(mapper) - - def update_hierarchical_order(self) -> NoReturn: - """ - Update hierarchical order. - """ - hierarchical_order = dict() - queue = Queue() - queue.put(item=(self.root, self.ROOT_LEVEL), block=False) - while not queue.empty(): - node_name, cur_level = queue.get(block=False) - node_inst = self[node_name] - if cur_level not in hierarchical_order: - hierarchical_order[cur_level] = [] - hierarchical_order[cur_level].append(node_name) - for successor_name in node_inst.successors(self.tree_identifier): - queue.put(item=(successor_name, cur_level + 1), block=False) - self._hierarchical_order = hierarchical_order - - def sub_graph_merging(self) -> NoReturn: - """Shrink the module has only one child.""" - self.update_hierarchical_order() - depths = sorted(list(self._hierarchical_order.keys()), reverse=True) - for depth in depths: - for node_name in self._hierarchical_order[depth]: - node_inst = self[node_name] - # If the node type is module and has only one child, - # then merge it with its child. - if node_inst.data.node_type == NodeType.MODULE.value and \ - len(node_inst.successors(self.tree_identifier)) == 1: - self.shrink(node_inst) - - def _adjust_structure(self) -> NoReturn: - """Adjust tree structure to generate source code.""" - self.sub_graph_merging() - self.update_hierarchical_order() diff --git a/mindinsight/mindconverter/graph_based_converter/mapper/base.py b/mindinsight/mindconverter/graph_based_converter/mapper/base.py index 8bd6c96b..ec85b839 100644 --- a/mindinsight/mindconverter/graph_based_converter/mapper/base.py +++ b/mindinsight/mindconverter/graph_based_converter/mapper/base.py @@ -171,3 +171,13 @@ class ONNXToMindSporeMapper(Mapper, abc.ABC): outputs_list = [f"opt_{{{variable_slot}}}"] outputs_mapping = ((0, 0),) return template, exchange_msg, outputs_list, outputs_mapping + + @staticmethod + def _find_val_by_index(loc_index, values_dict): + """Find value by location index of values_dict.""" + result = None + for idx, dict_val in enumerate(values_dict.values()): + if idx == loc_index: + result = dict_val + break + return result diff --git a/mindinsight/mindconverter/graph_based_converter/mapper/impl/nn/conv_mapper.py b/mindinsight/mindconverter/graph_based_converter/mapper/impl/nn/conv_mapper.py index 1cb09a61..3100f555 100644 --- a/mindinsight/mindconverter/graph_based_converter/mapper/impl/nn/conv_mapper.py +++ b/mindinsight/mindconverter/graph_based_converter/mapper/impl/nn/conv_mapper.py @@ -42,7 +42,7 @@ class ConvMapper(ONNXToMindSporeMapper): """Convert params from PyTorch to MindSpore""" weights = kwargs['weights'] params = kwargs['params'] - weight = weights['weight'].numpy() + weight = weights['weight'] weight = np.transpose(weight, list(range(2, weight.ndim)) + [1, 0]) if isinstance(params['dilations'], list): dilation = tuple(params['dilations']) @@ -130,7 +130,6 @@ class ConvMapper(ONNXToMindSporeMapper): dim = len(kernel_size) return f"nn.Conv{dim}d" - weight = weight.numpy() dim = weight.ndim - 2 return f"nn.Conv{dim}d" diff --git a/mindinsight/mindconverter/graph_based_converter/mapper/impl/nn/dense_mapper.py b/mindinsight/mindconverter/graph_based_converter/mapper/impl/nn/dense_mapper.py index b11240d5..2c1cb219 100644 --- a/mindinsight/mindconverter/graph_based_converter/mapper/impl/nn/dense_mapper.py +++ b/mindinsight/mindconverter/graph_based_converter/mapper/impl/nn/dense_mapper.py @@ -13,6 +13,7 @@ # limitations under the License. # ============================================================================== """Mapper module.""" +import numpy as np from mindinsight.mindconverter.graph_based_converter.mapper.base import ONNXToMindSporeMapper from mindinsight.mindconverter.graph_based_converter.mapper.gen_setting import Setting @@ -27,8 +28,11 @@ class DenseMapper(ONNXToMindSporeMapper): @staticmethod def _convert_params(**kwargs): weights = kwargs['weights'] - has_bias = bool('bias' in weights) - weight = weights['weight'].numpy().transpose() + weight_index = 0 + bias_index = 1 + bias = DenseMapper._find_val_by_index(bias_index, weights) + has_bias = isinstance(bias, np.ndarray) + weight = DenseMapper._find_val_by_index(weight_index, weights).transpose() in_channels, out_channels = weight.shape return { 'in_channels': in_channels, diff --git a/mindinsight/mindconverter/graph_based_converter/mapper/impl/nn/pad_mapper.py b/mindinsight/mindconverter/graph_based_converter/mapper/impl/nn/pad_mapper.py index d99fcff5..6dedee16 100644 --- a/mindinsight/mindconverter/graph_based_converter/mapper/impl/nn/pad_mapper.py +++ b/mindinsight/mindconverter/graph_based_converter/mapper/impl/nn/pad_mapper.py @@ -13,6 +13,7 @@ # limitations under the License. # ============================================================================== """Mapper module.""" +from mindinsight.mindconverter.graph_based_converter.common.utils import convert_bytes_string_to_string from mindinsight.mindconverter.graph_based_converter.mapper.base import ONNXToMindSporeMapper from mindinsight.mindconverter.graph_based_converter.mapper.gen_setting import Setting @@ -47,7 +48,7 @@ class PadMapper(ONNXToMindSporeMapper): def _convert_params(**kwargs): weights = kwargs.get("weights") params = kwargs.get("params") - mode = params.get('mode', 'constant') + mode = convert_bytes_string_to_string(params.get('mode', 'constant')) pads_onnx = params.get("pads") if params.get("pads") else list(weights.values())[0].tolist() if mode == 'constant' and params.get('value') is None: if params.get('pads') or weights: diff --git a/mindinsight/mindconverter/graph_based_converter/mapper/impl/nn/pool_mapper.py b/mindinsight/mindconverter/graph_based_converter/mapper/impl/nn/pool_mapper.py index f3baa01d..feae93d1 100644 --- a/mindinsight/mindconverter/graph_based_converter/mapper/impl/nn/pool_mapper.py +++ b/mindinsight/mindconverter/graph_based_converter/mapper/impl/nn/pool_mapper.py @@ -36,7 +36,7 @@ class PoolMapper(ONNXToMindSporeMapper): transformed_params["kernel_size"] = tuple(params['kernel_shape']) transformed_params["stride"] = tuple(params['strides']) if "pads" in params: - if sum(params['pads']) == 0: + if sum(params['pads']) == 0 and not params.get('ceil_mode', None): pad_mode = '\"valid\"' else: pad_mode = '\"same\"' diff --git a/mindinsight/mindconverter/graph_based_converter/third_party_graph/__init__.py b/mindinsight/mindconverter/graph_based_converter/third_party_graph/__init__.py index 9dd5ba22..55424443 100644 --- a/mindinsight/mindconverter/graph_based_converter/third_party_graph/__init__.py +++ b/mindinsight/mindconverter/graph_based_converter/third_party_graph/__init__.py @@ -1,4 +1,4 @@ -# Copyright 2020 Huawei Technologies Co., Ltd.All Rights Reserved. +# Copyright 2020-2021 Huawei Technologies Co., Ltd.All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -39,14 +39,9 @@ class GraphFactory: Returns: Graph, graph instance. """ - if all([input_nodes, output_nodes]): - onnx_graph_module = import_module( - 'mindinsight.mindconverter.graph_based_converter.third_party_graph.onnx_graph') - onnx_graph = getattr(onnx_graph_module, 'OnnxGraph') - return onnx_graph.load(model_path=graph_path, input_nodes=input_nodes, - output_nodes=output_nodes, sample_shape=sample_shape) - - pytorch_graph_module = import_module( - 'mindinsight.mindconverter.graph_based_converter.third_party_graph.pytorch_graph') - pytorch_graph = getattr(pytorch_graph_module, 'PyTorchGraph') - return pytorch_graph.load(model_path=graph_path, sample_shape=sample_shape) + + onnx_graph_module = import_module( + 'mindinsight.mindconverter.graph_based_converter.third_party_graph.onnx_graph') + onnx_graph = getattr(onnx_graph_module, 'OnnxGraph') + return onnx_graph.load(model_path=graph_path, input_nodes=input_nodes, + output_nodes=output_nodes, sample_shape=sample_shape) diff --git a/mindinsight/mindconverter/graph_based_converter/third_party_graph/base.py b/mindinsight/mindconverter/graph_based_converter/third_party_graph/base.py index af6e5c69..29153e54 100644 --- a/mindinsight/mindconverter/graph_based_converter/third_party_graph/base.py +++ b/mindinsight/mindconverter/graph_based_converter/third_party_graph/base.py @@ -264,7 +264,7 @@ class Graph(BaseGraph, abc.ABC): Returns: cls, graph instance. """ - src_graph = cls.load_graph(graph_path=model_path, **kwargs) + src_graph = cls.load_graph(graph_path=model_path, sample_shape=sample_shape, **kwargs) ckpt = cls.load_checkpoint(ckpt_path=checkpoint) if checkpoint else None if ckpt is not None: diff --git a/mindinsight/mindconverter/graph_based_converter/third_party_graph/onnx_graph.py b/mindinsight/mindconverter/graph_based_converter/third_party_graph/onnx_graph.py index 3ccf088e..74b6d4fc 100644 --- a/mindinsight/mindconverter/graph_based_converter/third_party_graph/onnx_graph.py +++ b/mindinsight/mindconverter/graph_based_converter/third_party_graph/onnx_graph.py @@ -1,4 +1,4 @@ -# Copyright 2020 Huawei Technologies Co., Ltd.All Rights Reserved. +# Copyright 2020-2021 Huawei Technologies Co., Ltd.All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -13,12 +13,14 @@ # limitations under the License. # ============================================================================== """Define ONNX graph.""" +from importlib import import_module from typing import Dict, NoReturn from mindinsight.mindconverter.common.log import logger as log from .base import Graph from .input_node import InputNode from .onnx_graph_node import OnnxGraphNode +from .pytorch_graph_parser import PyTorchGraphParser from .tf_graph_parser import TFGraphParser from .onnx_utils import OnnxDataLoader @@ -151,7 +153,7 @@ class OnnxGraph(Graph): input_shape (tuple): Input shape. """ input_node = InputNode(input_shape) - input_node_name = self._raw_input_nodes.replace(":0", "") + input_node_name = self._raw_input_nodes for node_name, node in self._nodes_collection.items(): if node_name in self._input_nodes: ipt_nd_name = input_node_name.format(input_node.scope_name) @@ -196,7 +198,18 @@ class OnnxGraph(Graph): """ tf_input_nodes = kwargs.get('input_nodes') tf_output_nodes = kwargs.get('output_nodes') - onnx_model = TFGraphParser.parse(graph_path, - input_nodes=tf_input_nodes, - output_nodes=tf_output_nodes) + if graph_path.endswith('.pb'): + onnx_model = TFGraphParser.parse(graph_path, + input_nodes=tf_input_nodes, + output_nodes=tf_output_nodes) + elif graph_path.endswith('.onnx'): + onnx = import_module('onnx') + onnx_model = onnx.load(graph_path) + optimizer = import_module( + 'mindinsight.mindconverter.graph_based_converter.third_party_graph.optimizer') + onnx_simplify = getattr(optimizer, 'OnnxSimplify')() + onnx_model = onnx_simplify.run_onnx_simplify(onnx_model, kwargs['sample_shape']) + + else: + onnx_model = PyTorchGraphParser.parse(graph_path, **kwargs) return onnx_model diff --git a/mindinsight/mindconverter/graph_based_converter/third_party_graph/onnx_utils.py b/mindinsight/mindconverter/graph_based_converter/third_party_graph/onnx_utils.py index e24ce7a8..3f40855f 100644 --- a/mindinsight/mindconverter/graph_based_converter/third_party_graph/onnx_utils.py +++ b/mindinsight/mindconverter/graph_based_converter/third_party_graph/onnx_utils.py @@ -1,4 +1,4 @@ -# Copyright 2020 Huawei Technologies Co., Ltd.All Rights Reserved. +# Copyright 2020-2021 Huawei Technologies Co., Ltd.All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -112,10 +112,10 @@ class OnnxTensor: def to_array(self): """Convert the tensor value from binary to np array.""" - onnx = import_module("onnx") + numpy_helper = import_module("onnx.numpy_helper") # Convert binary data to np.array if not isinstance(self.raw_tensor, (np.ndarray, list, tuple, int, float)): - return onnx.numpy_helper.to_array(self.raw_tensor) + return numpy_helper.to_array(self.raw_tensor) return self.raw_tensor @@ -383,15 +383,24 @@ class OnnxDataLoader: """Parse each onnx nodes in the model.""" nodes_topo_idx = [] for idx, node in enumerate(self.nodes): + if not node.name: + node.name = "_".join(node.output) n = OnnxNode(node) self._nodes_dict[n.name] = n nodes_topo_idx.append((idx, n.name)) if len(node.output) > 1: raise ModelNotSupportError(msg=f"{node.name} has multi-outputs which is not supported now.") self.output_name_to_node_name[node.output[0]] = node.name + + for ipt_nd in node.input: + if ipt_nd not in self.output_name_to_node_name: + if self._global_context.onnx_node_inputs.get(n.name): + self._global_context.onnx_node_inputs[n.name].append(ipt_nd) + else: + self._global_context.onnx_node_inputs[n.name] = [ipt_nd] + self._global_context.onnx_node_name_to_topo_idx[n.name] = idx - node_inputs = [i.replace(":0", "") for i in node.input] - self._global_context.onnx_node_inputs[n.name] = node_inputs + self._global_context.onnx_nodes_collection = self._nodes_dict self._global_context.onnx_nodes_topo_index = nodes_topo_idx @@ -449,7 +458,11 @@ class OnnxDataLoader: input_node = self.get_node(input_node_name) node.precursor_onnx_node_dict[input_node_name] = input_node input_node.successor_onnx_node_dict[node_name] = node - continue + + if self._global_context.onnx_node_inputs.get(node.name): + self._global_context.onnx_node_inputs[node.name].append(input_node_name) + else: + self._global_context.onnx_node_inputs[node.name] = [input_node_name] def initialize(self): """Initialize the OnnxDataLoader.""" diff --git a/mindinsight/mindconverter/graph_based_converter/third_party_graph/optimizer.py b/mindinsight/mindconverter/graph_based_converter/third_party_graph/optimizer.py new file mode 100644 index 00000000..63c0476f --- /dev/null +++ b/mindinsight/mindconverter/graph_based_converter/third_party_graph/optimizer.py @@ -0,0 +1,144 @@ +# Copyright 2021 Huawei Technologies Co., Ltd.All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Define ONNX optimizer operations.""" +import copy +from importlib import import_module + +import numpy as np + +from ..common.utils import fetch_output_from_onnx_model + + +class OnnxSimplify: + """To simplify onnx model.""" + def __init__(self): + self._onnx_model = None + self._constant_nodes = list() + self._outputs_infer = dict() + + def run_onnx_simplify(self, onnx_model, sample_shape): + """ + Run to simplify onnx model. + + Args: + onnx_model (onnx.ModelProto): Onnx Model. + sample_shape (tuple): Sample shape of input. + """ + self._onnx_model = onnx_model + self._optimizer() + self._get_constant_nodes() + self._onnx_infer(sample_shape) + self._replace_constant_nodes() + self._optimizer() + + return self._onnx_model + + def _optimizer(self): + """Run optimizer from onnx to eliminate constant nodes.""" + + onnxoptimizer = import_module('onnxoptimizer') + optimizers_list = [ + 'eliminate_deadend', + 'eliminate_nop_dropout', + 'eliminate_nop_cast', + 'eliminate_nop_monotone_argmax', + 'eliminate_nop_pad', + 'extract_constant_to_initializer', + 'eliminate_unused_initializer', + 'eliminate_nop_transpose', + 'eliminate_identity', + 'fuse_add_bias_into_conv', + 'fuse_consecutive_concats', + 'fuse_consecutive_log_softmax', + 'fuse_consecutive_reduce_unsqueeze', + 'fuse_consecutive_squeezes', + 'fuse_consecutive_transposes', + 'fuse_matmul_add_bias_into_gemm', + 'fuse_pad_into_conv', + 'fuse_transpose_into_gemm' + ] + + input_num = len(self._onnx_model.graph.input) + onnx_model_optimized = onnxoptimizer.optimize(self._onnx_model, optimizers_list, fixed_point=True) + + if self._onnx_model.ir_version > 3: + del onnx_model_optimized.graph.input[input_num:] + self._onnx_model = onnx_model_optimized + + def _get_constant_nodes(self): + """Get constant nodes.""" + + const_nodes = list() + const_tensors = [tensor_init.name for tensor_init in self._onnx_model.graph.initializer] + const_tensors.append([node.output[0] + for node in self._onnx_model.graph.node if node.op_type == 'Constant']) + + for node in self._onnx_model.graph.node: + if node.op_type == 'Shape' or all([input_node in const_tensors for input_node in node.input]): + const_nodes.append(node) + const_tensors.extend(node.output) + + self._constant_nodes = copy.deepcopy(const_nodes) + + def _onnx_infer(self, infer_inputs_shape): + """ + Run onnx inference to get outputs of constant nodes. + + Args: + infer_inputs_shape (tuple): Input shape for running inference. + """ + + input_onnx = self._onnx_model.graph.input[0] + input_onnx_name = input_onnx.name + feed_dict = {input_onnx_name: np.random.rand(*infer_inputs_shape).astype(np.float32)} + + output_nodes_name = list() + for node in self._constant_nodes: + output_nodes_name.extend(node.output) + + self._outputs_infer = fetch_output_from_onnx_model(self._onnx_model, feed_dict, output_nodes_name) + + def _replace_constant_nodes(self): + """Replace constant nodes to nodes with op_type 'Constant'.""" + + onnx = import_module('onnx') + np_helper = import_module('onnx.numpy_helper') + + for i, node in enumerate(self._onnx_model.graph.node): + if node in self._constant_nodes: + for output in node.output: + new_attr = onnx.helper.make_attribute( + 'value', + np_helper.from_array(self._outputs_infer[output], name=output) + ) + + new_node = onnx.helper.make_node( + op_type='Constant', + inputs=list(), + outputs=[output], + name='_'.join(('node', output)) + ) + new_node.attribute.extend([new_attr]) + self._insert_node(self._onnx_model.graph.node, i + 1, new_node) + del self._onnx_model.graph.node[i] + + @staticmethod + def _insert_node(repeated_container, index, node): + """Insert node into onnx model.""" + + repeated_container.extend([repeated_container[-1]]) + for i in reversed(range(index + 1, len(repeated_container) - 1)): + repeated_container[i].CopyFrom(repeated_container[i - 1]) + repeated_container[index].CopyFrom(node) diff --git a/mindinsight/mindconverter/graph_based_converter/third_party_graph/pytorch_graph.py b/mindinsight/mindconverter/graph_based_converter/third_party_graph/pytorch_graph.py deleted file mode 100644 index 3cc84216..00000000 --- a/mindinsight/mindconverter/graph_based_converter/third_party_graph/pytorch_graph.py +++ /dev/null @@ -1,691 +0,0 @@ -# Copyright 2020 Huawei Technologies Co., Ltd.All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""Define PyTorch graph.""" -import os -import re -import warnings -from copy import deepcopy -from importlib import import_module -from typing import Dict, NoReturn - -import numpy as np - -from mindinsight.conf import settings -from mindinsight.mindconverter.common.log import logger as log -from .base import Graph -from .input_node import InputNode -from .pytorch_graph_node import PyTorchGraphNode -from .pytorch_graph_parser import PyTorchGraphParser -from .torch_utils import set_opset_version -from ..common.utils import fetch_output_from_onnx_model - -from ..constant import SEPARATOR_IN_SCOPE, LINK_IN_SCOPE, SEPARATOR_BTW_NAME_AND_ID, SCALAR_WITHOUT_SHAPE, \ - MIN_SCOPE_LENGTH, SEPARATOR_TITLE_AND_CONTENT_IN_CONSTRUCT, ONNX_OPSET_VERSION, MODEL_INPUT_NAME -from ..constant import LEFT_BUCKET, RIGHT_BUCKET -from ...common.exceptions import ModelNotSupportError - -NONE_SCOPE_OP = { - "onnx::Add": "Add", - "onnx::Flatten": "Flatten", - "onnx::Concat": "Concat", - "onnx::Squeeze": "Squeeze", - "onnx::Unsqueeze": "Unsqueeze", - "onnx::Split": "Split", - "onnx::Reshape": "Reshape", - "onnx::Transpose": "Transpose", - "onnx::Constant": "Constant", - "onnx::ReduceMean": "ReduceMean", - "onnx::Resize": "Resize", - "onnx::Pad": "Pad" -} - -CONSTANT_NODES_PATTERN = { - "onnx::Resize": [ - 'onnx::Concat', - 'onnx::Slice', - 'onnx::Cast', - 'onnx::Concat', - 'onnx::Unsqueeze', - 'onnx::Floor', - 'onnx::Mul', - 'onnx::Cast', - 'onnx::Gather', - 'onnx::Shape' - ], - "onnx::Pad": [ - 'onnx::Cast', - 'onnx::Concat', - 'onnx::ConstantOfShape', - 'onnx::Sub', - 'onnx::Mul', - 'onnx::Div', - 'onnx::Gather', - 'onnx::Shape', - 'onnx::Unsqueeze', - 'onnx::Slice', - 'onnx::Reshape', - 'onnx::Transpose' - ], - "onnx::Constant": list() -} - - -def normalize_scope_name(node, scope_name_dict): - """ - Rename scope name into uniform. - - Args: - node (Node): PyTorch node. - scope_name_dict (dict): Dictionary of scope names with the key node_id. - - Returns: - str, normalized scope name. - """ - global NONE_SCOPE_OP - - scope_name = node.scopeName() - if not scope_name: - name = [retrieve_scope_name(node, scope_name_dict)] - else: - name = scope_name.replace(SEPARATOR_BTW_NAME_AND_ID, '').split(SEPARATOR_IN_SCOPE) - scopes = [] - for segment in name: - segment = segment.split(LINK_IN_SCOPE)[0] - left = segment.find(LEFT_BUCKET) - right = segment.find(RIGHT_BUCKET) - if left != -1: - if segment[left + 1: right].isdigit(): - scopes.append(f"{segment[:left]}_{segment[left + 1: right]}") - else: - scopes.append(segment[left + 1: right]) - else: - scopes.append(segment) - if node.kind() in NONE_SCOPE_OP.keys(): - scopes.append(NONE_SCOPE_OP[node.kind()]) - scopes = [s for s in scopes if s] - node_id = PyTorchGraph.get_node_id(node) - return f"{SEPARATOR_IN_SCOPE.join(scopes)}_{'&'.join(node_id)}" - - -def retrieve_scope_name(node, scope_name_dict): - """ - Retrieve scope name from input nodes. - - Args: - node (Node): PyTorch node. - scope_name_dict (dict): Dictionary of scope names with the key node_id. - - Return: - str: Scope name. - """ - node_content = \ - SEPARATOR_TITLE_AND_CONTENT_IN_CONSTRUCT.join(str(node).split(SEPARATOR_TITLE_AND_CONTENT_IN_CONSTRUCT)[1:]) - node_inputs = re.findall(r"[(](.*?)[)]", node_content)[0] - node_inputs = re.sub(r"[\s%]", '', node_inputs).split(",") - - scope_name_ipt_nodes = list() - for node_input in node_inputs: - if not scope_name_dict.get(node_input, None): - continue - scope_name_ipt_nodes.append(scope_name_dict[node_input]) - - scope_name_split = list() - for idx, _ in enumerate(scope_name_ipt_nodes): - if not scope_name_split: - scope_name_split = scope_name_ipt_nodes[idx] - else: - scope_name_split = [ - sub_scope_name - for sub_scope_name in scope_name_split if sub_scope_name in scope_name_ipt_nodes[idx] - ] - scope_name = SEPARATOR_IN_SCOPE.join(scope_name_split) - return scope_name - - -class PyTorchGraph(Graph): - """ - Define PyTorch graph. - - Args: - model (Module): PyTorch model. - sample_shape (tuple): Input shape of the model. - - """ - - def __init__(self, model, sample_shape: tuple): - super(PyTorchGraph, self).__init__(model=model) - - from .torch_utils import unique_state_dict - - self._params_dict = unique_state_dict(model) - self._original_shape = list() - self._nodes = list() - self._constant_nodes = list() - self._dynamic_nodes = list() - self._has_eliminated_nodes = False - self._file_graph_onnx = os.path.join( - settings.WORKSPACE, 'log/mindconverter/' - ) - - self.build(sample_shape) - - @staticmethod - def _check_input_shape(input_shape): - """ - Check input shape. - - Args: - input_shape (tuple): Input tensor shape. - - """ - if not input_shape: - err_msg = "`input_shape` can not be None." - log.error(err_msg) - raise ValueError(err_msg) - - for item in input_shape: - if not isinstance(item, int): - err_msg = "Only support model with one input now, " \ - "and each shape value in `input_shape` should be int." - log.error(err_msg) - raise ValueError(err_msg) - - @staticmethod - def _extract_shape(shape): - """ - Extract shape from string-type shape. - - Args: - shape (str): Shape value in string-type. - - Returns: - list, shape. - """ - if "," not in shape: - return [] - - shape_arr = [] - for s in shape.split(","): - s = s.strip() - if not s: - return [] - if ":" in s: - s = s.split(":")[0] - s = s.replace("!", "") - if not s.isdigit(): - return [] - shape_arr.append(int(s)) - return shape_arr - - def _trace_torch_graph(self, input_shape): - """ - Trace torch computational graph. - - Args: - input_shape (tuple): Shape. - - Returns: - object, pytorch graph. - """ - import torch - from torch.onnx import OperatorExportTypes - from .torch_utils import OverloadTorchModuleTemporarily - from .torch_utils import create_autograd_variable - from .torch_utils import onnx_tracer - - warnings.simplefilter("ignore") - - batched_sample = create_autograd_variable(torch.rand(*input_shape)) - - try: - try: - # Assign execution mode to eval. - self.model.eval() - - with OverloadTorchModuleTemporarily() as _: - # In pytorch higher version, trace function has a known. - graph = onnx_tracer(self.model, batched_sample, - OperatorExportTypes.ONNX) - return graph - - except RuntimeError: - # Assign execution mode to eval. - self.model.eval() - - with OverloadTorchModuleTemporarily() as _: - # In pytorch higher version, trace function has a known. - set_opset_version(ONNX_OPSET_VERSION) - graph = onnx_tracer(self.model, batched_sample, - OperatorExportTypes.ONNX) - return graph - - except RuntimeError as error: - log.error(str(error)) - log.exception(error) - raise error - - def build(self, input_shape): - """ - Build graph tree. - - Args: - input_shape (tuple): Input shape of model. - - """ - self._check_input_shape(input_shape) - self._original_shape = input_shape - feed_forward_ipt_shape = tuple(input_shape) - graph = self._trace_torch_graph(feed_forward_ipt_shape) - nodes = list(graph.nodes()) - self._nodes = nodes - - scope_name_dict = dict() - - self._constant_nodes, self._dynamic_nodes = self._get_constant_nodes(nodes) - - for node in nodes: - output_name = ', '.join(list(self._extract_node_name(output) for output in node.outputs())) - if output_name in self._dynamic_nodes: - continue - - node_name = normalize_scope_name(node, scope_name_dict) - scope_name_dict[node_name.split(SEPARATOR_BTW_NAME_AND_ID)[-1]] \ - = list(node_name.split(SEPARATOR_BTW_NAME_AND_ID)[0].split(SEPARATOR_IN_SCOPE)) - output_shape_str_list = re.findall(r'[^()!]+', str(node)) - output_shape_str = output_shape_str_list[1] - output_shape = self._extract_shape(output_shape_str) - weight_scope = '.'.join( - re.findall(r'\[([\w\d.]+)]', node.scopeName()) - ) - - if self._constant_nodes: - node_weight = self._replace_constant_node(node) - else: - node_weight = {} - - for scope, weight in self._params_dict.items(): - split_scope = scope.split('.') - if '.'.join(split_scope[:-1]) == weight_scope: - node_weight[split_scope[-1]] = weight - - if not node_weight and node.kind() == 'onnx::Conv': - weight_names = list(self._params_dict.keys()) - node_input_names = [self._extract_input_name(node_input) for node_input in node.inputs()] - for node_input_name in node_input_names: - if int(node_input_name) > len(weight_names): - continue - weight = self._params_dict[weight_names[int(node_input_name) - 1]] - node_weight[weight_names[int(node_input_name) - 1]] = weight - - self._shape_dict[node_name] = output_shape - self._nodes_collection[node_name] = PyTorchGraphNode(node, node_weight) - self._nodes_record[node_name] = node_name - - for node_input in list(node.inputs()): - if self._extract_input_name(node_input) in self._constant_nodes: - continue - # Connect input node and src node. - nd_id = PyTorchGraph.get_node_id(node_input.node()) - nd_scope_name = node_input.node().kind() in NONE_SCOPE_OP or \ - node_input.node().scopeName() - - if nd_id and nd_scope_name: - node_input_name = normalize_scope_name( - node_input.node(), scope_name_dict - ) - self.build_connection(node_input_name, node_name) - - self._unmerge_multi_ipt_opt_script() - - super(PyTorchGraph, self).build(input_shape=input_shape) - self._collect_ipt_shape_of_each_node(feed_forward_ipt_shape) - - @staticmethod - def _extract_node_name(node): - """Extract node name for node.""" - result = re.match(r"\d+", str(node)) - if result: - return result.group(0) - return None - - @staticmethod - def _extract_input_name(node_input): - """Extract node input name from node input.""" - node_input_name = str(node_input).split('defined in')[0].strip() - return node_input_name - - def _get_constant_nodes(self, nodes): - """ - Get constant nodes to be eliminated. - - Args: - nodes (Nodes): Nodes in torch._C.Graph. - - Returns: - Union(dict, list), output of constant_input_node_name and dynamic nodes name. - """ - constant_input_nodes = list() - dynamic_nodes = list() - for node in nodes: - if node.kind() == 'onnx::Resize': - self._has_eliminated_nodes = True - constant_input_node, dynamic_node = self._generate_inputs_of(node) - constant_input_nodes += constant_input_node - dynamic_nodes += dynamic_node - - outputs = dict() - if self._has_eliminated_nodes: - torch = import_module('torch') - device_target = 'cuda' if torch.cuda.is_available() else 'cpu' - dump_input = torch.randn(*self._original_shape, device=device_target) - temp_onnx_path = os.path.realpath(os.path.join(self._file_graph_onnx, - '.graph_onnx.onnx')) - - symbolic_helper = import_module('torch.onnx.symbolic_helper') - export_onnx_opset_version = getattr(symbolic_helper, '_export_onnx_opset_version') - try: - torch.onnx.export(self.model.to(device_target), dump_input, - temp_onnx_path, opset_version=export_onnx_opset_version) - - outputs = self._onnx_infer(temp_onnx_path, constant_input_nodes, self._original_shape) - finally: - if os.path.exists(temp_onnx_path): - os.remove(temp_onnx_path) - - return outputs, dynamic_nodes - - def _generate_inputs_of(self, node): - """ - Generate inputs of certain node. - - Args: - node (Node): Node of torch._C.Graph. - - """ - pattern_op_lst = CONSTANT_NODES_PATTERN.get(node.kind(), None) - constant_input_nodes = list() - dynamic_nodes = list() - if not isinstance(pattern_op_lst, list): - return constant_input_nodes, dynamic_nodes - if not pattern_op_lst: - dynamic_nodes += self.get_node_id(node) - return constant_input_nodes, dynamic_nodes - - node_inputs_name = [self._extract_input_name(node_input) for node_input in node.inputs()] - - for node_input_name in node_inputs_name: - node_name_path = self._search_node_path(node_input_name, pattern_op_lst) - if node_name_path and self._get_node_from_graph(node_name_path[-1]).kind() == 'onnx::Shape': - constant_input_nodes.append(node_input_name) - dynamic_nodes += node_name_path - - return constant_input_nodes, dynamic_nodes - - def _search_node_path(self, node_name, pattern_op_lst): - """ - Search node path based on pattern_op_list. - - Args: - node_name (str): Node name. - pattern_op_lst (list): Pattern list of certain operator. - - Returns: - list[str]: node names in pattern. - """ - node_type_lst = list() - node_name_lst = list() - node = self._get_node_from_graph(node_name) - - if node_name == MODEL_INPUT_NAME: - return node_name_lst - - if node.kind() not in pattern_op_lst: - return node_name_lst - - node_type_lst.append(node.kind()) - node_name_lst.append(node_name) - - node_inputs_name = [self._extract_input_name(node_input) for node_input in node.inputs()] - for node_input_name in node_inputs_name: - node_name_lst += self._search_node_path(node_input_name, pattern_op_lst) - - return node_name_lst - - def _get_node_from_graph(self, node_name): - """Get torch._C.Node from torch._C.Graph.""" - for idx, node in enumerate(self._nodes): - node_id = ', '.join(self.get_node_id(node)) - if node_id == node_name: - return self._nodes[idx] - return None - - @staticmethod - def _onnx_infer(file_graph_onnx, infer_outputs, infer_inputs_shape): - """ - Infer onnx model to get outputs of inner nodes. - - Args: - file_graph_onnx (str): File path of onnx. - infer_outputs (list): Outputs for infer. - infer_inputs_shape (list): Input shape for infer. - - """ - onnx = import_module('onnx') - tensor_proto = getattr(onnx, 'TensorProto') - onnx_model = onnx.load(file_graph_onnx) - - for onnx_node in onnx_model.graph.node: - if set(onnx_node.output).issubset(set(infer_outputs)): - onnx_node.name = ', '.join([f"{output_name}" for output_name in onnx_node.output]) - - input_onnx = onnx_model.graph.input[0] - node_type = tensor_proto.DataType.Name(input_onnx.type.tensor_type.elem_type) - if node_type != 'FLOAT': - raise ModelNotSupportError(f"Input type should be FLOAT32, but got {node_type}. " - f"Please report issue to us if extra input type is needed.") - - input_onnx_name = input_onnx.name - feed_dict = {input_onnx_name: np.random.rand(*infer_inputs_shape).astype(np.float32)} - outputs = fetch_output_from_onnx_model(onnx_model, feed_dict, infer_outputs) - - return outputs - - def _replace_constant_node(self, node): - """Replace constant node.""" - node_weight = dict() - for node_input in list(node.inputs()): - node_input_name = self._extract_input_name(node_input) - if node_input_name in self._constant_nodes: - node_weight[node_input_name] = self._constant_nodes[node_input_name] - return node_weight - - def _collect_ipt_shape_of_each_node(self, input_shape): - """ - Collect input tensor shape of each node. - - Args: - input_shape (tuple): Input shape. - - """ - input_node = InputNode(input_shape) - input_node_name = "{}InputNode" - for node_name, node in self._nodes_collection.items(): - if node_name in self._input_nodes: - ipt_nd_name = input_node_name.format(input_node.scope_name) - input_node.set_scope_name(node.scope_name) - node.precursor_nodes.insert(0, ipt_nd_name) - input_node.set_successor_nodes(node_name) - self._shape_dict[ipt_nd_name] = input_node.output_shape - - if not self._shape_dict[node_name]: - self._shape_dict[node_name] = SCALAR_WITHOUT_SHAPE - - ipt_shape = [] - for p_nd in node.precursor_nodes: - shp = self._shape_dict.get(p_nd) - ipt_shape.append(tuple(shp) if isinstance(shp, list) else shp) - - self._input_shape[node_name] = ipt_shape[0] if len(ipt_shape) == 1 else ipt_shape - - def _generate_module(self): - """Generate modules.""" - module_dict = dict() - for node_key, _ in self._nodes_collection.items(): - node_key_in_scope = node_key.split(SEPARATOR_IN_SCOPE) - if len(node_key_in_scope) < MIN_SCOPE_LENGTH: - continue - - for idx in range(1, len(node_key_in_scope)): - node_key_module = SEPARATOR_IN_SCOPE.join(node_key_in_scope[:idx]) - node_name = SEPARATOR_IN_SCOPE.join(node_key_in_scope[:idx+1]) - if not module_dict.get(node_key_module, None): - module_dict[node_key_module] = {node_name} - else: - module_dict[node_key_module].add(node_name) - - return module_dict - - def _check_multi_ipt_opt(self): - """Check whether multi-input exists.""" - module_dict = self._generate_module() - for _, nodes_per_module in module_dict.items(): - prcs_nodes_out_from_module = set() - for node_name in nodes_per_module: - if re.search(r"[\d]+[&][\d]+", node_name): - self._is_multi_opt_graph = True - return True - - node = self._nodes_collection.get(node_name, None) - if node: - prcs_nodes = node.precursor_nodes - else: - continue - - for prcs_node in prcs_nodes: - if prcs_node not in nodes_per_module: - prcs_node_module = SEPARATOR_IN_SCOPE.join(prcs_node.split(SEPARATOR_IN_SCOPE)[:-1]) - if prcs_node_module not in nodes_per_module: - prcs_nodes_out_from_module.add(prcs_node) - - if len(prcs_nodes_out_from_module) > 1: - return True - - return False - - def _unmerge_multi_ipt_opt_script(self): - """Unmerge all submodule.""" - if self._check_multi_ipt_opt() or self._has_eliminated_nodes: - for node_key, node_inst in deepcopy(self._nodes_collection).items(): - prsc_nodes = node_inst.precursor_nodes - scsr_nodes = node_inst.successor_nodes - - node_inst.is_in_multi_opt_graph = self._is_multi_opt_graph - - node_inst.precursor_nodes = [SEPARATOR_IN_SCOPE.join((prsc_node.split(SEPARATOR_IN_SCOPE)[0], - prsc_node.split(SEPARATOR_IN_SCOPE)[-1])) - for prsc_node in deepcopy(prsc_nodes)] - node_inst.successor_nodes = [SEPARATOR_IN_SCOPE.join((scsr_node.split(SEPARATOR_IN_SCOPE)[0], - scsr_node.split(SEPARATOR_IN_SCOPE)[-1])) - for scsr_node in deepcopy(scsr_nodes)] - - reduce_node_key = SEPARATOR_IN_SCOPE.join((node_key.split(SEPARATOR_IN_SCOPE)[0], - node_key.split(SEPARATOR_IN_SCOPE)[-1])) - - del self._nodes_collection[node_key] - self._nodes_collection[reduce_node_key] = node_inst - - for node_key, shape in deepcopy(self._shape_dict).items(): - reduce_node_key = SEPARATOR_IN_SCOPE.join((node_key.split(SEPARATOR_IN_SCOPE)[0], - node_key.split(SEPARATOR_IN_SCOPE)[-1])) - - del self._shape_dict[node_key] - self._shape_dict[reduce_node_key] = shape - - def sub_graph_merging(self): - """ - Merge split operation into one. - """ - raise NotImplementedError() - - def build_connection(self, src, tgt) -> NoReturn: - """ - Build connection between source node and target node. - - Args: - src (str): Source node name. - tgt (str): Target node name. - - """ - # If src and tgt are the same node, src not in node_collection or - # tgt not in node_collection, then skip this edge. - if src == tgt or src not in self._nodes_collection or tgt not in self._nodes_collection: - if src.split(':')[0] not in self._nodes_collection: - log.warning("Graph construct a self-loop node %s. Ignored.", src) - return - if tgt not in self._nodes_collection[src.split(':')[0]].successor_nodes: - self._nodes_collection[src.split(':')[0]].successor_nodes.append(tgt) - if src not in self._nodes_collection[tgt].precursor_nodes: - self._nodes_collection[tgt.split(':')[0]].precursor_nodes.append(src) - - @staticmethod - def load_checkpoint(ckpt_path: str) -> Dict: - """ - Load checkpoint. - - Args: - ckpt_path (str): Checkpoint file path. - - Returns: - dict, weights in model. - """ - - @staticmethod - def load_metadata(**kwargs): - """ - Load graph metadata. - """ - err_msg = "class `PyTorchGraph` has not implemented " \ - "`load_metadata()`." - log.error(err_msg) - raise NotImplementedError(err_msg) - - @staticmethod - def load_graph(graph_path: str, **kwargs): - """ - Load graph. - - Args: - graph_path (str): Graph path. - - Returns: - object, pytorch model. - """ - torch_model = PyTorchGraphParser.parse(graph_path) - return torch_model - - @staticmethod - def get_node_id(node): - """ - Get node id using regular expr. - - Args: - node (Node): PyTorch node. - - Returns: - str, node id. - """ - node_title = str(node).split(SEPARATOR_TITLE_AND_CONTENT_IN_CONSTRUCT)[0] - node_id = re.findall(r"[%](.*?) [:]", node_title) - return node_id diff --git a/mindinsight/mindconverter/graph_based_converter/third_party_graph/pytorch_graph_node.py b/mindinsight/mindconverter/graph_based_converter/third_party_graph/pytorch_graph_node.py deleted file mode 100644 index 96692565..00000000 --- a/mindinsight/mindconverter/graph_based_converter/third_party_graph/pytorch_graph_node.py +++ /dev/null @@ -1,236 +0,0 @@ -# Copyright 2020 Huawei Technologies Co., Ltd.All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""Define PyTorch graph node.""" -import re - -from .base import GraphNode -from ..common.utils import is_converted - -from ..constant import NodeType, SEPARATOR_IN_SCOPE, SEPARATOR_BTW_NAME_AND_ID, LEFT_BUCKET, RIGHT_BUCKET, \ - SEPARATOR_IN_ONNX_OP, SEPARATOR_TITLE_AND_CONTENT_IN_CONSTRUCT - - -class PyTorchGraphNode(GraphNode): - """ - PyTorch graph node. - - Args: - node (torch._C.Node): Node in raw PyTorch graph. - - """ - - _type_frozen = False - _module_name_frozen = False - - def __init__(self, node=None, weight=None): - super(PyTorchGraphNode, self).__init__(node=node) - self._op_params = self._get_raw_params(node) - self._op_name = node.kind() if node else None - self._scope_name = node.scopeName() if node else None - self._weight = weight - self._ipt_var_names, self._opt_var_names \ - = self._extract_ipt_opt_var_names() if node else (list(), list()) - - def _extract_ipt_opt_var_names(self): - """Extract ipt and opt var names.""" - node_content = SEPARATOR_TITLE_AND_CONTENT_IN_CONSTRUCT.join( - str(self._src_node).split(SEPARATOR_TITLE_AND_CONTENT_IN_CONSTRUCT)[1:] - ) - node_inputs = re.findall(r"[(](.*?)[)]", node_content)[0] - node_inputs = re.sub(r"[\s%]", '', node_inputs).split(",") - node_title = str(self._src_node).split(SEPARATOR_TITLE_AND_CONTENT_IN_CONSTRUCT)[0] - node_outputs = re.findall(r"[%](.*?) [:]", node_title) - return node_inputs, node_outputs - - def clear_args_of_declaration(self): - """ - Clear `self._args_in_code`. - """ - self._args_in_code = dict() - - def _get_arg_name(self, arg, variable_name): - """ - Get arg name. - - Args: - arg (str): Generate arg name. - - Returns: - str, arg name in function or class declaration. - """ - return f"{arg}_{variable_name}" - - @property - def is_in_multi_opt_graph(self): - return self._is_in_multi_opt_graph - - @is_in_multi_opt_graph.setter - def is_in_multi_opt_graph(self, multi_opt_state): - self._is_in_multi_opt_graph = multi_opt_state - - @property - def hash_key(self): - """ - Return unique hash key of current node. - - Returns: - str, hash key. - """ - if self._node_type not in {NodeType.CLASS.value, - NodeType.FUNC.value, - NodeType.MODULE.value}: - self._hash_key = self._op_name.lower() - return self._hash_key - - @hash_key.setter - def hash_key(self, h): - """ - Setter of hash key. - - Args: - h (str): Key. - - """ - self._hash_key = h - - @property - def op_name(self): - """ - Op name in torch. - - Returns: - str, op name. - """ - return self._op_name - - @op_name.setter - def op_name(self, name): - """ - Setter of op name. - - Args: - name(str): op_name. - - """ - self._op_name = name - - @property - def real_name(self): - return - - def add_input_and_output_shape(self, input_shape, output_shape): - """ - Add the node input shape. - - Args: - output_shape (tuple): Output tensor shape. - input_shape (tuple): Input tensor shape. - - """ - self._ipt_shape = input_shape - self._opt_shape = output_shape - - def to_code(self, ipt_args_in_construct: str, variable_name: str, output_var: list, code_fragment): - """ - Generate statements. - - Args: - variable_name (str): Variable name. - ipt_args_in_construct (str): Args of input. - output_var (list): Output variable names in construct. - code_fragment (CodeFragment): CodeFragment instance. - - Returns: - Union[str, str], declare in init and call in construct. - """ - operator = code_fragment.operation - - args = self.args_in_code - settings = code_fragment.code_setting - - if self._node_type == NodeType.OPERATION.value and not is_converted(code_fragment.operation): - args.update({"input_shape": self.input_shape, - "output_shape": self.output_shape}) - - if self._node_type == NodeType.OPERATION.value: - expr = ", ".join([f"{k.replace(f'_{variable_name}', '')}={v}" - for k, v in args.items()]) - ipt_args_settings_in_construct = self._generate_ipt_args_settings_in_construct( - ipt_args_in_construct, settings) - else: - # When it's type is module, class or func, - # it's not necessary to replace var. - expr = ", ".join([f"{k.replace(f'_{variable_name}', '')}={v}" - for k, v in args.items()]) - ipt_args_settings_in_construct = ipt_args_in_construct - - if SEPARATOR_IN_ONNX_OP in operator: - operator = operator.replace(SEPARATOR_IN_ONNX_OP, ".") - - declare = f"self.{variable_name} = {operator}({expr})" - call = f"{', '.join([output for output in output_var])}" \ - f" = self.{variable_name}({ipt_args_settings_in_construct})" - - return declare, call - - def to_ir(self): - """ - No need to implement for now. - """ - raise NotImplementedError - - def _get_raw_params(self, node): - """ - Get params in onnx. - - Args: - node (Any): Node. - - Returns: - dict, raw params. - """ - from .torch_utils import getitem_of_node - - raw_params = dict() - - if not node: - return raw_params - - for k in node.attributeNames(): - raw_params[k] = getitem_of_node(node, k) - return raw_params - - def replace_with_arg(self, src_arg, tgt_arg): - """ - Replace actual parameter with formal parameter. - - Args: - src_arg (str): Original arg name. - tgt_arg (str): Target arg name. - - """ - self._args_in_code[src_arg] = tgt_arg - - @staticmethod - def _extract_var_name(scope_name: str): - """ - Extract variable name from scope name. - """ - if not scope_name: - return None - var = scope_name.split(SEPARATOR_IN_SCOPE)[-1].lower() - var = var.replace(LEFT_BUCKET, SEPARATOR_BTW_NAME_AND_ID).replace( - RIGHT_BUCKET, "") - return var diff --git a/mindinsight/mindconverter/graph_based_converter/third_party_graph/pytorch_graph_parser.py b/mindinsight/mindconverter/graph_based_converter/third_party_graph/pytorch_graph_parser.py index 12741d43..09f2e9b1 100644 --- a/mindinsight/mindconverter/graph_based_converter/third_party_graph/pytorch_graph_parser.py +++ b/mindinsight/mindconverter/graph_based_converter/third_party_graph/pytorch_graph_parser.py @@ -1,4 +1,4 @@ -# Copyright 2020 Huawei Technologies Co., Ltd.All Rights Reserved. +# Copyright 2020-2021 Huawei Technologies Co., Ltd.All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -18,6 +18,7 @@ from importlib import import_module from mindinsight.mindconverter.common.log import logger as log from .base import GraphParser +from .optimizer import OnnxSimplify from ...common.exceptions import ModelNotSupportError @@ -38,7 +39,6 @@ class PyTorchGraphParser(GraphParser): Returns: object, torch model. """ - torch = import_module("torch") if not os.path.exists(model_path): error = FileNotFoundError("`model_path` must be assigned with " @@ -47,14 +47,66 @@ class PyTorchGraphParser(GraphParser): raise error try: - if torch.cuda.is_available(): - model = torch.load(f=model_path) - else: - model = torch.load(f=model_path, map_location="cpu") + onnx_model_sim = cls._convert_pytorch_graph_to_onnx( + model_path, kwargs['sample_shape'], opset_version=11) + return onnx_model_sim + except ModuleNotFoundError: error_msg = "Cannot find model scripts in system path, " \ "set `--project_path` to the path of model scripts folder correctly." error = ModuleNotFoundError(error_msg) raise error - return model + @staticmethod + def _convert_pytorch_graph_to_onnx(model_path, sample_shape, opset_version=None): + """ + Convert Pytorch model to ONNX model. + + Args: + model_path (str): Path to the Pytorch model. + sample_shape (tuple): Input shape to generate onnx model. + opset_version (int): Op set version of onnx. + """ + + torch = import_module('torch') + has_cuda = torch.cuda.is_available() + if has_cuda: + model = torch.load(f=model_path).cuda() + dump_input = torch.randn(*sample_shape, device='cuda') + else: + model = torch.load(f=model_path, map_location="cpu") + dump_input = torch.randn(*sample_shape, device='cpu') + + if isinstance(model, torch.nn.DataParallel): + raise ValueError('torch.nn.DataParallel is not supported by ONNX exporter.') + + torch_onnx = import_module('torch.onnx') + operator_export_types = getattr(torch_onnx, 'OperatorExportTypes') + utils = import_module('torch.onnx.utils') + model_to_graph = getattr(utils, '_model_to_graph') + + symbolic_helper = import_module('torch.onnx.symbolic_helper') + default_onnx_opset_version = getattr(symbolic_helper, '_default_onnx_opset_version') + set_opset_version = getattr(symbolic_helper, '_set_opset_version') + set_operator_export_type = getattr(symbolic_helper, '_set_operator_export_type') + if not opset_version: + opset_version = default_onnx_opset_version + + operator_export_type = operator_export_types.ONNX + set_opset_version(opset_version) + set_operator_export_type(operator_export_type) + + graph, params_dict, _ = model_to_graph(model, dump_input, _retain_param_name=True) + export_onnx = getattr(graph, '_export_onnx') + proto, _ = export_onnx( + params_dict, opset_version, dict(), False, + operator_export_type, True, False, dict(), + True, False) + + onnx = import_module('onnx') + onnx_model = onnx.load_model_from_string(proto) + + onnx_simplify = OnnxSimplify() + onnx_model_sim = onnx_simplify.run_onnx_simplify(onnx_model, sample_shape) + + return onnx_model_sim diff --git a/mindinsight/mindconverter/graph_based_converter/third_party_graph/torch_utils.py b/mindinsight/mindconverter/graph_based_converter/third_party_graph/torch_utils.py deleted file mode 100644 index 623d4089..00000000 --- a/mindinsight/mindconverter/graph_based_converter/third_party_graph/torch_utils.py +++ /dev/null @@ -1,105 +0,0 @@ -# Copyright 2020 Huawei Technologies Co., Ltd.All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""Define pytorch tracer context manager.""" -import importlib - -from torch.nn import Module -from torch.onnx.utils import _trace -from torch.onnx.utils import _node_getitem -from torch.onnx.symbolic_helper import _set_opset_version - - -SCRIPT_METHOD = getattr(importlib.import_module("torch._C"), - "ScriptMethod") -onnx_tracer = _trace -getitem_of_node = _node_getitem -set_opset_version = _set_opset_version - - -def unique_state_dict(model): - """ - Wrapper of torch.jit._unique_state_dict. - - Args: - model (Module): Torch model. - - Returns: - dict, params. - """ - from torch.jit import _unique_state_dict - - return _unique_state_dict(model) - - -def create_autograd_variable(tensor): - """ - Create autograd variable to trace the whole graph. - - Args: - tensor (torch.Tensor): Tensor. - - Returns: - torch.autograd.Variable, variable. - """ - variable = getattr(importlib.import_module("torch.autograd"), "Variable") - return variable(tensor, requires_grad=False) - - -class OverloadTorchModuleTemporarily: - """ - Fix bugs in new version of pytorch. - PyTorch official solution. - """ - - def __init__(self): - self.backup = None - - def __enter__(self): - def _tracing_name(traced_module, tracing_state): - traced_module_stack = getattr(tracing_state, "_traced_module_stack") - if not traced_module_stack: - return None - module = traced_module_stack[-1] - for name, child in module.named_children(): - if child is traced_module: - return name - return None - - def _slow_forward(self_, *inputs, **kwargs): - tracing_state = getattr(importlib.import_module("torch._C"), - "_get_tracing_state")() - if not tracing_state or isinstance(self_.forward, SCRIPT_METHOD): - return self_.forward(*inputs, **kwargs) - if not hasattr(tracing_state, '_traced_module_stack'): - tracing_state._traced_module_stack = [] - name = _tracing_name(self_, tracing_state) - get_name_func = getattr(self_, "_get_name") - if name: - tracing_state.push_scope('%s[%s]' % (get_name_func(), name)) - else: - tracing_state.push_scope(get_name_func()) - tracing_state._traced_module_stack.append(self_) - try: - result = self_.forward(*inputs, **kwargs) - finally: - tracing_state.pop_scope() - tracing_state._traced_module_stack.pop() - return result - - self.backup = getattr(Module, "_slow_forward") - setattr(Module, '_slow_forward', _slow_forward) - - def __exit__(self, exc_type, exc_val, exc_tb): - setattr(Module, '_slow_forward', self.backup) diff --git a/tests/ut/mindconverter/graph_based_converter/hierarchical_tree/test_name_mgr.py b/tests/ut/mindconverter/graph_based_converter/common/test_name_mgr.py similarity index 89% rename from tests/ut/mindconverter/graph_based_converter/hierarchical_tree/test_name_mgr.py rename to tests/ut/mindconverter/graph_based_converter/common/test_name_mgr.py index 7d47956f..666f2ef5 100644 --- a/tests/ut/mindconverter/graph_based_converter/hierarchical_tree/test_name_mgr.py +++ b/tests/ut/mindconverter/graph_based_converter/common/test_name_mgr.py @@ -1,4 +1,4 @@ -# Copyright 2020 Huawei Technologies Co., Ltd.All Rights Reserved. +# Copyright 2020-2021 Huawei Technologies Co., Ltd.All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -14,7 +14,7 @@ # ============================================================================== """Test name manager module.""" from unittest import TestCase -from mindinsight.mindconverter.graph_based_converter.hierarchical_tree.name_mgr import NameMgr, GlobalVarNameMgr, \ +from mindinsight.mindconverter.graph_based_converter.common.name_mgr import NameMgr, GlobalVarNameMgr, \ global_op_namespace diff --git a/tests/ut/mindconverter/graph_based_converter/hierarchical_tree/__init__.py b/tests/ut/mindconverter/graph_based_converter/hierarchical_tree/__init__.py deleted file mode 100644 index 60898242..00000000 --- a/tests/ut/mindconverter/graph_based_converter/hierarchical_tree/__init__.py +++ /dev/null @@ -1,15 +0,0 @@ -# Copyright 2020 Huawei Technologies Co., Ltd.All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""Unit test for mindconvert.graph_based_converter.hierarchical_tree interface.""" diff --git a/tests/ut/mindconverter/graph_based_converter/hierarchical_tree/test_hierarchical_tree.py b/tests/ut/mindconverter/graph_based_converter/hierarchical_tree/test_hierarchical_tree.py deleted file mode 100644 index 45775f86..00000000 --- a/tests/ut/mindconverter/graph_based_converter/hierarchical_tree/test_hierarchical_tree.py +++ /dev/null @@ -1,177 +0,0 @@ -# Copyright 2020 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ -"""Test hierarchical tree module.""" -import os -import shutil -from unittest import mock - -import pytest - -from mindinsight.mindconverter.graph_based_converter.hierarchical_tree.hierarchical_tree import HierarchicalTree -from mindinsight.mindconverter.graph_based_converter.third_party_graph.pytorch_graph_node import PyTorchGraphNode -from mindinsight.mindconverter.graph_based_converter.mapper.base import ONNXToMindSporeMapper -from mindinsight.mindconverter.graph_based_converter.constant import NodeType - -from tests.ut.mindconverter.graph_based_converter.conftest import TEST_BASE_PATH - - -class TestHierarchicalTree: - """Test the class of HierarchicalTree.""" - - def test_tree_identifier(self): - """Test tree_identifier""" - tree = HierarchicalTree() - assert isinstance(tree.tree_identifier, str) - - @mock.patch( - '.'.join((TEST_BASE_PATH, 'third_party_graph.pytorch_graph_node.PyTorchGraphNode._get_raw_params'))) - def test_insert(self, get_raw_params): - """Test insert""" - get_raw_params.return_value = [] - tree = HierarchicalTree() - pt_node = PyTorchGraphNode() - tree.insert(pt_node, 'ResNet') - assert tree.root == 'ResNet' - - def test_remove(self): - """Test remove function.""" - tree = HierarchicalTree() - tree.create_node( - tag='node_root', - identifier='root', - parent=None, - data=None - ) - node = tree.get_node('root') - tree.remove(node) - assert tree.root is None - - @mock.patch( - '.'.join((TEST_BASE_PATH, 'third_party_graph.pytorch_graph_node.PyTorchGraphNode._get_raw_params'))) - def test_shrink(self, get_raw_params): - """Test shrink function.""" - params = {'root': {}, - 'root/child0': {}, - 'root/child0/child1': {}} - tree = self._create_tree(get_raw_params=get_raw_params, params=params) - node = tree.get_node('root/child0') - tree.shrink(node) - assert tree.leaves()[0].tag == 'child1' - - @pytest.mark.parametrize('params', [{ - 'tree_params': {'root': {'op_name': 'Root', - 'precursor_nodes': [], - 'successor_nodes': ['root/relu'], - 'node_type': NodeType.MODULE.value, - 'input_shape': [1, 3, 224, 224], - 'output_shape': [1, 1, 224, 224]}, - 'root/relu': {'op_name': 'onnx::Relu', - 'precursor_nodes': ['root'], - 'successor_nodes': ['root/unknown'], - 'node_type': NodeType.OPERATION.value, - 'input_shape': [1, 3, 224, 224], - 'output_shape': [1, 3, 224, 224]}, - 'root/unknown': {'op_name': 'onnx::Unknown', - 'precursor_nodes': ['root/relu'], - 'successor_nodes': [], - 'node_type': NodeType.OPERATION.value, - 'input_shape': [1, 3, 224, 224], - 'output_shape': [1, 1, 224, 224]}}, - 'report_dir': 'report_folder' - }, { - 'tree_params': {'root': {'op_name': 'Root', - 'precursor_nodes': [], - 'successor_nodes': ['root/relu'], - 'node_type': NodeType.MODULE.value, - 'input_shape': [1, 3, 224, 224], - 'output_shape': [1, 1, 224, 224]}, - 'root/relu': {'op_name': 'onnx::Relu', - 'precursor_nodes': ['root'], - 'successor_nodes': ['root/unknown'], - 'node_type': NodeType.OPERATION.value, - 'input_shape': [1, 3, 224, 224], - 'output_shape': [1, 3, 224, 224]}, - 'root/unknown': {'op_name': 'onnx::Unknown', - 'precursor_nodes': ['root/relu'], - 'successor_nodes': [], - 'node_type': NodeType.OPERATION.value, - 'input_shape': [1, 3, 224, 224], - 'output_shape': [1, 1, 224, 224]}}, - 'report_dir': None - }]) - @mock.patch( - '.'.join((TEST_BASE_PATH, 'third_party_graph.pytorch_graph_node.PyTorchGraphNode._get_raw_params'))) - def test_save_source_file(self, get_raw_params, params): - """Test save_source_file function.""" - tree_params = params['tree_params'] - out_folder = 'out_folder' - report_folder = params['report_dir'] - model_name = 'model_name' - mapper = ONNXToMindSporeMapper() - - tree = self._create_tree(get_raw_params=get_raw_params, params=tree_params) - tree.save_source_files(out_folder, mapper, model_name, report_folder) - - out_path = os.path.realpath(os.path.join(out_folder, f"{model_name}.py")) - report_folder_test = report_folder if report_folder else out_folder - report_path = os.path.realpath( - os.path.join(report_folder_test, f"report_of_{model_name}.txt")) - try: - assert os.path.exists(out_path) - assert os.path.exists(report_path) - with open(out_path, 'r') as out_r: - code = out_r.read() - assert 'nn.ReLU' in code - assert 'onnx.Unknown' in code - with open(report_path, 'r') as report_r: - report = report_r.read() - assert "[UnConvert] 'onnx::Unknown' didn't convert." in report - assert "Converted Rate: 50.00%." in report - finally: - shutil.rmtree(out_folder) - if report_folder: - shutil.rmtree(report_folder) - - @staticmethod - def _create_node(key, val, weight, input_shape, output_shape): - """Create node.""" - node = PyTorchGraphNode(weight=weight) - node.add_input_and_output_shape(input_shape, output_shape) - node.tag = key.split('/')[-1] if len(key.split('/')) > 1 else key - node.op_name = val['op_name'] if val.get('op_name') else None - node.precursor_nodes = val['precursor_nodes'] if val.get('precursor_nodes') else [] - node.successor_nodes = val['successor_nodes'] if val.get('successor_nodes') else [] - node.node_type = val['node_type'] if val.get('node_type') else None - return node - - @staticmethod - def _create_tree(get_raw_params, params): - """Create tree.""" - tree = HierarchicalTree() - for key, val in params.items(): - input_shape = val['input_shape'] if val.get('input_shape') else [] - output_shape = val['output_shape'] if val.get('output_shape') else [] - get_raw_params.return_value = val['op_params'] if val.get('op_params') else dict() - weight = val['weight'] if val.get('weight') else None - - node = TestHierarchicalTree._create_node(key, val, weight, input_shape, output_shape) - - tree.create_node( - tag=node.tag, - identifier=key, - parent='/'.join(key.split('/')[:-1]) if len(key.split('/')) > 1 else None, - data=node - ) - return tree diff --git a/tests/ut/mindconverter/graph_based_converter/mapper/__init__.py b/tests/ut/mindconverter/graph_based_converter/mapper/__init__.py index 393ef6d6..3ef346aa 100644 --- a/tests/ut/mindconverter/graph_based_converter/mapper/__init__.py +++ b/tests/ut/mindconverter/graph_based_converter/mapper/__init__.py @@ -1,4 +1,4 @@ -# Copyright 2020 Huawei Technologies Co., Ltd.All Rights Reserved. +# Copyright 2020-2021 Huawei Technologies Co., Ltd.All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -12,4 +12,4 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""Unit test for mindconvert.graph_based_converter.mapper interface.""" +"""Unit test for mindconverter.graph_based_converter.mapper interface.""" diff --git a/tests/ut/mindconverter/graph_based_converter/mapper/test_mapper.py b/tests/ut/mindconverter/graph_based_converter/mapper/test_mapper.py index 1dd6aa51..eff77c6d 100644 --- a/tests/ut/mindconverter/graph_based_converter/mapper/test_mapper.py +++ b/tests/ut/mindconverter/graph_based_converter/mapper/test_mapper.py @@ -1,4 +1,4 @@ -# Copyright 2020 Huawei Technologies Co., Ltd.All Rights Reserved. +# Copyright 2020-2021 Huawei Technologies Co., Ltd.All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -18,7 +18,6 @@ import pytest from mindinsight.mindconverter.graph_based_converter.mapper.base import ONNXToMindSporeMapper from mindinsight.mindconverter.graph_based_converter.mapper.gen_setting import Setting -from tests.utils import mindspore class TestMappers: @@ -30,7 +29,7 @@ class TestMappers: 'group': 1, 'pads': [1, 2, 3, 4], 'strides': [1, 1]}, - 'weights': {'weight': mindspore.Tensor(np.zeros([64, 3, 1, 1], dtype=np.int32))}}, + 'weights': {'weight': np.zeros((64, 3, 1, 1), dtype=np.int32)}}, 'expected_output': {'converter_name': 'nn.Conv2d', 'converted_params': {'in_channels': 3, 'out_channels': 64, @@ -47,7 +46,7 @@ class TestMappers: 'group': 1, 'pads': [0, 0, 0, 0], 'strides': [1, 1]}, - 'weights': {'weight': mindspore.Tensor(np.zeros([64, 3, 2, 2], dtype=np.int32))}}, + 'weights': {'weight': np.zeros((64, 3, 2, 2), dtype=np.int32)}}, 'expected_output': {'converter_name': 'nn.Conv2d', 'converted_params': {'in_channels': 3, 'out_channels': 64, @@ -61,8 +60,8 @@ class TestMappers: }, { 'input': {'op_name': 'onnx::Gemm', 'params': dict(), - 'weights': {'weight': mindspore.Tensor(np.zeros([10, 3], dtype=np.int32)), - 'bias': mindspore.Tensor(np.zeros([10, 1], dtype=np.int32))}}, + 'weights': {'weight': np.zeros((10, 3), dtype=np.int32), + 'bias': np.zeros((10, 1), dtype=np.int32)}}, 'expected_output': {'converter_name': 'nn.Dense', 'converted_params': {'in_channels': 3, 'out_channels': 10,