|
|
@@ -28,15 +28,25 @@ class Vocabulary(object): |
|
|
|
vocab["word"] |
|
|
|
vocab.to_word(5) |
|
|
|
""" |
|
|
|
def __init__(self):
    """Initialise the vocabulary from the default word table and default labels."""
    # Deep-copy so this instance gets its own table rather than sharing the default one.
    self.word2idx = deepcopy(DEFAULT_WORD_TO_INDEX)
    self.padding_label, self.unknown_label = (
        DEFAULT_PADDING_LABEL,
        DEFAULT_UNKNOWN_LABEL,
    )
|
|
|
def __init__(self, need_default=True):
    """Create a vocabulary, with or without the default reserved labels.

    :param bool need_default: set if the Vocabulary has default labels reserved.
        When false, ``word2idx`` starts empty and both ``padding_label`` and
        ``unknown_label`` are ``None``.
    """
    if not need_default:
        # No reserved entries: empty table and no special labels.
        self.word2idx = {}
        self.padding_label = self.unknown_label = None
    else:
        # Deep-copy so this instance gets its own table rather than sharing the default one.
        self.word2idx = deepcopy(DEFAULT_WORD_TO_INDEX)
        self.padding_label = DEFAULT_PADDING_LABEL
        self.unknown_label = DEFAULT_UNKNOWN_LABEL

    # Remember which mode was requested.
    self.has_default = need_default
    # idx2word is not populated here; it starts as None.
    self.idx2word = None
|
|
|
|
|
|
|
def __len__(self): |
|
|
|
return len(self.word2idx) |
|
|
|
|
|
|
|
|
|
|
|
def update(self, word): |
|
|
|
"""add word or list of words into Vocabulary |
|
|
|
|
|
|
@@ -73,9 +83,13 @@ class Vocabulary(object): |
|
|
|
return self[w] |
|
|
|
|
|
|
|
def unknown_idx(self):
    """Return the index mapped to ``unknown_label``, or ``None`` if no unknown label is set."""
    label = self.unknown_label
    return None if label is None else self.word2idx[label]
|
|
|
|
|
|
|
def padding_idx(self):
    """Return the index mapped to ``padding_label``, or ``None`` if no padding label is set."""
    label = self.padding_label
    if label is not None:
        return self.word2idx[label]
    return None
|
|
|
|
|
|
|
def build_reverse_vocab(self): |
|
|
|