
dist_trainer for fp16

commit f711d3070a (tags/v1.0.0alpha)
yunfan, 3 years ago
2 changed files with 47 additions and 39 deletions
  1. fastNLP/core/dist_trainer.py (+40, -36)
  2. tests/core/test_dist_trainer.py (+7, -3)

fastNLP/core/dist_trainer.py (+40, -36)

@@ -29,14 +29,10 @@ from .dataset import DataSet
 from .losses import _prepare_losser
 from .optimizer import Optimizer
 from .utils import _build_args
+from .utils import _build_fp16_env
 from .utils import _get_func_signature
 from .utils import _move_dict_value_to_device
 
-try:
-    from apex import amp
-except:
-    amp = None
-
 __all__ = [
     'get_local_rank',
     'DistTrainer',
@@ -72,7 +68,7 @@ class DistTrainer():
                  dev_data=None, metrics=None, metric_key=None,
                  update_every=1, print_every=10, validate_every=-1,
                  save_path=None, device='auto',
-                 fp16='', use_tqdm=True, **kwargs):
+                 fp16=False, use_tqdm=True, **kwargs):
         r"""
 
         :param train_data: the training set, of type :class:`~fastNLP.DataSet`.
@@ -103,12 +99,15 @@ class DistTrainer():
:param str,None save_path: 将模型保存路径,如果路径不存在,将自动创建文件夹。如果为None,则不保存模型。如果dev_data为None,则保存
最后一次迭代的模型。保存的时候不仅保存了参数,还保存了模型结构。即便使用DataParallel,这里也只保存模型。
:param str device: 指定 device,可以是 gpu,cpu 或 auto
:param str fp16: 指定半精度训练的优化等级,可为 O1,O2 或 O3,若为空字符串则不使用半精度
:param bool fp16: 指定是否使用半精度训练
:param bool use_tqdm: 是否使用tqdm来显示训练进度; 如果为False,则将loss打印在终端中。
:param kwargs: 支持配置可选参数
bool test_use_tqdm: 在dev上验证的时候是否开启tqdm
Sampler test_sampler: 在evaluate的时候使用的sampler
int dev_batch_size: 在evaluate时,使用的evaluate的batch大小
bool test_use_fp16: test时使用fp16
bool set_grad_to_none: zero_grad时将grad设为None而不是0
GradScaler gradscaler: 自定义的梯度 scaler
"""
assert device in ['auto', 'cuda', 'cpu'], "Please set correct device in [auto', 'cuda', 'cpu']"
if device == 'auto':
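
The change in semantics: Apex's amp.initialize(model, optimizer, opt_level=...) is replaced by PyTorch's native AMP, which pairs an autocast context for the forward pass with a GradScaler for the backward pass. A minimal self-contained sketch of that pattern in plain PyTorch (the toy model and data are invented for illustration; this is not fastNLP code):

    import torch
    from torch import nn
    from torch.cuda.amp import GradScaler, autocast

    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    model = nn.Linear(10, 2).to(device)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
    criterion = nn.CrossEntropyLoss()
    scaler = GradScaler(enabled=(device == 'cuda'))  # becomes a no-op on cpu

    for _ in range(3):
        x = torch.randn(8, 10, device=device)
        y = torch.randint(0, 2, (8,), device=device)
        optimizer.zero_grad()
        with autocast(enabled=(device == 'cuda')):   # forward in mixed precision
            loss = criterion(model(x), y)
        scaler.scale(loss).backward()                # scale the loss against fp16 underflow
        scaler.step(optimizer)                       # unscales grads, skips step on inf/NaN
        scaler.update()                              # adapts the scale factor
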
@@ -147,14 +146,19 @@ class DistTrainer():
         self.use_tqdm = use_tqdm
 
         model.to(self.device)
-        optimizer = self._get_optimizer(optimizer)
 
         # init fp16, must before DataParallel init
-        if len(self.fp16):
-            assert isinstance(self.fp16, str), "Please set Apex AMP optimization level selected in ['O0', 'O1', 'O2', 'O3']"
-            _check_fp16()
-            assert device == 'cuda', "Amp requires cuda device"
-            model, optimizer = amp.initialize(model, optimizer, opt_level=self.fp16)
+        autocast, GradScaler = _build_fp16_env(dummy=not self.fp16)
+        self.auto_cast = autocast
+        user_grad_scaler = kwargs.get('gradscaler', None)
+        if user_grad_scaler is not None:
+            assert self.fp16, "must set fp16=True to enable gradscaler"
+            grad_scaler = user_grad_scaler
+        else:
+            grad_scaler = GradScaler()
+        self.grad_scaler = grad_scaler
+
+        self.set_grad_to_none = kwargs.get('set_grad_to_none', True)
 
         # init DataParallel
         if parse_version(torch.__version__) >= parse_version('1.1'):
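
_build_fp16_env itself is not part of this diff. From the call site above, it returns an (autocast, GradScaler) pair and hands back inert stand-ins when dummy=True; a plausible sketch of such a helper (an assumption about its shape, not the actual fastNLP implementation):

    import contextlib
    import torch

    def build_fp16_env(dummy=False):
        """Return an (autocast, GradScaler) pair; no-op stand-ins when dummy=True."""
        if dummy:
            class DummyGradScaler:
                def scale(self, loss):      # leave the loss unscaled
                    return loss
                def step(self, optimizer):  # plain optimizer step
                    optimizer.step()
                def update(self):           # nothing to adapt
                    pass
            return contextlib.nullcontext, DummyGradScaler
        assert torch.cuda.is_available(), "fp16 requires a CUDA device"
        return torch.cuda.amp.autocast, torch.cuda.amp.GradScaler
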
@@ -165,6 +169,7 @@ class DistTrainer():
                                  output_device=self.local_rank)
         self.model = self.ddp_model.module
 
+        optimizer = self._get_optimizer(optimizer)
         self.optimizer = optimizer
         if isinstance(self.train_data, DataSet):
             self.sampler = DistributedSampler(self.train_data)
@@ -197,11 +202,9 @@ class DistTrainer():
         self.logger = logger
         self.logger.info("Setup Distributed Trainer")
         self.logger.warning("Process pid: {}, rank: {}, local rank: {}, device: {}, fp16: {}".format(
-            os.getpid(), self.rank, self.local_rank, self.device, self.fp16 if self.fp16 else False))
+            os.getpid(), self.rank, self.local_rank, self.device, self.fp16))
         self.logger.info("Num of processes: {}".format(self.world_size))
         self.logger.info("Use device: {}".format(device))
-        self.logger.info("Training with fp16: {}, optimization level: {}".format(
-            len(self.fp16) > 0, self.fp16 if self.fp16 else None))
 
     def _maybe_no_sync(self):
         """
@@ -343,28 +346,20 @@ class DistTrainer():
                 indices = data_iterator.get_batch_indices()
                 # negative sampling; replace unknown; re-weight batch_y
                 self.callback_manager.on_batch_begin(batch_x, batch_y, indices)
-                prediction = self._data_forward(self.ddp_model, batch_x)
+                with self.auto_cast():
+                    prediction = self._data_forward(self.ddp_model, batch_x)
+                    # edit prediction
+                    self.callback_manager.on_loss_begin(batch_y, prediction)
+                    loss = self._compute_loss(prediction, batch_y)
 
-                # edit prediction
-                self.callback_manager.on_loss_begin(batch_y, prediction)
-                loss = self._compute_loss(prediction, batch_y)
                 if self.update_every > 1:
                     loss = loss / self.update_every
-                avg_loss += loss.item()
+                avg_loss += loss.detach()
 
                 # Is loss NaN or inf? requires_grad = False
                 self.callback_manager.on_backward_begin(loss)
 
                 # with self._maybe_no_sync():
-                if self.fp16:
-                    with amp.scale_loss(loss, self.optimizer) as scale_loss:
-                        scale_loss.backward()
-                else:
-                    loss.backward()
-
+                self.grad_scaler.scale(loss).backward()
                 self.callback_manager.on_backward_end()
-                self._update()
+                if self.step % self.update_every == 0:
+                    self._update()
                 self.callback_manager.on_step_end()
 
                 if self.step % self.print_every == 0:
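
A likely motivation for accumulating loss.detach() instead of loss.item() (my reading; the commit message does not say): .item() copies the scalar to the host and forces a GPU synchronization on every step, whereas .detach() keeps the running sum on the device without retaining the autograd graph. Roughly:

    import torch

    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    avg_loss = torch.zeros((), device=device)

    for step in range(100):
        loss = torch.randn((), device=device)  # stand-in for a real training loss
        avg_loss += loss.detach()              # stays on device; no per-step sync

    print((avg_loss / 100).item())             # a single sync, when reporting
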
@@ -390,13 +385,22 @@ class DistTrainer():
             self.pbar = None
         # ============ tqdm end ============== #
 
+    def _clear_grad_opt(self, optimizer):
+        if self.set_grad_to_none:
+            for group in optimizer.param_groups:
+                for p in group['params']:
+                    if p.grad is not None:
+                        p.grad = None
+        else:
+            optimizer.zero_grad()
+
     def _update(self):
         r"""Perform weight update on a model.
 
         """
-        if self.step % self.update_every == 0:
-            self.optimizer.step()
-            self.ddp_model.zero_grad()
+        self.grad_scaler.step(self.optimizer)
+        self.grad_scaler.update()
+        self._clear_grad_opt(self.optimizer)
 
     def _data_forward(self, network, x):
         x = _build_args(self._forward_func, **x)
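
The new _clear_grad_opt replaces the old self.ddp_model.zero_grad(). With set_grad_to_none=True it frees the gradient tensors rather than filling them with zeros, which saves memory and a kernel launch per parameter. A small runnable illustration:

    import torch

    model = torch.nn.Linear(4, 2)
    opt = torch.optim.SGD(model.parameters(), lr=0.1)
    model(torch.randn(8, 4)).sum().backward()

    # What _clear_grad_opt does when set_grad_to_none=True:
    for group in opt.param_groups:
        for p in group['params']:
            if p.grad is not None:
                p.grad = None  # drop the tensor; the next backward() re-allocates it

    # PyTorch >= 1.7 exposes the same behavior as opt.zero_grad(set_to_none=True).
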


tests/core/test_dist_trainer.py (+7, -3)

@@ -6,13 +6,14 @@ from argparse import ArgumentParser
 
 import numpy as np
 import torch.cuda
+import torch.distributed as dist
 
 from fastNLP import AccuracyMetric
 from fastNLP import CrossEntropyLoss, BCELoss
 from fastNLP import DataSet
 from fastNLP import Instance
 from fastNLP import SGD
-from fastNLP.core.callback import EchoCallback
+from fastNLP.core.callback import EchoCallback, GradientClipCallback
 from fastNLP.core.dist_trainer import DistTrainer, get_local_rank
 from fastNLP.models.base_model import NaiveClassifier
@@ -103,7 +104,7 @@ class TestDistTrainer(unittest.TestCase):
             model=model, train_data=data_set, optimizer=SGD(lr=0.1),
             loss=CrossEntropyLoss(pred="predict", target="y"),
             batch_size_per_gpu=8, n_epochs=3, print_every=50, save_path=self.save_path,
-            fp16='O1'
+            fp16=True
         )
         trainer.train()
         """
@@ -113,18 +114,20 @@ class TestDistTrainer(unittest.TestCase):
shutil.rmtree(self.save_path)
def run3(self):
# test callbacks, especially clip-norm
set_rng_seed(100)
data_set, model = prepare_env()
trainer = DistTrainer(
data_set, model, optimizer=None,
loss=BCELoss(pred="predict", target="y"),
n_epochs=3, print_every=50,
callbacks_all=[EchoCallback('callbacks_all')],
callbacks_all=[GradientClipCallback()],
callbacks_master=[EchoCallback('callbacks_master')]
)
trainer.train()
def run4(self):
# test metrics, save, and others
set_rng_seed(100)
data_set, model = prepare_env()
@@ -173,4 +176,5 @@ if __name__ == '__main__':
     parser.add_argument('--test', type=int)
     args, _ = parser.parse_known_args()
     if args.test and hasattr(runner, 'run%s' % args.test):
+        dist.init_process_group("nccl")
         getattr(runner, 'run%s' % args.test)()
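
Because the __main__ block now calls dist.init_process_group("nccl") directly, these tests are presumably meant to be started through PyTorch's distributed launcher, which spawns one process per GPU and sets the rendezvous environment variables the NCCL backend needs. A plausible invocation (the process count and test number are illustrative):

    python -m torch.distributed.launch --nproc_per_node=2 tests/core/test_dist_trainer.py --test 3
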
