Browse Source

[add] docstring in batch, dist_trainer; [update] dist_trainer, callback

tags/v0.4.10
yunfan 5 years ago
parent
commit
02cfc9f421
5 changed files with 170 additions and 51 deletions
  1. +18
    -3
      fastNLP/core/batch.py
  2. +25
    -24
      fastNLP/core/callback.py
  3. +115
    -24
      fastNLP/core/dist_trainer.py
  4. +1
    -0
      fastNLP/core/trainer.py
  5. +11
    -0
      fastNLP/core/utils.py

+ 18
- 3
fastNLP/core/batch.py View File

@@ -193,13 +193,14 @@ class DataSetIter(BatchIter):
Default: ``None`` Default: ``None``
:param bool as_numpy: 若为 ``True`` , 输出batch为 numpy.array. 否则为 :class:`torch.Tensor`. :param bool as_numpy: 若为 ``True`` , 输出batch为 numpy.array. 否则为 :class:`torch.Tensor`.
Default: ``False`` Default: ``False``
:param int num_workers: 使用多少个进程来预处理数据 :param int num_workers: 使用多少个进程来预处理数据
:param bool pin_memory: 是否将产生的tensor使用pin memory, 可能会加快速度。 :param bool pin_memory: 是否将产生的tensor使用pin memory, 可能会加快速度。
:param bool drop_last: 如果最后一个batch没有batch_size这么多sample,就扔掉最后一个 :param bool drop_last: 如果最后一个batch没有batch_size这么多sample,就扔掉最后一个
:param timeout:
:param timeout: 生成一个batch的timeout值
:param worker_init_fn: 在每个worker启动时调用该函数,会传入一个值,该值是worker的index。 :param worker_init_fn: 在每个worker启动时调用该函数,会传入一个值,该值是worker的index。
:param collate_fn: 用于将样本组合成batch的函数
""" """
assert isinstance(dataset, DataSet) assert isinstance(dataset, DataSet)
dataset = DataSetGetter(dataset, as_numpy) dataset = DataSetGetter(dataset, as_numpy)
@@ -220,12 +221,26 @@ class DataSetIter(BatchIter):


class TorchLoaderIter(BatchIter): class TorchLoaderIter(BatchIter):
""" """
与DataSetIter类似,但用于pytorch的DataSet对象。通过使用TorchLoaderIter封装pytorch的DataSet,然后将其传入到Trainer中。
与DataSetIter类似,但用于pytorch的DataSet对象。
通过使用TorchLoaderIter封装pytorch的DataSet,然后将其传入到Trainer中。


""" """
def __init__(self, dataset, batch_size=1, sampler=None, def __init__(self, dataset, batch_size=1, sampler=None,
num_workers=0, pin_memory=False, drop_last=False, num_workers=0, pin_memory=False, drop_last=False,
timeout=0, worker_init_fn=None, collate_fn=None): timeout=0, worker_init_fn=None, collate_fn=None):
"""

:param dataset: :class:`~fastNLP.DataSet` 对象, 数据集
:param int batch_size: 取出的batch大小
:param sampler: 规定使用的 :class:`~fastNLP.Sampler` 方式. 若为 ``None`` , 使用 :class:`~fastNLP.SequentialSampler`.

Default: ``None``
:param int num_workers: 使用多少个进程来预处理数据
:param bool pin_memory: 是否将产生的tensor使用pin memory, 可能会加快速度。
:param bool drop_last: 如果最后一个batch没有batch_size这么多sample,就扔掉最后一个
:param timeout: 生成一个batch的timeout值
:param worker_init_fn: 在每个worker启动时调用该函数,会传入一个值,该值是worker的index。
:param collate_fn: 用于将样本组合成batch的函数"""
assert len(dataset) > 0 assert len(dataset) > 0
ins = dataset[0] ins = dataset[0]
assert len(ins) == 2 and \ assert len(ins) == 2 and \


+ 25
- 24
fastNLP/core/callback.py View File

@@ -87,12 +87,18 @@ except:
from .dataset import DataSet from .dataset import DataSet
from .tester import Tester from .tester import Tester
from ._logger import logger from ._logger import logger
from .utils import _check_fp16


try: try:
import fitlog import fitlog
except: except:
pass pass


