diff --git a/fastNLP/core/callbacks/more_evaluate_callback.py b/fastNLP/core/callbacks/more_evaluate_callback.py index dbb6505f..6c015bdf 100644 --- a/fastNLP/core/callbacks/more_evaluate_callback.py +++ b/fastNLP/core/callbacks/more_evaluate_callback.py @@ -108,7 +108,7 @@ class MoreEvaluateCallback(HasMonitorCallback): 'metrics': self.metrics, 'driver': self.kwargs.get('driver', trainer.driver), 'device': self.kwargs.get('device', trainer.device), - 'batch_step_fn': self.kwargs.get('batch_step_fn', trainer.evaluate_batch_step_fn), + 'evaluate_batch_step_fn': self.kwargs.get('evaluate_batch_step_fn', trainer.evaluate_batch_step_fn), 'evaluate_fn': self.evaluate_fn, 'input_mapping': self.kwargs.get('input_mapping', trainer.input_mapping), 'output_mapping': self.kwargs.get('output_mapping', trainer.output_mapping), diff --git a/fastNLP/core/controllers/evaluator.py b/fastNLP/core/controllers/evaluator.py index 95379302..ada31edb 100644 --- a/fastNLP/core/controllers/evaluator.py +++ b/fastNLP/core/controllers/evaluator.py @@ -30,22 +30,12 @@ class Evaluator: driver: Driver _evaluate_batch_loop: Loop - def __init__( - self, - model, - dataloaders, - metrics: Optional[Union[Dict, Metric]] = None, - driver: Union[str, Driver] = 'torch', - device: Optional[Union[int, List[int], str]] = None, - batch_step_fn: Optional[callable] = None, - evaluate_fn: Optional[str] = None, - input_mapping: Optional[Union[Callable, Dict]] = None, - output_mapping: Optional[Union[Callable, Dict]] = None, - model_wo_auto_param_call: bool = False, - fp16: bool = False, - verbose: int = 1, - **kwargs - ): + def __init__(self, model, dataloaders, metrics: Optional[Union[Dict, Metric]] = None, + driver: Union[str, Driver] = 'torch', device: Optional[Union[int, List[int], str]] = None, + evaluate_batch_step_fn: Optional[callable] = None, evaluate_fn: Optional[str] = None, + input_mapping: Optional[Union[Callable, Dict]] = None, + output_mapping: Optional[Union[Callable, Dict]] = None, model_wo_auto_param_call: bool = False, + fp16: bool = False, verbose: int = 1, **kwargs): """ :param model: 待测试的模型,如果传入的 driver 为 Driver 实例,该参数将被忽略。 @@ -54,13 +44,13 @@ class Evaluator: metric ,torchmetrics,allennlpmetrics等。 :param driver: 使用 driver 。 :param device: 使用的设备。 - :param batch_step_fn: callable的对象,接受 (evaluator, batch) 作为参数,其中 evaluator 为 Evaluator 对象,batch 为 - DataLoader 中返回的对象。一个 batch_step_fn 的例子可参考 fastNLP.core.controller.loops.evaluate_batch_loop 的 - batch_step_fn 函数。 + :param evaluate_batch_step_fn: 定制每次 evaluate batch 执行的函数。该函数应接受的两个参数为 `evaluator` 和 `batch`, + 不需要有返回值;可以参考 fastNLP.core.controllers.loops.evaluate_batch_loop.EvaluateBatchLoop中的batch_step_fn函数。 :param evaluate_fn: 用来控制 `Evaluator` 在评测的前向传播过程中是调用哪一个函数,例如是 `model.evaluate_step` 还是 `model.forward`;(1) 如果该值是 None,那么我们会默认使用 `evaluate_step` 当做前向传播的函数,如果在模型中没有 找到该方法,则使用 `model.forward` 函数;(2) 如果为 str 类型,则尝试从 model 中寻找该方法,找不到则报错。 - :param input_mapping: 对 dataloader 中输出的内容将通过 input_mapping 处理之后再输入到 model 以及 metric 中 + :param input_mapping: 对 dataloader 中输出的内容将通过 input_mapping 处理之后再输入到 model 以及 metric 中。如果针对 + model 和 metric 需要不同的 mapping,请考虑使用 evaluate_batch_step_fn 参数定制。 :param output_mapping: 对 model 输出的内容,将通过 output_mapping 处理之后再输入到 metric 中。 :param model_wo_auto_param_call: 是否关闭在训练时调用我们的 auto_param_call 来自动匹配 batch 和 forward 函数的参数的行为; 如果该值为 True,并且当 batch 为字典时,我们会根据 forward 所需要的参数从 batch 中提取对应的对象,传入到 forward 函数中;如果该值 @@ -96,9 +86,9 @@ class Evaluator: self.device = device self.verbose = verbose - if batch_step_fn is not None: - _check_valid_parameters_number(batch_step_fn, ['trainer', 'batch'], fn_name='batch_step_fn') - self.batch_step_fn = batch_step_fn + if evaluate_batch_step_fn is not None: + _check_valid_parameters_number(evaluate_batch_step_fn, ['evaluator', 'batch'], fn_name='evaluate_batch_step_fn') + self.evaluate_batch_step_fn = evaluate_batch_step_fn self.input_mapping = input_mapping self.output_mapping = output_mapping @@ -106,7 +96,7 @@ class Evaluator: if not isinstance(dataloaders, dict): dataloaders = {None: dataloaders} - self.evaluate_batch_loop = EvaluateBatchLoop(batch_step_fn=batch_step_fn) + self.evaluate_batch_loop = EvaluateBatchLoop(batch_step_fn=evaluate_batch_step_fn) self.driver.setup() self.driver.barrier() @@ -235,8 +225,8 @@ class Evaluator: @evaluate_batch_loop.setter def evaluate_batch_loop(self, loop: Loop): - if self.batch_step_fn is not None: - logger.warning("`batch_step_fn` was customized in the Evaluator initialization, it will be ignored " + if self.evaluate_batch_step_fn is not None: + logger.warning("`evaluate_batch_step_fn` was customized in the Evaluator initialization, it will be ignored " "when the `evaluate_batch_loop` is also customized.") self._evaluate_batch_loop = loop @@ -249,15 +239,15 @@ class Evaluator: """ self.metrics_wrapper.reset() - def update(self, *args, **kwargs): + def update(self, batch, outputs): """ - 调用所有metric的 update 方法,对当前 batch 的结果进行累积,会根据相应 metric 的参数列表进行匹配传参。 + 自动调用所有 metric 的 update 方法,会根据不同 metric 的参数列表进行匹配传参。 - :param args: - :param kwargs: + :param batch: 一般是来自于 DataLoader 的输出,如果不为 dict 类型的话,该值将被忽略。 + :param outputs: 一般是来自于模型的输出。类别应为 dict 或者 dataclass 类型。 :return: """ - self.metrics_wrapper.update(*args, **kwargs) + self.metrics_wrapper.update(batch, outputs) def get_dataloader_metric(self, dataloader_name: Optional[str] = '') -> Dict: """ @@ -271,7 +261,7 @@ class Evaluator: @property def metrics_wrapper(self): """ - 由于需要保持 Evaluator 中 metrics 对象与用户传入的 metrics 保持完全一致(方便他在 batch_step_fn )中使用,同时也为了支持 + 由于需要保持 Evaluator 中 metrics 对象与用户传入的 metrics 保持完全一致(方便他在 evaluate_batch_step_fn )中使用,同时也为了支持 不同形式的 metric( fastNLP 的 metric/torchmetrics 等),所以 Evaluator 在进行 metric 操作的时候都调用 metrics_wrapper 进行操作。 @@ -283,11 +273,11 @@ class Evaluator: def evaluate_step(self, batch): """ - 将 batch 传递到model中进行处理,根据当前 evaluate_fn 选择进行 evaluate 还是 test 。会将返回结果经过 output_mapping 处理后再 + 将 batch 传递到model中进行处理,根据当前 evaluate_fn 选择进行 evaluate 。会将返回结果经过 output_mapping 处理后再 返回。 - :param batch: - :return: + :param batch: {evaluate_fn} 函数支持的输入类型 + :return: {evaluate_fn} 函数的输出结果,如果有设置 output_mapping ,将是 output_mapping 之后的结果。 """ outputs = self.driver.model_call(batch, self._evaluate_step, self._evaluate_step_signature_fn) outputs = match_and_substitute_params(self.output_mapping, outputs) diff --git a/fastNLP/core/controllers/trainer.py b/fastNLP/core/controllers/trainer.py index 9a3c30d5..cbec1a01 100644 --- a/fastNLP/core/controllers/trainer.py +++ b/fastNLP/core/controllers/trainer.py @@ -83,10 +83,10 @@ class Trainer(TrainerEventTrigger): :param n_epochs: 训练总共的 epoch 的数量,默认为 20; :param evaluate_dataloaders: 验证数据集,其可以是单独的一个数据集,也可以是多个数据集;当为多个数据集时,注意其必须是 Dict;默认 为 None; - :param batch_step_fn: 用来替换 `TrainBatchLoop` 中的 `batch_step_fn` 函数,注意该函数的两个参数必须为 `trainer` 和 - `batch`;默认为 None; - :param evaluate_batch_step_fn: 用来替换 'Evaluator' 中的 `EvaluateBatchLoop` 中的 `batch_step_fn` 函数,注意该函数的 - 两个参数必须为 `evaluator` 和 `batch`;默认为 None; + :param batch_step_fn: 定制每次 train batch 执行的函数。该函数应接受两个参数为 `trainer` 和`batch`,不需要要返回值;可以 + 参考 fastNLP.core.controllers.loops.train_batch_loop.TrainBatchLoop中的batch_step_fn函数。 + :param evaluate_batch_step_fn: 定制每次 evaluate batch 执行的函数。该函数应接受的两个参数为 `evaluator` 和 `batch`, + 不需要有返回值;可以参考 fastNLP.core.controllers.loops.evaluate_batch_loop.EvaluateBatchLoop中的batch_step_fn函数。 :param train_fn: 用来控制 `Trainer` 在训练的前向传播过程中是调用模型的哪一个函数,例如是 `train_step` 还是 `forward`; 默认为 None,如果该值是 None,那么我们会默认使用 `train_step` 当做前向传播的函数,如果在模型中没有找到该方法, 则使用模型默认的前向传播函数。 @@ -136,9 +136,9 @@ class Trainer(TrainerEventTrigger): 默认为 auto , auto 表示如果检测到当前 terminal 为交互型 则使用 RichCallback,否则使用 RawTextCallback对象。如果 需要定制 progress bar 的参数,例如打印频率等,可以传入 RichCallback, RawTextCallback 对象。 train_input_mapping: 与 input_mapping 一致,但是只用于 train 中。与 input_mapping 互斥。 - train_output_mapping: 与 output_mapping 一致,但是只用于 train 中。与 input_mapping 互斥。 + train_output_mapping: 与 output_mapping 一致,但是只用于 train 中。与 output_mapping 互斥。 evaluate_input_mapping: 与 input_mapping 一致,但是只用于 evaluate 中。与 input_mapping 互斥。 - evaluate_output_mapping: 与 output_mapping 一致,但是只用于 evaluate 中。与 input_mapping 互斥。 + evaluate_output_mapping: 与 output_mapping 一致,但是只用于 evaluate 中。与 output_mapping 互斥。 """ self.model = model self.marker = marker @@ -268,21 +268,12 @@ class Trainer(TrainerEventTrigger): progress_bar = kwargs.get('progress_bar', 'auto') # 如果不为 if not (isinstance(progress_bar, str) or progress_bar is None): # 应该是ProgressCallback,获取其名称。 progress_bar = progress_bar.name - self.evaluator = Evaluator( - model=model, - dataloaders=evaluate_dataloaders, - metrics=metrics, - driver=self.driver, - device=device, - batch_step_fn=evaluate_batch_step_fn, - evaluate_fn=evaluate_fn, - input_mapping=input_mapping, - output_mapping=output_mapping, - fp16=fp16, - verbose=0, - use_dist_sampler=kwargs.get("evaluate_use_dist_sampler", None), - progress_bar=progress_bar - ) + self.evaluator = Evaluator(model=model, dataloaders=evaluate_dataloaders, metrics=metrics, + driver=self.driver, device=device, evaluate_batch_step_fn=evaluate_batch_step_fn, + evaluate_fn=evaluate_fn, input_mapping=input_mapping, + output_mapping=output_mapping, fp16=fp16, verbose=0, + use_dist_sampler=kwargs.get("evaluate_use_dist_sampler", None), + progress_bar=progress_bar) if train_fn is not None and not isinstance(train_fn, str): raise TypeError("Parameter `train_fn` can only be `str` type when it is not None.")