- # Copyright 2020 Huawei Technologies Co., Ltd
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- # ==============================================================================
-
- import mindspore.dataset as ds
-
-
def test_clue():
    """
    Test CLUE with repeat, skip and so on
    """
    train_file = '../data/dataset/testCLUE/afqmc/train.json'

    # 3 rows repeated twice = 6, minus the 3 skipped = 3 remaining
    dataset = ds.CLUEDataset(train_file, task='AFQMC', usage='train', shuffle=False)
    dataset = dataset.repeat(2).skip(3)

    rows = [
        {
            'label': row['label'].item().decode("utf8"),
            'sentence1': row['sentence1'].item().decode("utf8"),
            'sentence2': row['sentence2'].item().decode("utf8"),
        }
        for row in dataset.create_dict_iterator()
    ]
    assert len(rows) == 3
-
-
def test_clue_num_shards():
    """
    Test num_shards param of CLUE dataset
    """
    train_file = '../data/dataset/testCLUE/afqmc/train.json'

    sharded = ds.CLUEDataset(train_file, task='AFQMC', usage='train', num_shards=3, shard_id=1)
    rows = [
        {
            'label': item['label'].item().decode("utf8"),
            'sentence1': item['sentence1'].item().decode("utf8"),
            'sentence2': item['sentence2'].item().decode("utf8"),
        }
        for item in sharded.create_dict_iterator()
    ]
    # 3 samples spread over 3 shards -> exactly one sample in this shard
    assert len(rows) == 1
-
-
def test_clue_num_samples():
    """
    Test num_samples param of CLUE dataset
    """
    train_file = '../data/dataset/testCLUE/afqmc/train.json'

    data = ds.CLUEDataset(train_file, task='AFQMC', usage='train', num_samples=2)
    # num_samples=2 caps the iterator at two rows
    assert sum(1 for _ in data.create_dict_iterator()) == 2
-
-
def test_textline_dataset_get_datasetsize():
    """
    Test get_dataset_size of TextFileDataset reading a CLUE data file.

    NOTE(review): despite the surrounding CLUE tests, this deliberately
    uses TextFileDataset (one row per line of the JSON-lines file),
    not CLUEDataset.
    """
    TRAIN_FILE = '../data/dataset/testCLUE/afqmc/train.json'

    data = ds.TextFileDataset(TRAIN_FILE)
    size = data.get_dataset_size()
    # the fixture file holds 3 records
    assert size == 3
-
-
def test_clue_afqmc():
    """
    Test AFQMC for train, test and evaluation
    """
    TRAIN_FILE = '../data/dataset/testCLUE/afqmc/train.json'
    TEST_FILE = '../data/dataset/testCLUE/afqmc/test.json'
    EVAL_FILE = '../data/dataset/testCLUE/afqmc/dev.json'

    def as_text(tensor):
        # decode a scalar string tensor into a python str
        return tensor.item().decode("utf8")

    # train split: label + sentence pair
    train_rows = []
    for row in ds.CLUEDataset(TRAIN_FILE, task='AFQMC', usage='train', shuffle=False).create_dict_iterator():
        train_rows.append({
            'label': as_text(row['label']),
            'sentence1': as_text(row['sentence1']),
            'sentence2': as_text(row['sentence2'])
        })
    assert len(train_rows) == 3

    # test split: id + sentence pair (no label read)
    test_rows = []
    for row in ds.CLUEDataset(TEST_FILE, task='AFQMC', usage='test', shuffle=False).create_dict_iterator():
        test_rows.append({
            'id': row['id'],
            'sentence1': as_text(row['sentence1']),
            'sentence2': as_text(row['sentence2'])
        })
    assert len(test_rows) == 3

    # evaluation split: label + sentence pair
    eval_rows = []
    for row in ds.CLUEDataset(EVAL_FILE, task='AFQMC', usage='eval', shuffle=False).create_dict_iterator():
        eval_rows.append({
            'label': as_text(row['label']),
            'sentence1': as_text(row['sentence1']),
            'sentence2': as_text(row['sentence2'])
        })
    assert len(eval_rows) == 3
