
test_concat.py

# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
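"""
Tests for dataset concatenation: the "+" operator and Dataset.concat(),
error handling for mismatched column names, ranks and data types, and
interaction with repeat, batch and shuffle.
"""
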
import numpy as np

import mindspore.common.dtype as mstype
import mindspore.dataset as ds
import mindspore.dataset.transforms.c_transforms as C
import mindspore.dataset.transforms.vision.py_transforms as F
from mindspore import log as logger


# In generator dataset: Number of rows is 3; its values are 0, 1, 2
def generator():
    for i in range(3):
        yield (np.array([i]),)


# In generator_10 dataset: Number of rows is 7; its values are 3, 4, 5 ... 9
def generator_10():
    for i in range(3, 10):
        yield (np.array([i]),)


# In generator_20 dataset: Number of rows is 10; its values are 10, 11, 12 ... 19
def generator_20():
    for i in range(10, 20):
        yield (np.array([i]),)


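# Concatenation preserves input order: rows of the first dataset are produced
# before rows of the second, so the expected sizes below simply add up
# (e.g. 3 + 7 = 10 rows in the first two tests).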
def test_concat_01():
    """
    Test concat: concat 2 datasets that have the same column name and data type
    """
    logger.info("test_concat_01")
    data1 = ds.GeneratorDataset(generator, ["col1"])
    data2 = ds.GeneratorDataset(generator_10, ["col1"])

    data3 = data1 + data2

    # Here i refers to index, d refers to data element
    for i, d in enumerate(data3):
        logger.info("data: %i", d[0][0])
        assert i == d[0][0]

    assert sum([1 for _ in data3]) == 10


def test_concat_02():
    """
    Test concat: concat 2 datasets using the concat() operation rather than the "+" operator
    """
    logger.info("test_concat_02")
    data1 = ds.GeneratorDataset(generator, ["col1"])
    data2 = ds.GeneratorDataset(generator_10, ["col1"])

    data3 = data1.concat(data2)

    # Here i refers to index, d refers to data element
    for i, d in enumerate(data3):
        logger.info("data: %i", d[0][0])
        assert i == d[0][0]

    assert sum([1 for _ in data3]) == 10


def test_concat_03():
    """
    Test concat: concat datasets that have different column names
    """
    logger.info("test_concat_03")
    data1 = ds.GeneratorDataset(generator, ["col1"])
    data2 = ds.GeneratorDataset(generator_10, ["col2"])

    data3 = data1 + data2

    # Iteration is expected to fail because the column names do not match
    try:
        for _, _ in enumerate(data3):
            pass
        assert False
    except RuntimeError:
        pass


def test_concat_04():
    """
    Test concat: concat datasets that have different ranks
    """
    logger.info("test_concat_04")
    data1 = ds.GeneratorDataset(generator, ["col1"])
    # Use the same column name so that only the rank differs after batching
    data2 = ds.GeneratorDataset(generator_10, ["col1"])

    # batch(3) adds a batch dimension, so the ranks of the two datasets differ
    data2 = data2.batch(3)
    data3 = data1 + data2

    # Iteration is expected to fail because the ranks do not match
    try:
        for _, _ in enumerate(data3):
            pass
        assert False
    except RuntimeError:
        pass


def test_concat_05():
    """
    Test concat: concat datasets that have different data types
    """
    logger.info("test_concat_05")
    data1 = ds.GeneratorDataset(generator, ["col1"])
    data2 = ds.GeneratorDataset(generator_10, ["col1"])

    # Cast data1 to float32 so the column data types no longer match
    type_cast_op = C.TypeCast(mstype.float32)
    data1 = data1.map(input_columns=["col1"], operations=type_cast_op)

    data3 = data1 + data2

    # Iteration is expected to fail because the data types do not match
    try:
        for _, _ in enumerate(data3):
            pass
        assert False
    except RuntimeError:
        pass


def test_concat_06():
    """
    Test concat: concat multiple datasets at once
    """
    logger.info("test_concat_06")
    data1 = ds.GeneratorDataset(generator, ["col1"])
    data2 = ds.GeneratorDataset(generator_10, ["col1"])
    data3 = ds.GeneratorDataset(generator_20, ["col1"])

    dataset = data1 + data2 + data3

    # Here i refers to index, d refers to data element
    for i, d in enumerate(dataset):
        logger.info("data: %i", d[0][0])
        assert i == d[0][0]

    assert sum([1 for _ in dataset]) == 20


def test_concat_07():
    """
    Test concat: concat one dataset with a list of datasets
    """
    logger.info("test_concat_07")
    data1 = ds.GeneratorDataset(generator, ["col1"])
    data2 = ds.GeneratorDataset(generator_10, ["col1"])
    data3 = ds.GeneratorDataset(generator_20, ["col1"])

    dataset = [data2] + [data3]
    data4 = data1 + dataset

    # Here i refers to index, d refers to data element
    for i, d in enumerate(data4):
        logger.info("data: %i", d[0][0])
        assert i == d[0][0]

    assert sum([1 for _ in data4]) == 20


def test_concat_08():
    """
    Test concat: concat 2 datasets, then repeat the result
    """
    logger.info("test_concat_08")
    data1 = ds.GeneratorDataset(generator, ["col1"])
    data2 = ds.GeneratorDataset(generator_10, ["col1"])

    data3 = data1 + data2
    data3 = data3.repeat(2)

    # Here i refers to index, d refers to data element
    for i, d in enumerate(data3):
        logger.info("data: %i", d[0][0])
        assert i % 10 == d[0][0]

    assert sum([1 for _ in data3]) == 20


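# Note the ordering difference: repeating after concat replays the whole
# concatenated sequence (test_concat_08), while repeating each input first
# replays each dataset before moving on (test_concat_09 and test_concat_10).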
def test_concat_09():
    """
    Test concat: concat 2 datasets, both of which have been repeated beforehand
    """
    logger.info("test_concat_09")
    data1 = ds.GeneratorDataset(generator, ["col1"])
    data2 = ds.GeneratorDataset(generator_10, ["col1"])

    data1 = data1.repeat(2)
    data2 = data2.repeat(2)
    data3 = data1 + data2

    res = [0, 1, 2, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 3, 4, 5, 6, 7, 8, 9]
    # Here i refers to index, d refers to data element
    for i, d in enumerate(data3):
        logger.info("data: %i", d[0][0])
        assert res[i] == d[0][0]

    assert sum([1 for _ in data3]) == 20


def test_concat_10():
    """
    Test concat: concat 2 datasets, one of which has been repeated beforehand
    """
    logger.info("test_concat_10")
    data1 = ds.GeneratorDataset(generator, ["col1"])
    data2 = ds.GeneratorDataset(generator_10, ["col1"])

    data1 = data1.repeat(2)
    data3 = data1 + data2

    res = [0, 1, 2, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
    # Here i refers to index, d refers to data element
    for i, d in enumerate(data3):
        logger.info("data: %i", d[0][0])
        assert res[i] == d[0][0]

    assert sum([1 for _ in data3]) == 13


def test_concat_11():
    """
    Test concat: batch the datasets, then concat
    """
    logger.info("test_concat_11")
    data1 = ds.GeneratorDataset(generator, ["col1"])
    data2 = ds.GeneratorDataset(generator_20, ["col1"])

    data1 = data1.batch(3)
    data2 = data2.batch(5)
    data3 = data1 + data2

    # First value of each batch: one batch of 3 from data1, two batches of 5 from data2
    res = [0, 10, 15]
    # Here i refers to index, d refers to data element
    for i, d in enumerate(data3):
        logger.info("data: %i", d[0][0])
        assert res[i] == d[0][0]

    assert sum([1 for _ in data3]) == 3


def test_concat_12():
    """
    Test concat: concat the datasets, then shuffle
    """
    logger.info("test_concat_12")
    data1 = ds.GeneratorDataset(generator, ["col1"])
    data2 = ds.GeneratorDataset(generator_10, ["col1"])

    # A GeneratorDataset does not know its size up front; set it explicitly
    # so that get_dataset_size() works on the concatenated dataset
    data1.set_dataset_size(3)
    data2.set_dataset_size(7)

    data3 = data1 + data2
    res = [8, 6, 2, 5, 0, 4, 9, 3, 7, 1]
    ds.config.set_seed(1)
    assert data3.get_dataset_size() == 10

    data3 = data3.shuffle(buffer_size=10)

    # Here i refers to index, d refers to data element
    for i, d in enumerate(data3):
        logger.info("data: %i", d[0][0])
        assert res[i] == d[0][0]

    assert sum([1 for _ in data3]) == 10


def test_concat_13():
    """
    Test concat: batch the datasets, concat, then shuffle
    """
    logger.info("test_concat_13")
    data1 = ds.GeneratorDataset(generator, ["col1"])
    data2 = ds.GeneratorDataset(generator_20, ["col1"])

    data1.set_dataset_size(3)
    data2.set_dataset_size(10)

    data1 = data1.batch(3)
    data2 = data2.batch(5)
    data3 = data1 + data2

    res = [15, 0, 10]
    ds.config.set_seed(1)
    assert data3.get_dataset_size() == 3

    data3 = data3.shuffle(buffer_size=int(data3.get_dataset_size()))

    # Here i refers to index, d refers to data element
    for i, d in enumerate(data3):
        logger.info("data: %i", d[0][0])
        assert res[i] == d[0][0]

    assert sum([1 for _ in data3]) == 3


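# The remaining tests read real image and record files from the repository's
# test data directory; the paths are relative to the test working directory.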
def test_concat_14():
    """
    Test concat: create datasets from different dataset folders, map transforms
    onto each, then concat
    """
    logger.info("test_concat_14")
    DATA_DIR = "../data/dataset/testPK/data"
    DATA_DIR2 = "../data/dataset/testImageNetData/train/"

    data1 = ds.ImageFolderDatasetV2(DATA_DIR, num_samples=3)
    data2 = ds.ImageFolderDatasetV2(DATA_DIR2, num_samples=2)
    transforms1 = F.ComposeOp([F.Decode(),
                               F.Resize((224, 224)),
                               F.ToTensor()])
    data1 = data1.map(input_columns=["image"], operations=transforms1())
    data2 = data2.map(input_columns=["image"], operations=transforms1())
    data3 = data1 + data2

    expected, output = [], []
    for d in data1:
        expected.append(d[0])
    for d in data2:
        expected.append(d[0])
    for d in data3:
        output.append(d[0])

    assert len(expected) == len(output)
    # The concatenated output must match the two inputs taken in order
    assert np.array_equal(np.array(output), np.array(expected))

    assert sum([1 for _ in data3]) == 5
    assert data3.get_dataset_size() == 5


def test_concat_15():
    """
    Test concat: create datasets from different file formats, then concat
    """
    logger.info("test_concat_15")
    DATA_DIR = "../data/dataset/testPK/data"
    DATA_DIR2 = ["../data/dataset/test_tf_file_3_images/train-0000-of-0001.data"]

    data1 = ds.ImageFolderDatasetV2(DATA_DIR)
    data2 = ds.TFRecordDataset(DATA_DIR2, columns_list=["image"])

    # Keep only the "image" column so the two datasets have matching columns
    data1 = data1.project(["image"])
    data3 = data1 + data2

    assert sum([1 for _ in data3]) == 47


if __name__ == "__main__":
    test_concat_01()
    test_concat_02()
    test_concat_03()
    test_concat_04()
    test_concat_05()
    test_concat_06()
    test_concat_07()
    test_concat_08()
    test_concat_09()
    test_concat_10()
    test_concat_11()
    test_concat_12()
    test_concat_13()
    test_concat_14()
    test_concat_15()