Browse Source

!1135 Adapt mindinsight backend to support CV counterfactual explanation

From: @lixiaohui33
Reviewed-by: @ouwenchang
Signed-off-by:
tags/v1.2.0-rc1
mindspore-ci-bot Gitee 4 years ago
parent
commit
9c83dad1f8
15 changed files with 650 additions and 241 deletions
  1. +82
    -34
      mindinsight/backend/explainer/explainer_api.py
  2. +13
    -0
      mindinsight/datavisual/proto_files/mindinsight_summary.proto
  3. +1
    -0
      mindinsight/explainer/common/enums.py
  4. +156
    -0
      mindinsight/explainer/encapsulator/_hoc_pil_apply.py
  5. +43
    -9
      mindinsight/explainer/encapsulator/datafile_encap.py
  6. +126
    -2
      mindinsight/explainer/encapsulator/explain_data_encap.py
  7. +5
    -3
      mindinsight/explainer/encapsulator/explain_job_encap.py
  8. +87
    -0
      mindinsight/explainer/encapsulator/hierarchical_occlusion_encap.py
  9. +16
    -88
      mindinsight/explainer/encapsulator/saliency_encap.py
  10. +111
    -98
      mindinsight/explainer/manager/explain_loader.py
  11. +1
    -0
      mindinsight/explainer/manager/explain_manager.py
  12. +3
    -1
      mindinsight/explainer/manager/explain_parser.py
  13. +3
    -1
      tests/ut/explainer/encapsulator/mock_explain_manager.py
  14. +3
    -1
      tests/ut/explainer/encapsulator/test_explain_job_encap.py
  15. +0
    -4
      tests/ut/explainer/manager/test_explain_loader.py

+ 82
- 34
mindinsight/backend/explainer/explainer_api.py View File

@@ -1,4 +1,4 @@
# Copyright 2020 Huawei Technologies Co., Ltd
# Copyright 2020-2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -14,8 +14,8 @@
# ============================================================================
"""Explainer restful api."""

import os
import json
import os
import urllib.parse

from flask import Blueprint
@@ -23,16 +23,18 @@ from flask import jsonify
from flask import request

from mindinsight.conf import settings
from mindinsight.utils.exceptions import ParamMissError
from mindinsight.utils.exceptions import ParamValueError
from mindinsight.datavisual.common.validation import Validation
from mindinsight.datavisual.data_transform.summary_watcher import SummaryWatcher
from mindinsight.datavisual.utils.tools import get_train_id
from mindinsight.explainer.manager.explain_manager import EXPLAIN_MANAGER
from mindinsight.explainer.encapsulator.explain_job_encap import ExplainJobEncap
from mindinsight.explainer.encapsulator.datafile_encap import DatafileEncap
from mindinsight.explainer.encapsulator.saliency_encap import SaliencyEncap
from mindinsight.explainer.encapsulator.evaluation_encap import EvaluationEncap
from mindinsight.explainer.encapsulator.explain_job_encap import ExplainJobEncap
from mindinsight.explainer.encapsulator.hierarchical_occlusion_encap import HierarchicalOcclusionEncap
from mindinsight.explainer.encapsulator.saliency_encap import SaliencyEncap
from mindinsight.explainer.manager.explain_manager import EXPLAIN_MANAGER
from mindinsight.utils.exceptions import ParamMissError
from mindinsight.utils.exceptions import ParamTypeError
from mindinsight.utils.exceptions import ParamValueError

URL_PREFIX = settings.URL_PATH_PREFIX + settings.API_PREFIX
BLUEPRINT = Blueprint("explainer", __name__, url_prefix=URL_PREFIX)
@@ -66,6 +68,50 @@ def _read_post_request(post_request):
return body


def _get_query_sample_parameters(data):
"""Get parameter for query."""

train_id = data.get("train_id")
if train_id is None:
raise ParamMissError('train_id')

labels = data.get("labels")
if labels is not None and not isinstance(labels, list):
raise ParamTypeError("labels", (list, None))
if labels:
for item in labels:
if not isinstance(item, str):
raise ParamTypeError("element of labels", str)

limit = data.get("limit", 10)
limit = Validation.check_limit(limit, min_value=1, max_value=100)
offset = data.get("offset", 0)
offset = Validation.check_offset(offset=offset)
sorted_name = data.get("sorted_name", "")
sorted_type = data.get("sorted_type", "descending")
if sorted_name not in ("", "confidence", "uncertainty"):
raise ParamValueError(f"sorted_name: {sorted_name}, valid options: '' 'confidence' 'uncertainty'")
if sorted_type not in ("ascending", "descending"):
raise ParamValueError(f"sorted_type: {sorted_type}, valid options: 'confidence' 'uncertainty'")

prediction_types = data.get("prediction_types")
if prediction_types is not None and not isinstance(prediction_types, list):
raise ParamTypeError("prediction_types", (list, None))
if prediction_types:
for item in prediction_types:
if item not in ['TP', 'FN', 'FP']:
raise ParamValueError(f"Item of prediction_types must be in ['TP', 'FN', 'FP'], but got {item}.")

query_kwarg = {"train_id": train_id,
"labels": labels,
"limit": limit,
"offset": offset,
"sorted_name": sorted_name,
"sorted_type": sorted_type,
"prediction_types": prediction_types}
return query_kwarg


@BLUEPRINT.route("/explainer/explain-jobs", methods=["GET"])
def query_explain_jobs():
"""Query explain jobs."""
@@ -99,37 +145,40 @@ def query_explain_job():
@BLUEPRINT.route("/explainer/saliency", methods=["POST"])
def query_saliency():
"""Query saliency map related results."""

data = _read_post_request(request)

train_id = data.get("train_id")
if train_id is None:
raise ParamMissError('train_id')

labels = data.get("labels")
query_kwarg = _get_query_sample_parameters(data)
explainers = data.get("explainers")
limit = data.get("limit", 10)
limit = Validation.check_limit(limit, min_value=1, max_value=100)
offset = data.get("offset", 0)
offset = Validation.check_offset(offset=offset)
sorted_name = data.get("sorted_name", "")
sorted_type = data.get("sorted_type", "descending")
if explainers is not None and not isinstance(explainers, list):
raise ParamTypeError("explainers", (list, None))
if explainers:
for item in explainers:
if not isinstance(item, str):
raise ParamTypeError("element of explainers", str)