-
-
def test_clue_cmnli():
    """
    Test CMNLI for train, test and evaluation.

    String fields are decoded to python str for every split; the test
    and eval sections previously stored raw tensors, which was
    inconsistent with the other CLUE tests in this file.
    """
    TRAIN_FILE = '../data/dataset/testCLUE/cmnli/train.json'
    TEST_FILE = '../data/dataset/testCLUE/cmnli/test.json'
    EVAL_FILE = '../data/dataset/testCLUE/cmnli/dev.json'

    # train
    buffer = []
    data = ds.CLUEDataset(TRAIN_FILE, task='CMNLI', usage='train', shuffle=False)
    for d in data.create_dict_iterator():
        buffer.append({
            'label': d['label'].item().decode("utf8"),
            'sentence1': d['sentence1'].item().decode("utf8"),
            'sentence2': d['sentence2'].item().decode("utf8")
        })
    assert len(buffer) == 3

    # test (no label read for this split)
    buffer = []
    data = ds.CLUEDataset(TEST_FILE, task='CMNLI', usage='test', shuffle=False)
    for d in data.create_dict_iterator():
        buffer.append({
            'id': d['id'],
            'sentence1': d['sentence1'].item().decode("utf8"),
            'sentence2': d['sentence2'].item().decode("utf8")
        })
    assert len(buffer) == 3

    # eval
    buffer = []
    data = ds.CLUEDataset(EVAL_FILE, task='CMNLI', usage='eval', shuffle=False)
    for d in data.create_dict_iterator():
        buffer.append({
            'label': d['label'].item().decode("utf8"),
            'sentence1': d['sentence1'].item().decode("utf8"),
            'sentence2': d['sentence2'].item().decode("utf8")
        })
    assert len(buffer) == 3
-
-
def test_clue_csl():
    """
    Test CSL for train, test and evaluation
    """
    TRAIN_FILE = '../data/dataset/testCLUE/csl/train.json'
    TEST_FILE = '../data/dataset/testCLUE/csl/test.json'
    EVAL_FILE = '../data/dataset/testCLUE/csl/dev.json'

    # train: id, abstract, keyword list and label
    train_rows = [
        {
            'id': row['id'],
            'abst': row['abst'].item().decode("utf8"),
            'keyword': [kw.item().decode("utf8") for kw in row['keyword']],
            'label': row['label'].item().decode("utf8")
        }
        for row in ds.CLUEDataset(TRAIN_FILE, task='CSL', usage='train', shuffle=False).create_dict_iterator()
    ]
    assert len(train_rows) == 3

    # test: label is not read for this split
    test_rows = [
        {
            'id': row['id'],
            'abst': row['abst'].item().decode("utf8"),
            'keyword': [kw.item().decode("utf8") for kw in row['keyword']],
        }
        for row in ds.CLUEDataset(TEST_FILE, task='CSL', usage='test', shuffle=False).create_dict_iterator()
    ]
    assert len(test_rows) == 3

    # eval
    eval_rows = [
        {
            'id': row['id'],
            'abst': row['abst'].item().decode("utf8"),
            'keyword': [kw.item().decode("utf8") for kw in row['keyword']],
            'label': row['label'].item().decode("utf8")
        }
        for row in ds.CLUEDataset(EVAL_FILE, task='CSL', usage='eval', shuffle=False).create_dict_iterator()
    ]
    assert len(eval_rows) == 3
-
-
def test_clue_iflytek():
    """
    Test IFLYTEK for train, test and evaluation
    """
    TRAIN_FILE = '../data/dataset/testCLUE/iflytek/train.json'
    TEST_FILE = '../data/dataset/testCLUE/iflytek/test.json'
    EVAL_FILE = '../data/dataset/testCLUE/iflytek/dev.json'

    def decode(tensor):
        # scalar string tensor -> python str
        return tensor.item().decode("utf8")

    # train
    train_rows = [
        {
            'label': decode(row['label']),
            'label_des': decode(row['label_des']),
            'sentence': decode(row['sentence']),
        }
        for row in ds.CLUEDataset(TRAIN_FILE, task='IFLYTEK', usage='train', shuffle=False).create_dict_iterator()
    ]
    assert len(train_rows) == 3

    # test: id + sentence only
    test_rows = [
        {
            'id': row['id'],
            'sentence': decode(row['sentence'])
        }
        for row in ds.CLUEDataset(TEST_FILE, task='IFLYTEK', usage='test', shuffle=False).create_dict_iterator()
    ]
    assert len(test_rows) == 3

    # eval
    eval_rows = [
        {
            'label': decode(row['label']),
            'label_des': decode(row['label_des']),
            'sentence': decode(row['sentence'])
        }
        for row in ds.CLUEDataset(EVAL_FILE, task='IFLYTEK', usage='eval', shuffle=False).create_dict_iterator()
    ]
    assert len(eval_rows) == 3
