diff --git a/mindinsight/mindconverter/cli.py b/mindinsight/mindconverter/cli.py index 44a1e07f..68b59464 100644 --- a/mindinsight/mindconverter/cli.py +++ b/mindinsight/mindconverter/cli.py @@ -1,4 +1,4 @@ -# Copyright 2020 Huawei Technologies Co., Ltd +# Copyright 2020-2021 Huawei Technologies Co., Ltd # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -318,8 +318,8 @@ parser.add_argument( action=ModelFileAction, required=False, help=""" - PyTorch .pth or Tensorflow .pb model file path to use graph - based schema to do script generation. When + PyTorch(.pth), Tensorflow(.pb) or ONNX(.onnx) model file path + is expected to do script generation based on graph schema. When `--in_file` and `--model_file` are both provided, use AST schema as default. """) diff --git a/mindinsight/mindconverter/graph_based_converter/constant.py b/mindinsight/mindconverter/graph_based_converter/constant.py index 330ef767..71db7f08 100644 --- a/mindinsight/mindconverter/graph_based_converter/constant.py +++ b/mindinsight/mindconverter/graph_based_converter/constant.py @@ -20,6 +20,7 @@ SEPARATOR_IN_SCOPE = "/" SEPARATOR_BTW_NAME_AND_ID = "_" SEPARATOR_TITLE_AND_CONTENT_IN_CONSTRUCT = "=" LINK_IN_SCOPE = "-" +LINK_IN_WEIGHT_NAME = "." LEFT_BUCKET = "[" RIGHT_BUCKET = "]" @@ -127,6 +128,12 @@ class FrameworkType(Enum): UNKNOWN = 2 +@unique +class WeightType(Enum): + PARAMETER = 0 + COMMON = 1 + + def get_imported_module(): """ Generate imported module header. @@ -137,5 +144,5 @@ def get_imported_module(): return f"import numpy as np{NEW_LINE}" \ f"import mindspore{NEW_LINE}" \ f"from mindspore import nn{NEW_LINE}" \ - f"from mindspore import Tensor{NEW_LINE}" \ + f"from mindspore import Tensor, Parameter{NEW_LINE}" \ f"from mindspore.ops import operations as P{NEW_LINE * 3}" diff --git a/mindinsight/mindconverter/graph_based_converter/generator/generator.py b/mindinsight/mindconverter/graph_based_converter/generator/generator.py index 049b0103..4ad9346f 100644 --- a/mindinsight/mindconverter/graph_based_converter/generator/generator.py +++ b/mindinsight/mindconverter/graph_based_converter/generator/generator.py @@ -29,7 +29,7 @@ from mindinsight.mindconverter.graph_based_converter.common.outputs import BaseO from mindinsight.mindconverter.graph_based_converter.common.yapf_config import mindspore_yapf_config from mindinsight.mindconverter.graph_based_converter.common.name_mgr import GlobalVarNameMgr from mindinsight.mindconverter.graph_based_converter.constant import NEW_LINE, SECOND_LEVEL_INDENT, \ - FIRST_LEVEL_INDENT, get_imported_module, SEPARATOR_BTW_NAME_AND_ID + FIRST_LEVEL_INDENT, get_imported_module, SEPARATOR_BTW_NAME_AND_ID, WeightType, LINK_IN_WEIGHT_NAME from mindinsight.mindconverter.graph_based_converter.report_generator import ReportGenerator from mindinsight.mindconverter.graph_based_converter.common.utils import replace_string_in_list @@ -499,12 +499,17 @@ class Generator: if node_inst.fragment.exchange_msg['var_0']['trainable_params']: weights_scope_name = self.generate_weight_scope_name(node_name) onnx_weight_inst = node_inst.fragment.exchange_msg['var_0']['weights'] - for idx, (weight_key, weight_value) in \ + for idx, (weight_key, weight_value_object) in \ enumerate(node_inst.fragment.exchange_msg['var_0']['trainable_params'].items()): - weight_name = '.'.join((weights_scope_name, weight_key)) - weight_shape = Tensor(weight_value).shape - data_type = Tensor(weight_value).dtype - trainable_weights_dict[weight_name] = weight_value + value_type = weight_value_object.get('type', WeightType.COMMON.value) + value_data = weight_value_object['data'] + if value_type == WeightType.PARAMETER.value: + weight_name = SEPARATOR_BTW_NAME_AND_ID.join((weights_scope_name, weight_key)) + else: + weight_name = LINK_IN_WEIGHT_NAME.join((weights_scope_name, weight_key)) + weight_shape = Tensor(value_data).shape + data_type = Tensor(value_data).dtype + trainable_weights_dict[weight_name] = value_data onnx_weight_name = onnx_weight_inst[idx].name onnx_weight_shape = onnx_weight_inst[idx].value.shape diff --git a/mindinsight/mindconverter/graph_based_converter/mapper/impl/nn/batch_norm_mapper.py b/mindinsight/mindconverter/graph_based_converter/mapper/impl/nn/batch_norm_mapper.py index 5a4591df..eee25de4 100644 --- a/mindinsight/mindconverter/graph_based_converter/mapper/impl/nn/batch_norm_mapper.py +++ b/mindinsight/mindconverter/graph_based_converter/mapper/impl/nn/batch_norm_mapper.py @@ -41,8 +41,8 @@ class BatchNormMapper(ONNXToMindSporeMapper): moving_mean = BatchNormMapper._find_val_by_index(2, weights) moving_variance = BatchNormMapper._find_val_by_index(3, weights) return { - 'gamma': gamma, - 'beta': beta, - 'moving_mean': moving_mean, - 'moving_variance': moving_variance + 'gamma': {'data': gamma}, + 'beta': {'data': beta}, + 'moving_mean': {'data': moving_mean}, + 'moving_variance': {'data': moving_variance} } diff --git a/mindinsight/mindconverter/graph_based_converter/mapper/impl/nn/conv_mapper.py b/mindinsight/mindconverter/graph_based_converter/mapper/impl/nn/conv_mapper.py index cdad6167..86668583 100644 --- a/mindinsight/mindconverter/graph_based_converter/mapper/impl/nn/conv_mapper.py +++ b/mindinsight/mindconverter/graph_based_converter/mapper/impl/nn/conv_mapper.py @@ -143,8 +143,8 @@ class ConvMapper(ONNXToMindSporeMapper): weight = ConvMapper._find_val_by_index(0, weights) bias = ConvMapper._find_val_by_index(1, weights) - converted_weights = {'weight': weight} + converted_weights = {'weight': {'data': weight}} if isinstance(bias, np.ndarray): - converted_weights['bias'] = bias + converted_weights['bias'] = {'data': bias} return converted_weights diff --git a/mindinsight/mindconverter/graph_based_converter/mapper/impl/nn/dense_mapper.py b/mindinsight/mindconverter/graph_based_converter/mapper/impl/nn/dense_mapper.py index f65f5a06..df70005b 100644 --- a/mindinsight/mindconverter/graph_based_converter/mapper/impl/nn/dense_mapper.py +++ b/mindinsight/mindconverter/graph_based_converter/mapper/impl/nn/dense_mapper.py @@ -45,6 +45,6 @@ class DenseMapper(ONNXToMindSporeMapper): weight = DenseMapper._find_val_by_index(0, weights) bias = DenseMapper._find_val_by_index(1, weights) return { - 'weight': weight, - 'bias': bias + 'weight': {'data': weight}, + 'bias': {'data': bias} } diff --git a/mindinsight/mindconverter/graph_based_converter/mapper/impl/nn/mat_mul_mapper.py b/mindinsight/mindconverter/graph_based_converter/mapper/impl/nn/mat_mul_mapper.py index 805fe9d7..f8467a48 100644 --- a/mindinsight/mindconverter/graph_based_converter/mapper/impl/nn/mat_mul_mapper.py +++ b/mindinsight/mindconverter/graph_based_converter/mapper/impl/nn/mat_mul_mapper.py @@ -14,7 +14,8 @@ # ============================================================================== """Mapper module.""" from mindinsight.mindconverter.graph_based_converter.mapper.base import ONNXToMindSporeMapper -from mindinsight.mindconverter.graph_based_converter.constant import ExchangeMessageKeywords, TemplateKeywords +from mindinsight.mindconverter.graph_based_converter.constant import ExchangeMessageKeywords, TemplateKeywords, \ + WeightType class MatMulMapper(ONNXToMindSporeMapper): @@ -32,7 +33,7 @@ class MatMulMapper(ONNXToMindSporeMapper): def _convert_trained_weights(**kwargs): weights = kwargs['weights'] weight = MatMulMapper._find_val_by_index(0, weights) - return {'weight': weight} + return {'w': {'data': weight, 'type': WeightType.PARAMETER.value}} @staticmethod def _generate_snippet_template(**kwargs): @@ -41,6 +42,7 @@ class MatMulMapper(ONNXToMindSporeMapper): op = kwargs.get("operation") args = kwargs.get("converted_params") weights = kwargs.get("weights") + trainable_params = kwargs.get('trainable_params', dict()) if not op: raise ValueError("Can not get MindSpore operation name.") if not weights: @@ -53,7 +55,8 @@ class MatMulMapper(ONNXToMindSporeMapper): args["weight_shape"] = tensor.shape args["weight_dtype"] = tensor.dtype init_tensor = f"self.{{{variable_slot}}}_w = " \ - f"Tensor(np.random.uniform(0, 1, {{weight_shape}}).astype(np.{{weight_dtype}}))" + f"Parameter(Tensor(np.random.uniform(0, 1, {{weight_shape}}).astype(np.{{weight_dtype}})), " \ + f"name=None)" construct_template = f"opt_{{{variable_slot}}} = self.{{{variable_slot}}}" \ f"({{{ExchangeMessageKeywords.VariableScope.value.INPUTS.value}}}," \ f"self.{{{variable_slot}}}_w)" @@ -72,7 +75,7 @@ class MatMulMapper(ONNXToMindSporeMapper): ExchangeMessageKeywords.VariableScope.value.INPUTS.value: [], ExchangeMessageKeywords.VariableScope.value.ARGS.value: args, ExchangeMessageKeywords.VariableScope.value.WEIGHTS.value: weights, - ExchangeMessageKeywords.VariableScope.value.TRAINABLE_PARAMS.value: {} + ExchangeMessageKeywords.VariableScope.value.TRAINABLE_PARAMS.value: trainable_params } } outputs_list = [f"opt_{{{variable_slot}}}"] diff --git a/mindinsight/mindconverter/graph_based_converter/mapper/impl/ops/add_mapper.py b/mindinsight/mindconverter/graph_based_converter/mapper/impl/ops/add_mapper.py index db2413d1..ba44de91 100644 --- a/mindinsight/mindconverter/graph_based_converter/mapper/impl/ops/add_mapper.py +++ b/mindinsight/mindconverter/graph_based_converter/mapper/impl/ops/add_mapper.py @@ -13,7 +13,8 @@ # limitations under the License. # ============================================================================== """Mapper module.""" -from mindinsight.mindconverter.graph_based_converter.constant import ExchangeMessageKeywords, TemplateKeywords +from mindinsight.mindconverter.graph_based_converter.constant import ExchangeMessageKeywords, TemplateKeywords, \ + WeightType from mindinsight.mindconverter.graph_based_converter.mapper.base import ONNXToMindSporeMapper @@ -30,7 +31,9 @@ class AddMapper(ONNXToMindSporeMapper): @staticmethod def _convert_trained_weights(**kwargs): - return dict() + weights = kwargs['weights'] + bias = AddMapper._find_val_by_index(0, weights) + return {'bias': {'data': bias, 'type': WeightType.PARAMETER.value}} @staticmethod def _generate_snippet_template(**kwargs): @@ -39,6 +42,7 @@ class AddMapper(ONNXToMindSporeMapper): op = kwargs.get("operation") args = kwargs.get("converted_params") weights = kwargs.get("weights") + trainable_params = kwargs.get('trainable_params', dict()) if not op: raise ValueError("Can not get MindSpore operation name.") if not weights: @@ -51,7 +55,8 @@ class AddMapper(ONNXToMindSporeMapper): args["bias_shape"] = tensor.shape args["bias_dtype"] = tensor.dtype init_tensor = f"self.{{{variable_slot}}}_bias = " \ - f"Tensor(np.random.uniform(0, 1, {{bias_shape}}).astype(np.{{bias_dtype}}))" + f"Parameter(Tensor(np.random.uniform(0, 1, {{bias_shape}}).astype(np.{{bias_dtype}})), " \ + f"name=None)" construct_template = f"opt_{{{variable_slot}}} = self.{{{variable_slot}}}" \ f"({{{ExchangeMessageKeywords.VariableScope.value.INPUTS.value}}}," \ f"self.{{{variable_slot}}}_bias)" @@ -70,7 +75,7 @@ class AddMapper(ONNXToMindSporeMapper): ExchangeMessageKeywords.VariableScope.value.INPUTS.value: [], ExchangeMessageKeywords.VariableScope.value.ARGS.value: args, ExchangeMessageKeywords.VariableScope.value.WEIGHTS.value: weights, - ExchangeMessageKeywords.VariableScope.value.TRAINABLE_PARAMS.value: {} + ExchangeMessageKeywords.VariableScope.value.TRAINABLE_PARAMS.value: trainable_params } } outputs_list = [f"opt_{{{variable_slot}}}"] diff --git a/mindinsight/mindconverter/graph_based_converter/sub_graph_searcher/searcher.py b/mindinsight/mindconverter/graph_based_converter/sub_graph_searcher/searcher.py index ea8ff3c3..c370ada0 100644 --- a/mindinsight/mindconverter/graph_based_converter/sub_graph_searcher/searcher.py +++ b/mindinsight/mindconverter/graph_based_converter/sub_graph_searcher/searcher.py @@ -211,7 +211,7 @@ def _retrieve_operators(module_path, module_dict): """Lift nodes upper.""" nonlocal added_module lifted_submodule = [] - continuity_idx = -1 + record = dict() lift_needed = _whether_to_lift(sub_module) for m in sub_module: scopes = m.split("/") @@ -219,8 +219,8 @@ def _retrieve_operators(module_path, module_dict): # If the scope depth is 3, like ModuleX/ModuleY/Gemm, # then we lift ModuleY to top level. md_name, md_idx = scopes[-2].split("_") - if continuity_idx != int(md_idx): - continuity_idx = int(md_idx) + if record.get(md_name, -1) != md_idx: + record[md_name] = md_idx added_module[md_name] = added_module.setdefault(md_name, -1) + 1 lifted_submodule.append(f"{md_name}_{added_module.setdefault(md_name, 0)}/{scopes[-1]}") continue