
dist_trainer for fp16

tags/v1.0.0alpha
yunfan 3 years ago
commit f711d3070a
2 changed files with 47 additions and 39 deletions:
  1. fastNLP/core/dist_trainer.py (+40, -36)
  2. tests/core/test_dist_trainer.py (+7, -3)

fastNLP/core/dist_trainer.py (+40, -36)

@@ -29,14 +29,10 @@ from .dataset import DataSet
 from .losses import _prepare_losser
 from .optimizer import Optimizer
 from .utils import _build_args
+from .utils import _build_fp16_env
 from .utils import _get_func_signature
 from .utils import _move_dict_value_to_device
 
-try:
-    from apex import amp
-except:
-    amp = None
-
 __all__ = [
     'get_local_rank',
     'DistTrainer',
@@ -72,7 +68,7 @@ class DistTrainer():
                  dev_data=None, metrics=None, metric_key=None,
                  update_every=1, print_every=10, validate_every=-1,
                  save_path=None, device='auto',
-                 fp16='', use_tqdm=True, **kwargs):
+                 fp16=False, use_tqdm=True, **kwargs):
         r"""
 
         :param train_data: training data, of type :class:`~fastNLP.DataSet`.
@@ -103,12 +99,15 @@ class DistTrainer():
         :param str,None save_path: path for saving the model; if the directory does not exist it is created automatically. If None, the model is not saved. If dev_data is None,
             the model from the last iteration is saved. Both the parameters and the model structure are saved; even when DataParallel is used, only the model itself is saved here.
         :param str device: device to use, one of gpu, cpu or auto
-        :param str fp16: Apex AMP optimization level for half-precision training, one of O1, O2 or O3; an empty string disables half precision
+        :param bool fp16: whether to train with half precision (fp16)
         :param bool use_tqdm: whether to use tqdm to display training progress; if False, the loss is printed to the terminal.
         :param kwargs: optional configuration parameters
                 bool test_use_tqdm: whether to enable tqdm while validating on the dev set
                 Sampler test_sampler: sampler to use during evaluation
                 int dev_batch_size: batch size to use during evaluation
+                bool test_use_fp16: whether to use fp16 at test time
+                bool set_grad_to_none: on zero_grad, set gradients to None instead of 0
+                GradScaler gradscaler: a user-supplied gradient scaler
         """
         assert device in ['auto', 'cuda', 'cpu'], "Please set correct device in [auto', 'cuda', 'cpu']"
         if device == 'auto':
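
For reference, a minimal usage sketch of the new interface. The model, data, and hyperparameter values below are illustrative and not part of this commit; only the fp16 flag and the set_grad_to_none / gradscaler kwargs come from the docstring above.

from torch.cuda.amp import GradScaler

from fastNLP import CrossEntropyLoss, SGD
from fastNLP.core.dist_trainer import DistTrainer

def build_fp16_trainer(model, train_set):
    # fp16 is now a plain bool; the Apex opt_level strings ('O1', 'O2', ...) are gone
    return DistTrainer(
        train_data=train_set, model=model, optimizer=SGD(lr=0.1),
        loss=CrossEntropyLoss(pred="predict", target="y"),
        batch_size_per_gpu=8, n_epochs=3,
        fp16=True,                                 # enable native torch.cuda.amp
        set_grad_to_none=True,                     # new optional kwarg
        gradscaler=GradScaler(init_scale=2.**14),  # new optional kwarg: custom scaler
    )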
@@ -147,14 +146,19 @@ class DistTrainer():
         self.use_tqdm = use_tqdm
 
         model.to(self.device)
-        optimizer = self._get_optimizer(optimizer)
 
         # init fp16, must before DataParallel init
-        if len(self.fp16):
-            assert isinstance(self.fp16, str), "Please set Apex AMP optimization level selected in ['O0', 'O1', 'O2', 'O3']"
-            _check_fp16()
-            assert device == 'cuda', "Amp requires cuda device"
-            model, optimizer = amp.initialize(model, optimizer, opt_level=self.fp16)
+        autocast, GradScaler = _build_fp16_env(dummy=not self.fp16)
+        self.auto_cast = autocast
+        user_grad_scaler = getattr(kwargs, 'gradscaler', None)
+        if user_grad_scaler is not None:
+            assert self.fp16, "must set fp16=True to enable gradscaler"
+            grad_scaler = user_grad_scaler
+        else:
+            grad_scaler = GradScaler()
+        self.grad_scaler = grad_scaler
+
+        self.set_grad_to_none = getattr(kwargs, 'set_grad_to_none', True)
 
         # init DataParallel
         if parse_version(torch.__version__)>=parse_version('1.1'):
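
The new _build_fp16_env utility is imported above but its body is not part of this diff. A plausible sketch of what it returns, assuming it mirrors the torch.cuda.amp API and degrades to no-op stand-ins when dummy=True (this is an illustration, not the fastNLP implementation):

from contextlib import contextmanager

import torch

def _build_fp16_env(dummy=False):
    # Sketch: return (autocast, GradScaler); no-op versions when dummy=True.
    if dummy:
        @contextmanager
        def autocast():
            yield  # plain fp32 context, no mixed precision

        class GradScaler:
            # identity versions of the three methods DistTrainer calls
            def scale(self, loss):
                return loss
            def step(self, optimizer):
                optimizer.step()
            def update(self):
                pass
    else:
        if not torch.cuda.is_available():
            raise RuntimeError("fp16 training requires a CUDA device")
        from torch.cuda.amp import autocast, GradScaler
    return autocast, GradScaler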
@@ -165,6 +169,7 @@ class DistTrainer():
                                              output_device=self.local_rank)
             self.model = self.ddp_model.module
 
+        optimizer = self._get_optimizer(optimizer)
         self.optimizer = optimizer
         if isinstance(self.train_data, DataSet):
             self.sampler = DistributedSampler(self.train_data)
@@ -197,11 +202,9 @@ class DistTrainer():
         self.logger = logger
         self.logger.info("Setup Distributed Trainer")
         self.logger.warning("Process pid: {}, rank: {}, local rank: {}, device: {}, fp16: {}".format(
-            os.getpid(), self.rank, self.local_rank, self.device, self.fp16 if self.fp16 else False))
+            os.getpid(), self.rank, self.local_rank, self.device, self.fp16))
         self.logger.info("Num of processes: {}".format(self.world_size))
         self.logger.info("Use device: {}".format(device))
-        self.logger.info("Training with fp16: {}, optimization level: {}".format(
-            len(self.fp16) > 0, self.fp16 if self.fp16 else None))
 
     def _maybe_no_sync(self):
         """
@@ -343,28 +346,20 @@ class DistTrainer():
                 indices = data_iterator.get_batch_indices()
                 # negative sampling; replace unknown; re-weight batch_y
                 self.callback_manager.on_batch_begin(batch_x, batch_y, indices)
-                prediction = self._data_forward(self.ddp_model, batch_x)
+                with self.auto_cast():
+                    prediction = self._data_forward(self.ddp_model, batch_x)
+                    # edit prediction
+                    self.callback_manager.on_loss_begin(batch_y, prediction)
+                    loss = self._compute_loss(prediction, batch_y)
 
-                # edit prediction
-                self.callback_manager.on_loss_begin(batch_y, prediction)
-                loss = self._compute_loss(prediction, batch_y)
-                if self.update_every > 1:
-                    loss = loss / self.update_every
-                avg_loss += loss.item()
+                avg_loss += loss.detach()
 
                 # Is loss NaN or inf? requires_grad = False
                 self.callback_manager.on_backward_begin(loss)
-
-                # with self._maybe_no_sync():
-                if self.fp16:
-                    with amp.scale_loss(loss, self.optimizer) as scale_loss:
-                        scale_loss.backward()
-                else:
-                    loss.backward()
-
+                self.grad_scaler.scale(loss).backward()
                 self.callback_manager.on_backward_end()
-                self._update()
+                if self.step % self.update_every == 0:
+                    self._update()
                 self.callback_manager.on_step_end()
 
                 if self.step % self.print_every == 0:
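
Stripped of fastNLP's callback machinery, the loop above follows the standard native-AMP recipe. A self-contained sketch with generic model/loss placeholders:

from torch.cuda.amp import GradScaler, autocast

def amp_train_step(model, batch_x, batch_y, loss_fn, optimizer,
                   scaler: GradScaler, step: int, update_every: int = 1):
    # forward pass and loss under autocast so eligible ops run in half precision
    with autocast():
        prediction = model(batch_x)
        loss = loss_fn(prediction, batch_y)
    # scale the loss before backward to avoid fp16 gradient underflow
    scaler.scale(loss).backward()
    if step % update_every == 0:
        scaler.step(optimizer)   # unscales gradients; skips the step on inf/NaN
        scaler.update()          # adapt the scale factor for the next iteration
        optimizer.zero_grad(set_to_none=True)
    return loss.detach()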
@@ -390,13 +385,22 @@ class DistTrainer():
             self.pbar = None
         # ============ tqdm end ============== #
 
+    def _clear_grad_opt(self, optimizer):
+        if self.set_grad_to_none:
+            for group in optimizer.param_groups:
+                for p in group['params']:
+                    if p.grad is not None:
+                        p.grad = None
+        else:
+            optimizer.zero_grad()
+
     def _update(self):
         r"""Perform weight update on a model.
 
         """
-        if self.step % self.update_every == 0:
-            self.optimizer.step()
-            self.ddp_model.zero_grad()
+        self.grad_scaler.step(self.optimizer)
+        self.grad_scaler.update()
+        self._clear_grad_opt(self.optimizer)
 
     def _data_forward(self, network, x):
         x = _build_args(self._forward_func, **x)
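
The new _clear_grad_opt helper does, for the set_grad_to_none=True path, what optimizer.zero_grad(set_to_none=True) does on PyTorch 1.7 and later: leaving gradients as None skips the memset and lets the next backward allocate fresh .grad tensors. The equivalence is an assumption based on the documented zero_grad semantics:

# PyTorch >= 1.7: roughly equivalent to _clear_grad_opt(optimizer) with set_grad_to_none=True
optimizer.zero_grad(set_to_none=True)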


tests/core/test_dist_trainer.py (+7, -3)

@@ -6,13 +6,14 @@ from argparse import ArgumentParser
 
 import numpy as np
 import torch.cuda
+import torch.distributed as dist
 
 from fastNLP import AccuracyMetric
 from fastNLP import CrossEntropyLoss, BCELoss
 from fastNLP import DataSet
 from fastNLP import Instance
 from fastNLP import SGD
-from fastNLP.core.callback import EchoCallback
+from fastNLP.core.callback import EchoCallback, GradientClipCallback
 from fastNLP.core.dist_trainer import DistTrainer, get_local_rank
 from fastNLP.models.base_model import NaiveClassifier


@@ -103,7 +104,7 @@ class TestDistTrainer(unittest.TestCase):
             model=model, train_data=data_set, optimizer=SGD(lr=0.1),
             loss=CrossEntropyLoss(pred="predict", target="y"),
             batch_size_per_gpu=8, n_epochs=3, print_every=50, save_path=self.save_path,
-            fp16='O1'
+            fp16=True
         )
         trainer.train()
         """
@@ -113,18 +114,20 @@ class TestDistTrainer(unittest.TestCase):
             shutil.rmtree(self.save_path)
 
     def run3(self):
+        # test callbacks, especially clip-norm
         set_rng_seed(100)
         data_set, model = prepare_env()
         trainer = DistTrainer(
             data_set, model, optimizer=None,
             loss=BCELoss(pred="predict", target="y"),
             n_epochs=3, print_every=50,
-            callbacks_all=[EchoCallback('callbacks_all')],
+            callbacks_all=[GradientClipCallback()],
             callbacks_master=[EchoCallback('callbacks_master')]
         )
         trainer.train()
 
     def run4(self):
+        # test metrics, save, and others
         set_rng_seed(100)
         data_set, model = prepare_env()
@@ -173,4 +176,5 @@ if __name__ == '__main__':
     parser.add_argument('--test', type=int)
     args, _ = parser.parse_known_args()
     if args.test and hasattr(runner, 'run%s' % args.test):
+        dist.init_process_group("nccl")
         getattr(runner, 'run%s' % args.test)()
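
Because each test process now calls dist.init_process_group("nccl") itself, the run* cases are meant to be started through the distributed launcher, one process per GPU. A hedged sketch of such a launch helper; the GPU count and script path are illustrative:

import subprocess
import sys

def launch_dist_test(test_id, n_gpu=2, script="tests/core/test_dist_trainer.py"):
    # torch.distributed.launch spawns n_gpu processes; each one parses --test
    # and calls dist.init_process_group("nccl") before running the selected case.
    cmd = [sys.executable, "-m", "torch.distributed.launch",
           "--nproc_per_node", str(n_gpu), script, "--test", str(test_id)]
    subprocess.check_call(cmd)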
