Browse Source

add batch device

tags/v0.3.1^2
yunfan 6 years ago
parent
commit
a37de4344d
2 changed files with 33 additions and 10 deletions
  1. +25
    -6
      fastNLP/core/batch.py
  2. +8
    -4
      fastNLP/core/trainer.py

+ 25
- 6
fastNLP/core/batch.py View File

@@ -16,10 +16,12 @@ class Batch(object):
:param int batch_size: the size of the batch
:param Sampler sampler: a Sampler object
:param bool as_numpy: If True, return Numpy array. Otherwise, return torch tensors.

:param bool prefetch: If True, use multiprocessing to fetch next batch when training.
:param str or torch.device device: the batch's device, if as_numpy is True, device is ignored.
"""

def __init__(self, dataset, batch_size, sampler=RandomSampler(), as_numpy=False):
def __init__(self, dataset, batch_size, sampler=RandomSampler(), as_numpy=False, prefetch=False,
device='cpu'):
self.dataset = dataset
self.batch_size = batch_size
self.sampler = sampler
@@ -28,6 +30,10 @@ class Batch(object):
self.curidx = 0
self.num_batches = len(dataset) // batch_size + int(len(dataset) % batch_size != 0)
self.cur_batch_indices = None
self.prefetch = prefetch
self.lengths = 0
if not as_numpy:
self.device = device if isinstance(device, torch.device) else torch.device(device)

def fetch_one(self):
if self.curidx >= len(self.idx_list):
@@ -44,6 +50,7 @@ class Batch(object):
batch = field.get(indices)
if not self.as_numpy and field.padder is not None:
batch = to_tensor(batch, field.dtype)
batch = batch.to(self.device)
if field.is_target:
batch_y[field_name] = batch
if field.is_input:
@@ -57,7 +64,21 @@ class Batch(object):
Iterate on dataset, fetch batch data. Fetch process don't block the iterate process
:return:
"""
return run_batch_iter(self)
if self.prefetch:
return run_batch_iter(self)
def batch_iter():
self.init_iter()
while 1:
res = self.fetch_one()
if res is None:
break
yield res
return batch_iter()

def init_iter(self):
self.idx_list = self.sampler(self.dataset)
self.curidx = 0
self.lengths = self.dataset.get_length()

def __len__(self):
return self.num_batches
@@ -78,9 +99,7 @@ def to_tensor(batch, dtype):


def run_fetch(batch, q):
batch.idx_list = batch.sampler(batch.dataset)
batch.curidx = 0
batch.lengths = batch.dataset.get_length()
batch.init_iter()
# print('start fetch')
while 1:
res = batch.fetch_one()


+ 8
- 4
fastNLP/core/trainer.py View File

@@ -34,8 +34,8 @@ from fastNLP.core.utils import get_func_signature
class Trainer(object):
def __init__(self, train_data, model, loss=None, metrics=None, n_epochs=3, batch_size=32, print_every=50,
validate_every=-1, dev_data=None, save_path=None, optimizer=Adam(lr=0.01, weight_decay=0),
check_code_level=0, metric_key=None, sampler=RandomSampler(), num_workers=0,
use_tqdm=True, use_cuda=False, callbacks=None):
check_code_level=0, metric_key=None, sampler=RandomSampler(), num_workers=0, pin_memory=False,
timeout=0, use_tqdm=True, use_cuda=False, callbacks=None):
"""
:param DataSet train_data: the training data
:param torch.nn.modules.module model: a PyTorch model
@@ -127,6 +127,8 @@ class Trainer(object):
self.best_dev_perf = None
self.sampler = sampler
self.num_workers = num_workers
self.pin_memory = pin_memory
self.timeout = timeout
self.callback_manager = CallbackManager(env={"trainer": self}, callbacks=callbacks)

if isinstance(optimizer, torch.optim.Optimizer):
@@ -247,7 +249,9 @@ class Trainer(object):
len(self.train_data) % self.batch_size != 0)) * self.n_epochs
with inner_tqdm(total=total_steps, postfix='loss:{0:<6.5f}', leave=False, dynamic_ncols=True) as pbar:
avg_loss = 0
data_iterator = Batch(self.train_data, batch_size=self.batch_size, sampler=self.sampler, as_numpy=False)
data_iterator = Batch(self.train_data, batch_size=self.batch_size, sampler=self.sampler, as_numpy=False,
num_workers=self.num_workers, pin_memory=self.pin_memory, timeout=self.timeout,
keep_process=True)
for epoch in range(1, self.n_epochs+1):
pbar.set_description_str(desc="Epoch {}/{}".format(epoch, self.n_epochs))
# early stopping
@@ -257,7 +261,7 @@ class Trainer(object):
# negative sampling; replace unknown; re-weight batch_y
self.callback_manager.before_batch(batch_x, batch_y, indices)
_move_dict_value_to_device(batch_x, batch_y, device=self._model_device,
non_blocking=self.use_cuda) # pin_memory, use non_blockling.
non_blocking=self.pin_memory) # pin_memory, use non_blockling.
prediction = self._data_forward(self.model, batch_x)

# edit prediction


Loading…
Cancel
Save