You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; can include dashes ('-'); and can be up to 35 characters long.

base_metric.py 2.7 kB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091
  1. """
  2. This module contains the base class used for evaluation.
  3. Copyright (c) 2024 LAMDA. All rights reserved.
  4. """
  5. import logging
  6. from abc import ABCMeta, abstractmethod
  7. from typing import Any, List, Optional
  8. from ...utils import print_log
  9. from ..structures import ListData
  10. class BaseMetric(metaclass=ABCMeta):
  11. """
  12. Base class for a metrics.
  13. The metrics first processes each batch of data_examples and appends the processed
  14. results to the results list. Then, it computes the metrics of the entire dataset.
  15. Parameters
  16. ----------
  17. prefix : str, optional
  18. The prefix that will be added in the metrics names to disambiguate homonymous
  19. metrics of different tasks. If prefix is not provided in the argument,
  20. self.default_prefix will be used instead. Defaults to None.
  21. """
  22. def __init__(
  23. self,
  24. prefix: Optional[str] = None,
  25. ) -> None:
  26. self.default_prefix = ""
  27. self.results: List[Any] = []
  28. self.prefix = prefix or self.default_prefix
  29. @abstractmethod
  30. def process(self, data_examples: ListData) -> None:
  31. """
  32. Process one batch of data examples. The processed results should be stored
  33. in ``self.results``, which will be used to compute the metrics when all
  34. batches have been processed.
  35. Parameters
  36. ----------
  37. data_examples : ListData
  38. A batch of data examples.
  39. """
  40. @abstractmethod
  41. def compute_metrics(self) -> dict:
  42. """
  43. Compute the metrics from processed results.
  44. Returns
  45. -------
  46. dict
  47. The computed metrics. The keys are the names of the metrics,
  48. and the values are the corresponding results.
  49. """
  50. def evaluate(self) -> dict:
  51. """
  52. Evaluate the model performance of the whole dataset after processing
  53. all batches.
  54. Returns
  55. -------
  56. dict
  57. Evaluation metrics dict on the val dataset. The keys are the
  58. names of the metrics, and the values are the corresponding results.
  59. """
  60. if len(self.results) == 0:
  61. print_log(
  62. f"{self.__class__.__name__} got empty `self.results`. Please "
  63. "ensure that the processed results are properly added into "
  64. "`self.results` in `process` method.",
  65. logger="current",
  66. level=logging.WARNING,
  67. )
  68. metrics = self.compute_metrics()
  69. # Add prefix to metrics names
  70. if self.prefix:
  71. metrics = {"/".join((self.prefix, k)): v for k, v in metrics.items()}
  72. # reset the results list
  73. self.results.clear()
  74. return metrics

An efficient Python toolkit for Abductive Learning (ABL), a novel paradigm that integrates machine learning and logical reasoning in a unified framework.