From 6427e85e8f7540cf60203dab16a0a4f04ce9b5ef Mon Sep 17 00:00:00 2001 From: FengZiYjun Date: Sat, 1 Dec 2018 15:44:52 +0800 Subject: [PATCH] =?UTF-8?q?=E5=8D=87=E7=BA=A7Vocab=EF=BC=9A=20*=20?= =?UTF-8?q?=E5=A2=9E=E9=87=8F=E6=B7=BB=E5=8A=A0=E5=8D=95=E8=AF=8D=E5=88=B0?= =?UTF-8?q?=E8=AF=8D=E5=85=B8=E4=B8=AD=20*=20lazy=20update:=20=E5=BD=93?= =?UTF-8?q?=E7=94=A8=E5=88=B0=E8=AF=8D=E5=85=B8=E7=9A=84=E6=97=B6=E5=80=99?= =?UTF-8?q?=E6=89=8D=E9=87=8D=E6=96=B0build=20*=20=E5=BD=93=E6=96=B0?= =?UTF-8?q?=E6=B7=BB=E5=8A=A0=E7=9A=84=E8=AF=8D=E5=AF=BC=E8=87=B4=E8=AF=8D?= =?UTF-8?q?=E5=85=B8=E5=A4=A7=E5=B0=8F=E8=B6=85=E5=87=BA=E9=99=90=E5=88=B6?= =?UTF-8?q?=E6=97=B6=EF=BC=8C=E6=89=93=E5=8D=B0=E4=B8=80=E4=B8=AAwarning?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Update Vocabulary: * More words can be added after the building. * Lazy update: rebuild automatically when vocab is used. * print warning when max size is reached --- fastNLP/core/vocabulary.py | 30 ++++++++++++++++++++++++++++-- test/core/test_vocabulary.py | 27 +++++++++++++++++++++++++++ 2 files changed, 55 insertions(+), 2 deletions(-) diff --git a/fastNLP/core/vocabulary.py b/fastNLP/core/vocabulary.py index 7b0ab614..ca6b4ebf 100644 --- a/fastNLP/core/vocabulary.py +++ b/fastNLP/core/vocabulary.py @@ -16,14 +16,35 @@ def isiterable(p_object): def check_build_vocab(func): + """A decorator to make sure the indexing is built before used. + + """ + def _wrapper(self, *args, **kwargs): - if self.word2idx is None: + if self.word2idx is None or self.rebuild is True: self.build_vocab() return func(self, *args, **kwargs) return _wrapper +def check_build_status(func): + """A decorator to check whether the vocabulary updates after the last build. + + """ + + def _wrapper(self, *args, **kwargs): + if self.rebuild is False: + self.rebuild = True + if self.max_size is not None and len(self.word_count) >= self.max_size: + print("[Warning] Vocabulary has reached the max size {} when calling {} method. " + "Adding more words may cause unexpected behaviour of Vocabulary. ".format( + self.max_size, func.__name__)) + return func(self, *args, **kwargs) + + return _wrapper + + class Vocabulary(object): """Use for word and index one to one mapping @@ -54,7 +75,9 @@ class Vocabulary(object): self.unknown_label = None self.word2idx = None self.idx2word = None + self.rebuild = True + @check_build_status def update(self, word_lst): """Add a list of words into the vocabulary. @@ -62,6 +85,7 @@ class Vocabulary(object): """ self.word_count.update(word_lst) + @check_build_status def add(self, word): """Add a single word into the vocabulary. @@ -69,6 +93,7 @@ class Vocabulary(object): """ self.word_count[word] += 1 + @check_build_status def add_word(self, word): """Add a single word into the vocabulary. @@ -76,6 +101,7 @@ class Vocabulary(object): """ self.add(word) + @check_build_status def add_word_lst(self, word_lst): """Add a list of words into the vocabulary. @@ -101,6 +127,7 @@ class Vocabulary(object): start_idx = len(self.word2idx) self.word2idx.update({w: i + start_idx for i, (w, _) in enumerate(words)}) self.build_reverse_vocab() + self.rebuild = False def build_reverse_vocab(self): """Build 'index to word' dict based on 'word to index' dict. @@ -188,4 +215,3 @@ class Vocabulary(object): """ self.__dict__.update(state) self.build_reverse_vocab() - diff --git a/test/core/test_vocabulary.py b/test/core/test_vocabulary.py index e140b1aa..e453e935 100644 --- a/test/core/test_vocabulary.py +++ b/test/core/test_vocabulary.py @@ -59,3 +59,30 @@ class TestIndexing(unittest.TestCase): vocab = Vocabulary(need_default=True, max_size=None, min_freq=None) vocab.update(text) self.assertEqual(text, [vocab.to_word(idx) for idx in [vocab[w] for w in text]]) + + +class TestOther(unittest.TestCase): + def test_additional_update(self): + vocab = Vocabulary(need_default=True, max_size=None, min_freq=None) + vocab.update(text) + + _ = vocab["well"] + self.assertEqual(vocab.rebuild, False) + + vocab.add("hahaha") + self.assertEqual(vocab.rebuild, True) + + _ = vocab["hahaha"] + self.assertEqual(vocab.rebuild, False) + self.assertTrue("hahaha" in vocab) + + def test_warning(self): + vocab = Vocabulary(need_default=True, max_size=len(set(text)), min_freq=None) + vocab.update(text) + self.assertEqual(vocab.rebuild, True) + print(len(vocab)) + self.assertEqual(vocab.rebuild, False) + + vocab.update(["hahahha", "hhh", "vvvv", "ass", "asss", "jfweiong", "eqgfeg", "feqfw"]) + # this will print a warning + self.assertEqual(vocab.rebuild, True)