Browse Source

!1103 Rewrite pytorch parser module to use onnx model module.

From: @moran3
Reviewed-by: 
Signed-off-by:
tags/v1.2.0-rc1
mindspore-ci-bot Gitee 4 years ago
parent
commit
d85a5f0a17
28 changed files with 312 additions and 2200 deletions
  1. +1
    -1
      mindinsight/mindconverter/graph_based_converter/common/name_mgr.py
  2. +3
    -0
      mindinsight/mindconverter/graph_based_converter/common/utils.py
  3. +20
    -25
      mindinsight/mindconverter/graph_based_converter/framework.py
  4. +4
    -3
      mindinsight/mindconverter/graph_based_converter/generator/args_translator.py
  5. +2
    -2
      mindinsight/mindconverter/graph_based_converter/generator/generator.py
  6. +2
    -2
      mindinsight/mindconverter/graph_based_converter/generator/module_struct.py
  7. +3
    -11
      mindinsight/mindconverter/graph_based_converter/generator/node_struct.py
  8. +0
    -89
      mindinsight/mindconverter/graph_based_converter/hierarchical_tree/__init__.py
  9. +0
    -796
      mindinsight/mindconverter/graph_based_converter/hierarchical_tree/hierarchical_tree.py
  10. +10
    -0
      mindinsight/mindconverter/graph_based_converter/mapper/base.py
  11. +1
    -2
      mindinsight/mindconverter/graph_based_converter/mapper/impl/nn/conv_mapper.py
  12. +6
    -2
      mindinsight/mindconverter/graph_based_converter/mapper/impl/nn/dense_mapper.py
  13. +2
    -1
      mindinsight/mindconverter/graph_based_converter/mapper/impl/nn/pad_mapper.py
  14. +1
    -1
      mindinsight/mindconverter/graph_based_converter/mapper/impl/nn/pool_mapper.py
  15. +7
    -12
      mindinsight/mindconverter/graph_based_converter/third_party_graph/__init__.py
  16. +1
    -1
      mindinsight/mindconverter/graph_based_converter/third_party_graph/base.py
  17. +18
    -5
      mindinsight/mindconverter/graph_based_converter/third_party_graph/onnx_graph.py
  18. +19
    -6
      mindinsight/mindconverter/graph_based_converter/third_party_graph/onnx_utils.py
  19. +144
    -0
      mindinsight/mindconverter/graph_based_converter/third_party_graph/optimizer.py
  20. +0
    -691
      mindinsight/mindconverter/graph_based_converter/third_party_graph/pytorch_graph.py
  21. +0
    -236
      mindinsight/mindconverter/graph_based_converter/third_party_graph/pytorch_graph_node.py
  22. +59
    -7
      mindinsight/mindconverter/graph_based_converter/third_party_graph/pytorch_graph_parser.py
  23. +0
    -105
      mindinsight/mindconverter/graph_based_converter/third_party_graph/torch_utils.py
  24. +2
    -2
      tests/ut/mindconverter/graph_based_converter/common/test_name_mgr.py
  25. +0
    -15
      tests/ut/mindconverter/graph_based_converter/hierarchical_tree/__init__.py
  26. +0
    -177
      tests/ut/mindconverter/graph_based_converter/hierarchical_tree/test_hierarchical_tree.py
  27. +2
    -2
      tests/ut/mindconverter/graph_based_converter/mapper/__init__.py
  28. +5
    -6
      tests/ut/mindconverter/graph_based_converter/mapper/test_mapper.py

mindinsight/mindconverter/graph_based_converter/hierarchical_tree/name_mgr.py → mindinsight/mindconverter/graph_based_converter/common/name_mgr.py View File

@@ -1,4 +1,4 @@
# Copyright 2020 Huawei Technologies Co., Ltd.All Rights Reserved.
# Copyright 2020-2021 Huawei Technologies Co., Ltd.All Rights Reserved.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.

+ 3
- 0
mindinsight/mindconverter/graph_based_converter/common/utils.py View File

@@ -194,6 +194,9 @@ def convert_bytes_string_to_string(bytes_str):


def get_framework_type(model_path): def get_framework_type(model_path):
"""Get framework type.""" """Get framework type."""
if model_path.endswith('.onnx'):
return FrameworkType.PYTORCH.value

try: try:
with open(model_path, 'rb') as f: with open(model_path, 'rb') as f:
if f.read(BINARY_HEADER_PYTORCH_BITS) == BINARY_HEADER_PYTORCH_FILE: if f.read(BINARY_HEADER_PYTORCH_BITS) == BINARY_HEADER_PYTORCH_FILE:


+ 20
- 25
mindinsight/mindconverter/graph_based_converter/framework.py View File

@@ -1,4 +1,4 @@
# Copyright 2020 Huawei Technologies Co., Ltd.All Rights Reserved.
# Copyright 2020-2021 Huawei Technologies Co., Ltd.All Rights Reserved.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
@@ -24,10 +24,12 @@ from mindinsight.mindconverter.graph_based_converter.common.utils import lib_ver
save_code_file_and_report, get_framework_type save_code_file_and_report, get_framework_type
from mindinsight.mindconverter.graph_based_converter.constant import FrameworkType, \ from mindinsight.mindconverter.graph_based_converter.constant import FrameworkType, \
ONNX_MIN_VER, TF2ONNX_MIN_VER, ONNXRUNTIME_MIN_VER ONNX_MIN_VER, TF2ONNX_MIN_VER, ONNXRUNTIME_MIN_VER
from mindinsight.mindconverter.graph_based_converter.generator import batch_add_nodes
from mindinsight.mindconverter.graph_based_converter.mapper import ONNXToMindSporeMapper from mindinsight.mindconverter.graph_based_converter.mapper import ONNXToMindSporeMapper
from mindinsight.mindconverter.common.log import logger as log, logger_console as log_console from mindinsight.mindconverter.common.log import logger as log, logger_console as log_console
from mindinsight.mindconverter.common.exceptions import GraphInitError, TreeCreationError, SourceFilesSaveError, \ from mindinsight.mindconverter.common.exceptions import GraphInitError, TreeCreationError, SourceFilesSaveError, \
BaseConverterError, UnknownModelError, GeneratorError, TfRuntimeError, RuntimeIntegrityError, ParamMissingError BaseConverterError, UnknownModelError, GeneratorError, TfRuntimeError, RuntimeIntegrityError, ParamMissingError
from mindinsight.mindconverter.graph_based_converter.third_party_graph import GraphFactory


permissions = os.R_OK | os.W_OK | os.X_OK permissions = os.R_OK | os.W_OK | os.X_OK
os.umask(permissions << 3 | permissions) os.umask(permissions << 3 | permissions)
@@ -62,6 +64,7 @@ def torch_installation_validation(func):
""" """


def _f(graph_path: str, sample_shape: tuple, def _f(graph_path: str, sample_shape: tuple,
input_nodes: str, output_nodes: str,
output_folder: str, report_folder: str = None): output_folder: str, report_folder: str = None):
# Check whether pytorch is installed. # Check whether pytorch is installed.
if not find_spec("torch") or not find_spec("onnx") or not find_spec("onnxruntime"): if not find_spec("torch") or not find_spec("onnx") or not find_spec("onnxruntime"):
@@ -93,6 +96,7 @@ def torch_installation_validation(func):
sys.exit(0) sys.exit(0)


func(graph_path=graph_path, sample_shape=sample_shape, func(graph_path=graph_path, sample_shape=sample_shape,
input_nodes=input_nodes, output_nodes=output_nodes,
output_folder=output_folder, report_folder=report_folder) output_folder=output_folder, report_folder=report_folder)