if sorted_name not in ("", "confidence", "uncertainty"):
raise ParamValueError(f"sorted_name: {sorted_name}, valid options: '' 'confidence' 'uncertainty'")
if sorted_type not in ("ascending", "descending"):
raise ParamValueError(f"sorted_type: {sorted_type}, valid options: 'confidence' 'uncertainty'")
query_kwarg["explainers"] = explainers

encapsulator = SaliencyEncap(
_image_url_formatter,
EXPLAIN_MANAGER)
count, samples = encapsulator.query_saliency_maps(train_id=train_id,
labels=labels,
explainers=explainers,
limit=limit,
offset=offset,
sorted_name=sorted_name,
sorted_type=sorted_type)
count, samples = encapsulator.query_saliency_maps(**query_kwarg)

return jsonify({
"count": count,
"samples": samples
})


@BLUEPRINT.route("/explainer/hoc", methods=["POST"])
def query_hoc():
"""Query hierarchical occlusion related results."""
data = _read_post_request(request)

query_kwargs = _get_query_sample_parameters(data)

encapsulator = HierarchicalOcclusionEncap(
_image_url_formatter,
EXPLAIN_MANAGER)
count, samples = encapsulator.query_hierarchical_occlusion(**query_kwargs)

return jsonify({
"count": count,
@@ -162,8 +211,8 @@ def query_image():
image_type = request.args.get("type")
if image_type is None:
raise ParamMissError("type")
if image_type not in ("original", "overlay"):
raise ParamValueError(f"type:{image_type}, valid options: 'original' 'overlay'")
if image_type not in ("original", "overlay", "outcome"):
raise ParamValueError(f"type:{image_type}, valid options: 'original' 'overlay' 'outcome'")

encapsulator = DatafileEncap(EXPLAIN_MANAGER)
image = encapsulator.query_image_binary(train_id, image_path, image_type)
@@ -177,6 +226,5 @@ def init_module(app):

Args:
app: the application obj.

"""
app.register_blueprint(BLUEPRINT)

+ 13
- 0
mindinsight/datavisual/proto_files/mindinsight_summary.proto View File

@@ -138,6 +138,17 @@ message Explain {
repeated string benchmark_method = 3;
}

message HocLayer{
optional float prob = 1;
repeated int32 box = 2; // List of repeated x, y, w, h
}

message Hoc {
optional int32 label = 1;
optional string mask = 2;
repeated HocLayer layer = 3;
}

optional int32 sample_id = 1; // The Metadata and sample id must have one fill in
optional string image_path = 2;
repeated int32 ground_truth_label = 3;
@@ -148,4 +159,6 @@ message Explain {

optional Metadata metadata = 7;
optional string status = 8; // enum value: run, end

repeated Hoc hoc = 9; // hierarchical occlusion counterfactual
}

+ 1
- 0
mindinsight/explainer/common/enums.py View File

@@ -40,6 +40,7 @@ class ExplainFieldsEnum(BaseEnum):
GROUND_TRUTH_LABEL = 'ground_truth_label'
INFERENCE = 'inference'
EXPLANATION = 'explanation'
HIERARCHICAL_OCCLUSION = 'hierarchical_occlusion'
STATUS = 'status'




+ 156
- 0
mindinsight/explainer/encapsulator/_hoc_pil_apply.py View File

@@ -0,0 +1,156 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

"""Utility functions for hierarchcial occlusion image generation."""

import re

import PIL
from PIL import ImageDraw, ImageEnhance, ImageFilter

MASK_GAUSSIAN_RE = r'^gaussian:(\d+)$'


class EditStep:
"""
Class that represents an edit step.

Args:
layer (int): Layer index.
x (int): Left pixel coordinate.
y (int): Top pixel coordinate.
width (int): Width in pixels.
height (int): Height in pixels.
"""
def __init__(self,
layer: int,
x: int,
y: int,
width: int,
height: int):
self.layer = layer
self.x = x
self.y = y
self.width = width
self.height = height

def to_coord_box(self):
"""
Convert to pixel coordinate box.

Returns:
tuple[int, int, int, int], tuple of left, top, right, bottom pixel coordinate.
"""
return self.x, self.y, self.x + self.width, self.y + self.height


def pil_apply_edit_steps(image, mask, edit_steps, by_masking=False, inplace=False):
"""
Apply edit steps on a PIL image.

Args:
image (PIL.Image): The input image in RGB mode.
mask (Union[str, int, tuple[int, int, int], PIL.Image.Image]): The mask to apply on the image, could be string
e.g. 'gaussian:9', a single, grey scale intensity [0, 255], a RBG tuple or a PIL Image object.
edit_steps (list[EditStep]): Edit steps to be drawn.
by_masking (bool): Whether to use masking method. Default: False.
inplace (bool): True to draw on the input image, otherwise draw on a cloned image.

Returns:
PIL.Image, the result image.
"""
if isinstance(mask, str):
mask = pil_compile_str_mask(mask, image)
if by_masking:
return _pil_apply_edit_steps_mask(image, mask, edit_steps, inplace)
return _pil_apply_edit_steps_unmask(image, mask, edit_steps, inplace)


def pil_compile_str_mask(mask, image):
"""Concert string mask to PIL Image."""
match = re.match(MASK_GAUSSIAN_RE, mask)
if match:
radius = int(match.group(1))
if radius > 0:
image_filter = ImageFilter.GaussianBlur(radius=radius)
mask_image = image.filter(image_filter)
mask_image = ImageEnhance.Brightness(mask_image).enhance(0.7)
mask_image = ImageEnhance.Color(mask_image).enhance(0.0)
return mask_image
raise ValueError(f"Invalid string mask: '{mask}'.")


def _pil_apply_edit_steps_unmask(image, mask, edit_steps, inplace=False):
"""
Apply edit steps from unmasking method on a PIL image.

Args:
image (PIL.Image): The input image.
mask (Union[int, tuple[int, int, int], PIL.Image]): The mask to apply on the image, could be a single grey
scale intensity [0, 255], a RBG tuple or a PIL Image.
edit_steps (list[EditStep]): Edit steps to be drawn.
inplace (bool): True to draw on the input image, otherwise draw on a cloned image.

Returns:
PIL.Image, the result image.
"""
if isinstance(mask, PIL.Image.Image):
if inplace:
bg = mask
else:
bg = mask.copy()
else:
if inplace:
raise ValueError('Argument inplace cannot be True when mask is not a PIL Image.')
if isinstance(mask, int):
mask = (mask, mask, mask)
bg = PIL.Image.new(mode="RGB", size=image.size, color=mask)

for step in edit_steps:
box = step.to_coord_box()
cropped = image.crop(box)
bg.paste(cropped, box=box)
return bg


def _pil_apply_edit_steps_mask(image, mask, edit_steps, inplace=False):
"""
Apply edit steps from unmasking method on a PIL image.

Args:
image (PIL.Image): The input image.
mask (Union[int, tuple[int, int, int], PIL.Image]): The mask to apply on the image, could be a single grey
scale intensity [0, 255], a RBG tuple or a PIL Image.
edit_steps (list[EditStep]): Edit steps to be drawn.
inplace (bool): True to draw on the input image, otherwise draw on a cloned image.

Returns:
PIL.Image, the result image.
"""
if not inplace:
image = image.copy()

if isinstance(mask, PIL.Image.Image):
for step in edit_steps:
box = step.to_coord_box()
cropped = mask.crop(box)
image.paste(cropped, box=box)
else:
if isinstance(mask, int):
mask = (mask, mask, mask)
draw = ImageDraw.Draw(image)
for step in edit_steps:
draw.rectangle(step.to_coord_box(), fill=mask)
return image

+ 43
- 9
mindinsight/explainer/encapsulator/datafile_encap.py View File

@@ -1,4 +1,4 @@
# Copyright 2020 Huawei Technologies Co., Ltd
# Copyright 2020-2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -14,16 +14,17 @@
# ============================================================================
"""Datafile encapsulator."""

import os
import io
import os

from PIL import Image
import numpy as np
from PIL import Image

from mindinsight.utils.exceptions import UnknownError
from mindinsight.utils.exceptions import FileSystemPermissionError
from mindinsight.datavisual.common.exceptions import ImageNotExistError
from mindinsight.explainer.encapsulator._hoc_pil_apply import EditStep, pil_apply_edit_steps
from mindinsight.explainer.encapsulator.explain_data_encap import ExplainDataEncap
from mindinsight.utils.exceptions import FileSystemPermissionError
from mindinsight.utils.exceptions import UnknownError

# Max uint8 value. for converting RGB pixels to [0,1] intensity.
_UINT8_MAX = 255
@@ -59,11 +60,44 @@ class DatafileEncap(ExplainDataEncap):
Args:
train_id (str): Job ID.
image_path (str): Image path relative to explain job's summary directory.
image_type (str): Image type, 'original' or 'overlay'.
image_type (str): Image type, Options: 'original', 'overlay' or 'outcome'.

Returns:
bytes, image binary.
"""
if image_type == "outcome":
sample_id, label, layer = image_path.strip(".jpg").split("_")
layer = int(layer)
job = self.job_manager.get_job(train_id)
samples = job.samples
label_idx = job.labels.index(label)

chosen_sample = samples[int(sample_id)]
original_path_image = chosen_sample['image']
abs_image_path = os.path.join(self.job_manager.summary_base_dir, _clean_train_id_b4_join(train_id),
original_path_image)
if self._is_forbidden(abs_image_path):
raise FileSystemPermissionError("Forbidden.")
try:
image = Image.open(abs_image_path)
except FileNotFoundError:
raise ImageNotExistError(f"train_id:{train_id} path:{image_path} type:{image_type}")
except PermissionError:
raise FileSystemPermissionError(f"train_id:{train_id} path:{image_path} type:{image_type}")
except OSError:
raise UnknownError(f"Invalid image file: train_id:{train_id} path:{image_path} type:{image_type}")

edit_steps = []
boxes = chosen_sample["hierarchical_occlusion"][label_idx]["hoc_layers"][layer]["boxes"]
mask = chosen_sample["hierarchical_occlusion"][label_idx]["mask"]

for box in boxes:
edit_steps.append(EditStep(layer, *box))
image_cp = pil_apply_edit_steps(image, mask, edit_steps)
buffer = io.BytesIO()
image_cp.save(buffer, format=_PNG_FORMAT)

return buffer.getvalue()

abs_image_path = os.path.join(self.job_manager.summary_base_dir,
_clean_train_id_b4_join(train_id),
@@ -94,10 +128,10 @@ class DatafileEncap(ExplainDataEncap):
raise UnknownError(f"Invalid image file: train_id:{train_id} path:{image_path} type:{image_type}")

if image.mode == _SINGLE_CHANNEL_MODE:
saliency = np.asarray(image)/_UINT8_MAX
saliency = np.asarray(image) / _UINT8_MAX
elif image.mode == _RGB_MODE:
saliency = np.asarray(image)
saliency = saliency[:, :, 0]/_UINT8_MAX
saliency = saliency[:, :, 0] / _UINT8_MAX
else:
raise UnknownError(f"Invalid overlay image mode:{image.mode}.")

@@ -105,7 +139,7 @@ class DatafileEncap(ExplainDataEncap):
for c in range(3):
saliency_stack[:, :, c] = saliency
rgba = saliency_stack * _SALIENCY_CMAP_HI
rgba += (1-saliency_stack) * _SALIENCY_CMAP_LOW
rgba += (1 - saliency_stack) * _SALIENCY_CMAP_LOW
rgba[:, :, 3] = saliency * _UINT8_MAX

overlay = Image.fromarray(np.uint8(rgba), mode=_RGBA_MODE)


+ 126
- 2
mindinsight/explainer/encapsulator/explain_data_encap.py View File

@@ -1,4 +1,4 @@
# Copyright 2020 Huawei Technologies Co., Ltd
# Copyright 2020-2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -12,7 +12,57 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Explain data encapsulator base class."""
"""Common explain data encapsulator base class."""

import copy

from mindinsight.utils.exceptions import ParamValueError


def _sort_key_min_confidence(sample, labels):
"""Samples sort key by the min. confidence."""
min_confidence = float("+inf")
for inference in sample["inferences"]:
if labels and inference["label"] not in labels:
continue
if inference["confidence"] < min_confidence:
min_confidence = inference["confidence"]
return min_confidence


def _sort_key_max_confidence(sample, labels):
"""Samples sort key by the max. confidence."""
max_confidence = float("-inf")
for inference in sample["inferences"]:
if labels and inference["label"] not in labels:
continue
if inference["confidence"] > max_confidence:
max_confidence = inference["confidence"]
return max_confidence


def _sort_key_min_confidence_sd(sample, labels):
"""Samples sort key by the min. confidence_sd."""
min_confidence_sd = float("+inf")
for inference in sample["inferences"]:
if labels and inference["label"] not in labels:
continue
confidence_sd = inference.get("confidence_sd", float("+inf"))
if confidence_sd < min_confidence_sd:
min_confidence_sd = confidence_sd
return min_confidence_sd


def _sort_key_max_confidence_sd(sample, labels):
"""Samples sort key by the max. confidence_sd."""
max_confidence_sd = float("-inf")
for inference in sample["inferences"]:
if labels and inference["label"] not in labels:
continue
confidence_sd = inference.get("confidence_sd", float("-inf"))
if confidence_sd > max_confidence_sd:
max_confidence_sd = confidence_sd
return max_confidence_sd


class ExplainDataEncap:
@@ -24,3 +74,77 @@ class ExplainDataEncap:
@property
def job_manager(self):
return self._job_manager


class ExplanationEncap(ExplainDataEncap):
"""Base encapsulator for explanation queries."""

def __init__(self, image_url_formatter, *args, **kwargs):
super().__init__(*args, **kwargs)
self._image_url_formatter = image_url_formatter

def _query_samples(self,
job,
labels,
sorted_name,
sorted_type,
prediction_types=None,
query_type="saliency_maps"):
"""
Query samples.

Args:
job (ExplainManager): Explain job to be query from.
labels (list[str]): Label filter.
sorted_name (str): Field to be sorted.
sorted_type (str): Sorting order, 'ascending' or 'descending'.
prediction_types (list[str]): Prediction type filter.

Returns:
list[dict], samples to be queried.
"""

samples = copy.deepcopy(job.get_all_samples())
samples = [sample for sample in samples if any(infer[query_type] for infer in sample['inferences'])]
if labels:
filtered = []
for sample in samples:
infer_labels = [inference["label"] for inference in sample["inferences"]]
for infer_label in infer_labels:
if infer_label in labels:
filtered.append(sample)
break
samples = filtered

if prediction_types and len(prediction_types) < 3:
filtered = []
for sample in samples:
infer_types = [inference["prediction_type"] for inference in sample["inferences"]]
for infer_type in infer_types:
if infer_type in prediction_types:
filtered.append(sample)
break
samples = filtered

reverse = sorted_type == "descending"
if sorted_name == "confidence":
if reverse:
samples.sort(key=lambda x: _sort_key_max_confidence(x, labels), reverse=reverse)
else:
samples.sort(key=lambda x: _sort_key_min_confidence(x, labels), reverse=reverse)
elif sorted_name == "uncertainty":
if not job.uncertainty_enabled:
raise ParamValueError("Uncertainty is not enabled, sorted_name cannot be 'uncertainty'")
if reverse:
samples.sort(key=lambda x: _sort_key_max_confidence_sd(x, labels), reverse=reverse)
else:
samples.sort(key=lambda x: _sort_key_min_confidence_sd(x, labels), reverse=reverse)
elif sorted_name != "":
raise ParamValueError("sorted_name")
return samples

def _get_image_url(self, train_id, image_path, image_type):
"""Returns image's url."""
if self._image_url_formatter is None:
return image_path
return self._image_url_formatter(train_id, image_path, image_type)

+ 5
- 3
mindinsight/explainer/encapsulator/explain_job_encap.py View File

@@ -1,4 +1,4 @@
# Copyright 2020 Huawei Technologies Co., Ltd
# Copyright 2020-2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -14,7 +14,6 @@
# ============================================================================
"""Explain job list encapsulator."""

import copy
from datetime import datetime

from mindinsight.explainer.encapsulator.explain_data_encap import ExplainDataEncap
@@ -61,6 +60,9 @@ class ExplainJobEncap(ExplainDataEncap):
info["train_id"] = dir_info["relative_path"]
info["create_time"] = dir_info["create_time"].strftime(cls.DATETIME_FORMAT)
info["update_time"] = dir_info["update_time"].strftime(cls.DATETIME_FORMAT)
info["saliency_map"] = dir_info["saliency_map"]
info["hierarchical_occlusion"] = dir_info["hierarchical_occlusion"]

return info

@classmethod
@@ -79,7 +81,7 @@ class ExplainJobEncap(ExplainDataEncap):
"""Convert ExplainJob's meta-data to jsonable info object."""
info = cls._job_2_info(job)
info["sample_count"] = job.sample_count
info["classes"] = copy.deepcopy(job.all_classes)
info["classes"] = [item for item in job.all_classes if item['sample_count'] > 0]
saliency_info = dict()
if job.min_confidence is None:
saliency_info["min_confidence"] = cls.DEFAULT_MIN_CONFIDENCE


+ 87
- 0
mindinsight/explainer/encapsulator/hierarchical_occlusion_encap.py View File

@@ -0,0 +1,87 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Hierarchical Occlusion encapsulator."""

from mindinsight.datavisual.common.exceptions import TrainJobNotExistError
from mindinsight.explainer.encapsulator.explain_data_encap import ExplanationEncap


class HierarchicalOcclusionEncap(ExplanationEncap):
"""Hierarchical occlusion encapsulator."""

def query_hierarchical_occlusion(self,
train_id,
labels,
limit,
offset,
sorted_name,
sorted_type,
prediction_types=None
):
"""
Query hierarchical occlusion results.

Args:
train_id (str): Job ID.
labels (list[str]): Label filter.
limit (int): Maximum number of items to be returned.
offset (int): Page offset.
sorted_name (str): Field to be sorted.
sorted_type (str): Sorting order, 'ascending' or 'descending'.
prediction_types (list[str]): Prediction types filter.

Returns:
tuple[int, list[dict]], total number of samples after filtering and list of sample results.
"""
job = self.job_manager.get_job(train_id)
if job is None:
raise TrainJobNotExistError(train_id)

samples = self._query_samples(job, labels, sorted_name, sorted_type, prediction_types,
query_type="hoc_layers")
sample_infos = []
obj_offset = offset * limit
count = len(samples)
end = count
if obj_offset + limit < end:
end = obj_offset + limit
for i in range(obj_offset, end):
sample = samples[i]
sample_infos.append(self._touch_sample(sample, job))

return count, sample_infos

def _touch_sample(self, sample, job):
"""
Final edit on single sample info.

Args:
sample (dict): Sample info.
job (ExplainManager): Explain job.

Returns:
dict, the edited sample info.
"""
sample_cp = sample.copy()
sample_cp["image"] = self._get_image_url(job.train_id, sample["image"], "original")
for inference_item in sample_cp["inferences"]:
new_list = []
for idx, hoc_layer in enumerate(inference_item["hoc_layers"]):
hoc_layer["outcome"] = self._get_image_url(job.train_id,
f"{sample['id']}_{inference_item['label']}_{idx}.jpg",
"outcome")
new_list.append(hoc_layer)
inference_item["hoc_layers"] = new_list
return sample_cp

+ 16
- 88
mindinsight/explainer/encapsulator/saliency_encap.py View File

@@ -1,4 +1,4 @@
# Copyright 2020 Huawei Technologies Co., Ltd
# Copyright 2020-2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -14,58 +14,13 @@
# ============================================================================
"""Saliency map encapsulator."""

import copy

from mindinsight.utils.exceptions import ParamValueError
from mindinsight.explainer.encapsulator.explain_data_encap import ExplainDataEncap
from mindinsight.datavisual.common.exceptions import TrainJobNotExistError
from mindinsight.explainer.encapsulator.explain_data_encap import ExplanationEncap


def _sort_key_min_confidence(sample):
"""Samples sort key by the min. confidence."""
min_confidence = float("+inf")
for inference in sample["inferences"]:
if inference["confidence"] < min_confidence:
min_confidence = inference["confidence"]
return min_confidence


def _sort_key_max_confidence(sample):
"""Samples sort key by the max. confidence."""
max_confidence = float("-inf")
for inference in sample["inferences"]:
if inference["confidence"] > max_confidence:
max_confidence = inference["confidence"]
return max_confidence


def _sort_key_min_confidence_sd(sample):
"""Samples sort key by the min. confidence_sd."""
min_confidence_sd = float("+inf")
for inference in sample["inferences"]:
confidence_sd = inference.get("confidence_sd", float("+inf"))
if confidence_sd < min_confidence_sd:
min_confidence_sd = confidence_sd
return min_confidence_sd


def _sort_key_max_confidence_sd(sample):
"""Samples sort key by the max. confidence_sd."""
max_confidence_sd = float("-inf")
for inference in sample["inferences"]:
confidence_sd = inference.get("confidence_sd", float("-inf"))
if confidence_sd > max_confidence_sd:
max_confidence_sd = confidence_sd
return max_confidence_sd


class SaliencyEncap(ExplainDataEncap):
class SaliencyEncap(ExplanationEncap):
"""Saliency map encapsulator."""

def __init__(self, image_url_formatter, *args, **kwargs):
super().__init__(*args, **kwargs)
self._image_url_formatter = image_url_formatter

def query_saliency_maps(self,
train_id,
labels,
@@ -73,55 +28,32 @@ class SaliencyEncap(ExplainDataEncap):
limit,
offset,
sorted_name,
sorted_type):
sorted_type,
prediction_types=None):
"""
Query saliency maps.
Args:
train_id (str): Job ID.
labels (list[str]): Label filter.
explainers (list[str]): Explainers of saliency maps to be shown.
limit (int): Max. no. of items to be returned.
limit (int): Maximum number of items to be returned.
offset (int): Page offset.
sorted_name (str): Field to be sorted.
sorted_type (str): Sorting order, 'ascending' or 'descending'.
prediction_types (list[str]): Prediction types filter. Default: None.

Returns:
tuple[int, list[dict]], total no. of samples after filtering and
list of sample result.
tuple[int, list[dict]], total number of samples after filtering and list of sample result.
"""
job = self.job_manager.get_job(train_id)
if job is None:
raise TrainJobNotExistError(train_id)

samples = copy.deepcopy(job.get_all_samples())
if labels:
filtered = []
for sample in samples:
infer_labels = [inference["label"] for inference in sample["inferences"]]
for infer_label in infer_labels:
if infer_label in labels:
filtered.append(sample)
break
samples = filtered

reverse = sorted_type == "descending"
if sorted_name == "confidence":
if reverse:
samples.sort(key=_sort_key_max_confidence, reverse=reverse)
else:
samples.sort(key=_sort_key_min_confidence, reverse=reverse)
elif sorted_name == "uncertainty":
if not job.uncertainty_enabled:
raise ParamValueError("Uncertainty is not enabled, sorted_name cannot be 'uncertainty'")
if reverse:
samples.sort(key=_sort_key_max_confidence_sd, reverse=reverse)
else:
samples.sort(key=_sort_key_min_confidence_sd, reverse=reverse)
elif sorted_name != "":
raise ParamValueError("sorted_name")
samples = self._query_samples(job, labels, sorted_name, sorted_type, prediction_types,
query_type="saliency_maps")

sample_infos = []
obj_offset = offset*limit
obj_offset = offset * limit
count = len(samples)
end = count
if obj_offset + limit < end:
@@ -139,11 +71,13 @@ class SaliencyEncap(ExplainDataEncap):
sample (dict): Sample info.
job (ExplainJob): Explain job.
explainers (list[str]): Explainer names.

Returns:
dict, the edited sample info.
"""
sample["image"] = self._get_image_url(job.train_id, sample['image'], "original")
for inference in sample["inferences"]:
sample_cp = sample.copy()
sample_cp["image"] = self._get_image_url(job.train_id, sample['image'], "original")
for inference in sample_cp["inferences"]:
new_list = []
for saliency_map in inference["saliency_maps"]:
if explainers and saliency_map["explainer"] not in explainers:
@@ -151,10 +85,4 @@ class SaliencyEncap(ExplainDataEncap):
saliency_map["overlay"] = self._get_image_url(job.train_id, saliency_map['overlay'], "overlay")
new_list.append(saliency_map)
inference["saliency_maps"] = new_list
return sample

def _get_image_url(self, train_id, image_path, image_type):
"""Returns image's url."""
if self._image_url_formatter is None:
return image_path
return self._image_url_formatter(train_id, image_path, image_type)
return sample_cp

+ 111
- 98
mindinsight/explainer/manager/explain_loader.py View File

@@ -14,14 +14,13 @@
# ============================================================================
"""ExplainLoader."""

from collections import defaultdict
from enum import Enum

import math
import os
import re
import threading
from collections import defaultdict
from datetime import datetime
from enum import Enum
from typing import Dict, Iterable, List, Optional, Union

from mindinsight.datavisual.common.exceptions import TrainJobNotExistError
@@ -44,6 +43,7 @@ _SAMPLE_FIELD_NAMES = [
ExplainFieldsEnum.GROUND_TRUTH_LABEL,
ExplainFieldsEnum.INFERENCE,
ExplainFieldsEnum.EXPLANATION,
ExplainFieldsEnum.HIERARCHICAL_OCCLUSION
]


@@ -75,10 +75,10 @@ class ExplainLoader:
'create_time': os.stat(summary_dir).st_ctime,
'update_time': os.stat(summary_dir).st_mtime,
'query_time': os.stat(summary_dir).st_ctime,
'uncertainty_enabled': False,
'uncertainty_enabled': False
}
self._samples = defaultdict(dict)
self._metadata = {'explainers': [], 'metrics': [], 'labels': []}
self._metadata = {'explainers': [], 'metrics': [], 'labels': [], 'min_confidence': 0.5}
self._benchmark = {'explainer_score': defaultdict(dict), 'label_score': defaultdict(dict)}

self._status = _LoaderStatus.STOP.value
@@ -91,25 +91,19 @@ class ExplainLoader:

Returns:
list[dict], a list of dict, each dict contains:
- id (int): label id
- label (str): label name
- sample_count (int): number of samples for each label

- id (int): Label id.
- label (str): Label name.
- sample_count (int): Number of samples for each label.
"""
sample_count_per_label = defaultdict(int)
samples_copy = self._samples.copy()
for sample in samples_copy.values():
if sample.get('image', False) and sample.get('ground_truth_label', False):
for label in sample['ground_truth_label']:
for sample in self._samples.values():
if sample.get('image') and (sample.get('ground_truth_label') or sample.get('predicted_label')):
for label in set(sample['ground_truth_label'] + sample['predicted_label']):
sample_count_per_label[label] += 1

all_classes_return = []
for label_id, label_name in enumerate(self._metadata['labels']):
single_info = {
'id': label_id,
'label': label_name,
'sample_count': sample_count_per_label[label_id]
}
all_classes_return.append(single_info)
all_classes_return = [{'id': label_id, 'label': label_name, 'sample_count': sample_count_per_label[label_id]}
for label_id, label_name in enumerate(self._metadata['labels'])]
return all_classes_return

@property
@@ -166,19 +160,23 @@ class ExplainLoader:

Returns:
list[dict], A list of evaluation results of each explainer. Each item contains:

- explainer (str): Name of evaluated explainer.
- evaluations (list[dict]): A list of evaluation results by different metrics.
- class_scores (list[dict]): A list of evaluation results on different labels.

Each item in the evaluations contains:

- metric (str): name of metric method
- score (float): evaluation result

Each item in the class_scores contains:

- label (str): Name of label
- evaluations (list[dict]): A list of evaluation results on different labels by different metrics.

Each item in evaluations contains:

- metric (str): Name of metric method
- score (float): Evaluation scores of explainer on specific label by the metric.
"""
@@ -215,7 +213,7 @@ class ExplainLoader:
@property
def min_confidence(self) -> Optional[float]:
"""Return minimum confidence used to filter the predicted labels."""
return None
return self._metadata['min_confidence']

@property
def sample_count(self) -> int:
@@ -227,19 +225,17 @@ class ExplainLoader:

Return:
int, total number of available samples in the loading job.

"""
sample_count = 0
samples_copy = self._samples.copy()
for sample in samples_copy.values():
if sample.get('image', False) and sample.get('ground_truth_label', False):
for sample in self._samples.values():
if sample.get('image', False):
sample_count += 1
return sample_count

@property
def samples(self) -> List[Dict]:
"""Return the information of all samples in the job."""
return self.get_all_samples()
return self._samples

@property
def train_id(self) -> str:
@@ -298,6 +294,7 @@ class ExplainLoader:
self._clear_job()
if event_dict:
self._import_data_from_event(event_dict)
self._reform_sample_info()

@property
def status(self):
@@ -317,59 +314,19 @@ class ExplainLoader:

def get_all_samples(self) -> List[Dict]:
"""
Return a list of sample information cachced in the explain job
Return a list of sample information cached in the explain job

Returns:
sample_list (List[SampleObj]): a list of sample objects, each object
consists of:

- id (int): sample id
- name (str): basename of image
- labels (list[str]): list of labels
- inferences list[dict])
- id (int): Sample id.
- name (str): Basename of image.
- inferences (list[dict]): List of inferences for all labels.
"""
returned_samples = []
samples_copy = self._samples.copy()
for sample_id, sample_info in samples_copy.items():
if not sample_info.get('image', False) and not sample_info.get('ground_truth_label', False):
continue
returned_sample = {
'id': sample_id,
'name': str(sample_id),
'image': sample_info['image'],
'labels': [self._metadata['labels'][i] for i in sample_info['ground_truth_label']],
}

# Check whether the sample has valid label-prob pairs.
if not ExplainLoader._is_inference_valid(sample_info):
continue

inferences = {}
for label, prob in zip(sample_info['ground_truth_label'] + sample_info['predicted_label'],
sample_info['ground_truth_prob'] + sample_info['predicted_prob']):
inferences[label] = {
'label': self._metadata['labels'][label],
'confidence': _round(prob),
'saliency_maps': []
}

if sample_info['ground_truth_prob_sd'] or sample_info['predicted_prob_sd']:
for label, std, low, high in zip(
sample_info['ground_truth_label'] + sample_info['predicted_label'],
sample_info['ground_truth_prob_sd'] + sample_info['predicted_prob_sd'],
sample_info['ground_truth_prob_itl95_low'] + sample_info['predicted_prob_itl95_low'],
sample_info['ground_truth_prob_itl95_hi'] + sample_info['predicted_prob_itl95_hi']
):
inferences[label]['confidence_sd'] = _round(std)
inferences[label]['confidence_itl95'] = [_round(low), _round(high)]

for explainer, label_heatmap_path_dict in sample_info['explanation'].items():
for label, heatmap_path in label_heatmap_path_dict.items():
if label in inferences:
inferences[label]['saliency_maps'].append({'explainer': explainer, 'overlay': heatmap_path})

returned_sample['inferences'] = list(inferences.values())
returned_samples.append(returned_sample)
returned_samples = [{'id': sample_id, 'name': info['name'], 'image': info['image'],
'inferences': list(info['inferences'].values())} for sample_id, info in
self._samples.items() if info.get('image', False)]
return returned_samples

def _import_data_from_event(self, event_dict: Dict):
@@ -461,30 +418,29 @@ class ExplainLoader:
- predicted_probs (list[int]): A list of confidences w.r.t the predicted labels.
- explanations (dict): Explanations is a dictionary where the each explainer name mapping to a dictionary
of saliency maps. The structure of explanations demonstrates below:

{
explainer_name_1: {label_1: saliency_id_1, label_2: saliency_id_2, ...},
explainer_name_2: {label_1: saliency_id_1, label_2: saliency_id_2, ...},
...
}
- hierarchical_occlusion (dict): A dictionary where each label is matched to a dictionary:
{label_1: [{prob: layer1_prob, bbox: []}, {prob: layer2_prob, bbox: []}],
label_2:
}
"""
if getattr(sample, 'sample_id', None) is None:
raise ParamValueError('sample_event has no sample_id')
sample_id = sample.sample_id
samples_copy = self._samples.copy()
if sample_id not in samples_copy:
if sample_id not in self._samples:
self._samples[sample_id] = {
'id': sample_id,
'name': str(sample_id),
'image': sample.image_path,
'ground_truth_label': [],
'ground_truth_prob': [],
'ground_truth_prob_sd': [],
'ground_truth_prob_itl95_low': [],
'ground_truth_prob_itl95_hi': [],
'predicted_label': [],
'predicted_prob': [],
'predicted_prob_sd': [],
'predicted_prob_itl95_low': [],
'predicted_prob_itl95_hi': [],
'explanation': defaultdict(dict)
'inferences': defaultdict(dict),
'explanation': defaultdict(dict),
'hierarchical_occlusion': defaultdict(dict)
}

if sample.image_path:
@@ -492,27 +448,66 @@ class ExplainLoader:

for tag in _SAMPLE_FIELD_NAMES:
if tag == ExplainFieldsEnum.GROUND_TRUTH_LABEL:
self._samples[sample_id]['ground_truth_label'].extend(list(sample.ground_truth_label))
if not self._samples[sample_id]['ground_truth_label']:
self._samples[sample_id]['ground_truth_label'].extend(list(sample.ground_truth_label))
elif tag == ExplainFieldsEnum.INFERENCE:
self._import_inference_from_event(sample, sample_id)
else:
elif tag == ExplainFieldsEnum.EXPLANATION:
self._import_explanation_from_event(sample, sample_id)
elif tag == ExplainFieldsEnum.HIERARCHICAL_OCCLUSION:
self._import_hoc_from_event(sample, sample_id)

def _reform_sample_info(self):
"""Reform the sample info."""
for _, sample_info in self._samples.items():
inferences = sample_info['inferences']
res_dict = defaultdict(list)
for explainer, label_heatmap_path_dict in sample_info['explanation'].items():
for label, heatmap_path in label_heatmap_path_dict.items():
res_dict[label].append({'explainer': explainer, 'overlay': heatmap_path})

for label, item in inferences.items():
item['saliency_maps'] = res_dict[label]

for label, item in sample_info['hierarchical_occlusion'].items():
inferences[label]['hoc_layers'] = item['hoc_layers']

def _import_inference_from_event(self, event, sample_id):
"""Parse the inference event."""
inference = event.inference
self._samples[sample_id]['ground_truth_prob'].extend(list(inference.ground_truth_prob))
self._samples[sample_id]['ground_truth_prob_sd'].extend(list(inference.ground_truth_prob_sd))
self._samples[sample_id]['ground_truth_prob_itl95_low'].extend(list(inference.ground_truth_prob_itl95_low))
self._samples[sample_id]['ground_truth_prob_itl95_hi'].extend(list(inference.ground_truth_prob_itl95_hi))
self._samples[sample_id]['predicted_label'].extend(list(inference.predicted_label))
self._samples[sample_id]['predicted_prob'].extend(list(inference.predicted_prob))
self._samples[sample_id]['predicted_prob_sd'].extend(list(inference.predicted_prob_sd))
self._samples[sample_id]['predicted_prob_itl95_low'].extend(list(inference.predicted_prob_itl95_low))
self._samples[sample_id]['predicted_prob_itl95_hi'].extend(list(inference.predicted_prob_itl95_hi))

if self._samples[sample_id]['ground_truth_prob_sd'] or self._samples[sample_id]['predicted_prob_sd']:
if inference.ground_truth_prob_sd or inference.predicted_prob_sd:
self._loader_info['uncertainty_enabled'] = True
if not self._samples[sample_id]['predicted_label']:
self._samples[sample_id]['predicted_label'].extend(list(inference.predicted_label))
if not self._samples[sample_id]['inferences']:
inferences = {}
for label, prob in zip(list(event.ground_truth_label) + list(inference.predicted_label),
list(inference.ground_truth_prob) + list(inference.predicted_prob)):
inferences[label] = {
'label': self._metadata['labels'][label],
'confidence': _round(prob),
'saliency_maps': [],
'hoc_layers': {},
}
if not event.ground_truth_label:
inferences[label]['prediction_type'] = None
else:
if prob < self.min_confidence:
inferences[label]['prediction_type'] = 'FN'
elif label in event.ground_truth_label:
inferences[label]['prediction_type'] = 'TP'
else:
inferences[label]['prediction_type'] = 'FP'
if self._loader_info['uncertainty_enabled']:
for label, std, low, high in zip(
list(event.ground_truth_label) + list(inference.predicted_label),
list(inference.ground_truth_prob_sd) + list(inference.predicted_prob_sd),
list(inference.ground_truth_prob_itl95_low) + list(inference.predicted_prob_itl95_low),
list(inference.ground_truth_prob_itl95_hi) + list(inference.predicted_prob_itl95_hi)):
inferences[label]['confidence_sd'] = _round(std)
inferences[label]['confidence_itl95'] = [_round(low), _round(high)]

self._samples[sample_id]['inferences'] = inferences

def _import_explanation_from_event(self, event, sample_id):
"""Parse the explanation event."""
@@ -525,6 +520,24 @@ class ExplainLoader:
label = explanation_item.label
sample_explanation[explainer][label] = explanation_item.heatmap_path

def _import_hoc_from_event(self, event, sample_id):
"""Parse the mango event."""
sample_hoc = self._samples[sample_id]['hierarchical_occlusion']
if event.hierarchical_occlusion:
for hoc_item in event.hierarchical_occlusion:
label = hoc_item.label
sample_hoc[label] = {}
sample_hoc[label]['label'] = label
sample_hoc[label]['mask'] = hoc_item.mask
sample_hoc[label]['confidence'] = self._samples[sample_id]['inferences'][label]['confidence']
sample_hoc[label]['hoc_layers'] = []
for hoc_layer in hoc_item.layer:
sample_hoc_dict = {'confidence': hoc_layer.prob}
box_lst = list(hoc_layer.box)
box = [box_lst[i: i + 4] for i in range(0, len(hoc_layer.box), 4)]
sample_hoc_dict['boxes'] = box
sample_hoc[label]['hoc_layers'].append(sample_hoc_dict)

def _clear_job(self):
"""Clear the cached data and update the time info of the loader."""
self._samples.clear()


+ 1
- 0
mindinsight/explainer/manager/explain_manager.py View File

@@ -114,6 +114,7 @@ class ExplainManager:
Returns:
tuple[total, directories], total indicates the overall number of explain directories and directories
indicate list of summary directory info including the following attributes.

- relative_path (str): Relative path of summary directory, referring to settings.SUMMARY_BASE_DIR,
starting with "./".
- create_time (datetime): Creation time of summary file.


+ 3
- 1
mindinsight/explainer/manager/explain_parser.py View File

@@ -30,6 +30,7 @@ from mindinsight.utils.exceptions import UnknownError
HEADER_SIZE = 8
CRC_STR_SIZE = 4
MAX_EVENT_STRING = 500000000

BenchmarkContainer = namedtuple('BenchmarkContainer', ['benchmark', 'status'])
MetadataContainer = namedtuple('MetadataContainer', ['metadata', 'status'])
InferfenceContainer = namedtuple('InferenceContainer', ['ground_truth_prob',
@@ -42,7 +43,7 @@ InferfenceContainer = namedtuple('InferenceContainer', ['ground_truth_prob',
'predicted_prob_itl95_low',
'predicted_prob_itl95_hi'])
SampleContainer = namedtuple('SampleContainer', ['sample_id', 'image_path', 'ground_truth_label', 'inference',
'explanation', 'status'])
'explanation', 'hierarchical_occlusion', 'status'])


class ExplainParser(_SummaryParser):
@@ -193,6 +194,7 @@ class ExplainParser(_SummaryParser):
ground_truth_label=tensor_event_value.ground_truth_label,
inference=inference,
explanation=tensor_event_value.explanation,
hierarchical_occlusion=tensor_event_value.hoc,
status=tensor_event_value.status
)
return sample_data


