| @@ -40,6 +40,7 @@ | |||
| #include "dataset/kernels/data/fill_op.h" | |||
| #include "dataset/kernels/data/mask_op.h" | |||
| #include "dataset/kernels/data/slice_op.h" | |||
| #include "mindspore/ccsrc/dataset/text/kernels/truncate_sequence_pair_op.h" | |||
| #include "dataset/kernels/data/type_cast_op.h" | |||
| #include "dataset/engine/datasetops/source/cifar_op.h" | |||
| #include "dataset/engine/datasetops/source/image_folder_op.h" | |||
| @@ -384,7 +385,7 @@ void bindTensorOps2(py::module *m) { | |||
| *m, "FillOp", "Tensor operation to return tensor filled with same value as input fill value.") | |||
| .def(py::init<std::shared_ptr<Tensor>>()); | |||
| (void)py::class_<SliceOp, TensorOp, std::shared_ptr<SliceOp>>(*m, "SliceOp", "Tensor Slice operation.") | |||
| (void)py::class_<SliceOp, TensorOp, std::shared_ptr<SliceOp>>(*m, "SliceOp", "Tensor slice operation.") | |||
| .def(py::init<bool>()) | |||
| .def(py::init([](const py::list &py_list) { | |||
| std::vector<dsize_t> c_list; | |||
| @@ -425,9 +426,13 @@ void bindTensorOps2(py::module *m) { | |||
| .export_values(); | |||
| (void)py::class_<MaskOp, TensorOp, std::shared_ptr<MaskOp>>(*m, "MaskOp", | |||
| "Tensor operation mask using relational comparator") | |||
| "Tensor mask operation using relational comparator") | |||
| .def(py::init<RelationalOp, std::shared_ptr<Tensor>, DataType>()); | |||
| (void)py::class_<TruncateSequencePairOp, TensorOp, std::shared_ptr<TruncateSequencePairOp>>( | |||
| *m, "TruncateSequencePairOp", "Tensor operation to truncate two tensors to a max_length") | |||
| .def(py::init<int64_t>()); | |||
| (void)py::class_<RandomRotationOp, TensorOp, std::shared_ptr<RandomRotationOp>>( | |||
| *m, "RandomRotationOp", | |||
| "Tensor operation to apply RandomRotation." | |||
| @@ -7,4 +7,5 @@ add_library(kernels-data OBJECT | |||
| to_float16_op.cc | |||
| fill_op.cc | |||
| slice_op.cc | |||
| mask_op.cc) | |||
| mask_op.cc | |||
| ) | |||
| @@ -33,7 +33,7 @@ Status MaskOp::Compute(const std::shared_ptr<Tensor> &input, std::shared_ptr<Ten | |||
| if (type_ != DataType::DE_BOOL) { | |||
| RETURN_IF_NOT_OK(cast_->Compute(temp_output, output)); | |||
| } else { | |||
| *output = temp_output; | |||
| *output = std::move(temp_output); | |||
| } | |||
| return Status::OK(); | |||
| @@ -17,5 +17,6 @@ add_library(text-kernels OBJECT | |||
| unicode_char_tokenizer_op.cc | |||
| ngram_op.cc | |||
| wordpiece_tokenizer_op.cc | |||
| truncate_sequence_pair_op.cc | |||
| ${ICU_DEPEND_FILES} | |||
| ) | |||
| @@ -0,0 +1,66 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "dataset/text/kernels/truncate_sequence_pair_op.h" | |||
| #include "dataset/core/tensor.h" | |||
| #include "dataset/kernels/tensor_op.h" | |||
| #include "dataset/kernels/data/slice_op.h" | |||
| namespace mindspore { | |||
| namespace dataset { | |||
| Status TruncateSequencePairOp::Compute(const TensorRow &input, TensorRow *output) { | |||
| IO_CHECK_VECTOR(input, output); | |||
| CHECK_FAIL_RETURN_UNEXPECTED(input.size() == 2, "Number of inputs should be two."); | |||
| std::shared_ptr<Tensor> seq1 = input[0]; | |||
| std::shared_ptr<Tensor> seq2 = input[1]; | |||
| CHECK_FAIL_RETURN_UNEXPECTED(seq1->shape().Rank() == 1 && seq2->shape().Rank() == 1, | |||
| "Both sequences should be of rank 1"); | |||
| dsize_t length1 = seq1->shape()[0]; | |||
| dsize_t length2 = seq2->shape()[0]; | |||
| dsize_t outLength1 = length1; | |||
| dsize_t outLength2 = length2; | |||
| dsize_t total = length1 + length2; | |||
| while (total > max_length_) { | |||
| if (outLength1 > outLength2) | |||
| outLength1--; | |||
| else | |||
| outLength2--; | |||
| total--; | |||
| } | |||
| std::shared_ptr<Tensor> outSeq1; | |||
| if (length1 != outLength1) { | |||
| std::unique_ptr<SliceOp> slice1(new SliceOp(Slice(outLength1 - length1))); | |||
| RETURN_IF_NOT_OK(slice1->Compute(seq1, &outSeq1)); | |||
| } else { | |||
| outSeq1 = std::move(seq1); | |||
| } | |||
| std::shared_ptr<Tensor> outSeq2; | |||
| if (length2 != outLength2) { | |||
| std::unique_ptr<SliceOp> slice2(new SliceOp(Slice(outLength2 - length2))); | |||
| RETURN_IF_NOT_OK(slice2->Compute(seq2, &outSeq2)); | |||
| } else { | |||
| outSeq2 = std::move(seq2); | |||
| } | |||
| output->push_back(outSeq1); | |||
| output->push_back(outSeq2); | |||
| return Status::OK(); | |||
| } | |||
| } // namespace dataset | |||
| } // namespace mindspore | |||
| @@ -0,0 +1,48 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef DATASET_KERNELS_DATA_TRUNCATE_SEQUENCE_PAIR_OP_H_ | |||
| #define DATASET_KERNELS_DATA_TRUNCATE_SEQUENCE_PAIR_OP_H_ | |||
| #include <algorithm> | |||
| #include <memory> | |||
| #include <string> | |||
| #include <utility> | |||
| #include <vector> | |||
| #include "dataset/core/tensor.h" | |||
| #include "dataset/kernels/tensor_op.h" | |||
| #include "dataset/kernels/data/type_cast_op.h" | |||
| #include "dataset/kernels/data/data_utils.h" | |||
| namespace mindspore { | |||
| namespace dataset { | |||
| class TruncateSequencePairOp : public TensorOp { | |||
| public: | |||
| explicit TruncateSequencePairOp(dsize_t length) : max_length_(length) {} | |||
| ~TruncateSequencePairOp() override = default; | |||
| void Print(std::ostream &out) const override { out << "TruncateSequencePairOp"; } | |||
| Status Compute(const TensorRow &input, TensorRow *output) override; | |||
| private: | |||
| dsize_t max_length_; | |||
| }; | |||
| } // namespace dataset | |||
| } // namespace mindspore | |||
| #endif // DATASET_KERNELS_DATA_TRUNCATE_SEQUENCE_PAIR_OP_H_ | |||
| @@ -16,12 +16,12 @@ | |||
| mindspore.dataset.text | |||
| """ | |||
| import platform | |||
| from .transforms import Lookup, JiebaTokenizer, UnicodeCharTokenizer, Ngram, WordpieceTokenizer | |||
| from .transforms import Lookup, JiebaTokenizer, UnicodeCharTokenizer, Ngram, WordpieceTokenizer, TruncateSequencePair | |||
| from .utils import to_str, to_bytes, JiebaMode, Vocab, NormalizeForm | |||
| __all__ = [ | |||
| "Lookup", "JiebaTokenizer", "UnicodeCharTokenizer", "Ngram", | |||
| "to_str", "to_bytes", "JiebaMode", "Vocab", "WordpieceTokenizer" | |||
| "to_str", "to_bytes", "JiebaMode", "Vocab", "WordpieceTokenizer", "TruncateSequencePair" | |||
| ] | |||
| if platform.system().lower() != 'windows': | |||
| @@ -23,7 +23,7 @@ import mindspore._c_dataengine as cde | |||
| from .utils import JiebaMode, NormalizeForm | |||
| from .validators import check_lookup, check_jieba_add_dict, \ | |||
| check_jieba_add_word, check_jieba_init, check_ngram | |||
| check_jieba_add_word, check_jieba_init, check_ngram, check_pair_truncate | |||
| class Lookup(cde.LookupOp): | |||
| @@ -344,3 +344,31 @@ if platform.system().lower() != 'windows': | |||
| self.preserve_unused_token = preserve_unused_token | |||
| super().__init__(self.vocab, self.suffix_indicator, self.max_bytes_per_token, self.unknown_token, | |||
| self.lower_case, self.keep_whitespace, self.normalization_form, self.preserve_unused_token) | |||
| class TruncateSequencePair(cde.TruncateSequencePairOp): | |||
| """ | |||
| Truncate a pair of rank-1 tensors such that the total length is less than max_length. | |||
| This operation takes two input tensors and returns two output Tenors. | |||
| Args: | |||
| max_length(int): Maximum length required. | |||
| Examples: | |||
| >>> # Data before | |||
| >>> # | col1 | col2 | | |||
| >>> # +---------+---------| | |||
| >>> # | [1,2,3] | [4,5] | | |||
| >>> # +---------+---------+ | |||
| >>> data = data.map(operations=TruncateSequencePair(4)) | |||
| >>> # Data after | |||
| >>> # | col1 | col2 | | |||
| >>> # +---------+---------+ | |||
| >>> # | [1,2] | [4,5] | | |||
| >>> # +---------+---------+ | |||
| """ | |||
| @check_pair_truncate | |||
| def __init__(self, max_length): | |||
| super().__init__(max_length) | |||
| @@ -20,7 +20,7 @@ from functools import wraps | |||
| import mindspore._c_dataengine as cde | |||
| from ..transforms.validators import check_uint32 | |||
| from ..transforms.validators import check_uint32, check_pos_int64 | |||
| def check_lookup(method): | |||
| @@ -298,3 +298,22 @@ def check_ngram(method): | |||
| return method(self, **kwargs) | |||
| return new_method | |||
| def check_pair_truncate(method): | |||
| """Wrapper method to check the parameters of number of pair truncate.""" | |||
| @wraps(method) | |||
| def new_method(self, *args, **kwargs): | |||
| max_length = (list(args) + [None])[0] | |||
| if "max_length" in kwargs: | |||
| max_length = kwargs.get("max_length") | |||
| if max_length is None: | |||
| raise ValueError("max_length is not provided.") | |||
| check_pos_int64(max_length) | |||
| kwargs["max_length"] = max_length | |||
| return method(self, **kwargs) | |||
| return new_method | |||
| @@ -216,7 +216,7 @@ def check_slice_op(method): | |||
| def check_mask_op(method): | |||
| """Wrapper method to check the parameters of slice.""" | |||
| """Wrapper method to check the parameters of mask.""" | |||
| @wraps(method) | |||
| def new_method(self, *args, **kwargs): | |||
| @@ -1,5 +1,5 @@ | |||
| /** | |||
| * Copyright 2019 Huawei Technologies Co., Ltd | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| @@ -0,0 +1,51 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include <memory> | |||
| #include <string> | |||
| #include "dataset/core/client.h" | |||
| #include "common/common.h" | |||
| #include "gtest/gtest.h" | |||
| #include "securec.h" | |||
| #include "dataset/core/tensor.h" | |||
| #include "mindspore/ccsrc/dataset/text/kernels/truncate_sequence_pair_op.h" | |||
| using namespace mindspore::dataset; | |||
| namespace py = pybind11; | |||
| class MindDataTestTruncatePairOp : public UT::Common { | |||
| public: | |||
| MindDataTestTruncatePairOp() {} | |||
| void SetUp() { GlobalInit(); } | |||
| }; | |||
| TEST_F(MindDataTestTruncatePairOp, Basics) { | |||
| std::shared_ptr<Tensor> t1; | |||
| Tensor::CreateTensor(&t1, std::vector<uint32_t>({1, 2, 3})); | |||
| std::shared_ptr<Tensor> t2; | |||
| Tensor::CreateTensor(&t2, std::vector<uint32_t>({4, 5})); | |||
| TensorRow in({t1, t2}); | |||
| std::shared_ptr<TruncateSequencePairOp> op = std::make_shared<TruncateSequencePairOp>(4); | |||
| TensorRow out; | |||
| ASSERT_TRUE(op->Compute(in, &out).IsOk()); | |||
| std::shared_ptr<Tensor> out1; | |||
| Tensor::CreateTensor(&out1, std::vector<uint32_t>({1, 2})); | |||
| std::shared_ptr<Tensor> out2; | |||
| Tensor::CreateTensor(&out2, std::vector<uint32_t>({4, 5})); | |||
| ASSERT_EQ(*out1, *out[0]); | |||
| ASSERT_EQ(*out2, *out[1]); | |||
| } | |||
| @@ -0,0 +1,67 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================== | |||
| """ | |||
| Testing Mask op in DE | |||
| """ | |||
| import numpy as np | |||
| import pytest | |||
| import mindspore.dataset as ds | |||
| import mindspore.dataset.text as text | |||
| def compare(in1, in2, length, out1, out2): | |||
| data = ds.NumpySlicesDataset({"s1": [in1], "s2": [in2]}) | |||
| data = data.map(input_columns=["s1", "s2"], operations=text.TruncateSequencePair(length)) | |||
| for d in data.create_dict_iterator(): | |||
| np.testing.assert_array_equal(out1, d["s1"]) | |||
| np.testing.assert_array_equal(out2, d["s2"]) | |||
| def test_basics(): | |||
| compare(in1=[1, 2, 3], in2=[4, 5], length=4, out1=[1, 2], out2=[4, 5]) | |||
| compare(in1=[1, 2], in2=[4, 5], length=4, out1=[1, 2], out2=[4, 5]) | |||
| compare(in1=[1], in2=[4], length=4, out1=[1], out2=[4]) | |||
| compare(in1=[1, 2, 3, 4], in2=[5], length=4, out1=[1, 2, 3], out2=[5]) | |||
| compare(in1=[1, 2, 3, 4], in2=[5, 6, 7, 8], length=4, out1=[1, 2], out2=[5, 6]) | |||
| def test_basics_odd(): | |||
| compare(in1=[1, 2, 3], in2=[4, 5], length=3, out1=[1, 2], out2=[4]) | |||
| compare(in1=[1, 2], in2=[4, 5], length=3, out1=[1, 2], out2=[4]) | |||
| compare(in1=[1], in2=[4], length=5, out1=[1], out2=[4]) | |||
| compare(in1=[1, 2, 3, 4], in2=[5], length=3, out1=[1, 2], out2=[5]) | |||
| compare(in1=[1, 2, 3, 4], in2=[5, 6, 7, 8], length=3, out1=[1, 2], out2=[5]) | |||
| def test_basics_str(): | |||
| compare(in1=[b"1", b"2", b"3"], in2=[4, 5], length=4, out1=[b"1", b"2"], out2=[4, 5]) | |||
| compare(in1=[b"1", b"2"], in2=[b"4", b"5"], length=4, out1=[b"1", b"2"], out2=[b"4", b"5"]) | |||
| compare(in1=[b"1"], in2=[4], length=4, out1=[b"1"], out2=[4]) | |||
| compare(in1=[b"1", b"2", b"3", b"4"], in2=[b"5"], length=4, out1=[b"1", b"2", b"3"], out2=[b"5"]) | |||
| compare(in1=[b"1", b"2", b"3", b"4"], in2=[5, 6, 7, 8], length=4, out1=[b"1", b"2"], out2=[5, 6]) | |||
| def test_exceptions(): | |||
| with pytest.raises(RuntimeError) as info: | |||
| compare(in1=[1, 2, 3, 4], in2=[5, 6, 7, 8], length=1, out1=[1, 2], out2=[5]) | |||
| assert "Indices are empty, generated tensor would be empty" in str(info.value) | |||
| if __name__ == "__main__": | |||
| test_basics() | |||
| test_basics_odd() | |||
| test_basics_str() | |||
| test_exceptions() | |||