Browse Source

Generate checkpoint

tags/v1.2.0-rc1
moran 4 years ago
parent
commit
5fb0d2e1d8
37 changed files with 479 additions and 176 deletions
  1. +31
    -1
      mindinsight/mindconverter/common/exceptions.py
  2. +3
    -3
      mindinsight/mindconverter/graph_based_converter/__init__.py
  3. +6
    -1
      mindinsight/mindconverter/graph_based_converter/common/code_fragment.py
  4. +1
    -1
      mindinsight/mindconverter/graph_based_converter/common/global_context.py
  5. +33
    -3
      mindinsight/mindconverter/graph_based_converter/common/utils.py
  6. +1
    -0
      mindinsight/mindconverter/graph_based_converter/constant.py
  7. +48
    -30
      mindinsight/mindconverter/graph_based_converter/framework.py
  8. +4
    -4
      mindinsight/mindconverter/graph_based_converter/generator/__init__.py
  9. +74
    -2
      mindinsight/mindconverter/graph_based_converter/generator/generator.py
  10. +7
    -7
      mindinsight/mindconverter/graph_based_converter/generator/module_struct.py
  11. +5
    -5
      mindinsight/mindconverter/graph_based_converter/generator/node_struct.py
  12. +2
    -2
      mindinsight/mindconverter/graph_based_converter/mapper/__init__.py
  13. +14
    -7
      mindinsight/mindconverter/graph_based_converter/mapper/base.py
  14. +11
    -6
      mindinsight/mindconverter/graph_based_converter/mapper/impl/nn/batch_norm_mapper.py
  15. +19
    -20
      mindinsight/mindconverter/graph_based_converter/mapper/impl/nn/conv_mapper.py
  16. +7
    -6
      mindinsight/mindconverter/graph_based_converter/mapper/impl/nn/dense_mapper.py
  17. +4
    -3
      mindinsight/mindconverter/graph_based_converter/mapper/impl/nn/mat_mul_mapper.py
  18. +1
    -6
      mindinsight/mindconverter/graph_based_converter/mapper/impl/nn/pad_mapper.py
  19. +103
    -10
      mindinsight/mindconverter/graph_based_converter/mapper/impl/nn/pool_mapper.py
  20. +1
    -2
      mindinsight/mindconverter/graph_based_converter/mapper/impl/ops/add_mapper.py
  21. +1
    -2
      mindinsight/mindconverter/graph_based_converter/mapper/impl/ops/mul_mapper.py
  22. +1
    -1
      mindinsight/mindconverter/graph_based_converter/mapper/impl/ops/resize_mapper.py
  23. +1
    -1
      mindinsight/mindconverter/graph_based_converter/mapper/impl/ops/slice_mapper.py
  24. +2
    -2
      mindinsight/mindconverter/graph_based_converter/sub_graph_searcher/__init__.py
  25. +3
    -3
      mindinsight/mindconverter/graph_based_converter/sub_graph_searcher/built_in_pattern.py
  26. +10
    -7
      mindinsight/mindconverter/graph_based_converter/sub_graph_searcher/search_path.py
  27. +10
    -7
      mindinsight/mindconverter/graph_based_converter/sub_graph_searcher/searcher.py
  28. +1
    -1
      mindinsight/mindconverter/graph_based_converter/third_party_graph/__init__.py
  29. +3
    -4
      mindinsight/mindconverter/graph_based_converter/third_party_graph/input_node.py
  30. +9
    -9
      mindinsight/mindconverter/graph_based_converter/third_party_graph/onnx_graph.py
  31. +6
    -7
      mindinsight/mindconverter/graph_based_converter/third_party_graph/onnx_graph_node.py
  32. +20
    -7
      mindinsight/mindconverter/graph_based_converter/third_party_graph/onnx_utils.py
  33. +1
    -1
      mindinsight/mindconverter/graph_based_converter/third_party_graph/optimizer.py
  34. +1
    -0
      mindinsight/mindconverter/graph_based_converter/third_party_graph/pytorch_graph_parser.py
  35. +3
    -3
      mindinsight/mindconverter/graph_based_converter/third_party_graph/tf_graph_parser.py
  36. +6
    -2
      tests/ut/mindconverter/graph_based_converter/mapper/test_mapper.py
  37. +26
    -0
      tests/utils/mindspore/train/serialization.py

+ 31
- 1
mindinsight/mindconverter/common/exceptions.py View File

@@ -1,4 +1,4 @@
# Copyright 2019 Huawei Technologies Co., Ltd
# Copyright 2020-2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -301,6 +301,8 @@ class SourceFilesSaveError(MindConverterException):
NODE_INPUT_TYPE_NOT_SUPPORT = 1
SCRIPT_GENERATE_FAIL = 2
REPORT_GENERATE_FAIL = 3
CKPT_GENERATE_FAIL = 4
MAP_GENERATE_FAIL = 5

BASE_ERROR_CODE = ConverterErrors.SOURCE_FILES_SAVE_FAIL.value
ERROR_CODE = ErrCode.UNKNOWN_ERROR.value
@@ -315,6 +317,8 @@ class SourceFilesSaveError(MindConverterException):
except_source = (NodeInputTypeNotSupportError,
ScriptGenerationError,
ReportGenerationError,
CheckPointGenerationError,
WeightMapGenerationError,
IOError, cls)
return except_source

@@ -437,6 +441,32 @@ class ReportGenerationError(SourceFilesSaveError):
return ZeroDivisionError, cls


class CheckPointGenerationError(SourceFilesSaveError):
"""The checkpoint generate fail error."""
ERROR_CODE = SourceFilesSaveError.ErrCode.CKPT_GENERATE_FAIL.value

def __init__(self, msg):
super(CheckPointGenerationError, self).__init__(msg=msg)

@classmethod
def raise_from(cls):
"""Raise from exceptions below."""
return cls


class WeightMapGenerationError(SourceFilesSaveError):
"""The weight names map generate fail error."""
ERROR_CODE = SourceFilesSaveError.ErrCode.MAP_GENERATE_FAIL.value

def __init__(self, msg):
super(WeightMapGenerationError, self).__init__(msg=msg)

@classmethod
def raise_from(cls):
"""Raise from exception below."""
return cls


class SubGraphSearchingError(MindConverterException):
"""Sub-graph searching exception."""



+ 3
- 3
mindinsight/mindconverter/graph_based_converter/__init__.py View File

@@ -1,4 +1,4 @@
# Copyright 2020 Huawei Technologies Co., Ltd.All Rights Reserved.
# Copyright 2020-2021 Huawei Technologies Co., Ltd.All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -15,5 +15,5 @@
"""Graph based scripts converter definition."""
__all__ = ["graph_based_converter_pytorch_to_ms", "graph_based_converter_tf_to_ms"]

from .framework import graph_based_converter_pytorch_to_ms
from .framework import graph_based_converter_tf_to_ms
from mindinsight.mindconverter.graph_based_converter.framework import graph_based_converter_pytorch_to_ms
from mindinsight.mindconverter.graph_based_converter.framework import graph_based_converter_tf_to_ms

+ 6
- 1
mindinsight/mindconverter/graph_based_converter/common/code_fragment.py View File

