diff --git a/fastNLP/core/_logger.py b/fastNLP/core/_logger.py
index 043a97c2..9051f700 100644
--- a/fastNLP/core/_logger.py
+++ b/fastNLP/core/_logger.py
@@ -176,4 +176,4 @@ logger = _init_logger(path=None, level='INFO')
 def init_logger_dist():
     global logger
     rank = dist.get_rank()
-    logger.setLevel(logging.INFO if rank else logging.WARNING)
+    logger.setLevel(logging.INFO if rank == 0 else logging.WARNING)
diff --git a/fastNLP/core/dist_trainer.py b/fastNLP/core/dist_trainer.py
index 56d123f4..289b434d 100644
--- a/fastNLP/core/dist_trainer.py
+++ b/fastNLP/core/dist_trainer.py
@@ -117,6 +117,8 @@ class DistTrainer():
         else:
             self.device = torch.device(device)
 
+        init_logger_dist()
+
         self.world_size = dist.get_world_size()
         self.rank = dist.get_rank()  # unique id for each process
 
@@ -180,7 +182,6 @@ class DistTrainer():
         else:
             self.cp_save_path = None
         # use INFO in the master, WARN for others
-        init_logger_dist()
         self.logger = logger
         self.logger.info("Setup Distributed Trainer")
         self.logger.warning("Process pid: {}, rank: {}, local rank: {}, device: {}, fp16: {}".format(
@@ -307,6 +308,7 @@ class DistTrainer():
         return results
 
     def _train(self):
+        dist.barrier()
         if not self.use_tqdm:
             from .utils import _pseudo_tqdm as inner_tqdm
         else:
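
Note on the intent of the patch: rank 0 is the master process, so after the fix only the master logs at INFO while the other ranks are raised to WARNING (the original `if rank` test was inverted relative to the "use INFO in the master, WARN for others" comment). `init_logger_dist()` is also moved earlier in `DistTrainer.__init__` so the level is set before any construction-time logging, and `dist.barrier()` makes all ranks enter `_train()` together. A minimal standalone sketch of the per-rank logging pattern follows; it assumes `torch.distributed` has already been initialised (e.g. via `init_process_group`), and the logger name "fastNLP" is a stand-in for the library's own logger object, not its exact API.

    import logging
    import torch.distributed as dist

    # Stand-in for fastNLP's module-level logger (illustrative only).
    logger = logging.getLogger("fastNLP")

    def init_logger_dist():
        # Only the master process (rank 0) logs at INFO; every other rank is
        # raised to WARNING so progress messages print once, not once per process.
        logger.setLevel(logging.INFO if dist.get_rank() == 0 else logging.WARNING)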