@@ -9,6 +9,7 @@ from tqdm import tqdm
 import logging
 import time
 from datetime import datetime, timedelta
+from functools import partial
 
 from .batch import DataSetIter, BatchIter
 from .callback import DistCallbackManager, CallbackException, TesterCallback
@@ -45,10 +46,12 @@ class DistTrainer():
                  callbacks_all=None, callbacks_master=None,
                  batch_size_per_gpu=8, n_epochs=1,
                  num_data_workers=1, drop_last=False,
-                 dev_data=None, metrics=None,
+                 dev_data=None, metrics=None, metric_key=None,
                  update_every=1, print_every=10, validate_every=-1,
+                 log_path=None,
                  save_every=-1, save_path=None, device='auto',
-                 fp16='', backend=None, init_method=None):
+                 fp16='', backend=None, init_method=None,
+                 find_unused_parameters=True):
 
         assert device in ['auto', 'cuda', 'cpu'], "Please set correct device in [auto', 'cuda', 'cpu']"
         if device == 'auto':
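
This hunk threads three new constructor options through DistTrainer: metric_key (which metric field decides whether a validation result counts as an improvement), log_path (an optional file target for the trainer's logging, wired up in the basicConfig hunk below), and find_unused_parameters (forwarded to DistributedDataParallel in the DDP hunk below). A minimal caller sketch; the model, data set, optimizer, and metric objects here are hypothetical stand-ins, not part of this diff:

    trainer = DistTrainer(train_data=train_set, model=my_model, optimizer=my_optimizer,
                          dev_data=dev_set, metrics=my_metric,
                          metric_key='acc',             # judge improvement by the 'acc' field
                          log_path='logs/train.log',    # log to a file instead of stderr
                          find_unused_parameters=True)  # tolerate branches skipped in forward()
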
@@ -87,6 +90,7 @@ class DistTrainer():
         self.callback_manager = DistCallbackManager(
             env={"trainer": self}, callbacks_all=callbacks_all,
             callbacks_master=callbacks_master)
+        self.metric_key = metric_key
 
         model.to(self.device)
         optimizer = self._get_optimizer(optimizer)
@@ -103,8 +107,13 @@ class DistTrainer():
             model, optimizer = amp.initialize(model, optimizer, opt_level=self.fp16)
 
         # init DataParallel
-        self.model = DDP(model, device_ids=[self.local_rank],
-                         output_device=self.local_rank)
+        if find_unused_parameters:
+            # to support old version
+            self.model = DDP(model, device_ids=[self.local_rank],
+                             output_device=self.local_rank, find_unused_parameters=find_unused_parameters)
+        else:
+            self.model = DDP(model, device_ids=[self.local_rank],
+                             output_device=self.local_rank)
         self.optimizer = optimizer
         self.sampler = DistributedSampler(self.train_data)
         self.data_iterator = self._get_data_iter(self.train_data)
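
The new flag matters because DDP by default expects every parameter to receive a gradient on each backward pass; find_unused_parameters=True makes DDP traverse the autograd graph after forward and mark untouched parameters as ready, so gradient reduction does not hang when part of the model is skipped. A self-contained sketch of the same wrapping decision, assuming torch.distributed is already initialized and local_rank is known:

    import torch
    from torch.nn.parallel import DistributedDataParallel as DDP

    def wrap_model(model, local_rank, find_unused_parameters=True):
        model = model.to(torch.device('cuda', local_rank))
        if find_unused_parameters:
            # tolerant mode: forward() may leave some parameters unused
            return DDP(model, device_ids=[local_rank], output_device=local_rank,
                       find_unused_parameters=True)
        # strict default: every parameter must contribute a gradient each step
        return DDP(model, device_ids=[local_rank], output_device=local_rank)
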
@@ -127,7 +136,8 @@ class DistTrainer():
         self.cp_save_path = None
 
         # use INFO in the master, WARN for others
-        logging.basicConfig(format='%(asctime)s - %(levelname)s - %(name)s - %(message)s',
+        logging.basicConfig(filename=log_path,
+                            format='%(asctime)s - %(levelname)s - %(name)s - %(message)s',
                             datefmt='%m/%d/%Y %H:%M:%S',
                             level=logging.INFO if self.is_master else logging.WARN)
         self.logger = logging.getLogger(__name__)
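
The only behavioural change here is filename=log_path: logging.basicConfig writes records to that file when a path is given, and falls back to the default stderr handler when log_path is None. The rank-dependent level is unchanged, so only the master process emits INFO. The same pattern in isolation (is_master and log_path are assumptions of the snippet):

    import logging

    def setup_logging(is_master, log_path=None):
        # filename=None keeps basicConfig's default stderr stream handler
        logging.basicConfig(filename=log_path,
                            format='%(asctime)s - %(levelname)s - %(name)s - %(message)s',
                            datefmt='%m/%d/%Y %H:%M:%S',
                            level=logging.INFO if is_master else logging.WARN)
        return logging.getLogger(__name__)
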
@@ -272,7 +282,15 @@ class DistTrainer():
 
                 if ((self.validate_every > 0 and self.step % self.validate_every == 0) or
                     (self.validate_every < 0 and self.step % len(data_iterator) == 0)):
-                    self.callback_manager.on_validation()
+                    self.callback_manager.on_valid_begin()
+                    eval_res = self.callback_manager.on_validation()
+                    eval_res = list(filter(lambda x: x is not None, eval_res))
+                    if len(eval_res):
+                        eval_res, is_better = list(zip(*eval_res))
+                    else:
+                        eval_res, is_better = None, None
+                    self.callback_manager.on_valid_end(
+                        eval_res, self.metric_key, self.optimizer, is_better)
                     dist.barrier()
 
         if self.cp_save_path and \
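
The expanded validation block assumes each callback's on_validation hook returns either None or an (eval_result, is_better) pair, with TesterCallback being the one that actually evaluates dev_data; the trainer drops the Nones and unzips the rest so on_valid_end receives one tuple of results and one tuple of is-better flags. The unzip idiom in isolation, with hypothetical return values:

    returns = [None, ({'acc': 0.91}, True), None]   # one callback evaluated, two abstained
    returns = list(filter(lambda x: x is not None, returns))
    if len(returns):
        eval_res, is_better = list(zip(*returns))   # eval_res == ({'acc': 0.91},), is_better == (True,)
    else:
        eval_res, is_better = None, None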