@@ -16,14 +16,35 @@ def isiterable(p_object):
def check_build_vocab(func):
"""A decorator to make sure the indexing is built before used.
"""
def _wrapper(self, *args, **kwargs):
if self.word2idx is None:
if self.word2idx is None or self.rebuild is True :
self.build_vocab()
return func(self, *args, **kwargs)
return _wrapper
def check_build_status(func):
"""A decorator to check whether the vocabulary updates after the last build.
"""
def _wrapper(self, *args, **kwargs):
if self.rebuild is False:
self.rebuild = True
if self.max_size is not None and len(self.word_count) >= self.max_size:
print("[Warning] Vocabulary has reached the max size {} when calling {} method. "
"Adding more words may cause unexpected behaviour of Vocabulary. ".format(
self.max_size, func.__name__))
return func(self, *args, **kwargs)
return _wrapper
class Vocabulary(object):
"""Use for word and index one to one mapping
@@ -54,7 +75,9 @@ class Vocabulary(object):
self.unknown_label = None
self.word2idx = None
self.idx2word = None
self.rebuild = True
@check_build_status
def update(self, word_lst):
"""Add a list of words into the vocabulary.
@@ -62,6 +85,7 @@ class Vocabulary(object):
"""
self.word_count.update(word_lst)
@check_build_status
def add(self, word):
"""Add a single word into the vocabulary.
@@ -69,6 +93,7 @@ class Vocabulary(object):
"""
self.word_count[word] += 1
@check_build_status
def add_word(self, word):
"""Add a single word into the vocabulary.
@@ -76,6 +101,7 @@ class Vocabulary(object):
"""
self.add(word)
@check_build_status
def add_word_lst(self, word_lst):
"""Add a list of words into the vocabulary.
@@ -101,6 +127,7 @@ class Vocabulary(object):
start_idx = len(self.word2idx)
self.word2idx.update({w: i + start_idx for i, (w, _) in enumerate(words)})
self.build_reverse_vocab()
self.rebuild = False
def build_reverse_vocab(self):
"""Build 'index to word' dict based on 'word to index' dict.
@@ -188,4 +215,3 @@ class Vocabulary(object):
"""
self.__dict__.update(state)
self.build_reverse_vocab()