return _f return _f
@@ -182,6 +186,7 @@ def _extract_model_name(model_path):
@SourceFilesSaveError.uniform_catcher() @SourceFilesSaveError.uniform_catcher()
@GeneratorError.uniform_catcher() @GeneratorError.uniform_catcher()
def graph_based_converter_pytorch_to_ms(graph_path: str, sample_shape: tuple, def graph_based_converter_pytorch_to_ms(graph_path: str, sample_shape: tuple,
input_nodes: str, output_nodes: str,
output_folder: str, report_folder: str = None): output_folder: str, report_folder: str = None):
""" """
PyTorch to MindSpore based on Graph. PyTorch to MindSpore based on Graph.
@@ -189,26 +194,18 @@ def graph_based_converter_pytorch_to_ms(graph_path: str, sample_shape: tuple,
Args: Args:
graph_path (str): Graph file path. graph_path (str): Graph file path.
sample_shape (tuple): Input shape of the model. sample_shape (tuple): Input shape of the model.
input_nodes (str): Input node(s) of the model.
output_nodes (str): Output node(s) of the model.
output_folder (str): Output folder. output_folder (str): Output folder.
report_folder (str): Report output folder path. report_folder (str): Report output folder path.

""" """
third_party_graph_module = import_module(
'mindinsight.mindconverter.graph_based_converter.third_party_graph')
hierarchical_tree_module = import_module(
'mindinsight.mindconverter.graph_based_converter.hierarchical_tree')
cls_graph_factory = getattr(third_party_graph_module, 'GraphFactory')
cls_hierarchical_tree_factory = getattr(hierarchical_tree_module, 'HierarchicalTreeFactory')

graph_obj = cls_graph_factory.init(graph_path, sample_shape=sample_shape)

hierarchical_tree = cls_hierarchical_tree_factory.create(graph_obj)


graph_obj = GraphFactory.init(graph_path, sample_shape=sample_shape,
input_nodes=input_nodes, output_nodes=output_nodes)
generator_inst = batch_add_nodes(graph_obj, ONNXToMindSporeMapper)
model_name = _extract_model_name(graph_path) model_name = _extract_model_name(graph_path)

hierarchical_tree.save_source_files(output_folder, mapper=ONNXToMindSporeMapper,
model_name=model_name,
report_folder=report_folder)
code_fragments = generator_inst.generate()
save_code_file_and_report(model_name, code_fragments, output_folder, report_folder)




@tf_installation_validation @tf_installation_validation
@@ -230,18 +227,13 @@ def graph_based_converter_tf_to_ms(graph_path: str, sample_shape: tuple,
output_nodes(str): Output node(s) of the model. output_nodes(str): Output node(s) of the model.
output_folder(str): Output folder. output_folder(str): Output folder.
report_folder(str): Report output folder path. report_folder(str): Report output folder path.

""" """
third_party_graph_module = import_module(
'mindinsight.mindconverter.graph_based_converter.third_party_graph')
cls_graph_factory = getattr(third_party_graph_module, 'GraphFactory')
batch_add_nodes = getattr(import_module('mindinsight.mindconverter.graph_based_converter.generator'),
"batch_add_nodes")

# Close unnecessary log. # Close unnecessary log.
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2' os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'


graph_obj = cls_graph_factory.init(graph_path, sample_shape=sample_shape,
input_nodes=input_nodes, output_nodes=output_nodes)
graph_obj = GraphFactory.init(graph_path, sample_shape=sample_shape,
input_nodes=input_nodes, output_nodes=output_nodes)
generator_inst = batch_add_nodes(graph_obj, ONNXToMindSporeMapper) generator_inst = batch_add_nodes(graph_obj, ONNXToMindSporeMapper)
model_name = _extract_model_name(graph_path) model_name = _extract_model_name(graph_path)
code_fragments = generator_inst.generate() code_fragments = generator_inst.generate()
@@ -255,7 +247,6 @@ def main_graph_base_converter(file_config):


Args: Args:
file_config (dict): The config of file which to convert. file_config (dict): The config of file which to convert.

""" """
graph_path = file_config['model_file'] graph_path = file_config['model_file']
frame_type = get_framework_type(graph_path) frame_type = get_framework_type(graph_path)
@@ -263,8 +254,12 @@ def main_graph_base_converter(file_config):
raise ParamMissingError("Param missing, `--shape` is required when using graph mode.") raise ParamMissingError("Param missing, `--shape` is required when using graph mode.")


if frame_type == FrameworkType.PYTORCH.value: if frame_type == FrameworkType.PYTORCH.value:
check_params = ['input_nodes', 'output_nodes']
check_params_exist(check_params, file_config)
graph_based_converter_pytorch_to_ms(graph_path=graph_path, graph_based_converter_pytorch_to_ms(graph_path=graph_path,
sample_shape=file_config['shape'], sample_shape=file_config['shape'],
input_nodes=file_config['input_nodes'],
output_nodes=file_config['output_nodes'],
output_folder=file_config['outfile_dir'], output_folder=file_config['outfile_dir'],
report_folder=file_config['report_dir']) report_folder=file_config['report_dir'])
elif frame_type == FrameworkType.TENSORFLOW.value: elif frame_type == FrameworkType.TENSORFLOW.value:


+ 4
- 3
mindinsight/mindconverter/graph_based_converter/generator/args_translator.py View File

@@ -1,4 +1,4 @@
# Copyright 2020 Huawei Technologies Co., Ltd.All Rights Reserved.
# Copyright 2020-2021 Huawei Technologies Co., Ltd.All Rights Reserved.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
@@ -213,10 +213,11 @@ class ArgsTranslationHelper:
Returns: Returns:
list, name of args to be formal. list, name of args to be formal.
""" """
ret = list()
if len(args_translators) < 2: if len(args_translators) < 2:
# only one args_translator provided, no formal args. # only one args_translator provided, no formal args.
return None
ret = []
return ret
base_args_t = args_translators[0] base_args_t = args_translators[0]
for arg_name, arg_val in base_args_t.actual_args.items(): for arg_name, arg_val in base_args_t.actual_args.items():
for args_t in args_translators[1:]: for args_t in args_translators[1:]:


+ 2
- 2
mindinsight/mindconverter/graph_based_converter/generator/generator.py View File

@@ -1,4 +1,4 @@
# Copyright 2020 Huawei Technologies Co., Ltd.All Rights Reserved.
# Copyright 2020-2021 Huawei Technologies Co., Ltd.All Rights Reserved.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
@@ -24,7 +24,7 @@ from .module_struct import ModuleStruct
from .args_translator import ArgsTranslationHelper from .args_translator import ArgsTranslationHelper
from ..common.global_context import GlobalContext from ..common.global_context import GlobalContext
from ...common.exceptions import GeneratorError from ...common.exceptions import GeneratorError
from ..hierarchical_tree.name_mgr import GlobalVarNameMgr
from ..common.name_mgr import GlobalVarNameMgr
from ..constant import NEW_LINE, SECOND_LEVEL_INDENT, FIRST_LEVEL_INDENT, CodeFormatConfig, get_imported_module from ..constant import NEW_LINE, SECOND_LEVEL_INDENT, FIRST_LEVEL_INDENT, CodeFormatConfig, get_imported_module
from ..report_generator import ReportGenerator from ..report_generator import ReportGenerator




+ 2
- 2
mindinsight/mindconverter/graph_based_converter/generator/module_struct.py View File

@@ -1,4 +1,4 @@
# Copyright 2020 Huawei Technologies Co., Ltd.All Rights Reserved.
# Copyright 2020-2021 Huawei Technologies Co., Ltd.All Rights Reserved.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
@@ -22,7 +22,7 @@ from ..common.utils import get_dict_key_by_value
from .args_translator import ArgsTranslation from .args_translator import ArgsTranslation
from ..common.code_fragment import ModuleFragment from ..common.code_fragment import ModuleFragment
from ..common.global_context import GlobalContext from ..common.global_context import GlobalContext
from ..hierarchical_tree.name_mgr import LocalVarNameMgr
from ..common.name_mgr import LocalVarNameMgr




class ModuleStruct: class ModuleStruct:


+ 3
- 11
mindinsight/mindconverter/graph_based_converter/generator/node_struct.py View File

@@ -1,4 +1,4 @@
# Copyright 2020 Huawei Technologies Co., Ltd.All Rights Reserved.
# Copyright 2020-2021 Huawei Technologies Co., Ltd.All Rights Reserved.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
@@ -18,7 +18,6 @@ from collections import OrderedDict
from .scope_utils import Scope from .scope_utils import Scope
from .args_translator import ArgsTranslation from .args_translator import ArgsTranslation
from ..common.code_fragment import CodeFragment from ..common.code_fragment import CodeFragment
from ..third_party_graph.pytorch_graph_node import PyTorchGraphNode
from ..third_party_graph.onnx_graph_node import OnnxGraphNode from ..third_party_graph.onnx_graph_node import OnnxGraphNode
from ..common.global_context import GlobalContext from ..common.global_context import GlobalContext
from ..constant import InputType from ..constant import InputType
@@ -110,11 +109,6 @@ class NodeStruct:
self.graph_node_ref = gn self.graph_node_ref = gn
self.scope_name = gn.scope_name self.scope_name = gn.scope_name


def _update_from_pytorch_gn(self, gn: PyTorchGraphNode):
"""Update basic info from PyTorchGraphNode."""
self.node_type = "PyTorchGraphNode"
self._update_basics_from_gn(gn)

def _update_from_onnx_gn(self, gn: OnnxGraphNode): def _update_from_onnx_gn(self, gn: OnnxGraphNode):
"""Update basic info from OnnxGraphNode.""" """Update basic info from OnnxGraphNode."""
self.node_type = "OnnxGraphNode" self.node_type = "OnnxGraphNode"
@@ -177,9 +171,8 @@ class NodeStruct:
arg (Union[PyTorchGraphNode, OnnxGraphNode, dict]): Node related obj. arg (Union[PyTorchGraphNode, OnnxGraphNode, dict]): Node related obj.
force_ready (bool): Force this NodeStruct is ready to generate. force_ready (bool): Force this NodeStruct is ready to generate.
""" """
if isinstance(arg, PyTorchGraphNode):
self._update_from_pytorch_gn(arg)
elif isinstance(arg, OnnxGraphNode):

if isinstance(arg, OnnxGraphNode):
self._update_from_onnx_gn(arg) self._update_from_onnx_gn(arg)
elif isinstance(arg, (dict, OrderedDict)): elif isinstance(arg, (dict, OrderedDict)):
self._update_from_mapper(arg) self._update_from_mapper(arg)
@@ -246,7 +239,6 @@ class NodeStruct:
"""Return the output variable name of current node.""" """Return the output variable name of current node."""
return "{}_opt".format(self.ms_var_name).lower() return "{}_opt".format(self.ms_var_name).lower()



@property @property
def args_translator(self): def args_translator(self):
"""Return the args translator of this Node.""" """Return the args translator of this Node."""


+ 0
- 89
mindinsight/mindconverter/graph_based_converter/hierarchical_tree/__init__.py View File

@@ -1,89 +0,0 @@
# Copyright 2020 Huawei Technologies Co., Ltd.All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Hierarchical tree module."""
__all__ = ["HierarchicalTreeFactory"]

import re

from mindinsight.mindconverter.common.log import logger as log
from .hierarchical_tree import HierarchicalTree
from ..third_party_graph.onnx_graph_node import OnnxGraphNode

from ...common.exceptions import NodeInputMissingError, TreeNodeInsertError


def _tf_model_node_name_reformat(node: OnnxGraphNode, node_name):
"""
Rename the node name by combining scope name and its original name.

Args:
node (OnnxGraphNode): OnnxGraphNode instance.
node_name (str): node name saved in Graph.

Returns:
str, re-formatted node name.
"""
scope_name = node.scope_name
new_name = None
regex = r"(?P<parent>.+/)(?P<op>\w+)"
match = re.match(regex, scope_name)
parent = match.group("parent")
node_name = '$' + node_name.replace('/', '::') + '$'

if scope_name:
new_name = parent + node_name
if new_name:
return new_name
return node_name


class HierarchicalTreeFactory:
"""Hierarchical tree factory."""

@classmethod
@TreeNodeInsertError.check_except("Tree node inserts failed.")
def create(cls, graph):
"""
Factory method of hierarchical tree.

Args:
graph: Graph obj.

Returns:
HierarchicalTree, tree.
"""
tree = HierarchicalTree()
node_scope_name = dict()
for _, node_name in enumerate(graph.nodes_in_topological_order):
node_inst = graph.get_node(node_name)
node_input = graph.get_input_shape(node_name)
node_output = graph.get_output_shape(node_name)
if node_input != 0 and not node_input:
err_msg = f"This model is not supported now. " \
f"Cannot find {node_name}'s input shape."
error = NodeInputMissingError(err_msg)
log.error(str(error))
raise error
if isinstance(node_inst, OnnxGraphNode):
node_name_with_scope = _tf_model_node_name_reformat(node_inst, node_name)
node_scope_name[node_name] = node_name_with_scope
node_name = node_name_with_scope

node_inst.add_input_and_output_shape(node_input, node_output)
tree.insert(node_inst, node_name)

if node_scope_name:
return tree, node_scope_name
return tree

+ 0
- 796
mindinsight/mindconverter/graph_based_converter/hierarchical_tree/hierarchical_tree.py View File

@@ -1,796 +0,0 @@
# Copyright 2020 Huawei Technologies Co., Ltd.All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Define hierarchical tree."""
import os
import stat
from copy import deepcopy
from typing import NoReturn, Union
from queue import Queue

from yapf.yapflib.yapf_api import FormatCode
from treelib import Tree, Node

from mindinsight.mindconverter.common.log import logger as log

from .name_mgr import ModuleNameMgr, GlobalVarNameMgr
from ..common.utils import is_converted, save_code_file_and_report
from ..mapper.base import Mapper
from ..third_party_graph.pytorch_graph_node import PyTorchGraphNode
from ..third_party_graph.onnx_graph_node import OnnxGraphNode
from ..constant import SEPARATOR_IN_SCOPE, get_imported_module, NO_CONVERTED_OPERATORS
from ..constant import CodeFormatConfig
from ..constant import SEPARATOR_BTW_NAME_AND_ID, FIRST_LEVEL_INDENT
from ..constant import NEW_LINE, SECOND_LEVEL_INDENT
from ..constant import NodeType
from ..report_generator import ReportGenerator
from ...common.exceptions import ReportGenerationError, ScriptGenerationError, NodeInputTypeNotSupportError


class HierarchicalTree(Tree):
"""Define hierarchical tree."""
flags = os.O_WRONLY | os.O_CREAT | os.O_EXCL
modes = stat.S_IRUSR | stat.S_IWUSR
modes_usr = stat.S_IRUSR | stat.S_IWUSR | stat.S_IXUSR

_root_created = False
ROOT_LEVEL = 0

GLOBAL_VAR_NAME_MGR = GlobalVarNameMgr()

def __init__(self):
super(HierarchicalTree, self).__init__()
self._hierarchical_order = dict()
# Manage mapping of unique key and module name.
self._merged_module = dict()
# Manage mapping of unique key and module args.
self._merged_module_args = dict()
# Record creation of module with unique key.
self._created_module = dict()
# Manage module name to used.
self._module_mgr = ModuleNameMgr()
# Manage variable name in a module.
self._vars_mgr_in_module = dict()
self._module_vars = dict()
# scope name mapping record for easy node searching
self._scope_name_map = dict()
self.code_fragment_recorder = dict()

@property
def tree_identifier(self):
"""
Return identifier of tree.

Returns:
tree, id of tree.
"""
return self.identifier

def get_node(self, nid):
"""Override get_node method to support tf ver. generated scope."""
if nid is None or not self.contains(nid):
if self._scope_name_map and nid in self._scope_name_map:
nid = self._scope_name_map.get(nid)
else:
return None
return self._nodes[nid]

def insert(self, node: Union[PyTorchGraphNode, OnnxGraphNode], node_name: str):
"""
Insert node into hierarchical tree.

Args:
node_name (str): Node name.
node (Union[PyTorchGraphNode, OnnxGraphNode]): Node to be inserted.

"""
scopes = node_name.split(SEPARATOR_IN_SCOPE)
for idx, scope in enumerate(scopes):
parent = SEPARATOR_IN_SCOPE.join(scopes[:idx])
identifier = SEPARATOR_IN_SCOPE.join(scopes[:idx + 1])
try_parent = f"{parent}{SEPARATOR_IN_SCOPE}{scope}" \
if parent else scope
if self.contains(try_parent):
# Whether current node existed.
parent = try_parent

if not parent and not self._root_created:
# If no root node, then create it and mark it.
parent = None
self._root_created = True
elif not parent and self._root_created:
# Already have root node, skip it.
continue

if not self.contains(identifier):
# Insert node into tree.
if isinstance(node, OnnxGraphNode):
tgt_node = node if idx == len(
scopes) - 1 else OnnxGraphNode()
else:
tgt_node = node if idx == len(
scopes) - 1 else PyTorchGraphNode()
tgt_node.successor_nodes = node.successor_nodes
tgt_node.precursor_nodes = node.precursor_nodes
tgt_node.node_type = (NodeType.OPERATION if idx == len(scopes) - 1
else NodeType.MODULE).value
tgt_node.variable_name = self._get_var_name(identifier)
self.create_node(
tag=scope.split(SEPARATOR_BTW_NAME_AND_ID)[0],
identifier=identifier,
parent=parent,
data=tgt_node
)

def remove(self, node: Node, keep_sub=False):
"""
Remove node into hierarchical tree.

Args:
node (Node): Node to be removed.
keep_sub (bool): Whether keep sub-tree.

"""
if not keep_sub:
self.remove_node(node.identifier)
return

def shrink(self, node: Node):
"""
Shrink sub-tree into one node.

Use child node to replace its ancestor.

Args:
node (Node): List of nodes to be merged.

"""
node_name = node.identifier
parent_node = self[node.predecessor(self.tree_identifier)]
# Keep successors of parent.
brothers = deepcopy(parent_node.successors(self.tree_identifier))
# Because shrink occurs when node has only one child,
# so we take index-0.
child = node.successors(self.tree_identifier)[0]
self.move_node(source=child,
destination=node.predecessor(self.tree_identifier))
self.remove(node)
brothers[brothers.index(node_name)] = child
parent_node.set_successors(brothers, tree_id=self.tree_identifier)

def save_source_files(self, out_folder: str, mapper: Mapper,
model_name: str,
report_folder: str = None,
scope_name_map: dict = None) -> NoReturn:
"""
Save source codes to target folder.

Args:
report_folder (str): Report folder.
mapper (Mapper): Mapper of third party framework and mindspore.
model_name(str): Name of Converted model.
out_folder (str): Output folder.
scope_name_map(str): Scope name map of tensorflow.

"""
if scope_name_map:
self._scope_name_map = scope_name_map
try:
self._adjust_structure()
code_fragments = self._generate_codes(mapper)
except (NodeInputTypeNotSupportError, ScriptGenerationError, ReportGenerationError) as e:
log.error("Error occur when generating codes.")
raise e

save_code_file_and_report(model_name, code_fragments, out_folder, report_folder)

def _preprocess_node_args(self, node, module_key):
"""
Remove unused args.

Args:
node (Node): Node instance.
module_key (str): Nodule key.

Returns:
Node, node.
"""
if module_key in self._merged_module_args:
node = self._clear_unused_args(
node, self._merged_module_args[module_key])
else:
node.data.clear_args_of_declaration()
return node

def _postprocess_node_args(self, node, precursor_module_key):
"""
Post process args in node.

Args:
node (Node): Node instance.
precursor_module_key (str): Parent node module name.

Returns:
Node, node.
"""
if node.data.node_type in {NodeType.MODULE.value, NodeType.CLASS.value,
NodeType.FUNC.value}:
# If current node is class or function, then
# remove unused args in __init__.
cur_module_key = node.data.hash_key or self.hash_key(node)
if cur_module_key in self._merged_module_args:
node = self._clear_unused_args(node,
self._merged_module_args[cur_module_key])

# `self._merged_module_args` records formal args.
# We need to replace actual args.
if precursor_module_key in self._merged_module_args:
# If parent node is in `_merged_module_args`, then
# replace current node args with arg name declared
# in _merged_module_args.
for arg in node.data.args_in_code.keys():
if arg in self._merged_module_args[precursor_module_key]:
node.data.replace_with_arg(arg, arg)
return node

def _clear_unused_args(self, node, used_args):
"""
Clear unused args.

Args:
node (Node): Node.
used_args (list): Args list.

Returns:
Node, node instance.
"""
args_in_code = list(node.data.args_in_code.keys())
for arg in args_in_code:
ori_arg = arg.replace(
f"_{self.code_fragment_recorder[node.identifier].declared_var_name}", ""
)
if ori_arg not in used_args:
node.data.args_in_code.pop(arg)
return node

@ScriptGenerationError.check_except("FormatCode run error. Check detailed information in log.")
@ReportGenerationError.check_except("Not find valid operators in converted script.")
def _generate_codes(self, mapper):
"""
Generate code files.

- 1. Generate args.
- 2. Merge module.
- 3. Pre-process node args.
- 4. Post-process child node args.
- 5. Generate class/func code.
- 6. Merge code snippets.

Args:
mapper (Mapper): Mapper of third party operation and mindspore.

Returns:
Dict, codes.
"""
code_blocks = [get_imported_module()]
depths = sorted(list(self._hierarchical_order.keys()), reverse=True)

for depth in depths:
node_collection = self._hierarchical_order[depth]
for node_name in node_collection:
# Traverse nodes in topological order.
node = self.get_node(node_name)
# 1. Generate args for each node in this level.
if node.data.node_type == NodeType.MODULE.value:
self._create_module_args_and_vars(node, mapper)
if depth == depths[-1]:
self.code_fragment_recorder[node.identifier] = node.data.param_transform(mapper, "")

# Module merging based on all nodes.
self._module_merging()

for depth in depths:
node_collection = self._hierarchical_order[depth]
snippets = set()
for node_name in node_collection:
nd_inst = self.get_node(node_name)
if nd_inst.data.node_type != NodeType.MODULE.value:
continue

# Generate hash key for node.
module_key = nd_inst.data.hash_key
# Get code generation func.
func, node_type = self._fetch_func_and_type(nd_inst)

if module_key in self._created_module:
# If the module has already been created,
# then assign the created module name to current node,
# and delete unused args.
module_name = self._created_module[module_key]
self.code_fragment_recorder[nd_inst.identifier].operation = module_name
self.code_fragment_recorder[nd_inst.identifier].node_type = node_type
self._preprocess_node_args(nd_inst, module_key)
continue

module_name = nd_inst.tag

if node_type == NodeType.CLASS.value:
module_name = f"{module_name[0].upper()}{module_name[1:]}"

# After node_type and module_name is frozen,
# then it's unchangeable.
module_name = self._module_mgr.get_name(module_name)
self.code_fragment_recorder[nd_inst.identifier].operation = module_name
self.code_fragment_recorder[nd_inst.identifier].node_type = node_type

# 3. Pre-process node args.
nd_inst = self._preprocess_node_args(nd_inst, module_key)
# 4. Post-process child node args.
for _, scsr_nd_name in enumerate(nd_inst.successors(self.tree_identifier)):
self._postprocess_node_args(self.get_node(scsr_nd_name), module_key)
# 5. Generate code.
snippets.add(func(nd_inst, self.code_fragment_recorder[nd_inst.identifier].operation, module_key))

code_blocks.extend(snippets)

if self._scope_name_map: # from tf. conversion
c_blocks = []
for s in code_blocks:
s = s.replace('$', '')
c_blocks.append(s)
code_blocks = c_blocks

formatted_code, _ = FormatCode("".join(code_blocks),
style_config=CodeFormatConfig.PEP8.value)

report_generator = ReportGenerator()
report = report_generator.gen_report(formatted_code)

return {"model": (formatted_code, report)}

def _fetch_func_and_type(self, node) -> Union[object, str]:
"""
Generate code snippet.

Args:
node (Node): Node.

Returns:
Union[object, str], code snippet func.
"""

def _is_func():
"""
The correct thought is to check whether have more than one
path in this block.
"""
nonlocal node

if node.predecessor(self.tree_identifier) is None:
return False

tgt_type = {NodeType.MODULE.value,
NodeType.FUNC.value, NodeType.CLASS.value}
md_type_lst = [self.get_node(child).data.node_type
for child in node.successors(self.tree_identifier)]
diff_set = set(md_type_lst) - tgt_type
return not diff_set

if _is_func():
return self._generate_func_snippet, NodeType.FUNC.value
return self._generate_class_snippet, NodeType.CLASS.value

def _generate_func_snippet(self, node, func_name, func_key):
"""
Generate function snippet.

Args:
node (Node): Node inst.

Returns:
str, code snippet.
"""
definition = ""

if func_key.lower() in self._merged_module_args and \
self._merged_module_args[func_key.lower()]:
definition = ", ".join(self._merged_module_args[func_key.lower()])

module_list = []
for node_name in node.successors(self.tree_identifier):
c_nd = self.get_node(node_name)
operator = self.code_fragment_recorder[c_nd.identifier].operation

if c_nd.data.node_type != NodeType.OPERATION.value:
hash_key = c_nd.data.hash_key or self.hash_key(c_nd)
if hash_key in self._created_module:
operator = self._created_module[hash_key]

args = c_nd.data.args_in_code
if c_nd.data.node_type == NodeType.OPERATION.value and not is_converted(
self.code_fragment_recorder[c_nd.identifier].operation):
args.update({"input_shape": c_nd.data.input_shape,
"output_shape": c_nd.data.output_shape})

# Generate code statement.
expr = ", ".join(
[f"{k.replace(f'_{self.code_fragment_recorder[c_nd.identifier].declared_var_name}', '')}={v}"
for k, v in args.items()]
)
code_line = f"{operator}({expr})"
module_list.append(code_line)

body = f",{NEW_LINE}{SECOND_LEVEL_INDENT}".join(module_list)
snippet = f"{FIRST_LEVEL_INDENT}module_list = [{NEW_LINE}" \
f"{SECOND_LEVEL_INDENT}{body}{NEW_LINE}" \
f"{FIRST_LEVEL_INDENT}]{NEW_LINE}" \
f"{FIRST_LEVEL_INDENT}return nn.SequentialCell(*module_list)"
definition = f"def {func_name}({definition}):{NEW_LINE}"

# Mark the structure has been created.
self._created_module[func_key.lower()] = func_name

return f"{definition}{snippet}{NEW_LINE * 3}"

def _generate_class_snippet(self, node, class_name, class_key):
"""
Generate class-type code snippet.

Args:
node (Node): Node.

Returns:
str, code snippet.
"""
super_call = f"super({class_name}, self).__init__()"

if class_key.lower() in self._merged_module_args and \
self._merged_module_args[class_key.lower()]:
args = f"{', '.join(self._merged_module_args[class_key.lower()])}"

class_init = f"{FIRST_LEVEL_INDENT}def __init__(self, " \
f"{args}):" \
f"{NEW_LINE}{SECOND_LEVEL_INDENT}" \
f"{super_call}{NEW_LINE}{SECOND_LEVEL_INDENT}"
else:
class_init = f"{FIRST_LEVEL_INDENT}def __init__(self):{NEW_LINE}{SECOND_LEVEL_INDENT}" \
f"{super_call}{NEW_LINE}{SECOND_LEVEL_INDENT}"

init_block = []
construct_block = []

for idx, node_name in enumerate(node.successors(self.tree_identifier)):
nd_inst = self.get_node(node_name)
if nd_inst.data.op_name in NO_CONVERTED_OPERATORS:
continue

# Generate code statement.
init, construct = self._generate_stat(nd_inst, node, idx)

# support multiple construct and init block returns:
if isinstance(construct, list):
construct_block += construct
else:
construct_block.append(construct)

if isinstance(init, list):
init_block += init
else:
init_block.append(init)

class_construct = f"{NEW_LINE}{FIRST_LEVEL_INDENT}def construct(self, x):" \
f"{NEW_LINE}{SECOND_LEVEL_INDENT}"
init_body = f"{NEW_LINE}{SECOND_LEVEL_INDENT}".join(init_block)
csrt_body = f"{NEW_LINE}{SECOND_LEVEL_INDENT}".join(construct_block)
csrt_rtn = f"{NEW_LINE}{SECOND_LEVEL_INDENT}return output{NEW_LINE}"

cls_definition = f"class {class_name}(nn.Cell):{NEW_LINE * 2}"

# Mark the structure has been created.
self._created_module[class_key.lower()] = class_name

return f"{cls_definition}" \
f"{class_init}" \
f"{init_body}{NEW_LINE}" \
f"{class_construct}" \
f"{csrt_body}{csrt_rtn}{NEW_LINE * 2}"

def _generate_stat(self, cur_nd_inst, pre_nd_inst, idx):
"""
Generate statements.

Args:
cur_nd_inst (Node): Current node instance.
pre_nd_inst (Node): Precursor node instance.
idx (int): Index of cur node.

Returns:
Tuple[str, str], declare in init and call in construct.
"""

ipt_args_in_construct = "x"
opt_arg_in_construct = ["output"]

if idx != 0:
if cur_nd_inst.data.is_in_multi_opt_graph:
ipt_args_in_construct = self._get_current_ipt_var(cur_nd_inst)
else:
# Get previous node output variable name.
ipt_args_in_construct = self._get_previous_opt_var(cur_nd_inst, pre_nd_inst)
if idx != len(pre_nd_inst.successors(self.tree_identifier)) - 1:
# Set opt variable name.
if cur_nd_inst.data.node_type == NodeType.MODULE.value or not cur_nd_inst.data.is_in_multi_opt_graph:
opt_arg_in_construct = [
f"{self.code_fragment_recorder[cur_nd_inst.identifier].declared_var_name}_opt"
]
else:
opt_arg_in_construct = [
f"opt_{var_name}"
for var_name in self.code_fragment_recorder[cur_nd_inst.identifier].output_var_name
]

declare, call = cur_nd_inst.data.to_code(ipt_args_in_construct=ipt_args_in_construct,
variable_name=self.code_fragment_recorder[
cur_nd_inst.identifier].declared_var_name,
output_var=opt_arg_in_construct,
code_fragment=self.code_fragment_recorder[cur_nd_inst.identifier])

return declare, call

@staticmethod
def _get_var_name(s):
"""
Get variable name using scope name.

Args:
s (str): String.

Returns:
str, variable name.
"""
return s.split(SEPARATOR_IN_SCOPE)[-1].lower().split(SEPARATOR_BTW_NAME_AND_ID)[0]

def _get_current_ipt_var(self, cur_nd):
""""
Get current input variable name from node_id.

Args:
cur_nd (Node): Current node.

Returns:
str, needed var names.
"""
if cur_nd.data.node_type != NodeType.OPERATION.value:
while True:
p_nd = cur_nd.successors(self.tree_identifier)
if not p_nd:
break
cur_nd = self.get_node(p_nd[0])

ipt_lst_raw = []
for operation_input in self.code_fragment_recorder[cur_nd.identifier].operation_inputs:
ipt_lst_raw.append(f"{operation_input}")

opt_var_names_p_nds = set()
for e in cur_nd.data.precursor_nodes:
p_nd = self.get_node(e)
if p_nd.data.op_name in NO_CONVERTED_OPERATORS:
continue

opt_var_names_p_nd = set(p_nd.data.opt_var_names)
opt_var_names_p_nds = set.union(opt_var_names_p_nds, opt_var_names_p_nd)

ipt_lst = [f"opt_{ipt}" for ipt in set(ipt_lst_raw).intersection(opt_var_names_p_nds)]
return ", ".join(ipt_lst)

def _find_all_previous_opt_var_(self, cur_nd, pre_nd):
"""
Find all input variable names.

Args:
cur_nd (Node): Current node.
pre_nd (Node): Precursor node.

Returns:
list, needed var names list.
"""
ipt_lst = []
if cur_nd.tag in NO_CONVERTED_OPERATORS:
return ipt_lst

for e in cur_nd.data.precursor_nodes:
p_nd = self.get_node(e)
if e not in pre_nd.successors(self.tree_identifier):
while True:
if p_nd.identifier in pre_nd.successors(self.tree_identifier):
ipt_lst.append(
f"{self.code_fragment_recorder[p_nd.identifier].declared_var_name}_opt"
)
break
pre_nd_name = p_nd.predecessor(self.tree_identifier)
if not pre_nd_name:
ipt_lst.append("x")
break
p_nd = self.get_node(pre_nd_name)
continue
ipt_lst.append(
f"{self.code_fragment_recorder[p_nd.identifier].declared_var_name}_opt"
)
return ipt_lst

def _get_previous_opt_var(self, cur_nd, pre_nd):
"""
Get needed input variable names.

Args:
cur_nd (Node): Current node.
pre_nd (Node): Precursor node.

Returns:
str, needed var names.
"""
if cur_nd.data.node_type != NodeType.OPERATION.value:
while True:
p_nd = cur_nd.successors(self.tree_identifier)
if not p_nd:
break
cur_nd = self.get_node(p_nd[0])
return ", ".join(self._find_all_previous_opt_var_(cur_nd, pre_nd))

def hash_key(self, node):
"""
Generate hash key for each node.

Args:
node (Node): Node.

Returns:
str, hash key.
"""
scsr_topo_order = []
for s in node.successors(self.tree_identifier):
cur_nd = self.get_node(s)
if cur_nd.data.node_type in {NodeType.MODULE.value,
NodeType.FUNC.value,
NodeType.CLASS.value}:
if cur_nd.data.hash_key:
scsr_topo_order.append(f"({cur_nd.data.hash_key})")
continue

raise ValueError("Current node doesn't have hash key.")

if cur_nd.data.hash_key:
scsr_topo_order.append(cur_nd.data.hash_key)
continue
unique_key = "->".join(scsr_topo_order)
node.data.hash_key = unique_key
return unique_key

def _module_merging(self):
"""Generate sub-module and corresponding params."""
merged_module_args = dict()
for module_key, module_args in self._merged_module.items():
if module_key not in merged_module_args:
merged_module_args[module_key] = []
# Take first element's args as base.
keys = module_args[0].keys()
for key in keys:
for i in range(1, len(module_args)):
if key in module_args[i] and module_args[0][key] != module_args[i][key]:
merged_module_args[module_key].append(key)
break
if key not in module_args[i]:
merged_module_args[module_key].append(key)
break

self._merged_module_args.update(merged_module_args)

def _create_module_args_and_vars(self, node, mapper):
"""
Create module args and variables in current node.

Args:
node (Node): Node on tree.
mapper (Mapper): Mapper of params.

"""
# All args and value pair in current node module.
module_args = dict()
module_key = self.hash_key(node)
created = False

if module_key not in self._vars_mgr_in_module:
self._vars_mgr_in_module[module_key] = self.GLOBAL_VAR_NAME_MGR
self._module_vars[module_key] = []
else:
created = True

# Sub-modules in the module could have arg name conflicts.
for idx, successor_name in enumerate(node.successors(self.tree_identifier)):
nd_inst = self.get_node(successor_name)
if nd_inst.data.op_name in NO_CONVERTED_OPERATORS:
continue

# Generation of params must behind variable assigment.
if created:
variable_name = self._module_vars[module_key][idx]
else:
variable_name = nd_inst.data.op_name or nd_inst.tag
variable_name = self._vars_mgr_in_module[module_key].get_name(variable_name)

code_fragment = nd_inst.data.param_transform(mapper, variable_name)
code_fragment.declared_var_name = variable_name
code_fragment.output_var_name = nd_inst.data.opt_var_names
code_fragment.update_operation_inputs(nd_inst.data.ipt_var_names)
self.code_fragment_recorder[nd_inst.identifier] = code_fragment

module_args.update(nd_inst.data.args_in_code)

if not created:
self._module_vars[module_key].append(variable_name)

node.data.args_in_code = module_args

# Collect module args of `module_key`.
if module_key not in self._merged_module:
self._merged_module[module_key] = [deepcopy(node.data.args_in_code)]
else:
self._merged_module[module_key].append(deepcopy(node.data.args_in_code))

@staticmethod
def _create_operation_args(node, mapper):
"""
Create operation args.

Args:
node (Node): Node on tree.
mapper (Mapper): Mapper of params.

"""
node.data.param_transform(mapper)

def update_hierarchical_order(self) -> NoReturn:
"""
Update hierarchical order.
"""
hierarchical_order = dict()
queue = Queue()
queue.put(item=(self.root, self.ROOT_LEVEL), block=False)
while not queue.empty():
node_name, cur_level = queue.get(block=False)
node_inst = self[node_name]
if cur_level not in hierarchical_order:
hierarchical_order[cur_level] = []
hierarchical_order[cur_level].append(node_name)
for successor_name in node_inst.successors(self.tree_identifier):
queue.put(item=(successor_name, cur_level + 1), block=False)
self._hierarchical_order = hierarchical_order

def sub_graph_merging(self) -> NoReturn:
"""Shrink the module has only one child."""
self.update_hierarchical_order()
depths = sorted(list(self._hierarchical_order.keys()), reverse=True)
for depth in depths:
for node_name in self._hierarchical_order[depth]:
node_inst = self[node_name]
# If the node type is module and has only one child,
# then merge it with its child.
if node_inst.data.node_type == NodeType.MODULE.value and \
len(node_inst.successors(self.tree_identifier)) == 1:
self.shrink(node_inst)

def _adjust_structure(self) -> NoReturn:
"""Adjust tree structure to generate source code."""
self.sub_graph_merging()
self.update_hierarchical_order()

+ 10
- 0
mindinsight/mindconverter/graph_based_converter/mapper/base.py View File

@@ -171,3 +171,13 @@ class ONNXToMindSporeMapper(Mapper, abc.ABC):
outputs_list = [f"opt_{{{variable_slot}}}"] outputs_list = [f"opt_{{{variable_slot}}}"]
outputs_mapping = ((0, 0),) outputs_mapping = ((0, 0),)
return template, exchange_msg, outputs_list, outputs_mapping return template, exchange_msg, outputs_list, outputs_mapping

@staticmethod
def _find_val_by_index(loc_index, values_dict):
"""Find value by location index of values_dict."""
result = None
for idx, dict_val in enumerate(values_dict.values()):
if idx == loc_index:
result = dict_val
break
return result

+ 1
- 2
mindinsight/mindconverter/graph_based_converter/mapper/impl/nn/conv_mapper.py View File

@@ -42,7 +42,7 @@ class ConvMapper(ONNXToMindSporeMapper):
"""Convert params from PyTorch to MindSpore""" """Convert params from PyTorch to MindSpore"""
weights = kwargs['weights'] weights = kwargs['weights']
params = kwargs['params'] params = kwargs['params']
weight = weights['weight'].numpy()
weight = weights['weight']
weight = np.transpose(weight, list(range(2, weight.ndim)) + [1, 0]) weight = np.transpose(weight, list(range(2, weight.ndim)) + [1, 0])
if isinstance(params['dilations'], list): if isinstance(params['dilations'], list):
dilation = tuple(params['dilations']) dilation = tuple(params['dilations'])
@@ -130,7 +130,6 @@ class ConvMapper(ONNXToMindSporeMapper):
dim = len(kernel_size) dim = len(kernel_size)
return f"nn.Conv{dim}d" return f"nn.Conv{dim}d"


weight = weight.numpy()
dim = weight.ndim - 2 dim = weight.ndim - 2
return f"nn.Conv{dim}d" return f"nn.Conv{dim}d"




+ 6
- 2
mindinsight/mindconverter/graph_based_converter/mapper/impl/nn/dense_mapper.py View File

@@ -13,6 +13,7 @@
# limitations under the License. # limitations under the License.
# ============================================================================== # ==============================================================================
"""Mapper module.""" """Mapper module."""
import numpy as np
from mindinsight.mindconverter.graph_based_converter.mapper.base import ONNXToMindSporeMapper from mindinsight.mindconverter.graph_based_converter.mapper.base import ONNXToMindSporeMapper
from mindinsight.mindconverter.graph_based_converter.mapper.gen_setting import Setting from mindinsight.mindconverter.graph_based_converter.mapper.gen_setting import Setting


@@ -27,8 +28,11 @@ class DenseMapper(ONNXToMindSporeMapper):
@staticmethod @staticmethod
def _convert_params(**kwargs): def _convert_params(**kwargs):
weights = kwargs['weights'] weights = kwargs['weights']
has_bias = bool('bias' in weights)
weight = weights['weight'].numpy().transpose()
weight_index = 0
bias_index = 1
bias = DenseMapper._find_val_by_index(bias_index, weights)
has_bias = isinstance(bias, np.ndarray)
weight = DenseMapper._find_val_by_index(weight_index, weights).transpose()
in_channels, out_channels = weight.shape in_channels, out_channels = weight.shape
return { return {
'in_channels': in_channels, 'in_channels': in_channels,


+ 2
- 1
mindinsight/mindconverter/graph_based_converter/mapper/impl/nn/pad_mapper.py View File

@@ -13,6 +13,7 @@
# limitations under the License. # limitations under the License.
# ============================================================================== # ==============================================================================
"""Mapper module.""" """Mapper module."""
from mindinsight.mindconverter.graph_based_converter.common.utils import convert_bytes_string_to_string
from mindinsight.mindconverter.graph_based_converter.mapper.base import ONNXToMindSporeMapper from mindinsight.mindconverter.graph_based_converter.mapper.base import ONNXToMindSporeMapper
from mindinsight.mindconverter.graph_based_converter.mapper.gen_setting import Setting from mindinsight.mindconverter.graph_based_converter.mapper.gen_setting import Setting


@@ -47,7 +48,7 @@ class PadMapper(ONNXToMindSporeMapper):
def _convert_params(**kwargs): def _convert_params(**kwargs):
weights = kwargs.get("weights") weights = kwargs.get("weights")
params = kwargs.get("params") params = kwargs.get("params")
mode = params.get('mode', 'constant')
mode = convert_bytes_string_to_string(params.get('mode', 'constant'))
pads_onnx = params.get("pads") if params.get("pads") else list(weights.values())[0].tolist() pads_onnx = params.get("pads") if params.get("pads") else list(weights.values())[0].tolist()
if mode == 'constant' and params.get('value') is None: if mode == 'constant' and params.get('value') is None:
if params.get('pads') or weights: if params.get('pads') or weights:


+ 1
- 1
mindinsight/mindconverter/graph_based_converter/mapper/impl/nn/pool_mapper.py View File

@@ -36,7 +36,7 @@ class PoolMapper(ONNXToMindSporeMapper):
transformed_params["kernel_size"] = tuple(params['kernel_shape']) transformed_params["kernel_size"] = tuple(params['kernel_shape'])
transformed_params["stride"] = tuple(params['strides']) transformed_params["stride"] = tuple(params['strides'])
if "pads" in params: if "pads" in params:
if sum(params['pads']) == 0:
if sum(params['pads']) == 0 and not params.get('ceil_mode', None):
pad_mode = '\"valid\"' pad_mode = '\"valid\"'
else: else:
pad_mode = '\"same\"' pad_mode = '\"same\"'


+ 7
- 12
mindinsight/mindconverter/graph_based_converter/third_party_graph/__init__.py View File

@@ -1,4 +1,4 @@
# Copyright 2020 Huawei Technologies Co., Ltd.All Rights Reserved.
# Copyright 2020-2021 Huawei Technologies Co., Ltd.All Rights Reserved.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
@@ -39,14 +39,9 @@ class GraphFactory:
Returns: Returns:
Graph, graph instance. Graph, graph instance.
""" """
if all([input_nodes, output_nodes]):
onnx_graph_module = import_module(
'mindinsight.mindconverter.graph_based_converter.third_party_graph.onnx_graph')
onnx_graph = getattr(onnx_graph_module, 'OnnxGraph')
return onnx_graph.load(model_path=graph_path, input_nodes=input_nodes,
output_nodes=output_nodes, sample_shape=sample_shape)

pytorch_graph_module = import_module(
'mindinsight.mindconverter.graph_based_converter.third_party_graph.pytorch_graph')
pytorch_graph = getattr(pytorch_graph_module, 'PyTorchGraph')
return pytorch_graph.load(model_path=graph_path, sample_shape=sample_shape)

onnx_graph_module = import_module(
'mindinsight.mindconverter.graph_based_converter.third_party_graph.onnx_graph')
onnx_graph = getattr(onnx_graph_module, 'OnnxGraph')
return onnx_graph.load(model_path=graph_path, input_nodes=input_nodes,
output_nodes=output_nodes, sample_shape=sample_shape)

+ 1
- 1
mindinsight/mindconverter/graph_based_converter/third_party_graph/base.py View File

@@ -264,7 +264,7 @@ class Graph(BaseGraph, abc.ABC):
Returns: Returns:
cls, graph instance. cls, graph instance.
""" """
src_graph = cls.load_graph(graph_path=model_path, **kwargs)
src_graph = cls.load_graph(graph_path=model_path, sample_shape=sample_shape, **kwargs)
ckpt = cls.load_checkpoint(ckpt_path=checkpoint) if checkpoint else None ckpt = cls.load_checkpoint(ckpt_path=checkpoint) if checkpoint else None


if ckpt is not None: if ckpt is not None:


+ 18
- 5
mindinsight/mindconverter/graph_based_converter/third_party_graph/onnx_graph.py View File

@@ -1,4 +1,4 @@
# Copyright 2020 Huawei Technologies Co., Ltd.All Rights Reserved.
# Copyright 2020-2021 Huawei Technologies Co., Ltd.All Rights Reserved.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
@@ -13,12 +13,14 @@
# limitations under the License. # limitations under the License.
# ============================================================================== # ==============================================================================
"""Define ONNX graph.""" """Define ONNX graph."""
from importlib import import_module
from typing import Dict, NoReturn from typing import Dict, NoReturn


from mindinsight.mindconverter.common.log import logger as log from mindinsight.mindconverter.common.log import logger as log
from .base import Graph from .base import Graph
from .input_node import InputNode from .input_node import InputNode
from .onnx_graph_node import OnnxGraphNode from .onnx_graph_node import OnnxGraphNode
from .pytorch_graph_parser import PyTorchGraphParser
from .tf_graph_parser import TFGraphParser from .tf_graph_parser import TFGraphParser
from .onnx_utils import OnnxDataLoader from .onnx_utils import OnnxDataLoader


@@ -151,7 +153,7 @@ class OnnxGraph(Graph):
input_shape (tuple): Input shape. input_shape (tuple): Input shape.
""" """
input_node = InputNode(input_shape) input_node = InputNode(input_shape)
input_node_name = self._raw_input_nodes.replace(":0", "")
input_node_name = self._raw_input_nodes
for node_name, node in self._nodes_collection.items(): for node_name, node in self._nodes_collection.items():
if node_name in self._input_nodes: if node_name in self._input_nodes:
ipt_nd_name = input_node_name.format(input_node.scope_name) ipt_nd_name = input_node_name.format(input_node.scope_name)
@@ -196,7 +198,18 @@ class OnnxGraph(Graph):
""" """
tf_input_nodes = kwargs.get('input_nodes') tf_input_nodes = kwargs.get('input_nodes')
tf_output_nodes = kwargs.get('output_nodes') tf_output_nodes = kwargs.get('output_nodes')
onnx_model = TFGraphParser.parse(graph_path,
input_nodes=tf_input_nodes,
output_nodes=tf_output_nodes)
if graph_path.endswith('.pb'):
onnx_model = TFGraphParser.parse(graph_path,
input_nodes=tf_input_nodes,
output_nodes=tf_output_nodes)
elif graph_path.endswith('.onnx'):
onnx = import_module('onnx')
onnx_model = onnx.load(graph_path)
optimizer = import_module(
'mindinsight.mindconverter.graph_based_converter.third_party_graph.optimizer')
onnx_simplify = getattr(optimizer, 'OnnxSimplify')()
onnx_model = onnx_simplify.run_onnx_simplify(onnx_model, kwargs['sample_shape'])

else:
onnx_model = PyTorchGraphParser.parse(graph_path, **kwargs)
return onnx_model return onnx_model

+ 19
- 6
mindinsight/mindconverter/graph_based_converter/third_party_graph/onnx_utils.py View File

@@ -1,4 +1,4 @@
# Copyright 2020 Huawei Technologies Co., Ltd.All Rights Reserved.
# Copyright 2020-2021 Huawei Technologies Co., Ltd.All Rights Reserved.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
@@ -112,10 +112,10 @@ class OnnxTensor:


def to_array(self): def to_array(self):
"""Convert the tensor value from binary to np array.""" """Convert the tensor value from binary to np array."""
onnx = import_module("onnx")
numpy_helper = import_module("onnx.numpy_helper")
# Convert binary data to np.array # Convert binary data to np.array
if not isinstance(self.raw_tensor, (np.ndarray, list, tuple, int, float)): if not isinstance(self.raw_tensor, (np.ndarray, list, tuple, int, float)):
return onnx.numpy_helper.to_array(self.raw_tensor)
return numpy_helper.to_array(self.raw_tensor)
return self.raw_tensor return self.raw_tensor




@@ -383,15 +383,24 @@ class OnnxDataLoader:
"""Parse each onnx nodes in the model.""" """Parse each onnx nodes in the model."""
nodes_topo_idx = [] nodes_topo_idx = []
for idx, node in enumerate(self.nodes): for idx, node in enumerate(self.nodes):
if not node.name:
node.name = "_".join(node.output)
n = OnnxNode(node) n = OnnxNode(node)
self._nodes_dict[n.name] = n self._nodes_dict[n.name] = n
nodes_topo_idx.append((idx, n.name)) nodes_topo_idx.append((idx, n.name))
if len(node.output) > 1: if len(node.output) > 1:
raise ModelNotSupportError(msg=f"{node.name} has multi-outputs which is not supported now.") raise ModelNotSupportError(msg=f"{node.name} has multi-outputs which is not supported now.")
self.output_name_to_node_name[node.output[0]] = node.name self.output_name_to_node_name[node.output[0]] = node.name

for ipt_nd in node.input:
if ipt_nd not in self.output_name_to_node_name:
if self._global_context.onnx_node_inputs.get(n.name):
self._global_context.onnx_node_inputs[n.name].append(ipt_nd)
else:
self._global_context.onnx_node_inputs[n.name] = [ipt_nd]

self._global_context.onnx_node_name_to_topo_idx[n.name] = idx self._global_context.onnx_node_name_to_topo_idx[n.name] = idx
node_inputs = [i.replace(":0", "") for i in node.input]
self._global_context.onnx_node_inputs[n.name] = node_inputs

self._global_context.onnx_nodes_collection = self._nodes_dict self._global_context.onnx_nodes_collection = self._nodes_dict
self._global_context.onnx_nodes_topo_index = nodes_topo_idx self._global_context.onnx_nodes_topo_index = nodes_topo_idx


@@ -449,7 +458,11 @@ class OnnxDataLoader:
input_node = self.get_node(input_node_name) input_node = self.get_node(input_node_name)
node.precursor_onnx_node_dict[input_node_name] = input_node node.precursor_onnx_node_dict[input_node_name] = input_node
input_node.successor_onnx_node_dict[node_name] = node input_node.successor_onnx_node_dict[node_name] = node
continue

if self._global_context.onnx_node_inputs.get(node.name):
self._global_context.onnx_node_inputs[node.name].append(input_node_name)
else:
self._global_context.onnx_node_inputs[node.name] = [input_node_name]


def initialize(self): def initialize(self):
"""Initialize the OnnxDataLoader.""" """Initialize the OnnxDataLoader."""


+ 144
- 0
mindinsight/mindconverter/graph_based_converter/third_party_graph/optimizer.py View File

@@ -0,0 +1,144 @@
# Copyright 2021 Huawei Technologies Co., Ltd.All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Define ONNX optimizer operations."""
import copy
from importlib import import_module

import numpy as np

from ..common.utils import fetch_output_from_onnx_model


class OnnxSimplify:
"""To simplify onnx model."""
def __init__(self):
self._onnx_model = None
self._constant_nodes = list()
self._outputs_infer = dict()

def run_onnx_simplify(self, onnx_model, sample_shape):
"""
Run to simplify onnx model.

Args:
onnx_model (onnx.ModelProto): Onnx Model.
sample_shape (tuple): Sample shape of input.
"""
self._onnx_model = onnx_model
self._optimizer()
self._get_constant_nodes()
self._onnx_infer(sample_shape)
self._replace_constant_nodes()
self._optimizer()

return self._onnx_model

def _optimizer(self):
"""Run optimizer from onnx to eliminate constant nodes."""

onnxoptimizer = import_module('onnxoptimizer')
optimizers_list = [
'eliminate_deadend',
'eliminate_nop_dropout',
'eliminate_nop_cast',
'eliminate_nop_monotone_argmax',
'eliminate_nop_pad',
'extract_constant_to_initializer',
'eliminate_unused_initializer',
'eliminate_nop_transpose',
'eliminate_identity',
'fuse_add_bias_into_conv',
'fuse_consecutive_concats',
'fuse_consecutive_log_softmax',
'fuse_consecutive_reduce_unsqueeze',
'fuse_consecutive_squeezes',
'fuse_consecutive_transposes',
'fuse_matmul_add_bias_into_gemm',
'fuse_pad_into_conv',
'fuse_transpose_into_gemm'
]

input_num = len(self._onnx_model.graph.input)
onnx_model_optimized = onnxoptimizer.optimize(self._onnx_model, optimizers_list, fixed_point=True)

if self._onnx_model.ir_version > 3:
del onnx_model_optimized.graph.input[input_num:]
self._onnx_model = onnx_model_optimized

def _get_constant_nodes(self):
"""Get constant nodes."""

const_nodes = list()
const_tensors = [tensor_init.name for tensor_init in self._onnx_model.graph.initializer]
const_tensors.append([node.output[0]
for node in self._onnx_model.graph.node if node.op_type == 'Constant'])

for node in self._onnx_model.graph.node:
if node.op_type == 'Shape' or all([input_node in const_tensors for input_node in node.input]):
const_nodes.append(node)
const_tensors.extend(node.output)

self._constant_nodes = copy.deepcopy(const_nodes)

def _onnx_infer(self, infer_inputs_shape):
"""
Run onnx inference to get outputs of constant nodes.

Args:
infer_inputs_shape (tuple): Input shape for running inference.
"""

input_onnx = self._onnx_model.graph.input[0]
input_onnx_name = input_onnx.name
feed_dict = {input_onnx_name: np.random.rand(*infer_inputs_shape).astype(np.float32)}

output_nodes_name = list()
for node in self._constant_nodes:
output_nodes_name.extend(node.output)

self._outputs_infer = fetch_output_from_onnx_model(self._onnx_model, feed_dict, output_nodes_name)

def _replace_constant_nodes(self):
"""Replace constant nodes to nodes with op_type 'Constant'."""

onnx = import_module('onnx')
np_helper = import_module('onnx.numpy_helper')

for i, node in enumerate(self._onnx_model.graph.node):
if node in self._constant_nodes:
for output in node.output:
new_attr = onnx.helper.make_attribute(
'value',
np_helper.from_array(self._outputs_infer[output], name=output)
)

new_node = onnx.helper.make_node(
op_type='Constant',
inputs=list(),
outputs=[output],
name='_'.join(('node', output))
)
new_node.attribute.extend([new_attr])
self._insert_node(self._onnx_model.graph.node, i + 1, new_node)
del self._onnx_model.graph.node[i]

@staticmethod
def _insert_node(repeated_container, index, node):
"""Insert node into onnx model."""

repeated_container.extend([repeated_container[-1]])
for i in reversed(range(index + 1, len(repeated_container) - 1)):
repeated_container[i].CopyFrom(repeated_container[i - 1])
repeated_container[index].CopyFrom(node)

+ 0
- 691
mindinsight/mindconverter/graph_based_converter/third_party_graph/pytorch_graph.py View File

@@ -1,691 +0,0 @@
# Copyright 2020 Huawei Technologies Co., Ltd.All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Define PyTorch graph."""
import os
import re
import warnings
from copy import deepcopy
from importlib import import_module
from typing import Dict, NoReturn

import numpy as np

from mindinsight.conf import settings
from mindinsight.mindconverter.common.log import logger as log
from .base import Graph
from .input_node import InputNode
from .pytorch_graph_node import PyTorchGraphNode
from .pytorch_graph_parser import PyTorchGraphParser
from .torch_utils import set_opset_version
from ..common.utils import fetch_output_from_onnx_model

from ..constant import SEPARATOR_IN_SCOPE, LINK_IN_SCOPE, SEPARATOR_BTW_NAME_AND_ID, SCALAR_WITHOUT_SHAPE, \
MIN_SCOPE_LENGTH, SEPARATOR_TITLE_AND_CONTENT_IN_CONSTRUCT, ONNX_OPSET_VERSION, MODEL_INPUT_NAME
from ..constant import LEFT_BUCKET, RIGHT_BUCKET
from ...common.exceptions import ModelNotSupportError

NONE_SCOPE_OP = {
"onnx::Add": "Add",
"onnx::Flatten": "Flatten",
"onnx::Concat": "Concat",
"onnx::Squeeze": "Squeeze",
"onnx::Unsqueeze": "Unsqueeze",
"onnx::Split": "Split",
"onnx::Reshape": "Reshape",
"onnx::Transpose": "Transpose",
"onnx::Constant": "Constant",
"onnx::ReduceMean": "ReduceMean",
"onnx::Resize": "Resize",
"onnx::Pad": "Pad"
}

CONSTANT_NODES_PATTERN = {
"onnx::Resize": [
'onnx::Concat',
'onnx::Slice',
'onnx::Cast',
'onnx::Concat',
'onnx::Unsqueeze',
'onnx::Floor',
'onnx::Mul',
'onnx::Cast',
'onnx::Gather',
'onnx::Shape'
],
"onnx::Pad": [
'onnx::Cast',
'onnx::Concat',
'onnx::ConstantOfShape',
'onnx::Sub',
'onnx::Mul',
'onnx::Div',
'onnx::Gather',
'onnx::Shape',
'onnx::Unsqueeze',
'onnx::Slice',
'onnx::Reshape',
'onnx::Transpose'
],
"onnx::Constant": list()
}


def normalize_scope_name(node, scope_name_dict):
"""
Rename scope name into uniform.

Args:
node (Node): PyTorch node.
scope_name_dict (dict): Dictionary of scope names with the key node_id.

Returns:
str, normalized scope name.
"""
global NONE_SCOPE_OP

scope_name = node.scopeName()
if not scope_name:
name = [retrieve_scope_name(node, scope_name_dict)]
else:
name = scope_name.replace(SEPARATOR_BTW_NAME_AND_ID, '').split(SEPARATOR_IN_SCOPE)
scopes = []
for segment in name:
segment = segment.split(LINK_IN_SCOPE)[0]
left = segment.find(LEFT_BUCKET)
right = segment.find(RIGHT_BUCKET)
if left != -1:
if segment[left + 1: right].isdigit():
scopes.append(f"{segment[:left]}_{segment[left + 1: right]}")
else:
scopes.append(segment[left + 1: right])
else:
scopes.append(segment)
if node.kind() in NONE_SCOPE_OP.keys():
scopes.append(NONE_SCOPE_OP[node.kind()])
scopes = [s for s in scopes if s]
node_id = PyTorchGraph.get_node_id(node)
return f"{SEPARATOR_IN_SCOPE.join(scopes)}_{'&'.join(node_id)}"


def retrieve_scope_name(node, scope_name_dict):
"""
Retrieve scope name from input nodes.

Args:
node (Node): PyTorch node.
scope_name_dict (dict): Dictionary of scope names with the key node_id.

Return:
str: Scope name.
"""
node_content = \
SEPARATOR_TITLE_AND_CONTENT_IN_CONSTRUCT.join(str(node).split(SEPARATOR_TITLE_AND_CONTENT_IN_CONSTRUCT)[1:])
node_inputs = re.findall(r"[(](.*?)[)]", node_content)[0]
node_inputs = re.sub(r"[\s%]", '', node_inputs).split(",")

scope_name_ipt_nodes = list()
for node_input in node_inputs:
if not scope_name_dict.get(node_input, None):
continue
scope_name_ipt_nodes.append(scope_name_dict[node_input])

scope_name_split = list()
for idx, _ in enumerate(scope_name_ipt_nodes):
if not scope_name_split:
scope_name_split = scope_name_ipt_nodes[idx]
else:
scope_name_split = [
sub_scope_name
for sub_scope_name in scope_name_split if sub_scope_name in scope_name_ipt_nodes[idx]
]
scope_name = SEPARATOR_IN_SCOPE.join(scope_name_split)
return scope_name


class PyTorchGraph(Graph):
"""
Define PyTorch graph.

Args:
model (Module): PyTorch model.
sample_shape (tuple): Input shape of the model.

"""

def __init__(self, model, sample_shape: tuple):
super(PyTorchGraph, self).__init__(model=model)

from .torch_utils import unique_state_dict

self._params_dict = unique_state_dict(model)
self._original_shape = list()
self._nodes = list()
self._constant_nodes = list()
self._dynamic_nodes = list()
self._has_eliminated_nodes = False
self._file_graph_onnx = os.path.join(
settings.WORKSPACE, 'log/mindconverter/'
)

self.build(sample_shape)

@staticmethod
def _check_input_shape(input_shape):
"""
Check input shape.

Args:
input_shape (tuple): Input tensor shape.

"""
if not input_shape:
err_msg = "`input_shape` can not be None."
log.error(err_msg)
raise ValueError(err_msg)

for item in input_shape:
if not isinstance(item, int):
err_msg = "Only support model with one input now, " \
"and each shape value in `input_shape` should be int."
log.error(err_msg)
raise ValueError(err_msg)

@staticmethod
def _extract_shape(shape):
"""
Extract shape from string-type shape.

Args:
shape (str): Shape value in string-type.

Returns:
list, shape.
"""
if "," not in shape:
return []

shape_arr = []
for s in shape.split(","):
s = s.strip()
if not s:
return []
if ":" in s:
s = s.split(":")[0]
s = s.replace("!", "")
if not s.isdigit():
return []
shape_arr.append(int(s))
return shape_arr

def _trace_torch_graph(self, input_shape):
"""
Trace torch computational graph.

Args:
input_shape (tuple): Shape.

Returns:
object, pytorch graph.
"""
import torch
from torch.onnx import OperatorExportTypes
from .torch_utils import OverloadTorchModuleTemporarily
from .torch_utils import create_autograd_variable
from .torch_utils import onnx_tracer

warnings.simplefilter("ignore")

batched_sample = create_autograd_variable(torch.rand(*input_shape))

try:
try:
# Assign execution mode to eval.
self.model.eval()

with OverloadTorchModuleTemporarily() as _:
# In pytorch higher version, trace function has a known.
graph = onnx_tracer(self.model, batched_sample,
OperatorExportTypes.ONNX)
return graph

except RuntimeError:
# Assign execution mode to eval.
self.model.eval()

with OverloadTorchModuleTemporarily() as _:
# In pytorch higher version, trace function has a known.
set_opset_version(ONNX_OPSET_VERSION)
graph = onnx_tracer(self.model, batched_sample,
OperatorExportTypes.ONNX)
return graph

except RuntimeError as error:
log.error(str(error))
log.exception(error)
raise error

def build(self, input_shape):
"""
Build graph tree.

Args:
input_shape (tuple): Input shape of model.

"""
self._check_input_shape(input_shape)
self._original_shape = input_shape
feed_forward_ipt_shape = tuple(input_shape)
graph = self._trace_torch_graph(feed_forward_ipt_shape)
nodes = list(graph.nodes())
self._nodes = nodes

scope_name_dict = dict()

self._constant_nodes, self._dynamic_nodes = self._get_constant_nodes(nodes)

for node in nodes:
output_name = ', '.join(list(self._extract_node_name(output) for output in node.outputs()))
if output_name in self._dynamic_nodes:
continue

node_name = normalize_scope_name(node, scope_name_dict)
scope_name_dict[node_name.split(SEPARATOR_BTW_NAME_AND_ID)[-1]] \
= list(node_name.split(SEPARATOR_BTW_NAME_AND_ID)[0].split(SEPARATOR_IN_SCOPE))
output_shape_str_list = re.findall(r'[^()!]+', str(node))
output_shape_str = output_shape_str_list[1]
output_shape = self._extract_shape(output_shape_str)
weight_scope = '.'.join(
re.findall(r'\[([\w\d.]+)]', node.scopeName())
)

if self._constant_nodes:
node_weight = self._replace_constant_node(node)
else:
node_weight = {}

for scope, weight in self._params_dict.items():
split_scope = scope.split('.')
if '.'.join(split_scope[:-1]) == weight_scope:
node_weight[split_scope[-1]] = weight

if not node_weight and node.kind() == 'onnx::Conv':
weight_names = list(self._params_dict.keys())
node_input_names = [self._extract_input_name(node_input) for node_input in node.inputs()]
for node_input_name in node_input_names:
if int(node_input_name) > len(weight_names):
continue
weight = self._params_dict[weight_names[int(node_input_name) - 1]]
node_weight[weight_names[int(node_input_name) - 1]] = weight

self._shape_dict[node_name] = output_shape
self._nodes_collection[node_name] = PyTorchGraphNode(node, node_weight)
self._nodes_record[node_name] = node_name

for node_input in list(node.inputs()):
if self._extract_input_name(node_input) in self._constant_nodes:
continue
# Connect input node and src node.
nd_id = PyTorchGraph.get_node_id(node_input.node())
nd_scope_name = node_input.node().kind() in NONE_SCOPE_OP or \
node_input.node().scopeName()

if nd_id and nd_scope_name:
node_input_name = normalize_scope_name(
node_input.node(), scope_name_dict
)
self.build_connection(node_input_name, node_name)

self._unmerge_multi_ipt_opt_script()

super(PyTorchGraph, self).build(input_shape=input_shape)
self._collect_ipt_shape_of_each_node(feed_forward_ipt_shape)

@staticmethod
def _extract_node_name(node):
"""Extract node name for node."""
result = re.match(r"\d+", str(node))
if result:
return result.group(0)
return None

@staticmethod
def _extract_input_name(node_input):
"""Extract node input name from node input."""
node_input_name = str(node_input).split('defined in')[0].strip()
return node_input_name

def _get_constant_nodes(self, nodes):
"""
Get constant nodes to be eliminated.

Args:
nodes (Nodes): Nodes in torch._C.Graph.

Returns:
Union(dict, list), output of constant_input_node_name and dynamic nodes name.
"""
constant_input_nodes = list()
dynamic_nodes = list()
for node in nodes:
if node.kind() == 'onnx::Resize':
self._has_eliminated_nodes = True
constant_input_node, dynamic_node = self._generate_inputs_of(node)
constant_input_nodes += constant_input_node
dynamic_nodes += dynamic_node

outputs = dict()
if self._has_eliminated_nodes:
torch = import_module('torch')
device_target = 'cuda' if torch.cuda.is_available() else 'cpu'
dump_input = torch.randn(*self._original_shape, device=device_target)
temp_onnx_path = os.path.realpath(os.path.join(self._file_graph_onnx,
'.graph_onnx.onnx'))

symbolic_helper = import_module('torch.onnx.symbolic_helper')
export_onnx_opset_version = getattr(symbolic_helper, '_export_onnx_opset_version')
try:
torch.onnx.export(self.model.to(device_target), dump_input,
temp_onnx_path, opset_version=export_onnx_opset_version)

outputs = self._onnx_infer(temp_onnx_path, constant_input_nodes, self._original_shape)
finally:
if os.path.exists(temp_onnx_path):
os.remove(temp_onnx_path)

return outputs, dynamic_nodes

def _generate_inputs_of(self, node):
"""
Generate inputs of certain node.

Args:
node (Node): Node of torch._C.Graph.

"""
pattern_op_lst = CONSTANT_NODES_PATTERN.get(node.kind(), None)
constant_input_nodes = list()
dynamic_nodes = list()
if not isinstance(pattern_op_lst, list):
return constant_input_nodes, dynamic_nodes
if not pattern_op_lst:
dynamic_nodes += self.get_node_id(node)
return constant_input_nodes, dynamic_nodes

node_inputs_name = [self._extract_input_name(node_input) for node_input in node.inputs()]

for node_input_name in node_inputs_name:
node_name_path = self._search_node_path(node_input_name, pattern_op_lst)
if node_name_path and self._get_node_from_graph(node_name_path[-1]).kind() == 'onnx::Shape':
constant_input_nodes.append(node_input_name)
dynamic_nodes += node_name_path

return constant_input_nodes, dynamic_nodes

def _search_node_path(self, node_name, pattern_op_lst):
"""
Search node path based on pattern_op_list.

Args:
node_name (str): Node name.
pattern_op_lst (list): Pattern list of certain operator.

Returns:
list[str]: node names in pattern.
"""
node_type_lst = list()
node_name_lst = list()
node = self._get_node_from_graph(node_name)

if node_name == MODEL_INPUT_NAME:
return node_name_lst

if node.kind() not in pattern_op_lst:
return node_name_lst

node_type_lst.append(node.kind())
node_name_lst.append(node_name)

node_inputs_name = [self._extract_input_name(node_input) for node_input in node.inputs()]
for node_input_name in node_inputs_name:
node_name_lst += self._search_node_path(node_input_name, pattern_op_lst)

return node_name_lst

def _get_node_from_graph(self, node_name):
"""Get torch._C.Node from torch._C.Graph."""
for idx, node in enumerate(self._nodes):
node_id = ', '.join(self.get_node_id(node))
if node_id == node_name:
return self._nodes[idx]
return None

@staticmethod
def _onnx_infer(file_graph_onnx, infer_outputs, infer_inputs_shape):
"""
Infer onnx model to get outputs of inner nodes.

Args:
file_graph_onnx (str): File path of onnx.
infer_outputs (list): Outputs for infer.
infer_inputs_shape (list): Input shape for infer.

"""
onnx = import_module('onnx')
tensor_proto = getattr(onnx, 'TensorProto')
onnx_model = onnx.load(file_graph_onnx)

for onnx_node in onnx_model.graph.node:
if set(onnx_node.output).issubset(set(infer_outputs)):
onnx_node.name = ', '.join([f"{output_name}" for output_name in onnx_node.output])

input_onnx = onnx_model.graph.input[0]
node_type = tensor_proto.DataType.Name(input_onnx.type.tensor_type.elem_type)
if node_type != 'FLOAT':
raise ModelNotSupportError(f"Input type should be FLOAT32, but got {node_type}. "
f"Please report issue to us if extra input type is needed.")

input_onnx_name = input_onnx.name
feed_dict = {input_onnx_name: np.random.rand(*infer_inputs_shape).astype(np.float32)}
outputs = fetch_output_from_onnx_model(onnx_model, feed_dict, infer_outputs)

return outputs

def _replace_constant_node(self, node):
"""Replace constant node."""
node_weight = dict()
for node_input in list(node.inputs()):
node_input_name = self._extract_input_name(node_input)
if node_input_name in self._constant_nodes:
node_weight[node_input_name] = self._constant_nodes[node_input_name]
return node_weight

def _collect_ipt_shape_of_each_node(self, input_shape):
"""
Collect input tensor shape of each node.

Args:
input_shape (tuple): Input shape.

"""
input_node = InputNode(input_shape)
input_node_name = "{}InputNode"
for node_name, node in self._nodes_collection.items():
if node_name in self._input_nodes:
ipt_nd_name = input_node_name.format(input_node.scope_name)
input_node.set_scope_name(node.scope_name)
node.precursor_nodes.insert(0, ipt_nd_name)
input_node.set_successor_nodes(node_name)
self._shape_dict[ipt_nd_name] = input_node.output_shape

if not self._shape_dict[node_name]:
self._shape_dict[node_name] = SCALAR_WITHOUT_SHAPE

ipt_shape = []
for p_nd in node.precursor_nodes:
shp = self._shape_dict.get(p_nd)
ipt_shape.append(tuple(shp) if isinstance(shp, list) else shp)

self._input_shape[node_name] = ipt_shape[0] if len(ipt_shape) == 1 else ipt_shape

def _generate_module(self):
"""Generate modules."""
module_dict = dict()
for node_key, _ in self._nodes_collection.items():
node_key_in_scope = node_key.split(SEPARATOR_IN_SCOPE)
if len(node_key_in_scope) < MIN_SCOPE_LENGTH:
continue

for idx in range(1, len(node_key_in_scope)):
node_key_module = SEPARATOR_IN_SCOPE.join(node_key_in_scope[:idx])
node_name = SEPARATOR_IN_SCOPE.join(node_key_in_scope[:idx+1])
if not module_dict.get(node_key_module, None):
module_dict[node_key_module] = {node_name}
else:
module_dict[node_key_module].add(node_name)

return module_dict

def _check_multi_ipt_opt(self):
"""Check whether multi-input exists."""
module_dict = self._generate_module()
for _, nodes_per_module in module_dict.items():
prcs_nodes_out_from_module = set()
for node_name in nodes_per_module:
if re.search(r"[\d]+[&][\d]+", node_name):
self._is_multi_opt_graph = True
return True

node = self._nodes_collection.get(node_name, None)
if node:
prcs_nodes = node.precursor_nodes
else:
continue

for prcs_node in prcs_nodes:
if prcs_node not in nodes_per_module:
prcs_node_module = SEPARATOR_IN_SCOPE.join(prcs_node.split(SEPARATOR_IN_SCOPE)[:-1])
if prcs_node_module not in nodes_per_module:
prcs_nodes_out_from_module.add(prcs_node)

if len(prcs_nodes_out_from_module) > 1:
return True

return False

def _unmerge_multi_ipt_opt_script(self):
"""Unmerge all submodule."""
if self._check_multi_ipt_opt() or self._has_eliminated_nodes:
for node_key, node_inst in deepcopy(self._nodes_collection).items():
prsc_nodes = node_inst.precursor_nodes
scsr_nodes = node_inst.successor_nodes

node_inst.is_in_multi_opt_graph = self._is_multi_opt_graph

node_inst.precursor_nodes = [SEPARATOR_IN_SCOPE.join((prsc_node.split(SEPARATOR_IN_SCOPE)[0],
prsc_node.split(SEPARATOR_IN_SCOPE)[-1]))
for prsc_node in deepcopy(prsc_nodes)]
node_inst.successor_nodes = [SEPARATOR_IN_SCOPE.join((scsr_node.split(SEPARATOR_IN_SCOPE)[0],
scsr_node.split(SEPARATOR_IN_SCOPE)[-1]))
for scsr_node in deepcopy(scsr_nodes)]

reduce_node_key = SEPARATOR_IN_SCOPE.join((node_key.split(SEPARATOR_IN_SCOPE)[0],
node_key.split(SEPARATOR_IN_SCOPE)[-1]))

del self._nodes_collection[node_key]
self._nodes_collection[reduce_node_key] = node_inst

for node_key, shape in deepcopy(self._shape_dict).items():
reduce_node_key = SEPARATOR_IN_SCOPE.join((node_key.split(SEPARATOR_IN_SCOPE)[0],
node_key.split(SEPARATOR_IN_SCOPE)[-1]))

del self._shape_dict[node_key]
self._shape_dict[reduce_node_key] = shape

def sub_graph_merging(self):
"""
Merge split operation into one.
"""
raise NotImplementedError()

def build_connection(self, src, tgt) -> NoReturn:
"""
Build connection between source node and target node.

Args:
src (str): Source node name.
tgt (str): Target node name.

"""
# If src and tgt are the same node, src not in node_collection or
# tgt not in node_collection, then skip this edge.
if src == tgt or src not in self._nodes_collection or tgt not in self._nodes_collection:
if src.split(':')[0] not in self._nodes_collection:
log.warning("Graph construct a self-loop node %s. Ignored.", src)
return
if tgt not in self._nodes_collection[src.split(':')[0]].successor_nodes:
self._nodes_collection[src.split(':')[0]].successor_nodes.append(tgt)
if src not in self._nodes_collection[tgt].precursor_nodes:
self._nodes_collection[tgt.split(':')[0]].precursor_nodes.append(src)

@staticmethod
def load_checkpoint(ckpt_path: str) -> Dict:
"""
Load checkpoint.

Args:
ckpt_path (str): Checkpoint file path.

Returns:
dict, weights in model.
"""

@staticmethod
def load_metadata(**kwargs):
"""
Load graph metadata.
"""
err_msg = "class `PyTorchGraph` has not implemented " \
"`load_metadata()`."
log.error(err_msg)
raise NotImplementedError(err_msg)

@staticmethod
def load_graph(graph_path: str, **kwargs):
"""
Load graph.

Args:
graph_path (str): Graph path.

Returns:
object, pytorch model.
"""
torch_model = PyTorchGraphParser.parse(graph_path)
return torch_model

@staticmethod
def get_node_id(node):
"""
Get node id using regular expr.

Args:
node (Node): PyTorch node.

Returns:
str, node id.
"""
node_title = str(node).split(SEPARATOR_TITLE_AND_CONTENT_IN_CONSTRUCT)[0]
node_id = re.findall(r"[%](.*?) [:]", node_title)
return node_id

+ 0
- 236
mindinsight/mindconverter/graph_based_converter/third_party_graph/pytorch_graph_node.py View File

@@ -1,236 +0,0 @@
# Copyright 2020 Huawei Technologies Co., Ltd.All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Define PyTorch graph node."""
import re

from .base import GraphNode
from ..common.utils import is_converted

from ..constant import NodeType, SEPARATOR_IN_SCOPE, SEPARATOR_BTW_NAME_AND_ID, LEFT_BUCKET, RIGHT_BUCKET, \
SEPARATOR_IN_ONNX_OP, SEPARATOR_TITLE_AND_CONTENT_IN_CONSTRUCT


class PyTorchGraphNode(GraphNode):
"""
PyTorch graph node.

Args:
node (torch._C.Node): Node in raw PyTorch graph.

"""

_type_frozen = False
_module_name_frozen = False

def __init__(self, node=None, weight=None):
super(PyTorchGraphNode, self).__init__(node=node)
self._op_params = self._get_raw_params(node)
self._op_name = node.kind() if node else None
self._scope_name = node.scopeName() if node else None
self._weight = weight
self._ipt_var_names, self._opt_var_names \
= self._extract_ipt_opt_var_names() if node else (list(), list())

def _extract_ipt_opt_var_names(self):
"""Extract ipt and opt var names."""
node_content = SEPARATOR_TITLE_AND_CONTENT_IN_CONSTRUCT.join(
str(self._src_node).split(SEPARATOR_TITLE_AND_CONTENT_IN_CONSTRUCT)[1:]
)
node_inputs = re.findall(r"[(](.*?)[)]", node_content)[0]
node_inputs = re.sub(r"[\s%]", '', node_inputs).split(",")
node_title = str(self._src_node).split(SEPARATOR_TITLE_AND_CONTENT_IN_CONSTRUCT)[0]
node_outputs = re.findall(r"[%](.*?) [:]", node_title)
return node_inputs, node_outputs

def clear_args_of_declaration(self):
"""
Clear `self._args_in_code`.
"""
self._args_in_code = dict()

def _get_arg_name(self, arg, variable_name):
"""
Get arg name.

Args:
arg (str): Generate arg name.

Returns:
str, arg name in function or class declaration.
"""
return f"{arg}_{variable_name}"

@property
def is_in_multi_opt_graph(self):
return self._is_in_multi_opt_graph

@is_in_multi_opt_graph.setter
def is_in_multi_opt_graph(self, multi_opt_state):
self._is_in_multi_opt_graph = multi_opt_state

@property
def hash_key(self):
"""
Return unique hash key of current node.

Returns:
str, hash key.
"""
if self._node_type not in {NodeType.CLASS.value,
NodeType.FUNC.value,
NodeType.MODULE.value}:
self._hash_key = self._op_name.lower()
return self._hash_key

@hash_key.setter
def hash_key(self, h):
"""
Setter of hash key.

Args:
h (str): Key.

"""
self._hash_key = h

@property
def op_name(self):
"""
Op name in torch.

Returns:
str, op name.
"""
return self._op_name

@op_name.setter
def op_name(self, name):
"""
Setter of op name.

Args:
name(str): op_name.

"""
self._op_name = name

@property
def real_name(self):
return

def add_input_and_output_shape(self, input_shape, output_shape):
"""
Add the node input shape.

Args:
output_shape (tuple): Output tensor shape.
input_shape (tuple): Input tensor shape.

"""
self._ipt_shape = input_shape
self._opt_shape = output_shape

def to_code(self, ipt_args_in_construct: str, variable_name: str, output_var: list, code_fragment):
"""
Generate statements.

Args:
variable_name (str): Variable name.
ipt_args_in_construct (str): Args of input.
output_var (list): Output variable names in construct.
code_fragment (CodeFragment): CodeFragment instance.

Returns:
Union[str, str], declare in init and call in construct.
"""
operator = code_fragment.operation

args = self.args_in_code
settings = code_fragment.code_setting

if self._node_type == NodeType.OPERATION.value and not is_converted(code_fragment.operation):
args.update({"input_shape": self.input_shape,
"output_shape": self.output_shape})

if self._node_type == NodeType.OPERATION.value:
expr = ", ".join([f"{k.replace(f'_{variable_name}', '')}={v}"
for k, v in args.items()])
ipt_args_settings_in_construct = self._generate_ipt_args_settings_in_construct(
ipt_args_in_construct, settings)
else:
# When it's type is module, class or func,
# it's not necessary to replace var.
expr = ", ".join([f"{k.replace(f'_{variable_name}', '')}={v}"
for k, v in args.items()])
ipt_args_settings_in_construct = ipt_args_in_construct

if SEPARATOR_IN_ONNX_OP in operator:
operator = operator.replace(SEPARATOR_IN_ONNX_OP, ".")

declare = f"self.{variable_name} = {operator}({expr})"
call = f"{', '.join([output for output in output_var])}" \
f" = self.{variable_name}({ipt_args_settings_in_construct})"

return declare, call

def to_ir(self):
"""
No need to implement for now.
"""
raise NotImplementedError

def _get_raw_params(self, node):
"""
Get params in onnx.

Args:
node (Any): Node.

Returns:
dict, raw params.
"""
from .torch_utils import getitem_of_node

raw_params = dict()

if not node:
return raw_params

for k in node.attributeNames():
raw_params[k] = getitem_of_node(node, k)
return raw_params

def replace_with_arg(self, src_arg, tgt_arg):
"""
Replace actual parameter with formal parameter.

Args:
src_arg (str): Original arg name.
tgt_arg (str): Target arg name.

"""
self._args_in_code[src_arg] = tgt_arg

@staticmethod
def _extract_var_name(scope_name: str):
"""
Extract variable name from scope name.
"""
if not scope_name:
return None
var = scope_name.split(SEPARATOR_IN_SCOPE)[-1].lower()
var = var.replace(LEFT_BUCKET, SEPARATOR_BTW_NAME_AND_ID).replace(
RIGHT_BUCKET, "")
return var

+ 59
- 7
mindinsight/mindconverter/graph_based_converter/third_party_graph/pytorch_graph_parser.py View File

@@ -1,4 +1,4 @@
# Copyright 2020 Huawei Technologies Co., Ltd.All Rights Reserved.
# Copyright 2020-2021 Huawei Technologies Co., Ltd.All Rights Reserved.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
@@ -18,6 +18,7 @@ from importlib import import_module


from mindinsight.mindconverter.common.log import logger as log from mindinsight.mindconverter.common.log import logger as log
from .base import GraphParser from .base import GraphParser
from .optimizer import OnnxSimplify
from ...common.exceptions import ModelNotSupportError from ...common.exceptions import ModelNotSupportError




@@ -38,7 +39,6 @@ class PyTorchGraphParser(GraphParser):
Returns: Returns:
object, torch model. object, torch model.
""" """
torch = import_module("torch")


if not os.path.exists(model_path): if not os.path.exists(model_path):
error = FileNotFoundError("`model_path` must be assigned with " error = FileNotFoundError("`model_path` must be assigned with "
@@ -47,14 +47,66 @@ class PyTorchGraphParser(GraphParser):
raise error raise error


try: try:
if torch.cuda.is_available():
model = torch.load(f=model_path)
else:
model = torch.load(f=model_path, map_location="cpu")
onnx_model_sim = cls._convert_pytorch_graph_to_onnx(
model_path, kwargs['sample_shape'], opset_version=11)
return onnx_model_sim
except ModuleNotFoundError: except ModuleNotFoundError:
error_msg = "Cannot find model scripts in system path, " \ error_msg = "Cannot find model scripts in system path, " \
"set `--project_path` to the path of model scripts folder correctly." "set `--project_path` to the path of model scripts folder correctly."
error = ModuleNotFoundError(error_msg) error = ModuleNotFoundError(error_msg)
raise error raise error


return model
@staticmethod
def _convert_pytorch_graph_to_onnx(model_path, sample_shape, opset_version=None):
"""
Convert Pytorch model to ONNX model.

Args:
model_path (str): Path to the Pytorch model.
sample_shape (tuple): Input shape to generate onnx model.
opset_version (int): Op set version of onnx.
"""

torch = import_module('torch')
has_cuda = torch.cuda.is_available()
if has_cuda:
model = torch.load(f=model_path).cuda()
dump_input = torch.randn(*sample_shape, device='cuda')
else:
model = torch.load(f=model_path, map_location="cpu")
dump_input = torch.randn(*sample_shape, device='cpu')

if isinstance(model, torch.nn.DataParallel):
raise ValueError('torch.nn.DataParallel is not supported by ONNX exporter.')

torch_onnx = import_module('torch.onnx')
operator_export_types = getattr(torch_onnx, 'OperatorExportTypes')
utils = import_module('torch.onnx.utils')
model_to_graph = getattr(utils, '_model_to_graph')

symbolic_helper = import_module('torch.onnx.symbolic_helper')
default_onnx_opset_version = getattr(symbolic_helper, '_default_onnx_opset_version')
set_opset_version = getattr(symbolic_helper, '_set_opset_version')
set_operator_export_type = getattr(symbolic_helper, '_set_operator_export_type')
if not opset_version:
opset_version = default_onnx_opset_version

operator_export_type = operator_export_types.ONNX
set_opset_version(opset_version)
set_operator_export_type(operator_export_type)

graph, params_dict, _ = model_to_graph(model, dump_input, _retain_param_name=True)
export_onnx = getattr(graph, '_export_onnx')
proto, _ = export_onnx(
params_dict, opset_version, dict(), False,
operator_export_type, True, False, dict(),
True, False)

onnx = import_module('onnx')
onnx_model = onnx.load_model_from_string(proto)

onnx_simplify = OnnxSimplify()
onnx_model_sim = onnx_simplify.run_onnx_simplify(onnx_model, sample_shape)

return onnx_model_sim

+ 0
- 105
mindinsight/mindconverter/graph_based_converter/third_party_graph/torch_utils.py View File

@@ -1,105 +0,0 @@
# Copyright 2020 Huawei Technologies Co., Ltd.All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Define pytorch tracer context manager."""
import importlib

from torch.nn import Module
from torch.onnx.utils import _trace
from torch.onnx.utils import _node_getitem
from torch.onnx.symbolic_helper import _set_opset_version


SCRIPT_METHOD = getattr(importlib.import_module("torch._C"),
"ScriptMethod")
onnx_tracer = _trace
getitem_of_node = _node_getitem
set_opset_version = _set_opset_version


def unique_state_dict(model):
"""
Wrapper of torch.jit._unique_state_dict.

Args:
model (Module): Torch model.

Returns:
dict, params.
"""
from torch.jit import _unique_state_dict

return _unique_state_dict(model)


def create_autograd_variable(tensor):
"""
Create autograd variable to trace the whole graph.

Args:
tensor (torch.Tensor): Tensor.

Returns:
torch.autograd.Variable, variable.
"""
variable = getattr(importlib.import_module("torch.autograd"), "Variable")
return variable(tensor, requires_grad=False)


class OverloadTorchModuleTemporarily:
"""
Fix bugs in new version of pytorch.
PyTorch official solution.
"""

def __init__(self):
self.backup = None

def __enter__(self):
def _tracing_name(traced_module, tracing_state):
traced_module_stack = getattr(tracing_state, "_traced_module_stack")
if not traced_module_stack:
return None
module = traced_module_stack[-1]
for name, child in module.named_children():
if child is traced_module:
return name
return None

def _slow_forward(self_, *inputs, **kwargs):
tracing_state = getattr(importlib.import_module("torch._C"),
"_get_tracing_state")()
if not tracing_state or isinstance(self_.forward, SCRIPT_METHOD):
return self_.forward(*inputs, **kwargs)
if not hasattr(tracing_state, '_traced_module_stack'):
tracing_state._traced_module_stack = []
name = _tracing_name(self_, tracing_state)
get_name_func = getattr(self_, "_get_name")
if name:
tracing_state.push_scope('%s[%s]' % (get_name_func(), name))
else:
tracing_state.push_scope(get_name_func())
tracing_state._traced_module_stack.append(self_)
try:
result = self_.forward(*inputs, **kwargs)
finally:
tracing_state.pop_scope()
tracing_state._traced_module_stack.pop()
return result

self.backup = getattr(Module, "_slow_forward")
setattr(Module, '_slow_forward', _slow_forward)

def __exit__(self, exc_type, exc_val, exc_tb):
setattr(Module, '_slow_forward', self.backup)

tests/ut/mindconverter/graph_based_converter/hierarchical_tree/test_name_mgr.py → tests/ut/mindconverter/graph_based_converter/common/test_name_mgr.py View File

@@ -1,4 +1,4 @@
# Copyright 2020 Huawei Technologies Co., Ltd.All Rights Reserved.
# Copyright 2020-2021 Huawei Technologies Co., Ltd.All Rights Reserved.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
@@ -14,7 +14,7 @@
# ============================================================================== # ==============================================================================
"""Test name manager module.""" """Test name manager module."""
from unittest import TestCase from unittest import TestCase
from mindinsight.mindconverter.graph_based_converter.hierarchical_tree.name_mgr import NameMgr, GlobalVarNameMgr, \
from mindinsight.mindconverter.graph_based_converter.common.name_mgr import NameMgr, GlobalVarNameMgr, \
global_op_namespace global_op_namespace





+ 0
- 15
tests/ut/mindconverter/graph_based_converter/hierarchical_tree/__init__.py View File

@@ -1,15 +0,0 @@
# Copyright 2020 Huawei Technologies Co., Ltd.All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Unit test for mindconvert.graph_based_converter.hierarchical_tree interface."""

+ 0
- 177
tests/ut/mindconverter/graph_based_converter/hierarchical_tree/test_hierarchical_tree.py View File

@@ -1,177 +0,0 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Test hierarchical tree module."""
import os
import shutil
from unittest import mock

import pytest

from mindinsight.mindconverter.graph_based_converter.hierarchical_tree.hierarchical_tree import HierarchicalTree
from mindinsight.mindconverter.graph_based_converter.third_party_graph.pytorch_graph_node import PyTorchGraphNode
from mindinsight.mindconverter.graph_based_converter.mapper.base import ONNXToMindSporeMapper
from mindinsight.mindconverter.graph_based_converter.constant import NodeType

from tests.ut.mindconverter.graph_based_converter.conftest import TEST_BASE_PATH


class TestHierarchicalTree:
"""Test the class of HierarchicalTree."""

def test_tree_identifier(self):
"""Test tree_identifier"""
tree = HierarchicalTree()
assert isinstance(tree.tree_identifier, str)

@mock.patch(
'.'.join((TEST_BASE_PATH, 'third_party_graph.pytorch_graph_node.PyTorchGraphNode._get_raw_params')))
def test_insert(self, get_raw_params):
"""Test insert"""
get_raw_params.return_value = []
tree = HierarchicalTree()
pt_node = PyTorchGraphNode()
tree.insert(pt_node, 'ResNet')
assert tree.root == 'ResNet'

def test_remove(self):
"""Test remove function."""
tree = HierarchicalTree()
tree.create_node(
tag='node_root',
identifier='root',
parent=None,
data=None
)
node = tree.get_node('root')
tree.remove(node)
assert tree.root is None

@mock.patch(
'.'.join((TEST_BASE_PATH, 'third_party_graph.pytorch_graph_node.PyTorchGraphNode._get_raw_params')))
def test_shrink(self, get_raw_params):
"""Test shrink function."""
params = {'root': {},
'root/child0': {},
'root/child0/child1': {}}
tree = self._create_tree(get_raw_params=get_raw_params, params=params)
node = tree.get_node('root/child0')
tree.shrink(node)
assert tree.leaves()[0].tag == 'child1'

@pytest.mark.parametrize('params', [{
'tree_params': {'root': {'op_name': 'Root',
'precursor_nodes': [],
'successor_nodes': ['root/relu'],
'node_type': NodeType.MODULE.value,
'input_shape': [1, 3, 224, 224],
'output_shape': [1, 1, 224, 224]},
'root/relu': {'op_name': 'onnx::Relu',
'precursor_nodes': ['root'],
'successor_nodes': ['root/unknown'],
'node_type': NodeType.OPERATION.value,
'input_shape': [1, 3, 224, 224],
'output_shape': [1, 3, 224, 224]},
'root/unknown': {'op_name': 'onnx::Unknown',
'precursor_nodes': ['root/relu'],
'successor_nodes': [],
'node_type': NodeType.OPERATION.value,
'input_shape': [1, 3, 224, 224],
'output_shape': [1, 1, 224, 224]}},
'report_dir': 'report_folder'
}, {
'tree_params': {'root': {'op_name': 'Root',
'precursor_nodes': [],
'successor_nodes': ['root/relu'],
'node_type': NodeType.MODULE.value,
'input_shape': [1, 3, 224, 224],
'output_shape': [1, 1, 224, 224]},
'root/relu': {'op_name': 'onnx::Relu',
'precursor_nodes': ['root'],
'successor_nodes': ['root/unknown'],
'node_type': NodeType.OPERATION.value,
'input_shape': [1, 3, 224, 224],
'output_shape': [1, 3, 224, 224]},
'root/unknown': {'op_name': 'onnx::Unknown',
'precursor_nodes': ['root/relu'],
'successor_nodes': [],
'node_type': NodeType.OPERATION.value,
'input_shape': [1, 3, 224, 224],
'output_shape': [1, 1, 224, 224]}},
'report_dir': None
}])
@mock.patch(
'.'.join((TEST_BASE_PATH, 'third_party_graph.pytorch_graph_node.PyTorchGraphNode._get_raw_params')))
def test_save_source_file(self, get_raw_params, params):
"""Test save_source_file function."""
tree_params = params['tree_params']
out_folder = 'out_folder'
report_folder = params['report_dir']
model_name = 'model_name'
mapper = ONNXToMindSporeMapper()

tree = self._create_tree(get_raw_params=get_raw_params, params=tree_params)
tree.save_source_files(out_folder, mapper, model_name, report_folder)

out_path = os.path.realpath(os.path.join(out_folder, f"{model_name}.py"))
report_folder_test = report_folder if report_folder else out_folder
report_path = os.path.realpath(
os.path.join(report_folder_test, f"report_of_{model_name}.txt"))
try:
assert os.path.exists(out_path)
assert os.path.exists(report_path)
with open(out_path, 'r') as out_r:
code = out_r.read()
assert 'nn.ReLU' in code
assert 'onnx.Unknown' in code
with open(report_path, 'r') as report_r:
report = report_r.read()
assert "[UnConvert] 'onnx::Unknown' didn't convert." in report
assert "Converted Rate: 50.00%." in report
finally:
shutil.rmtree(out_folder)
if report_folder:
shutil.rmtree(report_folder)

@staticmethod
def _create_node(key, val, weight, input_shape, output_shape):
"""Create node."""
node = PyTorchGraphNode(weight=weight)
node.add_input_and_output_shape(input_shape, output_shape)
node.tag = key.split('/')[-1] if len(key.split('/')) > 1 else key
node.op_name = val['op_name'] if val.get('op_name') else None
node.precursor_nodes = val['precursor_nodes'] if val.get('precursor_nodes') else []
node.successor_nodes = val['successor_nodes'] if val.get('successor_nodes') else []
node.node_type = val['node_type'] if val.get('node_type') else None
return node

@staticmethod
def _create_tree(get_raw_params, params):
"""Create tree."""
tree = HierarchicalTree()
for key, val in params.items():
input_shape = val['input_shape'] if val.get('input_shape') else []
output_shape = val['output_shape'] if val.get('output_shape') else []
get_raw_params.return_value = val['op_params'] if val.get('op_params') else dict()
weight = val['weight'] if val.get('weight') else None

node = TestHierarchicalTree._create_node(key, val, weight, input_shape, output_shape)

tree.create_node(
tag=node.tag,
identifier=key,
parent='/'.join(key.split('/')[:-1]) if len(key.split('/')) > 1 else None,
data=node
)
return tree

+ 2
- 2
tests/ut/mindconverter/graph_based_converter/mapper/__init__.py View File

@@ -1,4 +1,4 @@
# Copyright 2020 Huawei Technologies Co., Ltd.All Rights Reserved.
# Copyright 2020-2021 Huawei Technologies Co., Ltd.All Rights Reserved.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
@@ -12,4 +12,4 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
# ============================================================================== # ==============================================================================
"""Unit test for mindconvert.graph_based_converter.mapper interface."""
"""Unit test for mindconverter.graph_based_converter.mapper interface."""

+ 5
- 6
tests/ut/mindconverter/graph_based_converter/mapper/test_mapper.py View File

@@ -1,4 +1,4 @@
# Copyright 2020 Huawei Technologies Co., Ltd.All Rights Reserved.
# Copyright 2020-2021 Huawei Technologies Co., Ltd.All Rights Reserved.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
@@ -18,7 +18,6 @@ import pytest


from mindinsight.mindconverter.graph_based_converter.mapper.base import ONNXToMindSporeMapper from mindinsight.mindconverter.graph_based_converter.mapper.base import ONNXToMindSporeMapper
from mindinsight.mindconverter.graph_based_converter.mapper.gen_setting import Setting from mindinsight.mindconverter.graph_based_converter.mapper.gen_setting import Setting
from tests.utils import mindspore




class TestMappers: class TestMappers:
@@ -30,7 +29,7 @@ class TestMappers:
'group': 1, 'group': 1,
'pads': [1, 2, 3, 4], 'pads': [1, 2, 3, 4],
'strides': [1, 1]}, 'strides': [1, 1]},
'weights': {'weight': mindspore.Tensor(np.zeros([64, 3, 1, 1], dtype=np.int32))}},
'weights': {'weight': np.zeros((64, 3, 1, 1), dtype=np.int32)}},
'expected_output': {'converter_name': 'nn.Conv2d', 'expected_output': {'converter_name': 'nn.Conv2d',
'converted_params': {'in_channels': 3, 'converted_params': {'in_channels': 3,
'out_channels': 64, 'out_channels': 64,
@@ -47,7 +46,7 @@ class TestMappers:
'group': 1, 'group': 1,
'pads': [0, 0, 0, 0], 'pads': [0, 0, 0, 0],
'strides': [1, 1]}, 'strides': [1, 1]},
'weights': {'weight': mindspore.Tensor(np.zeros([64, 3, 2, 2], dtype=np.int32))}},
'weights': {'weight': np.zeros((64, 3, 2, 2), dtype=np.int32)}},
'expected_output': {'converter_name': 'nn.Conv2d', 'expected_output': {'converter_name': 'nn.Conv2d',
'converted_params': {'in_channels': 3, 'converted_params': {'in_channels': 3,
'out_channels': 64, 'out_channels': 64,
@@ -61,8 +60,8 @@ class TestMappers:
}, { }, {
'input': {'op_name': 'onnx::Gemm', 'input': {'op_name': 'onnx::Gemm',
'params': dict(), 'params': dict(),
'weights': {'weight': mindspore.Tensor(np.zeros([10, 3], dtype=np.int32)),
'bias': mindspore.Tensor(np.zeros([10, 1], dtype=np.int32))}},
'weights': {'weight': np.zeros((10, 3), dtype=np.int32),
'bias': np.zeros((10, 1), dtype=np.int32)}},
'expected_output': {'converter_name': 'nn.Dense', 'expected_output': {'converter_name': 'nn.Dense',
'converted_params': {'in_channels': 3, 'converted_params': {'in_channels': 3,
'out_channels': 10, 'out_channels': 10,


Loading…
Cancel
Save