Browse Source

将total_batches改名为n_batches,以和n_epochs对应;增加n_batches作为Trainer的初始化参数

tags/v1.0.0alpha
yhcc 2 years ago
parent
commit
057806de7d
12 changed files with 46 additions and 32 deletions
  1. +3
    -3
      fastNLP/core/callbacks/progress_callback.py
  2. +1
    -1
      fastNLP/core/callbacks/torch_callbacks/torch_lr_sched_callback.py
  3. +4
    -3
      fastNLP/core/controllers/evaluator.py
  4. +18
    -10
      fastNLP/core/controllers/trainer.py
  5. +2
    -2
      fastNLP/core/controllers/utils/state.py
  6. +5
    -4
      fastNLP/core/dataset/dataset.py
  7. +1
    -1
      fastNLP/core/samplers/reproducible_batch_sampler.py
  8. +1
    -1
      fastNLP/core/utils/rich_progress.py
  9. +4
    -3
      fastNLP/io/file_reader.py
  10. +5
    -2
      fastNLP/io/loader/conll.py
  11. +1
    -1
      tests/core/callbacks/torch_callbacks/test_torch_warmup_callback.py
  12. +1
    -1
      tests/helpers/callbacks/helper_callbacks_torch.py

+ 3
- 3
fastNLP/core/callbacks/progress_callback.py View File

@@ -79,7 +79,7 @@ class RichCallback(ProgressCallback):

def on_train_begin(self, trainer):
self.task2id['epoch'] = self.progress_bar.add_task(description='Epoch:0', total=trainer.n_epochs,
completed=trainer.global_forward_batches/(trainer.total_batches+1e-6))
completed=trainer.global_forward_batches/(trainer.n_batches+1e-6))

def on_train_epoch_begin(self, trainer):
self.epoch_bar_update_advance = self.print_every/(trainer.num_batches_per_epoch + 1e-6)
@@ -190,7 +190,7 @@ class RawTextCallback(ProgressCallback):
self.loss = 0
text = f'Epoch:{trainer.cur_epoch_idx}/{trainer.n_epochs}, Batch:{trainer.batch_idx_in_epoch}, ' \
f'loss:{round(loss, self.loss_round_ndigit)}, ' \
f'finished {round(trainer.global_forward_batches/trainer.total_batches*100, 2)}%.'
f'finished {round(trainer.global_forward_batches/trainer.n_batches*100, 2)}%.'
logger.info(text)

def on_evaluate_end(self, trainer, results):
@@ -251,7 +251,7 @@ class TqdmCallback(ProgressCallback):
def on_train_begin(self, trainer):
self.task2id['epoch'] = self.progress_bar.add_task(description='Epoch:0', total=trainer.n_epochs,
bar_format='{desc}: {percentage:3.0f}%|{bar}| [{elapsed}<{remaining}, {rate_fmt}, {postfix}]',
initial=trainer.global_forward_batches/(trainer.total_batches+1e-6))
initial=trainer.global_forward_batches/(trainer.n_batches+1e-6))

def on_train_epoch_begin(self, trainer):
self.epoch_bar_update_advance = self.print_every/(trainer.num_batches_per_epoch + 1e-6)


+ 1
- 1
fastNLP/core/callbacks/torch_callbacks/torch_lr_sched_callback.py View File

@@ -41,7 +41,7 @@ class TorchWarmupCallback(Callback):
return max((progress - 1.) / (self.warmup - 1.), 0.)

def on_train_begin(self, trainer):
self.t_steps = trainer.total_batches
self.t_steps = trainer.n_batches
if self.warmup >1:
self.warmup = self.warmup / self.t_steps
self.t_steps = max(2, self.t_steps) # 不能小于2


+ 4
- 3
fastNLP/core/controllers/evaluator.py View File

