@@ -2,7 +2,7 @@ import numpy as np
import torch
from fastNLP.core.sampler import RandomSampler
import torch.multiprocessing as mp


class Batch(object):
    """Batch is an iterable object which iterates over mini-batches.

@@ -29,15 +29,9 @@ class Batch(object):
        self.num_batches = len(dataset) // batch_size + int(len(dataset) % batch_size != 0)
        self.cur_batch_indices = None
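        # Note: the num_batches expression above is integer ceiling division;
        # an equivalent (hypothetical) spelling would be:
        #   self.num_batches = (len(dataset) + batch_size - 1) // batch_size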
    def __iter__(self):
        self.idx_list = self.sampler(self.dataset)
        self.curidx = 0
        self.lengths = self.dataset.get_length()
        return self

    def __next__(self):
    def fetch_one(self):
        if self.curidx >= len(self.idx_list):
            raise StopIteration
            return None
        else:
            endidx = min(self.curidx + self.batch_size, len(self.idx_list))
            batch_x, batch_y = {}, {}

@@ -56,9 +50,15 @@ class Batch(object):
                batch_x[field_name] = batch

            self.curidx = endidx
            return batch_x, batch_y
    def __iter__(self):
        """
        Iterate over the dataset and fetch batch data. Fetching runs in a
        background process and does not block the iterating process.

        :return: a generator yielding (batch_x, batch_y) pairs
        """
        return run_batch_iter(self)

    def __len__(self):
        return self.num_batches
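    # A minimal usage sketch (illustrative only, not part of the diff;
    # DataSet and SequentialSampler are the standard fastNLP classes):
    #
    #   from fastNLP.core.dataset import DataSet
    #   from fastNLP.core.sampler import SequentialSampler
    #
    #   ds = DataSet({"x": [[1, 2], [3, 4], [5, 6]], "y": [0, 1, 0]})
    #   ds.set_input("x")
    #   ds.set_target("y")
    #   for batch_x, batch_y in Batch(ds, batch_size=2, sampler=SequentialSampler()):
    #       print(batch_x["x"], batch_y["y"])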
@@ -75,3 +75,38 @@ def to_tensor(batch, dtype):
    except:
        # conversion failures are silently ignored; the batch is returned unchanged
        pass
    return batch
def run_fetch(batch, q):
    # Producer: runs in a child process and fills the queue ahead of the consumer.
    batch.idx_list = batch.sampler(batch.dataset)
    batch.curidx = 0
    batch.lengths = batch.dataset.get_length()
    # print('start fetch')
    while 1:
        res = batch.fetch_one()
        # print('fetch one')
        q.put(res)
        if res is None:
            # None is the end-of-data sentinel: wait until the consumer has
            # processed every queued item, then exit.
            # print('fetch done, waiting processing')
            q.join()
            break
    # print('fetch exit')
def run_batch_iter(batch):
    q = mp.JoinableQueue(maxsize=10)
    fetch_p = mp.Process(target=run_fetch, args=(batch, q))
    fetch_p.daemon = True
    fetch_p.start()
    # print('fork fetch process')
    while 1:
        res = q.get()
        q.task_done()
        # print('get fetched')
        if res is None:
            break
        yield res
    # Stop the producer unconditionally, whether it finished or the consumer
    # exited early.
    fetch_p.terminate()
    fetch_p.join()
    # print('iter done')
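# A self-contained sketch of the same producer-consumer prefetch pattern,
# with a plain iterable standing in for Batch. The names `_producer` and
# `prefetch_iter` are hypothetical, for illustration only:

def _producer(items, q):
    for item in items:
        q.put(item)
    q.put(None)  # end-of-data sentinel
    q.join()     # block until the consumer has marked every item done

def prefetch_iter(items, maxsize=10):
    q = mp.JoinableQueue(maxsize=maxsize)
    p = mp.Process(target=_producer, args=(list(items), q))
    p.daemon = True  # never outlive the parent process
    p.start()
    while True:
        item = q.get()
        q.task_done()
        if item is None:
            break
        yield item
    p.join()

# for x in prefetch_iter(range(5)):
#     print(x)
#
# Unlike run_batch_iter above, this sketch assumes the consumer drains the
# queue; the original's terminate() also covers consumers that stop early.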
@@ -34,8 +34,8 @@ from fastNLP.core.utils import get_func_signature


class Trainer(object):
    def __init__(self, train_data, model, loss=None, metrics=None, n_epochs=3, batch_size=32, print_every=50,
                 validate_every=-1, dev_data=None, save_path=None, optimizer=Adam(lr=0.01, weight_decay=0),
                 check_code_level=0, metric_key=None, sampler=RandomSampler(), num_workers=0, pin_memory=False,
                 timeout=0, use_tqdm=True, use_cuda=False, callbacks=None):
                 check_code_level=0, metric_key=None, sampler=RandomSampler(), num_workers=0,
                 use_tqdm=True, use_cuda=False, callbacks=None):
        """
        :param DataSet train_data: the training data
        :param torch.nn.modules.module model: a PyTorch model
@@ -127,8 +127,6 @@ class Trainer(object):
        self.best_dev_perf = None
        self.sampler = sampler
        self.num_workers = num_workers
        self.pin_memory = pin_memory
        self.timeout = timeout
        self.callback_manager = CallbackManager(env={"trainer": self}, callbacks=callbacks)

        if isinstance(optimizer, torch.optim.Optimizer):
@@ -249,9 +247,7 @@ class Trainer(object):
                           len(self.train_data) % self.batch_size != 0)) * self.n_epochs
        with inner_tqdm(total=total_steps, postfix='loss:{0:<6.5f}', leave=False, dynamic_ncols=True) as pbar:
            avg_loss = 0
            data_iterator = Batch(self.train_data, batch_size=self.batch_size, sampler=self.sampler, as_numpy=False,
                                  num_workers=self.num_workers, pin_memory=self.pin_memory, timeout=self.timeout,
                                  keep_process=True)
            data_iterator = Batch(self.train_data, batch_size=self.batch_size, sampler=self.sampler, as_numpy=False)
            for epoch in range(1, self.n_epochs+1):
                pbar.set_description_str(desc="Epoch {}/{}".format(epoch, self.n_epochs))
                # early stopping
@@ -261,7 +257,7 @@ class Trainer(object):
                    # negative sampling; replace unknown; re-weight batch_y
                    self.callback_manager.before_batch(batch_x, batch_y, indices)
                    _move_dict_value_to_device(batch_x, batch_y, device=self._model_device,
                                               non_blocking=self.pin_memory)  # with pinned memory, use non_blocking transfers
                                               non_blocking=self.use_cuda)  # with pinned memory, use non_blocking transfers
                    prediction = self._data_forward(self.model, batch_x)

                    # edit prediction
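# Sketch of the pin_memory / non_blocking interaction referenced in the
# comment above (illustration only, assuming a CUDA device; not part of
# the diff):
#
#   cpu_t = torch.randn(1024, 1024).pin_memory()  # page-locked host memory
#   gpu_t = cpu_t.to("cuda", non_blocking=True)   # copy can overlap compute
#
# With pageable (unpinned) host memory, non_blocking=True is effectively
# synchronous for host-to-device copies.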
@@ -876,7 +876,7 @@ class ConllPOSReader(object):


class ConllxDataLoader(object):
    def load(self, path, return_dataset=False):
    def load(self, path):
        datalist = []
        with open(path, 'r', encoding='utf-8') as f:
            sample = []

@@ -894,15 +894,13 @@ class ConllxDataLoader(object):
        data = [self.get_one(sample) for sample in datalist]
        data_list = list(filter(lambda x: x is not None, data))
        if return_dataset is True:
            ds = DataSet()
            for example in data_list:
                ds.append(Instance(words=example[0],
                                   pos_tags=example[1],
                                   heads=example[2],
                                   labels=example[3]))
            data_list = ds
        return data_list

        ds = DataSet()
        for example in data_list:
            ds.append(Instance(words=example[0],
                               pos_tags=example[1],
                               heads=example[2],
                               labels=example[3]))
        return ds
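    # Usage sketch for the simplified loader (the file path below is a
    # placeholder; not part of the diff):
    #
    #   loader = ConllxDataLoader()
    #   ds = loader.load("train.conllx")  # now always returns a DataSet
    #   print(len(ds), ds[0]["words"])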
    def get_one(self, sample):
        # transpose the per-token rows into per-field columns
        sample = list(map(list, zip(*sample)))