From 03f49c8264cf3a3c1b4912afcad2b7b11bb985a0 Mon Sep 17 00:00:00 2001
From: yunfan
Date: Sat, 19 Jan 2019 19:44:32 +0800
Subject: [PATCH] - batch with multiprocessing

---
 fastNLP/core/batch.py        | 55 +++++++++++++++++++++++++++++-------
 fastNLP/core/trainer.py      | 12 +++-----
 fastNLP/io/dataset_loader.py | 18 ++++++------
 3 files changed, 57 insertions(+), 28 deletions(-)

diff --git a/fastNLP/core/batch.py b/fastNLP/core/batch.py
index d4fcbf23..3faab8c0 100644
--- a/fastNLP/core/batch.py
+++ b/fastNLP/core/batch.py
@@ -2,7 +2,7 @@ import numpy as np
 import torch
 
 from fastNLP.core.sampler import RandomSampler
-
+import torch.multiprocessing as mp
 
 class Batch(object):
     """Batch is an iterable object which iterates over mini-batches.
@@ -29,15 +29,9 @@ class Batch(object):
         self.num_batches = len(dataset) // batch_size + int(len(dataset) % batch_size != 0)
         self.cur_batch_indices = None
 
-    def __iter__(self):
-        self.idx_list = self.sampler(self.dataset)
-        self.curidx = 0
-        self.lengths = self.dataset.get_length()
-        return self
-
-    def __next__(self):
+    def fetch_one(self):
         if self.curidx >= len(self.idx_list):
-            raise StopIteration
+            return None
         else:
             endidx = min(self.curidx + self.batch_size, len(self.idx_list))
             batch_x, batch_y = {}, {}
@@ -56,9 +50,15 @@ class Batch(object):
                     batch_x[field_name] = batch
 
             self.curidx = endidx
-            return batch_x, batch_y
 
+    def __iter__(self):
+        """
+        Iterate over the dataset and fetch batch data. The fetch process doesn't block the iteration.
+        :return:
+        """
+        return run_batch_iter(self)
+
     def __len__(self):
         return self.num_batches
 
@@ -75,3 +75,38 @@ def to_tensor(batch, dtype):
     except:
         pass
     return batch
+
+
+def run_fetch(batch, q):
+    batch.idx_list = batch.sampler(batch.dataset)
+    batch.curidx = 0
+    batch.lengths = batch.dataset.get_length()
+    # print('start fetch')
+    while 1:
+        res = batch.fetch_one()
+        # print('fetch one')
+        q.put(res)
+        if res is None:
+            # print('fetch done, waiting processing')
+            q.join()
+            break
+    # print('fetch exit')
+
+
+def run_batch_iter(batch):
+    q = mp.JoinableQueue(maxsize=10)
+    fetch_p = mp.Process(target=run_fetch, args=(batch, q))
+    fetch_p.daemon = True
+    fetch_p.start()
+    # print('fork fetch process')
+    while 1:
+        res = q.get()
+        q.task_done()
+        # print('get fetched')
+        if res is None:
+            break
+        yield res
+    fetch_p.terminate()
+    fetch_p.join()
+    # print('iter done')
+
diff --git a/fastNLP/core/trainer.py b/fastNLP/core/trainer.py
index a5861091..faa0d0a2 100644
--- a/fastNLP/core/trainer.py
+++ b/fastNLP/core/trainer.py
@@ -34,8 +34,8 @@ from fastNLP.core.utils import get_func_signature
 class Trainer(object):
     def __init__(self, train_data, model, loss=None, metrics=None, n_epochs=3, batch_size=32, print_every=50,
                  validate_every=-1, dev_data=None, save_path=None, optimizer=Adam(lr=0.01, weight_decay=0),
-                 check_code_level=0, metric_key=None, sampler=RandomSampler(), num_workers=0, pin_memory=False,
-                 timeout=0, use_tqdm=True, use_cuda=False, callbacks=None):
+                 check_code_level=0, metric_key=None, sampler=RandomSampler(), num_workers=0,
+                 use_tqdm=True, use_cuda=False, callbacks=None):
         """
         :param DataSet train_data: the training data
         :param torch.nn.modules.module model: a PyTorch model
@@ -127,8 +127,6 @@ class Trainer(object):
         self.best_dev_perf = None
         self.sampler = sampler
         self.num_workers = num_workers
-        self.pin_memory = pin_memory
-        self.timeout = timeout
         self.callback_manager = CallbackManager(env={"trainer": self}, callbacks=callbacks)
 
         if isinstance(optimizer, torch.optim.Optimizer):
@@ -249,9 +247,7 @@ class Trainer(object):
                             len(self.train_data) % self.batch_size != 0)) * self.n_epochs
         with inner_tqdm(total=total_steps, postfix='loss:{0:<6.5f}', leave=False, dynamic_ncols=True) as pbar:
             avg_loss = 0
-            data_iterator = Batch(self.train_data, batch_size=self.batch_size, sampler=self.sampler, as_numpy=False,
-                                  num_workers=self.num_workers, pin_memory=self.pin_memory, timeout=self.timeout,
-                                  keep_process=True)
+            data_iterator = Batch(self.train_data, batch_size=self.batch_size, sampler=self.sampler, as_numpy=False)
             for epoch in range(1, self.n_epochs+1):
                 pbar.set_description_str(desc="Epoch {}/{}".format(epoch, self.n_epochs))
                 # early stopping
@@ -261,7 +257,7 @@ class Trainer(object):
                 # negative sampling; replace unknown; re-weight batch_y
                 self.callback_manager.before_batch(batch_x, batch_y, indices)
                 _move_dict_value_to_device(batch_x, batch_y, device=self._model_device,
-                                           non_blocking=self.pin_memory)  # pin_memory, use non_blockling.
+                                           non_blocking=self.use_cuda)  # pin_memory, use non_blocking.
                 prediction = self._data_forward(self.model, batch_x)
 
                 # edit prediction
diff --git a/fastNLP/io/dataset_loader.py b/fastNLP/io/dataset_loader.py
index 211d6cc9..1fcdb7d9 100644
--- a/fastNLP/io/dataset_loader.py
+++ b/fastNLP/io/dataset_loader.py
@@ -876,7 +876,7 @@ class ConllPOSReader(object):
 
 
 class ConllxDataLoader(object):
-    def load(self, path, return_dataset=False):
+    def load(self, path):
         datalist = []
         with open(path, 'r', encoding='utf-8') as f:
             sample = []
@@ -894,15 +894,13 @@ class ConllxDataLoader(object):
 
         data = [self.get_one(sample) for sample in datalist]
         data_list = list(filter(lambda x: x is not None, data))
-        if return_dataset is True:
-            ds = DataSet()
-            for example in data_list:
-                ds.append(Instance(words=example[0],
-                                   pos_tags=example[1],
-                                   heads=example[2],
-                                   labels=example[3]))
-            data_list = ds
-        return data_list
+        ds = DataSet()
+        for example in data_list:
+            ds.append(Instance(words=example[0],
+                               pos_tags=example[1],
+                               heads=example[2],
+                               labels=example[3]))
+        return ds
 
     def get_one(self, sample):
         sample = list(map(list, zip(*sample)))
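
Note on the design: the new run_fetch/run_batch_iter pair in batch.py is a producer/consumer prefetcher. Below is a minimal, self-contained sketch of that pattern, assuming only the standard-library multiprocessing module (torch.multiprocessing used in the patch mirrors this API); the batch_iter name and the plain-list "dataset" are illustrative and not part of fastNLP. The producer puts batches into a bounded JoinableQueue and signals the end with a None sentinel; the consumer yields batches as they arrive, so fetching overlaps with training.

import multiprocessing as mp


def run_fetch(data, batch_size, q):
    # producer: slice the data into batches and push them into the queue
    for start in range(0, len(data), batch_size):
        q.put(data[start:start + batch_size])
    q.put(None)   # sentinel: no more batches
    q.join()      # wait until the consumer has acknowledged every item


def batch_iter(data, batch_size):
    # consumer: fork the producer, then yield batches until the sentinel arrives
    q = mp.JoinableQueue(maxsize=10)
    fetch_p = mp.Process(target=run_fetch, args=(data, batch_size, q))
    fetch_p.daemon = True
    fetch_p.start()
    while True:
        res = q.get()
        q.task_done()
        if res is None:
            break
        yield res
    fetch_p.join()


if __name__ == '__main__':
    for batch in batch_iter(list(range(10)), batch_size=4):
        print(batch)  # [0, 1, 2, 3], then [4, 5, 6, 7], then [8, 9]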