You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-'), and can be up to 35 characters long.

test_random_crop.py 22 kB

5 years ago

  1. # Copyright 2019 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ==============================================================================
  15. """
  16. Testing RandomCrop op in DE
  17. """
  18. import numpy as np
  19. import mindspore.dataset.transforms.vision.c_transforms as c_vision
  20. import mindspore.dataset.transforms.vision.py_transforms as py_vision
  21. import mindspore.dataset.transforms.vision.utils as mode
  22. import mindspore.dataset as ds
  23. from mindspore import log as logger
  24. from util import save_and_check_md5, visualize_list, config_get_set_seed, \
  25. config_get_set_num_parallel_workers
  26. GENERATE_GOLDEN = False
  27. DATA_DIR = ["../data/dataset/test_tf_file_3_images/train-0000-of-0001.data"]
  28. SCHEMA_DIR = "../data/dataset/test_tf_file_3_images/datasetSchema.json"
  29. def test_random_crop_op_c(plot=False):
  30. """
  31. Test RandomCrop Op in c transforms
  32. """
  33. logger.info("test_random_crop_op_c")
  34. # First dataset
  35. data1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"], shuffle=False)
  36. random_crop_op = c_vision.RandomCrop([512, 512], [200, 200, 200, 200])
  37. decode_op = c_vision.Decode()
  38. data1 = data1.map(input_columns=["image"], operations=decode_op)
  39. data1 = data1.map(input_columns=["image"], operations=random_crop_op)
  40. # Second dataset
  41. data2 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"], shuffle=False)
  42. data2 = data2.map(input_columns=["image"], operations=decode_op)
  43. image_cropped = []
  44. image = []
  45. for item1, item2 in zip(data1.create_dict_iterator(), data2.create_dict_iterator()):
  46. image1 = item1["image"]
  47. image2 = item2["image"]
  48. image_cropped.append(image1)
  49. image.append(image2)
  50. if plot:
  51. visualize_list(image, image_cropped)
  52. def test_random_crop_op_py(plot=False):
  53. """
  54. Test RandomCrop op in py transforms
  55. """
  56. logger.info("test_random_crop_op_py")
  57. # First dataset
  58. data1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"], shuffle=False)
  59. transforms1 = [
  60. py_vision.Decode(),
  61. py_vision.RandomCrop([512, 512], [200, 200, 200, 200]),
  62. py_vision.ToTensor()
  63. ]
  64. transform1 = py_vision.ComposeOp(transforms1)
  65. data1 = data1.map(input_columns=["image"], operations=transform1())
  66. # Second dataset
  67. # Second dataset for comparison
  68. data2 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"], shuffle=False)
  69. transforms2 = [
  70. py_vision.Decode(),
  71. py_vision.ToTensor()
  72. ]
  73. transform2 = py_vision.ComposeOp(transforms2)
  74. data2 = data2.map(input_columns=["image"], operations=transform2())
  75. crop_images = []
  76. original_images = []
  77. for item1, item2 in zip(data1.create_dict_iterator(), data2.create_dict_iterator()):
  78. crop = (item1["image"].transpose(1, 2, 0) * 255).astype(np.uint8)
  79. original = (item2["image"].transpose(1, 2, 0) * 255).astype(np.uint8)
  80. crop_images.append(crop)
  81. original_images.append(original)
  82. if plot:
  83. visualize_list(original_images, crop_images)
  84. def test_random_crop_01_c():
  85. """
  86. Test RandomCrop op with c_transforms: size is a single integer, expected to pass
  87. """
  88. logger.info("test_random_crop_01_c")
  89. original_seed = config_get_set_seed(0)
  90. original_num_parallel_workers = config_get_set_num_parallel_workers(1)
  91. # Generate dataset
  92. data = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"], shuffle=False)
  93. # Note: If size is an int, a square crop of size (size, size) is returned.
  94. random_crop_op = c_vision.RandomCrop(512)
  95. decode_op = c_vision.Decode()
  96. data = data.map(input_columns=["image"], operations=decode_op)
  97. data = data.map(input_columns=["image"], operations=random_crop_op)
  98. filename = "random_crop_01_c_result.npz"
  99. save_and_check_md5(data, filename, generate_golden=GENERATE_GOLDEN)
  100. # Restore config setting
  101. ds.config.set_seed(original_seed)
  102. ds.config.set_num_parallel_workers(original_num_parallel_workers)
  103. def test_random_crop_01_py():
  104. """
  105. Test RandomCrop op with py_transforms: size is a single integer, expected to pass
  106. """
  107. logger.info("test_random_crop_01_py")
  108. original_seed = config_get_set_seed(0)
  109. original_num_parallel_workers = config_get_set_num_parallel_workers(1)
  110. # Generate dataset
  111. data = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"], shuffle=False)
  112. # Note: If size is an int, a square crop of size (size, size) is returned.
  113. transforms = [
  114. py_vision.Decode(),
  115. py_vision.RandomCrop(512),
  116. py_vision.ToTensor()
  117. ]
  118. transform = py_vision.ComposeOp(transforms)
  119. data = data.map(input_columns=["image"], operations=transform())
  120. filename = "random_crop_01_py_result.npz"
  121. save_and_check_md5(data, filename, generate_golden=GENERATE_GOLDEN)
  122. # Restore config setting
  123. ds.config.set_seed(original_seed)
  124. ds.config.set_num_parallel_workers(original_num_parallel_workers)
  125. def test_random_crop_02_c():
  126. """
  127. Test RandomCrop op with c_transforms: size is a list/tuple with length 2, expected to pass
  128. """
  129. logger.info("test_random_crop_02_c")
  130. original_seed = config_get_set_seed(0)
  131. original_num_parallel_workers = config_get_set_num_parallel_workers(1)
  132. # Generate dataset
  133. data = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"], shuffle=False)
  134. # Note: If size is a sequence of length 2, it should be (height, width).
  135. random_crop_op = c_vision.RandomCrop([512, 375])
  136. decode_op = c_vision.Decode()
  137. data = data.map(input_columns=["image"], operations=decode_op)
  138. data = data.map(input_columns=["image"], operations=random_crop_op)
  139. filename = "random_crop_02_c_result.npz"
  140. save_and_check_md5(data, filename, generate_golden=GENERATE_GOLDEN)
  141. # Restore config setting
  142. ds.config.set_seed(original_seed)
  143. ds.config.set_num_parallel_workers(original_num_parallel_workers)
  144. def test_random_crop_02_py():
  145. """
  146. Test RandomCrop op with py_transforms: size is a list/tuple with length 2, expected to pass
  147. """
  148. logger.info("test_random_crop_02_py")
  149. original_seed = config_get_set_seed(0)
  150. original_num_parallel_workers = config_get_set_num_parallel_workers(1)
  151. # Generate dataset
  152. data = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"], shuffle=False)
  153. # Note: If size is a sequence of length 2, it should be (height, width).
  154. transforms = [
  155. py_vision.Decode(),
  156. py_vision.RandomCrop([512, 375]),
  157. py_vision.ToTensor()
  158. ]
  159. transform = py_vision.ComposeOp(transforms)
  160. data = data.map(input_columns=["image"], operations=transform())
  161. filename = "random_crop_02_py_result.npz"
  162. save_and_check_md5(data, filename, generate_golden=GENERATE_GOLDEN)
  163. # Restore config setting
  164. ds.config.set_seed(original_seed)
  165. ds.config.set_num_parallel_workers(original_num_parallel_workers)
  166. def test_random_crop_03_c():
  167. """
  168. Test RandomCrop op with c_transforms: input image size == crop size, expected to pass
  169. """
  170. logger.info("test_random_crop_03_c")
  171. original_seed = config_get_set_seed(0)
  172. original_num_parallel_workers = config_get_set_num_parallel_workers(1)
  173. # Generate dataset
  174. data = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"], shuffle=False)
  175. # Note: The size of the image is 4032*2268
  176. random_crop_op = c_vision.RandomCrop([2268, 4032])
  177. decode_op = c_vision.Decode()
  178. data = data.map(input_columns=["image"], operations=decode_op)
  179. data = data.map(input_columns=["image"], operations=random_crop_op)
  180. filename = "random_crop_03_c_result.npz"
  181. save_and_check_md5(data, filename, generate_golden=GENERATE_GOLDEN)
  182. # Restore config setting
  183. ds.config.set_seed(original_seed)
  184. ds.config.set_num_parallel_workers(original_num_parallel_workers)
  185. def test_random_crop_03_py():
  186. """
  187. Test RandomCrop op with py_transforms: input image size == crop size, expected to pass
  188. """
  189. logger.info("test_random_crop_03_py")
  190. original_seed = config_get_set_seed(0)
  191. original_num_parallel_workers = config_get_set_num_parallel_workers(1)
  192. # Generate dataset
  193. data = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"], shuffle=False)
  194. # Note: The size of the image is 4032*2268
  195. transforms = [
  196. py_vision.Decode(),
  197. py_vision.RandomCrop([2268, 4032]),
  198. py_vision.ToTensor()
  199. ]
  200. transform = py_vision.ComposeOp(transforms)
  201. data = data.map(input_columns=["image"], operations=transform())
  202. filename = "random_crop_03_py_result.npz"
  203. save_and_check_md5(data, filename, generate_golden=GENERATE_GOLDEN)
  204. # Restore config setting
  205. ds.config.set_seed(original_seed)
  206. ds.config.set_num_parallel_workers(original_num_parallel_workers)
  207. def test_random_crop_04_c():
  208. """
  209. Test RandomCrop op with c_transforms: input image size < crop size, expected to fail
  210. """
  211. logger.info("test_random_crop_04_c")
  212. # Generate dataset
  213. data = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"], shuffle=False)
  214. # Note: The size of the image is 4032*2268
  215. random_crop_op = c_vision.RandomCrop([2268, 4033])
  216. decode_op = c_vision.Decode()
  217. data = data.map(input_columns=["image"], operations=decode_op)
  218. data = data.map(input_columns=["image"], operations=random_crop_op)
  219. try:
  220. data.create_dict_iterator().get_next()
  221. except RuntimeError as e:
  222. logger.info("Got an exception in DE: {}".format(str(e)))
  223. assert "Crop size is greater than the image dim" in str(e)
  224. def test_random_crop_04_py():
  225. """
  226. Test RandomCrop op with py_transforms:
  227. input image size < crop size, expected to fail
  228. """
  229. logger.info("test_random_crop_04_py")
  230. # Generate dataset
  231. data = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"], shuffle=False)
  232. # Note: The size of the image is 4032*2268
  233. transforms = [
  234. py_vision.Decode(),
  235. py_vision.RandomCrop([2268, 4033]),
  236. py_vision.ToTensor()
  237. ]
  238. transform = py_vision.ComposeOp(transforms)
  239. data = data.map(input_columns=["image"], operations=transform())
  240. try:
  241. data.create_dict_iterator().get_next()
  242. except RuntimeError as e:
  243. logger.info("Got an exception in DE: {}".format(str(e)))
  244. assert "Crop size" in str(e)
  245. def test_random_crop_05_c():
  246. """
  247. Test RandomCrop op with c_transforms:
  248. input image size < crop size but pad_if_needed is enabled,
  249. expected to pass
  250. """
  251. logger.info("test_random_crop_05_c")
  252. original_seed = config_get_set_seed(0)
  253. original_num_parallel_workers = config_get_set_num_parallel_workers(1)
  254. # Generate dataset
  255. data = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"], shuffle=False)
  256. # Note: The size of the image is 4032*2268
  257. random_crop_op = c_vision.RandomCrop([2268, 4033], [200, 200, 200, 200], pad_if_needed=True)
  258. decode_op = c_vision.Decode()
  259. data = data.map(input_columns=["image"], operations=decode_op)
  260. data = data.map(input_columns=["image"], operations=random_crop_op)
  261. filename = "random_crop_05_c_result.npz"
  262. save_and_check_md5(data, filename, generate_golden=GENERATE_GOLDEN)
  263. # Restore config setting
  264. ds.config.set_seed(original_seed)
  265. ds.config.set_num_parallel_workers(original_num_parallel_workers)
  266. def test_random_crop_05_py():
  267. """
  268. Test RandomCrop op with py_transforms:
  269. input image size < crop size but pad_if_needed is enabled,
  270. expected to pass
  271. """
  272. logger.info("test_random_crop_05_py")
  273. original_seed = config_get_set_seed(0)
  274. original_num_parallel_workers = config_get_set_num_parallel_workers(1)
  275. # Generate dataset
  276. data = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"], shuffle=False)
  277. # Note: The size of the image is 4032*2268
  278. transforms = [
  279. py_vision.Decode(),
  280. py_vision.RandomCrop([2268, 4033], [200, 200, 200, 200], pad_if_needed=True),
  281. py_vision.ToTensor()
  282. ]
  283. transform = py_vision.ComposeOp(transforms)
  284. data = data.map(input_columns=["image"], operations=transform())
  285. filename = "random_crop_05_py_result.npz"
  286. save_and_check_md5(data, filename, generate_golden=GENERATE_GOLDEN)
  287. # Restore config setting
  288. ds.config.set_seed(original_seed)
  289. ds.config.set_num_parallel_workers(original_num_parallel_workers)
  290. def test_random_crop_06_c():
  291. """
  292. Test RandomCrop op with c_transforms:
  293. invalid size, expected to raise TypeError
  294. """
  295. logger.info("test_random_crop_06_c")
  296. # Generate dataset
  297. data = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"], shuffle=False)
  298. try:
  299. # Note: if size is neither an int nor a list of length 2, an exception will raise
  300. random_crop_op = c_vision.RandomCrop([512, 512, 375])
  301. decode_op = c_vision.Decode()
  302. data = data.map(input_columns=["image"], operations=decode_op)
  303. data = data.map(input_columns=["image"], operations=random_crop_op)
  304. except TypeError as e:
  305. logger.info("Got an exception in DE: {}".format(str(e)))
  306. assert "Size should be a single integer" in str(e)
  307. def test_random_crop_06_py():
  308. """
  309. Test RandomCrop op with py_transforms:
  310. invalid size, expected to raise TypeError
  311. """
  312. logger.info("test_random_crop_06_py")
  313. # Generate dataset
  314. data = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"], shuffle=False)
  315. try:
  316. # Note: if size is neither an int nor a list of length 2, an exception will raise
  317. transforms = [
  318. py_vision.Decode(),
  319. py_vision.RandomCrop([512, 512, 375]),
  320. py_vision.ToTensor()
  321. ]
  322. transform = py_vision.ComposeOp(transforms)
  323. data = data.map(input_columns=["image"], operations=transform())
  324. except TypeError as e:
  325. logger.info("Got an exception in DE: {}".format(str(e)))
  326. assert "Size should be a single integer" in str(e)
  327. def test_random_crop_07_c():
  328. """
  329. Test RandomCrop op with c_transforms:
  330. padding_mode is Border.CONSTANT and fill_value is 255 (White),
  331. expected to pass
  332. """
  333. logger.info("test_random_crop_07_c")
  334. original_seed = config_get_set_seed(0)
  335. original_num_parallel_workers = config_get_set_num_parallel_workers(1)
  336. # Generate dataset
  337. data = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"], shuffle=False)
  338. # Note: The padding_mode is default as Border.CONSTANT and set filling color to be white.
  339. random_crop_op = c_vision.RandomCrop(512, [200, 200, 200, 200], fill_value=(255, 255, 255))
  340. decode_op = c_vision.Decode()
  341. data = data.map(input_columns=["image"], operations=decode_op)
  342. data = data.map(input_columns=["image"], operations=random_crop_op)
  343. filename = "random_crop_07_c_result.npz"
  344. save_and_check_md5(data, filename, generate_golden=GENERATE_GOLDEN)
  345. # Restore config setting
  346. ds.config.set_seed(original_seed)
  347. ds.config.set_num_parallel_workers(original_num_parallel_workers)
  348. def test_random_crop_07_py():
  349. """
  350. Test RandomCrop op with py_transforms:
  351. padding_mode is Border.CONSTANT and fill_value is 255 (White),
  352. expected to pass
  353. """
  354. logger.info("test_random_crop_07_py")
  355. original_seed = config_get_set_seed(0)
  356. original_num_parallel_workers = config_get_set_num_parallel_workers(1)
  357. # Generate dataset
  358. data = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"], shuffle=False)
  359. # Note: The padding_mode is default as Border.CONSTANT and set filling color to be white.
  360. transforms = [
  361. py_vision.Decode(),
  362. py_vision.RandomCrop(512, [200, 200, 200, 200], fill_value=(255, 255, 255)),
  363. py_vision.ToTensor()
  364. ]
  365. transform = py_vision.ComposeOp(transforms)
  366. data = data.map(input_columns=["image"], operations=transform())
  367. filename = "random_crop_07_py_result.npz"
  368. save_and_check_md5(data, filename, generate_golden=GENERATE_GOLDEN)
  369. # Restore config setting
  370. ds.config.set_seed(original_seed)
  371. ds.config.set_num_parallel_workers(original_num_parallel_workers)
  372. def test_random_crop_08_c():
  373. """
  374. Test RandomCrop op with c_transforms: padding_mode is Border.EDGE,
  375. expected to pass
  376. """
  377. logger.info("test_random_crop_08_c")
  378. original_seed = config_get_set_seed(0)
  379. original_num_parallel_workers = config_get_set_num_parallel_workers(1)
  380. # Generate dataset
  381. data = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"], shuffle=False)
  382. # Note: The padding_mode is Border.EDGE.
  383. random_crop_op = c_vision.RandomCrop(512, [200, 200, 200, 200], padding_mode=mode.Border.EDGE)
  384. decode_op = c_vision.Decode()
  385. data = data.map(input_columns=["image"], operations=decode_op)
  386. data = data.map(input_columns=["image"], operations=random_crop_op)
  387. filename = "random_crop_08_c_result.npz"
  388. save_and_check_md5(data, filename, generate_golden=GENERATE_GOLDEN)
  389. # Restore config setting
  390. ds.config.set_seed(original_seed)
  391. ds.config.set_num_parallel_workers(original_num_parallel_workers)
  392. def test_random_crop_08_py():
  393. """
  394. Test RandomCrop op with py_transforms: padding_mode is Border.EDGE,
  395. expected to pass
  396. """
  397. logger.info("test_random_crop_08_py")
  398. original_seed = config_get_set_seed(0)
  399. original_num_parallel_workers = config_get_set_num_parallel_workers(1)
  400. # Generate dataset
  401. data = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"], shuffle=False)
  402. # Note: The padding_mode is Border.EDGE.
  403. transforms = [
  404. py_vision.Decode(),
  405. py_vision.RandomCrop(512, [200, 200, 200, 200], padding_mode=mode.Border.EDGE),
  406. py_vision.ToTensor()
  407. ]
  408. transform = py_vision.ComposeOp(transforms)
  409. data = data.map(input_columns=["image"], operations=transform())
  410. filename = "random_crop_08_py_result.npz"
  411. save_and_check_md5(data, filename, generate_golden=GENERATE_GOLDEN)
  412. # Restore config setting
  413. ds.config.set_seed(original_seed)
  414. ds.config.set_num_parallel_workers(original_num_parallel_workers)
  415. def test_random_crop_09():
  416. """
  417. Test RandomCrop op: invalid type of input image (not PIL), expected to raise TypeError
  418. """
  419. logger.info("test_random_crop_09")
  420. # Generate dataset
  421. data = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"], shuffle=False)
  422. transforms = [
  423. py_vision.Decode(),
  424. py_vision.ToTensor(),
  425. # Note: if input is not PIL image, TypeError will raise
  426. py_vision.RandomCrop(512)
  427. ]
  428. transform = py_vision.ComposeOp(transforms)
  429. data = data.map(input_columns=["image"], operations=transform())
  430. try:
  431. data.create_dict_iterator().get_next()
  432. except RuntimeError as e:
  433. logger.info("Got an exception in DE: {}".format(str(e)))
  434. assert "should be PIL Image" in str(e)
  435. def test_random_crop_comp(plot=False):
  436. """
  437. Test RandomCrop and compare between python and c image augmentation
  438. """
  439. logger.info("Test RandomCrop with c_transform and py_transform comparison")
  440. cropped_size = 512
  441. # First dataset
  442. data1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"], shuffle=False)
  443. random_crop_op = c_vision.RandomCrop(cropped_size)
  444. decode_op = c_vision.Decode()
  445. data1 = data1.map(input_columns=["image"], operations=decode_op)
  446. data1 = data1.map(input_columns=["image"], operations=random_crop_op)
  447. # Second dataset
  448. data2 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"], shuffle=False)
  449. transforms = [
  450. py_vision.Decode(),
  451. py_vision.RandomCrop(cropped_size),
  452. py_vision.ToTensor()
  453. ]
  454. transform = py_vision.ComposeOp(transforms)
  455. data2 = data2.map(input_columns=["image"], operations=transform())
  456. image_c_cropped = []
  457. image_py_cropped = []
  458. for item1, item2 in zip(data1.create_dict_iterator(), data2.create_dict_iterator()):
  459. c_image = item1["image"]
  460. py_image = (item2["image"].transpose(1, 2, 0) * 255).astype(np.uint8)
  461. image_c_cropped.append(c_image)
  462. image_py_cropped.append(py_image)
  463. if plot:
  464. visualize_list(image_c_cropped, image_py_cropped, visualize_mode=2)
  465. if __name__ == "__main__":
  466. test_random_crop_01_c()
  467. test_random_crop_02_c()
  468. test_random_crop_03_c()
  469. test_random_crop_04_c()
  470. test_random_crop_05_c()
  471. test_random_crop_06_c()
  472. test_random_crop_07_c()
  473. test_random_crop_08_c()
  474. test_random_crop_01_py()
  475. test_random_crop_02_py()
  476. test_random_crop_03_py()
  477. test_random_crop_04_py()
  478. test_random_crop_05_py()
  479. test_random_crop_06_py()
  480. test_random_crop_07_py()
  481. test_random_crop_08_py()
  482. test_random_crop_09()
  483. test_random_crop_op_c(True)
  484. test_random_crop_op_py(True)
  485. test_random_crop_comp(True)