Browse Source

CI fix: shorten long functions. And format API comments.

tags/v1.2.0-rc1
lixiaohui 4 years ago
parent
commit
896129bb37
12 changed files with 261 additions and 124 deletions
  1. +124
    -21
      mindinsight/backend/explainer/explainer_api.py
  2. +5
    -10
      mindinsight/datavisual/data_transform/summary_watcher.py
  3. +14
    -1
      mindinsight/explainer/common/enums.py
  4. +3
    -3
      mindinsight/explainer/encapsulator/_hoc_pil_apply.py
  5. +43
    -35
      mindinsight/explainer/encapsulator/datafile_encap.py
  6. +8
    -13
      mindinsight/explainer/encapsulator/explain_data_encap.py
  7. +4
    -2
      mindinsight/explainer/encapsulator/explain_job_encap.py
  8. +9
    -5
      mindinsight/explainer/encapsulator/hierarchical_occlusion_encap.py
  9. +10
    -4
      mindinsight/explainer/encapsulator/saliency_encap.py
  10. +5
    -6
      mindinsight/explainer/manager/explain_loader.py
  11. +15
    -13
      mindinsight/explainer/manager/explain_manager.py
  12. +21
    -11
      mindinsight/explainer/manager/explain_parser.py

+ 124
- 21
mindinsight/backend/explainer/explainer_api.py View File

@@ -40,8 +40,52 @@ URL_PREFIX = settings.URL_PATH_PREFIX + settings.API_PREFIX
BLUEPRINT = Blueprint("explainer", __name__, url_prefix=URL_PREFIX)


def _validate_type(param, name, expected_types):
"""
Common function to validate type.

Args:
param (object): Parameter to be validated.
name (str): Name of the parameter.
expected_types (type, tuple[type]): Expected type(s) of param.

Raises:
ParamTypeError: When param is not an instance of expected_types.
"""

if not isinstance(param, expected_types):
raise ParamTypeError(name, expected_types)


def _validate_value(param, name, expected_values):
"""
Common function to validate values of param.

Args:
param (object): Parameter to be validated.
name (str): Name of the parameter.
expected_values (tuple) : Expected values of param.

Raises:
ParamValueError: When param is not in expected_values.
"""

if param not in expected_values:
raise ParamValueError(f"Valid options for {name} are {expected_values}, but got {param}.")


