Browse Source

remove device in batch

tags/v0.3.1^2
yunfan 6 years ago
parent
commit
9474ab4b34
2 changed files with 15 additions and 12 deletions
  1. +13
    -11
      fastNLP/core/batch.py
  2. +2
    -1
      fastNLP/core/trainer.py

+ 13
- 11
fastNLP/core/batch.py View File

@@ -20,8 +20,7 @@ class Batch(object):
:param str or torch.device device: the batch's device, if as_numpy is True, device is ignored.
"""


def __init__(self, dataset, batch_size, sampler=RandomSampler(), as_numpy=False, prefetch=False,
device='cpu'):
def __init__(self, dataset, batch_size, sampler=RandomSampler(), as_numpy=False, prefetch=False):
self.dataset = dataset
self.batch_size = batch_size
self.sampler = sampler
@@ -32,8 +31,6 @@ class Batch(object):
self.cur_batch_indices = None
self.prefetch = prefetch
self.lengths = 0
if not as_numpy:
self.device = device if isinstance(device, torch.device) else torch.device(device)


def fetch_one(self):
if self.curidx >= len(self.idx_list):
@@ -50,7 +47,6 @@ class Batch(object):
batch = field.get(indices)
if not self.as_numpy and field.padder is not None:
batch = to_tensor(batch, field.dtype)
batch = batch.to(self.device)
if field.is_target:
batch_y[field_name] = batch
if field.is_input:
@@ -119,12 +115,18 @@ def run_batch_iter(batch):
fetch_p.start()
# print('fork fetch process')
while 1:
res = q.get()
q.task_done()
# print('get fetched')
if res is None:
break
yield res
try:
res = q.get(timeout=1)
q.task_done()
# print('get fetched')
if res is None:
break
yield res
except Exception as e:
if fetch_p.is_alive():
continue
else:
break
fetch_p.terminate()
fetch_p.join()
# print('iter done')


+ 2
- 1
fastNLP/core/trainer.py View File

@@ -229,12 +229,13 @@ class Trainer(object):
with inner_tqdm(total=total_steps, postfix='loss:{0:<6.5f}', leave=False, dynamic_ncols=True) as pbar:
avg_loss = 0
data_iterator = Batch(self.train_data, batch_size=self.batch_size, sampler=self.sampler, as_numpy=False,
prefetch=self.prefetch, device=self._model_device)
prefetch=self.prefetch)
for epoch in range(1, self.n_epochs+1):
pbar.set_description_str(desc="Epoch {}/{}".format(epoch, self.n_epochs))
# early stopping
self.callback_manager.before_epoch(epoch, self.n_epochs)
for batch_x, batch_y in data_iterator:
_move_dict_value_to_device(batch_x, batch_y, device=self._model_device)
indices = data_iterator.get_batch_indices()
# negative sampling; replace unknown; re-weight batch_y
self.callback_manager.before_batch(batch_x, batch_y, indices)


Loading…
Cancel
Save