Browse Source

升级Vocab:

* 增量添加单词到词典中
* lazy update: 当用到词典的时候才重新build
* 当新添加的词导致词典大小超出限制时,打印一个warning

Update Vocabulary:
* More words can be added after the vocabulary has been built.
* Lazy update: the vocabulary is rebuilt automatically the next time it is used.
* Print a warning when the max size is reached.
tags/v0.2.0^2
FengZiYjun 6 years ago
parent
commit
6427e85e8f
2 changed files with 55 additions and 2 deletions
  1. +28
    -2
      fastNLP/core/vocabulary.py
  2. +27
    -0
      test/core/test_vocabulary.py

+ 28
- 2
fastNLP/core/vocabulary.py View File

@@ -16,14 +16,35 @@ def isiterable(p_object):


def check_build_vocab(func):
    """Decorate a method so the index is (re)built before the method runs.

    The wrapper calls ``self.build_vocab()`` when ``self.word2idx`` is
    ``None`` or when ``self.rebuild`` is ``True``, then forwards every
    argument to ``func`` and returns its result.

    :param func: the method to wrap; it receives ``self`` as first argument.
    :return: the wrapping function, carrying ``func``'s name and docstring.
    """
    # Imported locally because the module's import block is not visible here.
    import functools

    @functools.wraps(func)
    def _wrapper(self, *args, **kwargs):
        # Build on first use, or again when words were added since the
        # last build (signalled by the ``rebuild`` flag).
        if self.word2idx is None or self.rebuild is True:
            self.build_vocab()
        return func(self, *args, **kwargs)

    return _wrapper


def check_build_status(func):
    """Decorate a mutating method so it flags the vocabulary as out of date.

    If ``self.rebuild`` is ``False`` when the wrapped method is called, the
    wrapper sets it to ``True``. In that same case it prints a warning when
    ``self.max_size`` is not ``None`` and ``len(self.word_count)`` has already
    reached ``self.max_size``. The call is then forwarded to ``func`` and its
    result returned.

    :param func: the method to wrap; it receives ``self`` as first argument.
    :return: the wrapping function, carrying ``func``'s name and docstring.
    """
    # Imported locally because the module's import block is not visible here.
    import functools

    @functools.wraps(func)
    def _wrapper(self, *args, **kwargs):
        if self.rebuild is False:
            self.rebuild = True
            # NOTE(review): the size check is placed inside the branch above,
            # so it fires only on the first modification after a build. The
            # original indentation was lost; confirm against the real file.
            if self.max_size is not None and len(self.word_count) >= self.max_size:
                print("[Warning] Vocabulary has reached the max size {} when calling {} method. "
                      "Adding more words may cause unexpected behaviour of Vocabulary. ".format(
                          self.max_size, func.__name__))
        return func(self, *args, **kwargs)

    return _wrapper


class Vocabulary(object):
"""Use for word and index one to one mapping

@@ -54,7 +75,9 @@ class Vocabulary(object):
self.unknown_label = None
self.word2idx = None
self.idx2word = None
self.rebuild = True

@check_build_status
def update(self, word_lst):
"""Add a list of words into the vocabulary.

@@ -62,6 +85,7 @@ class Vocabulary(object):
"""
self.word_count.update(word_lst)

@check_build_status
def add(self, word):
"""Add a single word into the vocabulary.

@@ -69,6 +93,7 @@ class Vocabulary(object):
"""
self.word_count[word] += 1

@check_build_status
def add_word(self, word):
"""Add a single word into the vocabulary.

@@ -76,6 +101,7 @@ class Vocabulary(object):
"""
self.add(word)

@check_build_status
def add_word_lst(self, word_lst):
"""Add a list of words into the vocabulary.

@@ -101,6 +127,7 @@ class Vocabulary(object):
start_idx = len(self.word2idx)
self.word2idx.update({w: i + start_idx for i, (w, _) in enumerate(words)})
self.build_reverse_vocab()
self.rebuild = False

def build_reverse_vocab(self):
"""Build 'index to word' dict based on 'word to index' dict.
@@ -188,4 +215,3 @@ class Vocabulary(object):
"""
self.__dict__.update(state)
self.build_reverse_vocab()


+ 27
- 0
test/core/test_vocabulary.py View File

@@ -59,3 +59,30 @@ class TestIndexing(unittest.TestCase):
vocab = Vocabulary(need_default=True, max_size=None, min_freq=None)
vocab.update(text)
self.assertEqual(text, [vocab.to_word(idx) for idx in [vocab[w] for w in text]])


class TestOther(unittest.TestCase):
    """Checks on the ``rebuild`` flag of ``Vocabulary`` across updates."""

    def test_additional_update(self):
        """Adding a word after a lookup sets ``rebuild``; the next lookup clears it."""
        vocabulary = Vocabulary(need_default=True, max_size=None, min_freq=None)
        vocabulary.update(text)

        # First lookup: the flag is expected to be cleared afterwards.
        _ = vocabulary["well"]
        self.assertEqual(vocabulary.rebuild, False)

        # Adding a new word is expected to raise the flag again.
        vocabulary.add("hahaha")
        self.assertEqual(vocabulary.rebuild, True)

        # Looking the new word up is expected to clear the flag and find it.
        _ = vocabulary["hahaha"]
        self.assertEqual(vocabulary.rebuild, False)
        self.assertIn("hahaha", vocabulary)

    def test_warning(self):
        """Updating a vocabulary whose ``max_size`` equals its distinct word count."""
        size_limit = len(set(text))
        vocabulary = Vocabulary(need_default=True, max_size=size_limit, min_freq=None)
        vocabulary.update(text)
        self.assertEqual(vocabulary.rebuild, True)
        # The assertion below expects this len() call to clear the flag.
        print(len(vocabulary))
        self.assertEqual(vocabulary.rebuild, False)

        extra_words = ["hahahha", "hhh", "vvvv", "ass", "asss", "jfweiong", "eqgfeg", "feqfw"]
        vocabulary.update(extra_words)
        # this will print a warning
        self.assertEqual(vocabulary.rebuild, True)

Loading…
Cancel
Save