diff --git a/abl/evaluation/semantics_metric.py b/abl/evaluation/semantics_metric.py index 09eb238..3ddb24c 100644 --- a/abl/evaluation/semantics_metric.py +++ b/abl/evaluation/semantics_metric.py @@ -1,19 +1,17 @@ from typing import Optional, Sequence +from ..reasoning import BaseKB from .base_metric import BaseMetric class SemanticsMetric(BaseMetric): - def __init__(self, prefix: Optional[str] = None) -> None: + def __init__(self, kb: BaseKB = None, prefix: Optional[str] = None) -> None: super().__init__(prefix) + self.kb = kb def process(self, data_samples: Sequence[dict]) -> None: - pred_pseudo_label = data_samples["pred_pseudo_label"] - gt_Y = data_samples["Y"] - logic_forward = data_samples["logic_forward"] - - for pred_z, y in zip(pred_pseudo_label, gt_Y): - if logic_forward(pred_z) == y: + for data_sample in data_samples: + if self.kb.entail(data_sample, data_sample["Y"][0]): self.results.append(1) else: self.results.append(0) diff --git a/abl/evaluation/symbol_metric.py b/abl/evaluation/symbol_metric.py index e133381..c2d7938 100644 --- a/abl/evaluation/symbol_metric.py +++ b/abl/evaluation/symbol_metric.py @@ -1,4 +1,4 @@ -from typing import Callable, Optional, Sequence +from typing import Optional, Sequence from .base_metric import BaseMetric