| @@ -1,4 +1,4 @@ | |||||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||||
| # Copyright 2020-2021 Huawei Technologies Co., Ltd | |||||
| # | # | ||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | # Licensed under the Apache License, Version 2.0 (the "License"); | ||||
| # you may not use this file except in compliance with the License. | # you may not use this file except in compliance with the License. | ||||
| @@ -318,8 +318,8 @@ parser.add_argument( | |||||
| action=ModelFileAction, | action=ModelFileAction, | ||||
| required=False, | required=False, | ||||
| help=""" | help=""" | ||||
| PyTorch .pth or Tensorflow .pb model file path to use graph | |||||
| based schema to do script generation. When | |||||
| PyTorch(.pth), Tensorflow(.pb) or ONNX(.onnx) model file path | |||||
| is expected to do script generation based on graph schema. When | |||||
| `--in_file` and `--model_file` are both provided, | `--in_file` and `--model_file` are both provided, | ||||
| use AST schema as default. | use AST schema as default. | ||||
| """) | """) | ||||
| @@ -20,6 +20,7 @@ SEPARATOR_IN_SCOPE = "/" | |||||
| SEPARATOR_BTW_NAME_AND_ID = "_" | SEPARATOR_BTW_NAME_AND_ID = "_" | ||||
| SEPARATOR_TITLE_AND_CONTENT_IN_CONSTRUCT = "=" | SEPARATOR_TITLE_AND_CONTENT_IN_CONSTRUCT = "=" | ||||
| LINK_IN_SCOPE = "-" | LINK_IN_SCOPE = "-" | ||||
| LINK_IN_WEIGHT_NAME = "." | |||||
| LEFT_BUCKET = "[" | LEFT_BUCKET = "[" | ||||
| RIGHT_BUCKET = "]" | RIGHT_BUCKET = "]" | ||||
| @@ -127,6 +128,12 @@ class FrameworkType(Enum): | |||||
| UNKNOWN = 2 | UNKNOWN = 2 | ||||
| @unique | |||||
| class WeightType(Enum): | |||||
| PARAMETER = 0 | |||||
| COMMON = 1 | |||||
| def get_imported_module(): | def get_imported_module(): | ||||
| """ | """ | ||||
| Generate imported module header. | Generate imported module header. | ||||
| @@ -137,5 +144,5 @@ def get_imported_module(): | |||||
| return f"import numpy as np{NEW_LINE}" \ | return f"import numpy as np{NEW_LINE}" \ | ||||
| f"import mindspore{NEW_LINE}" \ | f"import mindspore{NEW_LINE}" \ | ||||
| f"from mindspore import nn{NEW_LINE}" \ | f"from mindspore import nn{NEW_LINE}" \ | ||||
| f"from mindspore import Tensor{NEW_LINE}" \ | |||||
| f"from mindspore import Tensor, Parameter{NEW_LINE}" \ | |||||
| f"from mindspore.ops import operations as P{NEW_LINE * 3}" | f"from mindspore.ops import operations as P{NEW_LINE * 3}" | ||||
| @@ -29,7 +29,7 @@ from mindinsight.mindconverter.graph_based_converter.common.outputs import BaseO | |||||
| from mindinsight.mindconverter.graph_based_converter.common.yapf_config import mindspore_yapf_config | from mindinsight.mindconverter.graph_based_converter.common.yapf_config import mindspore_yapf_config | ||||
| from mindinsight.mindconverter.graph_based_converter.common.name_mgr import GlobalVarNameMgr | from mindinsight.mindconverter.graph_based_converter.common.name_mgr import GlobalVarNameMgr | ||||
| from mindinsight.mindconverter.graph_based_converter.constant import NEW_LINE, SECOND_LEVEL_INDENT, \ | from mindinsight.mindconverter.graph_based_converter.constant import NEW_LINE, SECOND_LEVEL_INDENT, \ | ||||
| FIRST_LEVEL_INDENT, get_imported_module, SEPARATOR_BTW_NAME_AND_ID | |||||
| FIRST_LEVEL_INDENT, get_imported_module, SEPARATOR_BTW_NAME_AND_ID, WeightType, LINK_IN_WEIGHT_NAME | |||||
| from mindinsight.mindconverter.graph_based_converter.report_generator import ReportGenerator | from mindinsight.mindconverter.graph_based_converter.report_generator import ReportGenerator | ||||
| from mindinsight.mindconverter.graph_based_converter.common.utils import replace_string_in_list | from mindinsight.mindconverter.graph_based_converter.common.utils import replace_string_in_list | ||||
| @@ -502,12 +502,17 @@ class Generator: | |||||
| if node_inst.fragment.exchange_msg['var_0']['trainable_params']: | if node_inst.fragment.exchange_msg['var_0']['trainable_params']: | ||||
| weights_scope_name = self.generate_weight_scope_name(node_name) | weights_scope_name = self.generate_weight_scope_name(node_name) | ||||
| onnx_weight_inst = node_inst.fragment.exchange_msg['var_0']['weights'] | onnx_weight_inst = node_inst.fragment.exchange_msg['var_0']['weights'] | ||||
| for idx, (weight_key, weight_value) in \ | |||||
| for idx, (weight_key, weight_value_object) in \ | |||||
| enumerate(node_inst.fragment.exchange_msg['var_0']['trainable_params'].items()): | enumerate(node_inst.fragment.exchange_msg['var_0']['trainable_params'].items()): | ||||
| weight_name = '.'.join((weights_scope_name, weight_key)) | |||||
| weight_shape = Tensor(weight_value).shape | |||||
| data_type = Tensor(weight_value).dtype | |||||
| trainable_weights_dict[weight_name] = weight_value | |||||
| value_type = weight_value_object.get('type', WeightType.COMMON.value) | |||||
| value_data = weight_value_object['data'] | |||||
| if value_type == WeightType.PARAMETER.value: | |||||
| weight_name = SEPARATOR_BTW_NAME_AND_ID.join((weights_scope_name, weight_key)) | |||||
| else: | |||||
| weight_name = LINK_IN_WEIGHT_NAME.join((weights_scope_name, weight_key)) | |||||
| weight_shape = Tensor(value_data).shape | |||||
| data_type = Tensor(value_data).dtype | |||||
| trainable_weights_dict[weight_name] = value_data | |||||
| onnx_weight_name = onnx_weight_inst[idx].name | onnx_weight_name = onnx_weight_inst[idx].name | ||||
| onnx_weight_shape = onnx_weight_inst[idx].value.shape | onnx_weight_shape = onnx_weight_inst[idx].value.shape | ||||
| @@ -41,8 +41,8 @@ class BatchNormMapper(ONNXToMindSporeMapper): | |||||
| moving_mean = BatchNormMapper._find_val_by_index(2, weights) | moving_mean = BatchNormMapper._find_val_by_index(2, weights) | ||||
| moving_variance = BatchNormMapper._find_val_by_index(3, weights) | moving_variance = BatchNormMapper._find_val_by_index(3, weights) | ||||
| return { | return { | ||||
| 'gamma': gamma, | |||||
| 'beta': beta, | |||||
| 'moving_mean': moving_mean, | |||||
| 'moving_variance': moving_variance | |||||
| 'gamma': {'data': gamma}, | |||||
| 'beta': {'data': beta}, | |||||
| 'moving_mean': {'data': moving_mean}, | |||||
| 'moving_variance': {'data': moving_variance} | |||||
| } | } | ||||
| @@ -143,8 +143,8 @@ class ConvMapper(ONNXToMindSporeMapper): | |||||
| weight = ConvMapper._find_val_by_index(0, weights) | weight = ConvMapper._find_val_by_index(0, weights) | ||||
| bias = ConvMapper._find_val_by_index(1, weights) | bias = ConvMapper._find_val_by_index(1, weights) | ||||
| converted_weights = {'weight': weight} | |||||
| converted_weights = {'weight': {'data': weight}} | |||||
| if isinstance(bias, np.ndarray): | if isinstance(bias, np.ndarray): | ||||
| converted_weights['bias'] = bias | |||||
| converted_weights['bias'] = {'data': bias} | |||||
| return converted_weights | return converted_weights | ||||
| @@ -45,6 +45,6 @@ class DenseMapper(ONNXToMindSporeMapper): | |||||
| weight = DenseMapper._find_val_by_index(0, weights) | weight = DenseMapper._find_val_by_index(0, weights) | ||||
| bias = DenseMapper._find_val_by_index(1, weights) | bias = DenseMapper._find_val_by_index(1, weights) | ||||
| return { | return { | ||||
| 'weight': weight, | |||||
| 'bias': bias | |||||
| 'weight': {'data': weight}, | |||||
| 'bias': {'data': bias} | |||||
| } | } | ||||
| @@ -14,7 +14,8 @@ | |||||
| # ============================================================================== | # ============================================================================== | ||||
| """Mapper module.""" | """Mapper module.""" | ||||
| from mindinsight.mindconverter.graph_based_converter.mapper.base import ONNXToMindSporeMapper | from mindinsight.mindconverter.graph_based_converter.mapper.base import ONNXToMindSporeMapper | ||||
| from mindinsight.mindconverter.graph_based_converter.constant import ExchangeMessageKeywords, TemplateKeywords | |||||
| from mindinsight.mindconverter.graph_based_converter.constant import ExchangeMessageKeywords, TemplateKeywords, \ | |||||
| WeightType | |||||
| class MatMulMapper(ONNXToMindSporeMapper): | class MatMulMapper(ONNXToMindSporeMapper): | ||||
| @@ -32,7 +33,7 @@ class MatMulMapper(ONNXToMindSporeMapper): | |||||
| def _convert_trained_weights(**kwargs): | def _convert_trained_weights(**kwargs): | ||||
| weights = kwargs['weights'] | weights = kwargs['weights'] | ||||
| weight = MatMulMapper._find_val_by_index(0, weights) | weight = MatMulMapper._find_val_by_index(0, weights) | ||||
| return {'weight': weight} | |||||
| return {'w': {'data': weight, 'type': WeightType.PARAMETER.value}} | |||||
| @staticmethod | @staticmethod | ||||
| def _generate_snippet_template(**kwargs): | def _generate_snippet_template(**kwargs): | ||||
| @@ -41,6 +42,7 @@ class MatMulMapper(ONNXToMindSporeMapper): | |||||
| op = kwargs.get("operation") | op = kwargs.get("operation") | ||||
| args = kwargs.get("converted_params") | args = kwargs.get("converted_params") | ||||
| weights = kwargs.get("weights") | weights = kwargs.get("weights") | ||||
| trainable_params = kwargs.get('trainable_params', dict()) | |||||
| if not op: | if not op: | ||||
| raise ValueError("Can not get MindSpore operation name.") | raise ValueError("Can not get MindSpore operation name.") | ||||
| if not weights: | if not weights: | ||||
| @@ -53,7 +55,8 @@ class MatMulMapper(ONNXToMindSporeMapper): | |||||
| args["weight_shape"] = tensor.shape | args["weight_shape"] = tensor.shape | ||||
| args["weight_dtype"] = tensor.dtype | args["weight_dtype"] = tensor.dtype | ||||
| init_tensor = f"self.{{{variable_slot}}}_w = " \ | init_tensor = f"self.{{{variable_slot}}}_w = " \ | ||||
| f"Tensor(np.random.uniform(0, 1, {{weight_shape}}).astype(np.{{weight_dtype}}))" | |||||
| f"Parameter(Tensor(np.random.uniform(0, 1, {{weight_shape}}).astype(np.{{weight_dtype}})), " \ | |||||
| f"name=None)" | |||||
| construct_template = f"opt_{{{variable_slot}}} = self.{{{variable_slot}}}" \ | construct_template = f"opt_{{{variable_slot}}} = self.{{{variable_slot}}}" \ | ||||
| f"({{{ExchangeMessageKeywords.VariableScope.value.INPUTS.value}}}," \ | f"({{{ExchangeMessageKeywords.VariableScope.value.INPUTS.value}}}," \ | ||||
| f"self.{{{variable_slot}}}_w)" | f"self.{{{variable_slot}}}_w)" | ||||
| @@ -72,7 +75,7 @@ class MatMulMapper(ONNXToMindSporeMapper): | |||||
| ExchangeMessageKeywords.VariableScope.value.INPUTS.value: [], | ExchangeMessageKeywords.VariableScope.value.INPUTS.value: [], | ||||
| ExchangeMessageKeywords.VariableScope.value.ARGS.value: args, | ExchangeMessageKeywords.VariableScope.value.ARGS.value: args, | ||||
| ExchangeMessageKeywords.VariableScope.value.WEIGHTS.value: weights, | ExchangeMessageKeywords.VariableScope.value.WEIGHTS.value: weights, | ||||
| ExchangeMessageKeywords.VariableScope.value.TRAINABLE_PARAMS.value: {} | |||||
| ExchangeMessageKeywords.VariableScope.value.TRAINABLE_PARAMS.value: trainable_params | |||||
| } | } | ||||
| } | } | ||||
| outputs_list = [f"opt_{{{variable_slot}}}"] | outputs_list = [f"opt_{{{variable_slot}}}"] | ||||
| @@ -13,7 +13,8 @@ | |||||
| # limitations under the License. | # limitations under the License. | ||||
| # ============================================================================== | # ============================================================================== | ||||
| """Mapper module.""" | """Mapper module.""" | ||||
| from mindinsight.mindconverter.graph_based_converter.constant import ExchangeMessageKeywords, TemplateKeywords | |||||
| from mindinsight.mindconverter.graph_based_converter.constant import ExchangeMessageKeywords, TemplateKeywords, \ | |||||
| WeightType | |||||
| from mindinsight.mindconverter.graph_based_converter.mapper.base import ONNXToMindSporeMapper | from mindinsight.mindconverter.graph_based_converter.mapper.base import ONNXToMindSporeMapper | ||||
| @@ -30,7 +31,9 @@ class AddMapper(ONNXToMindSporeMapper): | |||||
| @staticmethod | @staticmethod | ||||
| def _convert_trained_weights(**kwargs): | def _convert_trained_weights(**kwargs): | ||||
| return dict() | |||||
| weights = kwargs['weights'] | |||||
| bias = AddMapper._find_val_by_index(0, weights) | |||||
| return {'bias': {'data': bias, 'type': WeightType.PARAMETER.value}} | |||||
| @staticmethod | @staticmethod | ||||
| def _generate_snippet_template(**kwargs): | def _generate_snippet_template(**kwargs): | ||||
| @@ -39,6 +42,7 @@ class AddMapper(ONNXToMindSporeMapper): | |||||
| op = kwargs.get("operation") | op = kwargs.get("operation") | ||||
| args = kwargs.get("converted_params") | args = kwargs.get("converted_params") | ||||
| weights = kwargs.get("weights") | weights = kwargs.get("weights") | ||||
| trainable_params = kwargs.get('trainable_params', dict()) | |||||
| if not op: | if not op: | ||||
| raise ValueError("Can not get MindSpore operation name.") | raise ValueError("Can not get MindSpore operation name.") | ||||
| if not weights: | if not weights: | ||||
| @@ -51,7 +55,8 @@ class AddMapper(ONNXToMindSporeMapper): | |||||
| args["bias_shape"] = tensor.shape | args["bias_shape"] = tensor.shape | ||||
| args["bias_dtype"] = tensor.dtype | args["bias_dtype"] = tensor.dtype | ||||
| init_tensor = f"self.{{{variable_slot}}}_bias = " \ | init_tensor = f"self.{{{variable_slot}}}_bias = " \ | ||||
| f"Tensor(np.random.uniform(0, 1, {{bias_shape}}).astype(np.{{bias_dtype}}))" | |||||
| f"Parameter(Tensor(np.random.uniform(0, 1, {{bias_shape}}).astype(np.{{bias_dtype}})), " \ | |||||
| f"name=None)" | |||||
| construct_template = f"opt_{{{variable_slot}}} = self.{{{variable_slot}}}" \ | construct_template = f"opt_{{{variable_slot}}} = self.{{{variable_slot}}}" \ | ||||
| f"({{{ExchangeMessageKeywords.VariableScope.value.INPUTS.value}}}," \ | f"({{{ExchangeMessageKeywords.VariableScope.value.INPUTS.value}}}," \ | ||||
| f"self.{{{variable_slot}}}_bias)" | f"self.{{{variable_slot}}}_bias)" | ||||
| @@ -70,7 +75,7 @@ class AddMapper(ONNXToMindSporeMapper): | |||||
| ExchangeMessageKeywords.VariableScope.value.INPUTS.value: [], | ExchangeMessageKeywords.VariableScope.value.INPUTS.value: [], | ||||
| ExchangeMessageKeywords.VariableScope.value.ARGS.value: args, | ExchangeMessageKeywords.VariableScope.value.ARGS.value: args, | ||||
| ExchangeMessageKeywords.VariableScope.value.WEIGHTS.value: weights, | ExchangeMessageKeywords.VariableScope.value.WEIGHTS.value: weights, | ||||
| ExchangeMessageKeywords.VariableScope.value.TRAINABLE_PARAMS.value: {} | |||||
| ExchangeMessageKeywords.VariableScope.value.TRAINABLE_PARAMS.value: trainable_params | |||||
| } | } | ||||
| } | } | ||||
| outputs_list = [f"opt_{{{variable_slot}}}"] | outputs_list = [f"opt_{{{variable_slot}}}"] | ||||
| @@ -211,7 +211,7 @@ def _retrieve_operators(module_path, module_dict): | |||||
| """Lift nodes upper.""" | """Lift nodes upper.""" | ||||
| nonlocal added_module | nonlocal added_module | ||||
| lifted_submodule = [] | lifted_submodule = [] | ||||
| continuity_idx = -1 | |||||
| record = dict() | |||||
| lift_needed = _whether_to_lift(sub_module) | lift_needed = _whether_to_lift(sub_module) | ||||
| for m in sub_module: | for m in sub_module: | ||||
| scopes = m.split("/") | scopes = m.split("/") | ||||
| @@ -219,8 +219,8 @@ def _retrieve_operators(module_path, module_dict): | |||||
| # If the scope depth is 3, like ModuleX/ModuleY/Gemm, | # If the scope depth is 3, like ModuleX/ModuleY/Gemm, | ||||
| # then we lift ModuleY to top level. | # then we lift ModuleY to top level. | ||||
| md_name, md_idx = scopes[-2].split("_") | md_name, md_idx = scopes[-2].split("_") | ||||
| if continuity_idx != int(md_idx): | |||||
| continuity_idx = int(md_idx) | |||||
| if record.get(md_name, -1) != md_idx: | |||||
| record[md_name] = md_idx | |||||
| added_module[md_name] = added_module.setdefault(md_name, -1) + 1 | added_module[md_name] = added_module.setdefault(md_name, -1) + 1 | ||||
| lifted_submodule.append(f"{md_name}_{added_module.setdefault(md_name, 0)}/{scopes[-1]}") | lifted_submodule.append(f"{md_name}_{added_module.setdefault(md_name, 0)}/{scopes[-1]}") | ||||
| continue | continue | ||||