Browse Source

Merge branch 'dev' of github.com:choosewhatulike/fastNLP-private into dev

tags/v0.4.10
yh_cc 5 years ago
parent
commit
cdfd3d3388
2 changed files with 147 additions and 115 deletions
  1. +136
    -104
      fastNLP/core/metrics.py
  2. +11
    -11
      test/core/test_metrics.py

+ 136
- 104
fastNLP/core/metrics.py View File

@@ -23,7 +23,7 @@ from fastNLP.core.vocabulary import Vocabulary


class MetricBase(object):
"""Base class for all metrics.
"""所有metrics的基类

所有的传入到Trainer, Tester的Metric需要继承自该对象。需要覆盖写入evaluate(), get_metric()方法。
@@ -94,18 +94,22 @@ class MetricBase(object):
return {'acc': acc} # 需要返回一个dict,key为该metric的名称,该名称会显示到Trainer的progress bar中


``MetricBase`` handles validity check of its input dictionaries - ``pred_dict`` and ``target_dict``.
``pred_dict`` is the output of ``forward()`` or prediction function of a model.
``target_dict`` is the ground truth from DataSet where ``is_target`` is set ``True``.
``MetricBase`` will do the following type checks:
``MetricBase`` 将会在输入的字典``pred_dict``和``target_dict``中进行检查.
``pred_dict`` 是模型当中``forward()``函数或者``predict()``函数的返回值.
``target_dict`` 是DataSet当中的ground truth, 判定ground truth的条件是field的``is_target``被设置为True.

1. whether self.evaluate has varargs, which is not supported.
2. whether params needed by self.evaluate is not included in ``pred_dict``, ``target_dict``.
3. whether params needed by self.evaluate duplicate in ``pred_dict``, ``target_dict``.
``MetricBase`` 会进行以下的类型检测:

Besides, before passing params into self.evaluate, this function will filter out params from output_dict and
target_dict which are not used in self.evaluate. (but if kwargs presented in self.evaluate, no filtering
will be conducted.)
1. self.evaluate当中是否有varargs, 这是不支持的.
2. self.evaluate当中所需要的参数是否既不在``pred_dict``也不在``target_dict``.
3. self.evaluate当中所需要的参数是否既在``pred_dict``也在``target_dict``.

除此以外,在参数被传入self.evaluate以前,这个函数会过滤掉``pred_dict``和``target_dict``当中未被self.evaluate用到的参数
(但如果kwargs是self.evaluate的参数,则不会进行过滤)


self.evaluate将计算一个批次(batch)的评价指标,并累计
self.get_metric将统计当前的评价指标并返回评价结果

"""
def __init__(self):
@@ -116,10 +120,10 @@ class MetricBase(object):
raise NotImplementedError

def _init_param_map(self, key_map=None, **kwargs):
"""Check the validity of key_map and other param map. Add these into self.param_map
"""检查key_map和其他参数map,并将这些映射关系添加到self.param_map

:param key_map: dict
:param kwargs:
:param dict key_map: 表示key的映射关系
:param kwargs: key word args里面的每一个的键-值对都会被构造成映射关系
:return: None
"""
value_counter = defaultdict(set)
@@ -162,17 +166,16 @@ class MetricBase(object):

