# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""ExplainLoader."""

import os
import re
from collections import defaultdict
from datetime import datetime
from typing import Dict, Iterable, List, Optional, Union

from mindinsight.explainer.common.enums import ExplainFieldsEnum
from mindinsight.explainer.common.log import logger
from mindinsight.explainer.manager.explain_parser import ExplainParser
from mindinsight.datavisual.data_access.file_handler import FileHandler
from mindinsight.datavisual.common.exceptions import TrainJobNotExistError
from mindinsight.utils.exceptions import ParamValueError, UnknownError

_NUM_DIGITS = 6

_EXPLAIN_FIELD_NAMES = [
    ExplainFieldsEnum.SAMPLE_ID,
    ExplainFieldsEnum.BENCHMARK,
    ExplainFieldsEnum.METADATA,
]

_SAMPLE_FIELD_NAMES = [
    ExplainFieldsEnum.GROUND_TRUTH_LABEL,
    ExplainFieldsEnum.INFERENCE,
    ExplainFieldsEnum.EXPLANATION,
]

# Inference fields copied verbatim from an inference event into the sample cache,
# in the order they are appended.
_INFERENCE_FIELD_NAMES = (
    'ground_truth_prob',
    'ground_truth_prob_sd',
    'ground_truth_prob_itl95_low',
    'ground_truth_prob_itl95_hi',
    'predicted_label',
    'predicted_prob',
    'predicted_prob_sd',
    'predicted_prob_itl95_low',
    'predicted_prob_itl95_hi',
)

# A summary file is accepted when its name contains "summary.<digits>" and ends with "_explain".
_SUMMARY_NAME_PATTERN = re.compile(r'summary\.\d+')


def _round(score):
    """Take round of a number to given precision."""
    return round(score, _NUM_DIGITS)


def _to_timestamp(new_time: Union[datetime, float]) -> float:
    """
    Convert a datetime or a float timestamp into a float timestamp.

    Args:
        new_time (datetime.datetime or float): The time to convert.

    Returns:
        float, the POSIX timestamp.

    Raises:
        TypeError: If `new_time` is neither a datetime.datetime nor a float.
    """
    if isinstance(new_time, datetime):
        return new_time.timestamp()
    if isinstance(new_time, float):
        return new_time
    raise TypeError('new_time should have type of datetime.datetime or float, but receive {}'
                    .format(type(new_time)))