@@ -460,14 +460,15 @@ class _MetricsWrapper:
for metric in self._metrics:
args = []
if not isinstance(batch, dict):
logger.warning_once(
logger.rank_zero_warning(
f"The output of the DataLoader is of type:`{type(batch)}`, fastNLP will only depend on "
f"the output of model to update metric.")
f"the output of model to update metric.", once=True)
else:
args.append(batch)
if not isinstance(outputs, dict):
raise RuntimeError(f"The output of your model is of type:`{type(outputs)}`, please either directly"
f" return a dict from your model or use `output_mapping` to convert it into dict type.")
f" return a dict from your model or use `output_mapping` to convert it into dict "
f"type.")
if isinstance(metric, Metric):
# 这样在 auto_param_call 报错的时候才清晰。
auto_param_call(metric.update, outputs, *args, signature_fn=metric.update.__wrapped__)


+ 18
- 10
fastNLP/core/controllers/trainer.py View File

@@ -110,7 +110,7 @@ class Trainer(TrainerEventTrigger):

对于使用 ``TorchDDPDriver`` 的更多细节,请见 :class:`~fastNLP.core.drivers.torch_driver.TorchDDPDriver`。

:param n_epochs: 训练总共的 epoch 的数量,默认为 20;
:param n_epochs: 训练总共的 epoch 的数量,默认为 20;也可以通过 ``n_batches`` 参数设置总共迭代多少个 ``batch`` 。
:param evaluate_dataloaders: 验证数据集,其可以是单独的一个数据集,也可以是多个数据集;当为多个数据集时,注意其必须是 Dict;默认
为 None;
:param batch_step_fn: 定制每次训练时前向运行一个 batch 的数据所执行的函数。该函数应接受两个参数为 ``trainer`` 和 ``batch``,
@@ -237,6 +237,8 @@ class Trainer(TrainerEventTrigger):

注意该参数仅当 ``Trainer`` 内置的 ``Evaluator`` 不为 None 时且有需要该参数但是没有设置该参数的 *callback* 实例才有效;

:param n_batches: 迭代多少个 ``batch`` 的训练结束。当该值不为 -1 时,将直接忽略 ``n_epochs`` 的值。

:param marker: 用于标记一个 ``Trainer`` 实例,从而在用户调用 ``Trainer.on`` 函数时,标记该函数属于哪一个具体的 ``Trainer`` 实例;默认为 None;

.. note::
@@ -356,6 +358,7 @@ class Trainer(TrainerEventTrigger):
fp16: bool = False,
monitor: Union[str, Callable] = None,
larger_better: bool = True,
n_batches: int = -1,
marker: Optional[str] = None,
**kwargs
):
@@ -426,6 +429,7 @@ class Trainer(TrainerEventTrigger):
model_wo_auto_param_call=model_wo_auto_param_call,
accumulation_steps=accumulation_steps,
fp16=fp16,
n_batches=n_batches,
marker=marker,
**kwargs
)
@@ -444,12 +448,12 @@ class Trainer(TrainerEventTrigger):
# 初始化 state,包括提供给用户的接口和我们自己使用的接口;
self.state = State()
self.trainer_state = TrainerState(
n_epochs=n_epochs,
n_epochs=n_epochs if n_batches!=-1 else None,
cur_epoch_idx=0,
global_forward_batches=0,
batch_idx_in_epoch=0,
num_batches_per_epoch=None, # 会在具体的 train_batch_loop 中进行初始化;
total_batches=None
n_batches=n_batches
)

if metrics is None and evaluate_dataloaders is not None:
@@ -598,14 +602,18 @@ class Trainer(TrainerEventTrigger):
self.dataloader = _TruncatedDataLoader(self.dataloader, num_train_batch_per_epoch)

self.num_batches_per_epoch = len(self.dataloader)
self.total_batches = self.num_batches_per_epoch * self.n_epochs
if self.n_batches == -1:
self.n_batches = self.num_batches_per_epoch * self.n_epochs
else:
self.n_epochs = (self.n_batches+self.num_batches_per_epoch-1)//self.num_batches_per_epoch

self.global_forward_batches = self.num_batches_per_epoch * self.cur_epoch_idx + self.batch_idx_in_epoch

try:
self.on_train_begin()
self.driver.barrier()
self.driver.zero_grad()
while self.cur_epoch_idx < self.n_epochs:
while self.cur_epoch_idx < self.n_epochs and self.global_forward_batches < self.n_batches:
# 这个是防止在 Trainer.load_checkpoint 之后还没结束当前 epoch 又继续 save
self.start_batch_idx_in_epoch = self.trainer_state.batch_idx_in_epoch
self.driver.set_model_mode("train")
@@ -1367,15 +1375,15 @@ class Trainer(TrainerEventTrigger):
self.trainer_state.num_batches_per_epoch = num_batches_per_epoch

