Browse Source

!1152 Fix mappers for weight transformer in MindConverter.

From: @moran3
Reviewed-by: @liuchongming74,@yelihua,@yelihua
Signed-off-by:
tags/v1.2.0-rc1
mindspore-ci-bot Gitee 4 years ago
parent
commit
cdf5fe8236
9 changed files with 49 additions and 29 deletions
  1. +3
    -3
      mindinsight/mindconverter/cli.py
  2. +8
    -1
      mindinsight/mindconverter/graph_based_converter/constant.py
  3. +11
    -6
      mindinsight/mindconverter/graph_based_converter/generator/generator.py
  4. +4
    -4
      mindinsight/mindconverter/graph_based_converter/mapper/impl/nn/batch_norm_mapper.py
  5. +2
    -2
      mindinsight/mindconverter/graph_based_converter/mapper/impl/nn/conv_mapper.py
  6. +2
    -2
      mindinsight/mindconverter/graph_based_converter/mapper/impl/nn/dense_mapper.py
  7. +7
    -4
      mindinsight/mindconverter/graph_based_converter/mapper/impl/nn/mat_mul_mapper.py
  8. +9
    -4
      mindinsight/mindconverter/graph_based_converter/mapper/impl/ops/add_mapper.py
  9. +3
    -3
      mindinsight/mindconverter/graph_based_converter/sub_graph_searcher/searcher.py

+ 3
- 3
mindinsight/mindconverter/cli.py View File

@@ -1,4 +1,4 @@
# Copyright 2020 Huawei Technologies Co., Ltd
# Copyright 2020-2021 Huawei Technologies Co., Ltd
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
@@ -318,8 +318,8 @@ parser.add_argument(
action=ModelFileAction, action=ModelFileAction,
required=False, required=False,
help=""" help="""
PyTorch .pth or Tensorflow .pb model file path to use graph
based schema to do script generation. When
PyTorch(.pth), Tensorflow(.pb) or ONNX(.onnx) model file path
is expected to do script generation based on graph schema. When
`--in_file` and `--model_file` are both provided, `--in_file` and `--model_file` are both provided,
use AST schema as default. use AST schema as default.
""") """)


+ 8
- 1
mindinsight/mindconverter/graph_based_converter/constant.py View File

@@ -20,6 +20,7 @@ SEPARATOR_IN_SCOPE = "/"
SEPARATOR_BTW_NAME_AND_ID = "_" SEPARATOR_BTW_NAME_AND_ID = "_"
SEPARATOR_TITLE_AND_CONTENT_IN_CONSTRUCT = "=" SEPARATOR_TITLE_AND_CONTENT_IN_CONSTRUCT = "="
LINK_IN_SCOPE = "-" LINK_IN_SCOPE = "-"
LINK_IN_WEIGHT_NAME = "."
LEFT_BUCKET = "[" LEFT_BUCKET = "["
RIGHT_BUCKET = "]" RIGHT_BUCKET = "]"


@@ -127,6 +128,12 @@ class FrameworkType(Enum):
UNKNOWN = 2 UNKNOWN = 2




# Tag stored (as its integer value) under the optional 'type' key of a
# trainable-params entry ({'data': ..., 'type': ...}) returned by a mapper's
# _convert_trained_weights; the generator reads it to decide how the
# exported weight name is assembled.
@unique
class WeightType(Enum):
# Used by the MatMul and Add mappers in this change. For this type the
# generator joins the weight scope name and the weight key with
# SEPARATOR_BTW_NAME_AND_ID ("_").
PARAMETER = 0
# Assumed by the generator when an entry carries no 'type' key. The scope
# name and the weight key are joined with LINK_IN_WEIGHT_NAME (".").
COMMON = 1


