
[update] bugfix in dist_trainer

tags/v0.5.5
yunfan, 5 years ago
parent commit b49c694c62
5 changed files with 101 additions and 71 deletions:

  1. fastNLP/__init__.py  (+5, -1)
  2. fastNLP/core/__init__.py  (+8, -3)
  3. fastNLP/core/_logger.py  (+8, -0)
  4. fastNLP/core/callback.py  (+4, -4)
  5. fastNLP/core/dist_trainer.py  (+76, -63)

fastNLP/__init__.py  (+5, -1)

@@ -24,6 +24,9 @@ __all__ = [
     "Trainer",
     "Tester",
 
+    "DistTrainer",
+    "get_local_rank",
+
     "Callback",
     "GradientClipCallback",
@@ -75,7 +78,8 @@ __all__ = [
     "cache_results",
-    'logger'
+    'logger',
+    "init_logger_dist",
 ]
 
 __version__ = '0.5.0'

fastNLP/core/__init__.py  (+8, -3)

@@ -33,12 +33,16 @@ __all__ = [
     "Tester",
     "Trainer",
 
+    "DistTrainer",
+    "get_local_rank",
+
     "cache_results",
     "seq_len_to_mask",
     "get_seq_len",
     "logger",
+    "init_logger_dist",
 
     "Callback",
     "GradientClipCallback",
     "EarlyStopCallback",
@@ -81,7 +85,7 @@ __all__ = [
     "Sampler",
 ]
 
-from ._logger import logger
+from ._logger import logger, init_logger_dist
 from .batch import DataSetIter, BatchIter, TorchLoaderIter
 from .callback import Callback, GradientClipCallback, EarlyStopCallback, FitlogCallback, EvaluateCallback, \
     LRScheduler, ControlC, LRFinder, TensorboardCallback, WarmupCallback, SaveModelCallback, CallbackException, \
@@ -100,3 +104,4 @@ from .trainer import Trainer
 from .utils import cache_results, seq_len_to_mask, get_seq_len
 from .vocabulary import Vocabulary
 from .collate_fn import ConcatCollateFn
+from .dist_trainer import DistTrainer, get_local_rank

fastNLP/core/_logger.py  (+8, -0)

@@ -18,6 +18,7 @@ logger.set_stdout('tqdm', level='WARN')
 
 __all__ = [
     'logger',
+    'init_logger_dist'
 ]
 
 import logging
@@ -25,6 +26,7 @@ import logging.config
 import os
 import sys
 import warnings
+from torch import distributed as dist
 
 ROOT_NAME = 'fastNLP'
 
@@ -169,3 +171,9 @@ def _get_logger(name=None, level='INFO'):
 
 logger = _init_logger(path=None, level='INFO')
+
+
+def init_logger_dist():
+    global logger
+    rank = dist.get_rank()
+    logger.setLevel(logging.INFO if rank == 0 else logging.WARNING)
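Since init_logger_dist() calls dist.get_rank(), the default process group must already be initialized when it runs. A minimal usage sketch (assuming a launch through python -m torch.distributed.launch, which provides the env:// rendezvous):

    import torch.distributed as dist
    from fastNLP import logger, init_logger_dist

    dist.init_process_group(backend='gloo')  # must happen before init_logger_dist()
    init_logger_dist()                       # INFO on rank 0, WARNING on other ranks
    logger.info("visible only on the master process")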

fastNLP/core/callback.py  (+4, -4)

@@ -114,6 +114,9 @@ class Callback(object):
         self._trainer = None  # reassigned inside Trainer
         self._disabled = False
 
+    def __repr__(self):
+        return self.__class__.__name__
+
     @property
     def trainer(self):
         r"""
@@ -1157,9 +1160,6 @@ class EchoCallback(Callback):
 class _TesterCallback(Callback):
     def __init__(self, data, model, metrics, metric_key=None, batch_size=16, num_workers=None):
         super(_TesterCallback, self).__init__()
-        if hasattr(model, 'module'):
-            # for data parallel model
-            model = model.module
         self.tester = Tester(data, model,
                              metrics=metrics, batch_size=batch_size,
                              num_workers=num_workers, verbose=0)
@@ -1183,7 +1183,7 @@
 
     @staticmethod
     def _get_score(metric_dict, key):
-        for metric in metric_dict.items():
+        for metric in metric_dict.values():
             if key in metric:
                 return metric[key]
         return None
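The _get_score change fixes a real lookup bug: iterating metric_dict.items() yields (name, result) pairs, so `key in metric` tested a tuple and never found the key; .values() yields the per-metric result dicts themselves. A small illustration with a hypothetical metric result:

    metric_dict = {'AccuracyMetric': {'acc': 0.91}}
    # old: metric iterates ('AccuracyMetric', {'acc': 0.91}) tuples -> 'acc' in tuple is False
    # new: metric iterates {'acc': 0.91} dicts -> 'acc' in dict is True
    score = next((m['acc'] for m in metric_dict.values() if 'acc' in m), None)  # 0.91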


fastNLP/core/dist_trainer.py  (+76, -63)

@@ -9,16 +9,18 @@ import os
 import time
 from datetime import datetime
 
+import contextlib
 import torch
 import torch.cuda
 import torch.distributed as dist
 import torch.optim
+from torch.serialization import default_restore_location
 from pkg_resources import parse_version
 from torch.nn.parallel import DistributedDataParallel as DDP
 from torch.utils.data.distributed import DistributedSampler
 from tqdm import tqdm
 
-from ._logger import logger
+from ._logger import logger, init_logger_dist
 from .batch import DataSetIter, BatchIter
 from .callback import DistCallbackManager, CallbackException
 from .callback import _TesterCallback
@@ -69,8 +71,8 @@ class DistTrainer():
                  num_workers=1, drop_last=False,
                  dev_data=None, metrics=None, metric_key=None,
                  update_every=1, print_every=10, validate_every=-1,
-                 save_every=-1, save_path=None, device='auto',
-                 fp16='', backend=None, init_method=None, use_tqdm=True):
+                 save_path=None, device='auto',
+                 fp16='', use_tqdm=True):
         r"""
 
         :param train_data: the training set, of type :class:`~fastNLP.DataSet`.
@@ -98,20 +100,15 @@ class DistTrainer():
             would run out of memory; setting batch_size=32 with update_every=4 achieves the same effect. Ignored when optimizer is None.
         :param int print_every: update the loss shown by tqdm every this many backward passes; if use_tqdm=False, print the loss instead.
         :param int validate_every: validate on the dev set every this many steps; if -1, validate once at the end of each epoch. Only effective when dev_data is given.
-        :param int save_every: save the model every this many steps; if -1, save once at the end of each epoch. Only effective when save_path is given.
         :param str,None save_path: where to save the model; the directory is created if it does not exist. If None, the model is not saved. If dev_data is None, the model
             from the last iteration is saved. Both the parameters and the model structure are saved; even under DataParallel, only the model itself is saved here.
         :param str device: the device to use, one of gpu, cpu or auto
         :param str fp16: optimization level for mixed-precision training, one of O1, O2 or O3; an empty string disables fp16.
-        :param backend: the distributed backend; see the pytorch documentation for details
-        :param init_method: the distributed init method; see the pytorch documentation for details
         :param bool use_tqdm: whether to show training progress with tqdm; if False, print the loss to the terminal.
         """
         assert device in ['auto', 'cuda', 'cpu'], "Please set correct device in ['auto', 'cuda', 'cpu']"
         if device == 'auto':
             device = 'cuda' if torch.cuda.is_available() else 'cpu'
-        if backend is None:
-            backend = 'nccl' if device == 'cuda' else 'gloo'
 
         # init distributed
         if device == 'cuda':
@@ -120,11 +117,9 @@ class DistTrainer():
         else:
             self.device = torch.device(device)
 
-        dist.init_process_group(backend=backend, init_method=init_method)
         self.world_size = dist.get_world_size()
         self.rank = dist.get_rank()  # unique id for each process
 
-        self.model = model
         self.train_data = train_data
         self.batch_size_per_gpu = int(batch_size_per_gpu)
         self.n_epochs = int(n_epochs)
@@ -133,12 +128,9 @@ class DistTrainer():
         self.update_every = int(update_every)
         self.print_every = int(print_every)
         self.validate_every = int(validate_every)
-        self.save_every = int(save_every)
         self.save_path = save_path
         self.losser = _prepare_losser(loss)
         self.fp16 = fp16
-        self.init_method = init_method
-        self.backend = backend
         self.local_rank = get_local_rank()
         self._forward_func = model.forward
         self.callback_manager = DistCallbackManager(
@@ -160,11 +152,12 @@ class DistTrainer():
 
         # init DataParallel
         if parse_version(torch.__version__)>=parse_version('1.1'):
-            self.model = DDP(model, device_ids=[self.local_rank],
+            self.ddp_model = DDP(model, device_ids=[self.local_rank],
                              output_device=self.local_rank, find_unused_parameters=True)
         else:
-            self.model = DDP(model, device_ids=[self.local_rank],
+            self.ddp_model = DDP(model, device_ids=[self.local_rank],
                              output_device=self.local_rank)
+        self.model = self.ddp_model.module
 
         self.optimizer = optimizer
         self.sampler = DistributedSampler(self.train_data)
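After this hunk the trainer keeps two handles on the network: self.ddp_model is the DistributedDataParallel wrapper used for forward, backward and optimization, while self.model is rebound to the bare underlying module, which is what checkpointing and user code should touch. As a one-line invariant (assuming a constructed trainer):

    assert trainer.model is trainer.ddp_model.module  # bare module vs. DDP wrapper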
@@ -177,18 +170,17 @@ class DistTrainer():
             cb = _TesterCallback(
                 dev_data, model, metrics,
                 batch_size=batch_size_per_gpu, num_workers=num_workers)
-            self.test_manager.add_callback([cb], master=False)
+            self.test_manager.add_callback([cb], master=True)
 
         # Setup logging
         dist.barrier()
         self.start_time = datetime.now().strftime('%m_%d_%Y-%H_%M')
         if self.save_path:
-            self.cp_save_path = os.path.join(self.save_path, 'checkpoints')
+            self.cp_save_path = self.save_path
         else:
             self.cp_save_path = None
 
         # use INFO in the master, WARN for others
-        logger.setLevel(logging.INFO if self.is_master else logging.WARNING)
+        init_logger_dist()
         self.logger = logger
         self.logger.info("Setup Distributed Trainer")
         self.logger.warning("Process pid: {}, rank: {}, local rank: {}, device: {}, fp16: {}".format(
@@ -198,6 +190,22 @@ class DistTrainer():
         self.logger.info("Training with fp16: {}, optimization level: {}".format(
             len(self.fp16) > 0, self.fp16 if self.fp16 else None))
 
+    def _maybe_no_sync(self):
+        """
+        Whenever *samples* contains more than one mini-batch, we
+        want to accumulate gradients locally and only call
+        all-reduce in the last backwards pass.
+        """
+        i = self.step % self.update_every
+        if (
+                self.world_size > 1
+                and hasattr(self.ddp_model, "no_sync")
+                and i != 0
+        ):
+            return self.ddp_model.no_sync()
+        else:
+            return contextlib.ExitStack()  # dummy contextmanager
+
     def _get_n_steps(self):
         batch_size = self.world_size * self.batch_size_per_gpu
         return (len(self.train_data) // batch_size + int(
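_maybe_no_sync() is the standard DDP gradient-accumulation idiom: suppress the gradient all-reduce on intermediate backward passes and synchronize only on the last pass of each accumulation window. (Its call site further down is still commented out in this commit.) A standalone sketch of the pattern, with hypothetical ddp_model, criterion, loader and optimizer names:

    import contextlib

    update_every = 4  # accumulate gradients over 4 mini-batches
    for step, (x, y) in enumerate(loader, start=1):
        # no_sync() skips the all-reduce; ExitStack() is a no-op stand-in
        ctx = ddp_model.no_sync() if step % update_every != 0 else contextlib.ExitStack()
        with ctx:
            loss = criterion(ddp_model(x), y) / update_every
            loss.backward()
        if step % update_every == 0:
            optimizer.step()
            ddp_model.zero_grad()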
@@ -219,9 +227,9 @@ class DistTrainer():
         if isinstance(optimizer, torch.optim.Optimizer):
             return optimizer
         elif isinstance(optimizer, Optimizer):
-            return optimizer.construct_from_pytorch(self.model.parameters())
+            return optimizer.construct_from_pytorch(self.ddp_model.parameters())
         elif optimizer is None:
-            return torch.optim.Adam(self.model.parameters(), lr=4e-3)
+            return torch.optim.Adam(self.ddp_model.parameters(), lr=4e-3)
         else:
             raise TypeError("optimizer can only be torch.optim.Optimizer type, not {}.".format(type(optimizer)))


@@ -252,8 +260,10 @@ class DistTrainer():
         self.logger.info("###### Training epochs started ######")
         self.logger.info('Total epochs: %d'% self.n_epochs)
         self.logger.info('Total steps: %d'% self.n_steps)
-        self.logger.info('Num instances per GPU %d'% self.batch_size_per_gpu)
-        self.logger.info('Total batch_size: %d'% self.batch_size_per_gpu * dist.get_world_size())
+        self.logger.info('Num instances per GPU: %d'% self.batch_size_per_gpu)
+        self.logger.info('Num of steps per update: %d' % self.update_every)
+        self.logger.info('Total batch_size: %d'%
+                         (self.batch_size_per_gpu * dist.get_world_size() * self.update_every))
         self.logger.info('Total num of samples: %d'% len(self.train_data))
         self.logger.info("Num of callbacks for all workers: {}".format(
             len(self.callback_manager.callbacks_all)))
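The old 'Total batch_size' line also had an operator-precedence bug: % (string formatting) binds tighter than *, so the string was formatted first and the result was then repeated world-size times. Parenthesizing the product fixes it:

    >>> 'Total batch_size: %d' % 8 * 2    # old behaviour: string repetition
    'Total batch_size: 8Total batch_size: 8'
    >>> 'Total batch_size: %d' % (8 * 2)  # fixed: formats the product
    'Total batch_size: 16'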
@@ -290,7 +300,7 @@ class DistTrainer():
             self.logger.info("###### Train finished ######")
             self.logger.info('Total train time: {} seconds.'.format(results['seconds']))
             if load_best_model and self.cp_save_path and len(self.test_manager.callbacks):
-                self.load_check_point('best')
+                self.load_check_point(self._best_save_name())
         finally:
             pass
         dist.barrier()
@@ -309,29 +319,32 @@ class DistTrainer():
         pbar = self.pbar
         avg_loss = 0
         data_iterator = self.data_iterator
-        self.model.zero_grad()
+        self.ddp_model.zero_grad()
         for epoch in range(1, self.n_epochs + 1):
             self.epoch = epoch
             pbar.set_description_str(desc="Epoch {}/{}".format(epoch, self.n_epochs))
             # early stopping
             self.callback_manager.on_epoch_begin()
             for batch_x, batch_y in data_iterator:
-                self.model.train()
                 self.step += 1
+                self.ddp_model.train()
                 _move_dict_value_to_device(batch_x, batch_y, device=self.device)
                 indices = data_iterator.get_batch_indices()
                 # negative sampling; replace unknown; re-weight batch_y
                 self.callback_manager.on_batch_begin(batch_x, batch_y, indices)
-                prediction = self._data_forward(self.model, batch_x)
+                prediction = self._data_forward(self.ddp_model, batch_x)
 
                 # edit prediction
                 self.callback_manager.on_loss_begin(batch_y, prediction)
                 loss = self._compute_loss(prediction, batch_y)
+                if self.update_every > 1:
+                    loss = loss / self.update_every
                 avg_loss += loss.item()
 
                 # Is loss NaN or inf? requires_grad = False
                 self.callback_manager.on_backward_begin(loss)
 
+                # with self._maybe_no_sync():
                 if self.fp16:
                     with amp.scale_loss(loss, self.optimizer) as scale_loss:
                         scale_loss.backward()
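Scaling the loss by 1/update_every before backward() turns the gradients accumulated over an update window into an average instead of a sum, so the effective step size does not grow with the accumulation factor:

    # gradients add up across backward() calls within a window;
    # dividing each loss by update_every yields the gradient of the window's mean loss
    loss = loss / update_every
    loss.backward()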
@@ -355,17 +368,10 @@ class DistTrainer():
                 if (self.validate_every > 0 and self.step % self.validate_every == 0):
                     self._do_validation()
 
-                if self.cp_save_path and \
-                        self.save_every > 0 and \
-                        self.step % self.save_every == 0:
-                    self.save_check_point()
-
                 # ================= mini-batch end ==================== #
             if self.validate_every < 0:
                 self._do_validation()
 
-            if self.save_every < 0 and self.cp_save_path:
-                self.save_check_point()
             # lr decay; early stopping
             self.callback_manager.on_epoch_end()
         # =============== epochs end =================== #
@@ -379,7 +385,7 @@ class DistTrainer():
         """
         if self.step % self.update_every == 0:
             self.optimizer.step()
-            self.model.zero_grad()
+            self.ddp_model.zero_grad()
 
     def _data_forward(self, network, x):
         x = _build_args(self._forward_func, **x)
@@ -406,44 +412,51 @@ class DistTrainer():
     def save_check_point(self, name=None, only_params=False):
         r"""Save the current model."""
         # only master saves models
+        if name is None:
+            name = 'checkpoint-{}.bin'.format(self.step)
+        os.makedirs(self.cp_save_path, exist_ok=True)
+        path = os.path.join(self.cp_save_path, name)
+        self.logger.info("Save checkpoint to {}".format(path))
+        model_to_save = self.ddp_model.module
+        if only_params:
+            model_to_save = model_to_save.state_dict()
         if self.is_master:
-            if name is None:
-                name = 'checkpoint-{}.bin'.format(self.step)
-            os.makedirs(self.cp_save_path, exist_ok=True)
-            path = os.path.join(self.cp_save_path, name)
-            self.logger.info("Save checkpoint to {}".format(path))
-            model_to_save = self.model.module
-            if only_params:
-                model_to_save = model_to_save.state_dict()
             torch.save(model_to_save, path)
 
     def load_check_point(self, name):
         path = os.path.join(self.cp_save_path, name)
         self.logger.info('reload best model from %s', path)
-        model_load = torch.load(path, map_location='cpu')
+        model_load = torch.load(
+            path,
+            map_location=lambda s, l: default_restore_location(s, "cpu"))
         if not isinstance(model_load, dict):
             model_load = model_load.state_dict()
-        self.model.module.load_state_dict(model_load)
+        self.model.load_state_dict(model_load)
+
+    def _best_save_name(self):
+        return "best_" + "_".join([self.model.__class__.__name__, self.metric_key, self.start_time])
 
     def _do_validation(self):
-        self.callback_manager.on_valid_begin()
-        # do evaluate on all nodes
-        eval_res = self.test_manager.on_valid_begin()
-        eval_res = list(filter(lambda x: x is not None, eval_res))
-        if len(eval_res):
-            eval_res, is_better = list(zip(*eval_res))
-        else:
-            eval_res, is_better = None, None
-        # save better model on master node
-        if self.is_master and is_better is not None and self.cp_save_path:
-            for i, better_flag in enumerate(is_better):
-                if better_flag:
-                    # TODO to support multiple datasets to evaluate
-                    self.save_check_point('best')
-                    break
-        self.callback_manager.on_valid_end(
-            eval_res, self.metric_key, self.optimizer, is_better)
-        dist.barrier()
+        with self.ddp_model.no_sync():
+            # gradient sync can be disabled here, since model parameters are not updated
+            self.callback_manager.on_valid_begin()
+            eval_res = self.test_manager.on_valid_begin()
+            eval_res = list(filter(lambda x: x is not None, eval_res))
+            if len(eval_res):
+                eval_res, is_better = list(zip(*eval_res))
+                eval_res = eval_res[0]
+                is_better = is_better[0]
+            else:
+                eval_res, is_better = None, None
+            # logger.info('{}, {}'.format(eval_res, is_better))
+            # save better model on master node
+            if is_better is not None and self.cp_save_path:
+                if is_better:
+                    self.save_check_point(self._best_save_name(), only_params=False)
+            dist.barrier()
+        self.callback_manager.on_valid_end(
+            eval_res, self.metric_key, self.optimizer, is_better)
+        self.ddp_model.train()
 
     def close(self):
         r"""Shut down the Trainer and destroy the processes."""

