You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

datafile_encap.py 6.2 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163
  1. # Copyright 2020-2021 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ============================================================================
  15. """Datafile encapsulator."""
  16. import io
  17. import os
  18. import numpy as np
  19. from PIL import Image
  20. from mindinsight.datavisual.common.exceptions import ImageNotExistError
  21. from mindinsight.explainer.common.enums import ImageQueryTypes
  22. from mindinsight.explainer.encapsulator._hoc_pil_apply import EditStep, pil_apply_edit_steps
  23. from mindinsight.explainer.encapsulator.explain_data_encap import ExplainDataEncap
  24. from mindinsight.utils.exceptions import FileSystemPermissionError
  25. from mindinsight.utils.exceptions import UnknownError
# Max uint8 value. Used to scale 8-bit pixel values to [0, 1] intensity, and to
# scale [0, 1] saliency back to an 8-bit alpha value.
_UINT8_MAX = 255
# RGBA color of low saliency (the color blended in as saliency approaches 0).
_SALIENCY_CMAP_LOW = (55, 25, 86, 255)
# RGBA color of high saliency (the color blended in as saliency approaches 1).
_SALIENCY_CMAP_HI = (255, 255, 0, 255)
# Channel modes, compared against PIL's `Image.mode`.
_SINGLE_CHANNEL_MODE = "L"
_RGBA_MODE = "RGBA"
_RGB_MODE = "RGB"
# Format name passed to `Image.save` when encoding results into an in-memory buffer.
_PNG_FORMAT = "PNG"
  37. def _clean_train_id_b4_join(train_id):
  38. """Clean train_id before joining to a path."""
  39. if train_id.startswith("./") or train_id.startswith(".\\"):
  40. return train_id[2:]
  41. return train_id
  42. class DatafileEncap(ExplainDataEncap):
  43. """Datafile encapsulator."""
  44. def query_image_binary(self, train_id, image_path, image_type):
  45. """
  46. Query image binary content.
  47. Args:
  48. train_id (str): Job ID.
  49. image_path (str): Image path relative to explain job's summary directory.
  50. image_type (str): Image type, Options: 'original', 'overlay' or 'outcome'.
  51. Returns:
  52. bytes, image binary content for UI to demonstrate.
  53. """
  54. if image_type == ImageQueryTypes.OUTCOME.value:
  55. return self._get_hoc_image(image_path, train_id)
  56. abs_image_path = os.path.join(self.job_manager.summary_base_dir,
  57. _clean_train_id_b4_join(train_id),
  58. image_path)
  59. if self._is_forbidden(abs_image_path):
  60. raise FileSystemPermissionError("Forbidden.")
  61. try:
  62. if image_type != ImageQueryTypes.OVERLAY.value:
  63. # no need to convert
  64. with open(abs_image_path, "rb") as fp:
  65. return fp.read()
  66. image = Image.open(abs_image_path)
  67. if image.mode == _RGBA_MODE:
  68. # It is RGBA already, do not convert.
  69. with open(abs_image_path, "rb") as fp:
  70. return fp.read()
  71. except FileNotFoundError:
  72. raise ImageNotExistError(f"train_id:{train_id} path:{image_path} type:{image_type}")
  73. except PermissionError:
  74. raise FileSystemPermissionError(f"train_id:{train_id} path:{image_path} type:{image_type}")
  75. except OSError:
  76. raise UnknownError(f"Invalid image file: train_id:{train_id} path:{image_path} type:{image_type}")
  77. if image.mode == _SINGLE_CHANNEL_MODE:
  78. saliency = np.asarray(image) / _UINT8_MAX
  79. elif image.mode == _RGB_MODE:
  80. saliency = np.asarray(image)
  81. saliency = saliency[:, :, 0] / _UINT8_MAX
  82. else:
  83. raise UnknownError(f"Invalid overlay image mode:{image.mode}.")
  84. saliency_stack = np.empty((saliency.shape[0], saliency.shape[1], 4))
  85. for c in range(3):
  86. saliency_stack[:, :, c] = saliency
  87. rgba = saliency_stack * _SALIENCY_CMAP_HI
  88. rgba += (1 - saliency_stack) * _SALIENCY_CMAP_LOW
  89. rgba[:, :, 3] = saliency * _UINT8_MAX
  90. overlay = Image.fromarray(np.uint8(rgba), mode=_RGBA_MODE)
  91. buffer = io.BytesIO()
  92. overlay.save(buffer, format=_PNG_FORMAT)
  93. return buffer.getvalue()
  94. def _is_forbidden(self, path):
  95. """Check if the path is outside summary base dir."""
  96. base_dir = os.path.realpath(self.job_manager.summary_base_dir)
  97. path = os.path.realpath(path)
  98. return not path.startswith(base_dir)
  99. def _get_hoc_image(self, image_path, train_id):
  100. """Get hoc image for image data demonstration in UI."""
  101. sample_id, label, layer = image_path.strip(".jpg").split("_")
  102. layer = int(layer)
  103. job = self.job_manager.get_job(train_id)
  104. samples = job.samples
  105. label_idx = job.labels.index(label)
  106. chosen_sample = samples[int(sample_id)]
  107. original_path_image = chosen_sample['image']
  108. abs_image_path = os.path.join(self.job_manager.summary_base_dir, _clean_train_id_b4_join(train_id),
  109. original_path_image)
  110. if self._is_forbidden(abs_image_path):
  111. raise FileSystemPermissionError("Forbidden.")
  112. image_type = ImageQueryTypes.OUTCOME.value
  113. try:
  114. image = Image.open(abs_image_path)
  115. except FileNotFoundError:
  116. raise ImageNotExistError(f"train_id:{train_id} path:{image_path} type:{image_type}")
  117. except PermissionError:
  118. raise FileSystemPermissionError(f"train_id:{train_id} path:{image_path} type:{image_type}")
  119. except OSError:
  120. raise UnknownError(f"Invalid image file: train_id:{train_id} path:{image_path} type:{image_type}")
  121. edit_steps = []
  122. boxes = chosen_sample["hierarchical_occlusion"][label_idx]["hoc_layers"][layer]["boxes"]
  123. mask = chosen_sample["hierarchical_occlusion"][label_idx]["mask"]
  124. for box in boxes:
  125. edit_steps.append(EditStep(layer, *box))
  126. image_cp = pil_apply_edit_steps(image, mask, edit_steps)
  127. buffer = io.BytesIO()
  128. image_cp.save(buffer, format=_PNG_FORMAT)
  129. return buffer.getvalue()