diff --git a/fastNLP/embeddings/static_embedding.py b/fastNLP/embeddings/static_embedding.py
index b907a278..fcfb3fec 100644
--- a/fastNLP/embeddings/static_embedding.py
+++ b/fastNLP/embeddings/static_embedding.py
@@ -175,6 +175,10 @@ class StaticEmbedding(TokenEmbedding):
                                       sparse=False, _weight=embedding)
         self._embed_size = self.embedding.weight.size(1)
         self.requires_grad = requires_grad
+
+    @property
+    def weight(self):
+        return self.embedding.weight
 
     def _randomly_init_embed(self, num_embedding, embedding_dim, init_embed=None):
         """
@@ -223,7 +227,7 @@ class StaticEmbedding(TokenEmbedding):
             else:
                 dim = len(parts) - 1
                 f.seek(0)
-            matrix = {}
+            matrix = {}  # key: the word's index in vocab; value: its vector, or None if the word is absent from the pretrained file
             if vocab.padding:
                 matrix[vocab.padding_idx] = torch.zeros(dim)
             if vocab.unknown:
@@ -270,11 +274,15 @@ class StaticEmbedding(TokenEmbedding):
             else:
                 unknown_idx = vocab.unknown_idx
             self.register_buffer('words_to_words', torch.full((len(vocab), ), fill_value=unknown_idx).long())
-            for index, (index_in_vocab, vec) in enumerate(matrix.items()):
-                if vec is not None:
-                    vectors[index] = vec
-                self.words_to_words[index_in_vocab] = index
-
+            index = 0
+            for word, index_in_vocab in vocab:
+                if index_in_vocab in matrix:
+                    vec = matrix.get(index_in_vocab)
+                    if vec is not None:  # use the pretrained vector; None means this word's vector is left to be trained
+                        vectors[index] = vec
+                    self.words_to_words[index_in_vocab] = index
+                    index += 1
+
             return vectors
 
     def forward(self, words):
diff --git a/fastNLP/io/file_utils.py b/fastNLP/io/file_utils.py
index db4ccc45..5195ec74 100644
--- a/fastNLP/io/file_utils.py
+++ b/fastNLP/io/file_utils.py
@@ -432,7 +432,9 @@ def get_from_cache(url: str, cache_dir: Path = None) -> Path:
                 shutil.rmtree(cache_path)
             os.close(fd)
             os.remove(temp_filename)
-            if os.path.isdir(uncompress_temp_dir):
+            if uncompress_temp_dir is None:
+                pass
+            elif os.path.isdir(uncompress_temp_dir):
                 shutil.rmtree(uncompress_temp_dir)
             elif os.path.isfile(uncompress_temp_dir):
                 os.remove(uncompress_temp_dir)
diff --git a/test/embeddings/test_static_embedding.py b/test/embeddings/test_static_embedding.py
index 61b7f2ed..755bb5cd 100644
--- a/test/embeddings/test_static_embedding.py
+++ b/test/embeddings/test_static_embedding.py
@@ -89,6 +89,24 @@ class TestLoad(unittest.TestCase):
         check_vector_equal(['Of', 'a'], vocab, embed, embed_dict, lower=True)
         check_word_unk(['notinfile1', 'The', 'notinfile2'], vocab, embed)
 
+    def test_sequential_index(self):
+        # when there are no no_create_entry words, words_to_words should be a sequential mapping
+        vocab = Vocabulary().add_word_lst(['The', 'a', 'notinfile1', 'A', 'notinfile2', 'notinfile2'])
+        embed = StaticEmbedding(vocab, model_dir_or_name='test/data_for_tests/embedding/small_static_embedding/'
+                                                         'glove.6B.50d_test.txt')
+        for index, i in enumerate(embed.words_to_words):
+            assert index == i
+
+        embed_dict = read_static_embed('test/data_for_tests/embedding/small_static_embedding/'
+                                       'glove.6B.50d_test.txt')
+
+        for word, index in vocab:
+            if word in embed_dict:
+                index = vocab.to_index(word)
+                v1 = embed(torch.LongTensor([index])).tolist()[0]
+                v2 = embed_dict[word]
+                for v1i, v2i in zip(v1, v2):
+                    self.assertAlmostEqual(v1i, v2i, places=4)
 
 def read_static_embed(fp):
     """
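
Taken together, the hunks assign pretrained rows to vocab words in vocabulary order, so words_to_words becomes a sequential mapping when no word is marked no_create_entry, and the new weight property exposes the underlying nn.Embedding matrix. A minimal sketch (not part of the patch) of the resulting behaviour, run from the repository root against the same small GloVe fixture the new test uses, assuming fastNLP and PyTorch are installed:

    from fastNLP import Vocabulary
    from fastNLP.embeddings import StaticEmbedding

    # No word below is marked no_create_entry.
    vocab = Vocabulary().add_word_lst(['The', 'a', 'notinfile1', 'A'])
    embed = StaticEmbedding(vocab, model_dir_or_name='test/data_for_tests/embedding/'
                                                     'small_static_embedding/glove.6B.50d_test.txt')

    # With the fix, vocab index i maps to row i of the embedding matrix.
    assert [int(i) for i in embed.words_to_words] == list(range(len(vocab)))

    # The new property gives direct access to the trainable weight tensor.
    print(embed.weight.shape)  # here: (len(vocab), 50), since every word keeps its own row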