def get_imported_module(): def get_imported_module():
""" """
Generate imported module header. Generate imported module header.
@@ -137,5 +144,5 @@ def get_imported_module():
return f"import numpy as np{NEW_LINE}" \ return f"import numpy as np{NEW_LINE}" \
f"import mindspore{NEW_LINE}" \ f"import mindspore{NEW_LINE}" \
f"from mindspore import nn{NEW_LINE}" \ f"from mindspore import nn{NEW_LINE}" \
f"from mindspore import Tensor{NEW_LINE}" \
f"from mindspore import Tensor, Parameter{NEW_LINE}" \
f"from mindspore.ops import operations as P{NEW_LINE * 3}" f"from mindspore.ops import operations as P{NEW_LINE * 3}"

+ 11
- 6
mindinsight/mindconverter/graph_based_converter/generator/generator.py View File

@@ -29,7 +29,7 @@ from mindinsight.mindconverter.graph_based_converter.common.outputs import BaseO
from mindinsight.mindconverter.graph_based_converter.common.yapf_config import mindspore_yapf_config from mindinsight.mindconverter.graph_based_converter.common.yapf_config import mindspore_yapf_config
from mindinsight.mindconverter.graph_based_converter.common.name_mgr import GlobalVarNameMgr from mindinsight.mindconverter.graph_based_converter.common.name_mgr import GlobalVarNameMgr
from mindinsight.mindconverter.graph_based_converter.constant import NEW_LINE, SECOND_LEVEL_INDENT, \ from mindinsight.mindconverter.graph_based_converter.constant import NEW_LINE, SECOND_LEVEL_INDENT, \
FIRST_LEVEL_INDENT, get_imported_module, SEPARATOR_BTW_NAME_AND_ID
FIRST_LEVEL_INDENT, get_imported_module, SEPARATOR_BTW_NAME_AND_ID, WeightType, LINK_IN_WEIGHT_NAME
from mindinsight.mindconverter.graph_based_converter.report_generator import ReportGenerator from mindinsight.mindconverter.graph_based_converter.report_generator import ReportGenerator
from mindinsight.mindconverter.graph_based_converter.common.utils import replace_string_in_list from mindinsight.mindconverter.graph_based_converter.common.utils import replace_string_in_list


@@ -499,12 +499,17 @@ class Generator:
if node_inst.fragment.exchange_msg['var_0']['trainable_params']: if node_inst.fragment.exchange_msg['var_0']['trainable_params']:
weights_scope_name = self.generate_weight_scope_name(node_name) weights_scope_name = self.generate_weight_scope_name(node_name)
onnx_weight_inst = node_inst.fragment.exchange_msg['var_0']['weights'] onnx_weight_inst = node_inst.fragment.exchange_msg['var_0']['weights']
for idx, (weight_key, weight_value) in \
for idx, (weight_key, weight_value_object) in \
enumerate(node_inst.fragment.exchange_msg['var_0']['trainable_params'].items()): enumerate(node_inst.fragment.exchange_msg['var_0']['trainable_params'].items()):
weight_name = '.'.join((weights_scope_name, weight_key))
weight_shape = Tensor(weight_value).shape
data_type = Tensor(weight_value).dtype
trainable_weights_dict[weight_name] = weight_value
value_type = weight_value_object.get('type', WeightType.COMMON.value)
value_data = weight_value_object['data']
if value_type == WeightType.PARAMETER.value:
weight_name = SEPARATOR_BTW_NAME_AND_ID.join((weights_scope_name, weight_key))
else:
weight_name = LINK_IN_WEIGHT_NAME.join((weights_scope_name, weight_key))
weight_shape = Tensor(value_data).shape
data_type = Tensor(value_data).dtype
trainable_weights_dict[weight_name] = value_data


onnx_weight_name = onnx_weight_inst[idx].name onnx_weight_name = onnx_weight_inst[idx].name
onnx_weight_shape = onnx_weight_inst[idx].value.shape onnx_weight_shape = onnx_weight_inst[idx].value.shape


