|
|
@@ -188,11 +188,13 @@ class Worker(threading.Thread): |
|
|
|
|
|
|
|
|
|
|
|
class KWSDataLoader: |
|
|
|
""" |
|
|
|
dataset: the dataset reference |
|
|
|
batchsize: data batch size |
|
|
|
numworkers: no. of workers |
|
|
|
prefetch: prefetch factor |
|
|
|
""" Load and organize audio data with multiple threads |
|
|
|
|
|
|
|
Args: |
|
|
|
dataset: the dataset reference |
|
|
|
batchsize: data batch size |
|
|
|
numworkers: no. of workers |
|
|
|
prefetch: prefetch factor |
|
|
|
""" |
|
|
|
|
|
|
|
def __init__(self, dataset, batchsize, numworkers, prefetch=2): |
|
|
@@ -202,7 +204,7 @@ class KWSDataLoader: |
|
|
|
self.isrun = True |
|
|
|
|
|
|
|
# data queue |
|
|
|
self.pool = queue.Queue(batchsize * prefetch) |
|
|
|
self.pool = queue.Queue(numworkers * prefetch) |
|
|
|
|
|
|
|
# initialize workers |
|
|
|
self.workerlist = [] |
|
|
@@ -270,11 +272,11 @@ class KWSDataLoader: |
|
|
|
w.stopWorker() |
|
|
|
|
|
|
|
while not self.pool.empty(): |
|
|
|
self.pool.get(block=True, timeout=0.001) |
|
|
|
self.pool.get(block=True, timeout=0.01) |
|
|
|
|
|
|
|
# wait workers terminated |
|
|
|
for w in self.workerlist: |
|
|
|
while not self.pool.empty(): |
|
|
|
self.pool.get(block=True, timeout=0.001) |
|
|
|
self.pool.get(block=True, timeout=0.01) |
|
|
|
w.join() |
|
|
|
logger.info('KWSDataLoader: All worker stopped.') |