
Added metrics

tags/v1.0.0alpha
MorningForest 3 years ago
parent commit 4bb56616b9
16 changed files with 1661 additions and 0 deletions
  1. +18 -0   fastNLP/core/metrics/__init__.py
  2. +75 -0   fastNLP/core/metrics/accuracy.py
  3. +12 -0   fastNLP/core/metrics/backend/__init__.py
  4. +75 -0   fastNLP/core/metrics/backend/auto_backend.py
  5. +75 -0   fastNLP/core/metrics/backend/backend.py
  6. +1 -0    fastNLP/core/metrics/backend/jittor_backend/__init__.py
  7. +72 -0   fastNLP/core/metrics/backend/jittor_backend/backend.py
  8. +5 -0    fastNLP/core/metrics/backend/paddle_backend/__init__.py
  9. +126 -0  fastNLP/core/metrics/backend/paddle_backend/backend.py
  10. +6 -0   fastNLP/core/metrics/backend/torch_backend/__init__.py
  11. +154 -0 fastNLP/core/metrics/backend/torch_backend/backend.py
  12. +142 -0 fastNLP/core/metrics/classify_f1_pre_rec_metric.py
  13. +281 -0 fastNLP/core/metrics/element.py
  14. +184 -0 fastNLP/core/metrics/metric.py
  15. +344 -0 fastNLP/core/metrics/span_f1_pre_rec_metric.py
  16. +91 -0  fastNLP/core/metrics/utils.py

+18 -0 fastNLP/core/metrics/__init__.py

@@ -0,0 +1,18 @@
__all__ = [
'Metric',
'Accuracy',
'Backend',
'AutoBackend',
'PaddleBackend',
'TorchBackend',
'SpanFPreRecMetric',
'ClassifyFPreRecMetric',
'func_post_proc'
]

from .metric import Metric
from .accuracy import Accuracy
from .backend import Backend, AutoBackend, PaddleBackend, TorchBackend
from .span_f1_pre_rec_metric import SpanFPreRecMetric
from .classify_f1_pre_rec_metric import ClassifyFPreRecMetric
from .utils import func_post_proc

+75 -0 fastNLP/core/metrics/accuracy.py

@@ -0,0 +1,75 @@
__all__ = [
'Accuracy'
]

from typing import Union
import warnings

import numpy as np

from fastNLP.core.metrics.metric import Metric
from fastNLP.core.metrics.backend import Backend
from fastNLP.core.utils.utils import seq_len_to_mask


class Accuracy(Metric):

def __init__(self, backend: Union[str, Backend, None] = 'auto',
aggregate_when_get_metric: bool = True):
super(Accuracy, self).__init__(backend=backend, aggregate_when_get_metric=aggregate_when_get_metric)
self.register_element(name='correct', value=0, aggregate_method='sum', backend=backend)
self.register_element(name='total', value=0, aggregate_method="sum", backend=backend)

def get_metric(self) -> dict:
r"""
get_metric() computes the final evaluation result from the statistics accumulated by update().

:return dict evaluate_result: {"acc": float}
"""
evaluate_result = {'acc': round(self.correct.get_scalar() / (self.total.get_scalar() + 1e-12), 6)}
return evaluate_result

def update(self, pred, target, seq_len=None):
r"""
update() accumulates the evaluation statistics for one batch of predictions.

:param torch.Tensor pred: prediction tensor of shape torch.Size([B,]), torch.Size([B, n_classes]),
torch.Size([B, max_len]), or torch.Size([B, max_len, n_classes])
:param torch.Tensor target: ground-truth tensor of shape torch.Size([B,]), torch.Size([B,]),
torch.Size([B, max_len]), or torch.Size([B, max_len])
:param torch.Tensor seq_len: sequence lengths of shape torch.Size([B]); ignored if a mask is also
passed in.
"""
# To stay framework-agnostic, convert all inputs to numpy before computing.
pred = self.tensor2numpy(pred)
target = self.tensor2numpy(target)
if seq_len is not None:
seq_len = self.tensor2numpy(seq_len)

if seq_len is not None and target.ndim > 1:
max_len = target.shape[1]
masks = seq_len_to_mask(seq_len, max_len)
else:
masks = None

if pred.ndim == target.ndim:
if np.prod(pred.shape) != np.prod(target.shape):
raise RuntimeError(f"when pred and target have the same number of dimensions, they should have the"
f" same number of elements, while target has shape:{target.shape} and "
f"pred has shape: {pred.shape}")

elif pred.ndim == target.ndim + 1:
pred = pred.argmax(axis=-1)
if seq_len is None and target.ndim > 1:
warnings.warn("You are not passing `seq_len` to exclude pad when calculating accuracy.")

else:
raise RuntimeError(f"when pred has shape:{pred.shape}, target should have shape: {pred.shape} or "
f"{pred.shape[:-1]}, got {target.shape}.")

if masks is not None:
self.total += masks.sum().item()
self.correct += ((pred == target) * masks).sum().item()
else:
self.total += np.prod(list(pred.shape)).item()
self.correct += (target == pred).sum().item()
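
A minimal usage sketch of the Accuracy metric above, assuming a working torch install; the shapes follow the update() docstring (illustrative, not part of this commit):

import torch
from fastNLP.core.metrics import Accuracy

metric = Accuracy()
pred = torch.tensor([[0.9, 0.1], [0.2, 0.8], [0.7, 0.3]])  # [B, n_classes]
target = torch.tensor([0, 1, 1])                           # [B,]
metric.update(pred, target)   # pred.argmax(-1) == [0, 1, 0], so 2 of 3 are correct
print(metric.get_metric())    # {'acc': 0.666667}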

+12 -0 fastNLP/core/metrics/backend/__init__.py

@@ -0,0 +1,12 @@
__all__ = [
'Backend',
'AutoBackend',
'TorchBackend',
'PaddleBackend'
]


from .backend import Backend
from .auto_backend import AutoBackend
from .torch_backend.backend import TorchBackend
from .paddle_backend.backend import PaddleBackend

+75 -0 fastNLP/core/metrics/backend/auto_backend.py

@@ -0,0 +1,75 @@
from typing import Union

from .backend import Backend
from .torch_backend.backend import TorchBackend
from .paddle_backend.backend import PaddleBackend
from .jittor_backend.backend import JittorBackend


class AutoBackend(Backend):
"""
An AutoBackend requires no explicit backend choice up front: it picks the concrete backend from the
types of the input data it later receives.

"""

def __init__(self, backend: Union[str, Backend, None]):
super(AutoBackend, self).__init__()
if backend != 'auto':
self._convert_backend(backend)

def _convert_backend(self, backend):
"""
Convert this AutoBackend into the appropriate concrete Backend object.

"""
if isinstance(backend, Backend):
self.__class__ = backend.__class__
# if backend is a str, simply pick the matching class
elif backend == 'torch':
self.__class__ = TorchBackend
elif backend == 'paddle':
self.__class__ = PaddleBackend
elif backend == 'jittor':
self.__class__ = JittorBackend
elif backend is None:
# nothing needs to be done for initialization
pass
else:
raise RuntimeError(f"Backend `{backend}` is not supported for now.")
self._specified = True

def choose_real_backend(self, args):
assert not self.is_specified(), "This method should not be called after backend has been specified. " \
"This must be a bug, please report."
types = []
for arg in args:
types.append(str(type(arg)))

torch_types = []
jittor_types = []
paddle_types = []
for type_name in types:
if 'torch' in type_name:
torch_types.append(type_name)
if 'paddle' in type_name:
paddle_types.append(type_name)
if 'jittor' in type_name:
jittor_types.append(type_name)

# following https://stackoverflow.com/a/3464154 , swapping __class__ switches the instance to the real backend
if len(torch_types) > 0 and len(jittor_types) == 0 and len(paddle_types) == 0:
backend = 'torch'
elif len(torch_types) == 0 and len(jittor_types) > 0 and len(paddle_types) == 0:
backend = 'jittor'
elif len(torch_types) == 0 and len(jittor_types) == 0 and len(paddle_types) > 0:
backend = 'paddle'
elif len(torch_types) == 0 and len(jittor_types) == 0 and len(paddle_types) == 0:
# simply fall back to the default backend
backend = None
else:
types = list(set(torch_types + jittor_types + paddle_types))
raise RuntimeError(
f"A mixture of tensor types:{types} has been received, please set the backend manually instead of "
f"using backend='auto'.")

self._convert_backend(backend)
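
A sketch of the class-swapping trick described above: the backend stays unspecified until tensors are seen, then the instance's class is replaced in place (illustrative only; assumes torch is available):

import torch
from fastNLP.core.metrics.backend.auto_backend import AutoBackend

backend = AutoBackend('auto')           # nothing chosen yet
print(backend.is_specified())           # False
backend.choose_real_backend([torch.tensor([1.0]), 'not a tensor'])
print(type(backend).__name__)           # TorchBackend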

+75 -0 fastNLP/core/metrics/backend/backend.py

@@ -0,0 +1,75 @@
from ..utils import AggregateMethodError


class Backend:
"""
All methods of Backend and its subclasses must be stateless.

"""

def __init__(self):
self._specified = False

def aggregate(self, tensor, method: str):
"""
Aggregate the results according to method and return them.
"""
if method is not None:
raise AggregateMethodError(should_have_aggregate_method=False, only_warn=True)

return tensor

def create_tensor(self, value: float):
"""
Create a tensor filled with value.
"""
return value

def fill_value(self, tensor, value: float):
"""
Set the tensor's value to value.

"""
return value

def get_scalar(self, tensor) -> float:
"""
The scalar value of the tensor.

:param tensor:
:return:
"""
return tensor

def is_specified(self) -> bool:
"""
Whether this backend is bound to a specific framework.

:return:
"""
return self._specified

def tensor2numpy(self, tensor):
"""
Convert the tensor to numpy.

:param tensor:
:return:
"""
return tensor

def move_tensor_to_device(self, tensor, device):
"""
A no-op in the base class; device-aware backends override this.
"""
return tensor

def all_gather_object(self, obj, group=None):
"""
Gather obj from every rank onto each rank. Returns a list whose i-th entry is the obj from rank i.

:param obj:
:param group:
:return:
"""
raise NotImplementedError(f"all_gather_object() function is not implemented for {self.__class__.__name__}.")
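
Since every method of the base class is a pass-through, an Element backed by it can accumulate plain Python floats; a small sketch (illustrative, not part of the commit):

from fastNLP.core.metrics.backend.backend import Backend

backend = Backend()
value = backend.create_tensor(0.0)            # just the float 0.0
value = value + 5
print(backend.get_scalar(value))              # 5.0
print(backend.aggregate(value, method=None))  # no-op without an aggregate_method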


+1 -0 fastNLP/core/metrics/backend/jittor_backend/__init__.py

@@ -0,0 +1 @@


+72 -0 fastNLP/core/metrics/backend/jittor_backend/backend.py

@@ -0,0 +1,72 @@
import numpy as np

from fastNLP.envs.imports import _NEED_IMPORT_JITTOR
from fastNLP.core.metrics.backend import Backend

if _NEED_IMPORT_JITTOR:
import jittor


class JittorBackend(Backend):

def __init__(self):
super(JittorBackend, self).__init__()
self._specified = True

def aggregate(self, tensor, method: str):
"""
Aggregate the results according to method and return them.
"""
return tensor

def create_tensor(self, value: float):
"""
Create a tensor filled with value.
"""
value = jittor.Var(value)
return value

def fill_value(self, tensor, value: float):
"""
Set the tensor's value to value.

"""
value = jittor.full_like(tensor, value)
return value

def get_scalar(self, tensor) -> float:
"""
The scalar value of the tensor.

:param tensor:
:return:
"""
return tensor.item()

def is_specified(self) -> bool:
"""
Whether this backend is bound to a specific framework.

:return:
"""
return self._specified

def tensor2numpy(self, tensor):
"""
Convert the tensor to numpy.

:param tensor:
:return:
"""
if isinstance(tensor, jittor.Var):
return tensor.detach().numpy()
elif isinstance(tensor, np.ndarray):
return tensor
else:
raise ValueError(f"tensor: {tensor} can not convert to ndarray!")

def move_tensor_to_device(self, tensor, device):
"""
jittor has no API for moving tensors between devices, so this method is effectively a no-op.
"""
return tensor

+5 -0 fastNLP/core/metrics/backend/paddle_backend/__init__.py

@@ -0,0 +1,5 @@
__all__ = [
'PaddleBackend'
]

from .backend import Backend as PaddleBackend

+126 -0 fastNLP/core/metrics/backend/paddle_backend/backend.py

@@ -0,0 +1,126 @@
from typing import List, Optional, Any

import numpy as np

from fastNLP.core.metrics.backend import Backend
from fastNLP.core.utils.paddle_utils import paddle_to
from fastNLP.core.metrics.utils import AggregateMethodError
from fastNLP.core.utils import is_in_paddle_dist
from fastNLP.core.drivers.paddle_driver.utils import get_device_from_visible
from fastNLP.envs.imports import _NEED_IMPORT_PADDLE

if _NEED_IMPORT_PADDLE:
import paddle
from paddle.fluid.dygraph import parallel_helper

def _simple_gather_all_tensors(result, group: Any, world_size: int) -> List:
gathered_result = [paddle.zeros_like(result) for _ in range(world_size)]
paddle.distributed.all_gather(gathered_result, result, group)
return gathered_result

class PaddleBackend(Backend):
def __init__(self):
super().__init__()
self._specified = True

def aggregate(self, tensor, method: str):
"""
Aggregate the results according to method and return them.
"""
if isinstance(tensor, paddle.Tensor):
if parallel_helper._is_parallel_ctx_initialized():
if method is None:
raise AggregateMethodError(should_have_aggregate_method=True)
tensor = self._gather_all(tensor)
if isinstance(tensor[0], paddle.Tensor):
tensor = paddle.stack(tensor)
# step 1: aggregate the results
if method == 'sum':
tensor = paddle.sum(tensor, axis=0)
elif method == 'mean':
tensor = paddle.mean(tensor, axis=0)
elif method == 'max':
tensor = paddle.max(tensor, axis=0)
elif method == 'min':
tensor = paddle.min(tensor, axis=0)
else:
raise AggregateMethodError(should_have_aggregate_method=False)

return tensor

def create_tensor(self, value: float):
"""
Create a tensor filled with value.
"""
tensor = paddle.ones((1,)).fill_(value)
return tensor

def fill_value(self, tensor, value: float):
"""
Set the tensor's value to value.

"""
tensor.fill_(value)
return tensor

def get_scalar(self, tensor) -> float:
return tensor.item()

def tensor2numpy(self, tensor) -> np.ndarray:
if isinstance(tensor, paddle.Tensor):
return tensor.cpu().detach().numpy()
elif isinstance(tensor, np.ndarray):
return tensor
else:
raise ValueError(f"tensor: {tensor} can not convert to ndarray!")

@staticmethod
def _gather_all(result, group: Optional[Any] = None) -> List:
"""
Gather result from every rank in group; since results may differ in size across ranks, padding is
applied where necessary.
"""
# TODO: check correctness
if group is None:
group = paddle.distributed.get_group(0)

world_size = group.nranks
paddle.distributed.barrier(group=group)

# if the tensor is a scalar, a simple gather suffices
if result.ndim == 0:
return _simple_gather_all_tensors(result, group, world_size)

# obtain the shape of result
local_size = paddle.to_tensor(result.shape)
# gather the sizes of result from every rank in the group
local_sizes = [paddle.zeros_like(local_size) for _ in range(world_size)]
paddle.distributed.all_gather(local_sizes, local_size, group=group)
# after stacking, take the maximum along each dimension of the shapes
max_size = paddle.stack(local_sizes).max(axis=0)
all_sizes_equal = all(all(ls == max_size) for ls in local_sizes)

# if all results have the same size, gather directly
if all_sizes_equal:
return _simple_gather_all_tensors(result, group, world_size)

# otherwise, pad up to the largest tensor
pad_dims = []
pad_by = (max_size - local_size).detach().cpu()
for val in reversed(pad_by):
pad_dims.append(0)
pad_dims.append(val.item())
result_padded = paddle.nn.functional.pad(result, pad_dims)
# gather again
gathered_result = [paddle.zeros_like(result_padded) for _ in range(world_size)]
paddle.distributed.all_gather(gathered_result, result_padded, group)
for idx, item_size in enumerate(local_sizes):
slice_param = [slice(dim_size) for dim_size in item_size]
gathered_result[idx] = gathered_result[idx][slice_param]
return gathered_result

def move_tensor_to_device(self, tensor, device):
# TODO: could handling it here cause bugs elsewhere?
if is_in_paddle_dist():
device = get_device_from_visible(device)
return paddle_to(tensor, device)


+6 -0 fastNLP/core/metrics/backend/torch_backend/__init__.py

@@ -0,0 +1,6 @@
__all__ = [
'TorchBackend'
]


from .backend import Backend as TorchBackend

+154 -0 fastNLP/core/metrics/backend/torch_backend/backend.py

@@ -0,0 +1,154 @@
from typing import Any, List, Optional

import numpy as np

from fastNLP.core.metrics.backend import Backend
from fastNLP.core.metrics.utils import AggregateMethodError
from fastNLP.envs.imports import _NEED_IMPORT_TORCH
from fastNLP.core.drivers.torch_driver.dist_utils import fastnlp_torch_all_gather


if _NEED_IMPORT_TORCH:
import torch
import torch.distributed as dist
import torch.nn.functional as F


def _simple_gather_all_tensors(result, group: Any, world_size: int) -> List:
gathered_result = [torch.zeros_like(result) for _ in range(world_size)]
dist.all_gather(gathered_result, result, group)
return gathered_result


class TorchBackend(Backend):
def __init__(self):
super().__init__()
self._specified = True

def aggregate(self, tensor, method: str):
"""
Aggregate the results according to method and return them.
"""
if isinstance(tensor, torch.Tensor):
if dist.is_initialized():
if method is None:
raise AggregateMethodError(should_have_aggregate_method=True)
tensor = self._gather_all(tensor)
if isinstance(tensor[0], torch.Tensor):
tensor = torch.stack(tensor)
# step 1: aggregate the results
if method == 'sum':
tensor = torch.sum(tensor, dim=0)
elif method == 'mean':
tensor = torch.mean(tensor, dim=0)
elif method == 'max':
tensor, _ = torch.max(tensor, dim=0)
elif method == 'min':
tensor, _ = torch.min(tensor, dim=0)
else:
raise AggregateMethodError(should_have_aggregate_method=False)

return tensor

def create_tensor(self, value: float):
"""
Create a tensor filled with value.
"""
tensor = torch.ones(1).fill_(value)
return tensor

def fill_value(self, tensor, value: float):
"""
Set the tensor's value to value.

"""
tensor.fill_(value)
return tensor

def get_scalar(self, tensor) -> float:
return tensor.item()

@staticmethod
def _gather_all(result, group: Optional[Any] = None) -> List:
"""Function to gather all tensors from several ddp processes onto a list that is broadcasted to all processes.
Works on tensors that have the same number of dimensions, but where each dimension may differ. In this case
tensors are padded, gathered and then trimmed to secure equal workload for all processes.

Args:
result: the value to sync
group: the process group to gather results from. Defaults to all processes (world)

Return:
gathered_result: list with size equal to the process group where
gathered_result[i] corresponds to result tensor from process i
"""

if group is None:
group = dist.group.WORLD

# convert tensors to contiguous format
result = result.contiguous()

world_size = dist.get_world_size(group)
dist.barrier(group=group)

# if the tensor is scalar, things are easy
if result.ndim == 0:
return _simple_gather_all_tensors(result, group, world_size)

# 1. Gather sizes of all tensors
local_size = torch.tensor(result.shape, device=result.device)
local_sizes = [torch.zeros_like(local_size) for _ in range(world_size)]
dist.all_gather(local_sizes, local_size, group=group)
max_size = torch.stack(local_sizes).max(dim=0).values
all_sizes_equal = all(all(ls == max_size) for ls in local_sizes)

# 2. If shapes are all the same, then do a simple gather:
if all_sizes_equal:
return _simple_gather_all_tensors(result, group, world_size)

# 3. If not, we need to pad each local tensor to maximum size, gather and then truncate
pad_dims = []
pad_by = (max_size - local_size).detach().cpu()
for val in reversed(pad_by):
pad_dims.append(0)
pad_dims.append(val.item())
result_padded = torch.nn.functional.pad(result, pad_dims)
gathered_result = [torch.zeros_like(result_padded) for _ in range(world_size)]
dist.all_gather(gathered_result, result_padded, group)
for idx, item_size in enumerate(local_sizes):
slice_param = [slice(dim_size) for dim_size in item_size]
gathered_result[idx] = gathered_result[idx][slice_param]
return gathered_result

def tensor2numpy(self, tensor) -> np.ndarray:
"""
Convert the given tensor to a numpy object.

"""

if isinstance(tensor, torch.Tensor):
return tensor.cpu().detach().numpy()
elif isinstance(tensor, np.ndarray):
return tensor
elif isinstance(tensor, (float, int)):
return tensor
else:
raise ValueError(f"tensor: {tensor} can not convert to ndarray!")

@staticmethod
def is_distributed() -> bool:
"""
:return: whether torch.distributed is available and initialized.
"""
return dist.is_available() and dist.is_initialized()

def move_tensor_to_device(self, tensor, device):
return tensor.to(device)

def all_gather_object(self, obj, group=None) -> List:
if self.is_distributed():
obj_list = fastnlp_torch_all_gather(obj, group=group)
return obj_list
return [obj]
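
The trickiest part of _gather_all above is building pad_dims: torch.nn.functional.pad expects (left, right) pairs starting from the last dimension, hence the reversed() loop. A single-process sketch of just that step, with hypothetical shapes (illustrative, not part of the commit):

import torch
import torch.nn.functional as F

local = torch.arange(6).reshape(2, 3)            # this rank's tensor, shape (2, 3)
max_size = torch.tensor([4, 3])                  # element-wise max shape across ranks
pad_by = (max_size - torch.tensor(local.shape))  # [2, 0]

pad_dims = []
for val in reversed(pad_by):                     # last dimension first
    pad_dims.append(0)
    pad_dims.append(val.item())
padded = F.pad(local, pad_dims)                  # zero-padded up to the max shape
print(padded.shape)                              # torch.Size([4, 3])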


+142 -0 fastNLP/core/metrics/classify_f1_pre_rec_metric.py

@@ -0,0 +1,142 @@
__all__ = [
'ClassifyFPreRecMetric'
]

from typing import Union, List
from collections import defaultdict
from functools import partial
import warnings

from .metric import Metric
from .backend import Backend
from fastNLP.core.vocabulary import Vocabulary
from fastNLP.core.utils.utils import seq_len_to_mask


def _compute_f_pre_rec(beta_square, tp, fn, fp):
r"""

:param beta_square: beta squared in the F-beta score
:param tp: int, true positive
:param fn: int, false negative
:param fp: int, false positive
:return: (f, pre, rec)
"""
pre = tp / (fp + tp + 1e-13)
rec = tp / (fn + tp + 1e-13)
f = (1 + beta_square) * pre * rec / (beta_square * pre + rec + 1e-13)

return f, pre, rec


class ClassifyFPreRecMetric(Metric):
def __init__(self, backend: Union[str, Backend, None] = 'auto', aggregate_when_get_metric: bool = False,
tag_vocab: Vocabulary = None, encoding_type: str = None, ignore_labels: List[str] = None,
only_gross: bool = True, f_type='micro', beta=1) -> None:
super(ClassifyFPreRecMetric, self).__init__(backend=backend,
aggregate_when_get_metric=aggregate_when_get_metric)
if f_type not in ('micro', 'macro'):
raise ValueError("f_type only supports `micro` or `macro`, got {}.".format(f_type))

self.ignore_labels = ignore_labels
self.f_type = f_type
self.beta = beta
self.beta_square = self.beta ** 2
self.only_gross = only_gross

self.tag_vocab = tag_vocab

self._tp = defaultdict(partial(self.register_element, aggregate_method='sum'))
self._fp = defaultdict(partial(self.register_element, aggregate_method='sum'))
self._fn = defaultdict(partial(self.register_element, aggregate_method='sum'))

def get_metric(self) -> dict:
r"""
get_metric() computes the final evaluation result from the statistics accumulated by update().

:return dict evaluate_result: {"f": float, "pre": float, "rec": float}
"""
evaluate_result = {}
if not self.only_gross or self.f_type == 'macro':
tags = set(self._fn.keys())
tags.update(set(self._fp.keys()))
tags.update(set(self._tp.keys()))
f_sum = 0
pre_sum = 0
rec_sum = 0
for tag in tags:
if self.tag_vocab is not None:
tag_name = self.tag_vocab.to_word(tag)
else:
tag_name = int(tag)
tp = self._tp[tag].get_scalar()
fn = self._fn[tag].get_scalar()
fp = self._fp[tag].get_scalar()
f, pre, rec = _compute_f_pre_rec(self.beta_square, tp, fn, fp)
f_sum += f
pre_sum += pre
rec_sum += rec
if not self.only_gross and tag != '': # tag != '' guards against the no-tag case
f_key = 'f-{}'.format(tag_name)
pre_key = 'pre-{}'.format(tag_name)
rec_key = 'rec-{}'.format(tag_name)
evaluate_result[f_key] = f
evaluate_result[pre_key] = pre
evaluate_result[rec_key] = rec

if self.f_type == 'macro':
evaluate_result['f'] = f_sum / len(tags)
evaluate_result['pre'] = pre_sum / len(tags)
evaluate_result['rec'] = rec_sum / len(tags)

if self.f_type == 'micro':
f, pre, rec = _compute_f_pre_rec(self.beta_square,
sum(val.get_scalar() for val in self._tp.values()),
sum(val.get_scalar() for val in self._fn.values()),
sum(val.get_scalar() for val in self._fp.values()))
evaluate_result['f'] = f
evaluate_result['pre'] = pre
evaluate_result['rec'] = rec

for key, value in evaluate_result.items():
evaluate_result[key] = round(value, 6)

return evaluate_result

def update(self, pred, target, seq_len=None):
pred = self.tensor2numpy(pred)
target = self.tensor2numpy(target)
if seq_len is not None:
seq_len = self.tensor2numpy(seq_len)

if seq_len is not None and target.ndim > 1:
max_len = target.shape[1]
masks = seq_len_to_mask(seq_len=seq_len, max_len=max_len)
else:
masks = None

if pred.ndim == target.ndim:
if len(pred.flatten()) != len(target.flatten()):
raise RuntimeError(f"when pred and target have the same number of dimensions, they should have the"
f" same number of elements, while pred has {len(pred.flatten())} elements and "
f"target has {len(target.flatten())} elements.")

elif pred.ndim == target.ndim + 1:
pred = pred.argmax(axis=-1)
if seq_len is None and target.ndim > 1:
warnings.warn("You are not passing `seq_len` to exclude pad when calculating accuracy.")
else:
raise RuntimeError(f"when pred has shape:{pred.shape}, target should have shape: {pred.shape} or "
f"{pred.shape[:-1]}, got {target.shape}.")
if masks is not None:
target = target * masks
pred = pred * masks
target_idxes = set(target.reshape(-1).tolist())
for target_idx in target_idxes:
self._tp[target_idx] += ((pred == target_idx) * (target == target_idx)).sum().item()
self._fp[target_idx] += ((pred == target_idx) * (target != target_idx)).sum().item()
self._fn[target_idx] += ((pred != target_idx) * (target == target_idx)).sum().item()
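
A minimal usage sketch for ClassifyFPreRecMetric with the micro average (torch assumed; illustrative only):

import torch
from fastNLP.core.metrics import ClassifyFPreRecMetric

metric = ClassifyFPreRecMetric(f_type='micro')
pred = torch.tensor([[0.6, 0.4], [0.3, 0.7], [0.8, 0.2]])  # [B, n_classes]
target = torch.tensor([0, 1, 1])
metric.update(pred, target)   # argmax -> [0, 1, 0]: tp=2, fp=1, fn=1
print(metric.get_metric())    # {'f': 0.666667, 'pre': 0.666667, 'rec': 0.666667}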



+281 -0 fastNLP/core/metrics/element.py

@@ -0,0 +1,281 @@
__all__ = [
'Element'
]

import os

from .backend import Backend, AutoBackend
from fastNLP.core.log import logger
from .utils import AggregateMethodError
from fastNLP.envs.env import FASTNLP_GLOBAL_RANK


class Element:
def __init__(self, value: float, aggregate_method, backend: Backend, name=None):
self.init_value = value
self.aggregate_method = aggregate_method
self.name = name
if backend == 'auto':
raise RuntimeError("You have to specify the backend.")
elif isinstance(backend, AutoBackend):
self.backend = backend
else:
self.backend = AutoBackend(backend)

if self.backend.is_specified():
value = self.backend.create_tensor(self.init_value)
else:
value = None
self._value = value
self.device = None

def aggregate(self):
"""
Aggregate this element automatically.

"""
try:
self._value = self.backend.aggregate(self._value, self.aggregate_method)
except AggregateMethodError as e:
msg = 'If you see this message, please report a bug.'
if self.name and e.should_have_aggregate_method:
msg = f"Element:{self.name} has no specified `aggregate_method`."
elif e.should_have_aggregate_method:
msg = "Element has no specified `aggregate_method`."
elif self.name and not e.should_have_aggregate_method:
msg = f"Element:{self.name}'s backend:{self.backend.__class__.__name__} does not support " \
f'aggregate_method:{self.aggregate_method}.'
elif not e.should_have_aggregate_method:
msg = f"Element's backend:{self.backend.__class__.__name__} does not support " \
f'aggregate_method:{self.aggregate_method}.'
if e.only_warn:
if int(os.environ.get(FASTNLP_GLOBAL_RANK, 0)) == 0:
logger.warning(msg)
self._value = self.backend.aggregate(self._value, method=None)
else:
raise RuntimeError(msg)

def reset(self):
if self.backend.is_specified():
self._value = self.backend.fill_value(self._value, self.init_value)

@property
def value(self):
self._check_value_initialized()
return self._value

@value.setter
def value(self, value):
self._check_value_initialized()
self._value = value

def get_scalar(self) -> float:
return self.backend.get_scalar(self._value)

def fill_value(self, value):
self._value = self.backend.fill_value(self._value, value)

def to(self, device):
# how should device be handled here?
if self._value is not None:
self._value = self.backend.move_tensor_to_device(self._value, device)
self.device = device

def _check_value_initialized(self):
if self._value is None:
assert self.backend.is_specified(), f"Backend is not specified, please specify backend in the Metric " \
f"initialization."
self._value = self.backend.create_tensor(self.init_value)
if self.device is not None:
self.to(device=self.device)

def _check_value_when_call(self):
if self.value is None:
prefix = f'Element:`{self.name}`' if self.name else 'Element'
raise RuntimeError(prefix + " is not initialized. Please either specify backend when creating this "
"element, or use it after it being used by the `Metric.compute()` method.")

def __add__(self, other):
self._check_value_when_call()
if isinstance(other, Element):
self.value += other.value
else:
self.value += other
return self

def __radd__(self, other):
self._check_value_when_call()
if isinstance(other, Element):
self.value += other.value
else:
self.value += other
return self

def __sub__(self, other):
self._check_value_when_call()
if isinstance(other, Element):
self.value -= other.value
else:
self.value -= other
return self

def __rsub__(self, other):
self._check_value_when_call()
if isinstance(other, Element):
self.value = other.value - self.value
else:
self.value = other - self.value
return self

def __mul__(self, other):
self._check_value_when_call()
if isinstance(other, Element):
self.value *= other.value
else:
self.value *= other
return self

def __imul__(self, other):
self._check_value_when_call()
if isinstance(other, Element):
self.value *= other.value
else:
self.value *= other
return self

def __floordiv__(self, other):
self._check_value_when_call()
if isinstance(other, Element):
self.value //= other.value
else:
self.value //= other
return self

def __rfloordiv__(self, other):
self._check_value_when_call()
if isinstance(other, Element):
self.value = other.value // self.value
else:
self.value = other // self.value
return self

def __truediv__(self, other):
self._check_value_when_call()
if isinstance(other, Element):
self.value /= other.value
else:
self.value /= other
return self

def __rtruediv__(self, other):
self._check_value_when_call()
if isinstance(other, Element):
self.value = other.value / self.value
else:
self.value = other / self.value
return self

def __mod__(self, other):
self._check_value_when_call()
if isinstance(other, Element):
self.value %= other.value
else:
self.value %= other
return self

def __rmod__(self, other):
self._check_value_when_call()
if isinstance(other, Element):
self.value = other.value % self.value
else:
self.value = other % self.value
return self

def __pow__(self, other, modulo=None):
self._check_value_when_call()
if modulo is None:
if isinstance(other, Element):
self.value **= other.value
else:
self.value **= other
else:
if isinstance(other, Element):
self.value = pow(self.value, other.value, modulo)
else:
self.value = pow(self.value, other, modulo)
return self

def __rpow__(self, other):
self._check_value_when_call()
if isinstance(other, Element):
self.value = other.value ** self.value
else:
self.value = other ** self.value
return self

def __lt__(self, other) -> bool:
self._check_value_when_call()
if isinstance(other, Element):
return self.value < other.value
else:
return self.value < other

def __le__(self, other) -> bool:
self._check_value_when_call()
if isinstance(other, Element):
return self.value <= other.value
else:
return self.value <= other

def __eq__(self, other):
self._check_value_when_call()
if isinstance(other, Element):
return self.value == other.value
else:
return self.value == other

def __ne__(self, other) -> bool:
self._check_value_when_call()
if isinstance(other, Element):
return self.value != other.value
else:
return self.value != other

def __ge__(self, other) -> bool:
self._check_value_when_call()
if isinstance(other, Element):
return self.value >= other.value
else:
return self.value >= other

def __gt__(self, other) -> bool:
self._check_value_when_call()
if isinstance(other, Element):
return self.value > other.value
else:
return self.value > other

def __str__(self):
return str(self.value)

def __repr__(self):
return str(self.value)

def __getattr__(self, item):
"""
Delegate attribute access to the underlying value, so that backend-specific tensor methods can be
called directly on the Element.

:param item:
:return:
"""
if self._value is None:
prefix = f'Element:`{self.name}`' if self.name else 'Element'
raise RuntimeError(prefix + " is not initialized. Please either specify backend when creating this "
"element, or use it after it being used by the `Metric.compute()` method.")
return getattr(self._value, item)
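
A sketch of why the arithmetic operators mutate and return self: `self.total += n` inside a Metric rebinds the attribute to the very same Element, which is exactly what Metric.__setattr__ requires (torch assumed; illustrative only):

from fastNLP.core.metrics.element import Element

ele = Element(value=0, aggregate_method='sum', backend='torch', name='total')
same = ele + 3            # __add__ updates the value in place and returns self
assert same is ele
print(ele.get_scalar())   # 3.0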

+184 -0 fastNLP/core/metrics/metric.py

@@ -0,0 +1,184 @@
__all__ = [
'Metric'
]

from abc import abstractmethod

from typing import Union
import functools
from contextlib import contextmanager
import numpy as np

from fastNLP.core.metrics.backend import Backend, AutoBackend
from fastNLP.core.metrics.element import Element


class Metric:
def __init__(self, backend: Union[str, Backend, None] = 'auto', aggregate_when_get_metric: bool = True):
"""

:param str backend: currently four kinds of backend are supported: [torch, paddle, jittor, auto]. 'auto'
means the concrete backend is decided from the arguments actually passed to Metric.update(); in most
cases 'auto' is the right choice.
:param bool aggregate_when_get_metric: whether, when computing the metric, the values of the same element
across processes are aggregated first; meaningless when the backend does not support distributed runs.
"""
self.backend = AutoBackend(backend)
self._updated = False
self.get_metric = self._sync_get_metric(self.get_metric)
self.update = self._wrap_update(self.update)
self.reset = self._wrap_auto_reset_elements(self.reset)
self.aggregate_when_get_metric = aggregate_when_get_metric
self._cannot_change_element = False
self._elements = {}

@property
def elements(self) -> dict:
return self._elements

def register_element(self, name=None, value: float = 0, aggregate_method=None, backend='auto') -> Element:
"""
Register an element object. Once registered it can be accessed inside the Metric as self.{name} and can
be treated as a tensor of the corresponding backend, i.e. used directly in arithmetic.
Note: for the metric to scale automatically to the multi-card case, be sure to declare aggregate_method.

:param name: the name of this element; after registration it is accessible as self.{name} in the Metric.
:param value: the initial value; Metric.reset() also resets the element to this value.
:param aggregate_method: how to aggregate results across cards; meaningless in single-card runs.
:param backend: the backend to use. The Element's concrete type is initialized from it: with backend
'torch' the object is a torch.Tensor; with 'paddle' a paddle.Tensor; with 'jittor' a jittor.Var. In
most cases the default 'auto' is fine: fastNLP initializes it from the arguments actually passed to
Metric.update(). For example, when those arguments contain only torch.Tensor as tensor type
(non-tensor inputs are allowed) the backend is taken to be torch; when they contain only jittor.Var,
the backend is taken to be jittor. If no tensor type is detected at all, a plain float is used for
the element.
:return: the registered Element object
"""
if backend == 'auto':
backend = self.backend
else:
backend = AutoBackend(backend)

# when name is None, derive a default variable name
if name is None:
name = f'ele_var_{len(self._elements)}'

element = Element(value=value, aggregate_method=aggregate_method, backend=backend, name=name)
self.elements[name] = element
setattr(self, name, element)
return element

def reset(self):
"""
If non-element state needs resetting, override this method with that reset logic. Registered element
objects are automatically reset to their initial values.

"""
pass

def _wrap_auto_reset_elements(self, reset):
@functools.wraps(reset)
def _wrap_reset(*args, **kwargs):
self._updated = False
for ele in self.elements.values():
ele.reset()
reset(*args, **kwargs)

return _wrap_reset

def _sync_get_metric(self, get_metric):
@functools.wraps(get_metric)
def _wrap_get_metric(*args, **kwargs):
assert self._updated, f"You have to call `{self.__class__.__name__}.update()` before calling " \
f"get_metric()."
with self.sync(recover=True, aggregate=self.aggregate_when_get_metric):
results = get_metric(*args, **kwargs)
return results

return _wrap_get_metric

def __setattr__(self, key, value):
if hasattr(self, '_cannot_change_element') and self._cannot_change_element is True:
if key in self.elements and value is not self.elements[key]:
raise RuntimeError(f"self.`{key}` is an element, cannot assign to a new value:{value}")
object.__setattr__(self, key, value)

def _wrap_update(self, update):
@functools.wraps(update)
def _wrap_update(*args, **kwargs):
self.check_backend(*args, **kwargs)
self._cannot_change_element = True
self._updated = True
return update(*args, **kwargs)

return _wrap_update

def check_backend(self, *args, **kwargs):
if not self.backend.is_specified():
_args = []
for arg in args:
_args.append(arg)
for arg in kwargs.values():
_args.append(arg)
self.backend.choose_real_backend(_args)

@contextmanager
def sync(self, recover=True, aggregate=False):
"""
Within this context, the metric first synchronizes the elements that need it. When recover is True,
the element values are restored to their pre-aggregation values on exit.

"""
keep_value = {}
if aggregate:
for name, element in self.elements.items():
# save the previous value
keep_value[name] = element.get_scalar()
# aggregate the results
element.aggregate()

yield

if recover and aggregate:
for name, element in self.elements.items():
# restore the previous value
if name in keep_value:
element.fill_value(value=keep_value.get(name))

@abstractmethod
def update(self, *args, **kwargs):
raise NotImplementedError()

@abstractmethod
def get_metric(self) -> dict:
raise NotImplementedError()

def set_auto_aggregate_when_get_metric(self, flag: bool):
"""
Set whether elements are aggregated automatically in get_metric().

"""
self.aggregate_when_get_metric = flag

def __getattr__(self, name: str) -> Element:
if 'elements' in self.__dict__:
elements = self.__dict__['elements']
if name in elements:
return elements[name]
raise AttributeError("`{}` object has no attribute `{}`".format(type(self).__name__, name))

def tensor2numpy(self, tensor) -> np.ndarray:
"""
Convert a tensor to a numpy variable.

:param tensor:
:return:
"""
return self.backend.tensor2numpy(tensor)

def to(self, device):
"""
Move all element variables to the given device.

:param device:
:return:
"""
for element in self.elements.values():
element.to(device)
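
A hypothetical custom metric sketching the intended workflow of this class: register elements in __init__, accumulate in update(), report in get_metric() (torch assumed; the SumMetric name is illustrative):

import torch
from fastNLP.core.metrics import Metric

class SumMetric(Metric):
    def __init__(self):
        super().__init__(backend='auto')
        self.register_element(name='total', value=0, aggregate_method='sum')

    def update(self, values):
        # the backend is resolved from `values` on the first call
        self.total += values.sum().item()

    def get_metric(self) -> dict:
        return {'total': self.total.get_scalar()}

metric = SumMetric()
metric.update(torch.tensor([1.0, 2.0, 3.0]))
print(metric.get_metric())  # {'total': 6.0}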

+344 -0 fastNLP/core/metrics/span_f1_pre_rec_metric.py

@@ -0,0 +1,344 @@
__all__ = [
'SpanFPreRecMetric'
]

from typing import Union, List, Optional
import warnings
from collections import defaultdict
from functools import partial

from fastNLP.core.metrics.backend import Backend
from fastNLP.core.metrics.metric import Metric
from fastNLP.core.vocabulary import Vocabulary


def _check_tag_vocab_and_encoding_type(tag_vocab: Union[Vocabulary, dict], encoding_type: str):
r"""
Check that the tags in the vocab match the encoding_type.

:param tag_vocab: a tag Vocabulary, or a dict such as {0:"O", 1:"B-tag1"}, i.e. index first, tag second.
:param encoding_type: bio, bmes, bioes, bmeso
:return:
"""
tag_set = set()
unk_token = '<unk>'
pad_token = '<pad>'
if isinstance(tag_vocab, Vocabulary):
unk_token = tag_vocab.unknown
pad_token = tag_vocab.padding
tag_vocab = tag_vocab.idx2word
for idx, tag in tag_vocab.items():
if tag in (unk_token, pad_token):
continue
tag = tag[:1].lower()
tag_set.add(tag)

tags = encoding_type
for tag in tag_set:
assert tag in tags, f"{tag} is not a valid tag in encoding type:{encoding_type}. Please check your " \
f"encoding_type."
tags = tags.replace(tag, '') # remove this tag
if tags: # if non-empty, some tags of the encoding type were never used
warnings.warn(f"Tag:{tags} in encoding type:{encoding_type} is not presented in your Vocabulary. Check your "
"encoding_type.")


def _get_encoding_type_from_tag_vocab(tag_vocab: Union[Vocabulary, dict]) -> str:
r"""
Infer the encoding type from the given Vocabulary; supports bmes, bioes, bmeso, bio.

:param tag_vocab: a tag Vocabulary, or a dict such as {0:"O", 1:"B-tag1"}, i.e. index first, tag second.
:return:
"""
tag_set = set()
unk_token = '<unk>'
pad_token = '<pad>'
if isinstance(tag_vocab, Vocabulary):
unk_token = tag_vocab.unknown
pad_token = tag_vocab.padding
tag_vocab = tag_vocab.idx2word
for idx, tag in tag_vocab.items():
if tag in (unk_token, pad_token):
continue
tag = tag[:1].lower()
tag_set.add(tag)

bmes_tag_set = set('bmes')
if tag_set == bmes_tag_set:
return 'bmes'
bio_tag_set = set('bio')
if tag_set == bio_tag_set:
return 'bio'
bmeso_tag_set = set('bmeso')
if tag_set == bmeso_tag_set:
return 'bmeso'
bioes_tag_set = set('bioes')
if tag_set == bioes_tag_set:
return 'bioes'
raise RuntimeError("encoding_type cannot be inferred automatically. Only support "
"'bio', 'bmes', 'bmeso', 'bioes' type.")


def _bmes_tag_to_spans(tags, ignore_labels=None):
r"""
Given a list of tags such as ['S-song', 'B-singer', 'M-singer', 'E-singer', 'S-movie', 'S-actor'],
return [('song', (0, 1)), ('singer', (1, 4)), ('movie', (4, 5)), ('actor', (5, 6))] (half-open intervals).
The input can also be a plain ['S', 'B', 'M', 'E', 'B', 'M', 'M', ...] sequence.

:param tags: List[str],
:param ignore_labels: List[str], labels in this list are ignored
:return: List[Tuple[str, List[int, int]]]. [(label, [start, end])]
"""
ignore_labels = set(ignore_labels) if ignore_labels else set()

spans = []
prev_bmes_tag = None
for idx, tag in enumerate(tags):
tag = tag.lower()
bmes_tag, label = tag[:1], tag[2:]
if bmes_tag in ('b', 's'):
spans.append((label, [idx, idx]))
elif bmes_tag in ('m', 'e') and prev_bmes_tag in ('b', 'm') and label == spans[-1][0]:
spans[-1][1][1] = idx
else:
spans.append((label, [idx, idx]))
prev_bmes_tag = bmes_tag
return [(span[0], (span[1][0], span[1][1] + 1))
for span in spans
if span[0] not in ignore_labels
]


def _bmeso_tag_to_spans(tags, ignore_labels=None):
r"""
Given a list of tags such as ['O', 'B-singer', 'M-singer', 'E-singer', 'O', 'O'],
return [('singer', (1, 4))] (half-open intervals).

:param tags: List[str],
:param ignore_labels: List[str], labels in this list are ignored
:return: List[Tuple[str, List[int, int]]]. [(label, [start, end])]
"""
ignore_labels = set(ignore_labels) if ignore_labels else set()

spans = []
prev_bmes_tag = None
for idx, tag in enumerate(tags):
tag = tag.lower()
bmes_tag, label = tag[:1], tag[2:]
if bmes_tag in ('b', 's'):
spans.append((label, [idx, idx]))
elif bmes_tag in ('m', 'e') and prev_bmes_tag in ('b', 'm') and label == spans[-1][0]:
spans[-1][1][1] = idx
elif bmes_tag == 'o':
pass
else:
spans.append((label, [idx, idx]))
prev_bmes_tag = bmes_tag
return [(span[0], (span[1][0], span[1][1] + 1))
for span in spans
if span[0] not in ignore_labels
]


def _bioes_tag_to_spans(tags, ignore_labels=None):
r"""
Given a list of tags such as ['O', 'B-singer', 'I-singer', 'E-singer', 'O', 'O'],
return [('singer', (1, 4))] (half-open intervals).

:param tags: List[str],
:param ignore_labels: List[str], labels in this list are ignored
:return: List[Tuple[str, List[int, int]]]. [(label, [start, end])]
"""
ignore_labels = set(ignore_labels) if ignore_labels else set()

spans = []
prev_bioes_tag = None
for idx, tag in enumerate(tags):
tag = tag.lower()
bioes_tag, label = tag[:1], tag[2:]
if bioes_tag in ('b', 's'):
spans.append((label, [idx, idx]))
elif bioes_tag in ('i', 'e') and prev_bioes_tag in ('b', 'i') and label == spans[-1][0]:
spans[-1][1][1] = idx
elif bioes_tag == 'o':
pass
else:
spans.append((label, [idx, idx]))
prev_bioes_tag = bioes_tag
return [(span[0], (span[1][0], span[1][1] + 1))
for span in spans
if span[0] not in ignore_labels
]


def _bio_tag_to_spans(tags, ignore_labels=None):
r"""
Given a list of tags such as ['O', 'B-singer', 'I-singer', 'I-singer', 'O', 'O'],
return [('singer', (1, 4))] (half-open intervals).

:param tags: List[str],
:param ignore_labels: List[str], labels in this list are ignored
:return: List[Tuple[str, List[int, int]]]. [(label, [start, end])]
"""
ignore_labels = set(ignore_labels) if ignore_labels else set()

spans = []
prev_bio_tag = None
for idx, tag in enumerate(tags):
tag = tag.lower()
bio_tag, label = tag[:1], tag[2:]
if bio_tag == 'b':
spans.append((label, [idx, idx]))
elif bio_tag == 'i' and prev_bio_tag in ('b', 'i') and label == spans[-1][0]:
spans[-1][1][1] = idx
elif bio_tag == 'o': # o tag does not count
pass
else:
spans.append((label, [idx, idx]))
prev_bio_tag = bio_tag
return [(span[0], (span[1][0], span[1][1] + 1)) for span in spans if span[0] not in ignore_labels]
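
A quick check of the helper above (pure Python, runnable as-is; illustrative only):

tags = ['O', 'B-singer', 'I-singer', 'I-singer', 'O', 'B-song']
print(_bio_tag_to_spans(tags))
# [('singer', (1, 4)), ('song', (5, 6))]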


def _compute_f_pre_rec(beta_square, tp, fn, fp):
r"""

:param beta_square: beta squared in the F-beta score
:param tp: int, true positive
:param fn: int, false negative
:param fp: int, false positive
:return: (f, pre, rec)
"""
pre = tp / (fp + tp + 1e-13)
rec = tp / (fn + tp + 1e-13)
f = (1 + beta_square) * pre * rec / (beta_square * pre + rec + 1e-13)

return f, pre, rec


class SpanFPreRecMetric(Metric):

def __init__(self, backend: Union[str, Backend, None] = 'auto', tag_vocab: Vocabulary = None,
encoding_type: str = None, ignore_labels: List[str] = None, only_gross: bool = True, f_type='micro',
beta=1, aggregate_when_get_metric: bool = True,) -> None:
super(SpanFPreRecMetric, self).__init__(backend=backend, aggregate_when_get_metric=aggregate_when_get_metric)
if f_type not in ('micro', 'macro'):
raise ValueError("f_type only supports `micro` or `macro`, got {}.".format(f_type))
if not isinstance(tag_vocab, Vocabulary):
raise TypeError("tag_vocab can only be fastNLP.Vocabulary, not {}.".format(type(tag_vocab)))
if encoding_type:
encoding_type = encoding_type.lower()
_check_tag_vocab_and_encoding_type(tag_vocab, encoding_type)
self.encoding_type = encoding_type
else:
self.encoding_type = _get_encoding_type_from_tag_vocab(tag_vocab)

if self.encoding_type == 'bmes':
self.tag_to_span_func = _bmes_tag_to_spans
elif self.encoding_type == 'bio':
self.tag_to_span_func = _bio_tag_to_spans
elif self.encoding_type == 'bmeso':
self.tag_to_span_func = _bmeso_tag_to_spans
elif self.encoding_type == 'bioes':
self.tag_to_span_func = _bioes_tag_to_spans
else:
raise ValueError("Only support 'bio', 'bmes', 'bmeso', 'bioes' type.")

self.ignore_labels = ignore_labels
self.f_type = f_type
self.beta = beta
self.beta_square = self.beta ** 2
self.only_gross = only_gross
self.tag_vocab = tag_vocab

self._true_positives = defaultdict(partial(self.register_element, aggregate_method='sum', name=None))
self._false_positives = defaultdict(partial(self.register_element, aggregate_method='sum', name=None))
self._false_negatives = defaultdict(partial(self.register_element, aggregate_method='sum', name=None))

def get_metric(self) -> dict:
evaluate_result = {}
if not self.only_gross or self.f_type == 'macro':
tags = set(self._false_negatives.keys())
tags.update(set(self._false_positives.keys()))
tags.update(set(self._true_positives.keys()))
f_sum = 0
pre_sum = 0
rec_sum = 0
for tag in tags:
tp = self._true_positives[tag].get_scalar()
fn = self._false_negatives[tag].get_scalar()
fp = self._false_positives[tag].get_scalar()
f, pre, rec = _compute_f_pre_rec(self.beta_square, tp, fn, fp)
f_sum += f
pre_sum += pre
rec_sum += rec
if not self.only_gross and tag != '': # tag != '' guards against the no-tag case
f_key = 'f-{}'.format(tag)
pre_key = 'pre-{}'.format(tag)
rec_key = 'rec-{}'.format(tag)
evaluate_result[f_key] = f
evaluate_result[pre_key] = pre
evaluate_result[rec_key] = rec

if self.f_type == 'macro':
evaluate_result['f'] = f_sum / len(tags)
evaluate_result['pre'] = pre_sum / len(tags)
evaluate_result['rec'] = rec_sum / len(tags)

if self.f_type == 'micro':
f, pre, rec = _compute_f_pre_rec(self.beta_square,
sum(val.get_scalar() for val in self._true_positives.values()),
sum(val.get_scalar() for val in self._false_negatives.values()),
sum(val.get_scalar() for val in self._false_positives.values()))
evaluate_result['f'] = f
evaluate_result['pre'] = pre
evaluate_result['rec'] = rec

for key, value in evaluate_result.items():
evaluate_result[key] = round(value, 6)

return evaluate_result

def update(self, pred, target, seq_len: Optional[List] = None) -> None:
r"""update() accumulates the evaluation statistics for one batch of predictions.

:param pred: [batch, seq_len] or [batch, seq_len, len(tag_vocab)], the predictions
:param target: [batch, seq_len], the ground truth
:param seq_len: [batch], the text lengths
:return:
"""
pred = self.tensor2numpy(pred)
target = self.tensor2numpy(target)

if pred.ndim == target.ndim and target.ndim == 2:
pass

elif pred.ndim == target.ndim + 1 and target.ndim == 2:
num_classes = pred.shape[-1]
pred = pred.argmax(axis=-1)
if (target >= num_classes).any():
raise ValueError("A gold label passed to SpanBasedF1Metric contains an "
"id >= {}, the number of classes.".format(num_classes))
else:
raise RuntimeError(f"when pred has shape:{pred.shape}, target should have shape: {pred.shape} or "
f"{pred.shape[:-1]}, got {target.shape}.")

batch_size = pred.shape[0]
pred = pred.tolist()
target = target.tolist()
for i in range(batch_size):
pred_tags = pred[i][:int(seq_len[i])]
gold_tags = target[i][:int(seq_len[i])]

pred_str_tags = [self.tag_vocab.to_word(tag) for tag in pred_tags]
gold_str_tags = [self.tag_vocab.to_word(tag) for tag in gold_tags]

pred_spans = self.tag_to_span_func(pred_str_tags, ignore_labels=self.ignore_labels)
gold_spans = self.tag_to_span_func(gold_str_tags, ignore_labels=self.ignore_labels)

for span in pred_spans:
if span in gold_spans:
self._true_positives[span[0]] += 1
gold_spans.remove(span)
else:
self._false_positives[span[0]] += 1
for span in gold_spans:
self._false_negatives[span[0]] += 1
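
A minimal usage sketch with a BIO tag vocabulary (torch assumed; illustrative only):

import torch
from fastNLP.core.metrics import SpanFPreRecMetric
from fastNLP.core.vocabulary import Vocabulary

tag_vocab = Vocabulary(unknown=None, padding=None)
tag_vocab.add_word_lst(['O', 'B-per', 'I-per'])
metric = SpanFPreRecMetric(tag_vocab=tag_vocab, encoding_type='bio')
pred = torch.tensor([[1, 2, 0]])     # B-per I-per O -> span ('per', (0, 2))
target = torch.tensor([[1, 2, 0]])
metric.update(pred, target, seq_len=torch.tensor([3]))
print(metric.get_metric())           # {'f': 1.0, 'pre': 1.0, 'rec': 1.0}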

+91 -0 fastNLP/core/metrics/utils.py

@@ -0,0 +1,91 @@
__all__ = [
'func_post_proc'
]

from typing import Any
from functools import wraps
from fastNLP.envs.imports import _NEED_IMPORT_PADDLE
from fastNLP.envs.utils import _module_available

_IS_TORCHMETRICS_AVAILABLE = _module_available('torchmetrics')
if _IS_TORCHMETRICS_AVAILABLE:
from torchmetrics import Metric as torchmetrics_Metric

_IS_ALLENNLP_AVAILABLE = _module_available('allennlp')
if _IS_ALLENNLP_AVAILABLE:
from allennlp.training.metrics import Metric as allennlp_Metric

if _NEED_IMPORT_PADDLE:
from paddle.metric import Metric as paddle_Metric


def _is_torchmetrics_metric(metric: Any) -> bool:
"""
Check whether the given object is a torchmetrics metric.

:param metric:
:return:
"""
if _IS_TORCHMETRICS_AVAILABLE:
return isinstance(metric, torchmetrics_Metric)
else:
return False


def _is_allennlp_metric(metric: Any) -> bool:
"""
Check whether the given object is an allennlp metric.

:param metric:
:return:
"""
if _IS_ALLENNLP_AVAILABLE:
return isinstance(metric, allennlp_Metric)
else:
return False


def _is_paddle_metric(metric: Any) -> bool:
"""
Check whether the given object is a paddle metric.

:param metric:
:return:
"""
if _NEED_IMPORT_PADDLE:
return isinstance(metric, paddle_Metric)
else:
return False


def func_post_proc(metric: 'Metric', fn: callable, method_name: str) -> 'Metric':
"""
Wrap fn around the {method_name} method of the metric object, so that the return value of
metric.{method_name} is passed through fn before being returned. Note that the modification of
metric's {method_name} is in place.

:param metric: the metric object
:param fn: applied to the return value of metric's {method_name} method
:param method_name: name of the method to wrap, typically 'get_metric' or 'accumulate'
:return: metric
"""
assert hasattr(metric, method_name) and callable(getattr(metric, method_name)), \
f"Parameter `metric` must have a {method_name} function."
assert callable(fn), "Parameter `fn` must be callable."

func = getattr(metric, method_name)

@wraps(func)
def wrap_method(*args, **kwargs):
res = func(*args, **kwargs)
return fn(res)

wrap_method.__wrapped_by_func_post_proc__ = True
setattr(metric, method_name, wrap_method)
return metric


class AggregateMethodError(Exception):
def __init__(self, should_have_aggregate_method, only_warn=False):
super(AggregateMethodError, self).__init__(self)
self.should_have_aggregate_method = should_have_aggregate_method
self.only_warn = only_warn
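
A usage sketch for func_post_proc above, rescaling Accuracy's result after get_metric() (torch assumed; illustrative only):

import torch
from fastNLP.core.metrics import Accuracy, func_post_proc

metric = Accuracy()
func_post_proc(metric, lambda res: {'acc_percent': res['acc'] * 100}, 'get_metric')
metric.update(torch.tensor([0, 1, 1]), torch.tensor([0, 1, 0]))
print(metric.get_metric())  # roughly {'acc_percent': 66.6667}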
