| @@ -2,7 +2,7 @@ import numpy as np | |||||
| import torch | import torch | ||||
| from fastNLP.core.sampler import RandomSampler | from fastNLP.core.sampler import RandomSampler | ||||
| import torch.multiprocessing as mp | |||||
| class Batch(object): | class Batch(object): | ||||
| """Batch is an iterable object which iterates over mini-batches. | """Batch is an iterable object which iterates over mini-batches. | ||||
| @@ -29,15 +29,9 @@ class Batch(object): | |||||
| self.num_batches = len(dataset) // batch_size + int(len(dataset) % batch_size != 0) | self.num_batches = len(dataset) // batch_size + int(len(dataset) % batch_size != 0) | ||||
| self.cur_batch_indices = None | self.cur_batch_indices = None | ||||
| def __iter__(self): | |||||
| self.idx_list = self.sampler(self.dataset) | |||||
| self.curidx = 0 | |||||
| self.lengths = self.dataset.get_length() | |||||
| return self | |||||
| def __next__(self): | |||||
| def fetch_one(self): | |||||
| if self.curidx >= len(self.idx_list): | if self.curidx >= len(self.idx_list): | ||||
| raise StopIteration | |||||
| return None | |||||
| else: | else: | ||||
| endidx = min(self.curidx + self.batch_size, len(self.idx_list)) | endidx = min(self.curidx + self.batch_size, len(self.idx_list)) | ||||
| batch_x, batch_y = {}, {} | batch_x, batch_y = {}, {} | ||||
| @@ -56,9 +50,15 @@ class Batch(object): | |||||
| batch_x[field_name] = batch | batch_x[field_name] = batch | ||||
| self.curidx = endidx | self.curidx = endidx | ||||
| return batch_x, batch_y | return batch_x, batch_y | ||||
| def __iter__(self): | |||||
| """ | |||||
| Iterate on dataset, fetch batch data. Fetch process don't block the iterate process | |||||
| :return: | |||||
| """ | |||||
| return run_batch_iter(self) | |||||
| def __len__(self): | def __len__(self): | ||||
| return self.num_batches | return self.num_batches | ||||
| @@ -75,3 +75,38 @@ def to_tensor(batch, dtype): | |||||
| except: | except: | ||||
| pass | pass | ||||
| return batch | return batch | ||||
| def run_fetch(batch, q): | |||||
| batch.idx_list = batch.sampler(batch.dataset) | |||||
| batch.curidx = 0 | |||||
| batch.lengths = batch.dataset.get_length() | |||||
| # print('start fetch') | |||||
| while 1: | |||||
| res = batch.fetch_one() | |||||
| # print('fetch one') | |||||
| q.put(res) | |||||
| if res is None: | |||||
| # print('fetch done, waiting processing') | |||||
| q.join() | |||||
| break | |||||
| # print('fetch exit') | |||||
| def run_batch_iter(batch): | |||||
| q = mp.JoinableQueue(maxsize=10) | |||||
| fetch_p = mp.Process(target=run_fetch, args=(batch, q)) | |||||
| fetch_p.daemon = True | |||||
| fetch_p.start() | |||||
| # print('fork fetch process') | |||||
| while 1: | |||||
| res = q.get() | |||||
| q.task_done() | |||||
| # print('get fetched') | |||||
| if res is None: | |||||
| break | |||||
| yield res | |||||
| fetch_p.terminate() | |||||
| fetch_p.join() | |||||
| # print('iter done') | |||||
| @@ -34,8 +34,8 @@ from fastNLP.core.utils import get_func_signature | |||||
| class Trainer(object): | class Trainer(object): | ||||
| def __init__(self, train_data, model, loss=None, metrics=None, n_epochs=3, batch_size=32, print_every=50, | def __init__(self, train_data, model, loss=None, metrics=None, n_epochs=3, batch_size=32, print_every=50, | ||||
| validate_every=-1, dev_data=None, save_path=None, optimizer=Adam(lr=0.01, weight_decay=0), | validate_every=-1, dev_data=None, save_path=None, optimizer=Adam(lr=0.01, weight_decay=0), | ||||
| check_code_level=0, metric_key=None, sampler=RandomSampler(), num_workers=0, pin_memory=False, | |||||
| timeout=0, use_tqdm=True, use_cuda=False, callbacks=None): | |||||
| check_code_level=0, metric_key=None, sampler=RandomSampler(), num_workers=0, | |||||
| use_tqdm=True, use_cuda=False, callbacks=None): | |||||
| """ | """ | ||||
| :param DataSet train_data: the training data | :param DataSet train_data: the training data | ||||
| :param torch.nn.modules.module model: a PyTorch model | :param torch.nn.modules.module model: a PyTorch model | ||||
| @@ -127,8 +127,6 @@ class Trainer(object): | |||||
| self.best_dev_perf = None | self.best_dev_perf = None | ||||
| self.sampler = sampler | self.sampler = sampler | ||||
| self.num_workers = num_workers | self.num_workers = num_workers | ||||
| self.pin_memory = pin_memory | |||||
| self.timeout = timeout | |||||
| self.callback_manager = CallbackManager(env={"trainer": self}, callbacks=callbacks) | self.callback_manager = CallbackManager(env={"trainer": self}, callbacks=callbacks) | ||||
| if isinstance(optimizer, torch.optim.Optimizer): | if isinstance(optimizer, torch.optim.Optimizer): | ||||
| @@ -249,9 +247,7 @@ class Trainer(object): | |||||
| len(self.train_data) % self.batch_size != 0)) * self.n_epochs | len(self.train_data) % self.batch_size != 0)) * self.n_epochs | ||||
| with inner_tqdm(total=total_steps, postfix='loss:{0:<6.5f}', leave=False, dynamic_ncols=True) as pbar: | with inner_tqdm(total=total_steps, postfix='loss:{0:<6.5f}', leave=False, dynamic_ncols=True) as pbar: | ||||
| avg_loss = 0 | avg_loss = 0 | ||||
| data_iterator = Batch(self.train_data, batch_size=self.batch_size, sampler=self.sampler, as_numpy=False, | |||||
| num_workers=self.num_workers, pin_memory=self.pin_memory, timeout=self.timeout, | |||||
| keep_process=True) | |||||
| data_iterator = Batch(self.train_data, batch_size=self.batch_size, sampler=self.sampler, as_numpy=False) | |||||
| for epoch in range(1, self.n_epochs+1): | for epoch in range(1, self.n_epochs+1): | ||||
| pbar.set_description_str(desc="Epoch {}/{}".format(epoch, self.n_epochs)) | pbar.set_description_str(desc="Epoch {}/{}".format(epoch, self.n_epochs)) | ||||
| # early stopping | # early stopping | ||||
| @@ -261,7 +257,7 @@ class Trainer(object): | |||||
| # negative sampling; replace unknown; re-weight batch_y | # negative sampling; replace unknown; re-weight batch_y | ||||
| self.callback_manager.before_batch(batch_x, batch_y, indices) | self.callback_manager.before_batch(batch_x, batch_y, indices) | ||||
| _move_dict_value_to_device(batch_x, batch_y, device=self._model_device, | _move_dict_value_to_device(batch_x, batch_y, device=self._model_device, | ||||
| non_blocking=self.pin_memory) # pin_memory, use non_blockling. | |||||
| non_blocking=self.use_cuda) # pin_memory, use non_blockling. | |||||
| prediction = self._data_forward(self.model, batch_x) | prediction = self._data_forward(self.model, batch_x) | ||||
| # edit prediction | # edit prediction | ||||
| @@ -876,7 +876,7 @@ class ConllPOSReader(object): | |||||
| class ConllxDataLoader(object): | class ConllxDataLoader(object): | ||||
| def load(self, path, return_dataset=False): | |||||
| def load(self, path): | |||||
| datalist = [] | datalist = [] | ||||
| with open(path, 'r', encoding='utf-8') as f: | with open(path, 'r', encoding='utf-8') as f: | ||||
| sample = [] | sample = [] | ||||
| @@ -894,15 +894,13 @@ class ConllxDataLoader(object): | |||||
| data = [self.get_one(sample) for sample in datalist] | data = [self.get_one(sample) for sample in datalist] | ||||
| data_list = list(filter(lambda x: x is not None, data)) | data_list = list(filter(lambda x: x is not None, data)) | ||||
| if return_dataset is True: | |||||
| ds = DataSet() | |||||
| for example in data_list: | |||||
| ds.append(Instance(words=example[0], | |||||
| pos_tags=example[1], | |||||
| heads=example[2], | |||||
| labels=example[3])) | |||||
| data_list = ds | |||||
| return data_list | |||||
| ds = DataSet() | |||||
| for example in data_list: | |||||
| ds.append(Instance(words=example[0], | |||||
| pos_tags=example[1], | |||||
| heads=example[2], | |||||
| labels=example[3])) | |||||
| return ds | |||||
| def get_one(self, sample): | def get_one(self, sample): | ||||
| sample = list(map(list, zip(*sample))) | sample = list(map(list, zip(*sample))) | ||||