
Merge branch 'master' of github.com:fastnlp/fastNLP

tags/v0.5.5
yh_cc 5 years ago
commit cbbfc18149
5 changed files with 103 additions and 71 deletions
  1. +5  -1   fastNLP/__init__.py
  2. +8  -3   fastNLP/core/__init__.py
  3. +8  -0   fastNLP/core/_logger.py
  4. +4  -4   fastNLP/core/callback.py
  5. +78 -63  fastNLP/core/dist_trainer.py

+5 -1  fastNLP/__init__.py

@@ -24,6 +24,9 @@ __all__ = [
"Trainer",
"Tester",

"DistTrainer",
"get_local_rank",
"Callback",
"GradientClipCallback",
@@ -75,7 +78,8 @@ __all__ = [
"cache_results",
'logger'
'logger',
"init_logger_dist",
]
__version__ = '0.5.0'



+8 -3  fastNLP/core/__init__.py

@@ -33,12 +33,16 @@ __all__ = [
"Tester",
"Trainer",

"DistTrainer",
"get_local_rank",

"cache_results",
"seq_len_to_mask",
"get_seq_len",
"logger",
"init_logger_dist",

"Callback",
"GradientClipCallback",
"EarlyStopCallback",
@@ -81,7 +85,7 @@ __all__ = [
"Sampler",
]

from ._logger import logger
from ._logger import logger, init_logger_dist
from .batch import DataSetIter, BatchIter, TorchLoaderIter
from .callback import Callback, GradientClipCallback, EarlyStopCallback, FitlogCallback, EvaluateCallback, \
LRScheduler, ControlC, LRFinder, TensorboardCallback, WarmupCallback, SaveModelCallback, CallbackException, \
@@ -100,3 +104,4 @@ from .trainer import Trainer
from .utils import cache_results, seq_len_to_mask, get_seq_len
from .vocabulary import Vocabulary
from .collate_fn import ConcatCollateFn
from .dist_trainer import DistTrainer, get_local_rank

+8 -0  fastNLP/core/_logger.py

@@ -18,6 +18,7 @@ logger.set_stdout('tqdm', level='WARN')

__all__ = [
'logger',
'init_logger_dist'
]

import logging
@@ -25,6 +26,7 @@ import logging.config
import os
import sys
import warnings
from torch import distributed as dist

ROOT_NAME = 'fastNLP'

@@ -169,3 +171,9 @@ def _get_logger(name=None, level='INFO'):


logger = _init_logger(path=None, level='INFO')


def init_logger_dist():
global logger
rank = dist.get_rank()
logger.setLevel(logging.INFO if rank == 0 else logging.WARNING)
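The new init_logger_dist helper only adjusts the log level from the calling process's rank, so it assumes torch.distributed has already been initialized. A minimal usage sketch, assuming a fastNLP build that contains this commit and a launcher (torchrun / torch.distributed.launch) that sets the usual rendezvous environment variables:

    import torch.distributed as dist
    from fastNLP import logger, init_logger_dist

    dist.init_process_group(backend='nccl')  # rendezvous info comes from the launcher's env vars
    init_logger_dist()                       # rank 0 stays at INFO, all other ranks are raised to WARNING

    logger.info("printed only on the master process")
    logger.warning("printed on every process")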

+4 -4  fastNLP/core/callback.py

@@ -114,6 +114,9 @@ class Callback(object):
self._trainer = None # reassigned inside Trainer
self._disabled = False

def __repr__(self):
return self.__class__.__name__

@property
def trainer(self):
r"""
@@ -1157,9 +1160,6 @@ class EchoCallback(Callback):
class _TesterCallback(Callback):
def __init__(self, data, model, metrics, metric_key=None, batch_size=16, num_workers=None):
super(_TesterCallback, self).__init__()
if hasattr(model, 'module'):
# for data parallel model
model = model.module
self.tester = Tester(data, model,
metrics=metrics, batch_size=batch_size,
num_workers=num_workers, verbose=0)
@@ -1183,7 +1183,7 @@ class _TesterCallback(Callback):

@staticmethod
def _get_score(metric_dict, key):
for metric in metric_dict.items():
for metric in metric_dict.values():
if key in metric:
return metric[key]
return None
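The one-line change above fixes a real lookup bug: the Tester returns a dict of dicts keyed by metric name, and iterating metric_dict.items() yields (name, dict) tuples, so the lookup checked tuple elements instead of score names. A small standalone sketch of the corrected lookup, with made-up metric names and values:

    def get_score(metric_dict, key):
        # metric_dict looks like {"AccuracyMetric": {"acc": 0.91}}
        for metric in metric_dict.values():  # iterate the inner result dicts, not (name, dict) pairs
            if key in metric:
                return metric[key]
        return None

    assert get_score({"AccuracyMetric": {"acc": 0.91}}, "acc") == 0.91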


+78 -63  fastNLP/core/dist_trainer.py

@@ -9,16 +9,18 @@ import os
import time
from datetime import datetime

import contextlib
import torch
import torch.cuda
import torch.distributed as dist
import torch.optim
from torch.serialization import default_restore_location
from pkg_resources import parse_version
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data.distributed import DistributedSampler
from tqdm import tqdm

from ._logger import logger
from ._logger import logger, init_logger_dist
from .batch import DataSetIter, BatchIter
from .callback import DistCallbackManager, CallbackException
from .callback import _TesterCallback
@@ -69,8 +71,8 @@ class DistTrainer():
num_workers=1, drop_last=False,
dev_data=None, metrics=None, metric_key=None,
update_every=1, print_every=10, validate_every=-1,
save_every=-1, save_path=None, device='auto',
fp16='', backend=None, init_method=None, use_tqdm=True):
save_path=None, device='auto',
fp16='', use_tqdm=True):
r"""

:param train_data: the training set, a :class:`~fastNLP.DataSet`.
@@ -98,20 +100,15 @@ class DistTrainer():
would run out of memory; setting batch_size=32 with update_every=4 achieves the same effect. This parameter has no effect when optimizer is None.
:param int print_every: how many backward passes between updates of the loss shown by tqdm; if use_tqdm=False, how many backward passes between printed losses.
:param int validate_every: validate on the dev set every this many steps; if -1, validate once at the end of each epoch. Only effective when dev_data is passed.
:param int save_every: save the model every this many steps; if -1, save once at the end of each epoch. Only effective when save_path is passed.
:param str,None save_path: path under which to save the model; the directory is created automatically if it does not exist. If None, the model is not saved. If dev_data is None, the model
from the last iteration is saved. Both the parameters and the model structure are saved; even when DataParallel is used, only the model itself is saved here.
:param str device: the device to use, one of gpu, cpu or auto
:param str fp16: optimization level for half-precision training, one of O1, O2 or O3; if an empty string, half precision is not used.
:param backend: the distributed backend; see the pytorch documentation for details.
:param init_method: the distributed init method; see the pytorch documentation for details.
:param bool use_tqdm: whether to use tqdm to show training progress; if False, print the loss to the terminal.
"""
assert device in ['auto', 'cuda', 'cpu'], "Please set correct device in ['auto', 'cuda', 'cpu']"
if device == 'auto':
device = 'cuda' if torch.cuda.is_available() else 'cpu'
if backend is None:
backend = 'nccl' if device == 'cuda' else 'gloo'

# init distributed
if device == 'cuda':
@@ -120,11 +117,11 @@ class DistTrainer():
else:
self.device = torch.device(device)

dist.init_process_group(backend=backend, init_method=init_method)
init_logger_dist()

self.world_size = dist.get_world_size()
self.rank = dist.get_rank() # unique id for each process

self.model = model
self.train_data = train_data
self.batch_size_per_gpu = int(batch_size_per_gpu)
self.n_epochs = int(n_epochs)
@@ -133,12 +130,9 @@ class DistTrainer():
self.update_every = int(update_every)
self.print_every = int(print_every)
self.validate_every = int(validate_every)
self.save_every = int(save_every)
self.save_path = save_path
self.losser = _prepare_losser(loss)
self.fp16 = fp16
self.init_method = init_method
self.backend = backend
self.local_rank = get_local_rank()
self._forward_func = model.forward
self.callback_manager = DistCallbackManager(
@@ -160,11 +154,12 @@ class DistTrainer():

# init DataParallel
if parse_version(torch.__version__)>=parse_version('1.1'):
self.model = DDP(model, device_ids=[self.local_rank],
self.ddp_model = DDP(model, device_ids=[self.local_rank],
output_device=self.local_rank, find_unused_parameters=True)
else:
self.model = DDP(model, device_ids=[self.local_rank],
self.ddp_model = DDP(model, device_ids=[self.local_rank],
output_device=self.local_rank)
self.model = self.ddp_model.module

self.optimizer = optimizer
self.sampler = DistributedSampler(self.train_data)
@@ -177,18 +172,16 @@ class DistTrainer():
cb = _TesterCallback(
dev_data, model, metrics,
batch_size=batch_size_per_gpu, num_workers=num_workers)
self.test_manager.add_callback([cb], master=False)
self.test_manager.add_callback([cb], master=True)

# Setup logging
dist.barrier()
self.start_time = datetime.now().strftime('%m_%d_%Y-%H_%M')
if self.save_path:
self.cp_save_path = os.path.join(self.save_path, 'checkpoints')
self.cp_save_path = self.save_path
else:
self.cp_save_path = None

# use INFO in the master, WARN for others
logger.setLevel(logging.INFO if self.is_master else logging.WARNING)
self.logger = logger
self.logger.info("Setup Distributed Trainer")
self.logger.warning("Process pid: {}, rank: {}, local rank: {}, device: {}, fp16: {}".format(
@@ -198,6 +191,22 @@ class DistTrainer():
self.logger.info("Training with fp16: {}, optimization level: {}".format(
len(self.fp16) > 0, self.fp16 if self.fp16 else None))

def _maybe_no_sync(self):
"""
Whenever *samples* contains more than one mini-batch, we
want to accumulate gradients locally and only call
all-reduce in the last backwards pass.
"""
i = self.step % self.update_every
if (
self.world_size > 1
and hasattr(self.ddp_model, "no_sync")
and i != 0
):
return self.ddp_model.no_sync()
else:
return contextlib.ExitStack() # dummy contextmanager
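The commented-out call site later in _train shows this helper is not wired in yet, so the snippet below only illustrates the underlying DDP gradient-accumulation idea: skip the gradient all-reduce on intermediate backward passes and synchronize on the step that will run the optimizer. It restates the method as a standalone function with a dummy model so it runs without any process group:

    import contextlib

    def maybe_no_sync(model, step, update_every, world_size):
        # suppress gradient all-reduce except on the step that triggers an optimizer update
        if world_size > 1 and hasattr(model, "no_sync") and step % update_every != 0:
            return model.no_sync()
        return contextlib.ExitStack()  # no-op context manager, gradients sync as usual

    class DummyModel:  # stands in for a DistributedDataParallel module
        pass

    for step in range(1, 5):
        with maybe_no_sync(DummyModel(), step, update_every=4, world_size=2):
            pass  # loss.backward() would go here; with a real DDP model, steps 1-3 would skip the all-reduce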

def _get_n_steps(self):
batch_size = self.world_size * self.batch_size_per_gpu
return (len(self.train_data) // batch_size + int(
@@ -219,9 +228,9 @@ class DistTrainer():
if isinstance(optimizer, torch.optim.Optimizer):
return optimizer
elif isinstance(optimizer, Optimizer):
return optimizer.construct_from_pytorch(self.model.parameters())
return optimizer.construct_from_pytorch(self.ddp_model.parameters())
elif optimizer is None:
return torch.optim.Adam(self.model.parameters(), lr=4e-3)
return torch.optim.Adam(self.ddp_model.parameters(), lr=4e-3)
else:
raise TypeError("optimizer can only be torch.optim.Optimizer type, not {}.".format(type(optimizer)))

@@ -252,8 +261,10 @@ class DistTrainer():
self.logger.info("###### Training epochs started ######")
self.logger.info('Total epochs: %d'% self.n_epochs)
self.logger.info('Total steps: %d'% self.n_steps)
self.logger.info('Num instances per GPU %d'% self.batch_size_per_gpu)
self.logger.info('Total batch_size: %d'% self.batch_size_per_gpu * dist.get_world_size())
self.logger.info('Num instances per GPU: %d'% self.batch_size_per_gpu)
self.logger.info('Num of steps per update: %d' % self.update_every)
self.logger.info('Total batch_size: %d'%
(self.batch_size_per_gpu * dist.get_world_size() * self.update_every))
self.logger.info('Total num of samples: %d'% len(self.train_data))
self.logger.info("Num of callbacks for all workers: {}".format(
len(self.callback_manager.callbacks_all)))
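The corrected log line also spells out the effective batch size: samples per update = batch_size_per_gpu * world_size * update_every. A quick worked example with hypothetical settings:

    batch_size_per_gpu = 32
    world_size = 4    # number of processes / GPUs
    update_every = 4  # backward passes per optimizer step

    total_batch_size = batch_size_per_gpu * world_size * update_every
    print(total_batch_size)  # 512 samples contribute to each parameter update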
@@ -290,13 +301,14 @@ class DistTrainer():
self.logger.info("###### Train finished ######")
self.logger.info('Total train time: {} seconds.'. format(results['seconds']))
if load_best_model and self.cp_save_path and len(self.test_manager.callbacks):
self.load_check_point('best')
self.load_check_point(self._best_save_name())
finally:
pass
dist.barrier()
return results

def _train(self):
dist.barrier()
if not self.use_tqdm:
from .utils import _pseudo_tqdm as inner_tqdm
else:
@@ -309,29 +321,32 @@ class DistTrainer():
pbar = self.pbar
avg_loss = 0
data_iterator = self.data_iterator
self.model.zero_grad()
self.ddp_model.zero_grad()
for epoch in range(1, self.n_epochs + 1):
self.epoch = epoch
pbar.set_description_str(desc="Epoch {}/{}".format(epoch, self.n_epochs))
# early stopping
self.callback_manager.on_epoch_begin()
for batch_x, batch_y in data_iterator:
self.model.train()
self.step += 1
self.ddp_model.train()
_move_dict_value_to_device(batch_x, batch_y, device=self.device)
indices = data_iterator.get_batch_indices()
# negative sampling; replace unknown; re-weight batch_y
self.callback_manager.on_batch_begin(batch_x, batch_y, indices)
prediction = self._data_forward(self.model, batch_x)
prediction = self._data_forward(self.ddp_model, batch_x)

# edit prediction
self.callback_manager.on_loss_begin(batch_y, prediction)
loss = self._compute_loss(prediction, batch_y)
if self.update_every > 1:
loss = loss / self.update_every
avg_loss += loss.item()

# Is loss NaN or inf? requires_grad = False
self.callback_manager.on_backward_begin(loss)

# with self._maybe_no_sync():
if self.fp16:
with amp.scale_loss(loss, self.optimizer) as scale_loss:
scale_loss.backward()
@@ -355,17 +370,10 @@ class DistTrainer():
if (self.validate_every > 0 and self.step % self.validate_every == 0):
self._do_validation()

if self.cp_save_path and \
self.save_every > 0 and \
self.step % self.save_every == 0:
self.save_check_point()

# ================= mini-batch end ==================== #
if self.validate_every < 0:
self._do_validation()

if self.save_every < 0 and self.cp_save_path:
self.save_check_point()
# lr decay; early stopping
self.callback_manager.on_epoch_end()
# =============== epochs end =================== #
@@ -379,7 +387,7 @@ class DistTrainer():
"""
if self.step % self.update_every == 0:
self.optimizer.step()
self.model.zero_grad()
self.ddp_model.zero_grad()

def _data_forward(self, network, x):
x = _build_args(self._forward_func, **x)
@@ -406,44 +414,51 @@ class DistTrainer():
def save_check_point(self, name=None, only_params=False):
r"""保存当前模型"""
# only master save models
if name is None:
name = 'checkpoint-{}.bin'.format(self.step)
os.makedirs(self.cp_save_path, exist_ok=True)
path = os.path.join(self.cp_save_path, name)
self.logger.info("Save checkpoint to {}".format(path))
model_to_save = self.ddp_model.module
if only_params:
model_to_save = model_to_save.state_dict()
if self.is_master:
if name is None:
name = 'checkpoint-{}.bin'.format(self.step)
os.makedirs(self.cp_save_path, exist_ok=True)
path = os.path.join(self.cp_save_path, name)
self.logger.info("Save checkpoint to {}".format(path))
model_to_save = self.model.module
if only_params:
model_to_save = model_to_save.state_dict()
torch.save(model_to_save, path)

def load_check_point(self, name):
path = os.path.join(self.cp_save_path, name)
self.logger.info('reload best model from %s', path)
model_load = torch.load(path, map_location='cpu')
model_load = torch.load(
path,
map_location=lambda s, l: default_restore_location(s, "cpu"))
if not isinstance(model_load, dict):
model_load = model_load.state_dict()
self.model.module.load_state_dict(model_load)
self.model.load_state_dict(model_load)

def _best_save_name(self):
return "best_" + "_".join([self.model.__class__.__name__, self.metric_key, self.start_time])

def _do_validation(self):
self.callback_manager.on_valid_begin()
# do evaluate on all nodes
eval_res = self.test_manager.on_valid_begin()
eval_res = list(filter(lambda x: x is not None, eval_res))
if len(eval_res):
eval_res, is_better = list(zip(*eval_res))
else:
eval_res, is_better = None, None
# save better model on master node
if self.is_master and is_better is not None and self.cp_save_path:
for i, better_flag in enumerate(is_better):
if better_flag:
# TODO to support multiple datasets to evaluate
self.save_check_point('best')
break
self.callback_manager.on_valid_end(
eval_res, self.metric_key, self.optimizer, is_better)
dist.barrier()
with self.ddp_model.no_sync():
# parameters are not updated during validation, so gradient sync can be disabled
self.callback_manager.on_valid_begin()
eval_res = self.test_manager.on_valid_begin()
eval_res = list(filter(lambda x: x is not None, eval_res))
if len(eval_res):
eval_res, is_better = list(zip(*eval_res))
eval_res = eval_res[0]
is_better = is_better[0]
else:
eval_res, is_better = None, None
# logger.info('{}, {}'.format(eval_res, is_better))
# save better model on master node
if is_better is not None and self.cp_save_path:
if is_better:
self.save_check_point(self._best_save_name(), only_params=False)
dist.barrier()
self.callback_manager.on_valid_end(
eval_res, self.metric_key, self.optimizer, is_better)
self.ddp_model.train()

def close(self):
r"""关闭Trainer,销毁进程"""

