Browse Source

!1300 add reduce_sum mapper

From: @shenghong96
Reviewed-by: @yelihua,@lilongfei15
Signed-off-by: @lilongfei15
pull/1300/MERGE
mindspore-ci-bot Gitee 4 years ago
parent
commit
fcb4281917
3 changed files with 88 additions and 2 deletions
  1. +10
    -1
      mindinsight/mindconverter/graph_based_converter/mapper/impl/nn/batch_norm_mapper.py
  2. +76
    -0
      mindinsight/mindconverter/graph_based_converter/mapper/impl/ops/reduce_sum_mapper.py
  3. +2
    -1
      mindinsight/mindconverter/graph_based_converter/mapper/onnx_to_ms.json

+ 10
- 1
mindinsight/mindconverter/graph_based_converter/mapper/impl/nn/batch_norm_mapper.py View File

@@ -15,13 +15,22 @@
"""Mapper module."""
from mindinsight.mindconverter.graph_based_converter.mapper.base import ONNXToMindSporeMapper

# The segmentation of BatchNorm's dim attribute.
DIM_SEGMENT = 2


class BatchNormMapper(ONNXToMindSporeMapper):
"""BatchNorm mapper."""

@staticmethod
def _operation_name_in_ms(*args, **kwargs):
dim = len(kwargs['params']['output_shape']) - 2
output_shape = len(kwargs['params']['output_shape'])
if output_shape == DIM_SEGMENT:
dim = 1
elif output_shape > DIM_SEGMENT:
dim = len(kwargs['params']['output_shape']) - 2
else:
return None
return f"nn.BatchNorm{dim}d"

@staticmethod


+ 76
- 0
mindinsight/mindconverter/graph_based_converter/mapper/impl/ops/reduce_sum_mapper.py View File

@@ -0,0 +1,76 @@
# Copyright 2021 Huawei Technologies Co., Ltd.All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Mapper module."""
from mindinsight.mindconverter.graph_based_converter.constant import ExchangeMessageKeywords, TemplateKeywords
from mindinsight.mindconverter.graph_based_converter.mapper.base import ONNXToMindSporeMapper


class ReduceSumMapper(ONNXToMindSporeMapper):
"""ReduceSum mapper."""

@staticmethod
def _operation_name_in_ms(*args, **kwargs):
return "P.ReduceSum"

@staticmethod
def _convert_params(**kwargs):
params = kwargs['params']
keep_dims = not params.get('keepdims', 1) == 0
return {'keep_dims': keep_dims}

@staticmethod
def _convert_trained_weights(**kwargs):
return dict()

@staticmethod
def _generate_snippet_template(**kwargs):
op = kwargs.get("operation")
args = kwargs.get("converted_params")
raw_params = kwargs.get("raw_params")
if raw_params.get('axes'):
axis = raw_params['axes'][0] if len(raw_params['axes']) == 1 else tuple(raw_params['axes'])
else:
axis = tuple()
variable_slot = "var_0"
init_template = f"self.{{{variable_slot}}} = {op}({', '.join(['%s={%s}' % (p, p) for p in args])})"

args["axis"] = axis
init_tensor = f"self.{{{variable_slot}}}_axis = {{axis}}"

construct_template = f"opt_{{{variable_slot}}} = self.{{{variable_slot}}}" \
f"({{{ExchangeMessageKeywords.VariableScope.value.INPUTS.value}}}, " \
f"self.{{{variable_slot}}}_axis)"

template = {
variable_slot: {
TemplateKeywords.INIT.value: [init_template, init_tensor],
TemplateKeywords.CONSTRUCT.value: [construct_template]
}
}
exchange_msg = {
variable_slot: {
ExchangeMessageKeywords.VariableScope.value.OPERATION.value: op,
ExchangeMessageKeywords.VariableScope.value.VARIABLE_NAME.value: None,
ExchangeMessageKeywords.VariableScope.value.OUTPUT_TYPE.value:
ExchangeMessageKeywords.VariableScope.value.TSR_TYPE.value,
ExchangeMessageKeywords.VariableScope.value.INPUTS.value: [],
ExchangeMessageKeywords.VariableScope.value.ARGS.value: args,
ExchangeMessageKeywords.VariableScope.value.WEIGHTS.value: [],
ExchangeMessageKeywords.VariableScope.value.TRAINABLE_PARAMS.value: {}
}
}
outputs_list = [f"opt_{{{variable_slot}}}"]
outputs_mapping = ((0, 0),)
return template, exchange_msg, outputs_list, outputs_mapping

+ 2
- 1
mindinsight/mindconverter/graph_based_converter/mapper/onnx_to_ms.json View File

@@ -40,5 +40,6 @@
"onnx::Floor": "mindinsight.mindconverter.graph_based_converter.mapper.impl.ops.floor_mapper.FloorMapper",
"onnx::CumSum": "mindinsight.mindconverter.graph_based_converter.mapper.impl.ops.cumsum_mapper.CumSumMapper",
"onnx::Sin": "mindinsight.mindconverter.graph_based_converter.mapper.impl.ops.sin_mapper.SinMapper",
"onnx::Cos": "mindinsight.mindconverter.graph_based_converter.mapper.impl.ops.cos_mapper.CosMapper"
"onnx::Cos": "mindinsight.mindconverter.graph_based_converter.mapper.impl.ops.cos_mapper.CosMapper",
"onnx::ReduceSum": "mindinsight.mindconverter.graph_based_converter.mapper.impl.ops.reduce_sum_mapper.ReduceSumMapper"
}

Loading…
Cancel
Save