Browse Source

增加部分关于evaluate_batch_step_fn的说明

tags/v1.0.0alpha
yh_cc 3 years ago
parent
commit
6f7bbfabca
3 changed files with 38 additions and 57 deletions
  1. +1
    -1
      fastNLP/core/callbacks/more_evaluate_callback.py
  2. +25
    -35
      fastNLP/core/controllers/evaluator.py
  3. +12
    -21
      fastNLP/core/controllers/trainer.py

+ 1
- 1
fastNLP/core/callbacks/more_evaluate_callback.py View File

@@ -108,7 +108,7 @@ class MoreEvaluateCallback(HasMonitorCallback):
'metrics': self.metrics,
'driver': self.kwargs.get('driver', trainer.driver),
'device': self.kwargs.get('device', trainer.device),
'batch_step_fn': self.kwargs.get('batch_step_fn', trainer.evaluate_batch_step_fn),
'evaluate_batch_step_fn': self.kwargs.get('evaluate_batch_step_fn', trainer.evaluate_batch_step_fn),
'evaluate_fn': self.evaluate_fn,
'input_mapping': self.kwargs.get('input_mapping', trainer.input_mapping),
'output_mapping': self.kwargs.get('output_mapping', trainer.output_mapping),


+ 25
- 35
fastNLP/core/controllers/evaluator.py View File

@@ -30,22 +30,12 @@ class Evaluator:
driver: Driver
_evaluate_batch_loop: Loop

def __init__(
self,
model,
dataloaders,
metrics: Optional[Union[Dict, Metric]] = None,
driver: Union[str, Driver] = 'torch',
device: Optional[Union[int, List[int], str]] = None,
batch_step_fn: Optional[callable] = None,
evaluate_fn: Optional[str] = None,
input_mapping: Optional[Union[Callable, Dict]] = None,
output_mapping: Optional[Union[Callable, Dict]] = None,
model_wo_auto_param_call: bool = False,
fp16: bool = False,
verbose: int = 1,
**kwargs
):
def __init__(self, model, dataloaders, metrics: Optional[Union[Dict, Metric]] = None,
driver: Union[str, Driver] = 'torch', device: Optional[Union[int, List[int], str]] = None,
evaluate_batch_step_fn: Optional[callable] = None, evaluate_fn: Optional[str] = None,
input_mapping: Optional[Union[Callable, Dict]] = None,
output_mapping: Optional[Union[Callable, Dict]] = None, model_wo_auto_param_call: bool = False,
fp16: bool = False, verbose: int = 1, **kwargs):
"""

:param model: 待测试的模型,如果传入的 driver 为 Driver 实例,该参数将被忽略。
@@ -54,13 +44,13 @@ class Evaluator:
metric ,torchmetrics,allennlpmetrics等。
:param driver: 使用 driver 。
:param device: 使用的设备。
:param batch_step_fn: callable的对象,接受 (evaluator, batch) 作为参数,其中 evaluator 为 Evaluator 对象,batch 为
DataLoader 中返回的对象。一个 batch_step_fn 的例子可参考 fastNLP.core.controller.loops.evaluate_batch_loop 的
batch_step_fn 函数。
:param evaluate_batch_step_fn: 定制每次 evaluate batch 执行的函数。该函数应接受的两个参数为 `evaluator` 和 `batch`,
不需要有返回值;可以参考 fastNLP.core.controllers.loops.evaluate_batch_loop.EvaluateBatchLoop中的batch_step_fn函数。
:param evaluate_fn: 用来控制 `Evaluator` 在评测的前向传播过程中是调用哪一个函数,例如是 `model.evaluate_step` 还是
`model.forward`;(1) 如果该值是 None,那么我们会默认使用 `evaluate_step` 当做前向传播的函数,如果在模型中没有
找到该方法,则使用 `model.forward` 函数;(2) 如果为 str 类型,则尝试从 model 中寻找该方法,找不到则报错。
:param input_mapping: 对 dataloader 中输出的内容将通过 input_mapping 处理之后再输入到 model 以及 metric 中
:param input_mapping: 对 dataloader 中输出的内容将通过 input_mapping 处理之后再输入到 model 以及 metric 中。如果针对
model 和 metric 需要不同的 mapping,请考虑使用 evaluate_batch_step_fn 参数定制。
:param output_mapping: 对 model 输出的内容,将通过 output_mapping 处理之后再输入到 metric 中。
:param model_wo_auto_param_call: 是否关闭在训练时调用我们的 auto_param_call 来自动匹配 batch 和 forward 函数的参数的行为;
如果该值为 True,并且当 batch 为字典时,我们会根据 forward 所需要的参数从 batch 中提取对应的对象,传入到 forward 函数中;如果该值
@@ -96,9 +86,9 @@ class Evaluator:
self.device = device
self.verbose = verbose

if batch_step_fn is not None:
_check_valid_parameters_number(batch_step_fn, ['trainer', 'batch'], fn_name='batch_step_fn')
self.batch_step_fn = batch_step_fn
if evaluate_batch_step_fn is not None:
_check_valid_parameters_number(evaluate_batch_step_fn, ['evaluator', 'batch'], fn_name='evaluate_batch_step_fn')
self.evaluate_batch_step_fn = evaluate_batch_step_fn

self.input_mapping = input_mapping
self.output_mapping = output_mapping
@@ -106,7 +96,7 @@ class Evaluator:
if not isinstance(dataloaders, dict):
dataloaders = {None: dataloaders}

self.evaluate_batch_loop = EvaluateBatchLoop(batch_step_fn=batch_step_fn)
self.evaluate_batch_loop = EvaluateBatchLoop(batch_step_fn=evaluate_batch_step_fn)

self.driver.setup()
self.driver.barrier()
@@ -235,8 +225,8 @@ class Evaluator:

@evaluate_batch_loop.setter
def evaluate_batch_loop(self, loop: Loop):
if self.batch_step_fn is not None:
logger.warning("`batch_step_fn` was customized in the Evaluator initialization, it will be ignored "
if self.evaluate_batch_step_fn is not None:
logger.warning("`evaluate_batch_step_fn` was customized in the Evaluator initialization, it will be ignored "
"when the `evaluate_batch_loop` is also customized.")
self._evaluate_batch_loop = loop

@@ -249,15 +239,15 @@ class Evaluator:
"""
self.metrics_wrapper.reset()

