|
|
@@ -1,19 +1,17 @@ |
|
|
|
from typing import Optional, Sequence |
|
|
|
|
|
|
|
from ..reasoning import BaseKB |
|
|
|
from .base_metric import BaseMetric |
|
|
|
|
|
|
|
|
|
|
|
class SemanticsMetric(BaseMetric): |
|
|
|
def __init__(self, prefix: Optional[str] = None) -> None: |
|
|
|
def __init__(self, kb: BaseKB = None, prefix: Optional[str] = None) -> None: |
|
|
|
super().__init__(prefix) |
|
|
|
self.kb = kb |
|
|
|
|
|
|
|
def process(self, data_samples: Sequence[dict]) -> None: |
|
|
|
pred_pseudo_label = data_samples["pred_pseudo_label"] |
|
|
|
gt_Y = data_samples["Y"] |
|
|
|
logic_forward = data_samples["logic_forward"] |
|
|
|
|
|
|
|
for pred_z, y in zip(pred_pseudo_label, gt_Y): |
|
|
|
if logic_forward(pred_z) == y: |
|
|
|
for data_sample in data_samples: |
|
|
|
if self.kb.entail(data_sample, data_sample["Y"][0]): |
|
|
|
self.results.append(1) |
|
|
|
else: |
|
|
|
self.results.append(0) |
|
|
|