diff --git a/mindinsight/mindconverter/graph_based_converter/mapper/impl/nn/drop_out_mapper.py b/mindinsight/mindconverter/graph_based_converter/mapper/impl/nn/drop_out_mapper.py new file mode 100644 index 00000000..622a90b4 --- /dev/null +++ b/mindinsight/mindconverter/graph_based_converter/mapper/impl/nn/drop_out_mapper.py @@ -0,0 +1,37 @@ +# Copyright 2021 Huawei Technologies Co., Ltd.All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Mapper module.""" +from mindinsight.mindconverter.graph_based_converter.mapper.base import ONNXToMindSporeMapper + + +class DropoutMapper(ONNXToMindSporeMapper): + """Dropout mapper.""" + + @staticmethod + def _operation_name_in_ms(*args, **kwargs): + return "nn.Dropout" + + @staticmethod + def _convert_params(**kwargs): + params = kwargs["params"] + if params.get("training_mode", False): + ratio = 1.0 - params.get('ratio', 0.5) + else: + ratio = 1.0 + return {'keep_prob': ratio} + + @staticmethod + def _convert_trained_weights(**kwargs): + return dict() diff --git a/mindinsight/mindconverter/graph_based_converter/mapper/onnx_to_ms.json b/mindinsight/mindconverter/graph_based_converter/mapper/onnx_to_ms.json index c4e660cd..51ff84eb 100644 --- a/mindinsight/mindconverter/graph_based_converter/mapper/onnx_to_ms.json +++ b/mindinsight/mindconverter/graph_based_converter/mapper/onnx_to_ms.json @@ -41,5 +41,6 @@ "onnx::CumSum": "mindinsight.mindconverter.graph_based_converter.mapper.impl.ops.cumsum_mapper.CumSumMapper", "onnx::Sin": "mindinsight.mindconverter.graph_based_converter.mapper.impl.ops.sin_mapper.SinMapper", "onnx::Cos": "mindinsight.mindconverter.graph_based_converter.mapper.impl.ops.cos_mapper.CosMapper", - "onnx::ReduceSum": "mindinsight.mindconverter.graph_based_converter.mapper.impl.ops.reduce_sum_mapper.ReduceSumMapper" + "onnx::ReduceSum": "mindinsight.mindconverter.graph_based_converter.mapper.impl.ops.reduce_sum_mapper.ReduceSumMapper", + "onnx::Dropout": "mindinsight.mindconverter.graph_based_converter.mapper.impl.nn.drop_out_mapper.DropoutMapper" } \ No newline at end of file