from typing import Optional, Sequence from ..reasoning import BaseKB from .base_metric import BaseMetric class SemanticsMetric(BaseMetric): def __init__(self, kb: BaseKB = None, prefix: Optional[str] = None) -> None: super().__init__(prefix) self.kb = kb def process(self, data_samples: Sequence[dict]) -> None: for data_sample in data_samples: if self.kb.check_equal(data_sample, data_sample["Y"][0]): self.results.append(1) else: self.results.append(0) def compute_metrics(self, results: list) -> dict: metrics = dict() metrics["semantics_accuracy"] = sum(results) / len(results) return metrics