You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; they can include dashes ('-') and can be up to 35 characters long.

test_rgb_hsv.py 6.4 kB

5 years ago
5 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170
  1. # Copyright 2019 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ==============================================================================
  15. """
  16. Testing RgbToHsv and HsvToRgb op in DE
  17. """
  18. import colorsys
  19. import numpy as np
  20. from numpy.testing import assert_allclose
  21. import mindspore.dataset as ds
  22. import mindspore.dataset.transforms.vision.py_transforms as vision
  23. import mindspore.dataset.transforms.vision.py_transforms_util as util
  24. DATA_DIR = ["../data/dataset/test_tf_file_3_images/train-0000-of-0001.data"]
  25. SCHEMA_DIR = "../data/dataset/test_tf_file_3_images/datasetSchema.json"
  26. def generate_numpy_random_rgb(shape):
  27. # Only generate floating points that are fractions like n / 256, since they
  28. # are RGB pixels. Some low-precision floating point types in this test can't
  29. # handle arbitrary precision floating points well.
  30. return np.random.randint(0, 256, shape) / 255.
  31. def test_rgb_hsv_hwc():
  32. rgb_flat = generate_numpy_random_rgb((64, 3)).astype(np.float32)
  33. rgb_np = rgb_flat.reshape((8, 8, 3))
  34. hsv_base = np.array([
  35. colorsys.rgb_to_hsv(
  36. r.astype(np.float64), g.astype(np.float64), b.astype(np.float64))
  37. for r, g, b in rgb_flat
  38. ])
  39. hsv_base = hsv_base.reshape((8, 8, 3))
  40. hsv_de = util.rgb_to_hsvs(rgb_np, True)
  41. assert hsv_base.shape == hsv_de.shape
  42. assert_allclose(hsv_base.flatten(), hsv_de.flatten(), rtol=1e-5, atol=0)
  43. hsv_flat = hsv_base.reshape(64, 3)
  44. rgb_base = np.array([
  45. colorsys.hsv_to_rgb(
  46. h.astype(np.float64), s.astype(np.float64), v.astype(np.float64))
  47. for h, s, v in hsv_flat
  48. ])
  49. rgb_base = rgb_base.reshape((8, 8, 3))
  50. rgb_de = util.hsv_to_rgbs(hsv_base, True)
  51. assert rgb_base.shape == rgb_de.shape
  52. assert_allclose(rgb_base.flatten(), rgb_de.flatten(), rtol=1e-5, atol=0)
  53. def test_rgb_hsv_batch_hwc():
  54. rgb_flat = generate_numpy_random_rgb((64, 3)).astype(np.float32)
  55. rgb_np = rgb_flat.reshape((4, 2, 8, 3))
  56. hsv_base = np.array([
  57. colorsys.rgb_to_hsv(
  58. r.astype(np.float64), g.astype(np.float64), b.astype(np.float64))
  59. for r, g, b in rgb_flat
  60. ])
  61. hsv_base = hsv_base.reshape((4, 2, 8, 3))
  62. hsv_de = util.rgb_to_hsvs(rgb_np, True)
  63. assert hsv_base.shape == hsv_de.shape
  64. assert_allclose(hsv_base.flatten(), hsv_de.flatten(), rtol=1e-5, atol=0)
  65. hsv_flat = hsv_base.reshape((64, 3))
  66. rgb_base = np.array([
  67. colorsys.hsv_to_rgb(
  68. h.astype(np.float64), s.astype(np.float64), v.astype(np.float64))
  69. for h, s, v in hsv_flat
  70. ])
  71. rgb_base = rgb_base.reshape((4, 2, 8, 3))
  72. rgb_de = util.hsv_to_rgbs(hsv_base, True)
  73. assert rgb_de.shape == rgb_base.shape
  74. assert_allclose(rgb_base.flatten(), rgb_de.flatten(), rtol=1e-5, atol=0)
  75. def test_rgb_hsv_chw():
  76. rgb_flat = generate_numpy_random_rgb((64, 3)).astype(np.float32)
  77. rgb_np = rgb_flat.reshape((3, 8, 8))
  78. hsv_base = np.array([
  79. np.vectorize(colorsys.rgb_to_hsv)(
  80. rgb_np[0, :, :].astype(np.float64), rgb_np[1, :, :].astype(np.float64), rgb_np[2, :, :].astype(np.float64))
  81. ])
  82. hsv_base = hsv_base.reshape((3, 8, 8))
  83. hsv_de = util.rgb_to_hsvs(rgb_np, False)
  84. assert hsv_base.shape == hsv_de.shape
  85. assert_allclose(hsv_base.flatten(), hsv_de.flatten(), rtol=1e-5, atol=0)
  86. rgb_base = np.array([
  87. np.vectorize(colorsys.hsv_to_rgb)(
  88. hsv_base[0, :, :].astype(np.float64), hsv_base[1, :, :].astype(np.float64),
  89. hsv_base[2, :, :].astype(np.float64))
  90. ])
  91. rgb_base = rgb_base.reshape((3, 8, 8))
  92. rgb_de = util.hsv_to_rgbs(hsv_base, False)
  93. assert rgb_de.shape == rgb_base.shape
  94. assert_allclose(rgb_base.flatten(), rgb_de.flatten(), rtol=1e-5, atol=0)
  95. def test_rgb_hsv_batch_chw():
  96. rgb_flat = generate_numpy_random_rgb((64, 3)).astype(np.float32)
  97. rgb_imgs = rgb_flat.reshape((4, 3, 2, 8))
  98. hsv_base_imgs = np.array([
  99. np.vectorize(colorsys.rgb_to_hsv)(
  100. img[0, :, :].astype(np.float64), img[1, :, :].astype(np.float64), img[2, :, :].astype(np.float64))
  101. for img in rgb_imgs
  102. ])
  103. hsv_de = util.rgb_to_hsvs(rgb_imgs, False)
  104. assert hsv_base_imgs.shape == hsv_de.shape
  105. assert_allclose(hsv_base_imgs.flatten(), hsv_de.flatten(), rtol=1e-5, atol=0)
  106. rgb_base = np.array([
  107. np.vectorize(colorsys.hsv_to_rgb)(
  108. img[0, :, :].astype(np.float64), img[1, :, :].astype(np.float64), img[2, :, :].astype(np.float64))
  109. for img in hsv_base_imgs
  110. ])
  111. rgb_de = util.hsv_to_rgbs(hsv_base_imgs, False)
  112. assert rgb_base.shape == rgb_de.shape
  113. assert_allclose(rgb_base.flatten(), rgb_de.flatten(), rtol=1e-5, atol=0)
  114. def test_rgb_hsv_pipeline():
  115. # First dataset
  116. transforms1 = [
  117. vision.Decode(),
  118. vision.Resize([64, 64]),
  119. vision.ToTensor()
  120. ]
  121. transforms1 = vision.ComposeOp(transforms1)
  122. ds1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"], shuffle=False)
  123. ds1 = ds1.map(input_columns=["image"], operations=transforms1())
  124. # Second dataset
  125. transforms2 = [
  126. vision.Decode(),
  127. vision.Resize([64, 64]),
  128. vision.ToTensor(),
  129. vision.RgbToHsv(),
  130. vision.HsvToRgb()
  131. ]
  132. transform2 = vision.ComposeOp(transforms2)
  133. ds2 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"], shuffle=False)
  134. ds2 = ds2.map(input_columns=["image"], operations=transform2())
  135. num_iter = 0
  136. for data1, data2 in zip(ds1.create_dict_iterator(), ds2.create_dict_iterator()):
  137. num_iter += 1
  138. ori_img = data1["image"]
  139. cvt_img = data2["image"]
  140. assert_allclose(ori_img.flatten(), cvt_img.flatten(), rtol=1e-5, atol=0)
  141. assert ori_img.shape == cvt_img.shape
  142. if __name__ == "__main__":
  143. test_rgb_hsv_hwc()
  144. test_rgb_hsv_batch_hwc()
  145. test_rgb_hsv_chw()
  146. test_rgb_hsv_batch_chw()
  147. test_rgb_hsv_pipeline()