from typing import Optional

from ablkit.reasoning import KBBase
from ablkit.data import BaseMetric, ListData


class BDDReasoningMetric(BaseMetric):
    """Metric scoring how often knowledge-base reasoning matches the targets.

    For every example, the predicted pseudo-labels are passed through
    ``kb.logic_forward`` and the output is compared element by element with
    the example's ``Y``. Each comparison contributes a 0/1 entry to
    ``self.results``; the reported accuracy is the mean of those entries.
    """

    def __init__(self, kb: KBBase, prefix: Optional[str] = None) -> None:
        """Store the knowledge base used for reasoning.

        Args:
            kb: Knowledge base whose ``logic_forward`` maps pseudo-labels to
                reasoning results.
            prefix: Forwarded unchanged to ``BaseMetric.__init__``.
        """
        super().__init__(prefix)
        self.kb = kb

    def process(self, data_examples: ListData) -> None:
        """Append one 0/1 correctness flag per compared output element.

        Args:
            data_examples: Batch exposing ``pred_pseudo_label``, ``Y`` and
                ``X``; the three are iterated in lockstep.
        """
        # Loop-invariant: when the knowledge base reports two arguments, the
        # raw input ``x`` is passed to ``logic_forward`` as a second
        # positional argument; otherwise only the pseudo-labels are passed.
        # NOTE(review): presumably ``_num_args`` mirrors the arity of
        # ``logic_forward`` -- confirm against KBBase.
        pass_input = self.kb._num_args == 2

        examples = zip(
            data_examples.pred_pseudo_label, data_examples.Y, data_examples.X
        )
        for pred_pseudo_label, y, x in examples:
            extra_args = (x,) if pass_input else ()
            pred_y = self.kb.logic_forward(pred_pseudo_label, *extra_args)
            # ``zip`` stops at the shorter sequence, so surplus elements in
            # either ``pred_y`` or ``y`` are not scored.
            self.results.extend(int(py == yy) for py, yy in zip(pred_y, y))

    def compute_metrics(self) -> dict:
        """Return the mean of the accumulated correctness flags.

        Returns:
            A dict with the single key ``"reasoning_accuracy"``. The value is
            ``0.0`` when no results have been accumulated, so that an empty
            evaluation does not raise ``ZeroDivisionError``.
        """
        results = self.results
        if not results:
            # Nothing was processed; report zero accuracy instead of
            # dividing by zero.
            return {"reasoning_accuracy": 0.0}
        return {"reasoning_accuracy": sum(results) / len(results)}