Browse Source

Add CumSum op mapper.

pull/1247/head
fuyu0830 4 years ago
parent
commit
63511e778b
3 changed files with 134 additions and 1 deletions
  1. +79
    -0
      mindinsight/mindconverter/graph_based_converter/mapper/impl/ops/cumsum_mapper.py
  2. +2
    -1
      mindinsight/mindconverter/graph_based_converter/mapper/onnx_to_ms.json
  3. +53
    -0
      tests/ut/mindconverter/graph_based_converter/mapper/test_cumsum.py

+ 79
- 0
mindinsight/mindconverter/graph_based_converter/mapper/impl/ops/cumsum_mapper.py View File

@@ -0,0 +1,79 @@
# Copyright 2021 Huawei Technologies Co., Ltd.All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""CumSum mapper module."""
from mindinsight.mindconverter.graph_based_converter.constant import ExchangeMessageKeywords, TemplateKeywords
from mindinsight.mindconverter.graph_based_converter.mapper.base import ONNXToMindSporeMapper


class CumSumMapper(ONNXToMindSporeMapper):
"""CumSum mapper."""

@staticmethod
def _operation_name_in_ms(*args, **kwargs):
return "P.CumSum"

@staticmethod
def _convert_params(**kwargs):
params = kwargs['params']
exclusive = params.get("exclusive", 0)
reverse = params.get("reverse", 0)
exclusive = exclusive == 1
reverse = reverse == 1
return {"exclusive": exclusive, "reverse": reverse}

@staticmethod
def _convert_trained_weights(**kwargs):
return dict()

@staticmethod
def _generate_snippet_template(**kwargs):
op = kwargs.get("operation")
args = kwargs.get("converted_params", dict())
raw_weights = kwargs.get("weights", list())
if not op or not raw_weights:
raise ValueError("Can not get MindSpore operation name.")

variable_slot = "var_0"
init_template = f"self.{{{variable_slot}}} = {op}({', '.join(['%s={%s}' % (p, p) for p in args])})"

axis = CumSumMapper._find_val_by_index(0, raw_weights)
args["axis"] = int(axis)
init_axis = f"self.{{{variable_slot}}}_axis = {{axis}}"

construct_template = f"opt_{{{variable_slot}}} = self.{{{variable_slot}}}" \
f"({{{ExchangeMessageKeywords.VariableScope.value.INPUTS.value}}}, " \
f"self.{{{variable_slot}}}_axis)"

template = {
variable_slot: {
TemplateKeywords.INIT.value: [init_template, init_axis],
TemplateKeywords.CONSTRUCT.value: [construct_template]
}
}

exchange_msg = {
variable_slot: {
ExchangeMessageKeywords.VariableScope.value.OPERATION.value: op,
ExchangeMessageKeywords.VariableScope.value.VARIABLE_NAME.value: None,
ExchangeMessageKeywords.VariableScope.value.OUTPUT_TYPE.value:
ExchangeMessageKeywords.VariableScope.value.TSR_TYPE.value,
ExchangeMessageKeywords.VariableScope.value.INPUTS.value: [],
ExchangeMessageKeywords.VariableScope.value.ARGS.value: args,
ExchangeMessageKeywords.VariableScope.value.WEIGHTS.value: dict()
}
}
outputs_list = [f"opt_{{{variable_slot}}}"]
outputs_mapping = ((0, 0),)
return template, exchange_msg, outputs_list, outputs_mapping

+ 2
- 1
mindinsight/mindconverter/graph_based_converter/mapper/onnx_to_ms.json View File

@@ -37,5 +37,6 @@
"onnx::Tanh": "mindinsight.mindconverter.graph_based_converter.mapper.impl.nn.tanh_mapper.TanhMapper", "onnx::Tanh": "mindinsight.mindconverter.graph_based_converter.mapper.impl.nn.tanh_mapper.TanhMapper",
"onnx::LSTM": "mindinsight.mindconverter.graph_based_converter.mapper.impl.nn.lstm_mapper.LSTMMapper", "onnx::LSTM": "mindinsight.mindconverter.graph_based_converter.mapper.impl.nn.lstm_mapper.LSTMMapper",
"onnx::Squeeze": "mindinsight.mindconverter.graph_based_converter.mapper.impl.ops.squeeze_mapper.SqueezeMapper", "onnx::Squeeze": "mindinsight.mindconverter.graph_based_converter.mapper.impl.ops.squeeze_mapper.SqueezeMapper",
"onnx::Floor": "mindinsight.mindconverter.graph_based_converter.mapper.impl.ops.floor_mapper.FloorMapper"
"onnx::Floor": "mindinsight.mindconverter.graph_based_converter.mapper.impl.ops.floor_mapper.FloorMapper",
"onnx::CumSum": "mindinsight.mindconverter.graph_based_converter.mapper.impl.ops.cumsum_mapper.CumSumMapper"
} }

+ 53
- 0
tests/ut/mindconverter/graph_based_converter/mapper/test_cumsum.py View File

@@ -0,0 +1,53 @@
# Copyright 2021 Huawei Technologies Co., Ltd.All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Test CumSum operator mapper."""
from unittest import TestCase
import numpy as np
from mindinsight.mindconverter.graph_based_converter.mapper.base import ONNXToMindSporeMapper
from mindinsight.mindconverter.graph_based_converter.common.code_fragment import NewFragment
from mindinsight.mindconverter.graph_based_converter.third_party_graph.onnx_utils import NodeWeight


class TestCumSum(TestCase):
"""Tester of CumSun."""

def test_mapper(self):
"""Test code generation."""
onnx_info = {
"op_name": "onnx::CumSum",
"attributes": {},
"weights": [NodeWeight(weight_name="axis", weight_value=np.int32(0), weight_location=0)]
}
template, exchange_msg, outputs_lists, outputs_mapping = ONNXToMindSporeMapper.convert(
onnx_info['op_name'],
onnx_info['attributes'],
onnx_info['weights']
)
exchange_msg['var_0']['variable_name'] = 'cumsum_op'
exchange_msg['var_0']['inputs'] = ['x']

fragment = NewFragment(data_entity=exchange_msg, code_template=template, outputs=outputs_lists,
outputs_mapping=outputs_mapping)

code = fragment()
init_code = code[0]
construct_code = code[1]
self.assertEqual(init_code, [
"self.cumsum_op = P.CumSum(exclusive=False, reverse=False)",
"self.cumsum_op_axis = 0"
])
self.assertEqual(construct_code, [
"opt_cumsum_op = self.cumsum_op(x, self.cumsum_op_axis)"
])

Loading…
Cancel
Save