def update(self, *args, **kwargs):
def update(self, batch, outputs):
"""
调用所有metric的 update 方法,对当前 batch 的结果进行累积,会根据相应 metric 的参数列表进行匹配传参。
自动调用所有 metric 的 update 方法,会根据不同 metric 的参数列表进行匹配传参。

:param args:
:param kwargs:
:param batch: 一般是来自于 DataLoader 的输出,如果不为 dict 类型的话,该值将被忽略。
:param outputs: 一般是来自于模型的输出。类别应为 dict 或者 dataclass 类型。
:return:
"""
self.metrics_wrapper.update(*args, **kwargs)
self.metrics_wrapper.update(batch, outputs)

def get_dataloader_metric(self, dataloader_name: Optional[str] = '') -> Dict:
"""
@@ -271,7 +261,7 @@ class Evaluator:
@property
def metrics_wrapper(self):
"""
由于需要保持 Evaluator 中 metrics 对象与用户传入的 metrics 保持完全一致(方便他在 batch_step_fn )中使用,同时也为了支持
由于需要保持 Evaluator 中 metrics 对象与用户传入的 metrics 保持完全一致(方便他在 evaluate_batch_step_fn )中使用,同时也为了支持
不同形式的 metric( fastNLP 的 metric/torchmetrics 等),所以 Evaluator 在进行 metric 操作的时候都调用 metrics_wrapper
进行操作。

@@ -283,11 +273,11 @@ class Evaluator:

def evaluate_step(self, batch):
"""
将 batch 传递到model中进行处理,根据当前 evaluate_fn 选择进行 evaluate 还是 test 。会将返回结果经过 output_mapping 处理后再
将 batch 传递到model中进行处理,根据当前 evaluate_fn 选择进行 evaluate 。会将返回结果经过 output_mapping 处理后再
返回。

:param batch:
:return:
:param batch: {evaluate_fn} 函数支持的输入类型
:return: {evaluate_fn} 函数的输出结果,如果有设置 output_mapping ,将是 output_mapping 之后的结果。
"""
outputs = self.driver.model_call(batch, self._evaluate_step, self._evaluate_step_signature_fn)
outputs = match_and_substitute_params(self.output_mapping, outputs)


