diff --git a/mindinsight/backend/explainer/__init__.py b/mindinsight/backend/explainer/__init__.py index 565ef09c..9c07fef5 100644 --- a/mindinsight/backend/explainer/__init__.py +++ b/mindinsight/backend/explainer/__init__.py @@ -13,7 +13,9 @@ # limitations under the License. # ============================================================================ """Module init file.""" +from mindinsight.conf import settings from mindinsight.backend.explainer.explainer_api import init_module as init_query_module +from mindinsight.explainer.manager.explain_manager import EXPLAIN_MANAGER def init_module(app): @@ -27,3 +29,4 @@ def init_module(app): """ init_query_module(app) + EXPLAIN_MANAGER.start_load_data(reload_interval=settings.RELOAD_INTERVAL) diff --git a/mindinsight/backend/explainer/explainer_api.py b/mindinsight/backend/explainer/explainer_api.py index 8a3eb6fa..f17773be 100644 --- a/mindinsight/backend/explainer/explainer_api.py +++ b/mindinsight/backend/explainer/explainer_api.py @@ -29,32 +29,16 @@ from mindinsight.datavisual.common.exceptions import ImageNotExistError from mindinsight.datavisual.common.validation import Validation from mindinsight.datavisual.data_transform.summary_watcher import SummaryWatcher from mindinsight.datavisual.utils.tools import get_train_id -from mindinsight.explainer.manager.explain_manager import ExplainManager +from mindinsight.explainer.manager.explain_manager import EXPLAIN_MANAGER from mindinsight.explainer.encapsulator.explain_job_encap import ExplainJobEncap from mindinsight.explainer.encapsulator.datafile_encap import DatafileEncap from mindinsight.explainer.encapsulator.saliency_encap import SaliencyEncap from mindinsight.explainer.encapsulator.evaluation_encap import EvaluationEncap - -URL_PREFIX = settings.URL_PATH_PREFIX+settings.API_PREFIX +URL_PREFIX = settings.URL_PATH_PREFIX + settings.API_PREFIX BLUEPRINT = Blueprint("explainer", __name__, url_prefix=URL_PREFIX) -class ExplainManagerHolder: - """ExplainManger instance holder.""" - - static_instance = None - - @classmethod - def get_instance(cls): - return cls.static_instance - - @classmethod - def initialize(cls): - cls.static_instance = ExplainManager(settings.SUMMARY_BASE_DIR) - cls.static_instance.start_load_data() - - def _image_url_formatter(train_id, image_path, image_type): """Returns image url.""" data = { @@ -91,7 +75,7 @@ def query_explain_jobs(): offset = Validation.check_offset(offset=offset) limit = Validation.check_limit(limit, min_value=1, max_value=SummaryWatcher.MAX_SUMMARY_DIR_COUNT) - encapsulator = ExplainJobEncap(ExplainManagerHolder.get_instance()) + encapsulator = ExplainJobEncap(EXPLAIN_MANAGER) total, jobs = encapsulator.query_explain_jobs(offset, limit) return jsonify({ @@ -107,7 +91,7 @@ def query_explain_job(): train_id = get_train_id(request) if train_id is None: raise ParamMissError("train_id") - encapsulator = ExplainJobEncap(ExplainManagerHolder.get_instance()) + encapsulator = ExplainJobEncap(EXPLAIN_MANAGER) metadata = encapsulator.query_meta(train_id) return jsonify(metadata) @@ -139,7 +123,7 @@ def query_saliency(): encapsulator = SaliencyEncap( _image_url_formatter, - ExplainManagerHolder.get_instance()) + EXPLAIN_MANAGER) count, samples = encapsulator.query_saliency_maps(train_id=train_id, labels=labels, explainers=explainers, @@ -160,7 +144,7 @@ def query_evaluation(): train_id = get_train_id(request) if train_id is None: raise ParamMissError("train_id") - encapsulator = EvaluationEncap(ExplainManagerHolder.get_instance()) + encapsulator = 
EvaluationEncap(EXPLAIN_MANAGER) scores = encapsulator.query_explainer_scores(train_id) return jsonify({ "explainer_scores": scores, @@ -182,7 +166,7 @@ def query_image(): if image_type not in ("original", "overlay"): raise ParamValueError(f"type:{image_type}, valid options: 'original' 'overlay'") - encapsulator = DatafileEncap(ExplainManagerHolder.get_instance()) + encapsulator = DatafileEncap(EXPLAIN_MANAGER) image = encapsulator.query_image_binary(train_id, image_path, image_type) if image is None: raise ImageNotExistError(f"{image_path}") @@ -198,5 +182,4 @@ def init_module(app): app: the application obj. """ - ExplainManagerHolder.initialize() app.register_blueprint(BLUEPRINT) diff --git a/mindinsight/explainer/common/enums.py b/mindinsight/explainer/common/enums.py index 48d490f5..6afc8efb 100644 --- a/mindinsight/explainer/common/enums.py +++ b/mindinsight/explainer/common/enums.py @@ -31,7 +31,7 @@ class DataManagerStatus(BaseEnum): INVALID = 'INVALID' -class PluginNameEnum(BaseEnum): +class ExplainFieldsEnum(BaseEnum): """Plugin Name Enum.""" EXPLAIN = 'explain' SAMPLE_ID = 'sample_id' diff --git a/mindinsight/explainer/encapsulator/explain_job_encap.py b/mindinsight/explainer/encapsulator/explain_job_encap.py index 794b9b35..0b2bd206 100644 --- a/mindinsight/explainer/encapsulator/explain_job_encap.py +++ b/mindinsight/explainer/encapsulator/explain_job_encap.py @@ -70,7 +70,7 @@ class ExplainJobEncap(ExplainDataEncap): info["train_id"] = job.train_id info["create_time"] = datetime.fromtimestamp(job.create_time)\ .strftime(cls.DATETIME_FORMAT) - info["update_time"] = datetime.fromtimestamp(job.latest_update_time)\ + info["update_time"] = datetime.fromtimestamp(job.update_time)\ .strftime(cls.DATETIME_FORMAT) return info diff --git a/mindinsight/explainer/manager/event_parse.py b/mindinsight/explainer/manager/event_parse.py deleted file mode 100644 index 51321351..00000000 --- a/mindinsight/explainer/manager/event_parse.py +++ /dev/null @@ -1,168 +0,0 @@ -# Copyright 2020 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
-# ============================================================================ -"""EventParser for summary event.""" -from collections import namedtuple, defaultdict -from typing import Dict, List, Optional, Tuple - -from mindinsight.explainer.common.enums import PluginNameEnum -from mindinsight.explainer.common.log import logger -from mindinsight.utils.exceptions import UnknownError - -_IMAGE_DATA_TAGS = { - 'sample_id': PluginNameEnum.SAMPLE_ID.value, - 'ground_truth_label': PluginNameEnum.GROUND_TRUTH_LABEL.value, - 'inference': PluginNameEnum.INFERENCE.value, - 'explanation': PluginNameEnum.EXPLANATION.value -} - -_NUM_DIGIT = 7 - - -class EventParser: - """Parser for event data.""" - - def __init__(self, job): - self._job = job - self._sample_pool = {} - - @staticmethod - def parse_metadata(metadata) -> Tuple[List, List, List]: - """Parse the metadata event.""" - explainers = list(metadata.explain_method) - metrics = list(metadata.benchmark_method) - labels = list(metadata.label) - return explainers, metrics, labels - - @staticmethod - def parse_benchmark(benchmarks) -> Tuple[Dict, Dict]: - """Parse the benchmark event.""" - explainer_score_dict = defaultdict(list) - label_score_dict = defaultdict(dict) - - for benchmark in benchmarks: - explainer = benchmark.explain_method - metric = benchmark.benchmark_method - metric_score = benchmark.total_score - label_score_event = benchmark.label_score - - explainer_score_dict[explainer].append({ - 'metric': metric, - 'score': round(metric_score, _NUM_DIGIT)}) - new_label_score_dict = EventParser._score_event_to_dict(label_score_event, metric) - for label, label_scores in new_label_score_dict.items(): - label_score_dict[explainer][label] = label_score_dict[explainer].get(label, []) + label_scores - - return explainer_score_dict, label_score_dict - - def parse_sample(self, sample: namedtuple) -> Optional[namedtuple]: - """Parse the sample event.""" - sample_id = sample.sample_id - - if sample_id not in self._sample_pool: - self._sample_pool[sample_id] = sample - return None - - for tag in _IMAGE_DATA_TAGS: - try: - if tag == PluginNameEnum.INFERENCE.value: - self._parse_inference(sample, sample_id) - elif tag == PluginNameEnum.EXPLANATION.value: - self._parse_explanation(sample, sample_id) - else: - self._parse_sample_info(sample, sample_id, tag) - except UnknownError as ex: - logger.warning("Parse %s data failed within image related data, detail: %r", tag, str(ex)) - continue - - if EventParser._is_ready_for_display(self._sample_pool[sample_id]): - return self._sample_pool[sample_id] - return None - - def clear(self): - """Clear the loaded data.""" - self._sample_pool.clear() - - @staticmethod - def _is_ready_for_display(image_container: namedtuple) -> bool: - """ - Check whether the image_container is ready for frontend display. - - Args: - image_container (namedtuple): container consists of sample data - - Return: - bool: whether the image_container if ready for display - """ - required_attrs = ['image_path', 'ground_truth_label', 'inference'] - for attr in required_attrs: - if not EventParser.is_attr_ready(image_container, attr): - return False - return True - - @staticmethod - def is_attr_ready(image_container: namedtuple, attr: str) -> bool: - """ - Check whether the given attribute is ready in image_container. 
- - Args: - image_container (namedtuple): container consist of sample data - attr (str): attribute to check - - Returns: - bool, whether the attr is ready - """ - if getattr(image_container, attr, False): - return True - return False - - @staticmethod - def _score_event_to_dict(label_score_event, metric): - """Transfer metric scores per label to pre-defined structure.""" - new_label_score_dict = defaultdict(list) - for label_id, label_score in enumerate(label_score_event): - new_label_score_dict[label_id].append({ - 'metric': metric, - 'score': round(label_score, _NUM_DIGIT), - }) - return new_label_score_dict - - def _parse_inference(self, event, sample_id): - """Parse the inference event.""" - self._sample_pool[sample_id].inference.ground_truth_prob.extend(event.inference.ground_truth_prob) - self._sample_pool[sample_id].inference.ground_truth_prob_sd.extend(event.inference.ground_truth_prob_sd) - self._sample_pool[sample_id].inference.ground_truth_prob_itl95_low.\ - extend(event.inference.ground_truth_prob_itl95_low) - self._sample_pool[sample_id].inference.ground_truth_prob_itl95_hi.\ - extend(event.inference.ground_truth_prob_itl95_hi) - - self._sample_pool[sample_id].inference.predicted_label.extend(event.inference.predicted_label) - self._sample_pool[sample_id].inference.predicted_prob.extend(event.inference.predicted_prob) - self._sample_pool[sample_id].inference.predicted_prob_sd.extend(event.inference.predicted_prob_sd) - self._sample_pool[sample_id].inference.predicted_prob_itl95_low.extend(event.inference.predicted_prob_itl95_low) - self._sample_pool[sample_id].inference.predicted_prob_itl95_hi.extend(event.inference.predicted_prob_itl95_hi) - - def _parse_explanation(self, event, sample_id): - """Parse the explanation event.""" - if event.explanation: - for explanation_item in event.explanation: - new_explanation = self._sample_pool[sample_id].explanation.add() - new_explanation.explain_method = explanation_item.explain_method - new_explanation.label = explanation_item.label - new_explanation.heatmap_path = explanation_item.heatmap_path - - def _parse_sample_info(self, event, sample_id, tag): - """Parse the event containing image info.""" - if not getattr(self._sample_pool[sample_id], tag): - setattr(self._sample_pool[sample_id], tag, getattr(event, tag)) diff --git a/mindinsight/explainer/manager/explain_job.py b/mindinsight/explainer/manager/explain_job.py deleted file mode 100644 index 3b5a96a2..00000000 --- a/mindinsight/explainer/manager/explain_job.py +++ /dev/null @@ -1,398 +0,0 @@ -# Copyright 2020 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
-# ============================================================================ -"""ExplainJob.""" - -import os -from collections import defaultdict -from datetime import datetime -from typing import Union - -from mindinsight.explainer.common.enums import PluginNameEnum -from mindinsight.explainer.common.log import logger -from mindinsight.explainer.manager.explain_parser import _ExplainParser -from mindinsight.explainer.manager.event_parse import EventParser -from mindinsight.datavisual.data_access.file_handler import FileHandler -from mindinsight.datavisual.common.exceptions import TrainJobNotExistError - -_NUM_DIGIT = 7 - - -class ExplainJob: - """ExplainJob which manage the record in the summary file.""" - - def __init__(self, - job_id: str, - summary_dir: str, - create_time: float, - latest_update_time: float): - - self._job_id = job_id - self._summary_dir = summary_dir - self._parser = _ExplainParser(summary_dir) - - self._event_parser = EventParser(self) - self._latest_update_time = latest_update_time - self._create_time = create_time - self._uncertainty_enabled = False - self._labels = [] - self._metrics = [] - self._explainers = [] - self._samples_info = {} - self._labels_info = {} - self._explainer_score_dict = defaultdict(list) - self._label_score_dict = defaultdict(dict) - - @property - def all_classes(self): - """ - Return a list of label info - - Returns: - class_objs (List[ClassObj]): a list of class_objects, each object - contains: - - - id (int): label id - - label (str): label name - - sample_count (int): number of samples for each label - """ - all_classes_return = [] - for label_id, label_info in self._labels_info.items(): - single_info = { - 'id': label_id, - 'label': label_info['label'], - 'sample_count': len(label_info['sample_ids'])} - all_classes_return.append(single_info) - return all_classes_return - - @property - def explainers(self): - """ - Return a list of explainer names - - Returns: - list(str), explainer names - """ - return self._explainers - - @property - def explainer_scores(self): - """Return evaluation results for every explainer.""" - merged_scores = [] - for explainer, explainer_score_on_metric in self._explainer_score_dict.items(): - label_scores = [] - for label, label_score_on_metric in self._label_score_dict[explainer].items(): - score_single_label = { - 'label': self._labels[label], - 'evaluations': label_score_on_metric, - } - label_scores.append(score_single_label) - merged_scores.append({ - 'explainer': explainer, - 'evaluations': explainer_score_on_metric, - 'class_scores': label_scores, - }) - return merged_scores - - @property - def sample_count(self): - """ - Return total number of samples in the job. 
- - Return: - int, total number of samples - - """ - return len(self._samples_info) - - @property - def train_id(self): - """ - Return ID of explain job - - Returns: - str, id of ExplainJob object - """ - return self._job_id - - @property - def metrics(self): - """ - Return a list of metric names - - Returns: - list(str), metric names - """ - return self._metrics - - @property - def min_confidence(self): - """ - Return minimum confidence - - Returns: - min_confidence (float): - """ - return None - - @property - def uncertainty_enabled(self): - return self._uncertainty_enabled - - @property - def create_time(self): - """ - Return the create time of summary file - - Returns: - creation timestamp (float) - - """ - return self._create_time - - @property - def labels(self): - """Return the label contained in the job.""" - return self._labels - - @property - def latest_update_time(self): - """ - Return last modification time stamp of summary file. - - Returns: - float, last_modification_time stamp - """ - return self._latest_update_time - - @latest_update_time.setter - def latest_update_time(self, new_time: Union[float, datetime]): - """ - Update the latest_update_time timestamp manually. - - Args: - new_time stamp (union[float, datetime]): updated time for the job - """ - if isinstance(new_time, datetime): - self._latest_update_time = new_time.timestamp() - elif isinstance(new_time, float): - self._latest_update_time = new_time - else: - raise TypeError('new_time should have type of float or datetime') - - @property - def loader_id(self): - """Return the job id.""" - return self._job_id - - @property - def samples(self): - """Return the information of all samples in the job.""" - return self._samples_info - - @staticmethod - def get_create_time(file_path: str) -> float: - """Return timestamp of create time of specific path.""" - create_time = os.stat(file_path).st_ctime - return create_time - - @staticmethod - def get_update_time(file_path: str) -> float: - """Return timestamp of update time of specific path.""" - update_time = os.stat(file_path).st_mtime - return update_time - - def _initialize_labels_info(self): - """Initialize a dict for labels in the job.""" - if self._labels is None: - logger.warning('No labels is provided in job %s', self._job_id) - return - - for label_id, label in enumerate(self._labels): - self._labels_info[label_id] = {'label': label, - 'sample_ids': set()} - - def _explanation_to_dict(self, explanation): - """Transfer the explanation from event to dict storage.""" - explain_info = { - 'explainer': explanation.explain_method, - 'overlay': explanation.heatmap_path, - } - return explain_info - - def _image_container_to_dict(self, sample_data): - """Transfer the image container to dict storage.""" - has_uncertainty = False - sample_id = sample_data.sample_id - - sample_info = { - 'id': sample_id, - 'image': sample_data.image_path, - 'name': str(sample_id), - 'labels': [self._labels_info[x]['label'] - for x in sample_data.ground_truth_label], - 'inferences': []} - - ground_truth_labels = list(sample_data.ground_truth_label) - ground_truth_probs = list(sample_data.inference.ground_truth_prob) - predicted_labels = list(sample_data.inference.predicted_label) - predicted_probs = list(sample_data.inference.predicted_prob) - - if sample_data.inference.predicted_prob_sd or sample_data.inference.ground_truth_prob_sd: - ground_truth_prob_sds = list(sample_data.inference.ground_truth_prob_sd) - ground_truth_prob_lows = list(sample_data.inference.ground_truth_prob_itl95_low) - 
ground_truth_prob_his = list(sample_data.inference.ground_truth_prob_itl95_hi) - predicted_prob_sds = list(sample_data.inference.predicted_prob_sd) - predicted_prob_lows = list(sample_data.inference.predicted_prob_itl95_low) - predicted_prob_his = list(sample_data.inference.predicted_prob_itl95_hi) - has_uncertainty = True - else: - ground_truth_prob_sds = ground_truth_prob_lows = ground_truth_prob_his = None - predicted_prob_sds = predicted_prob_lows = predicted_prob_his = None - - inference_info = {} - for label, prob in zip( - ground_truth_labels + predicted_labels, - ground_truth_probs + predicted_probs): - inference_info[label] = { - 'label': self._labels_info[label]['label'], - 'confidence': round(prob, _NUM_DIGIT), - 'saliency_maps': []} - - if ground_truth_prob_sds or predicted_prob_sds: - for label, sd, low, hi in zip( - ground_truth_labels + predicted_labels, - ground_truth_prob_sds + predicted_prob_sds, - ground_truth_prob_lows + predicted_prob_lows, - ground_truth_prob_his + predicted_prob_his): - inference_info[label]['confidence_sd'] = sd - inference_info[label]['confidence_itl95'] = [low, hi] - - if EventParser.is_attr_ready(sample_data, 'explanation'): - for explanation in sample_data.explanation: - explanation_dict = self._explanation_to_dict(explanation) - inference_info[explanation.label]['saliency_maps'].append(explanation_dict) - - sample_info['inferences'] = list(inference_info.values()) - return sample_info, has_uncertainty - - def _import_sample(self, sample): - """Add sample object of given sample id.""" - for label_id in sample.ground_truth_label: - self._labels_info[label_id]['sample_ids'].add(sample.sample_id) - - sample_info, has_uncertainty = self._image_container_to_dict(sample) - self._samples_info.update({sample_info['id']: sample_info}) - self._uncertainty_enabled |= has_uncertainty - - def get_all_samples(self): - """ - Return a list of sample information cachced in the explain job - - Returns: - sample_list (List[SampleObj]): a list of sample objects, each object - consists of: - - - id (int): sample id - - name (str): basename of image - - labels (list[str]): list of labels - - inferences list[dict]) - """ - samples_in_list = list(self._samples_info.values()) - return samples_in_list - - def _is_metadata_empty(self): - """Check whether metadata is loaded first.""" - if not self._explainers or not self._metrics or not self._labels: - return True - return False - - def _import_data_from_event(self, event): - """Parse and import data from the event data.""" - tags = { - 'sample_id': PluginNameEnum.SAMPLE_ID, - 'benchmark': PluginNameEnum.BENCHMARK, - 'metadata': PluginNameEnum.METADATA - } - - if 'metadata' not in event and self._is_metadata_empty(): - raise ValueError('metadata is empty, should write metadata first in the summary.') - for tag in tags: - if tag not in event: - continue - - if tag == PluginNameEnum.SAMPLE_ID.value: - sample_event = event[tag] - sample_data = self._event_parser.parse_sample(sample_event) - if sample_data is not None: - self._import_sample(sample_data) - continue - - if tag == PluginNameEnum.BENCHMARK.value: - benchmark_event = event[tag].benchmark - explain_score_dict, label_score_dict = EventParser.parse_benchmark(benchmark_event) - self._update_benchmark(explain_score_dict, label_score_dict) - - elif tag == PluginNameEnum.METADATA.value: - metadata_event = event[tag].metadata - metadata = EventParser.parse_metadata(metadata_event) - self._explainers, self._metrics, self._labels = metadata - self._initialize_labels_info() - 
- def load(self): - """ - Start loading data from parser. - """ - valid_file_names = [] - for filename in FileHandler.list_dir(self._summary_dir): - if FileHandler.is_file( - FileHandler.join(self._summary_dir, filename)): - valid_file_names.append(filename) - - if not valid_file_names: - raise TrainJobNotExistError('No summary file found in %s, explain job will be delete.' % self._summary_dir) - - is_end = False - while not is_end: - is_clean, is_end, event = self._parser.parse_explain(valid_file_names) - - if is_clean: - logger.info('Summary file in %s update, reload the clean the loaded data.', self._summary_dir) - self._clean_job() - - if event: - self._import_data_from_event(event) - - def _clean_job(self): - """Clean the cached data in job.""" - self._latest_update_time = ExplainJob.get_update_time(self._summary_dir) - self._create_time = ExplainJob.get_update_time(self._summary_dir) - self._labels.clear() - self._metrics.clear() - self._explainers.clear() - self._samples_info.clear() - self._labels_info.clear() - self._explainer_score_dict.clear() - self._label_score_dict.clear() - self._event_parser.clear() - - def _update_benchmark(self, explainer_score_dict, labels_score_dict): - """Update the benchmark info.""" - for explainer, score in explainer_score_dict.items(): - self._explainer_score_dict[explainer].extend(score) - - for explainer, score in labels_score_dict.items(): - for label, score_of_label in score.items(): - self._label_score_dict[explainer][label] = (self._label_score_dict[explainer].get(label, []) - + score_of_label) diff --git a/mindinsight/explainer/manager/explain_loader.py b/mindinsight/explainer/manager/explain_loader.py new file mode 100644 index 00000000..190b2165 --- /dev/null +++ b/mindinsight/explainer/manager/explain_loader.py @@ -0,0 +1,580 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================ +"""ExplainLoader.""" + +import os +import re +from collections import defaultdict +from datetime import datetime +from typing import Dict, Iterable, List, Optional, Union + +from mindinsight.explainer.common.enums import ExplainFieldsEnum +from mindinsight.explainer.common.log import logger +from mindinsight.explainer.manager.explain_parser import ExplainParser +from mindinsight.datavisual.data_access.file_handler import FileHandler +from mindinsight.datavisual.common.exceptions import TrainJobNotExistError +from mindinsight.utils.exceptions import ParamValueError, UnknownError + +_NUM_DIGITS = 6 + +_EXPLAIN_FIELD_NAMES = [ + ExplainFieldsEnum.SAMPLE_ID, + ExplainFieldsEnum.BENCHMARK, + ExplainFieldsEnum.METADATA, +] + +_SAMPLE_FIELD_NAMES = [ + ExplainFieldsEnum.GROUND_TRUTH_LABEL, + ExplainFieldsEnum.INFERENCE, + ExplainFieldsEnum.EXPLANATION, +] + + +def _round(score): + """Round a number to the given precision.""" + return round(score, _NUM_DIGITS) + + +class ExplainLoader: + """ExplainLoader which manages the records in the summary file.""" + + def __init__(self, + loader_id: str, + summary_dir: str): + + self._parser = ExplainParser(summary_dir) + + self._loader_info = { + 'loader_id': loader_id, + 'summary_dir': summary_dir, + 'create_time': os.stat(summary_dir).st_ctime, + 'update_time': os.stat(summary_dir).st_mtime, + 'query_time': os.stat(summary_dir).st_ctime, + 'uncertainty_enabled': False, + } + self._samples = defaultdict(dict) + self._metadata = {'explainers': [], 'metrics': [], 'labels': []} + self._benchmark = {'explainer_score': defaultdict(dict), 'label_score': defaultdict(dict)} + + @property + def all_classes(self) -> List[Dict]: + """ + Return a list of detailed label information, including label id, label name and sample count of each label. + + Returns: + list[dict], a list of dict, each dict contains: + - id (int): label id + - label (str): label name + - sample_count (int): number of samples for each label + """ + sample_count_per_label = defaultdict(int) + samples_copy = self._samples.copy() + for sample in samples_copy.values(): + if sample.get('image', False) and sample.get('ground_truth_label', False): + for label in sample['ground_truth_label']: + sample_count_per_label[label] += 1 + + all_classes_return = [] + for label_id, label_name in enumerate(self._metadata['labels']): + single_info = { + 'id': label_id, + 'label': label_name, + 'sample_count': sample_count_per_label[label_id] + } + all_classes_return.append(single_info) + return all_classes_return + + @property + def query_time(self) -> float: + """Return query timestamp of explain loader.""" + return self._loader_info['query_time'] + + @query_time.setter + def query_time(self, new_time: Union[datetime, float]): + """ + Update the query_time timestamp manually. + + Args: + new_time (datetime.datetime or float): Updated query_time for the explain loader. 
""" + if isinstance(new_time, datetime): + self._loader_info['query_time'] = new_time.timestamp() + elif isinstance(new_time, float): + self._loader_info['query_time'] = new_time + else: + raise TypeError('new_time should have type of datetime.datetime or float, but receive {}' + .format(type(new_time))) + + @property + def create_time(self) -> float: + """Return the create timestamp of summary file.""" + return self._loader_info['create_time'] + + @create_time.setter + def create_time(self, new_time: Union[datetime, float]): + """ + Update the create_time manually. + + Args: + new_time (datetime.datetime or float): Updated create_time of summary_file. + """ + if isinstance(new_time, datetime): + self._loader_info['create_time'] = new_time.timestamp() + elif isinstance(new_time, float): + self._loader_info['create_time'] = new_time + else: + raise TypeError('new_time should have type of datetime.datetime or float, but receive {}' + .format(type(new_time))) + + @property + def explainers(self) -> List[str]: + """Return a list of explainer names recorded in the summary file.""" + return self._metadata['explainers'] + + @property + def explainer_scores(self) -> List[Dict]: + """ + Return evaluation results for every explainer. + + Returns: + list[dict], A list of evaluation results of each explainer. Each item contains: + - explainer (str): Name of evaluated explainer. + - evaluations (list[dict]): A list of evaluation results by different metrics. + - class_scores (list[dict]): A list of evaluation results on different labels. + + Each item in the evaluations contains: + - metric (str): name of metric method + - score (float): evaluation result + + Each item in the class_scores contains: + - label (str): Name of label + - evaluations (list[dict]): A list of evaluation results on different labels by different metrics. + + Each item in evaluations contains: + - metric (str): Name of metric method + - score (float): Evaluation scores of explainer on specific label by the metric. + """ + explainer_scores = [] + for explainer, explainer_score_on_metric in self._benchmark['explainer_score'].copy().items(): + metric_scores = [{'metric': metric, 'score': _round(score)} + for metric, score in explainer_score_on_metric.items()] + label_scores = [] + for label, label_score_on_metric in self._benchmark['label_score'][explainer].copy().items(): + score_of_single_label = { + 'label': self._metadata['labels'][label], + 'evaluations': [ + {'metric': metric, 'score': _round(score)} for metric, score in label_score_on_metric.items() + ], + } + label_scores.append(score_of_single_label) + explainer_scores.append({ + 'explainer': explainer, + 'evaluations': metric_scores, + 'class_scores': label_scores, + }) + return explainer_scores + + @property + def labels(self) -> List[str]: + """Return the labels recorded in the summary.""" + return self._metadata['labels'] + + @property + def metrics(self) -> List[str]: + """Return a list of metric names recorded in the summary file.""" + return self._metadata['metrics'] + + @property + def min_confidence(self) -> Optional[float]: + """Return minimum confidence used to filter the predicted labels.""" + return None + + @property + def sample_count(self) -> int: + """ + Return total number of samples in the loader. + + Since the loader only returns available samples (i.e. with original image data and ground_truth_label loaded in + cache), the returned count only takes the available samples into account. + + Return: + int, total number of available samples in the loading job. 
+ + """ + sample_count = 0 + samples_copy = self._samples.copy() + for sample in samples_copy.values(): + if sample.get('image', False) and sample.get('ground_truth_label', False): + sample_count += 1 + return sample_count + + @property + def samples(self) -> List[Dict]: + """Return the information of all samples in the job.""" + return self.get_all_samples() + + @property + def train_id(self) -> str: + """Return ID of explain loader.""" + return self._loader_info['loader_id'] + + @property + def uncertainty_enabled(self): + """Whether uncertainty is enabled.""" + return self._loader_info['uncertainty_enabled'] + + @property + def update_time(self) -> float: + """Return latest modification timestamp of summary file.""" + return self._loader_info['update_time'] + + @update_time.setter + def update_time(self, new_time: Union[datetime, float]): + """ + Update the update_time manually. + + Args: + new_time (datetime.datetime or float): Updated time for the summary file. + """ + if isinstance(new_time, datetime): + self._loader_info['update_time'] = new_time.timestamp() + elif isinstance(new_time, float): + self._loader_info['update_time'] = new_time + else: + raise TypeError('new_time should have type of datetime.datetime or float, but receive {}' + .format(type(new_time))) + + def load(self): + """Start loading data from the latest summary file to the loader.""" + filenames = [] + for filename in FileHandler.list_dir(self._loader_info['summary_dir']): + if FileHandler.is_file(FileHandler.join(self._loader_info['summary_dir'], filename)): + filenames.append(filename) + filenames = ExplainLoader._filter_files(filenames) + + if not filenames: + raise TrainJobNotExistError('No summary file found in %s, explain job will be deleted.' + % self._loader_info['summary_dir']) + + is_end = False + while not is_end: + is_clean, is_end, event_dict = self._parser.parse_explain(filenames) + + if is_clean: + logger.info('Summary file in %s updated, reload the data in the summary.', + self._loader_info['summary_dir']) + self._clear_job() + if event_dict: + self._import_data_from_event(event_dict) + + def get_all_samples(self) -> List[Dict]: + """ + Return a list of sample information cached in the explain job. + + Returns: + sample_list (List[SampleObj]): a list of sample objects, each object + consists of: + + - id (int): sample id + - name (str): basename of image + - labels (list[str]): list of labels + - inferences (list[dict]) + """ + returned_samples = [] + samples_copy = self._samples.copy() + for sample_id, sample_info in samples_copy.items(): + if not sample_info.get('image', False) and not sample_info.get('ground_truth_label', False): + continue + returned_sample = { + 'id': sample_id, + 'name': str(sample_id), + 'image': sample_info['image'], + 'labels': sample_info['ground_truth_label'], + } + + if not ExplainLoader._is_inference_valid(sample_info): + continue + + inferences = {} + for label, prob in zip(sample_info['ground_truth_label'] + sample_info['predicted_label'], + sample_info['ground_truth_prob'] + sample_info['predicted_prob']): + inferences[label] = { + 'label': self._metadata['labels'][label], + 'confidence': _round(prob), + 'saliency_maps': [] + } + + if sample_info['ground_truth_prob_sd'] or sample_info['predicted_prob_sd']: + for label, std, low, high in zip( + sample_info['ground_truth_label'] + sample_info['predicted_label'], + sample_info['ground_truth_prob_sd'] + sample_info['predicted_prob_sd'], + sample_info['ground_truth_prob_itl95_low'] + 
sample_info['predicted_prob_itl95_low'], + sample_info['ground_truth_prob_itl95_hi'] + sample_info['predicted_prob_itl95_hi'] + ): + inferences[label]['confidence_sd'] = _round(std) + inferences[label]['confidence_itl95'] = [_round(low), _round(high)] + + for explainer, label_heatmap_path_dict in sample_info['explanation'].items(): + for label, heatmap_path in label_heatmap_path_dict.items(): + if label in inferences: + inferences[label]['saliency_maps'].append({'explainer': explainer, 'overlay': heatmap_path}) + + returned_sample['inferences'] = list(inferences.values()) + returned_samples.append(returned_sample) + return returned_samples + + def _import_data_from_event(self, event_dict: Dict): + """Parse and import data from the event data.""" + if 'metadata' not in event_dict and self._is_metadata_empty(): + raise ParamValueError('metadata is incomplete, should write metadata first in the summary.') + + for tag, event in event_dict.items(): + if tag == ExplainFieldsEnum.METADATA.value: + self._import_metadata_from_event(event.metadata) + elif tag == ExplainFieldsEnum.BENCHMARK.value: + self._import_benchmark_from_event(event.benchmark) + elif tag == ExplainFieldsEnum.SAMPLE_ID.value: + self._import_sample_from_event(event) + else: + logger.info('Unknown ExplainField: %s', tag) + + def _is_metadata_empty(self): + """Check whether metadata is completely loaded first.""" + if not self._metadata['labels']: + return True + return False + + def _import_metadata_from_event(self, metadata_event): + """Import the metadata from event into loader.""" + + def take_union(existed_list, imported_data): + """Take union of existed_list and imported_data.""" + if isinstance(imported_data, Iterable): + for sample in imported_data: + if sample not in existed_list: + existed_list.append(sample) + + take_union(self._metadata['explainers'], metadata_event.explain_method) + take_union(self._metadata['metrics'], metadata_event.benchmark_method) + take_union(self._metadata['labels'], metadata_event.label) + + def _import_benchmark_from_event(self, benchmarks): + """ + Parse the benchmark event. + + Benchmark data are separated into 'explainer_score' and 'label_score'. 'explainer_score' contains overall + evaluation results of each explainer by different metrics, while 'label_score' additionally divides the results + w.r.t different labels. + + The structure of self._benchmark['explainer_score'] is shown below: + { + explainer_1: {metric_name_1: score_1, ...}, + explainer_2: {metric_name_1: score_1, ...}, + ... + } + + The structure of self._benchmark['label_score'] is: + { + explainer_1: {label_id: {metric_1: score_1, metric_2: score_2, ...}, ...}, + explainer_2: {label_id: {metric_1: score_1, metric_2: score_2, ...}, ...}, + ... + } + + Args: + benchmarks (benchmark_container): Parsed benchmarks data from summary file. 
+ """ + explainer_score = self._benchmark['explainer_score'] + label_score = self._benchmark['label_score'] + + for benchmark in benchmarks: + explainer = benchmark.explain_method + metric = benchmark.benchmark_method + metric_score = benchmark.total_score + label_score_event = benchmark.label_score + + explainer_score[explainer][metric] = metric_score + new_label_score_dict = ExplainLoader._score_event_to_dict(label_score_event, metric) + for label, scores_of_metric in new_label_score_dict.items(): + if label not in label_score[explainer]: + label_score[explainer][label] = {} + label_score[explainer][label].update(scores_of_metric) + + def _import_sample_from_event(self, sample): + """ + Parse the sample event. + + Detailed data of each sample are stored in self._samples, identified by sample_id. Each sample is stored + in the following structure. + + - ground_truth_labels (list[int]): A list of ground truth labels of the sample. + - ground_truth_probs (list[float]): A list of confidences of ground-truth label from black-box model. + - predicted_labels (list[int]): A list of predicted labels from the black-box model. + - predicted_probs (list[float]): A list of confidences w.r.t the predicted labels. + - explanations (dict): Explanations is a dictionary where each explainer name maps to a dictionary + of saliency maps. The structure of explanations is shown below: + + { + explainer_name_1: {label_1: saliency_id_1, label_2: saliency_id_2, ...}, + explainer_name_2: {label_1: saliency_id_1, label_2: saliency_id_2, ...}, + ... + } + """ + if not getattr(sample, 'sample_id', None): + raise ParamValueError('sample_event has no sample_id') + sample_id = sample.sample_id + samples_copy = self._samples.copy() + if sample_id not in samples_copy: + self._samples[sample_id] = { + 'ground_truth_label': [], + 'ground_truth_prob': [], + 'ground_truth_prob_sd': [], + 'ground_truth_prob_itl95_low': [], + 'ground_truth_prob_itl95_hi': [], + 'predicted_label': [], + 'predicted_prob': [], + 'predicted_prob_sd': [], + 'predicted_prob_itl95_low': [], + 'predicted_prob_itl95_hi': [], + 'explanation': defaultdict(dict) + } + + if sample.image_path: + self._samples[sample_id]['image'] = sample.image_path + + for tag in _SAMPLE_FIELD_NAMES: + try: + if ExplainLoader._is_attr_empty(sample, tag.value): + continue + if tag == ExplainFieldsEnum.GROUND_TRUTH_LABEL: + self._samples[sample_id]['ground_truth_label'].extend(list(sample.ground_truth_label)) + elif tag == ExplainFieldsEnum.INFERENCE: + self._import_inference_from_event(sample, sample_id) + elif tag == ExplainFieldsEnum.EXPLANATION: + self._import_explanation_from_event(sample, sample_id) + except UnknownError as ex: + logger.warning("Parse %s data failed within image related data, detail: %r", tag, str(ex)) + + def _import_inference_from_event(self, event, sample_id): + """Parse the inference event.""" + inference = event.inference + self._samples[sample_id]['ground_truth_prob'].extend(list(inference.ground_truth_prob)) + self._samples[sample_id]['ground_truth_prob_sd'].extend(list(inference.ground_truth_prob_sd)) + self._samples[sample_id]['ground_truth_prob_itl95_low'].extend(list(inference.ground_truth_prob_itl95_low)) + self._samples[sample_id]['ground_truth_prob_itl95_hi'].extend(list(inference.ground_truth_prob_itl95_hi)) + self._samples[sample_id]['predicted_label'].extend(list(inference.predicted_label)) + self._samples[sample_id]['predicted_prob'].extend(list(inference.predicted_prob)) + 
self._samples[sample_id]['predicted_prob_sd'].extend(list(inference.predicted_prob_sd)) + self._samples[sample_id]['predicted_prob_itl95_low'].extend(list(inference.predicted_prob_itl95_low)) + self._samples[sample_id]['predicted_prob_itl95_hi'].extend(list(inference.predicted_prob_itl95_hi)) + + if self._samples[sample_id]['ground_truth_prob_sd'] or self._samples[sample_id]['predicted_prob_sd']: + self._loader_info['uncertainty_enabled'] = True + + def _import_explanation_from_event(self, event, sample_id): + """Parse the explanation event.""" + if self._samples[sample_id]['explanation'] is None: + self._samples[sample_id]['explanation'] = defaultdict(dict) + sample_explanation = self._samples[sample_id]['explanation'] + + for explanation_item in event.explanation: + explainer = explanation_item.explain_method + label = explanation_item.label + sample_explanation[explainer][label] = explanation_item.heatmap_path + + def _clear_job(self): + """Clear the cached data and update the time info of the loader.""" + self._samples.clear() + self._loader_info['create_time'] = os.stat(self._loader_info['summary_dir']).st_ctime + self._loader_info['update_time'] = os.stat(self._loader_info['summary_dir']).st_mtime + self._loader_info['query_time'] = max(self._loader_info['update_time'], self._loader_info['query_time']) + + def clear_inner_dict(outer_dict): + """Clear the inner structured data of the given dict.""" + for item in outer_dict.values(): + item.clear() + + map(clear_inner_dict, [self._metadata, self._benchmark]) + + @staticmethod + def _filter_files(filenames): + """ + Gets a list of summary files. + + Args: + filenames (list[str]): File name list, like [filename1, filename2]. + + Returns: + list[str], filename list. + """ + return list(filter(lambda filename: (re.search(r'summary\.\d+', filename) and filename.endswith("_explain")), + filenames)) + + @staticmethod + def _is_attr_empty(event, attr_name) -> bool: + if not getattr(event, attr_name): + return True + for item in getattr(event, attr_name): + if not isinstance(item, list) or item: + return False + return True + + @staticmethod + def _is_ground_truth_label_valid(sample_id: str, sample_info: Dict) -> bool: + if len(sample_info['ground_truth_label']) != len(sample_info['ground_truth_prob']): + logger.info('length of ground_truth_prob does not match the length of ground_truth_label. ' + 'length of ground_truth_label is: %s but length of ground_truth_prob is: %s. ' + 'sample_id is: %s.', + len(sample_info['ground_truth_label']), len(sample_info['ground_truth_prob']), sample_id) + return False + return True + + @staticmethod + def _is_inference_valid(sample): + """ + Check whether the inference data is empty or has the same length. + + If the probs have a different length from the labels, it can be confusing when assigning each prob to a label. + 'is_inference_valid' returns True only when the data sizes match each other. Note that prob data could be + empty, so empty prob will pass the check. 
+ """ + ground_truth_len = len(sample['ground_truth_label']) + for name in ['ground_truth_prob', 'ground_truth_prob_sd', + 'ground_truth_prob_itl95_low', 'ground_truth_prob_itl95_hi']: + if sample[name] and len(sample[name]) != ground_truth_len: + return False + + predicted_len = len(sample['predicted_label']) + for name in ['predicted_prob', 'predicted_prob_sd', + 'predicted_prob_itl95_low', 'predicted_prob_itl95_hi']: + if sample[name] and len(sample[name]) != predicted_len: + return False + return True + + @staticmethod + def _is_predicted_label_valid(sample_id: str, sample_info: Dict) -> bool: + if len(sample_info['predicted_label']) != len(sample_info['predicted_prob']): + logger.info('length of predicted_probs does not match the length of predicted_labels' + 'length of predicted_probs: %s but receive length of predicted_label: %s, sample_id: %s.', + len(sample_info['predicted_prob']), len(sample_info['predicted_label']), sample_id) + return False + return True + + @staticmethod + def _score_event_to_dict(label_score_event, metric) -> Dict: + """Transfer metric scores per label to pre-defined structure.""" + new_label_score_dict = defaultdict(dict) + for label_id, label_score in enumerate(label_score_event): + new_label_score_dict[label_id][metric] = label_score + return new_label_score_dict diff --git a/mindinsight/explainer/manager/explain_manager.py b/mindinsight/explainer/manager/explain_manager.py index dcda34f0..73756ca5 100644 --- a/mindinsight/explainer/manager/explain_manager.py +++ b/mindinsight/explainer/manager/explain_manager.py @@ -17,17 +17,20 @@ import os import threading import time +from collections import OrderedDict +from datetime import datetime +from typing import Optional +from mindinsight.conf import settings from mindinsight.datavisual.common import exceptions from mindinsight.datavisual.common.enums import BaseEnum from mindinsight.explainer.common.log import logger -from mindinsight.explainer.manager.explain_job import ExplainJob +from mindinsight.explainer.manager.explain_loader import ExplainLoader from mindinsight.datavisual.data_access.file_handler import FileHandler from mindinsight.datavisual.data_transform.summary_watcher import SummaryWatcher from mindinsight.utils.exceptions import MindInsightException, ParamValueError, UnknownError -_MAX_LOADER_NUM = 3 -_MAX_INTERVAL = 3 +_MAX_LOADERS_NUM = 3 class _ExplainManagerStatus(BaseEnum): @@ -43,245 +46,63 @@ class ExplainManager: def __init__(self, summary_base_dir: str): self._summary_base_dir = summary_base_dir - self._loader_pool = {} - self._deleted_ids = [] - self._status = _ExplainManagerStatus.INIT.value + self._loader_pool = OrderedDict() + self._loading_status = _ExplainManagerStatus.INIT.value self._status_mutex = threading.Lock() self._loader_pool_mutex = threading.Lock() - self._max_loader_num = _MAX_LOADER_NUM - self._reload_interval = None + self._max_loaders_num = _MAX_LOADERS_NUM + self._summary_watcher = SummaryWatcher() - def _reload_data(self): - """periodically load summary from file.""" - while True: - try: - self._load_data() - - if not self._reload_interval: - break - time.sleep(self._reload_interval) - except UnknownError as ex: - logger.exception(ex) - logger.error('Unknown Error raise when loading summary files, status: %r, and loader pool size is %r.' 
- 'Detail: %s', self._status, len(self._loader_pool), str(ex)) - self._status = _ExplainManagerStatus.INVALID.value - - def _load_data(self): - """Loading the summary in the given base directory.""" - logger.info('Start to load data, reload interval: %r.', self._reload_interval) - - with self._status_mutex: - if self._status == _ExplainManagerStatus.LOADING.value: - logger.info('Current status is %s, will ignore to load data.', self._status) - return - - self._status = _ExplainManagerStatus.LOADING.value - - try: - self._generate_loaders() - self._execute_load_data() - except Exception as ex: - raise UnknownError(ex) - - if not self._loader_pool: - self._status = _ExplainManagerStatus.INVALID.value - else: - self._status = _ExplainManagerStatus.DONE.value - - logger.info('Load event data end, status: %r, and loader pool size is %r', - self._status, len(self._loader_pool)) - - def _update_loader_latest_update_time(self, loader_id, latest_update_time=None): - """update the update time of loader of given id.""" - if latest_update_time is None: - latest_update_time = time.time() - self._loader_pool[loader_id].latest_update_time = latest_update_time - - def _delete_loader(self, loader_id): - """delete loader given loader_id""" - if self._loader_pool.get(loader_id, None) is not None: - self._loader_pool.pop(loader_id) - logger.debug('delete loader %s', loader_id) - - def _add_loader(self, loader): - """add loader to the loader_pool.""" - if len(self._loader_pool) >= _MAX_LOADER_NUM: - delete_num = len(self._loader_pool) - _MAX_LOADER_NUM + 1 - sorted_loaders = sorted( - self._loader_pool.items(), - key=lambda x: x[1].latest_update_time) - - for index in range(delete_num): - delete_loader_id = sorted_loaders[index][0] - self._delete_loader(delete_loader_id) - self._loader_pool.update({loader.loader_id: loader}) - - def _deal_loaders(self, latest_loaders): - """"update the loader pool.""" - with self._loader_pool_mutex: - for loader_id, loader in latest_loaders: - if self._loader_pool.get(loader_id, None) is None: - self._add_loader(loader) - continue - - if (self._loader_pool[loader_id].latest_update_time - < loader.latest_update_time): - self._update_loader_latest_update_time( - loader_id, loader.latest_update_time) - - @staticmethod - def _generate_loader_id(relative_path): - """Generate loader id for given path""" - loader_id = relative_path - return loader_id - - @staticmethod - def _generate_loader_name(relative_path): - """Generate_loader name for given path.""" - loader_name = relative_path - return loader_name - - def _generate_loader_by_relative_path(self, relative_path: str) -> ExplainJob: - """Generate explain job from given relative path.""" - current_dir = os.path.realpath(FileHandler.join( - self._summary_base_dir, relative_path - )) - loader_id = self._generate_loader_id(relative_path) - loader = ExplainJob( - job_id=loader_id, - summary_dir=current_dir, - create_time=ExplainJob.get_create_time(current_dir), - latest_update_time=ExplainJob.get_update_time(current_dir)) - return loader - - def _generate_loaders(self): - """Generate job loaders from the summary watcher.""" - dir_map_mtime_dict = {} - loader_dict = {} - min_modify_time = None - _, summaries = SummaryWatcher().list_explain_directories( - self._summary_base_dir) - - for item in summaries: - relative_path = item.get('relative_path') - modify_time = item.get('update_time').timestamp() - loader_id = self._generate_loader_id(relative_path) - - loader = self._loader_pool.get(loader_id, None) - if loader is not None and 
loader.latest_update_time > modify_time: - modify_time = loader.latest_update_time - - if min_modify_time is None: - min_modify_time = modify_time - - if len(dir_map_mtime_dict) < _MAX_LOADER_NUM: - if modify_time < min_modify_time: - min_modify_time = modify_time - dir_map_mtime_dict.update({relative_path: modify_time}) - else: - if modify_time >= min_modify_time: - dir_map_mtime_dict.update({relative_path: modify_time}) - - sorted_dir_tuple = sorted(dir_map_mtime_dict.items(), - key=lambda d: d[1])[-_MAX_LOADER_NUM:] - - for relative_path, modify_time in sorted_dir_tuple: - loader_id = self._generate_loader_id(relative_path) - loader = self._generate_loader_by_relative_path(relative_path) - loader_dict.update({loader_id: loader}) - - sorted_loaders = sorted(loader_dict.items(), - key=lambda x: x[1].latest_update_time) - latest_loaders = sorted_loaders[-_MAX_LOADER_NUM:] - self._deal_loaders(latest_loaders) - - def _execute_loader(self, loader_id): - """Execute the data loading.""" - try: - with self._loader_pool_mutex: - loader = self._loader_pool.get(loader_id, None) - if loader is None: - logger.debug('Loader %r has been deleted, will not load' - 'data', loader_id) - return - loader.load() + @property + def summary_base_dir(self): + """Return the base directory for summary records.""" + return self._summary_base_dir - except MindInsightException as ex: - logger.warning('Data loader %r load data failed. Delete data_loader. Detail: %s', loader_id, ex) - with self._loader_pool_mutex: - self._delete_loader(loader_id) + def start_load_data(self, reload_interval: int = 0): + """ + Start individual thread to cache explain_jobs and loading summary data periodically. - def _execute_load_data(self): - """Execute the loader in the pool to load data.""" - loader_pool = self._get_snapshot_loader_pool() - for loader_id in loader_pool: - self._execute_loader(loader_id) + Args: + reload_interval (int): Specify the loading period in seconds. If interval == 0, data will only be loaded + once. Default: 0. + """ + thread = threading.Thread(target=self._repeat_loading, + name='start_load_thread', + args=(reload_interval,), + daemon=True) + time.sleep(1) + thread.start() - def _get_snapshot_loader_pool(self): - """Get snapshot of loader_pool.""" - with self._loader_pool_mutex: - return dict(self._loader_pool) + def get_job(self, loader_id: str) -> Optional[ExplainLoader]: + """ + Return ExplainLoader given loader_id. - def _check_status_valid(self): - """Check manager status.""" - if self._status == _ExplainManagerStatus.INIT.value: - raise exceptions.SummaryLogIsLoading('Data is loading, current status is %s' % self._status) + If explain job w.r.t given loader_id is not found, None will be returned. 
- @staticmethod - def _check_train_id_valid(train_id: str): - """Verify the train_id is valid.""" - if not train_id.startswith('./'): - logger.warning('train_id does not start with "./"') - return False - - if len(train_id.split('/')) > 2: - logger.warning('train_id contains multiple "/"') - return False - return True - - def _check_train_job_exist(self, train_id): - """Verify thee train_job is existed given train_id.""" - if train_id in self._loader_pool: - return - self._check_train_id_valid(train_id) - if SummaryWatcher().is_summary_directory(self._summary_base_dir, train_id): - return - raise ParamValueError('Can not find the train job in the manager, train_id: %s' % train_id) + Args: + loader_id (str): The id of expected ExplainLoader - def _reload_data_again(self): - """Reload the data one more time.""" - logger.debug('Start to reload data again.') - thread = threading.Thread(target=self._load_data, - name='reload_data_thread') - thread.daemon = False - thread.start() + Return: + explain_job + """ + self._check_status_valid() - def _get_job(self, train_id): - """Retrieve train_job given train_id.""" - is_reload = False with self._loader_pool_mutex: - loader = self._loader_pool.get(train_id, None) - - if loader is None: - relative_path = train_id - temp_loader = self._generate_loader_by_relative_path( - relative_path) + if loader_id in self._loader_pool: + self._loader_pool[loader_id].query_time = datetime.now().timestamp() + self._loader_pool.move_to_end(loader_id, last=False) + return self._loader_pool[loader_id] - if temp_loader is None: - return None - - self._add_loader(temp_loader) - is_reload = True - - if is_reload: - self._reload_data_again() + try: + loader = self._generate_loader_from_relative_path(loader_id) + loader.query_time = datetime.now().timestamp() + self._add_loader(loader) + self._reload_data_again() + except ParamValueError: + logger.warning('Cannot find summary in path: %s. No explain_job will be returned.', loader_id) + return None return loader - @property - def summary_base_dir(self): - """Return the base directory for summary records.""" - return self._summary_base_dir - def get_job_list(self, offset=0, limit=None): """ Return List of explain jobs. includes job ID, create and update time. @@ -298,44 +119,146 @@ class ExplainManager: - create_time (datetime): Creation time of summary file. - update_time (datetime): Modification time of summary file. """ - watcher = SummaryWatcher() total, dir_infos = \ - watcher.list_explain_directories(self._summary_base_dir, - offset=offset, limit=limit) + self._summary_watcher.list_explain_directories(self._summary_base_dir, offset=offset, limit=limit) return total, dir_infos - def get_job(self, train_id): + def _repeat_loading(self, repeat_interval): + """Periodically loading summary.""" + while True: + try: + logger.info('Start to load data, repeat interval: %r.', repeat_interval) + self._load_data() + if not repeat_interval: + return + time.sleep(repeat_interval) + except UnknownError as ex: + logger.exception(ex) + logger.error('Unexpected error happens when loading data. Loading status: %s, loading pool size: %d' + 'Detail: %s', self._loading_status, len(self._loader_pool), str(ex)) + + def _load_data(self): + """ + Prepare loaders in cache and start loading the data from summaries. + + Only a limited number of loaders will be cached in terms of updated_time or query_time. The size of cache + pool is determined by _MAX_LOADERS_NUM. 
When the manager starts loading data, only the latest _MAX_LOADERS_NUM + summaries will be loaded into the cache. If a cached loader is queried by 'get_job', the query_time of the loader + will be updated and the loader will be moved to the end of the cache. If an uncached summary is queried, + a new loader instance will be generated and put at the end of the cache. """ - Return ExplainJob given train_id. + try: + with self._status_mutex: + if self._loading_status == _ExplainManagerStatus.LOADING.value: + logger.info('Current status is %s, will ignore to load data.', self._loading_status) + return - If explain job w.r.t given train_id is not found, None will be returned. + self._loading_status = _ExplainManagerStatus.LOADING.value - Args: - train_id (str): The id of expected ExplainJob + self._cache_loaders() + self._execute_loading() - Return: - explain_job - """ - self._check_status_valid() - self._check_train_job_exist(train_id) + if not self._loader_pool: + self._loading_status = _ExplainManagerStatus.INVALID.value + else: + self._loading_status = _ExplainManagerStatus.DONE.value + + logger.info('Load event data end, status: %s, and loader pool size: %d', + self._loading_status, len(self._loader_pool)) + + except Exception as ex: + self._loading_status = _ExplainManagerStatus.INVALID.value + logger.exception(ex) + raise UnknownError(str(ex)) + + def _cache_loaders(self): + """Cache explain loaders in the cache pool.""" + dir_map_mtime_dict = [] + _, summaries_info = self._summary_watcher.list_explain_directories(self._summary_base_dir) + + for summary_info in summaries_info: + summary_path = summary_info.get('relative_path') + summary_update_time = summary_info.get('update_time').timestamp() + + if summary_path in self._loader_pool: + summary_update_time = max(summary_update_time, self._loader_pool[summary_path].query_time) + + dir_map_mtime_dict.append((summary_info, summary_update_time)) + + sorted_summaries_info = sorted(dir_map_mtime_dict, key=lambda x: x[1])[-_MAX_LOADERS_NUM:] - loader = self._get_job(train_id) - if loader is None: - return None + with self._loader_pool_mutex: + for summary_info, query_time in sorted_summaries_info: + summary_path = summary_info['relative_path'] + if summary_path not in self._loader_pool: + loader = self._generate_loader_from_relative_path(summary_path) + self._add_loader(loader) + else: + self._loader_pool[summary_path].query_time = query_time + self._loader_pool.move_to_end(summary_path, last=False) + + def _generate_loader_from_relative_path(self, relative_path: str) -> ExplainLoader: + """Generate an explain loader from the given relative path.""" + self._check_summary_exist(relative_path) + current_dir = os.path.realpath(FileHandler.join(self._summary_base_dir, relative_path)) + loader_id = self._generate_loader_id(relative_path) + loader = ExplainLoader(loader_id=loader_id, summary_dir=current_dir) return loader - def start_load_data(self, reload_interval=_MAX_INTERVAL): - """ - Start threads for loading data.
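The selection performed by _cache_loaders boils down to ranking each summary by the larger of its file update time and its last query time, then keeping only the newest entries. A small stand-alone sketch of that idea, with made-up paths and timestamps (the real code works on SummaryWatcher results and ExplainLoader objects):

# Hypothetical inputs: (relative_path, file update time) pairs plus a record of
# when cached jobs were last queried.
summaries = [('./job_a', 100.0), ('./job_b', 300.0), ('./job_c', 200.0)]
query_times = {'./job_c': 400.0}   # job_c was queried after its file was written
MAX_LOADERS_NUM = 2

# Rank by the most recent of the two timestamps, then keep the newest N.
ranked = [(path, max(mtime, query_times.get(path, mtime))) for path, mtime in summaries]
latest = sorted(ranked, key=lambda item: item[1])[-MAX_LOADERS_NUM:]
print(latest)   # [('./job_b', 300.0), ('./job_c', 400.0)]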
+ def _add_loader(self, loader): + """Add a loader to the loader_pool.""" + if loader.train_id not in self._loader_pool: + self._loader_pool[loader.train_id] = loader + else: + self._loader_pool.move_to_end(loader.train_id) - Args: - reload_interval (int): interval to reload the summary from file - """ - self._reload_interval = reload_interval + while len(self._loader_pool) > self._max_loaders_num: + self._loader_pool.popitem(last=False) + + def _execute_loading(self): + """Execute the data loading.""" + for loader_id in list(self._loader_pool.keys()): + try: + with self._loader_pool_mutex: + loader = self._loader_pool.get(loader_id, None) + if loader is None: + logger.debug('Loader %r has been deleted, will not load data', loader_id) + return + loader.load() + + except MindInsightException as ex: + logger.warning('Data loader %r load data failed. Delete data_loader. Detail: %s', loader_id, ex) + with self._loader_pool_mutex: + self._delete_loader(loader_id) + + def _delete_loader(self, loader_id): + """Delete the loader with the given loader_id.""" + if loader_id in self._loader_pool: + self._loader_pool.pop(loader_id) + logger.debug('delete loader %s', loader_id) - thread = threading.Thread(target=self._reload_data, name='start_load_data_thread') - thread.daemon = True + def _check_status_valid(self): + """Check manager status.""" + if self._loading_status == _ExplainManagerStatus.INIT.value: + raise exceptions.SummaryLogIsLoading('Data is loading, current status is %s' % self._loading_status) + + def _check_summary_exist(self, loader_id): + """Verify that the summary directory exists for the given loader_id.""" + if not self._summary_watcher.is_summary_directory(self._summary_base_dir, loader_id): + raise ParamValueError('Can not find the train job in the manager.') + + def _reload_data_again(self): + """Reload the data one more time.""" + logger.debug('Start to reload data again.') + thread = threading.Thread(target=self._load_data, name='reload_data_thread') + thread.daemon = False thread.start() - # wait for data loading - time.sleep(1) + @staticmethod + def _generate_loader_id(relative_path): + """Generate a loader id for the given relative path.""" + loader_id = relative_path + return loader_id + + +EXPLAIN_MANAGER = ExplainManager(summary_base_dir=settings.SUMMARY_BASE_DIR) diff --git a/mindinsight/explainer/manager/explain_parser.py b/mindinsight/explainer/manager/explain_parser.py index 456db89e..ea9b5a98 100644 --- a/mindinsight/explainer/manager/explain_parser.py +++ b/mindinsight/explainer/manager/explain_parser.py @@ -17,49 +17,41 @@ File parser for MindExplain data. This module is used to parse the MindExplain log file.
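_add_loader above caps the pool with an ordered mapping and popitem. The classic ordered-dict cache looks roughly like the sketch below; this is a generic LRU-style illustration under the usual "most recent entry goes to the back, evict from the front" convention, and the OrderedCache class is hypothetical (ExplainManager's own move_to_end/popitem arguments order its pool slightly differently).

from collections import OrderedDict


class OrderedCache:
    """Generic ordered-dict cache sketch (hypothetical helper, not MindInsight code)."""

    def __init__(self, max_items):
        self._max_items = max_items
        self._pool = OrderedDict()

    def get(self, key):
        if key not in self._pool:
            return None
        self._pool.move_to_end(key)         # mark as most recently used
        return self._pool[key]

    def put(self, key, value):
        self._pool[key] = value
        self._pool.move_to_end(key)
        while len(self._pool) > self._max_items:
            self._pool.popitem(last=False)  # evict the oldest entry

    def keys(self):
        return list(self._pool)


cache = OrderedCache(max_items=2)
cache.put('job_a', 'loader_a')
cache.put('job_b', 'loader_b')
cache.get('job_a')                          # 'job_a' becomes most recent
cache.put('job_c', 'loader_c')              # 'job_b' is evicted
print(cache.keys())                         # ['job_a', 'job_c']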
""" -import re -import collections +from collections import namedtuple from google.protobuf.message import DecodeError from mindinsight.datavisual.common import exceptions -from mindinsight.explainer.common.enums import PluginNameEnum +from mindinsight.explainer.common.enums import ExplainFieldsEnum from mindinsight.explainer.common.log import logger from mindinsight.datavisual.data_access.file_handler import FileHandler from mindinsight.datavisual.data_transform.ms_data_loader import _SummaryParser from mindinsight.datavisual.proto_files import mindinsight_summary_pb2 as summary_pb2 -from mindinsight.datavisual.proto_files.mindinsight_summary_pb2 import Explain from mindinsight.utils.exceptions import UnknownError HEADER_SIZE = 8 CRC_STR_SIZE = 4 MAX_EVENT_STRING = 500000000 -BenchmarkContainer = collections.namedtuple('BenchmarkContainer', ['benchmark', 'status']) -MetadataContainer = collections.namedtuple('MetadataContainer', ['metadata', 'status']) - - -class ImageDataContainer: - """ - Container for image data to allow pickling. - - Args: - explain_message (Explain): Explain proto buffer message. - """ - - def __init__(self, explain_message: Explain): - self.sample_id = explain_message.sample_id - self.image_path = explain_message.image_path - self.ground_truth_label = explain_message.ground_truth_label - self.inference = explain_message.inference - self.explanation = explain_message.explanation - self.status = explain_message.status - - -class _ExplainParser(_SummaryParser): +BenchmarkContainer = namedtuple('BenchmarkContainer', ['benchmark', 'status']) +MetadataContainer = namedtuple('MetadataContainer', ['metadata', 'status']) +InferfenceContainer = namedtuple('InferenceContainer', ['ground_truth_prob', + 'ground_truth_prob_sd', + 'ground_truth_prob_itl95_low', + 'ground_truth_prob_itl95_hi', + 'predicted_label', + 'predicted_prob', + 'predicted_prob_sd', + 'predicted_prob_itl95_low', + 'predicted_prob_itl95_hi']) +SampleContainer = namedtuple('SampleContainer', ['sample_id', 'image_path', 'ground_truth_label', 'inference', + 'explanation', 'status']) + + +class ExplainParser(_SummaryParser): """The summary file parser.""" def __init__(self, summary_dir): - super(_ExplainParser, self).__init__(summary_dir) + super(ExplainParser, self).__init__(summary_dir) self._latest_filename = '' def parse_explain(self, filenames): @@ -71,8 +63,7 @@ class _ExplainParser(_SummaryParser): Returns: bool, True if all the summary files are finished loading. """ - summary_files = self.filter_files(filenames) - summary_files = self.sort_files(summary_files) + summary_files = self.sort_files(filenames) is_end = False is_clean = False @@ -125,20 +116,6 @@ class _ExplainParser(_SummaryParser): logger.exception(ex) raise UnknownError(str(ex)) - def filter_files(self, filenames): - """ - Gets a list of summary files. - - Args: - filenames (list[str]): File name list, like [filename1, filename2]. - - Returns: - list[str], filename list. 
- """ - return list(filter( - lambda filename: (re.search(r'summary\.\d+', filename) - and filename.endswith("_explain")), filenames)) - @staticmethod def _event_decode(event_str): """ @@ -153,9 +130,9 @@ class _ExplainParser(_SummaryParser): logger.debug("Deserialize event string completed.") fields = { - 'sample_id': PluginNameEnum.SAMPLE_ID, - 'benchmark': PluginNameEnum.BENCHMARK, - 'metadata': PluginNameEnum.METADATA + 'sample_id': ExplainFieldsEnum.SAMPLE_ID, + 'benchmark': ExplainFieldsEnum.BENCHMARK, + 'metadata': ExplainFieldsEnum.METADATA } tensor_event_value = getattr(event, 'explain') @@ -163,19 +140,19 @@ class _ExplainParser(_SummaryParser): field_list = [] tensor_value_list = [] for field in fields: - if not getattr(tensor_event_value, field): + if not getattr(tensor_event_value, field, False): continue - if PluginNameEnum.METADATA.value == field and not tensor_event_value.metadata.label: + if ExplainFieldsEnum.METADATA.value == field and not tensor_event_value.metadata.label: continue tensor_value = None - if field == PluginNameEnum.SAMPLE_ID.value: - tensor_value = _ExplainParser._add_image_data(tensor_event_value) - elif field == PluginNameEnum.BENCHMARK.value: - tensor_value = _ExplainParser._add_benchmark(tensor_event_value) - elif field == PluginNameEnum.METADATA.value: - tensor_value = _ExplainParser._add_metadata(tensor_event_value) + if field == ExplainFieldsEnum.SAMPLE_ID.value: + tensor_value = ExplainParser._add_image_data(tensor_event_value) + elif field == ExplainFieldsEnum.BENCHMARK.value: + tensor_value = ExplainParser._add_benchmark(tensor_event_value) + elif field == ExplainFieldsEnum.METADATA.value: + tensor_value = ExplainParser._add_metadata(tensor_event_value) logger.debug("Event generated, label is %s, step is %s.", field, event.step) field_list.append(field) tensor_value_list.append(tensor_value) @@ -189,8 +166,26 @@ class _ExplainParser(_SummaryParser): Args: tensor_event_value: the object of Explain message """ - image_data = ImageDataContainer(tensor_event_value) - return image_data + inference = InferfenceContainer( + ground_truth_prob=tensor_event_value.inference.ground_truth_prob, + ground_truth_prob_sd=tensor_event_value.inference.ground_truth_prob_sd, + ground_truth_prob_itl95_low=tensor_event_value.inference.ground_truth_prob_itl95_low, + ground_truth_prob_itl95_hi=tensor_event_value.inference.ground_truth_prob_itl95_hi, + predicted_label=tensor_event_value.inference.predicted_label, + predicted_prob=tensor_event_value.inference.predicted_prob, + predicted_prob_sd=tensor_event_value.inference.predicted_prob_sd, + predicted_prob_itl95_low=tensor_event_value.inference.predicted_prob_itl95_low, + predicted_prob_itl95_hi=tensor_event_value.inference.predicted_prob_itl95_hi + ) + sample_data = SampleContainer( + sample_id=tensor_event_value.sample_id, + image_path=tensor_event_value.image_path, + ground_truth_label=tensor_event_value.ground_truth_label, + inference=inference, + explanation=tensor_event_value.explanation, + status=tensor_event_value.status + ) + return sample_data @staticmethod def _add_benchmark(tensor_event_value): diff --git a/tests/ut/explainer/encapsulator/mock_explain_manager.py b/tests/ut/explainer/encapsulator/mock_explain_manager.py index 2c705088..624beb1e 100644 --- a/tests/ut/explainer/encapsulator/mock_explain_manager.py +++ b/tests/ut/explainer/encapsulator/mock_explain_manager.py @@ -25,7 +25,7 @@ class MockExplainJob: self.create_time = datetime.timestamp( datetime.strptime("2020-10-01 20:21:23", 
ExplainJobEncap.DATETIME_FORMAT)) - self.latest_update_time = self.create_time + self.update_time = self.create_time self.sample_count = 1999 self.min_confidence = 0.5 self.explainers = ["Gradient"]
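The mock above stores create_time and update_time as POSIX timestamps parsed from a fixed string. A minimal sketch of that round-trip, assuming the usual '%Y-%m-%d %H:%M:%S' layout for ExplainJobEncap.DATETIME_FORMAT (the actual constant is defined elsewhere and not shown here):

from datetime import datetime

DATETIME_FORMAT = '%Y-%m-%d %H:%M:%S'   # assumed value, for illustration only

create_time = datetime.strptime('2020-10-01 20:21:23', DATETIME_FORMAT).timestamp()
update_time = create_time               # the mock reuses the creation time

# Formatting the timestamp back reproduces the original string.
print(datetime.fromtimestamp(update_time).strftime(DATETIME_FORMAT))   # 2020-10-01 20:21:23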