diff --git a/fastNLP/api/api.py b/fastNLP/api/api.py index 351b210d..2e7cbfcf 100644 --- a/fastNLP/api/api.py +++ b/fastNLP/api/api.py @@ -206,7 +206,7 @@ class POS(API): prediction[idx] = list(prediction[idx]) + ([0] * (max_length - len(prediction[idx]))) truth[idx] = list(truth[idx]) + ([0] * (max_length - len(truth[idx]))) evaluator = SpanFPreRecMetric(tag_vocab=tag_vocab, pred="predict", target="truth", - seq_lens="word_seq_origin_len") + seq_len="word_seq_origin_len") evaluator({"predict": torch.Tensor(prediction), "word_seq_origin_len": torch.Tensor(seq_len)}, {"truth": torch.Tensor(truth)}) test_result = evaluator.get_metric() @@ -300,15 +300,15 @@ class CWS(API): pp(te_dataset) from ..core.tester import Tester - from ..core.metrics import BMESF1PreRecMetric + from ..core.metrics import SpanFPreRecMetric - tester = Tester(data=te_dataset, model=cws_model, metrics=BMESF1PreRecMetric(target='target'), batch_size=64, + tester = Tester(data=te_dataset, model=cws_model, metrics=SpanFPreRecMetric(tag_proc.get_vocab()), batch_size=64, verbose=0) eval_res = tester.test() - f1 = eval_res['BMESF1PreRecMetric']['f'] - pre = eval_res['BMESF1PreRecMetric']['pre'] - rec = eval_res['BMESF1PreRecMetric']['rec'] + f1 = eval_res['SpanFPreRecMetric']['f'] + pre = eval_res['SpanFPreRecMetric']['pre'] + rec = eval_res['SpanFPreRecMetric']['rec'] # print("f1:{:.2f}, pre:{:.2f}, rec:{:.2f}".format(f1, pre, rec)) return {"F1": f1, "precision": pre, "recall": rec} diff --git a/fastNLP/core/__init__.py b/fastNLP/core/__init__.py index 52e716f6..3c5b3f42 100644 --- a/fastNLP/core/__init__.py +++ b/fastNLP/core/__init__.py @@ -20,7 +20,7 @@ from .dataset import DataSet from .field import FieldArray, Padder, AutoPadder, EngChar2DPadder from .instance import Instance from .losses import LossFunc, CrossEntropyLoss, L1Loss, BCELoss, NLLLoss, LossInForward -from .metrics import AccuracyMetric, BMESF1PreRecMetric, SpanFPreRecMetric, SQuADMetric +from .metrics import AccuracyMetric, SpanFPreRecMetric, SQuADMetric from .optimizer import Optimizer, SGD, Adam from .sampler import SequentialSampler, BucketSampler, RandomSampler, Sampler from .tester import Tester diff --git a/fastNLP/core/metrics.py b/fastNLP/core/metrics.py index 5cdfdc44..0354e7cc 100644 --- a/fastNLP/core/metrics.py +++ b/fastNLP/core/metrics.py @@ -618,149 +618,6 @@ class SpanFPreRecMetric(MetricBase): return f, pre, rec -class BMESF1PreRecMetric(MetricBase): - """ - 别名::class:`fastNLP.BMESF1PreRecMetric` :class:`fastNLP.core.metrics.BMESF1PreRecMetric` - - 按照BMES标注方式计算f1, precision, recall。由于可能存在非法tag,比如"BS",所以需要用以下的表格做转换,cur_B意思是当前tag是B, - next_B意思是后一个tag是B。则cur_B=S,即将当前被predict是B的tag标为S;next_M=B, 即将后一个被predict是M的tag标为B - - +-------+---------+----------+----------+---------+---------+ - | | next_B | next_M | next_E | next_S | end | - +=======+=========+==========+==========+=========+=========+ - | start | 合法 | next_M=B | next_E=S | 合法 | -- | - +-------+---------+----------+----------+---------+---------+ - | cur_B | cur_B=S | 合法 | 合法 | cur_B=S | cur_B=S | - +-------+---------+----------+----------+---------+---------+ - | cur_M | cur_M=E | 合法 | 合法 | cur_M=E | cur_M=E | - +-------+---------+----------+----------+---------+---------+ - | cur_E | 合法 | next_M=B | next_E=S | 合法 | 合法 | - +-------+---------+----------+----------+---------+---------+ - | cur_S | 合法 | next_M=B | next_E=S | 合法 | 合法 | - +-------+---------+----------+----------+---------+---------+ - - 举例: - prediction为BSEMS,会被认为是SSSSS. - - 本Metric不检验target的合法性,请务必保证target的合法性。 - pred的形状应该为(batch_size, max_len) 或 (batch_size, max_len, 4)。 - target形状为 (batch_size, max_len) - seq_lens形状为 (batch_size, ) - - 需要申明BMES这四种tag中,各种tag对应的idx。所有不为b_idx, m_idx, e_idx, s_idx的数字都认为是s_idx。 - - :param b_idx: int, Begin标签所对应的tag idx. - :param m_idx: int, Middle标签所对应的tag idx. - :param e_idx: int, End标签所对应的tag idx. - :param s_idx: int, Single标签所对应的tag idx - :param pred: str, 用该key在evaluate()时从传入dict中取出prediction数据。 为None,则使用'pred'取数据 - :param target: str, 用该key在evaluate()时从传入dict中取出target数据。 为None,则使用'target'取数据 - :param seq_len: str, 用该key在evaluate()时从传入dict中取出seqence length数据。为None,则使用'seq_len'取数据。 - """ - - def __init__(self, b_idx=0, m_idx=1, e_idx=2, s_idx=3, pred=None, target=None, seq_len=None): - super().__init__() - - self._init_param_map(pred=pred, target=target, seq_len=seq_len) - - self.yt_wordnum = 0 - self.yp_wordnum = 0 - self.corr_num = 0 - - self.b_idx = b_idx - self.m_idx = m_idx - self.e_idx = e_idx - self.s_idx = s_idx - # 还原init处介绍的矩阵 - self._valida_matrix = { - -1: [(-1, -1), (1, self.b_idx), (1, self.s_idx), (-1, -1)], # magic start idx - self.b_idx:[(0, self.s_idx), (-1, -1), (-1, -1), (0, self.s_idx), (0, self.s_idx)], - self.m_idx:[(0, self.e_idx), (-1, -1), (-1, -1), (0, self.e_idx), (0, self.e_idx)], - self.e_idx:[(-1, -1), (1, self.b_idx), (1, self.s_idx), (-1, -1), (-1, -1)], - self.s_idx:[(-1, -1), (1, self.b_idx), (1, self.s_idx), (-1, -1), (-1, -1)], - } - - def _validate_tags(self, tags): - """ - 给定一个tag的Tensor,返回合法tag - - :param torch.Tensor tags: [seq_len] - :return: 返回修改为合法tag的list - """ - assert len(tags)!=0 - assert isinstance(tags, torch.Tensor) and len(tags.size())==1 - padded_tags = [-1, *tags.tolist(), -1] - for idx in range(len(padded_tags)-1): - cur_tag = padded_tags[idx] - if cur_tag not in self._valida_matrix: - cur_tag = self.s_idx - if padded_tags[idx+1] not in self._valida_matrix: - padded_tags[idx+1] = self.s_idx - next_tag = padded_tags[idx+1] - shift_tag = self._valida_matrix[cur_tag][next_tag] - if shift_tag[0]!=-1: - padded_tags[idx+shift_tag[0]] = shift_tag[1] - - return padded_tags[1:-1] - - def evaluate(self, pred, target, seq_len): - """evaluate函数将针对一个批次的预测结果做评价指标的累计 - - :param pred: [batch, seq_len] 或者 [batch, seq_len, 4] - :param target: [batch, seq_len] - :param seq_len: [batch] - :return: - """ - if not isinstance(pred, torch.Tensor): - raise TypeError(f"`pred` in {_get_func_signature(self.evaluate)} must be torch.Tensor," - f"got {type(pred)}.") - if not isinstance(target, torch.Tensor): - raise TypeError(f"`target` in {_get_func_signature(self.evaluate)} must be torch.Tensor," - f"got {type(target)}.") - - if not isinstance(seq_len, torch.Tensor): - raise TypeError(f"`seq_lens` in {_get_func_signature(self.evaluate)} must be torch.Tensor," - f"got {type(seq_len)}.") - - if pred.size() == target.size() and len(target.size()) == 2: - pass - elif len(pred.size()) == len(target.size()) + 1 and len(target.size()) == 2: - pred = pred.argmax(dim=-1) - else: - raise RuntimeError(f"In {_get_func_signature(self.evaluate)}, when pred have " - f"size:{pred.size()}, target should have size: {pred.size()} or " - f"{pred.size()[:-1]}, got {target.size()}.") - - for idx in range(len(pred)): - target_tags = target[idx][:seq_len[idx]].tolist() - pred_tags = pred[idx][:seq_len[idx]] - pred_tags = self._validate_tags(pred_tags) - start_idx = 0 - for t_idx, (t_tag, p_tag) in enumerate(zip(target_tags, pred_tags)): - if t_tag in (self.s_idx, self.e_idx): - self.yt_wordnum += 1 - corr_flag = True - for i in range(start_idx, t_idx+1): - if target_tags[i]!=pred_tags[i]: - corr_flag = False - if corr_flag: - self.corr_num += 1 - start_idx = t_idx + 1 - if p_tag in (self.s_idx, self.e_idx): - self.yp_wordnum += 1 - - def get_metric(self, reset=True): - """get_metric函数将根据evaluate函数累计的评价指标统计量来计算最终的评价结果.""" - P = self.corr_num / (self.yp_wordnum + 1e-12) - R = self.corr_num / (self.yt_wordnum + 1e-12) - F = 2 * P * R / (P + R + 1e-12) - evaluate_result = {'f': round(F, 6), 'pre':round(P, 6), 'rec': round(R, 6)} - if reset: - self.yp_wordnum = 0 - self.yt_wordnum = 0 - self.corr_num = 0 - return evaluate_result - def _prepare_metrics(metrics): """ diff --git a/fastNLP/modules/decoder/CRF.py b/fastNLP/modules/decoder/CRF.py index 0d8ec25a..2c9080b2 100644 --- a/fastNLP/modules/decoder/CRF.py +++ b/fastNLP/modules/decoder/CRF.py @@ -4,12 +4,12 @@ from torch import nn from ..utils import initial_parameter -def allowed_transitions(id2label, encoding_type='bio', include_start_end=True): +def allowed_transitions(id2target, encoding_type='bio', include_start_end=True): """ 给定一个id到label的映射表,返回所有可以跳转的(from_tag_id, to_tag_id)列表。 - :param dict id2label: key是label的indices,value是str类型的tag或tag-label。value可以是只有tag的, 比如"B", "M"; 也可以是 - "B-NN", "M-NN", tag和label之间一定要用"-"隔开。一般可以通过Vocabulary.get_id2word()得到id2label。 + :param dict id2target: key是label的indices,value是str类型的tag或tag-label。value可以是只有tag的, 比如"B", "M"; 也可以是 + "B-NN", "M-NN", tag和label之间一定要用"-"隔开。一般可以通过Vocabulary.idx2word得到id2label。 :param str encoding_type: 支持"bio", "bmes", "bmeso"。 :param bool include_start_end: 是否包含开始与结尾的转换。比如在bio中,b/o可以在开头,但是i不能在开头; 为True,返回的结果中会包含(start_idx, b_idx), (start_idx, o_idx), 但是不包含(start_idx, i_idx); @@ -17,12 +17,12 @@ def allowed_transitions(id2label, encoding_type='bio', include_start_end=True): 为False, 返回的结果中不含与开始结尾相关的内容 :return: List[Tuple(int, int)]], 内部的Tuple是可以进行跳转的(from_tag_id, to_tag_id)。 """ - num_tags = len(id2label) + num_tags = len(id2target) start_idx = num_tags end_idx = num_tags + 1 encoding_type = encoding_type.lower() allowed_trans = [] - id_label_lst = list(id2label.items()) + id_label_lst = list(id2target.items()) if include_start_end: id_label_lst += [(start_idx, 'start'), (end_idx, 'end')] def split_tag_label(from_label): @@ -160,7 +160,7 @@ class ConditionalRandomField(nn.Module): if allowed_transitions is None: constrain = torch.zeros(num_tags + 2, num_tags + 2) else: - constrain = torch.new_full((num_tags+2, num_tags+2), fill_value=-10000.0, dtype=torch.float) + constrain = torch.full((num_tags+2, num_tags+2), fill_value=-10000.0, dtype=torch.float) for from_tag_id, to_tag_id in allowed_transitions: constrain[from_tag_id, to_tag_id] = 0 self._constrain = nn.Parameter(constrain, requires_grad=False) diff --git a/test/core/test_metrics.py b/test/core/test_metrics.py index babfb4aa..db508e39 100644 --- a/test/core/test_metrics.py +++ b/test/core/test_metrics.py @@ -351,31 +351,6 @@ class SpanF1PreRecMetric(unittest.TestCase): # fastnlp_bmes_metric.get_metric()) -class TestBMESF1PreRecMetric(unittest.TestCase): - def test_case1(self): - seq_lens = torch.LongTensor([4, 2]) - pred = torch.randn(2, 4, 4) - target = torch.LongTensor([[0, 1, 2, 3], - [3, 3, 0, 0]]) - pred_dict = {'pred': pred} - target_dict = {'target': target, 'seq_len': seq_lens} - - metric = BMESF1PreRecMetric() - metric(pred_dict, target_dict) - metric.get_metric() - - def test_case2(self): - # 测试相同两个seqence,应该给出{f1: 1, precision:1, recall:1} - seq_lens = torch.LongTensor([4, 2]) - target = torch.LongTensor([[0, 1, 2, 3], - [3, 3, 0, 0]]) - pred_dict = {'pred': target} - target_dict = {'target': target, 'seq_len': seq_lens} - - metric = BMESF1PreRecMetric() - metric(pred_dict, target_dict) - self.assertDictEqual(metric.get_metric(), {'f': 1.0, 'pre': 1.0, 'rec': 1.0}) - class TestUsefulFunctions(unittest.TestCase): # 测试metrics.py中一些看上去挺有用的函数