+ 4
- 4
mindinsight/mindconverter/graph_based_converter/mapper/impl/nn/batch_norm_mapper.py View File

@@ -41,8 +41,8 @@ class BatchNormMapper(ONNXToMindSporeMapper):
moving_mean = BatchNormMapper._find_val_by_index(2, weights) moving_mean = BatchNormMapper._find_val_by_index(2, weights)
moving_variance = BatchNormMapper._find_val_by_index(3, weights) moving_variance = BatchNormMapper._find_val_by_index(3, weights)
return { return {
'gamma': gamma,
'beta': beta,
'moving_mean': moving_mean,
'moving_variance': moving_variance
'gamma': {'data': gamma},
'beta': {'data': beta},
'moving_mean': {'data': moving_mean},
'moving_variance': {'data': moving_variance}
} }

+ 2
- 2
mindinsight/mindconverter/graph_based_converter/mapper/impl/nn/conv_mapper.py View File

@@ -143,8 +143,8 @@ class ConvMapper(ONNXToMindSporeMapper):
weight = ConvMapper._find_val_by_index(0, weights) weight = ConvMapper._find_val_by_index(0, weights)
bias = ConvMapper._find_val_by_index(1, weights) bias = ConvMapper._find_val_by_index(1, weights)


converted_weights = {'weight': weight}
converted_weights = {'weight': {'data': weight}}
if isinstance(bias, np.ndarray): if isinstance(bias, np.ndarray):
converted_weights['bias'] = bias
converted_weights['bias'] = {'data': bias}


return converted_weights return converted_weights

+ 2
- 2
mindinsight/mindconverter/graph_based_converter/mapper/impl/nn/dense_mapper.py View File

@@ -45,6 +45,6 @@ class DenseMapper(ONNXToMindSporeMapper):
weight = DenseMapper._find_val_by_index(0, weights) weight = DenseMapper._find_val_by_index(0, weights)
bias = DenseMapper._find_val_by_index(1, weights) bias = DenseMapper._find_val_by_index(1, weights)
return { return {
'weight': weight,
'bias': bias
'weight': {'data': weight},
'bias': {'data': bias}
} }

+ 7
- 4
mindinsight/mindconverter/graph_based_converter/mapper/impl/nn/mat_mul_mapper.py View File

@@ -14,7 +14,8 @@
# ============================================================================== # ==============================================================================
"""Mapper module.""" """Mapper module."""
from mindinsight.mindconverter.graph_based_converter.mapper.base import ONNXToMindSporeMapper from mindinsight.mindconverter.graph_based_converter.mapper.base import ONNXToMindSporeMapper
from mindinsight.mindconverter.graph_based_converter.constant import ExchangeMessageKeywords, TemplateKeywords
from mindinsight.mindconverter.graph_based_converter.constant import ExchangeMessageKeywords, TemplateKeywords, \
WeightType




class MatMulMapper(ONNXToMindSporeMapper): class MatMulMapper(ONNXToMindSporeMapper):
@@ -32,7 +33,7 @@ class MatMulMapper(ONNXToMindSporeMapper):
def _convert_trained_weights(**kwargs): def _convert_trained_weights(**kwargs):
weights = kwargs['weights'] weights = kwargs['weights']
weight = MatMulMapper._find_val_by_index(0, weights) weight = MatMulMapper._find_val_by_index(0, weights)
return {'weight': weight}
return {'w': {'data': weight, 'type': WeightType.PARAMETER.value}}