@property
def total_batches(self) -> int:
def n_batches(self) -> int:
r"""
:return: 返回整体的训练中实际会训练多少个 batch 的数据;
"""
return self.trainer_state.total_batches
return self.trainer_state.n_batches

@total_batches.setter
def total_batches(self, total_batches: int):
self.trainer_state.total_batches = total_batches
@n_batches.setter
def n_batches(self, n_batches: int):
self.trainer_state.n_batches = n_batches

""" driver property """



+ 2
- 2
fastNLP/core/controllers/utils/state.py View File

@@ -50,7 +50,7 @@ class TrainerState:
:param global_forward_batches: 当前模型总共 forward 了多少个 step;
:param batch_idx_in_epoch: 训练中在当前 epoch 的第几个 step;
:param num_batches_per_epoch: 每一个 epoch 会 forward 多少个 step;
:param total_batches: 完整训练过程会 forward 的 step 数量,注意 total_batches = total_batches * n_epochs;
:param n_batches: 完整训练过程会 forward 的 step 数量,注意 n_batches = n_batches * n_epochs;
"""
n_epochs: Optional[int] = None # 无论如何重新算

@@ -61,7 +61,7 @@ class TrainerState:

num_batches_per_epoch: Optional[int] = None # 无论如何重新算

total_batches: Optional[int] = None # 无论如何重新算
n_batches: Optional[int] = None # 无论如何重新算

def state_dict(self) -> Dict:
r"""


+ 5
- 4
fastNLP/core/dataset/dataset.py View File

@@ -156,7 +156,6 @@ import _pickle as pickle
from copy import deepcopy
from typing import Optional, List, Callable, Union, Dict, Any, Mapping
from types import LambdaType
from subprocess import DEVNULL
import sys
import time

@@ -170,6 +169,7 @@ from fastNLP.core.utils.rich_progress import f_rich_progress, DummyFRichProgress
from fastNLP.core.utils.tqdm_progress import f_tqdm_progress
from ..log import logger
from fastNLP.core.utils.dummy_class import DummyClass
from ..utils.utils import _get_fun_msg


