|
@@ -20,8 +20,7 @@ class Batch(object): |
|
|
:param str or torch.device device: the batch's device, if as_numpy is True, device is ignored. |
|
|
:param str or torch.device device: the batch's device, if as_numpy is True, device is ignored. |
|
|
""" |
|
|
""" |
|
|
|
|
|
|
|
|
def __init__(self, dataset, batch_size, sampler=RandomSampler(), as_numpy=False, prefetch=False, |
|
|
|
|
|
device='cpu'): |
|
|
|
|
|
|
|
|
def __init__(self, dataset, batch_size, sampler=RandomSampler(), as_numpy=False, prefetch=False): |
|
|
self.dataset = dataset |
|
|
self.dataset = dataset |
|
|
self.batch_size = batch_size |
|
|
self.batch_size = batch_size |
|
|
self.sampler = sampler |
|
|
self.sampler = sampler |
|
@@ -32,8 +31,6 @@ class Batch(object): |
|
|
self.cur_batch_indices = None |
|
|
self.cur_batch_indices = None |
|
|
self.prefetch = prefetch |
|
|
self.prefetch = prefetch |
|
|
self.lengths = 0 |
|
|
self.lengths = 0 |
|
|
if not as_numpy: |
|
|
|
|
|
self.device = device if isinstance(device, torch.device) else torch.device(device) |
|
|
|
|
|
|
|
|
|
|
|
def fetch_one(self): |
|
|
def fetch_one(self): |
|
|
if self.curidx >= len(self.idx_list): |
|
|
if self.curidx >= len(self.idx_list): |
|
@@ -50,7 +47,6 @@ class Batch(object): |
|
|
batch = field.get(indices) |
|
|
batch = field.get(indices) |
|
|
if not self.as_numpy and field.padder is not None: |
|
|
if not self.as_numpy and field.padder is not None: |
|
|
batch = to_tensor(batch, field.dtype) |
|
|
batch = to_tensor(batch, field.dtype) |
|
|
batch = batch.to(self.device) |
|
|
|
|
|
if field.is_target: |
|
|
if field.is_target: |
|
|
batch_y[field_name] = batch |
|
|
batch_y[field_name] = batch |
|
|
if field.is_input: |
|
|
if field.is_input: |
|
@@ -119,12 +115,18 @@ def run_batch_iter(batch): |
|
|
fetch_p.start() |
|
|
fetch_p.start() |
|
|
# print('fork fetch process') |
|
|
# print('fork fetch process') |
|
|
while 1: |
|
|
while 1: |
|
|
res = q.get() |
|
|
|
|
|
q.task_done() |
|
|
|
|
|
# print('get fetched') |
|
|
|
|
|
if res is None: |
|
|
|
|
|
break |
|
|
|
|
|
yield res |
|
|
|
|
|
|
|
|
try: |
|
|
|
|
|
res = q.get(timeout=1) |
|
|
|
|
|
q.task_done() |
|
|
|
|
|
# print('get fetched') |
|
|
|
|
|
if res is None: |
|
|
|
|
|
break |
|
|
|
|
|
yield res |
|
|
|
|
|
except Exception as e: |
|
|
|
|
|
if fetch_p.is_alive(): |
|
|
|
|
|
continue |
|
|
|
|
|
else: |
|
|
|
|
|
break |
|
|
fetch_p.terminate() |
|
|
fetch_p.terminate() |
|
|
fetch_p.join() |
|
|
fetch_p.join() |
|
|
# print('iter done') |
|
|
# print('iter done') |
|
|