- import itertools
-
- import numpy as np
- import torch
- from mmcv.runner import get_dist_info
- from torch.utils.data.sampler import Sampler
-
-
- class InfiniteGroupBatchSampler(Sampler):
- """Similar to `BatchSampler` warping a `GroupSampler. It is designed for
- iteration-based runners like `IterBasedRunner` and yields a mini-batch
- indices each time, all indices in a batch should be in the same group.
-
- The implementation logic is adapted from
- https://github.com/facebookresearch/detectron2/blob/main/detectron2/data/samplers/grouped_batch_sampler.py
-
- Args:
- dataset (object): The dataset.
- batch_size (int): When the model is :obj:`DistributedDataParallel`,
- it is the number of training samples on each GPU.
- When the model is :obj:`DataParallel`, it is
- `num_gpus * samples_per_gpu`.
- Default: 1.
- world_size (int, optional): Number of processes participating in
- distributed training. Default: None.
- rank (int, optional): Rank of current process. Default: None.
- seed (int): Random seed. Default: 0.
- shuffle (bool): Whether to shuffle the indices of a dummy `epoch`.
- Note that `shuffle` cannot guarantee that sequential indices will
- be generated, because the sampler must ensure that all indices in
- a batch belong to the same group. Default: True.
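-
- Example (an illustrative sketch, not part of the original module;
- ``dataset`` is assumed to be a map-style dataset exposing a ``flag``
- array of per-sample group ids, as MMDetection datasets do)::
-
- >>> from torch.utils.data import DataLoader
- >>> batch_sampler = InfiniteGroupBatchSampler(dataset, batch_size=2)
- >>> loader = DataLoader(dataset, batch_sampler=batch_sampler)
- >>> data_iter = iter(loader)
- >>> batch = next(data_iter)  # every batch stays within a single group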
- """ # noqa: W605
-
- def __init__(self,
- dataset,
- batch_size=1,
- world_size=None,
- rank=None,
- seed=0,
- shuffle=True):
- _rank, _world_size = get_dist_info()
- if world_size is None:
- world_size = _world_size
- if rank is None:
- rank = _rank
- self.rank = rank
- self.world_size = world_size
- self.dataset = dataset
- self.batch_size = batch_size
- self.seed = seed if seed is not None else 0
- self.shuffle = shuffle
-
- assert hasattr(self.dataset, 'flag')
- self.flag = self.dataset.flag
- self.group_sizes = np.bincount(self.flag)
- # buffer used to save indices of each group
- self.buffer_per_group = {k: [] for k in range(len(self.group_sizes))}
-
- self.size = len(dataset)
- self.indices = self._indices_of_rank()
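- # `self.indices` is a single generator shared by all `__iter__` calls,
- # so iteration resumes from the current position instead of restarting.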
-
- def _infinite_indices(self):
- """Infinitely yield a sequence of indices."""
- g = torch.Generator()
- g.manual_seed(self.seed)
- while True:
- if self.shuffle:
- yield from torch.randperm(self.size, generator=g).tolist()
-
- else:
- yield from torch.arange(self.size).tolist()
-
- def _indices_of_rank(self):
- """Slice the infinite indices by rank."""
- yield from itertools.islice(self._infinite_indices(), self.rank, None,
- self.world_size)
-
- def __iter__(self):
- # once batch size is reached, yield the indices
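- # indices are buffered per group; a batch is yielded only once a
- # group's buffer is full, so partially filled buffers simply wait for
- # more indices of the same group and are never emitted incomplete.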
- for idx in self.indices:
- flag = self.flag[idx]
- group_buffer = self.buffer_per_group[flag]
- group_buffer.append(idx)
- if len(group_buffer) == self.batch_size:
- yield group_buffer[:]
- del group_buffer[:]
-
- def __len__(self):
- """Length of base dataset."""
- return self.size
-
- def set_epoch(self, epoch):
- """Not supported in `IterationBased` runner."""
- raise NotImplementedError
-
-
- class InfiniteBatchSampler(Sampler):
- """Similar to `BatchSampler` warping a `DistributedSampler. It is designed
- iteration-based runners like `IterBasedRunner` and yields a mini-batch
- indices each time.
-
- The implementation logic is adapted from
- https://github.com/facebookresearch/detectron2/blob/main/detectron2/data/samplers/grouped_batch_sampler.py
-
- Args:
- dataset (object): The dataset.
- batch_size (int): When the model is :obj:`DistributedDataParallel`,
- it is the number of training samples on each GPU.
- When the model is :obj:`DataParallel`, it is
- `num_gpus * samples_per_gpu`.
- Default: 1.
- world_size (int, optional): Number of processes participating in
- distributed training. Default: None.
- rank (int, optional): Rank of current process. Default: None.
- seed (int): Random seed. Default: 0.
- shuffle (bool): Whether to shuffle the dataset or not. Default: True.
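-
- Example (an illustrative sketch, not part of the original module;
- ``dataset`` is assumed to be any map-style dataset)::
-
- >>> from torch.utils.data import DataLoader
- >>> batch_sampler = InfiniteBatchSampler(dataset, batch_size=2, shuffle=True)
- >>> loader = DataLoader(dataset, batch_sampler=batch_sampler)
- >>> batch = next(iter(loader))  # one mini-batch per training iteration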
- """ # noqa: W605
-
- def __init__(self,
- dataset,
- batch_size=1,
- world_size=None,
- rank=None,
- seed=0,
- shuffle=True):
- _rank, _world_size = get_dist_info()
- if world_size is None:
- world_size = _world_size
- if rank is None:
- rank = _rank
- self.rank = rank
- self.world_size = world_size
- self.dataset = dataset
- self.batch_size = batch_size
- self.seed = seed if seed is not None else 0
- self.shuffle = shuffle
- self.size = len(dataset)
- self.indices = self._indices_of_rank()
-
- def _infinite_indices(self):
- """Infinitely yield a sequence of indices."""
- g = torch.Generator()
- g.manual_seed(self.seed)
- while True:
- if self.shuffle:
- yield from torch.randperm(self.size, generator=g).tolist()
-
- else:
- yield from torch.arange(self.size).tolist()
-
- def _indices_of_rank(self):
- """Slice the infinite indices by rank."""
- yield from itertools.islice(self._infinite_indices(), self.rank, None,
- self.world_size)
-
- def __iter__(self):
- # once batch size is reached, yield the indices
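- # unlike `InfiniteGroupBatchSampler`, indices are batched in the order
- # they arrive from the rank-sliced stream, with no grouping constraint.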
- batch_buffer = []
- for idx in self.indices:
- batch_buffer.append(idx)
- if len(batch_buffer) == self.batch_size:
- yield batch_buffer
- batch_buffer = []
-
- def __len__(self):
- """Length of base dataset."""
- return self.size
-
- def set_epoch(self, epoch):
- """Not supported in `IterationBased` runner."""
- raise NotImplementedError