|
|
@@ -177,8 +177,13 @@ class DistTrainer(): |
|
|
|
self.batch_size = self.world_size * self.batch_size_per_gpu |
|
|
|
self.n_steps = self._get_n_steps() |
|
|
|
|
|
|
|
self.dev_data = dev_data |
|
|
|
self.metrics = metrics |
|
|
|
self.test_use_tqdm = True |
|
|
|
self.kwargs = kwargs |
|
|
|
self.test_use_tqdm = kwargs.get('test_use_tqdm', self.use_tqdm) |
|
|
|
dev_batch_size = kwargs.get('dev_batch_size', batch_size_per_gpu) |
|
|
|
|
|
|
|
# for evaluation, only run eval on master proc |
|
|
|
if dev_data and metrics: |
|
|
|
cb = _TesterCallback( |
|
|
|