import logging
from abc import ABCMeta, abstractmethod
from typing import Any, List, Optional, Sequence

from ..utils import print_log


class BaseMetric(metaclass=ABCMeta):
    """Base class for a metric.

    The metric first processes each batch of data samples and predictions,
    and appends the processed results to ``self.results``. Once all batches
    have been processed, it computes the metrics of the entire dataset from
    the accumulated results.

    Args:
        prefix (str, optional): The prefix that will be added to the metric
            names to disambiguate homonymous metrics of different evaluators.
            If prefix is not provided in the argument, ``self.default_prefix``
            will be used instead. Default: None.
    """

    default_prefix: Optional[str] = None

    def __init__(self, prefix: Optional[str] = None) -> None:
        self.results: List[Any] = []
        self.prefix = prefix or self.default_prefix

    @abstractmethod
    def process(self, data_samples: Sequence[dict]) -> None:
        """Process one batch of data samples and predictions.

        The processed results should be stored in ``self.results``, which
        will be used to compute the metrics when all batches have been
        processed.

        Args:
            data_samples (Sequence[dict]): A batch of outputs from
                the model.
        """

    @abstractmethod
    def compute_metrics(self, results: list) -> dict:
        """Compute the metrics from processed results.

        Args:
            results (list): The processed results of each batch.

        Returns:
            dict: The computed metrics. The keys are the names of the
                metrics, and the values are corresponding results.
        """

    def evaluate(self) -> dict:
        """Evaluate the model performance on the whole dataset after
        processing all batches.

        Returns:
            dict: Evaluation metrics dict on the val dataset. The keys are
                the names of the metrics, and the values are corresponding
                results.
        """
        if len(self.results) == 0:
            print_log(
                f'{self.__class__.__name__} got empty `self.results`. Please '
                'ensure that the processed results are properly added into '
                '`self.results` in the `process` method.',
                logger='current',
                level=logging.WARNING)

        metrics = self.compute_metrics(self.results)
        # Add prefix to metric names
        if self.prefix:
            metrics = {
                '/'.join((self.prefix, k)): v
                for k, v in metrics.items()
            }

        # Reset the results list for the next evaluation round.
        self.results.clear()
        return metrics
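

# ---------------------------------------------------------------------------
# Usage sketch (illustrative, not part of the class above): a minimal
# subclass that computes top-1 accuracy. The sample keys ``pred_label`` and
# ``gt_label`` are hypothetical; real data samples may be structured
# differently.
class AccuracyMetric(BaseMetric):

    default_prefix = 'acc'

    def process(self, data_samples: Sequence[dict]) -> None:
        # Record 1 for a correct prediction and 0 otherwise.
        for sample in data_samples:
            self.results.append(
                int(sample['pred_label'] == sample['gt_label']))

    def compute_metrics(self, results: list) -> dict:
        return {'top1': sum(results) / len(results)}


# ``process`` is called once per batch; ``evaluate`` computes the final
# metrics, applies the prefix, and clears ``self.results``.
metric = AccuracyMetric()
metric.process([{'pred_label': 1, 'gt_label': 1},
                {'pred_label': 0, 'gt_label': 1}])
print(metric.evaluate())  # {'acc/top1': 0.5}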