def __call__(self, pred_dict, target_dict):
"""

This method will call self.evaluate method.
Before calling self.evaluate, it will first check the validity of output_dict, target_dict
(1) whether params needed by self.evaluate is not included in output_dict,target_dict.
(2) whether params needed by self.evaluate duplicate in pred_dict, target_dict
(3) whether params in output_dict, target_dict are not used by evaluate.(Might cause warning)
Besides, before passing params into self.evaluate, this function will filter out params from output_dict and
target_dict which are not used in self.evaluate. (but if **kwargs presented in self.evaluate, no filtering
will be conducted.)
:param pred_dict: usually the output of forward or prediction function
:param target_dict: usually features set as target..
这个方法会调用self.evaluate 方法.
在调用之前,会进行以下检测:
1. self.evaluate当中是否有varargs, 这是不支持的.
2. self.evaluate当中所需要的参数是否既不在``pred_dict``也不在``target_dict``.
3. self.evaluate当中所需要的参数是否既在``pred_dict``也在``target_dict``.

除此以外,在参数被传入self.evaluate以前,这个函数会过滤掉``pred_dict``和``target_dict``当中未被self.evaluate用到的参数
(但如果kwargs是self.evaluate的参数,则不会进行过滤)
:param pred_dict: 模型的forward函数或者predict函数返回的dict
:param target_dict: DataSet.batch_y里的键-值对所组成的dict(即is_target=True的fields的内容)
:return:
"""
if not callable(self.evaluate):
@@ -244,25 +247,29 @@ class MetricBase(object):


class AccuracyMetric(MetricBase):
"""Accuracy Metric

"""
def __init__(self, pred=None, target=None, seq_lens=None):
"""准确率Metric"""
def __init__(self, pred=None, target=None, seq_len=None):
"""
:param pred: 参数映射表中`pred`的映射关系,None表示映射关系为`pred`->`pred`
:param target: 参数映射表中`target`的映射关系,None表示映射关系为`target`->`target`
:param seq_len: 参数映射表中`seq_len`的映射关系,None表示映射关系为`seq_len`->`seq_len`
"""
super().__init__()

self._init_param_map(pred=pred, target=target, seq_lens=seq_lens)
self._init_param_map(pred=pred, target=target, seq_len=seq_len)

self.total = 0
self.acc_count = 0

def evaluate(self, pred, target, seq_lens=None):
"""
def evaluate(self, pred, target, seq_len=None):
"""evaluate函数将针对一个批次的预测结果做评价指标的累计

:param pred: . Element's shape can be: torch.Size([B,]), torch.Size([B, n_classes]), torch.Size([B, max_len]),
torch.Size([B, max_len, n_classes])
:param target: Element's can be: torch.Size([B,]), torch.Size([B,]), torch.Size([B, max_len]),
torch.Size([B, max_len])
:param seq_lens: Element's can be: None, None, torch.Size([B], torch.Size([B]). ignored if masks are provided.
:param torch.Tensor pred: 预测的tensor, tensor的形状可以是torch.Size([B,]), torch.Size([B, n_classes]),
torch.Size([B, max_len]), 或者torch.Size([B, max_len, n_classes])
:param torch.Tensor target: 真实值的tensor, tensor的形状可以是torch.Size([B,]),
torch.Size([B, max_len]), 或者torch.Size([B, max_len])
:param torch.Tensor seq_len: 序列长度标记, 形状为torch.Size([B]), 也可以为None.
如果mask也被传进来的话seq_len会被忽略.

"""
# TODO 这里报错需要更改,因为pred是啥用户并不知道。需要告知用户真实的value
@@ -273,12 +280,12 @@ class AccuracyMetric(MetricBase):
raise TypeError(f"`target` in {get_func_signature(self.evaluate)} must be torch.Tensor,"
f"got {type(target)}.")

if seq_lens is not None and not isinstance(seq_lens, torch.Tensor):
if seq_len is not None and not isinstance(seq_len, torch.Tensor):
raise TypeError(f"`seq_lens` in {get_func_signature(self.evaluate)} must be torch.Tensor,"
f"got {type(seq_lens)}.")

if seq_lens is not None:
masks = seq_lens_to_masks(seq_lens=seq_lens)
if seq_len is not None:
masks = seq_lens_to_masks(seq_lens=seq_len)
else:
masks = None

@@ -300,10 +307,10 @@ class AccuracyMetric(MetricBase):
self.total += np.prod(list(pred.size()))

def get_metric(self, reset=True):
"""Returns computed metric.
"""get_metric函数将根据evaluate函数累计的评价指标统计量来计算最终的评价结果.

:param bool reset: whether to recount next time.
:return evaluate_result: {"acc": float}
:param bool reset: 在调用完get_metric后是否清空评价指标统计量.
:return dict evaluate_result: {"acc": float}
"""
evaluate_result = {'acc': round(float(self.acc_count) / (self.total + 1e-12), 6)}
if reset:
@@ -311,7 +318,8 @@ class AccuracyMetric(MetricBase):
self.total = 0
return evaluate_result

