Browse Source

update DistTrainer

tags/v1.0.0alpha
yh_cc 3 years ago
parent
commit
35d0371955
1 changed files with 7 additions and 9 deletions
  1. +7
    -9
      fastNLP/core/dist_trainer.py

+ 7
- 9
fastNLP/core/dist_trainer.py View File

@@ -109,6 +109,7 @@ class DistTrainer():
:param kwargs: 支持配置可选参数 :param kwargs: 支持配置可选参数
bool test_use_tqdm: 在dev上验证的时候是否开启tqdm bool test_use_tqdm: 在dev上验证的时候是否开启tqdm
Sampler test_sampler: 在evaluate的时候使用的sampler Sampler test_sampler: 在evaluate的时候使用的sampler
int dev_batch_size: 在evaluate时,使用的evaluate的batch大小
""" """
assert device in ['auto', 'cuda', 'cpu'], "Please set correct device in [auto', 'cuda', 'cpu']" assert device in ['auto', 'cuda', 'cpu'], "Please set correct device in [auto', 'cuda', 'cpu']"
if device == 'auto': if device == 'auto':
@@ -172,17 +173,14 @@ class DistTrainer():
self.batch_size = self.world_size * self.batch_size_per_gpu self.batch_size = self.world_size * self.batch_size_per_gpu
self.n_steps = self._get_n_steps() self.n_steps = self._get_n_steps()


if 'test_use_tqdm' in kwargs:
test_use_tqdm = kwargs.get('test_use_tqdm')
else:
test_use_tqdm = self.use_tqdm

self.test_use_tqdm = kwargs.get('test_use_tqdm', self.use_tqdm)
dev_batch_size = kwargs.get('dev_batch_size', batch_size_per_gpu)
# for evaluation, only run eval on master proc # for evaluation, only run eval on master proc
if dev_data and metrics: if dev_data and metrics:
cb = _TesterCallback( cb = _TesterCallback(
dev_data, model, metrics, dev_data, model, metrics,
batch_size=batch_size_per_gpu, num_workers=num_workers, sampler=kwargs.get('test_sampler', None),
use_tqdm=test_use_tqdm)
batch_size=dev_batch_size, num_workers=num_workers, sampler=kwargs.get('test_sampler', None),
use_tqdm=self.test_use_tqdm)
self.test_manager.add_callback([cb], master=True) self.test_manager.add_callback([cb], master=True)


# Setup logging # Setup logging
@@ -379,11 +377,11 @@ class DistTrainer():


self.callback_manager.on_batch_end() self.callback_manager.on_batch_end()


if (self.validate_every > 0 and self.step % self.validate_every == 0):
if (self.validate_every > 0 and self.step % self.validate_every == 0) and len(self.test_manager.callbacks):
self._do_validation() self._do_validation()


# ================= mini-batch end ==================== # # ================= mini-batch end ==================== #
if self.validate_every < 0:
if self.validate_every < 0 and len(self.test_manager.callbacks):
self._do_validation() self._do_validation()


# lr decay; early stopping # lr decay; early stopping


Loading…
Cancel
Save