From e8cc702737f93afc285a90051f29788b40847523 Mon Sep 17 00:00:00 2001 From: yunfan Date: Thu, 20 Sep 2018 15:11:01 +0800 Subject: [PATCH] add default switch --- fastNLP/core/vocabulary.py | 24 +++++++++++++++++++----- 1 file changed, 19 insertions(+), 5 deletions(-) diff --git a/fastNLP/core/vocabulary.py b/fastNLP/core/vocabulary.py index 79b70939..ad618ff9 100644 --- a/fastNLP/core/vocabulary.py +++ b/fastNLP/core/vocabulary.py @@ -28,15 +28,25 @@ class Vocabulary(object): vocab["word"] vocab.to_word(5) """ - def __init__(self): - self.word2idx = deepcopy(DEFAULT_WORD_TO_INDEX) - self.padding_label = DEFAULT_PADDING_LABEL - self.unknown_label = DEFAULT_UNKNOWN_LABEL + def __init__(self, need_default=True): + """ + :param bool need_default: set if the Vocabulary has default labels reserved. + """ + if need_default: + self.word2idx = deepcopy(DEFAULT_WORD_TO_INDEX) + self.padding_label = DEFAULT_PADDING_LABEL + self.unknown_label = DEFAULT_UNKNOWN_LABEL + else: + self.word2idx = {} + self.padding_label = None + self.unknown_label = None + + self.has_default = need_default self.idx2word = None def __len__(self): return len(self.word2idx) - + def update(self, word): """add word or list of words into Vocabulary @@ -73,9 +83,13 @@ class Vocabulary(object): return self[w] def unknown_idx(self): + if self.unknown_label is None: + return None return self.word2idx[self.unknown_label] def padding_idx(self): + if self.padding_label is None: + return None return self.word2idx[self.padding_label] def build_reverse_vocab(self):