Browse Source

Use collections.Counter for word counts in Vocabulary; add a cached load function (load_with_cache) to BaseLoader

tags/v0.2.0
yunfan 6 years ago
parent
commit
3a42c84a47
2 changed files with 36 additions and 26 deletions
  1. +20
    -24
      fastNLP/core/vocabulary.py
  2. +16
    -2
      fastNLP/io/base_loader.py

+ 20
- 24
fastNLP/core/vocabulary.py View File

@@ -1,4 +1,5 @@
from copy import deepcopy
from collections import Counter

DEFAULT_PADDING_LABEL = '<pad>' # dict index = 0
DEFAULT_UNKNOWN_LABEL = '<unk>' # dict index = 1
@@ -23,9 +24,6 @@ def check_build_vocab(func):
def _wrapper(self, *args, **kwargs):
if self.word2idx is None:
self.build_vocab()
self.build_reverse_vocab()
elif self.idx2word is None:
self.build_reverse_vocab()
return func(self, *args, **kwargs)
return _wrapper

@@ -49,7 +47,7 @@ class Vocabulary(object):
"""
self.max_size = max_size
self.min_freq = min_freq
self.word_count = {}
self.word_count = Counter()
self.has_default = need_default
if self.has_default:
self.padding_label = DEFAULT_PADDING_LABEL
@@ -71,13 +69,14 @@ class Vocabulary(object):
self.update(w)
else:
# it's a word to be added
if word not in self.word_count:
self.word_count[word] = 1
else:
self.word_count[word] += 1
self.word_count[word] += 1
self.word2idx = None
return self

def update_list(self, sent):
    """Add every word of ``sent`` to the word counts in a single call.

    :param sent: iterable of words; each occurrence increments that word's
        entry in ``self.word_count``.
    :return: ``self``, so that calls can be chained.
    """
    self.word_count.update(sent)
    # The counts changed, so the cached word-to-index mapping is stale.
    self.word2idx = None
    return self

def build_vocab(self):
"""build 'word to index' dict, and filter the word using `max_size` and `min_freq`
"""
@@ -88,26 +87,25 @@ class Vocabulary(object):
else:
self.word2idx = {}

words = sorted(self.word_count.items(), key=lambda kv: kv[1], reverse=True)
max_size = min(self.max_size, len(self.word_count)) if self.max_size else None
words = self.word_count.most_common(max_size)
if self.min_freq is not None:
words = list(filter(lambda kv: kv[1] >= self.min_freq, words))
if self.max_size is not None and len(words) > self.max_size:
words = words[:self.max_size]
for w, _ in words:
self.word2idx[w] = len(self.word2idx)
words = filter(lambda kv: kv[1] >= self.min_freq, words)
start_idx = len(self.word2idx)
self.word2idx.update({w:i+start_idx for i, (w,_) in enumerate(words)})
self.build_reverse_vocab()

def build_reverse_vocab(self):
"""build 'index to word' dict based on 'word to index' dict
"""
self.idx2word = {self.word2idx[w] : w for w in self.word2idx}
self.idx2word = {i: w for w, i in self.word2idx.items()}

@check_build_vocab
def __len__(self):
    """Return the number of entries in the word-to-index mapping."""
    vocab_size = len(self.word2idx)
    return vocab_size

@check_build_vocab
def has_word(self, w):
return w in self.word2idx
return self.__contains__(w)

@check_build_vocab
def __getitem__(self, w):
@@ -122,14 +120,13 @@ class Vocabulary(object):
else:
raise ValueError("word {} not in vocabulary".format(w))

@check_build_vocab
def to_index(self, w):
""" like to_index(w) function, turn a word to the index
if w is not in Vocabulary, return the unknown label

:param str w:
"""
return self[w]
return self.__getitem__(w)

@property
@check_build_vocab
@@ -140,7 +137,7 @@ class Vocabulary(object):

def __setattr__(self, name, val):
self.__dict__[name] = val
if name in self.__dict__ and name in ["unknown_label", "padding_label"]:
if name in ["unknown_label", "padding_label"]:
self.word2idx = None

@property
@@ -156,8 +153,6 @@ class Vocabulary(object):

:param int idx:
"""
if self.idx2word is None:
self.build_reverse_vocab()
return self.idx2word[idx]

def __getstate__(self):
@@ -172,12 +167,13 @@ class Vocabulary(object):
"""use to restore state from pickle
"""
self.__dict__.update(state)
self.idx2word = None
self.build_reverse_vocab()

@check_build_vocab
def __contains__(self, item):
"""Check if a word in vocabulary.

:param item: the word
:return: True or False
"""
return self.has_word(item)
return item in self.word2idx

+ 16
- 2
fastNLP/io/base_loader.py View File

@@ -1,3 +1,6 @@
import os
import _pickle as pickle

class BaseLoader(object):

def __init__(self):
@@ -9,12 +12,23 @@ class BaseLoader(object):
text = f.readlines()
return [line.strip() for line in text]

@staticmethod
def load(data_path):
@classmethod
def load(cls, data_path):
with open(data_path, "r", encoding="utf-8") as f:
text = f.readlines()
return [[word for word in sent.strip()] for sent in text]

@classmethod
def load_with_cache(cls, data_path, cache_path):
    """Load ``data_path`` through a pickle cache stored at ``cache_path``.

    The cache is used only when ``cache_path`` is an existing file whose
    modification time is strictly newer than that of ``data_path``. Otherwise
    ``cls.load(data_path)`` is called and its result is pickled to
    ``cache_path``. A cache file that cannot be unpickled (for example a
    truncated one) is treated as stale and rebuilt instead of raising.

    NOTE(review): unpickling can execute arbitrary code, so ``cache_path``
    must only point at files written by a trusted process.

    :param data_path: path of the source data file.
    :param cache_path: path of the pickle cache file.
    :return: the result of ``cls.load(data_path)``, either fresh or cached.
    """
    cache_is_fresh = (
        os.path.isfile(cache_path)
        and os.path.getmtime(data_path) < os.path.getmtime(cache_path)
    )
    if cache_is_fresh:
        try:
            with open(cache_path, 'rb') as f:
                return pickle.load(f)
        except (pickle.UnpicklingError, EOFError):
            # Unreadable cache: fall through and regenerate it from the data.
            pass
    obj = cls.load(data_path)
    with open(cache_path, 'wb') as f:
        pickle.dump(obj, f)
    return obj


class ToyLoader0(BaseLoader):
"""


Loading…
Cancel
Save