@staticmethod @staticmethod
def _generate_snippet_template(**kwargs): def _generate_snippet_template(**kwargs):
@@ -41,6 +42,7 @@ class MatMulMapper(ONNXToMindSporeMapper):
op = kwargs.get("operation") op = kwargs.get("operation")
args = kwargs.get("converted_params") args = kwargs.get("converted_params")
weights = kwargs.get("weights") weights = kwargs.get("weights")
trainable_params = kwargs.get('trainable_params', dict())
if not op: if not op:
raise ValueError("Can not get MindSpore operation name.") raise ValueError("Can not get MindSpore operation name.")
if not weights: if not weights:
@@ -53,7 +55,8 @@ class MatMulMapper(ONNXToMindSporeMapper):
args["weight_shape"] = tensor.shape args["weight_shape"] = tensor.shape
args["weight_dtype"] = tensor.dtype args["weight_dtype"] = tensor.dtype
init_tensor = f"self.{{{variable_slot}}}_w = " \ init_tensor = f"self.{{{variable_slot}}}_w = " \
f"Tensor(np.random.uniform(0, 1, {{weight_shape}}).astype(np.{{weight_dtype}}))"
f"Parameter(Tensor(np.random.uniform(0, 1, {{weight_shape}}).astype(np.{{weight_dtype}})), " \
f"name=None)"
construct_template = f"opt_{{{variable_slot}}} = self.{{{variable_slot}}}" \ construct_template = f"opt_{{{variable_slot}}} = self.{{{variable_slot}}}" \
f"({{{ExchangeMessageKeywords.VariableScope.value.INPUTS.value}}}," \ f"({{{ExchangeMessageKeywords.VariableScope.value.INPUTS.value}}}," \
f"self.{{{variable_slot}}}_w)" f"self.{{{variable_slot}}}_w)"
@@ -72,7 +75,7 @@ class MatMulMapper(ONNXToMindSporeMapper):
ExchangeMessageKeywords.VariableScope.value.INPUTS.value: [], ExchangeMessageKeywords.VariableScope.value.INPUTS.value: [],
ExchangeMessageKeywords.VariableScope.value.ARGS.value: args, ExchangeMessageKeywords.VariableScope.value.ARGS.value: args,
ExchangeMessageKeywords.VariableScope.value.WEIGHTS.value: weights, ExchangeMessageKeywords.VariableScope.value.WEIGHTS.value: weights,
ExchangeMessageKeywords.VariableScope.value.TRAINABLE_PARAMS.value: {}
ExchangeMessageKeywords.VariableScope.value.TRAINABLE_PARAMS.value: trainable_params
} }
} }
outputs_list = [f"opt_{{{variable_slot}}}"] outputs_list = [f"opt_{{{variable_slot}}}"]


+ 9
- 4
mindinsight/mindconverter/graph_based_converter/mapper/impl/ops/add_mapper.py View File

@@ -13,7 +13,8 @@
# limitations under the License. # limitations under the License.
# ============================================================================== # ==============================================================================
"""Mapper module.""" """Mapper module."""
from mindinsight.mindconverter.graph_based_converter.constant import ExchangeMessageKeywords, TemplateKeywords
from mindinsight.mindconverter.graph_based_converter.constant import ExchangeMessageKeywords, TemplateKeywords, \
WeightType
from mindinsight.mindconverter.graph_based_converter.mapper.base import ONNXToMindSporeMapper from mindinsight.mindconverter.graph_based_converter.mapper.base import ONNXToMindSporeMapper




@@ -30,7 +31,9 @@ class AddMapper(ONNXToMindSporeMapper):


@staticmethod @staticmethod
def _convert_trained_weights(**kwargs): def _convert_trained_weights(**kwargs):
return dict()
weights = kwargs['weights']
bias = AddMapper._find_val_by_index(0, weights)
return {'bias': {'data': bias, 'type': WeightType.PARAMETER.value}}


