@@ -175,6 +175,10 @@ class StaticEmbedding(TokenEmbedding):
                                       sparse=False, _weight=embedding)
         self._embed_size = self.embedding.weight.size(1)
         self.requires_grad = requires_grad
 
+    @property
+    def weight(self):
+        return self.embedding.weight
+
     def _randomly_init_embed(self, num_embedding, embedding_dim, init_embed=None):
         """
@@ -223,7 +227,7 @@ class StaticEmbedding(TokenEmbedding):
             else:
                 dim = len(parts) - 1
                 f.seek(0)
-            matrix = {}
+            matrix = {}  # key: the word's index in vocab; value: its vector, or None if the word is absent from the pretrained file
             if vocab.padding:
                 matrix[vocab.padding_idx] = torch.zeros(dim)
             if vocab.unknown:
@@ -270,11 +274,15 @@ class StaticEmbedding(TokenEmbedding):
             else:
                 unknown_idx = vocab.unknown_idx
             self.register_buffer('words_to_words', torch.full((len(vocab), ), fill_value=unknown_idx).long())
-            for index, (index_in_vocab, vec) in enumerate(matrix.items()):
-                if vec is not None:
-                    vectors[index] = vec
-                self.words_to_words[index_in_vocab] = index
+            index = 0
+            for word, index_in_vocab in vocab:
+                if index_in_vocab in matrix:
+                    vec = matrix.get(index_in_vocab)
+                    if vec is not None:  # use the pretrained vector; None means this word's vector will be trained
+                        vectors[index] = vec
+                    self.words_to_words[index_in_vocab] = index
+                    index += 1
 
             return vectors
 
     def forward(self, words):
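The rewritten loop walks the vocab in order instead of `matrix.items()`, so the dense row index grows together with the vocab index and `words_to_words` ends up sequential whenever every word has an entry in `matrix`. A standalone sketch of the idea, using toy stand-ins rather than the library code:

    import torch

    # toy stand-ins: (word, index_in_vocab) pairs and the pretrained lookup result
    vocab = [('<pad>', 0), ('<unk>', 1), ('The', 2), ('a', 3), ('notinfile', 4)]
    matrix = {0: torch.zeros(4), 1: torch.zeros(4),
              2: torch.ones(4), 3: torch.ones(4), 4: None}  # None -> vector to be trained

    unknown_idx = 1
    vectors = torch.zeros(len(matrix), 4)
    words_to_words = torch.full((len(vocab),), fill_value=unknown_idx).long()

    index = 0
    for word, index_in_vocab in vocab:
        if index_in_vocab in matrix:
            vec = matrix.get(index_in_vocab)
            if vec is not None:                 # keep the pretrained vector
                vectors[index] = vec
            words_to_words[index_in_vocab] = index
            index += 1

    print(words_to_words.tolist())  # [0, 1, 2, 3, 4] -- sequential, as the new test asserts
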
@@ -432,7 +432,9 @@ def get_from_cache(url: str, cache_dir: Path = None) -> Path:
                 shutil.rmtree(cache_path)
             os.close(fd)
             os.remove(temp_filename)
-            if os.path.isdir(uncompress_temp_dir):
+            if uncompress_temp_dir is None:
+                pass
+            elif os.path.isdir(uncompress_temp_dir):
                 shutil.rmtree(uncompress_temp_dir)
             elif os.path.isfile(uncompress_temp_dir):
                 os.remove(uncompress_temp_dir)
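In this failure-cleanup path, `uncompress_temp_dir` can still be `None` when the error happened before anything was extracted (which is presumably why the guard was added), and `os.path.isdir(None)` would itself raise `TypeError` and mask the original exception. A minimal sketch of the guarded pattern, with a hypothetical helper name, not the library function:

    import os
    import shutil

    def _cleanup_uncompressed(path):
        """Remove whatever was extracted, tolerating the 'nothing extracted yet' case."""
        if path is None:                 # download failed before any extraction
            pass
        elif os.path.isdir(path):
            shutil.rmtree(path)          # extracted into a directory
        elif os.path.isfile(path):
            os.remove(path)              # extracted to a single file

    _cleanup_uncompressed(None)          # safe: no TypeError
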
@@ -89,6 +89,24 @@ class TestLoad(unittest.TestCase):
         check_vector_equal(['Of', 'a'], vocab, embed, embed_dict, lower=True)
         check_word_unk(['notinfile1', 'The', 'notinfile2'], vocab, embed)
 
+    def test_sequential_index(self):
+        # without any no_create_entry words, words_to_words should be sequential
+        vocab = Vocabulary().add_word_lst(['The', 'a', 'notinfile1', 'A', 'notinfile2', 'notinfile2'])
+        embed = StaticEmbedding(vocab, model_dir_or_name='test/data_for_tests/embedding/small_static_embedding/'
+                                                         'glove.6B.50d_test.txt')
+        for index, i in enumerate(embed.words_to_words):
+            assert index == i
+
+        embed_dict = read_static_embed('test/data_for_tests/embedding/small_static_embedding/'
+                                       'glove.6B.50d_test.txt')
+
+        for word, index in vocab:
+            if word in embed_dict:
+                index = vocab.to_index(word)
+                v1 = embed(torch.LongTensor([index])).tolist()[0]
+                v2 = embed_dict[word]
+                for v1i, v2i in zip(v1, v2):
+                    self.assertAlmostEqual(v1i, v2i, places=4)
 
 def read_static_embed(fp):
     """