|
|
@@ -9,6 +9,7 @@ import atexit |
|
|
|
|
|
|
|
from .sampler import RandomSampler, Sampler |
|
|
|
import torch.multiprocessing as mp |
|
|
|
from queue import Empty, Full |
|
|
|
|
|
|
|
_python_is_exit = False |
|
|
|
|
|
|
@@ -92,7 +93,7 @@ class Batch(object): |
|
|
|
:return: |
|
|
|
""" |
|
|
|
if self.prefetch: |
|
|
|
return _run_batch_iter(self) |
|
|
|
return self._run_batch_iter(self) |
|
|
|
|
|
|
|
def batch_iter(): |
|
|
|
self.init_iter() |
|
|
@@ -120,6 +121,57 @@ class Batch(object): |
|
|
|
""" |
|
|
|
return self.cur_batch_indices |
|
|
|
|
|
|
|
@staticmethod |
|
|
|
def _run_fetch(batch, q): |
|
|
|
try: |
|
|
|
global _python_is_exit |
|
|
|
batch.init_iter() |
|
|
|
# print('start fetch') |
|
|
|
while 1: |
|
|
|
res = batch.fetch_one() |
|
|
|
# print('fetch one') |
|
|
|
while 1: |
|
|
|
try: |
|
|
|
q.put(res, timeout=3) |
|
|
|
break |
|
|
|
except Full: |
|
|
|
if _python_is_exit: |
|
|
|
return |
|
|
|
if res is None: |
|
|
|
# print('fetch done, waiting processing') |
|
|
|
break |
|
|
|
# print('fetch exit') |
|
|
|
except Exception as e: |
|
|
|
q.put(e) |
|
|
|
finally: |
|
|
|
q.join() |
|
|
|
|
|
|
|
@staticmethod |
|
|
|
def _run_batch_iter(batch): |
|
|
|
q = mp.JoinableQueue(maxsize=10) |
|
|
|
fetch_p = mp.Process(target=Batch._run_fetch, args=(batch, q)) |
|
|
|
fetch_p.daemon = True |
|
|
|
fetch_p.start() |
|
|
|
# print('fork fetch process') |
|
|
|
while 1: |
|
|
|
try: |
|
|
|
res = q.get(timeout=1) |
|
|
|
q.task_done() |
|
|
|
# print('get fetched') |
|
|
|
if res is None: |
|
|
|
break |
|
|
|
elif isinstance(res, Exception): |
|
|
|
raise res |
|
|
|
yield res |
|
|
|
except Empty as e: |
|
|
|
if fetch_p.is_alive(): |
|
|
|
continue |
|
|
|
else: |
|
|
|
break |
|
|
|
fetch_p.terminate() |
|
|
|
fetch_p.join() |
|
|
|
# print('iter done') |
|
|
|
|
|
|
|
|
|
|
|
def _to_tensor(batch, dtype): |
|
|
|
try: |
|
|
@@ -131,47 +183,3 @@ def _to_tensor(batch, dtype): |
|
|
|
pass |
|
|
|
return batch |
|
|
|
|
|
|
|
|
|
|
|
def _run_fetch(batch, q): |
|
|
|
global _python_is_exit |
|
|
|
batch.init_iter() |
|
|
|
# print('start fetch') |
|
|
|
while 1: |
|
|
|
res = batch.fetch_one() |
|
|
|
# print('fetch one') |
|
|
|
while 1: |
|
|
|
try: |
|
|
|
q.put(res, timeout=3) |
|
|
|
break |
|
|
|
except: |
|
|
|
if _python_is_exit: |
|
|
|
return |
|
|
|
if res is None: |
|
|
|
# print('fetch done, waiting processing') |
|
|
|
q.join() |
|
|
|
break |
|
|
|
# print('fetch exit') |
|
|
|
|
|
|
|
|
|
|
|
def _run_batch_iter(batch): |
|
|
|
q = mp.JoinableQueue(maxsize=10) |
|
|
|
fetch_p = mp.Process(target=_run_fetch, args=(batch, q)) |
|
|
|
fetch_p.daemon = True |
|
|
|
fetch_p.start() |
|
|
|
# print('fork fetch process') |
|
|
|
while 1: |
|
|
|
try: |
|
|
|
res = q.get(timeout=1) |
|
|
|
q.task_done() |
|
|
|
# print('get fetched') |
|
|
|
if res is None: |
|
|
|
break |
|
|
|
yield res |
|
|
|
except Exception as e: |
|
|
|
if fetch_p.is_alive(): |
|
|
|
continue |
|
|
|
else: |
|
|
|
break |
|
|
|
fetch_p.terminate() |
|
|
|
fetch_p.join() |
|
|
|
# print('iter done') |