
!1320 [MindConverter] Fix shared weights for a mixture of tensor and non-tensor parameters

From: @moran3
Reviewed-by: @liuchongming74, @ouwenchang
Signed-off-by: @ouwenchang
pull/1320/MERGE
mindspore-ci-bot Gitee 4 years ago
parent commit 2530053d15
1 changed file with 35 additions and 14 deletions:

  1. +35 −14 mindinsight/mindconverter/graph_based_converter/generator/generator.py
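In short, the patch teaches Generator._set_translated_args_for_unshared_weights to handle repeated nodes whose shared weight is a mixture of tensor and non-tensor values: tensor weights still become generated Parameter placeholders, while plain values are forwarded as ordinary constructor arguments. A minimal sketch of the two kinds of code the generator emits, using MindSpore's public Parameter/Tensor API (the variable names module_weight and module_scale are illustrative only, not part of the patch):

import numpy as np
from mindspore import Parameter, Tensor

# Tensor case: the generator emits a randomly initialized Parameter
# placeholder built from the original weight's shape and dtype.
weight_shape = (4, 4)
module_weight = Parameter(
    Tensor(np.random.uniform(0, 1, weight_shape).astype(np.float32)),
    name=None)

# Non-tensor case: the original literal is passed through unchanged
# instead of being wrapped in a Parameter.
module_scale = 0.5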

+35 −14 mindinsight/mindconverter/graph_based_converter/generator/generator.py

@@ -30,13 +30,15 @@ from mindinsight.mindconverter.graph_based_converter.common.outputs import BaseO
 from mindinsight.mindconverter.graph_based_converter.common.yapf_config import mindspore_yapf_config
 from mindinsight.mindconverter.graph_based_converter.common.name_mgr import GlobalVarNameMgr
 from mindinsight.mindconverter.graph_based_converter.constant import NEW_LINE, SECOND_LEVEL_INDENT, \
-    FIRST_LEVEL_INDENT, get_imported_module, SEPARATOR_BTW_NAME_AND_ID, WeightType, LINK_IN_WEIGHT_NAME
+    FIRST_LEVEL_INDENT, get_imported_module, SEPARATOR_BTW_NAME_AND_ID, WeightType, LINK_IN_WEIGHT_NAME, \
+    ExchangeMessageKeywords
 from mindinsight.mindconverter.graph_based_converter.report_generator import ReportGenerator
 from mindinsight.mindconverter.graph_based_converter.common.utils import replace_string_in_list
 from mindinsight.mindconverter.graph_based_converter.generator.matcher import MatcherLauncher
 from mindinsight.mindconverter.graph_based_converter.generator.shared_weights import SharedWeightHelper
 from mindinsight.mindconverter.graph_based_converter.constant import CHECKPOINT_SEGMENT_SIZE


 class CodeStruct:
     """
     Define the Code template for each module generated in the final output.
@@ -112,7 +114,7 @@ class CodeStruct:
                 init_lines += init_str
                 cons_lines += cons_str

-            else: # is ModuleStruct
+            else:  # is ModuleStruct
                 # check if this instance generated CodeStruct
                 if GlobalContext().code_structs.get(struct.pattern_id) is None:
                     CodeStruct(struct, repeated_submodules)
@@ -126,7 +128,7 @@ class CodeStruct:
         self.new_line = f"{FIRST_LEVEL_INDENT}def __init__({', '.join(module_def_args)}):"
         self.new_line = f"{SECOND_LEVEL_INDENT}super({class_name}, self).__init__()"

-        #add shared weights declaration in init code part
+        # add shared weights declaration in init code part
         if md_struct.identifier == []:
             passthrough_w_declaration = SharedWeightHelper.public_module_shared_weight_statement_generation(md_struct)
             for s in passthrough_w_declaration:
@@ -259,14 +261,34 @@ class Generator:
     @staticmethod
     def _set_translated_args_for_unshared_weights(t_param_postfix, nd_struct_list):
         """Set the weight with given param postfix to args translation."""
+
+        args_name = ExchangeMessageKeywords.VariableScope.value.ARGS.value
+        parameters_name = ExchangeMessageKeywords.VariableScope.value.PARAMETERS_DECLARED.value
+        trainable_params_name = ExchangeMessageKeywords.VariableScope.value.TRAINABLE_PARAMS.value
+
+        has_non_tensor = not np.alltrue(
+            [bool(nd_struct.fragment.default_var.get(parameters_name)) for _, nd_struct in nd_struct_list])
         for _, nd_struct in nd_struct_list:
-            nparr = nd_struct.fragment.default_var["trainable_params"].get(t_param_postfix).get('data')
-            nd_struct.fragment.default_var["args"][f"{t_param_postfix}_shape"] = nparr.shape
-            nd_struct.fragment.default_var["args"][f"{t_param_postfix}_dtype"] = nparr.dtype
-            init_tensor_template = f"Parameter(Tensor(np.random.uniform(0, 1, "\
-                                   f"{{{t_param_postfix}_shape}}).astype(np.{{{t_param_postfix}_dtype}})), "\
-                                   f"name=None)"
-            nd_struct.fragment.default_var["parameters"][t_param_postfix] = init_tensor_template
+            if has_non_tensor:
+                parameters = nd_struct.fragment.default_var.get(parameters_name)
+                if parameters:
+                    nparr = nd_struct.fragment.default_var[trainable_params_name].get(t_param_postfix).get('data')
+                    init_tensor_template = nd_struct.fragment.fragment.create_parameter(nparr.shape, nparr.dtype)
+                    nd_struct.fragment.default_var[args_name][f"{t_param_postfix}"] = init_tensor_template
+                    nd_struct.fragment.default_var[parameters_name][t_param_postfix] = f"{{var_0}}_{t_param_postfix}"
+                else:
+                    value_name = f"{t_param_postfix}_value"
+                    init_tensor_template = f"{nd_struct.fragment.default_var[args_name][value_name]}"
+                    nd_struct.fragment.default_var[args_name][f"{t_param_postfix}"] = init_tensor_template
+                    nd_struct.fragment.default_var[parameters_name] = {t_param_postfix: f"{{var_0}}_{t_param_postfix}"}
+                    del nd_struct.fragment.default_var[args_name][value_name]
+            else:
+                nparr = nd_struct.fragment.default_var[trainable_params_name].get(t_param_postfix).get('data')
+                nd_struct.fragment.default_var[args_name][f"{t_param_postfix}_shape"] = nparr.shape
+                nd_struct.fragment.default_var[args_name][f"{t_param_postfix}_dtype"] = nparr.dtype
+                init_tensor_template = nd_struct.fragment.fragment.create_parameter(f"{{{t_param_postfix}_shape}}",
+                                                                                    f"{{{t_param_postfix}_dtype}}")
+                nd_struct.fragment.default_var[parameters_name][t_param_postfix] = init_tensor_template

     def _get_same_trainable_params_onnx_name_from_repeated_nodes(self,
                                                                  t_param_postfix,
@@ -279,8 +301,8 @@
         for (_, nd_struct) in nd_struct_list[1:]:
             compared_t_param_data_dict = nd_struct.fragment.default_var["trainable_params"].get(t_param_postfix)
             if not compared_t_param_data_dict:
-                raise ValueError(f"Inconsistent trainable params detected for node "\
-                                 f"{nd_struct.topo_idx} with base node {base_nd_struct.topo_idx}")
+                raise ValueError(f"Inconsistent trainable params detected for node " \
+                                 f"{nd_struct.topo_idx} with base node {base_nd_struct.topo_idx}")
             compared_t_name = compared_t_param_data_dict.get('onnx_name')
             t_onnx_names.append(compared_t_name)
         return t_onnx_names
@@ -301,7 +323,7 @@
         if base_nd_struct.fragment.default_var.get("parameters"):
             # set only if has parameters as it requires rewritten.
             for (t_param_postfix, t_param_data_dict) in \
-                base_nd_struct.fragment.default_var["trainable_params"].items():
+                    base_nd_struct.fragment.default_var["trainable_params"].items():
                 if not isinstance(t_param_data_dict.get('data'), np.ndarray):
                     continue
                 Generator._set_translated_args_for_unshared_weights(t_param_postfix, nd_struct_list)
@@ -332,7 +354,6 @@
                     continue
                 Generator._set_translated_args_for_unshared_weights(t_param_postfix, nd_struct_list)

-
     def _list_formal_parameters_in_a_module(self, module_filter_return):
         """
         Find all formal args / params from nodes in a module.
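The pivot of the new logic is the has_non_tensor probe: if any node in the repeated group declares no parameters, the whole group takes the mixed tensor/non-tensor path, and plain values stored under "<name>_value" are lifted out and exposed directly as module arguments. A self-contained sketch of that decision, with plain dicts as hypothetical stand-ins for the fragment structures (not the real MindConverter objects):

import numpy as np

# Hypothetical stand-ins for nd_struct.fragment.default_var of two repeated nodes.
node_vars = [
    {"parameters": {"weight": "placeholder"},  # weight is a real tensor here
     "trainable_params": {"weight": {"data": np.ones((2, 2), np.float32)}}},
    {"parameters": {},                         # "weight" is a plain value here
     "args": {"weight_value": 0.5}},
]

# Mirrors the patch: any node without declared parameters marks the group as mixed.
has_non_tensor = not all(bool(v.get("parameters")) for v in node_vars)

for v in node_vars:
    if has_non_tensor and not v.get("parameters"):
        # Non-tensor branch: lift the literal out of "weight_value" and
        # expose it directly under the bare parameter name.
        v["args"]["weight"] = v["args"].pop("weight_value")

print(has_non_tensor)        # True
print(node_vars[1]["args"])  # {'weight': 0.5}

The patch itself performs this test with np.alltrue; the built-in all() is equivalent here, and np.alltrue was later deprecated and removed in NumPy 2.0.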

