|
|
@@ -30,22 +30,12 @@ class Evaluator: |
|
|
|
driver: Driver |
|
|
|
_evaluate_batch_loop: Loop |
|
|
|
|
|
|
|
def __init__( |
|
|
|
self, |
|
|
|
model, |
|
|
|
dataloaders, |
|
|
|
metrics: Optional[Union[Dict, Metric]] = None, |
|
|
|
driver: Union[str, Driver] = 'torch', |
|
|
|
device: Optional[Union[int, List[int], str]] = None, |
|
|
|
batch_step_fn: Optional[callable] = None, |
|
|
|
evaluate_fn: Optional[str] = None, |
|
|
|
input_mapping: Optional[Union[Callable, Dict]] = None, |
|
|
|
output_mapping: Optional[Union[Callable, Dict]] = None, |
|
|
|
model_wo_auto_param_call: bool = False, |
|
|
|
fp16: bool = False, |
|
|
|
verbose: int = 1, |
|
|
|
**kwargs |
|
|
|
): |
|
|
|
def __init__(self, model, dataloaders, metrics: Optional[Union[Dict, Metric]] = None, |
|
|
|
driver: Union[str, Driver] = 'torch', device: Optional[Union[int, List[int], str]] = None, |
|
|
|
evaluate_batch_step_fn: Optional[callable] = None, evaluate_fn: Optional[str] = None, |
|
|
|
input_mapping: Optional[Union[Callable, Dict]] = None, |
|
|
|
output_mapping: Optional[Union[Callable, Dict]] = None, model_wo_auto_param_call: bool = False, |
|
|
|
fp16: bool = False, verbose: int = 1, **kwargs): |
|
|
|
""" |
|
|
|
|
|
|
|
:param model: 待测试的模型,如果传入的 driver 为 Driver 实例,该参数将被忽略。 |
|
|
@@ -54,13 +44,13 @@ class Evaluator: |
|
|
|
metric ,torchmetrics,allennlpmetrics等。 |
|
|
|
:param driver: 使用 driver 。 |
|
|
|
:param device: 使用的设备。 |
|
|
|
:param batch_step_fn: callable的对象,接受 (evaluator, batch) 作为参数,其中 evaluator 为 Evaluator 对象,batch 为 |
|
|
|
DataLoader 中返回的对象。一个 batch_step_fn 的例子可参考 fastNLP.core.controller.loops.evaluate_batch_loop 的 |
|
|
|
batch_step_fn 函数。 |
|
|
|
:param evaluate_batch_step_fn: 定制每次 evaluate batch 执行的函数。该函数应接受的两个参数为 `evaluator` 和 `batch`, |
|
|
|
不需要有返回值;可以参考 fastNLP.core.controllers.loops.evaluate_batch_loop.EvaluateBatchLoop中的batch_step_fn函数。 |
|
|
|
:param evaluate_fn: 用来控制 `Evaluator` 在评测的前向传播过程中是调用哪一个函数,例如是 `model.evaluate_step` 还是 |
|
|
|
`model.forward`;(1) 如果该值是 None,那么我们会默认使用 `evaluate_step` 当做前向传播的函数,如果在模型中没有 |
|
|
|
找到该方法,则使用 `model.forward` 函数;(2) 如果为 str 类型,则尝试从 model 中寻找该方法,找不到则报错。 |
|
|
|
:param input_mapping: 对 dataloader 中输出的内容将通过 input_mapping 处理之后再输入到 model 以及 metric 中 |
|
|
|
:param input_mapping: 对 dataloader 中输出的内容将通过 input_mapping 处理之后再输入到 model 以及 metric 中。如果针对 |
|
|
|
model 和 metric 需要不同的 mapping,请考虑使用 evaluate_batch_step_fn 参数定制。 |
|
|
|
:param output_mapping: 对 model 输出的内容,将通过 output_mapping 处理之后再输入到 metric 中。 |
|
|
|
:param model_wo_auto_param_call: 是否关闭在训练时调用我们的 auto_param_call 来自动匹配 batch 和 forward 函数的参数的行为; |
|
|
|
如果该值为 True,并且当 batch 为字典时,我们会根据 forward 所需要的参数从 batch 中提取对应的对象,传入到 forward 函数中;如果该值 |
|
|
@@ -96,9 +86,9 @@ class Evaluator: |
|
|
|
self.device = device |
|
|
|
self.verbose = verbose |
|
|
|
|
|
|
|
if batch_step_fn is not None: |
|
|
|
_check_valid_parameters_number(batch_step_fn, ['trainer', 'batch'], fn_name='batch_step_fn') |
|
|
|
self.batch_step_fn = batch_step_fn |
|
|
|
if evaluate_batch_step_fn is not None: |
|
|
|
_check_valid_parameters_number(evaluate_batch_step_fn, ['evaluator', 'batch'], fn_name='evaluate_batch_step_fn') |
|
|
|
self.evaluate_batch_step_fn = evaluate_batch_step_fn |
|
|
|
|
|
|
|
self.input_mapping = input_mapping |
|
|
|
self.output_mapping = output_mapping |
|
|
@@ -106,7 +96,7 @@ class Evaluator: |
|
|
|
if not isinstance(dataloaders, dict): |
|
|
|
dataloaders = {None: dataloaders} |
|
|
|
|
|
|
|
self.evaluate_batch_loop = EvaluateBatchLoop(batch_step_fn=batch_step_fn) |
|
|
|
self.evaluate_batch_loop = EvaluateBatchLoop(batch_step_fn=evaluate_batch_step_fn) |
|
|
|
|
|
|
|
self.driver.setup() |
|
|
|
self.driver.barrier() |
|
|
@@ -235,8 +225,8 @@ class Evaluator: |
|
|
|
|
|
|
|
@evaluate_batch_loop.setter |
|
|
|
def evaluate_batch_loop(self, loop: Loop): |
|
|
|
if self.batch_step_fn is not None: |
|
|
|
logger.warning("`batch_step_fn` was customized in the Evaluator initialization, it will be ignored " |
|
|
|
if self.evaluate_batch_step_fn is not None: |
|
|
|
logger.warning("`evaluate_batch_step_fn` was customized in the Evaluator initialization, it will be ignored " |
|
|
|
"when the `evaluate_batch_loop` is also customized.") |
|
|
|
self._evaluate_batch_loop = loop |
|
|
|
|
|
|
@@ -249,15 +239,15 @@ class Evaluator: |
|
|
|
""" |
|
|
|
self.metrics_wrapper.reset() |
|
|
|
|
|
|
|
def update(self, *args, **kwargs): |
|
|
|
def update(self, batch, outputs): |
|
|
|
""" |
|
|
|
调用所有metric的 update 方法,对当前 batch 的结果进行累积,会根据相应 metric 的参数列表进行匹配传参。 |
|
|
|
自动调用所有 metric 的 update 方法,会根据不同 metric 的参数列表进行匹配传参。 |
|
|
|
|
|
|
|
:param args: |
|
|
|
:param kwargs: |
|
|
|
:param batch: 一般是来自于 DataLoader 的输出,如果不为 dict 类型的话,该值将被忽略。 |
|
|
|
:param outputs: 一般是来自于模型的输出。类别应为 dict 或者 dataclass 类型。 |
|
|
|
:return: |
|
|
|
""" |
|
|
|
self.metrics_wrapper.update(*args, **kwargs) |
|
|
|
self.metrics_wrapper.update(batch, outputs) |
|
|
|
|
|
|
|
def get_dataloader_metric(self, dataloader_name: Optional[str] = '') -> Dict: |
|
|
|
""" |
|
|
@@ -271,7 +261,7 @@ class Evaluator: |
|
|
|
@property |
|
|
|
def metrics_wrapper(self): |
|
|
|
""" |
|
|
|
由于需要保持 Evaluator 中 metrics 对象与用户传入的 metrics 保持完全一致(方便他在 batch_step_fn )中使用,同时也为了支持 |
|
|
|
由于需要保持 Evaluator 中 metrics 对象与用户传入的 metrics 保持完全一致(方便他在 evaluate_batch_step_fn )中使用,同时也为了支持 |
|
|
|
不同形式的 metric( fastNLP 的 metric/torchmetrics 等),所以 Evaluator 在进行 metric 操作的时候都调用 metrics_wrapper |
|
|
|
进行操作。 |
|
|
|
|
|
|
@@ -283,11 +273,11 @@ class Evaluator: |
|
|
|
|
|
|
|
def evaluate_step(self, batch): |
|
|
|
""" |
|
|
|
将 batch 传递到model中进行处理,根据当前 evaluate_fn 选择进行 evaluate 还是 test 。会将返回结果经过 output_mapping 处理后再 |
|
|
|
将 batch 传递到model中进行处理,根据当前 evaluate_fn 选择进行 evaluate 。会将返回结果经过 output_mapping 处理后再 |
|
|
|
返回。 |
|
|
|
|
|
|
|
:param batch: |
|
|
|
:return: |
|
|
|
:param batch: {evaluate_fn} 函数支持的输入类型 |
|
|
|
:return: {evaluate_fn} 函数的输出结果,如果有设置 output_mapping ,将是 output_mapping 之后的结果。 |
|
|
|
""" |
|
|
|
outputs = self.driver.model_call(batch, self._evaluate_step, self._evaluate_step_signature_fn) |
|
|
|
outputs = match_and_substitute_params(self.output_mapping, outputs) |
|
|
|