| @@ -1,4 +1,4 @@ | |||
| # Copyright 2019 Huawei Technologies Co., Ltd | |||
| # Copyright 2020-2021 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| @@ -301,6 +301,8 @@ class SourceFilesSaveError(MindConverterException): | |||
| NODE_INPUT_TYPE_NOT_SUPPORT = 1 | |||
| SCRIPT_GENERATE_FAIL = 2 | |||
| REPORT_GENERATE_FAIL = 3 | |||
| CKPT_GENERATE_FAIL = 4 | |||
| MAP_GENERATE_FAIL = 5 | |||
| BASE_ERROR_CODE = ConverterErrors.SOURCE_FILES_SAVE_FAIL.value | |||
| ERROR_CODE = ErrCode.UNKNOWN_ERROR.value | |||
| @@ -315,6 +317,8 @@ class SourceFilesSaveError(MindConverterException): | |||
| except_source = (NodeInputTypeNotSupportError, | |||
| ScriptGenerationError, | |||
| ReportGenerationError, | |||
| CheckPointGenerationError, | |||
| WeightMapGenerationError, | |||
| IOError, cls) | |||
| return except_source | |||
| @@ -437,6 +441,32 @@ class ReportGenerationError(SourceFilesSaveError): | |||
| return ZeroDivisionError, cls | |||
| class CheckPointGenerationError(SourceFilesSaveError): | |||
| """The checkpoint generate fail error.""" | |||
| ERROR_CODE = SourceFilesSaveError.ErrCode.CKPT_GENERATE_FAIL.value | |||
| def __init__(self, msg): | |||
| super(CheckPointGenerationError, self).__init__(msg=msg) | |||
| @classmethod | |||
| def raise_from(cls): | |||
| """Raise from exceptions below.""" | |||
| return cls | |||
| class WeightMapGenerationError(SourceFilesSaveError): | |||
| """The weight names map generate fail error.""" | |||
| ERROR_CODE = SourceFilesSaveError.ErrCode.MAP_GENERATE_FAIL.value | |||
| def __init__(self, msg): | |||
| super(WeightMapGenerationError, self).__init__(msg=msg) | |||
| @classmethod | |||
| def raise_from(cls): | |||
| """Raise from exception below.""" | |||
| return cls | |||
| class SubGraphSearchingError(MindConverterException): | |||
| """Sub-graph searching exception.""" | |||
| @@ -1,4 +1,4 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd.All Rights Reserved. | |||
| # Copyright 2020-2021 Huawei Technologies Co., Ltd.All Rights Reserved. | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| @@ -15,5 +15,5 @@ | |||
| """Graph based scripts converter definition.""" | |||
| __all__ = ["graph_based_converter_pytorch_to_ms", "graph_based_converter_tf_to_ms"] | |||
| from .framework import graph_based_converter_pytorch_to_ms | |||
| from .framework import graph_based_converter_tf_to_ms | |||
| from mindinsight.mindconverter.graph_based_converter.framework import graph_based_converter_pytorch_to_ms | |||
| from mindinsight.mindconverter.graph_based_converter.framework import graph_based_converter_tf_to_ms | |||
| @@ -191,18 +191,23 @@ class CodeFragment(Fragment): | |||
| """ | |||
| def __init__(self, operation, actual_args, settings, input_shape, output_shape, | |||
| trainable_params=None): | |||
| trainable_params=None, trainable_weights=None): | |||
| super(CodeFragment, self).__init__(operation=operation, actual_args=actual_args, | |||
| input_shape=input_shape, output_shape=output_shape, | |||
| settings=settings) | |||
| self._trainable_params = dict() # External weights, like Matmul. | |||
| self._init_trainable_params = trainable_params # Can put into operation init method, like Conv2d. | |||
| self._trainable_weights = trainable_weights | |||
| @property | |||
| def trainable_params(self): | |||
| """Return the trainable parameters.""" | |||
| return self._trainable_params | |||
| @property | |||
| def trainable_weights(self): | |||
| return self._trainable_weights | |||
| class ModuleFragment(Fragment): | |||
| """Manage module type code variables.""" | |||
| @@ -14,7 +14,7 @@ | |||
| # ============================================================================== | |||
| """Define GlobalContext class to save required resources during whole conversion procedure.""" | |||
| from collections import OrderedDict | |||
| from .outputs import OutputStorage | |||
| from mindinsight.mindconverter.graph_based_converter.common.outputs import OutputStorage | |||
| class Singleton(type): | |||
| @@ -13,16 +13,21 @@ | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| """Define common utils.""" | |||
| import json | |||
| import os | |||
| import stat | |||
| from importlib import import_module | |||
| from importlib.util import find_spec | |||
| from typing import List, Tuple, Mapping | |||
| from mindinsight.mindconverter.common.exceptions import ScriptGenerationError, ReportGenerationError, UnknownModelError | |||
| from mindinsight.mindconverter.common.exceptions import ScriptGenerationError, ReportGenerationError, \ | |||
| UnknownModelError, CheckPointGenerationError, WeightMapGenerationError | |||
| from mindinsight.mindconverter.common.log import logger as log | |||
| from mindinsight.mindconverter.graph_based_converter.constant import SEPARATOR_IN_ONNX_OP, BINARY_HEADER_PYTORCH_BITS, \ | |||
| FrameworkType, BINARY_HEADER_PYTORCH_FILE, TENSORFLOW_MODEL_SUFFIX | |||
| from mindspore.train.serialization import save_checkpoint | |||
| def is_converted(operation: str): | |||
| """ | |||
| @@ -96,7 +101,6 @@ def save_code_file_and_report(model_name: str, code_lines: Mapping[str, Tuple], | |||
| code_lines (dict): Code lines. | |||
| out_folder (str): Output folder. | |||
| report_folder (str): Report output folder. | |||
| """ | |||
| flags = os.O_WRONLY | os.O_CREAT | os.O_EXCL | |||
| modes = stat.S_IRUSR | stat.S_IWUSR | |||
| @@ -114,7 +118,7 @@ def save_code_file_and_report(model_name: str, code_lines: Mapping[str, Tuple], | |||
| os.makedirs(report_folder, modes_usr) | |||
| for file_name in code_lines: | |||
| code, report = code_lines[file_name] | |||
| code, report, trainable_weights, weight_map = code_lines[file_name] | |||
| code_file_path = os.path.realpath(os.path.join(out_folder, f"{model_name}.py")) | |||
| report_file_path = os.path.realpath(os.path.join(report_folder, f"report_of_{model_name}.txt")) | |||
| try: | |||
| @@ -133,6 +137,31 @@ def save_code_file_and_report(model_name: str, code_lines: Mapping[str, Tuple], | |||
| except (IOError, FileExistsError) as error: | |||
| raise ReportGenerationError(str(error)) | |||
| ckpt_file_path = os.path.realpath(os.path.join(out_folder, f"{model_name}.ckpt")) | |||
| try: | |||
| if os.path.exists(ckpt_file_path): | |||
| raise CheckPointGenerationError("Checkpoint file with the same name already exists.") | |||
| save_checkpoint(trainable_weights, ckpt_file_path) | |||
| except TypeError as error: | |||
| raise CheckPointGenerationError(str(error)) | |||
| weight_map_path = os.path.realpath(os.path.join(out_folder, f"weight_map_of_{model_name}.json")) | |||
| try: | |||
| if os.path.exists(weight_map_path): | |||
| raise WeightMapGenerationError("Weight map file with the same name already exists.") | |||
| with os.fdopen(os.open(weight_map_path, flags, stat.S_IRUSR), 'w') as map_f: | |||
| weight_map_json = {f"{model_name}": weight_map} | |||
| json.dump(weight_map_json, map_f) | |||
| except (IOError, FileExistsError) as error: | |||
| raise WeightMapGenerationError(str(error)) | |||
| def onnx_satisfied(): | |||
| """Validate ONNX , ONNXRUNTIME, ONNXOPTIMIZER installation.""" | |||
| if not find_spec("onnx") or not find_spec("onnxruntime") or not find_spec("onnxoptimizer"): | |||
| return False | |||
| return True | |||
| def lib_version_satisfied(current_ver: str, mini_ver_limited: str, | |||
| newest_ver_limited: str = ""): | |||
| @@ -220,6 +249,7 @@ def reset_init_or_construct(template, variable_slot, new_data, scope): | |||
| template[variable_slot][scope] += new_data | |||
| return template | |||
| def replace_string_in_list(str_list: list, original_str: str, target_str: str): | |||
| """ | |||
| Replace a string in a list by provided string. | |||
| @@ -41,6 +41,7 @@ UNKNOWN_DIM_VAL = "unk__001" | |||
| ONNX_MIN_VER = "1.8.0" | |||
| TF2ONNX_MIN_VER = "1.7.1" | |||
| ONNXRUNTIME_MIN_VER = "1.5.2" | |||
| ONNXOPTIMIZER_MIN_VER = "0.1.2" | |||
| @unique | |||
| @@ -21,10 +21,10 @@ from importlib.util import find_spec | |||
| import mindinsight | |||
| from mindinsight.mindconverter.graph_based_converter.common.global_context import GlobalContext | |||
| from mindinsight.mindconverter.graph_based_converter.common.utils import lib_version_satisfied, \ | |||
| from mindinsight.mindconverter.graph_based_converter.common.utils import lib_version_satisfied, onnx_satisfied, \ | |||
| save_code_file_and_report, get_framework_type | |||
| from mindinsight.mindconverter.graph_based_converter.constant import FrameworkType, \ | |||
| ONNX_MIN_VER, TF2ONNX_MIN_VER, ONNXRUNTIME_MIN_VER | |||
| ONNX_MIN_VER, TF2ONNX_MIN_VER, ONNXRUNTIME_MIN_VER, ONNXOPTIMIZER_MIN_VER | |||
| from mindinsight.mindconverter.graph_based_converter.generator import batch_add_nodes | |||
| from mindinsight.mindconverter.graph_based_converter.mapper import ONNXToMindSporeMapper | |||
| from mindinsight.mindconverter.common.log import logger as log, logger_console as log_console | |||
| @@ -53,6 +53,18 @@ parser.add_argument("--report", type=str, required=False, | |||
| help="Generated reports output folder path.") | |||
| def onnx_lib_version_satisfied(): | |||
| """Check onnx libs version whether is satisfied.""" | |||
| onnx = import_module("onnx") | |||
| ort = import_module("onnxruntime") | |||
| optimizer = import_module("onnxoptimizer.version") | |||
| if not lib_version_satisfied(getattr(onnx, "__version__"), ONNX_MIN_VER) \ | |||
| or not lib_version_satisfied(getattr(ort, "__version__"), ONNXRUNTIME_MIN_VER) \ | |||
| or not lib_version_satisfied(getattr(optimizer, "version"), ONNXOPTIMIZER_MIN_VER): | |||
| return False | |||
| return True | |||
| def torch_installation_validation(func): | |||
| """ | |||
| Validate args of func. | |||
| @@ -68,26 +80,33 @@ def torch_installation_validation(func): | |||
| input_nodes: str, output_nodes: str, | |||
| output_folder: str, report_folder: str = None): | |||
| # Check whether pytorch is installed. | |||
| if not find_spec("torch") or not find_spec("onnx") or not find_spec("onnxruntime"): | |||
| error = RuntimeIntegrityError(f"PyTorch, onnx(>={ONNX_MIN_VER}) and " | |||
| f"onnxruntime(>={ONNXRUNTIME_MIN_VER}) " | |||
| f"are required when using graph based " | |||
| f"scripts converter, and PyTorch version must " | |||
| f"be consisted with model generation runtime.") | |||
| error_info = None | |||
| if graph_path.endswith('.onnx'): | |||
| if not onnx_satisfied(): | |||
| error_info = f"onnx(>={ONNX_MIN_VER}, onnxruntime(>={ONNXRUNTIME_MIN_VER}) and " \ | |||
| f"onnxoptimizer(>={ONNXOPTIMIZER_MIN_VER}) " \ | |||
| f"are required when using graph based scripts converter." | |||
| else: | |||
| if not find_spec("torch") or not onnx_satisfied(): | |||
| error_info = f"PyTorch, onnx(>={ONNX_MIN_VER}), " \ | |||
| f"onnxruntime(>={ONNXRUNTIME_MIN_VER}) and " \ | |||
| f"onnxoptimizer(>={ONNXOPTIMIZER_MIN_VER}) " \ | |||
| f"are required when using graph based " \ | |||
| f"scripts converter, and PyTorch version must " \ | |||
| f"be consisted with model generation runtime." | |||
| if error_info: | |||
| error = RuntimeIntegrityError(error_info) | |||
| log.error(error) | |||
| log_console.error("\n") | |||
| log_console.error(str(error)) | |||
| log_console.error("\n") | |||
| sys.exit(0) | |||
| onnx = import_module("onnx") | |||
| ort = import_module("onnxruntime") | |||
| if not lib_version_satisfied(getattr(onnx, "__version__"), ONNX_MIN_VER) \ | |||
| or not lib_version_satisfied(getattr(ort, "__version__"), ONNXRUNTIME_MIN_VER): | |||
| if not onnx_lib_version_satisfied(): | |||
| error = RuntimeIntegrityError( | |||
| f"onnx(>={ONNX_MIN_VER}) and " | |||
| f"onnxruntime(>={ONNXRUNTIME_MIN_VER}) are required when using graph " | |||
| f"onnx(>={ONNX_MIN_VER}), " | |||
| f"onnxruntime(>={ONNXRUNTIME_MIN_VER}) and " | |||
| f"onnxoptimizer(>={ONNXOPTIMIZER_MIN_VER}) are required when using graph " | |||
| f"based scripts converter for Pytorch conversion." | |||
| ) | |||
| log.error(error) | |||
| @@ -128,11 +147,11 @@ def tf_installation_validation(func): | |||
| output_folder: str, report_folder: str = None, | |||
| input_nodes: str = None, output_nodes: str = None): | |||
| # Check whether tensorflow is installed. | |||
| if not _check_tf_installation() or not find_spec("tf2onnx") \ | |||
| or not find_spec("onnx") or not find_spec("onnxruntime"): | |||
| if not _check_tf_installation() or not onnx_satisfied(): | |||
| error = RuntimeIntegrityError( | |||
| f"TensorFlow, tf2onnx(>={TF2ONNX_MIN_VER}), onnx(>={ONNX_MIN_VER}) and " | |||
| f"onnxruntime(>={ONNXRUNTIME_MIN_VER}) are required when using graph " | |||
| f"TensorFlow, tf2onnx(>={TF2ONNX_MIN_VER}), onnx(>={ONNX_MIN_VER}), " | |||
| f"onnxruntime(>={ONNXRUNTIME_MIN_VER}) and onnxoptimizer(>={ONNXOPTIMIZER_MIN_VER}) " | |||
| f"are required when using graph " | |||
| f"based scripts converter for TensorFlow conversion." | |||
| ) | |||
| log.error(error) | |||
| @@ -141,15 +160,14 @@ def tf_installation_validation(func): | |||
| log_console.error("\n") | |||
| sys.exit(0) | |||
| onnx, tf2onnx = import_module("onnx"), import_module("tf2onnx") | |||
| ort = import_module("onnxruntime") | |||
| tf2onnx = import_module("tf2onnx") | |||
| if not lib_version_satisfied(getattr(onnx, "__version__"), ONNX_MIN_VER) \ | |||
| or not lib_version_satisfied(getattr(ort, "__version__"), ONNXRUNTIME_MIN_VER) \ | |||
| or not lib_version_satisfied(getattr(tf2onnx, "__version__"), TF2ONNX_MIN_VER): | |||
| if not lib_version_satisfied(getattr(tf2onnx, "__version__"), TF2ONNX_MIN_VER) \ | |||
| or not onnx_lib_version_satisfied(): | |||
| error = RuntimeIntegrityError( | |||
| f"TensorFlow, tf2onnx(>={TF2ONNX_MIN_VER}), onnx(>={ONNX_MIN_VER}) and " | |||
| f"onnxruntime(>={ONNXRUNTIME_MIN_VER}) are required when using graph " | |||
| f"TensorFlow, tf2onnx(>={TF2ONNX_MIN_VER}), onnx(>={ONNX_MIN_VER}), " | |||
| f"onnxruntime(>={ONNXRUNTIME_MIN_VER}) and onnxoptimizer(>={ONNXOPTIMIZER_MIN_VER}) " | |||
| f"are required when using graph " | |||
| f"based scripts converter for TensorFlow conversion." | |||
| ) | |||
| log.error(error) | |||
| @@ -258,12 +276,12 @@ def main_graph_base_converter(file_config): | |||
| raise ParamMissingError("Param missing, `--shape` is required when using graph mode.") | |||
| if frame_type == FrameworkType.PYTORCH.value: | |||
| check_params = ['input_nodes', 'output_nodes'] | |||
| check_params_exist(check_params, file_config) | |||
| graph_based_converter_pytorch_to_ms(graph_path=graph_path, | |||
| sample_shape=file_config['shape'], | |||
| input_nodes=file_config['input_nodes'], | |||
| output_nodes=file_config['output_nodes'], | |||
| input_nodes=file_config['input_nodes'] if file_config['input_nodes'] | |||
| else 'input.1', | |||
| output_nodes=file_config['output_nodes'] if file_config['output_nodes'] | |||
| else '', | |||
| output_folder=file_config['outfile_dir'], | |||
| report_folder=file_config['report_dir']) | |||
| elif frame_type == FrameworkType.TENSORFLOW.value: | |||
| @@ -18,10 +18,10 @@ __all__ = ["batch_add_nodes"] | |||
| import re | |||
| import copy | |||
| from .generator import Generator, CodeStruct | |||
| from ..common.code_fragment import CodeFragment, NewFragment | |||
| from ..common.outputs import NodeOutputManager | |||
| from ..constant import ExchangeMessageKeywords | |||
| from mindinsight.mindconverter.graph_based_converter.generator.generator import Generator, CodeStruct | |||
| from mindinsight.mindconverter.graph_based_converter.common.code_fragment import NewFragment | |||
| from mindinsight.mindconverter.graph_based_converter.common.outputs import NodeOutputManager | |||
| from mindinsight.mindconverter.graph_based_converter.constant import ExchangeMessageKeywords | |||
| def _tf_model_node_name_reformat(node, node_name): | |||
| @@ -16,6 +16,7 @@ | |||
| import copy | |||
| from collections import OrderedDict | |||
| from mindspore import Tensor | |||
| from yapf.yapflib.yapf_api import FormatCode | |||
| from mindinsight.mindconverter.common.exceptions import GeneratorError | |||
| @@ -28,7 +29,7 @@ from mindinsight.mindconverter.graph_based_converter.common.outputs import BaseO | |||
| from mindinsight.mindconverter.graph_based_converter.common.yapf_config import mindspore_yapf_config | |||
| from mindinsight.mindconverter.graph_based_converter.common.name_mgr import GlobalVarNameMgr | |||
| from mindinsight.mindconverter.graph_based_converter.constant import NEW_LINE, SECOND_LEVEL_INDENT, \ | |||
| FIRST_LEVEL_INDENT, get_imported_module | |||
| FIRST_LEVEL_INDENT, get_imported_module, SEPARATOR_BTW_NAME_AND_ID | |||
| from mindinsight.mindconverter.graph_based_converter.report_generator import ReportGenerator | |||
| from mindinsight.mindconverter.graph_based_converter.common.utils import replace_string_in_list | |||
| @@ -469,6 +470,74 @@ class Generator: | |||
| """Return all ModuleStructs in this model.""" | |||
| return self._module_struct_collections | |||
| def generate_weight_scope_name(self, node): | |||
| """Generate weight scope name for checkpoint.""" | |||
| replaced_module_dict = self.node_structs[node].global_context_mgr.known_module_name | |||
| scope_list = self.node_structs[node].scope.scope_list | |||
| ms_var_name = self.node_structs[node].ms_var_name | |||
| weight_scope_name = None | |||
| for scope in scope_list[1:]: | |||
| replaced_module = replaced_module_dict.get(scope.split(SEPARATOR_BTW_NAME_AND_ID)[0]) | |||
| if replaced_module: | |||
| scope = scope.replace(scope.split(SEPARATOR_BTW_NAME_AND_ID)[0], replaced_module) | |||
| if not weight_scope_name: | |||
| weight_scope_name = scope | |||
| else: | |||
| weight_scope_name = '.'.join((weight_scope_name, scope)) | |||
| if not weight_scope_name: | |||
| weight_scope_name = ms_var_name | |||
| else: | |||
| weight_scope_name = '.'.join((weight_scope_name, ms_var_name)) | |||
| return weight_scope_name.lower() | |||
| def generate_checkpoint(self): | |||
| """Generate checkpoint.""" | |||
| trainable_weights_dict = dict() | |||
| weight_map = list() | |||
| for node_name, node_inst in self.node_structs.items(): | |||
| if node_inst.fragment.exchange_msg['var_0']['trainable_params']: | |||
| weights_scope_name = self.generate_weight_scope_name(node_name) | |||
| onnx_weight_inst = node_inst.fragment.exchange_msg['var_0']['weights'] | |||
| for idx, (weight_key, weight_value) in \ | |||
| enumerate(node_inst.fragment.exchange_msg['var_0']['trainable_params'].items()): | |||
| weight_name = '.'.join((weights_scope_name, weight_key)) | |||
| weight_shape = Tensor(weight_value).shape | |||
| data_type = Tensor(weight_value).dtype | |||
| trainable_weights_dict[weight_name] = weight_value | |||
| onnx_weight_name = onnx_weight_inst[idx].name | |||
| onnx_weight_shape = onnx_weight_inst[idx].value.shape | |||
| onnx_data_type = onnx_weight_inst[idx].value.dtype | |||
| weight_map.append( | |||
| { | |||
| 'converted_weight': { | |||
| 'name': weight_name, | |||
| 'shape': weight_shape, | |||
| 'data_type': str(data_type) | |||
| }, | |||
| 'source_weight': { | |||
| 'name': onnx_weight_name, | |||
| 'shape': onnx_weight_shape, | |||
| 'data_type': str(onnx_data_type) | |||
| } | |||
| } | |||
| ) | |||
| save_obj = list() | |||
| for weight_name, weight_value in trainable_weights_dict.items(): | |||
| obj = { | |||
| 'name': weight_name, | |||
| 'data': Tensor(weight_value) | |||
| } | |||
| save_obj.append(obj) | |||
| return save_obj, weight_map | |||
| @GeneratorError.check_except("Generator occurs an error when generating code statements.") | |||
| def generate(self): | |||
| """ | |||
| @@ -479,6 +548,9 @@ class Generator: | |||
| """ | |||
| self._form_bottom_submodule() | |||
| self._recursive_form_module() | |||
| ckpt_data_list, weight_map = self.generate_checkpoint() | |||
| CodeStruct(self.module_structs.get('[]'), self._repeated_submodules) | |||
| outputs = [get_imported_module()] | |||
| @@ -494,7 +566,7 @@ class Generator: | |||
| report = report_generator.gen_report(formatted_code) | |||
| del self._global_context | |||
| return {"model": (formatted_code, report)} | |||
| return {"model": (formatted_code, report, ckpt_data_list, weight_map)} | |||
| def get_node_struct(self, node_identifier): | |||
| """ | |||
| @@ -17,13 +17,13 @@ | |||
| import copy | |||
| from collections import OrderedDict | |||
| from .node_struct import NodeStruct | |||
| from .scope_utils import Scope | |||
| from ..common.utils import get_dict_key_by_value | |||
| from .args_translator import ArgsTranslation | |||
| from ..common.code_fragment import ModuleFragment | |||
| from ..common.global_context import GlobalContext | |||
| from ..common.name_mgr import LocalVarNameMgr | |||
| from mindinsight.mindconverter.graph_based_converter.generator.node_struct import NodeStruct | |||
| from mindinsight.mindconverter.graph_based_converter.generator.scope_utils import Scope | |||
| from mindinsight.mindconverter.graph_based_converter.common.utils import get_dict_key_by_value | |||
| from mindinsight.mindconverter.graph_based_converter.generator.args_translator import ArgsTranslation | |||
| from mindinsight.mindconverter.graph_based_converter.common.code_fragment import ModuleFragment | |||
| from mindinsight.mindconverter.graph_based_converter.common.global_context import GlobalContext | |||
| from mindinsight.mindconverter.graph_based_converter.common.name_mgr import LocalVarNameMgr | |||
| class ModuleStruct: | |||
| @@ -17,11 +17,11 @@ from collections import OrderedDict | |||
| from mindinsight.mindconverter.graph_based_converter.common.code_fragment import NewFragment | |||
| from mindinsight.mindconverter.graph_based_converter.generator.fragment_utils import FragmentHandler | |||
| from .scope_utils import Scope | |||
| from .args_translator import ArgsTranslation | |||
| from ..third_party_graph.onnx_graph_node import OnnxGraphNode | |||
| from ..common.global_context import GlobalContext | |||
| from ...common.exceptions import GeneratorError | |||
| from mindinsight.mindconverter.graph_based_converter.generator.scope_utils import Scope | |||
| from mindinsight.mindconverter.graph_based_converter.generator.args_translator import ArgsTranslation | |||
| from mindinsight.mindconverter.graph_based_converter.third_party_graph.onnx_graph_node import OnnxGraphNode | |||
| from mindinsight.mindconverter.graph_based_converter.common.global_context import GlobalContext | |||
| from mindinsight.mindconverter.common.exceptions import GeneratorError | |||
| class NodeStruct: | |||
| @@ -1,4 +1,4 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd.All Rights Reserved. | |||
| # Copyright 2020-2021 Huawei Technologies Co., Ltd.All Rights Reserved. | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| @@ -16,4 +16,4 @@ | |||
| __all__ = ["ONNXToMindSporeMapper"] | |||
| from .base import ONNXToMindSporeMapper | |||
| from mindinsight.mindconverter.graph_based_converter.mapper.base import ONNXToMindSporeMapper | |||
| @@ -108,18 +108,21 @@ class ONNXToMindSporeMapper(Mapper, abc.ABC): | |||
| try: | |||
| converter_name = op_name_converter(params=params, weights=weights, op_name=op_name) | |||
| converted_params = params_converter(params=params, weights=weights) | |||
| if "input_shape" in converted_params: | |||
| converted_params.pop("input_shape") | |||
| if "output_shape" in converted_params: | |||
| converted_params.pop("output_shape") | |||
| # set to converted_weights to enable weight migration | |||
| _ = weights_converter(weights=weights) if weights else dict() | |||
| converted_weights = weights_converter(weights=weights) if weights else dict() | |||
| code_template, exchange_msg, outputs_list, outputs_mapping = template_generator( | |||
| operation=converter_name, | |||
| converted_params=converted_params, | |||
| raw_params=params, | |||
| weights=weights | |||
| weights=weights, | |||
| trainable_params=converted_weights | |||
| ) | |||
| except (AttributeError, KeyError, ValueError, TypeError, IndexError) as e: | |||
| err_msg = f"Converting {op_name} failed, see {str(e)}" | |||
| log.error(err_msg) | |||
| @@ -148,6 +151,7 @@ class ONNXToMindSporeMapper(Mapper, abc.ABC): | |||
| op = kwargs.get("operation") | |||
| args = kwargs.get("converted_params", dict()) | |||
| weights = kwargs.get("weights") | |||
| trainable_params = kwargs.get("trainable_params", dict()) | |||
| if not op: | |||
| raise ValueError("Can not get MindSpore operation name.") | |||
| variable_slot = "var_0" | |||
| @@ -169,7 +173,7 @@ class ONNXToMindSporeMapper(Mapper, abc.ABC): | |||
| ExchangeMessageKeywords.VariableScope.value.INPUTS.value: [], | |||
| ExchangeMessageKeywords.VariableScope.value.ARGS.value: args, | |||
| ExchangeMessageKeywords.VariableScope.value.WEIGHTS.value: weights, | |||
| ExchangeMessageKeywords.VariableScope.value.TRAINABLE_PARAMS.value: {} | |||
| ExchangeMessageKeywords.VariableScope.value.TRAINABLE_PARAMS.value: trainable_params | |||
| } | |||
| } | |||
| outputs_list = [f"opt_{{{variable_slot}}}"] | |||
| @@ -177,11 +181,14 @@ class ONNXToMindSporeMapper(Mapper, abc.ABC): | |||
| return template, exchange_msg, outputs_list, outputs_mapping | |||
| @staticmethod | |||
| def _find_val_by_index(loc_index, values_dict): | |||
| """Find value by location index of values_dict.""" | |||
| def _find_val_by_index(loc_index, weights_list): | |||
| """Find value by location index of weights_list.""" | |||
| result = None | |||
| for idx, dict_val in enumerate(values_dict.values()): | |||
| if loc_index < 0: | |||
| return weights_list[loc_index].value | |||
| for idx, weight in enumerate(weights_list): | |||
| if idx == loc_index: | |||
| result = dict_val | |||
| result = weight.value | |||
| break | |||
| return result | |||
| @@ -14,7 +14,6 @@ | |||
| # ============================================================================== | |||
| """Mapper module.""" | |||
| from mindinsight.mindconverter.graph_based_converter.mapper.base import ONNXToMindSporeMapper | |||
| from mindinsight.mindconverter.graph_based_converter.mapper.gen_setting import Setting | |||
| class BatchNormMapper(ONNXToMindSporeMapper): | |||
| @@ -36,8 +35,14 @@ class BatchNormMapper(ONNXToMindSporeMapper): | |||
| @staticmethod | |||
| def _convert_trained_weights(**kwargs): | |||
| return dict() | |||
| @staticmethod | |||
| def _convert_settings(**kwargs): | |||
| return Setting() | |||
| weights = kwargs['weights'] | |||
| gamma = BatchNormMapper._find_val_by_index(0, weights) | |||
| beta = BatchNormMapper._find_val_by_index(1, weights) | |||
| moving_mean = BatchNormMapper._find_val_by_index(2, weights) | |||
| moving_variance = BatchNormMapper._find_val_by_index(3, weights) | |||
| return { | |||
| 'gamma': gamma, | |||
| 'beta': beta, | |||
| 'moving_mean': moving_mean, | |||
| 'moving_variance': moving_variance | |||
| } | |||
| @@ -14,7 +14,6 @@ | |||
| # ============================================================================== | |||
| """Mapper module.""" | |||
| import numpy as np | |||
| from mindinsight.mindconverter.graph_based_converter.mapper.gen_setting import Setting | |||
| from mindinsight.mindconverter.graph_based_converter.mapper.base import ONNXToMindSporeMapper | |||
| from mindinsight.mindconverter.graph_based_converter.common.utils import convert_bytes_string_to_string | |||
| @@ -42,7 +41,7 @@ class ConvMapper(ONNXToMindSporeMapper): | |||
| """Convert params from PyTorch to MindSpore""" | |||
| weights = kwargs['weights'] | |||
| params = kwargs['params'] | |||
| weight = weights['weight'] | |||
| weight = ConvMapper._find_val_by_index(0, weights) | |||
| weight = np.transpose(weight, list(range(2, weight.ndim)) + [1, 0]) | |||
| if isinstance(params['dilations'], list): | |||
| dilation = tuple(params['dilations']) | |||
| @@ -76,11 +75,13 @@ class ConvMapper(ONNXToMindSporeMapper): | |||
| """Convert params from Tensorflow to MindSpore""" | |||
| weights = kwargs['weights'] | |||
| params = kwargs['params'] | |||
| # regex to find Conv weight | |||
| weight = list(weights.values())[0] | |||
| weight = ConvMapper._find_val_by_index(0, weights) | |||
| bias = ConvMapper._find_val_by_index(1, weights) | |||
| if weight is None: | |||
| raise ValueError("Conv. Mapper cannot get the weight.") | |||
| has_bias = isinstance(bias, np.ndarray) | |||
| auto_pad = None | |||
| if params.get("auto_pad") is not None: | |||
| auto_pad = convert_bytes_string_to_string(params.get("auto_pad")) | |||
| @@ -119,18 +120,14 @@ class ConvMapper(ONNXToMindSporeMapper): | |||
| 'padding': padding, | |||
| 'pad_mode': pad_mode, | |||
| 'dilation': dilation, | |||
| 'group': params.get('group', 1)} | |||
| 'group': params.get('group', 1), | |||
| 'has_bias': has_bias | |||
| } | |||
| @staticmethod | |||
| def _operation_name_in_ms(*args, **kwargs): | |||
| weight = kwargs['weights'].get('weight', 'empty') | |||
| if weight == 'empty': # is from tf | |||
| kernel_size = kwargs['params'].get('kernel_shape') | |||
| dim = len(kernel_size) | |||
| return f"nn.Conv{dim}d" | |||
| dim = weight.ndim - 2 | |||
| kernel_size = kwargs['params'].get('kernel_shape') | |||
| dim = len(kernel_size) | |||
| return f"nn.Conv{dim}d" | |||
| @staticmethod | |||
| @@ -138,14 +135,16 @@ class ConvMapper(ONNXToMindSporeMapper): | |||
| weights = kwargs['weights'] | |||
| params = kwargs['params'] | |||
| if weights.get('weight', 'empty') == 'empty': # is from tf | |||
| return ConvMapper.convert_params_tf(params=params, weights=weights) | |||
| return ConvMapper.convert_params_torch(params=params, weights=weights) | |||
| return ConvMapper.convert_params_tf(params=params, weights=weights) | |||
| @staticmethod | |||
| def _convert_trained_weights(**kwargs): | |||
| return dict() | |||
| weights = kwargs['weights'] | |||
| weight = ConvMapper._find_val_by_index(0, weights) | |||
| bias = ConvMapper._find_val_by_index(1, weights) | |||
| @staticmethod | |||
| def _convert_settings(**kwargs): | |||
| return Setting() | |||
| converted_weights = {'weight': weight} | |||
| if isinstance(bias, np.ndarray): | |||
| converted_weights['bias'] = bias | |||
| return converted_weights | |||
| @@ -15,7 +15,6 @@ | |||
| """Mapper module.""" | |||
| import numpy as np | |||
| from mindinsight.mindconverter.graph_based_converter.mapper.base import ONNXToMindSporeMapper | |||
| from mindinsight.mindconverter.graph_based_converter.mapper.gen_setting import Setting | |||
| class DenseMapper(ONNXToMindSporeMapper): | |||
| @@ -42,8 +41,10 @@ class DenseMapper(ONNXToMindSporeMapper): | |||
| @staticmethod | |||
| def _convert_trained_weights(**kwargs): | |||
| return dict() | |||
| @staticmethod | |||
| def _convert_settings(**kwargs): | |||
| return Setting() | |||
| weights = kwargs['weights'] | |||
| weight = DenseMapper._find_val_by_index(0, weights) | |||
| bias = DenseMapper._find_val_by_index(1, weights) | |||
| return { | |||
| 'weight': weight, | |||
| 'bias': bias | |||
| } | |||
| @@ -32,7 +32,9 @@ class MatMulMapper(ONNXToMindSporeMapper): | |||
| @staticmethod | |||
| def _convert_trained_weights(**kwargs): | |||
| return dict() | |||
| weights = kwargs['weights'] | |||
| weight = MatMulMapper._find_val_by_index(0, weights) | |||
| return {'weight': weight} | |||
| @staticmethod | |||
| def _convert_settings(**kwargs): | |||
| @@ -56,8 +58,7 @@ class MatMulMapper(ONNXToMindSporeMapper): | |||
| if not weights: | |||
| return template, exchange_msg, outputs_list, outputs_mapping | |||
| weight = list(weights.items())[0] | |||
| _, tensor = weight | |||
| tensor = MatMulMapper._find_val_by_index(0, weights) | |||
| variable_slot = "var_0" | |||
| init_template = f"self.{{{variable_slot}}} = {op}({', '.join(['%s={%s}' % (p, p) for p in args])})" | |||
| @@ -15,7 +15,6 @@ | |||
| """Mapper module.""" | |||
| from mindinsight.mindconverter.graph_based_converter.common.utils import convert_bytes_string_to_string | |||
| from mindinsight.mindconverter.graph_based_converter.mapper.base import ONNXToMindSporeMapper | |||
| from mindinsight.mindconverter.graph_based_converter.mapper.gen_setting import Setting | |||
| def _padding_format_convert(padding: list): | |||
| @@ -49,7 +48,7 @@ class PadMapper(ONNXToMindSporeMapper): | |||
| weights = kwargs.get("weights") | |||
| params = kwargs.get("params") | |||
| mode = convert_bytes_string_to_string(params.get('mode', 'constant')) | |||
| pads_onnx = params.get("pads") if params.get("pads") else list(weights.values())[0].tolist() | |||
| pads_onnx = params.get("pads") if params.get("pads") else PadMapper._find_val_by_index(0, weights).tolist() | |||
| if mode == 'constant' and params.get('value') is None: | |||
| if params.get('pads') or weights: | |||
| if isinstance(pads_onnx, list): | |||
| @@ -76,7 +75,3 @@ class PadMapper(ONNXToMindSporeMapper): | |||
| @staticmethod | |||
| def _convert_trained_weights(**kwargs): | |||
| return dict() | |||
| @staticmethod | |||
| def _convert_settings(**kwargs): | |||
| return Setting() | |||
| @@ -13,12 +13,16 @@ | |||
| # limitations under the License. | |||
| # ============================================================================== | |||
| """Mapper module.""" | |||
| import math | |||
| import numpy as np | |||
| from mindinsight.mindconverter.graph_based_converter.mapper.base import ONNXToMindSporeMapper | |||
| from mindinsight.mindconverter.graph_based_converter.mapper.gen_setting import Setting | |||
| from mindinsight.mindconverter.graph_based_converter.constant import ExchangeMessageKeywords, TemplateKeywords | |||
| class PoolMapper(ONNXToMindSporeMapper): | |||
| """MaxPool mapper.""" | |||
| """Pool mapper.""" | |||
| @staticmethod | |||
| def _operation_name_in_ms(*args, **kwargs): | |||
| @@ -35,12 +39,6 @@ class PoolMapper(ONNXToMindSporeMapper): | |||
| transformed_params = dict() | |||
| transformed_params["kernel_size"] = tuple(params['kernel_shape']) | |||
| transformed_params["stride"] = tuple(params['strides']) | |||
| if "pads" in params: | |||
| if sum(params['pads']) == 0 and not params.get('ceil_mode', None): | |||
| pad_mode = '\"valid\"' | |||
| else: | |||
| pad_mode = '\"same\"' | |||
| transformed_params["pad_mode"] = pad_mode | |||
| return transformed_params | |||
| @@ -49,5 +47,100 @@ class PoolMapper(ONNXToMindSporeMapper): | |||
| return dict() | |||
| @staticmethod | |||
| def _convert_settings(**kwargs): | |||
| return Setting() | |||
| def _get_ms_opt_shape(**kwargs): | |||
| """Get output shape in MindSpore.""" | |||
| params = kwargs['raw_params'] | |||
| input_shape = params['input_shape'] | |||
| kernel_shape = params['kernel_shape'] | |||
| strides = params['strides'] | |||
| dilations = params.get('dilations', (1, 1)) | |||
| # For mindspore, | |||
| # output_shape[i] = ceil((input_shape[i] - ((kernel_shape[i] - 1) * dilations[i] + 1) + 1) / strides[i]) | |||
| ms_opt_shape = np.true_divide(np.subtract(np.array(input_shape[-len(kernel_shape):], dtype=np.float32), | |||
| ((np.array(kernel_shape, dtype=np.float32) - 1) * | |||
| np.array(dilations, dtype=np.float32) + 1)) + 1, | |||
| np.array(strides, dtype=np.float32)).tolist() | |||
| ms_opt_shape_ceil = tuple(math.ceil(ms_opt_shape_axis) for ms_opt_shape_axis in ms_opt_shape) | |||
| return ms_opt_shape_ceil | |||
| @staticmethod | |||
| def _generate_snippet_template(**kwargs): | |||
| template, exchange_msg, outputs_list, outputs_mapping = ONNXToMindSporeMapper._generate_snippet_template( | |||
| **kwargs) | |||
| op = kwargs.get("operation") | |||
| args = kwargs.get("converted_params", dict()) | |||
| ms_opt_shape = PoolMapper._get_ms_opt_shape(**kwargs) | |||
| tensor_opt_shape = kwargs['raw_params']['output_shape'] | |||
| tensor_ipt_shape = kwargs['raw_params']['input_shape'] | |||
| kernel_shape = kwargs['raw_params']['kernel_shape'] | |||
| dilations = kwargs['raw_params'].get('dilations', (1, 1)) | |||
| strides = kwargs['raw_params']['strides'] | |||
| onnx_opt_shape = tensor_opt_shape[-len(ms_opt_shape):] | |||
| if np.all(np.array(ms_opt_shape) == np.array(onnx_opt_shape)): | |||
| return template, exchange_msg, outputs_list, outputs_mapping | |||
| variable_slot = "var_0" | |||
| init_template = f"self.{{{variable_slot}}} = {op}({', '.join(['%s={%s}' % (p, p) for p in args])})" | |||
| construct_template = f"opt_{{{variable_slot}}} = self.{{{variable_slot}}}(opt_{{{variable_slot}}})" | |||
| init_template_pad, construct_template_pad, paddings = \ | |||
| PoolMapper._generate_pad_init_and_construct(tensor_opt_shape, tensor_ipt_shape, | |||
| ms_opt_shape, variable_slot, | |||
| kernel_shape, dilations, strides) | |||
| template = { | |||
| variable_slot: { | |||
| TemplateKeywords.INIT.value: [init_template_pad, init_template], | |||
| TemplateKeywords.CONSTRUCT.value: [construct_template_pad, construct_template] | |||
| } | |||
| } | |||
| args['paddings'] = paddings | |||
| exchange_msg = { | |||
| variable_slot: { | |||
| ExchangeMessageKeywords.VariableScope.value.OPERATION.value: op, | |||
| ExchangeMessageKeywords.VariableScope.value.VARIABLE_NAME.value: None, | |||
| ExchangeMessageKeywords.VariableScope.value.OUTPUT_TYPE.value: | |||
| ExchangeMessageKeywords.VariableScope.value.TSR_TYPE.value, | |||
| ExchangeMessageKeywords.VariableScope.value.INPUTS.value: [], | |||
| ExchangeMessageKeywords.VariableScope.value.ARGS.value: args, | |||
| ExchangeMessageKeywords.VariableScope.value.WEIGHTS.value: dict(), | |||
| ExchangeMessageKeywords.VariableScope.value.TRAINABLE_PARAMS.value: dict() | |||
| } | |||
| } | |||
| return template, exchange_msg, outputs_list, outputs_mapping | |||
| @staticmethod | |||
| def _generate_pad_init_and_construct(tensor_opt_shape, tensor_ipt_shape, | |||
| ms_opt_shape, variable_slot, kernel_shape, dilations, strides): | |||
| """Generate pad code in init and construct.""" | |||
| onnx_opt_shape = tensor_opt_shape[-len(ms_opt_shape):] | |||
| onnx_ipt_shape = tensor_ipt_shape[-len(ms_opt_shape):] | |||
| if np.any(np.array(ms_opt_shape) > np.array(onnx_opt_shape)): | |||
| raise ValueError(f"ms_opt_shape[{ms_opt_shape}] should be no larger than onnx_opt_shape[{onnx_opt_shape}].") | |||
| # shape_diff[i] = (onnx_opt_shape[i] - 1)*strides[i] - | |||
| # (onnx_ipt_shape[i] - ((kernel_shape[i] - 1)*dilations[i] + 1)) | |||
| shape_diff = np.subtract((np.array(onnx_opt_shape) - 1)*np.array(strides), | |||
| np.subtract(np.array(onnx_ipt_shape), | |||
| (np.array(kernel_shape) - 1)*np.array(dilations) + 1)).tolist() | |||
| zero_pad_single = (0, 0) | |||
| paddings = [zero_pad_single] | |||
| num_zero_pads = len(tensor_opt_shape) - len(ms_opt_shape) | |||
| for _ in range(num_zero_pads - 1): | |||
| paddings.append(zero_pad_single) | |||
| for axis_diff in shape_diff: | |||
| paddings.append((int(axis_diff//2), int(axis_diff//2 + axis_diff % 2))) | |||
| init_template_pad = f"self.pad_{{{variable_slot}}} = nn.Pad(paddings={{paddings}})" | |||
| construct_template_pad = f"opt_{{{variable_slot}}} = self.pad_{{{variable_slot}}}" \ | |||
| f"({{{ExchangeMessageKeywords.VariableScope.value.INPUTS.value}}})" | |||
| return init_template_pad, construct_template_pad, tuple(paddings) | |||
| @@ -56,8 +56,7 @@ class AddMapper(ONNXToMindSporeMapper): | |||
| if not weights: | |||
| return template, exchange_msg, outputs_list, outputs_mapping | |||
| bias = list(weights.items())[0] | |||
| _, tensor = bias | |||
| tensor = AddMapper._find_val_by_index(0, weights) | |||
| variable_slot = "var_0" | |||
| init_template = f"self.{{{variable_slot}}} = {op}({', '.join(['%s={%s}' % (p, p) for p in args])})" | |||
| @@ -53,8 +53,7 @@ class MulMapper(ONNXToMindSporeMapper): | |||
| if not weights: | |||
| return template, exchange_msg, outputs_list, outputs_mapping | |||
| weight = list(weights.items())[0] | |||
| _, tensor = weight | |||
| tensor = MulMapper._find_val_by_index(0, weights) | |||
| variable_slot = "var_0" | |||
| init_template = f"self.{{{variable_slot}}} = {op}({', '.join(['%s={%s}' % (p, p) for p in args])})" | |||
| @@ -60,7 +60,7 @@ class ResizeMapper(ONNXToMindSporeMapper): | |||
| align_corners = True | |||
| # Get requested size for resize | |||
| size = list(weights.values())[-1][-2:].tolist() | |||
| size = ResizeMapper._find_val_by_index(-1, weights)[-2:].tolist() | |||
| return {"size": tuple(size), | |||
| "align_corners": align_corners} | |||
| @@ -48,7 +48,7 @@ class SliceMapper(ONNXToMindSporeMapper): | |||
| def _generate_snippet_template(**kwargs): | |||
| template, exchange_msg, outputs_list, outputs_mapping = ONNXToMindSporeMapper._generate_snippet_template( | |||
| **kwargs) | |||
| weights = list(kwargs.get("weights").values()) # start, end, axis | |||
| weights = [weight.value for weight in kwargs.get('weights')] # start, end, axis | |||
| opt_shape = kwargs["raw_params"]["output_shape"] | |||
| if not weights: | |||
| raise ValueError("Cannot get required params from slice.") | |||
| @@ -1,4 +1,4 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd.All Rights Reserved. | |||
| # Copyright 2020-2021 Huawei Technologies Co., Ltd.All Rights Reserved. | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| @@ -15,4 +15,4 @@ | |||
| """Searcher of scope name.""" | |||
| __all__ = ["generate_scope_name"] | |||
| from .searcher import generate_scope_name | |||
| from mindinsight.mindconverter.graph_based_converter.sub_graph_searcher.searcher import generate_scope_name | |||
| @@ -1,4 +1,4 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd.All Rights Reserved. | |||
| # Copyright 2020-2021 Huawei Technologies Co., Ltd.All Rights Reserved. | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| @@ -16,8 +16,8 @@ | |||
| __all__ = ["BUILT_IN_PATTERN", "register_pattern", "is_built_in_pattern"] | |||
| from .common import cal_matching_score | |||
| from .pattern import Pattern | |||
| from mindinsight.mindconverter.graph_based_converter.sub_graph_searcher.common import cal_matching_score | |||
| from mindinsight.mindconverter.graph_based_converter.sub_graph_searcher.pattern import Pattern | |||
| BUILT_IN_PATTERN = dict() | |||
| @@ -1,4 +1,4 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd.All Rights Reserved. | |||
| # Copyright 2020-2021 Huawei Technologies Co., Ltd.All Rights Reserved. | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| @@ -17,12 +17,15 @@ import copy | |||
| import uuid | |||
| from typing import Dict, List, Callable, Union | |||
| from collections import OrderedDict | |||
| from .common import context, gen_hash_key, DagGraph, MAX_OUT_DEGREE, cal_matching_score | |||
| from .known_module_name import BUILT_IN_MODULE_NAME | |||
| from .pattern import Pattern, scope_name_mapping | |||
| from .built_in_pattern import BUILT_IN_PATTERN, is_built_in_pattern | |||
| from .pattern_fuzzy_matching import pattern_fuzzy_matching | |||
| from ..third_party_graph.onnx_utils import OnnxNode, BaseNode | |||
| from mindinsight.mindconverter.graph_based_converter.sub_graph_searcher.common import context, gen_hash_key, DagGraph, \ | |||
| MAX_OUT_DEGREE, cal_matching_score | |||
| from mindinsight.mindconverter.graph_based_converter.sub_graph_searcher.known_module_name import BUILT_IN_MODULE_NAME | |||
| from mindinsight.mindconverter.graph_based_converter.sub_graph_searcher.pattern import Pattern, scope_name_mapping | |||
| from mindinsight.mindconverter.graph_based_converter.sub_graph_searcher.built_in_pattern import BUILT_IN_PATTERN, \ | |||
| is_built_in_pattern | |||
| from mindinsight.mindconverter.graph_based_converter.sub_graph_searcher.pattern_fuzzy_matching import \ | |||
| pattern_fuzzy_matching | |||
| from mindinsight.mindconverter.graph_based_converter.third_party_graph.onnx_utils import OnnxNode, BaseNode | |||
| module_name_to_src = {} | |||
| used_module_name = dict() | |||
| @@ -1,4 +1,4 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd.All Rights Reserved. | |||
| # Copyright 2020-2021 Huawei Technologies Co., Ltd.All Rights Reserved. | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| @@ -16,12 +16,15 @@ | |||
| from queue import PriorityQueue | |||
| from typing import Dict, List | |||
| from .common import context, DagGraph, gen_hash_key, ACCEPTABLE_RESULT_COUNT | |||
| from .common import MINI_FREQUENCY, MAX_ITERATION_DEPTH, SATISFIED_SCORE | |||
| from ..common.global_context import GlobalContext | |||
| from ..third_party_graph.onnx_utils import BaseNode | |||
| from .search_path import SearchPath, Pattern, generate_pattern, find_built_in_pattern | |||
| from ...common.exceptions import SubGraphSearchingError | |||
| from mindinsight.mindconverter.graph_based_converter.sub_graph_searcher.common import context, DagGraph, gen_hash_key, \ | |||
| ACCEPTABLE_RESULT_COUNT | |||
| from mindinsight.mindconverter.graph_based_converter.sub_graph_searcher.common import MINI_FREQUENCY, \ | |||
| MAX_ITERATION_DEPTH, SATISFIED_SCORE | |||
| from mindinsight.mindconverter.graph_based_converter.common.global_context import GlobalContext | |||
| from mindinsight.mindconverter.graph_based_converter.third_party_graph.onnx_utils import BaseNode | |||
| from mindinsight.mindconverter.graph_based_converter.sub_graph_searcher.search_path import SearchPath, Pattern, \ | |||
| generate_pattern, find_built_in_pattern | |||
| from mindinsight.mindconverter.common.exceptions import SubGraphSearchingError | |||
| def _is_satisfied(path): | |||
| @@ -17,7 +17,7 @@ | |||
| __all__ = ["GraphFactory"] | |||
| from importlib import import_module | |||
| from .base import Graph | |||
| from mindinsight.mindconverter.graph_based_converter.third_party_graph.base import Graph | |||
| class GraphFactory: | |||
| @@ -1,4 +1,4 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd.All Rights Reserved. | |||
| # Copyright 2020-2021 Huawei Technologies Co., Ltd.All Rights Reserved. | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| @@ -15,8 +15,8 @@ | |||
| """Define PyTorch graph node.""" | |||
| import os | |||
| from .base import GraphNode | |||
| from ..constant import SEPARATOR_IN_SCOPE, NodeType | |||
| from mindinsight.mindconverter.graph_based_converter.third_party_graph.base import GraphNode | |||
| from mindinsight.mindconverter.graph_based_converter.constant import SEPARATOR_IN_SCOPE, NodeType | |||
| class InputNode(GraphNode): | |||
| @@ -25,7 +25,6 @@ class InputNode(GraphNode): | |||
| Args: | |||
| input_shape: Input shape of module. | |||
| """ | |||
| def _get_arg_name(self, arg, variable_name): | |||
| @@ -17,12 +17,12 @@ from importlib import import_module | |||
| from typing import Dict, NoReturn | |||
| from mindinsight.mindconverter.common.log import logger as log | |||
| from .base import Graph | |||
| from .input_node import InputNode | |||
| from .onnx_graph_node import OnnxGraphNode | |||
| from .pytorch_graph_parser import PyTorchGraphParser | |||
| from .tf_graph_parser import TFGraphParser | |||
| from .onnx_utils import OnnxDataLoader | |||
| from mindinsight.mindconverter.graph_based_converter.third_party_graph.base import Graph | |||
| from mindinsight.mindconverter.graph_based_converter.third_party_graph.input_node import InputNode | |||
| from mindinsight.mindconverter.graph_based_converter.third_party_graph.onnx_graph_node import OnnxGraphNode | |||
| from mindinsight.mindconverter.graph_based_converter.third_party_graph.pytorch_graph_parser import PyTorchGraphParser | |||
| from mindinsight.mindconverter.graph_based_converter.third_party_graph.tf_graph_parser import TFGraphParser | |||
| from mindinsight.mindconverter.graph_based_converter.third_party_graph.onnx_utils import OnnxDataLoader, NodeWeight | |||
| NONE_SCOPE_OP = { | |||
| "onnx::Add": "Add", | |||
| @@ -126,7 +126,7 @@ class OnnxGraph(Graph): | |||
| self._shape_dict = model_data.node_output_shape_dict | |||
| for ind, (node_name, node) in enumerate(model_data.nodes_dict.items()): | |||
| node_weight = {} | |||
| node_weights = list() | |||
| node.scope_name = scope_name_list[ind] | |||
| inputs = node.input_name_list | |||
| # check each input from node or tensors | |||
| @@ -135,8 +135,8 @@ class OnnxGraph(Graph): | |||
| tensor = model_data.tensors_dict[i] | |||
| t_name = tensor.name | |||
| t_value = tensor.to_array() | |||
| node_weight[t_name] = t_value | |||
| self._nodes_collection[node_name] = OnnxGraphNode(node, node_weight) | |||
| node_weights.append(NodeWeight(t_name, t_value)) | |||
| self._nodes_collection[node_name] = OnnxGraphNode(node, node_weights) | |||
| self._nodes_record[node_name] = node_name | |||
| for nd_ipt_name in node.precursor_onnx_node_dict: | |||
| @@ -1,4 +1,4 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd.All Rights Reserved. | |||
| # Copyright 2020-2021 Huawei Technologies Co., Ltd.All Rights Reserved. | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| @@ -15,11 +15,11 @@ | |||
| """Define ONNX graph node.""" | |||
| from importlib import import_module | |||
| from .base import GraphNode | |||
| from ..common.utils import is_converted | |||
| from mindinsight.mindconverter.graph_based_converter.third_party_graph.base import GraphNode | |||
| from mindinsight.mindconverter.graph_based_converter.common.utils import is_converted | |||
| from ..constant import NodeType, SEPARATOR_IN_SCOPE, SEPARATOR_BTW_NAME_AND_ID, LEFT_BUCKET, RIGHT_BUCKET, \ | |||
| SEPARATOR_IN_ONNX_OP | |||
| from mindinsight.mindconverter.graph_based_converter.constant import NodeType, SEPARATOR_IN_SCOPE, \ | |||
| SEPARATOR_BTW_NAME_AND_ID, LEFT_BUCKET, RIGHT_BUCKET, SEPARATOR_IN_ONNX_OP | |||
| class OnnxGraphNode(GraphNode): | |||
| @@ -28,7 +28,7 @@ class OnnxGraphNode(GraphNode): | |||
| Args: | |||
| node (OnnxNode): OnnxNode Object. | |||
| weight (dict): Dictionary records weight and bias. | |||
| weight (list): List of recording node weights. | |||
| """ | |||
| _type_frozen = False | |||
| _module_name_frozen = False | |||
| @@ -227,7 +227,6 @@ class OnnxGraphNode(GraphNode): | |||
| Args: | |||
| src_arg (str): Original arg name. | |||
| tgt_arg (str): Target arg name. | |||
| """ | |||
| self._args_in_code[src_arg] = tgt_arg | |||
| @@ -22,13 +22,13 @@ from typing import Union | |||
| import numpy as np | |||
| from mindinsight.mindconverter.common.log import logger as log | |||
| from ..common.utils import fetch_output_from_onnx_model | |||
| from ..common.global_context import GlobalContext | |||
| from .optimizer import OnnxSimplify | |||
| from mindinsight.mindconverter.graph_based_converter.common.utils import fetch_output_from_onnx_model | |||
| from mindinsight.mindconverter.graph_based_converter.common.global_context import GlobalContext | |||
| from mindinsight.mindconverter.graph_based_converter.third_party_graph.optimizer import OnnxSimplify | |||
| from ..constant import ONNX_TYPE_INT, ONNX_TYPE_INTS, ONNX_TYPE_STRING, \ | |||
| from mindinsight.mindconverter.graph_based_converter.constant import ONNX_TYPE_INT, ONNX_TYPE_INTS, ONNX_TYPE_STRING, \ | |||
| ONNX_TYPE_FLOATS, ONNX_TYPE_FLOAT, SCALAR_WITHOUT_SHAPE, DYNAMIC_SHAPE, UNKNOWN_DIM_VAL | |||
| from ...common.exceptions import GraphInitError, ModelLoadingError | |||
| from mindinsight.mindconverter.common.exceptions import GraphInitError, ModelLoadingError | |||
| def convert_tf_graph_to_onnx(model_path, model_inputs, model_outputs, opset=12): | |||
| @@ -128,7 +128,6 @@ class ParamsAttribute: | |||
| raw_attribute (onnx.AttributeProto): onnx.AttributeProto instance. | |||
| node (onnx.NodeProto): Must pass the onnx.NodeProto instance | |||
| containing the same AttributeProto. | |||
| """ | |||
| def __init__(self, raw_attribute, node): | |||
| @@ -148,7 +147,6 @@ class ParamsAttribute: | |||
| Args: | |||
| attrs (onnx.AttributeProto): onnx.AttributeProto instance. | |||
| """ | |||
| if not attrs: | |||
| return | |||
| @@ -604,3 +602,18 @@ class OnnxDataLoader: | |||
| eliminated_nodes = _traceback_precursor_nodes_until_shape_op(to_shape) | |||
| self.dynamic_resize_node.append(nd_name) | |||
| self.eliminated_nodes += eliminated_nodes | |||
| class NodeWeight: | |||
| """Node weight struct.""" | |||
| def __init__(self, weight_name, weight_value): | |||
| self._weight_name = weight_name | |||
| self._weight_value = weight_value | |||
| @property | |||
| def name(self): | |||
| return self._weight_name | |||
| @property | |||
| def value(self): | |||
| return self._weight_value | |||
| @@ -18,7 +18,7 @@ from importlib import import_module | |||
| import numpy as np | |||
| from ..common.utils import fetch_output_from_onnx_model | |||
| from mindinsight.mindconverter.graph_based_converter.common.utils import fetch_output_from_onnx_model | |||
| class OnnxSimplify: | |||
| @@ -20,6 +20,7 @@ from mindinsight.mindconverter.common.log import logger as log | |||
| from mindinsight.mindconverter.graph_based_converter.third_party_graph.base import GraphParser | |||
| from mindinsight.mindconverter.common.exceptions import ModelNotSupportError | |||
| class PyTorchGraphParser(GraphParser): | |||
| """Define pytorch graph parser.""" | |||
| @@ -1,4 +1,4 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd.All Rights Reserved. | |||
| # Copyright 2020-2021 Huawei Technologies Co., Ltd.All Rights Reserved. | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| @@ -18,8 +18,8 @@ import re | |||
| from importlib import import_module | |||
| from mindinsight.mindconverter.common.log import logger as log | |||
| from .base import GraphParser | |||
| from ...common.exceptions import ModelNotSupportError | |||
| from mindinsight.mindconverter.graph_based_converter.third_party_graph.base import GraphParser | |||
| from mindinsight.mindconverter.common.exceptions import ModelNotSupportError | |||
| class TFGraphParser(GraphParser): | |||
| @@ -89,7 +89,9 @@ class TestMappers: | |||
| 'input': {'op_name': 'onnx::MaxPool', | |||
| 'params': {'kernel_shape': [3, 3], | |||
| 'pads': [1, 1, 1, 1], | |||
| 'strides': [2, 2]}, | |||
| 'strides': [2, 2], | |||
| 'input_shape': (1, 3, 224, 224), | |||
| 'output_shape': (1, 3, 112, 112)}, | |||
| 'weights': dict()}, | |||
| 'expected_output': {'converter_name': 'nn.MaxPool2d', | |||
| 'converted_params': {'kernel_size': (3, 3), | |||
| @@ -100,7 +102,9 @@ class TestMappers: | |||
| 'input': {'op_name': 'onnx::AveragePool', | |||
| 'params': {'kernel_shape': [3, 3], | |||
| 'pads': [1, 1, 1, 1], | |||
| 'strides': [2, 2]}, | |||
| 'strides': [2, 2], | |||
| 'input_shape': (1, 3, 224, 224), | |||
| 'output_shape': (1, 3, 112, 112)}, | |||
| 'weights': dict()}, | |||
| 'expected_output': {'converter_name': 'nn.AvgPool2d', | |||
| 'converted_params': {'kernel_size': (3, 3), | |||
| @@ -0,0 +1,26 @@ | |||
| # Copyright 2021 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| """Mock the MindSpore mindspore/train/serialization.py.""" | |||
| def save_checkpoint(trainable_weights, ckpt_file_name): | |||
| """ | |||
| Mock save_checkpoint. | |||
| Args: | |||
| trainable_weights (list): List of weights. | |||
| ckpt_file_name (str): Path to save checkpoint file. | |||
| """ | |||
| return len(trainable_weights), ckpt_file_name | |||