|
- import numpy as np
- import random
-
-
class NormalSampler:
    """Iterator that hands out the integers ``0 .. num_of_data - 1``.

    The object is its own iterator: ``__iter__`` returns ``self`` and
    ``__next__`` walks an internal cursor over a pre-built index list.
    The first ``iter()`` call uses the order prepared in ``__init__``;
    every later ``iter()`` call rewinds the cursor and, when ``shuffle``
    is truthy, reshuffles the index list in place.
    """

    def __init__(self, num_of_data=1000, shuffle=False):
        """Build the index list, shuffling it once if ``shuffle`` is truthy."""
        self.shuffle = shuffle
        # False until the first ``iter()`` call has been served.
        self.need_reinitialize = False
        self._index = 0
        self._num_of_data = num_of_data

        indices = list(range(num_of_data))
        if shuffle:
            random.shuffle(indices)
        self._data = indices

    def __iter__(self):
        """Return ``self``, rewinding (and reshuffling) on repeat calls."""
        if not self.need_reinitialize:
            # First pass: the state set up in ``__init__`` is used as-is;
            # only remember that the next pass has to start over.
            self.need_reinitialize = True
            return self

        self._index = 0
        if self.shuffle:
            random.shuffle(self._data)
        return self

    def __next__(self):
        """Return the index under the cursor and advance the cursor."""
        cursor = self._index
        if cursor >= self._num_of_data:
            raise StopIteration
        current = self._data[cursor]
        self._index = cursor + 1
        return current

    def __len__(self):
        """Return the ``num_of_data`` value given at construction."""
        return self._num_of_data
-
-
class NormalBatchSampler:
    """Group the items produced by ``sampler`` into lists of ``batch_size``.

    Iterating yields lists of consecutive sampler items. A trailing list
    shorter than ``batch_size`` is yielded only when ``drop_last`` is False.
    """

    def __init__(self, sampler, batch_size: int, drop_last: bool) -> None:
        """Validate and store the batching configuration.

        Raises:
            ValueError: if ``batch_size`` is not a positive, non-bool int,
                or if ``drop_last`` is not a bool.
        """
        # ``sampler`` is deliberately not type-checked: an object can be
        # iterable through ``__getitem__`` alone, which an ``isinstance``
        # test against collections.abc.Iterable would not recognise.
        size_is_valid = (
            isinstance(batch_size, int)
            and not isinstance(batch_size, bool)
            and batch_size > 0
        )
        if not size_is_valid:
            raise ValueError("batch_size should be a positive integer value, "
                             "but got batch_size={}".format(batch_size))
        if not isinstance(drop_last, bool):
            raise ValueError("drop_last should be a boolean value, but got "
                             "drop_last={}".format(drop_last))

        self.sampler = sampler
        self.batch_size = batch_size
        self.drop_last = drop_last

    def __iter__(self):
        """Yield successive lists of sampler items."""
        pending = []
        for sample in self.sampler:
            pending.append(sample)
            if len(pending) != self.batch_size:
                continue
            yield pending
            # Start a new list so the one just yielded is never mutated.
            pending = []

        # Whatever is left is a short final batch.
        if pending and not self.drop_last:
            yield pending

    def __len__(self) -> int:
        """Return the number of batches one full iteration yields."""
        full_batches, leftover = divmod(len(self.sampler), self.batch_size)
        if self.drop_last or leftover == 0:
            return full_batches
        return full_batches + 1
-
-
class RandomDataset:
    """Indexable container filled with values from ``np.random.rand``."""

    def __init__(self, num_data=10):
        """Draw ``num_data`` values with ``np.random.rand`` and store them."""
        samples = np.random.rand(num_data)
        self.data = samples

    def __len__(self):
        """Return the length of the stored array."""
        values = self.data
        return len(values)

    def __getitem__(self, item):
        """Return ``self.data[item]``; indexing is delegated to the array."""
        values = self.data
        return values[item]
-
-
|