
[update] bugfix in dist_trainer

tags/v0.5.5
yunfan 5 years ago
commit 8a084c4f52
2 changed files with 4 additions and 2 deletions
  1. fastNLP/core/_logger.py (+1, -1)
  2. fastNLP/core/dist_trainer.py (+3, -1)

fastNLP/core/_logger.py (+1, -1)

@@ -176,4 +176,4 @@ logger = _init_logger(path=None, level='INFO')
 def init_logger_dist():
     global logger
     rank = dist.get_rank()
-    logger.setLevel(logging.INFO if rank else logging.WARNING)
+    logger.setLevel(logging.INFO if rank == 0 else logging.WARNING)

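This one-character change is the actual bugfix. In torch.distributed, rank 0 is the master process, and the expression logging.INFO if rank else logging.WARNING takes the INFO branch only for non-zero ranks, so the master logged at WARNING while every worker logged at INFO, the reverse of the intent ("use INFO in the master, WARN for others"). A minimal standalone sketch of the corrected behavior (not fastNLP code; the helper name is made up for illustration):

import logging

def level_for_rank(rank: int) -> int:
    # Master (rank 0) should be verbose; all other ranks stay quiet.
    return logging.INFO if rank == 0 else logging.WARNING

# The old expression swapped the two levels:
# rank 0 got WARNING and every non-zero rank got INFO.
assert level_for_rank(0) == logging.INFO
assert level_for_rank(1) == logging.WARNING
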
fastNLP/core/dist_trainer.py (+3, -1)

@@ -117,6 +117,8 @@ class DistTrainer():
         else:
             self.device = torch.device(device)
 
+        init_logger_dist()
+
         self.world_size = dist.get_world_size()
         self.rank = dist.get_rank()  # unique id for each process

@@ -180,7 +182,6 @@ class DistTrainer():
         else:
             self.cp_save_path = None
         # use INFO in the master, WARN for others
-        init_logger_dist()
         self.logger = logger
         self.logger.info("Setup Distributed Trainer")
         self.logger.warning("Process pid: {}, rank: {}, local rank: {}, device: {}, fp16: {}".format(
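Together with the hunk above, this moves the single init_logger_dist() call from the middle of __init__ to right after the device is resolved, so the rank-dependent log level is already configured before the setup code that runs between the two call sites can emit anything. A rough sketch of the resulting order, assuming a process group has already been created with dist.init_process_group (the class and logger names here are illustrative, not fastNLP's):

import logging
import torch.distributed as dist

logger = logging.getLogger("dist_example")  # stand-in for fastNLP's module logger

def init_logger_dist():
    # Mirrors fastNLP/core/_logger.py: INFO on the master, WARNING elsewhere.
    logger.setLevel(logging.INFO if dist.get_rank() == 0 else logging.WARNING)

class DistTrainerSketch:
    def __init__(self, device):
        self.device = device
        init_logger_dist()                      # level is set before any setup logging
        self.world_size = dist.get_world_size()
        self.rank = dist.get_rank()
        logger.info("Setup Distributed Trainer")  # visible only on rank 0
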
@@ -307,6 +308,7 @@ class DistTrainer():
         return results
 
     def _train(self):
+        dist.barrier()
         if not self.use_tqdm:
             from .utils import _pseudo_tqdm as inner_tqdm
         else:
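
The added dist.barrier() makes every process wait at the top of _train() until all ranks have reached it, so no rank starts its training loop while the others are still finishing setup. A minimal standalone sketch of the pattern (a single-process Gloo group, just so the snippet can run on its own; real jobs launch one process per device):

import os
import torch.distributed as dist

# Single-process group purely for illustration; MASTER_ADDR/PORT are required
# by the default env:// initialization.
os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
os.environ.setdefault("MASTER_PORT", "29500")
dist.init_process_group("gloo", rank=0, world_size=1)

def _train():
    # Block here until every rank in the group has arrived.
    dist.barrier()
    # ... the training loop would start only after all ranks are synchronized ...

_train()
dist.destroy_process_group()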

