
test_datasets_sharding.py

# Copyright 2019 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
import mindspore.dataset as ds
from mindspore import log as logger


def test_imagefolder_shardings(print_res=False):
    """ Test ImageFolderDatasetV2 sharding """
    image_folder_dir = "../data/dataset/testPK/data"

    def sharding_config(num_shards, shard_id, num_samples, shuffle, class_index, repeat_cnt=1):
        data1 = ds.ImageFolderDatasetV2(image_folder_dir, num_samples=num_samples, num_shards=num_shards,
                                        shard_id=shard_id, shuffle=shuffle, class_indexing=class_index, decode=True)
        data1 = data1.repeat(repeat_cnt)
        res = []
        for item in data1.create_dict_iterator():  # each data is a dictionary
            res.append(item["label"].item())
        if print_res:
            logger.info("labels of dataset: {}".format(res))
        return res

    # total 44 rows in dataset
    assert sharding_config(4, 0, 5, False, dict()) == [0, 0, 0, 1, 1]  # 5 rows
    assert sharding_config(4, 0, 12, False, dict()) == [0, 0, 0, 1, 1, 1, 2, 2, 2, 3, 3]  # 11 rows
    assert sharding_config(4, 3, None, False, dict()) == [0, 0, 1, 1, 1, 2, 2, 2, 3, 3, 3]  # 11 rows
    assert sharding_config(1, 0, 55, False, dict()) == [0] * 11 + [1] * 11 + [2] * 11 + [3] * 11  # 44 rows
    assert sharding_config(2, 0, 55, False, dict()) == [0] * 6 + [1] * 5 + [2] * 6 + [3] * 5  # 22 rows
    assert sharding_config(2, 1, 55, False, dict()) == [0] * 5 + [1] * 6 + [2] * 5 + [3] * 6  # 22 rows
    # total 22 rows in dataset because class indexing takes only 2 folders
    assert len(sharding_config(4, 0, None, True, {"class1": 111, "class2": 999})) == 6
    assert len(sharding_config(4, 2, 3, True, {"class1": 111, "class2": 999})) == 3
    # test with repeat
    assert sharding_config(4, 0, 12, False, dict(), 3) == [0, 0, 0, 1, 1, 1, 2, 2, 2, 3, 3] * 3
    assert sharding_config(4, 0, 5, False, dict(), 5) == [0, 0, 0, 1, 1] * 5
    assert len(sharding_config(5, 1, None, True, {"class1": 111, "class2": 999}, 4)) == 20
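

# Note on the expectations above: with 44 images sorted by class (11 per class,
# labels 0-3), the expected lists are consistent with a round-robin split of sample
# indices across shards. The helper below is a sketch added for illustration only
# (its name and the round-robin assumption are not part of the original test):
def _expected_imagefolder_shard(num_shards, shard_id, rows_per_class=11, num_classes=4):
    labels = [i // rows_per_class for i in range(rows_per_class * num_classes)]
    return labels[shard_id::num_shards]  # shard k takes indices k, k+num_shards, ...
# e.g. _expected_imagefolder_shard(4, 3) == [0, 0, 1, 1, 1, 2, 2, 2, 3, 3, 3]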


def test_tfrecord_shardings1(print_res=False):
    """ Test TFRecordDataset sharding with num_parallel_workers=1 """
    # total 40 rows in dataset
    tf_files = ["../data/dataset/tf_file_dataset/test1.data", "../data/dataset/tf_file_dataset/test2.data",
                "../data/dataset/tf_file_dataset/test3.data", "../data/dataset/tf_file_dataset/test4.data"]

    def sharding_config(num_shards, shard_id, num_samples, repeat_cnt=1):
        data1 = ds.TFRecordDataset(tf_files, num_shards=num_shards, shard_id=shard_id, num_samples=num_samples,
                                   shuffle=ds.Shuffle.FILES, num_parallel_workers=1)
        data1 = data1.repeat(repeat_cnt)
        res = []
        for item in data1.create_dict_iterator():  # each data is a dictionary
            res.append(item["scalars"][0])
        if print_res:
            logger.info("scalars of dataset: {}".format(res))
        return res

    assert sharding_config(2, 0, None, 1) == [11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30]  # 20 rows
    assert sharding_config(2, 1, None, 1) == [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40]  # 20 rows
    assert sharding_config(2, 0, 3, 1) == [11, 12, 13]  # 3 rows
    assert sharding_config(2, 1, 3, 1) == [1, 2, 3]  # 3 rows
    assert sharding_config(2, 0, 40, 1) == [11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30]  # 20 rows
    assert sharding_config(2, 1, 40, 1) == [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40]  # 20 rows
    assert sharding_config(2, 0, 55, 1) == [11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30]  # 20 rows
    assert sharding_config(2, 1, 55, 1) == [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40]  # 20 rows
    assert sharding_config(3, 0, 8, 1) == [11, 12, 13, 14, 15, 16, 17, 18]  # 8 rows
    assert sharding_config(3, 1, 8, 1) == [1, 2, 3, 4, 5, 6, 7, 8]  # 8 rows
    assert sharding_config(3, 2, 8, 1) == [21, 22, 23, 24, 25, 26, 27, 28]  # 8 rows
    assert sharding_config(4, 0, 2, 1) == [11, 12]  # 2 rows
    assert sharding_config(4, 1, 2, 1) == [1, 2]  # 2 rows
    assert sharding_config(4, 2, 2, 1) == [21, 22]  # 2 rows
    assert sharding_config(4, 3, 2, 1) == [31, 32]  # 2 rows
    assert sharding_config(3, 0, 4, 2) == [11, 12, 13, 14, 21, 22, 23, 24]  # 8 rows
    assert sharding_config(3, 1, 4, 2) == [1, 2, 3, 4, 11, 12, 13, 14]  # 8 rows
    assert sharding_config(3, 2, 4, 2) == [21, 22, 23, 24, 31, 32, 33, 34]  # 8 rows
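

# Note (a sketch inferred from the expectations above, not a documented contract):
# with shuffle=ds.Shuffle.FILES, whole 10-row files are dealt out to shards, and
# under the default seed the first epoch appears to read the files in the order
# [test2, test1, test3, test4]. The helper name below is hypothetical; the repeat
# cases above see different rows in later epochs because the file order is
# reshuffled each epoch, and num_samples then truncates each shard to a prefix.
def _expected_tfrecord_shard_first_epoch(num_shards, shard_id):
    files = [list(range(start, start + 10)) for start in (11, 1, 21, 31)]  # inferred file order
    return [row for f in files[shard_id::num_shards] for row in f]
# e.g. _expected_tfrecord_shard_first_epoch(2, 1)[:3] == [1, 2, 3]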


def test_tfrecord_shardings4(print_res=False):
    """ Test TFRecordDataset sharding with num_parallel_workers=4 """
    # total 40 rows in dataset
    tf_files = ["../data/dataset/tf_file_dataset/test1.data", "../data/dataset/tf_file_dataset/test2.data",
                "../data/dataset/tf_file_dataset/test3.data", "../data/dataset/tf_file_dataset/test4.data"]

    def sharding_config(num_shards, shard_id, num_samples, repeat_cnt=1):
        data1 = ds.TFRecordDataset(tf_files, num_shards=num_shards, shard_id=shard_id, num_samples=num_samples,
                                   shuffle=ds.Shuffle.FILES, num_parallel_workers=4)
        data1 = data1.repeat(repeat_cnt)
        res = []
        for item in data1.create_dict_iterator():  # each data is a dictionary
            res.append(item["scalars"][0])
        if print_res:
            logger.info("scalars of dataset: {}".format(res))
        return res

    def check_result(result_list, expect_length, expect_set):
        assert len(result_list) == expect_length
        assert set(result_list) == expect_set

    check_result(sharding_config(2, 0, None, 1), 20,
                 {11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30})
    check_result(sharding_config(2, 1, None, 1), 20,
                 {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40})
    check_result(sharding_config(2, 0, 3, 1), 3, {11, 12, 21})
    check_result(sharding_config(2, 1, 3, 1), 3, {1, 2, 31})
    check_result(sharding_config(2, 0, 40, 1), 20,
                 {11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30})
    check_result(sharding_config(2, 1, 40, 1), 20,
                 {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40})
    check_result(sharding_config(2, 0, 55, 1), 20,
                 {11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30})
    check_result(sharding_config(2, 1, 55, 1), 20,
                 {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40})
    check_result(sharding_config(3, 0, 8, 1), 8, {32, 33, 34, 11, 12, 13, 14, 31})
    check_result(sharding_config(3, 1, 8, 1), 8, {1, 2, 3, 4, 5, 6, 7, 8})
    check_result(sharding_config(3, 2, 8, 1), 8, {21, 22, 23, 24, 25, 26, 27, 28})
    check_result(sharding_config(4, 0, 2, 1), 2, {11, 12})
    check_result(sharding_config(4, 1, 2, 1), 2, {1, 2})
    check_result(sharding_config(4, 2, 2, 1), 2, {21, 22})
    check_result(sharding_config(4, 3, 2, 1), 2, {31, 32})
    check_result(sharding_config(3, 0, 4, 2), 8, {32, 1, 2, 11, 12, 21, 22, 31})
    check_result(sharding_config(3, 1, 4, 2), 8, {1, 2, 3, 4, 11, 12, 13, 14})
    check_result(sharding_config(3, 2, 4, 2), 8, {32, 33, 34, 21, 22, 23, 24, 31})
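

# Note: with num_parallel_workers=4 the workers interleave rows from a shard's files
# nondeterministically, so the checks above compare length and set membership rather
# than an ordered list; each shard still yields the same set of rows, only the
# within-shard order varies from run to run.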


def test_manifest_shardings(print_res=False):
    """ Test ManifestDataset sharding """
    manifest_file = "../data/dataset/testManifestData/test5trainimgs.json"

    def sharding_config(num_shards, shard_id, num_samples, shuffle, repeat_cnt=1):
        data1 = ds.ManifestDataset(manifest_file, num_samples=num_samples, num_shards=num_shards, shard_id=shard_id,
                                   shuffle=shuffle, decode=True)
        data1 = data1.repeat(repeat_cnt)
        res = []
        for item in data1.create_dict_iterator():  # each data is a dictionary
            res.append(item["label"].item())
        if print_res:
            logger.info("labels of dataset: {}".format(res))
        return res

    # 5 train images in total
    assert sharding_config(2, 0, None, False) == [0, 1, 1]
    assert sharding_config(2, 1, None, False) == [0, 0, 0]
    assert sharding_config(2, 0, 2, False) == [0, 1]
    assert sharding_config(2, 1, 2, False) == [0, 0]
    # with repeat
    assert sharding_config(2, 1, None, False, 3) == [0, 0, 0] * 3
    assert sharding_config(2, 0, 2, False, 5) == [0, 1] * 5
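

# Sketch (an assumption inferred from the expectations above, not stated in the
# manifest itself): the 5 train images appear to carry the label sequence
# [0, 0, 1, 0, 1], with shards padded to equal size by wrapping around, so each of
# the 2 shards holds ceil(5 / 2) = 3 rows. The helper name below is hypothetical.
def _expected_manifest_shard(num_shards, shard_id):
    labels = [0, 0, 1, 0, 1]
    per_shard = -(-len(labels) // num_shards)  # ceil division
    return [labels[(shard_id + i * num_shards) % len(labels)] for i in range(per_shard)]
# e.g. _expected_manifest_shard(2, 1) == [0, 0, 0]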


def test_voc_shardings(print_res=False):
    """ Test VOCDataset sharding via DistributedSampler """
    voc_dir = "../data/dataset/testVOC2012"

    def sharding_config(num_shards, shard_id, num_samples, shuffle, repeat_cnt=1):
        sampler = ds.DistributedSampler(num_shards, shard_id, shuffle=shuffle, num_samples=num_samples)
        data1 = ds.VOCDataset(voc_dir, decode=True, sampler=sampler)
        data1 = data1.repeat(repeat_cnt)
        res = []
        for item in data1.create_dict_iterator():  # each data is a dictionary
            res.append(item["image"].shape[0])
        if print_res:
            logger.info("first dims of dataset images: {}".format(res))
        return res

    # 10 images in total, always decode to get the shape
    # first dim of all 10 images: [2268, 2268, 2268, 2268, 642, 607, 561, 596, 612, 2268]
    # 3 shard workers, the 0th worker gets the 0th, 3rd, 6th and 9th images
    assert sharding_config(3, 0, None, False, 2) == [2268, 2268, 561, 2268] * 2
    # 3 shard workers, the 1st worker gets the 1st, 4th, 7th and 0th images;
    # the last one wraps around because shard sizes are rounded up
    assert sharding_config(3, 1, 5, False, 3) == [2268, 642, 596, 2268] * 3
    # 3 shard workers, the 2nd worker gets the 2nd, 5th, 8th and 11th (which wraps to the 1st),
    # then takes the first 2 because num_samples = 2
    assert sharding_config(3, 2, 2, False, 4) == [2268, 607] * 4
    # test that each epoch, each shard worker returns a different sample
    assert len(sharding_config(2, 0, None, True, 1)) == 5
    assert len(set(sharding_config(11, 0, None, True, 10))) > 1
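

# Sketch of the DistributedSampler arithmetic described in the comments above
# (the helper name is hypothetical, added only to make the wrap-around explicit):
# shard k draws indices k, k + num_shards, k + 2*num_shards, ... modulo the dataset
# size, with shard sizes rounded up; num_samples then truncates the shard.
def _expected_voc_shard(num_shards, shard_id):
    dims = [2268, 2268, 2268, 2268, 642, 607, 561, 596, 612, 2268]  # first dims above
    per_shard = -(-len(dims) // num_shards)  # ceil division
    return [dims[(shard_id + i * num_shards) % len(dims)] for i in range(per_shard)]
# e.g. _expected_voc_shard(3, 1) == [2268, 642, 596, 2268]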


def test_cifar10_shardings(print_res=False):
    """ Test Cifar10Dataset sharding """
    cifar10_dir = "../data/dataset/testCifar10Data"

    def sharding_config(num_shards, shard_id, num_samples, shuffle, repeat_cnt=1):
        data1 = ds.Cifar10Dataset(cifar10_dir, num_shards=num_shards, shard_id=shard_id, num_samples=num_samples,
                                  shuffle=shuffle)
        data1 = data1.repeat(repeat_cnt)
        res = []
        for item in data1.create_dict_iterator():  # each data is a dictionary
            res.append(item["label"].item())
        if print_res:
            logger.info("labels of dataset: {}".format(res))
        return res

    # 10000 rows in total. CIFAR reads everything into memory, which would make each
    # test case very slow; therefore, only 2 test cases for now.
    assert sharding_config(10000, 9999, 7, False, 1) == [9]
    assert sharding_config(10000, 0, 4, False, 3) == [0, 0, 0]
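

# Note: with 10000 rows split into 10000 shards, each shard holds a single row, so
# num_samples values larger than 1 have no effect. Shard 9999 maps to a row whose
# label is 9 in this test copy of the data, and shard 0's single label-0 row
# repeated 3 times yields [0, 0, 0].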


def test_cifar100_shardings(print_res=False):
    """ Test Cifar100Dataset sharding """
    cifar100_dir = "../data/dataset/testCifar100Data"

    def sharding_config(num_shards, shard_id, num_samples, shuffle, repeat_cnt=1):
        data1 = ds.Cifar100Dataset(cifar100_dir, num_shards=num_shards, shard_id=shard_id, num_samples=num_samples,
                                   shuffle=shuffle)
        data1 = data1.repeat(repeat_cnt)
        res = []
        for item in data1.create_dict_iterator():  # each data is a dictionary
            res.append(item["coarse_label"].item())
        if print_res:
            logger.info("labels of dataset: {}".format(res))
        return res

    # 10000 rows in total in the test.bin CIFAR-100 file
    assert sharding_config(1000, 999, 7, False, 2) == [1, 18, 10, 17, 5, 0, 15] * 2
    assert sharding_config(1000, 0, None, False) == [10, 16, 2, 11, 10, 17, 11, 14, 13, 3]


def test_mnist_shardings(print_res=False):
    """ Test MnistDataset sharding """
    mnist_dir = "../data/dataset/testMnistData"

    def sharding_config(num_shards, shard_id, num_samples, shuffle, repeat_cnt=1):
        data1 = ds.MnistDataset(mnist_dir, num_shards=num_shards, shard_id=shard_id, num_samples=num_samples,
                                shuffle=shuffle)
        data1 = data1.repeat(repeat_cnt)
        res = []
        for item in data1.create_dict_iterator():  # each data is a dictionary
            res.append(item["label"].item())
        if print_res:
            logger.info("labels of dataset: {}".format(res))
        return res

    # 70K rows in total, divided across 10K hosts; each host gets 7 images
    assert sharding_config(10000, 0, num_samples=5, shuffle=False, repeat_cnt=3) == [0, 0, 0]
    assert sharding_config(10000, 9999, num_samples=None, shuffle=False, repeat_cnt=1) == [9]
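

# Note: although the comment above describes the full 70K-row MNIST split, the
# expected outputs suggest the test copy of the data holds one image per shard at
# num_shards=10000 (num_samples=5 with repeat_cnt=3 yields only 3 labels), so shard 0
# repeats its single label-0 image and shard 9999 holds a single label-9 image.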


if __name__ == '__main__':
    test_imagefolder_shardings(True)
    test_tfrecord_shardings1(True)
    test_tfrecord_shardings4(True)
    test_manifest_shardings(True)
    test_voc_shardings(True)
    test_cifar10_shardings(True)
    test_cifar100_shardings(True)
    test_mnist_shardings(True)