try:
from apex import amp
except:
amp = None



class Callback(object): class Callback(object):
""" """
@@ -269,14 +275,6 @@ class Callback(object):
:return: :return:
""" """
pass pass

def on_validation(self):
"""
如果Trainer中设置了验证,则会在每次需要验证时调用该函数

:return:
"""
pass
def on_epoch_end(self): def on_epoch_end(self):
""" """
@@ -470,7 +468,7 @@ class GradientClipCallback(Callback):
if self.step%self.update_every==0: if self.step%self.update_every==0:
if self.parameters is None: if self.parameters is None:
if getattr(self.trainer, 'fp16', ''): if getattr(self.trainer, 'fp16', ''):
from apex import amp
_check_fp16()
self.clip_fun(amp.master_params(self.optimizer), self.clip_value) self.clip_fun(amp.master_params(self.optimizer), self.clip_value)
self.clip_fun(self.model.parameters(), self.clip_value) self.clip_fun(self.model.parameters(), self.clip_value)
else: else:
@@ -1036,27 +1034,23 @@ class EchoCallback(Callback):
return super(EchoCallback, self).__getattribute__(item) return super(EchoCallback, self).__getattribute__(item)




class TesterCallback(Callback):
class _TesterCallback(Callback):
def __init__(self, data, model, metrics, metric_key=None, batch_size=16, num_workers=None): def __init__(self, data, model, metrics, metric_key=None, batch_size=16, num_workers=None):
super(TesterCallback, self).__init__()
super(_TesterCallback, self).__init__()
if hasattr(model, 'module'): if hasattr(model, 'module'):
# for data parallel model # for data parallel model
model = model.module model = model.module
self.tester = Tester(data, model, self.tester = Tester(data, model,
metrics=metrics, batch_size=batch_size, metrics=metrics, batch_size=batch_size,
num_workers=num_workers, verbose=0) num_workers=num_workers, verbose=0)
# parse metric_key
# increase_better is True. It means the exp result gets better if the indicator increases.
# It is true by default.
self.increase_better = True
if metric_key is not None: if metric_key is not None:
self.increase_better = False if metric_key[0] == "-" else True
self.metric_key = metric_key[1:] if metric_key[0] == "+" or metric_key[0] == "-" else metric_key
self.metric_key, self.increase_better = self._parse_metric_key(metric_key)
else: else:
self.metric_key = None self.metric_key = None
self.increase_better = True
self.score = None self.score = None


