|
|
@@ -618,149 +618,6 @@ class SpanFPreRecMetric(MetricBase): |
|
|
|
return f, pre, rec |
|
|
|
|
|
|
|
|
|
|
|
class BMESF1PreRecMetric(MetricBase): |
|
|
|
""" |
|
|
|
别名::class:`fastNLP.BMESF1PreRecMetric` :class:`fastNLP.core.metrics.BMESF1PreRecMetric` |
|
|
|
|
|
|
|
按照BMES标注方式计算f1, precision, recall。由于可能存在非法tag,比如"BS",所以需要用以下的表格做转换,cur_B意思是当前tag是B, |
|
|
|
next_B意思是后一个tag是B。则cur_B=S,即将当前被predict是B的tag标为S;next_M=B, 即将后一个被predict是M的tag标为B |
|
|
|
|
|
|
|
+-------+---------+----------+----------+---------+---------+ |
|
|
|
| | next_B | next_M | next_E | next_S | end | |
|
|
|
+=======+=========+==========+==========+=========+=========+ |
|
|
|
| start | 合法 | next_M=B | next_E=S | 合法 | -- | |
|
|
|
+-------+---------+----------+----------+---------+---------+ |
|
|
|
| cur_B | cur_B=S | 合法 | 合法 | cur_B=S | cur_B=S | |
|
|
|
+-------+---------+----------+----------+---------+---------+ |
|
|
|
| cur_M | cur_M=E | 合法 | 合法 | cur_M=E | cur_M=E | |
|
|
|
+-------+---------+----------+----------+---------+---------+ |
|
|
|
| cur_E | 合法 | next_M=B | next_E=S | 合法 | 合法 | |
|
|
|
+-------+---------+----------+----------+---------+---------+ |
|
|
|
| cur_S | 合法 | next_M=B | next_E=S | 合法 | 合法 | |
|
|
|
+-------+---------+----------+----------+---------+---------+ |
|
|
|
|
|
|
|
举例: |
|
|
|
prediction为BSEMS,会被认为是SSSSS. |
|
|
|
|
|
|
|
本Metric不检验target的合法性,请务必保证target的合法性。 |
|
|
|
pred的形状应该为(batch_size, max_len) 或 (batch_size, max_len, 4)。 |
|
|
|
target形状为 (batch_size, max_len) |
|
|
|
seq_lens形状为 (batch_size, ) |
|
|
|
|
|
|
|
需要申明BMES这四种tag中,各种tag对应的idx。所有不为b_idx, m_idx, e_idx, s_idx的数字都认为是s_idx。 |
|
|
|
|
|
|
|
:param b_idx: int, Begin标签所对应的tag idx. |
|
|
|
:param m_idx: int, Middle标签所对应的tag idx. |
|
|
|
:param e_idx: int, End标签所对应的tag idx. |
|
|
|
:param s_idx: int, Single标签所对应的tag idx |
|
|
|
:param pred: str, 用该key在evaluate()时从传入dict中取出prediction数据。 为None,则使用'pred'取数据 |
|
|
|
:param target: str, 用该key在evaluate()时从传入dict中取出target数据。 为None,则使用'target'取数据 |
|
|
|
:param seq_len: str, 用该key在evaluate()时从传入dict中取出seqence length数据。为None,则使用'seq_len'取数据。 |
|
|
|
""" |
|
|
|
|
|
|
|
def __init__(self, b_idx=0, m_idx=1, e_idx=2, s_idx=3, pred=None, target=None, seq_len=None): |
|
|
|
super().__init__() |
|
|
|
|
|
|
|
self._init_param_map(pred=pred, target=target, seq_len=seq_len) |
|
|
|
|
|
|
|
self.yt_wordnum = 0 |
|
|
|
self.yp_wordnum = 0 |
|
|
|
self.corr_num = 0 |
|
|
|
|
|
|
|
self.b_idx = b_idx |
|
|
|
self.m_idx = m_idx |
|
|
|
self.e_idx = e_idx |
|
|
|
self.s_idx = s_idx |
|
|
|
# 还原init处介绍的矩阵 |
|
|
|
self._valida_matrix = { |
|
|
|
-1: [(-1, -1), (1, self.b_idx), (1, self.s_idx), (-1, -1)], # magic start idx |
|
|
|
self.b_idx:[(0, self.s_idx), (-1, -1), (-1, -1), (0, self.s_idx), (0, self.s_idx)], |
|
|
|
self.m_idx:[(0, self.e_idx), (-1, -1), (-1, -1), (0, self.e_idx), (0, self.e_idx)], |
|
|
|
self.e_idx:[(-1, -1), (1, self.b_idx), (1, self.s_idx), (-1, -1), (-1, -1)], |
|
|
|
self.s_idx:[(-1, -1), (1, self.b_idx), (1, self.s_idx), (-1, -1), (-1, -1)], |
|
|
|
} |
|
|
|
|
|
|
|
def _validate_tags(self, tags): |
|
|
|
""" |
|
|
|
给定一个tag的Tensor,返回合法tag |
|
|
|
|
|
|
|
:param torch.Tensor tags: [seq_len] |
|
|
|
:return: 返回修改为合法tag的list |
|
|
|
""" |
|
|
|
assert len(tags)!=0 |
|
|
|
assert isinstance(tags, torch.Tensor) and len(tags.size())==1 |
|
|
|
padded_tags = [-1, *tags.tolist(), -1] |
|
|
|
for idx in range(len(padded_tags)-1): |
|
|
|
cur_tag = padded_tags[idx] |
|
|
|
if cur_tag not in self._valida_matrix: |
|
|
|
cur_tag = self.s_idx |
|
|
|
if padded_tags[idx+1] not in self._valida_matrix: |
|
|
|
padded_tags[idx+1] = self.s_idx |
|
|
|
next_tag = padded_tags[idx+1] |
|
|
|
shift_tag = self._valida_matrix[cur_tag][next_tag] |
|
|
|
if shift_tag[0]!=-1: |
|
|
|
padded_tags[idx+shift_tag[0]] = shift_tag[1] |
|
|
|
|
|
|
|
return padded_tags[1:-1] |
|
|
|
|
|
|
|
def evaluate(self, pred, target, seq_len): |
|
|
|
"""evaluate函数将针对一个批次的预测结果做评价指标的累计 |
|
|
|
|
|
|
|
:param pred: [batch, seq_len] 或者 [batch, seq_len, 4] |
|
|
|
:param target: [batch, seq_len] |
|
|
|
:param seq_len: [batch] |
|
|
|
:return: |
|
|
|
""" |
|
|
|
if not isinstance(pred, torch.Tensor): |
|
|
|
raise TypeError(f"`pred` in {_get_func_signature(self.evaluate)} must be torch.Tensor," |
|
|
|
f"got {type(pred)}.") |
|
|
|
if not isinstance(target, torch.Tensor): |
|
|
|
raise TypeError(f"`target` in {_get_func_signature(self.evaluate)} must be torch.Tensor," |
|
|
|
f"got {type(target)}.") |
|
|
|
|
|
|
|
if not isinstance(seq_len, torch.Tensor): |
|
|
|
raise TypeError(f"`seq_lens` in {_get_func_signature(self.evaluate)} must be torch.Tensor," |
|
|
|
f"got {type(seq_len)}.") |
|
|
|
|
|
|
|
if pred.size() == target.size() and len(target.size()) == 2: |
|
|
|
pass |
|
|
|
elif len(pred.size()) == len(target.size()) + 1 and len(target.size()) == 2: |
|
|
|
pred = pred.argmax(dim=-1) |
|
|
|
else: |
|
|
|
raise RuntimeError(f"In {_get_func_signature(self.evaluate)}, when pred have " |
|
|
|
f"size:{pred.size()}, target should have size: {pred.size()} or " |
|
|
|
f"{pred.size()[:-1]}, got {target.size()}.") |
|
|
|
|
|
|
|
for idx in range(len(pred)): |
|
|
|
target_tags = target[idx][:seq_len[idx]].tolist() |
|
|
|
pred_tags = pred[idx][:seq_len[idx]] |
|
|
|
pred_tags = self._validate_tags(pred_tags) |
|
|
|
start_idx = 0 |
|
|
|
for t_idx, (t_tag, p_tag) in enumerate(zip(target_tags, pred_tags)): |
|
|
|
if t_tag in (self.s_idx, self.e_idx): |
|
|
|
self.yt_wordnum += 1 |
|
|
|
corr_flag = True |
|
|
|
for i in range(start_idx, t_idx+1): |
|
|
|
if target_tags[i]!=pred_tags[i]: |
|
|
|
corr_flag = False |
|
|
|
if corr_flag: |
|
|
|
self.corr_num += 1 |
|
|
|
start_idx = t_idx + 1 |
|
|
|
if p_tag in (self.s_idx, self.e_idx): |
|
|
|
self.yp_wordnum += 1 |
|
|
|
|
|
|
|
def get_metric(self, reset=True): |
|
|
|
"""get_metric函数将根据evaluate函数累计的评价指标统计量来计算最终的评价结果.""" |
|
|
|
P = self.corr_num / (self.yp_wordnum + 1e-12) |
|
|
|
R = self.corr_num / (self.yt_wordnum + 1e-12) |
|
|
|
F = 2 * P * R / (P + R + 1e-12) |
|
|
|
evaluate_result = {'f': round(F, 6), 'pre':round(P, 6), 'rec': round(R, 6)} |
|
|
|
if reset: |
|
|
|
self.yp_wordnum = 0 |
|
|
|
self.yt_wordnum = 0 |
|
|
|
self.corr_num = 0 |
|
|
|
return evaluate_result |
|
|
|
|
|
|
|
|
|
|
|
def _prepare_metrics(metrics): |
|
|
|
""" |
|
|
|