
symbol_metric.py

```python
from typing import Any, Optional

from .base_metric import BaseMetric


class SymbolMetric(BaseMetric):
    """Symbol-level (character) accuracy of predicted vs. ground-truth pseudo-labels."""

    def __init__(self, prefix: Optional[str] = None) -> None:
        super().__init__(prefix)

    def process(self, data_samples: Any) -> None:
        # data_samples must expose `pred_pseudo_label` and `gt_pseudo_label`
        # attributes, each a sequence of per-sample symbol sequences.
        pred_pseudo_label = data_samples.pred_pseudo_label
        gt_pseudo_label = data_samples.gt_pseudo_label
        if len(pred_pseudo_label) != len(gt_pseudo_label):
            raise ValueError("lengths of pred_pseudo_label and gt_pseudo_label should be equal")
        for pred_z, z in zip(pred_pseudo_label, gt_pseudo_label):
            # Count positions where the predicted symbol equals the ground truth.
            correct_num = 0
            for pred_symbol, symbol in zip(pred_z, z):
                if pred_symbol == symbol:
                    correct_num += 1
            # Store this sample's fraction of correctly predicted symbols.
            self.results.append(correct_num / len(z))

    def compute_metrics(self, results: list) -> dict:
        metrics = dict()
        # Report the mean of the per-sample accuracies.
        metrics["character_accuracy"] = sum(results) / len(results)
        return metrics
```
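Note that `character_accuracy` is the mean of per-sample ratios, not the total number of correct symbols divided by the total number of symbols. The standalone sketch below reproduces the arithmetic of `SymbolMetric` without `BaseMetric`, so it can run on its own. The function names are hypothetical and are not part of the toolkit. `pooled_symbol_accuracy` is an alternative that `SymbolMetric` does not compute; it is included only to show how the two averages can differ.

```python
from typing import Hashable, List, Sequence


def symbol_accuracy_per_sample(
    pred_pseudo_label: Sequence[Sequence[Hashable]],
    gt_pseudo_label: Sequence[Sequence[Hashable]],
) -> List[float]:
    """Hypothetical helper mirroring SymbolMetric.process: one ratio per sample."""
    if len(pred_pseudo_label) != len(gt_pseudo_label):
        raise ValueError("lengths of pred_pseudo_label and gt_pseudo_label should be equal")
    results = []
    for pred_z, z in zip(pred_pseudo_label, gt_pseudo_label):
        # zip stops at the shorter sequence; the denominator is the ground-truth length.
        correct_num = sum(p == s for p, s in zip(pred_z, z))
        results.append(correct_num / len(z))
    return results


def character_accuracy(results: Sequence[float]) -> float:
    """Hypothetical helper mirroring SymbolMetric.compute_metrics: mean of per-sample ratios."""
    return sum(results) / len(results)


def pooled_symbol_accuracy(
    pred_pseudo_label: Sequence[Sequence[Hashable]],
    gt_pseudo_label: Sequence[Sequence[Hashable]],
) -> float:
    """Alternative NOT used by SymbolMetric: total correct symbols / total symbols."""
    correct = sum(
        p == s
        for pred_z, z in zip(pred_pseudo_label, gt_pseudo_label)
        for p, s in zip(pred_z, z)
    )
    total = sum(len(z) for z in gt_pseudo_label)
    return correct / total


if __name__ == "__main__":
    pred = [[1, 2, 3], [0, 5]]
    gt = [[1, 2, 4], [0, 5]]
    per_sample = symbol_accuracy_per_sample(pred, gt)
    print(per_sample, character_accuracy(per_sample), pooled_symbol_accuracy(pred, gt))
```

In this example the per-sample ratios are 2/3 and 1.0, so the mean-based `character_accuracy` is about 0.833. The pooled ratio is 4/5 = 0.8. The mean-based form weights every sample equally regardless of its length.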

An efficient Python toolkit for Abductive Learning (ABL), a novel paradigm that integrates machine learning and logical reasoning in a unified framework.