from typing import Optional

from ...reasoning import KBBase
from ..structures import ListData
from .base_metric import BaseMetric


class ReasoningMetric(BaseMetric):
- """
- A metrics class for evaluating the model performance on tasks need reasoning.
-
- This class is designed to calculate the accuracy of the reasoing results. Reasoning
- results are generated by first using the learning part to predict pseudo-labels
- and then using a knowledge base (KB) to perform logical reasoning. The reasoning results
- are then compared with the ground truth to calculate the accuracy.
-
- Parameters
- ----------
- kb : KBBase
- An instance of a knowledge base, used for logical reasoning and validation.
- If not provided, reasoning checks are not performed. Default to None.
- prefix : str, optional
- The prefix that will be added to the metrics names to disambiguate homonymous
- metrics of different tasks. Inherits from BaseMetric. Default to None.
-
- Notes
- -----
- The `ReasoningMetric` expects data_examples to have the attributes `pred_pseudo_label`,
- `Y`, and `X`, corresponding to the predicted pseduo labels, ground truth of reasoning
- results, and input data, respectively.
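
    Examples
    --------
    A minimal usage sketch. ``AddKB`` is a hypothetical ``KBBase`` subclass whose
    ``logic_forward`` sums a list of digit pseudo-labels; any concrete KB subclass
    follows the same pattern:

    >>> kb = AddKB()  # hypothetical; substitute your own KBBase subclass
    >>> metric = ReasoningMetric(kb=kb, prefix="mnist_add")
    >>> data_examples = ListData()
    >>> data_examples.pred_pseudo_label = [[1, 2]]
    >>> data_examples.Y = [3]
    >>> data_examples.X = [None]  # unused when logic_forward takes one argument
    >>> metric.process(data_examples)
    >>> metric.compute_metrics()
    {'reasoning_accuracy': 1.0}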
- """
-
- def __init__(self, kb: KBBase, prefix: Optional[str] = None) -> None:
- super().__init__(prefix)
- self.kb = kb
-
    def process(self, data_examples: ListData) -> None:
        """
        Process a batch of data examples.

        This method takes in a batch of data examples, each containing predicted
        pseudo-labels (``pred_pseudo_label``), the ground truth of reasoning results
        (``Y``), and input data (``X``). It evaluates the reasoning accuracy of each
        example by comparing the logical reasoning result of the predicted
        pseudo-labels (derived using the knowledge base) against ``Y``. The result of
        this comparison (1 for correct reasoning, 0 for incorrect) is appended to
        ``self.results``.

        Parameters
        ----------
        data_examples : ListData
            A batch of data examples.
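
        Examples
        --------
        A sketch using the hypothetical ``AddKB`` from the class-level example,
        where the pseudo-labels [1, 2] reason to 1 + 2 = 3, matching ``Y``:

        >>> metric = ReasoningMetric(kb=AddKB())  # AddKB is hypothetical
        >>> data_examples = ListData()
        >>> data_examples.pred_pseudo_label = [[1, 2]]
        >>> data_examples.Y = [3]
        >>> data_examples.X = [None]
        >>> metric.process(data_examples)
        >>> metric.results
        [1]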
- """
        pred_pseudo_label_list = data_examples.pred_pseudo_label
        y_list = data_examples.Y
        x_list = data_examples.X
        for pred_pseudo_label, y, x in zip(pred_pseudo_label_list, y_list, x_list):
            # Pass the input data to logic_forward only when the KB defines
            # logic_forward to take it as a second argument.
            if self.kb._num_args == 2:
                reasoning_result = self.kb.logic_forward(pred_pseudo_label, x)
            else:
                reasoning_result = self.kb.logic_forward(pred_pseudo_label)
            # Record 1 for a correct reasoning result and 0 otherwise.
            self.results.append(1 if self.kb._check_equal(reasoning_result, y) else 0)

    def compute_metrics(self) -> dict:
        """
        Compute the reasoning accuracy from ``self.results``, i.e., the fraction of
        correctly reasoned examples over all processed examples. Assumes that
        ``process`` has been called at least once, so that ``self.results`` is
        non-empty.

        Returns
        -------
        dict
            A dictionary containing the computed metrics. It includes the key
            'reasoning_accuracy', which maps to the calculated reasoning accuracy,
            represented as a float between 0 and 1.
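
        Examples
        --------
        A small illustration of the computation, with hypothetical accumulated
        results (three correct out of four):

        >>> metric = ReasoningMetric(kb=kb)  # kb: any KBBase instance
        >>> metric.results = [1, 1, 0, 1]  # e.g., after several process() calls
        >>> metric.compute_metrics()
        {'reasoning_accuracy': 0.75}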
- """
- results = self.results
- metrics = dict()
- metrics["reasoning_accuracy"] = sum(results) / len(results)
- return metrics