
reasoning_metric.py 3.2 kB

from typing import Optional

from ...reasoning import KBBase
from ..structures import ListData
from .base_metric import BaseMetric


class ReasoningMetric(BaseMetric):
    """
    A metrics class for evaluating model performance on tasks that require reasoning.

    This class is designed to calculate the accuracy of reasoning results. Reasoning
    results are generated by first using the learning part to predict pseudo-labels
    and then using a knowledge base (KB) to perform logical reasoning. The reasoning
    results are then compared with the ground truth to calculate the accuracy.

    Parameters
    ----------
    kb : KBBase
        An instance of a knowledge base, used for logical reasoning and validation.
        If not provided, reasoning checks are not performed. Defaults to None.
    prefix : str, optional
        The prefix that will be added to the metric names to disambiguate homonymous
        metrics of different tasks. Inherited from BaseMetric. Defaults to None.

    Notes
    -----
    The ``ReasoningMetric`` expects data_examples to have the attributes
    ``pred_pseudo_label``, ``Y``, and ``X``, corresponding to the predicted
    pseudo-labels, ground truth of reasoning results, and input data, respectively.
    """

    def __init__(self, kb: KBBase, prefix: Optional[str] = None) -> None:
        super().__init__(prefix)
        self.kb = kb

    def process(self, data_examples: ListData) -> None:
        """
        Process a batch of data examples.

        This method takes in a batch of data examples, each containing predicted
        pseudo-labels (``pred_pseudo_label``), ground truth of reasoning results
        (``Y``), and input data (``X``). It evaluates the reasoning accuracy of each
        example by comparing the logical reasoning result (derived using the knowledge
        base) of the predicted pseudo-labels against ``Y``. The result of this
        comparison (1 for correct reasoning, 0 for incorrect) is appended to
        ``self.results``.

        Parameters
        ----------
        data_examples : ListData
            A batch of data examples.
        """
        pred_pseudo_label_list = data_examples.pred_pseudo_label
        y_list = data_examples.Y
        x_list = data_examples.X
        for pred_pseudo_label, y, x in zip(pred_pseudo_label_list, y_list, x_list):
            # Pass the input x to logic_forward only when the KB's logic_forward
            # takes both the pseudo-labels and the input as arguments.
            if self.kb._check_equal(
                self.kb.logic_forward(pred_pseudo_label, *(x,) if self.kb._num_args == 2 else ()), y
            ):
                self.results.append(1)
            else:
                self.results.append(0)

    def compute_metrics(self) -> dict:
        """
        Compute the reasoning accuracy metric from ``self.results``. It calculates the
        percentage of correctly reasoned examples over all examples.

        Returns
        -------
        dict
            A dictionary containing the computed metrics. It includes the key
            'reasoning_accuracy', which maps to the calculated reasoning accuracy,
            represented as a float between 0 and 1.
        """
        results = self.results
        metrics = dict()
        metrics["reasoning_accuracy"] = sum(results) / len(results)
        return metrics
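
For orientation, here is a minimal, self-contained sketch of how this metric consumes a batch. The ToyAdditionKB and ToyBatch classes below are hypothetical stand-ins that only mimic the interface ReasoningMetric relies on (logic_forward, _check_equal, _num_args, and the pred_pseudo_label/Y/X attributes); they are not the toolkit's real KBBase or ListData, and the example assumes BaseMetric initializes self.results as an empty list.

class ToyAdditionKB:
    """Toy KB: the reasoning result is the sum of the pseudo-labels."""

    _num_args = 1  # logic_forward takes only the pseudo-labels, not the input x

    def logic_forward(self, pseudo_label):
        return sum(pseudo_label)

    def _check_equal(self, reasoning_result, y):
        return reasoning_result == y


class ToyBatch:
    """Stand-in for ListData: a batch exposing the three expected attributes."""

    def __init__(self, pred_pseudo_label, Y, X):
        self.pred_pseudo_label = pred_pseudo_label
        self.Y = Y
        self.X = X


metric = ReasoningMetric(kb=ToyAdditionKB())
batch = ToyBatch(
    pred_pseudo_label=[[1, 2], [3, 4]],  # predicted pseudo-labels per example
    Y=[3, 8],                            # ground-truth reasoning results
    X=[None, None],                      # raw inputs, unused by this toy KB
)
metric.process(batch)
print(metric.compute_metrics())  # {'reasoning_accuracy': 0.5}

With these inputs, [1, 2] reasons to 3 (matching its ground truth) while [3, 4] reasons to 7 (not matching 8), so one of the two examples is counted correct and the reported accuracy is 0.5.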

An efficient Python toolkit for Abductive Learning (ABL), a novel paradigm that integrates machine learning and logical reasoning in a unified framework.