@@ -117,6 +117,8 @@ class DistTrainer():
         else:
             self.device = torch.device(device)
 
+        init_logger_dist()
+
         self.world_size = dist.get_world_size()
         self.rank = dist.get_rank() # unique id for each process
 
@@ -180,7 +182,6 @@ class DistTrainer():
         else:
             self.cp_save_path = None
         # use INFO in the master, WARN for others
-        init_logger_dist()
         self.logger = logger
         self.logger.info("Setup Distributed Trainer")
         self.logger.warning("Process pid: {}, rank: {}, local rank: {}, device: {}, fp16: {}".format(
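
Taken together, the two hunks above move init_logger_dist() out of the logger block: it now runs earlier in __init__, right after device selection and before dist.get_world_size() and dist.get_rank() are read, instead of just before the logger attributes are set. The sketch below shows what an initializer of this kind typically looks like, following the "use INFO in the master, WARN for others" comment; it is illustrative only, not fastNLP's actual implementation.

    import logging
    import torch.distributed as dist

    def init_logger_dist_sketch():
        # Illustrative sketch, not fastNLP's init_logger_dist: keep full INFO
        # output on rank 0 (the master) and reduce every other rank to WARNING
        # so setup messages are not duplicated across workers.
        # Assumes dist.init_process_group(...) has already been called.
        level = logging.INFO if dist.get_rank() == 0 else logging.WARNING
        logging.getLogger().setLevel(level)
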
@@ -307,6 +308,7 @@ class DistTrainer():
         return results
 
     def _train(self):
+        dist.barrier()
         if not self.use_tqdm:
             from .utils import _pseudo_tqdm as inner_tqdm
         else:
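
The hunk above adds dist.barrier() at the top of _train(). The call blocks each process until every rank in the group has reached it, so no worker starts the training loop before the others have reached the same point. A minimal sketch of the pattern follows; the function name, step count, and loop body are assumptions for illustration, not part of the diff.

    import torch.distributed as dist

    def run_training_sketch(num_steps=10):
        # Illustrative only. Assumes dist.init_process_group(...) has already
        # been called in this process.
        dist.barrier()                 # wait here until every rank arrives
        for step in range(num_steps):  # all ranks start the loop together
            pass                       # forward / backward / optimizer step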