@@ -191,18 +191,23 @@ class CodeFragment(Fragment):
"""

def __init__(self, operation, actual_args, settings, input_shape, output_shape,
trainable_params=None):
trainable_params=None, trainable_weights=None):
super(CodeFragment, self).__init__(operation=operation, actual_args=actual_args,
input_shape=input_shape, output_shape=output_shape,
settings=settings)
self._trainable_params = dict() # External weights, like Matmul.
self._init_trainable_params = trainable_params # Can put into operation init method, like Conv2d.
self._trainable_weights = trainable_weights

@property
def trainable_params(self):
"""Return the trainable parameters."""
return self._trainable_params

@property
def trainable_weights(self):
return self._trainable_weights


class ModuleFragment(Fragment):
"""Manage module type code variables."""


+ 1
- 1
mindinsight/mindconverter/graph_based_converter/common/global_context.py View File

@@ -14,7 +14,7 @@
# ==============================================================================
"""Define GlobalContext class to save required resources during whole conversion procedure."""
from collections import OrderedDict
from .outputs import OutputStorage
from mindinsight.mindconverter.graph_based_converter.common.outputs import OutputStorage


class Singleton(type):


+ 33
- 3
mindinsight/mindconverter/graph_based_converter/common/utils.py View File

@@ -13,16 +13,21 @@
# limitations under the License.
# ============================================================================
"""Define common utils."""
import json
import os
import stat
from importlib import import_module
from importlib.util import find_spec
from typing import List, Tuple, Mapping

from mindinsight.mindconverter.common.exceptions import ScriptGenerationError, ReportGenerationError, UnknownModelError
from mindinsight.mindconverter.common.exceptions import ScriptGenerationError, ReportGenerationError, \
UnknownModelError, CheckPointGenerationError, WeightMapGenerationError
from mindinsight.mindconverter.common.log import logger as log
from mindinsight.mindconverter.graph_based_converter.constant import SEPARATOR_IN_ONNX_OP, BINARY_HEADER_PYTORCH_BITS, \
FrameworkType, BINARY_HEADER_PYTORCH_FILE, TENSORFLOW_MODEL_SUFFIX

from mindspore.train.serialization import save_checkpoint


def is_converted(operation: str):
"""
@@ -96,7 +101,6 @@ def save_code_file_and_report(model_name: str, code_lines: Mapping[str, Tuple],
code_lines (dict): Code lines.
out_folder (str): Output folder.
report_folder (str): Report output folder.

"""
flags = os.O_WRONLY | os.O_CREAT | os.O_EXCL
modes = stat.S_IRUSR | stat.S_IWUSR
@@ -114,7 +118,7 @@ def save_code_file_and_report(model_name: str, code_lines: Mapping[str, Tuple],
os.makedirs(report_folder, modes_usr)

for file_name in code_lines:
code, report = code_lines[file_name]
code, report, trainable_weights, weight_map = code_lines[file_name]
code_file_path = os.path.realpath(os.path.join(out_folder, f"{model_name}.py"))
report_file_path = os.path.realpath(os.path.join(report_folder, f"report_of_{model_name}.txt"))
try:
@@ -133,6 +137,31 @@ def save_code_file_and_report(model_name: str, code_lines: Mapping[str, Tuple],
except (IOError, FileExistsError) as error:
raise ReportGenerationError(str(error))

ckpt_file_path = os.path.realpath(os.path.join(out_folder, f"{model_name}.ckpt"))
try:
if os.path.exists(ckpt_file_path):
raise CheckPointGenerationError("Checkpoint file with the same name already exists.")
save_checkpoint(trainable_weights, ckpt_file_path)
except TypeError as error:
raise CheckPointGenerationError(str(error))

weight_map_path = os.path.realpath(os.path.join(out_folder, f"weight_map_of_{model_name}.json"))
try:
if os.path.exists(weight_map_path):
raise WeightMapGenerationError("Weight map file with the same name already exists.")
with os.fdopen(os.open(weight_map_path, flags, stat.S_IRUSR), 'w') as map_f:
weight_map_json = {f"{model_name}": weight_map}
json.dump(weight_map_json, map_f)
except (IOError, FileExistsError) as error:
raise WeightMapGenerationError(str(error))


def onnx_satisfied():
"""Validate ONNX , ONNXRUNTIME, ONNXOPTIMIZER installation."""
if not find_spec("onnx") or not find_spec("onnxruntime") or not find_spec("onnxoptimizer"):
return False
return True


def lib_version_satisfied(current_ver: str, mini_ver_limited: str,
newest_ver_limited: str = ""):
@@ -220,6 +249,7 @@ def reset_init_or_construct(template, variable_slot, new_data, scope):
template[variable_slot][scope] += new_data
return template


def replace_string_in_list(str_list: list, original_str: str, target_str: str):
"""
Replace a string in a list by provided string.


+ 1
- 0
mindinsight/mindconverter/graph_based_converter/constant.py View File

@@ -41,6 +41,7 @@ UNKNOWN_DIM_VAL = "unk__001"
ONNX_MIN_VER = "1.8.0"
TF2ONNX_MIN_VER = "1.7.1"
ONNXRUNTIME_MIN_VER = "1.5.2"
ONNXOPTIMIZER_MIN_VER = "0.1.2"


@unique


+ 48
- 30
mindinsight/mindconverter/graph_based_converter/framework.py View File

@@ -21,10 +21,10 @@ from importlib.util import find_spec

import mindinsight
from mindinsight.mindconverter.graph_based_converter.common.global_context import GlobalContext
from mindinsight.mindconverter.graph_based_converter.common.utils import lib_version_satisfied, \
from mindinsight.mindconverter.graph_based_converter.common.utils import lib_version_satisfied, onnx_satisfied, \
save_code_file_and_report, get_framework_type
from mindinsight.mindconverter.graph_based_converter.constant import FrameworkType, \
ONNX_MIN_VER, TF2ONNX_MIN_VER, ONNXRUNTIME_MIN_VER
ONNX_MIN_VER, TF2ONNX_MIN_VER, ONNXRUNTIME_MIN_VER, ONNXOPTIMIZER_MIN_VER
from mindinsight.mindconverter.graph_based_converter.generator import batch_add_nodes
from mindinsight.mindconverter.graph_based_converter.mapper import ONNXToMindSporeMapper
from mindinsight.mindconverter.common.log import logger as log, logger_console as log_console
@@ -53,6 +53,18 @@ parser.add_argument("--report", type=str, required=False,
help="Generated reports output folder path.")


def onnx_lib_version_satisfied():
"""Check onnx libs version whether is satisfied."""
onnx = import_module("onnx")
ort = import_module("onnxruntime")
optimizer = import_module("onnxoptimizer.version")
if not lib_version_satisfied(getattr(onnx, "__version__"), ONNX_MIN_VER) \
or not lib_version_satisfied(getattr(ort, "__version__"), ONNXRUNTIME_MIN_VER) \
or not lib_version_satisfied(getattr(optimizer, "version"), ONNXOPTIMIZER_MIN_VER):
return False
return True


def torch_installation_validation(func):
"""
Validate args of func.
@@ -68,26 +80,33 @@ def torch_installation_validation(func):
input_nodes: str, output_nodes: str,
output_folder: str, report_folder: str = None):
# Check whether pytorch is installed.
if not find_spec("torch") or not find_spec("onnx") or not find_spec("onnxruntime"):
error = RuntimeIntegrityError(f"PyTorch, onnx(>={ONNX_MIN_VER}) and "
f"onnxruntime(>={ONNXRUNTIME_MIN_VER}) "
f"are required when using graph based "
f"scripts converter, and PyTorch version must "
f"be consisted with model generation runtime.")
error_info = None
if graph_path.endswith('.onnx'):
if not onnx_satisfied():
error_info = f"onnx(>={ONNX_MIN_VER}, onnxruntime(>={ONNXRUNTIME_MIN_VER}) and " \
f"onnxoptimizer(>={ONNXOPTIMIZER_MIN_VER}) " \
f"are required when using graph based scripts converter."
else:
if not find_spec("torch") or not onnx_satisfied():
error_info = f"PyTorch, onnx(>={ONNX_MIN_VER}), " \
f"onnxruntime(>={ONNXRUNTIME_MIN_VER}) and " \
f"onnxoptimizer(>={ONNXOPTIMIZER_MIN_VER}) " \
f"are required when using graph based " \
f"scripts converter, and PyTorch version must " \
f"be consisted with model generation runtime."
if error_info:
error = RuntimeIntegrityError(error_info)
log.error(error)
log_console.error("\n")
log_console.error(str(error))
log_console.error("\n")
sys.exit(0)

onnx = import_module("onnx")
ort = import_module("onnxruntime")

if not lib_version_satisfied(getattr(onnx, "__version__"), ONNX_MIN_VER) \
or not lib_version_satisfied(getattr(ort, "__version__"), ONNXRUNTIME_MIN_VER):
if not onnx_lib_version_satisfied():
error = RuntimeIntegrityError(
f"onnx(>={ONNX_MIN_VER}) and "
f"onnxruntime(>={ONNXRUNTIME_MIN_VER}) are required when using graph "
f"onnx(>={ONNX_MIN_VER}), "
f"onnxruntime(>={ONNXRUNTIME_MIN_VER}) and "
f"onnxoptimizer(>={ONNXOPTIMIZER_MIN_VER}) are required when using graph "
f"based scripts converter for Pytorch conversion."
)
log.error(error)
@@ -128,11 +147,11 @@ def tf_installation_validation(func):
output_folder: str, report_folder: str = None,
input_nodes: str = None, output_nodes: str = None):
# Check whether tensorflow is installed.
if not _check_tf_installation() or not find_spec("tf2onnx") \
or not find_spec("onnx") or not find_spec("onnxruntime"):
if not _check_tf_installation() or not onnx_satisfied():
error = RuntimeIntegrityError(
f"TensorFlow, tf2onnx(>={TF2ONNX_MIN_VER}), onnx(>={ONNX_MIN_VER}) and "
f"onnxruntime(>={ONNXRUNTIME_MIN_VER}) are required when using graph "
f"TensorFlow, tf2onnx(>={TF2ONNX_MIN_VER}), onnx(>={ONNX_MIN_VER}), "
f"onnxruntime(>={ONNXRUNTIME_MIN_VER}) and onnxoptimizer(>={ONNXOPTIMIZER_MIN_VER}) "
f"are required when using graph "
f"based scripts converter for TensorFlow conversion."
)
log.error(error)
@@ -141,15 +160,14 @@ def tf_installation_validation(func):
log_console.error("\n")
sys.exit(0)

onnx, tf2onnx = import_module("onnx"), import_module("tf2onnx")
ort = import_module("onnxruntime")
tf2onnx = import_module("tf2onnx")

if not lib_version_satisfied(getattr(onnx, "__version__"), ONNX_MIN_VER) \
or not lib_version_satisfied(getattr(ort, "__version__"), ONNXRUNTIME_MIN_VER) \
or not lib_version_satisfied(getattr(tf2onnx, "__version__"), TF2ONNX_MIN_VER):
if not lib_version_satisfied(getattr(tf2onnx, "__version__"), TF2ONNX_MIN_VER) \
or not onnx_lib_version_satisfied():
error = RuntimeIntegrityError(
f"TensorFlow, tf2onnx(>={TF2ONNX_MIN_VER}), onnx(>={ONNX_MIN_VER}) and "
f"onnxruntime(>={ONNXRUNTIME_MIN_VER}) are required when using graph "
f"TensorFlow, tf2onnx(>={TF2ONNX_MIN_VER}), onnx(>={ONNX_MIN_VER}), "
f"onnxruntime(>={ONNXRUNTIME_MIN_VER}) and onnxoptimizer(>={ONNXOPTIMIZER_MIN_VER}) "
f"are required when using graph "
f"based scripts converter for TensorFlow conversion."
)
log.error(error)
@@ -258,12 +276,12 @@ def main_graph_base_converter(file_config):
raise ParamMissingError("Param missing, `--shape` is required when using graph mode.")

if frame_type == FrameworkType.PYTORCH.value:
check_params = ['input_nodes', 'output_nodes']
check_params_exist(check_params, file_config)
graph_based_converter_pytorch_to_ms(graph_path=graph_path,
sample_shape=file_config['shape'],
input_nodes=file_config['input_nodes'],
output_nodes=file_config['output_nodes'],
input_nodes=file_config['input_nodes'] if file_config['input_nodes']
else 'input.1',
output_nodes=file_config['output_nodes'] if file_config['output_nodes']
else '',
output_folder=file_config['outfile_dir'],
report_folder=file_config['report_dir'])
elif frame_type == FrameworkType.TENSORFLOW.value:


+ 4
- 4
mindinsight/mindconverter/graph_based_converter/generator/__init__.py View File

@@ -18,10 +18,10 @@ __all__ = ["batch_add_nodes"]
import re
import copy

from .generator import Generator, CodeStruct
from ..common.code_fragment import CodeFragment, NewFragment
from ..common.outputs import NodeOutputManager
from ..constant import ExchangeMessageKeywords
from mindinsight.mindconverter.graph_based_converter.generator.generator import Generator, CodeStruct
from mindinsight.mindconverter.graph_based_converter.common.code_fragment import NewFragment
from mindinsight.mindconverter.graph_based_converter.common.outputs import NodeOutputManager
from mindinsight.mindconverter.graph_based_converter.constant import ExchangeMessageKeywords


def _tf_model_node_name_reformat(node, node_name):


+ 74
- 2
mindinsight/mindconverter/graph_based_converter/generator/generator.py View File

@@ -16,6 +16,7 @@
import copy
from collections import OrderedDict

from mindspore import Tensor
from yapf.yapflib.yapf_api import FormatCode

from mindinsight.mindconverter.common.exceptions import GeneratorError
@@ -28,7 +29,7 @@ from mindinsight.mindconverter.graph_based_converter.common.outputs import BaseO
from mindinsight.mindconverter.graph_based_converter.common.yapf_config import mindspore_yapf_config
from mindinsight.mindconverter.graph_based_converter.common.name_mgr import GlobalVarNameMgr
from mindinsight.mindconverter.graph_based_converter.constant import NEW_LINE, SECOND_LEVEL_INDENT, \
FIRST_LEVEL_INDENT, get_imported_module
FIRST_LEVEL_INDENT, get_imported_module, SEPARATOR_BTW_NAME_AND_ID
from mindinsight.mindconverter.graph_based_converter.report_generator import ReportGenerator
from mindinsight.mindconverter.graph_based_converter.common.utils import replace_string_in_list

@@ -469,6 +470,74 @@ class Generator:
"""Return all ModuleStructs in this model."""
return self._module_struct_collections

def generate_weight_scope_name(self, node):
"""Generate weight scope name for checkpoint."""
replaced_module_dict = self.node_structs[node].global_context_mgr.known_module_name
scope_list = self.node_structs[node].scope.scope_list
ms_var_name = self.node_structs[node].ms_var_name

weight_scope_name = None
for scope in scope_list[1:]:
replaced_module = replaced_module_dict.get(scope.split(SEPARATOR_BTW_NAME_AND_ID)[0])
if replaced_module:
scope = scope.replace(scope.split(SEPARATOR_BTW_NAME_AND_ID)[0], replaced_module)
if not weight_scope_name:
weight_scope_name = scope
else:
weight_scope_name = '.'.join((weight_scope_name, scope))

if not weight_scope_name:
weight_scope_name = ms_var_name
else:
weight_scope_name = '.'.join((weight_scope_name, ms_var_name))

return weight_scope_name.lower()

def generate_checkpoint(self):
"""Generate checkpoint."""

trainable_weights_dict = dict()
weight_map = list()
for node_name, node_inst in self.node_structs.items():
if node_inst.fragment.exchange_msg['var_0']['trainable_params']:
weights_scope_name = self.generate_weight_scope_name(node_name)
onnx_weight_inst = node_inst.fragment.exchange_msg['var_0']['weights']
for idx, (weight_key, weight_value) in \
enumerate(node_inst.fragment.exchange_msg['var_0']['trainable_params'].items()):
weight_name = '.'.join((weights_scope_name, weight_key))
weight_shape = Tensor(weight_value).shape
data_type = Tensor(weight_value).dtype
trainable_weights_dict[weight_name] = weight_value

onnx_weight_name = onnx_weight_inst[idx].name
onnx_weight_shape = onnx_weight_inst[idx].value.shape
onnx_data_type = onnx_weight_inst[idx].value.dtype

weight_map.append(
{
'converted_weight': {
'name': weight_name,
'shape': weight_shape,
'data_type': str(data_type)
},
'source_weight': {
'name': onnx_weight_name,
'shape': onnx_weight_shape,
'data_type': str(onnx_data_type)
}
}
)

save_obj = list()
for weight_name, weight_value in trainable_weights_dict.items():
obj = {
'name': weight_name,
'data': Tensor(weight_value)
}
save_obj.append(obj)

return save_obj, weight_map

@GeneratorError.check_except("Generator occurs an error when generating code statements.")
def generate(self):
"""
@@ -479,6 +548,9 @@ class Generator:
"""
self._form_bottom_submodule()
self._recursive_form_module()

ckpt_data_list, weight_map = self.generate_checkpoint()

CodeStruct(self.module_structs.get('[]'), self._repeated_submodules)

outputs = [get_imported_module()]
@@ -494,7 +566,7 @@ class Generator:
report = report_generator.gen_report(formatted_code)
del self._global_context

return {"model": (formatted_code, report)}
return {"model": (formatted_code, report, ckpt_data_list, weight_map)}

def get_node_struct(self, node_identifier):
"""


+ 7
- 7
mindinsight/mindconverter/graph_based_converter/generator/module_struct.py View File

@@ -17,13 +17,13 @@
import copy
from collections import OrderedDict

from .node_struct import NodeStruct
from .scope_utils import Scope
from ..common.utils import get_dict_key_by_value
from .args_translator import ArgsTranslation
from ..common.code_fragment import ModuleFragment
from ..common.global_context import GlobalContext
from ..common.name_mgr import LocalVarNameMgr
from mindinsight.mindconverter.graph_based_converter.generator.node_struct import NodeStruct
from mindinsight.mindconverter.graph_based_converter.generator.scope_utils import Scope
from mindinsight.mindconverter.graph_based_converter.common.utils import get_dict_key_by_value
from mindinsight.mindconverter.graph_based_converter.generator.args_translator import ArgsTranslation
from mindinsight.mindconverter.graph_based_converter.common.code_fragment import ModuleFragment
from mindinsight.mindconverter.graph_based_converter.common.global_context import GlobalContext
from mindinsight.mindconverter.graph_based_converter.common.name_mgr import LocalVarNameMgr


class ModuleStruct:


+ 5
- 5
mindinsight/mindconverter/graph_based_converter/generator/node_struct.py View File

@@ -17,11 +17,11 @@ from collections import OrderedDict

from mindinsight.mindconverter.graph_based_converter.common.code_fragment import NewFragment
from mindinsight.mindconverter.graph_based_converter.generator.fragment_utils import FragmentHandler
from .scope_utils import Scope
from .args_translator import ArgsTranslation
from ..third_party_graph.onnx_graph_node import OnnxGraphNode
from ..common.global_context import GlobalContext
from ...common.exceptions import GeneratorError
from mindinsight.mindconverter.graph_based_converter.generator.scope_utils import Scope
from mindinsight.mindconverter.graph_based_converter.generator.args_translator import ArgsTranslation
from mindinsight.mindconverter.graph_based_converter.third_party_graph.onnx_graph_node import OnnxGraphNode
from mindinsight.mindconverter.graph_based_converter.common.global_context import GlobalContext
from mindinsight.mindconverter.common.exceptions import GeneratorError


class NodeStruct:


+ 2
- 2
mindinsight/mindconverter/graph_based_converter/mapper/__init__.py View File

@@ -1,4 +1,4 @@
# Copyright 2020 Huawei Technologies Co., Ltd.All Rights Reserved.
# Copyright 2020-2021 Huawei Technologies Co., Ltd.All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -16,4 +16,4 @@

__all__ = ["ONNXToMindSporeMapper"]

from .base import ONNXToMindSporeMapper
from mindinsight.mindconverter.graph_based_converter.mapper.base import ONNXToMindSporeMapper

+ 14
- 7
mindinsight/mindconverter/graph_based_converter/mapper/base.py View File

@@ -108,18 +108,21 @@ class ONNXToMindSporeMapper(Mapper, abc.ABC):
try:
converter_name = op_name_converter(params=params, weights=weights, op_name=op_name)
converted_params = params_converter(params=params, weights=weights)

if "input_shape" in converted_params:
converted_params.pop("input_shape")
if "output_shape" in converted_params:
converted_params.pop("output_shape")
# set to converted_weights to enable weight migration
_ = weights_converter(weights=weights) if weights else dict()
converted_weights = weights_converter(weights=weights) if weights else dict()
code_template, exchange_msg, outputs_list, outputs_mapping = template_generator(
operation=converter_name,
converted_params=converted_params,
raw_params=params,
weights=weights
weights=weights,
trainable_params=converted_weights
)

except (AttributeError, KeyError, ValueError, TypeError, IndexError) as e:
err_msg = f"Converting {op_name} failed, see {str(e)}"
log.error(err_msg)
@@ -148,6 +151,7 @@ class ONNXToMindSporeMapper(Mapper, abc.ABC):
op = kwargs.get("operation")
args = kwargs.get("converted_params", dict())
weights = kwargs.get("weights")
trainable_params = kwargs.get("trainable_params", dict())
if not op:
raise ValueError("Can not get MindSpore operation name.")
variable_slot = "var_0"
@@ -169,7 +173,7 @@ class ONNXToMindSporeMapper(Mapper, abc.ABC):
ExchangeMessageKeywords.VariableScope.value.INPUTS.value: [],
ExchangeMessageKeywords.VariableScope.value.ARGS.value: args,
ExchangeMessageKeywords.VariableScope.value.WEIGHTS.value: weights,
ExchangeMessageKeywords.VariableScope.value.TRAINABLE_PARAMS.value: {}
ExchangeMessageKeywords.VariableScope.value.TRAINABLE_PARAMS.value: trainable_params
}
}
outputs_list = [f"opt_{{{variable_slot}}}"]
@@ -177,11 +181,14 @@ class ONNXToMindSporeMapper(Mapper, abc.ABC):
return template, exchange_msg, outputs_list, outputs_mapping

