|
|
@@ -1,4 +1,5 @@ |
|
|
|
from copy import deepcopy |
|
|
|
from collections import Counter |
|
|
|
|
|
|
|
DEFAULT_PADDING_LABEL = '<pad>' # dict index = 0 |
|
|
|
DEFAULT_UNKNOWN_LABEL = '<unk>' # dict index = 1 |
|
|
@@ -23,9 +24,6 @@ def check_build_vocab(func): |
|
|
|
def _wrapper(self, *args, **kwargs): |
|
|
|
if self.word2idx is None: |
|
|
|
self.build_vocab() |
|
|
|
self.build_reverse_vocab() |
|
|
|
elif self.idx2word is None: |
|
|
|
self.build_reverse_vocab() |
|
|
|
return func(self, *args, **kwargs) |
|
|
|
return _wrapper |
|
|
|
|
|
|
@@ -49,7 +47,7 @@ class Vocabulary(object): |
|
|
|
""" |
|
|
|
self.max_size = max_size |
|
|
|
self.min_freq = min_freq |
|
|
|
self.word_count = {} |
|
|
|
self.word_count = Counter() |
|
|
|
self.has_default = need_default |
|
|
|
if self.has_default: |
|
|
|
self.padding_label = DEFAULT_PADDING_LABEL |
|
|
@@ -71,13 +69,14 @@ class Vocabulary(object): |
|
|
|
self.update(w) |
|
|
|
else: |
|
|
|
# it's a word to be added |
|
|
|
if word not in self.word_count: |
|
|
|
self.word_count[word] = 1 |
|
|
|
else: |
|
|
|
self.word_count[word] += 1 |
|
|
|
self.word_count[word] += 1 |
|
|
|
self.word2idx = None |
|
|
|
return self |
|
|
|
|
|
|
|
def update_list(self, sent):
    """Add every word of a tokenised sentence to the frequency counts.

    The counts are updated in bulk through ``Counter.update`` and the
    word-to-index mapping is reset to ``None`` so that it is rebuilt
    from the new counts the next time it is needed.

    :param sent: an iterable of words (e.g. a list of str)
    :return: ``self``, so calls can be chained like ``update``
    """
    self.word_count.update(sent)
    # the existing index no longer reflects the counts
    self.word2idx = None
    return self
|
|
|
|
|
|
|
def build_vocab(self): |
|
|
|
"""build 'word to index' dict, and filter the word using `max_size` and `min_freq` |
|
|
|
""" |
|
|
@@ -88,26 +87,25 @@ class Vocabulary(object): |
|
|
|
else: |
|
|
|
self.word2idx = {} |
|
|
|
|
|
|
|
words = sorted(self.word_count.items(), key=lambda kv: kv[1], reverse=True) |
|
|
|
max_size = min(self.max_size, len(self.word_count)) if self.max_size else None |
|
|
|
words = self.word_count.most_common(max_size) |
|
|
|
if self.min_freq is not None: |
|
|
|
words = list(filter(lambda kv: kv[1] >= self.min_freq, words)) |
|
|
|
if self.max_size is not None and len(words) > self.max_size: |
|
|
|
words = words[:self.max_size] |
|
|
|
for w, _ in words: |
|
|
|
self.word2idx[w] = len(self.word2idx) |
|
|
|
words = filter(lambda kv: kv[1] >= self.min_freq, words) |
|
|
|
start_idx = len(self.word2idx) |
|
|
|
self.word2idx.update({w:i+start_idx for i, (w,_) in enumerate(words)}) |
|
|
|
self.build_reverse_vocab() |
|
|
|
|
|
|
|
def build_reverse_vocab(self):
    """Build the 'index to word' dict from the 'word to index' dict.

    ``self.idx2word`` is replaced by a new dict that maps each index
    found in ``self.word2idx`` back to its word.
    """
    # one pass over the items; no second lookup per key
    self.idx2word = {idx: word for word, idx in self.word2idx.items()}
|
|
|
|
|
|
|
@check_build_vocab
def __len__(self):
    """Return the number of entries in the word-to-index mapping."""
    mapping = self.word2idx
    return len(mapping)
|
|
|
|
|
|
|
@check_build_vocab
def has_word(self, w):
    """Tell whether a word is present in the vocabulary.

    The test is a direct membership check against ``self.word2idx``.

    :param w: the word to look for
    :return: True if ``w`` is a key of the word-to-index mapping,
        False otherwise
    """
    return w in self.word2idx
|
|
|
|
|
|
|
@check_build_vocab |
|
|
|
def __getitem__(self, w): |
|
|
@@ -122,14 +120,13 @@ class Vocabulary(object): |
|
|
|
else: |
|
|
|
raise ValueError("word {} not in vocabulary".format(w)) |
|
|
|
|
|
|
|
@check_build_vocab
def to_index(self, w):
    """Turn a word into its index; equivalent to ``self[w]``.

    The lookup itself, including how a word that is not in the
    vocabulary is handled, is delegated to ``__getitem__``.
    NOTE(review): the previous docstring stated that an
    out-of-vocabulary word yields the unknown label -- confirm
    against ``__getitem__``.

    :param str w: the word to look up
    :return: whatever ``self[w]`` returns for ``w``
    """
    return self[w]
|
|
|
|
|
|
|
@property |
|
|
|
@check_build_vocab |
|
|
@@ -140,7 +137,7 @@ class Vocabulary(object): |
|
|
|
|
|
|
|
def __setattr__(self, name, val):
    """Store an attribute and drop the index when a special label changes.

    Assigning ``unknown_label`` or ``padding_label`` resets
    ``self.word2idx`` to ``None``; assigning any other attribute only
    stores the value.

    :param str name: name of the attribute being assigned
    :param val: value to store
    """
    # Write straight into the instance dict for the attribute itself.
    self.__dict__[name] = val
    if name in ("unknown_label", "padding_label"):
        # a changed label makes the existing mapping stale
        self.word2idx = None
|
|
|
|
|
|
|
@property |
|
|
@@ -156,8 +153,6 @@ class Vocabulary(object): |
|
|
|
|
|
|
|
:param int idx: |
|
|
|
""" |
|
|
|
if self.idx2word is None: |
|
|
|
self.build_reverse_vocab() |
|
|
|
return self.idx2word[idx] |
|
|
|
|
|
|
|
def __getstate__(self): |
|
|
@@ -172,12 +167,13 @@ class Vocabulary(object): |
|
|
|
"""use to restore state from pickle |
|
|
|
""" |
|
|
|
self.__dict__.update(state) |
|
|
|
self.idx2word = None |
|
|
|
self.build_reverse_vocab() |
|
|
|
|
|
|
|
@check_build_vocab
def __contains__(self, item):
    """Check if a word is in the vocabulary.

    Implemented as a direct membership test on ``self.word2idx`` rather
    than by calling ``has_word``, so the two methods can never end up
    delegating to each other.

    :param item: the word
    :return: True or False
    """
    return item in self.word2idx