
test_apply.py

# Copyright 2019 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
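"""Tests for Dataset.apply(): passing a user-defined function that chains
dataset operations (map, batch, repeat, shuffle) onto a dataset, and checking
that the result matches building the same pipeline directly."""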
import numpy as np

import mindspore.dataset as ds
import mindspore.dataset.transforms.vision.c_transforms as vision
from mindspore import log as logger

DATA_DIR = "../data/dataset/testPK/data"


# Generate a 1d int numpy array with values 0 to 63
def generator_1d():
    for i in range(64):
        yield (np.array([i]),)
def test_apply_generator_case():
    # apply dataset operations
    data1 = ds.GeneratorDataset(generator_1d, ["data"])
    data2 = ds.GeneratorDataset(generator_1d, ["data"])

    def dataset_fn(ds_):
        ds_ = ds_.repeat(2)
        return ds_.batch(4)

    data1 = data1.apply(dataset_fn)

    # Build the same pipeline directly on data2 for comparison
    data2 = data2.repeat(2)
    data2 = data2.batch(4)

    for item1, item2 in zip(data1.create_dict_iterator(), data2.create_dict_iterator()):
        assert np.array_equal(item1["data"], item2["data"])
def test_apply_imagefolder_case():
    # apply dataset map operations
    data1 = ds.ImageFolderDatasetV2(DATA_DIR, num_shards=4, shard_id=3)
    data2 = ds.ImageFolderDatasetV2(DATA_DIR, num_shards=4, shard_id=3)

    decode_op = vision.Decode()
    normalize_op = vision.Normalize([121.0, 115.0, 100.0], [70.0, 68.0, 71.0])

    def dataset_fn(ds_):
        ds_ = ds_.map(operations=decode_op)
        ds_ = ds_.map(operations=normalize_op)
        ds_ = ds_.repeat(2)
        return ds_

    data1 = data1.apply(dataset_fn)

    # Build the same pipeline directly on data2 for comparison
    data2 = data2.map(operations=decode_op)
    data2 = data2.map(operations=normalize_op)
    data2 = data2.repeat(2)

    for item1, item2 in zip(data1.create_dict_iterator(), data2.create_dict_iterator()):
        assert np.array_equal(item1["image"], item2["image"])
def test_apply_flow_case_0(id_=0):
    # apply control flow operations
    data1 = ds.GeneratorDataset(generator_1d, ["data"])

    def dataset_fn(ds_):
        if id_ == 0:
            ds_ = ds_.batch(4)
        elif id_ == 1:
            ds_ = ds_.repeat(2)
        elif id_ == 2:
            ds_ = ds_.batch(4)
            ds_ = ds_.repeat(2)
        else:
            ds_ = ds_.shuffle(buffer_size=4)
        return ds_

    data1 = data1.apply(dataset_fn)

    num_iter = 0
    for _ in data1.create_dict_iterator():
        num_iter = num_iter + 1

    # The generator yields 64 rows: batch(4) -> 16 batches, repeat(2) -> 128 rows,
    # batch(4) then repeat(2) -> 32 batches, shuffle -> 64 rows unchanged
    if id_ == 0:
        assert num_iter == 16
    elif id_ == 1:
        assert num_iter == 128
    elif id_ == 2:
        assert num_iter == 32
    else:
        assert num_iter == 64
def test_apply_flow_case_1(id_=1):
    # apply control flow operations
    data1 = ds.GeneratorDataset(generator_1d, ["data"])

    def dataset_fn(ds_):
        if id_ == 0:
            ds_ = ds_.batch(4)
        elif id_ == 1:
            ds_ = ds_.repeat(2)
        elif id_ == 2:
            ds_ = ds_.batch(4)
            ds_ = ds_.repeat(2)
        else:
            ds_ = ds_.shuffle(buffer_size=4)
        return ds_

    data1 = data1.apply(dataset_fn)

    num_iter = 0
    for _ in data1.create_dict_iterator():
        num_iter = num_iter + 1

    if id_ == 0:
        assert num_iter == 16
    elif id_ == 1:
        assert num_iter == 128
    elif id_ == 2:
        assert num_iter == 32
    else:
        assert num_iter == 64
def test_apply_flow_case_2(id_=2):
    # apply control flow operations
    data1 = ds.GeneratorDataset(generator_1d, ["data"])

    def dataset_fn(ds_):
        if id_ == 0:
            ds_ = ds_.batch(4)
        elif id_ == 1:
            ds_ = ds_.repeat(2)
        elif id_ == 2:
            ds_ = ds_.batch(4)
            ds_ = ds_.repeat(2)
        else:
            ds_ = ds_.shuffle(buffer_size=4)
        return ds_

    data1 = data1.apply(dataset_fn)

    num_iter = 0
    for _ in data1.create_dict_iterator():
        num_iter = num_iter + 1

    if id_ == 0:
        assert num_iter == 16
    elif id_ == 1:
        assert num_iter == 128
    elif id_ == 2:
        assert num_iter == 32
    else:
        assert num_iter == 64
def test_apply_flow_case_3(id_=3):
    # apply control flow operations
    data1 = ds.GeneratorDataset(generator_1d, ["data"])

    def dataset_fn(ds_):
        if id_ == 0:
            ds_ = ds_.batch(4)
        elif id_ == 1:
            ds_ = ds_.repeat(2)
        elif id_ == 2:
            ds_ = ds_.batch(4)
            ds_ = ds_.repeat(2)
        else:
            ds_ = ds_.shuffle(buffer_size=4)
        return ds_

    data1 = data1.apply(dataset_fn)

    num_iter = 0
    for _ in data1.create_dict_iterator():
        num_iter = num_iter + 1

    if id_ == 0:
        assert num_iter == 16
    elif id_ == 1:
        assert num_iter == 128
    elif id_ == 2:
        assert num_iter == 32
    else:
        assert num_iter == 64
def test_apply_exception_case():
    # apply exception operations
    data1 = ds.GeneratorDataset(generator_1d, ["data"])

    def dataset_fn(ds_):
        ds_ = ds_.repeat(2)
        return ds_.batch(4)

    def exception_fn():
        return np.array([[0], [1], [3], [4], [5]])

    try:
        # apply() requires a callable, so passing a string should raise TypeError
        data1 = data1.apply("123")
        for _ in data1.create_dict_iterator():
            pass
        assert False
    except TypeError:
        pass

    try:
        # a function that neither takes nor returns a Dataset should also raise TypeError
        data1 = data1.apply(exception_fn)
        for _ in data1.create_dict_iterator():
            pass
        assert False
    except TypeError:
        pass

    try:
        # applying dataset_fn to the same dataset twice is expected to raise ValueError
        data2 = data1.apply(dataset_fn)
        _ = data1.apply(dataset_fn)
        for _, _ in zip(data1.create_dict_iterator(), data2.create_dict_iterator()):
            pass
        assert False
    except ValueError as e:
        logger.info("Got an exception in DE: {}".format(str(e)))
if __name__ == '__main__':
    logger.info("Running test_apply.py test_apply_generator_case() function")
    test_apply_generator_case()

    logger.info("Running test_apply.py test_apply_imagefolder_case() function")
    test_apply_imagefolder_case()

    logger.info("Running test_apply.py test_apply_flow_case(id) function")
    test_apply_flow_case_0()
    test_apply_flow_case_1()
    test_apply_flow_case_2()
    test_apply_flow_case_3()

    logger.info("Running test_apply.py test_apply_exception_case() function")
    test_apply_exception_case()