@@ -0,0 +1,18 @@ | |||||
__all__ = [ | |||||
"Metric", | |||||
"Accuracy", | |||||
'Backend', | |||||
'AutoBackend', | |||||
'PaddleBackend', | |||||
'TorchBackend', | |||||
'SpanFPreRecMetric', | |||||
'ClassifyFPreRecMetric', | |||||
'func_post_proc' | |||||
] | |||||
from .metric import Metric | |||||
from .accuracy import Accuracy | |||||
from .backend import Backend, AutoBackend, PaddleBackend, TorchBackend | |||||
from .span_f1_pre_rec_metric import SpanFPreRecMetric | |||||
from .classify_f1_pre_rec_metric import ClassifyFPreRecMetric | |||||
from .utils import func_post_proc |
@@ -0,0 +1,75 @@ | |||||
__all__ = [ | |||||
'Accuracy' | |||||
] | |||||
from typing import Union | |||||
import warnings | |||||
import numpy as np | |||||
from fastNLP.core.metrics.metric import Metric | |||||
from fastNLP.core.metrics.backend import Backend | |||||
from fastNLP.core.utils.utils import seq_len_to_mask | |||||
class Accuracy(Metric): | |||||
def __init__(self, backend: Union[str, Backend, None] = 'auto', | |||||
aggregate_when_get_metric: bool = True): | |||||
super(Accuracy, self).__init__(backend=backend, aggregate_when_get_metric=aggregate_when_get_metric) | |||||
self.register_element(name='correct', value=0, aggregate_method='sum', backend=backend) | |||||
self.register_element(name='total', value=0, aggregate_method="sum", backend=backend) | |||||
def get_metric(self) -> dict: | |||||
r""" | |||||
get_metric函数将根据evaluate函数累计的评价指标统计量来计算最终的评价结果. | |||||
:return dict evaluate_result: {"acc": float} | |||||
""" | |||||
evaluate_result = {'acc': round(self.correct.get_scalar() / (self.total.get_scalar() + 1e-12), 6)} | |||||
return evaluate_result | |||||
def update(self, pred, target, seq_len=None): | |||||
r""" | |||||
evaluate函数将针对一个批次的预测结果做评价指标的累计 | |||||
:param torch.Tensor pred: 预测的tensor, tensor的形状可以是torch.Size([B,]), torch.Size([B, n_classes]), | |||||
torch.Size([B, max_len]), 或者torch.Size([B, max_len, n_classes]) | |||||
:param torch.Tensor target: 真实值的tensor, tensor的形状可以是Element's can be: torch.Size([B,]), | |||||
torch.Size([B,]), torch.Size([B, max_len]), 或者torch.Size([B, max_len]) | |||||
:param torch.Tensor seq_len: 序列长度标记, 标记的形状可以是None, None, torch.Size([B]), 或者torch.Size([B]). | |||||
如果mask也被传进来的话seq_len会被忽略. | |||||
""" | |||||
# 为了兼容不同框架,我们将输入变量全部转为numpy类型来进行计算。 | |||||
pred = self.tensor2numpy(pred) | |||||
target = self.tensor2numpy(target) | |||||
if seq_len is not None: | |||||
seq_len = self.tensor2numpy(seq_len) | |||||
if seq_len is not None and target.ndim > 1: | |||||
max_len = target.shape[1] | |||||
masks = seq_len_to_mask(seq_len, max_len) | |||||
else: | |||||
masks = None | |||||
if pred.ndim == target.ndim: | |||||
if np.prod(pred.shape) != np.prod(target.shape): | |||||
raise RuntimeError(f"when pred have same dimensions with target, they should have same element numbers." | |||||
f" while target have shape:{target.shape}, " | |||||
f"pred have shape: {target.shape}") | |||||
elif pred.ndim == target.ndim + 1: | |||||
pred = pred.argmax(axis=-1) | |||||
if seq_len is None and target.ndim > 1: | |||||
warnings.warn("You are not passing `seq_len` to exclude pad when calculate accuracy.") | |||||
else: | |||||
raise RuntimeError(f"when pred havesize:{pred.shape}, target should have size: {pred.shape} or " | |||||
f"{pred.shape[:-1]}, got {target.shape}.") | |||||
if masks is not None: | |||||
self.total += masks.sum().item() | |||||
self.correct += ((pred == target) * masks).sum().item() | |||||
else: | |||||
self.total += np.prod(list(pred.shape)).item() | |||||
self.correct += (target == pred).sum().item() |
@@ -0,0 +1,12 @@ | |||||
__all__ = [ | |||||
'Backend', | |||||
'AutoBackend', | |||||
'TorchBackend', | |||||
'PaddleBackend' | |||||
] | |||||
from .backend import Backend | |||||
from .auto_backend import AutoBackend | |||||
from .torch_backend.backend import TorchBackend | |||||
from .paddle_backend.backend import PaddleBackend |
@@ -0,0 +1,75 @@ | |||||
from typing import Union | |||||
from .backend import Backend | |||||
from .torch_backend.backend import TorchBackend | |||||
from .paddle_backend.backend import PaddleBackend | |||||
from .jittor_backend.backend import JittorBackend | |||||
class AutoBackend(Backend): | |||||
""" | |||||
不需要初始化backend的AutoBackend,能够根据get_metric时候判断输入数据类型来选择backend是什么类型的 | |||||
""" | |||||
def __init__(self, backend: Union[str, Backend, None]): | |||||
super(AutoBackend, self).__init__() | |||||
if backend != 'auto': | |||||
self._convert_backend(backend) | |||||
def _convert_backend(self, backend): | |||||
""" | |||||
将AutoBackend转换为合适的Backend对象 | |||||
""" | |||||
if isinstance(backend, Backend): | |||||
self.__class__ = backend.__class__ | |||||
# 如果是str,直接选择就好了 | |||||
elif backend == 'torch': | |||||
self.__class__ = TorchBackend | |||||
elif backend == 'paddle': | |||||
self.__class__ = PaddleBackend | |||||
elif backend == 'jittor': | |||||
self.__class__ = JittorBackend | |||||
elif backend is None: | |||||
# 不用做任何事情就可以初始化了 | |||||
pass | |||||
else: | |||||
raise RuntimeError(f"We did not support `{backend}` to be used as backend for now.") | |||||
self._specified = True | |||||
def choose_real_backend(self, args): | |||||
assert not self.is_specified(), "This method should not be called after backend has been specified. " \ | |||||
"This must be a bug, please report." | |||||
types = [] | |||||
for arg in args: | |||||
types.append(str(type(arg))) | |||||
torch_types = [] | |||||
jittor_types = [] | |||||
paddle_types = [] | |||||
for type_name in types: | |||||
if 'torch' in type_name: | |||||
torch_types.append(type_name) | |||||
if 'paddle' in type_name: | |||||
paddle_types.append(type_name) | |||||
if 'jittor' in type_name: | |||||
jittor_types.append(type_name) | |||||
# 根据 https://stackoverflow.com/a/3464154 ,可以通过这种方法实现切换成真实的 backend 上 | |||||
if len(torch_types) > 0 and len(jittor_types) == 0 and len(paddle_types) == 0: | |||||
backend = 'torch' | |||||
elif len(torch_types) == 0 and len(jittor_types) > 0 and len(paddle_types) == 0: | |||||
backend = 'jittor' | |||||
elif len(torch_types) == 0 and len(jittor_types) == 0 and len(paddle_types) > 0: | |||||
backend = 'paddle' | |||||
elif len(torch_types) == 0 and len(jittor_types) == 0 and len(paddle_types) == 0: | |||||
# 直接使用default的backend就好了 | |||||
backend = None | |||||
else: | |||||
types = list(set(torch_types + jittor_types + paddle_types)) | |||||
raise RuntimeError( | |||||
f"Mixture of tensor type:{types} have been accept, please manually set backend instead of " | |||||
f"using backend=auto.") | |||||
self._convert_backend(backend) |
@@ -0,0 +1,75 @@ | |||||
from ..utils import AggregateMethodError | |||||
class Backend: | |||||
""" | |||||
Backend 及其子类的所有方法都必须是无状态的。 | |||||
""" | |||||
def __init__(self): | |||||
self._specified = False | |||||
def aggregate(self, tensor, method: str): | |||||
""" | |||||
聚集结果,并根据method计算后,返回结果 | |||||
""" | |||||
if method is not None: | |||||
return AggregateMethodError(should_have_aggregate_method=False, only_warn=True) | |||||
return tensor | |||||
def create_tensor(self, value: float): | |||||
""" | |||||
创建tensor,并且填入value作为值 | |||||
""" | |||||
return value | |||||
def fill_value(self, tensor, value: float): | |||||
""" | |||||
将tensor的值设置为value | |||||
""" | |||||
return value | |||||
def get_scalar(self, tensor) -> float: | |||||
""" | |||||
tensor的saclar值 | |||||
:param tensor: | |||||
:return: | |||||
""" | |||||
return tensor | |||||
def is_specified(self) -> bool: | |||||
""" | |||||
判断是否是某种框架的backend | |||||
:return: | |||||
""" | |||||
return self._specified | |||||
def tensor2numpy(self, tensor): | |||||
""" | |||||
将tensor转为numpy | |||||
:param tensor: | |||||
:return: | |||||
""" | |||||
return tensor | |||||
def move_tensor_to_device(self, tensor, device): | |||||
""" | |||||
""" | |||||
return tensor | |||||
def all_gather_object(self, obj, group=None): | |||||
""" | |||||
给定 obj 将各个 rank 上的 obj 汇总到每个 obj 上。返回一个 list 对象,里面依次为各个 rank 对应的 obj 。 | |||||
:param obj: | |||||
:param group: | |||||
:return: | |||||
""" | |||||
raise NotImplementedError(f"all_gather_object() function is not implemented for {self.__class__.__name__}.") | |||||
@@ -0,0 +1 @@ | |||||
@@ -0,0 +1,72 @@ | |||||
import numpy as np | |||||
from fastNLP.envs.imports import _NEED_IMPORT_JITTOR | |||||
from fastNLP.core.metrics.backend import Backend | |||||
if _NEED_IMPORT_JITTOR: | |||||
import jittor | |||||
class JittorBackend(Backend): | |||||
def __init__(self): | |||||
super(JittorBackend, self).__init__() | |||||
self._specified = True | |||||
def aggregate(self, tensor, method: str): | |||||
""" | |||||
聚集结果,并根据method计算后,返回结果 | |||||
""" | |||||
return tensor | |||||
def create_tensor(self, value: float): | |||||
""" | |||||
创建tensor,并且填入value作为值 | |||||
""" | |||||
value = jittor.Var(value) | |||||
return value | |||||
def fill_value(self, tensor, value: float): | |||||
""" | |||||
将tensor的值设置为value | |||||
""" | |||||
value = jittor.full_like(tensor, value) | |||||
return value | |||||
def get_scalar(self, tensor) -> float: | |||||
""" | |||||
tensor的saclar值 | |||||
:param tensor: | |||||
:return: | |||||
""" | |||||
return tensor.item() | |||||
def is_specified(self) -> bool: | |||||
""" | |||||
判断是否是某种框架的backend | |||||
:return: | |||||
""" | |||||
return self._specified | |||||
def tensor2numpy(self, tensor): | |||||
""" | |||||
将tensor转为numpy | |||||
:param tensor: | |||||
:return: | |||||
""" | |||||
if isinstance(tensor, jittor.Var): | |||||
return tensor.detach().numpy() | |||||
elif isinstance(tensor, np.array): | |||||
return tensor | |||||
else: | |||||
raise ValueError(f"tensor: {tensor} can not convert to ndarray!") | |||||
def move_tensor_to_device(self, tensor, device): | |||||
""" | |||||
jittor的没有转移设备的函数,因此该函数实际上无效 | |||||
""" | |||||
return tensor |
@@ -0,0 +1,5 @@ | |||||
__all__ = [ | |||||
'PaddleBackend' | |||||
] | |||||
from .backend import Backend as PaddleBackend |
@@ -0,0 +1,126 @@ | |||||
from typing import List, Optional, Any | |||||
import numpy as np | |||||
from fastNLP.core.metrics.backend import Backend | |||||
from fastNLP.core.utils.paddle_utils import paddle_to | |||||
from fastNLP.core.metrics.utils import AggregateMethodError | |||||
from fastNLP.core.utils import is_in_paddle_dist | |||||
from fastNLP.core.drivers.paddle_driver.utils import get_device_from_visible | |||||
from fastNLP.envs.imports import _NEED_IMPORT_PADDLE | |||||
if _NEED_IMPORT_PADDLE: | |||||
import paddle | |||||
from paddle.fluid.dygraph import parallel_helper | |||||
def _simple_gather_all_tensors(result, group: Any, world_size: int) -> List: | |||||
gathered_result = [paddle.zeros_like(result) for _ in range(world_size)] | |||||
paddle.distributed.all_gather(gathered_result, result, group) | |||||
return gathered_result | |||||
class PaddleBackend(Backend): | |||||
def __init__(self): | |||||
super().__init__() | |||||
self._specified = True | |||||
def aggregate(self, tensor, method: str): | |||||
""" | |||||
聚集结果,并根据method计算后,返回结果 | |||||
""" | |||||
if isinstance(tensor, paddle.Tensor): | |||||
if parallel_helper._is_parallel_ctx_initialized(): | |||||
if method is None: | |||||
raise AggregateMethodError(should_have_aggregate_method=True) | |||||
tensor = self._gather_all(tensor) | |||||
if isinstance(tensor[0], paddle.Tensor): | |||||
tensor = paddle.stack(tensor) | |||||
# 第一步, aggregate结果 | |||||
if method == 'sum': | |||||
tensor = paddle.sum(tensor, dim=0) | |||||
elif method == 'mean': | |||||
tensor = paddle.mean(tensor, dim=0) | |||||
elif method == 'max': | |||||
tensor, _ = paddle.max(tensor, dim=0) | |||||
elif method == 'min': | |||||
tensor, _ = paddle.min(tensor, dim=0) | |||||
else: | |||||
raise AggregateMethodError(should_have_aggregate_method=False) | |||||
return tensor | |||||
def create_tensor(self, value: float): | |||||
""" | |||||
创建tensor,并且填入value作为值 | |||||
""" | |||||
tensor = paddle.ones((1,)).fill_(value) | |||||
return tensor | |||||
def fill_value(self, tensor, value: float): | |||||
""" | |||||
将tensor的值设置为value | |||||
""" | |||||
tensor.fill_(value) | |||||
return tensor | |||||
def get_scalar(self, tensor) -> float: | |||||
return tensor.item() | |||||
def tensor2numpy(self, tensor) -> np.array: | |||||
if isinstance(tensor, paddle.Tensor): | |||||
return tensor.cpu().detach().numpy() | |||||
elif isinstance(tensor, np.array): | |||||
return tensor | |||||
else: | |||||
raise ValueError(f"tensor: {tensor} can not convert to ndarray!") | |||||
@staticmethod | |||||
def _gather_all(result, group: Optional[Any] = None) -> List: | |||||
""" | |||||
聚合 group 中所有的 result;由于不同 group 中 result 大小不同,因此在适当的时候需要进行 padding | |||||
""" | |||||
# TODO check 正确性 | |||||
if group is None: | |||||
group = paddle.distributed.get_group(0) | |||||
world_size = group.nranks | |||||
paddle.distributed.barrier(group=group) | |||||
# 张量为 标量的情况,简单地gather就好 | |||||
if result.ndim == 0: | |||||
return _simple_gather_all_tensors(result, group, world_size) | |||||
# 获得 result 的 shape | |||||
local_size = paddle.to_tensor(result.shape) | |||||
# 将 group 中所有 result 的大小聚合在一起 | |||||
local_sizes = [paddle.zeros_like(local_size) for _ in range(world_size)] | |||||
paddle.distributed.all_gather(local_sizes, local_size, group=group) | |||||
# 堆叠后,计算出 shape 每一维度的最大值 | |||||
max_size = paddle.stack(local_sizes).max(axis=0).values | |||||
all_sizes_equal = all(all(ls == max_size) for ls in local_sizes) | |||||
# 如果所有的结果大小相同,那么可以直接聚合 | |||||
if all_sizes_equal: | |||||
return _simple_gather_all_tensors(result, group, world_size) | |||||
# 否则,padding 与最大的张量对齐 | |||||
pad_dims = [] | |||||
pad_by = (max_size - local_size).detach().cpu() | |||||
for val in reversed(pad_by): | |||||
pad_dims.append(0) | |||||
pad_dims.append(val.item()) | |||||
result_padded = paddle.nn.functional.pad(result, pad_dims) | |||||
# 重新进行聚合 | |||||
gathered_result = [paddle.zeros_like(result_padded) for _ in range(world_size)] | |||||
paddle.distributed.all_gather(gathered_result, result_padded, group) | |||||
for idx, item_size in enumerate(local_sizes): | |||||
slice_param = [slice(dim_size) for dim_size in item_size] | |||||
gathered_result[idx] = gathered_result[idx][slice_param] | |||||
return gathered_result | |||||
def move_tensor_to_device(self, tensor, device): | |||||
# TODO 如果在这里处理的话,会不会在别的地方引起bug? | |||||
if is_in_paddle_dist(): | |||||
device = get_device_from_visible(device) | |||||
return paddle_to(tensor, device) | |||||
@@ -0,0 +1,6 @@ | |||||
__all__ = [ | |||||
'TorchBackend' | |||||
] | |||||
from .backend import Backend as TorchBackend |
@@ -0,0 +1,154 @@ | |||||
from typing import Any, List, Optional | |||||
import numpy as np | |||||
from fastNLP.core.metrics.backend import Backend | |||||
from fastNLP.core.metrics.utils import AggregateMethodError | |||||
from fastNLP.envs.imports import _NEED_IMPORT_TORCH | |||||
from fastNLP.core.drivers.torch_driver.dist_utils import fastnlp_torch_all_gather | |||||
if _NEED_IMPORT_TORCH: | |||||
import torch | |||||
import torch.distributed as dist | |||||
import torch.nn.functional as F | |||||
def _simple_gather_all_tensors(result, group: Any, world_size: int) -> List: | |||||
gathered_result = [torch.zeros_like(result) for _ in range(world_size)] | |||||
dist.all_gather(gathered_result, result, group) | |||||
return gathered_result | |||||
class TorchBackend(Backend): | |||||
def __init__(self): | |||||
super().__init__() | |||||
self._specified = True | |||||
def aggregate(self, tensor, method: str): | |||||
""" | |||||
聚集结果,并根据method计算后,返回结果。 | |||||
""" | |||||
if isinstance(tensor, torch.Tensor): | |||||
if dist.is_initialized(): | |||||
if method is None: | |||||
raise AggregateMethodError(should_have_aggregate_method=True) | |||||
tensor = self._gather_all(tensor) | |||||
if isinstance(tensor[0], torch.Tensor): | |||||
tensor = torch.stack(tensor) | |||||
# 第一步, aggregate结果 | |||||
if method == 'sum': | |||||
tensor = torch.sum(tensor, dim=0) | |||||
elif method == 'mean': | |||||
tensor = torch.mean(tensor, dim=0) | |||||
elif method == 'max': | |||||
tensor, _ = torch.max(tensor, dim=0) | |||||
elif method == 'min': | |||||
tensor, _ = torch.min(tensor, dim=0) | |||||
else: | |||||
raise AggregateMethodError(should_have_aggregate_method=False) | |||||
return tensor | |||||
def create_tensor(self, value: float): | |||||
""" | |||||
创建tensor,并且填入value作为值 | |||||
""" | |||||
tensor = torch.ones(1).fill_(value) | |||||
return tensor | |||||
def fill_value(self, tensor, value: float): | |||||
""" | |||||
将tensor的值设置为value | |||||
""" | |||||
tensor.fill_(value) | |||||
return tensor | |||||
def get_scalar(self, tensor) -> float: | |||||
return tensor.item() | |||||
@staticmethod | |||||
def _gather_all(result, group: Optional[Any] = None) -> List: | |||||
"""Function to gather all tensors from several ddp processes onto a list that is broadcasted to all processes. | |||||
Works on tensors that have the same number of dimensions, but where each dimension may differ. In this case | |||||
tensors are padded, gathered and then trimmed to secure equal workload for all processes. | |||||
Args: | |||||
result: the value to sync | |||||
group: the process group to gather results from. Defaults to all processes (world) | |||||
Return: | |||||
gathered_result: list with size equal to the process group where | |||||
gathered_result[i] corresponds to result tensor from process i | |||||
""" | |||||
if group is None: | |||||
group = dist.group.WORLD | |||||
# convert tensors to contiguous format | |||||
result = result.contiguous() | |||||
world_size = dist.get_world_size(group) | |||||
dist.barrier(group=group) | |||||
# if the tensor is scalar, things are easy | |||||
if result.ndim == 0: | |||||
return _simple_gather_all_tensors(result, group, world_size) | |||||
# 1. Gather sizes of all tensors | |||||
local_size = torch.tensor(result.shape, device=result.device) | |||||
local_sizes = [torch.zeros_like(local_size) for _ in range(world_size)] | |||||
dist.all_gather(local_sizes, local_size, group=group) | |||||
max_size = torch.stack(local_sizes).max(dim=0).values | |||||
all_sizes_equal = all(all(ls == max_size) for ls in local_sizes) | |||||
# 2. If shapes are all the same, then do a simple gather: | |||||
if all_sizes_equal: | |||||
return _simple_gather_all_tensors(result, group, world_size) | |||||
# 3. If not, we need to pad each local tensor to maximum size, gather and then truncate | |||||
pad_dims = [] | |||||
pad_by = (max_size - local_size).detach().cpu() | |||||
for val in reversed(pad_by): | |||||
pad_dims.append(0) | |||||
pad_dims.append(val.item()) | |||||
result_padded = torch.nn.functional.pad(result, pad_dims) | |||||
gathered_result = [torch.zeros_like(result_padded) for _ in range(world_size)] | |||||
dist.all_gather(gathered_result, result_padded, group) | |||||
for idx, item_size in enumerate(local_sizes): | |||||
slice_param = [slice(dim_size) for dim_size in item_size] | |||||
gathered_result[idx] = gathered_result[idx][slice_param] | |||||
return gathered_result | |||||
def tensor2numpy(self, tensor) -> np.array: | |||||
""" | |||||
将对应的tensor转为numpy对象 | |||||
""" | |||||
if isinstance(tensor, torch.Tensor): | |||||
return tensor.cpu().detach().numpy() | |||||
elif isinstance(tensor, np.ndarray): | |||||
return tensor | |||||
elif isinstance(tensor, (float, int)): | |||||
return tensor | |||||
else: | |||||
raise ValueError(f"tensor: {tensor} can not convert to ndarray!") | |||||
@staticmethod | |||||
def is_distributed() -> bool: | |||||
""" | |||||
:return: | |||||
""" | |||||
return dist.is_available() and dist.is_initialized() | |||||
def move_tensor_to_device(self, tensor, device): | |||||
return tensor.to(device) | |||||
def all_gather_object(self, obj, group=None) -> List: | |||||
if self.is_distributed(): | |||||
obj_list = fastnlp_torch_all_gather(obj, group=group) | |||||
return obj_list | |||||
return [obj] | |||||
@@ -0,0 +1,142 @@ | |||||
__all__ = [ | |||||
'ClassifyFPreRecMetric' | |||||
] | |||||
from typing import Union, List | |||||
from collections import defaultdict | |||||
from functools import partial | |||||
import warnings | |||||
from .metric import Metric | |||||
from .backend import Backend | |||||
from fastNLP.core.vocabulary import Vocabulary | |||||
from fastNLP.core.utils.utils import seq_len_to_mask | |||||
def _compute_f_pre_rec(beta_square, tp, fn, fp): | |||||
r""" | |||||
:param tp: int, true positive | |||||
:param fn: int, false negative | |||||
:param fp: int, false positive | |||||
:return: (f, pre, rec) | |||||
""" | |||||
pre = tp / (fp + tp + 1e-13) | |||||
rec = tp / (fn + tp + 1e-13) | |||||
f = (1 + beta_square) * pre * rec / (beta_square * pre + rec + 1e-13) | |||||
return f, pre, rec | |||||
class ClassifyFPreRecMetric(Metric): | |||||
def __init__(self, backend: Union[str, Backend, None] = 'auto', aggregate_when_get_metric: bool = False, | |||||
tag_vocab: Vocabulary = None, encoding_type: str = None, ignore_labels: List[str] = None, | |||||
only_gross: bool = True, f_type='micro', beta=1) -> None: | |||||
super(ClassifyFPreRecMetric, self).__init__(backend=backend, | |||||
aggregate_when_get_metric=aggregate_when_get_metric) | |||||
if f_type not in ('micro', 'macro'): | |||||
raise ValueError("f_type only supports `micro` or `macro`', got {}.".format(f_type)) | |||||
self.ignore_labels = ignore_labels | |||||
self.f_type = f_type | |||||
self.beta = beta | |||||
self.beta_square = self.beta ** 2 | |||||
self.only_gross = only_gross | |||||
self.tag_vocab = tag_vocab | |||||
self._tp, self._fp, self._fn = defaultdict(partial(self.register_element, aggregate_method='sum')),\ | |||||
defaultdict(partial(self.register_element, aggregate_method='sum')),\ | |||||
defaultdict(partial(self.register_element, aggregate_method='sum')) | |||||
def get_metric(self) -> dict: | |||||
r""" | |||||
get_metric函数将根据evaluate函数累计的评价指标统计量来计算最终的评价结果. | |||||
:return dict evaluate_result: {"acc": float} | |||||
""" | |||||
evaluate_result = {} | |||||
if not self.only_gross or self.f_type == 'macro': | |||||
tags = set(self._fn.keys()) | |||||
tags.update(set(self._fp.keys())) | |||||
tags.update(set(self._tp.keys())) | |||||
f_sum = 0 | |||||
pre_sum = 0 | |||||
rec_sum = 0 | |||||
for tag in tags: | |||||
if self.tag_vocab is not None: | |||||
tag_name = self.tag_vocab.to_word(tag) | |||||
else: | |||||
tag_name = int(tag) | |||||
tp = self._tp[tag] | |||||
fn = self._fn[tag] | |||||
fp = self._fp[tag] | |||||
f, pre, rec = _compute_f_pre_rec(self.beta_square, tp, fn, fp) | |||||
f_sum += f | |||||
pre_sum += pre | |||||
rec_sum += rec | |||||
if not self.only_gross and tag != '': # tag!=''防止无tag的情况 | |||||
f_key = 'f-{}'.format(tag_name) | |||||
pre_key = 'pre-{}'.format(tag_name) | |||||
rec_key = 'rec-{}'.format(tag_name) | |||||
evaluate_result[f_key] = f | |||||
evaluate_result[pre_key] = pre | |||||
evaluate_result[rec_key] = rec | |||||
if self.f_type == 'macro': | |||||
evaluate_result['f'] = f_sum / len(tags) | |||||
evaluate_result['pre'] = pre_sum / len(tags) | |||||
evaluate_result['rec'] = rec_sum / len(tags) | |||||
if self.f_type == 'micro': | |||||
f, pre, rec = _compute_f_pre_rec(self.beta_square, | |||||
sum(self._tp.values()), | |||||
sum(self._fn.values()), | |||||
sum(self._fp.values())) | |||||
evaluate_result['f'] = f | |||||
evaluate_result['pre'] = pre | |||||
evaluate_result['rec'] = rec | |||||
for key, value in evaluate_result.items(): | |||||
evaluate_result[key] = round(value, 6) | |||||
return evaluate_result | |||||
def update(self, pred, target, seq_len=None): | |||||
pred = self.tensor2numpy(pred) | |||||
target = self.tensor2numpy(target) | |||||
if seq_len is not None: | |||||
seq_len = self.tensor2numpy(seq_len) | |||||
if seq_len is not None and target.ndim > 1: | |||||
max_len = target.ndim[-1] | |||||
masks = seq_len_to_mask(seq_len=seq_len, max_len=max_len) | |||||
else: | |||||
masks = None | |||||
if pred.ndim == target.ndim: | |||||
if len(pred.flatten()) != len(target.flatten()): | |||||
raise RuntimeError(f"when pred have same dimensions with target, they should have same element numbers." | |||||
f" while target have element numbers:{len(pred.flatten())}, " | |||||
f"pred have element numbers: {len(target.flatten())}") | |||||
pass | |||||
elif len(pred.ndim) == len(target.ndim) + 1: | |||||
pred = pred.argmax(axis=-1) | |||||
if seq_len is None and len(target.ndim) > 1: | |||||
warnings.warn("You are not passing `seq_len` to exclude pad when calculate accuracy.") | |||||
else: | |||||
raise RuntimeError(f"when pred have " | |||||
f"size:{pred.ndim}, target should have size: {pred.ndim} or " | |||||
f"{pred.ndim[:-1]}, got {target.ndim}.") | |||||
if masks is not None: | |||||
target = target * masks | |||||
pred = pred * masks | |||||
target_idxes = set(target.reshape(-1).tolist()) | |||||
for target_idx in target_idxes: | |||||
self._tp[target_idx] += ((pred == target_idx) * (target != target_idx)).sum().item() | |||||
self._fp[target_idx] += ((pred == target_idx) * (target == target_idx)).sum().item() | |||||
self._fn[target_idx] += ((pred != target_idx) * (target != target_idx)).sum().item() | |||||
@@ -0,0 +1,281 @@ | |||||
__all__ = [ | |||||
'Element' | |||||
] | |||||
import os | |||||
from .backend import Backend, AutoBackend | |||||
from fastNLP.core.log import logger | |||||
from .utils import AggregateMethodError | |||||
from fastNLP.envs.env import FASTNLP_GLOBAL_RANK | |||||
class Element: | |||||
def __init__(self, value: float, aggregate_method, backend: Backend, name=None): | |||||
self.init_value = value | |||||
self.aggregate_method = aggregate_method | |||||
self.name = name | |||||
if backend == 'auto': | |||||
raise RuntimeError("You have to specify the backend.") | |||||
elif isinstance(backend, AutoBackend): | |||||
self.backend = backend | |||||
else: | |||||
self.backend = AutoBackend(backend) | |||||
if self.backend.is_specified(): | |||||
value = self.backend.create_tensor(self.init_value) | |||||
else: | |||||
value = None | |||||
self._value = value | |||||
self.device = None | |||||
def aggregate(self): | |||||
""" | |||||
自动aggregate对应的元素 | |||||
""" | |||||
try: | |||||
self._value = self.backend.aggregate(self._value, self.aggregate_method) | |||||
except AggregateMethodError as e: | |||||
msg = 'If you see this message, please report a bug.' | |||||
if self.name and e.should_have_aggregate_method: | |||||
msg = f"Element:{self.name} has no specified `aggregate_method`." | |||||
elif e.should_have_aggregate_method: | |||||
msg = "Element has no specified `aggregate_method`." | |||||
elif self.name and not e.should_have_aggregate_method: | |||||
msg = f"Element:{self.name}'s backend:{self.backend.__class__.__name__} does not support " \ | |||||
f'aggregate_method:{self.aggregate_method}.' | |||||
elif not e.should_have_aggregate_method: | |||||
msg = f"Element's backend:{self.backend.__class__.__name__} does not support " \ | |||||
f'aggregate_method:{self.aggregate_method}.' | |||||
if e.only_warn: | |||||
if int(os.environ.get(FASTNLP_GLOBAL_RANK, 0)) == 0: | |||||
logger.warning(msg) | |||||
self._value = self.backend.aggregate(self._value, method=None) | |||||
else: | |||||
raise RuntimeError(msg) | |||||
def reset(self): | |||||
if self.backend.is_specified(): | |||||
self._value = self.backend.fill_value(self._value, self.init_value) | |||||
@property | |||||
def value(self): | |||||
return self._value | |||||
@value.setter | |||||
def value(self, value): | |||||
self._check_value_initialized() | |||||
self._value = value | |||||
@value.getter | |||||
def value(self): | |||||
self._check_value_initialized() | |||||
return self._value | |||||
def get_scalar(self) -> float: | |||||
return self.backend.get_scalar(self._value) | |||||
def fill_value(self, value): | |||||
self._value = self.backend.fill_value(self._value, value) | |||||
def to(self, device): | |||||
# device这里如何处理呢? | |||||
if self._value is not None: | |||||
self._value = self.backend.move_tensor_to_device(self._value, device) | |||||
self.device = device | |||||
def _check_value_initialized(self): | |||||
if self._value is None: | |||||
assert self.backend.is_specified(), f"Backend is not specified, please specify backend in the Metric " \ | |||||
f"initialization." | |||||
self._value = self.backend.create_tensor(self.init_value) | |||||
if self.device is not None: | |||||
self.to(device=self.device) | |||||
def _check_value_when_call(self): | |||||
if self.value is None: | |||||
prefix = f'Element:`{self.name}`' if self.name else 'Element' | |||||
raise RuntimeError(prefix + " is not initialized. Please either specify backend when creating this " | |||||
"element, or use it after it being used by the `Metric.compute()` method.") | |||||
def __add__(self, other): | |||||
self._check_value_when_call() | |||||
if isinstance(other, Element): | |||||
self.value += other.value | |||||
else: | |||||
self.value += other | |||||
return self | |||||
def __radd__(self, other): | |||||
self._check_value_when_call() | |||||
if isinstance(other, Element): | |||||
self.value += other.value | |||||
else: | |||||
self.value += other | |||||
return self | |||||
def __sub__(self, other): | |||||
self._check_value_when_call() | |||||
if isinstance(other, Element): | |||||
self.value -= other.value | |||||
else: | |||||
self.value -= other | |||||
return self | |||||
def __rsub__(self, other): | |||||
self._check_value_when_call() | |||||
if isinstance(other, Element): | |||||
self.value -= other.value | |||||
else: | |||||
self.value -= other | |||||
return self | |||||
def __mul__(self, other): | |||||
self._check_value_when_call() | |||||
if isinstance(other, Element): | |||||
self.value *= other.value | |||||
else: | |||||
self.value *= other | |||||
return self | |||||
def __imul__(self, other): | |||||
self._check_value_when_call() | |||||
if isinstance(other, Element): | |||||
self.value *= other.value | |||||
else: | |||||
self.value *= other | |||||
return self | |||||
def __floordiv__(self, other): | |||||
self._check_value_when_call() | |||||
if isinstance(other, Element): | |||||
self.value //= other.value | |||||
else: | |||||
self.value //= other | |||||
return self | |||||
def __rfloordiv__(self, other): | |||||
self._check_value_when_call() | |||||
if isinstance(other, Element): | |||||
self.value //= other.value | |||||
else: | |||||
self.value //= other | |||||
return self | |||||
def __truediv__(self, other): | |||||
self._check_value_when_call() | |||||
if isinstance(other, Element): | |||||
self.value /= other.value | |||||
else: | |||||
self.value /= other | |||||
return self | |||||
def __rtruediv__(self, other): | |||||
self._check_value_when_call() | |||||
if isinstance(other, Element): | |||||
self.value /= other.value | |||||
else: | |||||
self.value /= other | |||||
return self | |||||
def __mod__(self, other): | |||||
self._check_value_when_call() | |||||
if isinstance(other, Element): | |||||
self.value %= other.value | |||||
else: | |||||
self.value %= other | |||||
return self | |||||
def __rmod__(self, other): | |||||
self._check_value_when_call() | |||||
if isinstance(other, Element): | |||||
self.value /= other.value | |||||
else: | |||||
self.value /= other | |||||
return self | |||||
def __pow__(self, other, modulo=None): | |||||
self._check_value_when_call() | |||||
if modulo is None: | |||||
if isinstance(other, Element): | |||||
self.value **= other.value | |||||
else: | |||||
self.value **= other | |||||
else: | |||||
if isinstance(other, Element): | |||||
self.value = pow(self.value, other.value, modulo) | |||||
else: | |||||
self.value = pow(self.value, other, modulo) | |||||
return self | |||||
def __rpow__(self, other): | |||||
self._check_value_when_call() | |||||
if isinstance(other, Element): | |||||
self.value **= other.value | |||||
else: | |||||
self.value **= other | |||||
return self | |||||
def __lt__(self, other) -> bool: | |||||
self._check_value_when_call() | |||||
if isinstance(other, Element): | |||||
return self.value < other.value | |||||
else: | |||||
return self.value < other | |||||
def __le__(self, other) -> bool: | |||||
self._check_value_when_call() | |||||
if isinstance(other, Element): | |||||
return self.value <= other.value | |||||
else: | |||||
return self.value <= other | |||||
def __eq__(self, other): | |||||
self._check_value_when_call() | |||||
if isinstance(other, Element): | |||||
return self.value == other.value | |||||
else: | |||||
return self.value == other | |||||
def __ne__(self, other) -> bool: | |||||
self._check_value_when_call() | |||||
if isinstance(other, Element): | |||||
return self.value != other.value | |||||
else: | |||||
return self.value != other | |||||
def __ge__(self, other) -> bool: | |||||
self._check_value_when_call() | |||||
if isinstance(other, Element): | |||||
return self.value >= other.value | |||||
else: | |||||
return self.value >= other | |||||
def __gt__(self, other) -> bool: | |||||
self._check_value_when_call() | |||||
if isinstance(other, Element): | |||||
return self.value > other.value | |||||
else: | |||||
return self.value > other | |||||
def __str__(self): | |||||
return str(self.value) | |||||
def __repr__(self): | |||||
return str(self.value) | |||||
def __getattr__(self, item): | |||||
""" | |||||
为FDataLoader提供dataset的方法和属性,实现该方法后,用户可以在FDataLoader实例化后使用apply等dataset的方法 | |||||
:param item: | |||||
:return: | |||||
""" | |||||
try: | |||||
if self._value is None: | |||||
prefix = f'Element:`{self.name}`' if self.name else 'Element' | |||||
raise RuntimeError(prefix + " is not initialized. Please either specify backend when creating this " | |||||
"element, or use it after it being used by the `Metric.compute()` method.") | |||||
return getattr(self._value, item) | |||||
except AttributeError as e: | |||||
raise e |
@@ -0,0 +1,184 @@ | |||||
__all__ = [ | |||||
'Metric' | |||||
] | |||||
from abc import abstractmethod | |||||
from typing import Union | |||||
import functools | |||||
from contextlib import contextmanager | |||||
import numpy as np | |||||
from fastNLP.core.metrics.backend import Backend, AutoBackend | |||||
from fastNLP.core.metrics.element import Element | |||||
class Metric: | |||||
def __init__(self, backend: Union[str, Backend, None] = 'auto', aggregate_when_get_metric: bool = True): | |||||
""" | |||||
:param str backend: 目前支持四种类型的backend, [torch, paddle, jittor, auto]。其中 auto 表示根据实际调用 Metric.update() | |||||
函数时传入的参数决定具体的 backend ,大部分情况下直接使用 auto 即可。 | |||||
:param bool aggregate_when_get_metric: 在计算 metric 的时候是否自动将各个进程上的相同的 element 的数字聚合后再得到metric, | |||||
当 backend 不支持分布式时,该参数无意义。 | |||||
""" | |||||
self.backend = AutoBackend(backend) | |||||
self._updated = False | |||||
self.get_metric = self._sync_get_metric(self.get_metric) | |||||
self.update = self._wrap_update(self.update) | |||||
self.reset = self._wrap_auto_reset_elements(self.reset) | |||||
self.aggregate_when_get_metric = aggregate_when_get_metric | |||||
self._cannot_change_element = False | |||||
self._elements = {} | |||||
@property | |||||
def elements(self) -> dict: | |||||
return self._elements | |||||
def register_element(self, name=None, value: float = 0, aggregate_method=None, backend='auto') -> Element: | |||||
""" | |||||
注册一个 element 对象,注册之后便可以通过在 Metric 中直接通过 self.{name} 进行调用,可以认为该对象即为对应 backend 的 | |||||
tensor 直接进行加减乘除计算即可。 | |||||
注意:如果想使得该 metric 可自动扩展到多卡的情况,请一定申明 aggregate_method 。 | |||||
:param name: 当前 element 的名字,注册后,在 Metric 中可以通过 self.{name} 访问该变量。 | |||||
:param value: 初始化的值。在调用 Metric.reset() 方法时也将自动设置为该值 | |||||
:param aggregate_method: 如何聚合多卡上的结果,如果为单卡执行,该值无意义。 | |||||
:param backend: 使用的 backend 。Element 的类型会根据 backend 进行实际的初始化。例如 backend 为 torch 则该对象为 | |||||
Torch.tensor ; 如果backend 为 paddle 则该对象为 paddle.tensor ;如果 backend 为 jittor , 则该对象为 jittor.Var 。 | |||||
一般情况下直接默认为 auto 就行了,fastNLP 会根据实际调用 Metric.update() 函数时传入的参数进行合理的初始化,例如当传入 | |||||
的参数中只包含 torch.Tensor 这一种 tensor 时(可以有其它非 tensor 类型的输入)则认为 backend 为 torch ;只包含 | |||||
jittor.Var 则认为 backend 这一种 tensor 时(可以有其它非 tensor 类型的输入)则认为 backend 为 jittor 。如果没有检测 | |||||
到任何一种 tensor ,就默认使用 float 类型作为 element 。 | |||||
:return: 注册的 Element 对象 | |||||
""" | |||||
if backend == 'auto': | |||||
backend = self.backend | |||||
else: | |||||
backend = AutoBackend(backend) | |||||
# 当name为None,默认为变量取得变量名 | |||||
if name is None: | |||||
name = f'ele_var_{len(self._elements)}' | |||||
element = Element(value=value, aggregate_method=aggregate_method, backend=backend, name=name) | |||||
self.elements[name] = element | |||||
setattr(self, name, element) | |||||
return element | |||||
def reset(self): | |||||
""" | |||||
如果有非 element 的对象需要 reset 的时候,在本方法中写下非 element 的reset 方式。注册的 element 对象会自动 reset 为初始值。 | |||||
""" | |||||
pass | |||||
def _wrap_auto_reset_elements(self, reset): | |||||
@functools.wraps(reset) | |||||
def _wrap_reset(*args, **kwargs): | |||||
self._updated = False | |||||
for ele in self.elements.values(): | |||||
ele.reset() | |||||
reset(*args, **kwargs) | |||||
return _wrap_reset | |||||
def _sync_get_metric(self, get_metric): | |||||
@functools.wraps(get_metric) | |||||
def _wrap_get_metric(*args, **kwargs): | |||||
assert self._updated, f"You have to call `{self.__class__.__name__}` update() function before calling " \ | |||||
f"get_metric()." | |||||
with self.sync(recover=True, aggregate=self.aggregate_when_get_metric): | |||||
results = get_metric(*args, **kwargs) | |||||
return results | |||||
return _wrap_get_metric | |||||
def __setattr__(self, key, value): | |||||
if hasattr(self, '_cannot_change_element') and self._cannot_change_element is True: | |||||
if key in self.elements and value is not self.elements[key]: | |||||
raise RuntimeError(f"self.`{key}` is an element, cannot assign to a new value:{value}") | |||||
object.__setattr__(self, key, value) | |||||
def _wrap_update(self, update): | |||||
@functools.wraps(update) | |||||
def _wrap_update(*args, **kwargs): | |||||
self.check_backend(*args, **kwargs) | |||||
self._cannot_change_element = True | |||||
self._updated = True | |||||
return update(*args, **kwargs) | |||||
return _wrap_update | |||||
def check_backend(self, *args, **kwargs): | |||||
if not self.backend.is_specified(): | |||||
_args = [] | |||||
for arg in args: | |||||
_args.append(arg) | |||||
for arg in kwargs.values(): | |||||
_args.append(arg) | |||||
self.backend.choose_real_backend(_args) | |||||
@contextmanager | |||||
def sync(self, recover=True, aggregate=False): | |||||
""" | |||||
在这个上下文下, metric 会自动先同步需要同步操作的 element 。当 recover 为 True 时,在退出环境的时候,会重新将 element 的 | |||||
值恢复到计算前的值。 | |||||
""" | |||||
keep_value = {} | |||||
if aggregate: | |||||
for name, element in self.elements.items(): | |||||
# 保存过去的值 | |||||
keep_value[name] = element.get_scalar() | |||||
# 聚合结果 | |||||
element.aggregate() | |||||
yield | |||||
if recover and aggregate: | |||||
for name, element in self.elements.items(): | |||||
# 恢复结果 | |||||
if name in keep_value: | |||||
element.fill_value(value=keep_value.get(name)) | |||||
@abstractmethod | |||||
def update(self, *args, **kwargs): | |||||
raise NotImplementedError() | |||||
@abstractmethod | |||||
def get_metric(self) -> dict: | |||||
raise NotImplementedError() | |||||
def set_auto_aggregate_when_get_metric(self, flag: bool): | |||||
""" | |||||
设置是否在 get_metric 的时候自动 aggregate | |||||
""" | |||||
self.aggregate_when_get_metric = flag | |||||
def __getattr__(self, name: str) -> Element: | |||||
if 'elements' in self.__dict__: | |||||
elements = self.__dict__['elements'] | |||||
if name in elements: | |||||
return elements[name] | |||||
raise AttributeError("`{}` object has no attribute `{}`".format(type(self).__name__, name)) | |||||
def tensor2numpy(self, tensor) -> np.array: | |||||
""" | |||||
将tensor向量转为numpy类型变量 | |||||
:param tensor: | |||||
:return: | |||||
""" | |||||
return self.backend.tensor2numpy(tensor) | |||||
def to(self, device): | |||||
""" | |||||
将所有的 element 变量移动到 device 设备上 | |||||
:param device: | |||||
:return: | |||||
""" | |||||
for element in self.elements.values(): | |||||
element.to(device) |
@@ -0,0 +1,344 @@ | |||||
__all__ = [ | |||||
'SpanFPreRecMetric' | |||||
] | |||||
from typing import Union, List, Optional | |||||
import warnings | |||||
from collections import defaultdict | |||||
from functools import partial | |||||
from fastNLP.core.metrics.backend import Backend | |||||
from fastNLP.core.metrics.metric import Metric | |||||
from fastNLP.core.vocabulary import Vocabulary | |||||
def _check_tag_vocab_and_encoding_type(tag_vocab: Union[Vocabulary, dict], encoding_type: str): | |||||
r""" | |||||
检查vocab中的tag是否与encoding_type是匹配的 | |||||
:param tag_vocab: 支持传入tag Vocabulary; 或者传入形如{0:"O", 1:"B-tag1"},即index在前,tag在后的dict。 | |||||
:param encoding_type: bio, bmes, bioes, bmeso | |||||
:return: | |||||
""" | |||||
tag_set = set() | |||||
unk_token = '<unk>' | |||||
pad_token = '<pad>' | |||||
if isinstance(tag_vocab, Vocabulary): | |||||
unk_token = tag_vocab.unknown | |||||
pad_token = tag_vocab.padding | |||||
tag_vocab = tag_vocab.idx2word | |||||
for idx, tag in tag_vocab.items(): | |||||
if tag in (unk_token, pad_token): | |||||
continue | |||||
tag = tag[:1].lower() | |||||
tag_set.add(tag) | |||||
tags = encoding_type | |||||
for tag in tag_set: | |||||
assert tag in tags, f"{tag} is not a valid tag in encoding type:{encoding_type}. Please check your " \ | |||||
f"encoding_type." | |||||
tags = tags.replace(tag, '') # 删除该值 | |||||
if tags: # 如果不为空,说明出现了未使用的tag | |||||
warnings.warn(f"Tag:{tags} in encoding type:{encoding_type} is not presented in your Vocabulary. Check your " | |||||
"encoding_type.") | |||||
def _get_encoding_type_from_tag_vocab(tag_vocab: Union[Vocabulary, dict]) -> str: | |||||
r""" | |||||
给定Vocabulary自动判断是哪种类型的encoding, 支持判断bmes, bioes, bmeso, bio | |||||
:param tag_vocab: 支持传入tag Vocabulary; 或者传入形如{0:"O", 1:"B-tag1"},即index在前,tag在后的dict。 | |||||
:return: | |||||
""" | |||||
tag_set = set() | |||||
unk_token = '<unk>' | |||||
pad_token = '<pad>' | |||||
if isinstance(tag_vocab, Vocabulary): | |||||
unk_token = tag_vocab.unknown | |||||
pad_token = tag_vocab.padding | |||||
tag_vocab = tag_vocab.idx2word | |||||
for idx, tag in tag_vocab.items(): | |||||
if tag in (unk_token, pad_token): | |||||
continue | |||||
tag = tag[:1].lower() | |||||
tag_set.add(tag) | |||||
bmes_tag_set = set('bmes') | |||||
if tag_set == bmes_tag_set: | |||||
return 'bmes' | |||||
bio_tag_set = set('bio') | |||||
if tag_set == bio_tag_set: | |||||
return 'bio' | |||||
bmeso_tag_set = set('bmeso') | |||||
if tag_set == bmeso_tag_set: | |||||
return 'bmeso' | |||||
bioes_tag_set = set('bioes') | |||||
if tag_set == bioes_tag_set: | |||||
return 'bioes' | |||||
raise RuntimeError("encoding_type cannot be inferred automatically. Only support " | |||||
"'bio', 'bmes', 'bmeso', 'bioes' type.") | |||||
def _bmes_tag_to_spans(tags, ignore_labels=None): | |||||
r""" | |||||
给定一个tags的lis,比如['S-song', 'B-singer', 'M-singer', 'E-singer', 'S-moive', 'S-actor']。 | |||||
返回[('song', (0, 1)), ('singer', (1, 4)), ('moive', (4, 5)), ('actor', (5, 6))] (左闭右开区间) | |||||
也可以是单纯的['S', 'B', 'M', 'E', 'B', 'M', 'M',...]序列 | |||||
:param tags: List[str], | |||||
:param ignore_labels: List[str], 在该list中的label将被忽略 | |||||
:return: List[Tuple[str, List[int, int]]]. [(label,[start, end])] | |||||
""" | |||||
ignore_labels = set(ignore_labels) if ignore_labels else set() | |||||
spans = [] | |||||
prev_bmes_tag = None | |||||
for idx, tag in enumerate(tags): | |||||
tag = tag.lower() | |||||
bmes_tag, label = tag[:1], tag[2:] | |||||
if bmes_tag in ('b', 's'): | |||||
spans.append((label, [idx, idx])) | |||||
elif bmes_tag in ('m', 'e') and prev_bmes_tag in ('b', 'm') and label == spans[-1][0]: | |||||
spans[-1][1][1] = idx | |||||
else: | |||||
spans.append((label, [idx, idx])) | |||||
prev_bmes_tag = bmes_tag | |||||
return [(span[0], (span[1][0], span[1][1] + 1)) | |||||
for span in spans | |||||
if span[0] not in ignore_labels | |||||
] | |||||
def _bmeso_tag_to_spans(tags, ignore_labels=None): | |||||
r""" | |||||
给定一个tags的lis,比如['O', 'B-singer', 'M-singer', 'E-singer', 'O', 'O']。 | |||||
返回[('singer', (1, 4))] (左闭右开区间) | |||||
:param tags: List[str], | |||||
:param ignore_labels: List[str], 在该list中的label将被忽略 | |||||
:return: List[Tuple[str, List[int, int]]]. [(label,[start, end])] | |||||
""" | |||||
ignore_labels = set(ignore_labels) if ignore_labels else set() | |||||
spans = [] | |||||
prev_bmes_tag = None | |||||
for idx, tag in enumerate(tags): | |||||
tag = tag.lower() | |||||
bmes_tag, label = tag[:1], tag[2:] | |||||
if bmes_tag in ('b', 's'): | |||||
spans.append((label, [idx, idx])) | |||||
elif bmes_tag in ('m', 'e') and prev_bmes_tag in ('b', 'm') and label == spans[-1][0]: | |||||
spans[-1][1][1] = idx | |||||
elif bmes_tag == 'o': | |||||
pass | |||||
else: | |||||
spans.append((label, [idx, idx])) | |||||
prev_bmes_tag = bmes_tag | |||||
return [(span[0], (span[1][0], span[1][1] + 1)) | |||||
for span in spans | |||||
if span[0] not in ignore_labels | |||||
] | |||||
def _bioes_tag_to_spans(tags, ignore_labels=None): | |||||
r""" | |||||
给定一个tags的lis,比如['O', 'B-singer', 'I-singer', 'E-singer', 'O', 'O']。 | |||||
返回[('singer', (1, 4))] (左闭右开区间) | |||||
:param tags: List[str], | |||||
:param ignore_labels: List[str], 在该list中的label将被忽略 | |||||
:return: List[Tuple[str, List[int, int]]]. [(label,[start, end])] | |||||
""" | |||||
ignore_labels = set(ignore_labels) if ignore_labels else set() | |||||
spans = [] | |||||
prev_bioes_tag = None | |||||
for idx, tag in enumerate(tags): | |||||
tag = tag.lower() | |||||
bioes_tag, label = tag[:1], tag[2:] | |||||
if bioes_tag in ('b', 's'): | |||||
spans.append((label, [idx, idx])) | |||||
elif bioes_tag in ('i', 'e') and prev_bioes_tag in ('b', 'i') and label == spans[-1][0]: | |||||
spans[-1][1][1] = idx | |||||
elif bioes_tag == 'o': | |||||
pass | |||||
else: | |||||
spans.append((label, [idx, idx])) | |||||
prev_bioes_tag = bioes_tag | |||||
return [(span[0], (span[1][0], span[1][1] + 1)) | |||||
for span in spans | |||||
if span[0] not in ignore_labels | |||||
] | |||||
def _bio_tag_to_spans(tags, ignore_labels=None): | |||||
r""" | |||||
给定一个tags的lis,比如['O', 'B-singer', 'I-singer', 'I-singer', 'O', 'O']。 | |||||
返回[('singer', (1, 4))] (左闭右开区间) | |||||
:param tags: List[str], | |||||
:param ignore_labels: List[str], 在该list中的label将被忽略 | |||||
:return: List[Tuple[str, List[int, int]]]. [(label,[start, end])] | |||||
""" | |||||
ignore_labels = set(ignore_labels) if ignore_labels else set() | |||||
spans = [] | |||||
prev_bio_tag = None | |||||
for idx, tag in enumerate(tags): | |||||
tag = tag.lower() | |||||
bio_tag, label = tag[:1], tag[2:] | |||||
if bio_tag == 'b': | |||||
spans.append((label, [idx, idx])) | |||||
elif bio_tag == 'i' and prev_bio_tag in ('b', 'i') and label == spans[-1][0]: | |||||
spans[-1][1][1] = idx | |||||
elif bio_tag == 'o': # o tag does not count | |||||
pass | |||||
else: | |||||
spans.append((label, [idx, idx])) | |||||
prev_bio_tag = bio_tag | |||||
return [(span[0], (span[1][0], span[1][1] + 1)) for span in spans if span[0] not in ignore_labels] | |||||
def _compute_f_pre_rec(beta_square, tp, fn, fp): | |||||
r""" | |||||
:param tp: int, true positive | |||||
:param fn: int, false negative | |||||
:param fp: int, false positive | |||||
:return: (f, pre, rec) | |||||
""" | |||||
pre = tp / (fp + tp + 1e-13) | |||||
rec = tp / (fn + tp + 1e-13) | |||||
f = (1 + beta_square) * pre * rec / (beta_square * pre + rec + 1e-13) | |||||
return f, pre, rec | |||||
class SpanFPreRecMetric(Metric): | |||||
def __init__(self, backend: Union[str, Backend, None] = 'auto', tag_vocab: Vocabulary = None, | |||||
encoding_type: str = None, ignore_labels: List[str] = None, only_gross: bool = True, f_type='micro', | |||||
beta=1, aggregate_when_get_metric: bool = True,) -> None: | |||||
super(SpanFPreRecMetric, self).__init__(backend=backend, aggregate_when_get_metric=aggregate_when_get_metric) | |||||
if f_type not in ('micro', 'macro'): | |||||
raise ValueError("f_type only supports `micro` or `macro`', got {}.".format(f_type)) | |||||
if not isinstance(tag_vocab, Vocabulary): | |||||
raise TypeError("tag_vocab can only be fastNLP.Vocabulary, not {}.".format(type(tag_vocab))) | |||||
if encoding_type: | |||||
encoding_type = encoding_type.lower() | |||||
_check_tag_vocab_and_encoding_type(tag_vocab, encoding_type) | |||||
self.encoding_type = encoding_type | |||||
else: | |||||
self.encoding_type = _get_encoding_type_from_tag_vocab(tag_vocab) | |||||
if self.encoding_type == 'bmes': | |||||
self.tag_to_span_func = _bmes_tag_to_spans | |||||
elif self.encoding_type == 'bio': | |||||
self.tag_to_span_func = _bio_tag_to_spans | |||||
elif self.encoding_type == 'bmeso': | |||||
self.tag_to_span_func = _bmeso_tag_to_spans | |||||
elif self.encoding_type == 'bioes': | |||||
self.tag_to_span_func = _bioes_tag_to_spans | |||||
else: | |||||
raise ValueError("Only support 'bio', 'bmes', 'bmeso', 'bioes' type.") | |||||
self.ignore_labels = ignore_labels | |||||
self.f_type = f_type | |||||
self.beta = beta | |||||
self.beta_square = self.beta ** 2 | |||||
self.only_gross = only_gross | |||||
self.tag_vocab = tag_vocab | |||||
self._true_positives = defaultdict(partial(self.register_element, aggregate_method='sum', name=None)) | |||||
self._false_positives = defaultdict(partial(self.register_element, aggregate_method='sum', name=None)) | |||||
self._false_negatives = defaultdict(partial(self.register_element, aggregate_method='sum', name=None)) | |||||
def get_metric(self) -> dict: | |||||
evaluate_result = {} | |||||
if not self.only_gross or self.f_type == 'macro': | |||||
tags = set(self._false_negatives.keys()) | |||||
tags.update(set(self._false_positives.keys())) | |||||
tags.update(set(self._true_positives.keys())) | |||||
f_sum = 0 | |||||
pre_sum = 0 | |||||
rec_sum = 0 | |||||
for tag in tags: | |||||
tp = self._true_positives[tag].get_scalar() | |||||
fn = self._false_negatives[tag].get_scalar() | |||||
fp = self._false_positives[tag].get_scalar() | |||||
f, pre, rec = _compute_f_pre_rec(self.beta_square, tp, fn, fp) | |||||
f_sum += f | |||||
pre_sum += pre | |||||
rec_sum += rec | |||||
if not self.only_gross and tag != '': # tag!=''防止无tag的情况 | |||||
f_key = 'f-{}'.format(tag) | |||||
pre_key = 'pre-{}'.format(tag) | |||||
rec_key = 'rec-{}'.format(tag) | |||||
evaluate_result[f_key] = f | |||||
evaluate_result[pre_key] = pre | |||||
evaluate_result[rec_key] = rec | |||||
if self.f_type == 'macro': | |||||
evaluate_result['f'] = f_sum / len(tags) | |||||
evaluate_result['pre'] = pre_sum / len(tags) | |||||
evaluate_result['rec'] = rec_sum / len(tags) | |||||
if self.f_type == 'micro': | |||||
f, pre, rec = _compute_f_pre_rec(self.beta_square, | |||||
sum(val.get_scalar() for val in self._true_positives.values()), | |||||
sum(val.get_scalar() for val in self._false_negatives.values()), | |||||
sum(val.get_scalar() for val in self._false_positives.values())) | |||||
evaluate_result['f'] = f | |||||
evaluate_result['pre'] = pre | |||||
evaluate_result['rec'] = rec | |||||
for key, value in evaluate_result.items(): | |||||
evaluate_result[key] = round(value, 6) | |||||
return evaluate_result | |||||
def update(self, pred, target, seq_len: Optional[List] = None) -> None: | |||||
r"""update函数将针对一个批次的预测结果做评价指标的累计 | |||||
:param pred: [batch, seq_len] 或者 [batch, seq_len, len(tag_vocab)], 预测的结果 | |||||
:param target: [batch, seq_len], 真实值 | |||||
:param seq_len: [batch] 文本长度标记 | |||||
:return: | |||||
""" | |||||
pred = self.tensor2numpy(pred) | |||||
target = self.tensor2numpy(target) | |||||
if pred.ndim == target.ndim and target.ndim == 2: | |||||
pass | |||||
elif pred.ndim == target.ndim + 1 and target.ndim == 2: | |||||
num_classes = pred.shape[-1] | |||||
pred = pred.argmax(axis=-1) | |||||
if (target >= num_classes).any(): | |||||
raise ValueError("A gold label passed to SpanBasedF1Metric contains an " | |||||
"id >= {}, the number of classes.".format(num_classes)) | |||||
else: | |||||
raise RuntimeError(f"when pred have size:{pred.ndim}, target should have size: {pred.ndim} or " | |||||
f"{pred.shape[:-1]}, got {target.ndim}.") | |||||
batch_size = pred.shape[0] | |||||
pred = pred.tolist() | |||||
target = target.tolist() | |||||
for i in range(batch_size): | |||||
pred_tags = pred[i][:int(seq_len[i])] | |||||
gold_tags = target[i][:int(seq_len[i])] | |||||
pred_str_tags = [self.tag_vocab.to_word(tag) for tag in pred_tags] | |||||
gold_str_tags = [self.tag_vocab.to_word(tag) for tag in gold_tags] | |||||
pred_spans = self.tag_to_span_func(pred_str_tags, ignore_labels=self.ignore_labels) | |||||
gold_spans = self.tag_to_span_func(gold_str_tags, ignore_labels=self.ignore_labels) | |||||
for span in pred_spans: | |||||
if span in gold_spans: | |||||
self._true_positives[span[0]] += 1 | |||||
gold_spans.remove(span) | |||||
else: | |||||
self._false_positives[span[0]] += 1 | |||||
for span in gold_spans: | |||||
self._false_negatives[span[0]] += 1 |
@@ -0,0 +1,91 @@ | |||||
__all__ = [ | |||||
'func_post_proc' | |||||
] | |||||
from typing import Any | |||||
from functools import wraps | |||||
from fastNLP.envs.imports import _NEED_IMPORT_PADDLE | |||||
from fastNLP.envs.utils import _module_available | |||||
_IS_TORCHMETRICS_AVAILABLE = _module_available('torchmetrics') | |||||
if _IS_TORCHMETRICS_AVAILABLE: | |||||
from torchmetrics import Metric as torchmetrics_Metric | |||||
_IS_ALLENNLP_AVAILABLE = _module_available('allennlp') | |||||
if _IS_ALLENNLP_AVAILABLE: | |||||
from allennlp.training.metrics import Metric as allennlp_Metric | |||||
if _NEED_IMPORT_PADDLE: | |||||
from paddle.metric import Metric as paddle_Metric | |||||
def _is_torchmetrics_metric(metric: Any) -> bool: | |||||
""" | |||||
检查输入的对象是否为torchmetrics对象 | |||||
:param metric: | |||||
:return: | |||||
""" | |||||
if _IS_TORCHMETRICS_AVAILABLE: | |||||
return isinstance(metric, torchmetrics_Metric) | |||||
else: | |||||
return False | |||||
def _is_allennlp_metric(metric: Any) -> bool: | |||||
""" | |||||
检查输入的对象是否为allennlp对象 | |||||
:param metric: | |||||
:return: | |||||
""" | |||||
if _IS_ALLENNLP_AVAILABLE: | |||||
return isinstance(metric, allennlp_Metric) | |||||
else: | |||||
return False | |||||
def _is_paddle_metric(metric: Any) -> bool: | |||||
""" | |||||
检查输入的对象是否为allennlp对象 | |||||
:param metric: | |||||
:return: | |||||
""" | |||||
if _NEED_IMPORT_PADDLE: | |||||
return isinstance(metric, paddle_Metric) | |||||
else: | |||||
return False | |||||
def func_post_proc(metric: 'Metric', fn: callable, method_name: str) -> 'Metric': | |||||
""" | |||||
将fn函数作用包裹在 metric 对象的 {method_name} 方法上,使得 metric.{method_name} 函数的返回结果先经过 fn 函数处理 | |||||
后再返回。注意对 metric 的 {method_name} 函数的修改是 inplace 的。 | |||||
:param metric: metric对象 | |||||
:param fn: 作用于 metric 的 accumulate 方法的返回值 | |||||
:param method_name: 一般来说,对于 | |||||
:return: metric | |||||
""" | |||||
assert hasattr(metric, method_name) and callable(getattr(metric, method_name)), \ | |||||
f"Parameter `metric` must have a {method_name} function." | |||||
assert callable(fn), "Parameter `fn` must be callable." | |||||
func = getattr(metric, method_name) | |||||
@wraps(func) | |||||
def wrap_method(*args, **kwargs): | |||||
res = func(*args, **kwargs) | |||||
return fn(res) | |||||
wrap_method.__wrapped_by_func_post_proc__ = True | |||||
setattr(metric, method_name, wrap_method) | |||||
return metric | |||||
class AggregateMethodError(BaseException): | |||||
def __init__(self, should_have_aggregate_method, only_warn=False): | |||||
super(AggregateMethodError, self).__init__(self) | |||||
self.should_have_aggregate_method = should_have_aggregate_method | |||||
self.only_warn = only_warn |