|
|
@@ -17,21 +17,30 @@ from tqdm import tqdm |
|
|
|
|
|
|
|
from ._logger import logger |
|
|
|
from .batch import DataSetIter, BatchIter |
|
|
|
from .callback import DistCallbackManager, CallbackException, TesterCallback |
|
|
|
from .callback import DistCallbackManager, CallbackException, _TesterCallback |
|
|
|
from .dataset import DataSet |
|
|
|
from .losses import _prepare_losser |
|
|
|
from .optimizer import Optimizer |
|
|
|
from .utils import _build_args |
|
|
|
from .utils import _get_func_signature |
|
|
|
from .utils import _move_dict_value_to_device |
|
|
|
from .utils import _check_fp16 |
|
|
|
|
|
|
|
|
|
|
|
# apex is an optional dependency used for fp16 (mixed precision) training.
# When it cannot be imported, bind ``amp`` to None so that importing this
# module still succeeds.
try:
    from apex import amp
except Exception:  # not a bare except: let KeyboardInterrupt/SystemExit propagate
    amp = None
|
|
|
|
|
|
|
# Public names exported by this module.
__all__ = [
    'get_local_rank',
    'DistTrainer',
]
|
|
|
|
|
|
|
|
|
|
|
def get_local_rank(): |
|
|
|
""" |
|
|
|
返回当前进程的 local rank, 0 到 N-1 ,N为当前分布式总进程数 |
|
|
|
""" |
|
|
|
if 'LOCAL_RANK' in os.environ: |
|
|
|
return int(os.environ['LOCAL_RANK']) |
|
|
|
from argparse import ArgumentParser |
|
|
@@ -46,7 +55,10 @@ def get_local_rank(): |
|
|
|
|
|
|
|
class DistTrainer(): |
|
|
|
""" |
|
|
|
Distributed Trainer that support distributed and mixed precision training |
|
|
|
分布式的 Trainer,支持分布式训练和混合精度的训练。具体实现原理请阅读 pytorch 官方文档。 |
|
|
|
|
|
|
|
Note: 使用分布式 Trainer 时会同时有多个进程执行训练代码。因此将单进程的训练代码改为多进程之前, |
|
|
|
请仔细检查,确保训练代码中的同步和互斥操作能正确执行(如模型保持,打印日志等) |
|
|
|
""" |
|
|
|
def __init__(self, train_data, model, optimizer=None, loss=None, |
|
|
|
callbacks_all=None, callbacks_master=None, |
|
|
@@ -55,8 +67,43 @@ class DistTrainer(): |
|
|
|
dev_data=None, metrics=None, metric_key=None, |
|
|
|
update_every=1, print_every=10, validate_every=-1, |
|
|
|
save_every=-1, save_path=None, device='auto', |
|
|
|
fp16='', backend=None, init_method=None): |
|
|
|
fp16='', backend=None, init_method=None, use_tqdm=True): |
|
|
|
""" |
|
|
|
|
|
|
|
:param train_data: 训练集, :class:`~fastNLP.DataSet` 类型。 |
|
|
|
:param nn.modules model: 待训练的模型 |
|
|
|
:param optimizer: `torch.optim.Optimizer` 优化器。如果为None,则Trainer使用默认的Adam(model.parameters(), lr=4e-3)这个优化器 |
|
|
|
:param loss: 使用的 :class:`~fastNLP.core.losses.LossBase` 对象。当为None时,默认使用 :class:`~fastNLP.LossInForward` |
|
|
|
:param list callbacks_all: 用于在train过程中起调节作用的回调函数,作用于所有训练进程中。 |
|
|
|
可使用的callback参见 :doc:`callback模块 <fastNLP.core.callback>` |
|
|
|
:param list callbacks_master: 用于在train过程中起调节作用的回调函数,只作用于其中一个进程( Master 进程)。 |
|
|
|
可使用的callback参见 :doc:`callback模块 <fastNLP.core.callback>` |
|
|
|
:param int batch_size_per_gpu: 训练时,每个进程的 batch 大小。 |
|
|
|
:param int n_epochs: 需要优化迭代多少次。 |
|
|
|
:param num_workers: int, 有多少个线程来进行数据pad处理。 |
|
|
|
:param drop_last: 如果最后一个batch没有正好为batch_size这么多数据,就扔掉最后一个batch |
|
|
|
:param dev_data: 用于做验证的DataSet, :class:`~fastNLP.DataSet` 类型。 |
|
|
|
:param metrics: 验证的评估函数。可以只使用一个 :class:`Metric<fastNLP.core.metrics.MetricBase>` , |
|
|
|
也可以使用多个 :class:`Metric<fastNLP.core.metrics.MetricBase>` ,通过列表传入。 |
|
|
|
如验证时取得了更好的验证结果(如果有多个Metric,以列表中第一个Metric为准),且save_path不为None, |
|
|
|
则保存当前模型。Metric种类详见 :doc:`metrics模块 <fastNLP.core.metrics>` 。仅在传入dev_data时有效。 |
|
|
|
:param str,None metric_key: :class:`Metric<fastNLP.core.metrics.MetricBase>` 有时会有多个指标, |
|
|
|
比如 :class:`~fastNLP.core.metrics.SpanFPreRecMetric` 中包含了'f', 'pre', 'rec'。此时需 |
|
|
|
要指定以哪个指标为准。另外有些指标是越小效果越好,比如语言模型的困惑度,这种情况下,在key前面增加一个'-'来表 |
|
|
|
明验证时,值越小越好(比如: "-ppl")。仅在传入dev_data时有效。 |
|
|
|
:param update_every: int, 多少步更新一次梯度。用于希望累计梯度的场景,比如需要128的batch_size, 但是直接设为128 |
|
|
|
会导致内存不足,通过设置batch_size=32, update_every=4达到目的。当optimizer为None时,该参数无效。 |
|
|
|
:param int print_every: 多少次反向传播更新tqdm显示的loss; 如果use_tqdm=False, 则多少次反向传播打印loss。 |
|
|
|
:param int validate_every: 多少个step在验证集上验证一次; 如果为-1,则每个epoch结束验证一次。仅在传入dev_data时有效。 |
|
|
|
:param int save_every: 多少个step保存一次模型,如果为-1,则每个epoch结束保存一次。仅在传入save_path时有效。 |
|
|
|
:param str,None save_path: 将模型保存路径,如果路径不存在,将自动创建文件夹。如果为None,则不保存模型。如果dev_data为None,则保存 |
|
|
|
最后一次迭代的模型。保存的时候不仅保存了参数,还保存了模型结构。即便使用DataParallel,这里也只保存模型。 |
|
|
|
:param str device: 指定 device,可以是 gpu,cpu 或 auto |
|
|
|
:param str fp16: 指定半精度训练的优化等级,可为 O1,O2 或 O3,若为空字符串则不使用半精度。 |
|
|
|
:param backend: 指定分布式的backend,详情参考 pytorch 文档 |
|
|
|
:param init_method 指定分布式的初始化方法,详情参考 pytorch 文档 |
|
|
|
:param bool use_tqdm: 是否使用tqdm来显示训练进度; 如果为False,则将loss打印在终端中。 |
|
|
|
""" |
|
|
|
assert device in ['auto', 'cuda', 'cpu'], "Please set correct device in [auto', 'cuda', 'cpu']" |
|
|
|
if device == 'auto': |
|
|
|
device = 'cuda' if torch.cuda.is_available() else 'cpu' |
|
|
@@ -94,7 +141,9 @@ class DistTrainer(): |
|
|
|
self.callback_manager = DistCallbackManager( |
|
|
|
env={"trainer": self}, callbacks_all=callbacks_all, |
|
|
|
callbacks_master=callbacks_master) |
|
|
|
self.test_manager = DistCallbackManager(env={'trainer': self}) |
|
|
|
self.metric_key = metric_key |
|
|
|
self.use_tqdm = use_tqdm |
|
|
|
|
|
|
|
model.to(self.device) |
|
|
|
optimizer = self._get_optimizer(optimizer) |
|
|
@@ -102,11 +151,7 @@ class DistTrainer(): |
|
|
|
# init fp16, must before DataParallel init |
|
|
|
if len(self.fp16): |
|
|
|
assert isinstance(self.fp16, str), "Please set Apex AMP optimization level selected in ['O0', 'O1', 'O2', 'O3']" |
|
|
|
try: |
|
|
|
from apex import amp |
|
|
|
except ImportError: |
|
|
|
raise ImportError("Please install apex from https://www.github.com/nvidia/apex to use fp16 training.") |
|
|
|
assert torch.backends.cudnn.enabled, "Amp requires cudnn backend to be enabled." |
|
|
|
_check_fp16() |
|
|
|
assert device == 'cuda', "Amp requires cuda device" |
|
|
|
model, optimizer = amp.initialize(model, optimizer, opt_level=self.fp16) |
|
|
|
|
|
|
@@ -121,14 +166,15 @@ class DistTrainer(): |
|
|
|
self.optimizer = optimizer |
|
|
|
self.sampler = DistributedSampler(self.train_data) |
|
|
|
self.data_iterator = self._get_data_iter(self.train_data) |
|
|
|
self.batch_size = self.world_size * self.batch_size_per_gpu |
|
|
|
self.n_steps = self._get_n_steps() |
|
|
|
|
|
|
|
# for evaluation, only run eval on master proc |
|
|
|
if dev_data and metrics: |
|
|
|
cb = TesterCallback( |
|
|
|
cb = _TesterCallback( |
|
|
|
dev_data, model, metrics, |
|
|
|
batch_size=batch_size_per_gpu, num_workers=num_workers) |
|
|
|
self.callback_manager.add_callback([cb], master=True) |
|
|
|
self.test_manager.add_callback([cb], master=True) |
|
|
|
|
|
|
|
# Setup logging |
|
|
|
dist.barrier() |
|
|
@@ -178,9 +224,27 @@ class DistTrainer(): |
|
|
|
|
|
|
|
@property
def is_master(self):
    """Whether the current process is the master process, i.e. ``self.rank == 0``."""
    return self.rank == 0
|
|
|
|
|
|
|
def train(self, on_exception='auto'): |
|
|
|
def train(self, load_best_model=True, on_exception='auto'): |
|
|
|
""" |
|
|
|
使用该函数使Trainer开始训练。 |
|
|
|
|
|
|
|
:param str on_exception: 在训练过程遭遇exception,并被 :py:class:Callback 的on_exception()处理后,是否继续抛出异常。 |
|
|
|
支持'ignore','raise', 'auto': 'ignore'将捕获异常,写在Trainer.train()后面的代码将继续运行; 'raise'将异常抛出; |
|
|
|
'auto'将ignore以下两种Exception: CallbackException与KeyboardInterrupt, raise其它exception. |
|
|
|
:return dict: 返回一个字典类型的数据, |
|
|
|
内含以下内容:: |
|
|
|
|
|
|
|
seconds: float, 表示训练时长 |
|
|
|
以下三个内容只有在提供了dev_data的情况下会有。 |
|
|
|
best_eval: Dict of Dict, 表示evaluation的结果。第一层的key为Metric的名称, |
|
|
|
第二层的key为具体的Metric |
|
|
|
best_epoch: int,在第几个epoch取得的最佳值 |
|
|
|
best_step: int, 在第几个step(batch)更新取得的最佳值 |
|
|
|
|
|
|
|
""" |
|
|
|
try: |
|
|
|
self.logger.info("###### Training epochs started ######") |
|
|
|
self.logger.info('Total epochs: %d'% self.n_epochs) |
|
|
@@ -222,17 +286,22 @@ class DistTrainer(): |
|
|
|
results['seconds'] = round(time.time() - start_time, 2) |
|
|
|
self.logger.info("###### Train finished ######") |
|
|
|
self.logger.info('Total train time: {} seconds.'. format(results['seconds'])) |
|
|
|
return results |
|
|
|
if load_best_model: |
|
|
|
self.load_check_point('best_{}'.format(self.metric_key)) |
|
|
|
finally: |
|
|
|
self.close() |
|
|
|
pass |
|
|
|
|
|
|
|
return results |
|
|
|
|
|
|
|
def _train(self): |
|
|
|
if self.fp16: |
|
|
|
# skip check, done in __init__() |
|
|
|
from apex import amp |
|
|
|
if not self.use_tqdm: |
|
|
|
from .utils import _pseudo_tqdm as inner_tqdm |
|
|
|
else: |
|
|
|
inner_tqdm = tqdm |
|
|
|
|
|
|
|
self.step = 0 |
|
|
|
self.epoch = 0 |
|
|
|
self.pbar = tqdm(total=self.n_steps, postfix='loss:{0:<6.5f}', |
|
|
|
self.pbar = inner_tqdm(total=self.n_steps, postfix='loss:{0:<6.5f}', |
|
|
|
leave=False, dynamic_ncols=True, disable=not self.is_master) |
|
|
|
pbar = self.pbar |
|
|
|
avg_loss = 0 |
|
|
@@ -292,8 +361,8 @@ class DistTrainer(): |
|
|
|
if self.validate_every < 0: |
|
|
|
self._do_validation() |
|
|
|
|
|
|
|
if self.save_every < 0 and self.cp_save_path: |
|
|
|
self.save_check_point() |
|
|
|
if self.save_every < 0 and self.cp_save_path: |
|
|
|
self.save_check_point() |
|
|
|
# lr decay; early stopping |
|
|
|
self.callback_manager.on_epoch_end() |
|
|
|
# =============== epochs end =================== # |
|
|
@@ -327,22 +396,35 @@ class DistTrainer(): |
|
|
|
loss = self.losser(predict, truth) |
|
|
|
if self.update_every > 1: |
|
|
|
loss = loss / self.update_every |
|
|
|
return loss.mean() |
|
|
|
if loss.dim() > 0: |
|
|
|
loss = loss.mean() |
|
|
|
return loss |
|
|
|
|
|
|
|
def save_check_point(self, name=None, only_params=False):
    """Save the current model into the directory ``self.cp_save_path``.

    Only the master process writes to disk; on every other process this
    call does nothing.

    :param str,None name: file name of the checkpoint inside
        ``self.cp_save_path``. When None, ``'checkpoint-<step>.bin'`` is used,
        where ``<step>`` is ``self.step``.
    :param bool only_params: if True, only the ``state_dict()`` of the model
        is saved; otherwise the whole model object is saved.
    """
    # only master save models
    if not self.is_master:
        return
    if name is None:
        name = 'checkpoint-{}.bin'.format(self.step)
    # The target directory is created on demand.
    os.makedirs(self.cp_save_path, exist_ok=True)
    path = os.path.join(self.cp_save_path, name)
    self.logger.info("Save checkpoint to {}".format(path))
    # Save the wrapped network, not the wrapper itself
    # (presumably self.model is the DistributedDataParallel wrapper - confirm in __init__).
    model_to_save = self.model.module
    if only_params:
        model_to_save = model_to_save.state_dict()
    torch.save(model_to_save, path)
|
|
|
|
|
|
|
def load_check_point(self, name):
    """Load the checkpoint file ``name`` from ``self.cp_save_path`` into the model.

    The file may contain either a state dict or a whole model object; in the
    latter case its ``state_dict()`` is used.

    :param str name: file name of the checkpoint inside ``self.cp_save_path``.
    """
    path = os.path.join(self.cp_save_path, name)
    self.logger.info('reload best model from %s', path)
    # Load onto the CPU first: by default tensors are restored to the device
    # they were saved from. load_state_dict() then copies the values into the
    # parameters that already exist on this process.
    model_load = torch.load(path, map_location='cpu')
    if not isinstance(model_load, dict):
        model_load = model_load.state_dict()
    # Load into ``self.model.module``, the same object that
    # save_check_point() saves, so the parameter names line up.
    self.model.module.load_state_dict(model_load)
|
|
|
|
|
|
|
def _do_validation(self): |
|
|
|
self.callback_manager.on_valid_begin() |
|
|
|
eval_res = self.callback_manager.on_validation() |
|
|
|
eval_res = self.test_manager.on_valid_begin() |
|
|
|
eval_res = list(filter(lambda x: x is not None, eval_res)) |
|
|
|
if len(eval_res): |
|
|
|
eval_res, is_better = list(zip(*eval_res)) |
|
|
@@ -350,7 +432,16 @@ class DistTrainer(): |
|
|
|
eval_res, is_better = None, None |
|
|
|
self.callback_manager.on_valid_end( |
|
|
|
eval_res, self.metric_key, self.optimizer, is_better) |
|
|
|
|
|
|
|
# save better model |
|
|
|
for i, better_flag in enumerate(is_better): |
|
|
|
if better_flag: |
|
|
|
# TODO to support multiple datasets to evaluate |
|
|
|
name = 'best_{}'.format(self.metric_key) |
|
|
|
self.save_check_point(name) |
|
|
|
break |
|
|
|
dist.barrier() |
|
|
|
|
|
|
|
def close(self):
    """Close the Trainer and destroy the distributed process group.

    Calling this when no process group is initialised (for example a second
    call) is a no-op.
    """
    # destroy_process_group() raises when no group is initialised, so check first.
    if dist.is_available() and dist.is_initialized():
        dist.destroy_process_group()