def check_build_vocab(func):
    """Decorator that makes sure the index mappings are built before use.

    The wrapped method calls ``self.build_vocab()`` first whenever the
    instance has no ``word2idx`` mapping yet or its ``rebuild`` flag is set,
    and only then runs ``func``.

    :param func: method taking ``self`` as its first argument.
    :return: the wrapped method; it returns whatever ``func`` returns.
    """
    import functools  # function-scope import keeps this block self-contained

    # Preserve the wrapped method's __name__/__doc__ so introspection and
    # help() show the real method instead of "_wrapper".
    @functools.wraps(func)
    def _wrapper(self, *args, **kwargs):
        if self.word2idx is None or self.rebuild:
            self.build_vocab()
        return func(self, *args, **kwargs)

    return _wrapper


def check_build_status(func):
    """Decorator that marks the vocabulary as changed since the last build.

    Before running ``func`` the wrapper sets ``self.rebuild`` to ``True`` and,
    when ``self.max_size`` is set and ``len(self.word_count)`` has already
    reached it, prints a warning naming the method being called.

    :param func: mutating method taking ``self`` as its first argument.
    :return: the wrapped method; it returns whatever ``func`` returns.
    """
    import functools  # function-scope import keeps this block self-contained

    @functools.wraps(func)
    def _wrapper(self, *args, **kwargs):
        # Any mutation invalidates the index built so far.
        self.rebuild = True
        # The size check runs before func, so it reflects the counts present
        # *before* this call adds anything.
        if self.max_size is not None and len(self.word_count) >= self.max_size:
            print("[Warning] Vocabulary has reached the max size {} when calling {} method. "
                  "Adding more words may cause unexpected behaviour of Vocabulary. ".format(
                self.max_size, func.__name__))
        return func(self, *args, **kwargs)

    return _wrapper
class TestOther(unittest.TestCase):
    """Tests for the ``rebuild`` flag and the max-size warning of Vocabulary."""

    def test_additional_update(self):
        """A word added after a build is indexed on the next lookup."""
        vocab = Vocabulary(need_default=True, max_size=None, min_freq=None)
        vocab.update(text)

        # The first lookup builds the index and clears the flag.
        _ = vocab["well"]
        self.assertEqual(vocab.rebuild, False)

        # Adding a word marks the index as out of date again.
        vocab.add("hahaha")
        self.assertEqual(vocab.rebuild, True)

        # The next lookup rebuilds, so the new word becomes visible.
        _ = vocab["hahaha"]
        self.assertEqual(vocab.rebuild, False)
        self.assertIn("hahaha", vocab)

    def test_warning(self):
        """Updating a vocabulary that is already at max_size prints a warning."""
        import contextlib
        import io

        vocab = Vocabulary(need_default=True, max_size=len(set(text)), min_freq=None)
        vocab.update(text)
        self.assertEqual(vocab.rebuild, True)
        # As in the original test, taking the length is expected to trigger
        # the build and clear the flag; the value itself is not needed.
        _ = len(vocab)
        self.assertEqual(vocab.rebuild, False)

        # word_count already holds len(set(text)) entries, so this update
        # should emit the max-size warning on stdout.
        captured = io.StringIO()
        with contextlib.redirect_stdout(captured):
            vocab.update(["hahahha", "hhh", "vvvv", "ass", "asss", "jfweiong", "eqgfeg", "feqfw"])
        self.assertIn("[Warning] Vocabulary has reached the max size", captured.getvalue())
        self.assertEqual(vocab.rebuild, True)