diff --git a/fastNLP/core/batch.py b/fastNLP/core/batch.py
index 3faab8c0..ead7087e 100644
--- a/fastNLP/core/batch.py
+++ b/fastNLP/core/batch.py
@@ -16,10 +16,12 @@ class Batch(object):
     :param int batch_size: the size of the batch
     :param Sampler sampler: a Sampler object
     :param bool as_numpy: If True, return Numpy array. Otherwise, return torch tensors.
-
+    :param bool prefetch: If True, use multiprocessing to fetch next batch when training.
+    :param str or torch.device device: the batch's device, if as_numpy is True, device is ignored.
     """
 
-    def __init__(self, dataset, batch_size, sampler=RandomSampler(), as_numpy=False):
+    def __init__(self, dataset, batch_size, sampler=RandomSampler(), as_numpy=False, prefetch=False,
+                 device='cpu'):
         self.dataset = dataset
         self.batch_size = batch_size
         self.sampler = sampler
@@ -28,6 +30,10 @@ class Batch(object):
         self.curidx = 0
         self.num_batches = len(dataset) // batch_size + int(len(dataset) % batch_size != 0)
         self.cur_batch_indices = None
+        self.prefetch = prefetch
+        self.lengths = 0
+        if not as_numpy:
+            self.device = device if isinstance(device, torch.device) else torch.device(device)
 
     def fetch_one(self):
         if self.curidx >= len(self.idx_list):
@@ -44,6 +50,7 @@ class Batch(object):
                 batch = field.get(indices)
                 if not self.as_numpy and field.padder is not None:
                     batch = to_tensor(batch, field.dtype)
+                    batch = batch.to(self.device)
                 if field.is_target:
                     batch_y[field_name] = batch
                 if field.is_input:
@@ -57,7 +64,21 @@ class Batch(object):
         Iterate on dataset, fetch batch data. Fetch process don't block the iterate process
         :return:
         """
-        return run_batch_iter(self)
+        if self.prefetch:
+            return run_batch_iter(self)
+        def batch_iter():
+            self.init_iter()
+            while 1:
+                res = self.fetch_one()
+                if res is None:
+                    break
+                yield res
+        return batch_iter()
+
+    def init_iter(self):
+        self.idx_list = self.sampler(self.dataset)
+        self.curidx = 0
+        self.lengths = self.dataset.get_length()
 
     def __len__(self):
         return self.num_batches
@@ -78,9 +99,7 @@ def to_tensor(batch, dtype):
 
 
 def run_fetch(batch, q):
-    batch.idx_list = batch.sampler(batch.dataset)
-    batch.curidx = 0
-    batch.lengths = batch.dataset.get_length()
+    batch.init_iter()
     # print('start fetch')
     while 1:
         res = batch.fetch_one()
diff --git a/fastNLP/core/trainer.py b/fastNLP/core/trainer.py
index faa0d0a2..a5861091 100644
--- a/fastNLP/core/trainer.py
+++ b/fastNLP/core/trainer.py
@@ -34,8 +34,8 @@ from fastNLP.core.utils import get_func_signature
 class Trainer(object):
     def __init__(self, train_data, model, loss=None, metrics=None, n_epochs=3, batch_size=32, print_every=50,
                  validate_every=-1, dev_data=None, save_path=None, optimizer=Adam(lr=0.01, weight_decay=0),
-                 check_code_level=0, metric_key=None, sampler=RandomSampler(), num_workers=0,
-                 use_tqdm=True, use_cuda=False, callbacks=None):
+                 check_code_level=0, metric_key=None, sampler=RandomSampler(), num_workers=0, pin_memory=False,
+                 timeout=0, use_tqdm=True, use_cuda=False, callbacks=None):
         """
         :param DataSet train_data: the training data
         :param torch.nn.modules.module model: a PyTorch model
@@ -127,6 +127,8 @@ class Trainer(object):
         self.best_dev_perf = None
         self.sampler = sampler
         self.num_workers = num_workers
+        self.pin_memory = pin_memory
+        self.timeout = timeout
         self.callback_manager = CallbackManager(env={"trainer": self}, callbacks=callbacks)
 
         if isinstance(optimizer, torch.optim.Optimizer):
@@ -247,7 +249,9 @@ class Trainer(object):
                               len(self.train_data) % self.batch_size != 0)) * self.n_epochs
         with inner_tqdm(total=total_steps, postfix='loss:{0:<6.5f}', leave=False, dynamic_ncols=True) as pbar:
             avg_loss = 0
-            data_iterator = Batch(self.train_data, batch_size=self.batch_size, sampler=self.sampler, as_numpy=False)
+            data_iterator = Batch(self.train_data, batch_size=self.batch_size, sampler=self.sampler, as_numpy=False,
+                                  num_workers=self.num_workers, pin_memory=self.pin_memory, timeout=self.timeout,
+                                  keep_process=True)
             for epoch in range(1, self.n_epochs+1):
                 pbar.set_description_str(desc="Epoch {}/{}".format(epoch, self.n_epochs))
                 # early stopping
@@ -257,7 +261,7 @@ class Trainer(object):
                     # negative sampling; replace unknown; re-weight batch_y
                     self.callback_manager.before_batch(batch_x, batch_y, indices)
                     _move_dict_value_to_device(batch_x, batch_y, device=self._model_device,
-                                               non_blocking=self.use_cuda)  # pin_memory, use non_blockling.
+                                               non_blocking=self.pin_memory)  # pin_memory, use non_blockling.
                     prediction = self._data_forward(self.model, batch_x)
 
                     # edit prediction
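
For readers unfamiliar with the prefetch path this change wires up: `run_batch_iter`/`run_fetch` hand batch production to a child process that feeds a queue while training consumes it, and `init_iter()` is factored out so the prefetching and non-prefetching code paths reset the sampler state the same way. The snippet below is a minimal, standalone sketch of that producer/consumer pattern using only the standard `multiprocessing` module; `prefetch_iter`, `_producer`, and `toy_batches` are illustrative names and are not part of the fastNLP API.

```python
# Minimal sketch of queue-based batch prefetching, analogous in spirit to the
# run_fetch / run_batch_iter helpers touched in this diff. Names are illustrative.
import multiprocessing as mp


def _producer(make_batches, q):
    # Child process: build batches and push them into the shared queue.
    for batch in make_batches():
        q.put(batch)
    q.put(None)  # sentinel: signals that no more batches will arrive


def prefetch_iter(make_batches, max_queued=2):
    # Parent process: yield batches as they become available, so the next
    # batch is being prepared while the current one is being consumed.
    q = mp.Queue(maxsize=max_queued)
    worker = mp.Process(target=_producer, args=(make_batches, q), daemon=True)
    worker.start()
    while True:
        batch = q.get()
        if batch is None:
            break
        yield batch
    worker.join()


def toy_batches():
    # Stand-in for Batch.fetch_one(): three tiny "batches" of indices.
    return ([i, i + 1] for i in range(0, 6, 2))


if __name__ == '__main__':
    for b in prefetch_iter(toy_batches):
        print(b)
```

In the diff itself, `Batch.__iter__` plays the consumer role (when `prefetch=True` it returns `run_batch_iter(self)`), while `run_fetch` is the producer that now calls `batch.init_iter()` before looping over `fetch_one()`.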