|
@@ -109,6 +109,7 @@ class DistTrainer(): |
|
|
:param kwargs: 支持配置可选参数 |
|
|
:param kwargs: 支持配置可选参数 |
|
|
bool test_use_tqdm: 在dev上验证的时候是否开启tqdm |
|
|
bool test_use_tqdm: 在dev上验证的时候是否开启tqdm |
|
|
Sampler test_sampler: 在evaluate的时候使用的sampler |
|
|
Sampler test_sampler: 在evaluate的时候使用的sampler |
|
|
|
|
|
int dev_batch_size: 在evaluate时,使用的evaluate的batch大小 |
|
|
""" |
|
|
""" |
|
|
assert device in ['auto', 'cuda', 'cpu'], "Please set correct device in [auto', 'cuda', 'cpu']" |
|
|
assert device in ['auto', 'cuda', 'cpu'], "Please set correct device in [auto', 'cuda', 'cpu']" |
|
|
if device == 'auto': |
|
|
if device == 'auto': |
|
@@ -172,17 +173,14 @@ class DistTrainer(): |
|
|
self.batch_size = self.world_size * self.batch_size_per_gpu |
|
|
self.batch_size = self.world_size * self.batch_size_per_gpu |
|
|
self.n_steps = self._get_n_steps() |
|
|
self.n_steps = self._get_n_steps() |
|
|
|
|
|
|
|
|
if 'test_use_tqdm' in kwargs: |
|
|
|
|
|
test_use_tqdm = kwargs.get('test_use_tqdm') |
|
|
|
|
|
else: |
|
|
|
|
|
test_use_tqdm = self.use_tqdm |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
self.test_use_tqdm = kwargs.get('test_use_tqdm', self.use_tqdm) |
|
|
|
|
|
dev_batch_size = kwargs.get('dev_batch_size', batch_size_per_gpu) |
|
|
# for evaluation, only run eval on master proc |
|
|
# for evaluation, only run eval on master proc |
|
|
if dev_data and metrics: |
|
|
if dev_data and metrics: |
|
|
cb = _TesterCallback( |
|
|
cb = _TesterCallback( |
|
|
dev_data, model, metrics, |
|
|
dev_data, model, metrics, |
|
|
batch_size=batch_size_per_gpu, num_workers=num_workers, sampler=kwargs.get('test_sampler', None), |
|
|
|
|
|
use_tqdm=test_use_tqdm) |
|
|
|
|
|
|
|
|
batch_size=dev_batch_size, num_workers=num_workers, sampler=kwargs.get('test_sampler', None), |
|
|
|
|
|
use_tqdm=self.test_use_tqdm) |
|
|
self.test_manager.add_callback([cb], master=True) |
|
|
self.test_manager.add_callback([cb], master=True) |
|
|
|
|
|
|
|
|
# Setup logging |
|
|
# Setup logging |
|
@@ -379,11 +377,11 @@ class DistTrainer(): |
|
|
|
|
|
|
|
|
self.callback_manager.on_batch_end() |
|
|
self.callback_manager.on_batch_end() |
|
|
|
|
|
|
|
|
if (self.validate_every > 0 and self.step % self.validate_every == 0): |
|
|
|
|
|
|
|
|
if (self.validate_every > 0 and self.step % self.validate_every == 0) and len(self.test_manager.callbacks): |
|
|
self._do_validation() |
|
|
self._do_validation() |
|
|
|
|
|
|
|
|
# ================= mini-batch end ==================== # |
|
|
# ================= mini-batch end ==================== # |
|
|
if self.validate_every < 0: |
|
|
|
|
|
|
|
|
if self.validate_every < 0 and len(self.test_manager.callbacks): |
|
|
self._do_validation() |
|
|
self._do_validation() |
|
|
|
|
|
|
|
|
# lr decay; early stopping |
|
|
# lr decay; early stopping |
|
|