
Fix mappers for weights transformer

tags/v1.2.0-rc1
moran committed 4 years ago · parent commit 03d159ec0a
9 changed files with 49 additions and 29 deletions
  1. mindinsight/mindconverter/cli.py (+3, -3)
  2. mindinsight/mindconverter/graph_based_converter/constant.py (+8, -1)
  3. mindinsight/mindconverter/graph_based_converter/generator/generator.py (+11, -6)
  4. mindinsight/mindconverter/graph_based_converter/mapper/impl/nn/batch_norm_mapper.py (+4, -4)
  5. mindinsight/mindconverter/graph_based_converter/mapper/impl/nn/conv_mapper.py (+2, -2)
  6. mindinsight/mindconverter/graph_based_converter/mapper/impl/nn/dense_mapper.py (+2, -2)
  7. mindinsight/mindconverter/graph_based_converter/mapper/impl/nn/mat_mul_mapper.py (+7, -4)
  8. mindinsight/mindconverter/graph_based_converter/mapper/impl/ops/add_mapper.py (+9, -4)
  9. mindinsight/mindconverter/graph_based_converter/sub_graph_searcher/searcher.py (+3, -3)

+3 -3  mindinsight/mindconverter/cli.py

@@ -1,4 +1,4 @@
-# Copyright 2020 Huawei Technologies Co., Ltd
+# Copyright 2020-2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -318,8 +318,8 @@ parser.add_argument(
action=ModelFileAction,
required=False,
help="""
-PyTorch .pth or Tensorflow .pb model file path to use graph
-based schema to do script generation. When
+PyTorch(.pth), Tensorflow(.pb) or ONNX(.onnx) model file path
+is expected to do script generation based on graph schema. When
`--in_file` and `--model_file` are both provided,
use AST schema as default.
""")


+8 -1  mindinsight/mindconverter/graph_based_converter/constant.py

@@ -20,6 +20,7 @@ SEPARATOR_IN_SCOPE = "/"
SEPARATOR_BTW_NAME_AND_ID = "_"
SEPARATOR_TITLE_AND_CONTENT_IN_CONSTRUCT = "="
LINK_IN_SCOPE = "-"
+LINK_IN_WEIGHT_NAME = "."
LEFT_BUCKET = "["
RIGHT_BUCKET = "]"

@@ -127,6 +128,12 @@ class FrameworkType(Enum):
UNKNOWN = 2


+@unique
+class WeightType(Enum):
+    PARAMETER = 0
+    COMMON = 1
+
+
def get_imported_module():
"""
Generate imported module header.
@@ -137,5 +144,5 @@ def get_imported_module():
return f"import numpy as np{NEW_LINE}" \
f"import mindspore{NEW_LINE}" \
f"from mindspore import nn{NEW_LINE}" \
f"from mindspore import Tensor{NEW_LINE}" \
f"from mindspore import Tensor, Parameter{NEW_LINE}" \
f"from mindspore.ops import operations as P{NEW_LINE * 3}"

+11 -6  mindinsight/mindconverter/graph_based_converter/generator/generator.py

@@ -29,7 +29,7 @@ from mindinsight.mindconverter.graph_based_converter.common.outputs import BaseO
from mindinsight.mindconverter.graph_based_converter.common.yapf_config import mindspore_yapf_config
from mindinsight.mindconverter.graph_based_converter.common.name_mgr import GlobalVarNameMgr
from mindinsight.mindconverter.graph_based_converter.constant import NEW_LINE, SECOND_LEVEL_INDENT, \
-    FIRST_LEVEL_INDENT, get_imported_module, SEPARATOR_BTW_NAME_AND_ID
+    FIRST_LEVEL_INDENT, get_imported_module, SEPARATOR_BTW_NAME_AND_ID, WeightType, LINK_IN_WEIGHT_NAME
from mindinsight.mindconverter.graph_based_converter.report_generator import ReportGenerator
from mindinsight.mindconverter.graph_based_converter.common.utils import replace_string_in_list

@@ -502,12 +502,17 @@ class Generator:
if node_inst.fragment.exchange_msg['var_0']['trainable_params']:
weights_scope_name = self.generate_weight_scope_name(node_name)
onnx_weight_inst = node_inst.fragment.exchange_msg['var_0']['weights']
-for idx, (weight_key, weight_value) in \
+for idx, (weight_key, weight_value_object) in \
enumerate(node_inst.fragment.exchange_msg['var_0']['trainable_params'].items()):
-weight_name = '.'.join((weights_scope_name, weight_key))
-weight_shape = Tensor(weight_value).shape
-data_type = Tensor(weight_value).dtype
-trainable_weights_dict[weight_name] = weight_value
+value_type = weight_value_object.get('type', WeightType.COMMON.value)
+value_data = weight_value_object['data']
+if value_type == WeightType.PARAMETER.value:
+    weight_name = SEPARATOR_BTW_NAME_AND_ID.join((weights_scope_name, weight_key))
+else:
+    weight_name = LINK_IN_WEIGHT_NAME.join((weights_scope_name, weight_key))
+weight_shape = Tensor(value_data).shape
+data_type = Tensor(value_data).dtype
+trainable_weights_dict[weight_name] = value_data

onnx_weight_name = onnx_weight_inst[idx].name
onnx_weight_shape = onnx_weight_inst[idx].value.shape
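
The generator now unpacks each {'data': ..., 'type': ...} object and picks the name separator by weight type: Parameter-typed weights are joined with "_" (SEPARATOR_BTW_NAME_AND_ID), common ones with "." (LINK_IN_WEIGHT_NAME). A standalone sketch of that naming rule, with the constants inlined and hypothetical scope names:

SEPARATOR_BTW_NAME_AND_ID = "_"
LINK_IN_WEIGHT_NAME = "."
PARAMETER, COMMON = 0, 1  # WeightType values

def weight_name(scope, key, value_type=COMMON):
    # Parameter-typed weights become e.g. 'scope_w'; common ones 'scope.weight'.
    sep = SEPARATOR_BTW_NAME_AND_ID if value_type == PARAMETER else LINK_IN_WEIGHT_NAME
    return sep.join((scope, key))

assert weight_name("module0.matmul_1", "w", PARAMETER) == "module0.matmul_1_w"
assert weight_name("module0.conv2d_0", "weight") == "module0.conv2d_0.weight"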


+4 -4  mindinsight/mindconverter/graph_based_converter/mapper/impl/nn/batch_norm_mapper.py

@@ -41,8 +41,8 @@ class BatchNormMapper(ONNXToMindSporeMapper):
moving_mean = BatchNormMapper._find_val_by_index(2, weights)
moving_variance = BatchNormMapper._find_val_by_index(3, weights)
return {
-    'gamma': gamma,
-    'beta': beta,
-    'moving_mean': moving_mean,
-    'moving_variance': moving_variance
+    'gamma': {'data': gamma},
+    'beta': {'data': beta},
+    'moving_mean': {'data': moving_mean},
+    'moving_variance': {'data': moving_variance}
}
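
This mapper (like conv_mapper and dense_mapper below) now wraps each value in a {'data': ...} dict so the generator can read weight_value_object['data'] uniformly; only mappers needing a named Parameter add a 'type' key. A before/after sketch with dummy arrays:

import numpy as np

gamma, beta = np.ones(8, dtype=np.float32), np.zeros(8, dtype=np.float32)

old_style = {'gamma': gamma, 'beta': beta}                      # bare arrays: no room for metadata
new_style = {'gamma': {'data': gamma}, 'beta': {'data': beta}}  # uniform, extensible shape
assert (new_style['gamma']['data'] == old_style['gamma']).all()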

+2 -2  mindinsight/mindconverter/graph_based_converter/mapper/impl/nn/conv_mapper.py

@@ -143,8 +143,8 @@ class ConvMapper(ONNXToMindSporeMapper):
weight = ConvMapper._find_val_by_index(0, weights)
bias = ConvMapper._find_val_by_index(1, weights)

-converted_weights = {'weight': weight}
+converted_weights = {'weight': {'data': weight}}
if isinstance(bias, np.ndarray):
-    converted_weights['bias'] = bias
+    converted_weights['bias'] = {'data': bias}

return converted_weights

+2 -2  mindinsight/mindconverter/graph_based_converter/mapper/impl/nn/dense_mapper.py

@@ -45,6 +45,6 @@ class DenseMapper(ONNXToMindSporeMapper):
weight = DenseMapper._find_val_by_index(0, weights)
bias = DenseMapper._find_val_by_index(1, weights)
return {
-    'weight': weight,
-    'bias': bias
+    'weight': {'data': weight},
+    'bias': {'data': bias}
}

+7 -4  mindinsight/mindconverter/graph_based_converter/mapper/impl/nn/mat_mul_mapper.py

@@ -14,7 +14,8 @@
# ==============================================================================
"""Mapper module."""
from mindinsight.mindconverter.graph_based_converter.mapper.base import ONNXToMindSporeMapper
-from mindinsight.mindconverter.graph_based_converter.constant import ExchangeMessageKeywords, TemplateKeywords
+from mindinsight.mindconverter.graph_based_converter.constant import ExchangeMessageKeywords, TemplateKeywords, \
+    WeightType


class MatMulMapper(ONNXToMindSporeMapper):
@@ -32,7 +33,7 @@ class MatMulMapper(ONNXToMindSporeMapper):
def _convert_trained_weights(**kwargs):
weights = kwargs['weights']
weight = MatMulMapper._find_val_by_index(0, weights)
-return {'weight': weight}
+return {'w': {'data': weight, 'type': WeightType.PARAMETER.value}}

@staticmethod
def _generate_snippet_template(**kwargs):
@@ -41,6 +42,7 @@ class MatMulMapper(ONNXToMindSporeMapper):
op = kwargs.get("operation")
args = kwargs.get("converted_params")
weights = kwargs.get("weights")
+trainable_params = kwargs.get('trainable_params', dict())
if not op:
raise ValueError("Can not get MindSpore operation name.")
if not weights:
@@ -53,7 +55,8 @@ class MatMulMapper(ONNXToMindSporeMapper):
args["weight_shape"] = tensor.shape
args["weight_dtype"] = tensor.dtype
init_tensor = f"self.{{{variable_slot}}}_w = " \
f"Tensor(np.random.uniform(0, 1, {{weight_shape}}).astype(np.{{weight_dtype}}))"
f"Parameter(Tensor(np.random.uniform(0, 1, {{weight_shape}}).astype(np.{{weight_dtype}})), " \
f"name=None)"
construct_template = f"opt_{{{variable_slot}}} = self.{{{variable_slot}}}" \
f"({{{ExchangeMessageKeywords.VariableScope.value.INPUTS.value}}}," \
f"self.{{{variable_slot}}}_w)"
@@ -72,7 +75,7 @@ class MatMulMapper(ONNXToMindSporeMapper):
ExchangeMessageKeywords.VariableScope.value.INPUTS.value: [],
ExchangeMessageKeywords.VariableScope.value.ARGS.value: args,
ExchangeMessageKeywords.VariableScope.value.WEIGHTS.value: weights,
-    ExchangeMessageKeywords.VariableScope.value.TRAINABLE_PARAMS.value: {}
+    ExchangeMessageKeywords.VariableScope.value.TRAINABLE_PARAMS.value: trainable_params
}
}
outputs_list = [f"opt_{{{variable_slot}}}"]
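
The generated weight is now wrapped in mindspore.Parameter (with name=None, letting MindSpore derive the name from the attribute path) instead of a bare Tensor, so it is registered among the cell's trainable parameters and can be overwritten when real weights are loaded. A sketch of roughly what the expanded template produces, with a placeholder shape:

import numpy as np
from mindspore import nn, Tensor, Parameter
from mindspore.ops import operations as P

class MatMulBlock(nn.Cell):
    # (16, 32) stands in for the real ONNX weight shape.
    def __init__(self):
        super().__init__()
        self.matmul = P.MatMul()
        # Parameter, not a bare Tensor: shows up in trainable_params().
        self.matmul_w = Parameter(
            Tensor(np.random.uniform(0, 1, (16, 32)).astype(np.float32)),
            name=None)

    def construct(self, x):
        return self.matmul(x, self.matmul_w)

net = MatMulBlock()
print(net(Tensor(np.ones((2, 16), np.float32))).shape)  # (2, 32)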


+9 -4  mindinsight/mindconverter/graph_based_converter/mapper/impl/ops/add_mapper.py

@@ -13,7 +13,8 @@
# limitations under the License.
# ==============================================================================
"""Mapper module."""
-from mindinsight.mindconverter.graph_based_converter.constant import ExchangeMessageKeywords, TemplateKeywords
+from mindinsight.mindconverter.graph_based_converter.constant import ExchangeMessageKeywords, TemplateKeywords, \
+    WeightType
from mindinsight.mindconverter.graph_based_converter.mapper.base import ONNXToMindSporeMapper


@@ -30,7 +31,9 @@ class AddMapper(ONNXToMindSporeMapper):

@staticmethod
def _convert_trained_weights(**kwargs):
-return dict()
+weights = kwargs['weights']
+bias = AddMapper._find_val_by_index(0, weights)
+return {'bias': {'data': bias, 'type': WeightType.PARAMETER.value}}

@staticmethod
def _generate_snippet_template(**kwargs):
@@ -39,6 +42,7 @@ class AddMapper(ONNXToMindSporeMapper):
op = kwargs.get("operation")
args = kwargs.get("converted_params")
weights = kwargs.get("weights")
+trainable_params = kwargs.get('trainable_params', dict())
if not op:
raise ValueError("Can not get MindSpore operation name.")
if not weights:
@@ -51,7 +55,8 @@ class AddMapper(ONNXToMindSporeMapper):
args["bias_shape"] = tensor.shape
args["bias_dtype"] = tensor.dtype
init_tensor = f"self.{{{variable_slot}}}_bias = " \
f"Tensor(np.random.uniform(0, 1, {{bias_shape}}).astype(np.{{bias_dtype}}))"
f"Parameter(Tensor(np.random.uniform(0, 1, {{bias_shape}}).astype(np.{{bias_dtype}})), " \
f"name=None)"
construct_template = f"opt_{{{variable_slot}}} = self.{{{variable_slot}}}" \
f"({{{ExchangeMessageKeywords.VariableScope.value.INPUTS.value}}}," \
f"self.{{{variable_slot}}}_bias)"
@@ -70,7 +75,7 @@ class AddMapper(ONNXToMindSporeMapper):
ExchangeMessageKeywords.VariableScope.value.INPUTS.value: [],
ExchangeMessageKeywords.VariableScope.value.ARGS.value: args,
ExchangeMessageKeywords.VariableScope.value.WEIGHTS.value: weights,
-    ExchangeMessageKeywords.VariableScope.value.TRAINABLE_PARAMS.value: {}
+    ExchangeMessageKeywords.VariableScope.value.TRAINABLE_PARAMS.value: trainable_params
}
}
outputs_list = [f"opt_{{{variable_slot}}}"]
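
add_mapper gets the same treatment as mat_mul_mapper: _convert_trained_weights no longer returns an empty dict but exposes the bias as a PARAMETER-typed weight, and the exchange message forwards the mapper's trainable_params instead of a hardcoded {}. A sketch of that threading (the plain dict keys here stand in for the ExchangeMessageKeywords enum values):

def build_exchange_msg(args, weights, trainable_params=None):
    # Previously the last field was always {}, so Add/MatMul weights
    # never reached the generator's weight-collecting pass.
    return {
        'args': args,
        'weights': weights,
        'trainable_params': trainable_params or dict(),
    }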


+3 -3  mindinsight/mindconverter/graph_based_converter/sub_graph_searcher/searcher.py

@@ -211,7 +211,7 @@ def _retrieve_operators(module_path, module_dict):
"""Lift nodes upper."""
nonlocal added_module
lifted_submodule = []
-continuity_idx = -1
+record = dict()
lift_needed = _whether_to_lift(sub_module)
for m in sub_module:
scopes = m.split("/")
@@ -219,8 +219,8 @@ def _retrieve_operators(module_path, module_dict):
# If the scope depth is 3, like ModuleX/ModuleY/Gemm,
# then we lift ModuleY to top level.
md_name, md_idx = scopes[-2].split("_")
-if continuity_idx != int(md_idx):
-    continuity_idx = int(md_idx)
+if record.get(md_name, -1) != md_idx:
+    record[md_name] = md_idx
added_module[md_name] = added_module.setdefault(md_name, -1) + 1
lifted_submodule.append(f"{md_name}_{added_module.setdefault(md_name, 0)}/{scopes[-1]}")
continue
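
The old code kept one shared continuity_idx, so when submodules with different names interleave, a name switch could reset the index and split operators belonging to a single original module instance. Tracking the last-seen index per module name fixes that. A self-contained sketch, with scope strings assumed from the diff context ('Name_idx' segments):

def lift_instance_names(scopes):
    # `scopes` holds 'Name_idx' strings in traversal order, e.g. the
    # second-to-last segment of 'ModuleX/ModuleY_3/Gemm' paths.
    record = dict()        # last-seen original idx per module name (the fix)
    added_module = dict()  # running lifted-instance counter per module name
    lifted = []
    for scope in scopes:
        md_name, md_idx = scope.split("_")
        if record.get(md_name, -1) != md_idx:
            record[md_name] = md_idx
            added_module[md_name] = added_module.setdefault(md_name, -1) + 1
        lifted.append(f"{md_name}_{added_module.setdefault(md_name, 0)}")
    return lifted

# With a single continuity_idx, the interleaved ModuleB_2 reset the index and
# the two ModuleA_1 entries were split into ModuleA_0/ModuleA_1; per-name
# tracking keeps them in one lifted instance:
assert lift_instance_names(["ModuleA_1", "ModuleB_2", "ModuleA_1"]) == \
    ["ModuleA_0", "ModuleB_0", "ModuleA_0"]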