@staticmethod
def _find_val_by_index(loc_index, values_dict):
"""Find value by location index of values_dict."""
def _find_val_by_index(loc_index, weights_list):
"""Find value by location index of weights_list."""
result = None
for idx, dict_val in enumerate(values_dict.values()):
if loc_index < 0:
return weights_list[loc_index].value

for idx, weight in enumerate(weights_list):
if idx == loc_index:
result = dict_val
result = weight.value
break
return result

+ 11
- 6
mindinsight/mindconverter/graph_based_converter/mapper/impl/nn/batch_norm_mapper.py View File

@@ -14,7 +14,6 @@
# ==============================================================================
"""Mapper module."""
from mindinsight.mindconverter.graph_based_converter.mapper.base import ONNXToMindSporeMapper
from mindinsight.mindconverter.graph_based_converter.mapper.gen_setting import Setting


class BatchNormMapper(ONNXToMindSporeMapper):
@@ -36,8 +35,14 @@ class BatchNormMapper(ONNXToMindSporeMapper):

@staticmethod
def _convert_trained_weights(**kwargs):
return dict()

@staticmethod
def _convert_settings(**kwargs):
return Setting()
weights = kwargs['weights']
gamma = BatchNormMapper._find_val_by_index(0, weights)
beta = BatchNormMapper._find_val_by_index(1, weights)
moving_mean = BatchNormMapper._find_val_by_index(2, weights)
moving_variance = BatchNormMapper._find_val_by_index(3, weights)
return {
'gamma': gamma,
'beta': beta,
'moving_mean': moving_mean,
'moving_variance': moving_variance
}

