|
@@ -29,14 +29,10 @@ from .dataset import DataSet |
|
|
from .losses import _prepare_losser |
|
|
from .losses import _prepare_losser |
|
|
from .optimizer import Optimizer |
|
|
from .optimizer import Optimizer |
|
|
from .utils import _build_args |
|
|
from .utils import _build_args |
|
|
|
|
|
from .utils import _build_fp16_env |
|
|
from .utils import _get_func_signature |
|
|
from .utils import _get_func_signature |
|
|
from .utils import _move_dict_value_to_device |
|
|
from .utils import _move_dict_value_to_device |
|
|
|
|
|
|
|
|
try: |
|
|
|
|
|
from apex import amp |
|
|
|
|
|
except: |
|
|
|
|
|
amp = None |
|
|
|
|
|
|
|
|
|
|
|
__all__ = [ |
|
|
__all__ = [ |
|
|
'get_local_rank', |
|
|
'get_local_rank', |
|
|
'DistTrainer', |
|
|
'DistTrainer', |
|
@@ -72,7 +68,7 @@ class DistTrainer(): |
|
|
dev_data=None, metrics=None, metric_key=None, |
|
|
dev_data=None, metrics=None, metric_key=None, |
|
|
update_every=1, print_every=10, validate_every=-1, |
|
|
update_every=1, print_every=10, validate_every=-1, |
|
|
save_path=None, device='auto', |
|
|
save_path=None, device='auto', |
|
|
fp16='', use_tqdm=True, **kwargs): |
|
|
|
|
|
|
|
|
fp16=False, use_tqdm=True, **kwargs): |
|
|
r""" |
|
|
r""" |
|
|
|
|
|
|
|
|
:param train_data: 训练集, :class:`~fastNLP.DataSet` 类型。 |
|
|
:param train_data: 训练集, :class:`~fastNLP.DataSet` 类型。 |
|
@@ -103,12 +99,15 @@ class DistTrainer(): |
|
|
:param str,None save_path: 将模型保存路径,如果路径不存在,将自动创建文件夹。如果为None,则不保存模型。如果dev_data为None,则保存 |
|
|
:param str,None save_path: 将模型保存路径,如果路径不存在,将自动创建文件夹。如果为None,则不保存模型。如果dev_data为None,则保存 |
|
|
最后一次迭代的模型。保存的时候不仅保存了参数,还保存了模型结构。即便使用DataParallel,这里也只保存模型。 |
|
|
最后一次迭代的模型。保存的时候不仅保存了参数,还保存了模型结构。即便使用DataParallel,这里也只保存模型。 |
|
|
:param str device: 指定 device,可以是 gpu,cpu 或 auto |
|
|
:param str device: 指定 device,可以是 gpu,cpu 或 auto |
|
|
:param str fp16: 指定半精度训练的优化等级,可为 O1,O2 或 O3,若为空字符串则不使用半精度。 |
|
|
|
|
|
|
|
|
:param bool fp16: 指定是否使用半精度训练。 |
|
|
:param bool use_tqdm: 是否使用tqdm来显示训练进度; 如果为False,则将loss打印在终端中。 |
|
|
:param bool use_tqdm: 是否使用tqdm来显示训练进度; 如果为False,则将loss打印在终端中。 |
|
|
:param kwargs: 支持配置可选参数 |
|
|
:param kwargs: 支持配置可选参数 |
|
|
bool test_use_tqdm: 在dev上验证的时候是否开启tqdm |
|
|
bool test_use_tqdm: 在dev上验证的时候是否开启tqdm |
|
|
Sampler test_sampler: 在evaluate的时候使用的sampler |
|
|
Sampler test_sampler: 在evaluate的时候使用的sampler |
|
|
int dev_batch_size: 在evaluate时,使用的evaluate的batch大小 |
|
|
int dev_batch_size: 在evaluate时,使用的evaluate的batch大小 |
|
|
|
|
|
bool test_use_fp16: test时使用fp16 |
|
|
|
|
|
bool set_grad_to_none: zero_grad时将grad设为None而不是0 |
|
|
|
|
|
GradScaler gradscaler: 自定义的梯度 scaler |
|
|
""" |
|
|
""" |
|
|
assert device in ['auto', 'cuda', 'cpu'], "Please set correct device in [auto', 'cuda', 'cpu']" |
|
|
assert device in ['auto', 'cuda', 'cpu'], "Please set correct device in [auto', 'cuda', 'cpu']" |
|
|
if device == 'auto': |
|
|
if device == 'auto': |
|
@@ -147,14 +146,19 @@ class DistTrainer(): |
|
|
self.use_tqdm = use_tqdm |
|
|
self.use_tqdm = use_tqdm |
|
|
|
|
|
|
|
|
model.to(self.device) |
|
|
model.to(self.device) |
|
|
optimizer = self._get_optimizer(optimizer) |
|
|
|
|
|
|
|
|
|
|
|
# init fp16, must before DataParallel init |
|
|
# init fp16, must before DataParallel init |
|
|
if len(self.fp16): |
|
|
|
|
|
assert isinstance(self.fp16, str), "Please set Apex AMP optimization level selected in ['O0', 'O1', 'O2', 'O3']" |
|
|
|
|
|
_check_fp16() |
|
|
|
|
|
assert device == 'cuda', "Amp requires cuda device" |
|
|
|
|
|
model, optimizer = amp.initialize(model, optimizer, opt_level=self.fp16) |
|
|
|
|
|
|
|
|
autocast, GradScaler = _build_fp16_env(dummy=not self.fp16) |
|
|
|
|
|
self.auto_cast = autocast |
|
|
|
|
|
user_grad_scaler = getattr(kwargs, 'gradscaler', None) |
|
|
|
|
|
if user_grad_scaler is not None: |
|
|
|
|
|
assert self.fp16, "must set fp16=True to enable gradscaler" |
|
|
|
|
|
grad_scaler = user_grad_scaler |
|
|
|
|
|
else: |
|
|
|
|
|
grad_scaler = GradScaler() |
|
|
|
|
|
self.grad_scaler = grad_scaler |
|
|
|
|
|
|
|
|
|
|
|
self.set_grad_to_none = getattr(kwargs, 'set_grad_to_none', True) |
|
|
|
|
|
|
|
|
# init DataParallel |
|
|
# init DataParallel |
|
|
if parse_version(torch.__version__)>=parse_version('1.1'): |
|
|
if parse_version(torch.__version__)>=parse_version('1.1'): |
|
@@ -165,6 +169,7 @@ class DistTrainer(): |
|
|
output_device=self.local_rank) |
|
|
output_device=self.local_rank) |
|
|
self.model = self.ddp_model.module |
|
|
self.model = self.ddp_model.module |
|
|
|
|
|
|
|
|
|
|
|
optimizer = self._get_optimizer(optimizer) |
|
|
self.optimizer = optimizer |
|
|
self.optimizer = optimizer |
|
|
if isinstance(self.train_data, DataSet): |
|
|
if isinstance(self.train_data, DataSet): |
|
|
self.sampler = DistributedSampler(self.train_data) |
|
|
self.sampler = DistributedSampler(self.train_data) |
|
@@ -197,11 +202,9 @@ class DistTrainer(): |
|
|
self.logger = logger |
|
|
self.logger = logger |
|
|
self.logger.info("Setup Distributed Trainer") |
|
|
self.logger.info("Setup Distributed Trainer") |
|
|
self.logger.warning("Process pid: {}, rank: {}, local rank: {}, device: {}, fp16: {}".format( |
|
|
self.logger.warning("Process pid: {}, rank: {}, local rank: {}, device: {}, fp16: {}".format( |
|
|
os.getpid(), self.rank, self.local_rank, self.device, self.fp16 if self.fp16 else False)) |
|
|
|
|
|
|
|
|
os.getpid(), self.rank, self.local_rank, self.device, self.fp16)) |
|
|
self.logger.info("Num of processes: {}".format(self.world_size)) |
|
|
self.logger.info("Num of processes: {}".format(self.world_size)) |
|
|
self.logger.info("Use device: {}".format(device)) |
|
|
self.logger.info("Use device: {}".format(device)) |
|
|
self.logger.info("Training with fp16: {}, optimization level: {}".format( |
|
|
|
|
|
len(self.fp16) > 0, self.fp16 if self.fp16 else None)) |
|
|
|
|
|
|
|
|
|
|
|
def _maybe_no_sync(self): |
|
|
def _maybe_no_sync(self): |
|
|
""" |
|
|
""" |
|
@@ -343,28 +346,20 @@ class DistTrainer(): |
|
|
indices = data_iterator.get_batch_indices() |
|
|
indices = data_iterator.get_batch_indices() |
|
|
# negative sampling; replace unknown; re-weight batch_y |
|
|
# negative sampling; replace unknown; re-weight batch_y |
|
|
self.callback_manager.on_batch_begin(batch_x, batch_y, indices) |
|
|
self.callback_manager.on_batch_begin(batch_x, batch_y, indices) |
|
|
prediction = self._data_forward(self.ddp_model, batch_x) |
|
|
|
|
|
|
|
|
with self.auto_cast(): |
|
|
|
|
|
prediction = self._data_forward(self.ddp_model, batch_x) |
|
|
|
|
|
# edit prediction |
|
|
|
|
|
self.callback_manager.on_loss_begin(batch_y, prediction) |
|
|
|
|
|
loss = self._compute_loss(prediction, batch_y) |
|
|
|
|
|
|
|
|
# edit prediction |
|
|
|
|
|
self.callback_manager.on_loss_begin(batch_y, prediction) |
|
|
|
|
|
loss = self._compute_loss(prediction, batch_y) |
|
|
|
|
|
if self.update_every > 1: |
|
|
|
|
|
loss = loss / self.update_every |
|
|
|
|
|
avg_loss += loss.item() |
|
|
|
|
|
|
|
|
avg_loss += loss.detach() |
|
|
|
|
|
|
|
|
# Is loss NaN or inf? requires_grad = False |
|
|
# Is loss NaN or inf? requires_grad = False |
|
|
self.callback_manager.on_backward_begin(loss) |
|
|
self.callback_manager.on_backward_begin(loss) |
|
|
|
|
|
|
|
|
# with self._maybe_no_sync(): |
|
|
|
|
|
if self.fp16: |
|
|
|
|
|
with amp.scale_loss(loss, self.optimizer) as scale_loss: |
|
|
|
|
|
scale_loss.backward() |
|
|
|
|
|
else: |
|
|
|
|
|
loss.backward() |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
self.grad_scaler.scale(loss).backward() |
|
|
self.callback_manager.on_backward_end() |
|
|
self.callback_manager.on_backward_end() |
|
|
|
|
|
|
|
|
self._update() |
|
|
|
|
|
|
|
|
if self.step % self.update_every == 0: |
|
|
|
|
|
self._update() |
|
|
self.callback_manager.on_step_end() |
|
|
self.callback_manager.on_step_end() |
|
|
|
|
|
|
|
|
if self.step % self.print_every == 0: |
|
|
if self.step % self.print_every == 0: |
|
@@ -390,13 +385,22 @@ class DistTrainer(): |
|
|
self.pbar = None |
|
|
self.pbar = None |
|
|
# ============ tqdm end ============== # |
|
|
# ============ tqdm end ============== # |
|
|
|
|
|
|
|
|
|
|
|
def _clear_grad_opt(self, optimizer): |
|
|
|
|
|
if self.set_grad_to_none: |
|
|
|
|
|
for group in optimizer.param_groups: |
|
|
|
|
|
for p in group['params']: |
|
|
|
|
|
if p.grad is not None: |
|
|
|
|
|
p.grad = None |
|
|
|
|
|
else: |
|
|
|
|
|
optimizer.zero_grad() |
|
|
|
|
|
|
|
|
def _update(self): |
|
|
def _update(self): |
|
|
r"""Perform weight update on a model. |
|
|
r"""Perform weight update on a model. |
|
|
|
|
|
|
|
|
""" |
|
|
""" |
|
|
if self.step % self.update_every == 0: |
|
|
|
|
|
self.optimizer.step() |
|
|
|
|
|
self.ddp_model.zero_grad() |
|
|
|
|
|
|
|
|
self.grad_scaler.step(self.optimizer) |
|
|
|
|
|
self.grad_scaler.update() |
|
|
|
|
|
self._clear_grad_opt(self.optimizer) |
|
|
|
|
|
|
|
|
def _data_forward(self, network, x): |
|
|
def _data_forward(self, network, x): |
|
|
x = _build_args(self._forward_func, **x) |
|
|
x = _build_args(self._forward_func, **x) |
|
|