Browse Source

Merge branch 'dev0.5.0' of https://github.com/fastnlp/fastNLP into dev0.5.0

tags/v0.4.10
xuyige 5 years ago
parent
commit
8142bad87a
2 changed files with 61 additions and 0 deletions
  1. +26
    -0
      fastNLP/core/metrics.py
  2. +35
    -0
      test/core/test_metrics.py

+ 26
- 0
fastNLP/core/metrics.py View File

@@ -23,6 +23,7 @@ from .utils import _get_func_signature
from .utils import seq_len_to_mask
from .vocabulary import Vocabulary
from abc import abstractmethod
import warnings


class MetricBase(object):
@@ -492,6 +493,30 @@ def _bio_tag_to_spans(tags, ignore_labels=None):
return [(span[0], (span[1][0], span[1][1] + 1)) for span in spans if span[0] not in ignore_labels]


def _check_tag_vocab_and_encoding_type(vocab:Vocabulary, encoding_type:str):
"""
检查vocab中的tag是否与encoding_type是匹配的

:param vocab: target的Vocabulary
:param encoding_type: bio, bmes, bioes, bmeso
:return:
"""
tag_set = set()
for tag, idx in vocab:
if idx in (vocab.unknown_idx, vocab.padding_idx):
continue
tag = tag[:1].lower()
tag_set.add(tag)
tags = encoding_type
for tag in tag_set:
assert tag in tags, f"{tag} is not a valid tag in encoding type:{encoding_type}. Please check your " \
f"encoding_type."
tags = tags.replace(tag, '') # 删除该值
if tags: # 如果不为空,说明出现了未使用的tag
warnings.warn(f"Tag:{tags} in encoding type:{encoding_type} is not presented in your Vocabulary. Check your "
"encoding_type.")


class SpanFPreRecMetric(MetricBase):
r"""
别名::class:`fastNLP.SpanFPreRecMetric` :class:`fastNLP.core.metrics.SpanFPreRecMetric`
@@ -546,6 +571,7 @@ class SpanFPreRecMetric(MetricBase):
raise ValueError("f_type only supports `micro` or `macro`', got {}.".format(f_type))
self.encoding_type = encoding_type
_check_tag_vocab_and_encoding_type(tag_vocab, encoding_type)
if self.encoding_type == 'bmes':
self.tag_to_span_func = _bmes_tag_to_spans
elif self.encoding_type == 'bio':


+ 35
- 0
test/core/test_metrics.py View File

@@ -338,6 +338,41 @@ class SpanF1PreRecMetric(unittest.TestCase):
for key, value in expected_metric.items():
self.assertAlmostEqual(value, metric_value[key], places=5)

def test_encoding_type(self):
# 检查传入的tag_vocab与encoding_type不符合时,是否会报错
vocabs = {}
import random
from itertools import product
for encoding_type in ['bio', 'bioes', 'bmeso']:
vocab = Vocabulary(unknown=None, padding=None)
for i in range(random.randint(10, 100)):
label = str(random.randint(1, 10))
for tag in encoding_type:
if tag!='o':
vocab.add_word(f'{tag}-{label}')
else:
vocab.add_word('o')
vocabs[encoding_type] = vocab
for e1, e2 in product(['bio', 'bioes', 'bmeso'], ['bio', 'bioes', 'bmeso']):
with self.subTest(e1=e1, e2=e2):
if e1==e2:
metric = SpanFPreRecMetric(vocabs[e1], encoding_type=e2)
else:
s2 = set(e2)
s2.update(set(e1))
if s2==set(e2):
continue
with self.assertRaises(AssertionError):
metric = SpanFPreRecMetric(vocabs[e1], encoding_type=e2)
for encoding_type in ['bio', 'bioes', 'bmeso']:
with self.assertRaises(AssertionError):
metric = SpanFPreRecMetric(vocabs[encoding_type], encoding_type='bmes')

with self.assertWarns(Warning):
vocab = Vocabulary(unknown=None, padding=None).add_word_lst(list('bmes'))
metric = SpanFPreRecMetric(vocab, encoding_type='bmeso')
vocab = Vocabulary().add_word_lst(list('bmes'))
metric = SpanFPreRecMetric(vocab, encoding_type='bmeso')

class TestUsefulFunctions(unittest.TestCase):
# 测试metrics.py中一些看上去挺有用的函数


Loading…
Cancel
Save