+ 19
- 20
mindinsight/mindconverter/graph_based_converter/mapper/impl/nn/conv_mapper.py View File

@@ -14,7 +14,6 @@
# ==============================================================================
"""Mapper module."""
import numpy as np
from mindinsight.mindconverter.graph_based_converter.mapper.gen_setting import Setting
from mindinsight.mindconverter.graph_based_converter.mapper.base import ONNXToMindSporeMapper
from mindinsight.mindconverter.graph_based_converter.common.utils import convert_bytes_string_to_string

@@ -42,7 +41,7 @@ class ConvMapper(ONNXToMindSporeMapper):
"""Convert params from PyTorch to MindSpore"""
weights = kwargs['weights']
params = kwargs['params']
weight = weights['weight']
weight = ConvMapper._find_val_by_index(0, weights)
weight = np.transpose(weight, list(range(2, weight.ndim)) + [1, 0])
if isinstance(params['dilations'], list):
dilation = tuple(params['dilations'])
@@ -76,11 +75,13 @@ class ConvMapper(ONNXToMindSporeMapper):
"""Convert params from Tensorflow to MindSpore"""
weights = kwargs['weights']
params = kwargs['params']
# regex to find Conv weight
weight = list(weights.values())[0]
weight = ConvMapper._find_val_by_index(0, weights)
bias = ConvMapper._find_val_by_index(1, weights)
if weight is None:
raise ValueError("Conv. Mapper cannot get the weight.")