def bmes_tag_to_spans(tags, ignore_labels=None):

def _bmes_tag_to_spans(tags, ignore_labels=None):
"""
给定一个tags的lis,比如['S', 'B-singer', 'M-singer', 'E-singer', 'S', 'S']。
返回[('', (0, 1)), ('singer', (1, 4)), ('', (4, 5)), ('', (5, 6))] (左闭右开区间)
@@ -339,7 +347,8 @@ def bmes_tag_to_spans(tags, ignore_labels=None):
if span[0] not in ignore_labels
]

def bmeso_tag_to_spans(tags, ignore_labels=None):

def _bmeso_tag_to_spans(tags, ignore_labels=None):
"""
给定一个tags的lis,比如['O', 'B-singer', 'M-singer', 'E-singer', 'O', 'O']。
返回[('singer', (1, 4))] (左闭右开区间)
@@ -369,7 +378,8 @@ def bmeso_tag_to_spans(tags, ignore_labels=None):
if span[0] not in ignore_labels
]

def bio_tag_to_spans(tags, ignore_labels=None):

def _bio_tag_to_spans(tags, ignore_labels=None):
"""
给定一个tags的lis,比如['O', 'B-singer', 'I-singer', 'I-singer', 'O', 'O']。
返回[('singer', (1, 4))] (左闭右开区间)
@@ -394,9 +404,7 @@ def bio_tag_to_spans(tags, ignore_labels=None):
else:
spans.append((label, [idx, idx]))
prev_bio_tag = bio_tag
return [(span[0], (span[1][0], span[1][1]+1))
for span in spans
if span[0] not in ignore_labels]
return [(span[0], (span[1][0], span[1][1]+1)) for span in spans if span[0] not in ignore_labels]


class SpanFPreRecMetric(MetricBase):
@@ -428,23 +436,23 @@ class SpanFPreRecMetric(MetricBase):
}

"""
def __init__(self, tag_vocab, pred=None, target=None, seq_lens=None, encoding_type='bio', ignore_labels=None,
def __init__(self, tag_vocab, pred=None, target=None, seq_len=None, encoding_type='bio', ignore_labels=None,
only_gross=True, f_type='micro', beta=1):
"""

:param tag_vocab: Vocabulary, 标签的vocabulary。支持的标签为"B"(没有label);或"B-xxx"(xxx为某种label,比如POS中的NN),
:param Vocabulary tag_vocab: 标签的vocabulary。支持的标签为"B"(没有label);或"B-xxx"(xxx为某种label,比如POS中的NN),
在解码时,会将相同xxx的认为是同一个label,比如['B-NN', 'E-NN']会被合并为一个'NN'.
:param pred: str, 用该key在evaluate()时从传入dict中取出prediction数据。 为None,则使用'pred'取数据
:param target: str, 用该key在evaluate()时从传入dict中取出target数据。 为None,则使用'target'取数据
:param seq_lens: str, 用该key在evaluate()时从传入dict中取出sequence length数据。为None,则使用'seq_lens'取数据。
:param encoding_type: str, 目前支持bio, bmes
:param ignore_labels, List[str]. 这个list中的class不会被用于计算。例如在POS tagging时传入['NN'],则不会计算'NN'这
:param str pred: 用该key在evaluate()时从传入dict中取出prediction数据。 为None,则使用'pred'取数据
:param str target: 用该key在evaluate()时从传入dict中取出target数据。 为None,则使用'target'取数据
:param str seq_len: 用该key在evaluate()时从传入dict中取出sequence length数据。为None,则使用'seq_len'取数据。
:param str encoding_type: 目前支持bio, bmes
:param list ignore_labels: str 组成的list. 这个list中的class不会被用于计算。例如在POS tagging时传入['NN'],则不会计算'NN'这
个label
:param only_gross, bool. 是否只计算总的f1, precision, recall的值;如果为False,不仅返回总的f1, pre, rec, 还会返回每个
:param bool only_gross: 是否只计算总的f1, precision, recall的值;如果为False,不仅返回总的f1, pre, rec, 还会返回每个
label的f1, pre, rec
:param f_type, str. 'micro'或'macro'. 'micro':通过先计算总体的TP,FN和FP的数量,再计算f, precision, recall; 'macro':
:param str f_type: 'micro'或'macro'. 'micro':通过先计算总体的TP,FN和FP的数量,再计算f, precision, recall; 'macro':
分布计算每个类别的f, precision, recall,然后做平均(各类别f的权重相同)
:param beta, float. f_beta分数,f_beta = (1 + beta^2)*(pre*rec)/(beta^2*pre + rec). 常用为beta=0.5, 1, 2. 若为0.5
:param float beta: f_beta分数,f_beta = (1 + beta^2)*(pre*rec)/(beta^2*pre + rec). 常用为beta=0.5, 1, 2. 若为0.5
则精确率的权重高于召回率;若为1,则两者平等;若为2,则召回率权重高于精确率。
"""
encoding_type = encoding_type.lower()
@@ -456,11 +464,11 @@ class SpanFPreRecMetric(MetricBase):

self.encoding_type = encoding_type
if self.encoding_type == 'bmes':
self.tag_to_span_func = bmes_tag_to_spans
self.tag_to_span_func = _bmes_tag_to_spans
elif self.encoding_type == 'bio':
self.tag_to_span_func = bio_tag_to_spans
self.tag_to_span_func = _bio_tag_to_spans
elif self.encoding_type == 'bmeso':
self.tag_to_span_func = bmeso_tag_to_spans
self.tag_to_span_func = _bmeso_tag_to_spans
else:
raise ValueError("Only support 'bio', 'bmes', 'bmeso' type.")

@@ -471,7 +479,7 @@ class SpanFPreRecMetric(MetricBase):
self.only_gross = only_gross

super().__init__()
self._init_param_map(pred=pred, target=target, seq_lens=seq_lens)
self._init_param_map(pred=pred, target=target, seq_len=seq_len)

self.tag_vocab = tag_vocab

@@ -479,12 +487,12 @@ class SpanFPreRecMetric(MetricBase):
self._false_positives = defaultdict(int)
self._false_negatives = defaultdict(int)

def evaluate(self, pred, target, seq_lens):
"""
A lot of design idea comes from allennlp's measure
:param pred:
:param target:
:param seq_lens:
def evaluate(self, pred, target, seq_len):
"""evaluate函数将针对一个批次的预测结果做评价指标的累计
:param pred: [batch, seq_len] 或者 [batch, seq_len, len(tag_vocab)], 预测的结果
:param target: [batch, seq_len], 真实值
:param seq_len: [batch] 文本长度标记
:return:
"""
if not isinstance(pred, torch.Tensor):
@@ -494,9 +502,9 @@ class SpanFPreRecMetric(MetricBase):
raise TypeError(f"`target` in {get_func_signature(self.evaluate)} must be torch.Tensor,"
f"got {type(target)}.")

if not isinstance(seq_lens, torch.Tensor):
if not isinstance(seq_len, torch.Tensor):
raise TypeError(f"`seq_lens` in {get_func_signature(self.evaluate)} must be torch.Tensor,"
f"got {type(seq_lens)}.")
f"got {type(seq_len)}.")

if pred.size() == target.size() and len(target.size()) == 2:
pass
@@ -513,8 +521,8 @@ class SpanFPreRecMetric(MetricBase):

batch_size = pred.size(0)
for i in range(batch_size):
pred_tags = pred[i, :int(seq_lens[i])].tolist()
gold_tags = target[i, :int(seq_lens[i])].tolist()
pred_tags = pred[i, :int(seq_len[i])].tolist()
gold_tags = target[i, :int(seq_len[i])].tolist()

pred_str_tags = [self.tag_vocab.to_word(tag) for tag in pred_tags]
gold_str_tags = [self.tag_vocab.to_word(tag) for tag in gold_tags]
@@ -532,8 +540,9 @@ class SpanFPreRecMetric(MetricBase):
self._false_negatives[span[0]] += 1

def get_metric(self, reset=True):
"""get_metric函数将根据evaluate函数累计的评价指标统计量来计算最终的评价结果."""
evaluate_result = {}
if not self.only_gross or self.f_type=='macro':
if not self.only_gross or self.f_type == 'macro':
tags = set(self._false_negatives.keys())
tags.update(set(self._false_positives.keys()))
tags.update(set(self._true_positives.keys()))
@@ -590,6 +599,7 @@ class SpanFPreRecMetric(MetricBase):

return f, pre, rec


class BMESF1PreRecMetric(MetricBase):
"""
按照BMES标注方式计算f1, precision, recall。由于可能存在非法tag,比如"BS",所以需要用以下的表格做转换,cur_B意思是当前tag是B,
@@ -619,7 +629,7 @@ class BMESF1PreRecMetric(MetricBase):

"""

def __init__(self, b_idx=0, m_idx=1, e_idx=2, s_idx=3, pred=None, target=None, seq_lens=None):
def __init__(self, b_idx=0, m_idx=1, e_idx=2, s_idx=3, pred=None, target=None, seq_len=None):
"""
需要申明BMES这四种tag中,各种tag对应的idx。所有不为b_idx, m_idx, e_idx, s_idx的数字都认为是s_idx。

@@ -629,11 +639,11 @@ class BMESF1PreRecMetric(MetricBase):
:param s_idx: int, Single标签所对应的tag idx
:param pred: str, 用该key在evaluate()时从传入dict中取出prediction数据。 为None,则使用'pred'取数据
:param target: str, 用该key在evaluate()时从传入dict中取出target数据。 为None,则使用'target'取数据
:param seq_lens: str, 用该key在evaluate()时从传入dict中取出seqence length数据。为None,则使用'seq_lens'取数据。
:param seq_len: str, 用该key在evaluate()时从传入dict中取出seqence length数据。为None,则使用'seq_len'取数据。
"""
super().__init__()

self._init_param_map(pred=pred, target=target, seq_lens=seq_lens)
self._init_param_map(pred=pred, target=target, seq_len=seq_len)

self.yt_wordnum = 0
self.yp_wordnum = 0
@@ -656,7 +666,7 @@ class BMESF1PreRecMetric(MetricBase):
"""
给定一个tag的Tensor,返回合法tag

:param tags: Tensor, shape: (seq_len, )
:param torch.Tensor tags: [seq_len]
:return: 返回修改为合法tag的list
"""
assert len(tags)!=0
@@ -675,7 +685,14 @@ class BMESF1PreRecMetric(MetricBase):

return padded_tags[1:-1]

def evaluate(self, pred, target, seq_lens):
def evaluate(self, pred, target, seq_len):
"""evaluate函数将针对一个批次的预测结果做评价指标的累计

:param pred: [batch, seq_len] 或者 [batch, seq_len, 4]
:param target: [batch, seq_len]
:param seq_len: [batch]
:return:
"""
if not isinstance(pred, torch.Tensor):
raise TypeError(f"`pred` in {get_func_signature(self.evaluate)} must be torch.Tensor,"
f"got {type(pred)}.")
@@ -683,9 +700,9 @@ class BMESF1PreRecMetric(MetricBase):
raise TypeError(f"`target` in {get_func_signature(self.evaluate)} must be torch.Tensor,"
f"got {type(target)}.")

if not isinstance(seq_lens, torch.Tensor):
if not isinstance(seq_len, torch.Tensor):
raise TypeError(f"`seq_lens` in {get_func_signature(self.evaluate)} must be torch.Tensor,"
f"got {type(seq_lens)}.")
f"got {type(seq_len)}.")

if pred.size() == target.size() and len(target.size()) == 2:
pass
@@ -697,7 +714,7 @@ class BMESF1PreRecMetric(MetricBase):
f"{pred.size()[:-1]}, got {target.size()}.")

for idx in range(len(pred)):
seq_len = seq_lens[idx]
seq_len = seq_len[idx]
target_tags = target[idx][:seq_len].tolist()
pred_tags = pred[idx][:seq_len]
pred_tags = self._validate_tags(pred_tags)
@@ -716,6 +733,7 @@ class BMESF1PreRecMetric(MetricBase):
self.yp_wordnum += 1

def get_metric(self, reset=True):
"""get_metric函数将根据evaluate函数累计的评价指标统计量来计算最终的评价结果."""
P = self.corr_num / (self.yp_wordnum + 1e-12)
R = self.corr_num / (self.yt_wordnum + 1e-12)
F = 2 * P * R / (P + R + 1e-12)
@@ -758,7 +776,7 @@ def _prepare_metrics(metrics):
return _metrics


def accuracy_topk(y_true, y_prob, k=1):
def _accuracy_topk(y_true, y_prob, k=1):
"""Compute accuracy of y_true matching top-k probable labels in y_prob.

:param y_true: ndarray, true label, [n_samples]
@@ -774,7 +792,7 @@ def accuracy_topk(y_true, y_prob, k=1):
return acc


def pred_topk(y_prob, k=1):
def _pred_topk(y_prob, k=1):
"""Return top-k predicted labels and corresponding probabilities.

:param y_prob: ndarray, size [n_samples, n_classes], probabilities on labels
@@ -793,22 +811,24 @@ def pred_topk(y_prob, k=1):


class SQuADMetric(MetricBase):
"""SQuAD数据集metric
"""

def __init__(self, pred_start=None, pred_end=None, target_start=None, target_end=None,
beta=1, right_open=False, print_predict_stat=False):
def __init__(self, pred1=None, pred2=None, target1=None, target2=None,
beta=1, right_open=True, print_predict_stat=False):
"""
:param pred_start: [batch], 预测答案开始的index, 如果SQuAD2.0中答案为空则为0
:param pred_end: [batch], 预测答案结束的index, 如果SQuAD2.0中答案为空则为0(左闭右闭区间)或者1(左闭右开区间)
:param target_start: [batch], 正确答案开始的index, 如果SQuAD2.0中答案为空则为0
:param target_end: [batch], 正确答案结束的index, 如果SQuAD2.0中答案为空则为0(左闭右闭区间)或者1(左闭右开区间)
:param beta: float. f_beta分数,f_beta = (1 + beta^2)*(pre*rec)/(beta^2*pre + rec). 常用为beta=0.5, 1, 2. 若为0.5
:param pred1: 参数映射表中`pred1`的映射关系,None表示映射关系为`pred1`->`pred1`
:param pred2: 参数映射表中`pred2`的映射关系,None表示映射关系为`pred2`->`pred2`
:param target1: 参数映射表中`target1`的映射关系,None表示映射关系为`target1`->`target1`
:param target2: 参数映射表中`target2`的映射关系,None表示映射关系为`target2`->`target2`
:param float beta: f_beta分数,f_beta = (1 + beta^2)*(pre*rec)/(beta^2*pre + rec). 常用为beta=0.5, 1, 2. 若为0.5
则精确率的权重高于召回率;若为1,则两者平等;若为2,则召回率权重高于精确率。
:param right_open: boolean. right_open为true表示start跟end指针指向一个左闭右开区间,为false表示指向一个左闭右闭区间。
:param print_predict_stat: boolean. True则输出预测答案是否为空与正确答案是否为空的统计信息, False则不输出
:param bool right_open: right_open为true表示start跟end指针指向一个左闭右开区间,为false表示指向一个左闭右闭区间。
:param bool print_predict_stat: True则输出预测答案是否为空与正确答案是否为空的统计信息, False则不输出
"""
super(SQuADMetric, self).__init__()

self._init_param_map(pred_start=pred_start, pred_end=pred_end, target_start=target_start, target_end=target_end)
self._init_param_map(pred1=pred1, pred2=pred2, target1=target1, target2=target2)

self.print_predict_stat = print_predict_stat

@@ -829,18 +849,29 @@ class SQuADMetric(MetricBase):

self.right_open = right_open

def evaluate(self, pred_start, pred_end, target_start, target_end):
"""
def evaluate(self, pred1, pred2, target1, target2):
"""evaluate函数将针对一个批次的预测结果做评价指标的累计

:param pred_start: [batch, seq_len]
:param pred_end: [batch, seq_len]
:param target_start: [batch]
:param target_end: [batch]
:param labels: [batch]
:return:
:param pred1: [batch]或者[batch, seq_len], 预测答案开始的index, 如果SQuAD2.0中答案为空则为0
:param pred2: [batch]或者[batch, seq_len] 预测答案结束的index, 如果SQuAD2.0中答案为空则为0(左闭右闭区间)或者1(左闭右开区间)
:param target1: [batch], 正确答案开始的index, 如果SQuAD2.0中答案为空则为0
:param target2: [batch], 正确答案结束的index, 如果SQuAD2.0中答案为空则为0(左闭右闭区间)或者1(左闭右开区间)
:return: None
"""
start_inference = pred_start.max(dim=-1)[1].cpu().tolist()
end_inference = pred_end.max(dim=-1)[1].cpu().tolist()
pred_start = pred1
pred_end = pred2
target_start = target1
target_end = target2

if len(pred_start.size()) == 2:
start_inference = pred_start.max(dim=-1)[1].cpu().tolist()
else:
start_inference = pred_start.cpu().tolist()
if len(pred_end.size()) == 2:
end_inference = pred_end.max(dim=-1)[1].cpu().tolist()
else:
end_inference = pred_end.cpu().tolist()

start, end = [], []
max_len = pred_start.size(1)
t_start = target_start.cpu().tolist()
@@ -885,6 +916,7 @@ class SQuADMetric(MetricBase):
self.has_ans_f += f

def get_metric(self, reset=True):
"""get_metric函数将根据evaluate函数累计的评价指标统计量来计算最终的评价结果."""
evaluate_result = {}

if self.no_ans_correct + self.no_ans_wrong + self.has_ans_correct + self.no_ans_wrong <= 0:


+ 11
- 11
test/core/test_metrics.py View File

@@ -5,7 +5,7 @@ import torch

from fastNLP.core.metrics import AccuracyMetric
from fastNLP.core.metrics import BMESF1PreRecMetric
from fastNLP.core.metrics import pred_topk, accuracy_topk
from fastNLP.core.metrics import _pred_topk, _accuracy_topk


class TestAccuracyMetric(unittest.TestCase):
@@ -134,8 +134,8 @@ class TestAccuracyMetric(unittest.TestCase):

class SpanF1PreRecMetric(unittest.TestCase):
def test_case1(self):
from fastNLP.core.metrics import bmes_tag_to_spans
from fastNLP.core.metrics import bio_tag_to_spans
from fastNLP.core.metrics import _bmes_tag_to_spans
from fastNLP.core.metrics import _bio_tag_to_spans

bmes_lst = ['M-8', 'S-2', 'S-0', 'B-9', 'B-6', 'E-5', 'B-7', 'S-2', 'E-7', 'S-8']
bio_lst = ['O-8', 'O-2', 'B-0', 'O-9', 'I-6', 'I-5', 'I-7', 'I-2', 'I-7', 'O-8']
@@ -145,8 +145,8 @@ class SpanF1PreRecMetric(unittest.TestCase):
expect_bio_res = set()
expect_bio_res.update([('7', (8, 9)), ('0', (2, 3)), ('2', (7, 8)), ('5', (5, 6)),
('6', (4, 5)), ('7', (6, 7))])
self.assertSetEqual(expect_bmes_res,set(bmes_tag_to_spans(bmes_lst)))
self.assertSetEqual(expect_bio_res, set(bio_tag_to_spans(bio_lst)))
self.assertSetEqual(expect_bmes_res,set(_bmes_tag_to_spans(bmes_lst)))
self.assertSetEqual(expect_bio_res, set(_bio_tag_to_spans(bio_lst)))
# 已与allennlp对应函数做过验证,但由于测试不能依赖allennlp,所以这里只是截取上面的例子做固定测试
# from allennlp.data.dataset_readers.dataset_utils import bio_tags_to_spans as allen_bio_tags_to_spans
# from allennlp.data.dataset_readers.dataset_utils import bmes_tags_to_spans as allen_bmes_tags_to_spans
@@ -161,8 +161,8 @@ class SpanF1PreRecMetric(unittest.TestCase):

def test_case2(self):
# 测试不带label的
from fastNLP.core.metrics import bmes_tag_to_spans
from fastNLP.core.metrics import bio_tag_to_spans
from fastNLP.core.metrics import _bmes_tag_to_spans
from fastNLP.core.metrics import _bio_tag_to_spans

bmes_lst = ['B', 'E', 'B', 'S', 'B', 'M', 'E', 'M', 'B', 'E']
bio_lst = ['I', 'B', 'O', 'O', 'I', 'O', 'I', 'B', 'O', 'O']
@@ -170,8 +170,8 @@ class SpanF1PreRecMetric(unittest.TestCase):
expect_bmes_res.update([('', (0, 2)), ('', (2, 3)), ('', (3, 4)), ('', (4, 7)), ('', (7, 8)), ('', (8, 10))])
expect_bio_res = set()
expect_bio_res.update([('', (7, 8)), ('', (6, 7)), ('', (4, 5)), ('', (0, 1)), ('', (1, 2))])
self.assertSetEqual(expect_bmes_res,set(bmes_tag_to_spans(bmes_lst)))
self.assertSetEqual(expect_bio_res, set(bio_tag_to_spans(bio_lst)))
self.assertSetEqual(expect_bmes_res,set(_bmes_tag_to_spans(bmes_lst)))
self.assertSetEqual(expect_bio_res, set(_bio_tag_to_spans(bio_lst)))
# 已与allennlp对应函数做过验证,但由于测试不能依赖allennlp,所以这里只是截取上面的例子做固定测试
# from allennlp.data.dataset_readers.dataset_utils import bio_tags_to_spans as allen_bio_tags_to_spans
# from allennlp.data.dataset_readers.dataset_utils import bmes_tags_to_spans as allen_bmes_tags_to_spans
@@ -366,7 +366,7 @@ class TestUsefulFunctions(unittest.TestCase):
# 测试metrics.py中一些看上去挺有用的函数
def test_case_1(self):
# multi-class
_ = accuracy_topk(np.random.randint(0, 3, size=(10, 1)), np.random.randint(0, 3, size=(10, 1)), k=3)
_ = pred_topk(np.random.randint(0, 3, size=(10, 1)))
_ = _accuracy_topk(np.random.randint(0, 3, size=(10, 1)), np.random.randint(0, 3, size=(10, 1)), k=3)
_ = _pred_topk(np.random.randint(0, 3, size=(10, 1)))

# 跑通即可

Loading…
Cancel
Save