diff --git a/fastNLP/core/metrics/__init__.py b/fastNLP/core/metrics/__init__.py new file mode 100644 index 00000000..4ab0ed36 --- /dev/null +++ b/fastNLP/core/metrics/__init__.py @@ -0,0 +1,18 @@ +__all__ = [ + "Metric", + "Accuracy", + 'Backend', + 'AutoBackend', + 'PaddleBackend', + 'TorchBackend', + 'SpanFPreRecMetric', + 'ClassifyFPreRecMetric', + 'func_post_proc' +] + +from .metric import Metric +from .accuracy import Accuracy +from .backend import Backend, AutoBackend, PaddleBackend, TorchBackend +from .span_f1_pre_rec_metric import SpanFPreRecMetric +from .classify_f1_pre_rec_metric import ClassifyFPreRecMetric +from .utils import func_post_proc diff --git a/fastNLP/core/metrics/accuracy.py b/fastNLP/core/metrics/accuracy.py new file mode 100644 index 00000000..0a60e4d7 --- /dev/null +++ b/fastNLP/core/metrics/accuracy.py @@ -0,0 +1,75 @@ +__all__ = [ + 'Accuracy' +] + +from typing import Union +import warnings + +import numpy as np + +from fastNLP.core.metrics.metric import Metric +from fastNLP.core.metrics.backend import Backend +from fastNLP.core.utils.utils import seq_len_to_mask + + +class Accuracy(Metric): + + def __init__(self, backend: Union[str, Backend, None] = 'auto', + aggregate_when_get_metric: bool = True): + super(Accuracy, self).__init__(backend=backend, aggregate_when_get_metric=aggregate_when_get_metric) + self.register_element(name='correct', value=0, aggregate_method='sum', backend=backend) + self.register_element(name='total', value=0, aggregate_method="sum", backend=backend) + + def get_metric(self) -> dict: + r""" + get_metric函数将根据evaluate函数累计的评价指标统计量来计算最终的评价结果. + + :return dict evaluate_result: {"acc": float} + """ + evaluate_result = {'acc': round(self.correct.get_scalar() / (self.total.get_scalar() + 1e-12), 6)} + return evaluate_result + + def update(self, pred, target, seq_len=None): + r""" + evaluate函数将针对一个批次的预测结果做评价指标的累计 + + :param torch.Tensor pred: 预测的tensor, tensor的形状可以是torch.Size([B,]), torch.Size([B, n_classes]), + torch.Size([B, max_len]), 或者torch.Size([B, max_len, n_classes]) + :param torch.Tensor target: 真实值的tensor, tensor的形状可以是Element's can be: torch.Size([B,]), + torch.Size([B,]), torch.Size([B, max_len]), 或者torch.Size([B, max_len]) + :param torch.Tensor seq_len: 序列长度标记, 标记的形状可以是None, None, torch.Size([B]), 或者torch.Size([B]). + 如果mask也被传进来的话seq_len会被忽略. + """ + # 为了兼容不同框架,我们将输入变量全部转为numpy类型来进行计算。 + pred = self.tensor2numpy(pred) + target = self.tensor2numpy(target) + if seq_len is not None: + seq_len = self.tensor2numpy(seq_len) + + if seq_len is not None and target.ndim > 1: + max_len = target.shape[1] + masks = seq_len_to_mask(seq_len, max_len) + else: + masks = None + + if pred.ndim == target.ndim: + if np.prod(pred.shape) != np.prod(target.shape): + raise RuntimeError(f"when pred have same dimensions with target, they should have same element numbers." + f" while target have shape:{target.shape}, " + f"pred have shape: {target.shape}") + + elif pred.ndim == target.ndim + 1: + pred = pred.argmax(axis=-1) + if seq_len is None and target.ndim > 1: + warnings.warn("You are not passing `seq_len` to exclude pad when calculate accuracy.") + + else: + raise RuntimeError(f"when pred havesize:{pred.shape}, target should have size: {pred.shape} or " + f"{pred.shape[:-1]}, got {target.shape}.") + + if masks is not None: + self.total += masks.sum().item() + self.correct += ((pred == target) * masks).sum().item() + else: + self.total += np.prod(list(pred.shape)).item() + self.correct += (target == pred).sum().item() diff --git a/fastNLP/core/metrics/backend/__init__.py b/fastNLP/core/metrics/backend/__init__.py new file mode 100644 index 00000000..196c9c1b --- /dev/null +++ b/fastNLP/core/metrics/backend/__init__.py @@ -0,0 +1,12 @@ +__all__ = [ + 'Backend', + 'AutoBackend', + 'TorchBackend', + 'PaddleBackend' +] + + +from .backend import Backend +from .auto_backend import AutoBackend +from .torch_backend.backend import TorchBackend +from .paddle_backend.backend import PaddleBackend diff --git a/fastNLP/core/metrics/backend/auto_backend.py b/fastNLP/core/metrics/backend/auto_backend.py new file mode 100644 index 00000000..87ef2393 --- /dev/null +++ b/fastNLP/core/metrics/backend/auto_backend.py @@ -0,0 +1,75 @@ +from typing import Union + +from .backend import Backend +from .torch_backend.backend import TorchBackend +from .paddle_backend.backend import PaddleBackend +from .jittor_backend.backend import JittorBackend + + +class AutoBackend(Backend): + """ + 不需要初始化backend的AutoBackend,能够根据get_metric时候判断输入数据类型来选择backend是什么类型的 + + """ + + def __init__(self, backend: Union[str, Backend, None]): + super(AutoBackend, self).__init__() + if backend != 'auto': + self._convert_backend(backend) + + def _convert_backend(self, backend): + """ + 将AutoBackend转换为合适的Backend对象 + + """ + if isinstance(backend, Backend): + self.__class__ = backend.__class__ + # 如果是str,直接选择就好了 + elif backend == 'torch': + self.__class__ = TorchBackend + elif backend == 'paddle': + self.__class__ = PaddleBackend + elif backend == 'jittor': + self.__class__ = JittorBackend + elif backend is None: + # 不用做任何事情就可以初始化了 + pass + else: + raise RuntimeError(f"We did not support `{backend}` to be used as backend for now.") + self._specified = True + + def choose_real_backend(self, args): + assert not self.is_specified(), "This method should not be called after backend has been specified. " \ + "This must be a bug, please report." + types = [] + for arg in args: + types.append(str(type(arg))) + + torch_types = [] + jittor_types = [] + paddle_types = [] + for type_name in types: + if 'torch' in type_name: + torch_types.append(type_name) + if 'paddle' in type_name: + paddle_types.append(type_name) + if 'jittor' in type_name: + jittor_types.append(type_name) + + # 根据 https://stackoverflow.com/a/3464154 ,可以通过这种方法实现切换成真实的 backend 上 + if len(torch_types) > 0 and len(jittor_types) == 0 and len(paddle_types) == 0: + backend = 'torch' + elif len(torch_types) == 0 and len(jittor_types) > 0 and len(paddle_types) == 0: + backend = 'jittor' + elif len(torch_types) == 0 and len(jittor_types) == 0 and len(paddle_types) > 0: + backend = 'paddle' + elif len(torch_types) == 0 and len(jittor_types) == 0 and len(paddle_types) == 0: + # 直接使用default的backend就好了 + backend = None + else: + types = list(set(torch_types + jittor_types + paddle_types)) + raise RuntimeError( + f"Mixture of tensor type:{types} have been accept, please manually set backend instead of " + f"using backend=auto.") + + self._convert_backend(backend) diff --git a/fastNLP/core/metrics/backend/backend.py b/fastNLP/core/metrics/backend/backend.py new file mode 100644 index 00000000..b9d6ca78 --- /dev/null +++ b/fastNLP/core/metrics/backend/backend.py @@ -0,0 +1,75 @@ +from ..utils import AggregateMethodError + + +class Backend: + """ + Backend 及其子类的所有方法都必须是无状态的。 + + """ + + def __init__(self): + self._specified = False + + def aggregate(self, tensor, method: str): + """ + 聚集结果,并根据method计算后,返回结果 + """ + if method is not None: + return AggregateMethodError(should_have_aggregate_method=False, only_warn=True) + + return tensor + + def create_tensor(self, value: float): + """ + 创建tensor,并且填入value作为值 + """ + return value + + def fill_value(self, tensor, value: float): + """ + 将tensor的值设置为value + + """ + return value + + def get_scalar(self, tensor) -> float: + """ + tensor的saclar值 + + :param tensor: + :return: + """ + return tensor + + def is_specified(self) -> bool: + """ + 判断是否是某种框架的backend + + :return: + """ + return self._specified + + def tensor2numpy(self, tensor): + """ + 将tensor转为numpy + + :param tensor: + :return: + """ + return tensor + + def move_tensor_to_device(self, tensor, device): + """ + """ + return tensor + + def all_gather_object(self, obj, group=None): + """ + 给定 obj 将各个 rank 上的 obj 汇总到每个 obj 上。返回一个 list 对象,里面依次为各个 rank 对应的 obj 。 + + :param obj: + :param group: + :return: + """ + raise NotImplementedError(f"all_gather_object() function is not implemented for {self.__class__.__name__}.") + diff --git a/fastNLP/core/metrics/backend/jittor_backend/__init__.py b/fastNLP/core/metrics/backend/jittor_backend/__init__.py new file mode 100644 index 00000000..8b137891 --- /dev/null +++ b/fastNLP/core/metrics/backend/jittor_backend/__init__.py @@ -0,0 +1 @@ + diff --git a/fastNLP/core/metrics/backend/jittor_backend/backend.py b/fastNLP/core/metrics/backend/jittor_backend/backend.py new file mode 100644 index 00000000..44831a57 --- /dev/null +++ b/fastNLP/core/metrics/backend/jittor_backend/backend.py @@ -0,0 +1,72 @@ +import numpy as np + +from fastNLP.envs.imports import _NEED_IMPORT_JITTOR +from fastNLP.core.metrics.backend import Backend + +if _NEED_IMPORT_JITTOR: + import jittor + + +class JittorBackend(Backend): + + def __init__(self): + super(JittorBackend, self).__init__() + self._specified = True + + def aggregate(self, tensor, method: str): + """ + 聚集结果,并根据method计算后,返回结果 + """ + return tensor + + def create_tensor(self, value: float): + """ + 创建tensor,并且填入value作为值 + """ + value = jittor.Var(value) + return value + + def fill_value(self, tensor, value: float): + """ + 将tensor的值设置为value + + """ + value = jittor.full_like(tensor, value) + return value + + def get_scalar(self, tensor) -> float: + """ + tensor的saclar值 + + :param tensor: + :return: + """ + return tensor.item() + + def is_specified(self) -> bool: + """ + 判断是否是某种框架的backend + + :return: + """ + return self._specified + + def tensor2numpy(self, tensor): + """ + 将tensor转为numpy + + :param tensor: + :return: + """ + if isinstance(tensor, jittor.Var): + return tensor.detach().numpy() + elif isinstance(tensor, np.array): + return tensor + else: + raise ValueError(f"tensor: {tensor} can not convert to ndarray!") + + def move_tensor_to_device(self, tensor, device): + """ + jittor的没有转移设备的函数,因此该函数实际上无效 + """ + return tensor diff --git a/fastNLP/core/metrics/backend/paddle_backend/__init__.py b/fastNLP/core/metrics/backend/paddle_backend/__init__.py new file mode 100644 index 00000000..f1f5a64b --- /dev/null +++ b/fastNLP/core/metrics/backend/paddle_backend/__init__.py @@ -0,0 +1,5 @@ +__all__ = [ + 'PaddleBackend' +] + +from .backend import Backend as PaddleBackend diff --git a/fastNLP/core/metrics/backend/paddle_backend/backend.py b/fastNLP/core/metrics/backend/paddle_backend/backend.py new file mode 100644 index 00000000..4028fcf4 --- /dev/null +++ b/fastNLP/core/metrics/backend/paddle_backend/backend.py @@ -0,0 +1,126 @@ +from typing import List, Optional, Any + +import numpy as np + +from fastNLP.core.metrics.backend import Backend +from fastNLP.core.utils.paddle_utils import paddle_to +from fastNLP.core.metrics.utils import AggregateMethodError +from fastNLP.core.utils import is_in_paddle_dist +from fastNLP.core.drivers.paddle_driver.utils import get_device_from_visible +from fastNLP.envs.imports import _NEED_IMPORT_PADDLE + +if _NEED_IMPORT_PADDLE: + import paddle + from paddle.fluid.dygraph import parallel_helper + +def _simple_gather_all_tensors(result, group: Any, world_size: int) -> List: + gathered_result = [paddle.zeros_like(result) for _ in range(world_size)] + paddle.distributed.all_gather(gathered_result, result, group) + return gathered_result + +class PaddleBackend(Backend): + def __init__(self): + super().__init__() + self._specified = True + + def aggregate(self, tensor, method: str): + """ + 聚集结果,并根据method计算后,返回结果 + """ + if isinstance(tensor, paddle.Tensor): + if parallel_helper._is_parallel_ctx_initialized(): + if method is None: + raise AggregateMethodError(should_have_aggregate_method=True) + tensor = self._gather_all(tensor) + if isinstance(tensor[0], paddle.Tensor): + tensor = paddle.stack(tensor) + # 第一步, aggregate结果 + if method == 'sum': + tensor = paddle.sum(tensor, dim=0) + elif method == 'mean': + tensor = paddle.mean(tensor, dim=0) + elif method == 'max': + tensor, _ = paddle.max(tensor, dim=0) + elif method == 'min': + tensor, _ = paddle.min(tensor, dim=0) + else: + raise AggregateMethodError(should_have_aggregate_method=False) + + return tensor + + def create_tensor(self, value: float): + """ + 创建tensor,并且填入value作为值 + """ + tensor = paddle.ones((1,)).fill_(value) + return tensor + + def fill_value(self, tensor, value: float): + """ + 将tensor的值设置为value + + """ + tensor.fill_(value) + return tensor + + def get_scalar(self, tensor) -> float: + return tensor.item() + + def tensor2numpy(self, tensor) -> np.array: + if isinstance(tensor, paddle.Tensor): + return tensor.cpu().detach().numpy() + elif isinstance(tensor, np.array): + return tensor + else: + raise ValueError(f"tensor: {tensor} can not convert to ndarray!") + + @staticmethod + def _gather_all(result, group: Optional[Any] = None) -> List: + """ + 聚合 group 中所有的 result;由于不同 group 中 result 大小不同,因此在适当的时候需要进行 padding + """ + # TODO check 正确性 + if group is None: + group = paddle.distributed.get_group(0) + + world_size = group.nranks + paddle.distributed.barrier(group=group) + + # 张量为 标量的情况,简单地gather就好 + if result.ndim == 0: + return _simple_gather_all_tensors(result, group, world_size) + + # 获得 result 的 shape + local_size = paddle.to_tensor(result.shape) + # 将 group 中所有 result 的大小聚合在一起 + local_sizes = [paddle.zeros_like(local_size) for _ in range(world_size)] + paddle.distributed.all_gather(local_sizes, local_size, group=group) + # 堆叠后,计算出 shape 每一维度的最大值 + max_size = paddle.stack(local_sizes).max(axis=0).values + all_sizes_equal = all(all(ls == max_size) for ls in local_sizes) + + # 如果所有的结果大小相同,那么可以直接聚合 + if all_sizes_equal: + return _simple_gather_all_tensors(result, group, world_size) + + # 否则,padding 与最大的张量对齐 + pad_dims = [] + pad_by = (max_size - local_size).detach().cpu() + for val in reversed(pad_by): + pad_dims.append(0) + pad_dims.append(val.item()) + result_padded = paddle.nn.functional.pad(result, pad_dims) + # 重新进行聚合 + gathered_result = [paddle.zeros_like(result_padded) for _ in range(world_size)] + paddle.distributed.all_gather(gathered_result, result_padded, group) + for idx, item_size in enumerate(local_sizes): + slice_param = [slice(dim_size) for dim_size in item_size] + gathered_result[idx] = gathered_result[idx][slice_param] + return gathered_result + + def move_tensor_to_device(self, tensor, device): + # TODO 如果在这里处理的话,会不会在别的地方引起bug? + if is_in_paddle_dist(): + device = get_device_from_visible(device) + return paddle_to(tensor, device) + diff --git a/fastNLP/core/metrics/backend/torch_backend/__init__.py b/fastNLP/core/metrics/backend/torch_backend/__init__.py new file mode 100644 index 00000000..37312e4d --- /dev/null +++ b/fastNLP/core/metrics/backend/torch_backend/__init__.py @@ -0,0 +1,6 @@ +__all__ = [ + 'TorchBackend' +] + + +from .backend import Backend as TorchBackend diff --git a/fastNLP/core/metrics/backend/torch_backend/backend.py b/fastNLP/core/metrics/backend/torch_backend/backend.py new file mode 100644 index 00000000..06304a98 --- /dev/null +++ b/fastNLP/core/metrics/backend/torch_backend/backend.py @@ -0,0 +1,154 @@ +from typing import Any, List, Optional + +import numpy as np + +from fastNLP.core.metrics.backend import Backend +from fastNLP.core.metrics.utils import AggregateMethodError +from fastNLP.envs.imports import _NEED_IMPORT_TORCH +from fastNLP.core.drivers.torch_driver.dist_utils import fastnlp_torch_all_gather + + +if _NEED_IMPORT_TORCH: + import torch + import torch.distributed as dist + import torch.nn.functional as F + + +def _simple_gather_all_tensors(result, group: Any, world_size: int) -> List: + gathered_result = [torch.zeros_like(result) for _ in range(world_size)] + dist.all_gather(gathered_result, result, group) + return gathered_result + + +class TorchBackend(Backend): + def __init__(self): + super().__init__() + self._specified = True + + def aggregate(self, tensor, method: str): + """ + 聚集结果,并根据method计算后,返回结果。 + """ + if isinstance(tensor, torch.Tensor): + if dist.is_initialized(): + if method is None: + raise AggregateMethodError(should_have_aggregate_method=True) + tensor = self._gather_all(tensor) + if isinstance(tensor[0], torch.Tensor): + tensor = torch.stack(tensor) + # 第一步, aggregate结果 + if method == 'sum': + tensor = torch.sum(tensor, dim=0) + elif method == 'mean': + tensor = torch.mean(tensor, dim=0) + elif method == 'max': + tensor, _ = torch.max(tensor, dim=0) + elif method == 'min': + tensor, _ = torch.min(tensor, dim=0) + else: + raise AggregateMethodError(should_have_aggregate_method=False) + + return tensor + + def create_tensor(self, value: float): + """ + 创建tensor,并且填入value作为值 + """ + tensor = torch.ones(1).fill_(value) + return tensor + + def fill_value(self, tensor, value: float): + """ + 将tensor的值设置为value + + """ + tensor.fill_(value) + return tensor + + def get_scalar(self, tensor) -> float: + return tensor.item() + + @staticmethod + def _gather_all(result, group: Optional[Any] = None) -> List: + """Function to gather all tensors from several ddp processes onto a list that is broadcasted to all processes. + Works on tensors that have the same number of dimensions, but where each dimension may differ. In this case + tensors are padded, gathered and then trimmed to secure equal workload for all processes. + + Args: + result: the value to sync + group: the process group to gather results from. Defaults to all processes (world) + + Return: + gathered_result: list with size equal to the process group where + gathered_result[i] corresponds to result tensor from process i + """ + + if group is None: + group = dist.group.WORLD + + # convert tensors to contiguous format + result = result.contiguous() + + world_size = dist.get_world_size(group) + dist.barrier(group=group) + + # if the tensor is scalar, things are easy + if result.ndim == 0: + return _simple_gather_all_tensors(result, group, world_size) + + # 1. Gather sizes of all tensors + local_size = torch.tensor(result.shape, device=result.device) + local_sizes = [torch.zeros_like(local_size) for _ in range(world_size)] + dist.all_gather(local_sizes, local_size, group=group) + max_size = torch.stack(local_sizes).max(dim=0).values + all_sizes_equal = all(all(ls == max_size) for ls in local_sizes) + + # 2. If shapes are all the same, then do a simple gather: + if all_sizes_equal: + return _simple_gather_all_tensors(result, group, world_size) + + # 3. If not, we need to pad each local tensor to maximum size, gather and then truncate + pad_dims = [] + pad_by = (max_size - local_size).detach().cpu() + for val in reversed(pad_by): + pad_dims.append(0) + pad_dims.append(val.item()) + result_padded = torch.nn.functional.pad(result, pad_dims) + gathered_result = [torch.zeros_like(result_padded) for _ in range(world_size)] + dist.all_gather(gathered_result, result_padded, group) + for idx, item_size in enumerate(local_sizes): + slice_param = [slice(dim_size) for dim_size in item_size] + gathered_result[idx] = gathered_result[idx][slice_param] + return gathered_result + + def tensor2numpy(self, tensor) -> np.array: + """ + 将对应的tensor转为numpy对象 + + """ + + if isinstance(tensor, torch.Tensor): + return tensor.cpu().detach().numpy() + elif isinstance(tensor, np.ndarray): + return tensor + elif isinstance(tensor, (float, int)): + return tensor + else: + raise ValueError(f"tensor: {tensor} can not convert to ndarray!") + + @staticmethod + def is_distributed() -> bool: + """ + :return: + """ + return dist.is_available() and dist.is_initialized() + + def move_tensor_to_device(self, tensor, device): + return tensor.to(device) + + def all_gather_object(self, obj, group=None) -> List: + if self.is_distributed(): + obj_list = fastnlp_torch_all_gather(obj, group=group) + return obj_list + return [obj] + diff --git a/fastNLP/core/metrics/classify_f1_pre_rec_metric.py b/fastNLP/core/metrics/classify_f1_pre_rec_metric.py new file mode 100644 index 00000000..a2a62d66 --- /dev/null +++ b/fastNLP/core/metrics/classify_f1_pre_rec_metric.py @@ -0,0 +1,142 @@ +__all__ = [ + 'ClassifyFPreRecMetric' +] + +from typing import Union, List +from collections import defaultdict +from functools import partial +import warnings + +from .metric import Metric +from .backend import Backend +from fastNLP.core.vocabulary import Vocabulary +from fastNLP.core.utils.utils import seq_len_to_mask + + +def _compute_f_pre_rec(beta_square, tp, fn, fp): + r""" + + :param tp: int, true positive + :param fn: int, false negative + :param fp: int, false positive + :return: (f, pre, rec) + """ + pre = tp / (fp + tp + 1e-13) + rec = tp / (fn + tp + 1e-13) + f = (1 + beta_square) * pre * rec / (beta_square * pre + rec + 1e-13) + + return f, pre, rec + + +class ClassifyFPreRecMetric(Metric): + def __init__(self, backend: Union[str, Backend, None] = 'auto', aggregate_when_get_metric: bool = False, + tag_vocab: Vocabulary = None, encoding_type: str = None, ignore_labels: List[str] = None, + only_gross: bool = True, f_type='micro', beta=1) -> None: + super(ClassifyFPreRecMetric, self).__init__(backend=backend, + aggregate_when_get_metric=aggregate_when_get_metric) + if f_type not in ('micro', 'macro'): + raise ValueError("f_type only supports `micro` or `macro`', got {}.".format(f_type)) + + self.ignore_labels = ignore_labels + self.f_type = f_type + self.beta = beta + self.beta_square = self.beta ** 2 + self.only_gross = only_gross + + self.tag_vocab = tag_vocab + + self._tp, self._fp, self._fn = defaultdict(partial(self.register_element, aggregate_method='sum')),\ + defaultdict(partial(self.register_element, aggregate_method='sum')),\ + defaultdict(partial(self.register_element, aggregate_method='sum')) + + def get_metric(self) -> dict: + r""" + get_metric函数将根据evaluate函数累计的评价指标统计量来计算最终的评价结果. + + :return dict evaluate_result: {"acc": float} + """ + evaluate_result = {} + if not self.only_gross or self.f_type == 'macro': + tags = set(self._fn.keys()) + tags.update(set(self._fp.keys())) + tags.update(set(self._tp.keys())) + f_sum = 0 + pre_sum = 0 + rec_sum = 0 + for tag in tags: + if self.tag_vocab is not None: + tag_name = self.tag_vocab.to_word(tag) + else: + tag_name = int(tag) + tp = self._tp[tag] + fn = self._fn[tag] + fp = self._fp[tag] + f, pre, rec = _compute_f_pre_rec(self.beta_square, tp, fn, fp) + f_sum += f + pre_sum += pre + rec_sum += rec + if not self.only_gross and tag != '': # tag!=''防止无tag的情况 + f_key = 'f-{}'.format(tag_name) + pre_key = 'pre-{}'.format(tag_name) + rec_key = 'rec-{}'.format(tag_name) + evaluate_result[f_key] = f + evaluate_result[pre_key] = pre + evaluate_result[rec_key] = rec + + if self.f_type == 'macro': + evaluate_result['f'] = f_sum / len(tags) + evaluate_result['pre'] = pre_sum / len(tags) + evaluate_result['rec'] = rec_sum / len(tags) + + if self.f_type == 'micro': + f, pre, rec = _compute_f_pre_rec(self.beta_square, + sum(self._tp.values()), + sum(self._fn.values()), + sum(self._fp.values())) + evaluate_result['f'] = f + evaluate_result['pre'] = pre + evaluate_result['rec'] = rec + + + for key, value in evaluate_result.items(): + evaluate_result[key] = round(value, 6) + + return evaluate_result + + def update(self, pred, target, seq_len=None): + pred = self.tensor2numpy(pred) + target = self.tensor2numpy(target) + if seq_len is not None: + seq_len = self.tensor2numpy(seq_len) + + if seq_len is not None and target.ndim > 1: + max_len = target.ndim[-1] + masks = seq_len_to_mask(seq_len=seq_len, max_len=max_len) + else: + masks = None + + if pred.ndim == target.ndim: + if len(pred.flatten()) != len(target.flatten()): + raise RuntimeError(f"when pred have same dimensions with target, they should have same element numbers." + f" while target have element numbers:{len(pred.flatten())}, " + f"pred have element numbers: {len(target.flatten())}") + + pass + elif len(pred.ndim) == len(target.ndim) + 1: + pred = pred.argmax(axis=-1) + if seq_len is None and len(target.ndim) > 1: + warnings.warn("You are not passing `seq_len` to exclude pad when calculate accuracy.") + else: + raise RuntimeError(f"when pred have " + f"size:{pred.ndim}, target should have size: {pred.ndim} or " + f"{pred.ndim[:-1]}, got {target.ndim}.") + if masks is not None: + target = target * masks + pred = pred * masks + target_idxes = set(target.reshape(-1).tolist()) + for target_idx in target_idxes: + self._tp[target_idx] += ((pred == target_idx) * (target != target_idx)).sum().item() + self._fp[target_idx] += ((pred == target_idx) * (target == target_idx)).sum().item() + self._fn[target_idx] += ((pred != target_idx) * (target != target_idx)).sum().item() + + diff --git a/fastNLP/core/metrics/element.py b/fastNLP/core/metrics/element.py new file mode 100644 index 00000000..b3a496bf --- /dev/null +++ b/fastNLP/core/metrics/element.py @@ -0,0 +1,281 @@ +__all__ = [ + 'Element' +] + +import os + +from .backend import Backend, AutoBackend +from fastNLP.core.log import logger +from .utils import AggregateMethodError +from fastNLP.envs.env import FASTNLP_GLOBAL_RANK + + +class Element: + def __init__(self, value: float, aggregate_method, backend: Backend, name=None): + self.init_value = value + self.aggregate_method = aggregate_method + self.name = name + if backend == 'auto': + raise RuntimeError("You have to specify the backend.") + elif isinstance(backend, AutoBackend): + self.backend = backend + else: + self.backend = AutoBackend(backend) + + if self.backend.is_specified(): + value = self.backend.create_tensor(self.init_value) + else: + value = None + self._value = value + self.device = None + + def aggregate(self): + """ + 自动aggregate对应的元素 + + """ + try: + self._value = self.backend.aggregate(self._value, self.aggregate_method) + except AggregateMethodError as e: + msg = 'If you see this message, please report a bug.' + if self.name and e.should_have_aggregate_method: + msg = f"Element:{self.name} has no specified `aggregate_method`." + elif e.should_have_aggregate_method: + msg = "Element has no specified `aggregate_method`." + elif self.name and not e.should_have_aggregate_method: + msg = f"Element:{self.name}'s backend:{self.backend.__class__.__name__} does not support " \ + f'aggregate_method:{self.aggregate_method}.' + elif not e.should_have_aggregate_method: + msg = f"Element's backend:{self.backend.__class__.__name__} does not support " \ + f'aggregate_method:{self.aggregate_method}.' + if e.only_warn: + if int(os.environ.get(FASTNLP_GLOBAL_RANK, 0)) == 0: + logger.warning(msg) + self._value = self.backend.aggregate(self._value, method=None) + else: + raise RuntimeError(msg) + + def reset(self): + if self.backend.is_specified(): + self._value = self.backend.fill_value(self._value, self.init_value) + + @property + def value(self): + return self._value + + @value.setter + def value(self, value): + self._check_value_initialized() + self._value = value + + @value.getter + def value(self): + self._check_value_initialized() + return self._value + + def get_scalar(self) -> float: + return self.backend.get_scalar(self._value) + + def fill_value(self, value): + self._value = self.backend.fill_value(self._value, value) + + def to(self, device): + # device这里如何处理呢? + if self._value is not None: + self._value = self.backend.move_tensor_to_device(self._value, device) + self.device = device + + def _check_value_initialized(self): + if self._value is None: + assert self.backend.is_specified(), f"Backend is not specified, please specify backend in the Metric " \ + f"initialization." + self._value = self.backend.create_tensor(self.init_value) + if self.device is not None: + self.to(device=self.device) + + def _check_value_when_call(self): + if self.value is None: + prefix = f'Element:`{self.name}`' if self.name else 'Element' + raise RuntimeError(prefix + " is not initialized. Please either specify backend when creating this " + "element, or use it after it being used by the `Metric.compute()` method.") + + def __add__(self, other): + self._check_value_when_call() + if isinstance(other, Element): + self.value += other.value + else: + self.value += other + return self + + def __radd__(self, other): + self._check_value_when_call() + if isinstance(other, Element): + self.value += other.value + else: + self.value += other + return self + + def __sub__(self, other): + self._check_value_when_call() + if isinstance(other, Element): + self.value -= other.value + else: + self.value -= other + return self + + def __rsub__(self, other): + self._check_value_when_call() + if isinstance(other, Element): + self.value -= other.value + else: + self.value -= other + return self + + def __mul__(self, other): + self._check_value_when_call() + if isinstance(other, Element): + self.value *= other.value + else: + self.value *= other + return self + + def __imul__(self, other): + self._check_value_when_call() + if isinstance(other, Element): + self.value *= other.value + else: + self.value *= other + return self + + def __floordiv__(self, other): + self._check_value_when_call() + if isinstance(other, Element): + self.value //= other.value + else: + self.value //= other + return self + + def __rfloordiv__(self, other): + self._check_value_when_call() + if isinstance(other, Element): + self.value //= other.value + else: + self.value //= other + return self + + def __truediv__(self, other): + self._check_value_when_call() + if isinstance(other, Element): + self.value /= other.value + else: + self.value /= other + return self + + def __rtruediv__(self, other): + self._check_value_when_call() + if isinstance(other, Element): + self.value /= other.value + else: + self.value /= other + return self + + def __mod__(self, other): + self._check_value_when_call() + if isinstance(other, Element): + self.value %= other.value + else: + self.value %= other + return self + + def __rmod__(self, other): + self._check_value_when_call() + if isinstance(other, Element): + self.value /= other.value + else: + self.value /= other + return self + + def __pow__(self, other, modulo=None): + self._check_value_when_call() + if modulo is None: + if isinstance(other, Element): + self.value **= other.value + else: + self.value **= other + else: + if isinstance(other, Element): + self.value = pow(self.value, other.value, modulo) + else: + self.value = pow(self.value, other, modulo) + return self + + def __rpow__(self, other): + self._check_value_when_call() + if isinstance(other, Element): + self.value **= other.value + else: + self.value **= other + return self + + def __lt__(self, other) -> bool: + self._check_value_when_call() + if isinstance(other, Element): + return self.value < other.value + else: + return self.value < other + + def __le__(self, other) -> bool: + self._check_value_when_call() + if isinstance(other, Element): + return self.value <= other.value + else: + return self.value <= other + + def __eq__(self, other): + self._check_value_when_call() + if isinstance(other, Element): + return self.value == other.value + else: + return self.value == other + + def __ne__(self, other) -> bool: + self._check_value_when_call() + if isinstance(other, Element): + return self.value != other.value + else: + return self.value != other + + def __ge__(self, other) -> bool: + self._check_value_when_call() + if isinstance(other, Element): + return self.value >= other.value + else: + return self.value >= other + + def __gt__(self, other) -> bool: + self._check_value_when_call() + if isinstance(other, Element): + return self.value > other.value + else: + return self.value > other + + def __str__(self): + return str(self.value) + + def __repr__(self): + return str(self.value) + + def __getattr__(self, item): + """ + 为FDataLoader提供dataset的方法和属性,实现该方法后,用户可以在FDataLoader实例化后使用apply等dataset的方法 + :param item: + :return: + """ + try: + if self._value is None: + prefix = f'Element:`{self.name}`' if self.name else 'Element' + raise RuntimeError(prefix + " is not initialized. Please either specify backend when creating this " + "element, or use it after it being used by the `Metric.compute()` method.") + return getattr(self._value, item) + except AttributeError as e: + raise e diff --git a/fastNLP/core/metrics/metric.py b/fastNLP/core/metrics/metric.py new file mode 100644 index 00000000..097671da --- /dev/null +++ b/fastNLP/core/metrics/metric.py @@ -0,0 +1,184 @@ +__all__ = [ + 'Metric' +] + +from abc import abstractmethod + +from typing import Union +import functools +from contextlib import contextmanager +import numpy as np + +from fastNLP.core.metrics.backend import Backend, AutoBackend +from fastNLP.core.metrics.element import Element + + +class Metric: + def __init__(self, backend: Union[str, Backend, None] = 'auto', aggregate_when_get_metric: bool = True): + """ + + :param str backend: 目前支持四种类型的backend, [torch, paddle, jittor, auto]。其中 auto 表示根据实际调用 Metric.update() + 函数时传入的参数决定具体的 backend ,大部分情况下直接使用 auto 即可。 + :param bool aggregate_when_get_metric: 在计算 metric 的时候是否自动将各个进程上的相同的 element 的数字聚合后再得到metric, + 当 backend 不支持分布式时,该参数无意义。 + """ + self.backend = AutoBackend(backend) + self._updated = False + self.get_metric = self._sync_get_metric(self.get_metric) + self.update = self._wrap_update(self.update) + self.reset = self._wrap_auto_reset_elements(self.reset) + self.aggregate_when_get_metric = aggregate_when_get_metric + self._cannot_change_element = False + self._elements = {} + + @property + def elements(self) -> dict: + return self._elements + + def register_element(self, name=None, value: float = 0, aggregate_method=None, backend='auto') -> Element: + """ + 注册一个 element 对象,注册之后便可以通过在 Metric 中直接通过 self.{name} 进行调用,可以认为该对象即为对应 backend 的 + tensor 直接进行加减乘除计算即可。 + 注意:如果想使得该 metric 可自动扩展到多卡的情况,请一定申明 aggregate_method 。 + + :param name: 当前 element 的名字,注册后,在 Metric 中可以通过 self.{name} 访问该变量。 + :param value: 初始化的值。在调用 Metric.reset() 方法时也将自动设置为该值 + :param aggregate_method: 如何聚合多卡上的结果,如果为单卡执行,该值无意义。 + :param backend: 使用的 backend 。Element 的类型会根据 backend 进行实际的初始化。例如 backend 为 torch 则该对象为 + Torch.tensor ; 如果backend 为 paddle 则该对象为 paddle.tensor ;如果 backend 为 jittor , 则该对象为 jittor.Var 。 + 一般情况下直接默认为 auto 就行了,fastNLP 会根据实际调用 Metric.update() 函数时传入的参数进行合理的初始化,例如当传入 + 的参数中只包含 torch.Tensor 这一种 tensor 时(可以有其它非 tensor 类型的输入)则认为 backend 为 torch ;只包含 + jittor.Var 则认为 backend 这一种 tensor 时(可以有其它非 tensor 类型的输入)则认为 backend 为 jittor 。如果没有检测 + 到任何一种 tensor ,就默认使用 float 类型作为 element 。 + :return: 注册的 Element 对象 + """ + if backend == 'auto': + backend = self.backend + else: + backend = AutoBackend(backend) + + # 当name为None,默认为变量取得变量名 + if name is None: + name = f'ele_var_{len(self._elements)}' + + element = Element(value=value, aggregate_method=aggregate_method, backend=backend, name=name) + self.elements[name] = element + setattr(self, name, element) + return element + + def reset(self): + """ + 如果有非 element 的对象需要 reset 的时候,在本方法中写下非 element 的reset 方式。注册的 element 对象会自动 reset 为初始值。 + + """ + pass + + def _wrap_auto_reset_elements(self, reset): + @functools.wraps(reset) + def _wrap_reset(*args, **kwargs): + self._updated = False + for ele in self.elements.values(): + ele.reset() + reset(*args, **kwargs) + + return _wrap_reset + + def _sync_get_metric(self, get_metric): + @functools.wraps(get_metric) + def _wrap_get_metric(*args, **kwargs): + assert self._updated, f"You have to call `{self.__class__.__name__}` update() function before calling " \ + f"get_metric()." + with self.sync(recover=True, aggregate=self.aggregate_when_get_metric): + results = get_metric(*args, **kwargs) + return results + + return _wrap_get_metric + + def __setattr__(self, key, value): + if hasattr(self, '_cannot_change_element') and self._cannot_change_element is True: + if key in self.elements and value is not self.elements[key]: + raise RuntimeError(f"self.`{key}` is an element, cannot assign to a new value:{value}") + object.__setattr__(self, key, value) + + def _wrap_update(self, update): + @functools.wraps(update) + def _wrap_update(*args, **kwargs): + self.check_backend(*args, **kwargs) + self._cannot_change_element = True + self._updated = True + return update(*args, **kwargs) + + return _wrap_update + + def check_backend(self, *args, **kwargs): + if not self.backend.is_specified(): + _args = [] + for arg in args: + _args.append(arg) + for arg in kwargs.values(): + _args.append(arg) + self.backend.choose_real_backend(_args) + + @contextmanager + def sync(self, recover=True, aggregate=False): + """ + 在这个上下文下, metric 会自动先同步需要同步操作的 element 。当 recover 为 True 时,在退出环境的时候,会重新将 element 的 + 值恢复到计算前的值。 + + """ + keep_value = {} + if aggregate: + for name, element in self.elements.items(): + # 保存过去的值 + keep_value[name] = element.get_scalar() + # 聚合结果 + element.aggregate() + + yield + + if recover and aggregate: + for name, element in self.elements.items(): + # 恢复结果 + if name in keep_value: + element.fill_value(value=keep_value.get(name)) + + @abstractmethod + def update(self, *args, **kwargs): + raise NotImplementedError() + + @abstractmethod + def get_metric(self) -> dict: + raise NotImplementedError() + + def set_auto_aggregate_when_get_metric(self, flag: bool): + """ + 设置是否在 get_metric 的时候自动 aggregate + + """ + self.aggregate_when_get_metric = flag + + def __getattr__(self, name: str) -> Element: + if 'elements' in self.__dict__: + elements = self.__dict__['elements'] + if name in elements: + return elements[name] + raise AttributeError("`{}` object has no attribute `{}`".format(type(self).__name__, name)) + + def tensor2numpy(self, tensor) -> np.array: + """ + 将tensor向量转为numpy类型变量 + + :param tensor: + :return: + """ + return self.backend.tensor2numpy(tensor) + + def to(self, device): + """ + 将所有的 element 变量移动到 device 设备上 + + :param device: + :return: + """ + for element in self.elements.values(): + element.to(device) diff --git a/fastNLP/core/metrics/span_f1_pre_rec_metric.py b/fastNLP/core/metrics/span_f1_pre_rec_metric.py new file mode 100644 index 00000000..45b412c8 --- /dev/null +++ b/fastNLP/core/metrics/span_f1_pre_rec_metric.py @@ -0,0 +1,344 @@ +__all__ = [ + 'SpanFPreRecMetric' +] + +from typing import Union, List, Optional +import warnings +from collections import defaultdict +from functools import partial + +from fastNLP.core.metrics.backend import Backend +from fastNLP.core.metrics.metric import Metric +from fastNLP.core.vocabulary import Vocabulary + + +def _check_tag_vocab_and_encoding_type(tag_vocab: Union[Vocabulary, dict], encoding_type: str): + r""" + 检查vocab中的tag是否与encoding_type是匹配的 + + :param tag_vocab: 支持传入tag Vocabulary; 或者传入形如{0:"O", 1:"B-tag1"},即index在前,tag在后的dict。 + :param encoding_type: bio, bmes, bioes, bmeso + :return: + """ + tag_set = set() + unk_token = '' + pad_token = '' + if isinstance(tag_vocab, Vocabulary): + unk_token = tag_vocab.unknown + pad_token = tag_vocab.padding + tag_vocab = tag_vocab.idx2word + for idx, tag in tag_vocab.items(): + if tag in (unk_token, pad_token): + continue + tag = tag[:1].lower() + tag_set.add(tag) + + tags = encoding_type + for tag in tag_set: + assert tag in tags, f"{tag} is not a valid tag in encoding type:{encoding_type}. Please check your " \ + f"encoding_type." + tags = tags.replace(tag, '') # 删除该值 + if tags: # 如果不为空,说明出现了未使用的tag + warnings.warn(f"Tag:{tags} in encoding type:{encoding_type} is not presented in your Vocabulary. Check your " + "encoding_type.") + + +def _get_encoding_type_from_tag_vocab(tag_vocab: Union[Vocabulary, dict]) -> str: + r""" + 给定Vocabulary自动判断是哪种类型的encoding, 支持判断bmes, bioes, bmeso, bio + + :param tag_vocab: 支持传入tag Vocabulary; 或者传入形如{0:"O", 1:"B-tag1"},即index在前,tag在后的dict。 + :return: + """ + tag_set = set() + unk_token = '' + pad_token = '' + if isinstance(tag_vocab, Vocabulary): + unk_token = tag_vocab.unknown + pad_token = tag_vocab.padding + tag_vocab = tag_vocab.idx2word + for idx, tag in tag_vocab.items(): + if tag in (unk_token, pad_token): + continue + tag = tag[:1].lower() + tag_set.add(tag) + + bmes_tag_set = set('bmes') + if tag_set == bmes_tag_set: + return 'bmes' + bio_tag_set = set('bio') + if tag_set == bio_tag_set: + return 'bio' + bmeso_tag_set = set('bmeso') + if tag_set == bmeso_tag_set: + return 'bmeso' + bioes_tag_set = set('bioes') + if tag_set == bioes_tag_set: + return 'bioes' + raise RuntimeError("encoding_type cannot be inferred automatically. Only support " + "'bio', 'bmes', 'bmeso', 'bioes' type.") + + +def _bmes_tag_to_spans(tags, ignore_labels=None): + r""" + 给定一个tags的lis,比如['S-song', 'B-singer', 'M-singer', 'E-singer', 'S-moive', 'S-actor']。 + 返回[('song', (0, 1)), ('singer', (1, 4)), ('moive', (4, 5)), ('actor', (5, 6))] (左闭右开区间) + 也可以是单纯的['S', 'B', 'M', 'E', 'B', 'M', 'M',...]序列 + + :param tags: List[str], + :param ignore_labels: List[str], 在该list中的label将被忽略 + :return: List[Tuple[str, List[int, int]]]. [(label,[start, end])] + """ + ignore_labels = set(ignore_labels) if ignore_labels else set() + + spans = [] + prev_bmes_tag = None + for idx, tag in enumerate(tags): + tag = tag.lower() + bmes_tag, label = tag[:1], tag[2:] + if bmes_tag in ('b', 's'): + spans.append((label, [idx, idx])) + elif bmes_tag in ('m', 'e') and prev_bmes_tag in ('b', 'm') and label == spans[-1][0]: + spans[-1][1][1] = idx + else: + spans.append((label, [idx, idx])) + prev_bmes_tag = bmes_tag + return [(span[0], (span[1][0], span[1][1] + 1)) + for span in spans + if span[0] not in ignore_labels + ] + + +def _bmeso_tag_to_spans(tags, ignore_labels=None): + r""" + 给定一个tags的lis,比如['O', 'B-singer', 'M-singer', 'E-singer', 'O', 'O']。 + 返回[('singer', (1, 4))] (左闭右开区间) + + :param tags: List[str], + :param ignore_labels: List[str], 在该list中的label将被忽略 + :return: List[Tuple[str, List[int, int]]]. [(label,[start, end])] + """ + ignore_labels = set(ignore_labels) if ignore_labels else set() + + spans = [] + prev_bmes_tag = None + for idx, tag in enumerate(tags): + tag = tag.lower() + bmes_tag, label = tag[:1], tag[2:] + if bmes_tag in ('b', 's'): + spans.append((label, [idx, idx])) + elif bmes_tag in ('m', 'e') and prev_bmes_tag in ('b', 'm') and label == spans[-1][0]: + spans[-1][1][1] = idx + elif bmes_tag == 'o': + pass + else: + spans.append((label, [idx, idx])) + prev_bmes_tag = bmes_tag + return [(span[0], (span[1][0], span[1][1] + 1)) + for span in spans + if span[0] not in ignore_labels + ] + + +def _bioes_tag_to_spans(tags, ignore_labels=None): + r""" + 给定一个tags的lis,比如['O', 'B-singer', 'I-singer', 'E-singer', 'O', 'O']。 + 返回[('singer', (1, 4))] (左闭右开区间) + + :param tags: List[str], + :param ignore_labels: List[str], 在该list中的label将被忽略 + :return: List[Tuple[str, List[int, int]]]. [(label,[start, end])] + """ + ignore_labels = set(ignore_labels) if ignore_labels else set() + + spans = [] + prev_bioes_tag = None + for idx, tag in enumerate(tags): + tag = tag.lower() + bioes_tag, label = tag[:1], tag[2:] + if bioes_tag in ('b', 's'): + spans.append((label, [idx, idx])) + elif bioes_tag in ('i', 'e') and prev_bioes_tag in ('b', 'i') and label == spans[-1][0]: + spans[-1][1][1] = idx + elif bioes_tag == 'o': + pass + else: + spans.append((label, [idx, idx])) + prev_bioes_tag = bioes_tag + return [(span[0], (span[1][0], span[1][1] + 1)) + for span in spans + if span[0] not in ignore_labels + ] + + +def _bio_tag_to_spans(tags, ignore_labels=None): + r""" + 给定一个tags的lis,比如['O', 'B-singer', 'I-singer', 'I-singer', 'O', 'O']。 + 返回[('singer', (1, 4))] (左闭右开区间) + + :param tags: List[str], + :param ignore_labels: List[str], 在该list中的label将被忽略 + :return: List[Tuple[str, List[int, int]]]. [(label,[start, end])] + """ + ignore_labels = set(ignore_labels) if ignore_labels else set() + + spans = [] + prev_bio_tag = None + for idx, tag in enumerate(tags): + tag = tag.lower() + bio_tag, label = tag[:1], tag[2:] + if bio_tag == 'b': + spans.append((label, [idx, idx])) + elif bio_tag == 'i' and prev_bio_tag in ('b', 'i') and label == spans[-1][0]: + spans[-1][1][1] = idx + elif bio_tag == 'o': # o tag does not count + pass + else: + spans.append((label, [idx, idx])) + prev_bio_tag = bio_tag + return [(span[0], (span[1][0], span[1][1] + 1)) for span in spans if span[0] not in ignore_labels] + + +def _compute_f_pre_rec(beta_square, tp, fn, fp): + r""" + + :param tp: int, true positive + :param fn: int, false negative + :param fp: int, false positive + :return: (f, pre, rec) + """ + pre = tp / (fp + tp + 1e-13) + rec = tp / (fn + tp + 1e-13) + f = (1 + beta_square) * pre * rec / (beta_square * pre + rec + 1e-13) + + return f, pre, rec + + +class SpanFPreRecMetric(Metric): + + def __init__(self, backend: Union[str, Backend, None] = 'auto', tag_vocab: Vocabulary = None, + encoding_type: str = None, ignore_labels: List[str] = None, only_gross: bool = True, f_type='micro', + beta=1, aggregate_when_get_metric: bool = True,) -> None: + super(SpanFPreRecMetric, self).__init__(backend=backend, aggregate_when_get_metric=aggregate_when_get_metric) + if f_type not in ('micro', 'macro'): + raise ValueError("f_type only supports `micro` or `macro`', got {}.".format(f_type)) + if not isinstance(tag_vocab, Vocabulary): + raise TypeError("tag_vocab can only be fastNLP.Vocabulary, not {}.".format(type(tag_vocab))) + if encoding_type: + encoding_type = encoding_type.lower() + _check_tag_vocab_and_encoding_type(tag_vocab, encoding_type) + self.encoding_type = encoding_type + else: + self.encoding_type = _get_encoding_type_from_tag_vocab(tag_vocab) + + if self.encoding_type == 'bmes': + self.tag_to_span_func = _bmes_tag_to_spans + elif self.encoding_type == 'bio': + self.tag_to_span_func = _bio_tag_to_spans + elif self.encoding_type == 'bmeso': + self.tag_to_span_func = _bmeso_tag_to_spans + elif self.encoding_type == 'bioes': + self.tag_to_span_func = _bioes_tag_to_spans + else: + raise ValueError("Only support 'bio', 'bmes', 'bmeso', 'bioes' type.") + + self.ignore_labels = ignore_labels + self.f_type = f_type + self.beta = beta + self.beta_square = self.beta ** 2 + self.only_gross = only_gross + self.tag_vocab = tag_vocab + + self._true_positives = defaultdict(partial(self.register_element, aggregate_method='sum', name=None)) + self._false_positives = defaultdict(partial(self.register_element, aggregate_method='sum', name=None)) + self._false_negatives = defaultdict(partial(self.register_element, aggregate_method='sum', name=None)) + + def get_metric(self) -> dict: + evaluate_result = {} + if not self.only_gross or self.f_type == 'macro': + tags = set(self._false_negatives.keys()) + tags.update(set(self._false_positives.keys())) + tags.update(set(self._true_positives.keys())) + f_sum = 0 + pre_sum = 0 + rec_sum = 0 + for tag in tags: + tp = self._true_positives[tag].get_scalar() + fn = self._false_negatives[tag].get_scalar() + fp = self._false_positives[tag].get_scalar() + f, pre, rec = _compute_f_pre_rec(self.beta_square, tp, fn, fp) + f_sum += f + pre_sum += pre + rec_sum += rec + if not self.only_gross and tag != '': # tag!=''防止无tag的情况 + f_key = 'f-{}'.format(tag) + pre_key = 'pre-{}'.format(tag) + rec_key = 'rec-{}'.format(tag) + evaluate_result[f_key] = f + evaluate_result[pre_key] = pre + evaluate_result[rec_key] = rec + + if self.f_type == 'macro': + evaluate_result['f'] = f_sum / len(tags) + evaluate_result['pre'] = pre_sum / len(tags) + evaluate_result['rec'] = rec_sum / len(tags) + + if self.f_type == 'micro': + f, pre, rec = _compute_f_pre_rec(self.beta_square, + sum(val.get_scalar() for val in self._true_positives.values()), + sum(val.get_scalar() for val in self._false_negatives.values()), + sum(val.get_scalar() for val in self._false_positives.values())) + evaluate_result['f'] = f + evaluate_result['pre'] = pre + evaluate_result['rec'] = rec + + for key, value in evaluate_result.items(): + evaluate_result[key] = round(value, 6) + + return evaluate_result + + def update(self, pred, target, seq_len: Optional[List] = None) -> None: + r"""update函数将针对一个批次的预测结果做评价指标的累计 + + :param pred: [batch, seq_len] 或者 [batch, seq_len, len(tag_vocab)], 预测的结果 + :param target: [batch, seq_len], 真实值 + :param seq_len: [batch] 文本长度标记 + :return: + """ + pred = self.tensor2numpy(pred) + target = self.tensor2numpy(target) + + if pred.ndim == target.ndim and target.ndim == 2: + pass + + elif pred.ndim == target.ndim + 1 and target.ndim == 2: + num_classes = pred.shape[-1] + pred = pred.argmax(axis=-1) + if (target >= num_classes).any(): + raise ValueError("A gold label passed to SpanBasedF1Metric contains an " + "id >= {}, the number of classes.".format(num_classes)) + else: + raise RuntimeError(f"when pred have size:{pred.ndim}, target should have size: {pred.ndim} or " + f"{pred.shape[:-1]}, got {target.ndim}.") + + batch_size = pred.shape[0] + pred = pred.tolist() + target = target.tolist() + for i in range(batch_size): + pred_tags = pred[i][:int(seq_len[i])] + gold_tags = target[i][:int(seq_len[i])] + + pred_str_tags = [self.tag_vocab.to_word(tag) for tag in pred_tags] + gold_str_tags = [self.tag_vocab.to_word(tag) for tag in gold_tags] + + pred_spans = self.tag_to_span_func(pred_str_tags, ignore_labels=self.ignore_labels) + gold_spans = self.tag_to_span_func(gold_str_tags, ignore_labels=self.ignore_labels) + + for span in pred_spans: + if span in gold_spans: + self._true_positives[span[0]] += 1 + gold_spans.remove(span) + else: + self._false_positives[span[0]] += 1 + for span in gold_spans: + self._false_negatives[span[0]] += 1 diff --git a/fastNLP/core/metrics/utils.py b/fastNLP/core/metrics/utils.py new file mode 100644 index 00000000..1363282a --- /dev/null +++ b/fastNLP/core/metrics/utils.py @@ -0,0 +1,91 @@ +__all__ = [ + 'func_post_proc' +] + +from typing import Any +from functools import wraps +from fastNLP.envs.imports import _NEED_IMPORT_PADDLE +from fastNLP.envs.utils import _module_available + +_IS_TORCHMETRICS_AVAILABLE = _module_available('torchmetrics') +if _IS_TORCHMETRICS_AVAILABLE: + from torchmetrics import Metric as torchmetrics_Metric + +_IS_ALLENNLP_AVAILABLE = _module_available('allennlp') +if _IS_ALLENNLP_AVAILABLE: + from allennlp.training.metrics import Metric as allennlp_Metric + +if _NEED_IMPORT_PADDLE: + from paddle.metric import Metric as paddle_Metric + + +def _is_torchmetrics_metric(metric: Any) -> bool: + """ + 检查输入的对象是否为torchmetrics对象 + + :param metric: + :return: + """ + if _IS_TORCHMETRICS_AVAILABLE: + return isinstance(metric, torchmetrics_Metric) + else: + return False + + +def _is_allennlp_metric(metric: Any) -> bool: + """ + 检查输入的对象是否为allennlp对象 + + :param metric: + :return: + """ + if _IS_ALLENNLP_AVAILABLE: + return isinstance(metric, allennlp_Metric) + else: + return False + + +def _is_paddle_metric(metric: Any) -> bool: + """ + 检查输入的对象是否为allennlp对象 + + :param metric: + :return: + """ + if _NEED_IMPORT_PADDLE: + return isinstance(metric, paddle_Metric) + else: + return False + + +def func_post_proc(metric: 'Metric', fn: callable, method_name: str) -> 'Metric': + """ + 将fn函数作用包裹在 metric 对象的 {method_name} 方法上,使得 metric.{method_name} 函数的返回结果先经过 fn 函数处理 + 后再返回。注意对 metric 的 {method_name} 函数的修改是 inplace 的。 + + :param metric: metric对象 + :param fn: 作用于 metric 的 accumulate 方法的返回值 + :param method_name: 一般来说,对于 + :return: metric + """ + assert hasattr(metric, method_name) and callable(getattr(metric, method_name)), \ + f"Parameter `metric` must have a {method_name} function." + assert callable(fn), "Parameter `fn` must be callable." + + func = getattr(metric, method_name) + + @wraps(func) + def wrap_method(*args, **kwargs): + res = func(*args, **kwargs) + return fn(res) + + wrap_method.__wrapped_by_func_post_proc__ = True + setattr(metric, method_name, wrap_method) + return metric + + +class AggregateMethodError(BaseException): + def __init__(self, should_have_aggregate_method, only_warn=False): + super(AggregateMethodError, self).__init__(self) + self.should_have_aggregate_method = should_have_aggregate_method + self.only_warn = only_warn