
metric.py

```python
from typing import Optional

from ablkit.reasoning import KBBase
from ablkit.data import BaseMetric, ListData


class BDDReasoningMetric(BaseMetric):
    """Element-wise reasoning accuracy: compares each output of the
    knowledge base's logic_forward against the matching entry in Y."""

    def __init__(self, kb: KBBase, prefix: Optional[str] = None) -> None:
        super().__init__(prefix)
        self.kb = kb

    def process(self, data_examples: ListData) -> None:
        pred_pseudo_label_list = data_examples.pred_pseudo_label
        y_list = data_examples.Y
        x_list = data_examples.X
        for pred_pseudo_label, y, x in zip(pred_pseudo_label_list, y_list, x_list):
            # Pass the raw input x only when logic_forward takes two arguments.
            pred_y = self.kb.logic_forward(
                pred_pseudo_label, *((x,) if self.kb._num_args == 2 else ())
            )
            # Record one hit (1) or miss (0) per output element.
            for py, yy in zip(pred_y, y):
                self.results.append(int(py == yy))

    def compute_metrics(self) -> dict:
        results = self.results
        metrics = dict()
        # Fraction of all output elements that match the ground truth.
        metrics["reasoning_accuracy"] = sum(results) / len(results)
        return metrics
```
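To see what `reasoning_accuracy` measures, here is a minimal sketch that reproduces the `process` / `compute_metrics` logic without depending on ABLkit. `ToyKB`, its XOR-of-adjacent-labels rule, and `ElementwiseReasoningMetric` are hypothetical illustrations, and `SimpleNamespace` stands in for `ListData`. The point it shows is that accuracy is averaged over individual output elements, so a partially correct example still contributes its correct positions.

```python
from types import SimpleNamespace
from typing import Any, List


class ToyKB:
    """Hypothetical stand-in for a KBBase subclass (not part of ABLkit)."""

    def __init__(self) -> None:
        # Mirrors the `_num_args` attribute the metric inspects:
        # 2 would mean logic_forward also receives the raw input x.
        self._num_args = 1

    def logic_forward(self, pseudo_label: List[int]) -> List[int]:
        # Hypothetical rule: each output is the XOR of two adjacent labels.
        return [(a + b) % 2 for a, b in zip(pseudo_label, pseudo_label[1:])]


class ElementwiseReasoningMetric:
    """Sketch of the same per-element accuracy, without ABLkit."""

    def __init__(self, kb: Any) -> None:
        self.kb = kb
        self.results: List[int] = []

    def process(self, data_examples: Any) -> None:
        for pred, y, x in zip(
            data_examples.pred_pseudo_label, data_examples.Y, data_examples.X
        ):
            extra = (x,) if self.kb._num_args == 2 else ()
            pred_y = self.kb.logic_forward(pred, *extra)
            # One 0/1 entry per output element, not per example.
            self.results.extend(int(py == yy) for py, yy in zip(pred_y, y))

    def compute_metrics(self) -> dict:
        return {"reasoning_accuracy": sum(self.results) / len(self.results)}


batch = SimpleNamespace(
    X=[None, None],
    pred_pseudo_label=[[0, 1, 1], [1, 1, 0]],
    Y=[[1, 0], [0, 0]],
)
metric = ElementwiseReasoningMetric(ToyKB())
metric.process(batch)
print(metric.compute_metrics())  # {'reasoning_accuracy': 0.75}
```

The first example matches on both positions and the second on one of two, giving 3/4 = 0.75. Because `zip` stops at the shorter sequence, an output of the wrong length is silently truncated rather than counted as an error; the same holds for the original `process` method.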

An efficient Python toolkit for Abductive Learning (ABL), a novel paradigm that integrates machine learning and logical reasoning in a unified framework.