|
|
@@ -30,13 +30,15 @@ from mindinsight.mindconverter.graph_based_converter.common.outputs import BaseO |
|
|
from mindinsight.mindconverter.graph_based_converter.common.yapf_config import mindspore_yapf_config |
|
|
from mindinsight.mindconverter.graph_based_converter.common.yapf_config import mindspore_yapf_config |
|
|
from mindinsight.mindconverter.graph_based_converter.common.name_mgr import GlobalVarNameMgr |
|
|
from mindinsight.mindconverter.graph_based_converter.common.name_mgr import GlobalVarNameMgr |
|
|
from mindinsight.mindconverter.graph_based_converter.constant import NEW_LINE, SECOND_LEVEL_INDENT, \ |
|
|
from mindinsight.mindconverter.graph_based_converter.constant import NEW_LINE, SECOND_LEVEL_INDENT, \ |
|
|
FIRST_LEVEL_INDENT, get_imported_module, SEPARATOR_BTW_NAME_AND_ID, WeightType, LINK_IN_WEIGHT_NAME |
|
|
|
|
|
|
|
|
FIRST_LEVEL_INDENT, get_imported_module, SEPARATOR_BTW_NAME_AND_ID, WeightType, LINK_IN_WEIGHT_NAME, \ |
|
|
|
|
|
ExchangeMessageKeywords |
|
|
from mindinsight.mindconverter.graph_based_converter.report_generator import ReportGenerator |
|
|
from mindinsight.mindconverter.graph_based_converter.report_generator import ReportGenerator |
|
|
from mindinsight.mindconverter.graph_based_converter.common.utils import replace_string_in_list |
|
|
from mindinsight.mindconverter.graph_based_converter.common.utils import replace_string_in_list |
|
|
from mindinsight.mindconverter.graph_based_converter.generator.matcher import MatcherLauncher |
|
|
from mindinsight.mindconverter.graph_based_converter.generator.matcher import MatcherLauncher |
|
|
from mindinsight.mindconverter.graph_based_converter.generator.shared_weights import SharedWeightHelper |
|
|
from mindinsight.mindconverter.graph_based_converter.generator.shared_weights import SharedWeightHelper |
|
|
from mindinsight.mindconverter.graph_based_converter.constant import CHECKPOINT_SEGMENT_SIZE |
|
|
from mindinsight.mindconverter.graph_based_converter.constant import CHECKPOINT_SEGMENT_SIZE |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class CodeStruct: |
|
|
class CodeStruct: |
|
|
""" |
|
|
""" |
|
|
Define the Code template for each module generated in the final output. |
|
|
Define the Code template for each module generated in the final output. |
|
|
@@ -112,7 +114,7 @@ class CodeStruct: |
|
|
init_lines += init_str |
|
|
init_lines += init_str |
|
|
cons_lines += cons_str |
|
|
cons_lines += cons_str |
|
|
|
|
|
|
|
|
else: # is ModuleStruct |
|
|
|
|
|
|
|
|
else: # is ModuleStruct |
|
|
# check if this instance generated CodeStruct |
|
|
# check if this instance generated CodeStruct |
|
|
if GlobalContext().code_structs.get(struct.pattern_id) is None: |
|
|
if GlobalContext().code_structs.get(struct.pattern_id) is None: |
|
|
CodeStruct(struct, repeated_submodules) |
|
|
CodeStruct(struct, repeated_submodules) |
|
|
@@ -126,7 +128,7 @@ class CodeStruct: |
|
|
self.new_line = f"{FIRST_LEVEL_INDENT}def __init__({', '.join(module_def_args)}):" |
|
|
self.new_line = f"{FIRST_LEVEL_INDENT}def __init__({', '.join(module_def_args)}):" |
|
|
self.new_line = f"{SECOND_LEVEL_INDENT}super({class_name}, self).__init__()" |
|
|
self.new_line = f"{SECOND_LEVEL_INDENT}super({class_name}, self).__init__()" |
|
|
|
|
|
|
|
|
#add shared weights declaration in init code part |
|
|
|
|
|
|
|
|
# add shared weights declaration in init code part |
|
|
if md_struct.identifier == []: |
|
|
if md_struct.identifier == []: |
|
|
passthrough_w_declaration = SharedWeightHelper.public_module_shared_weight_statement_generation(md_struct) |
|
|
passthrough_w_declaration = SharedWeightHelper.public_module_shared_weight_statement_generation(md_struct) |
|
|
for s in passthrough_w_declaration: |
|
|
for s in passthrough_w_declaration: |
|
|
@@ -259,14 +261,34 @@ class Generator: |
|
|
@staticmethod |
|
|
@staticmethod |
|
|
def _set_translated_args_for_unshared_weights(t_param_postfix, nd_struct_list): |
|
|
def _set_translated_args_for_unshared_weights(t_param_postfix, nd_struct_list): |
|
|
"""Set the weight with given param postfix to args translation.""" |
|
|
"""Set the weight with given param postfix to args translation.""" |
|
|
|
|
|
|
|
|
|
|
|
args_name = ExchangeMessageKeywords.VariableScope.value.ARGS.value |
|
|
|
|
|
parameters_name = ExchangeMessageKeywords.VariableScope.value.PARAMETERS_DECLARED.value |
|
|
|
|
|
trainable_params_name = ExchangeMessageKeywords.VariableScope.value.TRAINABLE_PARAMS.value |
|
|
|
|
|
|
|
|
|
|
|
has_non_tensor = not np.alltrue( |
|
|
|
|
|
[bool(nd_struct.fragment.default_var.get(parameters_name)) for _, nd_struct in nd_struct_list]) |
|
|
for _, nd_struct in nd_struct_list: |
|
|
for _, nd_struct in nd_struct_list: |
|
|
nparr = nd_struct.fragment.default_var["trainable_params"].get(t_param_postfix).get('data') |
|
|
|
|
|
nd_struct.fragment.default_var["args"][f"{t_param_postfix}_shape"] = nparr.shape |
|
|
|
|
|
nd_struct.fragment.default_var["args"][f"{t_param_postfix}_dtype"] = nparr.dtype |
|
|
|
|
|
init_tensor_template = f"Parameter(Tensor(np.random.uniform(0, 1, "\ |
|
|
|
|
|
f"{{{t_param_postfix}_shape}}).astype(np.{{{t_param_postfix}_dtype}})), "\ |
|
|
|
|
|
f"name=None)" |
|
|
|
|
|
nd_struct.fragment.default_var["parameters"][t_param_postfix] = init_tensor_template |
|
|
|
|
|
|
|
|
if has_non_tensor: |
|
|
|
|
|
parameters = nd_struct.fragment.default_var.get(parameters_name) |
|
|
|
|
|
if parameters: |
|
|
|
|
|
nparr = nd_struct.fragment.default_var[trainable_params_name].get(t_param_postfix).get('data') |
|
|
|
|
|
init_tensor_template = nd_struct.fragment.fragment.create_parameter(nparr.shape, nparr.dtype) |
|
|
|
|
|
nd_struct.fragment.default_var[args_name][f"{t_param_postfix}"] = init_tensor_template |
|
|
|
|
|
nd_struct.fragment.default_var[parameters_name][t_param_postfix] = f"{{var_0}}_{t_param_postfix}" |
|
|
|
|
|
else: |
|
|
|
|
|
value_name = f"{t_param_postfix}_value" |
|
|
|
|
|
init_tensor_template = f"{nd_struct.fragment.default_var[args_name][value_name]}" |
|
|
|
|
|
nd_struct.fragment.default_var[args_name][f"{t_param_postfix}"] = init_tensor_template |
|
|
|
|
|
nd_struct.fragment.default_var[parameters_name] = {t_param_postfix: f"{{var_0}}_{t_param_postfix}"} |
|
|
|
|
|
del nd_struct.fragment.default_var[args_name][value_name] |
|
|
|
|
|
else: |
|
|
|
|
|
nparr = nd_struct.fragment.default_var[trainable_params_name].get(t_param_postfix).get('data') |
|
|
|
|
|
nd_struct.fragment.default_var[args_name][f"{t_param_postfix}_shape"] = nparr.shape |
|
|
|
|
|
nd_struct.fragment.default_var[args_name][f"{t_param_postfix}_dtype"] = nparr.dtype |
|
|
|
|
|
init_tensor_template = nd_struct.fragment.fragment.create_parameter(f"{{{t_param_postfix}_shape}}", |
|
|
|
|
|
f"{{{t_param_postfix}_dtype}}") |
|
|
|
|
|
nd_struct.fragment.default_var[parameters_name][t_param_postfix] = init_tensor_template |
|
|
|
|
|
|
|
|
def _get_same_trainable_params_onnx_name_from_repeated_nodes(self, |
|
|
def _get_same_trainable_params_onnx_name_from_repeated_nodes(self, |
|
|
t_param_postfix, |
|
|
t_param_postfix, |
|
|
@@ -279,8 +301,8 @@ class Generator: |
|
|
for (_, nd_struct) in nd_struct_list[1:]: |
|
|
for (_, nd_struct) in nd_struct_list[1:]: |
|
|
compared_t_param_data_dict = nd_struct.fragment.default_var["trainable_params"].get(t_param_postfix) |
|
|
compared_t_param_data_dict = nd_struct.fragment.default_var["trainable_params"].get(t_param_postfix) |
|
|
if not compared_t_param_data_dict: |
|
|
if not compared_t_param_data_dict: |
|
|
raise ValueError(f"Inconsistent trainable params detected for node "\ |
|
|
|
|
|
f"{nd_struct.topo_idx} with base node {base_nd_struct.topo_idx}") |
|
|
|
|
|
|
|
|
raise ValueError(f"Inconsistent trainable params detected for node " \ |
|
|
|
|
|
f"{nd_struct.topo_idx} with base node {base_nd_struct.topo_idx}") |
|
|
compared_t_name = compared_t_param_data_dict.get('onnx_name') |
|
|
compared_t_name = compared_t_param_data_dict.get('onnx_name') |
|
|
t_onnx_names.append(compared_t_name) |
|
|
t_onnx_names.append(compared_t_name) |
|
|
return t_onnx_names |
|
|
return t_onnx_names |
|
|
@@ -301,7 +323,7 @@ class Generator: |
|
|
if base_nd_struct.fragment.default_var.get("parameters"): |
|
|
if base_nd_struct.fragment.default_var.get("parameters"): |
|
|
# set only if has parameters as it requires rewritten. |
|
|
# set only if has parameters as it requires rewritten. |
|
|
for (t_param_postfix, t_param_data_dict) in \ |
|
|
for (t_param_postfix, t_param_data_dict) in \ |
|
|
base_nd_struct.fragment.default_var["trainable_params"].items(): |
|
|
|
|
|
|
|
|
base_nd_struct.fragment.default_var["trainable_params"].items(): |
|
|
if not isinstance(t_param_data_dict.get('data'), np.ndarray): |
|
|
if not isinstance(t_param_data_dict.get('data'), np.ndarray): |
|
|
continue |
|
|
continue |
|
|
Generator._set_translated_args_for_unshared_weights(t_param_postfix, nd_struct_list) |
|
|
Generator._set_translated_args_for_unshared_weights(t_param_postfix, nd_struct_list) |
|
|
@@ -332,7 +354,6 @@ class Generator: |
|
|
continue |
|
|
continue |
|
|
Generator._set_translated_args_for_unshared_weights(t_param_postfix, nd_struct_list) |
|
|
Generator._set_translated_args_for_unshared_weights(t_param_postfix, nd_struct_list) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _list_formal_parameters_in_a_module(self, module_filter_return): |
|
|
def _list_formal_parameters_in_a_module(self, module_filter_return): |
|
|
""" |
|
|
""" |
|
|
Find all formal args / params from nodes in a module. |
|
|
Find all formal args / params from nodes in a module. |
|
|
|