From 6427e85e8f7540cf60203dab16a0a4f04ce9b5ef Mon Sep 17 00:00:00 2001
From: FengZiYjun
Date: Sat, 1 Dec 2018 15:44:52 +0800
Subject: [PATCH 1/2] Upgrade Vocab: add words to the vocabulary incrementally;
 lazy update: only rebuild when the vocabulary is used; print a warning when
 newly added words push the vocabulary size over the limit
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Update Vocabulary:
* More words can be added after the vocabulary has been built.
* Lazy update: rebuild automatically when the vocab is used.
* Print a warning when the max size is reached.
---
 fastNLP/core/vocabulary.py   | 30 ++++++++++++++++++++++++++++--
 test/core/test_vocabulary.py | 27 +++++++++++++++++++++++++++
 2 files changed, 55 insertions(+), 2 deletions(-)

diff --git a/fastNLP/core/vocabulary.py b/fastNLP/core/vocabulary.py
index 7b0ab614..ca6b4ebf 100644
--- a/fastNLP/core/vocabulary.py
+++ b/fastNLP/core/vocabulary.py
@@ -16,14 +16,35 @@ def isiterable(p_object):
 
 
 def check_build_vocab(func):
+    """A decorator to make sure the indexing is built before it is used.
+
+    """
+
     def _wrapper(self, *args, **kwargs):
-        if self.word2idx is None:
+        if self.word2idx is None or self.rebuild is True:
             self.build_vocab()
         return func(self, *args, **kwargs)
 
     return _wrapper
 
 
+def check_build_status(func):
+    """A decorator to check whether the vocabulary has been updated since the last build.
+
+    """
+
+    def _wrapper(self, *args, **kwargs):
+        if self.rebuild is False:
+            self.rebuild = True
+            if self.max_size is not None and len(self.word_count) >= self.max_size:
+                print("[Warning] Vocabulary has reached the max size {} when calling {} method. "
+                      "Adding more words may cause unexpected behaviour of Vocabulary. ".format(
+                    self.max_size, func.__name__))
+        return func(self, *args, **kwargs)
+
+    return _wrapper
+
+
 class Vocabulary(object):
     """Use for word and index one to one mapping
 
@@ -54,7 +75,9 @@ class Vocabulary(object):
         self.unknown_label = None
         self.word2idx = None
         self.idx2word = None
+        self.rebuild = True
 
+    @check_build_status
     def update(self, word_lst):
         """Add a list of words into the vocabulary.
 
@@ -62,6 +85,7 @@
         """
         self.word_count.update(word_lst)
 
+    @check_build_status
     def add(self, word):
         """Add a single word into the vocabulary.
 
@@ -69,6 +93,7 @@
         """
         self.word_count[word] += 1
 
+    @check_build_status
     def add_word(self, word):
         """Add a single word into the vocabulary.
 
@@ -76,6 +101,7 @@
         """
         self.add(word)
 
+    @check_build_status
     def add_word_lst(self, word_lst):
         """Add a list of words into the vocabulary.
 
@@ -101,6 +127,7 @@
         start_idx = len(self.word2idx)
         self.word2idx.update({w: i + start_idx for i, (w, _) in enumerate(words)})
         self.build_reverse_vocab()
+        self.rebuild = False
 
     def build_reverse_vocab(self):
         """Build 'index to word' dict based on 'word to index' dict.
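
A minimal usage sketch of the lazy-rebuild behaviour introduced above (illustrative only, not part of the patch; the constructor arguments follow the existing tests):

    from fastNLP.core.vocabulary import Vocabulary

    vocab = Vocabulary(need_default=True, max_size=None, min_freq=None)
    vocab.update(["the", "quick", "brown", "fox"])  # words are counted, rebuild stays pending
    idx = vocab["quick"]   # first lookup triggers build_vocab(); vocab.rebuild is now False
    vocab.add("lazy")      # adding a word after the build flips vocab.rebuild back to True
    idx = vocab["lazy"]    # the next lookup rebuilds once more, so the new word gets an index
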
@@ -188,4 +215,3 @@ class Vocabulary(object):
         """
         self.__dict__.update(state)
         self.build_reverse_vocab()
-
diff --git a/test/core/test_vocabulary.py b/test/core/test_vocabulary.py
index e140b1aa..e453e935 100644
--- a/test/core/test_vocabulary.py
+++ b/test/core/test_vocabulary.py
@@ -59,3 +59,30 @@ class TestIndexing(unittest.TestCase):
         vocab = Vocabulary(need_default=True, max_size=None, min_freq=None)
         vocab.update(text)
         self.assertEqual(text, [vocab.to_word(idx) for idx in [vocab[w] for w in text]])
+
+
+class TestOther(unittest.TestCase):
+    def test_additional_update(self):
+        vocab = Vocabulary(need_default=True, max_size=None, min_freq=None)
+        vocab.update(text)
+
+        _ = vocab["well"]
+        self.assertEqual(vocab.rebuild, False)
+
+        vocab.add("hahaha")
+        self.assertEqual(vocab.rebuild, True)
+
+        _ = vocab["hahaha"]
+        self.assertEqual(vocab.rebuild, False)
+        self.assertTrue("hahaha" in vocab)
+
+    def test_warning(self):
+        vocab = Vocabulary(need_default=True, max_size=len(set(text)), min_freq=None)
+        vocab.update(text)
+        self.assertEqual(vocab.rebuild, True)
+        print(len(vocab))
+        self.assertEqual(vocab.rebuild, False)
+
+        vocab.update(["hahahha", "hhh", "vvvv", "ass", "asss", "jfweiong", "eqgfeg", "feqfw"])
+        # this will print a warning
+        self.assertEqual(vocab.rebuild, True)

From 3120cdd09a8f83378b59fd7e4f71da16ba4f7b12 Mon Sep 17 00:00:00 2001
From: FengZiYjun
Date: Sat, 1 Dec 2018 17:23:25 +0800
Subject: [PATCH 2/2] Update embed_loader: add a fast_load_embedding method
 that indexes the pre-trained embeddings with the words in Vocab; when a word
 in Vocab is missing from the pre-trained embeddings, sample its vector from
 a normal distribution fitted to the existing embeddings
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Update embed_loader:
* Add the fast_load_embedding method, which indexes the pre-trained embedding
  with the words in Vocab.
* If words in Vocab do not exist in the pre-trained embedding, sample them from
  a normal distribution computed from the current embeddings.
---
 fastNLP/io/embed_loader.py                | 77 ++++++++++++++++-------
 test/data_for_tests/glove.6B.50d_test.txt |  2 -
 test/io/test_embed_loader.py              | 12 ++++
 3 files changed, 66 insertions(+), 25 deletions(-)
 create mode 100644 test/io/test_embed_loader.py

diff --git a/fastNLP/io/embed_loader.py b/fastNLP/io/embed_loader.py
index 878ea1b6..6e557c2b 100644
--- a/fastNLP/io/embed_loader.py
+++ b/fastNLP/io/embed_loader.py
@@ -1,3 +1,4 @@
+import numpy as np
 import torch
 
 from fastNLP.core.vocabulary import Vocabulary
@@ -26,7 +27,7 @@ class EmbedLoader(BaseLoader):
         emb = {}
         with open(emb_file, 'r', encoding='utf-8') as f:
             for line in f:
-                line = list(filter(lambda w: len(w)>0, line.strip().split(' ')))
+                line = list(filter(lambda w: len(w) > 0, line.strip().split(' ')))
                 if len(line) > 2:
                     emb[line[0]] = torch.Tensor(list(map(float, line[1:])))
         return emb
@@ -35,9 +36,9 @@
     def _load_pretrain(emb_file, emb_type):
         """Read txt data from embedding file and convert to np.array as pre-trained embedding
 
-        :param emb_file: str, the pre-trained embedding file path
-        :param emb_type: str, the pre-trained embedding data format
-        :return dict: {str: np.array}
+        :param str emb_file: the pre-trained embedding file path
+        :param str emb_type: the pre-trained embedding data format
+        :return dict embedding: {str: np.array}
         """
         if emb_type == 'glove':
             return EmbedLoader._load_glove(emb_file)
@@ -45,38 +46,68 @@
             raise Exception("embedding type {} not support yet".format(emb_type))
 
     @staticmethod
-    def load_embedding(emb_dim, emb_file, emb_type, vocab, emb_pkl):
+    def load_embedding(emb_dim, emb_file, emb_type, vocab):
         """Load the pre-trained embedding and combine with the given dictionary.
 
-        :param emb_dim: int, the dimension of the embedding. Should be the same as pre-trained embedding.
-        :param emb_file: str, the pre-trained embedding file path.
-        :param emb_type: str, the pre-trained embedding format, support glove now
-        :param vocab: Vocabulary, a mapping from word to index, can be provided by user or built from pre-trained embedding
-        :param emb_pkl: str, the embedding pickle file.
+        :param int emb_dim: the dimension of the embedding. Should be the same as pre-trained embedding.
+        :param str emb_file: the pre-trained embedding file path.
+        :param str emb_type: the pre-trained embedding format, support glove now
+        :param Vocabulary vocab: a mapping from word to index, can be provided by user or built from pre-trained embedding
         :return embedding_tensor: Tensor of shape (len(word_dict), emb_dim)
                 vocab: input vocab or vocab built by pre-train
-        TODO: fragile code
+
         """
-        # If the embedding pickle exists, load it and return.
-        # if os.path.exists(emb_pkl):
-        #     with open(emb_pkl, "rb") as f:
-        #         embedding_tensor, vocab = _pickle.load(f)
-        #     return embedding_tensor, vocab
-        # Otherwise, load the pre-trained embedding.
         pretrain = EmbedLoader._load_pretrain(emb_file, emb_type)
         if vocab is None:
             # build vocabulary from pre-trained embedding
             vocab = Vocabulary()
            for w in pretrain.keys():
-                vocab.update(w)
+                vocab.add(w)
         embedding_tensor = torch.randn(len(vocab), emb_dim)
         for w, v in pretrain.items():
             if len(v.shape) > 1 or emb_dim != v.shape[0]:
-                raise ValueError('pretrian embedding dim is {}, dismatching required {}'.format(v.shape, (emb_dim,)))
+                raise ValueError(
+                    "Pretrained embedding dim is {}. Dimension mismatched. Required {}".format(v.shape, (emb_dim,)))
             if vocab.has_word(w):
                 embedding_tensor[vocab[w]] = v
-
-        # save and return the result
-        # with open(emb_pkl, "wb") as f:
-        #     _pickle.dump((embedding_tensor, vocab), f)
         return embedding_tensor, vocab
+
+    @staticmethod
+    def parse_glove_line(line):
+        line = list(filter(lambda w: len(w) > 0, line.strip().split(" ")))
+        if len(line) <= 2:
+            raise RuntimeError("something went wrong when parsing the glove embedding")
+        return line[0], torch.Tensor(list(map(float, line[1:])))
+
+    @staticmethod
+    def fast_load_embedding(emb_dim, emb_file, vocab):
+        """Fast load the pre-trained embedding and combine it with the given vocabulary.
+        This loading method reads the embedding file line by line.
+
+        :param int emb_dim: the dimension of the embedding. Should be the same as pre-trained embedding.
+        :param str emb_file: the pre-trained embedding file path.
+        :param Vocabulary vocab: a mapping from word to index; must be provided
+        :return numpy.ndarray embedding_matrix:
+
+        """
+        if vocab is None:
+            raise RuntimeError("You must provide a vocabulary.")
+        embedding_matrix = np.zeros(shape=(len(vocab), emb_dim))
+        hit_flags = np.zeros(shape=(len(vocab),), dtype=int)
+        with open(emb_file, "r", encoding="utf-8") as f:
+            for line in f:
+                word, vector = EmbedLoader.parse_glove_line(line)
+                if word in vocab:
+                    if len(vector.shape) > 1 or emb_dim != vector.shape[0]:
+                        raise ValueError("Pre-trained embedding dim is {}. Expected {}.".format(vector.shape, (emb_dim,)))
+                    embedding_matrix[vocab[word]] = vector
+                    hit_flags[vocab[word]] = 1
+
+        if np.sum(hit_flags) < len(vocab):
+            # some words in vocab are missing from the pre-trained embedding
+            # sample their vectors from a normal distribution fitted to the loaded vectors
+            vocab_embed = embedding_matrix[np.where(hit_flags)]
+            mean, cov = vocab_embed.mean(axis=0), np.cov(vocab_embed.T)
+            sampled_vectors = np.random.multivariate_normal(mean, cov, size=(len(vocab) - np.sum(hit_flags),))
+            embedding_matrix[np.where(1 - hit_flags)] = sampled_vectors
+        return embedding_matrix
diff --git a/test/data_for_tests/glove.6B.50d_test.txt b/test/data_for_tests/glove.6B.50d_test.txt
index cd71b26e..8b443cca 100644
--- a/test/data_for_tests/glove.6B.50d_test.txt
+++ b/test/data_for_tests/glove.6B.50d_test.txt
@@ -8,5 +8,3 @@ in 0.33042 0.24995 -0.60874 0.10923 0.036372 0.151 -0.55083 -0.074239 -0.092307
 a 0.21705 0.46515 -0.46757 0.10082 1.0135 0.74845 -0.53104 -0.26256 0.16812 0.13182 -0.24909 -0.44185 -0.21739 0.51004 0.13448 -0.43141 -0.03123 0.20674 -0.78138 -0.20148 -0.097401 0.16088 -0.61836 -0.18504 -0.12461 -2.2526 -0.22321 0.5043 0.32257 0.15313 3.9636 -0.71365 -0.67012 0.28388 0.21738 0.14433 0.25926 0.23434 0.4274 -0.44451 0.13813 0.36973 -0.64289 0.024142 -0.039315 -0.26037 0.12017 -0.043782 0.41013 0.1796
 " 0.25769 0.45629 -0.76974 -0.37679 0.59272 -0.063527 0.20545 -0.57385 -0.29009 -0.13662 0.32728 1.4719 -0.73681 -0.12036 0.71354 -0.46098 0.65248 0.48887 -0.51558 0.039951 -0.34307 -0.014087 0.86488 0.3546 0.7999 -1.4995 -1.8153 0.41128 0.23921 -0.43139 3.6623 -0.79834 -0.54538 0.16943 -0.82017 -0.3461 0.69495 -1.2256 -0.17992 -0.057474 0.030498 -0.39543 -0.38515 -1.0002 0.087599 -0.31009 -0.34677 -0.31438 0.75004 0.97065
 's 0.23727 0.40478 -0.20547 0.58805 0.65533 0.32867 -0.81964 -0.23236 0.27428 0.24265 0.054992 0.16296 -1.2555 -0.086437 0.44536 0.096561 -0.16519 0.058378 -0.38598 0.086977 0.0033869 0.55095 -0.77697 -0.62096 0.092948 -2.5685 -0.67739 0.10151 -0.48643 -0.057805 3.1859 -0.017554 -0.16138 0.055486 -0.25885 -0.33938 -0.19928 0.26049 0.10478 -0.55934 -0.12342 0.65961 -0.51802 -0.82995 -0.082739 0.28155 -0.423 -0.27378 -0.007901 -0.030231
-
-
diff --git a/test/io/test_embed_loader.py b/test/io/test_embed_loader.py
new file mode 100644
index 00000000..0a7c4fcf
--- /dev/null
+++ b/test/io/test_embed_loader.py
@@ -0,0 +1,12 @@
+import unittest
+
+from fastNLP.core.vocabulary import Vocabulary
+from fastNLP.io.embed_loader import EmbedLoader
+
+
+class TestEmbedLoader(unittest.TestCase):
+    def test_case(self):
+        vocab = Vocabulary()
+        vocab.update(["the", "in", "I", "to", "of", "hahaha"])
+        embedding = EmbedLoader().fast_load_embedding(50, "../data_for_tests/glove.6B.50d_test.txt", vocab)
+        self.assertEqual(tuple(embedding.shape), (len(vocab), 50))
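
A minimal usage sketch of the new fast_load_embedding method, mirroring the test above (illustrative only, not part of the patch; the fixture path assumes the snippet is run from the repository root, adjust it to your layout):

    from fastNLP.core.vocabulary import Vocabulary
    from fastNLP.io.embed_loader import EmbedLoader

    vocab = Vocabulary()
    vocab.update(["the", "in", "I", "to", "of", "hahaha"])
    # words found in the GloVe file keep their pre-trained vectors; the remaining
    # entries (e.g. "hahaha") are sampled from a normal distribution fitted to the loaded rows
    matrix = EmbedLoader().fast_load_embedding(50, "test/data_for_tests/glove.6B.50d_test.txt", vocab)
    assert matrix.shape == (len(vocab), 50)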