From 9474ab4b341c0c81c7350b37dcf1b06bc509b7bb Mon Sep 17 00:00:00 2001
From: yunfan
Date: Mon, 21 Jan 2019 22:28:31 +0800
Subject: [PATCH] remove device in batch

---
 fastNLP/core/batch.py   | 24 +++++++++++++-----------
 fastNLP/core/trainer.py |  3 ++-
 2 files changed, 15 insertions(+), 12 deletions(-)

diff --git a/fastNLP/core/batch.py b/fastNLP/core/batch.py
index ead7087e..88d9185d 100644
--- a/fastNLP/core/batch.py
+++ b/fastNLP/core/batch.py
@@ -20,8 +20,7 @@ class Batch(object):
 
         :param str or torch.device device: the batch's device, if as_numpy is True, device is ignored.
     """
-    def __init__(self, dataset, batch_size, sampler=RandomSampler(), as_numpy=False, prefetch=False,
-                 device='cpu'):
+    def __init__(self, dataset, batch_size, sampler=RandomSampler(), as_numpy=False, prefetch=False):
         self.dataset = dataset
         self.batch_size = batch_size
         self.sampler = sampler
@@ -32,8 +31,6 @@ class Batch(object):
         self.cur_batch_indices = None
         self.prefetch = prefetch
         self.lengths = 0
-        if not as_numpy:
-            self.device = device if isinstance(device, torch.device) else torch.device(device)
 
     def fetch_one(self):
         if self.curidx >= len(self.idx_list):
@@ -50,7 +47,6 @@ class Batch(object):
                 batch = field.get(indices)
                 if not self.as_numpy and field.padder is not None:
                     batch = to_tensor(batch, field.dtype)
-                    batch = batch.to(self.device)
                 if field.is_target:
                     batch_y[field_name] = batch
                 if field.is_input:
@@ -119,12 +115,18 @@ def run_batch_iter(batch):
     fetch_p.start()
     # print('fork fetch process')
     while 1:
-        res = q.get()
-        q.task_done()
-        # print('get fetched')
-        if res is None:
-            break
-        yield res
+        try:
+            res = q.get(timeout=1)
+            q.task_done()
+            # print('get fetched')
+            if res is None:
+                break
+            yield res
+        except Exception as e:
+            if fetch_p.is_alive():
+                continue
+            else:
+                break
     fetch_p.terminate()
     fetch_p.join()
     # print('iter done')
diff --git a/fastNLP/core/trainer.py b/fastNLP/core/trainer.py
index 8ca3d22a..8112af88 100644
--- a/fastNLP/core/trainer.py
+++ b/fastNLP/core/trainer.py
@@ -229,12 +229,13 @@ class Trainer(object):
         with inner_tqdm(total=total_steps, postfix='loss:{0:<6.5f}', leave=False, dynamic_ncols=True) as pbar:
             avg_loss = 0
             data_iterator = Batch(self.train_data, batch_size=self.batch_size, sampler=self.sampler, as_numpy=False,
-                                  prefetch=self.prefetch, device=self._model_device)
+                                  prefetch=self.prefetch)
             for epoch in range(1, self.n_epochs+1):
                 pbar.set_description_str(desc="Epoch {}/{}".format(epoch, self.n_epochs))
                 # early stopping
                 self.callback_manager.before_epoch(epoch, self.n_epochs)
                 for batch_x, batch_y in data_iterator:
+                    _move_dict_value_to_device(batch_x, batch_y, device=self._model_device)
                     indices = data_iterator.get_batch_indices()
                     # negative sampling; replace unknown; re-weight batch_y
                     self.callback_manager.before_batch(batch_x, batch_y, indices)