
explain_loader.py

# Copyright 2020-2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""ExplainLoader."""
import math
import os
import re
import threading
from collections import defaultdict
from datetime import datetime
from enum import Enum
from typing import Dict, Iterable, List, Optional, Union

from mindinsight.datavisual.common.exceptions import TrainJobNotExistError
from mindinsight.datavisual.data_access.file_handler import FileHandler
from mindinsight.explainer.common.enums import ExplainFieldsEnum
from mindinsight.explainer.common.log import logger
from mindinsight.explainer.manager.explain_parser import ExplainParser
from mindinsight.utils.exceptions import ParamValueError, UnknownError

_NAN_CONSTANT = 'NaN'
_NUM_DIGITS = 6

_EXPLAIN_FIELD_NAMES = [
    ExplainFieldsEnum.SAMPLE_ID,
    ExplainFieldsEnum.BENCHMARK,
    ExplainFieldsEnum.METADATA,
]

_SAMPLE_FIELD_NAMES = [
    ExplainFieldsEnum.GROUND_TRUTH_LABEL,
    ExplainFieldsEnum.INFERENCE,
    ExplainFieldsEnum.EXPLANATION,
    ExplainFieldsEnum.HIERARCHICAL_OCCLUSION,
]


class _LoaderStatus(Enum):
    STOP = 'STOP'
    LOADING = 'LOADING'
    PENDING = 'PENDING'
    LOADED = 'LOADED'


def _round(score):
    """Round a number to the given precision."""
    try:
        return round(score, _NUM_DIGITS)
    except TypeError:
        return score


class ExplainLoader:
    """ExplainLoader, which manages the records in a summary file."""

    def __init__(self,
                 loader_id: str,
                 summary_dir: str):
        self._parser = ExplainParser(summary_dir)

        self._loader_info = {
            'loader_id': loader_id,
            'summary_dir': summary_dir,
            'create_time': os.stat(summary_dir).st_ctime,
            'update_time': os.stat(summary_dir).st_mtime,
            'query_time': os.stat(summary_dir).st_ctime,
            'uncertainty_enabled': False,
        }
        self._samples = defaultdict(dict)
        self._metadata = {'explainers': [], 'metrics': [], 'labels': [], 'min_confidence': 0.5}
        self._benchmark = {'explainer_score': defaultdict(dict), 'label_score': defaultdict(dict)}

        self._status = _LoaderStatus.PENDING.value
        self._status_mutex = threading.Lock()

    @property
    def all_classes(self) -> List[Dict]:
        """
        Return a list of detailed label information, including label id, label name and sample count of each label.

        Returns:
            list[dict], a list of dicts, each containing:

                - id (int): Label id.
                - label (str): Label name.
                - sample_count (int): Number of samples for the label.
                - saliency_sample_count (int): Number of samples with saliency maps for the label.
                - hoc_sample_count (int): Number of samples with HOC results for the label.
        """
        sample_count_per_label = defaultdict(int)
        saliency_count_per_label = defaultdict(int)
        hoc_count_per_label = defaultdict(int)
        for sample in self._samples.values():
            if sample.get('image') and (sample.get('ground_truth_label') or sample.get('predicted_label')):
                for label in set(sample['ground_truth_label'] + sample['predicted_label']):
                    sample_count_per_label[label] += 1
                    if sample['inferences'][label]['saliency_maps']:
                        saliency_count_per_label[label] += 1
                    if sample['inferences'][label]['hoc_layers']:
                        hoc_count_per_label[label] += 1

        all_classes_return = [{'id': label_id,
                               'label': label_name,
                               'sample_count': sample_count_per_label[label_id],
                               'saliency_sample_count': saliency_count_per_label[label_id],
                               'hoc_sample_count': hoc_count_per_label[label_id]}
                              for label_id, label_name in enumerate(self._metadata['labels'])]
        return all_classes_return

    @property
    def query_time(self) -> float:
        """Return the query timestamp of the explain loader."""
        return self._loader_info['query_time']

    @query_time.setter
    def query_time(self, new_time: Union[datetime, float]):
        """
        Update the query_time timestamp manually.

        Args:
            new_time (datetime.datetime or float): Updated query_time for the explain loader.
        """
        if isinstance(new_time, datetime):
            self._loader_info['query_time'] = new_time.timestamp()
        elif isinstance(new_time, float):
            self._loader_info['query_time'] = new_time
        else:
            raise TypeError('new_time should have type of datetime.datetime or float, but received {}'
                            .format(type(new_time)))

    @property
    def create_time(self) -> float:
        """Return the create timestamp of the summary file."""
        return self._loader_info['create_time']

    @create_time.setter
    def create_time(self, new_time: Union[datetime, float]):
        """
        Update the create_time manually.

        Args:
            new_time (datetime.datetime or float): Updated create_time of the summary file.
        """
        if isinstance(new_time, datetime):
            self._loader_info['create_time'] = new_time.timestamp()
        elif isinstance(new_time, float):
            self._loader_info['create_time'] = new_time
        else:
            raise TypeError('new_time should have type of datetime.datetime or float, but received {}'
                            .format(type(new_time)))

    @property
    def explainers(self) -> List[str]:
        """Return a list of explainer names recorded in the summary file."""
        return self._metadata['explainers']

    @property
    def explainer_scores(self) -> List[Dict]:
        """
        Return evaluation results for every explainer.

        Returns:
            list[dict], a list of evaluation results, one per explainer. Each item contains:

                - explainer (str): Name of the evaluated explainer.
                - evaluations (list[dict]): A list of evaluation results by different metrics.
                - class_scores (list[dict]): A list of evaluation results on different labels.

            Each item in the evaluations contains:

                - metric (str): Name of the metric method.
                - score (float): Evaluation result.

            Each item in the class_scores contains:

                - label (str): Name of the label.
                - evaluations (list[dict]): A list of evaluation results on the label by different metrics.
                  Each item in evaluations contains:

                    - metric (str): Name of the metric method.
                    - score (float): Evaluation score of the explainer on the specific label by the metric.
        """
        explainer_scores = []
        for explainer, explainer_score_on_metric in self._benchmark['explainer_score'].copy().items():
            metric_scores = [{'metric': metric, 'score': _round(score)}
                             for metric, score in explainer_score_on_metric.items()]
            label_scores = []
            for label, label_score_on_metric in self._benchmark['label_score'][explainer].copy().items():
                score_of_single_label = {
                    'label': self._metadata['labels'][label],
                    'evaluations': [
                        {'metric': metric, 'score': _round(score)} for metric, score in label_score_on_metric.items()
                    ],
                }
                label_scores.append(score_of_single_label)
            explainer_scores.append({
                'explainer': explainer,
                'evaluations': metric_scores,
                'class_scores': label_scores,
            })
        return explainer_scores

    @property
    def labels(self) -> List[str]:
        """Return the labels recorded in the summary."""
        return self._metadata['labels']

    @property
    def metrics(self) -> List[str]:
        """Return a list of metric names recorded in the summary file."""
        return self._metadata['metrics']

    @property
    def min_confidence(self) -> Optional[float]:
        """Return the minimum confidence used to filter the predicted labels."""
        return self._metadata['min_confidence']

    @property
    def sample_count(self) -> int:
        """
        Return the total number of samples in the loader.

        Since the loader only returns available samples (i.e. those with original image data and ground_truth_label
        loaded in cache), the returned count only takes the available samples into account.

        Returns:
            int, total number of available samples in the loading job.
        """
        sample_count = 0
        for sample in self._samples.values():
            if sample.get('image', False):
                sample_count += 1
        return sample_count

    @property
    def samples(self) -> Dict:
        """Return the information of all samples in the job, keyed by sample id."""
        return self._samples

    @property
    def train_id(self) -> str:
        """Return the ID of the explain loader."""
        return self._loader_info['loader_id']

    @property
    def uncertainty_enabled(self):
        """Whether uncertainty is enabled."""
        return self._loader_info['uncertainty_enabled']

    @property
    def update_time(self) -> float:
        """Return the latest modification timestamp of the summary file."""
        return self._loader_info['update_time']

    @update_time.setter
    def update_time(self, new_time: Union[datetime, float]):
        """
        Update the update_time manually.

        Args:
            new_time (datetime.datetime or float): Updated time for the summary file.
        """
        if isinstance(new_time, datetime):
            self._loader_info['update_time'] = new_time.timestamp()
        elif isinstance(new_time, float):
            self._loader_info['update_time'] = new_time
        else:
            raise TypeError('new_time should have type of datetime.datetime or float, but received {}'
                            .format(type(new_time)))

    def load(self):
        """Start loading data from the latest summary file to the loader."""
        if self.status != _LoaderStatus.LOADED.value:
            self.status = _LoaderStatus.LOADING.value

        filenames = []
        for filename in FileHandler.list_dir(self._loader_info['summary_dir']):
            if FileHandler.is_file(FileHandler.join(self._loader_info['summary_dir'], filename)):
                filenames.append(filename)
        filenames = ExplainLoader._filter_files(filenames)

        if not filenames:
            raise TrainJobNotExistError('No summary file found in %s, explain job will be deleted.'
                                        % self._loader_info['summary_dir'])

        is_end = False
        while not is_end and self.status != _LoaderStatus.STOP.value:
            try:
                file_changed, is_end, event_dict = self._parser.list_events(filenames)
            except UnknownError:
                is_end = True
                break

            if file_changed:
                logger.info('Summary file in %s updated, reloading the data in the summary.',
                            self._loader_info['summary_dir'])
                self._clear_job()
                if self.status != _LoaderStatus.STOP.value:
                    self.status = _LoaderStatus.LOADING.value
            if event_dict:
                self._import_data_from_event(event_dict)
                self._reform_sample_info()
            if is_end:
                self.status = _LoaderStatus.LOADED.value

    @property
    def status(self):
        """Get the status of this class with lock."""
        with self._status_mutex:
            return self._status

    @status.setter
    def status(self, status):
        """Set the status of this class with lock."""
        with self._status_mutex:
            self._status = status

    def stop(self):
        """Stop loading data."""
        self.status = _LoaderStatus.STOP.value

    def get_all_samples(self) -> List[Dict]:
        """
        Return a list of sample information cached in the explain job.

        Returns:
            sample_list (list[SampleObj]): a list of sample objects, each consisting of:

                - id (int): Sample id.
                - name (str): Basename of the image.
                - inferences (list[dict]): List of inferences for all labels.
        """
        returned_samples = [{'id': sample_id, 'name': info['name'], 'image': info['image'],
                             'inferences': list(info['inferences'].values())} for sample_id, info in
                            self._samples.items() if info.get('image', False)]
        return returned_samples

    def _import_data_from_event(self, event_dict: Dict):
        """Parse and import data from the event data."""
        if 'metadata' not in event_dict and self._is_metadata_empty():
            raise ParamValueError('metadata is incomplete, should write metadata first in the summary.')

        for tag, event in event_dict.items():
            if tag == ExplainFieldsEnum.METADATA.value:
                self._import_metadata_from_event(event.metadata)
            elif tag == ExplainFieldsEnum.BENCHMARK.value:
                self._import_benchmark_from_event(event.benchmark)
            elif tag == ExplainFieldsEnum.SAMPLE_ID.value:
                self._import_sample_from_event(event)
            else:
                logger.info('Unknown ExplainField: %s.', tag)

    def _is_metadata_empty(self):
        """Check whether the metadata has been loaded."""
        if not self._metadata['labels']:
            return True
        return False

    def _import_metadata_from_event(self, metadata_event):
        """Import the metadata from an event into the loader."""

        def take_union(existed_list, imported_data):
            """Take the union of existed_list and imported_data."""
            if isinstance(imported_data, Iterable):
                for sample in imported_data:
                    if sample not in existed_list:
                        existed_list.append(sample)
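
        # Metadata may be written across several events; merge new values into
        # the cached lists instead of overwriting them.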
        take_union(self._metadata['explainers'], metadata_event.explain_method)
        take_union(self._metadata['metrics'], metadata_event.benchmark_method)
        take_union(self._metadata['labels'], metadata_event.label)

    def _import_benchmark_from_event(self, benchmarks):
        """
        Parse the benchmark event.

        Benchmark data are separated into 'explainer_score' and 'label_score'. 'explainer_score' contains the overall
        evaluation results of each explainer by different metrics, while 'label_score' additionally divides the
        results w.r.t. different labels.

        The structure of self._benchmark['explainer_score'] is shown below:

            {
                explainer_1: {metric_name_1: score_1, ...},
                explainer_2: {metric_name_1: score_1, ...},
                ...
            }

        The structure of self._benchmark['label_score'] is:

            {
                explainer_1: {label_id: {metric_1: score_1, metric_2: score_2, ...}, ...},
                explainer_2: {label_id: {metric_1: score_1, metric_2: score_2, ...}, ...},
                ...
            }

        Args:
            benchmarks (BenchmarkContainer): Parsed benchmark data from the summary file.
        """
        explainer_score = self._benchmark['explainer_score']
        label_score = self._benchmark['label_score']

        for benchmark in benchmarks:
            explainer = benchmark.explain_method
            metric = benchmark.benchmark_method
            metric_score = benchmark.total_score
            label_score_event = benchmark.label_score

            explainer_score[explainer][metric] = _NAN_CONSTANT if math.isnan(metric_score) else metric_score
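            # NaN scores are stored as the string constant 'NaN' rather than a
            # float NaN, presumably to keep the scores JSON-serializable.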
            new_label_score_dict = ExplainLoader._score_event_to_dict(label_score_event, metric)
            for label, scores_of_metric in new_label_score_dict.items():
                if label not in label_score[explainer]:
                    label_score[explainer][label] = {}
                label_score[explainer][label].update(scores_of_metric)

    def _import_sample_from_event(self, sample):
        """
        Parse the sample event.

        Detailed data of each sample are stored in self._samples, identified by sample_id. Each sample is stored
        in the following structure:

            - ground_truth_labels (list[int]): A list of ground truth labels of the sample.
            - ground_truth_probs (list[float]): A list of confidences of the ground-truth labels from the black-box
              model.
            - predicted_labels (list[int]): A list of predicted labels from the black-box model.
            - predicted_probs (list[float]): A list of confidences w.r.t. the predicted labels.
            - explanations (dict): A dictionary where each explainer name maps to a dictionary of saliency maps.
              The structure of explanations is shown below:

                {
                    explainer_name_1: {label_1: saliency_id_1, label_2: saliency_id_2, ...},
                    explainer_name_2: {label_1: saliency_id_1, label_2: saliency_id_2, ...},
                    ...
                }

            - hierarchical_occlusion (dict): A dictionary where each label maps to a list of HOC layers:

                {
                    label_1: [{prob: layer1_prob, bbox: []}, {prob: layer2_prob, bbox: []}, ...],
                    label_2: ...,
                    ...
                }
        """
        if getattr(sample, 'sample_id', None) is None:
            raise ParamValueError('sample_event has no sample_id')

        sample_id = sample.sample_id
        if sample_id not in self._samples:
            self._samples[sample_id] = {
                'id': sample_id,
                'name': str(sample_id),
                'image': sample.image_path,
                'ground_truth_label': [],
                'predicted_label': [],
                'inferences': defaultdict(dict),
                'explanation': defaultdict(dict),
                'hierarchical_occlusion': defaultdict(dict)
            }

        if sample.image_path:
            self._samples[sample_id]['image'] = sample.image_path

        for tag in _SAMPLE_FIELD_NAMES:
            if tag == ExplainFieldsEnum.GROUND_TRUTH_LABEL:
                if not self._samples[sample_id]['ground_truth_label']:
                    self._samples[sample_id]['ground_truth_label'].extend(list(sample.ground_truth_label))
            elif tag == ExplainFieldsEnum.INFERENCE:
                self._import_inference_from_event(sample, sample_id)
            elif tag == ExplainFieldsEnum.EXPLANATION:
                self._import_explanation_from_event(sample, sample_id)
            elif tag == ExplainFieldsEnum.HIERARCHICAL_OCCLUSION:
                self._import_hoc_from_event(sample, sample_id)

    def _reform_sample_info(self):
        """Rearrange cached sample info: attach saliency maps and HOC layers to the per-label inference entries."""
        for _, sample_info in self._samples.items():
            inferences = sample_info['inferences']
            res_dict = defaultdict(list)

            for explainer, label_heatmap_path_dict in sample_info['explanation'].items():
                for label, heatmap_path in label_heatmap_path_dict.items():
                    res_dict[label].append({'explainer': explainer, 'overlay': heatmap_path})

            for label, item in inferences.items():
                item['saliency_maps'] = res_dict[label]

            for label, item in sample_info['hierarchical_occlusion'].items():
                inferences[label]['hoc_layers'] = item['hoc_layers']

    def _import_inference_from_event(self, event, sample_id):
        """Parse the inference event."""
        inference = event.inference
        if inference.ground_truth_prob_sd or inference.predicted_prob_sd:
            self._loader_info['uncertainty_enabled'] = True
        if not self._samples[sample_id]['predicted_label']:
            self._samples[sample_id]['predicted_label'].extend(list(inference.predicted_label))
        if not self._samples[sample_id]['inferences']:
            inferences = {}
            for label, prob in zip(list(event.ground_truth_label) + list(inference.predicted_label),
                                   list(inference.ground_truth_prob) + list(inference.predicted_prob)):
                inferences[label] = {
                    'label': self._metadata['labels'][label],
                    'confidence': _round(prob),
                    'saliency_maps': [],
                    'hoc_layers': {},
                }
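                # Classify each label against the ground truth: labels below the
                # confidence threshold are treated as misses (FN); confident
                # labels count as TP when present in the ground truth, else FP.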
                if not event.ground_truth_label:
                    inferences[label]['prediction_type'] = None
                else:
                    if prob < self.min_confidence:
                        inferences[label]['prediction_type'] = 'FN'
                    elif label in event.ground_truth_label:
                        inferences[label]['prediction_type'] = 'TP'
                    else:
                        inferences[label]['prediction_type'] = 'FP'
            if self._loader_info['uncertainty_enabled']:
                for label, std, low, high in zip(
                        list(event.ground_truth_label) + list(inference.predicted_label),
                        list(inference.ground_truth_prob_sd) + list(inference.predicted_prob_sd),
                        list(inference.ground_truth_prob_itl95_low) + list(inference.predicted_prob_itl95_low),
                        list(inference.ground_truth_prob_itl95_hi) + list(inference.predicted_prob_itl95_hi)):
                    inferences[label]['confidence_sd'] = _round(std)
                    inferences[label]['confidence_itl95'] = [_round(low), _round(high)]
            self._samples[sample_id]['inferences'] = inferences

    def _import_explanation_from_event(self, event, sample_id):
        """Parse the explanation event."""
        if self._samples[sample_id]['explanation'] is None:
            self._samples[sample_id]['explanation'] = defaultdict(dict)
        sample_explanation = self._samples[sample_id]['explanation']

        for explanation_item in event.explanation:
            explainer = explanation_item.explain_method
            label = explanation_item.label
            sample_explanation[explainer][label] = explanation_item.heatmap_path

    def _import_hoc_from_event(self, event, sample_id):
        """Parse the hierarchical occlusion (HOC) event."""
        sample_hoc = self._samples[sample_id]['hierarchical_occlusion']
        if event.hierarchical_occlusion:
            for hoc_item in event.hierarchical_occlusion:
                label = hoc_item.label
                sample_hoc[label] = {}
                sample_hoc[label]['label'] = label
                sample_hoc[label]['mask'] = hoc_item.mask
                sample_hoc[label]['confidence'] = self._samples[sample_id]['inferences'][label]['confidence']
                sample_hoc[label]['hoc_layers'] = []
                for hoc_layer in hoc_item.layer:
                    sample_hoc_dict = {'confidence': hoc_layer.prob}
                    box_lst = list(hoc_layer.box)
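                    # hoc_layer.box is a flat list of coordinates; every 4
                    # consecutive values form one bounding box.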
                    box = [box_lst[i: i + 4] for i in range(0, len(box_lst), 4)]
                    sample_hoc_dict['boxes'] = box
                    sample_hoc[label]['hoc_layers'].append(sample_hoc_dict)

    def _clear_job(self):
        """Clear the cached data and update the time info of the loader."""
        self._samples.clear()
        self._loader_info['create_time'] = os.stat(self._loader_info['summary_dir']).st_ctime
        self._loader_info['update_time'] = os.stat(self._loader_info['summary_dir']).st_mtime
        self._loader_info['query_time'] = max(self._loader_info['update_time'], self._loader_info['query_time'])

        def clear_inner_dict(outer_dict):
            """Clear the inner structured data of the given dict."""
            for item in outer_dict.values():
                # Skip scalar entries (e.g. 'min_confidence') that have no clear().
                if hasattr(item, 'clear'):
                    item.clear()

        # map() is lazy in Python 3; iterate explicitly so the clearing runs.
        for inner_dict in (self._metadata, self._benchmark):
            clear_inner_dict(inner_dict)

    @staticmethod
    def _filter_files(filenames):
        """
        Get the list of summary files.

        Args:
            filenames (list[str]): File name list, like [filename1, filename2].

        Returns:
            list[str], filename list.
        """
        return list(filter(
            lambda filename: (re.search(r'summary\.\d+', filename) and filename.endswith("_explain")), filenames))

    @staticmethod
    def _is_inference_valid(sample):
        """
        Check whether the inference data are empty or have matching lengths.

        If the probs have a different length from the labels, it can be confusing when assigning each prob to a
        label. '_is_inference_valid' returns True only when the data sizes match each other. Note that prob data
        could be empty, so empty probs will pass the check.
        """
        ground_truth_len = len(sample['ground_truth_label'])
        for name in ['ground_truth_prob', 'ground_truth_prob_sd',
                     'ground_truth_prob_itl95_low', 'ground_truth_prob_itl95_hi']:
            if sample[name] and len(sample[name]) != ground_truth_len:
                logger.info('Length of %s does not match the ground_truth_label. Length of ground_truth_label: %d, '
                            'length of %s: %d.', name, ground_truth_len, name, len(sample[name]))
                return False

        predicted_len = len(sample['predicted_label'])
        for name in ['predicted_prob', 'predicted_prob_sd',
                     'predicted_prob_itl95_low', 'predicted_prob_itl95_hi']:
            if sample[name] and len(sample[name]) != predicted_len:
                logger.info('Length of %s does not match the predicted_label. Length of predicted_label: %d, '
                            'length of %s: %d.', name, predicted_len, name, len(sample[name]))
                return False
        return True

    @staticmethod
    def _score_event_to_dict(label_score_event, metric) -> Dict:
        """Convert per-label metric scores into the pre-defined benchmark structure."""
        new_label_score_dict = defaultdict(dict)
        for label_id, label_score in enumerate(label_score_event):
            new_label_score_dict[label_id][metric] = _NAN_CONSTANT if math.isnan(label_score) else label_score
        return new_label_score_dict
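
For reference, a minimal sketch of how the loader can be driven directly. The loader id and summary directory below are hypothetical; within MindInsight, ExplainLoader instances are normally created and managed by the explainer manager rather than constructed by hand.

    # Hypothetical standalone usage; the directory must contain MindSpore
    # explain summary files (names matching ExplainLoader._filter_files).
    loader = ExplainLoader(loader_id='explain_job', summary_dir='/tmp/explain_job')
    loader.load()                           # parse events until the files end
    print(loader.sample_count)              # samples with image data available
    for sample in loader.get_all_samples():
        print(sample['id'], sample['name'], len(sample['inferences']))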