|
|
@@ -269,7 +269,7 @@ class AccuracyMetric(MetricBase): |
|
|
|
|
|
|
|
:param pred: 参数映射表中 `pred` 的映射关系,None表示映射关系为 `pred` -> `pred` |
|
|
|
:param target: 参数映射表中 `target` 的映射关系,None表示映射关系为 `target` -> `target` |
|
|
|
:param seq_len: 参数映射表中 `seq_lens` 的映射关系,None表示映射关系为 `seq_len` -> `seq_len` |
|
|
|
:param seq_len: 参数映射表中 `seq_len` 的映射关系,None表示映射关系为 `seq_len` -> `seq_len` |
|
|
|
""" |
|
|
|
def __init__(self, pred=None, target=None, seq_len=None): |
|
|
|
|
|
|
@@ -458,7 +458,7 @@ class SpanFPreRecMetric(MetricBase): |
|
|
|
在解码时,会将相同xxx的认为是同一个label,比如['B-NN', 'E-NN']会被合并为一个'NN'. |
|
|
|
:param str pred: 用该key在evaluate()时从传入dict中取出prediction数据。 为None,则使用'pred'取数据 |
|
|
|
:param str target: 用该key在evaluate()时从传入dict中取出target数据。 为None,则使用'target'取数据 |
|
|
|
:param str seq_len: 用该key在evaluate()时从传入dict中取出sequence length数据。为None,则使用'seq_lens'取数据。 |
|
|
|
:param str seq_len: 用该key在evaluate()时从传入dict中取出sequence length数据。为None,则使用'seq_len'取数据。 |
|
|
|
:param str encoding_type: 目前支持bio, bmes |
|
|
|
:param list ignore_labels: str 组成的list. 这个list中的class不会被用于计算。例如在POS tagging时传入['NN'],则不会计算'NN'这 |
|
|
|
个label |
|
|
@@ -729,9 +729,8 @@ class BMESF1PreRecMetric(MetricBase): |
|
|
|
f"{pred.size()[:-1]}, got {target.size()}.") |
|
|
|
|
|
|
|
for idx in range(len(pred)): |
|
|
|
seq_len = seq_len[idx] |
|
|
|
target_tags = target[idx][:seq_len].tolist() |
|
|
|
pred_tags = pred[idx][:seq_len] |
|
|
|
target_tags = target[idx][:seq_len[idx]].tolist() |
|
|
|
pred_tags = pred[idx][:seq_len[idx]] |
|
|
|
pred_tags = self._validate_tags(pred_tags) |
|
|
|
start_idx = 0 |
|
|
|
for t_idx, (t_tag, p_tag) in enumerate(zip(target_tags, pred_tags)): |
|
|
|