has_bias = isinstance(bias, np.ndarray)

auto_pad = None
if params.get("auto_pad") is not None:
auto_pad = convert_bytes_string_to_string(params.get("auto_pad"))
@@ -119,18 +120,14 @@ class ConvMapper(ONNXToMindSporeMapper):
'padding': padding,
'pad_mode': pad_mode,
'dilation': dilation,
'group': params.get('group', 1)}
'group': params.get('group', 1),
'has_bias': has_bias
}

@staticmethod
def _operation_name_in_ms(*args, **kwargs):
weight = kwargs['weights'].get('weight', 'empty')

if weight == 'empty': # is from tf
kernel_size = kwargs['params'].get('kernel_shape')
dim = len(kernel_size)
return f"nn.Conv{dim}d"

dim = weight.ndim - 2
kernel_size = kwargs['params'].get('kernel_shape')
dim = len(kernel_size)
return f"nn.Conv{dim}d"

@staticmethod
@@ -138,14 +135,16 @@ class ConvMapper(ONNXToMindSporeMapper):
weights = kwargs['weights']
params = kwargs['params']

if weights.get('weight', 'empty') == 'empty': # is from tf
return ConvMapper.convert_params_tf(params=params, weights=weights)
return ConvMapper.convert_params_torch(params=params, weights=weights)
return ConvMapper.convert_params_tf(params=params, weights=weights)

@staticmethod
def _convert_trained_weights(**kwargs):
return dict()
weights = kwargs['weights']
weight = ConvMapper._find_val_by_index(0, weights)
bias = ConvMapper._find_val_by_index(1, weights)

@staticmethod
def _convert_settings(**kwargs):
return Setting()
converted_weights = {'weight': weight}
if isinstance(bias, np.ndarray):
converted_weights['bias'] = bias

return converted_weights

+ 7
- 6
mindinsight/mindconverter/graph_based_converter/mapper/impl/nn/dense_mapper.py View File

@@ -15,7 +15,6 @@
"""Mapper module."""
import numpy as np
from mindinsight.mindconverter.graph_based_converter.mapper.base import ONNXToMindSporeMapper
from mindinsight.mindconverter.graph_based_converter.mapper.gen_setting import Setting


class DenseMapper(ONNXToMindSporeMapper):
@@ -42,8 +41,10 @@ class DenseMapper(ONNXToMindSporeMapper):

@staticmethod
def _convert_trained_weights(**kwargs):
return dict()

@staticmethod
def _convert_settings(**kwargs):
return Setting()
weights = kwargs['weights']
weight = DenseMapper._find_val_by_index(0, weights)
bias = DenseMapper._find_val_by_index(1, weights)
return {
'weight': weight,
'bias': bias
}

+ 4
- 3
mindinsight/mindconverter/graph_based_converter/mapper/impl/nn/mat_mul_mapper.py View File

@@ -32,7 +32,9 @@ class MatMulMapper(ONNXToMindSporeMapper):

@staticmethod
def _convert_trained_weights(**kwargs):
return dict()
weights = kwargs['weights']
weight = MatMulMapper._find_val_by_index(0, weights)
return {'weight': weight}

@staticmethod
def _convert_settings(**kwargs):
@@ -56,8 +58,7 @@ class MatMulMapper(ONNXToMindSporeMapper):
if not weights:
return template, exchange_msg, outputs_list, outputs_mapping

weight = list(weights.items())[0]
_, tensor = weight
tensor = MatMulMapper._find_val_by_index(0, weights)

variable_slot = "var_0"
init_template = f"self.{{{variable_slot}}} = {op}({', '.join(['%s={%s}' % (p, p) for p in args])})"


+ 1
- 6
mindinsight/mindconverter/graph_based_converter/mapper/impl/nn/pad_mapper.py View File

@@ -15,7 +15,6 @@
"""Mapper module."""
from mindinsight.mindconverter.graph_based_converter.common.utils import convert_bytes_string_to_string
from mindinsight.mindconverter.graph_based_converter.mapper.base import ONNXToMindSporeMapper
from mindinsight.mindconverter.graph_based_converter.mapper.gen_setting import Setting


def _padding_format_convert(padding: list):
@@ -49,7 +48,7 @@ class PadMapper(ONNXToMindSporeMapper):
weights = kwargs.get("weights")
params = kwargs.get("params")
mode = convert_bytes_string_to_string(params.get('mode', 'constant'))
pads_onnx = params.get("pads") if params.get("pads") else list(weights.values())[0].tolist()
pads_onnx = params.get("pads") if params.get("pads") else PadMapper._find_val_by_index(0, weights).tolist()
if mode == 'constant' and params.get('value') is None:
if params.get('pads') or weights:
if isinstance(pads_onnx, list):
@@ -76,7 +75,3 @@ class PadMapper(ONNXToMindSporeMapper):
@staticmethod
def _convert_trained_weights(**kwargs):
return dict()

@staticmethod
def _convert_settings(**kwargs):
return Setting()

+ 103
- 10
mindinsight/mindconverter/graph_based_converter/mapper/impl/nn/pool_mapper.py View File

@@ -13,12 +13,16 @@
# limitations under the License.
# ==============================================================================
"""Mapper module."""
import math

import numpy as np

from mindinsight.mindconverter.graph_based_converter.mapper.base import ONNXToMindSporeMapper
from mindinsight.mindconverter.graph_based_converter.mapper.gen_setting import Setting
from mindinsight.mindconverter.graph_based_converter.constant import ExchangeMessageKeywords, TemplateKeywords


class PoolMapper(ONNXToMindSporeMapper):
"""MaxPool mapper."""
"""Pool mapper."""

