|
|
@@ -3,13 +3,8 @@ from collections import Counter |
|
|
|
|
|
|
|
DEFAULT_PADDING_LABEL = '<pad>' # dict index = 0 |
|
|
|
DEFAULT_UNKNOWN_LABEL = '<unk>' # dict index = 1 |
|
|
|
DEFAULT_RESERVED_LABEL = ['<reserved-2>', |
|
|
|
'<reserved-3>', |
|
|
|
'<reserved-4>'] # dict index = 2~4 |
|
|
|
|
|
|
|
DEFAULT_WORD_TO_INDEX = {DEFAULT_PADDING_LABEL: 0, DEFAULT_UNKNOWN_LABEL: 1, |
|
|
|
DEFAULT_RESERVED_LABEL[0]: 2, DEFAULT_RESERVED_LABEL[1]: 3, |
|
|
|
DEFAULT_RESERVED_LABEL[2]: 4} |
|
|
|
DEFAULT_WORD_TO_INDEX = {DEFAULT_PADDING_LABEL: 0, DEFAULT_UNKNOWN_LABEL: 1} |
|
|
|
|
|
|
|
|
|
|
|
def isiterable(p_object): |
|
|
@@ -58,24 +53,23 @@ class Vocabulary(object): |
|
|
|
self.word2idx = None |
|
|
|
self.idx2word = None |
|
|
|
|
|
|
|
def update(self, word): |
|
|
|
def update(self, word_lst): |
|
|
|
"""add word or list of words into Vocabulary |
|
|
|
|
|
|
|
:param word: a list of string or a single string |
|
|
|
""" |
|
|
|
if not isinstance(word, str) and isiterable(word): |
|
|
|
# it's a nested list |
|
|
|
for w in word: |
|
|
|
self.update(w) |
|
|
|
else: |
|
|
|
# it's a word to be added |
|
|
|
self.word_count[word] += 1 |
|
|
|
self.word2idx = None |
|
|
|
return self |
|
|
|
self.word_count.update(word_lst) |
|
|
|
|
|
|
|
|
|
|
|
def add(self, word): |
|
|
|
self.word_count[word] += 1 |
|
|
|
|
|
|
|
def add_word(self, word): |
|
|
|
self.add(word) |
|
|
|
|
|
|
|
def add_word_lst(self, word_lst): |
|
|
|
self.update(word_lst) |
|
|
|
|
|
|
|
def update_list(self, sent): |
|
|
|
self.word_count.update(sent) |
|
|
|
self.word2idx = None |
|
|
|
|
|
|
|
def build_vocab(self): |
|
|
|
"""build 'word to index' dict, and filter the word using `max_size` and `min_freq` |
|
|
|