@@ -23,6 +23,7 @@ from .utils import _get_func_signature | |||
from .utils import seq_len_to_mask | |||
from .vocabulary import Vocabulary | |||
from abc import abstractmethod | |||
import warnings | |||
class MetricBase(object): | |||
@@ -492,6 +493,30 @@ def _bio_tag_to_spans(tags, ignore_labels=None): | |||
return [(span[0], (span[1][0], span[1][1] + 1)) for span in spans if span[0] not in ignore_labels] | |||
def _check_tag_vocab_and_encoding_type(vocab:Vocabulary, encoding_type:str): | |||
""" | |||
检查vocab中的tag是否与encoding_type是匹配的 | |||
:param vocab: target的Vocabulary | |||
:param encoding_type: bio, bmes, bioes, bmeso | |||
:return: | |||
""" | |||
tag_set = set() | |||
for tag, idx in vocab: | |||
if idx in (vocab.unknown_idx, vocab.padding_idx): | |||
continue | |||
tag = tag[:1].lower() | |||
tag_set.add(tag) | |||
tags = encoding_type | |||
for tag in tag_set: | |||
assert tag in tags, f"{tag} is not a valid tag in encoding type:{encoding_type}. Please check your " \ | |||
f"encoding_type." | |||
tags = tags.replace(tag, '') # 删除该值 | |||
if tags: # 如果不为空,说明出现了未使用的tag | |||
warnings.warn(f"Tag:{tags} in encoding type:{encoding_type} is not presented in your Vocabulary. Check your " | |||
"encoding_type.") | |||
class SpanFPreRecMetric(MetricBase): | |||
r""" | |||
别名::class:`fastNLP.SpanFPreRecMetric` :class:`fastNLP.core.metrics.SpanFPreRecMetric` | |||
@@ -546,6 +571,7 @@ class SpanFPreRecMetric(MetricBase): | |||
raise ValueError("f_type only supports `micro` or `macro`', got {}.".format(f_type)) | |||
self.encoding_type = encoding_type | |||
_check_tag_vocab_and_encoding_type(tag_vocab, encoding_type) | |||
if self.encoding_type == 'bmes': | |||
self.tag_to_span_func = _bmes_tag_to_spans | |||
elif self.encoding_type == 'bio': | |||
@@ -338,6 +338,41 @@ class SpanF1PreRecMetric(unittest.TestCase): | |||
for key, value in expected_metric.items(): | |||
self.assertAlmostEqual(value, metric_value[key], places=5) | |||
def test_encoding_type(self): | |||
# 检查传入的tag_vocab与encoding_type不符合时,是否会报错 | |||
vocabs = {} | |||
import random | |||
from itertools import product | |||
for encoding_type in ['bio', 'bioes', 'bmeso']: | |||
vocab = Vocabulary(unknown=None, padding=None) | |||
for i in range(random.randint(10, 100)): | |||
label = str(random.randint(1, 10)) | |||
for tag in encoding_type: | |||
if tag!='o': | |||
vocab.add_word(f'{tag}-{label}') | |||
else: | |||
vocab.add_word('o') | |||
vocabs[encoding_type] = vocab | |||
for e1, e2 in product(['bio', 'bioes', 'bmeso'], ['bio', 'bioes', 'bmeso']): | |||
with self.subTest(e1=e1, e2=e2): | |||
if e1==e2: | |||
metric = SpanFPreRecMetric(vocabs[e1], encoding_type=e2) | |||
else: | |||
s2 = set(e2) | |||
s2.update(set(e1)) | |||
if s2==set(e2): | |||
continue | |||
with self.assertRaises(AssertionError): | |||
metric = SpanFPreRecMetric(vocabs[e1], encoding_type=e2) | |||
for encoding_type in ['bio', 'bioes', 'bmeso']: | |||
with self.assertRaises(AssertionError): | |||
metric = SpanFPreRecMetric(vocabs[encoding_type], encoding_type='bmes') | |||
with self.assertWarns(Warning): | |||
vocab = Vocabulary(unknown=None, padding=None).add_word_lst(list('bmes')) | |||
metric = SpanFPreRecMetric(vocab, encoding_type='bmeso') | |||
vocab = Vocabulary().add_word_lst(list('bmes')) | |||
metric = SpanFPreRecMetric(vocab, encoding_type='bmeso') | |||
class TestUsefulFunctions(unittest.TestCase): | |||
# 测试metrics.py中一些看上去挺有用的函数 | |||