-
-
def test_clue_tnews():
    """
    Test TNEWS for train, test and evaluation
    """
    TRAIN_FILE = '../data/dataset/testCLUE/tnews/train.json'
    TEST_FILE = '../data/dataset/testCLUE/tnews/test.json'
    EVAL_FILE = '../data/dataset/testCLUE/tnews/dev.json'

    def keywords_of(row):
        # keywords may be empty; decode only when the tensor holds data,
        # otherwise keep the raw (empty) tensor as-is
        tensor = row['keywords']
        if tensor.size > 0:
            return tensor.item().decode("utf8")
        return tensor

    # train
    train_rows = []
    for row in ds.CLUEDataset(TRAIN_FILE, task='TNEWS', usage='train', shuffle=False).create_dict_iterator():
        train_rows.append({
            'label': row['label'].item().decode("utf8"),
            'label_desc': row['label_desc'].item().decode("utf8"),
            'sentence': row['sentence'].item().decode("utf8"),
            'keywords': keywords_of(row)
        })
    assert len(train_rows) == 3

    # test: id + sentence + keywords only
    test_rows = []
    for row in ds.CLUEDataset(TEST_FILE, task='TNEWS', usage='test', shuffle=False).create_dict_iterator():
        test_rows.append({
            'id': row['id'],
            'sentence': row['sentence'].item().decode("utf8"),
            'keywords': keywords_of(row)
        })
    assert len(test_rows) == 3

    # eval
    eval_rows = []
    for row in ds.CLUEDataset(EVAL_FILE, task='TNEWS', usage='eval', shuffle=False).create_dict_iterator():
        eval_rows.append({
            'label': row['label'].item().decode("utf8"),
            'label_desc': row['label_desc'].item().decode("utf8"),
            'sentence': row['sentence'].item().decode("utf8"),
            'keywords': keywords_of(row)
        })
    assert len(eval_rows) == 3
-
-
def test_clue_wsc():
    """
    Test WSC for train, test and evaluation.

    shuffle=False is now passed explicitly for each split (it was
    omitted before, unlike every other CLUE test in this file) so
    iteration order is deterministic.
    """
    TRAIN_FILE = '../data/dataset/testCLUE/wsc/train.json'
    TEST_FILE = '../data/dataset/testCLUE/wsc/test.json'
    EVAL_FILE = '../data/dataset/testCLUE/wsc/dev.json'

    # train
    buffer = []
    data = ds.CLUEDataset(TRAIN_FILE, task='WSC', usage='train', shuffle=False)
    for d in data.create_dict_iterator():
        buffer.append({
            'span1_index': d['span1_index'],
            'span2_index': d['span2_index'],
            'span1_text': d['span1_text'].item().decode("utf8"),
            'span2_text': d['span2_text'].item().decode("utf8"),
            'idx': d['idx'],
            'label': d['label'].item().decode("utf8"),
            'text': d['text'].item().decode("utf8")
        })
    assert len(buffer) == 3

    # test (no label read for this split)
    buffer = []
    data = ds.CLUEDataset(TEST_FILE, task='WSC', usage='test', shuffle=False)
    for d in data.create_dict_iterator():
        buffer.append({
            'span1_index': d['span1_index'],
            'span2_index': d['span2_index'],
            'span1_text': d['span1_text'].item().decode("utf8"),
            'span2_text': d['span2_text'].item().decode("utf8"),
            'idx': d['idx'],
            'text': d['text'].item().decode("utf8")
        })
    assert len(buffer) == 3

    # eval
    buffer = []
    data = ds.CLUEDataset(EVAL_FILE, task='WSC', usage='eval', shuffle=False)
    for d in data.create_dict_iterator():
        buffer.append({
            'span1_index': d['span1_index'],
            'span2_index': d['span2_index'],
            'span1_text': d['span1_text'].item().decode("utf8"),
            'span2_text': d['span2_text'].item().decode("utf8"),
            'idx': d['idx'],
            'label': d['label'].item().decode("utf8"),
            'text': d['text'].item().decode("utf8")
        })
    assert len(buffer) == 3
-
def test_clue_to_device():
    """
    Test CLUE with to_device
    """
    train_file = '../data/dataset/testCLUE/afqmc/train.json'
    dataset = ds.CLUEDataset(train_file, task='AFQMC', usage='train', shuffle=False)
    # transfer the pipeline to device and kick off the send
    transfer = dataset.to_device()
    transfer.send()
-
-
if __name__ == "__main__":
    # run every CLUE dataset test, in the order they are defined
    for case in (test_clue,
                 test_clue_num_shards,
                 test_clue_num_samples,
                 test_textline_dataset_get_datasetsize,
                 test_clue_afqmc,
                 test_clue_cmnli,
                 test_clue_csl,
                 test_clue_iflytek,
                 test_clue_tnews,
                 test_clue_wsc,
                 test_clue_to_device):
        case()
|