You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; can include dashes ('-'); and can be up to 35 characters long.

test_datasets_cifarop.py 14 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394
  1. # Copyright 2019 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ==============================================================================
  15. """
  16. Test Cifar10 and Cifar100 dataset operators
  17. """
  18. import os
  19. import pytest
  20. import numpy as np
  21. import matplotlib.pyplot as plt
  22. import mindspore.dataset as ds
  23. from mindspore import log as logger
  24. DATA_DIR_10 = "../data/dataset/testCifar10Data"
  25. DATA_DIR_100 = "../data/dataset/testCifar100Data"
  26. def load_cifar(path, kind="cifar10"):
  27. """
  28. load Cifar10/100 data
  29. """
  30. raw = np.empty(0, dtype=np.uint8)
  31. for file_name in os.listdir(path):
  32. if file_name.endswith(".bin"):
  33. with open(os.path.join(path, file_name), mode='rb') as file:
  34. raw = np.append(raw, np.fromfile(file, dtype=np.uint8), axis=0)
  35. if kind == "cifar10":
  36. raw = raw.reshape(-1, 3073)
  37. labels = raw[:, 0]
  38. images = raw[:, 1:]
  39. elif kind == "cifar100":
  40. raw = raw.reshape(-1, 3074)
  41. labels = raw[:, :2]
  42. images = raw[:, 2:]
  43. else:
  44. raise ValueError("Invalid parameter value")
  45. images = images.reshape(-1, 3, 32, 32)
  46. images = images.transpose(0, 2, 3, 1)
  47. return images, labels
  48. def visualize_dataset(images, labels):
  49. """
  50. Helper function to visualize the dataset samples
  51. """
  52. num_samples = len(images)
  53. for i in range(num_samples):
  54. plt.subplot(1, num_samples, i + 1)
  55. plt.imshow(images[i])
  56. plt.title(labels[i])
  57. plt.show()
  58. ### Testcases for Cifar10Dataset Op ###
  59. def test_cifar10_content_check():
  60. """
  61. Validate Cifar10Dataset image readings
  62. """
  63. logger.info("Test Cifar10Dataset Op with content check")
  64. data1 = ds.Cifar10Dataset(DATA_DIR_10, num_samples=100, shuffle=False)
  65. images, labels = load_cifar(DATA_DIR_10)
  66. num_iter = 0
  67. # in this example, each dictionary has keys "image" and "label"
  68. for i, d in enumerate(data1.create_dict_iterator()):
  69. np.testing.assert_array_equal(d["image"], images[i])
  70. np.testing.assert_array_equal(d["label"], labels[i])
  71. num_iter += 1
  72. assert num_iter == 100
  73. def test_cifar10_basic():
  74. """
  75. Validate CIFAR10
  76. """
  77. logger.info("Test Cifar10Dataset Op")
  78. # case 0: test loading the whole dataset
  79. data0 = ds.Cifar10Dataset(DATA_DIR_10)
  80. num_iter0 = 0
  81. for _ in data0.create_dict_iterator():
  82. num_iter0 += 1
  83. assert num_iter0 == 10000
  84. # case 1: test num_samples
  85. data1 = ds.Cifar10Dataset(DATA_DIR_10, num_samples=100)
  86. num_iter1 = 0
  87. for _ in data1.create_dict_iterator():
  88. num_iter1 += 1
  89. assert num_iter1 == 100
  90. # case 2: test num_parallel_workers
  91. data2 = ds.Cifar10Dataset(DATA_DIR_10, num_samples=50, num_parallel_workers=1)
  92. num_iter2 = 0
  93. for _ in data2.create_dict_iterator():
  94. num_iter2 += 1
  95. assert num_iter2 == 50
  96. # case 3: test repeat
  97. data3 = ds.Cifar10Dataset(DATA_DIR_10, num_samples=100)
  98. data3 = data3.repeat(3)
  99. num_iter3 = 0
  100. for _ in data3.create_dict_iterator():
  101. num_iter3 += 1
  102. assert num_iter3 == 300
  103. # case 4: test batch with drop_remainder=False
  104. data4 = ds.Cifar10Dataset(DATA_DIR_10, num_samples=100)
  105. assert data4.get_dataset_size() == 100
  106. assert data4.get_batch_size() == 1
  107. data4 = data4.batch(batch_size=7) # drop_remainder is default to be False
  108. assert data4.get_dataset_size() == 15
  109. assert data4.get_batch_size() == 7
  110. num_iter4 = 0
  111. for _ in data4.create_dict_iterator():
  112. num_iter4 += 1
  113. assert num_iter4 == 15
  114. # case 5: test batch with drop_remainder=True
  115. data5 = ds.Cifar10Dataset(DATA_DIR_10, num_samples=100)
  116. assert data5.get_dataset_size() == 100
  117. assert data5.get_batch_size() == 1
  118. data5 = data5.batch(batch_size=7, drop_remainder=True) # the rest of incomplete batch will be dropped
  119. assert data5.get_dataset_size() == 14
  120. assert data5.get_batch_size() == 7
  121. num_iter5 = 0
  122. for _ in data5.create_dict_iterator():
  123. num_iter5 += 1
  124. assert num_iter5 == 14
  125. def test_cifar10_pk_sampler():
  126. """
  127. Test Cifar10Dataset with PKSampler
  128. """
  129. logger.info("Test Cifar10Dataset Op with PKSampler")
  130. golden = [0, 0, 0, 1, 1, 1, 2, 2, 2, 3, 3, 3, 4, 4, 4,
  131. 5, 5, 5, 6, 6, 6, 7, 7, 7, 8, 8, 8, 9, 9, 9]
  132. sampler = ds.PKSampler(3)
  133. data = ds.Cifar10Dataset(DATA_DIR_10, sampler=sampler)
  134. num_iter = 0
  135. label_list = []
  136. for item in data.create_dict_iterator():
  137. label_list.append(item["label"])
  138. num_iter += 1
  139. np.testing.assert_array_equal(golden, label_list)
  140. assert num_iter == 30
  141. def test_cifar10_sequential_sampler():
  142. """
  143. Test Cifar10Dataset with SequentialSampler
  144. """
  145. logger.info("Test Cifar10Dataset Op with SequentialSampler")
  146. num_samples = 30
  147. sampler = ds.SequentialSampler(num_samples=num_samples)
  148. data1 = ds.Cifar10Dataset(DATA_DIR_10, sampler=sampler)
  149. data2 = ds.Cifar10Dataset(DATA_DIR_10, shuffle=False, num_samples=num_samples)
  150. num_iter = 0
  151. for item1, item2 in zip(data1.create_dict_iterator(), data2.create_dict_iterator()):
  152. np.testing.assert_equal(item1["label"], item2["label"])
  153. num_iter += 1
  154. assert num_iter == num_samples
  155. def test_cifar10_exception():
  156. """
  157. Test error cases for Cifar10Dataset
  158. """
  159. logger.info("Test error cases for Cifar10Dataset")
  160. error_msg_1 = "sampler and shuffle cannot be specified at the same time"
  161. with pytest.raises(RuntimeError, match=error_msg_1):
  162. ds.Cifar10Dataset(DATA_DIR_10, shuffle=False, sampler=ds.PKSampler(3))
  163. error_msg_2 = "sampler and sharding cannot be specified at the same time"
  164. with pytest.raises(RuntimeError, match=error_msg_2):
  165. ds.Cifar10Dataset(DATA_DIR_10, sampler=ds.PKSampler(3), num_shards=2, shard_id=0)
  166. error_msg_3 = "num_shards is specified and currently requires shard_id as well"
  167. with pytest.raises(RuntimeError, match=error_msg_3):
  168. ds.Cifar10Dataset(DATA_DIR_10, num_shards=10)
  169. error_msg_4 = "shard_id is specified but num_shards is not"
  170. with pytest.raises(RuntimeError, match=error_msg_4):
  171. ds.Cifar10Dataset(DATA_DIR_10, shard_id=0)
  172. error_msg_5 = "Input shard_id is not within the required interval"
  173. with pytest.raises(ValueError, match=error_msg_5):
  174. ds.Cifar10Dataset(DATA_DIR_10, num_shards=2, shard_id=-1)
  175. with pytest.raises(ValueError, match=error_msg_5):
  176. ds.Cifar10Dataset(DATA_DIR_10, num_shards=2, shard_id=5)
  177. error_msg_6 = "num_parallel_workers exceeds"
  178. with pytest.raises(ValueError, match=error_msg_6):
  179. ds.Cifar10Dataset(DATA_DIR_10, shuffle=False, num_parallel_workers=0)
  180. with pytest.raises(ValueError, match=error_msg_6):
  181. ds.Cifar10Dataset(DATA_DIR_10, shuffle=False, num_parallel_workers=88)
  182. def test_cifar10_visualize(plot=False):
  183. """
  184. Visualize Cifar10Dataset results
  185. """
  186. logger.info("Test Cifar10Dataset visualization")
  187. data1 = ds.Cifar10Dataset(DATA_DIR_10, num_samples=10, shuffle=False)
  188. num_iter = 0
  189. image_list, label_list = [], []
  190. for item in data1.create_dict_iterator():
  191. image = item["image"]
  192. label = item["label"]
  193. image_list.append(image)
  194. label_list.append("label {}".format(label))
  195. assert isinstance(image, np.ndarray)
  196. assert image.shape == (32, 32, 3)
  197. assert image.dtype == np.uint8
  198. assert label.dtype == np.uint32
  199. num_iter += 1
  200. assert num_iter == 10
  201. if plot:
  202. visualize_dataset(image_list, label_list)
  203. ### Testcases for Cifar100Dataset Op ###
  204. def test_cifar100_content_check():
  205. """
  206. Validate Cifar100Dataset image readings
  207. """
  208. logger.info("Test Cifar100Dataset with content check")
  209. data1 = ds.Cifar100Dataset(DATA_DIR_100, num_samples=100, shuffle=False)
  210. images, labels = load_cifar(DATA_DIR_100, kind="cifar100")
  211. num_iter = 0
  212. # in this example, each dictionary has keys "image", "coarse_label" and "fine_image"
  213. for i, d in enumerate(data1.create_dict_iterator()):
  214. np.testing.assert_array_equal(d["image"], images[i])
  215. np.testing.assert_array_equal(d["coarse_label"], labels[i][0])
  216. np.testing.assert_array_equal(d["fine_label"], labels[i][1])
  217. num_iter += 1
  218. assert num_iter == 100
  219. def test_cifar100_basic():
  220. """
  221. Test Cifar100Dataset
  222. """
  223. logger.info("Test Cifar100Dataset")
  224. # case 1: test num_samples
  225. data1 = ds.Cifar100Dataset(DATA_DIR_100, num_samples=100)
  226. num_iter1 = 0
  227. for _ in data1.create_dict_iterator():
  228. num_iter1 += 1
  229. assert num_iter1 == 100
  230. # case 2: test repeat
  231. data1 = data1.repeat(2)
  232. num_iter2 = 0
  233. for _ in data1.create_dict_iterator():
  234. num_iter2 += 1
  235. assert num_iter2 == 200
  236. # case 3: test num_parallel_workers
  237. data2 = ds.Cifar100Dataset(DATA_DIR_100, num_samples=100, num_parallel_workers=1)
  238. num_iter3 = 0
  239. for _ in data2.create_dict_iterator():
  240. num_iter3 += 1
  241. assert num_iter3 == 100
  242. # case 4: test batch with drop_remainder=False
  243. data3 = ds.Cifar100Dataset(DATA_DIR_100, num_samples=100)
  244. assert data3.get_dataset_size() == 100
  245. assert data3.get_batch_size() == 1
  246. data3 = data3.batch(batch_size=3)
  247. assert data3.get_dataset_size() == 34
  248. assert data3.get_batch_size() == 3
  249. num_iter4 = 0
  250. for _ in data3.create_dict_iterator():
  251. num_iter4 += 1
  252. assert num_iter4 == 34
  253. # case 4: test batch with drop_remainder=True
  254. data4 = ds.Cifar100Dataset(DATA_DIR_100, num_samples=100)
  255. data4 = data4.batch(batch_size=3, drop_remainder=True)
  256. assert data4.get_dataset_size() == 33
  257. assert data4.get_batch_size() == 3
  258. num_iter5 = 0
  259. for _ in data4.create_dict_iterator():
  260. num_iter5 += 1
  261. assert num_iter5 == 33
  262. def test_cifar100_pk_sampler():
  263. """
  264. Test Cifar100Dataset with PKSampler
  265. """
  266. logger.info("Test Cifar100Dataset with PKSampler")
  267. golden = [i for i in range(20)]
  268. sampler = ds.PKSampler(1)
  269. data = ds.Cifar100Dataset(DATA_DIR_100, sampler=sampler)
  270. num_iter = 0
  271. label_list = []
  272. for item in data.create_dict_iterator():
  273. label_list.append(item["coarse_label"])
  274. num_iter += 1
  275. np.testing.assert_array_equal(golden, label_list)
  276. assert num_iter == 20
  277. def test_cifar100_exception():
  278. """
  279. Test error cases for Cifar100Dataset
  280. """
  281. logger.info("Test error cases for Cifar100Dataset")
  282. error_msg_1 = "sampler and shuffle cannot be specified at the same time"
  283. with pytest.raises(RuntimeError, match=error_msg_1):
  284. ds.Cifar100Dataset(DATA_DIR_100, shuffle=False, sampler=ds.PKSampler(3))
  285. error_msg_2 = "sampler and sharding cannot be specified at the same time"
  286. with pytest.raises(RuntimeError, match=error_msg_2):
  287. ds.Cifar100Dataset(DATA_DIR_100, sampler=ds.PKSampler(3), num_shards=2, shard_id=0)
  288. error_msg_3 = "num_shards is specified and currently requires shard_id as well"
  289. with pytest.raises(RuntimeError, match=error_msg_3):
  290. ds.Cifar100Dataset(DATA_DIR_100, num_shards=10)
  291. error_msg_4 = "shard_id is specified but num_shards is not"
  292. with pytest.raises(RuntimeError, match=error_msg_4):
  293. ds.Cifar100Dataset(DATA_DIR_100, shard_id=0)
  294. error_msg_5 = "Input shard_id is not within the required interval"
  295. with pytest.raises(ValueError, match=error_msg_5):
  296. ds.Cifar100Dataset(DATA_DIR_100, num_shards=2, shard_id=-1)
  297. with pytest.raises(ValueError, match=error_msg_5):
  298. ds.Cifar10Dataset(DATA_DIR_100, num_shards=2, shard_id=5)
  299. error_msg_6 = "num_parallel_workers exceeds"
  300. with pytest.raises(ValueError, match=error_msg_6):
  301. ds.Cifar100Dataset(DATA_DIR_100, shuffle=False, num_parallel_workers=0)
  302. with pytest.raises(ValueError, match=error_msg_6):
  303. ds.Cifar100Dataset(DATA_DIR_100, shuffle=False, num_parallel_workers=88)
  304. def test_cifar100_visualize(plot=False):
  305. """
  306. Visualize Cifar100Dataset results
  307. """
  308. logger.info("Test Cifar100Dataset visualization")
  309. data1 = ds.Cifar100Dataset(DATA_DIR_100, num_samples=10, shuffle=False)
  310. num_iter = 0
  311. image_list, label_list = [], []
  312. for item in data1.create_dict_iterator():
  313. image = item["image"]
  314. coarse_label = item["coarse_label"]
  315. fine_label = item["fine_label"]
  316. image_list.append(image)
  317. label_list.append("coarse_label {}\nfine_label {}".format(coarse_label, fine_label))
  318. assert isinstance(image, np.ndarray)
  319. assert image.shape == (32, 32, 3)
  320. assert image.dtype == np.uint8
  321. assert coarse_label.dtype == np.uint32
  322. assert fine_label.dtype == np.uint32
  323. num_iter += 1
  324. assert num_iter == 10
  325. if plot:
  326. visualize_dataset(image_list, label_list)
  327. if __name__ == '__main__':
  328. test_cifar10_content_check()
  329. test_cifar10_basic()
  330. test_cifar10_pk_sampler()
  331. test_cifar10_sequential_sampler()
  332. test_cifar10_exception()
  333. test_cifar10_visualize(plot=False)
  334. test_cifar100_content_check()
  335. test_cifar100_basic()
  336. test_cifar100_pk_sampler()
  337. test_cifar100_exception()
  338. test_cifar100_visualize(plot=False)