|
- # Copyright 2020 Huawei Technologies Co., Ltd
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- # ============================================================================
- """Datafile encapsulator."""
-
- import os
- import io
-
- from PIL import Image
- from PIL import UnidentifiedImageError
- import numpy as np
-
- from mindinsight.utils.exceptions import UnknownError
- from mindinsight.utils.exceptions import FileSystemPermissionError
- from mindinsight.datavisual.common.exceptions import ImageNotExistError
- from mindinsight.explainer.encapsulator.explain_data_encap import ExplainDataEncap
-
- # Max uint8 value. Used to scale 8-bit pixel values down to [0, 1] intensity
- # and to scale [0, 1] saliency back up to an 8-bit alpha value.
- _UINT8_MAX = 255
-
- # RGBA color that zero saliency maps to in the generated overlay.
- # The alpha component is replaced by the saliency-derived alpha when the
- # overlay is built.
- _SALIENCY_CMAP_LOW = (55, 25, 86, 255)
-
- # RGBA color that full saliency maps to in the generated overlay.
- # The alpha component is replaced by the saliency-derived alpha when the
- # overlay is built.
- _SALIENCY_CMAP_HI = (255, 255, 0, 255)
-
- # Image mode strings, compared against the `mode` attribute of an opened image.
- _SINGLE_CHANNEL_MODE = "L"
- _RGBA_MODE = "RGBA"
- _RGB_MODE = "RGB"
-
- # Format name passed to `save` when encoding the generated overlay.
- _PNG_FORMAT = "PNG"
-
-
- def _clean_train_id_b4_join(train_id):
- """Clean train_id before joining to a path."""
- if train_id.startswith("./") or train_id.startswith(".\\"):
- return train_id[2:]
- return train_id
-
-
class DatafileEncap(ExplainDataEncap):
    """Datafile encapsulator."""

    def query_image_binary(self, train_id, image_path, image_type):
        """
        Query image binary content.

        For any image_type other than 'overlay', and for overlay files that
        are already RGBA, the file bytes are returned untouched. Single
        channel ("L") and RGB overlay files are treated as saliency maps and
        rendered into an RGBA PNG.

        Args:
            train_id (str): Job ID.
            image_path (str): Image path relative to explain job's summary directory.
            image_type (str): Image type, 'original' or 'overlay'.

        Returns:
            bytes, image binary.

        Raises:
            FileSystemPermissionError: If the resolved path lies outside the
                summary base directory, or the file cannot be read due to
                permissions.
            ImageNotExistError: If the file does not exist.
            UnknownError: If the file is not a recognizable image, or the
                overlay image mode is not "L", "RGB" or "RGBA".
        """
        abs_image_path = os.path.join(self.job_manager.summary_base_dir,
                                      _clean_train_id_b4_join(train_id),
                                      image_path)

        if self._is_forbidden(abs_image_path):
            raise FileSystemPermissionError("Forbidden.")

        is_overlay = image_type == "overlay"
        mode = None
        saliency = None

        try:
            if is_overlay:
                # The context manager closes the underlying file handle; the
                # divisions below produce new arrays, so nothing refers to
                # the image once it is closed.
                with Image.open(abs_image_path) as image:
                    mode = image.mode
                    if mode == _SINGLE_CHANNEL_MODE:
                        saliency = np.asarray(image) / _UINT8_MAX
                    elif mode == _RGB_MODE:
                        # Only the first channel carries the saliency value.
                        saliency = np.asarray(image)[:, :, 0] / _UINT8_MAX

            if not is_overlay or mode == _RGBA_MODE:
                # No conversion needed, return the file as stored.
                with open(abs_image_path, "rb") as fp:
                    return fp.read()

        except FileNotFoundError as err:
            raise ImageNotExistError(image_path) from err
        except PermissionError as err:
            raise FileSystemPermissionError(image_path) from err
        except UnidentifiedImageError as err:
            raise UnknownError(f"Invalid image file: {image_path}") from err

        if saliency is None:
            raise UnknownError(f"Invalid overlay image mode:{mode}.")

        return self._render_overlay(saliency)

    @staticmethod
    def _render_overlay(saliency):
        """
        Render a [0, 1] saliency map as an RGBA PNG.

        Color is a linear blend from _SALIENCY_CMAP_LOW (saliency 0) to
        _SALIENCY_CMAP_HI (saliency 1); alpha is the saliency scaled to 8 bits.

        Args:
            saliency (numpy.ndarray): 2D array of saliency values in [0, 1].

        Returns:
            bytes, PNG encoded overlay.
        """
        # Add a trailing axis so the map broadcasts against the 4-tuple colors.
        weight = saliency[:, :, np.newaxis]
        rgba = weight * _SALIENCY_CMAP_HI + (1 - weight) * _SALIENCY_CMAP_LOW
        # Alpha follows the saliency itself, not the colormap alpha.
        rgba[:, :, 3] = saliency * _UINT8_MAX

        overlay = Image.fromarray(np.uint8(rgba), mode=_RGBA_MODE)
        buffer = io.BytesIO()
        overlay.save(buffer, format=_PNG_FORMAT)

        return buffer.getvalue()

    def _is_forbidden(self, path):
        """
        Check if the path is outside summary base dir.

        Both paths are resolved with realpath and compared component-wise,
        so a sibling directory that merely shares the base directory's name
        as a string prefix (e.g. "<base>_other") is rejected.

        Args:
            path (str): Path to check.

        Returns:
            bool, True if the path is not the base dir or located under it.
        """
        base_dir = os.path.realpath(self.job_manager.summary_base_dir)
        target = os.path.realpath(path)
        try:
            return os.path.commonpath([base_dir, target]) != base_dir
        except ValueError:
            # Raised when the paths cannot share a common root (for example
            # different drives on Windows): the target is outside base_dir.
            return True
|