You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

event_parse.py 7.1 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168
  1. # Copyright 2020 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ============================================================================
  15. """EventParser for summary event."""
  16. from collections import namedtuple, defaultdict
  17. from typing import Dict, List, Optional, Tuple
  18. from mindinsight.explainer.common.enums import PluginNameEnum
  19. from mindinsight.explainer.common.log import logger
  20. from mindinsight.utils.exceptions import UnknownError
  21. _IMAGE_DATA_TAGS = {
  22. 'sample_id': PluginNameEnum.SAMPLE_ID.value,
  23. 'ground_truth_label': PluginNameEnum.GROUND_TRUTH_LABEL.value,
  24. 'inference': PluginNameEnum.INFERENCE.value,
  25. 'explanation': PluginNameEnum.EXPLANATION.value
  26. }
  27. _NUM_DIGIT = 7
  28. class EventParser:
  29. """Parser for event data."""
  30. def __init__(self, job):
  31. self._job = job
  32. self._sample_pool = {}
  33. @staticmethod
  34. def parse_metadata(metadata) -> Tuple[List, List, List]:
  35. """Parse the metadata event."""
  36. explainers = list(metadata.explain_method)
  37. metrics = list(metadata.benchmark_method)
  38. labels = list(metadata.label)
  39. return explainers, metrics, labels
  40. @staticmethod
  41. def parse_benchmark(benchmarks) -> Tuple[Dict, Dict]:
  42. """Parse the benchmark event."""
  43. explainer_score_dict = defaultdict(list)
  44. label_score_dict = defaultdict(dict)
  45. for benchmark in benchmarks:
  46. explainer = benchmark.explain_method
  47. metric = benchmark.benchmark_method
  48. metric_score = benchmark.total_score
  49. label_score_event = benchmark.label_score
  50. explainer_score_dict[explainer].append({
  51. 'metric': metric,
  52. 'score': round(metric_score, _NUM_DIGIT)})
  53. new_label_score_dict = EventParser._score_event_to_dict(label_score_event, metric)
  54. for label, label_scores in new_label_score_dict.items():
  55. label_score_dict[explainer][label] = label_score_dict[explainer].get(label, []) + label_scores
  56. return explainer_score_dict, label_score_dict
  57. def parse_sample(self, sample: namedtuple) -> Optional[namedtuple]:
  58. """Parse the sample event."""
  59. sample_id = sample.sample_id
  60. if sample_id not in self._sample_pool:
  61. self._sample_pool[sample_id] = sample
  62. return None
  63. for tag in _IMAGE_DATA_TAGS:
  64. try:
  65. if tag == PluginNameEnum.INFERENCE.value:
  66. self._parse_inference(sample, sample_id)
  67. elif tag == PluginNameEnum.EXPLANATION.value:
  68. self._parse_explanation(sample, sample_id)
  69. else:
  70. self._parse_sample_info(sample, sample_id, tag)
  71. except UnknownError as ex:
  72. logger.warning("Parse %s data failed within image related data, detail: %r", tag, str(ex))
  73. continue
  74. if EventParser._is_ready_for_display(self._sample_pool[sample_id]):
  75. return self._sample_pool[sample_id]
  76. return None
  77. def clear(self):
  78. """Clear the loaded data."""
  79. self._sample_pool.clear()
  80. @staticmethod
  81. def _is_ready_for_display(image_container: namedtuple) -> bool:
  82. """
  83. Check whether the image_container is ready for frontend display.
  84. Args:
  85. image_container (namedtuple): container consists of sample data
  86. Return:
  87. bool: whether the image_container if ready for display
  88. """
  89. required_attrs = ['image_path', 'ground_truth_label', 'inference']
  90. for attr in required_attrs:
  91. if not EventParser.is_attr_ready(image_container, attr):
  92. return False
  93. return True
  94. @staticmethod
  95. def is_attr_ready(image_container: namedtuple, attr: str) -> bool:
  96. """
  97. Check whether the given attribute is ready in image_container.
  98. Args:
  99. image_container (namedtuple): container consist of sample data
  100. attr (str): attribute to check
  101. Returns:
  102. bool, whether the attr is ready
  103. """
  104. if getattr(image_container, attr, False):
  105. return True
  106. return False
  107. @staticmethod
  108. def _score_event_to_dict(label_score_event, metric):
  109. """Transfer metric scores per label to pre-defined structure."""
  110. new_label_score_dict = defaultdict(list)
  111. for label_id, label_score in enumerate(label_score_event):
  112. new_label_score_dict[label_id].append({
  113. 'metric': metric,
  114. 'score': round(label_score, _NUM_DIGIT),
  115. })
  116. return new_label_score_dict
  117. def _parse_inference(self, event, sample_id):
  118. """Parse the inference event."""
  119. self._sample_pool[sample_id].inference.ground_truth_prob.extend(event.inference.ground_truth_prob)
  120. self._sample_pool[sample_id].inference.ground_truth_prob_sd.extend(event.inference.ground_truth_prob_sd)
  121. self._sample_pool[sample_id].inference.ground_truth_prob_itl95_low.\
  122. extend(event.inference.ground_truth_prob_itl95_low)
  123. self._sample_pool[sample_id].inference.ground_truth_prob_itl95_hi.\
  124. extend(event.inference.ground_truth_prob_itl95_hi)
  125. self._sample_pool[sample_id].inference.predicted_label.extend(event.inference.predicted_label)
  126. self._sample_pool[sample_id].inference.predicted_prob.extend(event.inference.predicted_prob)
  127. self._sample_pool[sample_id].inference.predicted_prob_sd.extend(event.inference.predicted_prob_sd)
  128. self._sample_pool[sample_id].inference.predicted_prob_itl95_low.extend(event.inference.predicted_prob_itl95_low)
  129. self._sample_pool[sample_id].inference.predicted_prob_itl95_hi.extend(event.inference.predicted_prob_itl95_hi)
  130. def _parse_explanation(self, event, sample_id):
  131. """Parse the explanation event."""
  132. if event.explanation:
  133. for explanation_item in event.explanation:
  134. new_explanation = self._sample_pool[sample_id].explanation.add()
  135. new_explanation.explain_method = explanation_item.explain_method
  136. new_explanation.label = explanation_item.label
  137. new_explanation.heatmap_path = explanation_item.heatmap_path
  138. def _parse_sample_info(self, event, sample_id, tag):
  139. """Parse the event containing image info."""
  140. if not getattr(self._sample_pool[sample_id], tag):
  141. setattr(self._sample_pool[sample_id], tag, getattr(event, tag))