progress_bars = {
@@ -780,8 +780,8 @@ class DataSet:
apply_out = self._apply_process(num_proc, func, progress_desc=progress_desc,
progress_bar=progress_bar)
# 只检测第一个数据是否为dict类型,若是则默认所有返回值为dict;否则报错。
if not isinstance(apply_out[0], dict):
raise Exception("The result of func is not a dict")
if not isinstance(apply_out[0], Mapping):
raise Exception(f"The result of func:{_get_fun_msg(func)} is not a dict, but of type {type(apply_out[0])}")

for key, value in apply_out[0].items():
results[key] = [value]
@@ -789,7 +789,8 @@ class DataSet:
try:
for idx, per_out in enumerate(apply_out[1:]):
if len(set(results.keys()) - set(per_out.keys())):
raise ApplyResultException("apply results have different fields", idx + 1)
raise ApplyResultException(f"Apply results have different fields:{set(results.keys())} and "
f"{set(per_out.keys())}", idx + 1)
for key, value in per_out.items():
results[key].append(value)



+ 1
- 1
fastNLP/core/samplers/reproducible_batch_sampler.py View File

@@ -169,7 +169,7 @@ class RandomBatchSampler(ReproducibleBatchSampler):
:param kwargs: fastNLP 保留使用
"""
def __init__(self, dataset, batch_size:int = 32, shuffle: bool = True,
drop_last: bool = False, seed: int = 0, **kwargs):
drop_last: bool = False, seed: int = None, **kwargs):
super().__init__()

self.dataset = dataset


+ 1
- 1
fastNLP/core/utils/rich_progress.py View File

@@ -120,7 +120,7 @@ class FRichProgress(Progress, metaclass=Singleton):

def add_task(
self,
description: str,
description: str = 'Progress',
start: bool = True,
total: float = 100.0,
completed: int = 0,


+ 4
- 3
fastNLP/io/file_reader.py View File

@@ -7,7 +7,7 @@ __all__ = []
import json
import csv

# from ..core import log
from ..core import logger


def _read_csv(path, encoding='utf-8', headers=None, sep=',', dropna=True):
@@ -81,7 +81,7 @@ def _read_json(path, encoding='utf-8', fields=None, dropna=True):
yield line_idx, _res


def _read_conll(path, encoding='utf-8',sep=None, indexes=None, dropna=True):
def _read_conll(path, encoding='utf-8',sep=None, indexes=None, dropna=True, drophash=True):
r"""
Construct a generator to read conll items.

@@ -91,6 +91,7 @@ def _read_conll(path, encoding='utf-8',sep=None, indexes=None, dropna=True):
:param indexes: conll object's column indexes that needed, if None, all columns are needed. default: None
:param dropna: weather to ignore and drop invalid data,
:if False, raise ValueError when reading invalid data. default: True
:param drophash: 是否丢掉以 # 开头的 line 。
:return: generator, every time yield (line number, conll item)
"""

@@ -121,7 +122,7 @@ def _read_conll(path, encoding='utf-8',sep=None, indexes=None, dropna=True):
sample = []
continue
raise ValueError('Invalid instance which ends at line: {}'.format(line_idx))
elif line.startswith('#'):
elif line.startswith('#') and drophash:
continue
else:
sample.append(line.split(sep)) if sep else sample.append(line.split())


+ 5
- 2
fastNLP/io/loader/conll.py View File

@@ -52,13 +52,14 @@ class ConllLoader(Loader):

"""
def __init__(self, headers, sep=None, indexes=None, dropna=True):
def __init__(self, headers, sep=None, indexes=None, dropna=True, drophash=True):
r"""
:param list headers: 每一列数据的名称,需为List or Tuple of str。``header`` 与 ``indexes`` 一一对应
:param list sep: 指定分隔符,默认为制表符
:param list indexes: 需要保留的数据列下标,从0开始。若为 ``None`` ,则所有列都保留。Default: ``None``
:param bool dropna: 是否忽略非法数据,若 ``False`` ,遇到非法数据时抛出 ``ValueError`` 。Default: ``True``
:param bool drophashtag: 是否忽略以 ``#`` 开头的句子。
"""
super(ConllLoader, self).__init__()
if not isinstance(headers, (list, tuple)):
@@ -66,6 +67,7 @@ class ConllLoader(Loader):
'invalid headers: {}, should be list of strings'.format(headers))
self.headers = headers
self.dropna = dropna
self.drophash = drophash
self.sep=sep
if indexes is None:
self.indexes = list(range(len(self.headers)))
@@ -82,7 +84,8 @@ class ConllLoader(Loader):
:return: DataSet
"""
ds = DataSet()
for idx, data in _read_conll(path,sep=self.sep, indexes=self.indexes, dropna=self.dropna):
for idx, data in _read_conll(path,sep=self.sep, indexes=self.indexes, dropna=self.dropna,
drophash=self.drophash):
ins = {h: data[i] for i, h in enumerate(self.headers)}
ds.append(Instance(**ins))
return ds


+ 1
- 1
tests/core/callbacks/torch_callbacks/test_torch_warmup_callback.py View File

@@ -32,4 +32,4 @@ def test_torch_warmup_callback(warmup, schedule, accumulation_steps):
elif schedule == 'constant':
assert np.allclose(0.1, kwargs['optimizers'].param_groups[0]['lr'])

assert len(r_callback.lrs)<=trainer.total_batches//accumulation_steps+1
assert len(r_callback.lrs)<=trainer.n_batches//accumulation_steps+1

+ 1
- 1
tests/helpers/callbacks/helper_callbacks_torch.py View File

@@ -55,4 +55,4 @@ class RecordAccumulationStepsCallback_Torch(Callback):

def on_train_end(self, trainer):
print(f"\n equal num: {self.equal}.\n")
print(f"\ntotal_batch_num: {trainer.total_batches}.\n")
print(f"\ntotal_batch_num: {trainer.n_batches}.\n")

Loading…
Cancel
Save