
Add a batch_per_epoch attribute to DistTrainer

tags/v1.0.0alpha
yh_cc · 3 years ago · commit 972185dc6c

3 changed files with 20 additions and 15 deletions
  1. fastNLP/core/dist_trainer.py  (+16 −11)
  2. fastNLP/io/data_bundle.py  (+1 −1)
  3. fastNLP/io/pipe/pipe.py  (+3 −3)

fastNLP/core/dist_trainer.py  (+16 −11)

@@ -73,7 +73,7 @@ class DistTrainer:
     r"""
 
     :param train_data: the training set, of type :class:`~fastNLP.DataSet`.
-    :param nn.modules model: the model to be trained
+    :param nn.modules, DDP model: the model to be trained
     :param optimizer: a `torch.optim.Optimizer`. If None, the Trainer uses the default Adam(model.parameters(), lr=4e-3) optimizer
     :param loss: the :class:`~fastNLP.core.losses.LossBase` object to use. If None, :class:`~fastNLP.LossInForward` is used by default
     :param list callbacks_all: callback functions that tune the training process, applied in every training process.
@@ -146,7 +146,6 @@ class DistTrainer:
         self.losser = _prepare_losser(loss)
         self.fp16 = fp16
         self.local_rank = get_local_rank()
-        self._forward_func = model.forward
         self.callback_manager = DistCallbackManager(
             env={"trainer": self}, callbacks_all=callbacks_all,
             callbacks_master=callbacks_master)
@@ -154,8 +153,6 @@ class DistTrainer:
         self.metric_key = metric_key
         self.use_tqdm = use_tqdm
 
-        model.to(self.device)
-
         # init fp16, must before DataParallel init
         autocast, GradScaler = _build_fp16_env(dummy=not self.fp16)
         self.auto_cast = autocast
@@ -170,15 +167,22 @@ class DistTrainer:
         self.set_grad_to_none = kwargs.get('set_grad_to_none', False)
 
         # init DataParallel
-        if parse_version(torch.__version__)>=parse_version('1.1'):
-            self.ddp_model = DDP(model, device_ids=[self.local_rank],
-                                 output_device=self.local_rank,
-                                 find_unused_parameters=kwargs.get('find_unused_parameters', False))
+        if isinstance(model, DDP):
+            self.ddp_model = model
         else:
-            self.ddp_model = DDP(model, device_ids=[self.local_rank],
-                                 output_device=self.local_rank)
+            if parse_version(torch.__version__)>=parse_version('1.1'):
+                self.ddp_model = DDP(model, device_ids=[self.local_rank],
+                                     output_device=self.local_rank,
+                                     find_unused_parameters=kwargs.get('find_unused_parameters', False))
+            else:
+                self.ddp_model = DDP(model, device_ids=[self.local_rank],
+                                     output_device=self.local_rank)
         self.model = self.ddp_model.module
 
+        self._forward_func = self.model.forward
+        self.model.to(self.device)
 
         optimizer = self._get_optimizer(optimizer)
         self.optimizer = optimizer
         if isinstance(self.train_data, DataSet):
@@ -207,7 +211,7 @@ class DistTrainer:
         # for evaluation, only run eval on master proc
         if dev_data and metrics:
             cb = _TesterCallback(
-                dev_data, model, metrics,
+                dev_data, self.model, metrics,
                 batch_size=dev_batch_size, num_workers=num_workers, sampler=kwargs.get('test_sampler', None),
                 use_tqdm=self.test_use_tqdm)
             self.test_manager.add_callback([cb], master=True)
@@ -343,6 +347,7 @@ class DistTrainer:
         avg_loss = 0
         data_iterator = self.data_iterator
         self.ddp_model.zero_grad()
+        self.batch_per_epoch = self.data_iterator.num_batches
         for epoch in range(1, self.n_epochs + 1):
             self.epoch = epoch
             pbar.set_description_str(desc="Epoch {}/{}".format(epoch, self.n_epochs))
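
Taken together, the dist_trainer.py changes mean a model that is already wrapped in DistributedDataParallel is reused as-is instead of being wrapped a second time, and batch_per_epoch is filled in from the data iterator before the epoch loop so callbacks can read it. A minimal sketch of how that could look from the caller's side; the callback and the launch details (torchrun, LOCAL_RANK, process-group setup) are assumptions about the usual fastNLP / torch.distributed workflow, not something this diff shows:

import os

import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP

from fastNLP import Callback
from fastNLP.core.dist_trainer import DistTrainer


class BatchPerEpochLogger(Callback):
    """Hypothetical callback that reads the attribute added by this commit."""

    def on_epoch_begin(self):
        # batch_per_epoch is now set from data_iterator.num_batches before the epoch loop
        print(f"epoch {self.trainer.epoch}: {self.trainer.batch_per_epoch} batches per epoch")


def train(model, train_data, loss):
    # Assumes launch via torchrun, which exports LOCAL_RANK; DDP needs an
    # initialized process group before the model can be wrapped.
    local_rank = int(os.environ["LOCAL_RANK"])
    torch.cuda.set_device(local_rank)
    if not dist.is_initialized():
        dist.init_process_group("nccl")

    # Pre-wrap the model ourselves; since this commit DistTrainer detects the
    # DDP instance and reuses it instead of wrapping it a second time.
    ddp_model = DDP(model.cuda(), device_ids=[local_rank], output_device=local_rank)

    trainer = DistTrainer(train_data, ddp_model, loss=loss,
                          callbacks_all=[BatchPerEpochLogger()])
    trainer.train()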


fastNLP/io/data_bundle.py  (+1 −1)

@@ -32,7 +32,7 @@ class DataBundle:
        :param vocabs: dict from name (string) to :class:`~fastNLP.Vocabulary`
        :param datasets: dict from name (string) to :class:`~fastNLP.DataSet`. It is best not to pass the same DataSet object in more than once,
-           as this may cause problems when the data is processed with a Pipe.
+           as this may cause problems when the data is processed with a Pipe; if several splits really must be identical, deepcopy them manually before passing them in.
        """
        self.vocabs = vocabs or {}
        self.datasets = datasets or {}
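
The reworded docstring points at a concrete pitfall: registering the very same DataSet object under two names makes a Pipe transform it twice. A small illustration of the recommended workaround; the field names and values are made up for the example:

from copy import deepcopy

from fastNLP import DataSet
from fastNLP.io import DataBundle

train = DataSet({"raw_words": ["a b c", "d e"], "target": [0, 1]})

# Passing the identical object twice would let a Pipe process it twice;
# the docstring now recommends a manual deepcopy when two splits must start out equal.
bundle = DataBundle(datasets={"train": train, "dev": deepcopy(train)})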


fastNLP/io/pipe/pipe.py  (+3 −3)

@@ -27,15 +27,15 @@ class Pipe:
        Process the input DataBundle, then return that DataBundle.
 
        :param ~fastNLP.DataBundle data_bundle: the DataBundle to process
-       :return:
+       :return: DataBundle
        """
        raise NotImplementedError
 
-    def process_from_file(self, paths) -> DataBundle:
+    def process_from_file(self, paths: str) -> DataBundle:
        r"""
        Given file path(s), build a processed DataBundle. The path forms accepted by paths are described in :meth:`fastNLP.io.Loader.load()`
 
-       :param paths:
+       :param str paths:
        :return: DataBundle
        """
        raise NotImplementedError
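
For reference, a toy Pipe subclass showing where the two annotated methods fit. The lower-casing logic, the field names, and the in-memory bundle built in process_from_file are invented for this sketch rather than taken from fastNLP:

from fastNLP import DataSet
from fastNLP.io import DataBundle
from fastNLP.io.pipe.pipe import Pipe


class LowercasePipe(Pipe):
    """Toy pipe: lower-cases a raw text field in every DataSet of the bundle."""

    def process(self, data_bundle: DataBundle) -> DataBundle:
        for name, dataset in data_bundle.datasets.items():
            dataset.apply_field(str.lower, field_name="raw_words",
                                new_field_name="raw_words")
        return data_bundle

    def process_from_file(self, paths: str) -> DataBundle:
        # A real pipe would delegate to a matching Loader(paths); a tiny
        # in-memory bundle keeps this sketch self-contained.
        bundle = DataBundle(datasets={"train": DataSet({"raw_words": ["Hello World"]})})
        return self.process(bundle)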
