
base_metric.py

import logging
from abc import ABCMeta, abstractmethod
from typing import Any, List, Optional, Sequence

from ..utils import print_log


class BaseMetric(metaclass=ABCMeta):
    """Base class for a metric.

    The metric first processes each batch of data samples and predictions,
    and appends the processed results to the results list. Once all batches
    have been processed, it computes the metrics of the entire dataset.

    Args:
        prefix (str, optional): The prefix that will be added to the metric
            names to disambiguate homonymous metrics of different evaluators.
            If prefix is not provided in the argument, ``self.default_prefix``
            will be used instead. Defaults to None.
    """

    default_prefix: Optional[str] = None

    def __init__(self, prefix: Optional[str] = None) -> None:
        self.results: List[Any] = []
        self.prefix = prefix or self.default_prefix

    @abstractmethod
    def process(self, data_samples: Sequence[dict]) -> None:
        """Process one batch of data samples and predictions. The processed
        results should be stored in ``self.results``, which will be used to
        compute the metrics when all batches have been processed.

        Args:
            data_samples (Sequence[dict]): A batch of outputs from the model.
        """

    @abstractmethod
    def compute_metrics(self, results: list) -> dict:
        """Compute the metrics from processed results.

        Args:
            results (list): The processed results of each batch.

        Returns:
            dict: The computed metrics. The keys are the names of the metrics,
            and the values are the corresponding results.
        """

    def evaluate(self) -> dict:
        """Evaluate the model performance on the whole dataset after
        processing all batches.

        Returns:
            dict: Evaluation metrics dict on the val dataset. The keys are the
            names of the metrics, and the values are the corresponding results.
        """
        if len(self.results) == 0:
            print_log(
                f'{self.__class__.__name__} got empty `self.results`. Please '
                'ensure that the processed results are properly added into '
                '`self.results` in the `process` method.',
                logger='current',
                level=logging.WARNING)
        metrics = self.compute_metrics(self.results)
        # Add prefix to metric names
        if self.prefix:
            metrics = {
                '/'.join((self.prefix, k)): v
                for k, v in metrics.items()
            }
        # Reset the results list for the next evaluation round
        self.results.clear()
        return metrics
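For context, here is a minimal sketch of how a concrete metric would subclass BaseMetric. The SymbolAccuracy class, its 'pred'/'gt' sample keys, and the batch contents below are hypothetical illustrations, not part of this file: a subclass only needs to implement process and compute_metrics, and evaluate then takes care of prefixing the metric names and resetting the results list.

# Hypothetical usage sketch; assumes BaseMetric from the listing above.
from typing import Sequence

class SymbolAccuracy(BaseMetric):
    default_prefix = 'character'  # illustrative prefix

    def process(self, data_samples: Sequence[dict]) -> None:
        # Store one (correct, total) count pair per sample in self.results.
        for sample in data_samples:
            correct = sum(p == g for p, g in zip(sample['pred'], sample['gt']))
            self.results.append((correct, len(sample['gt'])))

    def compute_metrics(self, results: list) -> dict:
        correct = sum(c for c, _ in results)
        total = sum(t for _, t in results)
        return {'accuracy': correct / total if total else 0.0}

metric = SymbolAccuracy()
metric.process([{'pred': [1, 2, 3], 'gt': [1, 2, 4]}])
print(metric.evaluate())  # {'character/accuracy': 0.666...}; prefix applied

Note that evaluate() clears self.results after computing, so the same metric instance can be reused across evaluation rounds without carrying over stale state.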

An efficient Python toolkit for Abductive Learning (ABL), a novel paradigm that integrates machine learning and logical reasoning in a unified framework.