
update static_embedding so that its words_to_words property is sequential

tags/v0.5.5
yh_cc 5 years ago
parent
commit 51e8c3c3b4
3 changed files with 35 additions and 7 deletions
  1. fastNLP/embeddings/static_embedding.py (+14, -6)
  2. fastNLP/io/file_utils.py (+3, -1)
  3. test/embeddings/test_static_embedding.py (+18, -0)

fastNLP/embeddings/static_embedding.py (+14, -6)

@@ -175,6 +175,10 @@ class StaticEmbedding(TokenEmbedding):
                                            sparse=False, _weight=embedding)
         self._embed_size = self.embedding.weight.size(1)
         self.requires_grad = requires_grad
 
+    @property
+    def weight(self):
+        return self.embedding.weight
+
     def _randomly_init_embed(self, num_embedding, embedding_dim, init_embed=None):
         """
@@ -223,7 +227,7 @@ class StaticEmbedding(TokenEmbedding):
             else:
                 dim = len(parts) - 1
             f.seek(0)
-            matrix = {}
+            matrix = {}  # key: the word's index in vocab; value: its vector, or None if the word was not found in the pretrained file
             if vocab.padding:
                 matrix[vocab.padding_idx] = torch.zeros(dim)
             if vocab.unknown:
@@ -270,11 +274,15 @@ class StaticEmbedding(TokenEmbedding):
             else:
                 unknown_idx = vocab.unknown_idx
             self.register_buffer('words_to_words', torch.full((len(vocab), ), fill_value=unknown_idx).long())
-            for index, (index_in_vocab, vec) in enumerate(matrix.items()):
-                if vec is not None:
-                    vectors[index] = vec
-                self.words_to_words[index_in_vocab] = index
+            index = 0
+            for word, index_in_vocab in vocab:
+                if index_in_vocab in matrix:
+                    vec = matrix.get(index_in_vocab)
+                    if vec is not None:  # use the vector found in the pretrained file; None means it still needs to be trained
+                        vectors[index] = vec
+                    self.words_to_words[index_in_vocab] = index
+                    index += 1
 
         return vectors
 
     def forward(self, words):
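For reference, here is a minimal standalone sketch of why iterating the vocab in order makes words_to_words sequential whenever every word has an entry in `matrix`. The helper name is hypothetical, and plain Python lists/dicts stand in for the real Vocabulary and tensors; it only mirrors the loop in the hunk above, it is not the library code.

```python
# Hypothetical sketch of the new mapping logic, not fastNLP's implementation.
# `vocab` is modelled as (word, index_in_vocab) pairs; `matrix` maps
# index_in_vocab -> vector or None, as in the diff above.

def build_words_to_words(vocab, matrix, unknown_idx):
    """Return (vectors, words_to_words) the way the new loop builds them."""
    words_to_words = [unknown_idx] * len(vocab)
    vectors = [None] * len(matrix)            # rows of the compact embedding matrix
    index = 0
    for word, index_in_vocab in vocab:        # iterate in vocab order, not dict order
        if index_in_vocab in matrix:
            vec = matrix[index_in_vocab]
            if vec is not None:               # found in the pretrained file
                vectors[index] = vec
            words_to_words[index_in_vocab] = index
            index += 1
    return vectors, words_to_words

# With no no_create_entry words, every vocab index is a key of `matrix`,
# so words_to_words comes out as [0, 1, 2, ...]:
vocab = [('<pad>', 0), ('<unk>', 1), ('The', 2), ('a', 3)]
matrix = {0: [0.0], 1: [0.1], 2: [0.2], 3: None}   # None -> vector to be trained
_, w2w = build_words_to_words(vocab, matrix, unknown_idx=1)
assert w2w == [0, 1, 2, 3]
```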


fastNLP/io/file_utils.py (+3, -1)

@@ -432,7 +432,9 @@ def get_from_cache(url: str, cache_dir: Path = None) -> Path:
             shutil.rmtree(cache_path)
         os.close(fd)
         os.remove(temp_filename)
-        if os.path.isdir(uncompress_temp_dir):
+        if uncompress_temp_dir is None:
+            pass
+        elif os.path.isdir(uncompress_temp_dir):
             shutil.rmtree(uncompress_temp_dir)
         elif os.path.isfile(uncompress_temp_dir):
             os.remove(uncompress_temp_dir)
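The new branch guards the cleanup when nothing was uncompressed: passing None to os.path.isdir raises a TypeError, which is presumably what this change avoids. A minimal sketch of the same branching as a standalone helper (hypothetical name, not fastNLP's API):

```python
import os
import shutil

def cleanup_uncompress_temp(uncompress_temp_dir):
    """Remove a temporary uncompressed path, tolerating the 'nothing to clean' case."""
    if uncompress_temp_dir is None:
        pass  # nothing was uncompressed, so there is nothing to delete
    elif os.path.isdir(uncompress_temp_dir):
        shutil.rmtree(uncompress_temp_dir)
    elif os.path.isfile(uncompress_temp_dir):
        os.remove(uncompress_temp_dir)
```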


test/embeddings/test_static_embedding.py (+18, -0)

@@ -89,6 +89,24 @@ class TestLoad(unittest.TestCase):
         check_vector_equal(['Of', 'a'], vocab, embed, embed_dict, lower=True)
         check_word_unk(['notinfile1', 'The', 'notinfile2'], vocab, embed)
 
+    def test_sequential_index(self):
+        # when there are no no_create_entry words, words_to_words should be sequential
+        vocab = Vocabulary().add_word_lst(['The', 'a', 'notinfile1', 'A', 'notinfile2', 'notinfile2'])
+        embed = StaticEmbedding(vocab, model_dir_or_name='test/data_for_tests/embedding/small_static_embedding/'
+                                                         'glove.6B.50d_test.txt')
+        for index, i in enumerate(embed.words_to_words):
+            assert index == i
+
+        embed_dict = read_static_embed('test/data_for_tests/embedding/small_static_embedding/'
+                                       'glove.6B.50d_test.txt')
+
+        for word, index in vocab:
+            if word in embed_dict:
+                index = vocab.to_index(word)
+                v1 = embed(torch.LongTensor([index])).tolist()[0]
+                v2 = embed_dict[word]
+                for v1i, v2i in zip(v1, v2):
+                    self.assertAlmostEqual(v1i, v2i, places=4)
 
 
 def read_static_embed(fp):
     """

