@@ -9,16 +9,18 @@ import os
import time
from datetime import datetime

import contextlib
import torch
import torch.cuda
import torch.distributed as dist
import torch.optim
from torch.serialization import default_restore_location
from pkg_resources import parse_version
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data.distributed import DistributedSampler
from tqdm import tqdm

from ._logger import logger
from ._logger import logger, init_logger_dist
from .batch import DataSetIter, BatchIter
from .callback import DistCallbackManager, CallbackException
from .callback import _TesterCallback

@@ -69,8 +71,8 @@ class DistTrainer():
                 num_workers=1, drop_last=False,
                 dev_data=None, metrics=None, metric_key=None,
                 update_every=1, print_every=10, validate_every=-1,
                 save_every=-1, save_path=None, device='auto',
                 fp16='', backend=None, init_method=None, use_tqdm=True):
                 save_path=None, device='auto',
                 fp16='', use_tqdm=True):
        r"""

        :param train_data: the training set, a :class:`~fastNLP.DataSet`.

@@ -98,20 +100,15 @@ class DistTrainer():
            would run out of memory; setting batch_size=32 with update_every=4 achieves the same effect. This argument has no effect when optimizer is None.
        :param int print_every: update the loss shown by tqdm every this many backward passes; if use_tqdm=False, print the loss every this many backward passes instead.
        :param int validate_every: evaluate on the dev set every this many steps; if -1, evaluate once at the end of every epoch. Only effective when dev_data is provided.
        :param int save_every: save the model every this many steps; if -1, save once at the end of every epoch. Only effective when save_path is provided.
        :param str,None save_path: path for saving the model; the folder is created automatically if it does not exist. If None, the model is not saved. If dev_data is None, the model
            from the last iteration is saved. Both the parameters and the model structure are saved. Even when DataParallel is used, only the model itself is saved here.
        :param str device: the device to use; one of 'auto', 'cuda' or 'cpu'.
        :param str fp16: the optimization level for mixed-precision training; one of O1, O2 or O3. If an empty string, half precision is not used.
        :param backend: the distributed backend; see the PyTorch documentation for details.
        :param str init_method: the initialization method for distributed training; see the PyTorch documentation for details.
        :param bool use_tqdm: whether to use tqdm to show training progress; if False, the loss is printed to the terminal instead.
        """
        assert device in ['auto', 'cuda', 'cpu'], "Please set correct device in ['auto', 'cuda', 'cpu']"
        if device == 'auto':
            device = 'cuda' if torch.cuda.is_available() else 'cpu'
        if backend is None:
            backend = 'nccl' if device == 'cuda' else 'gloo'
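        # Note (added for clarity): NCCL is the usual backend for multi-GPU training,
        # while Gloo supports CPU-only process groups; see the PyTorch distributed docs.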

        # init distributed
        if device == 'cuda':

@@ -120,11 +117,11 @@ class DistTrainer():
        else:
            self.device = torch.device(device)

        dist.init_process_group(backend=backend, init_method=init_method)
        init_logger_dist()

        self.world_size = dist.get_world_size()
        self.rank = dist.get_rank()  # unique id for each process
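        # (clarifying comment) rank is the global index of this process across all nodes;
        # the local rank stored later in self.local_rank is the GPU index on the current machine.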

        self.model = model
        self.train_data = train_data
        self.batch_size_per_gpu = int(batch_size_per_gpu)
        self.n_epochs = int(n_epochs)

@@ -133,12 +130,9 @@ class DistTrainer():
        self.update_every = int(update_every)
        self.print_every = int(print_every)
        self.validate_every = int(validate_every)
        self.save_every = int(save_every)
        self.save_path = save_path
        self.losser = _prepare_losser(loss)
        self.fp16 = fp16
        self.init_method = init_method
        self.backend = backend
        self.local_rank = get_local_rank()
        self._forward_func = model.forward
        self.callback_manager = DistCallbackManager(

@@ -160,11 +154,12 @@ class DistTrainer():

        # init DataParallel
        if parse_version(torch.__version__)>=parse_version('1.1'):
            self.model = DDP(model, device_ids=[self.local_rank],
            self.ddp_model = DDP(model, device_ids=[self.local_rank],
                                 output_device=self.local_rank, find_unused_parameters=True)
        else:
            self.model = DDP(model, device_ids=[self.local_rank],
            self.ddp_model = DDP(model, device_ids=[self.local_rank],
                                 output_device=self.local_rank)
        self.model = self.ddp_model.module
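        # (clarifying comment) the DDP wrapper is kept in self.ddp_model, while self.model refers to
        # the underlying module, so checkpointing and forward-signature inspection work on the raw model.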

        self.optimizer = optimizer
        self.sampler = DistributedSampler(self.train_data)

@@ -177,18 +172,16 @@ class DistTrainer():
            cb = _TesterCallback(
                dev_data, model, metrics,
                batch_size=batch_size_per_gpu, num_workers=num_workers)
            self.test_manager.add_callback([cb], master=False)
            self.test_manager.add_callback([cb], master=True)

        # Setup logging
        dist.barrier()
        self.start_time = datetime.now().strftime('%m_%d_%Y-%H_%M')
        if self.save_path:
            self.cp_save_path = os.path.join(self.save_path, 'checkpoints')
            self.cp_save_path = self.save_path
        else:
            self.cp_save_path = None

        # use INFO in the master, WARN for others
        logger.setLevel(logging.INFO if self.is_master else logging.WARNING)
        self.logger = logger
        self.logger.info("Setup Distributed Trainer")
        self.logger.warning("Process pid: {}, rank: {}, local rank: {}, device: {}, fp16: {}".format(

@@ -198,6 +191,22 @@ class DistTrainer():
        self.logger.info("Training with fp16: {}, optimization level: {}".format(
            len(self.fp16) > 0, self.fp16 if self.fp16 else None))

    def _maybe_no_sync(self):
        """
        When gradients are accumulated over more than one mini-batch, we
        want to accumulate them locally and only call
        all-reduce in the last backward pass.
        """
        i = self.step % self.update_every
        if (
            self.world_size > 1
            and hasattr(self.ddp_model, "no_sync")
            and i != 0
        ):
            return self.ddp_model.no_sync()
        else:
            return contextlib.ExitStack()  # dummy contextmanager
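    # Sketch (based on the commented-out call in _train below, not a verbatim part of this diff):
    # the helper above is meant to wrap the backward pass during gradient accumulation so that DDP
    # skips its gradient all-reduce on every step except the last one of the accumulation window:
    #     with self._maybe_no_sync():
    #         loss.backward()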

    def _get_n_steps(self):
        batch_size = self.world_size * self.batch_size_per_gpu
        return (len(self.train_data) // batch_size + int(

@@ -219,9 +228,9 @@ class DistTrainer():
        if isinstance(optimizer, torch.optim.Optimizer):
            return optimizer
        elif isinstance(optimizer, Optimizer):
            return optimizer.construct_from_pytorch(self.model.parameters())
            return optimizer.construct_from_pytorch(self.ddp_model.parameters())
        elif optimizer is None:
            return torch.optim.Adam(self.model.parameters(), lr=4e-3)
            return torch.optim.Adam(self.ddp_model.parameters(), lr=4e-3)
        else:
            raise TypeError("optimizer can only be torch.optim.Optimizer type, not {}.".format(type(optimizer)))

@@ -252,8 +261,10 @@ class DistTrainer():
        self.logger.info("###### Training epochs started ######")
        self.logger.info('Total epochs: %d'% self.n_epochs)
        self.logger.info('Total steps: %d'% self.n_steps)
        self.logger.info('Num instances per GPU %d'% self.batch_size_per_gpu)
        self.logger.info('Total batch_size: %d'% self.batch_size_per_gpu * dist.get_world_size())
        self.logger.info('Num instances per GPU: %d'% self.batch_size_per_gpu)
        self.logger.info('Num of steps per update: %d' % self.update_every)
        self.logger.info('Total batch_size: %d'%
                         (self.batch_size_per_gpu * dist.get_world_size() * self.update_every))
        self.logger.info('Total num of samples: %d'% len(self.train_data))
        self.logger.info("Num of callbacks for all workers: {}".format(
            len(self.callback_manager.callbacks_all)))

@@ -290,13 +301,14 @@ class DistTrainer():
            self.logger.info("###### Train finished ######")
            self.logger.info('Total train time: {} seconds.'. format(results['seconds']))
            if load_best_model and self.cp_save_path and len(self.test_manager.callbacks):
                self.load_check_point('best')
                self.load_check_point(self._best_save_name())
        finally:
            pass
        dist.barrier()
        return results

    def _train(self):
        dist.barrier()
        if not self.use_tqdm:
            from .utils import _pseudo_tqdm as inner_tqdm
        else:

@@ -309,29 +321,32 @@ class DistTrainer():
        pbar = self.pbar
        avg_loss = 0
        data_iterator = self.data_iterator
        self.model.zero_grad()
        self.ddp_model.zero_grad()
        for epoch in range(1, self.n_epochs + 1):
            self.epoch = epoch
            pbar.set_description_str(desc="Epoch {}/{}".format(epoch, self.n_epochs))
            # early stopping
            self.callback_manager.on_epoch_begin()
            for batch_x, batch_y in data_iterator:
                self.model.train()
                self.step += 1
                self.ddp_model.train()
                _move_dict_value_to_device(batch_x, batch_y, device=self.device)
                indices = data_iterator.get_batch_indices()
                # negative sampling; replace unknown; re-weight batch_y
                self.callback_manager.on_batch_begin(batch_x, batch_y, indices)
                prediction = self._data_forward(self.model, batch_x)
                prediction = self._data_forward(self.ddp_model, batch_x)

                # edit prediction
                self.callback_manager.on_loss_begin(batch_y, prediction)
                loss = self._compute_loss(prediction, batch_y)
                if self.update_every > 1:
                    loss = loss / self.update_every
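                # (clarifying comment) dividing by update_every keeps the accumulated gradient on the
                # same scale as a single large batch when several backward passes share one optimizer step.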
                avg_loss += loss.item()

                # Is loss NaN or inf? requires_grad = False
                self.callback_manager.on_backward_begin(loss)

                # with self._maybe_no_sync():
                if self.fp16:
                    with amp.scale_loss(loss, self.optimizer) as scale_loss:
                        scale_loss.backward()
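                # (clarifying comment) amp refers to NVIDIA Apex's mixed-precision API, assumed to be
                # imported and initialized in a part of the file not shown in this diff; scale_loss
                # up-scales the loss so fp16 gradients do not underflow, and Apex un-scales them
                # again before the weights are updated.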

@@ -355,17 +370,10 @@ class DistTrainer():
                if (self.validate_every > 0 and self.step % self.validate_every == 0):
                    self._do_validation()

                if self.cp_save_path and \
                        self.save_every > 0 and \
                        self.step % self.save_every == 0:
                    self.save_check_point()

            # ================= mini-batch end ==================== #
            if self.validate_every < 0:
                self._do_validation()

            if self.save_every < 0 and self.cp_save_path:
                self.save_check_point()
            # lr decay; early stopping
            self.callback_manager.on_epoch_end()
        # =============== epochs end =================== #

@@ -379,7 +387,7 @@ class DistTrainer():
        """
        if self.step % self.update_every == 0:
            self.optimizer.step()
            self.model.zero_grad()
            self.ddp_model.zero_grad()
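            # (clarifying comment) the optimizer only steps once every update_every backward passes,
            # i.e. gradients from the intermediate mini-batches are accumulated before the update.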

    def _data_forward(self, network, x):
        x = _build_args(self._forward_func, **x)

@@ -406,44 +414,51 @@ class DistTrainer():
    def save_check_point(self, name=None, only_params=False):
        r"""Save the current model."""
        # only master save models
        if name is None:
            name = 'checkpoint-{}.bin'.format(self.step)
        os.makedirs(self.cp_save_path, exist_ok=True)
        path = os.path.join(self.cp_save_path, name)
        self.logger.info("Save checkpoint to {}".format(path))
        model_to_save = self.ddp_model.module
        if only_params:
            model_to_save = model_to_save.state_dict()
        if self.is_master:
            if name is None:
                name = 'checkpoint-{}.bin'.format(self.step)
            os.makedirs(self.cp_save_path, exist_ok=True)
            path = os.path.join(self.cp_save_path, name)
            self.logger.info("Save checkpoint to {}".format(path))
            model_to_save = self.model.module
            if only_params:
                model_to_save = model_to_save.state_dict()
            torch.save(model_to_save, path)
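            # (clarifying comment) the unwrapped module, not the DDP wrapper, is what gets saved,
            # so the checkpoint can be loaded later outside of distributed training.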

    def load_check_point(self, name):
        path = os.path.join(self.cp_save_path, name)
        self.logger.info('reload best model from %s', path)
        model_load = torch.load(path, map_location='cpu')
        model_load = torch.load(
            path,
            map_location=lambda s, l: default_restore_location(s, "cpu"))
        if not isinstance(model_load, dict):
            model_load = model_load.state_dict()
        self.model.module.load_state_dict(model_load)
        self.model.load_state_dict(model_load)

    def _best_save_name(self):
        return "best_" + "_".join([self.model.__class__.__name__, self.metric_key, self.start_time])

    def _do_validation(self):
        self.callback_manager.on_valid_begin()
        # do evaluate on all nodes
        eval_res = self.test_manager.on_valid_begin()
        eval_res = list(filter(lambda x: x is not None, eval_res))
        if len(eval_res):
            eval_res, is_better = list(zip(*eval_res))
        else:
            eval_res, is_better = None, None
        # save better model on master node
        if self.is_master and is_better is not None and self.cp_save_path:
            for i, better_flag in enumerate(is_better):
                if better_flag:
                    # TODO to support multiple datasets to evaluate
                    self.save_check_point('best')
                    break
        self.callback_manager.on_valid_end(
            eval_res, self.metric_key, self.optimizer, is_better)
        dist.barrier()
        with self.ddp_model.no_sync():
            # since the model parameters are not updated during evaluation, gradient sync can be turned off
            self.callback_manager.on_valid_begin()
            eval_res = self.test_manager.on_valid_begin()
            eval_res = list(filter(lambda x: x is not None, eval_res))
            if len(eval_res):
                eval_res, is_better = list(zip(*eval_res))
                eval_res = eval_res[0]
                is_better = is_better[0]
            else:
                eval_res, is_better = None, None
            # logger.info('{}, {}'.format(eval_res, is_better))
            # save better model on master node
            if is_better is not None and self.cp_save_path:
                if is_better:
                    self.save_check_point(self._best_save_name(), only_params=False)
        dist.barrier()
        self.callback_manager.on_valid_end(
            eval_res, self.metric_key, self.optimizer, is_better)
        self.ddp_model.train()
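        # (clarifying comment) every rank runs the evaluation above; only the master ends up writing
        # the best checkpoint, and the barrier keeps all ranks in step before training resumes.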

    def close(self):
        r"""Shut down the Trainer and destroy the processes."""