
rewrite ExplainManager and ExplainLoader

tags/v1.1.0
YuhanShi53 5 years ago
parent commit 3a05552c7e
10 changed files with 822 additions and 904 deletions
  1. +3 -0      mindinsight/backend/explainer/__init__.py
  2. +7 -24     mindinsight/backend/explainer/explainer_api.py
  3. +1 -1      mindinsight/explainer/common/enums.py
  4. +1 -1      mindinsight/explainer/encapsulator/explain_job_encap.py
  5. +0 -168    mindinsight/explainer/manager/event_parse.py
  6. +0 -398    mindinsight/explainer/manager/explain_job.py
  7. +580 -0    mindinsight/explainer/manager/explain_loader.py
  8. +178 -255  mindinsight/explainer/manager/explain_manager.py
  9. +51 -56    mindinsight/explainer/manager/explain_parser.py
  10. +1 -1     tests/ut/explainer/encapsulator/mock_explain_manager.py

+3 -0  mindinsight/backend/explainer/__init__.py

@@ -13,7 +13,9 @@
# limitations under the License.
# ============================================================================
"""Module init file."""
from mindinsight.conf import settings
from mindinsight.backend.explainer.explainer_api import init_module as init_query_module
from mindinsight.explainer.manager.explain_manager import EXPLAIN_MANAGER


def init_module(app):
@@ -27,3 +29,4 @@ def init_module(app):

"""
init_query_module(app)
EXPLAIN_MANAGER.start_load_data(reload_interval=settings.RELOAD_INTERVAL)
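The rewritten init_module above both registers the explainer blueprint (via init_query_module) and starts the shared manager's background loading with the configured reload interval. Below is a minimal sketch of how this entry point might be exercised, assuming a Flask app object; the create_app helper is hypothetical and not part of the commit.

# Hedged sketch: wiring the explainer backend into a Flask app.
# 'create_app' is a hypothetical helper; init_module is the function shown above.
from flask import Flask

from mindinsight.backend.explainer import init_module


def create_app():
    """Build a demo app and register the explainer backend."""
    app = Flask(__name__)
    # Registers the explainer blueprint and starts EXPLAIN_MANAGER's
    # background loading thread (interval taken from settings.RELOAD_INTERVAL).
    init_module(app)
    return app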

+7 -24  mindinsight/backend/explainer/explainer_api.py

@@ -29,32 +29,16 @@ from mindinsight.datavisual.common.exceptions import ImageNotExistError
from mindinsight.datavisual.common.validation import Validation
from mindinsight.datavisual.data_transform.summary_watcher import SummaryWatcher
from mindinsight.datavisual.utils.tools import get_train_id
from mindinsight.explainer.manager.explain_manager import ExplainManager
from mindinsight.explainer.manager.explain_manager import EXPLAIN_MANAGER
from mindinsight.explainer.encapsulator.explain_job_encap import ExplainJobEncap
from mindinsight.explainer.encapsulator.datafile_encap import DatafileEncap
from mindinsight.explainer.encapsulator.saliency_encap import SaliencyEncap
from mindinsight.explainer.encapsulator.evaluation_encap import EvaluationEncap


URL_PREFIX = settings.URL_PATH_PREFIX+settings.API_PREFIX
URL_PREFIX = settings.URL_PATH_PREFIX + settings.API_PREFIX
BLUEPRINT = Blueprint("explainer", __name__, url_prefix=URL_PREFIX)


class ExplainManagerHolder:
"""ExplainManger instance holder."""

static_instance = None

@classmethod
def get_instance(cls):
return cls.static_instance

@classmethod
def initialize(cls):
cls.static_instance = ExplainManager(settings.SUMMARY_BASE_DIR)
cls.static_instance.start_load_data()


def _image_url_formatter(train_id, image_path, image_type):
"""Returns image url."""
data = {
@@ -91,7 +75,7 @@ def query_explain_jobs():
offset = Validation.check_offset(offset=offset)
limit = Validation.check_limit(limit, min_value=1, max_value=SummaryWatcher.MAX_SUMMARY_DIR_COUNT)

encapsulator = ExplainJobEncap(ExplainManagerHolder.get_instance())
encapsulator = ExplainJobEncap(EXPLAIN_MANAGER)
total, jobs = encapsulator.query_explain_jobs(offset, limit)

return jsonify({
@@ -107,7 +91,7 @@ def query_explain_job():
train_id = get_train_id(request)
if train_id is None:
raise ParamMissError("train_id")
encapsulator = ExplainJobEncap(ExplainManagerHolder.get_instance())
encapsulator = ExplainJobEncap(EXPLAIN_MANAGER)
metadata = encapsulator.query_meta(train_id)

return jsonify(metadata)
@@ -139,7 +123,7 @@ def query_saliency():

encapsulator = SaliencyEncap(
_image_url_formatter,
ExplainManagerHolder.get_instance())
EXPLAIN_MANAGER)
count, samples = encapsulator.query_saliency_maps(train_id=train_id,
labels=labels,
explainers=explainers,
@@ -160,7 +144,7 @@ def query_evaluation():
train_id = get_train_id(request)
if train_id is None:
raise ParamMissError("train_id")
encapsulator = EvaluationEncap(ExplainManagerHolder.get_instance())
encapsulator = EvaluationEncap(EXPLAIN_MANAGER)
scores = encapsulator.query_explainer_scores(train_id)
return jsonify({
"explainer_scores": scores,
@@ -182,7 +166,7 @@ def query_image():
if image_type not in ("original", "overlay"):
raise ParamValueError(f"type:{image_type}, valid options: 'original' 'overlay'")

encapsulator = DatafileEncap(ExplainManagerHolder.get_instance())
encapsulator = DatafileEncap(EXPLAIN_MANAGER)
image = encapsulator.query_image_binary(train_id, image_path, image_type)
if image is None:
raise ImageNotExistError(f"{image_path}")
@@ -198,5 +182,4 @@ def init_module(app):
app: the application obj.

"""
ExplainManagerHolder.initialize()
app.register_blueprint(BLUEPRINT)
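With ExplainManagerHolder removed, every endpoint now obtains the shared manager by importing the module-level EXPLAIN_MANAGER singleton. A hedged sketch of the resulting usage pattern, assuming the encapsulator API shown in this file; list_first_jobs is a hypothetical helper, not part of the commit.

# Hedged sketch: querying explain jobs through the module-level singleton
# instead of ExplainManagerHolder.get_instance().
from mindinsight.explainer.encapsulator.explain_job_encap import ExplainJobEncap
from mindinsight.explainer.manager.explain_manager import EXPLAIN_MANAGER


def list_first_jobs(limit=10):
    """Hypothetical helper: return up to `limit` explain jobs via the singleton."""
    encapsulator = ExplainJobEncap(EXPLAIN_MANAGER)
    total, jobs = encapsulator.query_explain_jobs(0, limit)
    return total, jobs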

+1 -1  mindinsight/explainer/common/enums.py

@@ -31,7 +31,7 @@ class DataManagerStatus(BaseEnum):
INVALID = 'INVALID'


class PluginNameEnum(BaseEnum):
class ExplainFieldsEnum(BaseEnum):
"""Plugin Name Enum."""
EXPLAIN = 'explain'
SAMPLE_ID = 'sample_id'


+1 -1  mindinsight/explainer/encapsulator/explain_job_encap.py

@@ -70,7 +70,7 @@ class ExplainJobEncap(ExplainDataEncap):
info["train_id"] = job.train_id
info["create_time"] = datetime.fromtimestamp(job.create_time)\
.strftime(cls.DATETIME_FORMAT)
info["update_time"] = datetime.fromtimestamp(job.latest_update_time)\
info["update_time"] = datetime.fromtimestamp(job.update_time)\
.strftime(cls.DATETIME_FORMAT)
return info
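The encapsulator now reads job.update_time instead of job.latest_update_time to match the renamed ExplainLoader property. A small sketch of the timestamp formatting it performs, assuming DATETIME_FORMAT is a strftime pattern such as '%Y-%m-%d %H:%M:%S' (the exact value is not shown in this diff) and the timestamp value is hypothetical.

# Hedged sketch: converting the loader's epoch timestamps to display strings.
from datetime import datetime

DATETIME_FORMAT = '%Y-%m-%d %H:%M:%S'  # assumed value of ExplainJobEncap.DATETIME_FORMAT
update_time = 1598846400.0             # hypothetical loader.update_time (epoch seconds)
print(datetime.fromtimestamp(update_time).strftime(DATETIME_FORMAT))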



+0 -168  mindinsight/explainer/manager/event_parse.py

@@ -1,168 +0,0 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""EventParser for summary event."""
from collections import namedtuple, defaultdict
from typing import Dict, List, Optional, Tuple

from mindinsight.explainer.common.enums import PluginNameEnum
from mindinsight.explainer.common.log import logger
from mindinsight.utils.exceptions import UnknownError

_IMAGE_DATA_TAGS = {
'sample_id': PluginNameEnum.SAMPLE_ID.value,
'ground_truth_label': PluginNameEnum.GROUND_TRUTH_LABEL.value,
'inference': PluginNameEnum.INFERENCE.value,
'explanation': PluginNameEnum.EXPLANATION.value
}

_NUM_DIGIT = 7


class EventParser:
"""Parser for event data."""

def __init__(self, job):
self._job = job
self._sample_pool = {}

@staticmethod
def parse_metadata(metadata) -> Tuple[List, List, List]:
"""Parse the metadata event."""
explainers = list(metadata.explain_method)
metrics = list(metadata.benchmark_method)
labels = list(metadata.label)
return explainers, metrics, labels

@staticmethod
def parse_benchmark(benchmarks) -> Tuple[Dict, Dict]:
"""Parse the benchmark event."""
explainer_score_dict = defaultdict(list)
label_score_dict = defaultdict(dict)

for benchmark in benchmarks:
explainer = benchmark.explain_method
metric = benchmark.benchmark_method
metric_score = benchmark.total_score
label_score_event = benchmark.label_score

explainer_score_dict[explainer].append({
'metric': metric,
'score': round(metric_score, _NUM_DIGIT)})
new_label_score_dict = EventParser._score_event_to_dict(label_score_event, metric)
for label, label_scores in new_label_score_dict.items():
label_score_dict[explainer][label] = label_score_dict[explainer].get(label, []) + label_scores

return explainer_score_dict, label_score_dict

def parse_sample(self, sample: namedtuple) -> Optional[namedtuple]:
"""Parse the sample event."""
sample_id = sample.sample_id

if sample_id not in self._sample_pool:
self._sample_pool[sample_id] = sample
return None

for tag in _IMAGE_DATA_TAGS:
try:
if tag == PluginNameEnum.INFERENCE.value:
self._parse_inference(sample, sample_id)
elif tag == PluginNameEnum.EXPLANATION.value:
self._parse_explanation(sample, sample_id)
else:
self._parse_sample_info(sample, sample_id, tag)
except UnknownError as ex:
logger.warning("Parse %s data failed within image related data, detail: %r", tag, str(ex))
continue

if EventParser._is_ready_for_display(self._sample_pool[sample_id]):
return self._sample_pool[sample_id]
return None

def clear(self):
"""Clear the loaded data."""
self._sample_pool.clear()

@staticmethod
def _is_ready_for_display(image_container: namedtuple) -> bool:
"""
Check whether the image_container is ready for frontend display.

Args:
image_container (namedtuple): container consists of sample data

Return:
bool: whether the image_container if ready for display
"""
required_attrs = ['image_path', 'ground_truth_label', 'inference']
for attr in required_attrs:
if not EventParser.is_attr_ready(image_container, attr):
return False
return True

@staticmethod
def is_attr_ready(image_container: namedtuple, attr: str) -> bool:
"""
Check whether the given attribute is ready in image_container.

Args:
image_container (namedtuple): container consist of sample data
attr (str): attribute to check

Returns:
bool, whether the attr is ready
"""
if getattr(image_container, attr, False):
return True
return False

@staticmethod
def _score_event_to_dict(label_score_event, metric):
"""Transfer metric scores per label to pre-defined structure."""
new_label_score_dict = defaultdict(list)
for label_id, label_score in enumerate(label_score_event):
new_label_score_dict[label_id].append({
'metric': metric,
'score': round(label_score, _NUM_DIGIT),
})
return new_label_score_dict

def _parse_inference(self, event, sample_id):
"""Parse the inference event."""
self._sample_pool[sample_id].inference.ground_truth_prob.extend(event.inference.ground_truth_prob)
self._sample_pool[sample_id].inference.ground_truth_prob_sd.extend(event.inference.ground_truth_prob_sd)
self._sample_pool[sample_id].inference.ground_truth_prob_itl95_low.\
extend(event.inference.ground_truth_prob_itl95_low)
self._sample_pool[sample_id].inference.ground_truth_prob_itl95_hi.\
extend(event.inference.ground_truth_prob_itl95_hi)

self._sample_pool[sample_id].inference.predicted_label.extend(event.inference.predicted_label)
self._sample_pool[sample_id].inference.predicted_prob.extend(event.inference.predicted_prob)
self._sample_pool[sample_id].inference.predicted_prob_sd.extend(event.inference.predicted_prob_sd)
self._sample_pool[sample_id].inference.predicted_prob_itl95_low.extend(event.inference.predicted_prob_itl95_low)
self._sample_pool[sample_id].inference.predicted_prob_itl95_hi.extend(event.inference.predicted_prob_itl95_hi)

def _parse_explanation(self, event, sample_id):
"""Parse the explanation event."""
if event.explanation:
for explanation_item in event.explanation:
new_explanation = self._sample_pool[sample_id].explanation.add()
new_explanation.explain_method = explanation_item.explain_method
new_explanation.label = explanation_item.label
new_explanation.heatmap_path = explanation_item.heatmap_path

def _parse_sample_info(self, event, sample_id, tag):
"""Parse the event containing image info."""
if not getattr(self._sample_pool[sample_id], tag):
setattr(self._sample_pool[sample_id], tag, getattr(event, tag))

+0 -398  mindinsight/explainer/manager/explain_job.py

@@ -1,398 +0,0 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""ExplainJob."""

import os
from collections import defaultdict
from datetime import datetime
from typing import Union

from mindinsight.explainer.common.enums import PluginNameEnum
from mindinsight.explainer.common.log import logger
from mindinsight.explainer.manager.explain_parser import _ExplainParser
from mindinsight.explainer.manager.event_parse import EventParser
from mindinsight.datavisual.data_access.file_handler import FileHandler
from mindinsight.datavisual.common.exceptions import TrainJobNotExistError

_NUM_DIGIT = 7


class ExplainJob:
"""ExplainJob which manage the record in the summary file."""

def __init__(self,
job_id: str,
summary_dir: str,
create_time: float,
latest_update_time: float):

self._job_id = job_id
self._summary_dir = summary_dir
self._parser = _ExplainParser(summary_dir)

self._event_parser = EventParser(self)
self._latest_update_time = latest_update_time
self._create_time = create_time
self._uncertainty_enabled = False
self._labels = []
self._metrics = []
self._explainers = []
self._samples_info = {}
self._labels_info = {}
self._explainer_score_dict = defaultdict(list)
self._label_score_dict = defaultdict(dict)

@property
def all_classes(self):
"""
Return a list of label info

Returns:
class_objs (List[ClassObj]): a list of class_objects, each object
contains:

- id (int): label id
- label (str): label name
- sample_count (int): number of samples for each label
"""
all_classes_return = []
for label_id, label_info in self._labels_info.items():
single_info = {
'id': label_id,
'label': label_info['label'],
'sample_count': len(label_info['sample_ids'])}
all_classes_return.append(single_info)
return all_classes_return

@property
def explainers(self):
"""
Return a list of explainer names

Returns:
list(str), explainer names
"""
return self._explainers

@property
def explainer_scores(self):
"""Return evaluation results for every explainer."""
merged_scores = []
for explainer, explainer_score_on_metric in self._explainer_score_dict.items():
label_scores = []
for label, label_score_on_metric in self._label_score_dict[explainer].items():
score_single_label = {
'label': self._labels[label],
'evaluations': label_score_on_metric,
}
label_scores.append(score_single_label)
merged_scores.append({
'explainer': explainer,
'evaluations': explainer_score_on_metric,
'class_scores': label_scores,
})
return merged_scores

@property
def sample_count(self):
"""
Return total number of samples in the job.

Return:
int, total number of samples

"""
return len(self._samples_info)

@property
def train_id(self):
"""
Return ID of explain job

Returns:
str, id of ExplainJob object
"""
return self._job_id

@property
def metrics(self):
"""
Return a list of metric names

Returns:
list(str), metric names
"""
return self._metrics

@property
def min_confidence(self):
"""
Return minimum confidence

Returns:
min_confidence (float):
"""
return None

@property
def uncertainty_enabled(self):
return self._uncertainty_enabled

@property
def create_time(self):
"""
Return the create time of summary file

Returns:
creation timestamp (float)

"""
return self._create_time

@property
def labels(self):
"""Return the label contained in the job."""
return self._labels

@property
def latest_update_time(self):
"""
Return last modification time stamp of summary file.

Returns:
float, last_modification_time stamp
"""
return self._latest_update_time

@latest_update_time.setter
def latest_update_time(self, new_time: Union[float, datetime]):
"""
Update the latest_update_time timestamp manually.

Args:
new_time stamp (union[float, datetime]): updated time for the job
"""
if isinstance(new_time, datetime):
self._latest_update_time = new_time.timestamp()
elif isinstance(new_time, float):
self._latest_update_time = new_time
else:
raise TypeError('new_time should have type of float or datetime')

@property
def loader_id(self):
"""Return the job id."""
return self._job_id

@property
def samples(self):
"""Return the information of all samples in the job."""
return self._samples_info

@staticmethod
def get_create_time(file_path: str) -> float:
"""Return timestamp of create time of specific path."""
create_time = os.stat(file_path).st_ctime
return create_time

@staticmethod
def get_update_time(file_path: str) -> float:
"""Return timestamp of update time of specific path."""
update_time = os.stat(file_path).st_mtime
return update_time

def _initialize_labels_info(self):
"""Initialize a dict for labels in the job."""
if self._labels is None:
logger.warning('No labels is provided in job %s', self._job_id)
return

for label_id, label in enumerate(self._labels):
self._labels_info[label_id] = {'label': label,
'sample_ids': set()}

def _explanation_to_dict(self, explanation):
"""Transfer the explanation from event to dict storage."""
explain_info = {
'explainer': explanation.explain_method,
'overlay': explanation.heatmap_path,
}
return explain_info

def _image_container_to_dict(self, sample_data):
"""Transfer the image container to dict storage."""
has_uncertainty = False
sample_id = sample_data.sample_id

sample_info = {
'id': sample_id,
'image': sample_data.image_path,
'name': str(sample_id),
'labels': [self._labels_info[x]['label']
for x in sample_data.ground_truth_label],
'inferences': []}

ground_truth_labels = list(sample_data.ground_truth_label)
ground_truth_probs = list(sample_data.inference.ground_truth_prob)
predicted_labels = list(sample_data.inference.predicted_label)
predicted_probs = list(sample_data.inference.predicted_prob)

if sample_data.inference.predicted_prob_sd or sample_data.inference.ground_truth_prob_sd:
ground_truth_prob_sds = list(sample_data.inference.ground_truth_prob_sd)
ground_truth_prob_lows = list(sample_data.inference.ground_truth_prob_itl95_low)
ground_truth_prob_his = list(sample_data.inference.ground_truth_prob_itl95_hi)
predicted_prob_sds = list(sample_data.inference.predicted_prob_sd)
predicted_prob_lows = list(sample_data.inference.predicted_prob_itl95_low)
predicted_prob_his = list(sample_data.inference.predicted_prob_itl95_hi)
has_uncertainty = True
else:
ground_truth_prob_sds = ground_truth_prob_lows = ground_truth_prob_his = None
predicted_prob_sds = predicted_prob_lows = predicted_prob_his = None

inference_info = {}
for label, prob in zip(
ground_truth_labels + predicted_labels,
ground_truth_probs + predicted_probs):
inference_info[label] = {
'label': self._labels_info[label]['label'],
'confidence': round(prob, _NUM_DIGIT),
'saliency_maps': []}

if ground_truth_prob_sds or predicted_prob_sds:
for label, sd, low, hi in zip(
ground_truth_labels + predicted_labels,
ground_truth_prob_sds + predicted_prob_sds,
ground_truth_prob_lows + predicted_prob_lows,
ground_truth_prob_his + predicted_prob_his):
inference_info[label]['confidence_sd'] = sd
inference_info[label]['confidence_itl95'] = [low, hi]

if EventParser.is_attr_ready(sample_data, 'explanation'):
for explanation in sample_data.explanation:
explanation_dict = self._explanation_to_dict(explanation)
inference_info[explanation.label]['saliency_maps'].append(explanation_dict)

sample_info['inferences'] = list(inference_info.values())
return sample_info, has_uncertainty

def _import_sample(self, sample):
"""Add sample object of given sample id."""
for label_id in sample.ground_truth_label:
self._labels_info[label_id]['sample_ids'].add(sample.sample_id)

sample_info, has_uncertainty = self._image_container_to_dict(sample)
self._samples_info.update({sample_info['id']: sample_info})
self._uncertainty_enabled |= has_uncertainty

def get_all_samples(self):
"""
Return a list of sample information cachced in the explain job

Returns:
sample_list (List[SampleObj]): a list of sample objects, each object
consists of:

- id (int): sample id
- name (str): basename of image
- labels (list[str]): list of labels
- inferences list[dict])
"""
samples_in_list = list(self._samples_info.values())
return samples_in_list

def _is_metadata_empty(self):
"""Check whether metadata is loaded first."""
if not self._explainers or not self._metrics or not self._labels:
return True
return False

def _import_data_from_event(self, event):
"""Parse and import data from the event data."""
tags = {
'sample_id': PluginNameEnum.SAMPLE_ID,
'benchmark': PluginNameEnum.BENCHMARK,
'metadata': PluginNameEnum.METADATA
}

if 'metadata' not in event and self._is_metadata_empty():
raise ValueError('metadata is empty, should write metadata first in the summary.')
for tag in tags:
if tag not in event:
continue

if tag == PluginNameEnum.SAMPLE_ID.value:
sample_event = event[tag]
sample_data = self._event_parser.parse_sample(sample_event)
if sample_data is not None:
self._import_sample(sample_data)
continue

if tag == PluginNameEnum.BENCHMARK.value:
benchmark_event = event[tag].benchmark
explain_score_dict, label_score_dict = EventParser.parse_benchmark(benchmark_event)
self._update_benchmark(explain_score_dict, label_score_dict)

elif tag == PluginNameEnum.METADATA.value:
metadata_event = event[tag].metadata
metadata = EventParser.parse_metadata(metadata_event)
self._explainers, self._metrics, self._labels = metadata
self._initialize_labels_info()

def load(self):
"""
Start loading data from parser.
"""
valid_file_names = []
for filename in FileHandler.list_dir(self._summary_dir):
if FileHandler.is_file(
FileHandler.join(self._summary_dir, filename)):
valid_file_names.append(filename)

if not valid_file_names:
raise TrainJobNotExistError('No summary file found in %s, explain job will be delete.' % self._summary_dir)

is_end = False
while not is_end:
is_clean, is_end, event = self._parser.parse_explain(valid_file_names)

if is_clean:
logger.info('Summary file in %s update, reload the clean the loaded data.', self._summary_dir)
self._clean_job()

if event:
self._import_data_from_event(event)

def _clean_job(self):
"""Clean the cached data in job."""
self._latest_update_time = ExplainJob.get_update_time(self._summary_dir)
self._create_time = ExplainJob.get_update_time(self._summary_dir)
self._labels.clear()
self._metrics.clear()
self._explainers.clear()
self._samples_info.clear()
self._labels_info.clear()
self._explainer_score_dict.clear()
self._label_score_dict.clear()
self._event_parser.clear()

def _update_benchmark(self, explainer_score_dict, labels_score_dict):
"""Update the benchmark info."""
for explainer, score in explainer_score_dict.items():
self._explainer_score_dict[explainer].extend(score)

for explainer, score in labels_score_dict.items():
for label, score_of_label in score.items():
self._label_score_dict[explainer][label] = (self._label_score_dict[explainer].get(label, [])
+ score_of_label)

+580 -0  mindinsight/explainer/manager/explain_loader.py

@@ -0,0 +1,580 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""ExplainLoader."""

import os
import re
from collections import defaultdict
from datetime import datetime
from typing import Dict, Iterable, List, Optional, Union

from mindinsight.explainer.common.enums import ExplainFieldsEnum
from mindinsight.explainer.common.log import logger
from mindinsight.explainer.manager.explain_parser import ExplainParser
from mindinsight.datavisual.data_access.file_handler import FileHandler
from mindinsight.datavisual.common.exceptions import TrainJobNotExistError
from mindinsight.utils.exceptions import ParamValueError, UnknownError

_NUM_DIGITS = 6

_EXPLAIN_FIELD_NAMES = [
ExplainFieldsEnum.SAMPLE_ID,
ExplainFieldsEnum.BENCHMARK,
ExplainFieldsEnum.METADATA,
]

_SAMPLE_FIELD_NAMES = [
ExplainFieldsEnum.GROUND_TRUTH_LABEL,
ExplainFieldsEnum.INFERENCE,
ExplainFieldsEnum.EXPLANATION,
]


def _round(score):
"""Take round of a number to given precision."""
return round(score, _NUM_DIGITS)


class ExplainLoader:
"""ExplainLoader which manage the record in the summary file."""

def __init__(self,
loader_id: str,
summary_dir: str):

self._parser = ExplainParser(summary_dir)

self._loader_info = {
'loader_id': loader_id,
'summary_dir': summary_dir,
'create_time': os.stat(summary_dir).st_ctime,
'update_time': os.stat(summary_dir).st_mtime,
'query_time': os.stat(summary_dir).st_ctime,
'uncertainty_enabled': False,
}
self._samples = defaultdict(dict)
self._metadata = {'explainers': [], 'metrics': [], 'labels': []}
self._benchmark = {'explainer_score': defaultdict(dict), 'label_score': defaultdict(dict)}

@property
def all_classes(self) -> List[Dict]:
"""
Return a list of detailed label information, including label id, label name and sample count of each label.

Returns:
list[dict], a list of dict, each dict contains:
- id (int): label id
- label (str): label name
- sample_count (int): number of samples for each label
"""
sample_count_per_label = defaultdict(int)
samples_copy = self._samples.copy()
for sample in samples_copy.values():
if sample.get('image', False) and sample.get('ground_truth_label', False):
for label in sample['ground_truth_label']:
sample_count_per_label[label] += 1

all_classes_return = []
for label_id, label_name in enumerate(self._metadata['labels']):
single_info = {
'id': label_id,
'label': label_name,
'sample_count': sample_count_per_label[label_id]
}
all_classes_return.append(single_info)
return all_classes_return

@property
def query_time(self) -> float:
"""Return query timestamp of explain loader."""
return self._loader_info['query_time']

@query_time.setter
def query_time(self, new_time: Union[datetime, float]):
"""
Update the query_time timestamp manually.

Args:
new_time (datetime.datetime or float): Updated query_time for the explain loader.
"""
if isinstance(new_time, datetime):
self._loader_info['query_time'] = new_time.timestamp()
elif isinstance(new_time, float):
self._loader_info['query_time'] = new_time
else:
raise TypeError('new_time should have type of datetime.datetime or float, but receive {}'
.format(type(new_time)))

@property
def create_time(self) -> float:
"""Return the create timestamp of summary file."""
return self._loader_info['create_time']

@create_time.setter
def create_time(self, new_time: Union[datetime, float]):
"""
Update the create_time manually.

Args:
new_time (datetime.datetime or float): Updated create_time of summary_file.
"""
if isinstance(new_time, datetime):
self._loader_info['create_time'] = new_time.timestamp()
elif isinstance(new_time, float):
self._loader_info['create_time'] = new_time
else:
raise TypeError('new_time should have type of datetime.datetime or float, but receive {}'
.format(type(new_time)))

@property
def explainers(self) -> List[str]:
"""Return a list of explainer names recorded in the summary file."""
return self._metadata['explainers']

@property
def explainer_scores(self) -> List[Dict]:
"""
Return evaluation results for every explainer.

Returns:
list[dict], A list of evaluation results of each explainer. Each item contains:
- explainer (str): Name of evaluated explainer.
- evaluations (list[dict]): A list of evaluation results by different metrics.
- class_scores (list[dict]): A list of evaluation results on different labels.

Each item in the evaluations contains:
- metric (str): name of metric method
- score (float): evaluation result

Each item in the class_scores contains:
- label (str): Name of label
- evaluations (list[dict]): A list of evaluation results on different labels by different metrics.

Each item in evaluations contains:
- metric (str): Name of metric method
- score (float): Evaluation scores of explainer on specific label by the metric.
"""
explainer_scores = []
for explainer, explainer_score_on_metric in self._benchmark['explainer_score'].copy().items():
metric_scores = [{'metric': metric, 'score': _round(score)}
for metric, score in explainer_score_on_metric.items()]
label_scores = []
for label, label_score_on_metric in self._benchmark['label_score'][explainer].copy().items():
score_of_single_label = {
'label': self._metadata['labels'][label],
'evaluations': [
{'metric': metric, 'score': _round(score)} for metric, score in label_score_on_metric.items()
],
}
label_scores.append(score_of_single_label)
explainer_scores.append({
'explainer': explainer,
'evaluations': metric_scores,
'class_scores': label_scores,
})
return explainer_scores

@property
def labels(self) -> List[str]:
"""Return the label recorded in the summary."""
return self._metadata['labels']

@property
def metrics(self) -> List[str]:
"""Return a list of metric names recorded in the summary file."""
return self._metadata['metrics']

@property
def min_confidence(self) -> Optional[float]:
"""Return minimum confidence used to filter the predicted labels."""
return None

@property
def sample_count(self) -> int:
"""
Return total number of samples in the loader.

Since the loader only returns available samples (i.e. those with original image data and ground_truth_label loaded
in cache), the returned count only takes the available samples into account.

Return:
int, total number of available samples in the loading job.

"""
sample_count = 0
samples_copy = self._samples.copy()
for sample in samples_copy.values():
if sample.get('image', False) and sample.get('ground_truth_label', False):
sample_count += 1
return sample_count

@property
def samples(self) -> List[Dict]:
"""Return the information of all samples in the job."""
return self.get_all_samples()

@property
def train_id(self) -> str:
"""Return ID of explain loader."""
return self._loader_info['loader_id']

@property
def uncertainty_enabled(self):
"""Whethter uncertainty is enabled."""
return self._loader_info['uncertainty_enabled']

@property
def update_time(self) -> float:
"""Return latest modification timestamp of summary file."""
return self._loader_info['update_time']

@update_time.setter
def update_time(self, new_time: Union[datetime, float]):
"""
Update the update_time manually.

Args:
new_time (datetime.datetime or float): Updated time for the summary file.
"""
if isinstance(new_time, datetime):
self._loader_info['update_time'] = new_time.timestamp()
elif isinstance(new_time, float):
self._loader_info['update_time'] = new_time
else:
raise TypeError('new_time should have type of datetime.datetime or float, but receive {}'
.format(type(new_time)))

def load(self):
"""Start loading data from the latest summary file to the loader."""
filenames = []
for filename in FileHandler.list_dir(self._loader_info['summary_dir']):
if FileHandler.is_file(FileHandler.join(self._loader_info['summary_dir'], filename)):
filenames.append(filename)
filenames = ExplainLoader._filter_files(filenames)

if not filenames:
raise TrainJobNotExistError('No summary file found in %s, explain job will be deleted.'
% self._loader_info['summary_dir'])

is_end = False
while not is_end:
is_clean, is_end, event_dict = self._parser.parse_explain(filenames)

if is_clean:
logger.info('Summary file in %s updated, reloading the data in the summary.',
self._loader_info['summary_dir'])
self._clear_job()
if event_dict:
self._import_data_from_event(event_dict)

def get_all_samples(self) -> List[Dict]:
"""
Return a list of sample information cached in the explain job.

Returns:
sample_list (List[SampleObj]): a list of sample objects, each object
consists of:

- id (int): sample id
- name (str): basename of image
- labels (list[str]): list of labels
- inferences (list[dict])
"""
returned_samples = []
samples_copy = self._samples.copy()
for sample_id, sample_info in samples_copy.items():
if not sample_info.get('image', False) and not sample_info.get('ground_truth_label', False):
continue
returned_sample = {
'id': sample_id,
'name': str(sample_id),
'image': sample_info['image'],
'labels': sample_info['ground_truth_label'],
}

if not ExplainLoader._is_inference_valid(sample_info):
continue

inferences = {}
for label, prob in zip(sample_info['ground_truth_label'] + sample_info['predicted_label'],
sample_info['ground_truth_prob'] + sample_info['predicted_prob']):
inferences[label] = {
'label': self._metadata['labels'][label],
'confidence': _round(prob),
'saliency_maps': []
}

if sample_info['ground_truth_prob_sd'] or sample_info['predicted_prob_sd']:
for label, std, low, high in zip(
sample_info['ground_truth_label'] + sample_info['predicted_label'],
sample_info['ground_truth_prob_sd'] + sample_info['predicted_prob_sd'],
sample_info['ground_truth_prob_itl95_low'] + sample_info['predicted_prob_itl95_low'],
sample_info['ground_truth_prob_itl95_hi'] + sample_info['predicted_prob_itl95_hi']
):
inferences[label]['confidence_sd'] = _round(std)
inferences[label]['confidence_itl95'] = [_round(low), _round(high)]

for explainer, label_heatmap_path_dict in sample_info['explanation'].items():
for label, heatmap_path in label_heatmap_path_dict.items():
if label in inferences:
inferences[label]['saliency_maps'].append({'explainer': explainer, 'overlay': heatmap_path})

returned_sample['inferences'] = list(inferences.values())
returned_samples.append(returned_sample)
return returned_samples

def _import_data_from_event(self, event_dict: Dict):
"""Parse and import data from the event data."""
if 'metadata' not in event_dict and self._is_metadata_empty():
raise ParamValueError('metadata is incomplete, should write metadata first in the summary.')

for tag, event in event_dict.items():
if tag == ExplainFieldsEnum.METADATA.value:
self._import_metadata_from_event(event.metadata)
elif tag == ExplainFieldsEnum.BENCHMARK.value:
self._import_benchmark_from_event(event.benchmark)
elif tag == ExplainFieldsEnum.SAMPLE_ID.value:
self._import_sample_from_event(event)
else:
logger.info('Unknown ExplainField: %s', tag)

def _is_metadata_empty(self):
"""Check whether metadata is completely loaded first."""
if not self._metadata['labels']:
return True
return False

def _import_metadata_from_event(self, metadata_event):
"""Import the metadata from event into loader."""

def take_union(existed_list, imported_data):
"""Take union of existed_list and imported_data."""
if isinstance(imported_data, Iterable):
for sample in imported_data:
if sample not in existed_list:
existed_list.append(sample)

take_union(self._metadata['explainers'], metadata_event.explain_method)
take_union(self._metadata['metrics'], metadata_event.benchmark_method)
take_union(self._metadata['labels'], metadata_event.label)

def _import_benchmark_from_event(self, benchmarks):
"""
Parse the benchmark event.

Benchmark data are separated into 'explainer_score' and 'label_score'. 'explainer_score' contains overall
evaluation results of each explainer by different metrics, while 'label_score' additionally divides the results
w.r.t different labels.

The structure of self._benchmark['explainer_score'] is demonstrated below:
{
explainer_1: {metric_name_1: score_1, ...},
explainer_2: {metric_name_1: score_1, ...},
...
}

The structure of self._benchmark['label_score'] is:
{
explainer_1: {label_id: {metric_1: score_1, metric_2: score_2, ...}, ...},
explainer_2: {label_id: {metric_1: score_1, metric_2: score_2, ...}, ...},
...
}

Args:
benchmarks (benchmark_container): Parsed benchmarks data from summary file.
"""
explainer_score = self._benchmark['explainer_score']
label_score = self._benchmark['label_score']

for benchmark in benchmarks:
explainer = benchmark.explain_method
metric = benchmark.benchmark_method
metric_score = benchmark.total_score
label_score_event = benchmark.label_score

explainer_score[explainer][metric] = metric_score
new_label_score_dict = ExplainLoader._score_event_to_dict(label_score_event, metric)
for label, scores_of_metric in new_label_score_dict.items():
if label not in label_score[explainer]:
label_score[explainer][label] = {}
label_score[explainer][label].update(scores_of_metric)

def _import_sample_from_event(self, sample):
"""
Parse the sample event.

Detailed data of each sample are stored in self._samples, identified by sample_id. Each sample's data is stored
in the following structure.

- ground_truth_labels (list[int]): A list of ground truth labels of the sample.
- ground_truth_probs (list[float]): A list of confidences of ground-truth label from black-box model.
- predicted_labels (list[int]): A list of predicted labels from the black-box model.
- predicted_probs (list[float]): A list of confidences w.r.t the predicted labels.
- explanations (dict): A dictionary where each explainer name maps to a dictionary
of saliency maps. The structure of explanations is demonstrated below:

{
explainer_name_1: {label_1: saliency_id_1, label_2: saliency_id_2, ...},
explainer_name_2: {label_1: saliency_id_1, label_2: saliency_id_2, ...},
...
}
"""
if not getattr(sample, 'sample_id', None):
raise ParamValueError('sample_event has no sample_id')
sample_id = sample.sample_id
samples_copy = self._samples.copy()
if sample_id not in samples_copy:
self._samples[sample_id] = {
'ground_truth_label': [],
'ground_truth_prob': [],
'ground_truth_prob_sd': [],
'ground_truth_prob_itl95_low': [],
'ground_truth_prob_itl95_hi': [],
'predicted_label': [],
'predicted_prob': [],
'predicted_prob_sd': [],
'predicted_prob_itl95_low': [],
'predicted_prob_itl95_hi': [],
'explanation': defaultdict(dict)
}

if sample.image_path:
self._samples[sample_id]['image'] = sample.image_path

for tag in _SAMPLE_FIELD_NAMES:
try:
if ExplainLoader._is_attr_empty(sample, tag.value):
continue
if tag == ExplainFieldsEnum.GROUND_TRUTH_LABEL:
self._samples[sample_id]['ground_truth_label'].extend(list(sample.ground_truth_label))
elif tag == ExplainFieldsEnum.INFERENCE:
self._import_inference_from_event(sample, sample_id)
elif tag == ExplainFieldsEnum.EXPLANATION:
self._import_explanation_from_event(sample, sample_id)
except UnknownError as ex:
logger.warning("Parse %s data failed within image related data, detail: %r", tag, str(ex))

def _import_inference_from_event(self, event, sample_id):
"""Parse the inference event."""
inference = event.inference
self._samples[sample_id]['ground_truth_prob'].extend(list(inference.ground_truth_prob))
self._samples[sample_id]['ground_truth_prob_sd'].extend(list(inference.ground_truth_prob_sd))
self._samples[sample_id]['ground_truth_prob_itl95_low'].extend(list(inference.ground_truth_prob_itl95_low))
self._samples[sample_id]['ground_truth_prob_itl95_hi'].extend(list(inference.ground_truth_prob_itl95_hi))
self._samples[sample_id]['predicted_label'].extend(list(inference.predicted_label))
self._samples[sample_id]['predicted_prob'].extend(list(inference.predicted_prob))
self._samples[sample_id]['predicted_prob_sd'].extend(list(inference.predicted_prob_sd))
self._samples[sample_id]['predicted_prob_itl95_low'].extend(list(inference.predicted_prob_itl95_low))
self._samples[sample_id]['predicted_prob_itl95_hi'].extend(list(inference.predicted_prob_itl95_hi))

if self._samples[sample_id]['ground_truth_prob_sd'] or self._samples[sample_id]['predicted_prob_sd']:
self._loader_info['uncertainty_enabled'] = True

def _import_explanation_from_event(self, event, sample_id):
"""Parse the explanation event."""
if self._samples[sample_id]['explanation'] is None:
self._samples[sample_id]['explanation'] = defaultdict(dict)
sample_explanation = self._samples[sample_id]['explanation']

for explanation_item in event.explanation:
explainer = explanation_item.explain_method
label = explanation_item.label
sample_explanation[explainer][label] = explanation_item.heatmap_path

def _clear_job(self):
"""Clear the cached data and update the time info of the loader."""
self._samples.clear()
self._loader_info['create_time'] = os.stat(self._loader_info['summary_dir']).st_ctime
self._loader_info['update_time'] = os.stat(self._loader_info['summary_dir']).st_mtime
self._loader_info['query_time'] = max(self._loader_info['update_time'], self._loader_info['query_time'])

def clear_inner_dict(outer_dict):
"""Clear the inner structured data of the given dict."""
for item in outer_dict.values():
item.clear()

map(clear_inner_dict, [self._metadata, self._benchmark])

@staticmethod
def _filter_files(filenames):
"""
Gets a list of summary files.

Args:
filenames (list[str]): File name list, like [filename1, filename2].

Returns:
list[str], filename list.
"""
return list(filter(lambda filename: (re.search(r'summary\.\d+', filename) and filename.endswith("_explain")),
filenames))

@staticmethod
def _is_attr_empty(event, attr_name) -> bool:
if not getattr(event, attr_name):
return True
for item in getattr(event, attr_name):
if not isinstance(item, list) or item:
return False
return True

@staticmethod
def _is_ground_truth_label_valid(sample_id: str, sample_info: Dict) -> bool:
if len(sample_info['ground_truth_label']) != len(sample_info['ground_truth_prob']):
logger.info('Length of ground_truth_prob does not match the length of ground_truth_label. '
'Length of ground_truth_label is: %s but length of ground_truth_prob is: %s. '
'sample_id is: %s.',
len(sample_info['ground_truth_label']), len(sample_info['ground_truth_prob']), sample_id)
return False
return True

@staticmethod
def _is_inference_valid(sample):
"""
Check whether the inference data is empty or have the same length.

If the probs have a different length from the labels, it can be confusing when assigning each prob to a label.
'_is_inference_valid' returns True only when the data sizes match each other. Note that prob data could be
empty, so empty probs will pass the check.
"""
ground_truth_len = len(sample['ground_truth_label'])
for name in ['ground_truth_prob', 'ground_truth_prob_sd',
'ground_truth_prob_itl95_low', 'ground_truth_prob_itl95_hi']:
if sample[name] and len(sample[name]) != ground_truth_len:
return False

predicted_len = len(sample['predicted_label'])
for name in ['predicted_prob', 'predicted_prob_sd',
'predicted_prob_itl95_low', 'predicted_prob_itl95_hi']:
if sample[name] and len(sample[name]) != predicted_len:
return False
return True

@staticmethod
def _is_predicted_label_valid(sample_id: str, sample_info: Dict) -> bool:
if len(sample_info['predicted_label']) != len(sample_info['predicted_prob']):
logger.info('Length of predicted_prob does not match the length of predicted_label. '
'Length of predicted_prob: %s but length of predicted_label: %s, sample_id: %s.',
len(sample_info['predicted_prob']), len(sample_info['predicted_label']), sample_id)
return False
return True

@staticmethod
def _score_event_to_dict(label_score_event, metric) -> Dict:
"""Transfer metric scores per label to pre-defined structure."""
new_label_score_dict = defaultdict(dict)
for label_id, label_score in enumerate(label_score_event):
new_label_score_dict[label_id][metric] = label_score
return new_label_score_dict
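Taken together, ExplainLoader exposes the parsed summary through its properties and get_all_samples(). A hedged usage sketch, assuming a summary directory that contains files matching the summary.<timestamp>..._explain filter used by _filter_files; both paths below are hypothetical and not part of the commit.

# Hedged sketch: driving ExplainLoader directly outside the manager.
from mindinsight.explainer.manager.explain_loader import ExplainLoader

loader = ExplainLoader(loader_id='./demo_job', summary_dir='/tmp/summaries/demo_job')
loader.load()                       # parse the *_explain summary files into the cache

print(loader.sample_count)          # samples with both image and ground_truth_label loaded
print(loader.all_classes)           # per-label sample counts
print(loader.explainer_scores)      # benchmark results grouped by explainer and label
for sample in loader.get_all_samples():
    print(sample['id'], sample['labels'])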

+178 -255  mindinsight/explainer/manager/explain_manager.py

@@ -17,17 +17,20 @@
import os
import threading
import time
from collections import OrderedDict
from datetime import datetime
from typing import Optional

from mindinsight.conf import settings
from mindinsight.datavisual.common import exceptions
from mindinsight.datavisual.common.enums import BaseEnum
from mindinsight.explainer.common.log import logger
from mindinsight.explainer.manager.explain_job import ExplainJob
from mindinsight.explainer.manager.explain_loader import ExplainLoader
from mindinsight.datavisual.data_access.file_handler import FileHandler
from mindinsight.datavisual.data_transform.summary_watcher import SummaryWatcher
from mindinsight.utils.exceptions import MindInsightException, ParamValueError, UnknownError

_MAX_LOADER_NUM = 3
_MAX_INTERVAL = 3
_MAX_LOADERS_NUM = 3


class _ExplainManagerStatus(BaseEnum):
@@ -43,245 +46,63 @@ class ExplainManager:

def __init__(self, summary_base_dir: str):
self._summary_base_dir = summary_base_dir
self._loader_pool = {}
self._deleted_ids = []
self._status = _ExplainManagerStatus.INIT.value
self._loader_pool = OrderedDict()
self._loading_status = _ExplainManagerStatus.INIT.value
self._status_mutex = threading.Lock()
self._loader_pool_mutex = threading.Lock()
self._max_loader_num = _MAX_LOADER_NUM
self._reload_interval = None
self._max_loaders_num = _MAX_LOADERS_NUM
self._summary_watcher = SummaryWatcher()

def _reload_data(self):
"""periodically load summary from file."""
while True:
try:
self._load_data()

if not self._reload_interval:
break
time.sleep(self._reload_interval)
except UnknownError as ex:
logger.exception(ex)
logger.error('Unknown Error raise when loading summary files, status: %r, and loader pool size is %r.'
'Detail: %s', self._status, len(self._loader_pool), str(ex))
self._status = _ExplainManagerStatus.INVALID.value

def _load_data(self):
"""Loading the summary in the given base directory."""
logger.info('Start to load data, reload interval: %r.', self._reload_interval)

with self._status_mutex:
if self._status == _ExplainManagerStatus.LOADING.value:
logger.info('Current status is %s, will ignore to load data.', self._status)
return

self._status = _ExplainManagerStatus.LOADING.value

try:
self._generate_loaders()
self._execute_load_data()
except Exception as ex:
raise UnknownError(ex)

if not self._loader_pool:
self._status = _ExplainManagerStatus.INVALID.value
else:
self._status = _ExplainManagerStatus.DONE.value

logger.info('Load event data end, status: %r, and loader pool size is %r',
self._status, len(self._loader_pool))

def _update_loader_latest_update_time(self, loader_id, latest_update_time=None):
"""update the update time of loader of given id."""
if latest_update_time is None:
latest_update_time = time.time()
self._loader_pool[loader_id].latest_update_time = latest_update_time

def _delete_loader(self, loader_id):
"""delete loader given loader_id"""
if self._loader_pool.get(loader_id, None) is not None:
self._loader_pool.pop(loader_id)
logger.debug('delete loader %s', loader_id)

def _add_loader(self, loader):
"""add loader to the loader_pool."""
if len(self._loader_pool) >= _MAX_LOADER_NUM:
delete_num = len(self._loader_pool) - _MAX_LOADER_NUM + 1
sorted_loaders = sorted(
self._loader_pool.items(),
key=lambda x: x[1].latest_update_time)

for index in range(delete_num):
delete_loader_id = sorted_loaders[index][0]
self._delete_loader(delete_loader_id)
self._loader_pool.update({loader.loader_id: loader})

def _deal_loaders(self, latest_loaders):
""""update the loader pool."""
with self._loader_pool_mutex:
for loader_id, loader in latest_loaders:
if self._loader_pool.get(loader_id, None) is None:
self._add_loader(loader)
continue

if (self._loader_pool[loader_id].latest_update_time
< loader.latest_update_time):
self._update_loader_latest_update_time(
loader_id, loader.latest_update_time)

@staticmethod
def _generate_loader_id(relative_path):
"""Generate loader id for given path"""
loader_id = relative_path
return loader_id

@staticmethod
def _generate_loader_name(relative_path):
"""Generate_loader name for given path."""
loader_name = relative_path
return loader_name

def _generate_loader_by_relative_path(self, relative_path: str) -> ExplainJob:
"""Generate explain job from given relative path."""
current_dir = os.path.realpath(FileHandler.join(
self._summary_base_dir, relative_path
))
loader_id = self._generate_loader_id(relative_path)
loader = ExplainJob(
job_id=loader_id,
summary_dir=current_dir,
create_time=ExplainJob.get_create_time(current_dir),
latest_update_time=ExplainJob.get_update_time(current_dir))
return loader

def _generate_loaders(self):
"""Generate job loaders from the summary watcher."""
dir_map_mtime_dict = {}
loader_dict = {}
min_modify_time = None
_, summaries = SummaryWatcher().list_explain_directories(
self._summary_base_dir)

for item in summaries:
relative_path = item.get('relative_path')
modify_time = item.get('update_time').timestamp()
loader_id = self._generate_loader_id(relative_path)

loader = self._loader_pool.get(loader_id, None)
if loader is not None and loader.latest_update_time > modify_time:
modify_time = loader.latest_update_time

if min_modify_time is None:
min_modify_time = modify_time

if len(dir_map_mtime_dict) < _MAX_LOADER_NUM:
if modify_time < min_modify_time:
min_modify_time = modify_time
dir_map_mtime_dict.update({relative_path: modify_time})
else:
if modify_time >= min_modify_time:
dir_map_mtime_dict.update({relative_path: modify_time})

sorted_dir_tuple = sorted(dir_map_mtime_dict.items(),
key=lambda d: d[1])[-_MAX_LOADER_NUM:]

for relative_path, modify_time in sorted_dir_tuple:
loader_id = self._generate_loader_id(relative_path)
loader = self._generate_loader_by_relative_path(relative_path)
loader_dict.update({loader_id: loader})

sorted_loaders = sorted(loader_dict.items(),
key=lambda x: x[1].latest_update_time)
latest_loaders = sorted_loaders[-_MAX_LOADER_NUM:]
self._deal_loaders(latest_loaders)

def _execute_loader(self, loader_id):
"""Execute the data loading."""
try:
with self._loader_pool_mutex:
loader = self._loader_pool.get(loader_id, None)
if loader is None:
logger.debug('Loader %r has been deleted, will not load'
'data', loader_id)
return
loader.load()
@property
def summary_base_dir(self):
"""Return the base directory for summary records."""
return self._summary_base_dir

except MindInsightException as ex:
logger.warning('Data loader %r load data failed. Delete data_loader. Detail: %s', loader_id, ex)
with self._loader_pool_mutex:
self._delete_loader(loader_id)
def start_load_data(self, reload_interval: int = 0):
"""
Start an individual thread to cache explain_jobs and load summary data periodically.

def _execute_load_data(self):
"""Execute the loader in the pool to load data."""
loader_pool = self._get_snapshot_loader_pool()
for loader_id in loader_pool:
self._execute_loader(loader_id)
Args:
reload_interval (int): Specify the loading period in seconds. If interval == 0, data will only be loaded
once. Default: 0.
"""
thread = threading.Thread(target=self._repeat_loading,
name='start_load_thread',
args=(reload_interval,),
daemon=True)
time.sleep(1)
thread.start()

def _get_snapshot_loader_pool(self):
"""Get snapshot of loader_pool."""
with self._loader_pool_mutex:
return dict(self._loader_pool)
def get_job(self, loader_id: str) -> Optional[ExplainLoader]:
"""
Return ExplainLoader given loader_id.

def _check_status_valid(self):
"""Check manager status."""
if self._status == _ExplainManagerStatus.INIT.value:
raise exceptions.SummaryLogIsLoading('Data is loading, current status is %s' % self._status)
If explain job w.r.t given loader_id is not found, None will be returned.

@staticmethod
def _check_train_id_valid(train_id: str):
"""Verify the train_id is valid."""
if not train_id.startswith('./'):
logger.warning('train_id does not start with "./"')
return False

if len(train_id.split('/')) > 2:
logger.warning('train_id contains multiple "/"')
return False
return True

def _check_train_job_exist(self, train_id):
"""Verify thee train_job is existed given train_id."""
if train_id in self._loader_pool:
return
self._check_train_id_valid(train_id)
if SummaryWatcher().is_summary_directory(self._summary_base_dir, train_id):
return
raise ParamValueError('Can not find the train job in the manager, train_id: %s' % train_id)
Args:
loader_id (str): The id of expected ExplainLoader

def _reload_data_again(self):
"""Reload the data one more time."""
logger.debug('Start to reload data again.')
thread = threading.Thread(target=self._load_data,
name='reload_data_thread')
thread.daemon = False
thread.start()
Return:
explain_job
"""
self._check_status_valid()

def _get_job(self, train_id):
"""Retrieve train_job given train_id."""
is_reload = False
with self._loader_pool_mutex:
loader = self._loader_pool.get(train_id, None)

if loader is None:
relative_path = train_id
temp_loader = self._generate_loader_by_relative_path(
relative_path)
if loader_id in self._loader_pool:
self._loader_pool[loader_id].query_time = datetime.now().timestamp()
self._loader_pool.move_to_end(loader_id, last=False)
return self._loader_pool[loader_id]

if temp_loader is None:
return None
self._add_loader(temp_loader)
is_reload = True
if is_reload:
self._reload_data_again()
try:
loader = self._generate_loader_from_relative_path(loader_id)
loader.query_time = datetime.now().timestamp()
self._add_loader(loader)
self._reload_data_again()
except ParamValueError:
logger.warning('Cannot find summary in path: %s. No explain_job will be returned.', loader_id)
return None
return loader

@property
def summary_base_dir(self):
"""Return the base directory for summary records."""
return self._summary_base_dir

def get_job_list(self, offset=0, limit=None):
"""
Return a list of explain jobs, including job ID, create time and update time.
@@ -298,44 +119,146 @@ class ExplainManager:
- create_time (datetime): Creation time of summary file.
- update_time (datetime): Modification time of summary file.
"""
watcher = SummaryWatcher()
total, dir_infos = \
watcher.list_explain_directories(self._summary_base_dir,
offset=offset, limit=limit)
self._summary_watcher.list_explain_directories(self._summary_base_dir, offset=offset, limit=limit)
return total, dir_infos

def get_job(self, train_id):
def _repeat_loading(self, repeat_interval):
"""Periodically loading summary."""
while True:
try:
logger.info('Start to load data, repeat interval: %r.', repeat_interval)
self._load_data()
if not repeat_interval:
return
time.sleep(repeat_interval)
except UnknownError as ex:
logger.exception(ex)
logger.error('Unexpected error happened when loading data. Loading status: %s, loading pool size: %d. '
'Detail: %s', self._loading_status, len(self._loader_pool), str(ex))

def _load_data(self):
"""
Prepare loaders in cache and start loading the data from summaries.

Only a limited number of loaders will be cached, in terms of update_time or query_time. The size of the cache
pool is determined by _MAX_LOADERS_NUM. When the manager starts loading data, only the latest _MAX_LOADERS_NUM
summaries will be loaded into the cache. If a cached loader is queried by 'get_job', the query_time of the loader
will be updated and the loader will be moved to the end of the cache. If an uncached summary is queried,
a new loader instance will be generated and put at the end of the cache.
"""
Return ExplainJob given train_id.
try:
with self._status_mutex:
if self._loading_status == _ExplainManagerStatus.LOADING.value:
logger.info('Current status is %s, will ignore to load data.', self._loading_status)
return

If explain job w.r.t given train_id is not found, None will be returned.
self._loading_status = _ExplainManagerStatus.LOADING.value

Args:
train_id (str): The id of expected ExplainJob
self._cache_loaders()
self._execute_loading()

Return:
explain_job
"""
self._check_status_valid()
self._check_train_job_exist(train_id)
if not self._loader_pool:
self._loading_status = _ExplainManagerStatus.INVALID.value
else:
self._loading_status = _ExplainManagerStatus.DONE.value

logger.info('Load event data end, status: %s, and loader pool size: %d',
self._loading_status, len(self._loader_pool))

except Exception as ex:
self._loading_status = _ExplainManagerStatus.INVALID.value
logger.exception(ex)
raise UnknownError(str(ex))

def _cache_loaders(self):
"""Cache explain loader in cache pool."""
dir_map_mtime_dict = []
_, summaries_info = self._summary_watcher.list_explain_directories(self._summary_base_dir)

for summary_info in summaries_info:
summary_path = summary_info.get('relative_path')
summary_update_time = summary_info.get('update_time').timestamp()

if summary_path in self._loader_pool:
summary_update_time = max(summary_update_time, self._loader_pool[summary_path].query_time)

dir_map_mtime_dict.append((summary_info, summary_update_time))

sorted_summaries_info = sorted(dir_map_mtime_dict, key=lambda x: x[1])[-_MAX_LOADERS_NUM:]

loader = self._get_job(train_id)
if loader is None:
return None
with self._loader_pool_mutex:
for summary_info, query_time in sorted_summaries_info:
summary_path = summary_info['relative_path']
if summary_path not in self._loader_pool:
loader = self._generate_loader_from_relative_path(summary_path)
self._add_loader(loader)
else:
self._loader_pool[summary_path].query_time = query_time
self._loader_pool.move_to_end(summary_path, last=False)

def _generate_loader_from_relative_path(self, relative_path: str) -> ExplainLoader:
"""Generate explain loader from the given relative path."""
self._check_summary_exist(relative_path)
current_dir = os.path.realpath(FileHandler.join(self._summary_base_dir, relative_path))
loader_id = self._generate_loader_id(relative_path)
loader = ExplainLoader(loader_id=loader_id, summary_dir=current_dir)
return loader

def start_load_data(self, reload_interval=_MAX_INTERVAL):
"""
Start threads for loading data.

Args:
reload_interval (int): interval to reload the summary from file
"""
self._reload_interval = reload_interval

thread = threading.Thread(target=self._reload_data, name='start_load_data_thread')
thread.daemon = True

# wait for data loading
time.sleep(1)

def _add_loader(self, loader):
"""Add the given loader to the loader_pool."""
if loader.train_id not in self._loader_pool:
self._loader_pool[loader.train_id] = loader
else:
self._loader_pool.move_to_end(loader.train_id)

while len(self._loader_pool) > self._max_loaders_num:
self._loader_pool.popitem(last=False)

def _execute_loading(self):
"""Execute the data loading."""
for loader_id in list(self._loader_pool.keys()):
try:
with self._loader_pool_mutex:
loader = self._loader_pool.get(loader_id, None)
if loader is None:
logger.debug('Loader %r has been deleted, will not load data', loader_id)
return
loader.load()

except MindInsightException as ex:
logger.warning('Data loader %r failed to load data and will be deleted. Detail: %s', loader_id, ex)
with self._loader_pool_mutex:
self._delete_loader(loader_id)

def _delete_loader(self, loader_id):
"""delete loader given loader_id"""
if loader_id in self._loader_pool:
self._loader_pool.pop(loader_id)
logger.debug('delete loader %s', loader_id)

def _check_status_valid(self):
"""Check manager status."""
if self._loading_status == _ExplainManagerStatus.INIT.value:
raise exceptions.SummaryLogIsLoading('Data is loading, current status is %s' % self._loading_status)

def _check_summary_exist(self, loader_id):
"""Verify thee train_job is existed given loader_id."""
if not self._summary_watcher.is_summary_directory(self._summary_base_dir, loader_id):
raise ParamValueError('Can not find the train job in the manager.')

def _reload_data_again(self):
"""Reload the data one more time."""
logger.debug('Start to reload data again.')
thread = threading.Thread(target=self._load_data, name='reload_data_thread')
thread.daemon = False
thread.start()

@staticmethod
def _generate_loader_id(relative_path):
"""Generate loader id for given path"""
loader_id = relative_path
return loader_id


EXPLAIN_MANAGER = ExplainManager(summary_base_dir=settings.SUMMARY_BASE_DIR)
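
EXPLAIN_MANAGER is a module-level singleton, so every importer shares a single loader pool. A hedged usage sketch, assuming the offset/limit signature of get_job_list shown above; the concrete values are illustrative:

from mindinsight.explainer.manager.explain_manager import EXPLAIN_MANAGER

# Query the first page of explain job directories from the shared instance.
total, dir_infos = EXPLAIN_MANAGER.get_job_list(offset=0, limit=10)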

+ 51
- 56
mindinsight/explainer/manager/explain_parser.py View File

@@ -17,49 +17,41 @@ File parser for MindExplain data.

This module is used to parse the MindExplain log file.
"""
import re
import collections
from collections import namedtuple

from google.protobuf.message import DecodeError

from mindinsight.datavisual.common import exceptions
from mindinsight.explainer.common.enums import PluginNameEnum
from mindinsight.explainer.common.enums import ExplainFieldsEnum
from mindinsight.explainer.common.log import logger
from mindinsight.datavisual.data_access.file_handler import FileHandler
from mindinsight.datavisual.data_transform.ms_data_loader import _SummaryParser
from mindinsight.datavisual.proto_files import mindinsight_summary_pb2 as summary_pb2
from mindinsight.datavisual.proto_files.mindinsight_summary_pb2 import Explain
from mindinsight.utils.exceptions import UnknownError

HEADER_SIZE = 8
CRC_STR_SIZE = 4
MAX_EVENT_STRING = 500000000
BenchmarkContainer = collections.namedtuple('BenchmarkContainer', ['benchmark', 'status'])
MetadataContainer = collections.namedtuple('MetadataContainer', ['metadata', 'status'])


class ImageDataContainer:
"""
Container for image data to allow pickling.

Args:
explain_message (Explain): Explain proto buffer message.
"""

def __init__(self, explain_message: Explain):
self.sample_id = explain_message.sample_id
self.image_path = explain_message.image_path
self.ground_truth_label = explain_message.ground_truth_label
self.inference = explain_message.inference
self.explanation = explain_message.explanation
self.status = explain_message.status


BenchmarkContainer = namedtuple('BenchmarkContainer', ['benchmark', 'status'])
MetadataContainer = namedtuple('MetadataContainer', ['metadata', 'status'])
InferfenceContainer = namedtuple('InferenceContainer', ['ground_truth_prob',
'ground_truth_prob_sd',
'ground_truth_prob_itl95_low',
'ground_truth_prob_itl95_hi',
'predicted_label',
'predicted_prob',
'predicted_prob_sd',
'predicted_prob_itl95_low',
'predicted_prob_itl95_hi'])
SampleContainer = namedtuple('SampleContainer', ['sample_id', 'image_path', 'ground_truth_label', 'inference',
'explanation', 'status'])
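
The removed ImageDataContainer below existed to keep the parsed sample data picklable; module-level namedtuples give the same property for free. A quick illustrative check (the field values are made up):

import pickle

sample = SampleContainer(sample_id=1, image_path='1.jpg', ground_truth_label=[2],
                         inference=None, explanation=[], status=0)
assert pickle.loads(pickle.dumps(sample)) == sample  # namedtuples round-trip through pickle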


class _ExplainParser(_SummaryParser):
class ExplainParser(_SummaryParser):
"""The summary file parser."""

def __init__(self, summary_dir):
super(_ExplainParser, self).__init__(summary_dir)
super(ExplainParser, self).__init__(summary_dir)
self._latest_filename = ''

def parse_explain(self, filenames):
@@ -71,8 +63,7 @@ class _ExplainParser(_SummaryParser):
Returns:
bool, True if all the summary files are finished loading.
"""
summary_files = self.filter_files(filenames)
summary_files = self.sort_files(summary_files)
summary_files = self.sort_files(filenames)

is_end = False
is_clean = False
@@ -125,20 +116,6 @@ class _ExplainParser(_SummaryParser):
logger.exception(ex)
raise UnknownError(str(ex))

def filter_files(self, filenames):
"""
Gets a list of summary files.

Args:
filenames (list[str]): File name list, like [filename1, filename2].

Returns:
list[str], filename list.
"""
return list(filter(
lambda filename: (re.search(r'summary\.\d+', filename)
and filename.endswith("_explain")), filenames))
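
For reference, the regex in the removed filter_files above kept only explain summary files: the name must contain a summary.<digits> segment and end with _explain; after this change the file names are passed straight to sort_files, presumably already filtered by the caller. A small illustration of the old condition (the file names are hypothetical):

import re

def is_explain_summary(filename):
    # Same condition as the removed filter_files above.
    return bool(re.search(r'summary\.\d+', filename)) and filename.endswith('_explain')

print(is_explain_summary('events.summary.1602831234.host0_explain'))  # True
print(is_explain_summary('events.summary.1602831234.host0'))          # False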

@staticmethod
def _event_decode(event_str):
"""
@@ -153,9 +130,9 @@ class _ExplainParser(_SummaryParser):
logger.debug("Deserialize event string completed.")

fields = {
'sample_id': PluginNameEnum.SAMPLE_ID,
'benchmark': PluginNameEnum.BENCHMARK,
'metadata': PluginNameEnum.METADATA
'sample_id': ExplainFieldsEnum.SAMPLE_ID,
'benchmark': ExplainFieldsEnum.BENCHMARK,
'metadata': ExplainFieldsEnum.METADATA
}

tensor_event_value = getattr(event, 'explain')
@@ -163,19 +140,19 @@ class _ExplainParser(_SummaryParser):
field_list = []
tensor_value_list = []
for field in fields:
if not getattr(tensor_event_value, field):
if not getattr(tensor_event_value, field, False):
continue

if PluginNameEnum.METADATA.value == field and not tensor_event_value.metadata.label:
if ExplainFieldsEnum.METADATA.value == field and not tensor_event_value.metadata.label:
continue

tensor_value = None
if field == PluginNameEnum.SAMPLE_ID.value:
tensor_value = _ExplainParser._add_image_data(tensor_event_value)
elif field == PluginNameEnum.BENCHMARK.value:
tensor_value = _ExplainParser._add_benchmark(tensor_event_value)
elif field == PluginNameEnum.METADATA.value:
tensor_value = _ExplainParser._add_metadata(tensor_event_value)
if field == ExplainFieldsEnum.SAMPLE_ID.value:
tensor_value = ExplainParser._add_image_data(tensor_event_value)
elif field == ExplainFieldsEnum.BENCHMARK.value:
tensor_value = ExplainParser._add_benchmark(tensor_event_value)
elif field == ExplainFieldsEnum.METADATA.value:
tensor_value = ExplainParser._add_metadata(tensor_event_value)
logger.debug("Event generated, label is %s, step is %s.", field, event.step)
field_list.append(field)
tensor_value_list.append(tensor_value)
@@ -189,8 +166,26 @@ class _ExplainParser(_SummaryParser):
Args:
tensor_event_value: the object of Explain message
"""
image_data = ImageDataContainer(tensor_event_value)
return image_data
inference = InferfenceContainer(
ground_truth_prob=tensor_event_value.inference.ground_truth_prob,
ground_truth_prob_sd=tensor_event_value.inference.ground_truth_prob_sd,
ground_truth_prob_itl95_low=tensor_event_value.inference.ground_truth_prob_itl95_low,
ground_truth_prob_itl95_hi=tensor_event_value.inference.ground_truth_prob_itl95_hi,
predicted_label=tensor_event_value.inference.predicted_label,
predicted_prob=tensor_event_value.inference.predicted_prob,
predicted_prob_sd=tensor_event_value.inference.predicted_prob_sd,
predicted_prob_itl95_low=tensor_event_value.inference.predicted_prob_itl95_low,
predicted_prob_itl95_hi=tensor_event_value.inference.predicted_prob_itl95_hi
)
sample_data = SampleContainer(
sample_id=tensor_event_value.sample_id,
image_path=tensor_event_value.image_path,
ground_truth_label=tensor_event_value.ground_truth_label,
inference=inference,
explanation=tensor_event_value.explanation,
status=tensor_event_value.status
)
return sample_data

@staticmethod
def _add_benchmark(tensor_event_value):


+ 1
- 1
tests/ut/explainer/encapsulator/mock_explain_manager.py View File

@@ -25,7 +25,7 @@ class MockExplainJob:
self.create_time = datetime.timestamp(
datetime.strptime("2020-10-01 20:21:23",
ExplainJobEncap.DATETIME_FORMAT))
self.latest_update_time = self.create_time
self.update_time = self.create_time
self.sample_count = 1999
self.min_confidence = 0.5
self.explainers = ["Gradient"]