class ExplainLoader:
    """ExplainLoader which manages the records in the explain summary files of one directory."""

    def __init__(self, loader_id: str, summary_dir: str):
        self._parser = ExplainParser(summary_dir)
        # Stat the directory once so all initial timestamps come from the same snapshot.
        dir_stat = os.stat(summary_dir)
        self._loader_info = {
            'loader_id': loader_id,
            'summary_dir': summary_dir,
            'create_time': dir_stat.st_ctime,
            'update_time': dir_stat.st_mtime,
            'query_time': dir_stat.st_ctime,
            'uncertainty_enabled': False,
        }
        self._samples = defaultdict(dict)
        self._metadata = {'explainers': [], 'metrics': [], 'labels': []}
        self._benchmark = {'explainer_score': defaultdict(dict), 'label_score': defaultdict(dict)}

    @property
    def all_classes(self) -> List[Dict]:
        """
        Return a list of detailed label information, including label id, label name and sample count of each label.

        Only available samples (with both an image and ground truth labels) are counted.

        Returns:
            list[dict], a list of dict, each dict contains:
                - id (int): label id
                - label (str): label name
                - sample_count (int): number of samples for each label
        """
        sample_count_per_label = defaultdict(int)
        # Iterate over a snapshot so concurrent imports do not change the dict size mid-iteration.
        for sample in self._samples.copy().values():
            if ExplainLoader._is_sample_available(sample):
                for label in sample['ground_truth_label']:
                    sample_count_per_label[label] += 1

        return [
            {
                'id': label_id,
                'label': label_name,
                'sample_count': sample_count_per_label[label_id],
            }
            for label_id, label_name in enumerate(self._metadata['labels'])
        ]

    @property
    def query_time(self) -> float:
        """Return query timestamp of explain loader."""
        return self._loader_info['query_time']

    @query_time.setter
    def query_time(self, new_time: Union[datetime, float]):
        """
        Update the query_time timestamp manually.

        Args:
            new_time (datetime.datetime or float): Updated query_time for the explain loader.

        Raises:
            TypeError: If `new_time` is neither a datetime.datetime nor a float.
        """
        self._loader_info['query_time'] = _to_timestamp(new_time)

    @property
    def create_time(self) -> float:
        """Return the create timestamp of summary file."""
        return self._loader_info['create_time']

    @create_time.setter
    def create_time(self, new_time: Union[datetime, float]):
        """
        Update the create_time manually.

        Args:
            new_time (datetime.datetime or float): Updated create_time of summary_file.

        Raises:
            TypeError: If `new_time` is neither a datetime.datetime nor a float.
        """
        self._loader_info['create_time'] = _to_timestamp(new_time)

    @property
    def explainers(self) -> List[str]:
        """Return a list of explainer names recorded in the summary file."""
        return self._metadata['explainers']

    @property
    def explainer_scores(self) -> List[Dict]:
        """
        Return evaluation results for every explainer.

        Returns:
            list[dict], A list of evaluation results of each explainer. Each item contains:
                - explainer (str): Name of evaluated explainer.
                - evaluations (list[dict]): A list of evaluation results by different metrics.
                - class_scores (list[dict]): A list of evaluation results on different labels.

            Each item in the evaluations contains:
                - metric (str): name of metric method
                - score (float): evaluation result

            Each item in the class_scores contains:
                - label (str): Name of label
                - evaluations (list[dict]): A list of evaluation results on different labels by different metrics.

            Each item in evaluations contains:
                - metric (str): Name of metric method
                - score (float): Evaluation scores of explainer on specific label by the metric.
        """
        explainer_scores = []
        for explainer, explainer_score_on_metric in self._benchmark['explainer_score'].copy().items():
            metric_scores = [{'metric': metric, 'score': _round(score)}
                             for metric, score in explainer_score_on_metric.items()]
            label_scores = []
            for label, label_score_on_metric in self._benchmark['label_score'][explainer].copy().items():
                label_scores.append({
                    'label': self._metadata['labels'][label],
                    'evaluations': [
                        {'metric': metric, 'score': _round(score)}
                        for metric, score in label_score_on_metric.items()
                    ],
                })
            explainer_scores.append({
                'explainer': explainer,
                'evaluations': metric_scores,
                'class_scores': label_scores,
            })
        return explainer_scores

    @property
    def labels(self) -> List[str]:
        """Return the labels recorded in the summary."""
        return self._metadata['labels']

    @property
    def metrics(self) -> List[str]:
        """Return a list of metric names recorded in the summary file."""
        return self._metadata['metrics']

    @property
    def min_confidence(self) -> Optional[float]:
        """Return minimum confidence used to filter the predicted labels (always None here)."""
        return None

    @property
    def sample_count(self) -> int:
        """
        Return total number of samples in the loader.

        Since the loader only returns available samples (i.e. with original image data and
        ground_truth_label loaded in cache), the returned count only takes the available samples
        into account.

        Returns:
            int, total number of available samples in the loading job.
        """
        return sum(1 for sample in self._samples.copy().values()
                   if ExplainLoader._is_sample_available(sample))

    @property
    def samples(self) -> List[Dict]:
        """Return the information of all available samples in the job."""
        return self.get_all_samples()

    @property
    def train_id(self) -> str:
        """Return ID of explain loader."""
        return self._loader_info['loader_id']

    @property
    def uncertainty_enabled(self):
        """Whether uncertainty data (probability standard deviations) has been loaded."""
        return self._loader_info['uncertainty_enabled']

    @property
    def update_time(self) -> float:
        """Return latest modification timestamp of summary file."""
        return self._loader_info['update_time']

    @update_time.setter
    def update_time(self, new_time: Union[datetime, float]):
        """
        Update the update_time manually.

        Args:
            new_time (datetime.datetime or float): Updated time for the summary file.

        Raises:
            TypeError: If `new_time` is neither a datetime.datetime nor a float.
        """
        self._loader_info['update_time'] = _to_timestamp(new_time)

    def load(self):
        """
        Start loading data from the latest summary file to the loader.

        Raises:
            TrainJobNotExistError: If no explain summary file exists in the summary directory.
        """
        summary_dir = self._loader_info['summary_dir']
        filenames = [filename for filename in FileHandler.list_dir(summary_dir)
                     if FileHandler.is_file(FileHandler.join(summary_dir, filename))]
        filenames = ExplainLoader._filter_files(filenames)
        if not filenames:
            raise TrainJobNotExistError('No summary file found in %s, explain job will be delete.'
                                        % summary_dir)

        is_end = False
        while not is_end:
            is_clean, is_end, event_dict = self._parser.parse_explain(filenames)
            if is_clean:
                # The parser signalled that the summary changed; drop everything cached so far.
                logger.info('Summary file in %s update, reload the data in the summary.', summary_dir)
                self._clear_job()
            if event_dict:
                self._import_data_from_event(event_dict)

    def get_all_samples(self) -> List[Dict]:
        """
        Return a list of sample information cached in the explain job.

        Samples lacking an image or ground truth labels, and samples whose inference data have
        inconsistent lengths, are skipped.

        Returns:
            list[dict], each dict consists of:
                - id (int): sample id
                - name (str): sample id as a string
                - image (str): image path of the sample
                - labels (list[int]): ground truth label ids
                - inferences (list[dict]): per-label confidence, optional uncertainty and saliency maps
        """
        returned_samples = []
        for sample_id, sample_info in self._samples.copy().items():
            # Require both fields, consistent with `sample_count` and `all_classes`; otherwise
            # reading sample_info['image'] below could raise KeyError.
            if not ExplainLoader._is_sample_available(sample_info):
                continue
            if not ExplainLoader._is_inference_valid(sample_info):
                continue
            returned_samples.append({
                'id': sample_id,
                'name': str(sample_id),
                'image': sample_info['image'],
                'labels': sample_info['ground_truth_label'],
                'inferences': self._build_inferences(sample_info),
            })
        return returned_samples

    def _build_inferences(self, sample_info: Dict) -> List[Dict]:
        """Assemble per-label inference records (confidence, uncertainty, saliency maps) of one sample."""
        inferences = {}
        # Zip each group separately: a prob list may legally be empty (see _is_inference_valid),
        # and zipping concatenated lists would then attach probabilities to the wrong labels.
        for prefix in ('ground_truth', 'predicted'):
            for label, prob in zip(sample_info[prefix + '_label'], sample_info[prefix + '_prob']):
                inferences[label] = {
                    'label': self._metadata['labels'][label],
                    'confidence': _round(prob),
                    'saliency_maps': []
                }

        for prefix in ('ground_truth', 'predicted'):
            for label, std, low, high in zip(sample_info[prefix + '_label'],
                                             sample_info[prefix + '_prob_sd'],
                                             sample_info[prefix + '_prob_itl95_low'],
                                             sample_info[prefix + '_prob_itl95_hi']):
                if label in inferences:
                    inferences[label]['confidence_sd'] = _round(std)
                    inferences[label]['confidence_itl95'] = [_round(low), _round(high)]

        for explainer, label_heatmap_path_dict in sample_info['explanation'].items():
            for label, heatmap_path in label_heatmap_path_dict.items():
                if label in inferences:
                    inferences[label]['saliency_maps'].append({'explainer': explainer, 'overlay': heatmap_path})

        return list(inferences.values())

    def _import_data_from_event(self, event_dict: Dict):
        """
        Parse and import data from the event data.

        Raises:
            ParamValueError: If no metadata has been loaded yet and this batch does not contain any.
        """
        if 'metadata' not in event_dict and self._is_metadata_empty():
            raise ParamValueError('metadata is imcomplete, should write metadata first in the summary.')

        for tag, event in event_dict.items():
            if tag == ExplainFieldsEnum.METADATA.value:
                self._import_metadata_from_event(event.metadata)
            elif tag == ExplainFieldsEnum.BENCHMARK.value:
                self._import_benchmark_from_event(event.benchmark)
            elif tag == ExplainFieldsEnum.SAMPLE_ID.value:
                self._import_sample_from_event(event)
            else:
                logger.info('Unknown ExplainField: %s', tag)

    def _is_metadata_empty(self):
        """Return True when no label metadata has been loaded yet."""
        return not self._metadata['labels']

    def _import_metadata_from_event(self, metadata_event):
        """Merge explainers, metrics and labels from the metadata event into the loader, keeping first-seen order."""

        def take_union(existed_list, imported_data):
            """Append items of imported_data that are not yet in existed_list."""
            if isinstance(imported_data, Iterable):
                for sample in imported_data:
                    if sample not in existed_list:
                        existed_list.append(sample)

        take_union(self._metadata['explainers'], metadata_event.explain_method)
        take_union(self._metadata['metrics'], metadata_event.benchmark_method)
        take_union(self._metadata['labels'], metadata_event.label)

    def _import_benchmark_from_event(self, benchmarks):
        """
        Parse the benchmark event.

        Benchmark data are separated into 'explainer_score' and 'label_score'. 'explainer_score' contains overall
        evaluation results of each explainer by different metrics, while 'label_score' additionally divides the results
        w.r.t different labels.

        The structure of self._benchmark['explainer_score'] demonstrates below:
            {
                explainer_1: {metric_name_1: score_1, ...},
                explainer_2: {metric_name_1: score_1, ...},
                ...
            }

        The structure of self._benchmark['label_score'] is:
            {
                explainer_1: {label_id: {metric_1: score_1, metric_2: score_2, ...}, ...},
                explainer_2: {label_id: {metric_1: score_1, metric_2: score_2, ...}, ...},
                ...
            }

        Args:
            benchmarks (benchmark_container): Parsed benchmarks data from summary file.
        """
        explainer_score = self._benchmark['explainer_score']
        label_score = self._benchmark['label_score']

        for benchmark in benchmarks:
            explainer = benchmark.explain_method
            metric = benchmark.benchmark_method
            explainer_score[explainer][metric] = benchmark.total_score

            new_label_score_dict = ExplainLoader._score_event_to_dict(benchmark.label_score, metric)
            for label, scores_of_metric in new_label_score_dict.items():
                label_score[explainer].setdefault(label, {}).update(scores_of_metric)

    def _import_sample_from_event(self, sample):
        """
        Parse the sample event.

        Detailed data of each sample are stored in self._samples, identified by sample_id. Each sample entry
        holds the following keys:
            - image (str): Image path of the sample (only present once an event carried one).
            - ground_truth_label (list[int]): Ground truth label ids of the sample.
            - ground_truth_prob (list[float]): Confidences of the ground truth labels from the black-box model.
            - ground_truth_prob_sd / _itl95_low / _itl95_hi (list[float]): Uncertainty data of those confidences.
            - predicted_label (list[int]): Predicted label ids from the black-box model.
            - predicted_prob (list[float]): Confidences of the predicted labels.
            - predicted_prob_sd / _itl95_low / _itl95_hi (list[float]): Uncertainty data of those confidences.
            - explanation (dict): Maps each explainer name to a dictionary of saliency maps:
                {
                    explainer_name_1: {label_1: heatmap_path_1, label_2: heatmap_path_2, ...},
                    explainer_name_2: {label_1: heatmap_path_1, label_2: heatmap_path_2, ...},
                    ...
                }

        Raises:
            ParamValueError: If the event carries no (or a falsy) sample_id.
        """
        if not getattr(sample, 'sample_id', None):
            raise ParamValueError('sample_event has no sample_id')
        sample_id = sample.sample_id
        # Membership test only; copying the whole cache here made bulk loading quadratic.
        if sample_id not in self._samples:
            self._samples[sample_id] = {
                'ground_truth_label': [],
                'ground_truth_prob': [],
                'ground_truth_prob_sd': [],
                'ground_truth_prob_itl95_low': [],
                'ground_truth_prob_itl95_hi': [],
                'predicted_label': [],
                'predicted_prob': [],
                'predicted_prob_sd': [],
                'predicted_prob_itl95_low': [],
                'predicted_prob_itl95_hi': [],
                'explanation': defaultdict(dict)
            }

        if sample.image_path:
            self._samples[sample_id]['image'] = sample.image_path

        for tag in _SAMPLE_FIELD_NAMES:
            try:
                if ExplainLoader._is_attr_empty(sample, tag.value):
                    continue
                if tag == ExplainFieldsEnum.GROUND_TRUTH_LABEL:
                    self._samples[sample_id]['ground_truth_label'].extend(list(sample.ground_truth_label))
                elif tag == ExplainFieldsEnum.INFERENCE:
                    self._import_inference_from_event(sample, sample_id)
                elif tag == ExplainFieldsEnum.EXPLANATION:
                    self._import_explanation_from_event(sample, sample_id)
            except UnknownError as ex:
                logger.warning("Parse %s data failed within image related data, detail: %r", tag, str(ex))

    def _import_inference_from_event(self, event, sample_id):
        """Append the inference data of the event to the cached sample and update the uncertainty flag."""
        inference = event.inference
        cached_sample = self._samples[sample_id]
        for field_name in _INFERENCE_FIELD_NAMES:
            cached_sample[field_name].extend(list(getattr(inference, field_name)))
        if cached_sample['ground_truth_prob_sd'] or cached_sample['predicted_prob_sd']:
            self._loader_info['uncertainty_enabled'] = True

    def _import_explanation_from_event(self, event, sample_id):
        """Record the heatmap path of each (explainer, label) pair carried by the explanation event."""
        if self._samples[sample_id]['explanation'] is None:
            self._samples[sample_id]['explanation'] = defaultdict(dict)
        sample_explanation = self._samples[sample_id]['explanation']

        for explanation_item in event.explanation:
            explainer = explanation_item.explain_method
            label = explanation_item.label
            sample_explanation[explainer][label] = explanation_item.heatmap_path

    def _clear_job(self):
        """Clear the cached samples, metadata and benchmarks, and update the time info of the loader."""
        self._samples.clear()
        dir_stat = os.stat(self._loader_info['summary_dir'])
        self._loader_info['create_time'] = dir_stat.st_ctime
        self._loader_info['update_time'] = dir_stat.st_mtime
        self._loader_info['query_time'] = max(self._loader_info['update_time'], self._loader_info['query_time'])
        # The flag is derived from sample data, which has just been dropped.
        self._loader_info['uncertainty_enabled'] = False

        # Explicit loop: map() is lazy in Python 3 and would never run the clearing callback.
        for outer_dict in (self._metadata, self._benchmark):
            for inner in outer_dict.values():
                inner.clear()

    @staticmethod
    def _filter_files(filenames):
        """
        Gets a list of summary files.

        Args:
            filenames (list[str]): File name list, like [filename1, filename2].

        Returns:
            list[str], names containing "summary.<digits>" and ending with "_explain".
        """
        return [filename for filename in filenames
                if _SUMMARY_NAME_PATTERN.search(filename) and filename.endswith("_explain")]

    @staticmethod
    def _is_sample_available(sample: Dict) -> bool:
        """Return True when the cached sample has both an image and at least one ground truth label."""
        return bool(sample.get('image', False) and sample.get('ground_truth_label', False))

    @staticmethod
    def _is_attr_empty(event, attr_name) -> bool:
        """Return True if the attribute is falsy or every item in it is an empty list."""
        attr_value = getattr(event, attr_name)
        if not attr_value:
            return True
        for item in attr_value:
            if not isinstance(item, list) or item:
                return False
        return True

    @staticmethod
    def _is_ground_truth_label_valid(sample_id: str, sample_info: Dict) -> bool:
        """Return True if ground_truth_label and ground_truth_prob have equal lengths; log otherwise."""
        if len(sample_info['ground_truth_label']) != len(sample_info['ground_truth_prob']):
            logger.info('length of ground_truth_prob does not match the length of ground_truth_label. '
                        'length of ground_truth_label is: %s but length of ground_truth_prob is: %s. '
                        'sample_id is : %s.',
                        len(sample_info['ground_truth_label']), len(sample_info['ground_truth_prob']), sample_id)
            return False
        return True

    @staticmethod
    def _is_inference_valid(sample):
        """
        Check whether the inference data is empty or have the same length.

        If the probs have different length with the labels, it can be confusing when assigning each prob to label.
        '_is_inference_valid' returns True only when the data sizes match each other. Note that prob data could be
        empty, so empty prob will pass the check.
        """
        ground_truth_len = len(sample['ground_truth_label'])
        for name in ['ground_truth_prob', 'ground_truth_prob_sd',
                     'ground_truth_prob_itl95_low', 'ground_truth_prob_itl95_hi']:
            if sample[name] and len(sample[name]) != ground_truth_len:
                return False

        predicted_len = len(sample['predicted_label'])
        for name in ['predicted_prob', 'predicted_prob_sd',
                     'predicted_prob_itl95_low', 'predicted_prob_itl95_hi']:
            if sample[name] and len(sample[name]) != predicted_len:
                return False
        return True

    @staticmethod
    def _is_predicted_label_valid(sample_id: str, sample_info: Dict) -> bool:
        """Return True if predicted_label and predicted_prob have equal lengths; log otherwise."""
        if len(sample_info['predicted_label']) != len(sample_info['predicted_prob']):
            logger.info('length of predicted_probs does not match the length of predicted_labels. '
                        'length of predicted_probs: %s but receive length of predicted_label: %s, sample_id: %s.',
                        len(sample_info['predicted_prob']), len(sample_info['predicted_label']), sample_id)
            return False
        return True

    @staticmethod
    def _score_event_to_dict(label_score_event, metric) -> Dict:
        """Convert per-label scores (indexed by label id) into {label_id: {metric: score}}."""
        new_label_score_dict = defaultdict(dict)
        for label_id, label_score in enumerate(label_score_event):
            new_label_score_dict[label_id][metric] = label_score
        return new_label_score_dict