
Rename total_batches to n_batches to match n_epochs; add n_batches as a Trainer initialization parameter

tags/v1.0.0alpha
yhcc committed 2 years ago · commit 057806de7d
12 changed files with 46 additions and 32 deletions:

 1. +3  -3   fastNLP/core/callbacks/progress_callback.py
 2. +1  -1   fastNLP/core/callbacks/torch_callbacks/torch_lr_sched_callback.py
 3. +4  -3   fastNLP/core/controllers/evaluator.py
 4. +18 -10  fastNLP/core/controllers/trainer.py
 5. +2  -2   fastNLP/core/controllers/utils/state.py
 6. +5  -4   fastNLP/core/dataset/dataset.py
 7. +1  -1   fastNLP/core/samplers/reproducible_batch_sampler.py
 8. +1  -1   fastNLP/core/utils/rich_progress.py
 9. +4  -3   fastNLP/io/file_reader.py
10. +5  -2   fastNLP/io/loader/conll.py
11. +1  -1   tests/core/callbacks/torch_callbacks/test_torch_warmup_callback.py
12. +1  -1   tests/helpers/callbacks/helper_callbacks_torch.py

fastNLP/core/callbacks/progress_callback.py (+3 -3)

@@ -79,7 +79,7 @@ class RichCallback(ProgressCallback):

     def on_train_begin(self, trainer):
         self.task2id['epoch'] = self.progress_bar.add_task(description='Epoch:0', total=trainer.n_epochs,
-                                                           completed=trainer.global_forward_batches/(trainer.total_batches+1e-6))
+                                                           completed=trainer.global_forward_batches/(trainer.n_batches+1e-6))

     def on_train_epoch_begin(self, trainer):
         self.epoch_bar_update_advance = self.print_every/(trainer.num_batches_per_epoch + 1e-6)

@@ -190,7 +190,7 @@ class RawTextCallback(ProgressCallback):
         self.loss = 0
         text = f'Epoch:{trainer.cur_epoch_idx}/{trainer.n_epochs}, Batch:{trainer.batch_idx_in_epoch}, ' \
                f'loss:{round(loss, self.loss_round_ndigit)}, ' \
-               f'finished {round(trainer.global_forward_batches/trainer.total_batches*100, 2)}%.'
+               f'finished {round(trainer.global_forward_batches/trainer.n_batches*100, 2)}%.'
         logger.info(text)

     def on_evaluate_end(self, trainer, results):

@@ -251,7 +251,7 @@ class TqdmCallback(ProgressCallback):
     def on_train_begin(self, trainer):
         self.task2id['epoch'] = self.progress_bar.add_task(description='Epoch:0', total=trainer.n_epochs,
                                                            bar_format='{desc}: {percentage:3.0f}%|{bar}| [{elapsed}<{remaining}, {rate_fmt}, {postfix}]',
-                                                           initial=trainer.global_forward_batches/(trainer.total_batches+1e-6))
+                                                           initial=trainer.global_forward_batches/(trainer.n_batches+1e-6))

     def on_train_epoch_begin(self, trainer):
         self.epoch_bar_update_advance = self.print_every/(trainer.num_batches_per_epoch + 1e-6)

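All three progress callbacks now seed the epoch bar from trainer.n_batches, so a run resumed from a checkpoint starts the bar at the right fraction. A standalone sketch of that fraction (illustration only, not fastNLP code):

# Minimal sketch of the resume-aware progress fraction used above.
def initial_completion(global_forward_batches: int, n_batches: int) -> float:
    # the +1e-6 guards against division by zero when n_batches is 0
    return global_forward_batches / (n_batches + 1e-6)

print(initial_completion(0, 1000))    # 0.0   -- fresh run
print(initial_completion(250, 1000))  # ~0.25 -- resumed a quarter of the way in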

fastNLP/core/callbacks/torch_callbacks/torch_lr_sched_callback.py (+1 -1)

@@ -41,7 +41,7 @@ class TorchWarmupCallback(Callback):
         return max((progress - 1.) / (self.warmup - 1.), 0.)

     def on_train_begin(self, trainer):
-        self.t_steps = trainer.total_batches
+        self.t_steps = trainer.n_batches
         if self.warmup > 1:
             self.warmup = self.warmup / self.t_steps
         self.t_steps = max(2, self.t_steps)  # must not be smaller than 2

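on_train_begin above normalizes an absolute warmup step count into a fraction of the total batch count, which now comes from trainer.n_batches. A standalone sketch with assumed values (illustration only, not fastNLP code):

n_batches = 10_000        # what trainer.n_batches holds once Trainer.run() has computed it
warmup = 500              # user asked for 500 warmup steps

t_steps = n_batches
if warmup > 1:            # interpret warmup as an absolute step count
    warmup = warmup / t_steps   # -> 0.05, i.e. warm up over 5% of training
t_steps = max(2, t_steps)       # keep at least 2 so the schedule stays well-defined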

fastNLP/core/controllers/evaluator.py (+4 -3)

@@ -460,14 +460,15 @@ class _MetricsWrapper:
         for metric in self._metrics:
             args = []
             if not isinstance(batch, dict):
-                logger.warning_once(
+                logger.rank_zero_warning(
                     f"The output of the DataLoader is of type:`{type(batch)}`, fastNLP will only depend on "
-                    f"the output of model to update metric.")
+                    f"the output of model to update metric.", once=True)
             else:
                 args.append(batch)
             if not isinstance(outputs, dict):
                 raise RuntimeError(f"The output of your model is of type:`{type(outputs)}`, please either directly"
-                                   f" return a dict from your model or use `output_mapping` to convert it into dict type.")
+                                   f" return a dict from your model or use `output_mapping` to convert it into dict "
+                                   f"type.")
             if isinstance(metric, Metric):
                 # This keeps the error clear when auto_param_call fails.
                 auto_param_call(metric.update, outputs, *args, signature_fn=metric.update.__wrapped__)

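As the RuntimeError suggests, a model that does not return a dict can be adapted via output_mapping. A hedged sketch; passing a callable output_mapping to the Evaluator is an assumption based on the fastNLP 1.0 API, not something this diff shows:

# Hypothetical model whose evaluate_step returns a tuple instead of a dict.
def output_mapping(outputs):
    pred, _hidden = outputs      # keep only what the metrics need
    return {'pred': pred}

# evaluator = Evaluator(model, dataloaders=dev_dataloader, metrics=metrics,
#                       output_mapping=output_mapping)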

fastNLP/core/controllers/trainer.py (+18 -10)

@@ -110,7 +110,7 @@ class Trainer(TrainerEventTrigger):

     For more details on using ``TorchDDPDriver``, see :class:`~fastNLP.core.drivers.torch_driver.TorchDDPDriver`.

-    :param n_epochs: the total number of training epochs; 20 by default;
+    :param n_epochs: the total number of training epochs; 20 by default; the total number of batches to iterate can also be set via the ``n_batches`` parameter.
     :param evaluate_dataloaders: the validation dataset(s); either a single dataset or several; when several, they must be a Dict;
         None by default;
     :param batch_step_fn: customizes the function executed to forward one batch of data in training. It should accept two parameters, ``trainer`` and ``batch``,

@@ -237,6 +237,8 @@ class Trainer(TrainerEventTrigger):

         Note this parameter only takes effect when the ``Evaluator`` built into ``Trainer`` is not None and there is a *callback* instance that needs it but has not set it;

+    :param n_batches: training ends after iterating this many batches. When this value is not -1, the value of ``n_epochs`` is ignored.
+
     :param marker: marks a ``Trainer`` instance so that when the user calls ``Trainer.on``, the registered function is tied to that specific ``Trainer`` instance; None by default;

     .. note::

@@ -356,6 +358,7 @@ class Trainer(TrainerEventTrigger):
             fp16: bool = False,
             monitor: Union[str, Callable] = None,
             larger_better: bool = True,
+            n_batches: int = -1,
             marker: Optional[str] = None,
             **kwargs
     ):

@@ -426,6 +429,7 @@ class Trainer(TrainerEventTrigger):
             model_wo_auto_param_call=model_wo_auto_param_call,
             accumulation_steps=accumulation_steps,
             fp16=fp16,
+            n_batches=n_batches,
             marker=marker,
             **kwargs
         )

@@ -444,12 +448,12 @@ class Trainer(TrainerEventTrigger):
         # Initialize state, including the interface exposed to users and the one we use internally;
         self.state = State()
         self.trainer_state = TrainerState(
-            n_epochs=n_epochs,
+            n_epochs=n_epochs if n_batches==-1 else None,
             cur_epoch_idx=0,
             global_forward_batches=0,
             batch_idx_in_epoch=0,
             num_batches_per_epoch=None,  # initialized inside the concrete train_batch_loop;
-            total_batches=None
+            n_batches=n_batches
         )

         if metrics is None and evaluate_dataloaders is not None:

@@ -598,14 +602,18 @@ class Trainer(TrainerEventTrigger):
             self.dataloader = _TruncatedDataLoader(self.dataloader, num_train_batch_per_epoch)

         self.num_batches_per_epoch = len(self.dataloader)
-        self.total_batches = self.num_batches_per_epoch * self.n_epochs
+        if self.n_batches == -1:
+            self.n_batches = self.num_batches_per_epoch * self.n_epochs
+        else:
+            self.n_epochs = (self.n_batches+self.num_batches_per_epoch-1)//self.num_batches_per_epoch
+
         self.global_forward_batches = self.num_batches_per_epoch * self.cur_epoch_idx + self.batch_idx_in_epoch

         try:
             self.on_train_begin()
             self.driver.barrier()
             self.driver.zero_grad()
-            while self.cur_epoch_idx < self.n_epochs:
+            while self.cur_epoch_idx < self.n_epochs and self.global_forward_batches < self.n_batches:
                 # Prevents saving again after Trainer.load_checkpoint before the current epoch has finished
                 self.start_batch_idx_in_epoch = self.trainer_state.batch_idx_in_epoch
                 self.driver.set_model_mode("train")

@@ -1367,15 +1375,15 @@ class Trainer(TrainerEventTrigger):
         self.trainer_state.num_batches_per_epoch = num_batches_per_epoch

     @property
-    def total_batches(self) -> int:
+    def n_batches(self) -> int:
         r"""
         :return: the number of batches that will actually be trained on over the whole training run;
         """
-        return self.trainer_state.total_batches
+        return self.trainer_state.n_batches

-    @total_batches.setter
-    def total_batches(self, total_batches: int):
-        self.trainer_state.total_batches = total_batches
+    @n_batches.setter
+    def n_batches(self, n_batches: int):
+        self.trainer_state.n_batches = n_batches

     """ driver property """

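Taken together, these changes let training length be specified in batches rather than epochs. A hedged usage sketch; model, optimizer, and dataloader construction are elided, and apart from n_batches the parameter names follow the fastNLP 1.0 Trainer signature rather than anything shown in this diff:

from fastNLP import Trainer

trainer = Trainer(
    model=model,                        # a torch.nn.Module, assumed built elsewhere
    driver='torch',
    train_dataloader=train_dataloader,  # assumed built elsewhere
    optimizers=optimizer,               # assumed built elsewhere
    n_batches=5000,                     # stop after 5000 forward batches; n_epochs is ignored
)
trainer.run()

# Internally, run() derives the epoch count by ceiling division:
#   n_epochs = (n_batches + num_batches_per_epoch - 1) // num_batches_per_epoch
# e.g. n_batches=5000 with 1500 batches per epoch gives n_epochs=4, and the
# training loop also stops once global_forward_batches reaches n_batches.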



fastNLP/core/controllers/utils/state.py (+2 -2)

@@ -50,7 +50,7 @@ class TrainerState:
     :param global_forward_batches: how many steps the model has forwarded in total so far;
     :param batch_idx_in_epoch: the step index within the current epoch;
     :param num_batches_per_epoch: how many steps each epoch forwards;
-    :param total_batches: the number of steps the whole training run will forward; note total_batches = num_batches_per_epoch * n_epochs;
+    :param n_batches: the number of steps the whole training run will forward; note n_batches = num_batches_per_epoch * n_epochs;
     """
     n_epochs: Optional[int] = None  # recomputed in any case

@@ -61,7 +61,7 @@ class TrainerState:

     num_batches_per_epoch: Optional[int] = None  # recomputed in any case

-    total_batches: Optional[int] = None  # recomputed in any case
+    n_batches: Optional[int] = None  # recomputed in any case

     def state_dict(self) -> Dict:
         r"""


fastNLP/core/dataset/dataset.py (+5 -4)

@@ -156,7 +156,6 @@ import _pickle as pickle
 from copy import deepcopy
 from typing import Optional, List, Callable, Union, Dict, Any, Mapping
 from types import LambdaType
-from subprocess import DEVNULL
 import sys
 import time

@@ -170,6 +169,7 @@ from fastNLP.core.utils.rich_progress import f_rich_progress, DummyFRichProgress
 from fastNLP.core.utils.tqdm_progress import f_tqdm_progress
 from ..log import logger
 from fastNLP.core.utils.dummy_class import DummyClass
+from ..utils.utils import _get_fun_msg


 progress_bars = {

@@ -780,8 +780,8 @@ class DataSet:
         apply_out = self._apply_process(num_proc, func, progress_desc=progress_desc,
                                         progress_bar=progress_bar)
         # Only check whether the first item is dict-like; if so, assume all return values are dicts; otherwise raise.
-        if not isinstance(apply_out[0], dict):
-            raise Exception("The result of func is not a dict")
+        if not isinstance(apply_out[0], Mapping):
+            raise Exception(f"The result of func:{_get_fun_msg(func)} is not a dict, but of type {type(apply_out[0])}")

         for key, value in apply_out[0].items():
             results[key] = [value]

@@ -789,7 +789,8 @@ class DataSet:
         try:
             for idx, per_out in enumerate(apply_out[1:]):
                 if len(set(results.keys()) - set(per_out.keys())):
-                    raise ApplyResultException("apply results have different fields", idx + 1)
+                    raise ApplyResultException(f"Apply results have different fields:{set(results.keys())} and "
+                                               f"{set(per_out.keys())}", idx + 1)
                 for key, value in per_out.items():
                     results[key].append(value)


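Note the check loosened from dict to Mapping, so any dict-like return value now passes, not just literal dicts. A standalone illustration (not fastNLP code):

from collections import UserDict
from collections.abc import Mapping

row = UserDict(words=['hello', 'world'], n_words=2)  # dict-like, but not a dict subclass
print(isinstance(row, dict))     # False -- the old check would have raised
print(isinstance(row, Mapping))  # True  -- the new check accepts it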


fastNLP/core/samplers/reproducible_batch_sampler.py (+1 -1)

@@ -169,7 +169,7 @@ class RandomBatchSampler(ReproducibleBatchSampler):
     :param kwargs: reserved for use by fastNLP
     """
     def __init__(self, dataset, batch_size:int = 32, shuffle: bool = True,
-                 drop_last: bool = False, seed: int = 0, **kwargs):
+                 drop_last: bool = False, seed: int = None, **kwargs):
         super().__init__()

         self.dataset = dataset


fastNLP/core/utils/rich_progress.py (+1 -1)

@@ -120,7 +120,7 @@ class FRichProgress(Progress, metaclass=Singleton):

     def add_task(
             self,
-            description: str,
+            description: str = 'Progress',
             start: bool = True,
             total: float = 100.0,
             completed: int = 0,


fastNLP/io/file_reader.py (+4 -3)

@@ -7,7 +7,7 @@ __all__ = []
 import json
 import csv

-# from ..core import log
+from ..core import logger


 def _read_csv(path, encoding='utf-8', headers=None, sep=',', dropna=True):

@@ -81,7 +81,7 @@ def _read_json(path, encoding='utf-8', fields=None, dropna=True):
         yield line_idx, _res


-def _read_conll(path, encoding='utf-8', sep=None, indexes=None, dropna=True):
+def _read_conll(path, encoding='utf-8', sep=None, indexes=None, dropna=True, drophash=True):
     r"""
     Construct a generator to read conll items.

@@ -91,6 +91,7 @@ def _read_conll(path, encoding='utf-8',sep=None, indexes=None, dropna=True):
     :param indexes: conll object's column indexes that needed, if None, all columns are needed. default: None
     :param dropna: whether to ignore and drop invalid data,
         if False, raise ValueError when reading invalid data. default: True
+    :param drophash: whether to drop lines starting with #. default: True
     :return: generator, every time yield (line number, conll item)
     """

@@ -121,7 +122,7 @@ def _read_conll(path, encoding='utf-8',sep=None, indexes=None, dropna=True):
                 sample = []
                 continue
             raise ValueError('Invalid instance which ends at line: {}'.format(line_idx))
-        elif line.startswith('#'):
+        elif line.startswith('#') and drophash:
             continue
         else:
             sample.append(line.split(sep)) if sep else sample.append(line.split())

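The new drophash flag controls whether comment lines, such as the # sent_id headers in CoNLL-U files, are skipped. A hedged sketch; the file name and contents are made up for illustration:

# sample.conll (hypothetical contents):
#     # sent_id = 1
#     Hello    INTJ
#     world    NOUN
from fastNLP.io.file_reader import _read_conll

for line_idx, item in _read_conll('sample.conll', drophash=True):
    print(line_idx, item)   # the '# sent_id = 1' line is skipped
# With drophash=False the '#' line would be parsed as a data row instead.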

fastNLP/io/loader/conll.py (+5 -2)

@@ -52,13 +52,14 @@ class ConllLoader(Loader):

     """
-    def __init__(self, headers, sep=None, indexes=None, dropna=True):
+    def __init__(self, headers, sep=None, indexes=None, dropna=True, drophash=True):
         r"""
         :param list headers: the name of each data column; must be a List or Tuple of str. ``header`` corresponds one-to-one with ``indexes``
         :param list sep: the separator; tab by default
         :param list indexes: zero-based indexes of the data columns to keep. If ``None``, all columns are kept. Default: ``None``
         :param bool dropna: whether to ignore invalid data; if ``False``, a ``ValueError`` is raised on invalid data. Default: ``True``
+        :param bool drophash: whether to ignore lines starting with ``#``. Default: ``True``
         """
         super(ConllLoader, self).__init__()
         if not isinstance(headers, (list, tuple)):

@@ -66,6 +67,7 @@ class ConllLoader(Loader):
                 'invalid headers: {}, should be list of strings'.format(headers))
         self.headers = headers
         self.dropna = dropna
+        self.drophash = drophash
         self.sep = sep
         if indexes is None:
             self.indexes = list(range(len(self.headers)))

@@ -82,7 +84,8 @@ class ConllLoader(Loader):
         :return: DataSet
         """
         ds = DataSet()
-        for idx, data in _read_conll(path, sep=self.sep, indexes=self.indexes, dropna=self.dropna):
+        for idx, data in _read_conll(path, sep=self.sep, indexes=self.indexes, dropna=self.dropna,
+                                     drophash=self.drophash):
             ins = {h: data[i] for i, h in enumerate(self.headers)}
             ds.append(Instance(**ins))
         return ds

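A hedged usage sketch of the loader-level option; the header names, column indexes, and file path are illustrative, and _load reading a single file into a DataSet is assumed from the fastNLP Loader API:

from fastNLP.io import ConllLoader

# Keep column 0 as 'words' and column 3 as 'pos', skipping '#' comment lines.
loader = ConllLoader(headers=['words', 'pos'], indexes=[0, 3], drophash=True)
ds = loader._load('train.conll')   # _load reads one file into a DataSet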

tests/core/callbacks/torch_callbacks/test_torch_warmup_callback.py (+1 -1)

@@ -32,4 +32,4 @@ def test_torch_warmup_callback(warmup, schedule, accumulation_steps):
     elif schedule == 'constant':
         assert np.allclose(0.1, kwargs['optimizers'].param_groups[0]['lr'])

-    assert len(r_callback.lrs)<=trainer.total_batches//accumulation_steps+1
+    assert len(r_callback.lrs)<=trainer.n_batches//accumulation_steps+1

tests/helpers/callbacks/helper_callbacks_torch.py (+1 -1)

@@ -55,4 +55,4 @@ class RecordAccumulationStepsCallback_Torch(Callback):

     def on_train_end(self, trainer):
         print(f"\n equal num: {self.equal}.\n")
-        print(f"\ntotal_batch_num: {trainer.total_batches}.\n")
+        print(f"\ntotal_batch_num: {trainer.n_batches}.\n")
