From 55783314bc1380104e4930e09f089c7ea9b86fa9 Mon Sep 17 00:00:00 2001 From: xuyige Date: Tue, 23 Apr 2019 23:07:59 +0800 Subject: [PATCH] update documents on metrics --- fastNLP/core/metrics.py | 240 +++++++++++++++++++++----------------- test/core/test_metrics.py | 22 ++-- 2 files changed, 147 insertions(+), 115 deletions(-) diff --git a/fastNLP/core/metrics.py b/fastNLP/core/metrics.py index 314be0d9..71ca2926 100644 --- a/fastNLP/core/metrics.py +++ b/fastNLP/core/metrics.py @@ -14,7 +14,7 @@ from fastNLP.core.vocabulary import Vocabulary class MetricBase(object): - """Base class for all metrics. + """所有metrics的基类 所有的传入到Trainer, Tester的Metric需要继承自该对象。需要覆盖写入evaluate(), get_metric()方法。 @@ -85,18 +85,22 @@ class MetricBase(object): return {'acc': acc} # 需要返回一个dict,key为该metric的名称,该名称会显示到Trainer的progress bar中 - ``MetricBase`` handles validity check of its input dictionaries - ``pred_dict`` and ``target_dict``. - ``pred_dict`` is the output of ``forward()`` or prediction function of a model. - ``target_dict`` is the ground truth from DataSet where ``is_target`` is set ``True``. - ``MetricBase`` will do the following type checks: + ``MetricBase`` 将会在输入的字典``pred_dict``和``target_dict``中进行检查. + ``pred_dict`` 是模型当中``forward()``函数或者``predict()``函数的返回值. + ``target_dict`` 是DataSet当中的ground truth, 判定ground truth的条件是field的``is_target``被设置为True. - 1. whether self.evaluate has varargs, which is not supported. - 2. whether params needed by self.evaluate is not included in ``pred_dict``, ``target_dict``. - 3. whether params needed by self.evaluate duplicate in ``pred_dict``, ``target_dict``. + ``MetricBase`` 会进行以下的类型检测: - Besides, before passing params into self.evaluate, this function will filter out params from output_dict and - target_dict which are not used in self.evaluate. (but if kwargs presented in self.evaluate, no filtering - will be conducted.) + 1. self.evaluate当中是否有varargs, 这是不支持的. + 2. self.evaluate当中所需要的参数是否既不在``pred_dict``也不在``target_dict``. + 3. self.evaluate当中所需要的参数是否既在``pred_dict``也在``target_dict``. + + 除此以外,在参数被传入self.evaluate以前,这个函数会检测``pred_dict``和``target_dict``当中没有被用到的参数 + 如果kwargs是self.evaluate的参数,则不会检测 + + + self.evaluate将计算一个批次(batch)的评价指标,并累计 + self.get_metric将统计当前的评价指标并返回评价结果 """ def __init__(self): @@ -107,10 +111,10 @@ class MetricBase(object): raise NotImplementedError def _init_param_map(self, key_map=None, **kwargs): - """Check the validity of key_map and other param map. Add these into self.param_map + """检查key_map和其他参数map,并将这些映射关系添加到self.param_map - :param key_map: dict - :param kwargs: + :param dict key_map: 表示key的映射关系 + :param kwargs: key word args里面的每一个的键-值对都会被构造成映射关系 :return: None """ value_counter = defaultdict(set) @@ -153,17 +157,16 @@ class MetricBase(object): def __call__(self, pred_dict, target_dict): """ - - This method will call self.evaluate method. - Before calling self.evaluate, it will first check the validity of output_dict, target_dict - (1) whether params needed by self.evaluate is not included in output_dict,target_dict. - (2) whether params needed by self.evaluate duplicate in pred_dict, target_dict - (3) whether params in output_dict, target_dict are not used by evaluate.(Might cause warning) - Besides, before passing params into self.evaluate, this function will filter out params from output_dict and - target_dict which are not used in self.evaluate. (but if **kwargs presented in self.evaluate, no filtering - will be conducted.) - :param pred_dict: usually the output of forward or prediction function - :param target_dict: usually features set as target.. + 这个方法会调用self.evaluate 方法. + 在调用之前,会进行以下检测: + 1. self.evaluate当中是否有varargs, 这是不支持的. + 2. self.evaluate当中所需要的参数是否既不在``pred_dict``也不在``target_dict``. + 3. self.evaluate当中所需要的参数是否既在``pred_dict``也在``target_dict``. + + 除此以外,在参数被传入self.evaluate以前,这个函数会检测``pred_dict``和``target_dict``当中没有被用到的参数 + 如果kwargs是self.evaluate的参数,则不会检测 + :param pred_dict: 模型的forward函数或者predict函数返回的dict + :param target_dict: DataSet.batch_y里的键-值对所组成的dict(即is_target=True的fields的内容) :return: """ if not callable(self.evaluate): @@ -235,25 +238,29 @@ class MetricBase(object): class AccuracyMetric(MetricBase): - """Accuracy Metric - - """ - def __init__(self, pred=None, target=None, seq_lens=None): + """准确率Metric""" + def __init__(self, pred=None, target=None, seq_len=None): + """ + :param pred: 参数映射表中`pred`的映射关系,None表示映射关系为`pred`->`pred` + :param target: 参数映射表中`target`的映射关系,None表示映射关系为`target`->`target` + :param seq_len: 参数映射表中`seq_lens`的映射关系,None表示映射关系为`seq_len`->`seq_len` + """ super().__init__() - self._init_param_map(pred=pred, target=target, seq_lens=seq_lens) + self._init_param_map(pred=pred, target=target, seq_len=seq_len) self.total = 0 self.acc_count = 0 - def evaluate(self, pred, target, seq_lens=None): - """ + def evaluate(self, pred, target, seq_len=None): + """evaluate函数将针对一个批次的预测结果做评价指标的累计 - :param pred: . Element's shape can be: torch.Size([B,]), torch.Size([B, n_classes]), torch.Size([B, max_len]), - torch.Size([B, max_len, n_classes]) - :param target: Element's can be: torch.Size([B,]), torch.Size([B,]), torch.Size([B, max_len]), - torch.Size([B, max_len]) - :param seq_lens: Element's can be: None, None, torch.Size([B], torch.Size([B]). ignored if masks are provided. + :param torch.Tensor pred: 预测的tensor, tensor的形状可以是torch.Size([B,]), torch.Size([B, n_classes]), + torch.Size([B, max_len]), 或者torch.Size([B, max_len, n_classes]) + :param torch.Tensor target: 真实值的tensor, tensor的形状可以是Element's can be: torch.Size([B,]), + torch.Size([B,]), torch.Size([B, max_len]), 或者torch.Size([B, max_len]) + :param torch.Tensor seq_len: 序列长度标记, 标记的形状可以是None, None, torch.Size([B]), 或者torch.Size([B]). + 如果mask也被传进来的话seq_len会被忽略. """ # TODO 这里报错需要更改,因为pred是啥用户并不知道。需要告知用户真实的value @@ -264,12 +271,12 @@ class AccuracyMetric(MetricBase): raise TypeError(f"`target` in {get_func_signature(self.evaluate)} must be torch.Tensor," f"got {type(target)}.") - if seq_lens is not None and not isinstance(seq_lens, torch.Tensor): + if seq_len is not None and not isinstance(seq_len, torch.Tensor): raise TypeError(f"`seq_lens` in {get_func_signature(self.evaluate)} must be torch.Tensor," f"got {type(seq_lens)}.") - if seq_lens is not None: - masks = seq_lens_to_masks(seq_lens=seq_lens) + if seq_len is not None: + masks = seq_lens_to_masks(seq_lens=seq_len) else: masks = None @@ -291,10 +298,10 @@ class AccuracyMetric(MetricBase): self.total += np.prod(list(pred.size())) def get_metric(self, reset=True): - """Returns computed metric. + """get_metric函数将根据evaluate函数累计的评价指标统计量来计算最终的评价结果. - :param bool reset: whether to recount next time. - :return evaluate_result: {"acc": float} + :param bool reset: 在调用完get_metric后是否清空评价指标统计量. + :return dict evaluate_result: {"acc": float} """ evaluate_result = {'acc': round(float(self.acc_count) / (self.total + 1e-12), 6)} if reset: @@ -302,7 +309,8 @@ class AccuracyMetric(MetricBase): self.total = 0 return evaluate_result -def bmes_tag_to_spans(tags, ignore_labels=None): + +def _bmes_tag_to_spans(tags, ignore_labels=None): """ 给定一个tags的lis,比如['S', 'B-singer', 'M-singer', 'E-singer', 'S', 'S']。 返回[('', (0, 1)), ('singer', (1, 4)), ('', (4, 5)), ('', (5, 6))] (左闭右开区间) @@ -330,7 +338,8 @@ def bmes_tag_to_spans(tags, ignore_labels=None): if span[0] not in ignore_labels ] -def bmeso_tag_to_spans(tags, ignore_labels=None): + +def _bmeso_tag_to_spans(tags, ignore_labels=None): """ 给定一个tags的lis,比如['O', 'B-singer', 'M-singer', 'E-singer', 'O', 'O']。 返回[('singer', (1, 4))] (左闭右开区间) @@ -360,7 +369,8 @@ def bmeso_tag_to_spans(tags, ignore_labels=None): if span[0] not in ignore_labels ] -def bio_tag_to_spans(tags, ignore_labels=None): + +def _bio_tag_to_spans(tags, ignore_labels=None): """ 给定一个tags的lis,比如['O', 'B-singer', 'I-singer', 'I-singer', 'O', 'O']。 返回[('singer', (1, 4))] (左闭右开区间) @@ -385,9 +395,7 @@ def bio_tag_to_spans(tags, ignore_labels=None): else: spans.append((label, [idx, idx])) prev_bio_tag = bio_tag - return [(span[0], (span[1][0], span[1][1]+1)) - for span in spans - if span[0] not in ignore_labels] + return [(span[0], (span[1][0], span[1][1]+1)) for span in spans if span[0] not in ignore_labels] class SpanFPreRecMetric(MetricBase): @@ -416,23 +424,23 @@ class SpanFPreRecMetric(MetricBase): } """ - def __init__(self, tag_vocab, pred=None, target=None, seq_lens=None, encoding_type='bio', ignore_labels=None, + def __init__(self, tag_vocab, pred=None, target=None, seq_len=None, encoding_type='bio', ignore_labels=None, only_gross=True, f_type='micro', beta=1): """ - :param tag_vocab: Vocabulary, 标签的vocabulary。支持的标签为"B"(没有label);或"B-xxx"(xxx为某种label,比如POS中的NN), + :param Vocabulary tag_vocab: 标签的vocabulary。支持的标签为"B"(没有label);或"B-xxx"(xxx为某种label,比如POS中的NN), 在解码时,会将相同xxx的认为是同一个label,比如['B-NN', 'E-NN']会被合并为一个'NN'. - :param pred: str, 用该key在evaluate()时从传入dict中取出prediction数据。 为None,则使用'pred'取数据 - :param target: str, 用该key在evaluate()时从传入dict中取出target数据。 为None,则使用'target'取数据 - :param seq_lens: str, 用该key在evaluate()时从传入dict中取出sequence length数据。为None,则使用'seq_lens'取数据。 - :param encoding_type: str, 目前支持bio, bmes - :param ignore_labels, List[str]. 这个list中的class不会被用于计算。例如在POS tagging时传入['NN'],则不会计算'NN'这 + :param str pred: 用该key在evaluate()时从传入dict中取出prediction数据。 为None,则使用'pred'取数据 + :param str target: 用该key在evaluate()时从传入dict中取出target数据。 为None,则使用'target'取数据 + :param str seq_len: 用该key在evaluate()时从传入dict中取出sequence length数据。为None,则使用'seq_lens'取数据。 + :param str encoding_type: 目前支持bio, bmes + :param list ignore_labels: str 组成的list. 这个list中的class不会被用于计算。例如在POS tagging时传入['NN'],则不会计算'NN'这 个label - :param only_gross, bool. 是否只计算总的f1, precision, recall的值;如果为False,不仅返回总的f1, pre, rec, 还会返回每个 + :param bool only_gross: 是否只计算总的f1, precision, recall的值;如果为False,不仅返回总的f1, pre, rec, 还会返回每个 label的f1, pre, rec - :param f_type, str. 'micro'或'macro'. 'micro':通过先计算总体的TP,FN和FP的数量,再计算f, precision, recall; 'macro': + :param str f_type: 'micro'或'macro'. 'micro':通过先计算总体的TP,FN和FP的数量,再计算f, precision, recall; 'macro': 分布计算每个类别的f, precision, recall,然后做平均(各类别f的权重相同) - :param beta, float. f_beta分数,f_beta = (1 + beta^2)*(pre*rec)/(beta^2*pre + rec). 常用为beta=0.5, 1, 2. 若为0.5 + :param float beta: f_beta分数,f_beta = (1 + beta^2)*(pre*rec)/(beta^2*pre + rec). 常用为beta=0.5, 1, 2. 若为0.5 则精确率的权重高于召回率;若为1,则两者平等;若为2,则召回率权重高于精确率。 """ encoding_type = encoding_type.lower() @@ -444,11 +452,11 @@ class SpanFPreRecMetric(MetricBase): self.encoding_type = encoding_type if self.encoding_type == 'bmes': - self.tag_to_span_func = bmes_tag_to_spans + self.tag_to_span_func = _bmes_tag_to_spans elif self.encoding_type == 'bio': - self.tag_to_span_func = bio_tag_to_spans + self.tag_to_span_func = _bio_tag_to_spans elif self.encoding_type == 'bmeso': - self.tag_to_span_func = bmeso_tag_to_spans + self.tag_to_span_func = _bmeso_tag_to_spans else: raise ValueError("Only support 'bio', 'bmes', 'bmeso' type.") @@ -459,7 +467,7 @@ class SpanFPreRecMetric(MetricBase): self.only_gross = only_gross super().__init__() - self._init_param_map(pred=pred, target=target, seq_lens=seq_lens) + self._init_param_map(pred=pred, target=target, seq_len=seq_len) self.tag_vocab = tag_vocab @@ -467,12 +475,12 @@ class SpanFPreRecMetric(MetricBase): self._false_positives = defaultdict(int) self._false_negatives = defaultdict(int) - def evaluate(self, pred, target, seq_lens): - """ - A lot of design idea comes from allennlp's measure - :param pred: - :param target: - :param seq_lens: + def evaluate(self, pred, target, seq_len): + """evaluate函数将针对一个批次的预测结果做评价指标的累计 + + :param pred: [batch, seq_len] 或者 [batch, seq_len, len(tag_vocab)], 预测的结果 + :param target: [batch, seq_len], 真实值 + :param seq_len: [batch] 文本长度标记 :return: """ if not isinstance(pred, torch.Tensor): @@ -482,9 +490,9 @@ class SpanFPreRecMetric(MetricBase): raise TypeError(f"`target` in {get_func_signature(self.evaluate)} must be torch.Tensor," f"got {type(target)}.") - if not isinstance(seq_lens, torch.Tensor): + if not isinstance(seq_len, torch.Tensor): raise TypeError(f"`seq_lens` in {get_func_signature(self.evaluate)} must be torch.Tensor," - f"got {type(seq_lens)}.") + f"got {type(seq_len)}.") if pred.size() == target.size() and len(target.size()) == 2: pass @@ -501,8 +509,8 @@ class SpanFPreRecMetric(MetricBase): batch_size = pred.size(0) for i in range(batch_size): - pred_tags = pred[i, :int(seq_lens[i])].tolist() - gold_tags = target[i, :int(seq_lens[i])].tolist() + pred_tags = pred[i, :int(seq_len[i])].tolist() + gold_tags = target[i, :int(seq_len[i])].tolist() pred_str_tags = [self.tag_vocab.to_word(tag) for tag in pred_tags] gold_str_tags = [self.tag_vocab.to_word(tag) for tag in gold_tags] @@ -520,8 +528,9 @@ class SpanFPreRecMetric(MetricBase): self._false_negatives[span[0]] += 1 def get_metric(self, reset=True): + """get_metric函数将根据evaluate函数累计的评价指标统计量来计算最终的评价结果.""" evaluate_result = {} - if not self.only_gross or self.f_type=='macro': + if not self.only_gross or self.f_type == 'macro': tags = set(self._false_negatives.keys()) tags.update(set(self._false_positives.keys())) tags.update(set(self._true_positives.keys())) @@ -578,6 +587,7 @@ class SpanFPreRecMetric(MetricBase): return f, pre, rec + class BMESF1PreRecMetric(MetricBase): """ 按照BMES标注方式计算f1, precision, recall。由于可能存在非法tag,比如"BS",所以需要用以下的表格做转换,cur_B意思是当前tag是B, @@ -607,7 +617,7 @@ class BMESF1PreRecMetric(MetricBase): """ - def __init__(self, b_idx=0, m_idx=1, e_idx=2, s_idx=3, pred=None, target=None, seq_lens=None): + def __init__(self, b_idx=0, m_idx=1, e_idx=2, s_idx=3, pred=None, target=None, seq_len=None): """ 需要申明BMES这四种tag中,各种tag对应的idx。所有不为b_idx, m_idx, e_idx, s_idx的数字都认为是s_idx。 @@ -617,11 +627,11 @@ class BMESF1PreRecMetric(MetricBase): :param s_idx: int, Single标签所对应的tag idx :param pred: str, 用该key在evaluate()时从传入dict中取出prediction数据。 为None,则使用'pred'取数据 :param target: str, 用该key在evaluate()时从传入dict中取出target数据。 为None,则使用'target'取数据 - :param seq_lens: str, 用该key在evaluate()时从传入dict中取出seqence length数据。为None,则使用'seq_lens'取数据。 + :param seq_len: str, 用该key在evaluate()时从传入dict中取出seqence length数据。为None,则使用'seq_len'取数据。 """ super().__init__() - self._init_param_map(pred=pred, target=target, seq_lens=seq_lens) + self._init_param_map(pred=pred, target=target, seq_len=seq_len) self.yt_wordnum = 0 self.yp_wordnum = 0 @@ -644,7 +654,7 @@ class BMESF1PreRecMetric(MetricBase): """ 给定一个tag的Tensor,返回合法tag - :param tags: Tensor, shape: (seq_len, ) + :param torch.Tensor tags: [seq_len] :return: 返回修改为合法tag的list """ assert len(tags)!=0 @@ -663,7 +673,14 @@ class BMESF1PreRecMetric(MetricBase): return padded_tags[1:-1] - def evaluate(self, pred, target, seq_lens): + def evaluate(self, pred, target, seq_len): + """evaluate函数将针对一个批次的预测结果做评价指标的累计 + + :param pred: [batch, seq_len] 或者 [batch, seq_len, 4] + :param target: [batch, seq_len] + :param seq_len: [batch] + :return: + """ if not isinstance(pred, torch.Tensor): raise TypeError(f"`pred` in {get_func_signature(self.evaluate)} must be torch.Tensor," f"got {type(pred)}.") @@ -671,9 +688,9 @@ class BMESF1PreRecMetric(MetricBase): raise TypeError(f"`target` in {get_func_signature(self.evaluate)} must be torch.Tensor," f"got {type(target)}.") - if not isinstance(seq_lens, torch.Tensor): + if not isinstance(seq_len, torch.Tensor): raise TypeError(f"`seq_lens` in {get_func_signature(self.evaluate)} must be torch.Tensor," - f"got {type(seq_lens)}.") + f"got {type(seq_len)}.") if pred.size() == target.size() and len(target.size()) == 2: pass @@ -685,7 +702,7 @@ class BMESF1PreRecMetric(MetricBase): f"{pred.size()[:-1]}, got {target.size()}.") for idx in range(len(pred)): - seq_len = seq_lens[idx] + seq_len = seq_len[idx] target_tags = target[idx][:seq_len].tolist() pred_tags = pred[idx][:seq_len] pred_tags = self._validate_tags(pred_tags) @@ -704,6 +721,7 @@ class BMESF1PreRecMetric(MetricBase): self.yp_wordnum += 1 def get_metric(self, reset=True): + """get_metric函数将根据evaluate函数累计的评价指标统计量来计算最终的评价结果.""" P = self.corr_num / (self.yp_wordnum + 1e-12) R = self.corr_num / (self.yt_wordnum + 1e-12) F = 2 * P * R / (P + R + 1e-12) @@ -746,7 +764,7 @@ def _prepare_metrics(metrics): return _metrics -def accuracy_topk(y_true, y_prob, k=1): +def _accuracy_topk(y_true, y_prob, k=1): """Compute accuracy of y_true matching top-k probable labels in y_prob. :param y_true: ndarray, true label, [n_samples] @@ -762,7 +780,7 @@ def accuracy_topk(y_true, y_prob, k=1): return acc -def pred_topk(y_prob, k=1): +def _pred_topk(y_prob, k=1): """Return top-k predicted labels and corresponding probabilities. :param y_prob: ndarray, size [n_samples, n_classes], probabilities on labels @@ -781,22 +799,24 @@ def pred_topk(y_prob, k=1): class SQuADMetric(MetricBase): + """SQuAD数据集metric + """ - def __init__(self, pred_start=None, pred_end=None, target_start=None, target_end=None, - beta=1, right_open=False, print_predict_stat=False): + def __init__(self, pred1=None, pred2=None, target1=None, target2=None, + beta=1, right_open=True, print_predict_stat=False): """ - :param pred_start: [batch], 预测答案开始的index, 如果SQuAD2.0中答案为空则为0 - :param pred_end: [batch], 预测答案结束的index, 如果SQuAD2.0中答案为空则为0(左闭右闭区间)或者1(左闭右开区间) - :param target_start: [batch], 正确答案开始的index, 如果SQuAD2.0中答案为空则为0 - :param target_end: [batch], 正确答案结束的index, 如果SQuAD2.0中答案为空则为0(左闭右闭区间)或者1(左闭右开区间) - :param beta: float. f_beta分数,f_beta = (1 + beta^2)*(pre*rec)/(beta^2*pre + rec). 常用为beta=0.5, 1, 2. 若为0.5 + :param pred1: 参数映射表中`pred1`的映射关系,None表示映射关系为`pred1`->`pred1` + :param pred2: 参数映射表中`pred2`的映射关系,None表示映射关系为`pred2`->`pred2` + :param target1: 参数映射表中`target1`的映射关系,None表示映射关系为`target1`->`target1` + :param target2: 参数映射表中`target2`的映射关系,None表示映射关系为`target2`->`target2` + :param float beta: f_beta分数,f_beta = (1 + beta^2)*(pre*rec)/(beta^2*pre + rec). 常用为beta=0.5, 1, 2. 若为0.5 则精确率的权重高于召回率;若为1,则两者平等;若为2,则召回率权重高于精确率。 - :param right_open: boolean. right_open为true表示start跟end指针指向一个左闭右开区间,为false表示指向一个左闭右闭区间。 - :param print_predict_stat: boolean. True则输出预测答案是否为空与正确答案是否为空的统计信息, False则不输出 + :param bool right_open: right_open为true表示start跟end指针指向一个左闭右开区间,为false表示指向一个左闭右闭区间。 + :param bool print_predict_stat: True则输出预测答案是否为空与正确答案是否为空的统计信息, False则不输出 """ super(SQuADMetric, self).__init__() - self._init_param_map(pred_start=pred_start, pred_end=pred_end, target_start=target_start, target_end=target_end) + self._init_param_map(pred1=pred1, pred2=pred2, target1=target1, target2=target2) self.print_predict_stat = print_predict_stat @@ -817,18 +837,29 @@ class SQuADMetric(MetricBase): self.right_open = right_open - def evaluate(self, pred_start, pred_end, target_start, target_end): - """ + def evaluate(self, pred1, pred2, target1, target2): + """evaluate函数将针对一个批次的预测结果做评价指标的累计 - :param pred_start: [batch, seq_len] - :param pred_end: [batch, seq_len] - :param target_start: [batch] - :param target_end: [batch] - :param labels: [batch] - :return: + :param pred1: [batch]或者[batch, seq_len], 预测答案开始的index, 如果SQuAD2.0中答案为空则为0 + :param pred2: [batch]或者[batch, seq_len] 预测答案结束的index, 如果SQuAD2.0中答案为空则为0(左闭右闭区间)或者1(左闭右开区间) + :param target1: [batch], 正确答案开始的index, 如果SQuAD2.0中答案为空则为0 + :param target2: [batch], 正确答案结束的index, 如果SQuAD2.0中答案为空则为0(左闭右闭区间)或者1(左闭右开区间) + :return: None """ - start_inference = pred_start.max(dim=-1)[1].cpu().tolist() - end_inference = pred_end.max(dim=-1)[1].cpu().tolist() + pred_start = pred1 + pred_end = pred2 + target_start = target1 + target_end = target2 + + if len(pred_start.size()) == 2: + start_inference = pred_start.max(dim=-1)[1].cpu().tolist() + else: + start_inference = pred_start.cpu().tolist() + if len(pred_end.size()) == 2: + end_inference = pred_end.max(dim=-1)[1].cpu().tolist() + else: + end_inference = pred_end.cpu().tolist() + start, end = [], [] max_len = pred_start.size(1) t_start = target_start.cpu().tolist() @@ -873,6 +904,7 @@ class SQuADMetric(MetricBase): self.has_ans_f += f def get_metric(self, reset=True): + """get_metric函数将根据evaluate函数累计的评价指标统计量来计算最终的评价结果.""" evaluate_result = {} if self.no_ans_correct + self.no_ans_wrong + self.has_ans_correct + self.no_ans_wrong <= 0: diff --git a/test/core/test_metrics.py b/test/core/test_metrics.py index 4fb2a04e..a0e8f1a5 100644 --- a/test/core/test_metrics.py +++ b/test/core/test_metrics.py @@ -5,7 +5,7 @@ import torch from fastNLP.core.metrics import AccuracyMetric from fastNLP.core.metrics import BMESF1PreRecMetric -from fastNLP.core.metrics import pred_topk, accuracy_topk +from fastNLP.core.metrics import _pred_topk, _accuracy_topk class TestAccuracyMetric(unittest.TestCase): @@ -134,8 +134,8 @@ class TestAccuracyMetric(unittest.TestCase): class SpanF1PreRecMetric(unittest.TestCase): def test_case1(self): - from fastNLP.core.metrics import bmes_tag_to_spans - from fastNLP.core.metrics import bio_tag_to_spans + from fastNLP.core.metrics import _bmes_tag_to_spans + from fastNLP.core.metrics import _bio_tag_to_spans bmes_lst = ['M-8', 'S-2', 'S-0', 'B-9', 'B-6', 'E-5', 'B-7', 'S-2', 'E-7', 'S-8'] bio_lst = ['O-8', 'O-2', 'B-0', 'O-9', 'I-6', 'I-5', 'I-7', 'I-2', 'I-7', 'O-8'] @@ -145,8 +145,8 @@ class SpanF1PreRecMetric(unittest.TestCase): expect_bio_res = set() expect_bio_res.update([('7', (8, 9)), ('0', (2, 3)), ('2', (7, 8)), ('5', (5, 6)), ('6', (4, 5)), ('7', (6, 7))]) - self.assertSetEqual(expect_bmes_res,set(bmes_tag_to_spans(bmes_lst))) - self.assertSetEqual(expect_bio_res, set(bio_tag_to_spans(bio_lst))) + self.assertSetEqual(expect_bmes_res,set(_bmes_tag_to_spans(bmes_lst))) + self.assertSetEqual(expect_bio_res, set(_bio_tag_to_spans(bio_lst))) # 已与allennlp对应函数做过验证,但由于测试不能依赖allennlp,所以这里只是截取上面的例子做固定测试 # from allennlp.data.dataset_readers.dataset_utils import bio_tags_to_spans as allen_bio_tags_to_spans # from allennlp.data.dataset_readers.dataset_utils import bmes_tags_to_spans as allen_bmes_tags_to_spans @@ -161,8 +161,8 @@ class SpanF1PreRecMetric(unittest.TestCase): def test_case2(self): # 测试不带label的 - from fastNLP.core.metrics import bmes_tag_to_spans - from fastNLP.core.metrics import bio_tag_to_spans + from fastNLP.core.metrics import _bmes_tag_to_spans + from fastNLP.core.metrics import _bio_tag_to_spans bmes_lst = ['B', 'E', 'B', 'S', 'B', 'M', 'E', 'M', 'B', 'E'] bio_lst = ['I', 'B', 'O', 'O', 'I', 'O', 'I', 'B', 'O', 'O'] @@ -170,8 +170,8 @@ class SpanF1PreRecMetric(unittest.TestCase): expect_bmes_res.update([('', (0, 2)), ('', (2, 3)), ('', (3, 4)), ('', (4, 7)), ('', (7, 8)), ('', (8, 10))]) expect_bio_res = set() expect_bio_res.update([('', (7, 8)), ('', (6, 7)), ('', (4, 5)), ('', (0, 1)), ('', (1, 2))]) - self.assertSetEqual(expect_bmes_res,set(bmes_tag_to_spans(bmes_lst))) - self.assertSetEqual(expect_bio_res, set(bio_tag_to_spans(bio_lst))) + self.assertSetEqual(expect_bmes_res,set(_bmes_tag_to_spans(bmes_lst))) + self.assertSetEqual(expect_bio_res, set(_bio_tag_to_spans(bio_lst))) # 已与allennlp对应函数做过验证,但由于测试不能依赖allennlp,所以这里只是截取上面的例子做固定测试 # from allennlp.data.dataset_readers.dataset_utils import bio_tags_to_spans as allen_bio_tags_to_spans # from allennlp.data.dataset_readers.dataset_utils import bmes_tags_to_spans as allen_bmes_tags_to_spans @@ -366,7 +366,7 @@ class TestUsefulFunctions(unittest.TestCase): # 测试metrics.py中一些看上去挺有用的函数 def test_case_1(self): # multi-class - _ = accuracy_topk(np.random.randint(0, 3, size=(10, 1)), np.random.randint(0, 3, size=(10, 1)), k=3) - _ = pred_topk(np.random.randint(0, 3, size=(10, 1))) + _ = _accuracy_topk(np.random.randint(0, 3, size=(10, 1)), np.random.randint(0, 3, size=(10, 1)), k=3) + _ = _pred_topk(np.random.randint(0, 3, size=(10, 1))) # 跑通即可