@@ -206,7 +206,7 @@ class POS(API): | |||||
prediction[idx] = list(prediction[idx]) + ([0] * (max_length - len(prediction[idx]))) | prediction[idx] = list(prediction[idx]) + ([0] * (max_length - len(prediction[idx]))) | ||||
truth[idx] = list(truth[idx]) + ([0] * (max_length - len(truth[idx]))) | truth[idx] = list(truth[idx]) + ([0] * (max_length - len(truth[idx]))) | ||||
evaluator = SpanFPreRecMetric(tag_vocab=tag_vocab, pred="predict", target="truth", | evaluator = SpanFPreRecMetric(tag_vocab=tag_vocab, pred="predict", target="truth", | ||||
seq_lens="word_seq_origin_len") | |||||
seq_len="word_seq_origin_len") | |||||
evaluator({"predict": torch.Tensor(prediction), "word_seq_origin_len": torch.Tensor(seq_len)}, | evaluator({"predict": torch.Tensor(prediction), "word_seq_origin_len": torch.Tensor(seq_len)}, | ||||
{"truth": torch.Tensor(truth)}) | {"truth": torch.Tensor(truth)}) | ||||
test_result = evaluator.get_metric() | test_result = evaluator.get_metric() | ||||
@@ -300,15 +300,15 @@ class CWS(API): | |||||
pp(te_dataset) | pp(te_dataset) | ||||
from ..core.tester import Tester | from ..core.tester import Tester | ||||
from ..core.metrics import BMESF1PreRecMetric | |||||
from ..core.metrics import SpanFPreRecMetric | |||||
tester = Tester(data=te_dataset, model=cws_model, metrics=BMESF1PreRecMetric(target='target'), batch_size=64, | |||||
tester = Tester(data=te_dataset, model=cws_model, metrics=SpanFPreRecMetric(tag_proc.get_vocab()), batch_size=64, | |||||
verbose=0) | verbose=0) | ||||
eval_res = tester.test() | eval_res = tester.test() | ||||
f1 = eval_res['BMESF1PreRecMetric']['f'] | |||||
pre = eval_res['BMESF1PreRecMetric']['pre'] | |||||
rec = eval_res['BMESF1PreRecMetric']['rec'] | |||||
f1 = eval_res['SpanFPreRecMetric']['f'] | |||||
pre = eval_res['SpanFPreRecMetric']['pre'] | |||||
rec = eval_res['SpanFPreRecMetric']['rec'] | |||||
# print("f1:{:.2f}, pre:{:.2f}, rec:{:.2f}".format(f1, pre, rec)) | # print("f1:{:.2f}, pre:{:.2f}, rec:{:.2f}".format(f1, pre, rec)) | ||||
return {"F1": f1, "precision": pre, "recall": rec} | return {"F1": f1, "precision": pre, "recall": rec} | ||||
@@ -20,7 +20,7 @@ from .dataset import DataSet | |||||
from .field import FieldArray, Padder, AutoPadder, EngChar2DPadder | from .field import FieldArray, Padder, AutoPadder, EngChar2DPadder | ||||
from .instance import Instance | from .instance import Instance | ||||
from .losses import LossFunc, CrossEntropyLoss, L1Loss, BCELoss, NLLLoss, LossInForward | from .losses import LossFunc, CrossEntropyLoss, L1Loss, BCELoss, NLLLoss, LossInForward | ||||
from .metrics import AccuracyMetric, BMESF1PreRecMetric, SpanFPreRecMetric, SQuADMetric | |||||
from .metrics import AccuracyMetric, SpanFPreRecMetric, SQuADMetric | |||||
from .optimizer import Optimizer, SGD, Adam | from .optimizer import Optimizer, SGD, Adam | ||||
from .sampler import SequentialSampler, BucketSampler, RandomSampler, Sampler | from .sampler import SequentialSampler, BucketSampler, RandomSampler, Sampler | ||||
from .tester import Tester | from .tester import Tester | ||||
@@ -618,149 +618,6 @@ class SpanFPreRecMetric(MetricBase): | |||||
return f, pre, rec | return f, pre, rec | ||||
class BMESF1PreRecMetric(MetricBase):
    """
    Alias: :class:`fastNLP.BMESF1PreRecMetric` :class:`fastNLP.core.metrics.BMESF1PreRecMetric`

    Compute f1, precision and recall for sequences labelled with the BMES scheme.
    Predictions may contain illegal tag sequences such as "BS", so they are first
    repaired with the table below. ``cur_B`` means the current tag is B and
    ``next_B`` means the following tag is B. ``cur_B=S`` means "relabel the current
    predicted B as S"; ``next_M=B`` means "relabel the following predicted M as B".

    +-------+---------+----------+----------+---------+---------+
    |       | next_B  |  next_M  |  next_E  |  next_S |   end   |
    +=======+=========+==========+==========+=========+=========+
    | start |  legal  | next_M=B | next_E=S |  legal  |    --   |
    +-------+---------+----------+----------+---------+---------+
    | cur_B | cur_B=S |  legal   |  legal   | cur_B=S | cur_B=S |
    +-------+---------+----------+----------+---------+---------+
    | cur_M | cur_M=E |  legal   |  legal   | cur_M=E | cur_M=E |
    +-------+---------+----------+----------+---------+---------+
    | cur_E |  legal  | next_M=B | next_E=S |  legal  |  legal  |
    +-------+---------+----------+----------+---------+---------+
    | cur_S |  legal  | next_M=B | next_E=S |  legal  |  legal  |
    +-------+---------+----------+----------+---------+---------+

    Example: the prediction ``BSEMS`` is treated as ``SSSSS``.

    This metric does NOT check that ``target`` is a legal BMES sequence; the caller
    must guarantee it.

    Expected shapes: ``pred`` is (batch_size, max_len) or (batch_size, max_len, 4),
    ``target`` is (batch_size, max_len) and ``seq_len`` is (batch_size, ).

    The index of each of the four tags must be declared. Every predicted value that
    is not one of ``b_idx``, ``m_idx``, ``e_idx``, ``s_idx`` is treated as ``s_idx``.

    :param b_idx: int, tag index of the Begin label.
    :param m_idx: int, tag index of the Middle label.
    :param e_idx: int, tag index of the End label.
    :param s_idx: int, tag index of the Single label.
    :param pred: str, key used in evaluate() to fetch the prediction from the input dict.
        If None, 'pred' is used.
    :param target: str, key used in evaluate() to fetch the target from the input dict.
        If None, 'target' is used.
    :param seq_len: str, key used in evaluate() to fetch the sequence lengths from the
        input dict. If None, 'seq_len' is used.
    """

    # Row key of the pseudo "start of sequence" state. A string can never collide
    # with a numeric tag index, whatever indices the caller configures.
    _START = 'start'
    # Column of the pseudo "end of sequence" state in every row of the table.
    _END_COLUMN = 4

    def __init__(self, b_idx=0, m_idx=1, e_idx=2, s_idx=3, pred=None, target=None, seq_len=None):
        super().__init__()
        self._init_param_map(pred=pred, target=target, seq_len=seq_len)

        # Running counts accumulated over batches by evaluate().
        self.yt_wordnum = 0  # number of words in the targets
        self.yp_wordnum = 0  # number of words in the (repaired) predictions
        self.corr_num = 0  # number of target words predicted exactly

        self.b_idx = b_idx
        self.m_idx = m_idx
        self.e_idx = e_idx
        self.s_idx = s_idx

        # Position of each tag in a table row. The rows below are ordered
        # (next_B, next_M, next_E, next_S, end); going through this mapping instead
        # of indexing with the raw tag id keeps the lookup correct when the caller
        # uses indices other than the defaults 0..3.
        self._tag_column = {b_idx: 0, m_idx: 1, e_idx: 2, s_idx: 3}

        # The table from the class docstring. Each cell is (offset, new_tag):
        # offset 0 rewrites the current tag, offset 1 rewrites the next tag and
        # offset -1 means the transition is legal and nothing is rewritten.
        legal = (-1, -1)
        self._valida_matrix = {
            self._START: [legal, (1, b_idx), (1, s_idx), legal, legal],
            b_idx: [(0, s_idx), legal, legal, (0, s_idx), (0, s_idx)],
            m_idx: [(0, e_idx), legal, legal, (0, e_idx), (0, e_idx)],
            e_idx: [legal, (1, b_idx), (1, s_idx), legal, legal],
            s_idx: [legal, (1, b_idx), (1, s_idx), legal, legal],
        }

    def _validate_tags(self, tags):
        """
        Repair a predicted tag sequence so that it is a legal BMES sequence.

        :param torch.Tensor tags: 1-D tensor of shape [seq_len], must not be empty.
        :return: list of tags of the same length, rewritten according to the table
            in the class docstring.
        """
        assert len(tags) != 0
        assert isinstance(tags, torch.Tensor) and len(tags.size()) == 1

        # Anything that is not one of the four declared tags counts as S.
        known = self._tag_column
        padded_tags = [self._START]
        padded_tags.extend(tag if tag in known else self.s_idx for tag in tags.tolist())
        padded_tags.append(None)  # placeholder for "end"; recognised by position only

        last = len(padded_tags) - 1
        for idx in range(last):
            next_pos = idx + 1
            if next_pos == last:
                column = self._END_COLUMN
            else:
                column = known[padded_tags[next_pos]]
            offset, new_tag = self._valida_matrix[padded_tags[idx]][column]
            if offset != -1:
                # offset 1 only occurs in the M/E columns and offset 0 never occurs
                # in the start row, so neither sentinel is ever overwritten.
                padded_tags[idx + offset] = new_tag
        return padded_tags[1:-1]

    def evaluate(self, pred, target, seq_len):
        """
        Accumulate the word counts of one batch of predictions.

        :param pred: torch.Tensor, [batch, seq_len] or [batch, seq_len, 4]
        :param target: torch.Tensor, [batch, seq_len]
        :param seq_len: torch.Tensor, [batch]
        :return: None
        :raises TypeError: if any argument is not a torch.Tensor.
        :raises RuntimeError: if the shapes of ``pred`` and ``target`` do not match
            one of the accepted combinations.
        """
        if not isinstance(pred, torch.Tensor):
            raise TypeError(f"`pred` in {_get_func_signature(self.evaluate)} must be torch.Tensor,"
                            f"got {type(pred)}.")
        if not isinstance(target, torch.Tensor):
            raise TypeError(f"`target` in {_get_func_signature(self.evaluate)} must be torch.Tensor,"
                            f"got {type(target)}.")
        if not isinstance(seq_len, torch.Tensor):
            raise TypeError(f"`seq_len` in {_get_func_signature(self.evaluate)} must be torch.Tensor,"
                            f"got {type(seq_len)}.")

        if pred.size() == target.size() and len(target.size()) == 2:
            pass
        elif len(pred.size()) == len(target.size()) + 1 and len(target.size()) == 2:
            # Per-tag scores were given: keep the best-scoring tag of every position.
            pred = pred.argmax(dim=-1)
        else:
            raise RuntimeError(f"In {_get_func_signature(self.evaluate)}, when pred have "
                               f"size:{pred.size()}, target should have size: {pred.size()} or "
                               f"{pred.size()[:-1]}, got {target.size()}.")

        # A word ends on every S or E tag.
        word_end_tags = (self.s_idx, self.e_idx)
        for idx in range(len(pred)):
            length = int(seq_len[idx])
            if length == 0:
                # An empty sequence contains no words and cannot be validated.
                continue
            target_tags = target[idx][:length].tolist()
            pred_tags = self._validate_tags(pred[idx][:length])
            start_idx = 0  # first position of the target word currently being read
            for t_idx, (t_tag, p_tag) in enumerate(zip(target_tags, pred_tags)):
                if t_tag in word_end_tags:
                    self.yt_wordnum += 1
                    # The word is correct only if every tag inside it matches.
                    if target_tags[start_idx:t_idx + 1] == pred_tags[start_idx:t_idx + 1]:
                        self.corr_num += 1
                    start_idx = t_idx + 1
                if p_tag in word_end_tags:
                    self.yp_wordnum += 1

    def get_metric(self, reset=True):
        """
        Compute the final scores from the counts accumulated by evaluate().

        :param bool reset: if True, the accumulated counts are cleared afterwards.
        :return: dict with keys 'f', 'pre' and 'rec', each rounded to 6 decimals.
        """
        # The small epsilon avoids a division by zero when nothing was evaluated.
        P = self.corr_num / (self.yp_wordnum + 1e-12)
        R = self.corr_num / (self.yt_wordnum + 1e-12)
        F = 2 * P * R / (P + R + 1e-12)
        evaluate_result = {'f': round(F, 6), 'pre': round(P, 6), 'rec': round(R, 6)}
        if reset:
            self.yp_wordnum = 0
            self.yt_wordnum = 0
            self.corr_num = 0
        return evaluate_result