@staticmethod @staticmethod
def _generate_snippet_template(**kwargs): def _generate_snippet_template(**kwargs):
@@ -39,6 +42,7 @@ class AddMapper(ONNXToMindSporeMapper):
op = kwargs.get("operation") op = kwargs.get("operation")
args = kwargs.get("converted_params") args = kwargs.get("converted_params")
weights = kwargs.get("weights") weights = kwargs.get("weights")
trainable_params = kwargs.get('trainable_params', dict())
if not op: if not op:
raise ValueError("Can not get MindSpore operation name.") raise ValueError("Can not get MindSpore operation name.")
if not weights: if not weights:
@@ -51,7 +55,8 @@ class AddMapper(ONNXToMindSporeMapper):
args["bias_shape"] = tensor.shape args["bias_shape"] = tensor.shape
args["bias_dtype"] = tensor.dtype args["bias_dtype"] = tensor.dtype
init_tensor = f"self.{{{variable_slot}}}_bias = " \ init_tensor = f"self.{{{variable_slot}}}_bias = " \
f"Tensor(np.random.uniform(0, 1, {{bias_shape}}).astype(np.{{bias_dtype}}))"
f"Parameter(Tensor(np.random.uniform(0, 1, {{bias_shape}}).astype(np.{{bias_dtype}})), " \
f"name=None)"
construct_template = f"opt_{{{variable_slot}}} = self.{{{variable_slot}}}" \ construct_template = f"opt_{{{variable_slot}}} = self.{{{variable_slot}}}" \
f"({{{ExchangeMessageKeywords.VariableScope.value.INPUTS.value}}}," \ f"({{{ExchangeMessageKeywords.VariableScope.value.INPUTS.value}}}," \
f"self.{{{variable_slot}}}_bias)" f"self.{{{variable_slot}}}_bias)"
@@ -70,7 +75,7 @@ class AddMapper(ONNXToMindSporeMapper):
ExchangeMessageKeywords.VariableScope.value.INPUTS.value: [], ExchangeMessageKeywords.VariableScope.value.INPUTS.value: [],
ExchangeMessageKeywords.VariableScope.value.ARGS.value: args, ExchangeMessageKeywords.VariableScope.value.ARGS.value: args,
ExchangeMessageKeywords.VariableScope.value.WEIGHTS.value: weights, ExchangeMessageKeywords.VariableScope.value.WEIGHTS.value: weights,
ExchangeMessageKeywords.VariableScope.value.TRAINABLE_PARAMS.value: {}
ExchangeMessageKeywords.VariableScope.value.TRAINABLE_PARAMS.value: trainable_params
} }
} }
outputs_list = [f"opt_{{{variable_slot}}}"] outputs_list = [f"opt_{{{variable_slot}}}"]


+ 3
- 3
mindinsight/mindconverter/graph_based_converter/sub_graph_searcher/searcher.py View File

@@ -211,7 +211,7 @@ def _retrieve_operators(module_path, module_dict):
"""Lift nodes upper.""" """Lift nodes upper."""
nonlocal added_module nonlocal added_module
lifted_submodule = [] lifted_submodule = []
continuity_idx = -1
record = dict()
lift_needed = _whether_to_lift(sub_module) lift_needed = _whether_to_lift(sub_module)
for m in sub_module: for m in sub_module:
scopes = m.split("/") scopes = m.split("/")
@@ -219,8 +219,8 @@ def _retrieve_operators(module_path, module_dict):
# If the scope depth is 3, like ModuleX/ModuleY/Gemm, # If the scope depth is 3, like ModuleX/ModuleY/Gemm,
# then we lift ModuleY to top level. # then we lift ModuleY to top level.
md_name, md_idx = scopes[-2].split("_") md_name, md_idx = scopes[-2].split("_")
if continuity_idx != int(md_idx):
continuity_idx = int(md_idx)
if record.get(md_name, -1) != md_idx:
record[md_name] = md_idx
added_module[md_name] = added_module.setdefault(md_name, -1) + 1 added_module[md_name] = added_module.setdefault(md_name, -1) + 1
lifted_submodule.append(f"{md_name}_{added_module.setdefault(md_name, 0)}/{scopes[-1]}") lifted_submodule.append(f"{md_name}_{added_module.setdefault(md_name, 0)}/{scopes[-1]}")
continue continue


Loading…
Cancel
Save