+ 12
- 21
fastNLP/core/controllers/trainer.py View File

@@ -83,10 +83,10 @@ class Trainer(TrainerEventTrigger):
:param n_epochs: 训练总共的 epoch 的数量,默认为 20;
:param evaluate_dataloaders: 验证数据集,其可以是单独的一个数据集,也可以是多个数据集;当为多个数据集时,注意其必须是 Dict;默认
为 None;
:param batch_step_fn: 用来替换 `TrainBatchLoop` 中的 `batch_step_fn` 函数,注意该函数的两个参数必须为 `trainer` 和
`batch`;默认为 None;
:param evaluate_batch_step_fn: 用来替换 'Evaluator' 中的 `EvaluateBatchLoop` 中的 `batch_step_fn` 函数,注意该函数的
两个参数必须为 `evaluator` 和 `batch`;默认为 None;
:param batch_step_fn: 定制每次 train batch 执行的函数。该函数应接受两个参数为 `trainer` 和`batch`,不需要要返回值;可以
参考 fastNLP.core.controllers.loops.train_batch_loop.TrainBatchLoop中的batch_step_fn函数。
:param evaluate_batch_step_fn: 定制每次 evaluate batch 执行的函数。该函数应接受的两个参数为 `evaluator` 和 `batch`,
不需要有返回值;可以参考 fastNLP.core.controllers.loops.evaluate_batch_loop.EvaluateBatchLoop中的batch_step_fn函数。
:param train_fn: 用来控制 `Trainer` 在训练的前向传播过程中是调用模型的哪一个函数,例如是 `train_step` 还是 `forward`;
默认为 None,如果该值是 None,那么我们会默认使用 `train_step` 当做前向传播的函数,如果在模型中没有找到该方法,
则使用模型默认的前向传播函数。
@@ -136,9 +136,9 @@ class Trainer(TrainerEventTrigger):
默认为 auto , auto 表示如果检测到当前 terminal 为交互型 则使用 RichCallback,否则使用 RawTextCallback对象。如果
需要定制 progress bar 的参数,例如打印频率等,可以传入 RichCallback, RawTextCallback 对象。
train_input_mapping: 与 input_mapping 一致,但是只用于 train 中。与 input_mapping 互斥。
train_output_mapping: 与 output_mapping 一致,但是只用于 train 中。与 input_mapping 互斥。
train_output_mapping: 与 output_mapping 一致,但是只用于 train 中。与 output_mapping 互斥。
evaluate_input_mapping: 与 input_mapping 一致,但是只用于 evaluate 中。与 input_mapping 互斥。
evaluate_output_mapping: 与 output_mapping 一致,但是只用于 evaluate 中。与 input_mapping 互斥。
evaluate_output_mapping: 与 output_mapping 一致,但是只用于 evaluate 中。与 output_mapping 互斥。
"""
self.model = model
self.marker = marker
@@ -268,21 +268,12 @@ class Trainer(TrainerEventTrigger):
progress_bar = kwargs.get('progress_bar', 'auto') # 如果不为
if not (isinstance(progress_bar, str) or progress_bar is None): # 应该是ProgressCallback,获取其名称。
progress_bar = progress_bar.name
self.evaluator = Evaluator(
model=model,
dataloaders=evaluate_dataloaders,
metrics=metrics,
driver=self.driver,
device=device,
batch_step_fn=evaluate_batch_step_fn,
evaluate_fn=evaluate_fn,
input_mapping=input_mapping,
output_mapping=output_mapping,
fp16=fp16,
verbose=0,
use_dist_sampler=kwargs.get("evaluate_use_dist_sampler", None),
progress_bar=progress_bar
)
self.evaluator = Evaluator(model=model, dataloaders=evaluate_dataloaders, metrics=metrics,
driver=self.driver, device=device, evaluate_batch_step_fn=evaluate_batch_step_fn,
evaluate_fn=evaluate_fn, input_mapping=input_mapping,
output_mapping=output_mapping, fp16=fp16, verbose=0,
use_dist_sampler=kwargs.get("evaluate_use_dist_sampler", None),
progress_bar=progress_bar)

if train_fn is not None and not isinstance(train_fn, str):
raise TypeError("Parameter `train_fn` can only be `str` type when it is not None.")


Loading…
Cancel
Save