Browse Source

add SQuAD metric

tags/v0.4.10
xuyige 5 years ago
parent
commit
55f65c3993
1 changed files with 151 additions and 0 deletions
  1. +151
    -0
      fastNLP/core/metrics.py

+ 151
- 0
fastNLP/core/metrics.py View File

@@ -822,3 +822,154 @@ def pred_topk(y_prob, k=1):
(1, k)) (1, k))
y_prob_topk = y_prob[x_axis_index, y_pred_topk] y_prob_topk = y_prob[x_axis_index, y_pred_topk]
return y_pred_topk, y_prob_topk return y_pred_topk, y_prob_topk


class SQuADMetric(MetricBase):
    """Exact-match (EM) and F-beta metric for SQuAD-style span extraction.

    Samples are split into two groups by their target span: "no answer"
    (the SQuAD 2.0 empty answer, encoded as a span starting at index 0) and
    "has answer". EM and F-beta are reported per group, and the overall
    ``EM`` / ``f_<beta>`` values are the unweighted mean of the groups that
    contain at least one sample.
    """

    def __init__(self, pred_start=None, pred_end=None, target_start=None, target_end=None,
                 beta=1, right_open=False, print_predict_stat=False):
        """
        The first four arguments are forwarded unchanged to ``_init_param_map``.
        NOTE(review): presumably they name the fields that are fed to the
        same-named parameters of :meth:`evaluate` -- confirm against MetricBase.

        :param pred_start: mapping for the predicted answer start. For an empty
            SQuAD 2.0 answer the predicted start index is 0.
        :param pred_end: mapping for the predicted answer end. For an empty
            SQuAD 2.0 answer the predicted end index is 0 (closed interval) or
            1 (half-open interval).
        :param target_start: mapping for the gold answer start. For an empty
            SQuAD 2.0 answer the index is 0.
        :param target_end: mapping for the gold answer end. For an empty
            SQuAD 2.0 answer the index is 0 (closed interval) or 1 (half-open
            interval).
        :param beta: float. Weight of the F-beta score,
            ``f_beta = (1 + beta^2) * (pre * rec) / (beta^2 * pre + rec)``.
            Common values are 0.5 (precision weighs more), 1 (balanced) and
            2 (recall weighs more).
        :param right_open: bool. ``True`` means the start/end pointers describe
            a half-open interval ``[start, end)``; ``False`` means a closed
            interval ``[start, end]``.
        :param print_predict_stat: bool. If ``True``, the result of
            :meth:`get_metric` also contains the counts of empty / non-empty
            predictions versus empty / non-empty targets.
        """
        super(SQuADMetric, self).__init__()

        self._init_param_map(pred_start=pred_start, pred_end=pred_end,
                             target_start=target_start, target_end=target_end)

        self.print_predict_stat = print_predict_stat
        self.f_beta = beta
        self.right_open = right_open

        self._reset_squad_counters()

    def _reset_squad_counters(self):
        """Set every accumulated statistic back to zero."""
        # Exact-match bookkeeping for samples whose target is the empty answer.
        self.no_ans_correct = 0
        self.no_ans_wrong = 0

        # Exact-match bookkeeping for samples whose target is a real span.
        self.has_ans_correct = 0
        self.has_ans_wrong = 0

        # Sum of the per-sample F-beta scores over the "has answer" samples.
        self.has_ans_f = 0.

        # Confusion counts "<target>2<prediction>" where "no" is the empty answer.
        self.no2no = 0
        self.no2yes = 0
        self.yes2no = 0
        self.yes2yes = 0

    def _is_no_answer(self, start, end):
        """Tell whether the half-open span ``[start, end)`` encodes "no answer".

        For closed-interval input the caller has already added 1 to ``end``,
        so the empty answer ``(0, 0)`` arrives here as ``(0, 1)``. For
        half-open input the documented empty answer is ``(0, 1)``; the
        zero-length span ``(0, 0)`` is accepted as well because that is what
        earlier versions of this metric matched.
        """
        if start != 0:
            return False
        if self.right_open:
            return end in (0, 1)
        return end == 1

    def evaluate(self, pred_start, pred_end, target_start, target_end):
        """Accumulate the statistics of one batch.

        The predicted start / end indices are the arg-max over the last
        dimension of ``pred_start`` / ``pred_end``. If the predicted start
        lies after the predicted end the two are swapped.

        :param pred_start: [batch, seq_len] tensor of start scores
        :param pred_end: [batch, seq_len] tensor of end scores
        :param target_start: [batch] tensor of gold start indices
        :param target_end: [batch] tensor of gold end indices
        :return: None
        """
        start_inference = pred_start.max(dim=-1)[1].cpu().tolist()
        end_inference = pred_end.max(dim=-1)[1].cpu().tolist()
        t_start = target_start.cpu().tolist()
        t_end = target_end.cpu().tolist()

        # Closed intervals are converted to half-open ones by moving the end one step right.
        end_shift = 0 if self.right_open else 1
        beta_square = self.f_beta ** 2

        for raw_s, raw_e, ts, te in zip(start_inference, end_inference, t_start, t_end):
            s, e = min(raw_s, raw_e), max(raw_s, raw_e)
            e += end_shift
            te += end_shift

            pred_is_empty = self._is_no_answer(s, e)

            if self._is_no_answer(ts, te):
                if pred_is_empty:
                    self.no_ans_correct += 1
                    self.no2no += 1
                else:
                    self.no_ans_wrong += 1
                    self.no2yes += 1
                continue

            if pred_is_empty:
                self.yes2no += 1
            else:
                self.yes2yes += 1

            if s == ts and e == te:
                self.has_ans_correct += 1
            else:
                self.has_ans_wrong += 1

            # Token-level overlap of [s, e) and [ts, te), computed arithmetically
            # instead of materialising two 0/1 masks of length seq_len.
            overlap = max(0, min(e, te) - max(s, ts))
            pred_len = max(0, e - s)
            target_len = max(0, te - ts)

            pre = overlap / pred_len if pred_len > 0 else 0
            rec = overlap / target_len if target_len > 0 else 0

            if pre + rec > 0:
                f = (1 + beta_square) * pre * rec / (beta_square * pre + rec)
            else:
                f = 0
            self.has_ans_f += f

    def get_metric(self, reset=True):
        """Return the accumulated scores as percentages rounded to 3 digits.

        :param reset: bool. If ``True``, all counters are cleared after the
            result has been computed.
        :return: dict. Empty if nothing has been evaluated. Otherwise it holds
            ``EM`` and ``f_<beta>`` (mean over the non-empty groups), the
            per-group ``noAns-*`` / ``hasAns-*`` entries for every non-empty
            group and, if ``print_predict_stat`` is set, the four confusion
            counts ``no2no``, ``no2yes``, ``yes2no`` and ``yes2yes``.
        """
        evaluate_result = {}

        no_ans_total = self.no_ans_correct + self.no_ans_wrong
        has_ans_total = self.has_ans_correct + self.has_ans_wrong

        if no_ans_total + has_ans_total <= 0:
            return evaluate_result

        f_key = f'f_{self.f_beta}'
        evaluate_result['EM'] = 0
        evaluate_result[f_key] = 0

        # Number of groups (no-answer / has-answer) that contribute to the mean.
        group_count = 0

        if no_ans_total > 0:
            # For the empty answer F-beta and exact match coincide.
            no_ans_score = round(100 * self.no_ans_correct / no_ans_total, 3)
            evaluate_result[f'noAns-f_{self.f_beta}'] = no_ans_score
            evaluate_result['noAns-EM'] = no_ans_score
            evaluate_result[f_key] += no_ans_score
            evaluate_result['EM'] += no_ans_score
            group_count += 1

        if has_ans_total > 0:
            has_ans_f = round(100 * self.has_ans_f / has_ans_total, 3)
            has_ans_em = round(100 * self.has_ans_correct / has_ans_total, 3)
            evaluate_result[f'hasAns-f_{self.f_beta}'] = has_ans_f
            evaluate_result['hasAns-EM'] = has_ans_em
            evaluate_result[f_key] += has_ans_f
            evaluate_result['EM'] += has_ans_em
            group_count += 1

        if self.print_predict_stat:
            evaluate_result['no2no'] = self.no2no
            evaluate_result['no2yes'] = self.no2yes
            evaluate_result['yes2no'] = self.yes2no
            evaluate_result['yes2yes'] = self.yes2yes

        # At least one group is non-empty here, so group_count >= 1.
        evaluate_result[f_key] = round(evaluate_result[f_key] / group_count, 3)
        evaluate_result['EM'] = round(evaluate_result['EM'] / group_count, 3)

        if reset:
            self._reset_squad_counters()

        return evaluate_result


Loading…
Cancel
Save