
Update trainer to follow syf's multi-process batch implementation
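
This change swaps the num_workers / pin_memory / timeout knobs for a single prefetch flag that is forwarded to Batch. As a rough illustration of the prefetch idea (a minimal sketch, not the fastNLP implementation; all names below are invented for the example), one extra process prepares batches and feeds a bounded queue that the training loop drains:

    import multiprocessing as mp

    _END = None  # sentinel: the producer has no more batches

    def _producer(batches, queue):
        # Runs in the extra process: push ready-made batches into the bounded queue.
        for batch in batches:
            queue.put(batch)
        queue.put(_END)

    def prefetch_batches(batches, maxsize=4):
        """Yield batches while a helper process prepares the next ones."""
        queue = mp.Queue(maxsize=maxsize)
        worker = mp.Process(target=_producer, args=(list(batches), queue), daemon=True)
        worker.start()
        while True:
            batch = queue.get()
            if batch is _END:
                break
            yield batch
        worker.join()

    if __name__ == "__main__":
        for batch in prefetch_batches(range(5)):
            print(batch)  # consume while the producer keeps the queue filled

In the diff below, the same idea surfaces as Trainer(..., prefetch=True): the flag and the model's device are handed to Batch, so batches arrive already moved to the device.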

tags/v0.3.1^2
yh committed 6 years ago
commit 47ec69ea96
1 changed file with 5 additions and 15 deletions:
  fastNLP/core/trainer.py (+5, -15)

fastNLP/core/trainer.py

@@ -33,8 +33,8 @@ from fastNLP.core.utils import get_func_signature
 class Trainer(object):
     def __init__(self, train_data, model, loss=None, metrics=None, n_epochs=3, batch_size=32, print_every=50,
                  validate_every=-1, dev_data=None, save_path=None, optimizer=Adam(lr=0.01, weight_decay=0),
-                 check_code_level=0, metric_key=None, sampler=RandomSampler(), num_workers=0, pin_memory=False,
-                 timeout=0, use_tqdm=True, use_cuda=False, callbacks=None):
+                 check_code_level=0, metric_key=None, sampler=RandomSampler(), prefetch=False, use_tqdm=True,
+                 use_cuda=False, callbacks=None):
         """
         :param DataSet train_data: the training data
         :param torch.nn.modules.module model: a PyTorch model
@@ -58,12 +58,7 @@ class Trainer(object):
 
                 metric_key="-PPL" # language model gets better as perplexity gets smaller
         :param BaseSampler sampler: method used to generate batch data.
-        :param num_workers: int, number of worker processes used to prepare batches. Defaults to 0, i.e. batches are
-            produced in the main thread. Experimental feature, use with caution: if the DataSet is large but each batch is quick to prepare, multiprocessing may not bring a speedup.
-        :param pin_memory: bool, defaults to False. When True, page-locked memory is used, which may increase memory usage;
-            if memory is plentiful, enabling it can give a speedup. When pin_memory is True, data is moved from CPU to GPU with non_blocking=True by default.
-        :param timeout: float, a value greater than 0, only effective when num_workers > 0. An error is raised if no batch is
-            obtained within this time, which helps detect whether batch production has stalled.
+        :param prefetch: bool, whether to use an extra process to produce batch data.
         :param bool use_tqdm: whether to use tqdm to show train progress.
         :param callbacks: List[Callback]. Callbacks that hook into the training process; e.g. early stopping and
             negative sampling can be implemented through the callback mechanism.
@@ -125,9 +120,7 @@ class Trainer(object):
         self.best_dev_step = None
         self.best_dev_perf = None
         self.sampler = sampler
-        self.num_workers = num_workers
-        self.pin_memory = pin_memory
-        self.timeout = timeout
+        self.prefetch = prefetch
         self.callback_manager = CallbackManager(env={"trainer": self}, callbacks=callbacks)
 
         if isinstance(optimizer, torch.optim.Optimizer):
@@ -236,8 +229,7 @@ class Trainer(object):
         with inner_tqdm(total=total_steps, postfix='loss:{0:<6.5f}', leave=False, dynamic_ncols=True) as pbar:
             avg_loss = 0
             data_iterator = Batch(self.train_data, batch_size=self.batch_size, sampler=self.sampler, as_numpy=False,
-                                  num_workers=self.num_workers, pin_memory=self.pin_memory, timeout=self.timeout,
-                                  keep_process=True)
+                                  prefetch=self.prefetch, device=self._model_device)
             for epoch in range(1, self.n_epochs+1):
                 pbar.set_description_str(desc="Epoch {}/{}".format(epoch, self.n_epochs))
                 # early stopping
@@ -246,8 +238,6 @@ class Trainer(object):
                     indices = data_iterator.get_batch_indices()
                     # negative sampling; replace unknown; re-weight batch_y
                     self.callback_manager.before_batch(batch_x, batch_y, indices)
-                    _move_dict_value_to_device(batch_x, batch_y, device=self._model_device,
-                                               non_blocking=self.pin_memory)  # pin_memory, use non_blocking.
                     prediction = self._data_forward(self.model, batch_x)
 
                     # edit prediction
