diff --git a/fastNLP/core/vocabulary.py b/fastNLP/core/vocabulary.py
index 5d9f2185..2f2358a1 100644
--- a/fastNLP/core/vocabulary.py
+++ b/fastNLP/core/vocabulary.py
@@ -1,4 +1,5 @@
 from copy import deepcopy
+from collections import Counter
 
 DEFAULT_PADDING_LABEL = '<pad>'  # dict index = 0
 DEFAULT_UNKNOWN_LABEL = '<unk>'  # dict index = 1
@@ -23,9 +24,6 @@ def check_build_vocab(func):
     def _wrapper(self, *args, **kwargs):
         if self.word2idx is None:
             self.build_vocab()
-            self.build_reverse_vocab()
-        elif self.idx2word is None:
-            self.build_reverse_vocab()
         return func(self, *args, **kwargs)
     return _wrapper
 
@@ -49,7 +47,7 @@ class Vocabulary(object):
         """
         self.max_size = max_size
         self.min_freq = min_freq
-        self.word_count = {}
+        self.word_count = Counter()
        self.has_default = need_default
         if self.has_default:
             self.padding_label = DEFAULT_PADDING_LABEL
@@ -71,13 +69,14 @@ class Vocabulary(object):
                 self.update(w)
         else:
             # it's a word to be added
-            if word not in self.word_count:
-                self.word_count[word] = 1
-            else:
-                self.word_count[word] += 1
+            self.word_count[word] += 1
             self.word2idx = None
         return self
 
+    def update_list(self, sent):
+        self.word_count.update(sent)
+        self.word2idx = None
+
     def build_vocab(self):
         """build 'word to index' dict, and filter the word using `max_size` and `min_freq`
         """
@@ -88,26 +87,25 @@ class Vocabulary(object):
         else:
             self.word2idx = {}
 
-        words = sorted(self.word_count.items(), key=lambda kv: kv[1], reverse=True)
+        max_size = min(self.max_size, len(self.word_count)) if self.max_size else None
+        words = self.word_count.most_common(max_size)
         if self.min_freq is not None:
-            words = list(filter(lambda kv: kv[1] >= self.min_freq, words))
-        if self.max_size is not None and len(words) > self.max_size:
-            words = words[:self.max_size]
-        for w, _ in words:
-            self.word2idx[w] = len(self.word2idx)
+            words = filter(lambda kv: kv[1] >= self.min_freq, words)
+        start_idx = len(self.word2idx)
+        self.word2idx.update({w: i + start_idx for i, (w, _) in enumerate(words)})
+        self.build_reverse_vocab()
 
     def build_reverse_vocab(self):
         """build 'index to word' dict based on 'word to index' dict
         """
-        self.idx2word = {self.word2idx[w] : w for w in self.word2idx}
+        self.idx2word = {i: w for w, i in self.word2idx.items()}
 
     @check_build_vocab
     def __len__(self):
         return len(self.word2idx)
 
-    @check_build_vocab
     def has_word(self, w):
-        return w in self.word2idx
+        return self.__contains__(w)
 
     @check_build_vocab
     def __getitem__(self, w):
@@ -122,14 +120,13 @@ class Vocabulary(object):
         else:
             raise ValueError("word {} not in vocabulary".format(w))
 
-    @check_build_vocab
     def to_index(self, w):
         """ like to_index(w) function, turn a word to the index
            if w is not in Vocabulary, return the unknown label
 
        :param str w:
        """
-        return self[w]
+        return self.__getitem__(w)
 
     @property
     @check_build_vocab
@@ -140,7 +137,7 @@ class Vocabulary(object):
 
     def __setattr__(self, name, val):
         self.__dict__[name] = val
-        if name in self.__dict__ and name in ["unknown_label", "padding_label"]:
+        if name in ["unknown_label", "padding_label"]:
             self.word2idx = None
 
     @property
@@ -156,8 +153,6 @@ class Vocabulary(object):
 
         :param int idx:
         """
-        if self.idx2word is None:
-            self.build_reverse_vocab()
         return self.idx2word[idx]
 
     def __getstate__(self):
@@ -172,12 +167,13 @@ class Vocabulary(object):
         """use to restore state from pickle
         """
         self.__dict__.update(state)
-        self.idx2word = None
+        self.build_reverse_vocab()
 
+    @check_build_vocab
     def __contains__(self, item):
         """Check if a word in vocabulary.
 
         :param item: the word
         :return: True or False
         """
-        return self.has_word(item)
+        return item in self.word2idx
diff --git a/fastNLP/io/base_loader.py b/fastNLP/io/base_loader.py
index fc2814c8..2cdfcab4 100644
--- a/fastNLP/io/base_loader.py
+++ b/fastNLP/io/base_loader.py
@@ -1,3 +1,6 @@
+import os
+import _pickle as pickle
+
 class BaseLoader(object):
 
     def __init__(self):
@@ -9,12 +12,23 @@ class BaseLoader(object):
         text = f.readlines()
         return [line.strip() for line in text]
 
-    @staticmethod
-    def load(data_path):
+    @classmethod
+    def load(cls, data_path):
         with open(data_path, "r", encoding="utf-8") as f:
             text = f.readlines()
         return [[word for word in sent.strip()] for sent in text]
 
+    @classmethod
+    def load_with_cache(cls, data_path, cache_path):
+        if os.path.isfile(cache_path) and os.path.getmtime(data_path) < os.path.getmtime(cache_path):
+            with open(cache_path, 'rb') as f:
+                return pickle.load(f)
+        else:
+            obj = cls.load(data_path)
+            with open(cache_path, 'wb') as f:
+                pickle.dump(obj, f)
+            return obj
+
 
 class ToyLoader0(BaseLoader):
     """
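
Usage note (not part of the patch): a minimal sketch of how the refactored pieces fit together. The keyword arguments to Vocabulary are assumed from the fields touched in __init__, and the file paths and sample word are placeholders, not values from this diff.

from fastNLP.core.vocabulary import Vocabulary
from fastNLP.io.base_loader import BaseLoader

# Re-parses data.txt only when it is newer than the cached pickle (paths are placeholders).
corpus = BaseLoader.load_with_cache("data.txt", "data.txt.cache")

vocab = Vocabulary(max_size=10000, min_freq=2)  # keyword names assumed from __init__
for sent in corpus:
    vocab.update_list(sent)    # new bulk update backed by Counter.update

vocab.build_vocab()            # word2idx and idx2word are now built together
print(len(vocab))              # vocabulary size (defaults included when need_default=True)
print(vocab.to_index("word"))  # words outside the vocabulary map to the unknown index
print("word" in vocab)         # __contains__ now checks word2idx directly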