@@ -0,0 +1,18 @@
__all__ = [
    'Metric',
    'Accuracy',
    'Backend',
    'AutoBackend',
    'PaddleBackend',
    'TorchBackend',
    'SpanFPreRecMetric',
    'ClassifyFPreRecMetric',
    'func_post_proc'
]

from .metric import Metric
from .accuracy import Accuracy
from .backend import Backend, AutoBackend, PaddleBackend, TorchBackend
from .span_f1_pre_rec_metric import SpanFPreRecMetric
from .classify_f1_pre_rec_metric import ClassifyFPreRecMetric
from .utils import func_post_proc
@@ -0,0 +1,75 @@
__all__ = [
    'Accuracy'
]

from typing import Union
import warnings

import numpy as np

from fastNLP.core.metrics.metric import Metric
from fastNLP.core.metrics.backend import Backend
from fastNLP.core.utils.utils import seq_len_to_mask


class Accuracy(Metric):
    def __init__(self, backend: Union[str, Backend, None] = 'auto',
                 aggregate_when_get_metric: bool = True):
        super(Accuracy, self).__init__(backend=backend, aggregate_when_get_metric=aggregate_when_get_metric)
        self.register_element(name='correct', value=0, aggregate_method='sum', backend=backend)
        self.register_element(name='total', value=0, aggregate_method='sum', backend=backend)

    def get_metric(self) -> dict:
        r"""
        Compute the final evaluation result from the statistics accumulated by the update() calls.

        :return dict evaluate_result: {"acc": float}
        """
        evaluate_result = {'acc': round(self.correct.get_scalar() / (self.total.get_scalar() + 1e-12), 6)}
        return evaluate_result

    def update(self, pred, target, seq_len=None):
        r"""
        Accumulate the evaluation statistics for one batch of predictions.

        :param pred: prediction tensor of shape [B,], [B, n_classes], [B, max_len] or
            [B, max_len, n_classes]
        :param target: ground-truth tensor of shape [B,] or [B, max_len]
        :param seq_len: sequence-length tensor of shape [B], used to mask out padding;
            only meaningful when target has more than one dimension
        """
        # To stay framework-agnostic, convert all inputs to numpy before computing.
        pred = self.tensor2numpy(pred)
        target = self.tensor2numpy(target)
        if seq_len is not None:
            seq_len = self.tensor2numpy(seq_len)

        if seq_len is not None and target.ndim > 1:
            max_len = target.shape[1]
            masks = seq_len_to_mask(seq_len, max_len)
        else:
            masks = None

        if pred.ndim == target.ndim:
            if np.prod(pred.shape) != np.prod(target.shape):
                raise RuntimeError(f"when pred has the same number of dimensions as target, they should have the "
                                   f"same number of elements, while target has shape:{target.shape} and "
                                   f"pred has shape: {pred.shape}")
        elif pred.ndim == target.ndim + 1:
            pred = pred.argmax(axis=-1)
            if seq_len is None and target.ndim > 1:
                warnings.warn("You are not passing `seq_len` to exclude pad when calculating accuracy.")
        else:
            raise RuntimeError(f"when pred has shape:{pred.shape}, target should have shape: {pred.shape} or "
                               f"{pred.shape[:-1]}, got {target.shape}.")

        if masks is not None:
            self.total += masks.sum().item()
            self.correct += ((pred == target) * masks).sum().item()
        else:
            self.total += np.prod(list(pred.shape)).item()
            self.correct += (target == pred).sum().item()
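
# A minimal usage sketch (an editor's illustration, not part of the library),
# assuming torch is installed: two of three predictions match the target, so
# the accuracy is 2/3.
if __name__ == '__main__':
    import torch

    metric = Accuracy()
    metric.update(pred=torch.tensor([1, 0, 1]), target=torch.tensor([1, 1, 1]))
    print(metric.get_metric())  # -> {'acc': 0.666667}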
@@ -0,0 +1,12 @@
__all__ = [
    'Backend',
    'AutoBackend',
    'TorchBackend',
    'PaddleBackend'
]

from .backend import Backend
from .auto_backend import AutoBackend
from .torch_backend.backend import TorchBackend
from .paddle_backend.backend import PaddleBackend
@@ -0,0 +1,75 @@
from typing import Union

from .backend import Backend
from .torch_backend.backend import TorchBackend
from .paddle_backend.backend import PaddleBackend
from .jittor_backend.backend import JittorBackend


class AutoBackend(Backend):
    """
    A Backend that does not need to be initialized upfront: it can decide which concrete
    backend to use from the types of the inputs it sees when the metric is updated.
    """

    def __init__(self, backend: Union[str, Backend, None]):
        super(AutoBackend, self).__init__()
        if backend != 'auto':
            self._convert_backend(backend)

    def _convert_backend(self, backend):
        """
        Convert this AutoBackend into the appropriate concrete Backend.
        """
        if isinstance(backend, Backend):
            self.__class__ = backend.__class__
        # if backend is a str, simply pick the matching class
        elif backend == 'torch':
            self.__class__ = TorchBackend
        elif backend == 'paddle':
            self.__class__ = PaddleBackend
        elif backend == 'jittor':
            self.__class__ = JittorBackend
        elif backend is None:
            # nothing to do, the default behaviour is fine
            pass
        else:
            raise RuntimeError(f"We do not support `{backend}` to be used as backend for now.")
        self._specified = True

    def choose_real_backend(self, args):
        assert not self.is_specified(), "This method should not be called after backend has been specified. " \
                                        "This must be a bug, please report."
        types = []
        for arg in args:
            types.append(str(type(arg)))

        torch_types = []
        jittor_types = []
        paddle_types = []
        for type_name in types:
            if 'torch' in type_name:
                torch_types.append(type_name)
            if 'paddle' in type_name:
                paddle_types.append(type_name)
            if 'jittor' in type_name:
                jittor_types.append(type_name)

        # switching to the real backend by reassigning __class__ follows
        # https://stackoverflow.com/a/3464154
        if len(torch_types) > 0 and len(jittor_types) == 0 and len(paddle_types) == 0:
            backend = 'torch'
        elif len(torch_types) == 0 and len(jittor_types) > 0 and len(paddle_types) == 0:
            backend = 'jittor'
        elif len(torch_types) == 0 and len(jittor_types) == 0 and len(paddle_types) > 0:
            backend = 'paddle'
        elif len(torch_types) == 0 and len(jittor_types) == 0 and len(paddle_types) == 0:
            # no framework tensor detected, fall back to the default backend
            backend = None
        else:
            types = list(set(torch_types + jittor_types + paddle_types))
            raise RuntimeError(
                f"Mixture of tensor types:{types} has been received, please manually set backend instead of "
                f"using backend='auto'.")
        self._convert_backend(backend)
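
# A small illustration (not part of the original file; run with `python -m` so
# the relative imports resolve) of the __class__ reassignment trick: after
# _convert_backend, type checks see the concrete backend.
if __name__ == '__main__':
    backend = AutoBackend('torch')
    print(type(backend).__name__)  # -> TorchBackend
    print(backend.is_specified())  # -> True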
@@ -0,0 +1,75 @@ | |||
from ..utils import AggregateMethodError | |||
class Backend: | |||
""" | |||
Backend 及其子类的所有方法都必须是无状态的。 | |||
""" | |||
def __init__(self): | |||
self._specified = False | |||
def aggregate(self, tensor, method: str): | |||
""" | |||
聚集结果,并根据method计算后,返回结果 | |||
""" | |||
if method is not None: | |||
return AggregateMethodError(should_have_aggregate_method=False, only_warn=True) | |||
return tensor | |||
def create_tensor(self, value: float): | |||
""" | |||
创建tensor,并且填入value作为值 | |||
""" | |||
return value | |||
def fill_value(self, tensor, value: float): | |||
""" | |||
将tensor的值设置为value | |||
""" | |||
return value | |||
def get_scalar(self, tensor) -> float: | |||
""" | |||
tensor的saclar值 | |||
:param tensor: | |||
:return: | |||
""" | |||
return tensor | |||
def is_specified(self) -> bool: | |||
""" | |||
判断是否是某种框架的backend | |||
:return: | |||
""" | |||
return self._specified | |||
def tensor2numpy(self, tensor): | |||
""" | |||
将tensor转为numpy | |||
:param tensor: | |||
:return: | |||
""" | |||
return tensor | |||
def move_tensor_to_device(self, tensor, device): | |||
""" | |||
""" | |||
return tensor | |||
def all_gather_object(self, obj, group=None): | |||
""" | |||
给定 obj 将各个 rank 上的 obj 汇总到每个 obj 上。返回一个 list 对象,里面依次为各个 rank 对应的 obj 。 | |||
:param obj: | |||
:param group: | |||
:return: | |||
""" | |||
raise NotImplementedError(f"all_gather_object() function is not implemented for {self.__class__.__name__}.") | |||
@@ -0,0 +1 @@
@@ -0,0 +1,72 @@
import numpy as np

from fastNLP.envs.imports import _NEED_IMPORT_JITTOR
from fastNLP.core.metrics.backend import Backend

if _NEED_IMPORT_JITTOR:
    import jittor


class JittorBackend(Backend):

    def __init__(self):
        super(JittorBackend, self).__init__()
        self._specified = True

    def aggregate(self, tensor, method: str):
        """
        Aggregate the results across processes according to method and return them.
        """
        return tensor

    def create_tensor(self, value: float):
        """
        Create a tensor filled with value.
        """
        value = jittor.Var(value)
        return value

    def fill_value(self, tensor, value: float):
        """
        Set the content of tensor to value.
        """
        value = jittor.full_like(tensor, value)
        return value

    def get_scalar(self, tensor) -> float:
        """
        Return the scalar value of tensor.

        :param tensor:
        :return:
        """
        return tensor.item()

    def is_specified(self) -> bool:
        """
        Whether this backend is bound to a concrete framework.

        :return:
        """
        return self._specified

    def tensor2numpy(self, tensor):
        """
        Convert tensor to numpy.

        :param tensor:
        :return:
        """
        if isinstance(tensor, jittor.Var):
            return tensor.detach().numpy()
        elif isinstance(tensor, np.ndarray):
            return tensor
        else:
            raise ValueError(f"tensor: {tensor} can not be converted to ndarray!")

    def move_tensor_to_device(self, tensor, device):
        """
        jittor has no explicit device-transfer API, so this method is a no-op.
        """
        return tensor
@@ -0,0 +1,5 @@
__all__ = [
    'PaddleBackend'
]

from .backend import Backend as PaddleBackend
@@ -0,0 +1,126 @@
from typing import List, Optional, Any

import numpy as np

from fastNLP.core.metrics.backend import Backend
from fastNLP.core.utils.paddle_utils import paddle_to
from fastNLP.core.metrics.utils import AggregateMethodError
from fastNLP.core.utils import is_in_paddle_dist
from fastNLP.core.drivers.paddle_driver.utils import get_device_from_visible
from fastNLP.envs.imports import _NEED_IMPORT_PADDLE

if _NEED_IMPORT_PADDLE:
    import paddle
    from paddle.fluid.dygraph import parallel_helper


def _simple_gather_all_tensors(result, group: Any, world_size: int) -> List:
    gathered_result = [paddle.zeros_like(result) for _ in range(world_size)]
    paddle.distributed.all_gather(gathered_result, result, group)
    return gathered_result


class PaddleBackend(Backend):
    def __init__(self):
        super().__init__()
        self._specified = True

    def aggregate(self, tensor, method: str):
        """
        Aggregate the results across processes according to method and return them.
        """
        if isinstance(tensor, paddle.Tensor):
            if parallel_helper._is_parallel_ctx_initialized():
                if method is None:
                    raise AggregateMethodError(should_have_aggregate_method=True)
                tensor = self._gather_all(tensor)
                if isinstance(tensor[0], paddle.Tensor):
                    tensor = paddle.stack(tensor)
                # apply the aggregation over the gathered results
                if method == 'sum':
                    tensor = paddle.sum(tensor, axis=0)
                elif method == 'mean':
                    tensor = paddle.mean(tensor, axis=0)
                elif method == 'max':
                    tensor = paddle.max(tensor, axis=0)
                elif method == 'min':
                    tensor = paddle.min(tensor, axis=0)
                else:
                    raise AggregateMethodError(should_have_aggregate_method=False)
        return tensor

    def create_tensor(self, value: float):
        """
        Create a tensor filled with value.
        """
        tensor = paddle.ones((1,)).fill_(value)
        return tensor

    def fill_value(self, tensor, value: float):
        """
        Set the content of tensor to value.
        """
        tensor.fill_(value)
        return tensor

    def get_scalar(self, tensor) -> float:
        return tensor.item()

    def tensor2numpy(self, tensor) -> np.ndarray:
        if isinstance(tensor, paddle.Tensor):
            return tensor.cpu().detach().numpy()
        elif isinstance(tensor, np.ndarray):
            return tensor
        else:
            raise ValueError(f"tensor: {tensor} can not be converted to ndarray!")

    @staticmethod
    def _gather_all(result, group: Optional[Any] = None) -> List:
        """
        Gather result from every process in group; since result may have a different size
        on each process, padding is applied where necessary.
        """
        # TODO verify correctness
        if group is None:
            group = paddle.distributed.get_group(0)

        world_size = group.nranks
        paddle.distributed.barrier(group=group)

        # if the tensor is a scalar, a simple gather is enough
        if result.ndim == 0:
            return _simple_gather_all_tensors(result, group, world_size)

        # obtain the shape of result
        local_size = paddle.to_tensor(result.shape)
        # gather the shapes of result on all processes
        local_sizes = [paddle.zeros_like(local_size) for _ in range(world_size)]
        paddle.distributed.all_gather(local_sizes, local_size, group=group)
        # stack the shapes and take the maximum of every dimension
        max_size = paddle.stack(local_sizes).max(axis=0)
        all_sizes_equal = all(all(ls == max_size) for ls in local_sizes)

        # if all results have the same shape, gather directly
        if all_sizes_equal:
            return _simple_gather_all_tensors(result, group, world_size)

        # otherwise, pad every tensor up to the largest shape
        pad_dims = []
        pad_by = (max_size - local_size).detach().cpu()
        for val in reversed(pad_by):
            pad_dims.append(0)
            pad_dims.append(val.item())
        result_padded = paddle.nn.functional.pad(result, pad_dims)
        # gather the padded tensors, then trim each one back to its true size
        gathered_result = [paddle.zeros_like(result_padded) for _ in range(world_size)]
        paddle.distributed.all_gather(gathered_result, result_padded, group)
        for idx, item_size in enumerate(local_sizes):
            slice_param = [slice(dim_size) for dim_size in item_size]
            gathered_result[idx] = gathered_result[idx][slice_param]
        return gathered_result

    def move_tensor_to_device(self, tensor, device):
        # TODO could handling the device here cause bugs elsewhere?
        if is_in_paddle_dist():
            device = get_device_from_visible(device)
        return paddle_to(tensor, device)
@@ -0,0 +1,6 @@
__all__ = [
    'TorchBackend'
]

from .backend import Backend as TorchBackend
@@ -0,0 +1,154 @@
from typing import Any, List, Optional

import numpy as np

from fastNLP.core.metrics.backend import Backend
from fastNLP.core.metrics.utils import AggregateMethodError
from fastNLP.envs.imports import _NEED_IMPORT_TORCH
from fastNLP.core.drivers.torch_driver.dist_utils import fastnlp_torch_all_gather

if _NEED_IMPORT_TORCH:
    import torch
    import torch.distributed as dist
    import torch.nn.functional as F


def _simple_gather_all_tensors(result, group: Any, world_size: int) -> List:
    gathered_result = [torch.zeros_like(result) for _ in range(world_size)]
    dist.all_gather(gathered_result, result, group)
    return gathered_result


class TorchBackend(Backend):
    def __init__(self):
        super().__init__()
        self._specified = True

    def aggregate(self, tensor, method: str):
        """
        Aggregate the results across processes according to method and return them.
        """
        if isinstance(tensor, torch.Tensor):
            if dist.is_initialized():
                if method is None:
                    raise AggregateMethodError(should_have_aggregate_method=True)
                tensor = self._gather_all(tensor)
                if isinstance(tensor[0], torch.Tensor):
                    tensor = torch.stack(tensor)
                # apply the aggregation over the gathered results
                if method == 'sum':
                    tensor = torch.sum(tensor, dim=0)
                elif method == 'mean':
                    tensor = torch.mean(tensor, dim=0)
                elif method == 'max':
                    tensor, _ = torch.max(tensor, dim=0)
                elif method == 'min':
                    tensor, _ = torch.min(tensor, dim=0)
                else:
                    raise AggregateMethodError(should_have_aggregate_method=False)
        return tensor

    def create_tensor(self, value: float):
        """
        Create a tensor filled with value.
        """
        tensor = torch.ones(1).fill_(value)
        return tensor

    def fill_value(self, tensor, value: float):
        """
        Set the content of tensor to value.
        """
        tensor.fill_(value)
        return tensor

    def get_scalar(self, tensor) -> float:
        return tensor.item()

    @staticmethod
    def _gather_all(result, group: Optional[Any] = None) -> List:
        """Gather tensors from several ddp processes onto a list that is broadcast to all processes.

        Works on tensors that have the same number of dimensions, but where each dimension may differ.
        In this case tensors are padded, gathered and then trimmed to secure equal workload for all processes.

        Args:
            result: the value to sync
            group: the process group to gather results from. Defaults to all processes (world)

        Return:
            gathered_result: list with size equal to the process group where
                gathered_result[i] corresponds to result tensor from process i
        """
        if group is None:
            group = dist.group.WORLD

        # convert tensors to contiguous format
        result = result.contiguous()

        world_size = dist.get_world_size(group)
        dist.barrier(group=group)

        # if the tensor is a scalar, things are easy
        if result.ndim == 0:
            return _simple_gather_all_tensors(result, group, world_size)

        # 1. Gather sizes of all tensors
        local_size = torch.tensor(result.shape, device=result.device)
        local_sizes = [torch.zeros_like(local_size) for _ in range(world_size)]
        dist.all_gather(local_sizes, local_size, group=group)
        max_size = torch.stack(local_sizes).max(dim=0).values
        all_sizes_equal = all(all(ls == max_size) for ls in local_sizes)

        # 2. If shapes are all the same, then do a simple gather:
        if all_sizes_equal:
            return _simple_gather_all_tensors(result, group, world_size)

        # 3. If not, we need to pad each local tensor to maximum size, gather and then truncate
        pad_dims = []
        pad_by = (max_size - local_size).detach().cpu()
        for val in reversed(pad_by):
            pad_dims.append(0)
            pad_dims.append(val.item())
        result_padded = F.pad(result, pad_dims)
        gathered_result = [torch.zeros_like(result_padded) for _ in range(world_size)]
        dist.all_gather(gathered_result, result_padded, group)
        for idx, item_size in enumerate(local_sizes):
            slice_param = [slice(dim_size) for dim_size in item_size]
            gathered_result[idx] = gathered_result[idx][slice_param]
        return gathered_result

    def tensor2numpy(self, tensor) -> np.ndarray:
        """
        Convert the given tensor to a numpy object.
        """
        if isinstance(tensor, torch.Tensor):
            return tensor.cpu().detach().numpy()
        elif isinstance(tensor, np.ndarray):
            return tensor
        elif isinstance(tensor, (float, int)):
            return tensor
        else:
            raise ValueError(f"tensor: {tensor} can not be converted to ndarray!")

    @staticmethod
    def is_distributed() -> bool:
        """
        :return: whether torch.distributed is available and has been initialized
        """
        return dist.is_available() and dist.is_initialized()

    def move_tensor_to_device(self, tensor, device):
        return tensor.to(device)

    def all_gather_object(self, obj, group=None) -> List:
        if self.is_distributed():
            obj_list = fastnlp_torch_all_gather(obj, group=group)
            return obj_list
        return [obj]
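
# A single-process sketch (illustration only, assuming torch is installed) of
# the pad/gather/trim idea behind _gather_all: pad every tensor to the
# per-dimension maximum shape, "gather" them, then slice each one back to its
# true size.
if __name__ == '__main__':
    import torch
    import torch.nn.functional as F

    local_results = [torch.ones(2, 3), torch.ones(4, 2)]  # stand-ins for per-rank tensors
    sizes = [torch.tensor(t.shape) for t in local_results]
    max_size = torch.stack(sizes).max(dim=0).values       # -> tensor([4, 3])
    padded = []
    for t, size in zip(local_results, sizes):
        pad_dims = []
        for val in (max_size - size).tolist()[::-1]:      # F.pad pads the last dim first
            pad_dims.extend([0, val])
        padded.append(F.pad(t, pad_dims))                 # every tensor is now 4 x 3
    trimmed = [p[tuple(slice(int(d)) for d in s)] for p, s in zip(padded, sizes)]
    assert all(t.shape == r.shape for t, r in zip(trimmed, local_results))
    print([tuple(t.shape) for t in trimmed])              # -> [(2, 3), (4, 2)]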
@@ -0,0 +1,142 @@
__all__ = [
    'ClassifyFPreRecMetric'
]

from typing import Union, List
from collections import defaultdict
from functools import partial
import warnings

from .metric import Metric
from .backend import Backend
from fastNLP.core.vocabulary import Vocabulary
from fastNLP.core.utils.utils import seq_len_to_mask


def _compute_f_pre_rec(beta_square, tp, fn, fp):
    r"""
    :param beta_square: float, the square of the beta weighting recall against precision
    :param tp: int, true positive
    :param fn: int, false negative
    :param fp: int, false positive
    :return: (f, pre, rec)
    """
    pre = tp / (fp + tp + 1e-13)
    rec = tp / (fn + tp + 1e-13)
    f = (1 + beta_square) * pre * rec / (beta_square * pre + rec + 1e-13)
    return f, pre, rec
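
# Worked example (editor's note): with tp=6, fp=2 and fn=4, precision is
# 6/8 = 0.75 and recall is 6/10 = 0.6; with beta_square=1 the harmonic mean
# gives F1 = 2 * 0.75 * 0.6 / (0.75 + 0.6) which is about 0.6667. The 1e-13
# terms only guard against division by zero.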

class ClassifyFPreRecMetric(Metric):
    def __init__(self, backend: Union[str, Backend, None] = 'auto', aggregate_when_get_metric: bool = False,
                 tag_vocab: Vocabulary = None, encoding_type: str = None, ignore_labels: List[str] = None,
                 only_gross: bool = True, f_type='micro', beta=1) -> None:
        super(ClassifyFPreRecMetric, self).__init__(backend=backend,
                                                    aggregate_when_get_metric=aggregate_when_get_metric)
        if f_type not in ('micro', 'macro'):
            raise ValueError("f_type only supports `micro` or `macro`, got {}.".format(f_type))

        self.ignore_labels = ignore_labels
        self.f_type = f_type
        self.beta = beta
        self.beta_square = self.beta ** 2
        self.only_gross = only_gross
        self.tag_vocab = tag_vocab

        self._tp, self._fp, self._fn = defaultdict(partial(self.register_element, aggregate_method='sum')),\
            defaultdict(partial(self.register_element, aggregate_method='sum')),\
            defaultdict(partial(self.register_element, aggregate_method='sum'))

    def get_metric(self) -> dict:
        r"""
        Compute the final evaluation result from the statistics accumulated by the update() calls.

        :return dict evaluate_result: {"f": float, "pre": float, "rec": float}
        """
        evaluate_result = {}

        if not self.only_gross or self.f_type == 'macro':
            tags = set(self._fn.keys())
            tags.update(set(self._fp.keys()))
            tags.update(set(self._tp.keys()))
            f_sum = 0
            pre_sum = 0
            rec_sum = 0
            for tag in tags:
                if self.tag_vocab is not None:
                    tag_name = self.tag_vocab.to_word(tag)
                else:
                    tag_name = int(tag)
                tp = self._tp[tag].get_scalar()
                fn = self._fn[tag].get_scalar()
                fp = self._fp[tag].get_scalar()
                f, pre, rec = _compute_f_pre_rec(self.beta_square, tp, fn, fp)
                f_sum += f
                pre_sum += pre
                rec_sum += rec
                if not self.only_gross and tag != '':  # tag != '' guards against the no-tag case
                    f_key = 'f-{}'.format(tag_name)
                    pre_key = 'pre-{}'.format(tag_name)
                    rec_key = 'rec-{}'.format(tag_name)
                    evaluate_result[f_key] = f
                    evaluate_result[pre_key] = pre
                    evaluate_result[rec_key] = rec

            if self.f_type == 'macro':
                evaluate_result['f'] = f_sum / len(tags)
                evaluate_result['pre'] = pre_sum / len(tags)
                evaluate_result['rec'] = rec_sum / len(tags)

        if self.f_type == 'micro':
            f, pre, rec = _compute_f_pre_rec(self.beta_square,
                                             sum(val.get_scalar() for val in self._tp.values()),
                                             sum(val.get_scalar() for val in self._fn.values()),
                                             sum(val.get_scalar() for val in self._fp.values()))
            evaluate_result['f'] = f
            evaluate_result['pre'] = pre
            evaluate_result['rec'] = rec

        for key, value in evaluate_result.items():
            evaluate_result[key] = round(value, 6)

        return evaluate_result

    def update(self, pred, target, seq_len=None):
        pred = self.tensor2numpy(pred)
        target = self.tensor2numpy(target)
        if seq_len is not None:
            seq_len = self.tensor2numpy(seq_len)

        if seq_len is not None and target.ndim > 1:
            max_len = target.shape[-1]
            masks = seq_len_to_mask(seq_len=seq_len, max_len=max_len)
        else:
            masks = None

        if pred.ndim == target.ndim:
            if len(pred.flatten()) != len(target.flatten()):
                raise RuntimeError(f"when pred has the same number of dimensions as target, they should have the "
                                   f"same number of elements, while target has {len(target.flatten())} elements "
                                   f"and pred has {len(pred.flatten())} elements")
        elif pred.ndim == target.ndim + 1:
            pred = pred.argmax(axis=-1)
            if seq_len is None and target.ndim > 1:
                warnings.warn("You are not passing `seq_len` to exclude pad when calculating the metric.")
        else:
            raise RuntimeError(f"when pred has shape:{pred.shape}, target should have shape: {pred.shape} or "
                               f"{pred.shape[:-1]}, got {target.shape}.")

        if masks is not None:
            target = target * masks
            pred = pred * masks

        target_idxes = set(target.reshape(-1).tolist())
        for target_idx in target_idxes:
            self._tp[target_idx] += ((pred == target_idx) * (target == target_idx)).sum().item()
            self._fp[target_idx] += ((pred == target_idx) * (target != target_idx)).sum().item()
            self._fn[target_idx] += ((pred != target_idx) * (target == target_idx)).sum().item()
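
# A minimal usage sketch (illustration only, assuming torch is installed; run
# with `python -m` so the relative imports resolve). Class 1 is predicted twice
# and both hits are correct, so precision is 1.0; one gold occurrence is
# missed, so recall is 2/3 and the micro F1 is 0.8.
if __name__ == '__main__':
    import torch

    metric = ClassifyFPreRecMetric(f_type='micro')
    metric.update(pred=torch.tensor([1, 0, 1]), target=torch.tensor([1, 1, 1]))
    print(metric.get_metric())  # -> {'f': 0.8, 'pre': 1.0, 'rec': 0.666667}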
@@ -0,0 +1,281 @@
__all__ = [
    'Element'
]

import os

from .backend import Backend, AutoBackend
from fastNLP.core.log import logger
from .utils import AggregateMethodError
from fastNLP.envs.env import FASTNLP_GLOBAL_RANK


class Element:
    def __init__(self, value: float, aggregate_method, backend: Backend, name=None):
        self.init_value = value
        self.aggregate_method = aggregate_method
        self.name = name
        if backend == 'auto':
            raise RuntimeError("You have to specify the backend.")
        elif isinstance(backend, AutoBackend):
            self.backend = backend
        else:
            self.backend = AutoBackend(backend)

        if self.backend.is_specified():
            value = self.backend.create_tensor(self.init_value)
        else:
            value = None
        self._value = value
        self.device = None

    def aggregate(self):
        """
        Aggregate this element across processes, using its aggregate_method.
        """
        try:
            self._value = self.backend.aggregate(self._value, self.aggregate_method)
        except AggregateMethodError as e:
            msg = 'If you see this message, please report a bug.'
            if self.name and e.should_have_aggregate_method:
                msg = f"Element:{self.name} has no specified `aggregate_method`."
            elif e.should_have_aggregate_method:
                msg = "Element has no specified `aggregate_method`."
            elif self.name and not e.should_have_aggregate_method:
                msg = f"Element:{self.name}'s backend:{self.backend.__class__.__name__} does not support " \
                      f'aggregate_method:{self.aggregate_method}.'
            elif not e.should_have_aggregate_method:
                msg = f"Element's backend:{self.backend.__class__.__name__} does not support " \
                      f'aggregate_method:{self.aggregate_method}.'
            if e.only_warn:
                if int(os.environ.get(FASTNLP_GLOBAL_RANK, 0)) == 0:
                    logger.warning(msg)
                self._value = self.backend.aggregate(self._value, method=None)
            else:
                raise RuntimeError(msg)

    def reset(self):
        if self.backend.is_specified():
            self._value = self.backend.fill_value(self._value, self.init_value)

    @property
    def value(self):
        return self._value

    @value.setter
    def value(self, value):
        self._check_value_initialized()
        self._value = value

    @value.getter
    def value(self):
        self._check_value_initialized()
        return self._value

    def get_scalar(self) -> float:
        return self.backend.get_scalar(self._value)

    def fill_value(self, value):
        self._value = self.backend.fill_value(self._value, value)

    def to(self, device):
        # how should device be handled here?
        if self._value is not None:
            self._value = self.backend.move_tensor_to_device(self._value, device)
        self.device = device

    def _check_value_initialized(self):
        if self._value is None:
            assert self.backend.is_specified(), f"Backend is not specified, please specify backend in the Metric " \
                                                f"initialization."
            self._value = self.backend.create_tensor(self.init_value)
            if self.device is not None:
                self.to(device=self.device)

    def _check_value_when_call(self):
        if self.value is None:
            prefix = f'Element:`{self.name}`' if self.name else 'Element'
            raise RuntimeError(prefix + " is not initialized. Please either specify backend when creating this "
                                        "element, or use it after it being used by the `Metric.compute()` method.")

    def __add__(self, other):
        self._check_value_when_call()
        if isinstance(other, Element):
            self.value += other.value
        else:
            self.value += other
        return self

    def __radd__(self, other):
        self._check_value_when_call()
        if isinstance(other, Element):
            self.value += other.value
        else:
            self.value += other
        return self

    def __sub__(self, other):
        self._check_value_when_call()
        if isinstance(other, Element):
            self.value -= other.value
        else:
            self.value -= other
        return self

    def __rsub__(self, other):
        self._check_value_when_call()
        if isinstance(other, Element):
            self.value = other.value - self.value
        else:
            self.value = other - self.value
        return self

    def __mul__(self, other):
        self._check_value_when_call()
        if isinstance(other, Element):
            self.value *= other.value
        else:
            self.value *= other
        return self

    def __imul__(self, other):
        self._check_value_when_call()
        if isinstance(other, Element):
            self.value *= other.value
        else:
            self.value *= other
        return self

    def __floordiv__(self, other):
        self._check_value_when_call()
        if isinstance(other, Element):
            self.value //= other.value
        else:
            self.value //= other
        return self

    def __rfloordiv__(self, other):
        self._check_value_when_call()
        if isinstance(other, Element):
            self.value = other.value // self.value
        else:
            self.value = other // self.value
        return self

    def __truediv__(self, other):
        self._check_value_when_call()
        if isinstance(other, Element):
            self.value /= other.value
        else:
            self.value /= other
        return self

    def __rtruediv__(self, other):
        self._check_value_when_call()
        if isinstance(other, Element):
            self.value = other.value / self.value
        else:
            self.value = other / self.value
        return self

    def __mod__(self, other):
        self._check_value_when_call()
        if isinstance(other, Element):
            self.value %= other.value
        else:
            self.value %= other
        return self

    def __rmod__(self, other):
        self._check_value_when_call()
        if isinstance(other, Element):
            self.value = other.value % self.value
        else:
            self.value = other % self.value
        return self

    def __pow__(self, other, modulo=None):
        self._check_value_when_call()
        if modulo is None:
            if isinstance(other, Element):
                self.value **= other.value
            else:
                self.value **= other
        else:
            if isinstance(other, Element):
                self.value = pow(self.value, other.value, modulo)
            else:
                self.value = pow(self.value, other, modulo)
        return self

    def __rpow__(self, other):
        self._check_value_when_call()
        if isinstance(other, Element):
            self.value = other.value ** self.value
        else:
            self.value = other ** self.value
        return self

    def __lt__(self, other) -> bool:
        self._check_value_when_call()
        if isinstance(other, Element):
            return self.value < other.value
        else:
            return self.value < other

    def __le__(self, other) -> bool:
        self._check_value_when_call()
        if isinstance(other, Element):
            return self.value <= other.value
        else:
            return self.value <= other

    def __eq__(self, other):
        self._check_value_when_call()
        if isinstance(other, Element):
            return self.value == other.value
        else:
            return self.value == other

    def __ne__(self, other) -> bool:
        self._check_value_when_call()
        if isinstance(other, Element):
            return self.value != other.value
        else:
            return self.value != other

    def __ge__(self, other) -> bool:
        self._check_value_when_call()
        if isinstance(other, Element):
            return self.value >= other.value
        else:
            return self.value >= other

    def __gt__(self, other) -> bool:
        self._check_value_when_call()
        if isinstance(other, Element):
            return self.value > other.value
        else:
            return self.value > other

    def __str__(self):
        return str(self.value)

    def __repr__(self):
        return str(self.value)

    def __getattr__(self, item):
        """
        Delegate attribute access to the underlying value, so that an Element can be used
        like the tensor it wraps (e.g. calling tensor methods directly on the Element).

        :param item:
        :return:
        """
        try:
            if self._value is None:
                prefix = f'Element:`{self.name}`' if self.name else 'Element'
                raise RuntimeError(prefix + " is not initialized. Please either specify backend when creating this "
                                            "element, or use it after it being used by the `Metric.compute()` method.")
            return getattr(self._value, item)
        except AttributeError as e:
            raise e
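
# A minimal sketch (illustration only, assuming torch is installed; run with
# `python -m` so the relative imports resolve) of how an Element behaves like
# its underlying tensor: in-place accumulation plus scalar extraction.
if __name__ == '__main__':
    element = Element(value=0, aggregate_method='sum', backend='torch', name='demo')
    element += 2
    element += 3
    print(element.get_scalar())  # -> 5.0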
@@ -0,0 +1,184 @@
__all__ = [
    'Metric'
]

from abc import abstractmethod
from typing import Union
import functools
from contextlib import contextmanager

import numpy as np

from fastNLP.core.metrics.backend import Backend, AutoBackend
from fastNLP.core.metrics.element import Element


class Metric:
    def __init__(self, backend: Union[str, Backend, None] = 'auto', aggregate_when_get_metric: bool = True):
        """
        :param str backend: one of [torch, paddle, jittor, auto], where auto means the
            concrete backend is chosen from the arguments actually passed to Metric.update();
            in most cases auto is all you need.
        :param bool aggregate_when_get_metric: whether, when computing the metric, the matching
            elements on all processes should first be aggregated; meaningless when the backend
            does not support distributed execution.
        """
        self.backend = AutoBackend(backend)
        self._updated = False
        self.get_metric = self._sync_get_metric(self.get_metric)
        self.update = self._wrap_update(self.update)
        self.reset = self._wrap_auto_reset_elements(self.reset)
        self.aggregate_when_get_metric = aggregate_when_get_metric
        self._cannot_change_element = False
        self._elements = {}

    @property
    def elements(self) -> dict:
        return self._elements

    def register_element(self, name=None, value: float = 0, aggregate_method=None, backend='auto') -> Element:
        """
        Register an element object. Once registered, it can be accessed inside the Metric as
        self.{name} and treated as a tensor of the corresponding backend, i.e. used directly
        in arithmetic expressions.

        Note: if this metric should automatically scale to the multi-card case, make sure to
        declare an aggregate_method.

        :param name: the name of this element; after registration it can be accessed in the
            Metric as self.{name}.
        :param value: the initial value; Metric.reset() also restores the element to this value.
        :param aggregate_method: how to aggregate the results on multiple cards; meaningless
            when running on a single card.
        :param backend: the backend to use. The concrete type of the Element is determined by
            the backend: torch gives a torch.Tensor, paddle gives a paddle.Tensor, and jittor
            gives a jittor.Var. The default auto is usually fine: fastNLP picks a sensible
            backend from the arguments passed to Metric.update(), e.g. torch when the only
            tensor type among the arguments is torch.Tensor (non-tensor arguments are allowed),
            and jittor when it is jittor.Var. If no tensor type is detected at all, a plain
            float is used as the element.
        :return: the registered Element object
        """
        if backend == 'auto':
            backend = self.backend
        else:
            backend = AutoBackend(backend)

        # when name is None, generate a default name for the element
        if name is None:
            name = f'ele_var_{len(self._elements)}'

        element = Element(value=value, aggregate_method=aggregate_method, backend=backend, name=name)
        self.elements[name] = element
        setattr(self, name, element)
        return element

    def reset(self):
        """
        If there is non-element state that needs resetting, implement that here. Registered
        element objects are automatically reset to their initial values.
        """
        pass

    def _wrap_auto_reset_elements(self, reset):
        @functools.wraps(reset)
        def _wrap_reset(*args, **kwargs):
            self._updated = False
            for ele in self.elements.values():
                ele.reset()
            reset(*args, **kwargs)

        return _wrap_reset

    def _sync_get_metric(self, get_metric):
        @functools.wraps(get_metric)
        def _wrap_get_metric(*args, **kwargs):
            assert self._updated, f"You have to call `{self.__class__.__name__}.update()` before calling " \
                                  f"`get_metric()`."
            with self.sync(recover=True, aggregate=self.aggregate_when_get_metric):
                results = get_metric(*args, **kwargs)
            return results

        return _wrap_get_metric

    def __setattr__(self, key, value):
        if hasattr(self, '_cannot_change_element') and self._cannot_change_element is True:
            if key in self.elements and value is not self.elements[key]:
                raise RuntimeError(f"self.`{key}` is an element, cannot assign to a new value:{value}")
        object.__setattr__(self, key, value)

    def _wrap_update(self, update):
        @functools.wraps(update)
        def _wrap_update(*args, **kwargs):
            self.check_backend(*args, **kwargs)
            self._cannot_change_element = True
            self._updated = True
            return update(*args, **kwargs)

        return _wrap_update

    def check_backend(self, *args, **kwargs):
        if not self.backend.is_specified():
            _args = []
            for arg in args:
                _args.append(arg)
            for arg in kwargs.values():
                _args.append(arg)
            self.backend.choose_real_backend(_args)

    @contextmanager
    def sync(self, recover=True, aggregate=False):
        """
        Within this context, the metric first synchronizes the elements that need aggregation.
        When recover is True, the elements are restored to their pre-aggregation values on exit.
        """
        keep_value = {}
        if aggregate:
            for name, element in self.elements.items():
                # remember the current value
                keep_value[name] = element.get_scalar()
                # aggregate the result
                element.aggregate()

        yield

        if recover and aggregate:
            for name, element in self.elements.items():
                # restore the saved value
                if name in keep_value:
                    element.fill_value(value=keep_value.get(name))

    @abstractmethod
    def update(self, *args, **kwargs):
        raise NotImplementedError()

    @abstractmethod
    def get_metric(self) -> dict:
        raise NotImplementedError()

    def set_auto_aggregate_when_get_metric(self, flag: bool):
        """
        Set whether get_metric aggregates automatically.
        """
        self.aggregate_when_get_metric = flag

    def __getattr__(self, name: str) -> Element:
        if '_elements' in self.__dict__:
            elements = self.__dict__['_elements']
            if name in elements:
                return elements[name]
        raise AttributeError("`{}` object has no attribute `{}`".format(type(self).__name__, name))

    def tensor2numpy(self, tensor) -> np.ndarray:
        """
        Convert a tensor to a numpy variable.

        :param tensor:
        :return:
        """
        return self.backend.tensor2numpy(tensor)

    def to(self, device):
        """
        Move all element variables to device.

        :param device:
        :return:
        """
        for element in self.elements.values():
            element.to(device)
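
# A minimal sketch of a custom metric built on register_element (illustration
# only, assuming torch is installed; the MeanLoss class is hypothetical): sum
# a loss and a counter, then average in get_metric.
if __name__ == '__main__':
    import torch

    class MeanLoss(Metric):
        def __init__(self, backend='auto'):
            super().__init__(backend=backend)
            self.register_element(name='loss_sum', value=0, aggregate_method='sum')
            self.register_element(name='count', value=0, aggregate_method='sum')

        def update(self, loss):
            self.loss_sum += loss.item()
            self.count += 1

        def get_metric(self) -> dict:
            return {'loss': round(self.loss_sum.get_scalar() / self.count.get_scalar(), 6)}

    metric = MeanLoss()
    metric.update(torch.tensor(4.0))
    metric.update(torch.tensor(2.0))
    print(metric.get_metric())  # -> {'loss': 3.0}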
@@ -0,0 +1,344 @@
__all__ = [
    'SpanFPreRecMetric'
]

from typing import Union, List, Optional
import warnings
from collections import defaultdict
from functools import partial

from fastNLP.core.metrics.backend import Backend
from fastNLP.core.metrics.metric import Metric
from fastNLP.core.vocabulary import Vocabulary


def _check_tag_vocab_and_encoding_type(tag_vocab: Union[Vocabulary, dict], encoding_type: str):
    r"""
    Check that the tags in the vocabulary match the encoding_type.

    :param tag_vocab: a tag Vocabulary, or a dict such as {0: "O", 1: "B-tag1"}, i.e. index
        first and tag second.
    :param encoding_type: bio, bmes, bioes, bmeso
    :return:
    """
    tag_set = set()
    unk_token = '<unk>'
    pad_token = '<pad>'
    if isinstance(tag_vocab, Vocabulary):
        unk_token = tag_vocab.unknown
        pad_token = tag_vocab.padding
        tag_vocab = tag_vocab.idx2word
    for idx, tag in tag_vocab.items():
        if tag in (unk_token, pad_token):
            continue
        tag = tag[:1].lower()
        tag_set.add(tag)

    tags = encoding_type
    for tag in tag_set:
        assert tag in tags, f"{tag} is not a valid tag in encoding type:{encoding_type}. Please check your " \
                            f"encoding_type."
        tags = tags.replace(tag, '')  # remove this tag from the remaining candidates
    if tags:  # if non-empty, some tags of the encoding type were never used
        warnings.warn(f"Tag:{tags} in encoding type:{encoding_type} is not presented in your Vocabulary. Check your "
                      "encoding_type.")


def _get_encoding_type_from_tag_vocab(tag_vocab: Union[Vocabulary, dict]) -> str:
    r"""
    Infer the encoding type from the Vocabulary; supports bmes, bioes, bmeso and bio.

    :param tag_vocab: a tag Vocabulary, or a dict such as {0: "O", 1: "B-tag1"}, i.e. index
        first and tag second.
    :return:
    """
    tag_set = set()
    unk_token = '<unk>'
    pad_token = '<pad>'
    if isinstance(tag_vocab, Vocabulary):
        unk_token = tag_vocab.unknown
        pad_token = tag_vocab.padding
        tag_vocab = tag_vocab.idx2word
    for idx, tag in tag_vocab.items():
        if tag in (unk_token, pad_token):
            continue
        tag = tag[:1].lower()
        tag_set.add(tag)

    bmes_tag_set = set('bmes')
    if tag_set == bmes_tag_set:
        return 'bmes'
    bio_tag_set = set('bio')
    if tag_set == bio_tag_set:
        return 'bio'
    bmeso_tag_set = set('bmeso')
    if tag_set == bmeso_tag_set:
        return 'bmeso'
    bioes_tag_set = set('bioes')
    if tag_set == bioes_tag_set:
        return 'bioes'
    raise RuntimeError("encoding_type cannot be inferred automatically. Only support "
                       "'bio', 'bmes', 'bmeso', 'bioes' type.")


def _bmes_tag_to_spans(tags, ignore_labels=None):
    r"""
    Given a list of tags such as ['S-song', 'B-singer', 'M-singer', 'E-singer', 'S-movie', 'S-actor'],
    return [('song', (0, 1)), ('singer', (1, 4)), ('movie', (4, 5)), ('actor', (5, 6))]
    (half-open intervals). A bare ['S', 'B', 'M', 'E', 'B', 'M', 'M', ...] sequence also works.

    :param tags: List[str],
    :param ignore_labels: List[str], labels in this list are ignored
    :return: List[Tuple[str, List[int, int]]]. [(label, [start, end])]
    """
    ignore_labels = set(ignore_labels) if ignore_labels else set()

    spans = []
    prev_bmes_tag = None
    for idx, tag in enumerate(tags):
        tag = tag.lower()
        bmes_tag, label = tag[:1], tag[2:]
        if bmes_tag in ('b', 's'):
            spans.append((label, [idx, idx]))
        elif bmes_tag in ('m', 'e') and prev_bmes_tag in ('b', 'm') and label == spans[-1][0]:
            spans[-1][1][1] = idx
        else:
            spans.append((label, [idx, idx]))
        prev_bmes_tag = bmes_tag
    return [(span[0], (span[1][0], span[1][1] + 1))
            for span in spans
            if span[0] not in ignore_labels
            ]


def _bmeso_tag_to_spans(tags, ignore_labels=None):
    r"""
    Given a list of tags such as ['O', 'B-singer', 'M-singer', 'E-singer', 'O', 'O'],
    return [('singer', (1, 4))] (half-open intervals).

    :param tags: List[str],
    :param ignore_labels: List[str], labels in this list are ignored
    :return: List[Tuple[str, List[int, int]]]. [(label, [start, end])]
    """
    ignore_labels = set(ignore_labels) if ignore_labels else set()

    spans = []
    prev_bmes_tag = None
    for idx, tag in enumerate(tags):
        tag = tag.lower()
        bmes_tag, label = tag[:1], tag[2:]
        if bmes_tag in ('b', 's'):
            spans.append((label, [idx, idx]))
        elif bmes_tag in ('m', 'e') and prev_bmes_tag in ('b', 'm') and label == spans[-1][0]:
            spans[-1][1][1] = idx
        elif bmes_tag == 'o':
            pass
        else:
            spans.append((label, [idx, idx]))
        prev_bmes_tag = bmes_tag
    return [(span[0], (span[1][0], span[1][1] + 1))
            for span in spans
            if span[0] not in ignore_labels
            ]


def _bioes_tag_to_spans(tags, ignore_labels=None):
    r"""
    Given a list of tags such as ['O', 'B-singer', 'I-singer', 'E-singer', 'O', 'O'],
    return [('singer', (1, 4))] (half-open intervals).

    :param tags: List[str],
    :param ignore_labels: List[str], labels in this list are ignored
    :return: List[Tuple[str, List[int, int]]]. [(label, [start, end])]
    """
    ignore_labels = set(ignore_labels) if ignore_labels else set()

    spans = []
    prev_bioes_tag = None
    for idx, tag in enumerate(tags):
        tag = tag.lower()
        bioes_tag, label = tag[:1], tag[2:]
        if bioes_tag in ('b', 's'):
            spans.append((label, [idx, idx]))
        elif bioes_tag in ('i', 'e') and prev_bioes_tag in ('b', 'i') and label == spans[-1][0]:
            spans[-1][1][1] = idx
        elif bioes_tag == 'o':
            pass
        else:
            spans.append((label, [idx, idx]))
        prev_bioes_tag = bioes_tag
    return [(span[0], (span[1][0], span[1][1] + 1))
            for span in spans
            if span[0] not in ignore_labels
            ]


def _bio_tag_to_spans(tags, ignore_labels=None):
    r"""
    Given a list of tags such as ['O', 'B-singer', 'I-singer', 'I-singer', 'O', 'O'],
    return [('singer', (1, 4))] (half-open intervals).

    :param tags: List[str],
    :param ignore_labels: List[str], labels in this list are ignored
    :return: List[Tuple[str, List[int, int]]]. [(label, [start, end])]
    """
    ignore_labels = set(ignore_labels) if ignore_labels else set()

    spans = []
    prev_bio_tag = None
    for idx, tag in enumerate(tags):
        tag = tag.lower()
        bio_tag, label = tag[:1], tag[2:]
        if bio_tag == 'b':
            spans.append((label, [idx, idx]))
        elif bio_tag == 'i' and prev_bio_tag in ('b', 'i') and label == spans[-1][0]:
            spans[-1][1][1] = idx
        elif bio_tag == 'o':  # o tag does not count
            pass
        else:
            spans.append((label, [idx, idx]))
        prev_bio_tag = bio_tag
    return [(span[0], (span[1][0], span[1][1] + 1)) for span in spans if span[0] not in ignore_labels]


def _compute_f_pre_rec(beta_square, tp, fn, fp):
    r"""
    :param beta_square: float, the square of the beta weighting recall against precision
    :param tp: int, true positive
    :param fn: int, false negative
    :param fp: int, false positive
    :return: (f, pre, rec)
    """
    pre = tp / (fp + tp + 1e-13)
    rec = tp / (fn + tp + 1e-13)
    f = (1 + beta_square) * pre * rec / (beta_square * pre + rec + 1e-13)
    return f, pre, rec


class SpanFPreRecMetric(Metric):
    def __init__(self, backend: Union[str, Backend, None] = 'auto', tag_vocab: Vocabulary = None,
                 encoding_type: str = None, ignore_labels: List[str] = None, only_gross: bool = True, f_type='micro',
                 beta=1, aggregate_when_get_metric: bool = True) -> None:
        super(SpanFPreRecMetric, self).__init__(backend=backend, aggregate_when_get_metric=aggregate_when_get_metric)
        if f_type not in ('micro', 'macro'):
            raise ValueError("f_type only supports `micro` or `macro`, got {}.".format(f_type))
        if not isinstance(tag_vocab, Vocabulary):
            raise TypeError("tag_vocab can only be fastNLP.Vocabulary, not {}.".format(type(tag_vocab)))
        if encoding_type:
            encoding_type = encoding_type.lower()
            _check_tag_vocab_and_encoding_type(tag_vocab, encoding_type)
            self.encoding_type = encoding_type
        else:
            self.encoding_type = _get_encoding_type_from_tag_vocab(tag_vocab)

        if self.encoding_type == 'bmes':
            self.tag_to_span_func = _bmes_tag_to_spans
        elif self.encoding_type == 'bio':
            self.tag_to_span_func = _bio_tag_to_spans
        elif self.encoding_type == 'bmeso':
            self.tag_to_span_func = _bmeso_tag_to_spans
        elif self.encoding_type == 'bioes':
            self.tag_to_span_func = _bioes_tag_to_spans
        else:
            raise ValueError("Only support 'bio', 'bmes', 'bmeso', 'bioes' type.")

        self.ignore_labels = ignore_labels
        self.f_type = f_type
        self.beta = beta
        self.beta_square = self.beta ** 2
        self.only_gross = only_gross
        self.tag_vocab = tag_vocab

        self._true_positives = defaultdict(partial(self.register_element, aggregate_method='sum', name=None))
        self._false_positives = defaultdict(partial(self.register_element, aggregate_method='sum', name=None))
        self._false_negatives = defaultdict(partial(self.register_element, aggregate_method='sum', name=None))

    def get_metric(self) -> dict:
        evaluate_result = {}

        if not self.only_gross or self.f_type == 'macro':
            tags = set(self._false_negatives.keys())
            tags.update(set(self._false_positives.keys()))
            tags.update(set(self._true_positives.keys()))
            f_sum = 0
            pre_sum = 0
            rec_sum = 0
            for tag in tags:
                tp = self._true_positives[tag].get_scalar()
                fn = self._false_negatives[tag].get_scalar()
                fp = self._false_positives[tag].get_scalar()
                f, pre, rec = _compute_f_pre_rec(self.beta_square, tp, fn, fp)
                f_sum += f
                pre_sum += pre
                rec_sum += rec
                if not self.only_gross and tag != '':  # tag != '' guards against the no-tag case
                    f_key = 'f-{}'.format(tag)
                    pre_key = 'pre-{}'.format(tag)
                    rec_key = 'rec-{}'.format(tag)
                    evaluate_result[f_key] = f
                    evaluate_result[pre_key] = pre
                    evaluate_result[rec_key] = rec

            if self.f_type == 'macro':
                evaluate_result['f'] = f_sum / len(tags)
                evaluate_result['pre'] = pre_sum / len(tags)
                evaluate_result['rec'] = rec_sum / len(tags)

        if self.f_type == 'micro':
            f, pre, rec = _compute_f_pre_rec(self.beta_square,
                                             sum(val.get_scalar() for val in self._true_positives.values()),
                                             sum(val.get_scalar() for val in self._false_negatives.values()),
                                             sum(val.get_scalar() for val in self._false_positives.values()))
            evaluate_result['f'] = f
            evaluate_result['pre'] = pre
            evaluate_result['rec'] = rec

        for key, value in evaluate_result.items():
            evaluate_result[key] = round(value, 6)

        return evaluate_result

    def update(self, pred, target, seq_len: Optional[List] = None) -> None:
        r"""Accumulate the evaluation statistics for one batch of predictions.

        :param pred: predictions, of shape [batch, seq_len] or [batch, seq_len, len(tag_vocab)]
        :param target: ground truth, of shape [batch, seq_len]
        :param seq_len: text lengths, of shape [batch]
        :return:
        """
        pred = self.tensor2numpy(pred)
        target = self.tensor2numpy(target)

        if pred.ndim == target.ndim and target.ndim == 2:
            pass

        elif pred.ndim == target.ndim + 1 and target.ndim == 2:
            num_classes = pred.shape[-1]
            pred = pred.argmax(axis=-1)
            if (target >= num_classes).any():
                raise ValueError("A gold label passed to SpanFPreRecMetric contains an "
                                 "id >= {}, the number of classes.".format(num_classes))
        else:
            raise RuntimeError(f"when pred has shape:{pred.shape}, target should have shape: {pred.shape} or "
                               f"{pred.shape[:-1]}, got {target.shape}.")

        batch_size = pred.shape[0]
        pred = pred.tolist()
        target = target.tolist()
        for i in range(batch_size):
            pred_tags = pred[i][:int(seq_len[i])]
            gold_tags = target[i][:int(seq_len[i])]

            pred_str_tags = [self.tag_vocab.to_word(tag) for tag in pred_tags]
            gold_str_tags = [self.tag_vocab.to_word(tag) for tag in gold_tags]

            pred_spans = self.tag_to_span_func(pred_str_tags, ignore_labels=self.ignore_labels)
            gold_spans = self.tag_to_span_func(gold_str_tags, ignore_labels=self.ignore_labels)

            for span in pred_spans:
                if span in gold_spans:
                    self._true_positives[span[0]] += 1
                    gold_spans.remove(span)
                else:
                    self._false_positives[span[0]] += 1
            for span in gold_spans:
                self._false_negatives[span[0]] += 1
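
# A minimal usage sketch (illustration only, assuming torch is installed): a
# BIO tag vocabulary, one sequence of length 4, and a prediction that matches
# the gold span exactly, so f, pre and rec all come out to 1.0.
if __name__ == '__main__':
    import torch

    vocab = Vocabulary(padding=None, unknown=None)
    vocab.add_word_lst(['O', 'B-loc', 'I-loc'])
    metric = SpanFPreRecMetric(tag_vocab=vocab)
    pred = torch.tensor([[vocab.to_index(t) for t in ['O', 'B-loc', 'I-loc', 'O']]])
    target = pred.clone()
    metric.update(pred, target, seq_len=torch.tensor([4]))
    print(metric.get_metric())  # -> {'f': 1.0, 'pre': 1.0, 'rec': 1.0}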
@@ -0,0 +1,91 @@
__all__ = [
    'func_post_proc'
]

from typing import Any
from functools import wraps

from fastNLP.envs.imports import _NEED_IMPORT_PADDLE
from fastNLP.envs.utils import _module_available

_IS_TORCHMETRICS_AVAILABLE = _module_available('torchmetrics')
if _IS_TORCHMETRICS_AVAILABLE:
    from torchmetrics import Metric as torchmetrics_Metric

_IS_ALLENNLP_AVAILABLE = _module_available('allennlp')
if _IS_ALLENNLP_AVAILABLE:
    from allennlp.training.metrics import Metric as allennlp_Metric

if _NEED_IMPORT_PADDLE:
    from paddle.metric import Metric as paddle_Metric


def _is_torchmetrics_metric(metric: Any) -> bool:
    """
    Check whether the given object is a torchmetrics metric.

    :param metric:
    :return:
    """
    if _IS_TORCHMETRICS_AVAILABLE:
        return isinstance(metric, torchmetrics_Metric)
    else:
        return False


def _is_allennlp_metric(metric: Any) -> bool:
    """
    Check whether the given object is an allennlp metric.

    :param metric:
    :return:
    """
    if _IS_ALLENNLP_AVAILABLE:
        return isinstance(metric, allennlp_Metric)
    else:
        return False


def _is_paddle_metric(metric: Any) -> bool:
    """
    Check whether the given object is a paddle metric.

    :param metric:
    :return:
    """
    if _NEED_IMPORT_PADDLE:
        return isinstance(metric, paddle_Metric)
    else:
        return False


def func_post_proc(metric: 'Metric', fn: callable, method_name: str) -> 'Metric':
    """
    Wrap fn around the {method_name} method of the metric object, so that the return value
    of metric.{method_name} is passed through fn before being returned. Note that the change
    to the metric's {method_name} method is made in place.

    :param metric: the metric object
    :param fn: applied to the return value of the wrapped method
    :param method_name: the name of the method to wrap
    :return: metric
    """
    assert hasattr(metric, method_name) and callable(getattr(metric, method_name)), \
        f"Parameter `metric` must have a {method_name} function."
    assert callable(fn), "Parameter `fn` must be callable."

    func = getattr(metric, method_name)

    @wraps(func)
    def wrap_method(*args, **kwargs):
        res = func(*args, **kwargs)
        return fn(res)

    wrap_method.__wrapped_by_func_post_proc__ = True

    setattr(metric, method_name, wrap_method)
    return metric
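
# A minimal usage sketch (illustration only): post-process the dict returned
# by an object's method. The DummyMetric class below is hypothetical and only
# exists to show the wrapping.
if __name__ == '__main__':
    class DummyMetric:
        def get_metric(self) -> dict:
            return {'acc': 0.9}

    metric = func_post_proc(DummyMetric(), lambda res: {f'dev-{k}': v for k, v in res.items()},
                            method_name='get_metric')
    print(metric.get_metric())  # -> {'dev-acc': 0.9}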

class AggregateMethodError(Exception):
    def __init__(self, should_have_aggregate_method, only_warn=False):
        super(AggregateMethodError, self).__init__(self)
        self.should_have_aggregate_method = should_have_aggregate_method
        self.only_warn = only_warn