
test_datasets_mnist.py

# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
  15. """
  16. Test Mnist dataset operators
  17. """
  18. import os
  19. import pytest
  20. import numpy as np
  21. import matplotlib.pyplot as plt
  22. import mindspore.dataset as ds
  23. from mindspore import log as logger
  24. DATA_DIR = "../data/dataset/testMnistData"
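

# Added reference note (not part of the original test logic): load_mnist() below
# reads the raw idx files directly so the content check can compare against
# MnistDataset. Assuming the standard MNIST idx layout, the label file carries an
# 8-byte header (magic number + item count) and the image file a 16-byte header
# (magic number + item count + rows + cols), which the reader skips before
# loading the payload bytes.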
def load_mnist(path):
    """
    Load MNIST data from the raw idx files.
    """
    labels_path = os.path.join(path, 't10k-labels-idx1-ubyte')
    images_path = os.path.join(path, 't10k-images-idx3-ubyte')
    with open(labels_path, 'rb') as lbpath:
        lbpath.read(8)  # skip the idx header
        labels = np.fromfile(lbpath, dtype=np.uint8)
    with open(images_path, 'rb') as imgpath:
        imgpath.read(16)  # skip the idx header
        images = np.fromfile(imgpath, dtype=np.uint8)
        images = images.reshape(-1, 28, 28, 1)
        images[images > 0] = 255  # perform binarization to maintain consistency with our API
    return images, labels


def visualize_dataset(images, labels):
    """
    Helper function to visualize the dataset samples
    """
    num_samples = len(images)
    for i in range(num_samples):
        plt.subplot(1, num_samples, i + 1)
        plt.imshow(images[i].squeeze(), cmap=plt.cm.gray)
        plt.title(labels[i])
    plt.show()


def test_mnist_content_check():
    """
    Validate MnistDataset image readings
    """
    logger.info("Test MnistDataset Op with content check")
    data1 = ds.MnistDataset(DATA_DIR, num_samples=100, shuffle=False)
    images, labels = load_mnist(DATA_DIR)
    num_iter = 0
    # in this example, each dictionary has keys "image" and "label"
    image_list, label_list = [], []
    for i, data in enumerate(data1.create_dict_iterator()):
        image_list.append(data["image"])
        label_list.append("label {}".format(data["label"]))
        np.testing.assert_array_equal(data["image"], images[i])
        np.testing.assert_array_equal(data["label"], labels[i])
        num_iter += 1
    assert num_iter == 100


def test_mnist_basic():
    """
    Validate MnistDataset
    """
    logger.info("Test MnistDataset Op")

    # case 1: test loading whole dataset
    data1 = ds.MnistDataset(DATA_DIR)
    num_iter1 = 0
    for _ in data1.create_dict_iterator():
        num_iter1 += 1
    assert num_iter1 == 10000

    # case 2: test num_samples
    data2 = ds.MnistDataset(DATA_DIR, num_samples=500)
    num_iter2 = 0
    for _ in data2.create_dict_iterator():
        num_iter2 += 1
    assert num_iter2 == 500

    # case 3: test repeat
    data3 = ds.MnistDataset(DATA_DIR, num_samples=200)
    data3 = data3.repeat(5)
    num_iter3 = 0
    for _ in data3.create_dict_iterator():
        num_iter3 += 1
    assert num_iter3 == 1000

    # case 4: test batch with drop_remainder=False
    data4 = ds.MnistDataset(DATA_DIR, num_samples=100)
    assert data4.get_dataset_size() == 100
    assert data4.get_batch_size() == 1
    data4 = data4.batch(batch_size=7)  # drop_remainder defaults to False
    assert data4.get_dataset_size() == 15
    assert data4.get_batch_size() == 7
    num_iter4 = 0
    for _ in data4.create_dict_iterator():
        num_iter4 += 1
    assert num_iter4 == 15

    # case 5: test batch with drop_remainder=True
    data5 = ds.MnistDataset(DATA_DIR, num_samples=100)
    assert data5.get_dataset_size() == 100
    assert data5.get_batch_size() == 1
    data5 = data5.batch(batch_size=7, drop_remainder=True)  # the last incomplete batch is dropped
    assert data5.get_dataset_size() == 14
    assert data5.get_batch_size() == 7
    num_iter5 = 0
    for _ in data5.create_dict_iterator():
        num_iter5 += 1
    assert num_iter5 == 14
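

# Added sanity note on the batch arithmetic in test_mnist_basic above: with 100
# samples and batch_size=7, drop_remainder=False yields ceil(100 / 7) = 15
# batches (the last batch holds only 2 samples), while drop_remainder=True
# yields floor(100 / 7) = 14 batches, since the 2 leftover samples are dropped.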


def test_mnist_pk_sampler():
    """
    Test MnistDataset with PKSampler
    """
    logger.info("Test MnistDataset Op with PKSampler")
    golden = [0, 0, 0, 1, 1, 1, 2, 2, 2, 3, 3, 3, 4, 4, 4,
              5, 5, 5, 6, 6, 6, 7, 7, 7, 8, 8, 8, 9, 9, 9]
    sampler = ds.PKSampler(3)
    data = ds.MnistDataset(DATA_DIR, sampler=sampler)
    num_iter = 0
    label_list = []
    for item in data.create_dict_iterator():
        label_list.append(item["label"])
        num_iter += 1
    np.testing.assert_array_equal(golden, label_list)
    assert num_iter == 30
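

# Added note: PKSampler(3) is expected to draw 3 samples from each of the 10
# digit classes, which is why the golden label list above contains 30 entries
# grouped by class.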


def test_mnist_sequential_sampler():
    """
    Test MnistDataset with SequentialSampler
    """
    logger.info("Test MnistDataset Op with SequentialSampler")
    num_samples = 50
    sampler = ds.SequentialSampler(num_samples=num_samples)
    data1 = ds.MnistDataset(DATA_DIR, sampler=sampler)
    data2 = ds.MnistDataset(DATA_DIR, shuffle=False, num_samples=num_samples)
    label_list1, label_list2 = [], []
    num_iter = 0
    for item1, item2 in zip(data1.create_dict_iterator(), data2.create_dict_iterator()):
        label_list1.append(item1["label"])
        label_list2.append(item2["label"])
        num_iter += 1
    np.testing.assert_array_equal(label_list1, label_list2)
    assert num_iter == num_samples
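

# Added note: the zip comparison above checks that SequentialSampler(num_samples=50)
# and plain shuffle=False with num_samples=50 read the same first 50 labels in the
# same order, i.e. the two configurations are expected to be equivalent.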


def test_mnist_exception():
    """
    Test error cases for MnistDataset
    """
    logger.info("Test error cases for MnistDataset")
    # each `match` string below is searched as a regular expression in the raised error message
    error_msg_1 = "sampler and shuffle cannot be specified at the same time"
    with pytest.raises(RuntimeError, match=error_msg_1):
        ds.MnistDataset(DATA_DIR, shuffle=False, sampler=ds.PKSampler(3))

    error_msg_2 = "sampler and sharding cannot be specified at the same time"
    with pytest.raises(RuntimeError, match=error_msg_2):
        ds.MnistDataset(DATA_DIR, sampler=ds.PKSampler(3), num_shards=2, shard_id=0)

    error_msg_3 = "num_shards is specified and currently requires shard_id as well"
    with pytest.raises(RuntimeError, match=error_msg_3):
        ds.MnistDataset(DATA_DIR, num_shards=10)

    error_msg_4 = "shard_id is specified but num_shards is not"
    with pytest.raises(RuntimeError, match=error_msg_4):
        ds.MnistDataset(DATA_DIR, shard_id=0)

    error_msg_5 = "Input shard_id is not within the required interval"
    with pytest.raises(ValueError, match=error_msg_5):
        ds.MnistDataset(DATA_DIR, num_shards=5, shard_id=-1)
    with pytest.raises(ValueError, match=error_msg_5):
        ds.MnistDataset(DATA_DIR, num_shards=5, shard_id=5)
    with pytest.raises(ValueError, match=error_msg_5):
        ds.MnistDataset(DATA_DIR, num_shards=2, shard_id=5)

    error_msg_6 = "num_parallel_workers exceeds"
    with pytest.raises(ValueError, match=error_msg_6):
        ds.MnistDataset(DATA_DIR, shuffle=False, num_parallel_workers=0)
    with pytest.raises(ValueError, match=error_msg_6):
        ds.MnistDataset(DATA_DIR, shuffle=False, num_parallel_workers=65)
    with pytest.raises(ValueError, match=error_msg_6):
        ds.MnistDataset(DATA_DIR, shuffle=False, num_parallel_workers=-2)

    error_msg_7 = "Argument shard_id"
    with pytest.raises(TypeError, match=error_msg_7):
        ds.MnistDataset(DATA_DIR, num_shards=2, shard_id="0")


def test_mnist_visualize(plot=False):
    """
    Visualize MnistDataset results
    """
    logger.info("Test MnistDataset visualization")
    data1 = ds.MnistDataset(DATA_DIR, num_samples=10, shuffle=False)
    num_iter = 0
    image_list, label_list = [], []
    for item in data1.create_dict_iterator():
        image = item["image"]
        label = item["label"]
        image_list.append(image)
        label_list.append("label {}".format(label))
        assert isinstance(image, np.ndarray)
        assert image.shape == (28, 28, 1)
        assert image.dtype == np.uint8
        assert label.dtype == np.uint32
        num_iter += 1
    assert num_iter == 10
    if plot:
        visualize_dataset(image_list, label_list)


if __name__ == '__main__':
    test_mnist_content_check()
    test_mnist_basic()
    test_mnist_pk_sampler()
    test_mnist_sequential_sampler()
    test_mnist_exception()
    test_mnist_visualize(plot=True)