def _image_url_formatter(train_id, image_path, image_type):
"""Returns image url."""
"""
Returns image url.

Args:
train_id (str): Id that specifies explain job.
image_path (str): Local path or unique string that specifies the image for query.
image_type (str): Image query type.

Returns:
str, url string for image query.
"""
data = {
"train_id": train_id,
"path": image_path,
@@ -69,38 +113,48 @@ def _read_post_request(post_request):


def _get_query_sample_parameters(data):
"""Get parameter for query."""
"""
Get parameter for query.

Args:
data (dict): Dict that contains request info.

Returns:
dict, key-value pairs to call backend query functions.

Raises:
ParamMissError: If train_id info is not in the request.
ParamTypeError: If certain key is not in the expected type in the request.
ParamValueError: If certain key does not have the expected value in the request.
"""

train_id = data.get("train_id")
if train_id is None:
raise ParamMissError('train_id')

labels = data.get("labels")
if labels is not None and not isinstance(labels, list):
raise ParamTypeError("labels", (list, None))
if labels is not None:
_validate_type(labels, "labels", list)
if labels:
for item in labels:
if not isinstance(item, str):
raise ParamTypeError("element of labels", str)
_validate_type(item, "element of labels", str)

limit = data.get("limit", 10)
limit = Validation.check_limit(limit, min_value=1, max_value=100)
offset = data.get("offset", 0)
offset = Validation.check_offset(offset=offset)
sorted_name = data.get("sorted_name", "")
_validate_value(sorted_name, "sorted_name", ('', 'confidence', 'uncertainty'))

sorted_type = data.get("sorted_type", "descending")
if sorted_name not in ("", "confidence", "uncertainty"):
raise ParamValueError(f"sorted_name: {sorted_name}, valid options: '' 'confidence' 'uncertainty'")
if sorted_type not in ("ascending", "descending"):
raise ParamValueError(f"sorted_type: {sorted_type}, valid options: 'confidence' 'uncertainty'")
_validate_value(sorted_type, "sorted_type", ("ascending", "descending"))

prediction_types = data.get("prediction_types")
if prediction_types is not None and not isinstance(prediction_types, list):
raise ParamTypeError("prediction_types", (list, None))
if prediction_types is not None:
_validate_type(prediction_types, "element of labels", list)
if prediction_types:
for item in prediction_types:
if item not in ['TP', 'FN', 'FP']:
raise ParamValueError(f"Item of prediction_types must be in ['TP', 'FN', 'FP'], but got {item}.")
_validate_value(item, "element of prediction_types", ('TP', 'FN', 'FP'))

query_kwarg = {"train_id": train_id,
"labels": labels,
@@ -114,7 +168,17 @@ def _get_query_sample_parameters(data):

@BLUEPRINT.route("/explainer/explain-jobs", methods=["GET"])
def query_explain_jobs():
"""Query explain jobs."""
"""
Query explain jobs.

Returns:
Response, contains dict that stores base directory, total number of jobs and their detailed job metadata.

Raises:
ParamMissError: If train_id info is not in the request.
ParamTypeError: If one of (offset, limit) is not integer in the request.
ParamValueError: If one of (offset, limit) does not have the expected value in the request.
"""
offset = request.args.get("offset", default=0)
limit = request.args.get("limit", default=10)
offset = Validation.check_offset(offset=offset)
@@ -132,7 +196,15 @@ def query_explain_jobs():

@BLUEPRINT.route("/explainer/explain-job", methods=["GET"])
def query_explain_job():
"""Query explain job meta-data."""
"""
Query explain job meta-data.

Returns:
Response, contains dict that stores metadata of the requested job.

Raises:
ParamMissError: If train_id info is not in the request.
"""
train_id = get_train_id(request)
if train_id is None:
raise ParamMissError("train_id")
@@ -143,7 +215,16 @@ def query_explain_job():

@BLUEPRINT.route("/explainer/saliency", methods=["POST"])
def query_saliency():
"""Query saliency map related results."""
"""
Query saliency map related results.

Returns:
Response, contains dict that stores number of samples and the detailed sample info.

Raises:
ParamTypeError: If certain key is not in the expected type in the request.
ParamValueError: If certain key does not have the expected value in the request.
"""
data = _read_post_request(request)
query_kwarg = _get_query_sample_parameters(data)
explainers = data.get("explainers")
@@ -169,7 +250,16 @@ def query_saliency():

@BLUEPRINT.route("/explainer/hoc", methods=["POST"])
def query_hoc():
"""Query hierarchical occlusion related results."""
"""
Query hierarchical occlusion related results.

Returns:
Response, contains dict that stores number of samples and the detailed sample info.

Raises:
ParamTypeError: If certain key is not in the expected type in the request.
ParamValueError: If certain key does not have the expected value in the request.
"""
data = _read_post_request(request)

query_kwargs = _get_query_sample_parameters(data)
@@ -193,7 +283,15 @@ def query_hoc():

@BLUEPRINT.route("/explainer/evaluation", methods=["GET"])
def query_evaluation():
"""Query saliency explainer evaluation scores."""
"""
Query saliency explainer evaluation scores.

Returns:
Response, contains dict that stores evaluation scores.

Raises:
ParamMissError: If train_id info is not in the request.
"""
train_id = get_train_id(request)
if train_id is None:
raise ParamMissError("train_id")
@@ -206,7 +304,12 @@ def query_evaluation():

@BLUEPRINT.route("/explainer/image", methods=["GET"])
def query_image():
"""Query image."""
"""
Query image.

Returns:
bytes, image binary content for UI to demonstrate.
"""
train_id = get_train_id(request)
if train_id is None:
raise ParamMissError("train_id")
@@ -230,6 +333,6 @@ def init_module(app):
Init module entry.

Args:
app: the application obj.
app (flask.app): The application obj.
"""
app.register_blueprint(BLUEPRINT)

+ 5
- 10
mindinsight/datavisual/data_transform/summary_watcher.py View File

@@ -220,10 +220,7 @@ class SummaryWatcher:
summary_dict[relative_path].update(job_dict)
if summary_dict[relative_path]['create_time'] < ctime:
summary_dict[relative_path].update({
'create_time': ctime,
'update_time': mtime,
})
summary_dict[relative_path].update({'create_time': ctime, 'update_time': mtime})
job_dict = _get_explain_job_info(summary_base_dir, relative_path, timestamp)
summary_dict[relative_path].update(job_dict)
@@ -243,12 +240,10 @@ class SummaryWatcher:
if not is_find:
return
profiler = {
'directory': os.path.join('.', entry.name),
'create_time': ctime,
'update_time': mtime,
"profiler_type": profiler_type
}
profiler = {'directory': os.path.join('.', entry.name),
'create_time': ctime,
'update_time': mtime,
"profiler_type": profiler_type}
if relative_path in summary_dict:
summary_dict[relative_path]['profiler'] = profiler


+ 14
- 1
mindinsight/explainer/common/enums.py View File

@@ -1,4 +1,4 @@
# Copyright 2020 Huawei Technologies Co., Ltd
# Copyright 2020-2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -50,3 +50,16 @@ class CacheStatus(enum.Enum):
NOT_IN_CACHE = "NOT_IN_CACHE"
CACHING = "CACHING"
CACHED = "CACHED"


class ExplanationKeys(enum.Enum):
"""Query type enums."""
HOC = "hoc_layers" # HOC: Hierarchical Occlusion, an explanation method we propose
SALIENCY = "saliency_maps"


class ImageQueryTypes(enum.Enum):
"""Image query type enums."""
ORIGINAL = 'original' # Query for the original image
OUTCOME = 'outcome' # Query for outcome of HOC explanation
OVERLAY = 'overlay' # Query for saliency maps overlay

+ 3
- 3
mindinsight/explainer/encapsulator/_hoc_pil_apply.py View File

@@ -63,7 +63,7 @@ def pil_apply_edit_steps(image, mask, edit_steps, by_masking=False, inplace=Fals
Args:
image (PIL.Image): The input image in RGB mode.
mask (Union[str, int, tuple[int, int, int], PIL.Image.Image]): The mask to apply on the image, could be string
e.g. 'gaussian:9', a single, grey scale intensity [0, 255], a RBG tuple or a PIL Image object.
e.g. 'gaussian:9', a single, grey scale intensity [0, 255], an RBG tuple or a PIL Image object.
edit_steps (list[EditStep]): Edit steps to be drawn.
by_masking (bool): Whether to use masking method. Default: False.
inplace (bool): True to draw on the input image, otherwise draw on a cloned image.
@@ -99,7 +99,7 @@ def _pil_apply_edit_steps_unmask(image, mask, edit_steps, inplace=False):
Args:
image (PIL.Image): The input image.
mask (Union[int, tuple[int, int, int], PIL.Image]): The mask to apply on the image, could be a single grey
scale intensity [0, 255], a RBG tuple or a PIL Image.
scale intensity [0, 255], an RBG tuple or a PIL Image.
edit_steps (list[EditStep]): Edit steps to be drawn.
inplace (bool): True to draw on the input image, otherwise draw on a cloned image.

@@ -132,7 +132,7 @@ def _pil_apply_edit_steps_mask(image, mask, edit_steps, inplace=False):
Args:
image (PIL.Image): The input image.
mask (Union[int, tuple[int, int, int], PIL.Image]): The mask to apply on the image, could be a single grey
scale intensity [0, 255], a RBG tuple or a PIL Image.
scale intensity [0, 255], an RBG tuple or a PIL Image.
edit_steps (list[EditStep]): Edit steps to be drawn.
inplace (bool): True to draw on the input image, otherwise draw on a cloned image.



+ 43
- 35
mindinsight/explainer/encapsulator/datafile_encap.py View File

@@ -21,6 +21,7 @@ import numpy as np
from PIL import Image

from mindinsight.datavisual.common.exceptions import ImageNotExistError
from mindinsight.explainer.common.enums import ImageQueryTypes
from mindinsight.explainer.encapsulator._hoc_pil_apply import EditStep, pil_apply_edit_steps
from mindinsight.explainer.encapsulator.explain_data_encap import ExplainDataEncap
from mindinsight.utils.exceptions import FileSystemPermissionError
@@ -63,41 +64,10 @@ class DatafileEncap(ExplainDataEncap):
image_type (str): Image type, Options: 'original', 'overlay' or 'outcome'.

Returns:
bytes, image binary.
bytes, image binary content for UI to demonstrate.
"""
if image_type == "outcome":
sample_id, label, layer = image_path.strip(".jpg").split("_")
layer = int(layer)
job = self.job_manager.get_job(train_id)
samples = job.samples
label_idx = job.labels.index(label)

chosen_sample = samples[int(sample_id)]
original_path_image = chosen_sample['image']
abs_image_path = os.path.join(self.job_manager.summary_base_dir, _clean_train_id_b4_join(train_id),
original_path_image)
if self._is_forbidden(abs_image_path):
raise FileSystemPermissionError("Forbidden.")
try:
image = Image.open(abs_image_path)
except FileNotFoundError:
raise ImageNotExistError(f"train_id:{train_id} path:{image_path} type:{image_type}")
except PermissionError:
raise FileSystemPermissionError(f"train_id:{train_id} path:{image_path} type:{image_type}")
except OSError:
raise UnknownError(f"Invalid image file: train_id:{train_id} path:{image_path} type:{image_type}")

edit_steps = []
boxes = chosen_sample["hierarchical_occlusion"][label_idx]["hoc_layers"][layer]["boxes"]
mask = chosen_sample["hierarchical_occlusion"][label_idx]["mask"]

for box in boxes:
edit_steps.append(EditStep(layer, *box))
image_cp = pil_apply_edit_steps(image, mask, edit_steps)
buffer = io.BytesIO()
image_cp.save(buffer, format=_PNG_FORMAT)

return buffer.getvalue()
if image_type == ImageQueryTypes.OUTCOME.value:
return self._get_hoc_image(image_path, train_id)

abs_image_path = os.path.join(self.job_manager.summary_base_dir,
_clean_train_id_b4_join(train_id),
@@ -108,7 +78,7 @@ class DatafileEncap(ExplainDataEncap):

try:

if image_type != "overlay":
if image_type != ImageQueryTypes.OVERLAY.value:
# no need to convert
with open(abs_image_path, "rb") as fp:
return fp.read()
@@ -153,3 +123,41 @@ class DatafileEncap(ExplainDataEncap):
base_dir = os.path.realpath(self.job_manager.summary_base_dir)
path = os.path.realpath(path)
return not path.startswith(base_dir)

def _get_hoc_image(self, image_path, train_id):
"""Get hoc image for image data demonstration in UI."""

sample_id, label, layer = image_path.strip(".jpg").split("_")
layer = int(layer)
job = self.job_manager.get_job(train_id)
samples = job.samples
label_idx = job.labels.index(label)

chosen_sample = samples[int(sample_id)]
original_path_image = chosen_sample['image']
abs_image_path = os.path.join(self.job_manager.summary_base_dir, _clean_train_id_b4_join(train_id),
original_path_image)
if self._is_forbidden(abs_image_path):
raise FileSystemPermissionError("Forbidden.")

image_type = ImageQueryTypes.OUTCOME.value
try:
image = Image.open(abs_image_path)
except FileNotFoundError:
raise ImageNotExistError(f"train_id:{train_id} path:{image_path} type:{image_type}")
except PermissionError:
raise FileSystemPermissionError(f"train_id:{train_id} path:{image_path} type:{image_type}")
except OSError:
raise UnknownError(f"Invalid image file: train_id:{train_id} path:{image_path} type:{image_type}")

edit_steps = []
boxes = chosen_sample["hierarchical_occlusion"][label_idx]["hoc_layers"][layer]["boxes"]
mask = chosen_sample["hierarchical_occlusion"][label_idx]["mask"]

for box in boxes:
edit_steps.append(EditStep(layer, *box))
image_cp = pil_apply_edit_steps(image, mask, edit_steps)
buffer = io.BytesIO()
image_cp.save(buffer, format=_PNG_FORMAT)

return buffer.getvalue()

+ 8
- 13
mindinsight/explainer/encapsulator/explain_data_encap.py View File

@@ -15,13 +15,13 @@
"""Common explain data encapsulator base class."""

import copy
from enum import Enum

from mindinsight.explainer.common.enums import ExplanationKeys
from mindinsight.utils.exceptions import ParamValueError


def _sort_key_min_confidence(sample, labels):
"""Samples sort key by the min. confidence."""
"""Samples sort key by the minimum confidence."""
min_confidence = float("+inf")
for inference in sample["inferences"]:
if labels and inference["label"] not in labels:
@@ -32,7 +32,7 @@ def _sort_key_min_confidence(sample, labels):


def _sort_key_max_confidence(sample, labels):
"""Samples sort key by the max. confidence."""
"""Samples sort key by the maximum confidence."""
max_confidence = float("-inf")
for inference in sample["inferences"]:
if labels and inference["label"] not in labels:
@@ -43,7 +43,7 @@ def _sort_key_max_confidence(sample, labels):


def _sort_key_min_confidence_sd(sample, labels):
"""Samples sort key by the min. confidence_sd."""
"""Samples sort key by the minimum confidence_sd."""
min_confidence_sd = float("+inf")
for inference in sample["inferences"]:
if labels and inference["label"] not in labels:
@@ -55,7 +55,7 @@ def _sort_key_min_confidence_sd(sample, labels):


def _sort_key_max_confidence_sd(sample, labels):
"""Samples sort key by the max. confidence_sd."""
"""Samples sort key by the maximum confidence_sd."""
max_confidence_sd = float("-inf")
for inference in sample["inferences"]:
if labels and inference["label"] not in labels:
@@ -65,11 +65,6 @@ def _sort_key_max_confidence_sd(sample, labels):
max_confidence_sd = confidence_sd
return max_confidence_sd

class ExplanationKeys(Enum):
"""Query type enums."""
HOC = "hoc_layers" # HOC: Hierarchical Occlusion, an explanation method we propose
SALIENCY = "saliency_maps"


class ExplainDataEncap:
"""Explain data encapsulator base class."""
@@ -105,9 +100,9 @@ class ExplanationEncap(ExplainDataEncap):
sorted_name (str): Field to be sorted.
sorted_type (str): Sorting order, 'ascending' or 'descending'.
prediction_types (list[str]): Prediction type filter.
drop_type (str, None): When it is None, no filer will be applied. When it is 'hoc_layers', samples without
hoc explanations will be filtered out. When it is 'saliency_maps', samples without saliency explanations
will be filtered out.
drop_type (str, None): When it is None, all data will be kept. When it is 'hoc_layers', samples without
hoc explanations will be drop out. When it is 'saliency_maps', samples without saliency explanations
will be drop out.

Returns:
list[dict], samples to be queried.


+ 4
- 2
mindinsight/explainer/encapsulator/explain_job_encap.py View File

@@ -29,11 +29,13 @@ class ExplainJobEncap(ExplainDataEncap):
def query_explain_jobs(self, offset, limit):
"""
Query explain job list.

Args:
offset (int): Page offset.
limit (int): Max. no. of items to be returned.
limit (int): Maximum number of items to be returned.

Returns:
tuple[int, list[Dict]], total no. of jobs and job list.
tuple[int, list[Dict]], total number of jobs and job list.
"""
total, dir_infos = self.job_manager.get_job_list(offset=offset, limit=limit)
job_infos = [self._dir_2_info(dir_info) for dir_info in dir_infos]


+ 9
- 5
mindinsight/explainer/encapsulator/hierarchical_occlusion_encap.py View File

@@ -15,7 +15,8 @@
"""Hierarchical Occlusion encapsulator."""

from mindinsight.datavisual.common.exceptions import TrainJobNotExistError
from mindinsight.explainer.encapsulator.explain_data_encap import ExplanationEncap, ExplanationKeys
from mindinsight.explainer.common.enums import ExplanationKeys, ImageQueryTypes
from mindinsight.explainer.encapsulator.explain_data_encap import ExplanationEncap


class HierarchicalOcclusionEncap(ExplanationEncap):
@@ -81,7 +82,10 @@ class HierarchicalOcclusionEncap(ExplanationEncap):
Returns:
dict, the edited sample info.
"""
sample["image"] = self._get_image_url(job.train_id, sample["image"], "original")
original = ImageQueryTypes.ORIGINAL.value
outcome = ImageQueryTypes.OUTCOME.value

sample["image"] = self._get_image_url(job.train_id, sample["image"], original)
inferences = sample["inferences"]
i = 0 # init index for while loop
while i < len(inferences):
@@ -91,9 +95,9 @@ class HierarchicalOcclusionEncap(ExplanationEncap):
continue
new_list = []
for idx, hoc_layer in enumerate(inference_item[ExplanationKeys.HOC.value]):
hoc_layer["outcome"] = self._get_image_url(job.train_id,
f"{sample['id']}_{inference_item['label']}_{idx}.jpg",
"outcome")
hoc_layer[outcome] = self._get_image_url(job.train_id,
f"{sample['id']}_{inference_item['label']}_{idx}.jpg",
outcome)
new_list.append(hoc_layer)
inference_item[ExplanationKeys.HOC.value] = new_list
i += 1


+ 10
- 4
mindinsight/explainer/encapsulator/saliency_encap.py View File

@@ -15,7 +15,8 @@
"""Saliency map encapsulator."""

from mindinsight.datavisual.common.exceptions import TrainJobNotExistError
from mindinsight.explainer.encapsulator.explain_data_encap import ExplanationEncap, ExplanationKeys
from mindinsight.explainer.common.enums import ExplanationKeys, ImageQueryTypes
from mindinsight.explainer.encapsulator.explain_data_encap import ExplanationEncap


class SaliencyEncap(ExplanationEncap):
@@ -32,6 +33,7 @@ class SaliencyEncap(ExplanationEncap):
prediction_types=None):
"""
Query saliency maps.

Args:
train_id (str): Job ID.
labels (list[str]): Label filter.
@@ -65,7 +67,8 @@ class SaliencyEncap(ExplanationEncap):

def _touch_sample(self, sample, job, explainers):
"""
Final editing the sample info.
Final edit on single sample info.

Args:
sample (dict): Sample info.
job (ExplainJob): Explain job.
@@ -74,14 +77,17 @@ class SaliencyEncap(ExplanationEncap):
Returns:
dict, the edited sample info.
"""
original = ImageQueryTypes.ORIGINAL.value
overlay = ImageQueryTypes.OVERLAY.value

sample_cp = sample.copy()
sample_cp["image"] = self._get_image_url(job.train_id, sample['image'], "original")
sample_cp["image"] = self._get_image_url(job.train_id, sample['image'], original)
for inference in sample_cp["inferences"]:
new_list = []
for saliency_map in inference[ExplanationKeys.SALIENCY.value]:
if explainers and saliency_map["explainer"] not in explainers:
continue
saliency_map["overlay"] = self._get_image_url(job.train_id, saliency_map['overlay'], "overlay")
saliency_map[overlay] = self._get_image_url(job.train_id, saliency_map[overlay], overlay)
new_list.append(saliency_map)
inference[ExplanationKeys.SALIENCY.value] = new_list
return sample_cp

+ 5
- 6
mindinsight/explainer/manager/explain_loader.py View File

@@ -270,7 +270,7 @@ class ExplainLoader:
Update the update_time manually.

Args:
new_time stamp (datetime.datetime or float): Updated time for the summary file.
new_time (datetime.datetime or float): Updated time for the summary file.
"""
if isinstance(new_time, datetime):
self._loader_info['update_time'] = new_time.timestamp()
@@ -333,11 +333,10 @@ class ExplainLoader:

def get_all_samples(self) -> List[Dict]:
"""
Return a list of sample information cached in the explain job
Return a list of sample information cached in the explain job.

Returns:
sample_list (List[SampleObj]): a list of sample objects, each object
consists of:
sample_list (list[SampleObj]): a list of sample objects, each object consists of:

- id (int): Sample id.
- name (str): Basename of image.
@@ -406,7 +405,7 @@ class ExplainLoader:
}

Args:
benchmarks (benchmark_container): Parsed benchmarks data from summary file.
benchmarks (BenchmarkContainer): Parsed benchmarks data from summary file.
"""
explainer_score = self._benchmark['explainer_score']
label_score = self._benchmark['label_score']
@@ -429,7 +428,7 @@ class ExplainLoader:
Parse the sample event.

Detailed data of each sample are store in self._samples, identified by sample_id. Each sample data are stored
in the following structure.
in the following structure:

- ground_truth_labels (list[int]): A list of ground truth labels of the sample.
- ground_truth_probs (list[float]): A list of confidences of ground-truth label from black-box model.


+ 15
- 13
mindinsight/explainer/manager/explain_manager.py View File

@@ -65,7 +65,7 @@ class ExplainManager:

Args:
reload_interval (int): Specify the loading period in seconds. If interval == 0, data will only be loaded
once. Default: 0.
once. Default: 0.
"""
thread = threading.Thread(target=self._repeat_loading,
name='explainer.start_load_thread',
@@ -80,10 +80,10 @@ class ExplainManager:
If explain job w.r.t given loader_id is not found, None will be returned.

Args:
loader_id (str): The id of expected ExplainLoader
loader_id (str): The id of expected ExplainLoader.

Return:
explain_job
Returns:
ExplainLoader, the data loader specified by loader_id.
"""
self._check_status_valid()

@@ -111,17 +111,19 @@ class ExplainManager:
Return List of explain jobs. includes job ID, create and update time.

Args:
offset (int): An offset for page. Ex, offset is 0, mean current page is 1. Default value is 0.
limit (int): The max data items for per page. Default value is 10.
offset (int): An offset for page. Ex, offset is 0, mean current page is 1. Default: 0.
limit (int): The max data items for per page. Default: 10.

Returns:
tuple[total, directories], total indicates the overall number of explain directories and directories
indicate list of summary directory info including the following attributes.
tuple, the elements of the returned tuple are:

- total (int): The overall number of explain directories
- dir_infos (list): List of summary directory info including the following attributes:

- relative_path (str): Relative path of summary directory, referring to settings.SUMMARY_BASE_DIR,
starting with "./".
- create_time (datetime): Creation time of summary file.
- update_time (datetime): Modification time of summary file.
- relative_path (str): Relative path of summary directory, referring to settings.SUMMARY_BASE_DIR,
starting with "./".
- create_time (datetime): Creation time of summary file.
- update_time (datetime): Modification time of summary file.
"""
total, dir_infos = \
self._summary_watcher.list_explain_directories(self._summary_base_dir, offset=offset, limit=limit)
@@ -216,7 +218,7 @@ class ExplainManager:
return loader

def _add_loader(self, loader):
"""add loader to the loader_pool."""
"""Add loader to the loader_pool."""
if loader.train_id not in self._loader_pool:
self._loader_pool[loader.train_id] = loader
else:


+ 21
- 11
mindinsight/explainer/manager/explain_parser.py View File

@@ -1,4 +1,4 @@
# Copyright 2020 Huawei Technologies Co., Ltd
# Copyright 2020-2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -59,12 +59,13 @@ class ExplainParser(_SummaryParser):

Args:
filenames (list[str]): File name list.

Returns:
tuple, will return (file_changed, is_end, event_data),
tuple, the elements of the tuple are:

file_changed (bool): True if the 9latest file is changed.
is_end (bool): True if all the summary files are finished loading.
event_data (dict): return an event data, key is field.
- file_changed (bool): True if the latest file is changed.
- is_end (bool): True if all the summary files are finished loading.
- event_data (dict): Event data where keys are explanation field.
"""
summary_files = self.sort_files(filenames)

@@ -134,6 +135,12 @@ class ExplainParser(_SummaryParser):

Args:
event_str (str): Message event string in summary proto, data read from file handler.

Returns:
tuple, the elements of the result tuple are:

- field_list (list): Explain fields to be parsed.
- tensor_value_list (list): Parsed data with respect to the field list.
"""

logger.debug("Start to parse event string. Event string len: %s.", len(event_str))
@@ -172,10 +179,13 @@ class ExplainParser(_SummaryParser):
@staticmethod
def _add_image_data(tensor_event_value):
"""
Parse image data based on sample_id in Explain message
Parse image data based on sample_id in Explain message.

Args:
tensor_event_value: the object of Explain message
tensor_event_value (Event): The object of Explain message.

Returns:
SampleContainer, a named tuple containing sample data.
"""
inference = InferfenceContainer(
ground_truth_prob=tensor_event_value.inference.ground_truth_prob,
@@ -205,10 +215,10 @@ class ExplainParser(_SummaryParser):
Parse benchmark data from Explain message.

Args:
tensor_event_value: the object of Explain message
tensor_event_value (Event): The object of Explain message.

Returns:
benchmark_data: An object containing benchmark.
BenchmarkContainer, a named tuple containing benchmark data.
"""
benchmark_data = BenchmarkContainer(
benchmark=tensor_event_value.benchmark,
@@ -223,10 +233,10 @@ class ExplainParser(_SummaryParser):
Parse metadata from Explain message.

Args:
tensor_event_value: the object of Explain message
tensor_event_value (Event): The object of Explain message.

Returns:
benchmark_data: An object containing metadata.
MetadataContainer, a named tuple containing benchmark data.
"""
metadata_value = MetadataContainer(
metadata=tensor_event_value.metadata,


Loading…
Cancel
Save