@@ -18,6 +18,15 @@ def isiterable(p_object):
         return False
     return True

+
+def check_build_vocab(func):
+    def _wrapper(self, *args, **kwargs):
+        if self.word2idx is None:
+            self.build_vocab()
+            self.build_reverse_vocab()
+        elif self.idx2word is None:
+            self.build_reverse_vocab()
+        return func(self, *args, **kwargs)
+    return _wrapper
+
+
 class Vocabulary(object):
     """Use for word and index one to one mapping
@@ -30,30 +39,23 @@ class Vocabulary(object):
         vocab["word"]
         vocab.to_word(5)
     """
-    def __init__(self, need_default=True):
+    def __init__(self, need_default=True, max_size=None, min_freq=None):
         """
         :param bool need_default: set if the Vocabulary has default labels reserved for sequences. Default: True.
+        :param int max_size: set the max number of words in Vocabulary. Default: None
+        :param int min_freq: set the min occur frequency of words in Vocabulary. Default: None
         """
-        if need_default:
-            self.word2idx = deepcopy(DEFAULT_WORD_TO_INDEX)
-            self.padding_label = DEFAULT_PADDING_LABEL
-            self.unknown_label = DEFAULT_UNKNOWN_LABEL
-        else:
-            self.word2idx = {}
-            self.padding_label = None
-            self.unknown_label = None
-
+        self.max_size = max_size
+        self.min_freq = min_freq
+        self.word_count = {}
         self.has_default = need_default
+        self.word2idx = None
         self.idx2word = None

-    def __len__(self):
-        return len(self.word2idx)
-
     def update(self, word):
         """add word or list of words into Vocabulary

         :param word: a list of string or a single string
         """
         if not isinstance(word, str) and isiterable(word):
@@ -61,12 +63,48 @@
             for w in word:
                 self.update(w)
         else:
-            # it's a word to be added
-            if word not in self.word2idx:
-                self.word2idx[word] = len(self)
-                if self.idx2word is not None:
-                    self.idx2word = None
+            # it's a word to be added
+            if word not in self.word_count:
+                self.word_count[word] = 1
+            else:
+                self.word_count[word] += 1
+            self.word2idx = None
+
+    def build_vocab(self):
+        """build 'word to index' dict, and filter the word using `max_size` and `min_freq`
+        """
+        if self.has_default:
+            self.word2idx = deepcopy(DEFAULT_WORD_TO_INDEX)
+            self.padding_label = DEFAULT_PADDING_LABEL
+            self.unknown_label = DEFAULT_UNKNOWN_LABEL
+        else:
+            self.word2idx = {}
+            self.padding_label = None
+            self.unknown_label = None
+
+        words = sorted(self.word_count.items(), key=lambda kv: kv[1], reverse=True)
+        if self.min_freq is not None:
+            words = list(filter(lambda kv: kv[1] >= self.min_freq, words))
+        if self.max_size is not None and len(words) > self.max_size:
+            words = words[:self.max_size]
+        for w, _ in words:
+            self.word2idx[w] = len(self.word2idx)
+
+    def build_reverse_vocab(self):
+        """build 'index to word' dict based on 'word to index' dict
+        """
+        self.idx2word = {self.word2idx[w]: w for w in self.word2idx}
+
+    @check_build_vocab
+    def __len__(self):
+        return len(self.word2idx)
+
+    @check_build_vocab
+    def has_word(self, w):
+        return w in self.word2idx
+
+    @check_build_vocab
     def __getitem__(self, w):
         """To support usage like::
@@ -74,32 +112,33 @@
         """
         if w in self.word2idx:
             return self.word2idx[w]
-        else:
+        elif self.has_default:
             return self.word2idx[DEFAULT_UNKNOWN_LABEL]
+        else:
+            raise ValueError("word {} not in vocabulary".format(w))

+    @check_build_vocab
     def to_index(self, w):
         """ like to_index(w) function, turn a word to the index
             if w is not in Vocabulary, return the unknown label

         :param str w:
         """
         return self[w]

+    @check_build_vocab
     def unknown_idx(self):
         if self.unknown_label is None:
             return None
         return self.word2idx[self.unknown_label]

+    @check_build_vocab
     def padding_idx(self):
         if self.padding_label is None:
             return None
         return self.word2idx[self.padding_label]

-    def build_reverse_vocab(self):
-        """build 'index to word' dict based on 'word to index' dict
-        """
-        self.idx2word = {self.word2idx[w]: w for w in self.word2idx}
-
+    @check_build_vocab
     def to_word(self, idx):
         """given a word's index, return the word itself
@@ -4,7 +4,7 @@ import os
 import numpy as np

 from fastNLP.loader.base_loader import BaseLoader
+from fastNLP.core.vocabulary import Vocabulary


 class EmbedLoader(BaseLoader):
     """docstring for EmbedLoader"""
@@ -13,18 +13,50 @@
         super(EmbedLoader, self).__init__(data_path)

     @staticmethod
-    def load_embedding(emb_dim, emb_file, word_dict, emb_pkl):
+    def _load_glove(emb_file):
+        """Read file as a glove embedding
+
+        file format:
+            embeddings are split by line,
+            for one embedding, word and numbers split by space
+        Example::
+
+            word_1 float_1 float_2 ... float_emb_dim
+            word_2 float_1 float_2 ... float_emb_dim
+            ...
+        """
+        emb = {}
+        with open(emb_file, 'r', encoding='utf-8') as f:
+            for line in f:
+                line = list(filter(lambda w: len(w) > 0, line.strip().split(' ')))
+                if len(line) > 0:
+                    emb[line[0]] = np.array(list(map(float, line[1:])))
+        return emb
+
+    @staticmethod
+    def _load_pretrain(emb_file, emb_type):
+        """Read txt data from embedding file and convert to np.array as pre-trained embedding
+
+        :param emb_file: str, the pre-trained embedding file path
+        :param emb_type: str, the pre-trained embedding data format
+        :return dict: {str: np.array}
+        """
+        if emb_type == 'glove':
+            return EmbedLoader._load_glove(emb_file)
+        else:
+            raise Exception("embedding type {} not supported yet".format(emb_type))
+
+    @staticmethod
+    def load_embedding(emb_dim, emb_file, emb_type, vocab, emb_pkl):
         """Load the pre-trained embedding and combine with the given dictionary.

-        :param emb_file: str, the pre-trained embedding.
-            The embedding file should have the following format:
-                Each line is a word embedding, where a word string is followed by multiple floats.
-                Floats are separated by space. The word and the first float are separated by space.
-        :param word_dict: dict, a mapping from word to index.
         :param emb_dim: int, the dimension of the embedding. Should be the same as pre-trained embedding.
+        :param emb_file: str, the pre-trained embedding file path.
+        :param emb_type: str, the pre-trained embedding format; only 'glove' is supported for now.
+        :param vocab: Vocabulary, a mapping from word to index; can be provided by the user or built from the pre-trained embedding.
         :param emb_pkl: str, the embedding pickle file.
         :return embedding_np: numpy array of shape (len(word_dict), emb_dim)
+            vocab: the input vocab, or a vocab built from the pre-trained embedding
         TODO: fragile code
         """
         # If the embedding pickle exists, load it and return.
@@ -33,18 +65,20 @@
             embedding_np = _pickle.load(f)
             return embedding_np
         # Otherwise, load the pre-trained embedding.
-        with open(emb_file, "r", encoding="utf-8") as f:
-            # begin with a random embedding
-            embedding_np = np.random.uniform(-1, 1, size=(len(word_dict), emb_dim))
-            for line in f:
-                line = line.strip().split()
-                if len(line) != emb_dim + 1:
-                    # skip this line if two embedding dimension not match
-                    continue
-                if line[0] in word_dict:
-                    # find the word and replace its embedding with a pre-trained one
-                    embedding_np[word_dict[line[0]]] = [float(i) for i in line[1:]]
+        pretrain = EmbedLoader._load_pretrain(emb_file, emb_type)
+        if vocab is None:
+            # build vocabulary from pre-trained embedding
+            vocab = Vocabulary()
+            for w in pretrain.keys():
+                vocab.update(w)
+        embedding_np = np.random.uniform(-1, 1, size=(len(vocab), emb_dim))
+        for w, v in pretrain.items():
+            if len(v.shape) > 1 or emb_dim != v.shape[0]:
+                raise ValueError('pre-trained embedding dim is {}, which does not match the required {}'.format(v.shape, (emb_dim,)))
+            if vocab.has_word(w):
+                embedding_np[vocab[w]] = v
         # save and return the result
         with open(emb_pkl, "wb") as f:
             _pickle.dump(embedding_np, f)
-        return embedding_np
+        return embedding_np, vocab
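A quick sketch of the new load_embedding call (the file paths below are placeholders, not from the patch): passing vocab=None makes the loader build a Vocabulary from the words in the pre-trained file, and the function now returns an (embedding_np, vocab) pair.

    from fastNLP.loader.embed_loader import EmbedLoader

    emb, vocab = EmbedLoader.load_embedding(
        emb_dim=50,
        emb_file='glove.6B.50d.txt',   # placeholder path
        emb_type='glove',
        vocab=None,                    # build the vocab from the embedding file
        emb_pkl='embedding.pkl')       # the resulting numpy array is pickled here
    assert emb.shape == (len(vocab), 50)

Note that the early-return branch for an existing emb_pkl still returns only embedding_np, so callers that rely on the returned vocab should only hit that cache when they already have a vocabulary of their own.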
@@ -0,0 +1,12 @@
+the 0.418 0.24968 -0.41242 0.1217 0.34527 -0.044457 -0.49688 -0.17862 -0.00066023 -0.6566 0.27843 -0.14767 -0.55677 0.14658 -0.0095095 0.011658 0.10204 -0.12792 -0.8443 -0.12181 -0.016801 -0.33279 -0.1552 -0.23131 -0.19181 -1.8823 -0.76746 0.099051 -0.42125 -0.19526 4.0071 -0.18594 -0.52287 -0.31681 0.00059213 0.0074449 0.17778 -0.15897 0.012041 -0.054223 -0.29871 -0.15749 -0.34758 -0.045637 -0.44251 0.18785 0.0027849 -0.18411 -0.11514 -0.78581
+, 0.013441 0.23682 -0.16899 0.40951 0.63812 0.47709 -0.42852 -0.55641 -0.364 -0.23938 0.13001 -0.063734 -0.39575 -0.48162 0.23291 0.090201 -0.13324 0.078639 -0.41634 -0.15428 0.10068 0.48891 0.31226 -0.1252 -0.037512 -1.5179 0.12612 -0.02442 -0.042961 -0.28351 3.5416 -0.11956 -0.014533 -0.1499 0.21864 -0.33412 -0.13872 0.31806 0.70358 0.44858 -0.080262 0.63003 0.32111 -0.46765 0.22786 0.36034 -0.37818 -0.56657 0.044691 0.30392
+. 0.15164 0.30177 -0.16763 0.17684 0.31719 0.33973 -0.43478 -0.31086 -0.44999 -0.29486 0.16608 0.11963 -0.41328 -0.42353 0.59868 0.28825 -0.11547 -0.041848 -0.67989 -0.25063 0.18472 0.086876 0.46582 0.015035 0.043474 -1.4671 -0.30384 -0.023441 0.30589 -0.21785 3.746 0.0042284 -0.18436 -0.46209 0.098329 -0.11907 0.23919 0.1161 0.41705 0.056763 -6.3681e-05 0.068987 0.087939 -0.10285 -0.13931 0.22314 -0.080803 -0.35652 0.016413 0.10216
+of 0.70853 0.57088 -0.4716 0.18048 0.54449 0.72603 0.18157 -0.52393 0.10381 -0.17566 0.078852 -0.36216 -0.11829 -0.83336 0.11917 -0.16605 0.061555 -0.012719 -0.56623 0.013616 0.22851 -0.14396 -0.067549 -0.38157 -0.23698 -1.7037 -0.86692 -0.26704 -0.2589 0.1767 3.8676 -0.1613 -0.13273 -0.68881 0.18444 0.0052464 -0.33874 -0.078956 0.24185 0.36576 -0.34727 0.28483 0.075693 -0.062178 -0.38988 0.22902 -0.21617 -0.22562 -0.093918 -0.80375
+to 0.68047 -0.039263 0.30186 -0.17792 0.42962 0.032246 -0.41376 0.13228 -0.29847 -0.085253 0.17118 0.22419 -0.10046 -0.43653 0.33418 0.67846 0.057204 -0.34448 -0.42785 -0.43275 0.55963 0.10032 0.18677 -0.26854 0.037334 -2.0932 0.22171 -0.39868 0.20912 -0.55725 3.8826 0.47466 -0.95658 -0.37788 0.20869 -0.32752 0.12751 0.088359 0.16351 -0.21634 -0.094375 0.018324 0.21048 -0.03088 -0.19722 0.082279 -0.09434 -0.073297 -0.064699 -0.26044
+and 0.26818 0.14346 -0.27877 0.016257 0.11384 0.69923 -0.51332 -0.47368 -0.33075 -0.13834 0.2702 0.30938 -0.45012 -0.4127 -0.09932 0.038085 0.029749 0.10076 -0.25058 -0.51818 0.34558 0.44922 0.48791 -0.080866 -0.10121 -1.3777 -0.10866 -0.23201 0.012839 -0.46508 3.8463 0.31362 0.13643 -0.52244 0.3302 0.33707 -0.35601 0.32431 0.12041 0.3512 -0.069043 0.36885 0.25168 -0.24517 0.25381 0.1367 -0.31178 -0.6321 -0.25028 -0.38097
+in 0.33042 0.24995 -0.60874 0.10923 0.036372 0.151 -0.55083 -0.074239 -0.092307 -0.32821 0.09598 -0.82269 -0.36717 -0.67009 0.42909 0.016496 -0.23573 0.12864 -1.0953 0.43334 0.57067 -0.1036 0.20422 0.078308 -0.42795 -1.7984 -0.27865 0.11954 -0.12689 0.031744 3.8631 -0.17786 -0.082434 -0.62698 0.26497 -0.057185 -0.073521 0.46103 0.30862 0.12498 -0.48609 -0.0080272 0.031184 -0.36576 -0.42699 0.42164 -0.11666 -0.50703 -0.027273 -0.53285
+a 0.21705 0.46515 -0.46757 0.10082 1.0135 0.74845 -0.53104 -0.26256 0.16812 0.13182 -0.24909 -0.44185 -0.21739 0.51004 0.13448 -0.43141 -0.03123 0.20674 -0.78138 -0.20148 -0.097401 0.16088 -0.61836 -0.18504 -0.12461 -2.2526 -0.22321 0.5043 0.32257 0.15313 3.9636 -0.71365 -0.67012 0.28388 0.21738 0.14433 0.25926 0.23434 0.4274 -0.44451 0.13813 0.36973 -0.64289 0.024142 -0.039315 -0.26037 0.12017 -0.043782 0.41013 0.1796
+" 0.25769 0.45629 -0.76974 -0.37679 0.59272 -0.063527 0.20545 -0.57385 -0.29009 -0.13662 0.32728 1.4719 -0.73681 -0.12036 0.71354 -0.46098 0.65248 0.48887 -0.51558 0.039951 -0.34307 -0.014087 0.86488 0.3546 0.7999 -1.4995 -1.8153 0.41128 0.23921 -0.43139 3.6623 -0.79834 -0.54538 0.16943 -0.82017 -0.3461 0.69495 -1.2256 -0.17992 -0.057474 0.030498 -0.39543 -0.38515 -1.0002 0.087599 -0.31009 -0.34677 -0.31438 0.75004 0.97065
+'s 0.23727 0.40478 -0.20547 0.58805 0.65533 0.32867 -0.81964 -0.23236 0.27428 0.24265 0.054992 0.16296 -1.2555 -0.086437 0.44536 0.096561 -0.16519 0.058378 -0.38598 0.086977 0.0033869 0.55095 -0.77697 -0.62096 0.092948 -2.5685 -0.67739 0.10151 -0.48643 -0.057805 3.1859 -0.017554 -0.16138 0.055486 -0.25885 -0.33938 -0.19928 0.26049 0.10478 -0.55934 -0.12342 0.65961 -0.51802 -0.82995 -0.082739 0.28155 -0.423 -0.27378 -0.007901 -0.030231
@@ -0,0 +1,33 @@
+import unittest
+import os
+
+import torch
+
+from fastNLP.loader.embed_loader import EmbedLoader
+from fastNLP.core.vocabulary import Vocabulary
+
+
+class TestEmbedLoader(unittest.TestCase):
+    glove_path = './test/data_for_tests/glove.6B.50d_test.txt'
+    pkl_path = './save'
+    raw_texts = ["i am a cat",
+                 "this is a test of new batch",
+                 "ha ha",
+                 "I am a good boy .",
+                 "This is the most beautiful girl ."
+                 ]
+    texts = [text.strip().split() for text in raw_texts]
+    vocab = Vocabulary()
+    vocab.update(texts)
+
+    def test1(self):
+        emb, _ = EmbedLoader.load_embedding(50, self.glove_path, 'glove', self.vocab, self.pkl_path)
+        self.assertTrue(emb.shape[0] == (len(self.vocab)))
+        self.assertTrue(emb.shape[1] == 50)
+        os.remove(self.pkl_path)
+
+    def test2(self):
+        try:
+            _ = EmbedLoader.load_embedding(100, self.glove_path, 'glove', self.vocab, self.pkl_path)
+            self.fail(msg="load mismatched embedding")
+        except ValueError:
+            pass
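As a side note, the dimension-mismatch check in test2 could also be written with unittest's assertRaises context manager; an equivalent sketch of the same test body:

    def test2(self):
        # same check as above, using the assertRaises idiom
        with self.assertRaises(ValueError):
            EmbedLoader.load_embedding(100, self.glove_path, 'glove', self.vocab, self.pkl_path)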