+ 3
- 1
tests/ut/explainer/encapsulator/mock_explain_manager.py View File

@@ -105,7 +105,9 @@ class MockExplainManager:
{
"relative_path": "./mock_job_1",
"create_time": datetime.strptime("2020-10-01 20:21:23", ExplainJobEncap.DATETIME_FORMAT),
"update_time": datetime.strptime("2020-10-01 20:21:23", ExplainJobEncap.DATETIME_FORMAT)
"update_time": datetime.strptime("2020-10-01 20:21:23", ExplainJobEncap.DATETIME_FORMAT),
"saliency_map": True,
"hierarchical_occlusion": True
}
]
return 1, job_list


+ 3
- 1
tests/ut/explainer/encapsulator/test_explain_job_encap.py View File

@@ -31,7 +31,9 @@ class TestExplainJobEncap:
{
"train_id": "./mock_job_1",
"create_time": "2020-10-01 20:21:23",
"update_time": "2020-10-01 20:21:23"
"update_time": "2020-10-01 20:21:23",
"saliency_map": True,
"hierarchical_occlusion": True
}
])
assert job_list == expected_result


+ 0
- 4
tests/ut/explainer/manager/test_explain_loader.py View File

@@ -24,10 +24,6 @@ from mindinsight.explainer.manager.explain_loader import _LoaderStatus
from mindinsight.explainer.manager.explain_parser import ExplainParser


def abc():
FileHandler.is_file('aaa')
print('after')


class TestExplainLoader:
"""Test explain loader class."""


Loading…
Cancel
Save