def _prepare_metrics(metrics): | def _prepare_metrics(metrics): | ||||
""" | """ | ||||
@@ -4,12 +4,12 @@ from torch import nn | |||||
from ..utils import initial_parameter | from ..utils import initial_parameter | ||||
def allowed_transitions(id2label, encoding_type='bio', include_start_end=True): | |||||
def allowed_transitions(id2target, encoding_type='bio', include_start_end=True): | |||||
""" | """ | ||||
给定一个id到label的映射表,返回所有可以跳转的(from_tag_id, to_tag_id)列表。 | 给定一个id到label的映射表,返回所有可以跳转的(from_tag_id, to_tag_id)列表。 | ||||
:param dict id2label: key是label的indices,value是str类型的tag或tag-label。value可以是只有tag的, 比如"B", "M"; 也可以是 | |||||
"B-NN", "M-NN", tag和label之间一定要用"-"隔开。一般可以通过Vocabulary.get_id2word()得到id2label。 | |||||
:param dict id2target: key是label的indices,value是str类型的tag或tag-label。value可以是只有tag的, 比如"B", "M"; 也可以是 | |||||
"B-NN", "M-NN", tag和label之间一定要用"-"隔开。一般可以通过Vocabulary.idx2word得到id2label。 | |||||
:param str encoding_type: 支持"bio", "bmes", "bmeso"。 | :param str encoding_type: 支持"bio", "bmes", "bmeso"。 | ||||
:param bool include_start_end: 是否包含开始与结尾的转换。比如在bio中,b/o可以在开头,但是i不能在开头; | :param bool include_start_end: 是否包含开始与结尾的转换。比如在bio中,b/o可以在开头,但是i不能在开头; | ||||
为True,返回的结果中会包含(start_idx, b_idx), (start_idx, o_idx), 但是不包含(start_idx, i_idx); | 为True,返回的结果中会包含(start_idx, b_idx), (start_idx, o_idx), 但是不包含(start_idx, i_idx); | ||||
@@ -17,12 +17,12 @@ def allowed_transitions(id2label, encoding_type='bio', include_start_end=True): | |||||
为False, 返回的结果中不含与开始结尾相关的内容 | 为False, 返回的结果中不含与开始结尾相关的内容 | ||||
:return: List[Tuple(int, int)]], 内部的Tuple是可以进行跳转的(from_tag_id, to_tag_id)。 | :return: List[Tuple(int, int)]], 内部的Tuple是可以进行跳转的(from_tag_id, to_tag_id)。 | ||||
""" | """ | ||||
num_tags = len(id2label) | |||||
num_tags = len(id2target) | |||||
start_idx = num_tags | start_idx = num_tags | ||||
end_idx = num_tags + 1 | end_idx = num_tags + 1 | ||||
encoding_type = encoding_type.lower() | encoding_type = encoding_type.lower() | ||||
allowed_trans = [] | allowed_trans = [] | ||||
id_label_lst = list(id2label.items()) | |||||
id_label_lst = list(id2target.items()) | |||||
if include_start_end: | if include_start_end: | ||||
id_label_lst += [(start_idx, 'start'), (end_idx, 'end')] | id_label_lst += [(start_idx, 'start'), (end_idx, 'end')] | ||||
def split_tag_label(from_label): | def split_tag_label(from_label): | ||||
@@ -160,7 +160,7 @@ class ConditionalRandomField(nn.Module): | |||||
if allowed_transitions is None: | if allowed_transitions is None: | ||||
constrain = torch.zeros(num_tags + 2, num_tags + 2) | constrain = torch.zeros(num_tags + 2, num_tags + 2) | ||||
else: | else: | ||||
constrain = torch.new_full((num_tags+2, num_tags+2), fill_value=-10000.0, dtype=torch.float) | |||||
constrain = torch.full((num_tags+2, num_tags+2), fill_value=-10000.0, dtype=torch.float) | |||||
for from_tag_id, to_tag_id in allowed_transitions: | for from_tag_id, to_tag_id in allowed_transitions: | ||||
constrain[from_tag_id, to_tag_id] = 0 | constrain[from_tag_id, to_tag_id] = 0 | ||||
self._constrain = nn.Parameter(constrain, requires_grad=False) | self._constrain = nn.Parameter(constrain, requires_grad=False) | ||||
@@ -351,31 +351,6 @@ class SpanF1PreRecMetric(unittest.TestCase): | |||||
# fastnlp_bmes_metric.get_metric()) | # fastnlp_bmes_metric.get_metric()) | ||||
class TestBMESF1PreRecMetric(unittest.TestCase):
    """Tests for ``BMESF1PreRecMetric`` on a two-sentence batch."""

    def test_case1(self):
        """Random per-tag scores of shape (2, 4, 4) are accepted without raising."""
        lengths = torch.LongTensor([4, 2])
        scores = torch.randn(2, 4, 4)
        gold = torch.LongTensor([[0, 1, 2, 3],
                                 [3, 3, 0, 0]])

        metric = BMESF1PreRecMetric()
        metric({'pred': scores}, {'target': gold, 'seq_len': lengths})
        metric.get_metric()

    def test_case2(self):
        """Identical prediction and target must give f1, precision and recall of 1."""
        lengths = torch.LongTensor([4, 2])
        gold = torch.LongTensor([[0, 1, 2, 3],
                                 [3, 3, 0, 0]])

        metric = BMESF1PreRecMetric()
        metric({'pred': gold}, {'target': gold, 'seq_len': lengths})

        expected = {'f': 1.0, 'pre': 1.0, 'rec': 1.0}
        self.assertDictEqual(metric.get_metric(), expected)
class TestUsefulFunctions(unittest.TestCase): | class TestUsefulFunctions(unittest.TestCase): | ||||
# 测试metrics.py中一些看上去挺有用的函数 | # 测试metrics.py中一些看上去挺有用的函数 | ||||