# Copyright (c) OpenMMLab. All rights reserved.
import bisect
import collections
import copy
import math
from collections import defaultdict

import numpy as np
from mmcv.utils import build_from_cfg, print_log
from torch.utils.data.dataset import ConcatDataset as _ConcatDataset

from .builder import DATASETS, PIPELINES
from .coco import CocoDataset


@DATASETS.register_module()
class ConcatDataset(_ConcatDataset):
    """A wrapper of concatenated dataset.

    Same as :obj:`torch.utils.data.dataset.ConcatDataset`, but it also
    concatenates the group flags used for aspect-ratio grouping.

    Args:
        datasets (list[:obj:`Dataset`]): A list of datasets.
        separate_eval (bool): Whether to evaluate the results
            separately if it is used as validation dataset.
            Defaults to True.
    """

    def __init__(self, datasets, separate_eval=True):
        super(ConcatDataset, self).__init__(datasets)
        self.CLASSES = datasets[0].CLASSES
        self.separate_eval = separate_eval
        if not separate_eval:
            if any([isinstance(ds, CocoDataset) for ds in datasets]):
                raise NotImplementedError(
                    'Evaluating concatenated CocoDataset as a whole is not'
                    ' supported! Please set "separate_eval=True"')
            elif len(set([type(ds) for ds in datasets])) != 1:
                raise NotImplementedError(
                    'All the datasets should have same types')

        if hasattr(datasets[0], 'flag'):
            self.flag = np.concatenate([ds.flag for ds in datasets])

    def get_cat_ids(self, idx):
        """Get category ids of concatenated dataset by index.

        Args:
            idx (int): Index of data.

        Returns:
            list[int]: All categories in the image of specified index.
        """

        if idx < 0:
            if -idx > len(self):
                raise ValueError(
                    'absolute value of index should not exceed dataset length')
            idx = len(self) + idx
        dataset_idx = bisect.bisect_right(self.cumulative_sizes, idx)
        if dataset_idx == 0:
            sample_idx = idx
        else:
            sample_idx = idx - self.cumulative_sizes[dataset_idx - 1]
        return self.datasets[dataset_idx].get_cat_ids(sample_idx)

    def evaluate(self, results, logger=None, **kwargs):
        """Evaluate the results.

        Args:
            results (list[list | tuple]): Testing results of the dataset.
            logger (logging.Logger | str | None): Logger used for printing
                related information during evaluation. Default: None.

        Returns:
            dict[str, float]: AP results of the total dataset or of each
                separate dataset if ``self.separate_eval=True``.
        """
        assert len(results) == self.cumulative_sizes[-1], \
            ('Dataset and results have different sizes: '
             f'{self.cumulative_sizes[-1]} vs. {len(results)}')

        # Check whether all the datasets support evaluation
        for dataset in self.datasets:
            assert hasattr(dataset, 'evaluate'), \
                f'{type(dataset)} does not implement evaluate function'

        if self.separate_eval:
            dataset_idx = -1
            total_eval_results = dict()
            for size, dataset in zip(self.cumulative_sizes, self.datasets):
                start_idx = 0 if dataset_idx == -1 else \
                    self.cumulative_sizes[dataset_idx]
                end_idx = self.cumulative_sizes[dataset_idx + 1]

                results_per_dataset = results[start_idx:end_idx]
                print_log(
                    f'\nEvaluating {dataset.ann_file} with '
                    f'{len(results_per_dataset)} images now',
                    logger=logger)

                eval_results_per_dataset = dataset.evaluate(
                    results_per_dataset, logger=logger, **kwargs)
                dataset_idx += 1
                for k, v in eval_results_per_dataset.items():
                    total_eval_results.update({f'{dataset_idx}_{k}': v})

            return total_eval_results
        elif any([isinstance(ds, CocoDataset) for ds in self.datasets]):
            raise NotImplementedError(
                'Evaluating concatenated CocoDataset as a whole is not'
                ' supported! Please set "separate_eval=True"')
        elif len(set([type(ds) for ds in self.datasets])) != 1:
            raise NotImplementedError(
                'All the datasets should have same types')
        else:
            original_data_infos = self.datasets[0].data_infos
            self.datasets[0].data_infos = sum(
                [dataset.data_infos for dataset in self.datasets], [])
            eval_results = self.datasets[0].evaluate(
                results, logger=logger, **kwargs)
            self.datasets[0].data_infos = original_data_infos
            return eval_results
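
# Usage sketch (illustrative; not part of the original module, and the
# dataset names are hypothetical): two already-built datasets of the same
# type can be merged so they behave as one, while still being evaluated
# separately.
#
#   concat = ConcatDataset([dataset_a, dataset_b], separate_eval=True)
#   # len(concat) == len(dataset_a) + len(dataset_b)
#   # concat.evaluate(results) prefixes each metric with the dataset index,
#   # e.g. '0_mAP' for the first dataset.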


@DATASETS.register_module()
class RepeatDataset:
    """A wrapper of repeated dataset.

    The length of the repeated dataset will be ``times`` times as large as
    that of the original dataset. This is useful when the data loading time
    is long but the dataset is small. Using RepeatDataset can reduce the
    data loading time between epochs.

    Args:
        dataset (:obj:`Dataset`): The dataset to be repeated.
        times (int): Repeat times.
    """

    def __init__(self, dataset, times):
        self.dataset = dataset
        self.times = times
        self.CLASSES = dataset.CLASSES
        if hasattr(self.dataset, 'flag'):
            self.flag = np.tile(self.dataset.flag, times)

        self._ori_len = len(self.dataset)

    def __getitem__(self, idx):
        return self.dataset[idx % self._ori_len]

    def get_cat_ids(self, idx):
        """Get category ids of repeat dataset by index.

        Args:
            idx (int): Index of data.

        Returns:
            list[int]: All categories in the image of specified index.
        """

        return self.dataset.get_cat_ids(idx % self._ori_len)

    def __len__(self):
        """Length after repetition."""
        return self.times * self._ori_len
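
# Usage sketch (illustrative; ``small_dataset`` is hypothetical): repeating
# a small dataset lets one epoch iterate over it several times, amortizing
# per-epoch overhead such as dataloader restarts.
#
#   repeated = RepeatDataset(small_dataset, times=5)
#   # len(repeated) == 5 * len(small_dataset)
#   # repeated[idx] maps back to small_dataset[idx % len(small_dataset)]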


# Modified from https://github.com/facebookresearch/detectron2/blob/41d475b75a230221e21d9cac5d69655e3415e3a4/detectron2/data/samplers/distributed_sampler.py#L57 # noqa
@DATASETS.register_module()
class ClassBalancedDataset:
    """A wrapper of repeated dataset with repeat factor.

    Suitable for training on class imbalanced datasets like LVIS. Following
    the sampling strategy in the `paper <https://arxiv.org/abs/1908.03195>`_,
    in each epoch, an image may appear multiple times based on its
    "repeat factor".
    The repeat factor for an image is a function of the frequency of the
    rarest category labeled in that image. The "frequency of category c" in
    [0, 1] is defined as the fraction of images in the training set (without
    repeats) in which category c appears.
    The dataset needs to implement :func:`self.get_cat_ids` to support
    ClassBalancedDataset.

    The repeat factor is computed as follows.

    1. For each category c, compute the fraction of images
       that contain it: :math:`f(c)`
    2. For each category c, compute the category-level repeat factor:
       :math:`r(c) = max(1, sqrt(t/f(c)))`
    3. For each image I, compute the image-level repeat factor:
       :math:`r(I) = max_{c in I} r(c)`

    Args:
        dataset (:obj:`CustomDataset`): The dataset to be repeated.
        oversample_thr (float): Frequency threshold below which data is
            repeated. For categories with ``f_c >= oversample_thr``, there is
            no oversampling. For categories with ``f_c < oversample_thr``,
            the degree of oversampling follows the square-root inverse
            frequency heuristic above.
        filter_empty_gt (bool, optional): If set true, images without bounding
            boxes will not be oversampled. Otherwise, they will be categorized
            as the pure background class and involved in the oversampling.
            Default: True.
    """

    def __init__(self, dataset, oversample_thr, filter_empty_gt=True):
        self.dataset = dataset
        self.oversample_thr = oversample_thr
        self.filter_empty_gt = filter_empty_gt
        self.CLASSES = dataset.CLASSES

        repeat_factors = self._get_repeat_factors(dataset, oversample_thr)
        repeat_indices = []
        for dataset_idx, repeat_factor in enumerate(repeat_factors):
            repeat_indices.extend([dataset_idx] * math.ceil(repeat_factor))
        self.repeat_indices = repeat_indices

        flags = []
        if hasattr(self.dataset, 'flag'):
            for flag, repeat_factor in zip(self.dataset.flag, repeat_factors):
                flags.extend([flag] * int(math.ceil(repeat_factor)))
            assert len(flags) == len(repeat_indices)
        self.flag = np.asarray(flags, dtype=np.uint8)

    def _get_repeat_factors(self, dataset, repeat_thr):
        """Get the repeat factor for each image in the dataset.

        Args:
            dataset (:obj:`CustomDataset`): The dataset.
            repeat_thr (float): The frequency threshold. If an image
                contains categories whose frequency is below the threshold,
                it will be repeated.

        Returns:
            list[float]: The repeat factor for each image in the dataset.
        """

        # 1. For each category c, compute the fraction of images
        # that contain it: f(c)
        category_freq = defaultdict(int)
        num_images = len(dataset)
        for idx in range(num_images):
            cat_ids = set(self.dataset.get_cat_ids(idx))
            if len(cat_ids) == 0 and not self.filter_empty_gt:
                cat_ids = set([len(self.CLASSES)])
            for cat_id in cat_ids:
                category_freq[cat_id] += 1
        for k, v in category_freq.items():
            category_freq[k] = v / num_images

        # 2. For each category c, compute the category-level repeat factor:
        # r(c) = max(1, sqrt(t/f(c)))
        category_repeat = {
            cat_id: max(1.0, math.sqrt(repeat_thr / cat_freq))
            for cat_id, cat_freq in category_freq.items()
        }

        # 3. For each image I, compute the image-level repeat factor:
        # r(I) = max_{c in I} r(c)
        repeat_factors = []
        for idx in range(num_images):
            cat_ids = set(self.dataset.get_cat_ids(idx))
            if len(cat_ids) == 0 and not self.filter_empty_gt:
                cat_ids = set([len(self.CLASSES)])
            repeat_factor = 1.0
            if len(cat_ids) > 0:
                repeat_factor = max(
                    {category_repeat[cat_id]
                     for cat_id in cat_ids})
            repeat_factors.append(repeat_factor)

        return repeat_factors

    def __getitem__(self, idx):
        ori_index = self.repeat_indices[idx]
        return self.dataset[ori_index]

    def __len__(self):
        """Length after repetition."""
        return len(self.repeat_indices)
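
# Worked example (illustrative): with oversample_thr t = 0.01, a category
# appearing in 0.25% of the images gets r(c) = max(1, sqrt(0.01 / 0.0025))
# = 2.0, so an image whose rarest category is c is repeated ceil(2.0) = 2
# times per epoch, while images containing only categories with f(c) >= t
# keep r(I) = 1. The dataset name below is hypothetical.
#
#   balanced = ClassBalancedDataset(lvis_dataset, oversample_thr=0.01)
#   # len(balanced) >= len(lvis_dataset)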


@DATASETS.register_module()
class AD_ClassBalancedDataset:
    """A wrapper that balances annotated and annotation-free images.

    Images without any category annotation (OK samples) keep a repeat
    factor of 1, while every annotated image (NG sample) is repeated
    ``num_ok / num_ng * oversample_thr`` times (at least once), so both
    groups contribute a comparable number of samples per epoch.

    Args:
        dataset (:obj:`CustomDataset`): The dataset to be repeated.
        oversample_thr (float): Scale applied to the OK/NG ratio when
            computing the repeat factor of annotated images. Default: 1.0.
        filter_empty_gt (bool, optional): Kept for interface compatibility
            with :obj:`ClassBalancedDataset`; annotation-free images are
            never oversampled here. Default: False.
    """

    def __init__(self, dataset, oversample_thr=1.0, filter_empty_gt=False):
        self.dataset = dataset
        self.oversample_thr = oversample_thr
        self.filter_empty_gt = filter_empty_gt
        self.CLASSES = dataset.CLASSES

        repeat_factors = self._get_repeat_factors(dataset, oversample_thr)
        repeat_indices = []
        for dataset_idx, repeat_factor in enumerate(repeat_factors):
            repeat_indices.extend([dataset_idx] * math.ceil(repeat_factor))
        self.repeat_indices = repeat_indices

        flags = []
        if hasattr(self.dataset, 'flag'):
            for flag, repeat_factor in zip(self.dataset.flag, repeat_factors):
                flags.extend([flag] * int(math.ceil(repeat_factor)))
            assert len(flags) == len(repeat_indices)
        self.flag = np.asarray(flags, dtype=np.uint8)

    def _get_repeat_factors(self, dataset, repeat_thr):
        """Compute a repeat factor for each image.

        Annotation-free (OK) images get a factor of 1.0; annotated (NG)
        images get ``max(1.0, num_ok / num_ng * repeat_thr)``.
        """
        num_images = len(dataset)
        # First pass: count annotation-free (OK) and annotated (NG) images.
        num_ok = 0
        num_ng = 0
        for idx in range(num_images):
            cat_ids = set(self.dataset.get_cat_ids(idx))
            if len(cat_ids) == 0:
                num_ok += 1
            else:
                num_ng += 1
        # Second pass: assign each NG image the OK/NG ratio (scaled by
        # repeat_thr) as its repeat factor, so NG samples are oversampled
        # to roughly match the number of OK samples per epoch.
        repeat_factors = []
        for idx in range(num_images):
            cat_ids = set(self.dataset.get_cat_ids(idx))
            if len(cat_ids) == 0:
                repeat_factor = 1.0
            else:
                repeat_factor = max(1.0, num_ok / num_ng * repeat_thr)
            repeat_factors.append(repeat_factor)

        return repeat_factors

    def __getitem__(self, idx):
        ori_index = self.repeat_indices[idx]
        return self.dataset[ori_index]

    def __len__(self):
        """Length after repetition."""
        return len(self.repeat_indices)
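
# Worked example (illustrative): with 900 annotation-free (OK) images, 100
# annotated (NG) images and oversample_thr=1.0, every NG image gets a repeat
# factor of max(1.0, 900 / 100 * 1.0) = 9.0, so one epoch holds roughly as
# many NG samples (100 * 9) as OK samples (900 * 1).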


@DATASETS.register_module()
class MultiImageMixDataset:
    """A wrapper of multiple images mixed dataset.

    Suitable for training with multi-image mixed data augmentation such as
    mosaic and mixup. For the augmentation pipeline of mixed image data,
    the `get_indexes` method needs to be provided to obtain the image
    indexes, and you can set `skip_type_keys` to skip certain transforms
    in the pipeline. At the same time, we provide the `dynamic_scale`
    parameter to dynamically change the output image size.

    Args:
        dataset (:obj:`CustomDataset`): The dataset to be mixed.
        pipeline (Sequence[dict]): Sequence of transform objects or
            config dicts to be composed.
        dynamic_scale (tuple[int], optional): The image scale that can be
            changed dynamically. Default to None.
        skip_type_keys (list[str], optional): Sequence of transform type
            strings to be skipped in the pipeline. Default to None.
    """

    def __init__(self,
                 dataset,
                 pipeline,
                 dynamic_scale=None,
                 skip_type_keys=None):
        assert isinstance(pipeline, collections.abc.Sequence)
        if skip_type_keys is not None:
            assert all([
                isinstance(skip_type_key, str)
                for skip_type_key in skip_type_keys
            ])
        self._skip_type_keys = skip_type_keys

        self.pipeline = []
        self.pipeline_types = []
        for transform in pipeline:
            if isinstance(transform, dict):
                self.pipeline_types.append(transform['type'])
                transform = build_from_cfg(transform, PIPELINES)
                self.pipeline.append(transform)
            else:
                raise TypeError('each transform in pipeline must be a dict')

        self.dataset = dataset
        self.CLASSES = dataset.CLASSES
        if hasattr(self.dataset, 'flag'):
            self.flag = dataset.flag
        self.num_samples = len(dataset)

        if dynamic_scale is not None:
            assert isinstance(dynamic_scale, tuple)
        # Always set the attribute so __getitem__ can check it safely even
        # when no dynamic scale is configured.
        self._dynamic_scale = dynamic_scale

    def __len__(self):
        return self.num_samples

    def __getitem__(self, idx):
        results = copy.deepcopy(self.dataset[idx])
        for (transform, transform_type) in zip(self.pipeline,
                                               self.pipeline_types):
            if self._skip_type_keys is not None and \
                    transform_type in self._skip_type_keys:
                continue

            if hasattr(transform, 'get_indexes'):
                indexes = transform.get_indexes(self.dataset)
                if not isinstance(indexes, collections.abc.Sequence):
                    indexes = [indexes]
                mix_results = [
                    copy.deepcopy(self.dataset[index]) for index in indexes
                ]
                results['mix_results'] = mix_results

            if self._dynamic_scale is not None:
                # Used by subsequent transforms, e.g. MixUp and Resize, to
                # automatically change the output image size.
                results['scale'] = self._dynamic_scale

            results = transform(results)

            if 'mix_results' in results:
                results.pop('mix_results')

        return results

    def update_skip_type_keys(self, skip_type_keys):
        """Update skip_type_keys. It is called by an external hook.

        Args:
            skip_type_keys (list[str], optional): Sequence of transform
                type strings to be skipped in the pipeline.
        """
        assert all([
            isinstance(skip_type_key, str) for skip_type_key in skip_type_keys
        ])
        self._skip_type_keys = skip_type_keys

    def update_dynamic_scale(self, dynamic_scale):
        """Update dynamic_scale. It is called by an external hook.

        Args:
            dynamic_scale (tuple[int]): The image scale that can be
                changed dynamically.
        """
        assert isinstance(dynamic_scale, tuple)
        self._dynamic_scale = dynamic_scale
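
# Usage sketch (illustrative; the exact transform configs are hypothetical,
# although Mosaic and RandomAffine are registered pipeline transforms in
# mmdet): transforms such as Mosaic implement ``get_indexes``, so the
# wrapper can fetch the extra images they blend in before each call.
#
#   mix_dataset = MultiImageMixDataset(
#       dataset=base_dataset,
#       pipeline=[
#           dict(type='Mosaic', img_scale=(640, 640)),
#           dict(type='RandomAffine'),
#       ],
#       dynamic_scale=(640, 640))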