From a66fc125cc2a22260aad8fe0e9d61edc3f3e6e14 Mon Sep 17 00:00:00 2001 From: shenghong96 Date: Tue, 20 Apr 2021 10:05:11 +0800 Subject: [PATCH] add reduce_sum mapper --- .../mapper/impl/nn/batch_norm_mapper.py | 11 ++- .../mapper/impl/ops/reduce_sum_mapper.py | 76 +++++++++++++++++++ .../mapper/onnx_to_ms.json | 3 +- 3 files changed, 88 insertions(+), 2 deletions(-) create mode 100644 mindinsight/mindconverter/graph_based_converter/mapper/impl/ops/reduce_sum_mapper.py diff --git a/mindinsight/mindconverter/graph_based_converter/mapper/impl/nn/batch_norm_mapper.py b/mindinsight/mindconverter/graph_based_converter/mapper/impl/nn/batch_norm_mapper.py index eee25de4..d9917854 100644 --- a/mindinsight/mindconverter/graph_based_converter/mapper/impl/nn/batch_norm_mapper.py +++ b/mindinsight/mindconverter/graph_based_converter/mapper/impl/nn/batch_norm_mapper.py @@ -15,13 +15,22 @@ """Mapper module.""" from mindinsight.mindconverter.graph_based_converter.mapper.base import ONNXToMindSporeMapper +# The segmentation of BatchNorm's dim attribute. +DIM_SEGMENT = 2 + class BatchNormMapper(ONNXToMindSporeMapper): """BatchNorm mapper.""" @staticmethod def _operation_name_in_ms(*args, **kwargs): - dim = len(kwargs['params']['output_shape']) - 2 + output_shape = len(kwargs['params']['output_shape']) + if output_shape == DIM_SEGMENT: + dim = 1 + elif output_shape > DIM_SEGMENT: + dim = len(kwargs['params']['output_shape']) - 2 + else: + return None return f"nn.BatchNorm{dim}d" @staticmethod diff --git a/mindinsight/mindconverter/graph_based_converter/mapper/impl/ops/reduce_sum_mapper.py b/mindinsight/mindconverter/graph_based_converter/mapper/impl/ops/reduce_sum_mapper.py new file mode 100644 index 00000000..05099d38 --- /dev/null +++ b/mindinsight/mindconverter/graph_based_converter/mapper/impl/ops/reduce_sum_mapper.py @@ -0,0 +1,76 @@ +# Copyright 2021 Huawei Technologies Co., Ltd.All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Mapper module.""" +from mindinsight.mindconverter.graph_based_converter.constant import ExchangeMessageKeywords, TemplateKeywords +from mindinsight.mindconverter.graph_based_converter.mapper.base import ONNXToMindSporeMapper + + +class ReduceSumMapper(ONNXToMindSporeMapper): + """ReduceSum mapper.""" + + @staticmethod + def _operation_name_in_ms(*args, **kwargs): + return "P.ReduceSum" + + @staticmethod + def _convert_params(**kwargs): + params = kwargs['params'] + keep_dims = not params.get('keepdims', 1) == 0 + return {'keep_dims': keep_dims} + + @staticmethod + def _convert_trained_weights(**kwargs): + return dict() + + @staticmethod + def _generate_snippet_template(**kwargs): + op = kwargs.get("operation") + args = kwargs.get("converted_params") + raw_params = kwargs.get("raw_params") + if raw_params.get('axes'): + axis = raw_params['axes'][0] if len(raw_params['axes']) == 1 else tuple(raw_params['axes']) + else: + axis = tuple() + variable_slot = "var_0" + init_template = f"self.{{{variable_slot}}} = {op}({', '.join(['%s={%s}' % (p, p) for p in args])})" + + args["axis"] = axis + init_tensor = f"self.{{{variable_slot}}}_axis = {{axis}}" + + construct_template = f"opt_{{{variable_slot}}} = self.{{{variable_slot}}}" \ + f"({{{ExchangeMessageKeywords.VariableScope.value.INPUTS.value}}}, " \ + f"self.{{{variable_slot}}}_axis)" + + template = { + variable_slot: { + TemplateKeywords.INIT.value: [init_template, init_tensor], + TemplateKeywords.CONSTRUCT.value: [construct_template] + } + } + exchange_msg = { + variable_slot: { + ExchangeMessageKeywords.VariableScope.value.OPERATION.value: op, + ExchangeMessageKeywords.VariableScope.value.VARIABLE_NAME.value: None, + ExchangeMessageKeywords.VariableScope.value.OUTPUT_TYPE.value: + ExchangeMessageKeywords.VariableScope.value.TSR_TYPE.value, + ExchangeMessageKeywords.VariableScope.value.INPUTS.value: [], + ExchangeMessageKeywords.VariableScope.value.ARGS.value: args, + ExchangeMessageKeywords.VariableScope.value.WEIGHTS.value: [], + ExchangeMessageKeywords.VariableScope.value.TRAINABLE_PARAMS.value: {} + } + } + outputs_list = [f"opt_{{{variable_slot}}}"] + outputs_mapping = ((0, 0),) + return template, exchange_msg, outputs_list, outputs_mapping diff --git a/mindinsight/mindconverter/graph_based_converter/mapper/onnx_to_ms.json b/mindinsight/mindconverter/graph_based_converter/mapper/onnx_to_ms.json index 19dcaa9e..4f50831a 100644 --- a/mindinsight/mindconverter/graph_based_converter/mapper/onnx_to_ms.json +++ b/mindinsight/mindconverter/graph_based_converter/mapper/onnx_to_ms.json @@ -40,5 +40,6 @@ "onnx::Floor": "mindinsight.mindconverter.graph_based_converter.mapper.impl.ops.floor_mapper.FloorMapper", "onnx::CumSum": "mindinsight.mindconverter.graph_based_converter.mapper.impl.ops.cumsum_mapper.CumSumMapper", "onnx::Sin": "mindinsight.mindconverter.graph_based_converter.mapper.impl.ops.sin_mapper.SinMapper", - "onnx::Cos": "mindinsight.mindconverter.graph_based_converter.mapper.impl.ops.cos_mapper.CosMapper" + "onnx::Cos": "mindinsight.mindconverter.graph_based_converter.mapper.impl.ops.cos_mapper.CosMapper", + "onnx::ReduceSum": "mindinsight.mindconverter.graph_based_converter.mapper.impl.ops.reduce_sum_mapper.ReduceSumMapper" } \ No newline at end of file