|
|
@@ -14,7 +14,7 @@ from fastNLP.core.vocabulary import Vocabulary |
|
|
|
|
|
|
|
|
|
|
|
class MetricBase(object): |
|
|
|
"""Base class for all metrics. |
|
|
|
"""所有metrics的基类 |
|
|
|
|
|
|
|
所有的传入到Trainer, Tester的Metric需要继承自该对象。需要覆盖写入evaluate(), get_metric()方法。 |
|
|
|
|
|
|
@@ -85,18 +85,22 @@ class MetricBase(object): |
|
|
|
return {'acc': acc} # 需要返回一个dict,key为该metric的名称,该名称会显示到Trainer的progress bar中 |
|
|
|
|
|
|
|
|
|
|
|
``MetricBase`` handles validity check of its input dictionaries - ``pred_dict`` and ``target_dict``. |
|
|
|
``pred_dict`` is the output of ``forward()`` or prediction function of a model. |
|
|
|
``target_dict`` is the ground truth from DataSet where ``is_target`` is set ``True``. |
|
|
|
``MetricBase`` will do the following type checks: |
|
|
|
``MetricBase`` 将会在输入的字典``pred_dict``和``target_dict``中进行检查. |
|
|
|
``pred_dict`` 是模型当中``forward()``函数或者``predict()``函数的返回值. |
|
|
|
``target_dict`` 是DataSet当中的ground truth, 判定ground truth的条件是field的``is_target``被设置为True. |
|
|
|
|
|
|
|
1. whether self.evaluate has varargs, which is not supported. |
|
|
|
2. whether params needed by self.evaluate is not included in ``pred_dict``, ``target_dict``. |
|
|
|
3. whether params needed by self.evaluate duplicate in ``pred_dict``, ``target_dict``. |
|
|
|
``MetricBase`` 会进行以下的类型检测: |
|
|
|
|
|
|
|
Besides, before passing params into self.evaluate, this function will filter out params from output_dict and |
|
|
|
target_dict which are not used in self.evaluate. (but if kwargs presented in self.evaluate, no filtering |
|
|
|
will be conducted.) |
|
|
|
1. self.evaluate当中是否有varargs, 这是不支持的. |
|
|
|
2. self.evaluate当中所需要的参数是否既不在``pred_dict``也不在``target_dict``. |
|
|
|
3. self.evaluate当中所需要的参数是否既在``pred_dict``也在``target_dict``. |
|
|
|
|
|
|
|
除此以外,在参数被传入self.evaluate以前,这个函数会检测``pred_dict``和``target_dict``当中没有被用到的参数 |
|
|
|
如果kwargs是self.evaluate的参数,则不会检测 |
|
|
|
|
|
|
|
|
|
|
|
self.evaluate将计算一个批次(batch)的评价指标,并累计 |
|
|
|
self.get_metric将统计当前的评价指标并返回评价结果 |
|
|
|
|
|
|
|
""" |
|
|
|
def __init__(self): |
|
|
@@ -107,10 +111,10 @@ class MetricBase(object): |
|
|
|
raise NotImplementedError |
|
|
|
|
|
|
|
def _init_param_map(self, key_map=None, **kwargs): |
|
|
|
"""Check the validity of key_map and other param map. Add these into self.param_map |
|
|
|
"""检查key_map和其他参数map,并将这些映射关系添加到self.param_map |
|
|
|
|
|
|
|
:param key_map: dict |
|
|
|
:param kwargs: |
|
|
|
:param dict key_map: 表示key的映射关系 |
|
|
|
:param kwargs: key word args里面的每一个的键-值对都会被构造成映射关系 |
|
|
|
:return: None |
|
|
|
""" |
|
|
|
value_counter = defaultdict(set) |
|
|
@@ -153,17 +157,16 @@ class MetricBase(object): |
|
|
|
|
|
|
|
def __call__(self, pred_dict, target_dict): |
|
|
|
""" |
|
|
|
|
|
|
|
This method will call self.evaluate method. |
|
|
|
Before calling self.evaluate, it will first check the validity of output_dict, target_dict |
|
|
|
(1) whether params needed by self.evaluate is not included in output_dict,target_dict. |
|
|
|
(2) whether params needed by self.evaluate duplicate in pred_dict, target_dict |
|
|
|
(3) whether params in output_dict, target_dict are not used by evaluate.(Might cause warning) |
|
|
|
Besides, before passing params into self.evaluate, this function will filter out params from output_dict and |
|
|
|
target_dict which are not used in self.evaluate. (but if **kwargs presented in self.evaluate, no filtering |
|
|
|
will be conducted.) |
|
|
|
:param pred_dict: usually the output of forward or prediction function |
|
|
|
:param target_dict: usually features set as target.. |
|
|
|
这个方法会调用self.evaluate 方法. |
|
|
|
在调用之前,会进行以下检测: |
|
|
|
1. self.evaluate当中是否有varargs, 这是不支持的. |
|
|
|
2. self.evaluate当中所需要的参数是否既不在``pred_dict``也不在``target_dict``. |
|
|
|
3. self.evaluate当中所需要的参数是否既在``pred_dict``也在``target_dict``. |
|
|
|
|
|
|
|
除此以外,在参数被传入self.evaluate以前,这个函数会检测``pred_dict``和``target_dict``当中没有被用到的参数 |
|
|
|
如果kwargs是self.evaluate的参数,则不会检测 |
|
|
|
:param pred_dict: 模型的forward函数或者predict函数返回的dict |
|
|
|
:param target_dict: DataSet.batch_y里的键-值对所组成的dict(即is_target=True的fields的内容) |
|
|
|
:return: |
|
|
|
""" |
|
|
|
if not callable(self.evaluate): |
|
|
@@ -235,25 +238,29 @@ class MetricBase(object): |
|
|
|
|
|
|
|
|
|
|
|
class AccuracyMetric(MetricBase): |
|
|
|
"""Accuracy Metric |
|
|
|
|
|
|
|
""" |
|
|
|
def __init__(self, pred=None, target=None, seq_lens=None): |
|
|
|
"""准确率Metric""" |
|
|
|
def __init__(self, pred=None, target=None, seq_len=None): |
|
|
|
""" |
|
|
|
:param pred: 参数映射表中`pred`的映射关系,None表示映射关系为`pred`->`pred` |
|
|
|
:param target: 参数映射表中`target`的映射关系,None表示映射关系为`target`->`target` |
|
|
|
:param seq_len: 参数映射表中`seq_lens`的映射关系,None表示映射关系为`seq_len`->`seq_len` |
|
|
|
""" |
|
|
|
super().__init__() |
|
|
|
|
|
|
|
self._init_param_map(pred=pred, target=target, seq_lens=seq_lens) |
|
|
|
self._init_param_map(pred=pred, target=target, seq_len=seq_len) |
|
|
|
|
|
|
|
self.total = 0 |
|
|
|
self.acc_count = 0 |
|
|
|
|
|
|
|
def evaluate(self, pred, target, seq_lens=None): |
|
|
|
""" |
|
|
|
def evaluate(self, pred, target, seq_len=None): |
|
|
|
"""evaluate函数将针对一个批次的预测结果做评价指标的累计 |
|
|
|
|
|
|
|
:param pred: . Element's shape can be: torch.Size([B,]), torch.Size([B, n_classes]), torch.Size([B, max_len]), |
|
|
|
torch.Size([B, max_len, n_classes]) |
|
|
|
:param target: Element's can be: torch.Size([B,]), torch.Size([B,]), torch.Size([B, max_len]), |
|
|
|
torch.Size([B, max_len]) |
|
|
|
:param seq_lens: Element's can be: None, None, torch.Size([B], torch.Size([B]). ignored if masks are provided. |
|
|
|
:param torch.Tensor pred: 预测的tensor, tensor的形状可以是torch.Size([B,]), torch.Size([B, n_classes]), |
|
|
|
torch.Size([B, max_len]), 或者torch.Size([B, max_len, n_classes]) |
|
|
|
:param torch.Tensor target: 真实值的tensor, tensor的形状可以是Element's can be: torch.Size([B,]), |
|
|
|
torch.Size([B,]), torch.Size([B, max_len]), 或者torch.Size([B, max_len]) |
|
|
|
:param torch.Tensor seq_len: 序列长度标记, 标记的形状可以是None, None, torch.Size([B]), 或者torch.Size([B]). |
|
|
|
如果mask也被传进来的话seq_len会被忽略. |
|
|
|
|
|
|
|
""" |
|
|
|
# TODO 这里报错需要更改,因为pred是啥用户并不知道。需要告知用户真实的value |
|
|
@@ -264,12 +271,12 @@ class AccuracyMetric(MetricBase): |
|
|
|
raise TypeError(f"`target` in {get_func_signature(self.evaluate)} must be torch.Tensor," |
|
|
|
f"got {type(target)}.") |
|
|
|
|
|
|
|
if seq_lens is not None and not isinstance(seq_lens, torch.Tensor): |
|
|
|
if seq_len is not None and not isinstance(seq_len, torch.Tensor): |
|
|
|
raise TypeError(f"`seq_lens` in {get_func_signature(self.evaluate)} must be torch.Tensor," |
|
|
|
f"got {type(seq_lens)}.") |
|
|
|
|
|
|
|
if seq_lens is not None: |
|
|
|
masks = seq_lens_to_masks(seq_lens=seq_lens) |
|
|
|
if seq_len is not None: |
|
|
|
masks = seq_lens_to_masks(seq_lens=seq_len) |
|
|
|
else: |
|
|
|
masks = None |
|
|
|
|
|
|
@@ -291,10 +298,10 @@ class AccuracyMetric(MetricBase): |
|
|
|
self.total += np.prod(list(pred.size())) |
|
|
|
|
|
|
|
def get_metric(self, reset=True): |
|
|
|
"""Returns computed metric. |
|
|
|
"""get_metric函数将根据evaluate函数累计的评价指标统计量来计算最终的评价结果. |
|
|
|
|
|
|
|
:param bool reset: whether to recount next time. |
|
|
|
:return evaluate_result: {"acc": float} |
|
|
|
:param bool reset: 在调用完get_metric后是否清空评价指标统计量. |
|
|
|
:return dict evaluate_result: {"acc": float} |
|
|
|
""" |
|
|
|
evaluate_result = {'acc': round(float(self.acc_count) / (self.total + 1e-12), 6)} |
|
|
|
if reset: |
|
|
@@ -302,7 +309,8 @@ class AccuracyMetric(MetricBase): |
|
|
|
self.total = 0 |
|
|
|
return evaluate_result |
|
|
|
|
|
|
|
def bmes_tag_to_spans(tags, ignore_labels=None): |
|
|
|
|
|
|
|
def _bmes_tag_to_spans(tags, ignore_labels=None): |
|
|
|
""" |
|
|
|
给定一个tags的lis,比如['S', 'B-singer', 'M-singer', 'E-singer', 'S', 'S']。 |
|
|
|
返回[('', (0, 1)), ('singer', (1, 4)), ('', (4, 5)), ('', (5, 6))] (左闭右开区间) |
|
|
@@ -330,7 +338,8 @@ def bmes_tag_to_spans(tags, ignore_labels=None): |
|
|
|
if span[0] not in ignore_labels |
|
|
|
] |
|
|
|
|
|
|
|
def bmeso_tag_to_spans(tags, ignore_labels=None): |
|
|
|
|
|
|
|
def _bmeso_tag_to_spans(tags, ignore_labels=None): |
|
|
|
""" |
|
|
|
给定一个tags的lis,比如['O', 'B-singer', 'M-singer', 'E-singer', 'O', 'O']。 |
|
|
|
返回[('singer', (1, 4))] (左闭右开区间) |
|
|
@@ -360,7 +369,8 @@ def bmeso_tag_to_spans(tags, ignore_labels=None): |
|
|
|
if span[0] not in ignore_labels |
|
|
|
] |
|
|
|
|
|
|
|
def bio_tag_to_spans(tags, ignore_labels=None): |
|
|
|
|
|
|
|
def _bio_tag_to_spans(tags, ignore_labels=None): |
|
|
|
""" |
|
|
|
给定一个tags的lis,比如['O', 'B-singer', 'I-singer', 'I-singer', 'O', 'O']。 |
|
|
|
返回[('singer', (1, 4))] (左闭右开区间) |
|
|
@@ -385,9 +395,7 @@ def bio_tag_to_spans(tags, ignore_labels=None): |
|
|
|
else: |
|
|
|
spans.append((label, [idx, idx])) |
|
|
|
prev_bio_tag = bio_tag |
|
|
|
return [(span[0], (span[1][0], span[1][1]+1)) |
|
|
|
for span in spans |
|
|
|
if span[0] not in ignore_labels] |
|
|
|
return [(span[0], (span[1][0], span[1][1]+1)) for span in spans if span[0] not in ignore_labels] |
|
|
|
|
|
|
|
|
|
|
|
class SpanFPreRecMetric(MetricBase): |
|
|
@@ -416,23 +424,23 @@ class SpanFPreRecMetric(MetricBase): |
|
|
|
} |
|
|
|
|
|
|
|
""" |
|
|
|
def __init__(self, tag_vocab, pred=None, target=None, seq_lens=None, encoding_type='bio', ignore_labels=None, |
|
|
|
def __init__(self, tag_vocab, pred=None, target=None, seq_len=None, encoding_type='bio', ignore_labels=None, |
|
|
|
only_gross=True, f_type='micro', beta=1): |
|
|
|
""" |
|
|
|
|
|
|
|
:param tag_vocab: Vocabulary, 标签的vocabulary。支持的标签为"B"(没有label);或"B-xxx"(xxx为某种label,比如POS中的NN), |
|
|
|
:param Vocabulary tag_vocab: 标签的vocabulary。支持的标签为"B"(没有label);或"B-xxx"(xxx为某种label,比如POS中的NN), |
|
|
|
在解码时,会将相同xxx的认为是同一个label,比如['B-NN', 'E-NN']会被合并为一个'NN'. |
|
|
|
:param pred: str, 用该key在evaluate()时从传入dict中取出prediction数据。 为None,则使用'pred'取数据 |
|
|
|
:param target: str, 用该key在evaluate()时从传入dict中取出target数据。 为None,则使用'target'取数据 |
|
|
|
:param seq_lens: str, 用该key在evaluate()时从传入dict中取出sequence length数据。为None,则使用'seq_lens'取数据。 |
|
|
|
:param encoding_type: str, 目前支持bio, bmes |
|
|
|
:param ignore_labels, List[str]. 这个list中的class不会被用于计算。例如在POS tagging时传入['NN'],则不会计算'NN'这 |
|
|
|
:param str pred: 用该key在evaluate()时从传入dict中取出prediction数据。 为None,则使用'pred'取数据 |
|
|
|
:param str target: 用该key在evaluate()时从传入dict中取出target数据。 为None,则使用'target'取数据 |
|
|
|
:param str seq_len: 用该key在evaluate()时从传入dict中取出sequence length数据。为None,则使用'seq_lens'取数据。 |
|
|
|
:param str encoding_type: 目前支持bio, bmes |
|
|
|
:param list ignore_labels: str 组成的list. 这个list中的class不会被用于计算。例如在POS tagging时传入['NN'],则不会计算'NN'这 |
|
|
|
个label |
|
|
|
:param only_gross, bool. 是否只计算总的f1, precision, recall的值;如果为False,不仅返回总的f1, pre, rec, 还会返回每个 |
|
|
|
:param bool only_gross: 是否只计算总的f1, precision, recall的值;如果为False,不仅返回总的f1, pre, rec, 还会返回每个 |
|
|
|
label的f1, pre, rec |
|
|
|
:param f_type, str. 'micro'或'macro'. 'micro':通过先计算总体的TP,FN和FP的数量,再计算f, precision, recall; 'macro': |
|
|
|
:param str f_type: 'micro'或'macro'. 'micro':通过先计算总体的TP,FN和FP的数量,再计算f, precision, recall; 'macro': |
|
|
|
分布计算每个类别的f, precision, recall,然后做平均(各类别f的权重相同) |
|
|
|
:param beta, float. f_beta分数,f_beta = (1 + beta^2)*(pre*rec)/(beta^2*pre + rec). 常用为beta=0.5, 1, 2. 若为0.5 |
|
|
|
:param float beta: f_beta分数,f_beta = (1 + beta^2)*(pre*rec)/(beta^2*pre + rec). 常用为beta=0.5, 1, 2. 若为0.5 |
|
|
|
则精确率的权重高于召回率;若为1,则两者平等;若为2,则召回率权重高于精确率。 |
|
|
|
""" |
|
|
|
encoding_type = encoding_type.lower() |
|
|
@@ -444,11 +452,11 @@ class SpanFPreRecMetric(MetricBase): |
|
|
|
|
|
|
|
self.encoding_type = encoding_type |
|
|
|
if self.encoding_type == 'bmes': |
|
|
|
self.tag_to_span_func = bmes_tag_to_spans |
|
|
|
self.tag_to_span_func = _bmes_tag_to_spans |
|
|
|
elif self.encoding_type == 'bio': |
|
|
|
self.tag_to_span_func = bio_tag_to_spans |
|
|
|
self.tag_to_span_func = _bio_tag_to_spans |
|
|
|
elif self.encoding_type == 'bmeso': |
|
|
|
self.tag_to_span_func = bmeso_tag_to_spans |
|
|
|
self.tag_to_span_func = _bmeso_tag_to_spans |
|
|
|
else: |
|
|
|
raise ValueError("Only support 'bio', 'bmes', 'bmeso' type.") |
|
|
|
|
|
|
@@ -459,7 +467,7 @@ class SpanFPreRecMetric(MetricBase): |
|
|
|
self.only_gross = only_gross |
|
|
|
|
|
|
|
super().__init__() |
|
|
|
self._init_param_map(pred=pred, target=target, seq_lens=seq_lens) |
|
|
|
self._init_param_map(pred=pred, target=target, seq_len=seq_len) |
|
|
|
|
|
|
|
self.tag_vocab = tag_vocab |
|
|
|
|
|
|
@@ -467,12 +475,12 @@ class SpanFPreRecMetric(MetricBase): |
|
|
|
self._false_positives = defaultdict(int) |
|
|
|
self._false_negatives = defaultdict(int) |
|
|
|
|
|
|
|
def evaluate(self, pred, target, seq_lens): |
|
|
|
""" |
|
|
|
A lot of design idea comes from allennlp's measure |
|
|
|
:param pred: |
|
|
|
:param target: |
|
|
|
:param seq_lens: |
|
|
|
def evaluate(self, pred, target, seq_len): |
|
|
|
"""evaluate函数将针对一个批次的预测结果做评价指标的累计 |
|
|
|
|
|
|
|
:param pred: [batch, seq_len] 或者 [batch, seq_len, len(tag_vocab)], 预测的结果 |
|
|
|
:param target: [batch, seq_len], 真实值 |
|
|
|
:param seq_len: [batch] 文本长度标记 |
|
|
|
:return: |
|
|
|
""" |
|
|
|
if not isinstance(pred, torch.Tensor): |
|
|
@@ -482,9 +490,9 @@ class SpanFPreRecMetric(MetricBase): |
|
|
|
raise TypeError(f"`target` in {get_func_signature(self.evaluate)} must be torch.Tensor," |
|
|
|
f"got {type(target)}.") |
|
|
|
|
|
|
|
if not isinstance(seq_lens, torch.Tensor): |
|
|
|
if not isinstance(seq_len, torch.Tensor): |
|
|
|
raise TypeError(f"`seq_lens` in {get_func_signature(self.evaluate)} must be torch.Tensor," |
|
|
|
f"got {type(seq_lens)}.") |
|
|
|
f"got {type(seq_len)}.") |
|
|
|
|
|
|
|
if pred.size() == target.size() and len(target.size()) == 2: |
|
|
|
pass |
|
|
@@ -501,8 +509,8 @@ class SpanFPreRecMetric(MetricBase): |
|
|
|
|
|
|
|
batch_size = pred.size(0) |
|
|
|
for i in range(batch_size): |
|
|
|
pred_tags = pred[i, :int(seq_lens[i])].tolist() |
|
|
|
gold_tags = target[i, :int(seq_lens[i])].tolist() |
|
|
|
pred_tags = pred[i, :int(seq_len[i])].tolist() |
|
|
|
gold_tags = target[i, :int(seq_len[i])].tolist() |
|
|
|
|
|
|
|
pred_str_tags = [self.tag_vocab.to_word(tag) for tag in pred_tags] |
|
|
|
gold_str_tags = [self.tag_vocab.to_word(tag) for tag in gold_tags] |
|
|
@@ -520,8 +528,9 @@ class SpanFPreRecMetric(MetricBase): |
|
|
|
self._false_negatives[span[0]] += 1 |
|
|
|
|
|
|
|
def get_metric(self, reset=True): |
|
|
|
"""get_metric函数将根据evaluate函数累计的评价指标统计量来计算最终的评价结果.""" |
|
|
|
evaluate_result = {} |
|
|
|
if not self.only_gross or self.f_type=='macro': |
|
|
|
if not self.only_gross or self.f_type == 'macro': |
|
|
|
tags = set(self._false_negatives.keys()) |
|
|
|
tags.update(set(self._false_positives.keys())) |
|
|
|
tags.update(set(self._true_positives.keys())) |
|
|
@@ -578,6 +587,7 @@ class SpanFPreRecMetric(MetricBase): |
|
|
|
|
|
|
|
return f, pre, rec |
|
|
|
|
|
|
|
|
|
|
|
class BMESF1PreRecMetric(MetricBase): |
|
|
|
""" |
|
|
|
按照BMES标注方式计算f1, precision, recall。由于可能存在非法tag,比如"BS",所以需要用以下的表格做转换,cur_B意思是当前tag是B, |
|
|
@@ -607,7 +617,7 @@ class BMESF1PreRecMetric(MetricBase): |
|
|
|
|
|
|
|
""" |
|
|
|
|
|
|
|
def __init__(self, b_idx=0, m_idx=1, e_idx=2, s_idx=3, pred=None, target=None, seq_lens=None): |
|
|
|
def __init__(self, b_idx=0, m_idx=1, e_idx=2, s_idx=3, pred=None, target=None, seq_len=None): |
|
|
|
""" |
|
|
|
需要申明BMES这四种tag中,各种tag对应的idx。所有不为b_idx, m_idx, e_idx, s_idx的数字都认为是s_idx。 |
|
|
|
|
|
|
@@ -617,11 +627,11 @@ class BMESF1PreRecMetric(MetricBase): |
|
|
|
:param s_idx: int, Single标签所对应的tag idx |
|
|
|
:param pred: str, 用该key在evaluate()时从传入dict中取出prediction数据。 为None,则使用'pred'取数据 |
|
|
|
:param target: str, 用该key在evaluate()时从传入dict中取出target数据。 为None,则使用'target'取数据 |
|
|
|
:param seq_lens: str, 用该key在evaluate()时从传入dict中取出seqence length数据。为None,则使用'seq_lens'取数据。 |
|
|
|
:param seq_len: str, 用该key在evaluate()时从传入dict中取出seqence length数据。为None,则使用'seq_len'取数据。 |
|
|
|
""" |
|
|
|
super().__init__() |
|
|
|
|
|
|
|
self._init_param_map(pred=pred, target=target, seq_lens=seq_lens) |
|
|
|
self._init_param_map(pred=pred, target=target, seq_len=seq_len) |
|
|
|
|
|
|
|
self.yt_wordnum = 0 |
|
|
|
self.yp_wordnum = 0 |
|
|
@@ -644,7 +654,7 @@ class BMESF1PreRecMetric(MetricBase): |
|
|
|
""" |
|
|
|
给定一个tag的Tensor,返回合法tag |
|
|
|
|
|
|
|
:param tags: Tensor, shape: (seq_len, ) |
|
|
|
:param torch.Tensor tags: [seq_len] |
|
|
|
:return: 返回修改为合法tag的list |
|
|
|
""" |
|
|
|
assert len(tags)!=0 |
|
|
@@ -663,7 +673,14 @@ class BMESF1PreRecMetric(MetricBase): |
|
|
|
|
|
|
|
return padded_tags[1:-1] |
|
|
|
|
|
|
|
def evaluate(self, pred, target, seq_lens): |
|
|
|
def evaluate(self, pred, target, seq_len): |
|
|
|
"""evaluate函数将针对一个批次的预测结果做评价指标的累计 |
|
|
|
|
|
|
|
:param pred: [batch, seq_len] 或者 [batch, seq_len, 4] |
|
|
|
:param target: [batch, seq_len] |
|
|
|
:param seq_len: [batch] |
|
|
|
:return: |
|
|
|
""" |
|
|
|
if not isinstance(pred, torch.Tensor): |
|
|
|
raise TypeError(f"`pred` in {get_func_signature(self.evaluate)} must be torch.Tensor," |
|
|
|
f"got {type(pred)}.") |
|
|
@@ -671,9 +688,9 @@ class BMESF1PreRecMetric(MetricBase): |
|
|
|
raise TypeError(f"`target` in {get_func_signature(self.evaluate)} must be torch.Tensor," |
|
|
|
f"got {type(target)}.") |
|
|
|
|
|
|
|
if not isinstance(seq_lens, torch.Tensor): |
|
|
|
if not isinstance(seq_len, torch.Tensor): |
|
|
|
raise TypeError(f"`seq_lens` in {get_func_signature(self.evaluate)} must be torch.Tensor," |
|
|
|
f"got {type(seq_lens)}.") |
|
|
|
f"got {type(seq_len)}.") |
|
|
|
|
|
|
|
if pred.size() == target.size() and len(target.size()) == 2: |
|
|
|
pass |
|
|
@@ -685,7 +702,7 @@ class BMESF1PreRecMetric(MetricBase): |
|
|
|
f"{pred.size()[:-1]}, got {target.size()}.") |
|
|
|
|
|
|
|
for idx in range(len(pred)): |
|
|
|
seq_len = seq_lens[idx] |
|
|
|
seq_len = seq_len[idx] |
|
|
|
target_tags = target[idx][:seq_len].tolist() |
|
|
|
pred_tags = pred[idx][:seq_len] |
|
|
|
pred_tags = self._validate_tags(pred_tags) |
|
|
@@ -704,6 +721,7 @@ class BMESF1PreRecMetric(MetricBase): |
|
|
|
self.yp_wordnum += 1 |
|
|
|
|
|
|
|
def get_metric(self, reset=True): |
|
|
|
"""get_metric函数将根据evaluate函数累计的评价指标统计量来计算最终的评价结果.""" |
|
|
|
P = self.corr_num / (self.yp_wordnum + 1e-12) |
|
|
|
R = self.corr_num / (self.yt_wordnum + 1e-12) |
|
|
|
F = 2 * P * R / (P + R + 1e-12) |
|
|
@@ -746,7 +764,7 @@ def _prepare_metrics(metrics): |
|
|
|
return _metrics |
|
|
|
|
|
|
|
|
|
|
|
def accuracy_topk(y_true, y_prob, k=1): |
|
|
|
def _accuracy_topk(y_true, y_prob, k=1): |
|
|
|
"""Compute accuracy of y_true matching top-k probable labels in y_prob. |
|
|
|
|
|
|
|
:param y_true: ndarray, true label, [n_samples] |
|
|
@@ -762,7 +780,7 @@ def accuracy_topk(y_true, y_prob, k=1): |
|
|
|
return acc |
|
|
|
|
|
|
|
|
|
|
|
def pred_topk(y_prob, k=1): |
|
|
|
def _pred_topk(y_prob, k=1): |
|
|
|
"""Return top-k predicted labels and corresponding probabilities. |
|
|
|
|
|
|
|
:param y_prob: ndarray, size [n_samples, n_classes], probabilities on labels |
|
|
@@ -781,22 +799,24 @@ def pred_topk(y_prob, k=1): |
|
|
|
|
|
|
|
|
|
|
|
class SQuADMetric(MetricBase): |
|
|
|
"""SQuAD数据集metric |
|
|
|
""" |
|
|
|
|
|
|
|
def __init__(self, pred_start=None, pred_end=None, target_start=None, target_end=None, |
|
|
|
beta=1, right_open=False, print_predict_stat=False): |
|
|
|
def __init__(self, pred1=None, pred2=None, target1=None, target2=None, |
|
|
|
beta=1, right_open=True, print_predict_stat=False): |
|
|
|
""" |
|
|
|
:param pred_start: [batch], 预测答案开始的index, 如果SQuAD2.0中答案为空则为0 |
|
|
|
:param pred_end: [batch], 预测答案结束的index, 如果SQuAD2.0中答案为空则为0(左闭右闭区间)或者1(左闭右开区间) |
|
|
|
:param target_start: [batch], 正确答案开始的index, 如果SQuAD2.0中答案为空则为0 |
|
|
|
:param target_end: [batch], 正确答案结束的index, 如果SQuAD2.0中答案为空则为0(左闭右闭区间)或者1(左闭右开区间) |
|
|
|
:param beta: float. f_beta分数,f_beta = (1 + beta^2)*(pre*rec)/(beta^2*pre + rec). 常用为beta=0.5, 1, 2. 若为0.5 |
|
|
|
:param pred1: 参数映射表中`pred1`的映射关系,None表示映射关系为`pred1`->`pred1` |
|
|
|
:param pred2: 参数映射表中`pred2`的映射关系,None表示映射关系为`pred2`->`pred2` |
|
|
|
:param target1: 参数映射表中`target1`的映射关系,None表示映射关系为`target1`->`target1` |
|
|
|
:param target2: 参数映射表中`target2`的映射关系,None表示映射关系为`target2`->`target2` |
|
|
|
:param float beta: f_beta分数,f_beta = (1 + beta^2)*(pre*rec)/(beta^2*pre + rec). 常用为beta=0.5, 1, 2. 若为0.5 |
|
|
|
则精确率的权重高于召回率;若为1,则两者平等;若为2,则召回率权重高于精确率。 |
|
|
|
:param right_open: boolean. right_open为true表示start跟end指针指向一个左闭右开区间,为false表示指向一个左闭右闭区间。 |
|
|
|
:param print_predict_stat: boolean. True则输出预测答案是否为空与正确答案是否为空的统计信息, False则不输出 |
|
|
|
:param bool right_open: right_open为true表示start跟end指针指向一个左闭右开区间,为false表示指向一个左闭右闭区间。 |
|
|
|
:param bool print_predict_stat: True则输出预测答案是否为空与正确答案是否为空的统计信息, False则不输出 |
|
|
|
""" |
|
|
|
super(SQuADMetric, self).__init__() |
|
|
|
|
|
|
|
self._init_param_map(pred_start=pred_start, pred_end=pred_end, target_start=target_start, target_end=target_end) |
|
|
|
self._init_param_map(pred1=pred1, pred2=pred2, target1=target1, target2=target2) |
|
|
|
|
|
|
|
self.print_predict_stat = print_predict_stat |
|
|
|
|
|
|
@@ -817,18 +837,29 @@ class SQuADMetric(MetricBase): |
|
|
|
|
|
|
|
self.right_open = right_open |
|
|
|
|
|
|
|
def evaluate(self, pred_start, pred_end, target_start, target_end): |
|
|
|
""" |
|
|
|
def evaluate(self, pred1, pred2, target1, target2): |
|
|
|
"""evaluate函数将针对一个批次的预测结果做评价指标的累计 |
|
|
|
|
|
|
|
:param pred_start: [batch, seq_len] |
|
|
|
:param pred_end: [batch, seq_len] |
|
|
|
:param target_start: [batch] |
|
|
|
:param target_end: [batch] |
|
|
|
:param labels: [batch] |
|
|
|
:return: |
|
|
|
:param pred1: [batch]或者[batch, seq_len], 预测答案开始的index, 如果SQuAD2.0中答案为空则为0 |
|
|
|
:param pred2: [batch]或者[batch, seq_len] 预测答案结束的index, 如果SQuAD2.0中答案为空则为0(左闭右闭区间)或者1(左闭右开区间) |
|
|
|
:param target1: [batch], 正确答案开始的index, 如果SQuAD2.0中答案为空则为0 |
|
|
|
:param target2: [batch], 正确答案结束的index, 如果SQuAD2.0中答案为空则为0(左闭右闭区间)或者1(左闭右开区间) |
|
|
|
:return: None |
|
|
|
""" |
|
|
|
start_inference = pred_start.max(dim=-1)[1].cpu().tolist() |
|
|
|
end_inference = pred_end.max(dim=-1)[1].cpu().tolist() |
|
|
|
pred_start = pred1 |
|
|
|
pred_end = pred2 |
|
|
|
target_start = target1 |
|
|
|
target_end = target2 |
|
|
|
|
|
|
|
if len(pred_start.size()) == 2: |
|
|
|
start_inference = pred_start.max(dim=-1)[1].cpu().tolist() |
|
|
|
else: |
|
|
|
start_inference = pred_start.cpu().tolist() |
|
|
|
if len(pred_end.size()) == 2: |
|
|
|
end_inference = pred_end.max(dim=-1)[1].cpu().tolist() |
|
|
|
else: |
|
|
|
end_inference = pred_end.cpu().tolist() |
|
|
|
|
|
|
|
start, end = [], [] |
|
|
|
max_len = pred_start.size(1) |
|
|
|
t_start = target_start.cpu().tolist() |
|
|
@@ -873,6 +904,7 @@ class SQuADMetric(MetricBase): |
|
|
|
self.has_ans_f += f |
|
|
|
|
|
|
|
def get_metric(self, reset=True): |
|
|
|
"""get_metric函数将根据evaluate函数累计的评价指标统计量来计算最终的评价结果.""" |
|
|
|
evaluate_result = {} |
|
|
|
|
|
|
|
if self.no_ans_correct + self.no_ans_wrong + self.has_ans_correct + self.no_ans_wrong <= 0: |
|
|
|