|
@@ -822,3 +822,154 @@ def pred_topk(y_prob, k=1): |
|
|
(1, k)) |
|
|
(1, k)) |
|
|
y_prob_topk = y_prob[x_axis_index, y_pred_topk] |
|
|
y_prob_topk = y_prob[x_axis_index, y_pred_topk] |
|
|
return y_pred_topk, y_prob_topk |
|
|
return y_pred_topk, y_prob_topk |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class SQuADMetric(MetricBase):
    """Exact-match (EM) and F-beta metric for SQuAD-style answer-span extraction.

    Samples are split into two groups according to the gold answer:
    "no answer" (the gold span is the empty-answer span) and "has answer".
    Counts are accumulated across calls to :meth:`evaluate` and turned into
    percentages by :meth:`get_metric`.
    """

    def __init__(self, pred_start=None, pred_end=None, target_start=None, target_end=None,
                 beta=1, right_open=False, print_predict_stat=False):
        """
        The four ``pred_*`` / ``target_*`` arguments are forwarded unchanged to
        ``self._init_param_map``. NOTE(review): they presumably name the fields
        that are later passed to :meth:`evaluate` -- confirm against MetricBase.

        :param pred_start: [batch], index at which the predicted answer starts;
            0 when the answer is empty (SQuAD 2.0).
        :param pred_end: [batch], index at which the predicted answer ends; for an
            empty answer (SQuAD 2.0) this is 0 with closed intervals or 1 with
            right-open intervals.
        :param target_start: [batch], index at which the gold answer starts;
            0 when the answer is empty (SQuAD 2.0).
        :param target_end: [batch], index at which the gold answer ends; for an
            empty answer (SQuAD 2.0) this is 0 with closed intervals or 1 with
            right-open intervals.
        :param beta: float, the beta of the F-beta score,
            f_beta = (1 + beta^2) * (pre * rec) / (beta^2 * pre + rec).
            Common values are 0.5, 1 and 2: with 0.5 precision weighs more than
            recall, with 1 they weigh the same, with 2 recall weighs more.
        :param right_open: boolean. True means start/end denote a left-closed,
            right-open interval; False means a closed interval.
        :param print_predict_stat: boolean. If True, :meth:`get_metric` also reports
            how often empty / non-empty gold answers were predicted as empty /
            non-empty; if False these statistics are omitted.
        """
        super().__init__()

        self._init_param_map(pred_start=pred_start, pred_end=pred_end,
                             target_start=target_start, target_end=target_end)

        self.print_predict_stat = print_predict_stat
        self.f_beta = beta
        self.right_open = right_open

        self._reset_counters()

    def _reset_counters(self):
        """Set every accumulated statistic back to zero."""
        # Exact-match counts for samples whose gold answer is empty.
        self.no_ans_correct = 0
        self.no_ans_wrong = 0

        # Exact-match counts for samples whose gold answer is not empty.
        self.has_ans_correct = 0
        self.has_ans_wrong = 0

        # Sum of the per-sample F-beta scores over the "has answer" samples.
        self.has_ans_f = 0.

        # Confusion counts named "<gold>2<predicted>", where "no" is the empty answer.
        self.no2no = 0
        self.no2yes = 0
        self.yes2no = 0
        self.yes2yes = 0

    def evaluate(self, pred_start, pred_end, target_start, target_end):
        """Accumulate statistics for one batch.

        The predicted start/end positions are the argmax over the last dimension
        of ``pred_start`` / ``pred_end``; if the predicted start lies after the
        predicted end, the two are swapped.

        :param pred_start: [batch, seq_len]
        :param pred_end: [batch, seq_len]
        :param target_start: [batch]
        :param target_end: [batch]
        :return: None, the counters on ``self`` are updated in place.
        """
        start_inference = pred_start.max(dim=-1)[1].cpu().tolist()
        end_inference = pred_end.max(dim=-1)[1].cpu().tolist()
        t_start = target_start.cpu().tolist()
        t_end = target_end.cpu().tolist()

        # Closed spans are turned into half-open ones by moving the end one step
        # to the right, so everything below works on [start, end) only.
        end_shift = 0 if self.right_open else 1
        beta_sq = self.f_beta ** 2

        for p_start, p_end, ts, te in zip(start_inference, end_inference, t_start, t_end):
            s = min(p_start, p_end)
            e = max(p_start, p_end) + end_shift
            te += end_shift

            # After normalisation the empty answer is [0, 1) in both modes:
            # (0, 0) closed becomes (0, 1), and (0, 1) right-open stays as is.
            pred_is_empty = s == 0 and e == 1
            gold_is_empty = ts == 0 and te == 1

            if gold_is_empty:
                if pred_is_empty:
                    self.no_ans_correct += 1
                    self.no2no += 1
                else:
                    self.no_ans_wrong += 1
                    self.no2yes += 1
                continue

            if pred_is_empty:
                self.yes2no += 1
            else:
                self.yes2yes += 1

            if s == ts and e == te:
                self.has_ans_correct += 1
            else:
                self.has_ans_wrong += 1

            # Token-level precision / recall from the overlap of the two
            # half-open intervals.
            overlap = max(0, min(e, te) - max(s, ts))
            pred_len = e - s
            gold_len = max(0, te - ts)
            pre = overlap / pred_len if pred_len > 0 else 0
            rec = overlap / gold_len if gold_len > 0 else 0

            # A sample without any overlap contributes an F-beta of 0.
            if pre + rec > 0:
                self.has_ans_f += (1 + beta_sq) * pre * rec / (beta_sq * pre + rec)

    def get_metric(self, reset=True):
        """Turn the accumulated counts into percentages.

        :param reset: if True, all counters are zeroed after the result is built.
        :return: dict. Empty when no sample has been evaluated. Otherwise it holds
            ``EM`` and ``f_<beta>``, each the unweighted mean of the per-group
            values that are present, plus ``noAns-EM`` / ``noAns-f_<beta>`` and
            ``hasAns-EM`` / ``hasAns-f_<beta>`` for every group that has at least
            one sample. With ``print_predict_stat`` set, the raw ``no2no``,
            ``no2yes``, ``yes2no`` and ``yes2yes`` counts are included as well.
            All scores are percentages rounded to 3 decimals.
        """
        evaluate_result = {}

        no_ans_total = self.no_ans_correct + self.no_ans_wrong
        has_ans_total = self.has_ans_correct + self.has_ans_wrong
        if no_ans_total + has_ans_total <= 0:
            return evaluate_result

        f_key = f'f_{self.f_beta}'

        # Create the overall entries first so they stay at the front of the dict;
        # they hold running sums until they are averaged further down.
        evaluate_result['EM'] = 0
        evaluate_result[f_key] = 0
        group_count = 0

        if no_ans_total > 0:
            # For an empty gold answer a prediction is either fully right or
            # fully wrong, so F-beta and EM are the same number.
            no_ans_score = round(100 * self.no_ans_correct / no_ans_total, 3)
            evaluate_result[f'noAns-{f_key}'] = no_ans_score
            evaluate_result['noAns-EM'] = no_ans_score
            evaluate_result[f_key] += no_ans_score
            evaluate_result['EM'] += no_ans_score
            group_count += 1

        if has_ans_total > 0:
            has_ans_f = round(100 * self.has_ans_f / has_ans_total, 3)
            has_ans_em = round(100 * self.has_ans_correct / has_ans_total, 3)
            evaluate_result[f'hasAns-{f_key}'] = has_ans_f
            evaluate_result['hasAns-EM'] = has_ans_em
            evaluate_result[f_key] += has_ans_f
            evaluate_result['EM'] += has_ans_em
            group_count += 1

        if self.print_predict_stat:
            evaluate_result['no2no'] = self.no2no
            evaluate_result['no2yes'] = self.no2yes
            evaluate_result['yes2no'] = self.yes2no
            evaluate_result['yes2yes'] = self.yes2yes

        # group_count is at least 1 here because at least one group is non-empty.
        evaluate_result[f_key] = round(evaluate_result[f_key] / group_count, 3)
        evaluate_result['EM'] = round(evaluate_result['EM'] / group_count, 3)

        if reset:
            self._reset_counters()

        return evaluate_result
|
|
|
|
|
|