From 55e736bf4c9020ce404400b605d1c2febd8d0766 Mon Sep 17 00:00:00 2001 From: yh_cc Date: Wed, 28 Aug 2019 23:53:20 +0800 Subject: [PATCH 1/2] =?UTF-8?q?SpanFMetric=E5=A2=9E=E5=8A=A0=E5=AF=B9encod?= =?UTF-8?q?ing=5Ftype=E5=92=8Ctag=5Fvocab=E7=9A=84=E6=A3=80=E6=9F=A5?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- fastNLP/core/metrics.py | 26 ++++++++++++++++++++++++++ test/core/test_metrics.py | 35 +++++++++++++++++++++++++++++++++++ 2 files changed, 61 insertions(+) diff --git a/fastNLP/core/metrics.py b/fastNLP/core/metrics.py index 1d1e3819..28d88fbc 100644 --- a/fastNLP/core/metrics.py +++ b/fastNLP/core/metrics.py @@ -23,6 +23,7 @@ from .utils import _get_func_signature from .utils import seq_len_to_mask from .vocabulary import Vocabulary from abc import abstractmethod +import warnings class MetricBase(object): @@ -492,6 +493,30 @@ def _bio_tag_to_spans(tags, ignore_labels=None): return [(span[0], (span[1][0], span[1][1] + 1)) for span in spans if span[0] not in ignore_labels] +def _check_tag_vocab_and_encoding_type(vocab:Vocabulary, encoding_type:str): + """ + 检查vocab中的tag是否与encoding_type是匹配的 + + :param vocab: target的Vocabulary + :param encoding_type: bio, bmes, bioes, bmeso + :return: + """ + tag_set = set() + for tag, idx in vocab: + if idx in (vocab.unknown_idx, vocab.padding_idx): + continue + tag = tag[:1] + tag_set.add(tag) + tags = encoding_type + for tag in tag_set: + assert tag in tags, f"{tag} is not a valid tag in encoding type:{encoding_type}. Please check your " \ + f"encoding_type." + tags = tags.replace(tag, '') # 删除该值 + if tags: # 如果不为空,说明出现了未使用的tag + warnings.warn(f"Tag:{tags} in encoding type:{encoding_type} is not presented in your Vocabulary. Check your " + "encoding_type.") + + class SpanFPreRecMetric(MetricBase): r""" 别名::class:`fastNLP.SpanFPreRecMetric` :class:`fastNLP.core.metrics.SpanFPreRecMetric` @@ -546,6 +571,7 @@ class SpanFPreRecMetric(MetricBase): raise ValueError("f_type only supports `micro` or `macro`', got {}.".format(f_type)) self.encoding_type = encoding_type + _check_tag_vocab_and_encoding_type(tag_vocab, encoding_type) if self.encoding_type == 'bmes': self.tag_to_span_func = _bmes_tag_to_spans elif self.encoding_type == 'bio': diff --git a/test/core/test_metrics.py b/test/core/test_metrics.py index 236066d6..5a7c55cf 100644 --- a/test/core/test_metrics.py +++ b/test/core/test_metrics.py @@ -338,6 +338,41 @@ class SpanF1PreRecMetric(unittest.TestCase): for key, value in expected_metric.items(): self.assertAlmostEqual(value, metric_value[key], places=5) + def test_encoding_type(self): + # 检查传入的tag_vocab与encoding_type不符合时,是否会报错 + vocabs = {} + import random + from itertools import product + for encoding_type in ['bio', 'bioes', 'bmeso']: + vocab = Vocabulary(unknown=None, padding=None) + for i in range(random.randint(10, 100)): + label = str(random.randint(1, 10)) + for tag in encoding_type: + if tag!='o': + vocab.add_word(f'{tag}-{label}') + else: + vocab.add_word('o') + vocabs[encoding_type] = vocab + for e1, e2 in product(['bio', 'bioes', 'bmeso'], ['bio', 'bioes', 'bmeso']): + with self.subTest(e1=e1, e2=e2): + if e1==e2: + metric = SpanFPreRecMetric(vocabs[e1], encoding_type=e2) + else: + s2 = set(e2) + s2.update(set(e1)) + if s2==set(e2): + continue + with self.assertRaises(AssertionError): + metric = SpanFPreRecMetric(vocabs[e1], encoding_type=e2) + for encoding_type in ['bio', 'bioes', 'bmeso']: + with self.assertRaises(AssertionError): + metric = SpanFPreRecMetric(vocabs[encoding_type], encoding_type='bmes') + + with self.assertWarns(Warning): + vocab = Vocabulary(unknown=None, padding=None).add_word_lst(list('bmes')) + metric = SpanFPreRecMetric(vocab, encoding_type='bmeso') + vocab = Vocabulary().add_word_lst(list('bmes')) + metric = SpanFPreRecMetric(vocab, encoding_type='bmeso') class TestUsefulFunctions(unittest.TestCase): # 测试metrics.py中一些看上去挺有用的函数 From cbe5b347e54ce5181887743c62b06aabcd00b778 Mon Sep 17 00:00:00 2001 From: yh_cc Date: Wed, 28 Aug 2019 23:53:53 +0800 Subject: [PATCH 2/2] =?UTF-8?q?SpanFMetric=E5=A2=9E=E5=8A=A0=E5=AF=B9encod?= =?UTF-8?q?ing=5Ftype=E5=92=8Ctag=5Fvocab=E7=9A=84=E6=A3=80=E6=9F=A5?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- fastNLP/core/metrics.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/fastNLP/core/metrics.py b/fastNLP/core/metrics.py index 28d88fbc..0dc601a3 100644 --- a/fastNLP/core/metrics.py +++ b/fastNLP/core/metrics.py @@ -505,7 +505,7 @@ def _check_tag_vocab_and_encoding_type(vocab:Vocabulary, encoding_type:str): for tag, idx in vocab: if idx in (vocab.unknown_idx, vocab.padding_idx): continue - tag = tag[:1] + tag = tag[:1].lower() tag_set.add(tag) tags = encoding_type for tag in tag_set: