You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

test_take.py 9.0 kB

5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363
  1. # Copyright 2020 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ==============================================================================
  15. import numpy as np
  16. import mindspore.dataset as ds
  17. from mindspore import log as logger
  18. # In generator dataset: Number of rows is 3, its value is 0, 1, 2
  19. def generator():
  20. for i in range(3):
  21. yield (np.array([i]),)
  22. # In generator dataset: Number of rows is 10, its value is 0, 1, 2 ... 10
  23. def generator_10():
  24. for i in range(10):
  25. yield (np.array([i]),)
  26. def filter_func_ge(data):
  27. if data > 3:
  28. return False
  29. return True
  30. def test_take_01():
  31. """
  32. Test take: origin there are 3 row, and take 1 row, in this case: will not meet eoe and eof
  33. """
  34. logger.info("test_take_01")
  35. data1 = ds.GeneratorDataset(generator, ["data"])
  36. data1 = data1.take(1)
  37. data1 = data1.repeat(2)
  38. # Here i refers to index, d refers to data element
  39. for _, d in enumerate(data1):
  40. assert d[0][0] == 0
  41. assert sum([1 for _ in data1]) == 2
  42. def test_take_02():
  43. """
  44. Test take: origin there are 3 row, and take 2 row, in this case: will meet eoe
  45. """
  46. logger.info("test_take_02")
  47. data1 = ds.GeneratorDataset(generator, ["data"])
  48. data1 = data1.take(2)
  49. data1 = data1.repeat(2)
  50. # Here i refers to index, d refers to data element
  51. for i, d in enumerate(data1):
  52. assert i % 2 == d[0][0]
  53. assert sum([1 for _ in data1]) == 4
  54. def test_take_03():
  55. """
  56. Test take: origin there are 3 row, and take 3 row, in this case: will meet eoe and eof
  57. """
  58. logger.info("test_take_03")
  59. data1 = ds.GeneratorDataset(generator, ["data"])
  60. data1 = data1.take(3)
  61. data1 = data1.repeat(2)
  62. # Here i refers to index, d refers to data elements
  63. for i, d in enumerate(data1):
  64. assert i % 3 == d[0][0]
  65. assert sum([1 for _ in data1]) == 6
  66. def test_take_04():
  67. """
  68. Test take: origin there are 3 row, and take 4 row, this is more than the total rows
  69. """
  70. logger.info("test_take_04")
  71. data1 = ds.GeneratorDataset(generator, ["data"])
  72. data1 = data1.take(4)
  73. data1 = data1.repeat(2)
  74. # Here i refers to index, d refers to data element
  75. for i, d in enumerate(data1):
  76. assert i % 3 == d[0][0]
  77. assert sum([1 for _ in data1]) == 6
  78. def test_take_05():
  79. """
  80. Test take: there is no repeat op
  81. """
  82. logger.info("test_take_05")
  83. data1 = ds.GeneratorDataset(generator, ["data"])
  84. data1 = data1.take(2)
  85. # Here i refers to index, d refers to data element
  86. for i, d in enumerate(data1):
  87. assert i == d[0][0]
  88. assert sum([1 for _ in data1]) == 2
  89. def test_take_06():
  90. """
  91. Test take: repeat is before take
  92. """
  93. logger.info("test_take_06")
  94. data1 = ds.GeneratorDataset(generator, ["data"])
  95. data1 = data1.repeat(2)
  96. data1 = data1.take(4)
  97. # Here i refers to index, d refers to data element
  98. for i, d in enumerate(data1):
  99. assert i % 3 == d[0][0]
  100. assert sum([1 for _ in data1]) == 4
  101. def test_take_07():
  102. """
  103. Test take: take is before batch, that mean take(N), N refer to rows num
  104. """
  105. logger.info("test_take_07")
  106. data1 = ds.GeneratorDataset(generator, ["data"])
  107. data1 = data1.take(2)
  108. data1 = data1.batch(2)
  109. assert sum([1 for _ in data1]) == 1
  110. def test_take_08():
  111. """
  112. Test take: take is after batch, that mean take(N), N refer to batches num
  113. """
  114. logger.info("test_take_08")
  115. data1 = ds.GeneratorDataset(generator, ["data"])
  116. data1 = data1.batch(2)
  117. data1 = data1.take(2)
  118. assert sum([1 for _ in data1]) == 2
  119. def test_take_09():
  120. """
  121. Test take: repeat count is -1, and read the whole dataset, take after repeat
  122. """
  123. logger.info("test_take_09")
  124. data1 = ds.GeneratorDataset(generator, ["data"])
  125. data1 = data1.repeat(2)
  126. data1 = data1.take(-1)
  127. # Here i refers to index, d refers to data element
  128. for i, d in enumerate(data1):
  129. assert i % 3 == d[0][0]
  130. assert sum([1 for _ in data1]) == 6
  131. def test_take_10():
  132. """
  133. Test take: repeat count is -1, and read the whole dataset, take before repeat
  134. """
  135. logger.info("test_take_10")
  136. data1 = ds.GeneratorDataset(generator, ["data"])
  137. data1 = data1.take(-1)
  138. data1 = data1.repeat(2)
  139. # Here i refers to index, d refers to data element
  140. for i, d in enumerate(data1):
  141. assert i % 3 == d[0][0]
  142. assert sum([1 for _ in data1]) == 6
  143. def test_take_11():
  144. """
  145. Test take: batch first, then do repeat and take operation
  146. """
  147. logger.info("test_take_11")
  148. data1 = ds.GeneratorDataset(generator, ["data"])
  149. data1 = data1.batch(2)
  150. data1 = data1.repeat(2)
  151. data1 = data1.take(-1)
  152. # Here i refers to index, d refers to data element
  153. for i, d in enumerate(data1):
  154. assert 2 * (i % 2) == d[0][0]
  155. assert sum([1 for _ in data1]) == 4
  156. def test_take_12():
  157. """
  158. Test take: take first, then do batch and repeat operation
  159. """
  160. logger.info("test_take_12")
  161. data1 = ds.GeneratorDataset(generator, ["data"])
  162. data1 = data1.take(2)
  163. data1 = data1.batch(2)
  164. data1 = data1.repeat(2)
  165. # Here i refers to index, d refers to data element
  166. for _, d in enumerate(data1):
  167. assert d[0][0] == 0
  168. assert sum([1 for _ in data1]) == 2
  169. def test_take_13():
  170. """
  171. Test take: skip first, then do take, batch and repeat operation
  172. """
  173. logger.info("test_take_13")
  174. data1 = ds.GeneratorDataset(generator, ["data"])
  175. data1 = data1.skip(2)
  176. data1 = data1.take(-1)
  177. data1 = data1.batch(2)
  178. data1 = data1.repeat(2)
  179. # Here i refers to index, d refers to data element
  180. for _, d in enumerate(data1):
  181. assert d[0][0] == 2
  182. assert sum([1 for _ in data1]) == 2
  183. def test_take_14():
  184. """
  185. Test take: take first, then do batch, skip and repeat operation
  186. """
  187. logger.info("test_take_14")
  188. data1 = ds.GeneratorDataset(generator, ["data"])
  189. data1 = data1.take(-1)
  190. data1 = data1.batch(2)
  191. data1 = data1.skip(1)
  192. data1 = data1.repeat(2)
  193. # Here i refers to index, d refers to data element
  194. for _, d in enumerate(data1):
  195. assert d[0][0] == 2
  196. assert sum([1 for _ in data1]) == 2
  197. def test_take_15():
  198. """
  199. Test take: large amount data, take a part, then do skip operation
  200. """
  201. logger.info("test_take_15")
  202. data1 = ds.GeneratorDataset(generator_10, ["data"])
  203. data1 = data1.take(6)
  204. data1 = data1.skip(2)
  205. # Here i refers to index, d refers to data element
  206. for i, d in enumerate(data1):
  207. assert (i + 2) == d[0][0]
  208. assert sum([1 for _ in data1]) == 4
  209. def test_take_16():
  210. """
  211. Test take: large amount data, skip a part, then do take operation
  212. """
  213. logger.info("test_take_16")
  214. data1 = ds.GeneratorDataset(generator_10, ["data"])
  215. data1 = data1.skip(3)
  216. data1 = data1.take(5)
  217. # Here i refers to index, d refers to data element
  218. for i, d in enumerate(data1):
  219. assert (i + 3) == d[0][0]
  220. assert sum([1 for _ in data1]) == 5
  221. def test_take_17():
  222. """
  223. Test take: take first, then do fiter operation
  224. """
  225. logger.info("test_take_17")
  226. data1 = ds.GeneratorDataset(generator_10, ["data"])
  227. data1 = data1.take(8)
  228. data1 = data1.filter(predicate=filter_func_ge, num_parallel_workers=4)
  229. # Here i refers to index, d refers to data element
  230. for i, d in enumerate(data1):
  231. assert i == d[0][0]
  232. assert sum([1 for _ in data1]) == 4
  233. def test_take_18():
  234. """
  235. Test take: take first, then do fiter, skip, batch and repeat operation
  236. """
  237. logger.info("test_take_18")
  238. data1 = ds.GeneratorDataset(generator_10, ["data"])
  239. data1 = data1.take(8)
  240. data1 = data1.filter(predicate=filter_func_ge, num_parallel_workers=4)
  241. data1 = data1.skip(2)
  242. data1 = data1.batch(2)
  243. data1 = data1.repeat(2)
  244. # Here i refers to index, d refers to data element
  245. for _, d in enumerate(data1):
  246. assert d[0][0] == 2
  247. assert sum([1 for _ in data1]) == 2
  248. if __name__ == '__main__':
  249. test_take_01()
  250. test_take_02()
  251. test_take_03()
  252. test_take_04()
  253. test_take_05()
  254. test_take_06()
  255. test_take_07()
  256. test_take_08()
  257. test_take_09()
  258. test_take_10()
  259. test_take_11()
  260. test_take_12()
  261. test_take_13()
  262. test_take_14()
  263. test_take_15()
  264. test_take_16()
  265. test_take_17()
  266. test_take_18()
  267. logger.info('== test take operation finished ==')