|
- from typing import Union, List, Optional, Dict, Callable
- from functools import partial
- from dataclasses import is_dataclass
- import sys
-
- __all__ = [
- 'Evaluator'
- ]
-
- from fastNLP.core.drivers import Driver, TorchDriver
- from ..drivers.choose_driver import choose_driver
- from .loops import Loop, EvaluateBatchLoop
- from fastNLP.core.utils import auto_param_call, dataclass_to_dict, \
- match_and_substitute_params, f_rich_progress, flat_nest_dict
- from fastNLP.core.metrics import Metric
- from fastNLP.core.metrics.utils import _is_torchmetrics_metric, _is_paddle_metric, _is_allennlp_metric
- from fastNLP.core.controllers.utils.utils import _TruncatedDataLoader
- from fastNLP.core.utils.utils import _check_valid_parameters_number
- from fastNLP.core.log import logger
-
-
- class Evaluator:
- """
- 用于对数据进行评测。
-
- :param model: 待测试的模型,如果传入的 driver 为 Driver 实例,该参数将被忽略。
- :param dataloaders: 待评测的数据集。如果为多个,请使用 dict 传入。
- :param metrics: 使用的 metric 。必须为 dict 类型,其中 key 为 metric 的名称,value 为一个 Metric 对象。支持 fastNLP 的
- metric ,torchmetrics,allennlpmetrics 等。
- :param driver: 训练模型所使用的具体的驱动模式,应当为以下选择中的一个:``["torch", "jittor", "paddle"]``
- 其中 "torch" 表示使用 ``TorchSingleDriver`` 或者 ``TorchDDPDriver``,具体使用哪一种取决于参数 ``device``
- 的设置。
- :param device: 该参数用来指定具体训练时使用的机器;注意当该参数仅当您通过 `torch.distributed.launch/run` 启动时可以为 None,
- 此时 fastNLP 不会对模型和数据进行设备之间的移动处理,但是你可以通过参数 `input_mapping` 和 `output_mapping` 来实现设备之间
- 数据迁移的工作(通过这两个参数传入两个处理数据的函数);同时你也可以通过在 kwargs 添加参数 "data_device" 来让我们帮助您将数据
- 迁移到指定的机器上(注意这种情况理应只出现在用户在 Trainer 实例化前自己构造 DDP 的场景);
-
- device 的可选输入如下所示:
-
- * *str*: 例如 'cpu', 'cuda', 'cuda:0', 'cuda:1' 等;
- * *torch.device*: 例如 'torch.device("cuda:0")';
- * *int*: 将使用 ``device_id`` 为该值的 ``gpu`` 进行训练;如果值为 -1,那么默认使用全部的显卡,此时使用的 driver 实例是 `TorchDDPDriver`;
- * *list(int)*: 如果多于 1 个device,应当通过该种方式进行设定;注意此时我们一定会使用 ``TorchDDPDriver``,不管您传入的列表的长度是 1 还是其它值;
- * *None*: 仅当用户自己通过训练框架提供的并行训练启动脚本开启 ddp 进程时为 None;
-
- .. note::
-
- 如果希望使用 ``TorchDDPDriver``,在初始化 ``Trainer`` 时您应当使用::
-
- Trainer(driver="torch", device=[0, 1])
-
- 注意如果这时 ``device=[0]``,我们仍旧会使用 ``TorchDDPDriver``。
-
- 如果希望使用 ``TorchSingleDriver``,则在初始化 ``Trainer`` 时您应当使用::
-
- Trainer(driver="torch", device=0)
-
- .. warning::
-
- 注意参数 ``device`` 仅当您通过 pytorch 或者其它训练框架自身的并行训练启动脚本启动 ddp 训练时才允许为 ``None``!
-
- 例如,当您使用::
-
- python -m torch.distributed.launch --nproc_per_node 2 train.py
-
- 来使用 ``TorchDDPDriver`` 时,此时参数 ``device`` 不再有效(不管您是否自己初始化 ``init_process_group``),我们将直接
- 通过 ``torch.device(f"cuda:{local_rank}")`` 来获取当前进程所使用的的具体的 gpu 设备。因此此时您需要使用 ``os.environ["CUDA_VISIBLE_DEVICES"]``
- 来指定要使用的具体的 gpu 设备。
-
- 另一点需要注意的是,当您没有选择自己初始化 ``init_process_group`` 时,我们仍旧会帮助您把模型和数据迁移到当前进程所使用的
- 具体的 gpu 设备上。但是如果您选择自己在 ``Trainer`` 初始化前(意味着在 ``driver`` 的 ``setup`` 前)初始化 ``init_process_group``,
- 那么对于模型的迁移应当完全由您自己来完成。此时对于数据的迁移,如果您在 ``Trainer`` 初始化时指定了参数 ``data_device``,那么
- 我们会将数据迁移到 ``data_device`` 上;如果其为 None,那么将数据迁移到正确的设备上应当由您自己来完成。
-
- 对于使用 ``TorchDDPDriver`` 的更多细节,请见 :class:`fastNLP.core.drivers.torch_driver.TorchDDPDriver`。
-
- :param evaluate_batch_step_fn: 定制每次 evaluate batch 执行的函数。该函数应接受的两个参数为 `evaluator` 和 `batch`,
- 不需要有返回值;可以参考 :meth:`~fastNLP.core.controllers.loops.EvaluateBatchLoop.batch_step_fn`
- 函数。
- :param evaluate_fn: 用来控制 ``Evaluator`` 在评测时调用模型的哪一个函数,例如是 ``evaluate_step`` 还是框架默认的前向接口;
- 默认为 ``None``,如果该值是 ``None``,那么我们会默认使用 ``evaluate_step`` , 如果在模型的定义类中没有找到该方法,
- 则使用模型默认的前向传播函数,例如对于 pytorch 来说就是 ``forward``。
-
- .. note::
- 查找逻辑如下所示:
-
- 1. 如果 ``evaluate_fn`` 为 None,那么在 model 的类 Model 中寻找方法 ``Model.evaluate_step``;如果没有找到,那么默认使用 ``Model.forward``;
- 2. 如果 ``evaluate_fn`` 为一个字符串,例如 'my_step_fn',那么我们首先会在 model 的类 Model 中寻找方法 ``Model.my_step_fn``,
- 如果没有找到,那么会直接报错;
-
- :param input_mapping: 应当为一个字典或者一个函数,表示在当前 step 拿到一个 batch 的数据后,应当做怎样的映射处理:
-
- 1. 如果 ``input_mapping`` 是一个字典:
-
- 1. 如果此时 batch 也是一个 ``Dict``,那么我们会把 batch 中同样在 ``input_mapping`` 中的 key 修改为 ``input_mapping`` 的对应 ``key`` 的 ``value``;
- 2. 如果此时 batch 是一个 ``dataclass``,那么我们会先将其转换为一个 ``Dict``,然后再进行上述转换;
- 3. 如果此时 batch 此时是其它类型,那么我们将会直接报错;
- 2. 如果 ``input_mapping`` 是一个函数,那么对于取出的 batch,我们将不会做任何处理,而是直接将其传入该函数里,并将其返回值
- 送入模型中;
-
- :param output_mapping: 应当为一个字典或者一个函数,表示在当前 step 拿到一个 model 的返回值后,应当做怎样的映射处理:
-
- 1. 如果 ``output_mapping`` 是一个字典:
-
- 1. 如果此时 batch 也是一个 ``Dict``,那么我们会把输出中同样在 ``output_mapping`` 中的 key 修改为 ``output_mapping`` 的对应 ``key`` 的 ``value``;
- 例如输出结果为 {'a': 1},而 output_mapping={'a':'b'} ,那么结果就是 {'b': 1}
- 2. 如果此时 batch 是一个 ``dataclass``,那么我们会先将其转换为一个 ``Dict``,然后再进行上述转换;
- 3. 如果此时 batch 此时是其它类型,那么我们将会直接报错;
- 2. 如果 ``output_mapping`` 是一个函数,我们将会把结果传入该函数里,并将其返回值送入到 metric 中。
-
- :param model_wo_auto_param_call: 是否关闭在训练时调用我们的 auto_param_call 来自动匹配 batch 和 evaluate_fn 函数的参数的行为;
- 如果该值为 True,并且当 batch 为字典时,我们会根据 evaluate_fn 所需要的参数从 batch 中提取对应的对象,传入到 evaluate_fn 函数中;如果该值
- 为 False,那么我们会将 batch 直接透传给 evaluate_fn 函数。
- :param fp16: 是否使用 fp16 。
- :param verbose: 是否打印 evaluate 的结果。
- :kwargs:
- * *torch_kwargs* -- 用于在指定 ``driver`` 为 'torch' 时设定具体 driver 实例的一些参数:
- * ddp_kwargs -- 用于在使用 ``TorchDDPDriver`` 时指定 ``DistributedDataParallel`` 初始化时的参数;例如传入
- {'find_unused_parameters': True} 来解决有参数不参与前向运算导致的报错等;
- * torch_non_blocking -- 表示用于 pytorch 的 tensor 的 to 方法的参数 non_blocking;
- * *data_device* -- 表示如果用户的模型 device (在 Driver 中对应为参数 model_device)为 None 时,我们会将数据迁移到 data_device 上;
- 注意如果 model_device 为 None,那么 data_device 不会起作用;
- * *model_use_eval_mode* (``bool``) --
- 是否在 evaluate 的时候将 model 的状态设置成 eval 状态。在 eval 状态下,model 的
- dropout 与 batch normalization 将会关闭。默认为True。如果为 False,fastNLP 不会对 model 的 evaluate 状态做任何设置。无论
- 该值是什么,fastNLP 都会在 evaluate 接受后将 model 的状态设置为 train 。
- * *use_dist_sampler* --
- 是否使用分布式evaluate的方式。仅当 driver 为分布式类型时,该参数才有效。默认为根据 driver 是否支持
- 分布式进行设置。如果为True,将使得每个进程上的 dataloader 自动使用不同数据,所有进程的数据并集是整个数据集。
- * *output_from_new_proc* --
- 应当为一个字符串,表示在多进程的 driver 中其它进程的输出流应当被做如何处理;其值应当为以下之一:
- ["all", "ignore", "only_error"];当该参数的值不是以上值时,该值应当表示一个文件夹的名字,我们会将其他 rank 的输出流重定向到
- log 文件中,然后将 log 文件保存在通过该参数值设定的文件夹中;默认为 "only_error";
- * *progress_bar* --
- evaluate 的时候显示的 progress bar 。目前支持三种 [None, 'raw', 'rich', 'auto'], auto 表示如果检测
- 到当前terminal为交互型则使用 rich,否则使用 raw。
- """
-
- driver: Driver
- _evaluate_batch_loop: Loop
-
- def __init__(self, model, dataloaders, metrics: Optional[Dict] = None,
- driver: Union[str, Driver] = 'torch', device: Optional[Union[int, List[int], str]] = None,
- evaluate_batch_step_fn: Optional[callable] = None, evaluate_fn: Optional[str] = None,
- input_mapping: Optional[Union[Callable, Dict]] = None,
- output_mapping: Optional[Union[Callable, Dict]] = None, model_wo_auto_param_call: bool = False,
- fp16: bool = False, verbose: int = 1, **kwargs):
- self.model = model
- self.metrics = metrics
- self.driver = choose_driver(model, driver, device, fp16=fp16, model_wo_auto_param_call=model_wo_auto_param_call,
- **kwargs)
-
- if dataloaders is None:
- raise ValueError("Parameter `dataloaders` can not be None.")
- self.dataloaders = dataloaders
- self.device = device
- self.verbose = verbose
-
- if evaluate_batch_step_fn is not None:
- _check_valid_parameters_number(evaluate_batch_step_fn, ['evaluator', 'batch'], fn_name='evaluate_batch_step_fn')
- self.evaluate_batch_step_fn = evaluate_batch_step_fn
-
- self.input_mapping = input_mapping
- self.output_mapping = output_mapping
-
- if not isinstance(dataloaders, dict):
- dataloaders = {None: dataloaders}
-
- self.evaluate_batch_loop = EvaluateBatchLoop(batch_step_fn=evaluate_batch_step_fn)
-
- self.driver.setup()
- self.driver.barrier()
-
- self.separator = kwargs.get('separator', '#')
- self.model_use_eval_mode = kwargs.get('model_use_eval_mode', True)
- use_dist_sampler = kwargs.get("use_dist_sampler", self.driver.is_distributed())
- if use_dist_sampler:
- self._dist_sampler = "unrepeatdist"
- else:
- self._dist_sampler = None
- self._metric_wrapper = None
- _ = self.metrics_wrapper # 触发检查
-
- if evaluate_fn is not None and not isinstance(evaluate_fn, str):
- raise TypeError("Parameter `evaluate_fn` can only be `str` type when it is not None.")
- self._evaluate_step, self._evaluate_step_signature_fn = \
- self.driver.get_model_call_fn("evaluate_step" if evaluate_fn is None else evaluate_fn)
- self.evaluate_fn = evaluate_fn
-
- self.dataloaders = {}
- for name, dl in dataloaders.items(): # 替换为正确的 sampler
- dl = self.driver.set_dist_repro_dataloader(dataloader=dl, dist=self._dist_sampler, reproducible=False)
- self.dataloaders[name] = dl
-
- self.progress_bar = kwargs.get('progress_bar', 'auto')
- if self.progress_bar == 'auto':
- self.progress_bar = 'raw' if f_rich_progress.dummy_rich else 'rich'
-
- self.driver.barrier()
-
- def run(self, num_eval_batch_per_dl: int = -1, **kwargs) -> Dict:
- """
- 返回一个字典类型的数据,其中 key 为 metric 的名字,value 为对应结果。返回的字典中,key 的命名规则如下
- ``metric_indicator_name#metric_name#dataloader_name`` ,其中 ``metric_indicator_name`` 是由使用的 metric 返回的结果
- 决定的,仅当 metric 的结果返回是 dict 类型是才有该值;``metric_name`` 则由 ``Evaluator`` 初始化时传入的 ``metrics`` 参数
- 决定;``dataloader_name``仅在传入的 ``dataloaders`` 为 dict 时会有。此外其中的 ``#`` 符号通过 ``Evaluator`` 初始化
- 参数 ``separator`` 进行设置。
-
- :param num_eval_batch_per_dl: 每个 dataloader 测试前多少个 batch 的数据,-1 为测试所有数据。
- :return:
- """
- assert isinstance(num_eval_batch_per_dl, int), "num_eval_batch_per_dl must be of int type."
- assert num_eval_batch_per_dl > 0 or num_eval_batch_per_dl == -1, "num_eval_batch_per_dl must be -1 or larger than 0."
-
- metric_results = {}
- self.reset()
- evaluate_context = self.driver.get_evaluate_context()
- self.driver.set_model_mode(mode='eval' if self.model_use_eval_mode else 'train')
- with evaluate_context():
- try:
- for dataloader_name, dataloader in self.dataloaders.items():
- self.driver.barrier()
- if num_eval_batch_per_dl != -1:
- dataloader = _TruncatedDataLoader(dataloader, num_eval_batch_per_dl)
- self.driver.set_sampler_epoch(dataloader, -1)
- self.start_progress_bar(total=len(dataloader), dataloader_name=dataloader_name)
- self.cur_dataloader_name = dataloader_name
- results = self.evaluate_batch_loop.run(self, dataloader)
- self.remove_progress_bar(dataloader_name)
- metric_results[dataloader_name] = results
- self.reset()
- self.driver.barrier()
- except BaseException as e:
- self.driver.on_exception()
- raise e
- finally:
- self.finally_progress_bar()
- if len(metric_results) > 0: # 如果 metric 不为 None 需要 print 。
- metric_results = flat_nest_dict(metric_results, separator=self.separator, compress_none_key=True, top_down=False)
- if self.verbose:
- if self.progress_bar == 'rich':
- f_rich_progress.print(metric_results)
- else:
- logger.info(metric_results)
- self.driver.set_model_mode(mode='train')
-
- return metric_results
-
- def start_progress_bar(self, total: int, dataloader_name):
- if self.progress_bar == 'rich':
- if dataloader_name is None:
- desc = f'Eval. Batch:0'
- else:
- desc = f'Eval. on {dataloader_name} Batch:0'
- self._rich_task_id = f_rich_progress.add_task(description=desc, total=total)
- elif self.progress_bar == 'raw':
- desc = 'Evaluation starts'
- if dataloader_name is not None:
- desc += f' on {dataloader_name}'
- logger.info('\n' + "*" * 10 + desc + '*' * 10)
-
- def update_progress_bar(self, batch_idx, dataloader_name, **kwargs):
- if dataloader_name is None:
- desc = f'Eval. Batch:{batch_idx}'
- else:
- desc = f'Eval. on {dataloader_name} Batch:{batch_idx}'
- if self.progress_bar == 'rich':
- assert hasattr(self, '_rich_task_id'), "You must first call `start_progress_bar()` before calling " \
- "update_progress_bar()"
- f_rich_progress.update(self._rich_task_id, description=desc, post_desc=kwargs.get('post_desc', ''),
- advance=kwargs.get('advance', 1), refresh=kwargs.get('refresh', True),
- visible=kwargs.get('visible', True))
- elif self.progress_bar == 'raw':
- if self.verbose > 1:
- logger.info(desc)
-
- def remove_progress_bar(self, dataloader_name):
- if self.progress_bar == 'rich' and hasattr(self, '_rich_task_id'):
- f_rich_progress.destroy_task(self._rich_task_id)
- delattr(self, '_rich_task_id')
- elif self.progress_bar == 'raw':
- desc = 'Evaluation ends'
- if dataloader_name is not None:
- desc += f' on {dataloader_name}'
- logger.info("*" * 10 + desc + '*' * 10 + '\n')
-
- def finally_progress_bar(self):
- if self.progress_bar == 'rich' and hasattr(self, '_rich_task_id'):
- f_rich_progress.destroy_task(self._rich_task_id)
- delattr(self, '_rich_task_id')
-
- @property
- def evaluate_batch_loop(self):
- return self._evaluate_batch_loop
-
- @evaluate_batch_loop.setter
- def evaluate_batch_loop(self, loop: Loop):
- if self.evaluate_batch_step_fn is not None:
- logger.rank_zero_warning("`evaluate_batch_step_fn` was customized in the Evaluator initialization, it will be ignored "
- "when the `evaluate_batch_loop` is also customized.")
- self._evaluate_batch_loop = loop
-
- def reset(self):
- """
- 调用所有 metric 的 reset() 方法,清除累积的状态。
-
- :return:
- """
- self.metrics_wrapper.reset()
-
- def update(self, batch, outputs):
- """
- 自动调用所有 metric 的 update 方法,会根据不同 metric 的参数列表进行匹配传参。
-
- :param batch: 一般是来自于 DataLoader 的输出,如果不为 dict 类型的话,该值将被忽略。
- :param outputs: 一般是来自于模型的输出。类别应为 dict 或者 dataclass 类型。
- :return:
- """
- self.metrics_wrapper.update(batch, outputs)
-
- def get_metric(self) -> Dict:
- """
- 调用所有 metric 的 get_metric 方法,并返回结果。其中 key 为 metric 的名称,value 是各个 metric 的结果。
-
- :return:
- """
- return self.metrics_wrapper.get_metric()
-
- @property
- def metrics_wrapper(self):
- """
- 由于需要保持 Evaluator 中 metrics 对象与用户传入的 metrics 保持完全一致(方便他在 evaluate_batch_step_fn )中使用,同时也为了支持
- 不同形式的 metric( fastNLP 的 metric/torchmetrics 等),所以 Evaluator 在进行 metric 操作的时候都调用 metrics_wrapper
- 进行操作。
-
- Returns:
- """
- if self._metric_wrapper is None:
- self._metric_wrapper = _MetricsWrapper(self.metrics, evaluator=self)
- return self._metric_wrapper
-
- def evaluate_step(self, batch):
- """
- 将 batch 传递到model中进行处理,根据当前 evaluate_fn 选择进行 evaluate 。会将返回结果经过 output_mapping 处理后再
- 返回。
-
- :param batch: {evaluate_fn} 函数支持的输入类型
- :return: {evaluate_fn} 函数的输出结果,如果有设置 output_mapping ,将是 output_mapping 之后的结果。
- """
- outputs = self.driver.model_call(batch, self._evaluate_step, self._evaluate_step_signature_fn)
- outputs = match_and_substitute_params(self.output_mapping, outputs)
- return outputs
-
- @property
- def metrics(self):
- """
- 返回用户传入的 metrics 对象。
-
- :return:
- """
- return self._metrics
-
- @metrics.setter
- def metrics(self, metrics):
- self._metrics = metrics
-
- def move_data_to_device(self, batch):
- return self.driver.move_data_to_device(batch)
-
-
- class _MetricsWrapper:
- """
- 注意 metrics 的输入只支持:Dict[str, Metric];
- 并且通过对 update() , reset() , get_metric() 函数的封装,实现支持 fastNLP 的 metric 以及 torchmetrics 或者更多。
-
- """
-
- def __init__(self, metrics, evaluator):
- self.evaluator = evaluator
- self._metrics = []
- self._metric_names = []
- if metrics is not None:
- if not isinstance(metrics, Dict):
- raise TypeError("Parameter `metrics` can only be `Dict` type.")
- for metric_name, metric in metrics.items():
- # 因为 torchmetrics 是一个 nn.Module,因此我们需要先将其移到对应的机器上;
- if _is_torchmetrics_metric(metric) and isinstance(evaluator.driver, TorchDriver):
- # torchmetrics 是默认自动开启了多卡的
- evaluator.driver.move_model_to_device(metric, evaluator.driver.data_device)
- elif isinstance(metric, Metric):
- # 如果数据是分布式的,但是不aggregate的话可能有问题
- if evaluator._dist_sampler is not None and metric.aggregate_when_get_metric is False:
- logger.rank_zero_warning(
- "You have replace the sampler as distributed sampler when evaluation, but your metric "
- f"{metric_name}:{metric.__class__.__name__}'s `aggregate_when_get_metric` is False.", once=True)
- if metric.aggregate_when_get_metric is None:
- metric.aggregate_when_get_metric = evaluator._dist_sampler is not None
-
- metric.to(evaluator.driver.data_device)
- self._metric_names.append(metric_name)
- self._metrics.append(metric)
-
- def update(self, batch, outputs):
- if is_dataclass(outputs):
- outputs = dataclass_to_dict(outputs)
- for metric in self._metrics:
- args = []
- if not isinstance(batch, dict):
- logger.warning_once(
- f"The output of the DataLoader is of type:`{type(batch)}`, fastNLP will only depend on "
- f"the output of model to update metric.")
- else:
- args.append(batch)
- if not isinstance(outputs, dict):
- raise RuntimeError(f"The output of your model is of type:`{type(outputs)}`, please either directly"
- f" return a dict from your model or use `output_mapping` to convert it into dict type.")
- if isinstance(metric, Metric):
- # 这样在 auto_param_call 报错的时候才清晰。
- auto_param_call(metric.update, outputs, *args, signature_fn=metric.update.__wrapped__)
- elif _is_torchmetrics_metric(metric):
- auto_param_call(metric.update, outputs, *args, signature_fn=metric.update.__wrapped__)
- elif _is_allennlp_metric(metric):
- auto_param_call(metric.__call__, outputs, *args)
- elif _is_paddle_metric(metric):
- res = auto_param_call(metric.compute, outputs, *args)
- metric.update(res)
-
- def reset(self):
- """
- 将 Metric 中的状态重新设置。
-
- :return:
- """
- for metric in self._metrics:
- if _is_allennlp_metric(metric):
- metric.get_metric(reset=True)
- elif _is_torchmetrics_metric(metric) or _is_paddle_metric(metric) or isinstance(metric, Metric):
- metric.reset()
-
- def get_metric(self) -> Dict:
- """
- 调用各个 metric 得到 metric 的结果。并使用 {'metric_name1': metric_results, 'metric_name2': metric_results} 的形式
- 返回。
-
- :return:
- """
- results = {}
- for metric_name, metric in zip(self._metric_names, self._metrics):
- if isinstance(metric, Metric):
- _results = metric.get_metric()
- elif _is_allennlp_metric(metric):
- _results = metric.get_metric(reset=False)
- elif _is_torchmetrics_metric(metric):
- _results = metric.compute()
- elif _is_paddle_metric(metric):
- _results = metric.accumulate()
- else:
- raise RuntimeError(f"Not support `{type(metric)}` for now.")
- if _results is not None:
- results[metric_name] = _results
- else:
- logger.warning_once(f"Metric:{metric_name} returns None when getting metric results.")
- return results
|