
symbol_accuracy.py 2.8 kB

  1. """
  2. This module contains the class SymbolAccuracy, which is used for evaluating symbol-level accuracy.
  3. Copyright (c) 2024 LAMDA. All rights reserved.
  4. """
  5. import numpy as np
  6. from ..structures import ListData
  7. from .base_metric import BaseMetric
  8. class SymbolAccuracy(BaseMetric):
  9. """
  10. A metrics class for evaluating symbol-level accuracy.
  11. This class is designed to assess the accuracy of symbol prediction. Symbol accuracy
  12. is calculated by comparing predicted presudo labels and their ground truth.
  13. Parameters
  14. ----------
  15. prefix : str, optional
  16. The prefix that will be added to the metrics names to disambiguate homonymous
  17. metrics of different tasks. Inherits from BaseMetric. Defaults to None.
  18. """
  19. def process(self, data_examples: ListData) -> None:
  20. """
  21. Processes a batch of data examples.
  22. This method takes in a batch of data examples, each containing a list of predicted
  23. pseudo-labels (pred_pseudo_label) and their ground truth (gt_pseudo_label). It
  24. calculates the accuracy by comparing the two lists. Then, a tuple of correct symbol
  25. count and total symbol count is appended to ``self.results``.
  26. Parameters
  27. ----------
  28. data_examples : ListData
  29. A batch of data examples, each containing:
  30. - ``pred_pseudo_label``: List of predicted pseudo-labels.
  31. - ``gt_pseudo_label``: List of ground truth pseudo-labels.
  32. Raises
  33. ------
  34. ValueError
  35. If the lengths of predicted and ground truth symbol lists are not equal.
  36. """
  37. pred_pseudo_label_list = data_examples.flatten("pred_pseudo_label")
  38. gt_pseudo_label_list = data_examples.flatten("gt_pseudo_label")
  39. if not len(pred_pseudo_label_list) == len(gt_pseudo_label_list):
  40. raise ValueError("lengthes of pred_pseudo_label and gt_pseudo_label should be equal")
  41. correct_num = np.sum(np.array(pred_pseudo_label_list) == np.array(gt_pseudo_label_list))
  42. self.results.append((correct_num, len(pred_pseudo_label_list)))
  43. def compute_metrics(self) -> dict:
  44. """
  45. Compute the symbol accuracy metrics from ``self.results``. It calculates the
  46. percentage of correctly predicted pseudo-labels over all pseudo-labels.
  47. Returns
  48. -------
  49. dict
  50. A dictionary containing the computed metrics. It includes the key
  51. 'character_accuracy' which maps to the calculated symbol-level accuracy,
  52. represented as a float between 0 and 1.
  53. """
  54. results = self.results
  55. metrics = dict()
  56. metrics["character_accuracy"] = sum(t[0] for t in results) / sum(t[1] for t in results)
  57. return metrics
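
A minimal usage sketch, not part of the file above: it assumes ListData accepts its fields as keyword arguments, that flatten() concatenates the per-example lists into one flat list, and that the import paths shown are correct; the two-example batch and its label values are purely illustrative.

    # Sketch under the assumptions stated above.
    from ablkit.data.structures import ListData          # import path is an assumption
    from ablkit.data.evaluation import SymbolAccuracy    # import path is an assumption

    # A batch of two examples; each example holds a list of pseudo-labels.
    data = ListData(
        pred_pseudo_label=[[1, 0, 1], [2, 2]],  # predicted symbols per example
        gt_pseudo_label=[[1, 1, 1], [2, 0]],    # ground-truth symbols per example
    )

    metric = SymbolAccuracy()
    metric.process(data)                 # appends (correct_num, total_num) to metric.results
    print(metric.compute_metrics())      # {'character_accuracy': 0.6}, i.e. 3 of 5 symbols correct
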

An efficient Python toolkit for Abductive Learning (ABL), a novel paradigm that integrates machine learning and logical reasoning in a unified framework.