|
|
@@ -15,24 +15,28 @@ from fastNLP.core.sampler import RandomSampler |
|
|
|
|
|
|
|
class Batch(object): |
|
|
|
def __init__(self, dataset, batch_size, sampler=RandomSampler(), as_numpy=False, num_workers=0, pin_memory=False, |
|
|
|
timeout=0.0): |
|
|
|
timeout=0.0, keep_process=False): |
|
|
|
""" |
|
|
|
Batch is an iterable object which iterates over mini-batches. |
|
|
|
|
|
|
|
Example:: |
|
|
|
|
|
|
|
for batch_x, batch_y in Batch(data_set, batch_size=16, sampler=SequentialSampler()): |
|
|
|
# ... |
|
|
|
iterator = Batch(data_set, batch_size=16, sampler=SequentialSampler()) |
|
|
|
for epoch in range(num_epochs): |
|
|
|
for batch_x, batch_y in iterator: # 每次epoch会重新使用sampler生成index的。 |
|
|
|
# ... |
|
|
|
|
|
|
|
:param DataSet dataset: a DataSet object |
|
|
|
:param int batch_size: the size of the batch |
|
|
|
:param Sampler sampler: a Sampler object |
|
|
|
:param bool as_numpy: If True, return Numpy array when possible. Otherwise, return torch tensors. |
|
|
|
:param bool as_numpy: If True, return Numpy array. Otherwise, return torch tensors. |
|
|
|
:param num_workers: int, 使用多少个进程来准备数据。默认为0, 即使用主线程生成数据。 特性处于实验阶段,谨慎使用。 |
|
|
|
如果DataSet较大,且每个batch的准备时间很短,使用多进程可能并不能提速。 |
|
|
|
:param pin_memory: bool, 默认为False. 设置为True时,有可能可以节省tensor从cpu移动到gpu的阻塞时间。 |
|
|
|
:param timeout: float, 大于0的数,只有在num_workers>0时才有用。超过该时间仍然没有获取到一个batch则报错,可以用于 |
|
|
|
检测是否出现了batch产生阻塞的情况。 |
|
|
|
:param keep_process: bool. 默认为False,该参数只在多进程下有效。在多进程的情况下,反复产生batch的iterator会导致 |
|
|
|
不断创建、销毁进程,可能对速度有一定的影响。当keep_process为True时,直到Batch对象被删除之前,多进程都没有关 |
|
|
|
闭。如果设置了keep_process为True,可以通过del BatchObject来删除Batch对象并关闭进程。 |
|
|
|
""" |
|
|
|
|
|
|
|
if num_workers < 0: |
|
|
@@ -45,15 +49,24 @@ class Batch(object): |
|
|
|
self.batch_size = batch_size |
|
|
|
self.sampler = sampler |
|
|
|
self.num_workers = num_workers |
|
|
|
self.keep_process = keep_process |
|
|
|
self.pin_memory = pin_memory |
|
|
|
self.timeout = timeout |
|
|
|
self.as_numpy = as_numpy |
|
|
|
self.num_batches = len(dataset) // batch_size + int(len(dataset) % batch_size != 0) |
|
|
|
self.cur_batch_indices = None |
|
|
|
self._data_iterator = None |
|
|
|
|
|
|
|
def __iter__(self): |
|
|
|
# TODO 现在多线程的情况下每个循环都会重新创建多进程,开销可能有点大。可以考虑直接复用iterator. |
|
|
|
return _DataLoaderIter(self) |
|
|
|
if self._data_iterator is not None: |
|
|
|
# 重新设置index_list |
|
|
|
self._data_iterator.reset() |
|
|
|
return self._data_iterator |
|
|
|
elif self.keep_process and self.num_workers>0: |
|
|
|
self._data_iterator = _DataLoaderIter(self) |
|
|
|
return self._data_iterator |
|
|
|
else: # 大多数情况是这个 |
|
|
|
return _DataLoaderIter(self) |
|
|
|
|
|
|
|
def __len__(self): |
|
|
|
return self.num_batches |
|
|
@@ -61,6 +74,12 @@ class Batch(object): |
|
|
|
def get_batch_indices(self): |
|
|
|
return self.cur_batch_indices |
|
|
|
|
|
|
|
def __del__(self): |
|
|
|
if self.keep_process is True: |
|
|
|
del self._data_iterator |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def to_tensor(batch, dtype): |
|
|
|
try: |
|
|
|
if dtype in (int, np.int8, np.int16, np.int32, np.int64): |
|
|
@@ -276,6 +295,7 @@ class _DataLoaderIter(object): |
|
|
|
self.num_workers = batcher.num_workers |
|
|
|
self.pin_memory = batcher.pin_memory and torch.cuda.is_available() |
|
|
|
self.timeout = batcher.timeout |
|
|
|
self.keep_process = batcher.keep_process |
|
|
|
self.done_event = threading.Event() |
|
|
|
self.curidx = 0 |
|
|
|
self.idx_list = self.sampler(self.dataset) |
|
|
@@ -335,6 +355,17 @@ class _DataLoaderIter(object): |
|
|
|
for _ in range(2 * self.num_workers): |
|
|
|
self._put_indices() |
|
|
|
|
|
|
|
def reset(self): |
|
|
|
""" |
|
|
|
重置curidx以及重新采样idx_list. 只有再需要keep_process时才有用 |
|
|
|
:return: |
|
|
|
""" |
|
|
|
if self.keep_process: |
|
|
|
self.curidx = 0 |
|
|
|
self.idx_list = self.sampler(self.dataset) |
|
|
|
for _ in range(2 * self.num_workers): |
|
|
|
self._put_indices() |
|
|
|
|
|
|
|
def _get_batch(self): |
|
|
|
if self.timeout > 0: |
|
|
|
try: |
|
|
@@ -366,7 +397,8 @@ class _DataLoaderIter(object): |
|
|
|
|
|
|
|
# 如果生成的数据为0了,则停止 |
|
|
|
if self.batches_outstanding == 0: |
|
|
|
self._shutdown_workers() |
|
|
|
if not self.keep_process: |
|
|
|
self._shutdown_workers() |
|
|
|
raise StopIteration |
|
|
|
|
|
|
|
while True: |
|
|
@@ -449,4 +481,4 @@ class _DataLoaderIter(object): |
|
|
|
|
|
|
|
def __del__(self): |
|
|
|
if self.num_workers > 0: |
|
|
|
self._shutdown_workers() |
|
|
|
self._shutdown_workers() |