You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

test_datasets_imagefolder.py 16 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496
  1. # Copyright 2019 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ==============================================================================
  15. import mindspore.dataset as ds
  16. from mindspore import log as logger
  17. DATA_DIR = "../data/dataset/testPK/data"
  18. def test_imagefolder_basic():
  19. logger.info("Test Case basic")
  20. # define parameters
  21. repeat_count = 1
  22. # apply dataset operations
  23. data1 = ds.ImageFolderDatasetV2(DATA_DIR)
  24. data1 = data1.repeat(repeat_count)
  25. num_iter = 0
  26. for item in data1.create_dict_iterator(): # each data is a dictionary
  27. # in this example, each dictionary has keys "image" and "label"
  28. logger.info("image is {}".format(item["image"]))
  29. logger.info("label is {}".format(item["label"]))
  30. num_iter += 1
  31. logger.info("Number of data in data1: {}".format(num_iter))
  32. assert num_iter == 44
  33. def test_imagefolder_numsamples():
  34. logger.info("Test Case numSamples")
  35. # define parameters
  36. repeat_count = 1
  37. # apply dataset operations
  38. data1 = ds.ImageFolderDatasetV2(DATA_DIR, num_samples=10, num_parallel_workers=2)
  39. data1 = data1.repeat(repeat_count)
  40. num_iter = 0
  41. for item in data1.create_dict_iterator(): # each data is a dictionary
  42. # in this example, each dictionary has keys "image" and "label"
  43. logger.info("image is {}".format(item["image"]))
  44. logger.info("label is {}".format(item["label"]))
  45. num_iter += 1
  46. logger.info("Number of data in data1: {}".format(num_iter))
  47. assert num_iter == 10
  48. random_sampler = ds.RandomSampler(num_samples=3, replacement=True)
  49. data1 = ds.ImageFolderDatasetV2(DATA_DIR, num_parallel_workers=2, sampler=random_sampler)
  50. num_iter = 0
  51. for item in data1.create_dict_iterator():
  52. num_iter += 1
  53. assert num_iter == 3
  54. random_sampler = ds.RandomSampler(num_samples=3, replacement=False)
  55. data1 = ds.ImageFolderDatasetV2(DATA_DIR, num_parallel_workers=2, sampler=random_sampler)
  56. num_iter = 0
  57. for item in data1.create_dict_iterator():
  58. num_iter += 1
  59. assert num_iter == 3
  60. def test_imagefolder_numshards():
  61. logger.info("Test Case numShards")
  62. # define parameters
  63. repeat_count = 1
  64. # apply dataset operations
  65. data1 = ds.ImageFolderDatasetV2(DATA_DIR, num_shards=4, shard_id=3)
  66. data1 = data1.repeat(repeat_count)
  67. num_iter = 0
  68. for item in data1.create_dict_iterator(): # each data is a dictionary
  69. # in this example, each dictionary has keys "image" and "label"
  70. logger.info("image is {}".format(item["image"]))
  71. logger.info("label is {}".format(item["label"]))
  72. num_iter += 1
  73. logger.info("Number of data in data1: {}".format(num_iter))
  74. assert num_iter == 11
  75. def test_imagefolder_shardid():
  76. logger.info("Test Case withShardID")
  77. # define parameters
  78. repeat_count = 1
  79. # apply dataset operations
  80. data1 = ds.ImageFolderDatasetV2(DATA_DIR, num_shards=4, shard_id=1)
  81. data1 = data1.repeat(repeat_count)
  82. num_iter = 0
  83. for item in data1.create_dict_iterator(): # each data is a dictionary
  84. # in this example, each dictionary has keys "image" and "label"
  85. logger.info("image is {}".format(item["image"]))
  86. logger.info("label is {}".format(item["label"]))
  87. num_iter += 1
  88. logger.info("Number of data in data1: {}".format(num_iter))
  89. assert num_iter == 11
  90. def test_imagefolder_noshuffle():
  91. logger.info("Test Case noShuffle")
  92. # define parameters
  93. repeat_count = 1
  94. # apply dataset operations
  95. data1 = ds.ImageFolderDatasetV2(DATA_DIR, shuffle=False)
  96. data1 = data1.repeat(repeat_count)
  97. num_iter = 0
  98. for item in data1.create_dict_iterator(): # each data is a dictionary
  99. # in this example, each dictionary has keys "image" and "label"
  100. logger.info("image is {}".format(item["image"]))
  101. logger.info("label is {}".format(item["label"]))
  102. num_iter += 1
  103. logger.info("Number of data in data1: {}".format(num_iter))
  104. assert num_iter == 44
  105. def test_imagefolder_extrashuffle():
  106. logger.info("Test Case extraShuffle")
  107. # define parameters
  108. repeat_count = 2
  109. # apply dataset operations
  110. data1 = ds.ImageFolderDatasetV2(DATA_DIR, shuffle=True)
  111. data1 = data1.shuffle(buffer_size=5)
  112. data1 = data1.repeat(repeat_count)
  113. num_iter = 0
  114. for item in data1.create_dict_iterator(): # each data is a dictionary
  115. # in this example, each dictionary has keys "image" and "label"
  116. logger.info("image is {}".format(item["image"]))
  117. logger.info("label is {}".format(item["label"]))
  118. num_iter += 1
  119. logger.info("Number of data in data1: {}".format(num_iter))
  120. assert num_iter == 88
  121. def test_imagefolder_classindex():
  122. logger.info("Test Case classIndex")
  123. # define parameters
  124. repeat_count = 1
  125. # apply dataset operations
  126. class_index = {"class3": 333, "class1": 111}
  127. data1 = ds.ImageFolderDatasetV2(DATA_DIR, class_indexing=class_index, shuffle=False)
  128. data1 = data1.repeat(repeat_count)
  129. golden = [111, 111, 111, 111, 111, 111, 111, 111, 111, 111, 111,
  130. 333, 333, 333, 333, 333, 333, 333, 333, 333, 333, 333]
  131. num_iter = 0
  132. for item in data1.create_dict_iterator(): # each data is a dictionary
  133. # in this example, each dictionary has keys "image" and "label"
  134. logger.info("image is {}".format(item["image"]))
  135. logger.info("label is {}".format(item["label"]))
  136. assert item["label"] == golden[num_iter]
  137. num_iter += 1
  138. logger.info("Number of data in data1: {}".format(num_iter))
  139. assert num_iter == 22
  140. def test_imagefolder_negative_classindex():
  141. logger.info("Test Case negative classIndex")
  142. # define parameters
  143. repeat_count = 1
  144. # apply dataset operations
  145. class_index = {"class3": -333, "class1": 111}
  146. data1 = ds.ImageFolderDatasetV2(DATA_DIR, class_indexing=class_index, shuffle=False)
  147. data1 = data1.repeat(repeat_count)
  148. golden = [111, 111, 111, 111, 111, 111, 111, 111, 111, 111, 111,
  149. -333, -333, -333, -333, -333, -333, -333, -333, -333, -333, -333]
  150. num_iter = 0
  151. for item in data1.create_dict_iterator(): # each data is a dictionary
  152. # in this example, each dictionary has keys "image" and "label"
  153. logger.info("image is {}".format(item["image"]))
  154. logger.info("label is {}".format(item["label"]))
  155. assert item["label"] == golden[num_iter]
  156. num_iter += 1
  157. logger.info("Number of data in data1: {}".format(num_iter))
  158. assert num_iter == 22
  159. def test_imagefolder_extensions():
  160. logger.info("Test Case extensions")
  161. # define parameters
  162. repeat_count = 1
  163. # apply dataset operations
  164. ext = [".jpg", ".JPEG"]
  165. data1 = ds.ImageFolderDatasetV2(DATA_DIR, extensions=ext)
  166. data1 = data1.repeat(repeat_count)
  167. num_iter = 0
  168. for item in data1.create_dict_iterator(): # each data is a dictionary
  169. # in this example, each dictionary has keys "image" and "label"
  170. logger.info("image is {}".format(item["image"]))
  171. logger.info("label is {}".format(item["label"]))
  172. num_iter += 1
  173. logger.info("Number of data in data1: {}".format(num_iter))
  174. assert num_iter == 44
  175. def test_imagefolder_decode():
  176. logger.info("Test Case decode")
  177. # define parameters
  178. repeat_count = 1
  179. # apply dataset operations
  180. ext = [".jpg", ".JPEG"]
  181. data1 = ds.ImageFolderDatasetV2(DATA_DIR, extensions=ext, decode=True)
  182. data1 = data1.repeat(repeat_count)
  183. num_iter = 0
  184. for item in data1.create_dict_iterator(): # each data is a dictionary
  185. # in this example, each dictionary has keys "image" and "label"
  186. logger.info("image is {}".format(item["image"]))
  187. logger.info("label is {}".format(item["label"]))
  188. num_iter += 1
  189. logger.info("Number of data in data1: {}".format(num_iter))
  190. assert num_iter == 44
  191. def test_sequential_sampler():
  192. logger.info("Test Case SequentialSampler")
  193. golden = [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
  194. 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
  195. 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
  196. 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3]
  197. # define parameters
  198. repeat_count = 1
  199. # apply dataset operations
  200. sampler = ds.SequentialSampler()
  201. data1 = ds.ImageFolderDatasetV2(DATA_DIR, sampler=sampler)
  202. data1 = data1.repeat(repeat_count)
  203. result = []
  204. num_iter = 0
  205. for item in data1.create_dict_iterator(): # each data is a dictionary
  206. # in this example, each dictionary has keys "image" and "label"
  207. result.append(item["label"])
  208. num_iter += 1
  209. logger.info("Result: {}".format(result))
  210. assert result == golden
  211. def test_random_sampler():
  212. logger.info("Test Case RandomSampler")
  213. # define parameters
  214. repeat_count = 1
  215. # apply dataset operations
  216. sampler = ds.RandomSampler()
  217. data1 = ds.ImageFolderDatasetV2(DATA_DIR, sampler=sampler)
  218. data1 = data1.repeat(repeat_count)
  219. num_iter = 0
  220. for item in data1.create_dict_iterator(): # each data is a dictionary
  221. # in this example, each dictionary has keys "image" and "label"
  222. logger.info("image is {}".format(item["image"]))
  223. logger.info("label is {}".format(item["label"]))
  224. num_iter += 1
  225. logger.info("Number of data in data1: {}".format(num_iter))
  226. assert num_iter == 44
  227. def test_distributed_sampler():
  228. logger.info("Test Case DistributedSampler")
  229. # define parameters
  230. repeat_count = 1
  231. # apply dataset operations
  232. sampler = ds.DistributedSampler(10, 1)
  233. data1 = ds.ImageFolderDatasetV2(DATA_DIR, sampler=sampler)
  234. data1 = data1.repeat(repeat_count)
  235. num_iter = 0
  236. for item in data1.create_dict_iterator(): # each data is a dictionary
  237. # in this example, each dictionary has keys "image" and "label"
  238. logger.info("image is {}".format(item["image"]))
  239. logger.info("label is {}".format(item["label"]))
  240. num_iter += 1
  241. logger.info("Number of data in data1: {}".format(num_iter))
  242. assert num_iter == 5
  243. def test_pk_sampler():
  244. logger.info("Test Case PKSampler")
  245. # define parameters
  246. repeat_count = 1
  247. # apply dataset operations
  248. sampler = ds.PKSampler(3)
  249. data1 = ds.ImageFolderDatasetV2(DATA_DIR, sampler=sampler)
  250. data1 = data1.repeat(repeat_count)
  251. num_iter = 0
  252. for item in data1.create_dict_iterator(): # each data is a dictionary
  253. # in this example, each dictionary has keys "image" and "label"
  254. logger.info("image is {}".format(item["image"]))
  255. logger.info("label is {}".format(item["label"]))
  256. num_iter += 1
  257. logger.info("Number of data in data1: {}".format(num_iter))
  258. assert num_iter == 12
  259. def test_subset_random_sampler():
  260. logger.info("Test Case SubsetRandomSampler")
  261. # define parameters
  262. repeat_count = 1
  263. # apply dataset operations
  264. indices = [0, 1, 2, 3, 4, 5, 12, 13, 14, 15, 16, 11]
  265. sampler = ds.SubsetRandomSampler(indices)
  266. data1 = ds.ImageFolderDatasetV2(DATA_DIR, sampler=sampler)
  267. data1 = data1.repeat(repeat_count)
  268. num_iter = 0
  269. for item in data1.create_dict_iterator(): # each data is a dictionary
  270. # in this example, each dictionary has keys "image" and "label"
  271. logger.info("image is {}".format(item["image"]))
  272. logger.info("label is {}".format(item["label"]))
  273. num_iter += 1
  274. logger.info("Number of data in data1: {}".format(num_iter))
  275. assert num_iter == 12
  276. def test_weighted_random_sampler():
  277. logger.info("Test Case WeightedRandomSampler")
  278. # define parameters
  279. repeat_count = 1
  280. # apply dataset operations
  281. weights = [1.0, 0.1, 0.02, 0.3, 0.4, 0.05, 1.2, 0.13, 0.14, 0.015, 0.16, 1.1]
  282. sampler = ds.WeightedRandomSampler(weights, 11)
  283. data1 = ds.ImageFolderDatasetV2(DATA_DIR, sampler=sampler)
  284. data1 = data1.repeat(repeat_count)
  285. num_iter = 0
  286. for item in data1.create_dict_iterator(): # each data is a dictionary
  287. # in this example, each dictionary has keys "image" and "label"
  288. logger.info("image is {}".format(item["image"]))
  289. logger.info("label is {}".format(item["label"]))
  290. num_iter += 1
  291. logger.info("Number of data in data1: {}".format(num_iter))
  292. assert num_iter == 11
  293. def test_imagefolder_rename():
  294. logger.info("Test Case rename")
  295. # define parameters
  296. repeat_count = 1
  297. # apply dataset operations
  298. data1 = ds.ImageFolderDatasetV2(DATA_DIR, num_samples=10)
  299. data1 = data1.repeat(repeat_count)
  300. num_iter = 0
  301. for item in data1.create_dict_iterator(): # each data is a dictionary
  302. # in this example, each dictionary has keys "image" and "label"
  303. logger.info("image is {}".format(item["image"]))
  304. logger.info("label is {}".format(item["label"]))
  305. num_iter += 1
  306. logger.info("Number of data in data1: {}".format(num_iter))
  307. assert num_iter == 10
  308. data1 = data1.rename(input_columns=["image"], output_columns="image2")
  309. num_iter = 0
  310. for item in data1.create_dict_iterator(): # each data is a dictionary
  311. # in this example, each dictionary has keys "image" and "label"
  312. logger.info("image is {}".format(item["image2"]))
  313. logger.info("label is {}".format(item["label"]))
  314. num_iter += 1
  315. logger.info("Number of data in data1: {}".format(num_iter))
  316. assert num_iter == 10
  317. def test_imagefolder_zip():
  318. logger.info("Test Case zip")
  319. # define parameters
  320. repeat_count = 2
  321. # apply dataset operations
  322. data1 = ds.ImageFolderDatasetV2(DATA_DIR, num_samples=10)
  323. data2 = ds.ImageFolderDatasetV2(DATA_DIR, num_samples=10)
  324. data1 = data1.repeat(repeat_count)
  325. # rename dataset2 for no conflict
  326. data2 = data2.rename(input_columns=["image", "label"], output_columns=["image1", "label1"])
  327. data3 = ds.zip((data1, data2))
  328. num_iter = 0
  329. for item in data3.create_dict_iterator(): # each data is a dictionary
  330. # in this example, each dictionary has keys "image" and "label"
  331. logger.info("image is {}".format(item["image"]))
  332. logger.info("label is {}".format(item["label"]))
  333. num_iter += 1
  334. logger.info("Number of data in data1: {}".format(num_iter))
  335. assert num_iter == 10
  336. if __name__ == '__main__':
  337. test_imagefolder_basic()
  338. logger.info('test_imagefolder_basic Ended.\n')
  339. test_imagefolder_numsamples()
  340. logger.info('test_imagefolder_numsamples Ended.\n')
  341. test_sequential_sampler()
  342. logger.info('test_sequential_sampler Ended.\n')
  343. test_random_sampler()
  344. logger.info('test_random_sampler Ended.\n')
  345. test_distributed_sampler()
  346. logger.info('test_distributed_sampler Ended.\n')
  347. test_pk_sampler()
  348. logger.info('test_pk_sampler Ended.\n')
  349. test_subset_random_sampler()
  350. logger.info('test_subset_random_sampler Ended.\n')
  351. test_weighted_random_sampler()
  352. logger.info('test_weighted_random_sampler Ended.\n')
  353. test_imagefolder_numshards()
  354. logger.info('test_imagefolder_numshards Ended.\n')
  355. test_imagefolder_shardid()
  356. logger.info('test_imagefolder_shardid Ended.\n')
  357. test_imagefolder_noshuffle()
  358. logger.info('test_imagefolder_noshuffle Ended.\n')
  359. test_imagefolder_extrashuffle()
  360. logger.info('test_imagefolder_extrashuffle Ended.\n')
  361. test_imagefolder_classindex()
  362. logger.info('test_imagefolder_classindex Ended.\n')
  363. test_imagefolder_negative_classindex()
  364. logger.info('test_imagefolder_negative_classindex Ended.\n')
  365. test_imagefolder_extensions()
  366. logger.info('test_imagefolder_extensions Ended.\n')
  367. test_imagefolder_decode()
  368. logger.info('test_imagefolder_decode Ended.\n')
  369. test_imagefolder_rename()
  370. logger.info('test_imagefolder_rename Ended.\n')
  371. test_imagefolder_zip()
  372. logger.info('test_imagefolder_zip Ended.\n')