@staticmethod
def _operation_name_in_ms(*args, **kwargs):
@@ -35,12 +39,6 @@ class PoolMapper(ONNXToMindSporeMapper):
transformed_params = dict()
transformed_params["kernel_size"] = tuple(params['kernel_shape'])
transformed_params["stride"] = tuple(params['strides'])
if "pads" in params:
if sum(params['pads']) == 0 and not params.get('ceil_mode', None):
pad_mode = '\"valid\"'
else:
pad_mode = '\"same\"'
transformed_params["pad_mode"] = pad_mode

return transformed_params

@@ -49,5 +47,100 @@ class PoolMapper(ONNXToMindSporeMapper):
return dict()

@staticmethod
def _convert_settings(**kwargs):
return Setting()
def _get_ms_opt_shape(**kwargs):
"""Get output shape in MindSpore."""
params = kwargs['raw_params']
input_shape = params['input_shape']
kernel_shape = params['kernel_shape']
strides = params['strides']
dilations = params.get('dilations', (1, 1))
# For mindspore,
# output_shape[i] = ceil((input_shape[i] - ((kernel_shape[i] - 1) * dilations[i] + 1) + 1) / strides[i])
ms_opt_shape = np.true_divide(np.subtract(np.array(input_shape[-len(kernel_shape):], dtype=np.float32),
((np.array(kernel_shape, dtype=np.float32) - 1) *
np.array(dilations, dtype=np.float32) + 1)) + 1,
np.array(strides, dtype=np.float32)).tolist()
ms_opt_shape_ceil = tuple(math.ceil(ms_opt_shape_axis) for ms_opt_shape_axis in ms_opt_shape)
return ms_opt_shape_ceil

@staticmethod
def _generate_snippet_template(**kwargs):
template, exchange_msg, outputs_list, outputs_mapping = ONNXToMindSporeMapper._generate_snippet_template(
**kwargs)
op = kwargs.get("operation")
args = kwargs.get("converted_params", dict())

ms_opt_shape = PoolMapper._get_ms_opt_shape(**kwargs)
tensor_opt_shape = kwargs['raw_params']['output_shape']
tensor_ipt_shape = kwargs['raw_params']['input_shape']
kernel_shape = kwargs['raw_params']['kernel_shape']
dilations = kwargs['raw_params'].get('dilations', (1, 1))
strides = kwargs['raw_params']['strides']
onnx_opt_shape = tensor_opt_shape[-len(ms_opt_shape):]

if np.all(np.array(ms_opt_shape) == np.array(onnx_opt_shape)):
return template, exchange_msg, outputs_list, outputs_mapping

variable_slot = "var_0"
init_template = f"self.{{{variable_slot}}} = {op}({', '.join(['%s={%s}' % (p, p) for p in args])})"
construct_template = f"opt_{{{variable_slot}}} = self.{{{variable_slot}}}(opt_{{{variable_slot}}})"

init_template_pad, construct_template_pad, paddings = \
PoolMapper._generate_pad_init_and_construct(tensor_opt_shape, tensor_ipt_shape,
ms_opt_shape, variable_slot,
kernel_shape, dilations, strides)

template = {
variable_slot: {
TemplateKeywords.INIT.value: [init_template_pad, init_template],
TemplateKeywords.CONSTRUCT.value: [construct_template_pad, construct_template]
}
}

args['paddings'] = paddings

exchange_msg = {
variable_slot: {
ExchangeMessageKeywords.VariableScope.value.OPERATION.value: op,
ExchangeMessageKeywords.VariableScope.value.VARIABLE_NAME.value: None,
ExchangeMessageKeywords.VariableScope.value.OUTPUT_TYPE.value:
ExchangeMessageKeywords.VariableScope.value.TSR_TYPE.value,
ExchangeMessageKeywords.VariableScope.value.INPUTS.value: [],
ExchangeMessageKeywords.VariableScope.value.ARGS.value: args,
ExchangeMessageKeywords.VariableScope.value.WEIGHTS.value: dict(),
ExchangeMessageKeywords.VariableScope.value.TRAINABLE_PARAMS.value: dict()
}
}

return template, exchange_msg, outputs_list, outputs_mapping

@staticmethod
def _generate_pad_init_and_construct(tensor_opt_shape, tensor_ipt_shape,
ms_opt_shape, variable_slot, kernel_shape, dilations, strides):
"""Generate pad code in init and construct."""
onnx_opt_shape = tensor_opt_shape[-len(ms_opt_shape):]
onnx_ipt_shape = tensor_ipt_shape[-len(ms_opt_shape):]

if np.any(np.array(ms_opt_shape) > np.array(onnx_opt_shape)):
raise ValueError(f"ms_opt_shape[{ms_opt_shape}] should be no larger than onnx_opt_shape[{onnx_opt_shape}].")

# shape_diff[i] = (onnx_opt_shape[i] - 1)*strides[i] -
# (onnx_ipt_shape[i] - ((kernel_shape[i] - 1)*dilations[i] + 1))
shape_diff = np.subtract((np.array(onnx_opt_shape) - 1)*np.array(strides),
np.subtract(np.array(onnx_ipt_shape),
(np.array(kernel_shape) - 1)*np.array(dilations) + 1)).tolist()

zero_pad_single = (0, 0)
paddings = [zero_pad_single]
num_zero_pads = len(tensor_opt_shape) - len(ms_opt_shape)
for _ in range(num_zero_pads - 1):
paddings.append(zero_pad_single)

