|
@@ -23,6 +23,7 @@ from .utils import _get_func_signature |
|
|
from .utils import seq_len_to_mask |
|
|
from .utils import seq_len_to_mask |
|
|
from .vocabulary import Vocabulary |
|
|
from .vocabulary import Vocabulary |
|
|
from abc import abstractmethod |
|
|
from abc import abstractmethod |
|
|
|
|
|
import warnings |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class MetricBase(object): |
|
|
class MetricBase(object): |
|
@@ -492,6 +493,30 @@ def _bio_tag_to_spans(tags, ignore_labels=None): |
|
|
return [(span[0], (span[1][0], span[1][1] + 1)) for span in spans if span[0] not in ignore_labels] |
|
|
return [(span[0], (span[1][0], span[1][1] + 1)) for span in spans if span[0] not in ignore_labels] |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _check_tag_vocab_and_encoding_type(vocab:Vocabulary, encoding_type:str): |
|
|
|
|
|
""" |
|
|
|
|
|
检查vocab中的tag是否与encoding_type是匹配的 |
|
|
|
|
|
|
|
|
|
|
|
:param vocab: target的Vocabulary |
|
|
|
|
|
:param encoding_type: bio, bmes, bioes, bmeso |
|
|
|
|
|
:return: |
|
|
|
|
|
""" |
|
|
|
|
|
tag_set = set() |
|
|
|
|
|
for tag, idx in vocab: |
|
|
|
|
|
if idx in (vocab.unknown_idx, vocab.padding_idx): |
|
|
|
|
|
continue |
|
|
|
|
|
tag = tag[:1].lower() |
|
|
|
|
|
tag_set.add(tag) |
|
|
|
|
|
tags = encoding_type |
|
|
|
|
|
for tag in tag_set: |
|
|
|
|
|
assert tag in tags, f"{tag} is not a valid tag in encoding type:{encoding_type}. Please check your " \ |
|
|
|
|
|
f"encoding_type." |
|
|
|
|
|
tags = tags.replace(tag, '') # 删除该值 |
|
|
|
|
|
if tags: # 如果不为空,说明出现了未使用的tag |
|
|
|
|
|
warnings.warn(f"Tag:{tags} in encoding type:{encoding_type} is not presented in your Vocabulary. Check your " |
|
|
|
|
|
"encoding_type.") |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class SpanFPreRecMetric(MetricBase): |
|
|
class SpanFPreRecMetric(MetricBase): |
|
|
r""" |
|
|
r""" |
|
|
别名::class:`fastNLP.SpanFPreRecMetric` :class:`fastNLP.core.metrics.SpanFPreRecMetric` |
|
|
别名::class:`fastNLP.SpanFPreRecMetric` :class:`fastNLP.core.metrics.SpanFPreRecMetric` |
|
@@ -546,6 +571,7 @@ class SpanFPreRecMetric(MetricBase): |
|
|
raise ValueError("f_type only supports `micro` or `macro`', got {}.".format(f_type)) |
|
|
raise ValueError("f_type only supports `micro` or `macro`', got {}.".format(f_type)) |
|
|
|
|
|
|
|
|
self.encoding_type = encoding_type |
|
|
self.encoding_type = encoding_type |
|
|
|
|
|
_check_tag_vocab_and_encoding_type(tag_vocab, encoding_type) |
|
|
if self.encoding_type == 'bmes': |
|
|
if self.encoding_type == 'bmes': |
|
|
self.tag_to_span_func = _bmes_tag_to_spans |
|
|
self.tag_to_span_func = _bmes_tag_to_spans |
|
|
elif self.encoding_type == 'bio': |
|
|
elif self.encoding_type == 'bio': |
|
|