
test_datasets_clue.py
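Unit tests for MindSpore's ds.CLUEDataset, covering the AFQMC, CMNLI, CSL, IFLYTEK, TNEWS and WSC tasks, plus repeat/skip, sharding, sampling, dataset-size and to_device behaviour.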

# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
import mindspore.dataset as ds


def test_clue():
    """
    Test CLUE with repeat, skip and so on
    """
    TRAIN_FILE = '../data/dataset/testCLUE/afqmc/train.json'
    buffer = []
    data = ds.CLUEDataset(TRAIN_FILE, task='AFQMC', usage='train', shuffle=False)
    data = data.repeat(2)
    data = data.skip(3)
    for d in data.create_dict_iterator():
        buffer.append({
            'label': d['label'].item().decode("utf8"),
            'sentence1': d['sentence1'].item().decode("utf8"),
            'sentence2': d['sentence2'].item().decode("utf8")
        })
    # 3 rows repeated twice gives 6; skipping 3 leaves 3
    assert len(buffer) == 3


def test_clue_num_shards():
    """
    Test num_shards param of CLUE dataset
    """
    TRAIN_FILE = '../data/dataset/testCLUE/afqmc/train.json'
    buffer = []
    data = ds.CLUEDataset(TRAIN_FILE, task='AFQMC', usage='train', num_shards=3, shard_id=1)
    for d in data.create_dict_iterator():
        buffer.append({
            'label': d['label'].item().decode("utf8"),
            'sentence1': d['sentence1'].item().decode("utf8"),
            'sentence2': d['sentence2'].item().decode("utf8")
        })
    # 3 rows split across 3 shards leaves 1 row in this shard
    assert len(buffer) == 1


def test_clue_num_samples():
    """
    Test num_samples param of CLUE dataset
    """
    TRAIN_FILE = '../data/dataset/testCLUE/afqmc/train.json'
    data = ds.CLUEDataset(TRAIN_FILE, task='AFQMC', usage='train', num_samples=2)
    count = 0
    for _ in data.create_dict_iterator():
        count += 1
    assert count == 2


def test_textline_dataset_get_datasetsize():
    """
    Test get_dataset_size by reading a CLUE data file with TextFileDataset
    """
    TRAIN_FILE = '../data/dataset/testCLUE/afqmc/train.json'
    data = ds.TextFileDataset(TRAIN_FILE)
    size = data.get_dataset_size()
    assert size == 3


def test_clue_afqmc():
    """
    Test AFQMC for train, test and evaluation
    """
    TRAIN_FILE = '../data/dataset/testCLUE/afqmc/train.json'
    TEST_FILE = '../data/dataset/testCLUE/afqmc/test.json'
    EVAL_FILE = '../data/dataset/testCLUE/afqmc/dev.json'

    # train
    buffer = []
    data = ds.CLUEDataset(TRAIN_FILE, task='AFQMC', usage='train', shuffle=False)
    for d in data.create_dict_iterator():
        buffer.append({
            'label': d['label'].item().decode("utf8"),
            'sentence1': d['sentence1'].item().decode("utf8"),
            'sentence2': d['sentence2'].item().decode("utf8")
        })
    assert len(buffer) == 3

    # test
    buffer = []
    data = ds.CLUEDataset(TEST_FILE, task='AFQMC', usage='test', shuffle=False)
    for d in data.create_dict_iterator():
        buffer.append({
            'id': d['id'],
            'sentence1': d['sentence1'].item().decode("utf8"),
            'sentence2': d['sentence2'].item().decode("utf8")
        })
    assert len(buffer) == 3

    # evaluation
    buffer = []
    data = ds.CLUEDataset(EVAL_FILE, task='AFQMC', usage='eval', shuffle=False)
    for d in data.create_dict_iterator():
        buffer.append({
            'label': d['label'].item().decode("utf8"),
            'sentence1': d['sentence1'].item().decode("utf8"),
            'sentence2': d['sentence2'].item().decode("utf8")
        })
    assert len(buffer) == 3


def test_clue_cmnli():
    """
    Test CMNLI for train, test and evaluation
    """
    TRAIN_FILE = '../data/dataset/testCLUE/cmnli/train.json'
    TEST_FILE = '../data/dataset/testCLUE/cmnli/test.json'
    EVAL_FILE = '../data/dataset/testCLUE/cmnli/dev.json'

    # train
    buffer = []
    data = ds.CLUEDataset(TRAIN_FILE, task='CMNLI', usage='train', shuffle=False)
    for d in data.create_dict_iterator():
        buffer.append({
            'label': d['label'].item().decode("utf8"),
            'sentence1': d['sentence1'].item().decode("utf8"),
            'sentence2': d['sentence2'].item().decode("utf8")
        })
    assert len(buffer) == 3

    # test
    buffer = []
    data = ds.CLUEDataset(TEST_FILE, task='CMNLI', usage='test', shuffle=False)
    for d in data.create_dict_iterator():
        buffer.append({
            'id': d['id'],
            'sentence1': d['sentence1'],
            'sentence2': d['sentence2']
        })
    assert len(buffer) == 3

    # eval
    buffer = []
    data = ds.CLUEDataset(EVAL_FILE, task='CMNLI', usage='eval', shuffle=False)
    for d in data.create_dict_iterator():
        buffer.append({
            'label': d['label'],
            'sentence1': d['sentence1'],
            'sentence2': d['sentence2']
        })
    assert len(buffer) == 3


def test_clue_csl():
    """
    Test CSL for train, test and evaluation
    """
    TRAIN_FILE = '../data/dataset/testCLUE/csl/train.json'
    TEST_FILE = '../data/dataset/testCLUE/csl/test.json'
    EVAL_FILE = '../data/dataset/testCLUE/csl/dev.json'

    # train
    buffer = []
    data = ds.CLUEDataset(TRAIN_FILE, task='CSL', usage='train', shuffle=False)
    for d in data.create_dict_iterator():
        buffer.append({
            'id': d['id'],
            'abst': d['abst'].item().decode("utf8"),
            'keyword': [i.item().decode("utf8") for i in d['keyword']],
            'label': d['label'].item().decode("utf8")
        })
    assert len(buffer) == 3

    # test
    buffer = []
    data = ds.CLUEDataset(TEST_FILE, task='CSL', usage='test', shuffle=False)
    for d in data.create_dict_iterator():
        buffer.append({
            'id': d['id'],
            'abst': d['abst'].item().decode("utf8"),
            'keyword': [i.item().decode("utf8") for i in d['keyword']],
        })
    assert len(buffer) == 3

    # eval
    buffer = []
    data = ds.CLUEDataset(EVAL_FILE, task='CSL', usage='eval', shuffle=False)
    for d in data.create_dict_iterator():
        buffer.append({
            'id': d['id'],
            'abst': d['abst'].item().decode("utf8"),
            'keyword': [i.item().decode("utf8") for i in d['keyword']],
            'label': d['label'].item().decode("utf8")
        })
    assert len(buffer) == 3


def test_clue_iflytek():
    """
    Test IFLYTEK for train, test and evaluation
    """
    TRAIN_FILE = '../data/dataset/testCLUE/iflytek/train.json'
    TEST_FILE = '../data/dataset/testCLUE/iflytek/test.json'
    EVAL_FILE = '../data/dataset/testCLUE/iflytek/dev.json'

    # train
    buffer = []
    data = ds.CLUEDataset(TRAIN_FILE, task='IFLYTEK', usage='train', shuffle=False)
    for d in data.create_dict_iterator():
        buffer.append({
            'label': d['label'].item().decode("utf8"),
            'label_des': d['label_des'].item().decode("utf8"),
            'sentence': d['sentence'].item().decode("utf8"),
        })
    assert len(buffer) == 3

    # test
    buffer = []
    data = ds.CLUEDataset(TEST_FILE, task='IFLYTEK', usage='test', shuffle=False)
    for d in data.create_dict_iterator():
        buffer.append({
            'id': d['id'],
            'sentence': d['sentence'].item().decode("utf8")
        })
    assert len(buffer) == 3

    # eval
    buffer = []
    data = ds.CLUEDataset(EVAL_FILE, task='IFLYTEK', usage='eval', shuffle=False)
    for d in data.create_dict_iterator():
        buffer.append({
            'label': d['label'].item().decode("utf8"),
            'label_des': d['label_des'].item().decode("utf8"),
            'sentence': d['sentence'].item().decode("utf8")
        })
    assert len(buffer) == 3


def test_clue_tnews():
    """
    Test TNEWS for train, test and evaluation
    """
    TRAIN_FILE = '../data/dataset/testCLUE/tnews/train.json'
    TEST_FILE = '../data/dataset/testCLUE/tnews/test.json'
    EVAL_FILE = '../data/dataset/testCLUE/tnews/dev.json'

    # train
    buffer = []
    data = ds.CLUEDataset(TRAIN_FILE, task='TNEWS', usage='train', shuffle=False)
    for d in data.create_dict_iterator():
        buffer.append({
            'label': d['label'].item().decode("utf8"),
            'label_desc': d['label_desc'].item().decode("utf8"),
            'sentence': d['sentence'].item().decode("utf8"),
            'keywords':
                d['keywords'].item().decode("utf8") if d['keywords'].size > 0 else d['keywords']
        })
    assert len(buffer) == 3

    # test
    buffer = []
    data = ds.CLUEDataset(TEST_FILE, task='TNEWS', usage='test', shuffle=False)
    for d in data.create_dict_iterator():
        buffer.append({
            'id': d['id'],
            'sentence': d['sentence'].item().decode("utf8"),
            'keywords':
                d['keywords'].item().decode("utf8") if d['keywords'].size > 0 else d['keywords']
        })
    assert len(buffer) == 3

    # eval
    buffer = []
    data = ds.CLUEDataset(EVAL_FILE, task='TNEWS', usage='eval', shuffle=False)
    for d in data.create_dict_iterator():
        buffer.append({
            'label': d['label'].item().decode("utf8"),
            'label_desc': d['label_desc'].item().decode("utf8"),
            'sentence': d['sentence'].item().decode("utf8"),
            'keywords':
                d['keywords'].item().decode("utf8") if d['keywords'].size > 0 else d['keywords']
        })
    assert len(buffer) == 3


def test_clue_wsc():
    """
    Test WSC for train, test and evaluation
    """
    TRAIN_FILE = '../data/dataset/testCLUE/wsc/train.json'
    TEST_FILE = '../data/dataset/testCLUE/wsc/test.json'
    EVAL_FILE = '../data/dataset/testCLUE/wsc/dev.json'

    # train
    buffer = []
    data = ds.CLUEDataset(TRAIN_FILE, task='WSC', usage='train')
    for d in data.create_dict_iterator():
        buffer.append({
            'span1_index': d['span1_index'],
            'span2_index': d['span2_index'],
            'span1_text': d['span1_text'].item().decode("utf8"),
            'span2_text': d['span2_text'].item().decode("utf8"),
            'idx': d['idx'],
            'label': d['label'].item().decode("utf8"),
            'text': d['text'].item().decode("utf8")
        })
    assert len(buffer) == 3

    # test
    buffer = []
    data = ds.CLUEDataset(TEST_FILE, task='WSC', usage='test')
    for d in data.create_dict_iterator():
        buffer.append({
            'span1_index': d['span1_index'],
            'span2_index': d['span2_index'],
            'span1_text': d['span1_text'].item().decode("utf8"),
            'span2_text': d['span2_text'].item().decode("utf8"),
            'idx': d['idx'],
            'text': d['text'].item().decode("utf8")
        })
    assert len(buffer) == 3

    # eval
    buffer = []
    data = ds.CLUEDataset(EVAL_FILE, task='WSC', usage='eval')
    for d in data.create_dict_iterator():
        buffer.append({
            'span1_index': d['span1_index'],
            'span2_index': d['span2_index'],
            'span1_text': d['span1_text'].item().decode("utf8"),
            'span2_text': d['span2_text'].item().decode("utf8"),
            'idx': d['idx'],
            'label': d['label'].item().decode("utf8"),
            'text': d['text'].item().decode("utf8")
        })
    assert len(buffer) == 3


def test_clue_to_device():
    """
    Test CLUE with to_device
    """
    TRAIN_FILE = '../data/dataset/testCLUE/afqmc/train.json'
    data = ds.CLUEDataset(TRAIN_FILE, task='AFQMC', usage='train', shuffle=False)
    # transfer the pipeline to the device queue and start sending data
    data = data.to_device()
    data.send()


if __name__ == "__main__":
    test_clue()
    test_clue_num_shards()
    test_clue_num_samples()
    test_textline_dataset_get_datasetsize()
    test_clue_afqmc()
    test_clue_cmnli()
    test_clue_csl()
    test_clue_iflytek()
    test_clue_tnews()
    test_clue_wsc()
    test_clue_to_device()
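
# A minimal sketch of how these tests are typically run, assuming pytest is
# installed and the fixture data exists under ../data/dataset/testCLUE:
#
#   pytest test_datasets_clue.py                     # run the whole file
#   pytest test_datasets_clue.py::test_clue_afqmc    # run a single test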