
- batch with multiprocessing

Commit 03f49c8264 by yunfan, 6 years ago (tags/v0.3.1^2)
3 changed files with 57 additions and 28 deletions
  1. fastNLP/core/batch.py (+45 -10)
  2. fastNLP/core/trainer.py (+4 -8)
  3. fastNLP/io/dataset_loader.py (+8 -10)

fastNLP/core/batch.py (+45 -10)

@@ -2,7 +2,7 @@ import numpy as np
 import torch
 
 from fastNLP.core.sampler import RandomSampler
-
+import torch.multiprocessing as mp
 
 class Batch(object):
     """Batch is an iterable object which iterates over mini-batches.
@@ -29,15 +29,9 @@ class Batch(object):
         self.num_batches = len(dataset) // batch_size + int(len(dataset) % batch_size != 0)
         self.cur_batch_indices = None
 
-    def __iter__(self):
-        self.idx_list = self.sampler(self.dataset)
-        self.curidx = 0
-        self.lengths = self.dataset.get_length()
-        return self
-
-    def __next__(self):
+    def fetch_one(self):
         if self.curidx >= len(self.idx_list):
-            raise StopIteration
+            return None
         else:
             endidx = min(self.curidx + self.batch_size, len(self.idx_list))
             batch_x, batch_y = {}, {}
@@ -56,9 +50,15 @@ class Batch(object):
                 batch_x[field_name] = batch
 
             self.curidx = endidx
-
             return batch_x, batch_y
 
+    def __iter__(self):
+        """
+        Iterate over the dataset and fetch batch data. The fetch process does not block the iteration process.
+        :return:
+        """
+        return run_batch_iter(self)
+
     def __len__(self):
         return self.num_batches
 
@@ -75,3 +75,38 @@ def to_tensor(batch, dtype):
     except:
         pass
     return batch
+
+
+def run_fetch(batch, q):
+    batch.idx_list = batch.sampler(batch.dataset)
+    batch.curidx = 0
+    batch.lengths = batch.dataset.get_length()
+    # print('start fetch')
+    while 1:
+        res = batch.fetch_one()
+        # print('fetch one')
+        q.put(res)
+        if res is None:
+            # print('fetch done, waiting processing')
+            q.join()
+            break
+    # print('fetch exit')
+
+
+def run_batch_iter(batch):
+    q = mp.JoinableQueue(maxsize=10)
+    fetch_p = mp.Process(target=run_fetch, args=(batch, q))
+    fetch_p.daemon = True
+    fetch_p.start()
+    # print('fork fetch process')
+    while 1:
+        res = q.get()
+        q.task_done()
+        # print('get fetched')
+        if res is None:
+            break
+        yield res
+    fetch_p.terminate()
+    fetch_p.join()
+    # print('iter done')
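
The new iterator is a classic single-producer/single-consumer design: `run_batch_iter` forks a daemon process running `run_fetch`, which prepares batches ahead of time and pushes them into a bounded `JoinableQueue`; a `None` sentinel marks exhaustion, and `q.join()` keeps the producer alive until the consumer has acknowledged every queued item. A minimal self-contained sketch of the same pattern (the names `produce` and `consume` are illustrative, not fastNLP API):

```python
# Sketch of the producer/consumer pattern behind run_fetch / run_batch_iter.
# Plain multiprocessing is used here; the commit uses torch.multiprocessing,
# a drop-in wrapper that shares tensors via shared memory.
import multiprocessing as mp


def produce(q, n_items):
    # Producer: prepare items ahead of the consumer, then send a None
    # sentinel and wait (q.join) until every queued item is acknowledged.
    for i in range(n_items):
        q.put(i * i)  # stand-in for Batch.fetch_one()
    q.put(None)
    q.join()


def consume(n_items):
    # Bounded queue (maxsize=10, as in the commit) applies back-pressure:
    # the producer never runs more than 10 items ahead of the consumer.
    q = mp.JoinableQueue(maxsize=10)
    p = mp.Process(target=produce, args=(q, n_items))
    p.daemon = True  # don't outlive the parent if iteration is abandoned
    p.start()
    while True:
        item = q.get()
        q.task_done()  # must pair with every get() so q.join() can return
        if item is None:
            break
        yield item
    p.terminate()
    p.join()


if __name__ == '__main__':
    print(list(consume(5)))  # [0, 1, 4, 9, 16]
```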


fastNLP/core/trainer.py (+4 -8)

@@ -34,8 +34,8 @@ from fastNLP.core.utils import get_func_signature
 class Trainer(object):
     def __init__(self, train_data, model, loss=None, metrics=None, n_epochs=3, batch_size=32, print_every=50,
                  validate_every=-1, dev_data=None, save_path=None, optimizer=Adam(lr=0.01, weight_decay=0),
-                 check_code_level=0, metric_key=None, sampler=RandomSampler(), num_workers=0, pin_memory=False,
-                 timeout=0, use_tqdm=True, use_cuda=False, callbacks=None):
+                 check_code_level=0, metric_key=None, sampler=RandomSampler(), num_workers=0,
+                 use_tqdm=True, use_cuda=False, callbacks=None):
         """
         :param DataSet train_data: the training data
         :param torch.nn.modules.module model: a PyTorch model
@@ -127,8 +127,6 @@ class Trainer(object):
         self.best_dev_perf = None
         self.sampler = sampler
         self.num_workers = num_workers
-        self.pin_memory = pin_memory
-        self.timeout = timeout
         self.callback_manager = CallbackManager(env={"trainer": self}, callbacks=callbacks)
 
         if isinstance(optimizer, torch.optim.Optimizer):
@@ -249,9 +247,7 @@
                                           len(self.train_data) % self.batch_size != 0)) * self.n_epochs
         with inner_tqdm(total=total_steps, postfix='loss:{0:<6.5f}', leave=False, dynamic_ncols=True) as pbar:
             avg_loss = 0
-            data_iterator = Batch(self.train_data, batch_size=self.batch_size, sampler=self.sampler, as_numpy=False,
-                                  num_workers=self.num_workers, pin_memory=self.pin_memory, timeout=self.timeout,
-                                  keep_process=True)
+            data_iterator = Batch(self.train_data, batch_size=self.batch_size, sampler=self.sampler, as_numpy=False)
             for epoch in range(1, self.n_epochs+1):
                 pbar.set_description_str(desc="Epoch {}/{}".format(epoch, self.n_epochs))
                 # early stopping
@@ -261,7 +257,7 @@
                 # negative sampling; replace unknown; re-weight batch_y
                 self.callback_manager.before_batch(batch_x, batch_y, indices)
                 _move_dict_value_to_device(batch_x, batch_y, device=self._model_device,
-                                           non_blocking=self.pin_memory)  # pin_memory, use non_blockling.
+                                           non_blocking=self.use_cuda)  # pin_memory, use non_blocking.
                 prediction = self._data_forward(self.model, batch_x)
 
                 # edit prediction
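
With the `pin_memory` and `timeout` knobs removed, `non_blocking` is now keyed off `use_cuda`. One caveat worth knowing: `non_blocking=True` only performs a truly asynchronous host-to-device copy when the source tensor lives in pinned (page-locked) host memory; for ordinary pageable tensors PyTorch silently falls back to a synchronous copy, so the flag is harmless but ineffective. A small illustrative sketch (not fastNLP code; tensor names and sizes are arbitrary):

```python
# Sketch of what the non_blocking flag does at the PyTorch level.
import torch

if torch.cuda.is_available():
    device = torch.device('cuda')
    pinned = torch.randn(1024, 1024).pin_memory()  # page-locked host memory
    pageable = torch.randn(1024, 1024)             # ordinary host memory

    a = pinned.to(device, non_blocking=True)    # true async copy; overlaps
                                                # with subsequent host work
    b = pageable.to(device, non_blocking=True)  # flag accepted, but the copy
                                                # degrades to a blocking one
    torch.cuda.synchronize()  # block until outstanding copies finish,
                              # e.g. before timing the transfer
```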


fastNLP/io/dataset_loader.py (+8 -10)

@@ -876,7 +876,7 @@ class ConllPOSReader(object):
 
 
 class ConllxDataLoader(object):
-    def load(self, path, return_dataset=False):
+    def load(self, path):
         datalist = []
         with open(path, 'r', encoding='utf-8') as f:
             sample = []
@@ -894,15 +894,13 @@ class ConllxDataLoader(object):
         data = [self.get_one(sample) for sample in datalist]
         data_list = list(filter(lambda x: x is not None, data))
 
-        if return_dataset is True:
-            ds = DataSet()
-            for example in data_list:
-                ds.append(Instance(words=example[0],
-                                   pos_tags=example[1],
-                                   heads=example[2],
-                                   labels=example[3]))
-            data_list = ds
-        return data_list
+        ds = DataSet()
+        for example in data_list:
+            ds.append(Instance(words=example[0],
+                               pos_tags=example[1],
+                               heads=example[2],
+                               labels=example[3]))
+        return ds
 
     def get_one(self, sample):
         sample = list(map(list, zip(*sample)))
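
After this change `ConllxDataLoader.load` always returns a `DataSet` of `Instance`s with `words`, `pos_tags`, `heads`, and `labels` fields, so the `return_dataset=True` flag disappears from call sites. A hedged usage sketch (the file path below is hypothetical):

```python
# Usage under the new signature; 'data/train.conllx' is a hypothetical path.
from fastNLP.io.dataset_loader import ConllxDataLoader

loader = ConllxDataLoader()
ds = loader.load('data/train.conllx')  # now returns a DataSet directly

print(len(ds))         # number of successfully parsed sentences
print(ds[0]['words'])  # each Instance carries words, pos_tags, heads, labels
```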

