
test_epoch_ctrl.py

# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""
Testing Epoch Control op in DE
"""
import itertools

import cv2
import numpy as np
import pytest

import mindspore.dataset as ds
import mindspore.dataset.transforms.vision.c_transforms as vision
from mindspore import log as logger

DATA_DIR = ["../data/dataset/test_tf_file_3_images/train-0000-of-0001.data"]
SCHEMA_DIR = "../data/dataset/test_tf_file_3_images/datasetSchema.json"
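
# The Epoch Control op is what lets a single iterator be re-entered epoch after
# epoch: when an iterator is created with num_epochs other than 1, the op is
# injected into the pipeline and keeps the C++ pipeline alive between epochs.
# Once the requested number of epochs has been consumed (or stop() is called),
# further __next__() calls fail, which is what the tests below verify.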
def diff_mse(in1, in2):
    """
    Return the mean squared error between two images (0-255 range) as a percentage.
    """
    mse = (np.square(in1.astype(float) / 255 - in2.astype(float) / 255)).mean()
    return mse * 100
def test_cifar10():
    """
    Test epoch control on a Cifar10 dataset with a tuple iterator and no num_epochs limit.
    """
    logger.info("Test dataset parameter")
    data_dir_10 = "../data/dataset/testCifar10Data"
    num_repeat = 2
    batch_size = 32
    limit_dataset = 100
    # apply dataset operations
    data1 = ds.Cifar10Dataset(data_dir_10, limit_dataset)
    data1 = data1.repeat(num_repeat)
    data1 = data1.batch(batch_size, True)
    num_epoch = 5
    # iter1 will always assume there is a next epoch and never shut down.
    iter1 = data1.create_tuple_iterator()
    epoch_count = 0
    sample_count = 0
    for _ in range(num_epoch):
        row_count = 0
        for _ in iter1:
            # in this example, each row has the columns "image" and "label"
            row_count += 1
        assert row_count == int(limit_dataset * num_repeat / batch_size)
        logger.debug("row_count: %d", row_count)
        epoch_count += 1
        sample_count += row_count
    assert epoch_count == num_epoch
    logger.debug("total epochs: %d", epoch_count)
    assert sample_count == int(limit_dataset * num_repeat / batch_size) * num_epoch
    logger.debug("total sample: %d", sample_count)
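
# test_decode_op below drives two pipelines over the same images: iter1 is
# created without num_epochs (so it never shuts down on its own), while iter2
# is created with num_epochs and is expected to raise once that budget is used up.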
def test_decode_op():
    """
    Test Decode op with epoch control: one iterator without num_epochs, one with num_epochs.
    """
    logger.info("test_decode_op")
    # Decode with rgb format set to True
    data1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"], shuffle=False)
    # Serialize and Load dataset requires using vision.Decode instead of vision.Decode().
    data1 = data1.map(input_columns=["image"], operations=[vision.Decode(True)])
    # Second dataset
    data2 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"], shuffle=False)
    num_epoch = 5
    # iter1 will always assume there is a next epoch and never shut down.
    iter1 = data1.create_dict_iterator()
    # iter2 will stop and shut down the pipeline after num_epoch epochs.
    iter2 = data2.create_dict_iterator(num_epoch)
    for _ in range(num_epoch):
        i = 0
        for item1, item2 in itertools.zip_longest(iter1, iter2):
            actual = item1["image"]
            expected = cv2.imdecode(item2["image"], cv2.IMREAD_COLOR)
            expected = cv2.cvtColor(expected, cv2.COLOR_BGR2RGB)
            assert actual.shape == expected.shape
            diff = actual - expected
            mse = np.sum(np.power(diff, 2))
            assert mse == 0
            i = i + 1
        assert i == 3

    # Users have the option to manually stop the iterator, or rely on the garbage collector.
    iter1.stop()
    # Expect an AttributeError since iter1 has been stopped.
    with pytest.raises(AttributeError) as info:
        iter1.__next__()
    assert "object has no attribute 'depipeline'" in str(info.value)

    with pytest.raises(RuntimeError) as info:
        iter2.__next__()
    err_msg = "EOF buffer encountered. Users try to fetch data beyond the specified number of epochs."
    assert err_msg in str(info.value)
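
# The remaining tests swap the file-based datasets for a small in-memory
# generator (0..63), so the expected value of every row in every epoch is exact
# and the iterator behaviour can be checked row by row.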
# Generate 1d int numpy array from 0 - 63
def generator_1d():
    """
    generator
    """
    for i in range(64):
        yield (np.array([i]),)
def test_generator_dict_0():
    """
    Test 1D Generator: dict iterator created in the loop declaration, single pass.
    """
    logger.info("Test 1D Generator : 0 - 63")
    # apply dataset operations
    data1 = ds.GeneratorDataset(generator_1d, ["data"])
    i = 0
    # create the iterator inside the loop declaration
    for item in data1.create_dict_iterator():  # each data is a dictionary
        golden = np.array([i])
        assert np.array_equal(item["data"], golden)
        i = i + 1
def test_generator_dict_1():
    """
    Test 1D Generator: a new dict iterator is created inside every epoch loop (anti-pattern).
    """
    logger.info("Test 1D Generator : 0 - 63")
    # apply dataset operations
    data1 = ds.GeneratorDataset(generator_1d, ["data"])
    for _ in range(10):
        i = 0
        # BAD. Do not create the iterator inside the epoch loop every time.
        # Create the iterator outside the epoch for loop.
        for item in data1.create_dict_iterator():  # each data is a dictionary
            golden = np.array([i])
            assert np.array_equal(item["data"], golden)
            i = i + 1
        assert i == 64
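
# Preferred pattern: create the iterator once and reuse it across epochs, as
# test_generator_dict_2 and the later tests do.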
def test_generator_dict_2():
    """
    Test 1D Generator: reuse one dict iterator across epochs; leave cleanup to the garbage collector.
    """
    logger.info("Test 1D Generator : 0 - 63")
    # apply dataset operations
    data1 = ds.GeneratorDataset(generator_1d, ["data"])
    iter1 = data1.create_dict_iterator()
    for _ in range(10):
        i = 0
        for item in iter1:  # each data is a dictionary
            golden = np.array([i])
            assert np.array_equal(item["data"], golden)
            i = i + 1
        assert i == 64

    # iter1 is still alive and running.
    item1 = iter1.__next__()
    assert item1
    # rely on the garbage collector to destroy iter1
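
# An iterator can be released in two ways: rely on the garbage collector (above)
# or call stop() explicitly (below). After stop() the Python object no longer
# holds its underlying pipeline, hence the AttributeError on 'depipeline'.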
def test_generator_dict_3():
    """
    Test 1D Generator: reuse one dict iterator across epochs, then stop it explicitly.
    """
    logger.info("Test 1D Generator : 0 - 63")
    # apply dataset operations
    data1 = ds.GeneratorDataset(generator_1d, ["data"])
    iter1 = data1.create_dict_iterator()
    for _ in range(10):
        i = 0
        for item in iter1:  # each data is a dictionary
            golden = np.array([i])
            assert np.array_equal(item["data"], golden)
            i = i + 1
        assert i == 64

    # optional
    iter1.stop()
    # Expect an AttributeError since iter1 has been stopped.
    with pytest.raises(AttributeError) as info:
        iter1.__next__()
    assert "object has no attribute 'depipeline'" in str(info.value)
def test_generator_dict_4():
    """
    Test 1D Generator: dict iterator with num_epochs=10; fetching an 11th epoch fails.
    """
    logger.info("Test 1D Generator : 0 - 63")
    # apply dataset operations
    data1 = ds.GeneratorDataset(generator_1d, ["data"])
    iter1 = data1.create_dict_iterator(num_epochs=10)
    for _ in range(10):
        i = 0
        for item in iter1:  # each data is a dictionary
            golden = np.array([i])
            assert np.array_equal(item["data"], golden)
            i = i + 1
        assert i == 64

    with pytest.raises(RuntimeError) as info:
        iter1.__next__()
    err_msg = "EOF buffer encountered. Users try to fetch data beyond the specified number of epochs."
    assert err_msg in str(info.value)
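
# The *_4_x variants cover the degenerate settings: with num_epochs=1 no epoch
# control op is injected, and with repeat(1) no repeat op is injected, yet the
# end-of-data error after the single epoch is the same RuntimeError.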
def test_generator_dict_4_1():
    """
    Test 1D Generator: num_epochs=1, so no epoch control op is injected.
    """
    logger.info("Test 1D Generator : 0 - 63")
    # apply dataset operations
    data1 = ds.GeneratorDataset(generator_1d, ["data"])
    # epoch ctrl op will not be injected if num_epochs is 1.
    iter1 = data1.create_dict_iterator(num_epochs=1)
    for _ in range(1):
        i = 0
        for item in iter1:  # each data is a dictionary
            golden = np.array([i])
            assert np.array_equal(item["data"], golden)
            i = i + 1
        assert i == 64

    with pytest.raises(RuntimeError) as info:
        iter1.__next__()
    err_msg = "EOF buffer encountered. Users try to fetch data beyond the specified number of epochs."
    assert err_msg in str(info.value)
def test_generator_dict_4_2():
    """
    Test 1D Generator: num_epochs=1 and repeat(1), so neither epoch control nor repeat op is injected.
    """
    logger.info("Test 1D Generator : 0 - 63")
    # apply dataset operations
    data1 = ds.GeneratorDataset(generator_1d, ["data"])
    # repeat op will not be injected when the repeat count is 1.
    data1 = data1.repeat(1)
    # epoch ctrl op will not be injected if num_epochs is 1.
    iter1 = data1.create_dict_iterator(num_epochs=1)
    for _ in range(1):
        i = 0
        for item in iter1:  # each data is a dictionary
            golden = np.array([i])
            assert np.array_equal(item["data"], golden)
            i = i + 1
        assert i == 64

    with pytest.raises(RuntimeError) as info:
        iter1.__next__()
    err_msg = "EOF buffer encountered. Users try to fetch data beyond the specified number of epochs."
    assert err_msg in str(info.value)
def test_generator_dict_5():
    """
    Test 1D Generator: dict iterator with num_epochs=11; after all 11 epochs, the next fetch fails.
    """
    logger.info("Test 1D Generator : 0 - 63")
    # apply dataset operations
    data1 = ds.GeneratorDataset(generator_1d, ["data"])
    iter1 = data1.create_dict_iterator(num_epochs=11)
    for _ in range(10):
        i = 0
        for item in iter1:  # each data is a dictionary
            golden = np.array([i])
            assert np.array_equal(item["data"], golden)
            i = i + 1
        assert i == 64

    # still one more epoch left in iter1.
    i = 0
    for item in iter1:  # each data is a dictionary
        golden = np.array([i])
        assert np.array_equal(item["data"], golden)
        i = i + 1
    assert i == 64

    # now iter1 has been exhausted and the C++ pipeline has been shut down.
    with pytest.raises(RuntimeError) as info:
        iter1.__next__()
    err_msg = "EOF buffer encountered. Users try to fetch data beyond the specified number of epochs."
    assert err_msg in str(info.value)
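
# The tuple-iterator tests below mirror the dict-iterator tests above; only the
# way a row is accessed changes (item[0] instead of item["data"]).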
# Test tuple iterator
def test_generator_tuple_0():
    """
    Test 1D Generator: tuple iterator created in the loop declaration, single pass.
    """
    logger.info("Test 1D Generator : 0 - 63")
    # apply dataset operations
    data1 = ds.GeneratorDataset(generator_1d, ["data"])
    i = 0
    # create the iterator inside the loop declaration
    for item in data1.create_tuple_iterator():  # each data is a list of columns
        golden = np.array([i])
        assert np.array_equal(item[0], golden)
        i = i + 1
def test_generator_tuple_1():
    """
    Test 1D Generator: a new tuple iterator is created inside every epoch loop (anti-pattern).
    """
    logger.info("Test 1D Generator : 0 - 63")
    # apply dataset operations
    data1 = ds.GeneratorDataset(generator_1d, ["data"])
    for _ in range(10):
        i = 0
        # BAD. Do not create the iterator inside the epoch loop every time.
        # Create the iterator outside the epoch for loop.
        for item in data1.create_tuple_iterator():  # each data is a list of columns
            golden = np.array([i])
            assert np.array_equal(item[0], golden)
            i = i + 1
        assert i == 64
def test_generator_tuple_2():
    """
    Test 1D Generator: reuse one tuple iterator across epochs; leave cleanup to the garbage collector.
    """
    logger.info("Test 1D Generator : 0 - 63")
    # apply dataset operations
    data1 = ds.GeneratorDataset(generator_1d, ["data"])
    iter1 = data1.create_tuple_iterator()
    for _ in range(10):
        i = 0
        for item in iter1:  # each data is a list of columns
            golden = np.array([i])
            assert np.array_equal(item[0], golden)
            i = i + 1
        assert i == 64

    # iter1 is still alive and running.
    item1 = iter1.__next__()
    assert item1
    # rely on the garbage collector to destroy iter1
def test_generator_tuple_3():
    """
    Test 1D Generator: reuse one tuple iterator across epochs, then stop it explicitly.
    """
    logger.info("Test 1D Generator : 0 - 63")
    # apply dataset operations
    data1 = ds.GeneratorDataset(generator_1d, ["data"])
    iter1 = data1.create_tuple_iterator()
    for _ in range(10):
        i = 0
        for item in iter1:  # each data is a list of columns
            golden = np.array([i])
            assert np.array_equal(item[0], golden)
            i = i + 1
        assert i == 64

    # optional
    iter1.stop()
    # Expect an AttributeError since iter1 has been stopped.
    with pytest.raises(AttributeError) as info:
        iter1.__next__()
    assert "object has no attribute 'depipeline'" in str(info.value)
def test_generator_tuple_4():
    """
    Test 1D Generator: tuple iterator with num_epochs=10; fetching an 11th epoch fails.
    """
    logger.info("Test 1D Generator : 0 - 63")
    # apply dataset operations
    data1 = ds.GeneratorDataset(generator_1d, ["data"])
    iter1 = data1.create_tuple_iterator(num_epochs=10)
    for _ in range(10):
        i = 0
        for item in iter1:  # each data is a list of columns
            golden = np.array([i])
            assert np.array_equal(item[0], golden)
            i = i + 1
        assert i == 64

    with pytest.raises(RuntimeError) as info:
        iter1.__next__()
    err_msg = "EOF buffer encountered. Users try to fetch data beyond the specified number of epochs."
    assert err_msg in str(info.value)
def test_generator_tuple_5():
    """
    Test 1D Generator: tuple iterator with num_epochs=11; after all 11 epochs, the next fetch fails.
    """
    logger.info("Test 1D Generator : 0 - 63")
    # apply dataset operations
    data1 = ds.GeneratorDataset(generator_1d, ["data"])
    iter1 = data1.create_tuple_iterator(num_epochs=11)
    for _ in range(10):
        i = 0
        for item in iter1:  # each data is a list of columns
            golden = np.array([i])
            assert np.array_equal(item[0], golden)
            i = i + 1
        assert i == 64

    # still one more epoch left in iter1.
    i = 0
    for item in iter1:  # each data is a list of columns
        golden = np.array([i])
        assert np.array_equal(item[0], golden)
        i = i + 1
    assert i == 64

    # now iter1 has been exhausted and the C++ pipeline has been shut down.
    with pytest.raises(RuntimeError) as info:
        iter1.__next__()
    err_msg = "EOF buffer encountered. Users try to fetch data beyond the specified number of epochs."
    assert err_msg in str(info.value)
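
# With repeat(n) in the pipeline, one epoch of the iterator yields n passes over
# the generator, so the per-epoch row counts below become 64 * 2 and 64 * 2 * 3.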
# Test with repeat
def test_generator_tuple_repeat_1():
    """
    Test 1D Generator + repeat(2): each epoch yields 64 * 2 rows; num_epochs=11.
    """
    logger.info("Test 1D Generator : 0 - 63")
    # apply dataset operations
    data1 = ds.GeneratorDataset(generator_1d, ["data"])
    data1 = data1.repeat(2)
    iter1 = data1.create_tuple_iterator(num_epochs=11)
    for _ in range(10):
        i = 0
        for item in iter1:  # each data is a list of columns
            golden = np.array([i % 64])
            assert np.array_equal(item[0], golden)
            i = i + 1
        assert i == 64 * 2

    # still one more epoch left in iter1.
    i = 0
    for item in iter1:  # each data is a list of columns
        golden = np.array([i % 64])
        assert np.array_equal(item[0], golden)
        i = i + 1
    assert i == 64 * 2

    # now iter1 has been exhausted and the C++ pipeline has been shut down.
    with pytest.raises(RuntimeError) as info:
        iter1.__next__()
    err_msg = "EOF buffer encountered. Users try to fetch data beyond the specified number of epochs."
    assert err_msg in str(info.value)
# Test with nested repeat
def test_generator_tuple_repeat_repeat_1():
    """
    Test 1D Generator + repeat(2) + repeat(3): each epoch yields 64 * 2 * 3 rows; num_epochs=11.
    """
    logger.info("Test 1D Generator : 0 - 63")
    # apply dataset operations
    data1 = ds.GeneratorDataset(generator_1d, ["data"])
    data1 = data1.repeat(2)
    data1 = data1.repeat(3)
    iter1 = data1.create_tuple_iterator(num_epochs=11)
    for _ in range(10):
        i = 0
        for item in iter1:  # each data is a list of columns
            golden = np.array([i % 64])
            assert np.array_equal(item[0], golden)
            i = i + 1
        assert i == 64 * 2 * 3

    # still one more epoch left in iter1.
    i = 0
    for item in iter1:  # each data is a list of columns
        golden = np.array([i % 64])
        assert np.array_equal(item[0], golden)
        i = i + 1
    assert i == 64 * 2 * 3

    # now iter1 has been exhausted and the C++ pipeline has been shut down.
    with pytest.raises(RuntimeError) as info:
        iter1.__next__()
    err_msg = "EOF buffer encountered. Users try to fetch data beyond the specified number of epochs."
    assert err_msg in str(info.value)
def test_generator_tuple_repeat_repeat_2():
    """
    Test 1D Generator + repeat(2) + repeat(3): reuse one tuple iterator, then stop it explicitly.
    """
    logger.info("Test 1D Generator : 0 - 63")
    # apply dataset operations
    data1 = ds.GeneratorDataset(generator_1d, ["data"])
    data1 = data1.repeat(2)
    data1 = data1.repeat(3)
    iter1 = data1.create_tuple_iterator()
    for _ in range(10):
        i = 0
        for item in iter1:  # each data is a list of columns
            golden = np.array([i % 64])
            assert np.array_equal(item[0], golden)
            i = i + 1
        assert i == 64 * 2 * 3

    # optional
    iter1.stop()
    # Expect an AttributeError since iter1 has been stopped.
    with pytest.raises(AttributeError) as info:
        iter1.__next__()
    assert "object has no attribute 'depipeline'" in str(info.value)
def test_generator_tuple_repeat_repeat_3():
    """
    Test 1D Generator + repeat(2) + repeat(3): reuse one tuple iterator for 10 epochs and then 5 more.
    """
    logger.info("Test 1D Generator : 0 - 63")
    # apply dataset operations
    data1 = ds.GeneratorDataset(generator_1d, ["data"])
    data1 = data1.repeat(2)
    data1 = data1.repeat(3)
    iter1 = data1.create_tuple_iterator()
    for _ in range(10):
        i = 0
        for item in iter1:  # each data is a list of columns
            golden = np.array([i % 64])
            assert np.array_equal(item[0], golden)
            i = i + 1
        assert i == 64 * 2 * 3

    for _ in range(5):
        i = 0
        for item in iter1:  # each data is a list of columns
            golden = np.array([i % 64])
            assert np.array_equal(item[0], golden)
            i = i + 1
        assert i == 64 * 2 * 3

    # rely on the garbage collector to destroy iter1
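
# test_generator_reusedataset keeps extending the same dataset object
# (repeat(2), then repeat(3), then batch(2)) and creates a fresh iterator after
# each change; every new iterator reflects the ops added so far.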
def test_generator_reusedataset():
    """
    Test 1D Generator: keep extending the same dataset object and create a new iterator each time.
    """
    logger.info("Test 1D Generator : 0 - 63")
    # apply dataset operations
    data1 = ds.GeneratorDataset(generator_1d, ["data"])
    data1 = data1.repeat(2)
    iter1 = data1.create_tuple_iterator()
    for _ in range(10):
        i = 0
        for item in iter1:  # each data is a list of columns
            golden = np.array([i % 64])
            assert np.array_equal(item[0], golden)
            i = i + 1
        assert i == 64 * 2

    data1 = data1.repeat(3)
    iter1 = data1.create_tuple_iterator()
    for _ in range(5):
        i = 0
        for item in iter1:  # each data is a list of columns
            golden = np.array([i % 64])
            assert np.array_equal(item[0], golden)
            i = i + 1
        assert i == 64 * 2 * 3

    data1 = data1.batch(2)
    iter1 = data1.create_dict_iterator()
    for _ in range(5):
        i = 0
        sample = 0
        for item in iter1:  # each data is a dictionary
            golden = np.array([[i % 64], [(i + 1) % 64]])
            assert np.array_equal(item["data"], golden)
            i = i + 2
            sample = sample + 1
        assert sample == 64 * 3

    # rely on the garbage collector to destroy iter1