def on_validation(self):
def on_valid_begin(self):
cur_score = self.tester.test() cur_score = self.tester.test()
eval_str = "Evaluation at Epoch {}/{}. Step:{}/{}. - {}".format( eval_str = "Evaluation at Epoch {}/{}. Step:{}/{}. - {}".format(
self.epoch, self.n_epochs, self.step, self.n_steps, self.epoch, self.n_epochs, self.step, self.n_steps,
@@ -1067,17 +1061,28 @@ class TesterCallback(Callback):
self.score = cur_score self.score = cur_score
return cur_score, is_better return cur_score, is_better


def _get_score(self, metric_dict, key):
@staticmethod
def _get_score(metric_dict, key):
for metric in metric_dict.items(): for metric in metric_dict.items():
if key in metric: if key in metric:
return metric[key] return metric[key]
return None return None


@staticmethod
def _parse_metric_key(metric_key):
# parse metric_key
# increase_better is True. It means the exp result gets better if the indicator increases.
# It is true by default.
increase_better = False if metric_key[0] == "-" else True
metric_key = metric_key[1:] if metric_key[0] == "+" or metric_key[0] == "-" else metric_key
return metric_key, increase_better

def compare_better(self, a): def compare_better(self, a):
if self.score is None: if self.score is None:
return True return True
if self.metric_key is None: if self.metric_key is None:
self.metric_key = list(list(self.score.values())[0].keys())[0]
metric_key = list(list(self.score.values())[0].keys())[0]
self.metric_key, self.increase_better = self._parse_metric_key(metric_key)
k = self.metric_key k = self.metric_key
score = self._get_score(self.score, k) score = self._get_score(self.score, k)
new_score = self._get_score(a, k) new_score = self._get_score(a, k)
@@ -1087,7 +1092,3 @@ class TesterCallback(Callback):
return score <= new_score return score <= new_score
else: else:
return score >= new_score return score >= new_score

def on_train_end(self):
self.logger.info('Evaluate on training ends.')
self.on_validation()

+ 115
- 24
fastNLP/core/dist_trainer.py View File

@@ -17,21 +17,30 @@ from tqdm import tqdm


from ._logger import logger from ._logger import logger
from .batch import DataSetIter, BatchIter from .batch import DataSetIter, BatchIter
from .callback import DistCallbackManager, CallbackException, TesterCallback
from .callback import DistCallbackManager, CallbackException, _TesterCallback
from .dataset import DataSet from .dataset import DataSet
from .losses import _prepare_losser from .losses import _prepare_losser
from .optimizer import Optimizer from .optimizer import Optimizer
from .utils import _build_args from .utils import _build_args
from .utils import _get_func_signature from .utils import _get_func_signature
from .utils import _move_dict_value_to_device from .utils import _move_dict_value_to_device
from .utils import _check_fp16


try:
from apex import amp
except:
amp = None


__all__ = [ __all__ = [
'get_local_rank', 'get_local_rank',
'DistTrainer', 'DistTrainer',
] ]



def get_local_rank(): def get_local_rank():
"""
返回当前进程的 local rank, 0 到 N-1 ,N为当前分布式总进程数
"""
if 'LOCAL_RANK' in os.environ: if 'LOCAL_RANK' in os.environ:
return int(os.environ['LOCAL_RANK']) return int(os.environ['LOCAL_RANK'])
from argparse import ArgumentParser from argparse import ArgumentParser
@@ -46,7 +55,10 @@ def get_local_rank():


class DistTrainer(): class DistTrainer():
""" """
Distributed Trainer that support distributed and mixed precision training
分布式的 Trainer,支持分布式训练和混合精度的训练。具体实现原理请阅读 pytorch 官方文档。

Note: 使用分布式 Trainer 时会同时有多个进程执行训练代码。因此将单进程的训练代码改为多进程之前,
请仔细检查,确保训练代码中的同步和互斥操作能正确执行(如模型保持,打印日志等)
""" """
def __init__(self, train_data, model, optimizer=None, loss=None, def __init__(self, train_data, model, optimizer=None, loss=None,
callbacks_all=None, callbacks_master=None, callbacks_all=None, callbacks_master=None,
@@ -55,8 +67,43 @@ class DistTrainer():
dev_data=None, metrics=None, metric_key=None, dev_data=None, metrics=None, metric_key=None,
update_every=1, print_every=10, validate_every=-1, update_every=1, print_every=10, validate_every=-1,
save_every=-1, save_path=None, device='auto', save_every=-1, save_path=None, device='auto',
fp16='', backend=None, init_method=None):
fp16='', backend=None, init_method=None, use_tqdm=True):
"""


:param train_data: 训练集, :class:`~fastNLP.DataSet` 类型。
:param nn.modules model: 待训练的模型
:param optimizer: `torch.optim.Optimizer` 优化器。如果为None,则Trainer使用默认的Adam(model.parameters(), lr=4e-3)这个优化器
:param loss: 使用的 :class:`~fastNLP.core.losses.LossBase` 对象。当为None时,默认使用 :class:`~fastNLP.LossInForward`
:param list callbacks_all: 用于在train过程中起调节作用的回调函数,作用于所有训练进程中。
可使用的callback参见 :doc:`callback模块 <fastNLP.core.callback>`
:param list callbacks_master: 用于在train过程中起调节作用的回调函数,只作用于其中一个进程( Master 进程)。
可使用的callback参见 :doc:`callback模块 <fastNLP.core.callback>`
:param int batch_size_per_gpu: 训练时,每个进程的 batch 大小。
:param int n_epochs: 需要优化迭代多少次。
:param num_workers: int, 有多少个线程来进行数据pad处理。
:param drop_last: 如果最后一个batch没有正好为batch_size这么多数据,就扔掉最后一个batch
:param dev_data: 用于做验证的DataSet, :class:`~fastNLP.DataSet` 类型。
:param metrics: 验证的评估函数。可以只使用一个 :class:`Metric<fastNLP.core.metrics.MetricBase>` ,
也可以使用多个 :class:`Metric<fastNLP.core.metrics.MetricBase>` ,通过列表传入。
如验证时取得了更好的验证结果(如果有多个Metric,以列表中第一个Metric为准),且save_path不为None,
则保存当前模型。Metric种类详见 :doc:`metrics模块 <fastNLP.core.metrics>` 。仅在传入dev_data时有效。
:param str,None metric_key: :class:`Metric<fastNLP.core.metrics.MetricBase>` 有时会有多个指标,
比如 :class:`~fastNLP.core.metrics.SpanFPreRecMetric` 中包含了'f', 'pre', 'rec'。此时需
要指定以哪个指标为准。另外有些指标是越小效果越好,比如语言模型的困惑度,这种情况下,在key前面增加一个'-'来表
明验证时,值越小越好(比如: "-ppl")。仅在传入dev_data时有效。
:param update_every: int, 多少步更新一次梯度。用于希望累计梯度的场景,比如需要128的batch_size, 但是直接设为128
会导致内存不足,通过设置batch_size=32, update_every=4达到目的。当optimizer为None时,该参数无效。
:param int print_every: 多少次反向传播更新tqdm显示的loss; 如果use_tqdm=False, 则多少次反向传播打印loss。
:param int validate_every: 多少个step在验证集上验证一次; 如果为-1,则每个epoch结束验证一次。仅在传入dev_data时有效。
:param int save_every: 多少个step保存一次模型,如果为-1,则每个epoch结束保存一次。仅在传入save_path时有效。
:param str,None save_path: 将模型保存路径,如果路径不存在,将自动创建文件夹。如果为None,则不保存模型。如果dev_data为None,则保存
最后一次迭代的模型。保存的时候不仅保存了参数,还保存了模型结构。即便使用DataParallel,这里也只保存模型。
:param str device: 指定 device,可以是 gpu,cpu 或 auto
:param str fp16: 指定半精度训练的优化等级,可为 O1,O2 或 O3,若为空字符串则不使用半精度。
:param backend: 指定分布式的backend,详情参考 pytorch 文档
:param init_method 指定分布式的初始化方法,详情参考 pytorch 文档
:param bool use_tqdm: 是否使用tqdm来显示训练进度; 如果为False,则将loss打印在终端中。
"""
assert device in ['auto', 'cuda', 'cpu'], "Please set correct device in [auto', 'cuda', 'cpu']" assert device in ['auto', 'cuda', 'cpu'], "Please set correct device in [auto', 'cuda', 'cpu']"
if device == 'auto': if device == 'auto':
device = 'cuda' if torch.cuda.is_available() else 'cpu' device = 'cuda' if torch.cuda.is_available() else 'cpu'
@@ -94,7 +141,9 @@ class DistTrainer():
self.callback_manager = DistCallbackManager( self.callback_manager = DistCallbackManager(
env={"trainer": self}, callbacks_all=callbacks_all, env={"trainer": self}, callbacks_all=callbacks_all,
callbacks_master=callbacks_master) callbacks_master=callbacks_master)
self.test_manager = DistCallbackManager(env={'trainer': self})
self.metric_key = metric_key self.metric_key = metric_key
self.use_tqdm = use_tqdm


model.to(self.device) model.to(self.device)
optimizer = self._get_optimizer(optimizer) optimizer = self._get_optimizer(optimizer)
@@ -102,11 +151,7 @@ class DistTrainer():
# init fp16, must before DataParallel init # init fp16, must before DataParallel init
if len(self.fp16): if len(self.fp16):
assert isinstance(self.fp16, str), "Please set Apex AMP optimization level selected in ['O0', 'O1', 'O2', 'O3']" assert isinstance(self.fp16, str), "Please set Apex AMP optimization level selected in ['O0', 'O1', 'O2', 'O3']"
try:
from apex import amp
except ImportError:
raise ImportError("Please install apex from https://www.github.com/nvidia/apex to use fp16 training.")
assert torch.backends.cudnn.enabled, "Amp requires cudnn backend to be enabled."
_check_fp16()
assert device == 'cuda', "Amp requires cuda device" assert device == 'cuda', "Amp requires cuda device"
model, optimizer = amp.initialize(model, optimizer, opt_level=self.fp16) model, optimizer = amp.initialize(model, optimizer, opt_level=self.fp16)


@@ -121,14 +166,15 @@ class DistTrainer():
self.optimizer = optimizer self.optimizer = optimizer
self.sampler = DistributedSampler(self.train_data) self.sampler = DistributedSampler(self.train_data)
self.data_iterator = self._get_data_iter(self.train_data) self.data_iterator = self._get_data_iter(self.train_data)
self.batch_size = self.world_size * self.batch_size_per_gpu
self.n_steps = self._get_n_steps() self.n_steps = self._get_n_steps()


# for evaluation, only run eval on master proc # for evaluation, only run eval on master proc
if dev_data and metrics: if dev_data and metrics:
cb = TesterCallback(
cb = _TesterCallback(
dev_data, model, metrics, dev_data, model, metrics,
batch_size=batch_size_per_gpu, num_workers=num_workers) batch_size=batch_size_per_gpu, num_workers=num_workers)
self.callback_manager.add_callback([cb], master=True)
self.test_manager.add_callback([cb], master=True)


# Setup logging # Setup logging
dist.barrier() dist.barrier()
@@ -178,9 +224,27 @@ class DistTrainer():


@property @property
def is_master(self): def is_master(self):
"""是否是主进程"""
return self.rank == 0 return self.rank == 0


def train(self, on_exception='auto'):
def train(self, load_best_model=True, on_exception='auto'):
"""
使用该函数使Trainer开始训练。

:param str on_exception: 在训练过程遭遇exception,并被 :py:class:Callback 的on_exception()处理后,是否继续抛出异常。
支持'ignore','raise', 'auto': 'ignore'将捕获异常,写在Trainer.train()后面的代码将继续运行; 'raise'将异常抛出;
'auto'将ignore以下两种Exception: CallbackException与KeyboardInterrupt, raise其它exception.
:return dict: 返回一个字典类型的数据,
内含以下内容::

seconds: float, 表示训练时长
以下三个内容只有在提供了dev_data的情况下会有。
best_eval: Dict of Dict, 表示evaluation的结果。第一层的key为Metric的名称,
第二层的key为具体的Metric
best_epoch: int,在第几个epoch取得的最佳值
best_step: int, 在第几个step(batch)更新取得的最佳值

"""
try: try:
self.logger.info("###### Training epochs started ######") self.logger.info("###### Training epochs started ######")
self.logger.info('Total epochs: %d'% self.n_epochs) self.logger.info('Total epochs: %d'% self.n_epochs)
@@ -222,17 +286,22 @@ class DistTrainer():
results['seconds'] = round(time.time() - start_time, 2) results['seconds'] = round(time.time() - start_time, 2)
self.logger.info("###### Train finished ######") self.logger.info("###### Train finished ######")
self.logger.info('Total train time: {} seconds.'. format(results['seconds'])) self.logger.info('Total train time: {} seconds.'. format(results['seconds']))
return results
if load_best_model:
self.load_check_point('best_{}'.format(self.metric_key))
finally: finally:
self.close()
pass

return results


def _train(self): def _train(self):
if self.fp16:
# skip check, done in __init__()
from apex import amp
if not self.use_tqdm:
from .utils import _pseudo_tqdm as inner_tqdm
else:
inner_tqdm = tqdm

self.step = 0 self.step = 0
self.epoch = 0 self.epoch = 0
self.pbar = tqdm(total=self.n_steps, postfix='loss:{0:<6.5f}',
self.pbar = inner_tqdm(total=self.n_steps, postfix='loss:{0:<6.5f}',
leave=False, dynamic_ncols=True, disable=not self.is_master) leave=False, dynamic_ncols=True, disable=not self.is_master)
pbar = self.pbar pbar = self.pbar
avg_loss = 0 avg_loss = 0
@@ -292,8 +361,8 @@ class DistTrainer():
if self.validate_every < 0: if self.validate_every < 0:
self._do_validation() self._do_validation()


if self.save_every < 0 and self.cp_save_path:
self.save_check_point()
if self.save_every < 0 and self.cp_save_path:
self.save_check_point()
# lr decay; early stopping # lr decay; early stopping
self.callback_manager.on_epoch_end() self.callback_manager.on_epoch_end()
# =============== epochs end =================== # # =============== epochs end =================== #
@@ -327,22 +396,35 @@ class DistTrainer():
loss = self.losser(predict, truth) loss = self.losser(predict, truth)
if self.update_every > 1: if self.update_every > 1:
loss = loss / self.update_every loss = loss / self.update_every
return loss.mean()
if loss.dim() > 0:
loss = loss.mean()
return loss


def save_check_point(self, only_params=False):
def save_check_point(self, name=None, only_params=False):
"""保存当前模型"""
# only master save models # only master save models
if self.is_master: if self.is_master:
if name is None:
name = 'checkpoint-{}.bin'.format(self.step)
os.makedirs(self.cp_save_path, exist_ok=True) os.makedirs(self.cp_save_path, exist_ok=True)
path = os.path.join(self.cp_save_path, 'checkpoint-{}.bin'.format(self.step))
path = os.path.join(self.cp_save_path, name)
self.logger.info("Save checkpoint to {}".format(path)) self.logger.info("Save checkpoint to {}".format(path))
model_to_save = self.model.module model_to_save = self.model.module
if only_params: if only_params:
model_to_save = model_to_save.state_dict() model_to_save = model_to_save.state_dict()
torch.save(model_to_save, path) torch.save(model_to_save, path)


def load_check_point(self, name):
path = os.path.join(self.cp_save_path, name)
self.logger.info('reload best model from %s', path)
model_load = torch.load(path)
if not isinstance(model_load, dict):
model_load = model_load.state_dict()
self.model.load_state_dict(model_load)

def _do_validation(self): def _do_validation(self):
self.callback_manager.on_valid_begin() self.callback_manager.on_valid_begin()
eval_res = self.callback_manager.on_validation()
eval_res = self.test_manager.on_valid_begin()
eval_res = list(filter(lambda x: x is not None, eval_res)) eval_res = list(filter(lambda x: x is not None, eval_res))
if len(eval_res): if len(eval_res):
eval_res, is_better = list(zip(*eval_res)) eval_res, is_better = list(zip(*eval_res))
@@ -350,7 +432,16 @@ class DistTrainer():
eval_res, is_better = None, None eval_res, is_better = None, None
self.callback_manager.on_valid_end( self.callback_manager.on_valid_end(
eval_res, self.metric_key, self.optimizer, is_better) eval_res, self.metric_key, self.optimizer, is_better)

# save better model
for i, better_flag in enumerate(is_better):
if better_flag:
# TODO to support multiple datasets to evaluate
name = 'best_{}'.format(self.metric_key)
self.save_check_point(name)
break
dist.barrier() dist.barrier()


def close(self): def close(self):
"""关闭Trainer,销毁进程"""
dist.destroy_process_group() dist.destroy_process_group()

+ 1
- 0
fastNLP/core/trainer.py View File

@@ -842,6 +842,7 @@ class Trainer(object):


@property @property
def is_master(self): def is_master(self):
"""是否是主进程"""
return True return True


DEFAULT_CHECK_BATCH_SIZE = 2 DEFAULT_CHECK_BATCH_SIZE = 2


+ 11
- 0
fastNLP/core/utils.py View File

@@ -19,6 +19,10 @@ import torch.nn as nn
from typing import List from typing import List
from ._logger import logger from ._logger import logger
from prettytable import PrettyTable from prettytable import PrettyTable
try:
from apex import amp
except:
amp = None


_CheckRes = namedtuple('_CheckRes', ['missing', 'unused', 'duplicated', 'required', 'all_needed', _CheckRes = namedtuple('_CheckRes', ['missing', 'unused', 'duplicated', 'required', 'all_needed',
'varargs']) 'varargs'])
@@ -805,3 +809,10 @@ def sub_column(string: str, c: int, c_size: int, title: str) -> str:
if len(string) > avg: if len(string) > avg:
string = string[:(avg - 3)] + "..." string = string[:(avg - 3)] + "..."
return string return string


def _check_fp16():
if amp is None:
raise ImportError("Please install apex from https://www.github.com/nvidia/apex to use fp16 training.")
if not torch.backends.cudnn.enabled:
raise RuntimeError("Amp requires cudnn backend to be enabled.")

Loading…
Cancel
Save