From 8a084c4f52d54bd153bc4c74ca299f9ddde606d6 Mon Sep 17 00:00:00 2001
From: yunfan
Date: Sat, 11 Apr 2020 17:01:17 +0800
Subject: [PATCH] [update] bugfix in dist_trainer

---
 fastNLP/core/_logger.py      | 2 +-
 fastNLP/core/dist_trainer.py | 4 +++-
 2 files changed, 4 insertions(+), 2 deletions(-)

diff --git a/fastNLP/core/_logger.py b/fastNLP/core/_logger.py
index 043a97c2..9051f700 100644
--- a/fastNLP/core/_logger.py
+++ b/fastNLP/core/_logger.py
@@ -176,4 +176,4 @@ logger = _init_logger(path=None, level='INFO')
 def init_logger_dist():
     global logger
     rank = dist.get_rank()
-    logger.setLevel(logging.INFO if rank else logging.WARNING)
+    logger.setLevel(logging.INFO if rank == 0 else logging.WARNING)
diff --git a/fastNLP/core/dist_trainer.py b/fastNLP/core/dist_trainer.py
index 56d123f4..289b434d 100644
--- a/fastNLP/core/dist_trainer.py
+++ b/fastNLP/core/dist_trainer.py
@@ -117,6 +117,8 @@ class DistTrainer():
         else:
             self.device = torch.device(device)
 
+        init_logger_dist()
+
         self.world_size = dist.get_world_size()
         self.rank = dist.get_rank() # unique id for each process
 
@@ -180,7 +182,6 @@ class DistTrainer():
         else:
             self.cp_save_path = None
         # use INFO in the master, WARN for others
-        init_logger_dist()
         self.logger = logger
         self.logger.info("Setup Distributed Trainer")
         self.logger.warning("Process pid: {}, rank: {}, local rank: {}, device: {}, fp16: {}".format(
@@ -307,6 +308,7 @@ class DistTrainer():
         return results
 
     def _train(self):
+        dist.barrier()
         if not self.use_tqdm:
             from .utils import _pseudo_tqdm as inner_tqdm
         else:
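
Illustrative note (not part of the patch): the fix in init_logger_dist corrects a truthiness bug. With the old expression `logging.INFO if rank else logging.WARNING`, rank 0 (the master) evaluates as falsy and ends up at WARNING while every other rank logs at INFO, which is the opposite of the intent in the "use INFO in the master, WARN for others" comment. A minimal standalone check, using only the standard logging module:

    import logging

    for rank in (0, 1, 2):
        old = logging.INFO if rank else logging.WARNING       # before the fix
        new = logging.INFO if rank == 0 else logging.WARNING  # after the fix
        print(rank, logging.getLevelName(old), logging.getLevelName(new))
    # 0 WARNING INFO   <- only the fixed check gives the master rank INFO
    # 1 INFO    WARNING
    # 2 INFO    WARNING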