@@ -16,14 +16,35 @@ def isiterable(p_object):
 
 def check_build_vocab(func):
+    """A decorator to make sure the indexing is built before used.
+    """
     def _wrapper(self, *args, **kwargs):
-        if self.word2idx is None:
+        if self.word2idx is None or self.rebuild is True:
             self.build_vocab()
         return func(self, *args, **kwargs)
     return _wrapper
 
 
+def check_build_status(func):
+    """A decorator to check whether the vocabulary updates after the last build.
+    """
+    def _wrapper(self, *args, **kwargs):
+        if self.rebuild is False:
+            self.rebuild = True
+            if self.max_size is not None and len(self.word_count) >= self.max_size:
+                print("[Warning] Vocabulary has reached the max size {} when calling {} method. "
+                      "Adding more words may cause unexpected behaviour of Vocabulary. ".format(
+                          self.max_size, func.__name__))
+        return func(self, *args, **kwargs)
+    return _wrapper
+
 
 class Vocabulary(object):
     """Use for word and index one to one mapping
@@ -54,7 +75,9 @@ class Vocabulary(object):
         self.unknown_label = None
         self.word2idx = None
         self.idx2word = None
+        self.rebuild = True
 
+    @check_build_status
     def update(self, word_lst):
         """Add a list of words into the vocabulary.
 
@@ -62,6 +85,7 @@ class Vocabulary(object):
         """
         self.word_count.update(word_lst)
 
+    @check_build_status
     def add(self, word):
         """Add a single word into the vocabulary.
 
@@ -69,6 +93,7 @@ class Vocabulary(object):
         """
         self.word_count[word] += 1
 
+    @check_build_status
     def add_word(self, word):
         """Add a single word into the vocabulary.
 
@@ -76,6 +101,7 @@ class Vocabulary(object):
         """
         self.add(word)
 
+    @check_build_status
     def add_word_lst(self, word_lst):
         """Add a list of words into the vocabulary.
 
@@ -101,6 +127,7 @@ class Vocabulary(object):
         start_idx = len(self.word2idx)
         self.word2idx.update({w: i + start_idx for i, (w, _) in enumerate(words)})
         self.build_reverse_vocab()
+        self.rebuild = False
 
     def build_reverse_vocab(self):
         """Build 'index to word' dict based on 'word to index' dict.
 
@@ -188,4 +215,3 @@ class Vocabulary(object):
         """
         self.__dict__.update(state)
         self.build_reverse_vocab()
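Taken together, the `vocabulary.py` changes implement a lazy-rebuild scheme: mutating methods only flip `self.rebuild`, and the word-to-index mapping is rebuilt on the next lookup. The following is a minimal, self-contained sketch of that pattern; the `LazyIndex` class is illustrative only, not part of the fastNLP API.

```python
# Illustrative sketch of the lazy-rebuild pattern (LazyIndex is hypothetical, not fastNLP code).
from collections import Counter


class LazyIndex:
    def __init__(self, max_size=None):
        self.word_count = Counter()
        self.word2idx = None
        self.max_size = max_size
        self.rebuild = True  # the index is stale until the first build

    def add(self, word):
        """Mutation: record the word and mark the index stale."""
        self.word_count[word] += 1
        self.rebuild = True

    def __getitem__(self, word):
        """Lookup: rebuild the index only if it is stale."""
        if self.word2idx is None or self.rebuild:
            most_common = self.word_count.most_common(self.max_size)
            self.word2idx = {w: i for i, (w, _) in enumerate(most_common)}
            self.rebuild = False
        return self.word2idx[word]


index = LazyIndex()
index.add("hello")
index.add("world")
print(index["hello"])  # first lookup triggers the build
index.add("again")     # marks the index stale
print(index["again"])  # next lookup rebuilds and finds the new word
```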
@@ -1,3 +1,4 @@
+import numpy as np
 import torch
 from fastNLP.core.vocabulary import Vocabulary
@@ -26,7 +27,7 @@ class EmbedLoader(BaseLoader):
         emb = {}
         with open(emb_file, 'r', encoding='utf-8') as f:
             for line in f:
-                line = list(filter(lambda w: len(w)>0, line.strip().split(' ')))
+                line = list(filter(lambda w: len(w) > 0, line.strip().split(' ')))
                 if len(line) > 2:
                     emb[line[0]] = torch.Tensor(list(map(float, line[1:])))
         return emb
@@ -35,9 +36,9 @@ class EmbedLoader(BaseLoader):
     def _load_pretrain(emb_file, emb_type):
         """Read txt data from embedding file and convert to np.array as pre-trained embedding
 
-        :param emb_file: str, the pre-trained embedding file path
-        :param emb_type: str, the pre-trained embedding data format
-        :return dict: {str: np.array}
+        :param str emb_file: the pre-trained embedding file path
+        :param str emb_type: the pre-trained embedding data format
+        :return dict embedding: `{str: np.array}`
         """
         if emb_type == 'glove':
             return EmbedLoader._load_glove(emb_file)
@@ -45,38 +46,68 @@ class EmbedLoader(BaseLoader):
             raise Exception("embedding type {} not support yet".format(emb_type))
 
     @staticmethod
-    def load_embedding(emb_dim, emb_file, emb_type, vocab, emb_pkl):
+    def load_embedding(emb_dim, emb_file, emb_type, vocab):
         """Load the pre-trained embedding and combine with the given dictionary.
 
-        :param emb_dim: int, the dimension of the embedding. Should be the same as pre-trained embedding.
-        :param emb_file: str, the pre-trained embedding file path.
-        :param emb_type: str, the pre-trained embedding format, support glove now
-        :param vocab: Vocabulary, a mapping from word to index, can be provided by user or built from pre-trained embedding
-        :param emb_pkl: str, the embedding pickle file.
+        :param int emb_dim: the dimension of the embedding. Should be the same as the pre-trained embedding.
+        :param str emb_file: the pre-trained embedding file path.
+        :param str emb_type: the pre-trained embedding format; only 'glove' is supported for now.
+        :param Vocabulary vocab: a mapping from word to index; can be provided by the user or built from the pre-trained embedding.
         :return embedding_tensor: Tensor of shape (len(word_dict), emb_dim)
                 vocab: input vocab or vocab built by pre-train
-        TODO: fragile code
         """
-        # If the embedding pickle exists, load it and return.
-        # if os.path.exists(emb_pkl):
-        #     with open(emb_pkl, "rb") as f:
-        #         embedding_tensor, vocab = _pickle.load(f)
-        #     return embedding_tensor, vocab
-        # Otherwise, load the pre-trained embedding.
         pretrain = EmbedLoader._load_pretrain(emb_file, emb_type)
         if vocab is None:
             # build vocabulary from pre-trained embedding
             vocab = Vocabulary()
             for w in pretrain.keys():
-                vocab.update(w)
+                vocab.add(w)
         embedding_tensor = torch.randn(len(vocab), emb_dim)
         for w, v in pretrain.items():
            if len(v.shape) > 1 or emb_dim != v.shape[0]:
-                raise ValueError('pretrian embedding dim is {}, dismatching required {}'.format(v.shape, (emb_dim,)))
+                raise ValueError(
+                    "Pretrained embedding dim is {}. Dimension mismatched. Required {}".format(v.shape, (emb_dim,)))
            if vocab.has_word(w):
                embedding_tensor[vocab[w]] = v
-        # save and return the result
-        # with open(emb_pkl, "wb") as f:
-        #     _pickle.dump((embedding_tensor, vocab), f)
        return embedding_tensor, vocab
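For reference, a hedged usage sketch of the updated four-argument `load_embedding`; the embedding path and dimension below are assumptions borrowed from the test fixture, not fixed values.

```python
from fastNLP.core.vocabulary import Vocabulary
from fastNLP.io.embed_loader import EmbedLoader

# emb_file and emb_dim are assumptions mirroring the test data.
vocab = Vocabulary()
vocab.update(["the", "in", "a"])

embedding_tensor, vocab = EmbedLoader.load_embedding(
    emb_dim=50,
    emb_file="../data_for_tests/glove.6B.50d_test.txt",
    emb_type="glove",
    vocab=vocab,
)
print(embedding_tensor.shape)  # torch.Size([<len(vocab)>, 50])
```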
+
+    @staticmethod
+    def parse_glove_line(line):
+        line = list(filter(lambda w: len(w) > 0, line.strip().split(" ")))
+        if len(line) <= 2:
+            raise RuntimeError("something went wrong when parsing the glove embedding line")
+        return line[0], torch.Tensor(list(map(float, line[1:])))
+
+    @staticmethod
+    def fast_load_embedding(emb_dim, emb_file, vocab):
+        """Fast load the pre-trained embedding and combine with the given dictionary.
+
+        This loading method uses line-by-line operation.
+
+        :param int emb_dim: the dimension of the embedding. Should be the same as the pre-trained embedding.
+        :param str emb_file: the pre-trained embedding file path.
+        :param Vocabulary vocab: a mapping from word to index; can be provided by the user or built from the pre-trained embedding.
+        :return numpy.ndarray embedding_matrix: embedding matrix of shape (len(vocab), emb_dim)
+        """
+        if vocab is None:
+            raise RuntimeError("You must provide a vocabulary.")
+        embedding_matrix = np.zeros(shape=(len(vocab), emb_dim))
+        hit_flags = np.zeros(shape=(len(vocab),), dtype=int)
+        with open(emb_file, "r", encoding="utf-8") as f:
+            for line in f:
+                word, vector = EmbedLoader.parse_glove_line(line)
+                if word in vocab:
+                    if len(vector.shape) > 1 or emb_dim != vector.shape[0]:
+                        raise ValueError("Pre-trained embedding dim is {}. Expect {}.".format(vector.shape, (emb_dim,)))
+                    embedding_matrix[vocab[word]] = vector
+                    hit_flags[vocab[word]] = 1
+        if np.sum(hit_flags) < len(vocab):
+            # some words in the vocabulary are missing from the pre-trained embedding;
+            # sample their vectors from a normal distribution fitted to the vectors that were found
+            vocab_embed = embedding_matrix[np.where(hit_flags)]
+            mean, cov = vocab_embed.mean(axis=0), np.cov(vocab_embed.T)
+            sampled_vectors = np.random.multivariate_normal(mean, cov, size=(len(vocab) - np.sum(hit_flags),))
+            embedding_matrix[np.where(1 - hit_flags)] = sampled_vectors
+        return embedding_matrix
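A hedged usage sketch of the new `parse_glove_line` and `fast_load_embedding` helpers, mirroring the test added below; the file path and the out-of-vocabulary token are assumptions. Rows for words missing from the embedding file are filled by sampling from a multivariate normal fitted to the vectors that were found, rather than left at zero.

```python
import numpy as np

from fastNLP.core.vocabulary import Vocabulary
from fastNLP.io.embed_loader import EmbedLoader

# parse_glove_line expects "<word> <value> <value> ..."; this 4-dim line is made up for brevity.
word, vector = EmbedLoader.parse_glove_line("the 0.418 0.24968 -0.41242 0.1217")
print(word, vector.shape)  # the torch.Size([4])

# Path mirrors the test fixture; "notinglove" is a made-up token absent from the file.
vocab = Vocabulary()
vocab.update(["the", "in", "a", "notinglove"])
matrix = EmbedLoader.fast_load_embedding(50, "../data_for_tests/glove.6B.50d_test.txt", vocab)
print(matrix.shape)                    # (len(vocab), 50)
print(isinstance(matrix, np.ndarray))  # True; the row for "notinglove" is sampled, not zero
```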
@@ -59,3 +59,30 @@ class TestIndexing(unittest.TestCase):
         vocab = Vocabulary(need_default=True, max_size=None, min_freq=None)
         vocab.update(text)
         self.assertEqual(text, [vocab.to_word(idx) for idx in [vocab[w] for w in text]])
+
+
+class TestOther(unittest.TestCase):
+    def test_additional_update(self):
+        vocab = Vocabulary(need_default=True, max_size=None, min_freq=None)
+        vocab.update(text)
+
+        _ = vocab["well"]
+        self.assertEqual(vocab.rebuild, False)
+
+        vocab.add("hahaha")
+        self.assertEqual(vocab.rebuild, True)
+
+        _ = vocab["hahaha"]
+        self.assertEqual(vocab.rebuild, False)
+        self.assertTrue("hahaha" in vocab)
+
+    def test_warning(self):
+        vocab = Vocabulary(need_default=True, max_size=len(set(text)), min_freq=None)
+        vocab.update(text)
+        self.assertEqual(vocab.rebuild, True)
+        print(len(vocab))
+        self.assertEqual(vocab.rebuild, False)
+
+        vocab.update(["hahahha", "hhh", "vvvv", "ass", "asss", "jfweiong", "eqgfeg", "feqfw"])
+        # this will print a warning
+        self.assertEqual(vocab.rebuild, True)
@@ -8,5 +8,3 @@ in 0.33042 0.24995 -0.60874 0.10923 0.036372 0.151 -0.55083 -0.074239 -0.092307
 a 0.21705 0.46515 -0.46757 0.10082 1.0135 0.74845 -0.53104 -0.26256 0.16812 0.13182 -0.24909 -0.44185 -0.21739 0.51004 0.13448 -0.43141 -0.03123 0.20674 -0.78138 -0.20148 -0.097401 0.16088 -0.61836 -0.18504 -0.12461 -2.2526 -0.22321 0.5043 0.32257 0.15313 3.9636 -0.71365 -0.67012 0.28388 0.21738 0.14433 0.25926 0.23434 0.4274 -0.44451 0.13813 0.36973 -0.64289 0.024142 -0.039315 -0.26037 0.12017 -0.043782 0.41013 0.1796
 " 0.25769 0.45629 -0.76974 -0.37679 0.59272 -0.063527 0.20545 -0.57385 -0.29009 -0.13662 0.32728 1.4719 -0.73681 -0.12036 0.71354 -0.46098 0.65248 0.48887 -0.51558 0.039951 -0.34307 -0.014087 0.86488 0.3546 0.7999 -1.4995 -1.8153 0.41128 0.23921 -0.43139 3.6623 -0.79834 -0.54538 0.16943 -0.82017 -0.3461 0.69495 -1.2256 -0.17992 -0.057474 0.030498 -0.39543 -0.38515 -1.0002 0.087599 -0.31009 -0.34677 -0.31438 0.75004 0.97065
 's 0.23727 0.40478 -0.20547 0.58805 0.65533 0.32867 -0.81964 -0.23236 0.27428 0.24265 0.054992 0.16296 -1.2555 -0.086437 0.44536 0.096561 -0.16519 0.058378 -0.38598 0.086977 0.0033869 0.55095 -0.77697 -0.62096 0.092948 -2.5685 -0.67739 0.10151 -0.48643 -0.057805 3.1859 -0.017554 -0.16138 0.055486 -0.25885 -0.33938 -0.19928 0.26049 0.10478 -0.55934 -0.12342 0.65961 -0.51802 -0.82995 -0.082739 0.28155 -0.423 -0.27378 -0.007901 -0.030231
@@ -0,0 +1,12 @@
+import unittest
+
+from fastNLP.core.vocabulary import Vocabulary
+from fastNLP.io.embed_loader import EmbedLoader
+
+
+class TestEmbedLoader(unittest.TestCase):
+    def test_case(self):
+        vocab = Vocabulary()
+        vocab.update(["the", "in", "I", "to", "of", "hahaha"])
+        embedding = EmbedLoader().fast_load_embedding(50, "../data_for_tests/glove.6B.50d_test.txt", vocab)
+        self.assertEqual(tuple(embedding.shape), (len(vocab), 50))