From 35d0371955e24f638202356013726b0cb9ec976b Mon Sep 17 00:00:00 2001 From: yh_cc Date: Sat, 5 Dec 2020 13:33:02 +0800 Subject: [PATCH] update DistTrainer --- fastNLP/core/dist_trainer.py | 16 +++++++--------- 1 file changed, 7 insertions(+), 9 deletions(-) diff --git a/fastNLP/core/dist_trainer.py b/fastNLP/core/dist_trainer.py index a76d0a05..bd7ba423 100644 --- a/fastNLP/core/dist_trainer.py +++ b/fastNLP/core/dist_trainer.py @@ -109,6 +109,7 @@ class DistTrainer(): :param kwargs: 支持配置可选参数 bool test_use_tqdm: 在dev上验证的时候是否开启tqdm Sampler test_sampler: 在evaluate的时候使用的sampler + int dev_batch_size: 在evaluate时,使用的evaluate的batch大小 """ assert device in ['auto', 'cuda', 'cpu'], "Please set correct device in [auto', 'cuda', 'cpu']" if device == 'auto': @@ -172,17 +173,14 @@ class DistTrainer(): self.batch_size = self.world_size * self.batch_size_per_gpu self.n_steps = self._get_n_steps() - if 'test_use_tqdm' in kwargs: - test_use_tqdm = kwargs.get('test_use_tqdm') - else: - test_use_tqdm = self.use_tqdm - + self.test_use_tqdm = kwargs.get('test_use_tqdm', self.use_tqdm) + dev_batch_size = kwargs.get('dev_batch_size', batch_size_per_gpu) # for evaluation, only run eval on master proc if dev_data and metrics: cb = _TesterCallback( dev_data, model, metrics, - batch_size=batch_size_per_gpu, num_workers=num_workers, sampler=kwargs.get('test_sampler', None), - use_tqdm=test_use_tqdm) + batch_size=dev_batch_size, num_workers=num_workers, sampler=kwargs.get('test_sampler', None), + use_tqdm=self.test_use_tqdm) self.test_manager.add_callback([cb], master=True) # Setup logging @@ -379,11 +377,11 @@ class DistTrainer(): self.callback_manager.on_batch_end() - if (self.validate_every > 0 and self.step % self.validate_every == 0): + if (self.validate_every > 0 and self.step % self.validate_every == 0) and len(self.test_manager.callbacks): self._do_validation() # ================= mini-batch end ==================== # - if self.validate_every < 0: + if self.validate_every < 0 and len(self.test_manager.callbacks): self._do_validation() # lr decay; early stopping