for axis_diff in shape_diff:
paddings.append((int(axis_diff//2), int(axis_diff//2 + axis_diff % 2)))

init_template_pad = f"self.pad_{{{variable_slot}}} = nn.Pad(paddings={{paddings}})"
construct_template_pad = f"opt_{{{variable_slot}}} = self.pad_{{{variable_slot}}}" \
f"({{{ExchangeMessageKeywords.VariableScope.value.INPUTS.value}}})"

return init_template_pad, construct_template_pad, tuple(paddings)

+ 1
- 2
mindinsight/mindconverter/graph_based_converter/mapper/impl/ops/add_mapper.py View File

@@ -56,8 +56,7 @@ class AddMapper(ONNXToMindSporeMapper):
if not weights:
return template, exchange_msg, outputs_list, outputs_mapping

bias = list(weights.items())[0]
_, tensor = bias
tensor = AddMapper._find_val_by_index(0, weights)

variable_slot = "var_0"
init_template = f"self.{{{variable_slot}}} = {op}({', '.join(['%s={%s}' % (p, p) for p in args])})"


+ 1
- 2
mindinsight/mindconverter/graph_based_converter/mapper/impl/ops/mul_mapper.py View File

@@ -53,8 +53,7 @@ class MulMapper(ONNXToMindSporeMapper):
if not weights:
return template, exchange_msg, outputs_list, outputs_mapping

weight = list(weights.items())[0]
_, tensor = weight
tensor = MulMapper._find_val_by_index(0, weights)

variable_slot = "var_0"
init_template = f"self.{{{variable_slot}}} = {op}({', '.join(['%s={%s}' % (p, p) for p in args])})"


+ 1
- 1
mindinsight/mindconverter/graph_based_converter/mapper/impl/ops/resize_mapper.py View File

@@ -60,7 +60,7 @@ class ResizeMapper(ONNXToMindSporeMapper):
align_corners = True

# Get requested size for resize
size = list(weights.values())[-1][-2:].tolist()
size = ResizeMapper._find_val_by_index(-1, weights)[-2:].tolist()

return {"size": tuple(size),
"align_corners": align_corners}


+ 1
- 1
mindinsight/mindconverter/graph_based_converter/mapper/impl/ops/slice_mapper.py View File

@@ -48,7 +48,7 @@ class SliceMapper(ONNXToMindSporeMapper):
def _generate_snippet_template(**kwargs):
template, exchange_msg, outputs_list, outputs_mapping = ONNXToMindSporeMapper._generate_snippet_template(
**kwargs)
weights = list(kwargs.get("weights").values()) # start, end, axis
weights = [weight.value for weight in kwargs.get('weights')] # start, end, axis
opt_shape = kwargs["raw_params"]["output_shape"]
if not weights:
raise ValueError("Cannot get required params from slice.")


+ 2
- 2
mindinsight/mindconverter/graph_based_converter/sub_graph_searcher/__init__.py View File

@@ -1,4 +1,4 @@
# Copyright 2020 Huawei Technologies Co., Ltd.All Rights Reserved.
# Copyright 2020-2021 Huawei Technologies Co., Ltd.All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -15,4 +15,4 @@
"""Searcher of scope name."""
__all__ = ["generate_scope_name"]

from .searcher import generate_scope_name
from mindinsight.mindconverter.graph_based_converter.sub_graph_searcher.searcher import generate_scope_name

+ 3
- 3
mindinsight/mindconverter/graph_based_converter/sub_graph_searcher/built_in_pattern.py View File

@@ -1,4 +1,4 @@
# Copyright 2020 Huawei Technologies Co., Ltd.All Rights Reserved.
# Copyright 2020-2021 Huawei Technologies Co., Ltd.All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -16,8 +16,8 @@

__all__ = ["BUILT_IN_PATTERN", "register_pattern", "is_built_in_pattern"]

from .common import cal_matching_score
from .pattern import Pattern
from mindinsight.mindconverter.graph_based_converter.sub_graph_searcher.common import cal_matching_score
from mindinsight.mindconverter.graph_based_converter.sub_graph_searcher.pattern import Pattern

BUILT_IN_PATTERN = dict()



+ 10
- 7
mindinsight/mindconverter/graph_based_converter/sub_graph_searcher/search_path.py View File

@@ -1,4 +1,4 @@
# Copyright 2020 Huawei Technologies Co., Ltd.All Rights Reserved.
# Copyright 2020-2021 Huawei Technologies Co., Ltd.All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -17,12 +17,15 @@ import copy
import uuid
from typing import Dict, List, Callable, Union
from collections import OrderedDict
from .common import context, gen_hash_key, DagGraph, MAX_OUT_DEGREE, cal_matching_score
from .known_module_name import BUILT_IN_MODULE_NAME
from .pattern import Pattern, scope_name_mapping
from .built_in_pattern import BUILT_IN_PATTERN, is_built_in_pattern
from .pattern_fuzzy_matching import pattern_fuzzy_matching
from ..third_party_graph.onnx_utils import OnnxNode, BaseNode
from mindinsight.mindconverter.graph_based_converter.sub_graph_searcher.common import context, gen_hash_key, DagGraph, \
MAX_OUT_DEGREE, cal_matching_score
from mindinsight.mindconverter.graph_based_converter.sub_graph_searcher.known_module_name import BUILT_IN_MODULE_NAME
from mindinsight.mindconverter.graph_based_converter.sub_graph_searcher.pattern import Pattern, scope_name_mapping
from mindinsight.mindconverter.graph_based_converter.sub_graph_searcher.built_in_pattern import BUILT_IN_PATTERN, \
is_built_in_pattern
from mindinsight.mindconverter.graph_based_converter.sub_graph_searcher.pattern_fuzzy_matching import \
pattern_fuzzy_matching
from mindinsight.mindconverter.graph_based_converter.third_party_graph.onnx_utils import OnnxNode, BaseNode

module_name_to_src = {}
used_module_name = dict()


+ 10
- 7
mindinsight/mindconverter/graph_based_converter/sub_graph_searcher/searcher.py View File

@@ -1,4 +1,4 @@
# Copyright 2020 Huawei Technologies Co., Ltd.All Rights Reserved.
# Copyright 2020-2021 Huawei Technologies Co., Ltd.All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -16,12 +16,15 @@
from queue import PriorityQueue
from typing import Dict, List

from .common import context, DagGraph, gen_hash_key, ACCEPTABLE_RESULT_COUNT
from .common import MINI_FREQUENCY, MAX_ITERATION_DEPTH, SATISFIED_SCORE
from ..common.global_context import GlobalContext
from ..third_party_graph.onnx_utils import BaseNode
from .search_path import SearchPath, Pattern, generate_pattern, find_built_in_pattern
from ...common.exceptions import SubGraphSearchingError
from mindinsight.mindconverter.graph_based_converter.sub_graph_searcher.common import context, DagGraph, gen_hash_key, \
ACCEPTABLE_RESULT_COUNT
from mindinsight.mindconverter.graph_based_converter.sub_graph_searcher.common import MINI_FREQUENCY, \
MAX_ITERATION_DEPTH, SATISFIED_SCORE
from mindinsight.mindconverter.graph_based_converter.common.global_context import GlobalContext
from mindinsight.mindconverter.graph_based_converter.third_party_graph.onnx_utils import BaseNode
from mindinsight.mindconverter.graph_based_converter.sub_graph_searcher.search_path import SearchPath, Pattern, \
generate_pattern, find_built_in_pattern
from mindinsight.mindconverter.common.exceptions import SubGraphSearchingError


def _is_satisfied(path):


+ 1
- 1
mindinsight/mindconverter/graph_based_converter/third_party_graph/__init__.py View File

@@ -17,7 +17,7 @@
__all__ = ["GraphFactory"]
from importlib import import_module

from .base import Graph
from mindinsight.mindconverter.graph_based_converter.third_party_graph.base import Graph


class GraphFactory:


+ 3
- 4
mindinsight/mindconverter/graph_based_converter/third_party_graph/input_node.py View File

@@ -1,4 +1,4 @@
# Copyright 2020 Huawei Technologies Co., Ltd.All Rights Reserved.
# Copyright 2020-2021 Huawei Technologies Co., Ltd.All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -15,8 +15,8 @@
"""Define PyTorch graph node."""
import os

from .base import GraphNode
from ..constant import SEPARATOR_IN_SCOPE, NodeType
from mindinsight.mindconverter.graph_based_converter.third_party_graph.base import GraphNode
from mindinsight.mindconverter.graph_based_converter.constant import SEPARATOR_IN_SCOPE, NodeType


class InputNode(GraphNode):
@@ -25,7 +25,6 @@ class InputNode(GraphNode):

Args:
input_shape: Input shape of module.

"""

def _get_arg_name(self, arg, variable_name):


+ 9
- 9
mindinsight/mindconverter/graph_based_converter/third_party_graph/onnx_graph.py View File

@@ -17,12 +17,12 @@ from importlib import import_module
from typing import Dict, NoReturn

from mindinsight.mindconverter.common.log import logger as log
from .base import Graph
from .input_node import InputNode
from .onnx_graph_node import OnnxGraphNode
from .pytorch_graph_parser import PyTorchGraphParser
from .tf_graph_parser import TFGraphParser
from .onnx_utils import OnnxDataLoader
from mindinsight.mindconverter.graph_based_converter.third_party_graph.base import Graph
from mindinsight.mindconverter.graph_based_converter.third_party_graph.input_node import InputNode
from mindinsight.mindconverter.graph_based_converter.third_party_graph.onnx_graph_node import OnnxGraphNode
from mindinsight.mindconverter.graph_based_converter.third_party_graph.pytorch_graph_parser import PyTorchGraphParser
from mindinsight.mindconverter.graph_based_converter.third_party_graph.tf_graph_parser import TFGraphParser
from mindinsight.mindconverter.graph_based_converter.third_party_graph.onnx_utils import OnnxDataLoader, NodeWeight

NONE_SCOPE_OP = {
"onnx::Add": "Add",
@@ -126,7 +126,7 @@ class OnnxGraph(Graph):

self._shape_dict = model_data.node_output_shape_dict
for ind, (node_name, node) in enumerate(model_data.nodes_dict.items()):
node_weight = {}
node_weights = list()
node.scope_name = scope_name_list[ind]
inputs = node.input_name_list
# check each input from node or tensors
@@ -135,8 +135,8 @@ class OnnxGraph(Graph):
tensor = model_data.tensors_dict[i]
t_name = tensor.name
t_value = tensor.to_array()
node_weight[t_name] = t_value
self._nodes_collection[node_name] = OnnxGraphNode(node, node_weight)
node_weights.append(NodeWeight(t_name, t_value))
self._nodes_collection[node_name] = OnnxGraphNode(node, node_weights)
self._nodes_record[node_name] = node_name

for nd_ipt_name in node.precursor_onnx_node_dict:


+ 6
- 7
mindinsight/mindconverter/graph_based_converter/third_party_graph/onnx_graph_node.py View File

@@ -1,4 +1,4 @@
# Copyright 2020 Huawei Technologies Co., Ltd.All Rights Reserved.
# Copyright 2020-2021 Huawei Technologies Co., Ltd.All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -15,11 +15,11 @@
"""Define ONNX graph node."""
from importlib import import_module

from .base import GraphNode
from ..common.utils import is_converted
from mindinsight.mindconverter.graph_based_converter.third_party_graph.base import GraphNode
from mindinsight.mindconverter.graph_based_converter.common.utils import is_converted

from ..constant import NodeType, SEPARATOR_IN_SCOPE, SEPARATOR_BTW_NAME_AND_ID, LEFT_BUCKET, RIGHT_BUCKET, \
SEPARATOR_IN_ONNX_OP
from mindinsight.mindconverter.graph_based_converter.constant import NodeType, SEPARATOR_IN_SCOPE, \
SEPARATOR_BTW_NAME_AND_ID, LEFT_BUCKET, RIGHT_BUCKET, SEPARATOR_IN_ONNX_OP


class OnnxGraphNode(GraphNode):
@@ -28,7 +28,7 @@ class OnnxGraphNode(GraphNode):

Args:
node (OnnxNode): OnnxNode Object.
weight (dict): Dictionary records weight and bias.
weight (list): List of recording node weights.
"""
_type_frozen = False
_module_name_frozen = False
@@ -227,7 +227,6 @@ class OnnxGraphNode(GraphNode):
Args:
src_arg (str): Original arg name.
tgt_arg (str): Target arg name.

"""
self._args_in_code[src_arg] = tgt_arg



+ 20
- 7
mindinsight/mindconverter/graph_based_converter/third_party_graph/onnx_utils.py View File

@@ -22,13 +22,13 @@ from typing import Union
import numpy as np

from mindinsight.mindconverter.common.log import logger as log
from ..common.utils import fetch_output_from_onnx_model
from ..common.global_context import GlobalContext
from .optimizer import OnnxSimplify
from mindinsight.mindconverter.graph_based_converter.common.utils import fetch_output_from_onnx_model
from mindinsight.mindconverter.graph_based_converter.common.global_context import GlobalContext
from mindinsight.mindconverter.graph_based_converter.third_party_graph.optimizer import OnnxSimplify

from ..constant import ONNX_TYPE_INT, ONNX_TYPE_INTS, ONNX_TYPE_STRING, \
from mindinsight.mindconverter.graph_based_converter.constant import ONNX_TYPE_INT, ONNX_TYPE_INTS, ONNX_TYPE_STRING, \
ONNX_TYPE_FLOATS, ONNX_TYPE_FLOAT, SCALAR_WITHOUT_SHAPE, DYNAMIC_SHAPE, UNKNOWN_DIM_VAL
from ...common.exceptions import GraphInitError, ModelLoadingError
from mindinsight.mindconverter.common.exceptions import GraphInitError, ModelLoadingError


def convert_tf_graph_to_onnx(model_path, model_inputs, model_outputs, opset=12):
@@ -128,7 +128,6 @@ class ParamsAttribute:
raw_attribute (onnx.AttributeProto): onnx.AttributeProto instance.
node (onnx.NodeProto): Must pass the onnx.NodeProto instance
containing the same AttributeProto.

"""

def __init__(self, raw_attribute, node):
@@ -148,7 +147,6 @@ class ParamsAttribute:

Args:
attrs (onnx.AttributeProto): onnx.AttributeProto instance.

"""
if not attrs:
return
@@ -604,3 +602,18 @@ class OnnxDataLoader:
eliminated_nodes = _traceback_precursor_nodes_until_shape_op(to_shape)
self.dynamic_resize_node.append(nd_name)
self.eliminated_nodes += eliminated_nodes


class NodeWeight:
"""Node weight struct."""
def __init__(self, weight_name, weight_value):
self._weight_name = weight_name
self._weight_value = weight_value

@property
def name(self):
return self._weight_name

@property
def value(self):
return self._weight_value

+ 1
- 1
mindinsight/mindconverter/graph_based_converter/third_party_graph/optimizer.py View File

@@ -18,7 +18,7 @@ from importlib import import_module

import numpy as np

from ..common.utils import fetch_output_from_onnx_model
from mindinsight.mindconverter.graph_based_converter.common.utils import fetch_output_from_onnx_model


class OnnxSimplify:


+ 1
- 0
mindinsight/mindconverter/graph_based_converter/third_party_graph/pytorch_graph_parser.py View File

@@ -20,6 +20,7 @@ from mindinsight.mindconverter.common.log import logger as log
from mindinsight.mindconverter.graph_based_converter.third_party_graph.base import GraphParser
from mindinsight.mindconverter.common.exceptions import ModelNotSupportError


class PyTorchGraphParser(GraphParser):
"""Define pytorch graph parser."""



+ 3
- 3
mindinsight/mindconverter/graph_based_converter/third_party_graph/tf_graph_parser.py View File

@@ -1,4 +1,4 @@
# Copyright 2020 Huawei Technologies Co., Ltd.All Rights Reserved.
# Copyright 2020-2021 Huawei Technologies Co., Ltd.All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -18,8 +18,8 @@ import re
from importlib import import_module

from mindinsight.mindconverter.common.log import logger as log
from .base import GraphParser
from ...common.exceptions import ModelNotSupportError
from mindinsight.mindconverter.graph_based_converter.third_party_graph.base import GraphParser
from mindinsight.mindconverter.common.exceptions import ModelNotSupportError


class TFGraphParser(GraphParser):


+ 6
- 2
tests/ut/mindconverter/graph_based_converter/mapper/test_mapper.py View File

@@ -89,7 +89,9 @@ class TestMappers:
'input': {'op_name': 'onnx::MaxPool',
'params': {'kernel_shape': [3, 3],
'pads': [1, 1, 1, 1],
'strides': [2, 2]},
'strides': [2, 2],
'input_shape': (1, 3, 224, 224),
'output_shape': (1, 3, 112, 112)},
'weights': dict()},
'expected_output': {'converter_name': 'nn.MaxPool2d',
'converted_params': {'kernel_size': (3, 3),
@@ -100,7 +102,9 @@ class TestMappers:
'input': {'op_name': 'onnx::AveragePool',
'params': {'kernel_shape': [3, 3],
'pads': [1, 1, 1, 1],
'strides': [2, 2]},
'strides': [2, 2],
'input_shape': (1, 3, 224, 224),
'output_shape': (1, 3, 112, 112)},
'weights': dict()},
'expected_output': {'converter_name': 'nn.AvgPool2d',
'converted_params': {'kernel_size': (3, 3),


+ 26
- 0
tests/utils/mindspore/train/serialization.py View File

@@ -0,0 +1,26 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Mock the MindSpore mindspore/train/serialization.py."""


def save_checkpoint(trainable_weights, ckpt_file_name):
"""
Mock save_checkpoint.

Args:
trainable_weights (list): List of weights.
ckpt_file_name (str): Path to save checkpoint file.
"""
return len(trainable_weights), ckpt_file_name

Loading…
Cancel
Save