
CI fix: shorten long functions. And format API comments.

tags/v1.2.0-rc1
lixiaohui 4 years ago
parent commit 896129bb37
12 changed files with 261 additions and 124 deletions
  1. +124 -21 mindinsight/backend/explainer/explainer_api.py
  2. +5 -10 mindinsight/datavisual/data_transform/summary_watcher.py
  3. +14 -1 mindinsight/explainer/common/enums.py
  4. +3 -3 mindinsight/explainer/encapsulator/_hoc_pil_apply.py
  5. +43 -35 mindinsight/explainer/encapsulator/datafile_encap.py
  6. +8 -13 mindinsight/explainer/encapsulator/explain_data_encap.py
  7. +4 -2 mindinsight/explainer/encapsulator/explain_job_encap.py
  8. +9 -5 mindinsight/explainer/encapsulator/hierarchical_occlusion_encap.py
  9. +10 -4 mindinsight/explainer/encapsulator/saliency_encap.py
  10. +5 -6 mindinsight/explainer/manager/explain_loader.py
  11. +15 -13 mindinsight/explainer/manager/explain_manager.py
  12. +21 -11 mindinsight/explainer/manager/explain_parser.py

+124 -21 mindinsight/backend/explainer/explainer_api.py

@@ -40,8 +40,52 @@ URL_PREFIX = settings.URL_PATH_PREFIX + settings.API_PREFIX
BLUEPRINT = Blueprint("explainer", __name__, url_prefix=URL_PREFIX)




def _validate_type(param, name, expected_types):
"""
Common function to validate type.

Args:
param (object): Parameter to be validated.
name (str): Name of the parameter.
expected_types (type, tuple[type]): Expected type(s) of param.

Raises:
ParamTypeError: When param is not an instance of expected_types.
"""

if not isinstance(param, expected_types):
raise ParamTypeError(name, expected_types)


def _validate_value(param, name, expected_values):
"""
Common function to validate values of param.

Args:
param (object): Parameter to be validated.
name (str): Name of the parameter.
expected_values (tuple) : Expected values of param.

Raises:
ParamValueError: When param is not in expected_values.
"""

if param not in expected_values:
raise ParamValueError(f"Valid options for {name} are {expected_values}, but got {param}.")


def _image_url_formatter(train_id, image_path, image_type):
"""Returns image url."""
"""
Returns image url.

Args:
train_id (str): Id that specifies explain job.
image_path (str): Local path or unique string that specifies the image for query.
image_type (str): Image query type.

Returns:
str, url string for image query.
"""
data = {
"train_id": train_id,
"path": image_path,
@@ -69,38 +113,48 @@ def _read_post_request(post_request):




def _get_query_sample_parameters(data):
"""Get parameter for query."""
"""
Get parameter for query.

Args:
data (dict): Dict that contains request info.

Returns:
dict, key-value pairs to call backend query functions.

Raises:
ParamMissError: If train_id info is not in the request.
ParamTypeError: If certain key is not in the expected type in the request.
ParamValueError: If certain key does not have the expected value in the request.
"""


train_id = data.get("train_id")
if train_id is None:
raise ParamMissError('train_id')


labels = data.get("labels")
if labels is not None and not isinstance(labels, list):
raise ParamTypeError("labels", (list, None))
if labels is not None:
_validate_type(labels, "labels", list)
if labels:
for item in labels:
if not isinstance(item, str):
raise ParamTypeError("element of labels", str)
_validate_type(item, "element of labels", str)


limit = data.get("limit", 10)
limit = Validation.check_limit(limit, min_value=1, max_value=100)
offset = data.get("offset", 0)
offset = Validation.check_offset(offset=offset)
sorted_name = data.get("sorted_name", "")
_validate_value(sorted_name, "sorted_name", ('', 'confidence', 'uncertainty'))

sorted_type = data.get("sorted_type", "descending")
if sorted_name not in ("", "confidence", "uncertainty"):
raise ParamValueError(f"sorted_name: {sorted_name}, valid options: '' 'confidence' 'uncertainty'")
if sorted_type not in ("ascending", "descending"):
raise ParamValueError(f"sorted_type: {sorted_type}, valid options: 'confidence' 'uncertainty'")
_validate_value(sorted_type, "sorted_type", ("ascending", "descending"))


prediction_types = data.get("prediction_types")
if prediction_types is not None and not isinstance(prediction_types, list):
raise ParamTypeError("prediction_types", (list, None))
if prediction_types is not None:
_validate_type(prediction_types, "element of labels", list)
if prediction_types:
for item in prediction_types:
if item not in ['TP', 'FN', 'FP']:
raise ParamValueError(f"Item of prediction_types must be in ['TP', 'FN', 'FP'], but got {item}.")
_validate_value(item, "element of prediction_types", ('TP', 'FN', 'FP'))


query_kwarg = {"train_id": train_id,
"labels": labels,
@@ -114,7 +168,17 @@ def _get_query_sample_parameters(data):


@BLUEPRINT.route("/explainer/explain-jobs", methods=["GET"]) @BLUEPRINT.route("/explainer/explain-jobs", methods=["GET"])
def query_explain_jobs(): def query_explain_jobs():
"""Query explain jobs."""
"""
Query explain jobs.

Returns:
Response, contains dict that stores base directory, total number of jobs and their detailed job metadata.

Raises:
ParamMissError: If train_id info is not in the request.
ParamTypeError: If one of (offset, limit) is not integer in the request.
ParamValueError: If one of (offset, limit) does not have the expected value in the request.
"""
offset = request.args.get("offset", default=0)
limit = request.args.get("limit", default=10)
offset = Validation.check_offset(offset=offset)
@@ -132,7 +196,15 @@ def query_explain_jobs():


@BLUEPRINT.route("/explainer/explain-job", methods=["GET"]) @BLUEPRINT.route("/explainer/explain-job", methods=["GET"])
def query_explain_job(): def query_explain_job():
"""Query explain job meta-data."""
"""
Query explain job meta-data.

Returns:
Response, contains dict that stores metadata of the requested job.

Raises:
ParamMissError: If train_id info is not in the request.
"""
train_id = get_train_id(request)
if train_id is None:
raise ParamMissError("train_id")
@@ -143,7 +215,16 @@ def query_explain_job():


@BLUEPRINT.route("/explainer/saliency", methods=["POST"]) @BLUEPRINT.route("/explainer/saliency", methods=["POST"])
def query_saliency(): def query_saliency():
"""Query saliency map related results."""
"""
Query saliency map related results.

Returns:
Response, contains dict that stores number of samples and the detailed sample info.

Raises:
ParamTypeError: If certain key is not in the expected type in the request.
ParamValueError: If certain key does not have the expected value in the request.
"""
data = _read_post_request(request)
query_kwarg = _get_query_sample_parameters(data)
explainers = data.get("explainers")
@@ -169,7 +250,16 @@ def query_saliency():


@BLUEPRINT.route("/explainer/hoc", methods=["POST"]) @BLUEPRINT.route("/explainer/hoc", methods=["POST"])
def query_hoc(): def query_hoc():
"""Query hierarchical occlusion related results."""
"""
Query hierarchical occlusion related results.

Returns:
Response, contains dict that stores number of samples and the detailed sample info.

Raises:
ParamTypeError: If certain key is not in the expected type in the request.
ParamValueError: If certain key does not have the expected value in the request.
"""
data = _read_post_request(request)


query_kwargs = _get_query_sample_parameters(data)
@@ -193,7 +283,15 @@ def query_hoc():


@BLUEPRINT.route("/explainer/evaluation", methods=["GET"]) @BLUEPRINT.route("/explainer/evaluation", methods=["GET"])
def query_evaluation(): def query_evaluation():
"""Query saliency explainer evaluation scores."""
"""
Query saliency explainer evaluation scores.

Returns:
Response, contains dict that stores evaluation scores.

Raises:
ParamMissError: If train_id info is not in the request.
"""
train_id = get_train_id(request)
if train_id is None:
raise ParamMissError("train_id")
@@ -206,7 +304,12 @@ def query_evaluation():


@BLUEPRINT.route("/explainer/image", methods=["GET"]) @BLUEPRINT.route("/explainer/image", methods=["GET"])
def query_image(): def query_image():
"""Query image."""
"""
Query image.

Returns:
bytes, image binary content for UI to demonstrate.
"""
train_id = get_train_id(request)
if train_id is None:
raise ParamMissError("train_id")
@@ -230,6 +333,6 @@ def init_module(app):
Init module entry.


Args:
app: the application obj.
app (flask.app): The application obj.
""" """
app.register_blueprint(BLUEPRINT)
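
Note: the two helpers factored out at the top of this file centralize the type and value checks that were previously inlined in _get_query_sample_parameters. A self-contained sketch of their behavior, with plain TypeError/ValueError standing in for MindInsight's ParamTypeError/ParamValueError:

# Standalone sketch of the new validation helpers; the real code raises
# MindInsight-specific exceptions, plain built-ins stand in here.
def _validate_type(param, name, expected_types):
    """Raise if param is not an instance of expected_types."""
    if not isinstance(param, expected_types):
        raise TypeError(f"{name} should be {expected_types}, but got {type(param)}.")

def _validate_value(param, name, expected_values):
    """Raise if param is not one of expected_values."""
    if param not in expected_values:
        raise ValueError(f"Valid options for {name} are {expected_values}, but got {param}.")

_validate_type(["cat", "dog"], "labels", list)                              # passes silently
_validate_value("descending", "sorted_type", ("ascending", "descending"))   # passes silently
try:
    _validate_value("confidence_sd", "sorted_name", ("", "confidence", "uncertainty"))
except ValueError as err:
    print(err)  # Valid options for sorted_name are ('', 'confidence', 'uncertainty'), but got confidence_sd.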

+5 -10 mindinsight/datavisual/data_transform/summary_watcher.py

@@ -220,10 +220,7 @@ class SummaryWatcher:
summary_dict[relative_path].update(job_dict)
if summary_dict[relative_path]['create_time'] < ctime:
summary_dict[relative_path].update({
'create_time': ctime,
'update_time': mtime,
})
summary_dict[relative_path].update({'create_time': ctime, 'update_time': mtime})
job_dict = _get_explain_job_info(summary_base_dir, relative_path, timestamp)
summary_dict[relative_path].update(job_dict)
@@ -243,12 +240,10 @@ class SummaryWatcher:
if not is_find:
return
profiler = {
'directory': os.path.join('.', entry.name),
'create_time': ctime,
'update_time': mtime,
"profiler_type": profiler_type
}
profiler = {'directory': os.path.join('.', entry.name),
'create_time': ctime,
'update_time': mtime,
"profiler_type": profiler_type}
if relative_path in summary_dict:
summary_dict[relative_path]['profiler'] = profiler


+14 -1 mindinsight/explainer/common/enums.py

@@ -1,4 +1,4 @@
# Copyright 2020 Huawei Technologies Co., Ltd
# Copyright 2020-2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -50,3 +50,16 @@ class CacheStatus(enum.Enum):
NOT_IN_CACHE = "NOT_IN_CACHE"
CACHING = "CACHING"
CACHED = "CACHED"


class ExplanationKeys(enum.Enum):
"""Query type enums."""
HOC = "hoc_layers" # HOC: Hierarchical Occlusion, an explanation method we propose
SALIENCY = "saliency_maps"


class ImageQueryTypes(enum.Enum):
"""Image query type enums."""
ORIGINAL = 'original' # Query for the original image
OUTCOME = 'outcome' # Query for outcome of HOC explanation
OVERLAY = 'overlay' # Query for saliency maps overlay
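
Note: these enums give the query strings a single home. A minimal sketch of the intended round trip between request strings and enum members (the enum body is copied from the diff above):

import enum

class ImageQueryTypes(enum.Enum):
    """Image query type enums."""
    ORIGINAL = 'original'
    OUTCOME = 'outcome'
    OVERLAY = 'overlay'

assert ImageQueryTypes.OVERLAY.value == 'overlay'              # enum -> query string
assert ImageQueryTypes('outcome') is ImageQueryTypes.OUTCOME   # query string -> enum

This is why the call sites below compare against ImageQueryTypes.<MEMBER>.value instead of bare string literals.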

+3 -3 mindinsight/explainer/encapsulator/_hoc_pil_apply.py

@@ -63,7 +63,7 @@ def pil_apply_edit_steps(image, mask, edit_steps, by_masking=False, inplace=Fals
Args:
image (PIL.Image): The input image in RGB mode.
mask (Union[str, int, tuple[int, int, int], PIL.Image.Image]): The mask to apply on the image, could be string
e.g. 'gaussian:9', a single, grey scale intensity [0, 255], a RBG tuple or a PIL Image object.
e.g. 'gaussian:9', a single, grey scale intensity [0, 255], an RBG tuple or a PIL Image object.
edit_steps (list[EditStep]): Edit steps to be drawn.
by_masking (bool): Whether to use masking method. Default: False.
inplace (bool): True to draw on the input image, otherwise draw on a cloned image.
@@ -99,7 +99,7 @@ def _pil_apply_edit_steps_unmask(image, mask, edit_steps, inplace=False):
Args:
image (PIL.Image): The input image.
mask (Union[int, tuple[int, int, int], PIL.Image]): The mask to apply on the image, could be a single grey
scale intensity [0, 255], a RBG tuple or a PIL Image.
scale intensity [0, 255], an RBG tuple or a PIL Image.
edit_steps (list[EditStep]): Edit steps to be drawn.
inplace (bool): True to draw on the input image, otherwise draw on a cloned image.


@@ -132,7 +132,7 @@ def _pil_apply_edit_steps_mask(image, mask, edit_steps, inplace=False):
Args:
image (PIL.Image): The input image.
mask (Union[int, tuple[int, int, int], PIL.Image]): The mask to apply on the image, could be a single grey
scale intensity [0, 255], a RBG tuple or a PIL Image.
scale intensity [0, 255], an RBG tuple or a PIL Image.
edit_steps (list[EditStep]): Edit steps to be drawn.
inplace (bool): True to draw on the input image, otherwise draw on a cloned image.




+43 -35 mindinsight/explainer/encapsulator/datafile_encap.py

@@ -21,6 +21,7 @@ import numpy as np
from PIL import Image


from mindinsight.datavisual.common.exceptions import ImageNotExistError
from mindinsight.explainer.common.enums import ImageQueryTypes
from mindinsight.explainer.encapsulator._hoc_pil_apply import EditStep, pil_apply_edit_steps
from mindinsight.explainer.encapsulator.explain_data_encap import ExplainDataEncap
from mindinsight.utils.exceptions import FileSystemPermissionError
@@ -63,41 +64,10 @@ class DatafileEncap(ExplainDataEncap):
image_type (str): Image type, Options: 'original', 'overlay' or 'outcome'.


Returns:
bytes, image binary.
bytes, image binary content for UI to demonstrate.
""" """
if image_type == "outcome":
sample_id, label, layer = image_path.strip(".jpg").split("_")
layer = int(layer)
job = self.job_manager.get_job(train_id)
samples = job.samples
label_idx = job.labels.index(label)

chosen_sample = samples[int(sample_id)]
original_path_image = chosen_sample['image']
abs_image_path = os.path.join(self.job_manager.summary_base_dir, _clean_train_id_b4_join(train_id),
original_path_image)
if self._is_forbidden(abs_image_path):
raise FileSystemPermissionError("Forbidden.")
try:
image = Image.open(abs_image_path)
except FileNotFoundError:
raise ImageNotExistError(f"train_id:{train_id} path:{image_path} type:{image_type}")
except PermissionError:
raise FileSystemPermissionError(f"train_id:{train_id} path:{image_path} type:{image_type}")
except OSError:
raise UnknownError(f"Invalid image file: train_id:{train_id} path:{image_path} type:{image_type}")

edit_steps = []
boxes = chosen_sample["hierarchical_occlusion"][label_idx]["hoc_layers"][layer]["boxes"]
mask = chosen_sample["hierarchical_occlusion"][label_idx]["mask"]

for box in boxes:
edit_steps.append(EditStep(layer, *box))
image_cp = pil_apply_edit_steps(image, mask, edit_steps)
buffer = io.BytesIO()
image_cp.save(buffer, format=_PNG_FORMAT)

return buffer.getvalue()
if image_type == ImageQueryTypes.OUTCOME.value:
return self._get_hoc_image(image_path, train_id)


abs_image_path = os.path.join(self.job_manager.summary_base_dir,
_clean_train_id_b4_join(train_id),
@@ -108,7 +78,7 @@ class DatafileEncap(ExplainDataEncap):


try:


if image_type != "overlay":
if image_type != ImageQueryTypes.OVERLAY.value:
# no need to convert
with open(abs_image_path, "rb") as fp:
return fp.read()
@@ -153,3 +123,41 @@ class DatafileEncap(ExplainDataEncap):
base_dir = os.path.realpath(self.job_manager.summary_base_dir)
path = os.path.realpath(path)
return not path.startswith(base_dir)

def _get_hoc_image(self, image_path, train_id):
"""Get hoc image for image data demonstration in UI."""

sample_id, label, layer = image_path.strip(".jpg").split("_")
layer = int(layer)
job = self.job_manager.get_job(train_id)
samples = job.samples
label_idx = job.labels.index(label)

chosen_sample = samples[int(sample_id)]
original_path_image = chosen_sample['image']
abs_image_path = os.path.join(self.job_manager.summary_base_dir, _clean_train_id_b4_join(train_id),
original_path_image)
if self._is_forbidden(abs_image_path):
raise FileSystemPermissionError("Forbidden.")

image_type = ImageQueryTypes.OUTCOME.value
try:
image = Image.open(abs_image_path)
except FileNotFoundError:
raise ImageNotExistError(f"train_id:{train_id} path:{image_path} type:{image_type}")
except PermissionError:
raise FileSystemPermissionError(f"train_id:{train_id} path:{image_path} type:{image_type}")
except OSError:
raise UnknownError(f"Invalid image file: train_id:{train_id} path:{image_path} type:{image_type}")

edit_steps = []
boxes = chosen_sample["hierarchical_occlusion"][label_idx]["hoc_layers"][layer]["boxes"]
mask = chosen_sample["hierarchical_occlusion"][label_idx]["mask"]

for box in boxes:
edit_steps.append(EditStep(layer, *box))
image_cp = pil_apply_edit_steps(image, mask, edit_steps)
buffer = io.BytesIO()
image_cp.save(buffer, format=_PNG_FORMAT)

return buffer.getvalue()
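
Note: _get_hoc_image decodes a synthetic file name of the form f"{sample_id}_{label}_{layer}.jpg", which hierarchical_occlusion_encap.py builds when it emits outcome URLs. A small round-trip sketch with made-up values:

# Encoding side (see hierarchical_occlusion_encap.py below).
sample_id, label, layer = 7, "cat", 2
image_path = f"{sample_id}_{label}_{layer}.jpg"  # "7_cat_2.jpg"

# Decoding side, exactly as _get_hoc_image does it. Note that
# str.strip(".jpg") removes the character set {'.', 'j', 'p', 'g'} from
# both ends; it drops the extension here but is not general suffix removal.
sid, lbl, lyr = image_path.strip(".jpg").split("_")
assert (int(sid), lbl, int(lyr)) == (sample_id, label, layer)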

+8 -13 mindinsight/explainer/encapsulator/explain_data_encap.py

@@ -15,13 +15,13 @@
"""Common explain data encapsulator base class.""" """Common explain data encapsulator base class."""


import copy
from enum import Enum


from mindinsight.explainer.common.enums import ExplanationKeys
from mindinsight.utils.exceptions import ParamValueError




def _sort_key_min_confidence(sample, labels):
"""Samples sort key by the min. confidence."""
"""Samples sort key by the minimum confidence."""
min_confidence = float("+inf")
for inference in sample["inferences"]:
if labels and inference["label"] not in labels:
@@ -32,7 +32,7 @@ def _sort_key_min_confidence(sample, labels):




def _sort_key_max_confidence(sample, labels):
"""Samples sort key by the max. confidence."""
"""Samples sort key by the maximum confidence."""
max_confidence = float("-inf")
for inference in sample["inferences"]:
if labels and inference["label"] not in labels:
@@ -43,7 +43,7 @@ def _sort_key_max_confidence(sample, labels):




def _sort_key_min_confidence_sd(sample, labels):
"""Samples sort key by the min. confidence_sd."""
"""Samples sort key by the minimum confidence_sd."""
min_confidence_sd = float("+inf")
for inference in sample["inferences"]:
if labels and inference["label"] not in labels:
@@ -55,7 +55,7 @@ def _sort_key_min_confidence_sd(sample, labels):




def _sort_key_max_confidence_sd(sample, labels):
"""Samples sort key by the max. confidence_sd."""
"""Samples sort key by the maximum confidence_sd."""
max_confidence_sd = float("-inf")
for inference in sample["inferences"]:
if labels and inference["label"] not in labels:
@@ -65,11 +65,6 @@ def _sort_key_max_confidence_sd(sample, labels):
max_confidence_sd = confidence_sd
return max_confidence_sd


class ExplanationKeys(Enum):
"""Query type enums."""
HOC = "hoc_layers" # HOC: Hierarchical Occlusion, an explanation method we propose
SALIENCY = "saliency_maps"



class ExplainDataEncap:
"""Explain data encapsulator base class."""
@@ -105,9 +100,9 @@ class ExplanationEncap(ExplainDataEncap):
sorted_name (str): Field to be sorted.
sorted_type (str): Sorting order, 'ascending' or 'descending'.
prediction_types (list[str]): Prediction type filter.
drop_type (str, None): When it is None, no filer will be applied. When it is 'hoc_layers', samples without
hoc explanations will be filtered out. When it is 'saliency_maps', samples without saliency explanations
will be filtered out.
drop_type (str, None): When it is None, all data will be kept. When it is 'hoc_layers', samples without
hoc explanations will be drop out. When it is 'saliency_maps', samples without saliency explanations
will be drop out.


Returns:
list[dict], samples to be queried.
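
Note: the four _sort_key_* helpers all follow one pattern: scan a sample's inferences, skip labels outside the filter, and return the extreme confidence (or confidence_sd). A self-contained sketch of the min-confidence variant with toy data; the aggregation line is reconstructed from the context lines above and is an assumption, not a verbatim copy:

def _sort_key_min_confidence(sample, labels):
    """Samples sort key by the minimum confidence."""
    min_confidence = float("+inf")
    for inference in sample["inferences"]:
        if labels and inference["label"] not in labels:
            continue
        min_confidence = min(min_confidence, inference["confidence"])  # reconstructed step
    return min_confidence

samples = [
    {"id": 0, "inferences": [{"label": "cat", "confidence": 0.9}]},
    {"id": 1, "inferences": [{"label": "cat", "confidence": 0.4},
                             {"label": "dog", "confidence": 0.8}]},
]
ordered = sorted(samples, key=lambda s: _sort_key_min_confidence(s, ["cat"]))
print([s["id"] for s in ordered])  # [1, 0]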


+4 -2 mindinsight/explainer/encapsulator/explain_job_encap.py

@@ -29,11 +29,13 @@ class ExplainJobEncap(ExplainDataEncap):
def query_explain_jobs(self, offset, limit):
"""
Query explain job list.

Args:
offset (int): Page offset.
limit (int): Max. no. of items to be returned.
limit (int): Maximum number of items to be returned.

Returns:
tuple[int, list[Dict]], total no. of jobs and job list.
tuple[int, list[Dict]], total number of jobs and job list.
""" """
total, dir_infos = self.job_manager.get_job_list(offset=offset, limit=limit)
job_infos = [self._dir_2_info(dir_info) for dir_info in dir_infos]


+9 -5 mindinsight/explainer/encapsulator/hierarchical_occlusion_encap.py

@@ -15,7 +15,8 @@
"""Hierarchical Occlusion encapsulator.""" """Hierarchical Occlusion encapsulator."""


from mindinsight.datavisual.common.exceptions import TrainJobNotExistError
from mindinsight.explainer.encapsulator.explain_data_encap import ExplanationEncap, ExplanationKeys
from mindinsight.explainer.common.enums import ExplanationKeys, ImageQueryTypes
from mindinsight.explainer.encapsulator.explain_data_encap import ExplanationEncap




class HierarchicalOcclusionEncap(ExplanationEncap):
@@ -81,7 +82,10 @@ class HierarchicalOcclusionEncap(ExplanationEncap):
Returns:
dict, the edited sample info.
"""
sample["image"] = self._get_image_url(job.train_id, sample["image"], "original")
original = ImageQueryTypes.ORIGINAL.value
outcome = ImageQueryTypes.OUTCOME.value

sample["image"] = self._get_image_url(job.train_id, sample["image"], original)
inferences = sample["inferences"]
i = 0  # init index for while loop
while i < len(inferences):
@@ -91,9 +95,9 @@ class HierarchicalOcclusionEncap(ExplanationEncap):
continue
new_list = []
for idx, hoc_layer in enumerate(inference_item[ExplanationKeys.HOC.value]):
hoc_layer["outcome"] = self._get_image_url(job.train_id,
f"{sample['id']}_{inference_item['label']}_{idx}.jpg",
"outcome")
hoc_layer[outcome] = self._get_image_url(job.train_id,
f"{sample['id']}_{inference_item['label']}_{idx}.jpg",
outcome)
new_list.append(hoc_layer)
inference_item[ExplanationKeys.HOC.value] = new_list
i += 1


+10 -4 mindinsight/explainer/encapsulator/saliency_encap.py

@@ -15,7 +15,8 @@
"""Saliency map encapsulator.""" """Saliency map encapsulator."""


from mindinsight.datavisual.common.exceptions import TrainJobNotExistError
from mindinsight.explainer.encapsulator.explain_data_encap import ExplanationEncap, ExplanationKeys
from mindinsight.explainer.common.enums import ExplanationKeys, ImageQueryTypes
from mindinsight.explainer.encapsulator.explain_data_encap import ExplanationEncap




class SaliencyEncap(ExplanationEncap):
@@ -32,6 +33,7 @@ class SaliencyEncap(ExplanationEncap):
prediction_types=None):
"""
Query saliency maps.

Args:
train_id (str): Job ID.
labels (list[str]): Label filter.
@@ -65,7 +67,8 @@ class SaliencyEncap(ExplanationEncap):


def _touch_sample(self, sample, job, explainers):
"""
Final editing the sample info.
Final edit on single sample info.

Args:
sample (dict): Sample info.
job (ExplainJob): Explain job.
@@ -74,14 +77,17 @@ class SaliencyEncap(ExplanationEncap):
Returns:
dict, the edited sample info.
"""
original = ImageQueryTypes.ORIGINAL.value
overlay = ImageQueryTypes.OVERLAY.value

sample_cp = sample.copy()
sample_cp["image"] = self._get_image_url(job.train_id, sample['image'], "original")
sample_cp["image"] = self._get_image_url(job.train_id, sample['image'], original)
for inference in sample_cp["inferences"]:
new_list = []
for saliency_map in inference[ExplanationKeys.SALIENCY.value]:
if explainers and saliency_map["explainer"] not in explainers:
continue
saliency_map["overlay"] = self._get_image_url(job.train_id, saliency_map['overlay'], "overlay")
saliency_map[overlay] = self._get_image_url(job.train_id, saliency_map[overlay], overlay)
new_list.append(saliency_map)
inference[ExplanationKeys.SALIENCY.value] = new_list
return sample_cp

+5 -6 mindinsight/explainer/manager/explain_loader.py

@@ -270,7 +270,7 @@ class ExplainLoader:
Update the update_time manually.


Args:
new_time stamp (datetime.datetime or float): Updated time for the summary file.
new_time (datetime.datetime or float): Updated time for the summary file.
""" """
if isinstance(new_time, datetime):
self._loader_info['update_time'] = new_time.timestamp()
@@ -333,11 +333,10 @@ class ExplainLoader:


def get_all_samples(self) -> List[Dict]:
"""
Return a list of sample information cached in the explain job
Return a list of sample information cached in the explain job.


Returns:
sample_list (List[SampleObj]): a list of sample objects, each object
consists of:
sample_list (list[SampleObj]): a list of sample objects, each object consists of:


- id (int): Sample id.
- name (str): Basename of image.
@@ -406,7 +405,7 @@ class ExplainLoader:
}


Args:
benchmarks (benchmark_container): Parsed benchmarks data from summary file.
benchmarks (BenchmarkContainer): Parsed benchmarks data from summary file.
""" """
explainer_score = self._benchmark['explainer_score'] explainer_score = self._benchmark['explainer_score']
label_score = self._benchmark['label_score'] label_score = self._benchmark['label_score']
@@ -429,7 +428,7 @@ class ExplainLoader:
Parse the sample event.


Detailed data of each sample are store in self._samples, identified by sample_id. Each sample data are stored
in the following structure.
in the following structure:


- ground_truth_labels (list[int]): A list of ground truth labels of the sample.
- ground_truth_probs (list[float]): A list of confidences of ground-truth label from black-box model.


+15 -13 mindinsight/explainer/manager/explain_manager.py

@@ -65,7 +65,7 @@ class ExplainManager:


Args:
reload_interval (int): Specify the loading period in seconds. If interval == 0, data will only be loaded
once. Default: 0.
once. Default: 0.
""" """
thread = threading.Thread(target=self._repeat_loading,
name='explainer.start_load_thread',
@@ -80,10 +80,10 @@ class ExplainManager:
If explain job w.r.t given loader_id is not found, None will be returned.


Args:
loader_id (str): The id of expected ExplainLoader
loader_id (str): The id of expected ExplainLoader.


Return:
explain_job
Returns:
ExplainLoader, the data loader specified by loader_id.
""" """
self._check_status_valid() self._check_status_valid()


@@ -111,17 +111,19 @@ class ExplainManager:
Return List of explain jobs. includes job ID, create and update time.


Args:
offset (int): An offset for page. Ex, offset is 0, mean current page is 1. Default value is 0.
limit (int): The max data items for per page. Default value is 10.
offset (int): An offset for page. Ex, offset is 0, mean current page is 1. Default: 0.
limit (int): The max data items for per page. Default: 10.


Returns:
tuple[total, directories], total indicates the overall number of explain directories and directories
indicate list of summary directory info including the following attributes.
tuple, the elements of the returned tuple are:

- total (int): The overall number of explain directories
- dir_infos (list): List of summary directory info including the following attributes:


- relative_path (str): Relative path of summary directory, referring to settings.SUMMARY_BASE_DIR,
starting with "./".
- create_time (datetime): Creation time of summary file.
- update_time (datetime): Modification time of summary file.
- relative_path (str): Relative path of summary directory, referring to settings.SUMMARY_BASE_DIR,
starting with "./".
- create_time (datetime): Creation time of summary file.
- update_time (datetime): Modification time of summary file.
""" """
total, dir_infos = \
self._summary_watcher.list_explain_directories(self._summary_base_dir, offset=offset, limit=limit)
@@ -216,7 +218,7 @@ class ExplainManager:
return loader


def _add_loader(self, loader):
"""add loader to the loader_pool."""
"""Add loader to the loader_pool."""
if loader.train_id not in self._loader_pool:
self._loader_pool[loader.train_id] = loader
else:
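
Note: per the reworked get_job_list docstring above, offset is a page index (offset 0 means page 1) and the call returns a (total, dir_infos) tuple. A self-contained sketch of that paging contract; FakeManager is a hypothetical stand-in for an initialized ExplainManager, not the real class:

class FakeManager:
    """Toy manager exposing the (total, items) paging contract."""
    def __init__(self, items):
        self._items = items

    def get_job_list(self, offset=0, limit=10):
        # offset is interpreted as a page index, matching the docstring.
        start = offset * limit
        return len(self._items), self._items[start:start + limit]

manager = FakeManager([{"relative_path": f"./job_{i}"} for i in range(23)])
offset, limit = 0, 10
while True:
    total, page = manager.get_job_list(offset=offset, limit=limit)
    for info in page:
        print(info["relative_path"])
    offset += 1
    if offset * limit >= total:
        break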


+21 -11 mindinsight/explainer/manager/explain_parser.py

@@ -1,4 +1,4 @@
# Copyright 2020 Huawei Technologies Co., Ltd
# Copyright 2020-2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -59,12 +59,13 @@ class ExplainParser(_SummaryParser):


Args:
filenames (list[str]): File name list.

Returns:
tuple, will return (file_changed, is_end, event_data),
tuple, the elements of the tuple are:


file_changed (bool): True if the 9latest file is changed.
is_end (bool): True if all the summary files are finished loading.
event_data (dict): return an event data, key is field.
- file_changed (bool): True if the latest file is changed.
- is_end (bool): True if all the summary files are finished loading.
- event_data (dict): Event data where keys are explanation field.
""" """
summary_files = self.sort_files(filenames) summary_files = self.sort_files(filenames)


@@ -134,6 +135,12 @@ class ExplainParser(_SummaryParser):


Args:
event_str (str): Message event string in summary proto, data read from file handler.

Returns:
tuple, the elements of the result tuple are:

- field_list (list): Explain fields to be parsed.
- tensor_value_list (list): Parsed data with respect to the field list.
""" """


logger.debug("Start to parse event string. Event string len: %s.", len(event_str))
@@ -172,10 +179,13 @@ class ExplainParser(_SummaryParser):
@staticmethod
def _add_image_data(tensor_event_value):
"""
Parse image data based on sample_id in Explain message
Parse image data based on sample_id in Explain message.


Args:
tensor_event_value: the object of Explain message
tensor_event_value (Event): The object of Explain message.

Returns:
SampleContainer, a named tuple containing sample data.
""" """
inference = InferfenceContainer( inference = InferfenceContainer(
ground_truth_prob=tensor_event_value.inference.ground_truth_prob, ground_truth_prob=tensor_event_value.inference.ground_truth_prob,
@@ -205,10 +215,10 @@ class ExplainParser(_SummaryParser):
Parse benchmark data from Explain message.


Args:
tensor_event_value: the object of Explain message
tensor_event_value (Event): The object of Explain message.


Returns:
benchmark_data: An object containing benchmark.
BenchmarkContainer, a named tuple containing benchmark data.
""" """
benchmark_data = BenchmarkContainer( benchmark_data = BenchmarkContainer(
benchmark=tensor_event_value.benchmark, benchmark=tensor_event_value.benchmark,
@@ -223,10 +233,10 @@ class ExplainParser(_SummaryParser):
Parse metadata from Explain message.


Args:
tensor_event_value: the object of Explain message
tensor_event_value (Event): The object of Explain message.


Returns:
benchmark_data: An object containing metadata.
MetadataContainer, a named tuple containing benchmark data.
""" """
metadata_value = MetadataContainer( metadata_value = MetadataContainer(
metadata=tensor_event_value.metadata, metadata=tensor_event_value.metadata,

