From 63511e778b4f1fe6b2a3c21fe6280b6092d7a582 Mon Sep 17 00:00:00 2001 From: fuyu0830 <1534155682@qq.com> Date: Sat, 20 Mar 2021 17:06:16 +0800 Subject: [PATCH] Add CumSum op mapper. --- .../mapper/impl/ops/cumsum_mapper.py | 79 +++++++++++++++++++ .../mapper/onnx_to_ms.json | 3 +- .../mapper/test_cumsum.py | 53 +++++++++++++ 3 files changed, 134 insertions(+), 1 deletion(-) create mode 100644 mindinsight/mindconverter/graph_based_converter/mapper/impl/ops/cumsum_mapper.py create mode 100644 tests/ut/mindconverter/graph_based_converter/mapper/test_cumsum.py diff --git a/mindinsight/mindconverter/graph_based_converter/mapper/impl/ops/cumsum_mapper.py b/mindinsight/mindconverter/graph_based_converter/mapper/impl/ops/cumsum_mapper.py new file mode 100644 index 00000000..9789c172 --- /dev/null +++ b/mindinsight/mindconverter/graph_based_converter/mapper/impl/ops/cumsum_mapper.py @@ -0,0 +1,79 @@ +# Copyright 2021 Huawei Technologies Co., Ltd.All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""CumSum mapper module.""" +from mindinsight.mindconverter.graph_based_converter.constant import ExchangeMessageKeywords, TemplateKeywords +from mindinsight.mindconverter.graph_based_converter.mapper.base import ONNXToMindSporeMapper + + +class CumSumMapper(ONNXToMindSporeMapper): + """CumSum mapper.""" + + @staticmethod + def _operation_name_in_ms(*args, **kwargs): + return "P.CumSum" + + @staticmethod + def _convert_params(**kwargs): + params = kwargs['params'] + exclusive = params.get("exclusive", 0) + reverse = params.get("reverse", 0) + exclusive = exclusive == 1 + reverse = reverse == 1 + return {"exclusive": exclusive, "reverse": reverse} + + @staticmethod + def _convert_trained_weights(**kwargs): + return dict() + + @staticmethod + def _generate_snippet_template(**kwargs): + op = kwargs.get("operation") + args = kwargs.get("converted_params", dict()) + raw_weights = kwargs.get("weights", list()) + if not op or not raw_weights: + raise ValueError("Can not get MindSpore operation name.") + + variable_slot = "var_0" + init_template = f"self.{{{variable_slot}}} = {op}({', '.join(['%s={%s}' % (p, p) for p in args])})" + + axis = CumSumMapper._find_val_by_index(0, raw_weights) + args["axis"] = int(axis) + init_axis = f"self.{{{variable_slot}}}_axis = {{axis}}" + + construct_template = f"opt_{{{variable_slot}}} = self.{{{variable_slot}}}" \ + f"({{{ExchangeMessageKeywords.VariableScope.value.INPUTS.value}}}, " \ + f"self.{{{variable_slot}}}_axis)" + + template = { + variable_slot: { + TemplateKeywords.INIT.value: [init_template, init_axis], + TemplateKeywords.CONSTRUCT.value: [construct_template] + } + } + + exchange_msg = { + variable_slot: { + ExchangeMessageKeywords.VariableScope.value.OPERATION.value: op, + ExchangeMessageKeywords.VariableScope.value.VARIABLE_NAME.value: None, + ExchangeMessageKeywords.VariableScope.value.OUTPUT_TYPE.value: + ExchangeMessageKeywords.VariableScope.value.TSR_TYPE.value, + ExchangeMessageKeywords.VariableScope.value.INPUTS.value: [], + ExchangeMessageKeywords.VariableScope.value.ARGS.value: args, + ExchangeMessageKeywords.VariableScope.value.WEIGHTS.value: dict() + } + } + outputs_list = [f"opt_{{{variable_slot}}}"] + outputs_mapping = ((0, 0),) + return template, exchange_msg, outputs_list, outputs_mapping diff --git a/mindinsight/mindconverter/graph_based_converter/mapper/onnx_to_ms.json b/mindinsight/mindconverter/graph_based_converter/mapper/onnx_to_ms.json index 589499ef..edb57157 100644 --- a/mindinsight/mindconverter/graph_based_converter/mapper/onnx_to_ms.json +++ b/mindinsight/mindconverter/graph_based_converter/mapper/onnx_to_ms.json @@ -37,5 +37,6 @@ "onnx::Tanh": "mindinsight.mindconverter.graph_based_converter.mapper.impl.nn.tanh_mapper.TanhMapper", "onnx::LSTM": "mindinsight.mindconverter.graph_based_converter.mapper.impl.nn.lstm_mapper.LSTMMapper", "onnx::Squeeze": "mindinsight.mindconverter.graph_based_converter.mapper.impl.ops.squeeze_mapper.SqueezeMapper", - "onnx::Floor": "mindinsight.mindconverter.graph_based_converter.mapper.impl.ops.floor_mapper.FloorMapper" + "onnx::Floor": "mindinsight.mindconverter.graph_based_converter.mapper.impl.ops.floor_mapper.FloorMapper", + "onnx::CumSum": "mindinsight.mindconverter.graph_based_converter.mapper.impl.ops.cumsum_mapper.CumSumMapper" } \ No newline at end of file diff --git a/tests/ut/mindconverter/graph_based_converter/mapper/test_cumsum.py b/tests/ut/mindconverter/graph_based_converter/mapper/test_cumsum.py new file mode 100644 index 00000000..1026d51e --- /dev/null +++ b/tests/ut/mindconverter/graph_based_converter/mapper/test_cumsum.py @@ -0,0 +1,53 @@ +# Copyright 2021 Huawei Technologies Co., Ltd.All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Test CumSum operator mapper.""" +from unittest import TestCase +import numpy as np +from mindinsight.mindconverter.graph_based_converter.mapper.base import ONNXToMindSporeMapper +from mindinsight.mindconverter.graph_based_converter.common.code_fragment import NewFragment +from mindinsight.mindconverter.graph_based_converter.third_party_graph.onnx_utils import NodeWeight + + +class TestCumSum(TestCase): + """Tester of CumSun.""" + + def test_mapper(self): + """Test code generation.""" + onnx_info = { + "op_name": "onnx::CumSum", + "attributes": {}, + "weights": [NodeWeight(weight_name="axis", weight_value=np.int32(0), weight_location=0)] + } + template, exchange_msg, outputs_lists, outputs_mapping = ONNXToMindSporeMapper.convert( + onnx_info['op_name'], + onnx_info['attributes'], + onnx_info['weights'] + ) + exchange_msg['var_0']['variable_name'] = 'cumsum_op' + exchange_msg['var_0']['inputs'] = ['x'] + + fragment = NewFragment(data_entity=exchange_msg, code_template=template, outputs=outputs_lists, + outputs_mapping=outputs_mapping) + + code = fragment() + init_code = code[0] + construct_code = code[1] + self.assertEqual(init_code, [ + "self.cumsum_op = P.CumSum(exclusive=False, reverse=False)", + "self.cumsum_op_axis = 0" + ]) + self.assertEqual(construct_code, [ + "opt_cumsum_op = self.cumsum_op(x, self.cumsum_op_axis)" + ])