
!1103 Rewrite PyTorch parser module to use the ONNX model module.

From: @moran3
Tag: v1.2.0-rc1
Committed by mindspore-ci-bot, 4 years ago
Parent commit: d85a5f0a17
28 changed files with 312 additions and 2200 deletions
  1. +1 -1    mindinsight/mindconverter/graph_based_converter/common/name_mgr.py
  2. +3 -0    mindinsight/mindconverter/graph_based_converter/common/utils.py
  3. +20 -25  mindinsight/mindconverter/graph_based_converter/framework.py
  4. +4 -3    mindinsight/mindconverter/graph_based_converter/generator/args_translator.py
  5. +2 -2    mindinsight/mindconverter/graph_based_converter/generator/generator.py
  6. +2 -2    mindinsight/mindconverter/graph_based_converter/generator/module_struct.py
  7. +3 -11   mindinsight/mindconverter/graph_based_converter/generator/node_struct.py
  8. +0 -89   mindinsight/mindconverter/graph_based_converter/hierarchical_tree/__init__.py
  9. +0 -796  mindinsight/mindconverter/graph_based_converter/hierarchical_tree/hierarchical_tree.py
  10. +10 -0  mindinsight/mindconverter/graph_based_converter/mapper/base.py
  11. +1 -2   mindinsight/mindconverter/graph_based_converter/mapper/impl/nn/conv_mapper.py
  12. +6 -2   mindinsight/mindconverter/graph_based_converter/mapper/impl/nn/dense_mapper.py
  13. +2 -1   mindinsight/mindconverter/graph_based_converter/mapper/impl/nn/pad_mapper.py
  14. +1 -1   mindinsight/mindconverter/graph_based_converter/mapper/impl/nn/pool_mapper.py
  15. +7 -12  mindinsight/mindconverter/graph_based_converter/third_party_graph/__init__.py
  16. +1 -1   mindinsight/mindconverter/graph_based_converter/third_party_graph/base.py
  17. +18 -5  mindinsight/mindconverter/graph_based_converter/third_party_graph/onnx_graph.py
  18. +19 -6  mindinsight/mindconverter/graph_based_converter/third_party_graph/onnx_utils.py
  19. +144 -0 mindinsight/mindconverter/graph_based_converter/third_party_graph/optimizer.py
  20. +0 -691 mindinsight/mindconverter/graph_based_converter/third_party_graph/pytorch_graph.py
  21. +0 -236 mindinsight/mindconverter/graph_based_converter/third_party_graph/pytorch_graph_node.py
  22. +59 -7  mindinsight/mindconverter/graph_based_converter/third_party_graph/pytorch_graph_parser.py
  23. +0 -105 mindinsight/mindconverter/graph_based_converter/third_party_graph/torch_utils.py
  24. +2 -2   tests/ut/mindconverter/graph_based_converter/common/test_name_mgr.py
  25. +0 -15  tests/ut/mindconverter/graph_based_converter/hierarchical_tree/__init__.py
  26. +0 -177 tests/ut/mindconverter/graph_based_converter/hierarchical_tree/test_hierarchical_tree.py
  27. +2 -2   tests/ut/mindconverter/graph_based_converter/mapper/__init__.py
  28. +5 -6   tests/ut/mindconverter/graph_based_converter/mapper/test_mapper.py
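Taken together, the change removes the torch-graph tracing path and the hierarchical tree, so every model now reaches the generator as an ONNX graph. A minimal sketch of the new flow, assuming a model already exported to ONNX (the file name, shape, and node names below are placeholders):

    from mindinsight.mindconverter.graph_based_converter.third_party_graph import GraphFactory
    from mindinsight.mindconverter.graph_based_converter.generator import batch_add_nodes
    from mindinsight.mindconverter.graph_based_converter.mapper import ONNXToMindSporeMapper

    # Build the graph object from the .onnx file; input/output node names are now required.
    graph_obj = GraphFactory.init("lenet.onnx", sample_shape=(1, 1, 32, 32),
                                  input_nodes="input.1", output_nodes="output")
    # Map each ONNX node to a MindSpore operator and generate the script.
    generator_inst = batch_add_nodes(graph_obj, ONNXToMindSporeMapper)
    code_fragments = generator_inst.generate()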

mindinsight/mindconverter/graph_based_converter/hierarchical_tree/name_mgr.py → mindinsight/mindconverter/graph_based_converter/common/name_mgr.py

@@ -1,4 +1,4 @@
-# Copyright 2020 Huawei Technologies Co., Ltd.All Rights Reserved.
+# Copyright 2020-2021 Huawei Technologies Co., Ltd.All Rights Reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.

+3 -0  mindinsight/mindconverter/graph_based_converter/common/utils.py

@@ -194,6 +194,9 @@ def convert_bytes_string_to_string(bytes_str):

 def get_framework_type(model_path):
     """Get framework type."""
+    if model_path.endswith('.onnx'):
+        return FrameworkType.PYTORCH.value
+
     try:
         with open(model_path, 'rb') as f:
             if f.read(BINARY_HEADER_PYTORCH_BITS) == BINARY_HEADER_PYTORCH_FILE:
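With this check, a path ending in `.onnx` is routed to the PyTorch branch of the converter, which now consumes ONNX models directly; other files still fall through to the binary-header probe below. For illustration (file names are hypothetical):

    get_framework_type("resnet50.onnx")    # FrameworkType.PYTORCH.value
    get_framework_type("frozen_graph.pb")  # decided by the binary-header check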


+20 -25  mindinsight/mindconverter/graph_based_converter/framework.py

@@ -1,4 +1,4 @@
-# Copyright 2020 Huawei Technologies Co., Ltd.All Rights Reserved.
+# Copyright 2020-2021 Huawei Technologies Co., Ltd.All Rights Reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -24,10 +24,12 @@ from mindinsight.mindconverter.graph_based_converter.common.utils import lib_ver
     save_code_file_and_report, get_framework_type
 from mindinsight.mindconverter.graph_based_converter.constant import FrameworkType, \
     ONNX_MIN_VER, TF2ONNX_MIN_VER, ONNXRUNTIME_MIN_VER
+from mindinsight.mindconverter.graph_based_converter.generator import batch_add_nodes
 from mindinsight.mindconverter.graph_based_converter.mapper import ONNXToMindSporeMapper
 from mindinsight.mindconverter.common.log import logger as log, logger_console as log_console
 from mindinsight.mindconverter.common.exceptions import GraphInitError, TreeCreationError, SourceFilesSaveError, \
     BaseConverterError, UnknownModelError, GeneratorError, TfRuntimeError, RuntimeIntegrityError, ParamMissingError
+from mindinsight.mindconverter.graph_based_converter.third_party_graph import GraphFactory

 permissions = os.R_OK | os.W_OK | os.X_OK
 os.umask(permissions << 3 | permissions)
@@ -62,6 +64,7 @@ def torch_installation_validation(func):
"""

def _f(graph_path: str, sample_shape: tuple,
input_nodes: str, output_nodes: str,
output_folder: str, report_folder: str = None):
# Check whether pytorch is installed.
if not find_spec("torch") or not find_spec("onnx") or not find_spec("onnxruntime"):
@@ -93,6 +96,7 @@ def torch_installation_validation(func):
             sys.exit(0)

         func(graph_path=graph_path, sample_shape=sample_shape,
+             input_nodes=input_nodes, output_nodes=output_nodes,
              output_folder=output_folder, report_folder=report_folder)

     return _f
@@ -182,6 +186,7 @@ def _extract_model_name(model_path):
 @SourceFilesSaveError.uniform_catcher()
 @GeneratorError.uniform_catcher()
 def graph_based_converter_pytorch_to_ms(graph_path: str, sample_shape: tuple,
+                                        input_nodes: str, output_nodes: str,
                                         output_folder: str, report_folder: str = None):
     """
     PyTorch to MindSpore based on Graph.
@@ -189,26 +194,18 @@ def graph_based_converter_pytorch_to_ms(graph_path: str, sample_shape: tuple,
     Args:
         graph_path (str): Graph file path.
         sample_shape (tuple): Input shape of the model.
+        input_nodes (str): Input node(s) of the model.
+        output_nodes (str): Output node(s) of the model.
         output_folder (str): Output folder.
         report_folder (str): Report output folder path.

     """
-    third_party_graph_module = import_module(
-        'mindinsight.mindconverter.graph_based_converter.third_party_graph')
-    hierarchical_tree_module = import_module(
-        'mindinsight.mindconverter.graph_based_converter.hierarchical_tree')
-    cls_graph_factory = getattr(third_party_graph_module, 'GraphFactory')
-    cls_hierarchical_tree_factory = getattr(hierarchical_tree_module, 'HierarchicalTreeFactory')
-
-    graph_obj = cls_graph_factory.init(graph_path, sample_shape=sample_shape)
-
-    hierarchical_tree = cls_hierarchical_tree_factory.create(graph_obj)
-
+    graph_obj = GraphFactory.init(graph_path, sample_shape=sample_shape,
+                                  input_nodes=input_nodes, output_nodes=output_nodes)
+    generator_inst = batch_add_nodes(graph_obj, ONNXToMindSporeMapper)
     model_name = _extract_model_name(graph_path)

-    hierarchical_tree.save_source_files(output_folder, mapper=ONNXToMindSporeMapper,
-                                        model_name=model_name,
-                                        report_folder=report_folder)
+    code_fragments = generator_inst.generate()
+    save_code_file_and_report(model_name, code_fragments, output_folder, report_folder)


@tf_installation_validation
@@ -230,18 +227,13 @@ def graph_based_converter_tf_to_ms(graph_path: str, sample_shape: tuple,
         output_nodes(str): Output node(s) of the model.
         output_folder(str): Output folder.
         report_folder(str): Report output folder path.

     """
-    third_party_graph_module = import_module(
-        'mindinsight.mindconverter.graph_based_converter.third_party_graph')
-    cls_graph_factory = getattr(third_party_graph_module, 'GraphFactory')
-    batch_add_nodes = getattr(import_module('mindinsight.mindconverter.graph_based_converter.generator'),
-                              "batch_add_nodes")
-
     # Close unnecessary log.
     os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'

-    graph_obj = cls_graph_factory.init(graph_path, sample_shape=sample_shape,
-                                       input_nodes=input_nodes, output_nodes=output_nodes)
+    graph_obj = GraphFactory.init(graph_path, sample_shape=sample_shape,
+                                  input_nodes=input_nodes, output_nodes=output_nodes)
     generator_inst = batch_add_nodes(graph_obj, ONNXToMindSporeMapper)
     model_name = _extract_model_name(graph_path)
     code_fragments = generator_inst.generate()
@@ -255,7 +247,6 @@ def main_graph_base_converter(file_config):

     Args:
         file_config (dict): The config of file which to convert.
-
     """
     graph_path = file_config['model_file']
     frame_type = get_framework_type(graph_path)
@@ -263,8 +254,12 @@ def main_graph_base_converter(file_config):
         raise ParamMissingError("Param missing, `--shape` is required when using graph mode.")

     if frame_type == FrameworkType.PYTORCH.value:
+        check_params = ['input_nodes', 'output_nodes']
+        check_params_exist(check_params, file_config)
         graph_based_converter_pytorch_to_ms(graph_path=graph_path,
                                             sample_shape=file_config['shape'],
+                                            input_nodes=file_config['input_nodes'],
+                                            output_nodes=file_config['output_nodes'],
                                             output_folder=file_config['outfile_dir'],
                                             report_folder=file_config['report_dir'])
     elif frame_type == FrameworkType.TENSORFLOW.value:
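Since the PyTorch branch now requires `input_nodes` and `output_nodes`, a caller-side sketch of `main_graph_base_converter` would look like this (the paths, shape, and node names are placeholders):

    file_config = {
        "model_file": "./lenet.onnx",
        "shape": (1, 1, 32, 32),
        "input_nodes": "input.1",
        "output_nodes": "output",
        "outfile_dir": "./output",
        "report_dir": "./report",
    }
    main_graph_base_converter(file_config)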


+4 -3  mindinsight/mindconverter/graph_based_converter/generator/args_translator.py

@@ -1,4 +1,4 @@
-# Copyright 2020 Huawei Technologies Co., Ltd.All Rights Reserved.
+# Copyright 2020-2021 Huawei Technologies Co., Ltd.All Rights Reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -213,10 +213,11 @@ class ArgsTranslationHelper:
         Returns:
             list, name of args to be formal.
         """
+        ret = list()
         if len(args_translators) < 2:
             # only one args_translator provided, no formal args.
-            return None
-        ret = []
+            return ret
         base_args_t = args_translators[0]
         for arg_name, arg_val in base_args_t.actual_args.items():
             for args_t in args_translators[1:]:


+2 -2  mindinsight/mindconverter/graph_based_converter/generator/generator.py

@@ -1,4 +1,4 @@
-# Copyright 2020 Huawei Technologies Co., Ltd.All Rights Reserved.
+# Copyright 2020-2021 Huawei Technologies Co., Ltd.All Rights Reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -24,7 +24,7 @@ from .module_struct import ModuleStruct
 from .args_translator import ArgsTranslationHelper
 from ..common.global_context import GlobalContext
 from ...common.exceptions import GeneratorError
-from ..hierarchical_tree.name_mgr import GlobalVarNameMgr
+from ..common.name_mgr import GlobalVarNameMgr
 from ..constant import NEW_LINE, SECOND_LEVEL_INDENT, FIRST_LEVEL_INDENT, CodeFormatConfig, get_imported_module
 from ..report_generator import ReportGenerator



+2 -2  mindinsight/mindconverter/graph_based_converter/generator/module_struct.py

@@ -1,4 +1,4 @@
-# Copyright 2020 Huawei Technologies Co., Ltd.All Rights Reserved.
+# Copyright 2020-2021 Huawei Technologies Co., Ltd.All Rights Reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -22,7 +22,7 @@ from ..common.utils import get_dict_key_by_value
 from .args_translator import ArgsTranslation
 from ..common.code_fragment import ModuleFragment
 from ..common.global_context import GlobalContext
-from ..hierarchical_tree.name_mgr import LocalVarNameMgr
+from ..common.name_mgr import LocalVarNameMgr


 class ModuleStruct:


+3 -11  mindinsight/mindconverter/graph_based_converter/generator/node_struct.py

@@ -1,4 +1,4 @@
-# Copyright 2020 Huawei Technologies Co., Ltd.All Rights Reserved.
+# Copyright 2020-2021 Huawei Technologies Co., Ltd.All Rights Reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -18,7 +18,6 @@ from collections import OrderedDict
 from .scope_utils import Scope
 from .args_translator import ArgsTranslation
 from ..common.code_fragment import CodeFragment
-from ..third_party_graph.pytorch_graph_node import PyTorchGraphNode
 from ..third_party_graph.onnx_graph_node import OnnxGraphNode
 from ..common.global_context import GlobalContext
 from ..constant import InputType
@@ -110,11 +109,6 @@ class NodeStruct:
         self.graph_node_ref = gn
         self.scope_name = gn.scope_name

-    def _update_from_pytorch_gn(self, gn: PyTorchGraphNode):
-        """Update basic info from PyTorchGraphNode."""
-        self.node_type = "PyTorchGraphNode"
-        self._update_basics_from_gn(gn)
-
     def _update_from_onnx_gn(self, gn: OnnxGraphNode):
         """Update basic info from OnnxGraphNode."""
         self.node_type = "OnnxGraphNode"
@@ -177,9 +171,8 @@ class NodeStruct:
             arg (Union[PyTorchGraphNode, OnnxGraphNode, dict]): Node related obj.
             force_ready (bool): Force this NodeStruct is ready to generate.
         """
-        if isinstance(arg, PyTorchGraphNode):
-            self._update_from_pytorch_gn(arg)
-        elif isinstance(arg, OnnxGraphNode):
+        if isinstance(arg, OnnxGraphNode):
             self._update_from_onnx_gn(arg)
         elif isinstance(arg, (dict, OrderedDict)):
             self._update_from_mapper(arg)
@@ -246,7 +239,6 @@ class NodeStruct:
"""Return the output variable name of current node."""
return "{}_opt".format(self.ms_var_name).lower()


@property
def args_translator(self):
"""Return the args translator of this Node."""


+0 -89  mindinsight/mindconverter/graph_based_converter/hierarchical_tree/__init__.py

@@ -1,89 +0,0 @@
# Copyright 2020 Huawei Technologies Co., Ltd.All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Hierarchical tree module."""
__all__ = ["HierarchicalTreeFactory"]

import re

from mindinsight.mindconverter.common.log import logger as log
from .hierarchical_tree import HierarchicalTree
from ..third_party_graph.onnx_graph_node import OnnxGraphNode

from ...common.exceptions import NodeInputMissingError, TreeNodeInsertError


def _tf_model_node_name_reformat(node: OnnxGraphNode, node_name):
"""
Rename the node name by combining scope name and its original name.

Args:
node (OnnxGraphNode): OnnxGraphNode instance.
node_name (str): node name saved in Graph.

Returns:
str, re-formatted node name.
"""
scope_name = node.scope_name
new_name = None
regex = r"(?P<parent>.+/)(?P<op>\w+)"
match = re.match(regex, scope_name)
parent = match.group("parent")
node_name = '$' + node_name.replace('/', '::') + '$'

if scope_name:
new_name = parent + node_name
if new_name:
return new_name
return node_name


class HierarchicalTreeFactory:
"""Hierarchical tree factory."""

@classmethod
@TreeNodeInsertError.check_except("Tree node inserts failed.")
def create(cls, graph):
"""
Factory method of hierarchical tree.

Args:
graph: Graph obj.

Returns:
HierarchicalTree, tree.
"""
tree = HierarchicalTree()
node_scope_name = dict()
for _, node_name in enumerate(graph.nodes_in_topological_order):
node_inst = graph.get_node(node_name)
node_input = graph.get_input_shape(node_name)
node_output = graph.get_output_shape(node_name)
if node_input != 0 and not node_input:
err_msg = f"This model is not supported now. " \
f"Cannot find {node_name}'s input shape."
error = NodeInputMissingError(err_msg)
log.error(str(error))
raise error
if isinstance(node_inst, OnnxGraphNode):
node_name_with_scope = _tf_model_node_name_reformat(node_inst, node_name)
node_scope_name[node_name] = node_name_with_scope
node_name = node_name_with_scope

node_inst.add_input_and_output_shape(node_input, node_output)
tree.insert(node_inst, node_name)

if node_scope_name:
return tree, node_scope_name
return tree

+0 -796  mindinsight/mindconverter/graph_based_converter/hierarchical_tree/hierarchical_tree.py

@@ -1,796 +0,0 @@
# Copyright 2020 Huawei Technologies Co., Ltd.All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Define hierarchical tree."""
import os
import stat
from copy import deepcopy
from typing import NoReturn, Union
from queue import Queue

from yapf.yapflib.yapf_api import FormatCode
from treelib import Tree, Node

from mindinsight.mindconverter.common.log import logger as log

from .name_mgr import ModuleNameMgr, GlobalVarNameMgr
from ..common.utils import is_converted, save_code_file_and_report
from ..mapper.base import Mapper
from ..third_party_graph.pytorch_graph_node import PyTorchGraphNode
from ..third_party_graph.onnx_graph_node import OnnxGraphNode
from ..constant import SEPARATOR_IN_SCOPE, get_imported_module, NO_CONVERTED_OPERATORS
from ..constant import CodeFormatConfig
from ..constant import SEPARATOR_BTW_NAME_AND_ID, FIRST_LEVEL_INDENT
from ..constant import NEW_LINE, SECOND_LEVEL_INDENT
from ..constant import NodeType
from ..report_generator import ReportGenerator
from ...common.exceptions import ReportGenerationError, ScriptGenerationError, NodeInputTypeNotSupportError


class HierarchicalTree(Tree):
"""Define hierarchical tree."""
flags = os.O_WRONLY | os.O_CREAT | os.O_EXCL
modes = stat.S_IRUSR | stat.S_IWUSR
modes_usr = stat.S_IRUSR | stat.S_IWUSR | stat.S_IXUSR

_root_created = False
ROOT_LEVEL = 0

GLOBAL_VAR_NAME_MGR = GlobalVarNameMgr()

def __init__(self):
super(HierarchicalTree, self).__init__()
self._hierarchical_order = dict()
# Manage mapping of unique key and module name.
self._merged_module = dict()
# Manage mapping of unique key and module args.
self._merged_module_args = dict()
# Record creation of module with unique key.
self._created_module = dict()
# Manage module name to used.
self._module_mgr = ModuleNameMgr()
# Manage variable name in a module.
self._vars_mgr_in_module = dict()
self._module_vars = dict()
# scope name mapping record for easy node searching
self._scope_name_map = dict()
self.code_fragment_recorder = dict()

@property
def tree_identifier(self):
"""
Return identifier of tree.

Returns:
tree, id of tree.
"""
return self.identifier

def get_node(self, nid):
"""Override get_node method to support tf ver. generated scope."""
if nid is None or not self.contains(nid):
if self._scope_name_map and nid in self._scope_name_map:
nid = self._scope_name_map.get(nid)
else:
return None
return self._nodes[nid]

def insert(self, node: Union[PyTorchGraphNode, OnnxGraphNode], node_name: str):
"""
Insert node into hierarchical tree.

Args:
node_name (str): Node name.
node (Union[PyTorchGraphNode, OnnxGraphNode]): Node to be inserted.

"""
scopes = node_name.split(SEPARATOR_IN_SCOPE)
for idx, scope in enumerate(scopes):
parent = SEPARATOR_IN_SCOPE.join(scopes[:idx])
identifier = SEPARATOR_IN_SCOPE.join(scopes[:idx + 1])
try_parent = f"{parent}{SEPARATOR_IN_SCOPE}{scope}" \
if parent else scope
if self.contains(try_parent):
# Whether current node existed.
parent = try_parent

if not parent and not self._root_created:
# If no root node, then create it and mark it.
parent = None
self._root_created = True
elif not parent and self._root_created:
# Already have root node, skip it.
continue

if not self.contains(identifier):
# Insert node into tree.
if isinstance(node, OnnxGraphNode):
tgt_node = node if idx == len(
scopes) - 1 else OnnxGraphNode()
else:
tgt_node = node if idx == len(
scopes) - 1 else PyTorchGraphNode()
tgt_node.successor_nodes = node.successor_nodes
tgt_node.precursor_nodes = node.precursor_nodes
tgt_node.node_type = (NodeType.OPERATION if idx == len(scopes) - 1
else NodeType.MODULE).value
tgt_node.variable_name = self._get_var_name(identifier)
self.create_node(
tag=scope.split(SEPARATOR_BTW_NAME_AND_ID)[0],
identifier=identifier,
parent=parent,
data=tgt_node
)

def remove(self, node: Node, keep_sub=False):
"""
Remove node into hierarchical tree.

Args:
node (Node): Node to be removed.
keep_sub (bool): Whether keep sub-tree.

"""
if not keep_sub:
self.remove_node(node.identifier)
return

def shrink(self, node: Node):
"""
Shrink sub-tree into one node.

Use child node to replace its ancestor.

Args:
node (Node): List of nodes to be merged.

"""
node_name = node.identifier
parent_node = self[node.predecessor(self.tree_identifier)]
# Keep successors of parent.
brothers = deepcopy(parent_node.successors(self.tree_identifier))
# Because shrink occurs when node has only one child,
# so we take index-0.
child = node.successors(self.tree_identifier)[0]
self.move_node(source=child,
destination=node.predecessor(self.tree_identifier))
self.remove(node)
brothers[brothers.index(node_name)] = child
parent_node.set_successors(brothers, tree_id=self.tree_identifier)

def save_source_files(self, out_folder: str, mapper: Mapper,
model_name: str,
report_folder: str = None,
scope_name_map: dict = None) -> NoReturn:
"""
Save source codes to target folder.

Args:
report_folder (str): Report folder.
mapper (Mapper): Mapper of third party framework and mindspore.
model_name(str): Name of Converted model.
out_folder (str): Output folder.
scope_name_map(str): Scope name map of tensorflow.

"""
if scope_name_map:
self._scope_name_map = scope_name_map
try:
self._adjust_structure()
code_fragments = self._generate_codes(mapper)
except (NodeInputTypeNotSupportError, ScriptGenerationError, ReportGenerationError) as e:
log.error("Error occur when generating codes.")
raise e

save_code_file_and_report(model_name, code_fragments, out_folder, report_folder)

def _preprocess_node_args(self, node, module_key):
"""
Remove unused args.

Args:
node (Node): Node instance.
module_key (str): Nodule key.

Returns:
Node, node.
"""
if module_key in self._merged_module_args:
node = self._clear_unused_args(
node, self._merged_module_args[module_key])
else:
node.data.clear_args_of_declaration()
return node

def _postprocess_node_args(self, node, precursor_module_key):
"""
Post process args in node.

Args:
node (Node): Node instance.
precursor_module_key (str): Parent node module name.

Returns:
Node, node.
"""
if node.data.node_type in {NodeType.MODULE.value, NodeType.CLASS.value,
NodeType.FUNC.value}:
# If current node is class or function, then
# remove unused args in __init__.
cur_module_key = node.data.hash_key or self.hash_key(node)
if cur_module_key in self._merged_module_args:
node = self._clear_unused_args(node,
self._merged_module_args[cur_module_key])

# `self._merged_module_args` records formal args.
# We need to replace actual args.
if precursor_module_key in self._merged_module_args:
# If parent node is in `_merged_module_args`, then
# replace current node args with arg name declared
# in _merged_module_args.
for arg in node.data.args_in_code.keys():
if arg in self._merged_module_args[precursor_module_key]:
node.data.replace_with_arg(arg, arg)
return node

def _clear_unused_args(self, node, used_args):
"""
Clear unused args.

Args:
node (Node): Node.
used_args (list): Args list.

Returns:
Node, node instance.
"""
args_in_code = list(node.data.args_in_code.keys())
for arg in args_in_code:
ori_arg = arg.replace(
f"_{self.code_fragment_recorder[node.identifier].declared_var_name}", ""
)
if ori_arg not in used_args:
node.data.args_in_code.pop(arg)
return node

@ScriptGenerationError.check_except("FormatCode run error. Check detailed information in log.")
@ReportGenerationError.check_except("Not find valid operators in converted script.")
def _generate_codes(self, mapper):
"""
Generate code files.

- 1. Generate args.
- 2. Merge module.
- 3. Pre-process node args.
- 4. Post-process child node args.
- 5. Generate class/func code.
- 6. Merge code snippets.

Args:
mapper (Mapper): Mapper of third party operation and mindspore.

Returns:
Dict, codes.
"""
code_blocks = [get_imported_module()]
depths = sorted(list(self._hierarchical_order.keys()), reverse=True)

for depth in depths:
node_collection = self._hierarchical_order[depth]
for node_name in node_collection:
# Traverse nodes in topological order.
node = self.get_node(node_name)
# 1. Generate args for each node in this level.
if node.data.node_type == NodeType.MODULE.value:
self._create_module_args_and_vars(node, mapper)
if depth == depths[-1]:
self.code_fragment_recorder[node.identifier] = node.data.param_transform(mapper, "")

# Module merging based on all nodes.
self._module_merging()

for depth in depths:
node_collection = self._hierarchical_order[depth]
snippets = set()
for node_name in node_collection:
nd_inst = self.get_node(node_name)
if nd_inst.data.node_type != NodeType.MODULE.value:
continue

# Generate hash key for node.
module_key = nd_inst.data.hash_key
# Get code generation func.
func, node_type = self._fetch_func_and_type(nd_inst)

if module_key in self._created_module:
# If the module has already been created,
# then assign the created module name to current node,
# and delete unused args.
module_name = self._created_module[module_key]
self.code_fragment_recorder[nd_inst.identifier].operation = module_name
self.code_fragment_recorder[nd_inst.identifier].node_type = node_type
self._preprocess_node_args(nd_inst, module_key)
continue

module_name = nd_inst.tag

if node_type == NodeType.CLASS.value:
module_name = f"{module_name[0].upper()}{module_name[1:]}"

# After node_type and module_name is frozen,
# then it's unchangeable.
module_name = self._module_mgr.get_name(module_name)
self.code_fragment_recorder[nd_inst.identifier].operation = module_name
self.code_fragment_recorder[nd_inst.identifier].node_type = node_type

# 3. Pre-process node args.
nd_inst = self._preprocess_node_args(nd_inst, module_key)
# 4. Post-process child node args.
for _, scsr_nd_name in enumerate(nd_inst.successors(self.tree_identifier)):
self._postprocess_node_args(self.get_node(scsr_nd_name), module_key)
# 5. Generate code.
snippets.add(func(nd_inst, self.code_fragment_recorder[nd_inst.identifier].operation, module_key))

code_blocks.extend(snippets)

if self._scope_name_map: # from tf. conversion
c_blocks = []
for s in code_blocks:
s = s.replace('$', '')
c_blocks.append(s)
code_blocks = c_blocks

formatted_code, _ = FormatCode("".join(code_blocks),
style_config=CodeFormatConfig.PEP8.value)

report_generator = ReportGenerator()
report = report_generator.gen_report(formatted_code)

return {"model": (formatted_code, report)}

def _fetch_func_and_type(self, node) -> Union[object, str]:
"""
Generate code snippet.

Args:
node (Node): Node.

Returns:
Union[object, str], code snippet func.
"""

def _is_func():
"""
The correct thought is to check whether have more than one
path in this block.
"""
nonlocal node

if node.predecessor(self.tree_identifier) is None:
return False

tgt_type = {NodeType.MODULE.value,
NodeType.FUNC.value, NodeType.CLASS.value}
md_type_lst = [self.get_node(child).data.node_type
for child in node.successors(self.tree_identifier)]
diff_set = set(md_type_lst) - tgt_type
return not diff_set

if _is_func():
return self._generate_func_snippet, NodeType.FUNC.value
return self._generate_class_snippet, NodeType.CLASS.value

def _generate_func_snippet(self, node, func_name, func_key):
"""
Generate function snippet.

Args:
node (Node): Node inst.

Returns:
str, code snippet.
"""
definition = ""

if func_key.lower() in self._merged_module_args and \
self._merged_module_args[func_key.lower()]:
definition = ", ".join(self._merged_module_args[func_key.lower()])

module_list = []
for node_name in node.successors(self.tree_identifier):
c_nd = self.get_node(node_name)
operator = self.code_fragment_recorder[c_nd.identifier].operation

if c_nd.data.node_type != NodeType.OPERATION.value:
hash_key = c_nd.data.hash_key or self.hash_key(c_nd)
if hash_key in self._created_module:
operator = self._created_module[hash_key]

args = c_nd.data.args_in_code
if c_nd.data.node_type == NodeType.OPERATION.value and not is_converted(
self.code_fragment_recorder[c_nd.identifier].operation):
args.update({"input_shape": c_nd.data.input_shape,
"output_shape": c_nd.data.output_shape})

# Generate code statement.
expr = ", ".join(
[f"{k.replace(f'_{self.code_fragment_recorder[c_nd.identifier].declared_var_name}', '')}={v}"
for k, v in args.items()]
)
code_line = f"{operator}({expr})"
module_list.append(code_line)

body = f",{NEW_LINE}{SECOND_LEVEL_INDENT}".join(module_list)
snippet = f"{FIRST_LEVEL_INDENT}module_list = [{NEW_LINE}" \
f"{SECOND_LEVEL_INDENT}{body}{NEW_LINE}" \
f"{FIRST_LEVEL_INDENT}]{NEW_LINE}" \
f"{FIRST_LEVEL_INDENT}return nn.SequentialCell(*module_list)"
definition = f"def {func_name}({definition}):{NEW_LINE}"

# Mark the structure has been created.
self._created_module[func_key.lower()] = func_name

return f"{definition}{snippet}{NEW_LINE * 3}"

def _generate_class_snippet(self, node, class_name, class_key):
"""
Generate class-type code snippet.

Args:
node (Node): Node.

Returns:
str, code snippet.
"""
super_call = f"super({class_name}, self).__init__()"

if class_key.lower() in self._merged_module_args and \
self._merged_module_args[class_key.lower()]:
args = f"{', '.join(self._merged_module_args[class_key.lower()])}"

class_init = f"{FIRST_LEVEL_INDENT}def __init__(self, " \
f"{args}):" \
f"{NEW_LINE}{SECOND_LEVEL_INDENT}" \
f"{super_call}{NEW_LINE}{SECOND_LEVEL_INDENT}"
else:
class_init = f"{FIRST_LEVEL_INDENT}def __init__(self):{NEW_LINE}{SECOND_LEVEL_INDENT}" \
f"{super_call}{NEW_LINE}{SECOND_LEVEL_INDENT}"

init_block = []
construct_block = []

for idx, node_name in enumerate(node.successors(self.tree_identifier)):
nd_inst = self.get_node(node_name)
if nd_inst.data.op_name in NO_CONVERTED_OPERATORS:
continue

# Generate code statement.
init, construct = self._generate_stat(nd_inst, node, idx)

# support multiple construct and init block returns:
if isinstance(construct, list):
construct_block += construct
else:
construct_block.append(construct)

if isinstance(init, list):
init_block += init
else:
init_block.append(init)

class_construct = f"{NEW_LINE}{FIRST_LEVEL_INDENT}def construct(self, x):" \
f"{NEW_LINE}{SECOND_LEVEL_INDENT}"
init_body = f"{NEW_LINE}{SECOND_LEVEL_INDENT}".join(init_block)
csrt_body = f"{NEW_LINE}{SECOND_LEVEL_INDENT}".join(construct_block)
csrt_rtn = f"{NEW_LINE}{SECOND_LEVEL_INDENT}return output{NEW_LINE}"

cls_definition = f"class {class_name}(nn.Cell):{NEW_LINE * 2}"

# Mark the structure has been created.
self._created_module[class_key.lower()] = class_name

return f"{cls_definition}" \
f"{class_init}" \
f"{init_body}{NEW_LINE}" \
f"{class_construct}" \
f"{csrt_body}{csrt_rtn}{NEW_LINE * 2}"

def _generate_stat(self, cur_nd_inst, pre_nd_inst, idx):
"""
Generate statements.

Args:
cur_nd_inst (Node): Current node instance.
pre_nd_inst (Node): Precursor node instance.
idx (int): Index of cur node.

Returns:
Tuple[str, str], declare in init and call in construct.
"""

ipt_args_in_construct = "x"
opt_arg_in_construct = ["output"]

if idx != 0:
if cur_nd_inst.data.is_in_multi_opt_graph:
ipt_args_in_construct = self._get_current_ipt_var(cur_nd_inst)
else:
# Get previous node output variable name.
ipt_args_in_construct = self._get_previous_opt_var(cur_nd_inst, pre_nd_inst)
if idx != len(pre_nd_inst.successors(self.tree_identifier)) - 1:
# Set opt variable name.
if cur_nd_inst.data.node_type == NodeType.MODULE.value or not cur_nd_inst.data.is_in_multi_opt_graph:
opt_arg_in_construct = [
f"{self.code_fragment_recorder[cur_nd_inst.identifier].declared_var_name}_opt"
]
else:
opt_arg_in_construct = [
f"opt_{var_name}"
for var_name in self.code_fragment_recorder[cur_nd_inst.identifier].output_var_name
]

declare, call = cur_nd_inst.data.to_code(ipt_args_in_construct=ipt_args_in_construct,
variable_name=self.code_fragment_recorder[
cur_nd_inst.identifier].declared_var_name,
output_var=opt_arg_in_construct,
code_fragment=self.code_fragment_recorder[cur_nd_inst.identifier])

return declare, call

@staticmethod
def _get_var_name(s):
"""
Get variable name using scope name.

Args:
s (str): String.

Returns:
str, variable name.
"""
return s.split(SEPARATOR_IN_SCOPE)[-1].lower().split(SEPARATOR_BTW_NAME_AND_ID)[0]

def _get_current_ipt_var(self, cur_nd):
""""
Get current input variable name from node_id.

Args:
cur_nd (Node): Current node.

Returns:
str, needed var names.
"""
if cur_nd.data.node_type != NodeType.OPERATION.value:
while True:
p_nd = cur_nd.successors(self.tree_identifier)
if not p_nd:
break
cur_nd = self.get_node(p_nd[0])

ipt_lst_raw = []
for operation_input in self.code_fragment_recorder[cur_nd.identifier].operation_inputs:
ipt_lst_raw.append(f"{operation_input}")

opt_var_names_p_nds = set()
for e in cur_nd.data.precursor_nodes:
p_nd = self.get_node(e)
if p_nd.data.op_name in NO_CONVERTED_OPERATORS:
continue

opt_var_names_p_nd = set(p_nd.data.opt_var_names)
opt_var_names_p_nds = set.union(opt_var_names_p_nds, opt_var_names_p_nd)

ipt_lst = [f"opt_{ipt}" for ipt in set(ipt_lst_raw).intersection(opt_var_names_p_nds)]
return ", ".join(ipt_lst)

def _find_all_previous_opt_var_(self, cur_nd, pre_nd):
"""
Find all input variable names.

Args:
cur_nd (Node): Current node.
pre_nd (Node): Precursor node.

Returns:
list, needed var names list.
"""
ipt_lst = []
if cur_nd.tag in NO_CONVERTED_OPERATORS:
return ipt_lst

for e in cur_nd.data.precursor_nodes:
p_nd = self.get_node(e)
if e not in pre_nd.successors(self.tree_identifier):
while True:
if p_nd.identifier in pre_nd.successors(self.tree_identifier):
ipt_lst.append(
f"{self.code_fragment_recorder[p_nd.identifier].declared_var_name}_opt"
)
break
pre_nd_name = p_nd.predecessor(self.tree_identifier)
if not pre_nd_name:
ipt_lst.append("x")
break
p_nd = self.get_node(pre_nd_name)
continue
ipt_lst.append(
f"{self.code_fragment_recorder[p_nd.identifier].declared_var_name}_opt"
)
return ipt_lst

def _get_previous_opt_var(self, cur_nd, pre_nd):
"""
Get needed input variable names.

Args:
cur_nd (Node): Current node.
pre_nd (Node): Precursor node.

Returns:
str, needed var names.
"""
if cur_nd.data.node_type != NodeType.OPERATION.value:
while True:
p_nd = cur_nd.successors(self.tree_identifier)
if not p_nd:
break
cur_nd = self.get_node(p_nd[0])
return ", ".join(self._find_all_previous_opt_var_(cur_nd, pre_nd))

def hash_key(self, node):
"""
Generate hash key for each node.

Args:
node (Node): Node.

Returns:
str, hash key.
"""
scsr_topo_order = []
for s in node.successors(self.tree_identifier):
cur_nd = self.get_node(s)
if cur_nd.data.node_type in {NodeType.MODULE.value,
NodeType.FUNC.value,
NodeType.CLASS.value}:
if cur_nd.data.hash_key:
scsr_topo_order.append(f"({cur_nd.data.hash_key})")
continue

raise ValueError("Current node doesn't have hash key.")

if cur_nd.data.hash_key:
scsr_topo_order.append(cur_nd.data.hash_key)
continue
unique_key = "->".join(scsr_topo_order)
node.data.hash_key = unique_key
return unique_key

def _module_merging(self):
"""Generate sub-module and corresponding params."""
merged_module_args = dict()
for module_key, module_args in self._merged_module.items():
if module_key not in merged_module_args:
merged_module_args[module_key] = []
# Take first element's args as base.
keys = module_args[0].keys()
for key in keys:
for i in range(1, len(module_args)):
if key in module_args[i] and module_args[0][key] != module_args[i][key]:
merged_module_args[module_key].append(key)
break
if key not in module_args[i]:
merged_module_args[module_key].append(key)
break

self._merged_module_args.update(merged_module_args)

def _create_module_args_and_vars(self, node, mapper):
"""
Create module args and variables in current node.

Args:
node (Node): Node on tree.
mapper (Mapper): Mapper of params.

"""
# All args and value pair in current node module.
module_args = dict()
module_key = self.hash_key(node)
created = False

if module_key not in self._vars_mgr_in_module:
self._vars_mgr_in_module[module_key] = self.GLOBAL_VAR_NAME_MGR
self._module_vars[module_key] = []
else:
created = True

# Sub-modules in the module could have arg name conflicts.
for idx, successor_name in enumerate(node.successors(self.tree_identifier)):
nd_inst = self.get_node(successor_name)
if nd_inst.data.op_name in NO_CONVERTED_OPERATORS:
continue

# Generation of params must behind variable assigment.
if created:
variable_name = self._module_vars[module_key][idx]
else:
variable_name = nd_inst.data.op_name or nd_inst.tag
variable_name = self._vars_mgr_in_module[module_key].get_name(variable_name)

code_fragment = nd_inst.data.param_transform(mapper, variable_name)
code_fragment.declared_var_name = variable_name
code_fragment.output_var_name = nd_inst.data.opt_var_names
code_fragment.update_operation_inputs(nd_inst.data.ipt_var_names)
self.code_fragment_recorder[nd_inst.identifier] = code_fragment

module_args.update(nd_inst.data.args_in_code)

if not created:
self._module_vars[module_key].append(variable_name)

node.data.args_in_code = module_args

# Collect module args of `module_key`.
if module_key not in self._merged_module:
self._merged_module[module_key] = [deepcopy(node.data.args_in_code)]
else:
self._merged_module[module_key].append(deepcopy(node.data.args_in_code))

@staticmethod
def _create_operation_args(node, mapper):
"""
Create operation args.

Args:
node (Node): Node on tree.
mapper (Mapper): Mapper of params.

"""
node.data.param_transform(mapper)

def update_hierarchical_order(self) -> NoReturn:
"""
Update hierarchical order.
"""
hierarchical_order = dict()
queue = Queue()
queue.put(item=(self.root, self.ROOT_LEVEL), block=False)
while not queue.empty():
node_name, cur_level = queue.get(block=False)
node_inst = self[node_name]
if cur_level not in hierarchical_order:
hierarchical_order[cur_level] = []
hierarchical_order[cur_level].append(node_name)
for successor_name in node_inst.successors(self.tree_identifier):
queue.put(item=(successor_name, cur_level + 1), block=False)
self._hierarchical_order = hierarchical_order

def sub_graph_merging(self) -> NoReturn:
"""Shrink the module has only one child."""
self.update_hierarchical_order()
depths = sorted(list(self._hierarchical_order.keys()), reverse=True)
for depth in depths:
for node_name in self._hierarchical_order[depth]:
node_inst = self[node_name]
# If the node type is module and has only one child,
# then merge it with its child.
if node_inst.data.node_type == NodeType.MODULE.value and \
len(node_inst.successors(self.tree_identifier)) == 1:
self.shrink(node_inst)

def _adjust_structure(self) -> NoReturn:
"""Adjust tree structure to generate source code."""
self.sub_graph_merging()
self.update_hierarchical_order()

+10 -0  mindinsight/mindconverter/graph_based_converter/mapper/base.py

@@ -171,3 +171,13 @@ class ONNXToMindSporeMapper(Mapper, abc.ABC):
         outputs_list = [f"opt_{{{variable_slot}}}"]
         outputs_mapping = ((0, 0),)
         return template, exchange_msg, outputs_list, outputs_mapping
+
+    @staticmethod
+    def _find_val_by_index(loc_index, values_dict):
+        """Find value by location index of values_dict."""
+        result = None
+        for idx, dict_val in enumerate(values_dict.values()):
+            if idx == loc_index:
+                result = dict_val
+                break
+        return result
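`_find_val_by_index` lets mappers fetch weight tensors by position instead of by name, since tensor names no longer follow the torch `state_dict` convention once the model arrives as ONNX. A usage sketch (the dict keys and arrays are hypothetical):

    weights = {"fc.weight": weight_array, "fc.bias": bias_array}
    weight = ONNXToMindSporeMapper._find_val_by_index(0, weights)   # weight_array
    bias = ONNXToMindSporeMapper._find_val_by_index(1, weights)     # bias_array
    missing = ONNXToMindSporeMapper._find_val_by_index(5, weights)  # None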

+1 -2  mindinsight/mindconverter/graph_based_converter/mapper/impl/nn/conv_mapper.py

@@ -42,7 +42,7 @@ class ConvMapper(ONNXToMindSporeMapper):
"""Convert params from PyTorch to MindSpore"""
weights = kwargs['weights']
params = kwargs['params']
weight = weights['weight'].numpy()
weight = weights['weight']
weight = np.transpose(weight, list(range(2, weight.ndim)) + [1, 0])
if isinstance(params['dilations'], list):
dilation = tuple(params['dilations'])
@@ -130,7 +130,6 @@ class ConvMapper(ONNXToMindSporeMapper):
             dim = len(kernel_size)
             return f"nn.Conv{dim}d"

-        weight = weight.numpy()
         dim = weight.ndim - 2
         return f"nn.Conv{dim}d"



+6 -2  mindinsight/mindconverter/graph_based_converter/mapper/impl/nn/dense_mapper.py

@@ -13,6 +13,7 @@
 # limitations under the License.
 # ==============================================================================
 """Mapper module."""
+import numpy as np
 from mindinsight.mindconverter.graph_based_converter.mapper.base import ONNXToMindSporeMapper
 from mindinsight.mindconverter.graph_based_converter.mapper.gen_setting import Setting

@@ -27,8 +28,11 @@ class DenseMapper(ONNXToMindSporeMapper):
     @staticmethod
     def _convert_params(**kwargs):
         weights = kwargs['weights']
-        has_bias = bool('bias' in weights)
-        weight = weights['weight'].numpy().transpose()
+        weight_index = 0
+        bias_index = 1
+        bias = DenseMapper._find_val_by_index(bias_index, weights)
+        has_bias = isinstance(bias, np.ndarray)
+        weight = DenseMapper._find_val_by_index(weight_index, weights).transpose()
         in_channels, out_channels = weight.shape
         return {
             'in_channels': in_channels,


+2 -1  mindinsight/mindconverter/graph_based_converter/mapper/impl/nn/pad_mapper.py

@@ -13,6 +13,7 @@
 # limitations under the License.
 # ==============================================================================
 """Mapper module."""
+from mindinsight.mindconverter.graph_based_converter.common.utils import convert_bytes_string_to_string
 from mindinsight.mindconverter.graph_based_converter.mapper.base import ONNXToMindSporeMapper
 from mindinsight.mindconverter.graph_based_converter.mapper.gen_setting import Setting

@@ -47,7 +48,7 @@ class PadMapper(ONNXToMindSporeMapper):
     def _convert_params(**kwargs):
         weights = kwargs.get("weights")
         params = kwargs.get("params")
-        mode = params.get('mode', 'constant')
+        mode = convert_bytes_string_to_string(params.get('mode', 'constant'))
         pads_onnx = params.get("pads") if params.get("pads") else list(weights.values())[0].tolist()
         if mode == 'constant' and params.get('value') is None:
             if params.get('pads') or weights:


+1 -1  mindinsight/mindconverter/graph_based_converter/mapper/impl/nn/pool_mapper.py

@@ -36,7 +36,7 @@ class PoolMapper(ONNXToMindSporeMapper):
         transformed_params["kernel_size"] = tuple(params['kernel_shape'])
         transformed_params["stride"] = tuple(params['strides'])
         if "pads" in params:
-            if sum(params['pads']) == 0:
+            if sum(params['pads']) == 0 and not params.get('ceil_mode', None):
                 pad_mode = '\"valid\"'
             else:
                 pad_mode = '\"same\"'
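The added condition means a pooling node with zero padding but `ceil_mode` set now maps to pad_mode "same" instead of "valid". For example, with a hypothetical params dict:

    params = {"kernel_shape": [2, 2], "strides": [2, 2], "pads": [0, 0, 0, 0], "ceil_mode": 1}
    # sum(params['pads']) == 0, but ceil_mode is truthy, so pad_mode becomes '"same"'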


+7 -12  mindinsight/mindconverter/graph_based_converter/third_party_graph/__init__.py

@@ -1,4 +1,4 @@
# Copyright 2020 Huawei Technologies Co., Ltd.All Rights Reserved.
# Copyright 2020-2021 Huawei Technologies Co., Ltd.All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -39,14 +39,9 @@ class GraphFactory:
         Returns:
             Graph, graph instance.
         """
-        if all([input_nodes, output_nodes]):
-            onnx_graph_module = import_module(
-                'mindinsight.mindconverter.graph_based_converter.third_party_graph.onnx_graph')
-            onnx_graph = getattr(onnx_graph_module, 'OnnxGraph')
-            return onnx_graph.load(model_path=graph_path, input_nodes=input_nodes,
-                                   output_nodes=output_nodes, sample_shape=sample_shape)
-
-        pytorch_graph_module = import_module(
-            'mindinsight.mindconverter.graph_based_converter.third_party_graph.pytorch_graph')
-        pytorch_graph = getattr(pytorch_graph_module, 'PyTorchGraph')
-        return pytorch_graph.load(model_path=graph_path, sample_shape=sample_shape)
+        onnx_graph_module = import_module(
+            'mindinsight.mindconverter.graph_based_converter.third_party_graph.onnx_graph')
+        onnx_graph = getattr(onnx_graph_module, 'OnnxGraph')
+        return onnx_graph.load(model_path=graph_path, input_nodes=input_nodes,
+                               output_nodes=output_nodes, sample_shape=sample_shape)

+1 -1  mindinsight/mindconverter/graph_based_converter/third_party_graph/base.py

@@ -264,7 +264,7 @@ class Graph(BaseGraph, abc.ABC):
         Returns:
             cls, graph instance.
         """
-        src_graph = cls.load_graph(graph_path=model_path, **kwargs)
+        src_graph = cls.load_graph(graph_path=model_path, sample_shape=sample_shape, **kwargs)
         ckpt = cls.load_checkpoint(ckpt_path=checkpoint) if checkpoint else None

         if ckpt is not None:


+18 -5  mindinsight/mindconverter/graph_based_converter/third_party_graph/onnx_graph.py

@@ -1,4 +1,4 @@
-# Copyright 2020 Huawei Technologies Co., Ltd.All Rights Reserved.
+# Copyright 2020-2021 Huawei Technologies Co., Ltd.All Rights Reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -13,12 +13,14 @@
 # limitations under the License.
 # ==============================================================================
 """Define ONNX graph."""
+from importlib import import_module
 from typing import Dict, NoReturn

 from mindinsight.mindconverter.common.log import logger as log
 from .base import Graph
 from .input_node import InputNode
 from .onnx_graph_node import OnnxGraphNode
+from .pytorch_graph_parser import PyTorchGraphParser
 from .tf_graph_parser import TFGraphParser
 from .onnx_utils import OnnxDataLoader

@@ -151,7 +153,7 @@ class OnnxGraph(Graph):
             input_shape (tuple): Input shape.
         """
         input_node = InputNode(input_shape)
-        input_node_name = self._raw_input_nodes.replace(":0", "")
+        input_node_name = self._raw_input_nodes
         for node_name, node in self._nodes_collection.items():
             if node_name in self._input_nodes:
                 ipt_nd_name = input_node_name.format(input_node.scope_name)
@@ -196,7 +198,18 @@ class OnnxGraph(Graph):
"""
tf_input_nodes = kwargs.get('input_nodes')
tf_output_nodes = kwargs.get('output_nodes')
onnx_model = TFGraphParser.parse(graph_path,
input_nodes=tf_input_nodes,
output_nodes=tf_output_nodes)
if graph_path.endswith('.pb'):
onnx_model = TFGraphParser.parse(graph_path,
input_nodes=tf_input_nodes,
output_nodes=tf_output_nodes)
elif graph_path.endswith('.onnx'):
onnx = import_module('onnx')
onnx_model = onnx.load(graph_path)
optimizer = import_module(
'mindinsight.mindconverter.graph_based_converter.third_party_graph.optimizer')
onnx_simplify = getattr(optimizer, 'OnnxSimplify')()
onnx_model = onnx_simplify.run_onnx_simplify(onnx_model, kwargs['sample_shape'])

else:
onnx_model = PyTorchGraphParser.parse(graph_path, **kwargs)
return onnx_model
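Note the dispatch order: `.pb` files go through `TFGraphParser`, `.onnx` files are loaded directly and constant-folded by the new `OnnxSimplify` optimizer, and anything else is treated as a serialized PyTorch model and handed to `PyTorchGraphParser.parse`, which under this commit is expected to return an ONNX model as well.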

+19 -6  mindinsight/mindconverter/graph_based_converter/third_party_graph/onnx_utils.py

@@ -1,4 +1,4 @@
-# Copyright 2020 Huawei Technologies Co., Ltd.All Rights Reserved.
+# Copyright 2020-2021 Huawei Technologies Co., Ltd.All Rights Reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -112,10 +112,10 @@ class OnnxTensor:

def to_array(self):
"""Convert the tensor value from binary to np array."""
onnx = import_module("onnx")
numpy_helper = import_module("onnx.numpy_helper")
# Convert binary data to np.array
if not isinstance(self.raw_tensor, (np.ndarray, list, tuple, int, float)):
return onnx.numpy_helper.to_array(self.raw_tensor)
return numpy_helper.to_array(self.raw_tensor)
return self.raw_tensor
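Importing `onnx.numpy_helper` explicitly avoids relying on the submodule being reachable as an attribute of the top-level `onnx` package. For reference, the helper converts a raw `TensorProto` into a NumPy array:

    from onnx import numpy_helper
    array = numpy_helper.to_array(tensor_proto)  # tensor_proto: an onnx TensorProto initializer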


@@ -383,15 +383,24 @@ class OnnxDataLoader:
"""Parse each onnx nodes in the model."""
nodes_topo_idx = []
for idx, node in enumerate(self.nodes):
if not node.name:
node.name = "_".join(node.output)
n = OnnxNode(node)
self._nodes_dict[n.name] = n
nodes_topo_idx.append((idx, n.name))
if len(node.output) > 1:
raise ModelNotSupportError(msg=f"{node.name} has multi-outputs which is not supported now.")
self.output_name_to_node_name[node.output[0]] = node.name

for ipt_nd in node.input:
if ipt_nd not in self.output_name_to_node_name:
if self._global_context.onnx_node_inputs.get(n.name):
self._global_context.onnx_node_inputs[n.name].append(ipt_nd)
else:
self._global_context.onnx_node_inputs[n.name] = [ipt_nd]

self._global_context.onnx_node_name_to_topo_idx[n.name] = idx
node_inputs = [i.replace(":0", "") for i in node.input]
self._global_context.onnx_node_inputs[n.name] = node_inputs

self._global_context.onnx_nodes_collection = self._nodes_dict
self._global_context.onnx_nodes_topo_index = nodes_topo_idx

@@ -449,7 +458,11 @@ class OnnxDataLoader:
                     input_node = self.get_node(input_node_name)
                     node.precursor_onnx_node_dict[input_node_name] = input_node
                     input_node.successor_onnx_node_dict[node_name] = node
-                    continue
+
+                    if self._global_context.onnx_node_inputs.get(node.name):
+                        self._global_context.onnx_node_inputs[node.name].append(input_node_name)
+                    else:
+                        self._global_context.onnx_node_inputs[node.name] = [input_node_name]

     def initialize(self):
         """Initialize the OnnxDataLoader."""


+144 -0  mindinsight/mindconverter/graph_based_converter/third_party_graph/optimizer.py

@@ -0,0 +1,144 @@
# Copyright 2021 Huawei Technologies Co., Ltd.All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Define ONNX optimizer operations."""
import copy
from importlib import import_module

import numpy as np

from ..common.utils import fetch_output_from_onnx_model


class OnnxSimplify:
"""To simplify onnx model."""
def __init__(self):
self._onnx_model = None
self._constant_nodes = list()
self._outputs_infer = dict()

def run_onnx_simplify(self, onnx_model, sample_shape):
"""
Run to simplify onnx model.

Args:
onnx_model (onnx.ModelProto): Onnx Model.
sample_shape (tuple): Sample shape of input.
"""
self._onnx_model = onnx_model
self._optimizer()
self._get_constant_nodes()
self._onnx_infer(sample_shape)
self._replace_constant_nodes()
self._optimizer()

return self._onnx_model

def _optimizer(self):
"""Run optimizer from onnx to eliminate constant nodes."""

onnxoptimizer = import_module('onnxoptimizer')
optimizers_list = [
'eliminate_deadend',
'eliminate_nop_dropout',
'eliminate_nop_cast',
'eliminate_nop_monotone_argmax',
'eliminate_nop_pad',
'extract_constant_to_initializer',
'eliminate_unused_initializer',
'eliminate_nop_transpose',
'eliminate_identity',
'fuse_add_bias_into_conv',
'fuse_consecutive_concats',
'fuse_consecutive_log_softmax',
'fuse_consecutive_reduce_unsqueeze',
'fuse_consecutive_squeezes',
'fuse_consecutive_transposes',
'fuse_matmul_add_bias_into_gemm',
'fuse_pad_into_conv',
'fuse_transpose_into_gemm'
]

input_num = len(self._onnx_model.graph.input)
onnx_model_optimized = onnxoptimizer.optimize(self._onnx_model, optimizers_list, fixed_point=True)

if self._onnx_model.ir_version > 3:
del onnx_model_optimized.graph.input[input_num:]
self._onnx_model = onnx_model_optimized

def _get_constant_nodes(self):
"""Get constant nodes."""

const_nodes = list()
const_tensors = [tensor_init.name for tensor_init in self._onnx_model.graph.initializer]
const_tensors.append([node.output[0]
for node in self._onnx_model.graph.node if node.op_type == 'Constant'])

for node in self._onnx_model.graph.node:
if node.op_type == 'Shape' or all([input_node in const_tensors for input_node in node.input]):
const_nodes.append(node)
const_tensors.extend(node.output)

self._constant_nodes = copy.deepcopy(const_nodes)

def _onnx_infer(self, infer_inputs_shape):
"""
Run onnx inference to get outputs of constant nodes.

Args:
infer_inputs_shape (tuple): Input shape for running inference.
"""

input_onnx = self._onnx_model.graph.input[0]
input_onnx_name = input_onnx.name
feed_dict = {input_onnx_name: np.random.rand(*infer_inputs_shape).astype(np.float32)}

output_nodes_name = list()
for node in self._constant_nodes:
output_nodes_name.extend(node.output)

self._outputs_infer = fetch_output_from_onnx_model(self._onnx_model, feed_dict, output_nodes_name)

def _replace_constant_nodes(self):
"""Replace constant nodes to nodes with op_type 'Constant'."""

onnx = import_module('onnx')
np_helper = import_module('onnx.numpy_helper')

for i, node in enumerate(self._onnx_model.graph.node):
if node in self._constant_nodes:
for output in node.output:
new_attr = onnx.helper.make_attribute(
'value',
np_helper.from_array(self._outputs_infer[output], name=output)
)

new_node = onnx.helper.make_node(
op_type='Constant',
inputs=list(),
outputs=[output],
name='_'.join(('node', output))
)
new_node.attribute.extend([new_attr])
self._insert_node(self._onnx_model.graph.node, i + 1, new_node)
del self._onnx_model.graph.node[i]

@staticmethod
def _insert_node(repeated_container, index, node):
"""Insert node into onnx model."""

repeated_container.extend([repeated_container[-1]])
for i in reversed(range(index + 1, len(repeated_container) - 1)):
repeated_container[i].CopyFrom(repeated_container[i - 1])
repeated_container[index].CopyFrom(node)
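The simplifier follows the same recipe as the standalone onnx-simplifier tool: run the onnxoptimizer passes, collect every node whose inputs are all constants (plus `Shape` nodes), run the model once on random input to materialize those nodes' outputs, and splice the results back in as `Constant` nodes so the optimizer can fold them away. A usage sketch (the model path and shape are placeholders):

    import onnx
    model = onnx.load("model.onnx")
    simplified = OnnxSimplify().run_onnx_simplify(model, (1, 3, 224, 224))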

+0 -691  mindinsight/mindconverter/graph_based_converter/third_party_graph/pytorch_graph.py

@@ -1,691 +0,0 @@
# Copyright 2020 Huawei Technologies Co., Ltd.All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Define PyTorch graph."""
import os
import re
import warnings
from copy import deepcopy
from importlib import import_module
from typing import Dict, NoReturn

import numpy as np

from mindinsight.conf import settings
from mindinsight.mindconverter.common.log import logger as log
from .base import Graph
from .input_node import InputNode
from .pytorch_graph_node import PyTorchGraphNode
from .pytorch_graph_parser import PyTorchGraphParser
from .torch_utils import set_opset_version
from ..common.utils import fetch_output_from_onnx_model

from ..constant import SEPARATOR_IN_SCOPE, LINK_IN_SCOPE, SEPARATOR_BTW_NAME_AND_ID, SCALAR_WITHOUT_SHAPE, \
MIN_SCOPE_LENGTH, SEPARATOR_TITLE_AND_CONTENT_IN_CONSTRUCT, ONNX_OPSET_VERSION, MODEL_INPUT_NAME
from ..constant import LEFT_BUCKET, RIGHT_BUCKET
from ...common.exceptions import ModelNotSupportError

NONE_SCOPE_OP = {
"onnx::Add": "Add",
"onnx::Flatten": "Flatten",
"onnx::Concat": "Concat",
"onnx::Squeeze": "Squeeze",
"onnx::Unsqueeze": "Unsqueeze",
"onnx::Split": "Split",
"onnx::Reshape": "Reshape",
"onnx::Transpose": "Transpose",
"onnx::Constant": "Constant",
"onnx::ReduceMean": "ReduceMean",
"onnx::Resize": "Resize",
"onnx::Pad": "Pad"
}

CONSTANT_NODES_PATTERN = {
"onnx::Resize": [
'onnx::Concat',
'onnx::Slice',
'onnx::Cast',
'onnx::Concat',
'onnx::Unsqueeze',
'onnx::Floor',
'onnx::Mul',
'onnx::Cast',
'onnx::Gather',
'onnx::Shape'
],
"onnx::Pad": [
'onnx::Cast',
'onnx::Concat',
'onnx::ConstantOfShape',
'onnx::Sub',
'onnx::Mul',
'onnx::Div',
'onnx::Gather',
'onnx::Shape',
'onnx::Unsqueeze',
'onnx::Slice',
'onnx::Reshape',
'onnx::Transpose'
],
"onnx::Constant": list()
}


def normalize_scope_name(node, scope_name_dict):
"""
Rename scope name into uniform.

Args:
node (Node): PyTorch node.
scope_name_dict (dict): Dictionary of scope names with the key node_id.

Returns:
str, normalized scope name.
"""
global NONE_SCOPE_OP

scope_name = node.scopeName()
if not scope_name:
name = [retrieve_scope_name(node, scope_name_dict)]
else:
name = scope_name.replace(SEPARATOR_BTW_NAME_AND_ID, '').split(SEPARATOR_IN_SCOPE)
scopes = []
for segment in name:
segment = segment.split(LINK_IN_SCOPE)[0]
left = segment.find(LEFT_BUCKET)
right = segment.find(RIGHT_BUCKET)
if left != -1:
if segment[left + 1: right].isdigit():
scopes.append(f"{segment[:left]}_{segment[left + 1: right]}")
else:
scopes.append(segment[left + 1: right])
else:
scopes.append(segment)
if node.kind() in NONE_SCOPE_OP.keys():
scopes.append(NONE_SCOPE_OP[node.kind()])
scopes = [s for s in scopes if s]
node_id = PyTorchGraph.get_node_id(node)
return f"{SEPARATOR_IN_SCOPE.join(scopes)}_{'&'.join(node_id)}"


def retrieve_scope_name(node, scope_name_dict):
"""
Retrieve scope name from input nodes.

Args:
node (Node): PyTorch node.
scope_name_dict (dict): Dictionary of scope names with the key node_id.

    Returns:
        str, scope name.
"""
node_content = \
SEPARATOR_TITLE_AND_CONTENT_IN_CONSTRUCT.join(str(node).split(SEPARATOR_TITLE_AND_CONTENT_IN_CONSTRUCT)[1:])
node_inputs = re.findall(r"[(](.*?)[)]", node_content)[0]
node_inputs = re.sub(r"[\s%]", '', node_inputs).split(",")

scope_name_ipt_nodes = list()
for node_input in node_inputs:
if not scope_name_dict.get(node_input, None):
continue
scope_name_ipt_nodes.append(scope_name_dict[node_input])

scope_name_split = list()
for idx, _ in enumerate(scope_name_ipt_nodes):
if not scope_name_split:
scope_name_split = scope_name_ipt_nodes[idx]
else:
scope_name_split = [
sub_scope_name
for sub_scope_name in scope_name_split if sub_scope_name in scope_name_ipt_nodes[idx]
]
scope_name = SEPARATOR_IN_SCOPE.join(scope_name_split)
return scope_name


class PyTorchGraph(Graph):
"""
Define PyTorch graph.

Args:
model (Module): PyTorch model.
sample_shape (tuple): Input shape of the model.

"""

def __init__(self, model, sample_shape: tuple):
super(PyTorchGraph, self).__init__(model=model)

from .torch_utils import unique_state_dict

self._params_dict = unique_state_dict(model)
self._original_shape = list()
self._nodes = list()
self._constant_nodes = list()
self._dynamic_nodes = list()
self._has_eliminated_nodes = False
self._file_graph_onnx = os.path.join(
settings.WORKSPACE, 'log/mindconverter/'
)

self.build(sample_shape)

@staticmethod
def _check_input_shape(input_shape):
"""
Check input shape.

Args:
input_shape (tuple): Input tensor shape.

"""
if not input_shape:
err_msg = "`input_shape` can not be None."
log.error(err_msg)
raise ValueError(err_msg)

for item in input_shape:
if not isinstance(item, int):
err_msg = "Only support model with one input now, " \
"and each shape value in `input_shape` should be int."
log.error(err_msg)
raise ValueError(err_msg)

@staticmethod
def _extract_shape(shape):
"""
Extract shape from string-type shape.

Args:
shape (str): Shape value in string-type.

Returns:
list, shape.
"""
if "," not in shape:
return []

shape_arr = []
for s in shape.split(","):
s = s.strip()
if not s:
return []
if ":" in s:
s = s.split(":")[0]
s = s.replace("!", "")
if not s.isdigit():
return []
shape_arr.append(int(s))
return shape_arr
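
    # Illustrative inputs (hypothetical shape strings taken from str(node)):
    #   "1, 3, 224, 224"   -> [1, 3, 224, 224]
    #   "1, 64!, 56:1, 56" -> [1, 64, 56, 56]  ("!" markers, ":" suffixes dropped)
    #   "Float"            -> []               (no comma: not a tensor shape)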

def _trace_torch_graph(self, input_shape):
"""
Trace torch computational graph.

Args:
input_shape (tuple): Shape.

Returns:
            object, PyTorch graph.
"""
import torch
from torch.onnx import OperatorExportTypes
from .torch_utils import OverloadTorchModuleTemporarily
from .torch_utils import create_autograd_variable
from .torch_utils import onnx_tracer

warnings.simplefilter("ignore")

batched_sample = create_autograd_variable(torch.rand(*input_shape))

try:
try:
# Assign execution mode to eval.
self.model.eval()

with OverloadTorchModuleTemporarily() as _:
                # In higher PyTorch versions, the trace function has a known issue.
graph = onnx_tracer(self.model, batched_sample,
OperatorExportTypes.ONNX)
return graph

except RuntimeError:
# Assign execution mode to eval.
self.model.eval()

with OverloadTorchModuleTemporarily() as _:
                # In higher PyTorch versions, the trace function has a known issue.
set_opset_version(ONNX_OPSET_VERSION)
graph = onnx_tracer(self.model, batched_sample,
OperatorExportTypes.ONNX)
return graph

except RuntimeError as error:
log.error(str(error))
log.exception(error)
raise error

def build(self, input_shape):
"""
Build graph tree.

Args:
input_shape (tuple): Input shape of model.

"""
self._check_input_shape(input_shape)
self._original_shape = input_shape
feed_forward_ipt_shape = tuple(input_shape)
graph = self._trace_torch_graph(feed_forward_ipt_shape)
nodes = list(graph.nodes())
self._nodes = nodes

scope_name_dict = dict()

self._constant_nodes, self._dynamic_nodes = self._get_constant_nodes(nodes)

for node in nodes:
output_name = ', '.join(list(self._extract_node_name(output) for output in node.outputs()))
if output_name in self._dynamic_nodes:
continue

node_name = normalize_scope_name(node, scope_name_dict)
scope_name_dict[node_name.split(SEPARATOR_BTW_NAME_AND_ID)[-1]] \
= list(node_name.split(SEPARATOR_BTW_NAME_AND_ID)[0].split(SEPARATOR_IN_SCOPE))
output_shape_str_list = re.findall(r'[^()!]+', str(node))
output_shape_str = output_shape_str_list[1]
output_shape = self._extract_shape(output_shape_str)
weight_scope = '.'.join(
re.findall(r'\[([\w\d.]+)]', node.scopeName())
)

if self._constant_nodes:
node_weight = self._replace_constant_node(node)
else:
node_weight = {}

for scope, weight in self._params_dict.items():
split_scope = scope.split('.')
if '.'.join(split_scope[:-1]) == weight_scope:
node_weight[split_scope[-1]] = weight

if not node_weight and node.kind() == 'onnx::Conv':
weight_names = list(self._params_dict.keys())
node_input_names = [self._extract_input_name(node_input) for node_input in node.inputs()]
for node_input_name in node_input_names:
if int(node_input_name) > len(weight_names):
continue
weight = self._params_dict[weight_names[int(node_input_name) - 1]]
node_weight[weight_names[int(node_input_name) - 1]] = weight

self._shape_dict[node_name] = output_shape
self._nodes_collection[node_name] = PyTorchGraphNode(node, node_weight)
self._nodes_record[node_name] = node_name

for node_input in list(node.inputs()):
if self._extract_input_name(node_input) in self._constant_nodes:
continue
# Connect input node and src node.
nd_id = PyTorchGraph.get_node_id(node_input.node())
nd_scope_name = node_input.node().kind() in NONE_SCOPE_OP or \
node_input.node().scopeName()

if nd_id and nd_scope_name:
node_input_name = normalize_scope_name(
node_input.node(), scope_name_dict
)
self.build_connection(node_input_name, node_name)

self._unmerge_multi_ipt_opt_script()

super(PyTorchGraph, self).build(input_shape=input_shape)
self._collect_ipt_shape_of_each_node(feed_forward_ipt_shape)

@staticmethod
def _extract_node_name(node):
"""Extract node name for node."""
result = re.match(r"\d+", str(node))
if result:
return result.group(0)
return None

@staticmethod
def _extract_input_name(node_input):
"""Extract node input name from node input."""
node_input_name = str(node_input).split('defined in')[0].strip()
return node_input_name

def _get_constant_nodes(self, nodes):
"""
Get constant nodes to be eliminated.

Args:
nodes (Nodes): Nodes in torch._C.Graph.

Returns:
            tuple(dict, list), outputs of constant input nodes and names of dynamic nodes.
"""
constant_input_nodes = list()
dynamic_nodes = list()
for node in nodes:
if node.kind() == 'onnx::Resize':
self._has_eliminated_nodes = True
constant_input_node, dynamic_node = self._generate_inputs_of(node)
constant_input_nodes += constant_input_node
dynamic_nodes += dynamic_node

outputs = dict()
if self._has_eliminated_nodes:
torch = import_module('torch')
device_target = 'cuda' if torch.cuda.is_available() else 'cpu'
dump_input = torch.randn(*self._original_shape, device=device_target)
temp_onnx_path = os.path.realpath(os.path.join(self._file_graph_onnx,
'.graph_onnx.onnx'))

symbolic_helper = import_module('torch.onnx.symbolic_helper')
export_onnx_opset_version = getattr(symbolic_helper, '_export_onnx_opset_version')
try:
torch.onnx.export(self.model.to(device_target), dump_input,
temp_onnx_path, opset_version=export_onnx_opset_version)

outputs = self._onnx_infer(temp_onnx_path, constant_input_nodes, self._original_shape)
finally:
if os.path.exists(temp_onnx_path):
os.remove(temp_onnx_path)

return outputs, dynamic_nodes

def _generate_inputs_of(self, node):
"""
        Collect constant inputs and dynamic nodes of a certain node.

Args:
node (Node): Node of torch._C.Graph.

"""
pattern_op_lst = CONSTANT_NODES_PATTERN.get(node.kind(), None)
constant_input_nodes = list()
dynamic_nodes = list()
if not isinstance(pattern_op_lst, list):
return constant_input_nodes, dynamic_nodes
if not pattern_op_lst:
dynamic_nodes += self.get_node_id(node)
return constant_input_nodes, dynamic_nodes

node_inputs_name = [self._extract_input_name(node_input) for node_input in node.inputs()]

for node_input_name in node_inputs_name:
node_name_path = self._search_node_path(node_input_name, pattern_op_lst)
if node_name_path and self._get_node_from_graph(node_name_path[-1]).kind() == 'onnx::Shape':
constant_input_nodes.append(node_input_name)
dynamic_nodes += node_name_path

return constant_input_nodes, dynamic_nodes

def _search_node_path(self, node_name, pattern_op_lst):
"""
        Search node path based on pattern_op_lst.

Args:
node_name (str): Node name.
pattern_op_lst (list): Pattern list of certain operator.

Returns:
            list[str], node names in the pattern.
"""
node_type_lst = list()
node_name_lst = list()
node = self._get_node_from_graph(node_name)

if node_name == MODEL_INPUT_NAME:
return node_name_lst

if node.kind() not in pattern_op_lst:
return node_name_lst

node_type_lst.append(node.kind())
node_name_lst.append(node_name)

node_inputs_name = [self._extract_input_name(node_input) for node_input in node.inputs()]
for node_input_name in node_inputs_name:
node_name_lst += self._search_node_path(node_input_name, pattern_op_lst)

return node_name_lst
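
    # An illustrative, standalone mirror of the backward pattern walk above,
    # using plain dicts instead of torch._C nodes (hypothetical ids and edges):
    #
    #     def search_path(name, kinds, inputs_of, kind_of):
    #         if kind_of.get(name) not in kinds:
    #             return []
    #         path = [name]
    #         for pred in inputs_of.get(name, ()):
    #             path += search_path(pred, kinds, inputs_of, kind_of)
    #         return path
    #
    #     inputs_of = {"5": ["4"], "4": ["3"], "3": []}
    #     kind_of = {"5": "onnx::Cast", "4": "onnx::Gather", "3": "onnx::Shape"}
    #     search_path("5", set(kind_of.values()), inputs_of, kind_of)
    #     # -> ['5', '4', '3']; the path ends on 'onnx::Shape', so input "5"
    #     # would be treated as a constant input by _generate_inputs_of.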

def _get_node_from_graph(self, node_name):
"""Get torch._C.Node from torch._C.Graph."""
for idx, node in enumerate(self._nodes):
node_id = ', '.join(self.get_node_id(node))
if node_id == node_name:
return self._nodes[idx]
return None

@staticmethod
def _onnx_infer(file_graph_onnx, infer_outputs, infer_inputs_shape):
"""
Infer onnx model to get outputs of inner nodes.

Args:
            file_graph_onnx (str): File path of the ONNX model.
            infer_outputs (list): Node outputs to fetch during inference.
            infer_inputs_shape (list): Input shape for inference.

"""
onnx = import_module('onnx')
tensor_proto = getattr(onnx, 'TensorProto')
onnx_model = onnx.load(file_graph_onnx)

for onnx_node in onnx_model.graph.node:
if set(onnx_node.output).issubset(set(infer_outputs)):
onnx_node.name = ', '.join([f"{output_name}" for output_name in onnx_node.output])

input_onnx = onnx_model.graph.input[0]
node_type = tensor_proto.DataType.Name(input_onnx.type.tensor_type.elem_type)
if node_type != 'FLOAT':
raise ModelNotSupportError(f"Input type should be FLOAT32, but got {node_type}. "
f"Please report issue to us if extra input type is needed.")

input_onnx_name = input_onnx.name
feed_dict = {input_onnx_name: np.random.rand(*infer_inputs_shape).astype(np.float32)}
outputs = fetch_output_from_onnx_model(onnx_model, feed_dict, infer_outputs)

return outputs

def _replace_constant_node(self, node):
"""Replace constant node."""
node_weight = dict()
for node_input in list(node.inputs()):
node_input_name = self._extract_input_name(node_input)
if node_input_name in self._constant_nodes:
node_weight[node_input_name] = self._constant_nodes[node_input_name]
return node_weight

def _collect_ipt_shape_of_each_node(self, input_shape):
"""
Collect input tensor shape of each node.

Args:
input_shape (tuple): Input shape.

"""
input_node = InputNode(input_shape)
input_node_name = "{}InputNode"
for node_name, node in self._nodes_collection.items():
if node_name in self._input_nodes:
ipt_nd_name = input_node_name.format(input_node.scope_name)
input_node.set_scope_name(node.scope_name)
node.precursor_nodes.insert(0, ipt_nd_name)
input_node.set_successor_nodes(node_name)
self._shape_dict[ipt_nd_name] = input_node.output_shape

if not self._shape_dict[node_name]:
self._shape_dict[node_name] = SCALAR_WITHOUT_SHAPE

ipt_shape = []
for p_nd in node.precursor_nodes:
shp = self._shape_dict.get(p_nd)
ipt_shape.append(tuple(shp) if isinstance(shp, list) else shp)

self._input_shape[node_name] = ipt_shape[0] if len(ipt_shape) == 1 else ipt_shape

def _generate_module(self):
"""Generate modules."""
module_dict = dict()
for node_key, _ in self._nodes_collection.items():
node_key_in_scope = node_key.split(SEPARATOR_IN_SCOPE)
if len(node_key_in_scope) < MIN_SCOPE_LENGTH:
continue

for idx in range(1, len(node_key_in_scope)):
node_key_module = SEPARATOR_IN_SCOPE.join(node_key_in_scope[:idx])
node_name = SEPARATOR_IN_SCOPE.join(node_key_in_scope[:idx+1])
if not module_dict.get(node_key_module, None):
module_dict[node_key_module] = {node_name}
else:
module_dict[node_key_module].add(node_name)

return module_dict
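
    # Illustrative result (hypothetical keys; assumes SEPARATOR_IN_SCOPE is "/"):
    #   node keys "ResNet/layer1/conv_3" and "ResNet/layer1/relu_4" produce
    #   module_dict == {
    #       "ResNet": {"ResNet/layer1"},
    #       "ResNet/layer1": {"ResNet/layer1/conv_3", "ResNet/layer1/relu_4"},
    #   }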

def _check_multi_ipt_opt(self):
"""Check whether multi-input exists."""
module_dict = self._generate_module()
for _, nodes_per_module in module_dict.items():
prcs_nodes_out_from_module = set()
for node_name in nodes_per_module:
if re.search(r"[\d]+[&][\d]+", node_name):
self._is_multi_opt_graph = True
return True

node = self._nodes_collection.get(node_name, None)
if node:
prcs_nodes = node.precursor_nodes
else:
continue

for prcs_node in prcs_nodes:
if prcs_node not in nodes_per_module:
prcs_node_module = SEPARATOR_IN_SCOPE.join(prcs_node.split(SEPARATOR_IN_SCOPE)[:-1])
if prcs_node_module not in nodes_per_module:
prcs_nodes_out_from_module.add(prcs_node)

if len(prcs_nodes_out_from_module) > 1:
return True

return False

def _unmerge_multi_ipt_opt_script(self):
"""Unmerge all submodule."""
if self._check_multi_ipt_opt() or self._has_eliminated_nodes:
for node_key, node_inst in deepcopy(self._nodes_collection).items():
prsc_nodes = node_inst.precursor_nodes
scsr_nodes = node_inst.successor_nodes

node_inst.is_in_multi_opt_graph = self._is_multi_opt_graph

node_inst.precursor_nodes = [SEPARATOR_IN_SCOPE.join((prsc_node.split(SEPARATOR_IN_SCOPE)[0],
prsc_node.split(SEPARATOR_IN_SCOPE)[-1]))
for prsc_node in deepcopy(prsc_nodes)]
node_inst.successor_nodes = [SEPARATOR_IN_SCOPE.join((scsr_node.split(SEPARATOR_IN_SCOPE)[0],
scsr_node.split(SEPARATOR_IN_SCOPE)[-1]))
for scsr_node in deepcopy(scsr_nodes)]

reduce_node_key = SEPARATOR_IN_SCOPE.join((node_key.split(SEPARATOR_IN_SCOPE)[0],
node_key.split(SEPARATOR_IN_SCOPE)[-1]))

del self._nodes_collection[node_key]
self._nodes_collection[reduce_node_key] = node_inst

for node_key, shape in deepcopy(self._shape_dict).items():
reduce_node_key = SEPARATOR_IN_SCOPE.join((node_key.split(SEPARATOR_IN_SCOPE)[0],
node_key.split(SEPARATOR_IN_SCOPE)[-1]))

del self._shape_dict[node_key]
self._shape_dict[reduce_node_key] = shape

def sub_graph_merging(self):
"""
        Merge split operations into one.
"""
raise NotImplementedError()

def build_connection(self, src, tgt) -> NoReturn:
"""
Build connection between source node and target node.

Args:
src (str): Source node name.
tgt (str): Target node name.

"""
        # Skip this edge if src and tgt are the same node, or if either
        # of them is not in the node collection.
if src == tgt or src not in self._nodes_collection or tgt not in self._nodes_collection:
if src.split(':')[0] not in self._nodes_collection:
log.warning("Graph construct a self-loop node %s. Ignored.", src)
return
if tgt not in self._nodes_collection[src.split(':')[0]].successor_nodes:
self._nodes_collection[src.split(':')[0]].successor_nodes.append(tgt)
        if src not in self._nodes_collection[tgt].precursor_nodes:
            # tgt was validated above; append to the same key that was checked.
            self._nodes_collection[tgt].precursor_nodes.append(src)

@staticmethod
def load_checkpoint(ckpt_path: str) -> Dict:
"""
Load checkpoint.

Args:
ckpt_path (str): Checkpoint file path.

Returns:
dict, weights in model.
"""

@staticmethod
def load_metadata(**kwargs):
"""
Load graph metadata.
"""
err_msg = "class `PyTorchGraph` has not implemented " \
"`load_metadata()`."
log.error(err_msg)
raise NotImplementedError(err_msg)

@staticmethod
def load_graph(graph_path: str, **kwargs):
"""
Load graph.

Args:
graph_path (str): Graph path.

Returns:
            object, PyTorch model.
"""
torch_model = PyTorchGraphParser.parse(graph_path)
return torch_model

@staticmethod
def get_node_id(node):
"""
Get node id using regular expr.

Args:
node (Node): PyTorch node.

Returns:
str, node id.
"""
node_title = str(node).split(SEPARATOR_TITLE_AND_CONTENT_IN_CONSTRUCT)[0]
node_id = re.findall(r"[%](.*?) [:]", node_title)
return node_id
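
    # Illustrative run of the node-id extraction (hypothetical node title, as
    # printed by torch._C for a traced node):
    #
    #     import re
    #     node_title = "%23 : Float(1, 64, 56, 56)"
    #     re.findall(r"[%](.*?) [:]", node_title)  # -> ['23']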

+ 0
- 236
mindinsight/mindconverter/graph_based_converter/third_party_graph/pytorch_graph_node.py View File

@@ -1,236 +0,0 @@
# Copyright 2020 Huawei Technologies Co., Ltd.All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Define PyTorch graph node."""
import re

from .base import GraphNode
from ..common.utils import is_converted

from ..constant import NodeType, SEPARATOR_IN_SCOPE, SEPARATOR_BTW_NAME_AND_ID, LEFT_BUCKET, RIGHT_BUCKET, \
SEPARATOR_IN_ONNX_OP, SEPARATOR_TITLE_AND_CONTENT_IN_CONSTRUCT


class PyTorchGraphNode(GraphNode):
"""
PyTorch graph node.

Args:
node (torch._C.Node): Node in raw PyTorch graph.

"""

_type_frozen = False
_module_name_frozen = False

def __init__(self, node=None, weight=None):
super(PyTorchGraphNode, self).__init__(node=node)
self._op_params = self._get_raw_params(node)
self._op_name = node.kind() if node else None
self._scope_name = node.scopeName() if node else None
self._weight = weight
self._ipt_var_names, self._opt_var_names \
= self._extract_ipt_opt_var_names() if node else (list(), list())

def _extract_ipt_opt_var_names(self):
"""Extract ipt and opt var names."""
node_content = SEPARATOR_TITLE_AND_CONTENT_IN_CONSTRUCT.join(
str(self._src_node).split(SEPARATOR_TITLE_AND_CONTENT_IN_CONSTRUCT)[1:]
)
node_inputs = re.findall(r"[(](.*?)[)]", node_content)[0]
node_inputs = re.sub(r"[\s%]", '', node_inputs).split(",")
node_title = str(self._src_node).split(SEPARATOR_TITLE_AND_CONTENT_IN_CONSTRUCT)[0]
node_outputs = re.findall(r"[%](.*?) [:]", node_title)
return node_inputs, node_outputs

def clear_args_of_declaration(self):
"""
Clear `self._args_in_code`.
"""
self._args_in_code = dict()

def _get_arg_name(self, arg, variable_name):
"""
Get arg name.

        Args:
            arg (str): Arg name.
            variable_name (str): Variable name.

        Returns:
            str, arg name in function or class declaration.
        """
return f"{arg}_{variable_name}"

@property
def is_in_multi_opt_graph(self):
return self._is_in_multi_opt_graph

@is_in_multi_opt_graph.setter
def is_in_multi_opt_graph(self, multi_opt_state):
self._is_in_multi_opt_graph = multi_opt_state

@property
def hash_key(self):
"""
Return unique hash key of current node.

Returns:
str, hash key.
"""
if self._node_type not in {NodeType.CLASS.value,
NodeType.FUNC.value,
NodeType.MODULE.value}:
self._hash_key = self._op_name.lower()
return self._hash_key

@hash_key.setter
def hash_key(self, h):
"""
Setter of hash key.

Args:
h (str): Key.

"""
self._hash_key = h

@property
def op_name(self):
"""
Op name in torch.

Returns:
str, op name.
"""
return self._op_name

@op_name.setter
def op_name(self, name):
"""
Setter of op name.

Args:
            name (str): Op name.

"""
self._op_name = name

@property
def real_name(self):
return

def add_input_and_output_shape(self, input_shape, output_shape):
"""
        Add the node input and output shapes.

        Args:
            input_shape (tuple): Input tensor shape.
            output_shape (tuple): Output tensor shape.

"""
self._ipt_shape = input_shape
self._opt_shape = output_shape

def to_code(self, ipt_args_in_construct: str, variable_name: str, output_var: list, code_fragment):
"""
Generate statements.

Args:
            ipt_args_in_construct (str): Input args in the construct method.
            variable_name (str): Variable name.
output_var (list): Output variable names in construct.
code_fragment (CodeFragment): CodeFragment instance.

Returns:
Union[str, str], declare in init and call in construct.
"""
operator = code_fragment.operation

args = self.args_in_code
settings = code_fragment.code_setting

if self._node_type == NodeType.OPERATION.value and not is_converted(code_fragment.operation):
args.update({"input_shape": self.input_shape,
"output_shape": self.output_shape})

if self._node_type == NodeType.OPERATION.value:
expr = ", ".join([f"{k.replace(f'_{variable_name}', '')}={v}"
for k, v in args.items()])
ipt_args_settings_in_construct = self._generate_ipt_args_settings_in_construct(
ipt_args_in_construct, settings)
else:
            # When its type is module, class or func,
            # it is not necessary to replace variables.
expr = ", ".join([f"{k.replace(f'_{variable_name}', '')}={v}"
for k, v in args.items()])
ipt_args_settings_in_construct = ipt_args_in_construct

if SEPARATOR_IN_ONNX_OP in operator:
operator = operator.replace(SEPARATOR_IN_ONNX_OP, ".")

declare = f"self.{variable_name} = {operator}({expr})"
call = f"{', '.join([output for output in output_var])}" \
f" = self.{variable_name}({ipt_args_settings_in_construct})"

return declare, call
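
    # Illustrative output (hypothetical Conv node, variable_name="conv1"):
    #   declare: "self.conv1 = nn.Conv2d(in_channels=3, out_channels=64)"
    #   call:    "opt_conv1 = self.conv1(x)"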

def to_ir(self):
"""
No need to implement for now.
"""
raise NotImplementedError

def _get_raw_params(self, node):
"""
        Get raw params of the ONNX-style node.

Args:
node (Any): Node.

Returns:
dict, raw params.
"""
from .torch_utils import getitem_of_node

raw_params = dict()

if not node:
return raw_params

for k in node.attributeNames():
raw_params[k] = getitem_of_node(node, k)
return raw_params

def replace_with_arg(self, src_arg, tgt_arg):
"""
Replace actual parameter with formal parameter.

Args:
src_arg (str): Original arg name.
tgt_arg (str): Target arg name.

"""
self._args_in_code[src_arg] = tgt_arg

@staticmethod
def _extract_var_name(scope_name: str):
"""
Extract variable name from scope name.
"""
if not scope_name:
return None
var = scope_name.split(SEPARATOR_IN_SCOPE)[-1].lower()
var = var.replace(LEFT_BUCKET, SEPARATOR_BTW_NAME_AND_ID).replace(
RIGHT_BUCKET, "")
return var

+ 59
- 7
mindinsight/mindconverter/graph_based_converter/third_party_graph/pytorch_graph_parser.py View File

@@ -1,4 +1,4 @@
# Copyright 2020 Huawei Technologies Co., Ltd.All Rights Reserved.
# Copyright 2020-2021 Huawei Technologies Co., Ltd.All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -18,6 +18,7 @@ from importlib import import_module

from mindinsight.mindconverter.common.log import logger as log
from .base import GraphParser
from .optimizer import OnnxSimplify
from ...common.exceptions import ModelNotSupportError


@@ -38,7 +39,6 @@ class PyTorchGraphParser(GraphParser):
Returns:
object, torch model.
"""
torch = import_module("torch")

if not os.path.exists(model_path):
error = FileNotFoundError("`model_path` must be assigned with "
@@ -47,14 +47,66 @@ class PyTorchGraphParser(GraphParser):
raise error

try:
if torch.cuda.is_available():
model = torch.load(f=model_path)
else:
model = torch.load(f=model_path, map_location="cpu")
onnx_model_sim = cls._convert_pytorch_graph_to_onnx(
model_path, kwargs['sample_shape'], opset_version=11)
return onnx_model_sim
except ModuleNotFoundError:
error_msg = "Cannot find model scripts in system path, " \
"set `--project_path` to the path of model scripts folder correctly."
error = ModuleNotFoundError(error_msg)
raise error

return model
@staticmethod
def _convert_pytorch_graph_to_onnx(model_path, sample_shape, opset_version=None):
"""
        Convert a PyTorch model to an ONNX model.

Args:
            model_path (str): Path to the PyTorch model.
            sample_shape (tuple): Input shape to generate the ONNX model.
            opset_version (int): Opset version of ONNX.

        Returns:
            onnx.ModelProto, simplified ONNX model.
        """

torch = import_module('torch')
has_cuda = torch.cuda.is_available()
if has_cuda:
model = torch.load(f=model_path).cuda()
dump_input = torch.randn(*sample_shape, device='cuda')
else:
model = torch.load(f=model_path, map_location="cpu")
dump_input = torch.randn(*sample_shape, device='cpu')

if isinstance(model, torch.nn.DataParallel):
raise ValueError('torch.nn.DataParallel is not supported by ONNX exporter.')

torch_onnx = import_module('torch.onnx')
operator_export_types = getattr(torch_onnx, 'OperatorExportTypes')
utils = import_module('torch.onnx.utils')
model_to_graph = getattr(utils, '_model_to_graph')

symbolic_helper = import_module('torch.onnx.symbolic_helper')
default_onnx_opset_version = getattr(symbolic_helper, '_default_onnx_opset_version')
set_opset_version = getattr(symbolic_helper, '_set_opset_version')
set_operator_export_type = getattr(symbolic_helper, '_set_operator_export_type')
if not opset_version:
opset_version = default_onnx_opset_version

operator_export_type = operator_export_types.ONNX
set_opset_version(opset_version)
set_operator_export_type(operator_export_type)

graph, params_dict, _ = model_to_graph(model, dump_input, _retain_param_name=True)
export_onnx = getattr(graph, '_export_onnx')
proto, _ = export_onnx(
params_dict, opset_version, dict(), False,
operator_export_type, True, False, dict(),
True, False)

onnx = import_module('onnx')
onnx_model = onnx.load_model_from_string(proto)

onnx_simplify = OnnxSimplify()
onnx_model_sim = onnx_simplify.run_onnx_simplify(onnx_model, sample_shape)

return onnx_model_sim
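
    # Illustrative call (hypothetical path and shape). With this change,
    # parse() returns a simplified ONNX model instead of a torch.nn.Module:
    #
    #     onnx_model = PyTorchGraphParser.parse(
    #         "/path/to/model.pth", sample_shape=(1, 3, 224, 224))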

+ 0
- 105
mindinsight/mindconverter/graph_based_converter/third_party_graph/torch_utils.py View File

@@ -1,105 +0,0 @@
# Copyright 2020 Huawei Technologies Co., Ltd.All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Define pytorch tracer context manager."""
import importlib

from torch.nn import Module
from torch.onnx.utils import _trace
from torch.onnx.utils import _node_getitem
from torch.onnx.symbolic_helper import _set_opset_version


SCRIPT_METHOD = getattr(importlib.import_module("torch._C"),
"ScriptMethod")
onnx_tracer = _trace
getitem_of_node = _node_getitem
set_opset_version = _set_opset_version


def unique_state_dict(model):
"""
Wrapper of torch.jit._unique_state_dict.

Args:
model (Module): Torch model.

Returns:
dict, params.
"""
from torch.jit import _unique_state_dict

return _unique_state_dict(model)


def create_autograd_variable(tensor):
"""
Create autograd variable to trace the whole graph.

Args:
tensor (torch.Tensor): Tensor.

Returns:
torch.autograd.Variable, variable.
"""
variable = getattr(importlib.import_module("torch.autograd"), "Variable")
return variable(tensor, requires_grad=False)


class OverloadTorchModuleTemporarily:
"""
    Fix a tracing bug in newer versions of PyTorch.
    This is the official PyTorch workaround.
"""

def __init__(self):
self.backup = None

def __enter__(self):
def _tracing_name(traced_module, tracing_state):
traced_module_stack = getattr(tracing_state, "_traced_module_stack")
if not traced_module_stack:
return None
module = traced_module_stack[-1]
for name, child in module.named_children():
if child is traced_module:
return name
return None

def _slow_forward(self_, *inputs, **kwargs):
tracing_state = getattr(importlib.import_module("torch._C"),
"_get_tracing_state")()
if not tracing_state or isinstance(self_.forward, SCRIPT_METHOD):
return self_.forward(*inputs, **kwargs)
if not hasattr(tracing_state, '_traced_module_stack'):
tracing_state._traced_module_stack = []
name = _tracing_name(self_, tracing_state)
get_name_func = getattr(self_, "_get_name")
if name:
tracing_state.push_scope('%s[%s]' % (get_name_func(), name))
else:
tracing_state.push_scope(get_name_func())
tracing_state._traced_module_stack.append(self_)
try:
result = self_.forward(*inputs, **kwargs)
finally:
tracing_state.pop_scope()
tracing_state._traced_module_stack.pop()
return result

self.backup = getattr(Module, "_slow_forward")
setattr(Module, '_slow_forward', _slow_forward)

def __exit__(self, exc_type, exc_val, exc_tb):
setattr(Module, '_slow_forward', self.backup)
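

# An illustrative usage of the context manager with the aliases defined above:
# patch Module._slow_forward only while tracing, so scope names are still
# recorded in newer PyTorch versions. ToyNet is a hypothetical stand-in for a
# user model; onnx_tracer/create_autograd_variable are defined in this file.
import torch
from torch.onnx import OperatorExportTypes


class ToyNet(Module):
    """Hypothetical stand-in for a user model."""

    def forward(self, x):
        return x.relu()


sample = create_autograd_variable(torch.rand(1, 3, 4, 4))
with OverloadTorchModuleTemporarily():
    graph = onnx_tracer(ToyNet(), sample, OperatorExportTypes.ONNX)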

tests/ut/mindconverter/graph_based_converter/hierarchical_tree/test_name_mgr.py → tests/ut/mindconverter/graph_based_converter/common/test_name_mgr.py View File

@@ -1,4 +1,4 @@
# Copyright 2020 Huawei Technologies Co., Ltd.All Rights Reserved.
# Copyright 2020-2021 Huawei Technologies Co., Ltd.All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -14,7 +14,7 @@
# ==============================================================================
"""Test name manager module."""
from unittest import TestCase
from mindinsight.mindconverter.graph_based_converter.hierarchical_tree.name_mgr import NameMgr, GlobalVarNameMgr, \
from mindinsight.mindconverter.graph_based_converter.common.name_mgr import NameMgr, GlobalVarNameMgr, \
global_op_namespace



+ 0
- 15
tests/ut/mindconverter/graph_based_converter/hierarchical_tree/__init__.py View File

@@ -1,15 +0,0 @@
# Copyright 2020 Huawei Technologies Co., Ltd.All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Unit test for mindconvert.graph_based_converter.hierarchical_tree interface."""

+ 0
- 177
tests/ut/mindconverter/graph_based_converter/hierarchical_tree/test_hierarchical_tree.py View File

@@ -1,177 +0,0 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Test hierarchical tree module."""
import os
import shutil
from unittest import mock

import pytest

from mindinsight.mindconverter.graph_based_converter.hierarchical_tree.hierarchical_tree import HierarchicalTree
from mindinsight.mindconverter.graph_based_converter.third_party_graph.pytorch_graph_node import PyTorchGraphNode
from mindinsight.mindconverter.graph_based_converter.mapper.base import ONNXToMindSporeMapper
from mindinsight.mindconverter.graph_based_converter.constant import NodeType

from tests.ut.mindconverter.graph_based_converter.conftest import TEST_BASE_PATH


class TestHierarchicalTree:
"""Test the class of HierarchicalTree."""

def test_tree_identifier(self):
"""Test tree_identifier"""
tree = HierarchicalTree()
assert isinstance(tree.tree_identifier, str)

@mock.patch(
'.'.join((TEST_BASE_PATH, 'third_party_graph.pytorch_graph_node.PyTorchGraphNode._get_raw_params')))
def test_insert(self, get_raw_params):
"""Test insert"""
get_raw_params.return_value = []
tree = HierarchicalTree()
pt_node = PyTorchGraphNode()
tree.insert(pt_node, 'ResNet')
assert tree.root == 'ResNet'

def test_remove(self):
"""Test remove function."""
tree = HierarchicalTree()
tree.create_node(
tag='node_root',
identifier='root',
parent=None,
data=None
)
node = tree.get_node('root')
tree.remove(node)
assert tree.root is None

@mock.patch(
'.'.join((TEST_BASE_PATH, 'third_party_graph.pytorch_graph_node.PyTorchGraphNode._get_raw_params')))
def test_shrink(self, get_raw_params):
"""Test shrink function."""
params = {'root': {},
'root/child0': {},
'root/child0/child1': {}}
tree = self._create_tree(get_raw_params=get_raw_params, params=params)
node = tree.get_node('root/child0')
tree.shrink(node)
assert tree.leaves()[0].tag == 'child1'

@pytest.mark.parametrize('params', [{
'tree_params': {'root': {'op_name': 'Root',
'precursor_nodes': [],
'successor_nodes': ['root/relu'],
'node_type': NodeType.MODULE.value,
'input_shape': [1, 3, 224, 224],
'output_shape': [1, 1, 224, 224]},
'root/relu': {'op_name': 'onnx::Relu',
'precursor_nodes': ['root'],
'successor_nodes': ['root/unknown'],
'node_type': NodeType.OPERATION.value,
'input_shape': [1, 3, 224, 224],
'output_shape': [1, 3, 224, 224]},
'root/unknown': {'op_name': 'onnx::Unknown',
'precursor_nodes': ['root/relu'],
'successor_nodes': [],
'node_type': NodeType.OPERATION.value,
'input_shape': [1, 3, 224, 224],
'output_shape': [1, 1, 224, 224]}},
'report_dir': 'report_folder'
}, {
'tree_params': {'root': {'op_name': 'Root',
'precursor_nodes': [],
'successor_nodes': ['root/relu'],
'node_type': NodeType.MODULE.value,
'input_shape': [1, 3, 224, 224],
'output_shape': [1, 1, 224, 224]},
'root/relu': {'op_name': 'onnx::Relu',
'precursor_nodes': ['root'],
'successor_nodes': ['root/unknown'],
'node_type': NodeType.OPERATION.value,
'input_shape': [1, 3, 224, 224],
'output_shape': [1, 3, 224, 224]},
'root/unknown': {'op_name': 'onnx::Unknown',
'precursor_nodes': ['root/relu'],
'successor_nodes': [],
'node_type': NodeType.OPERATION.value,
'input_shape': [1, 3, 224, 224],
'output_shape': [1, 1, 224, 224]}},
'report_dir': None
}])
@mock.patch(
'.'.join((TEST_BASE_PATH, 'third_party_graph.pytorch_graph_node.PyTorchGraphNode._get_raw_params')))
def test_save_source_file(self, get_raw_params, params):
"""Test save_source_file function."""
tree_params = params['tree_params']
out_folder = 'out_folder'
report_folder = params['report_dir']
model_name = 'model_name'
mapper = ONNXToMindSporeMapper()

tree = self._create_tree(get_raw_params=get_raw_params, params=tree_params)
tree.save_source_files(out_folder, mapper, model_name, report_folder)

out_path = os.path.realpath(os.path.join(out_folder, f"{model_name}.py"))
report_folder_test = report_folder if report_folder else out_folder
report_path = os.path.realpath(
os.path.join(report_folder_test, f"report_of_{model_name}.txt"))
try:
assert os.path.exists(out_path)
assert os.path.exists(report_path)
with open(out_path, 'r') as out_r:
code = out_r.read()
assert 'nn.ReLU' in code
assert 'onnx.Unknown' in code
with open(report_path, 'r') as report_r:
report = report_r.read()
assert "[UnConvert] 'onnx::Unknown' didn't convert." in report
assert "Converted Rate: 50.00%." in report
finally:
shutil.rmtree(out_folder)
if report_folder:
shutil.rmtree(report_folder)

@staticmethod
def _create_node(key, val, weight, input_shape, output_shape):
"""Create node."""
node = PyTorchGraphNode(weight=weight)
node.add_input_and_output_shape(input_shape, output_shape)
node.tag = key.split('/')[-1] if len(key.split('/')) > 1 else key
node.op_name = val['op_name'] if val.get('op_name') else None
node.precursor_nodes = val['precursor_nodes'] if val.get('precursor_nodes') else []
node.successor_nodes = val['successor_nodes'] if val.get('successor_nodes') else []
node.node_type = val['node_type'] if val.get('node_type') else None
return node

@staticmethod
def _create_tree(get_raw_params, params):
"""Create tree."""
tree = HierarchicalTree()
for key, val in params.items():
input_shape = val['input_shape'] if val.get('input_shape') else []
output_shape = val['output_shape'] if val.get('output_shape') else []
get_raw_params.return_value = val['op_params'] if val.get('op_params') else dict()
weight = val['weight'] if val.get('weight') else None

node = TestHierarchicalTree._create_node(key, val, weight, input_shape, output_shape)

tree.create_node(
tag=node.tag,
identifier=key,
parent='/'.join(key.split('/')[:-1]) if len(key.split('/')) > 1 else None,
data=node
)
return tree

+ 2
- 2
tests/ut/mindconverter/graph_based_converter/mapper/__init__.py View File

@@ -1,4 +1,4 @@
# Copyright 2020 Huawei Technologies Co., Ltd.All Rights Reserved.
# Copyright 2020-2021 Huawei Technologies Co., Ltd.All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -12,4 +12,4 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Unit test for mindconvert.graph_based_converter.mapper interface."""
"""Unit test for mindconverter.graph_based_converter.mapper interface."""

+ 5
- 6
tests/ut/mindconverter/graph_based_converter/mapper/test_mapper.py View File

@@ -1,4 +1,4 @@
# Copyright 2020 Huawei Technologies Co., Ltd.All Rights Reserved.
# Copyright 2020-2021 Huawei Technologies Co., Ltd.All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -18,7 +18,6 @@ import pytest

from mindinsight.mindconverter.graph_based_converter.mapper.base import ONNXToMindSporeMapper
from mindinsight.mindconverter.graph_based_converter.mapper.gen_setting import Setting
from tests.utils import mindspore


class TestMappers:
@@ -30,7 +29,7 @@ class TestMappers:
'group': 1,
'pads': [1, 2, 3, 4],
'strides': [1, 1]},
'weights': {'weight': mindspore.Tensor(np.zeros([64, 3, 1, 1], dtype=np.int32))}},
'weights': {'weight': np.zeros((64, 3, 1, 1), dtype=np.int32)}},
'expected_output': {'converter_name': 'nn.Conv2d',
'converted_params': {'in_channels': 3,
'out_channels': 64,
@@ -47,7 +46,7 @@ class TestMappers:
'group': 1,
'pads': [0, 0, 0, 0],
'strides': [1, 1]},
'weights': {'weight': mindspore.Tensor(np.zeros([64, 3, 2, 2], dtype=np.int32))}},
'weights': {'weight': np.zeros((64, 3, 2, 2), dtype=np.int32)}},
'expected_output': {'converter_name': 'nn.Conv2d',
'converted_params': {'in_channels': 3,
'out_channels': 64,
@@ -61,8 +60,8 @@ class TestMappers:
}, {
'input': {'op_name': 'onnx::Gemm',
'params': dict(),
'weights': {'weight': mindspore.Tensor(np.zeros([10, 3], dtype=np.int32)),
'bias': mindspore.Tensor(np.zeros([10, 1], dtype=np.int32))}},
'weights': {'weight': np.zeros((10, 3), dtype=np.int32),
'bias': np.zeros((10, 1), dtype=np.int32)}},
'expected_output': {'converter_name': 'nn.Dense',
'converted_params': {'in_channels': 3,
'out_channels': 10,

