You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-') and can be up to 35 characters long.

datafile_encap.py 6.0 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155
  1. # Copyright 2020-2021 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ============================================================================
  15. """Datafile encapsulator."""
  16. import io
  17. import os
  18. import numpy as np
  19. from PIL import Image
  20. from mindinsight.datavisual.common.exceptions import ImageNotExistError
  21. from mindinsight.explainer.encapsulator._hoc_pil_apply import EditStep, pil_apply_edit_steps
  22. from mindinsight.explainer.encapsulator.explain_data_encap import ExplainDataEncap
  23. from mindinsight.utils.exceptions import FileSystemPermissionError
  24. from mindinsight.utils.exceptions import UnknownError
# Max uint8 value, used to scale 8-bit pixel values to [0, 1] intensity and back.
_UINT8_MAX = 255
# RGBA color that zero saliency is mapped to.
_SALIENCY_CMAP_LOW = (55, 25, 86, 255)
# RGBA color that full saliency is mapped to.
_SALIENCY_CMAP_HI = (255, 255, 0, 255)
# PIL channel modes.
_SINGLE_CHANNEL_MODE = "L"
_RGBA_MODE = "RGBA"
_RGB_MODE = "RGB"
# Format used when encoding generated images into bytes.
_PNG_FORMAT = "PNG"
  36. def _clean_train_id_b4_join(train_id):
  37. """Clean train_id before joining to a path."""
  38. if train_id.startswith("./") or train_id.startswith(".\\"):
  39. return train_id[2:]
  40. return train_id
  41. class DatafileEncap(ExplainDataEncap):
  42. """Datafile encapsulator."""
  43. def query_image_binary(self, train_id, image_path, image_type):
  44. """
  45. Query image binary content.
  46. Args:
  47. train_id (str): Job ID.
  48. image_path (str): Image path relative to explain job's summary directory.
  49. image_type (str): Image type, Options: 'original', 'overlay' or 'outcome'.
  50. Returns:
  51. bytes, image binary.
  52. """
  53. if image_type == "outcome":
  54. sample_id, label, layer = image_path.strip(".jpg").split("_")
  55. layer = int(layer)
  56. job = self.job_manager.get_job(train_id)
  57. samples = job.samples
  58. label_idx = job.labels.index(label)
  59. chosen_sample = samples[int(sample_id)]
  60. original_path_image = chosen_sample['image']
  61. abs_image_path = os.path.join(self.job_manager.summary_base_dir, _clean_train_id_b4_join(train_id),
  62. original_path_image)
  63. if self._is_forbidden(abs_image_path):
  64. raise FileSystemPermissionError("Forbidden.")
  65. try:
  66. image = Image.open(abs_image_path)
  67. except FileNotFoundError:
  68. raise ImageNotExistError(f"train_id:{train_id} path:{image_path} type:{image_type}")
  69. except PermissionError:
  70. raise FileSystemPermissionError(f"train_id:{train_id} path:{image_path} type:{image_type}")
  71. except OSError:
  72. raise UnknownError(f"Invalid image file: train_id:{train_id} path:{image_path} type:{image_type}")
  73. edit_steps = []
  74. boxes = chosen_sample["hierarchical_occlusion"][label_idx]["hoc_layers"][layer]["boxes"]
  75. mask = chosen_sample["hierarchical_occlusion"][label_idx]["mask"]
  76. for box in boxes:
  77. edit_steps.append(EditStep(layer, *box))
  78. image_cp = pil_apply_edit_steps(image, mask, edit_steps)
  79. buffer = io.BytesIO()
  80. image_cp.save(buffer, format=_PNG_FORMAT)
  81. return buffer.getvalue()
  82. abs_image_path = os.path.join(self.job_manager.summary_base_dir,
  83. _clean_train_id_b4_join(train_id),
  84. image_path)
  85. if self._is_forbidden(abs_image_path):
  86. raise FileSystemPermissionError("Forbidden.")
  87. try:
  88. if image_type != "overlay":
  89. # no need to convert
  90. with open(abs_image_path, "rb") as fp:
  91. return fp.read()
  92. image = Image.open(abs_image_path)
  93. if image.mode == _RGBA_MODE:
  94. # It is RGBA already, do not convert.
  95. with open(abs_image_path, "rb") as fp:
  96. return fp.read()
  97. except FileNotFoundError:
  98. raise ImageNotExistError(f"train_id:{train_id} path:{image_path} type:{image_type}")
  99. except PermissionError:
  100. raise FileSystemPermissionError(f"train_id:{train_id} path:{image_path} type:{image_type}")
  101. except OSError:
  102. raise UnknownError(f"Invalid image file: train_id:{train_id} path:{image_path} type:{image_type}")
  103. if image.mode == _SINGLE_CHANNEL_MODE:
  104. saliency = np.asarray(image) / _UINT8_MAX
  105. elif image.mode == _RGB_MODE:
  106. saliency = np.asarray(image)
  107. saliency = saliency[:, :, 0] / _UINT8_MAX
  108. else:
  109. raise UnknownError(f"Invalid overlay image mode:{image.mode}.")
  110. saliency_stack = np.empty((saliency.shape[0], saliency.shape[1], 4))
  111. for c in range(3):
  112. saliency_stack[:, :, c] = saliency
  113. rgba = saliency_stack * _SALIENCY_CMAP_HI
  114. rgba += (1 - saliency_stack) * _SALIENCY_CMAP_LOW
  115. rgba[:, :, 3] = saliency * _UINT8_MAX
  116. overlay = Image.fromarray(np.uint8(rgba), mode=_RGBA_MODE)
  117. buffer = io.BytesIO()
  118. overlay.save(buffer, format=_PNG_FORMAT)
  119. return buffer.getvalue()
  120. def _is_forbidden(self, path):
  121. """Check if the path is outside summary base dir."""
  122. base_dir = os.path.realpath(self.job_manager.summary_base_dir)
  123. path = os.path.realpath(path)
  124. return not path.startswith(base_dir)