
Fix a bug in StaticEmbedding

tags/v0.4.10
yh · 6 years ago
commit 23e283c459
1 changed file with 3 additions and 1 deletion

fastNLP/embeddings/static_embedding.py (+3 -1)

@@ -122,6 +122,7 @@ class StaticEmbedding(TokenEmbedding):
                 unknown_idx = lowered_vocab.unknown_idx
             else:
                 unknown_idx = embedding.size(0) - 1  # otherwise the last row is the unknown token
+            self.words_to_words = nn.Parameter(torch.arange(len(vocab)).long(), requires_grad=False)
             words_to_words = nn.Parameter(torch.full((len(vocab),), fill_value=unknown_idx).long(),
                                           requires_grad=False)
             for word, index in vocab:
@@ -129,7 +130,7 @@ class StaticEmbedding(TokenEmbedding):
                     word = word.lower()
                     if word not in lowered_vocab and lowered_vocab._is_word_no_create_entry(word):
                         continue  # no entry needs to be created; it already defaults to unknown
-                words_to_words[index] = words_to_words[lowered_vocab.to_index(word)]
+                words_to_words[index] = self.words_to_words[lowered_vocab.to_index(word)]
             self.words_to_words = words_to_words
             self._word_unk_index = lowered_vocab.unknown_idx  # replace the unknown index
         else:
@@ -137,6 +138,7 @@ class StaticEmbedding(TokenEmbedding):
                 embedding = self._load_with_vocab(model_path, vocab=vocab, init_method=init_method)
             else:
                 embedding = self._randomly_init_embed(len(vocab), embedding_dim, init_method)
+            self.words_to_words = nn.Parameter(torch.arange(len(vocab)).long(), requires_grad=False)
         if normalize:
             embedding /= (torch.norm(embedding, dim=1, keepdim=True) + 1e-12)
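The change concerns the words_to_words index map that StaticEmbedding builds when word lowercasing is enabled: every word in the original vocabulary should point at the embedding row of its lowercased form, with unknown_idx as the fallback. Before this commit, the right-hand side of the loop indexed into words_to_words itself, which is keyed by original-vocab indices and still holds mostly unknown_idx while the loop runs; the added identity mapping self.words_to_words makes the lookup return the lowered-vocab index directly. Below is a minimal sketch of that remapping, assuming toy dict vocabularies in place of fastNLP's Vocabulary; the names vocab, lowered_vocab and identity_map are illustrative, not the library's API.

import torch
from torch import nn

# Hypothetical stand-ins for fastNLP's Vocabulary objects (word -> index).
vocab = {"<pad>": 0, "<unk>": 1, "Apple": 2, "apple": 3, "Banana": 4}
lowered_vocab = {"<pad>": 0, "<unk>": 1, "apple": 2, "banana": 3}
unknown_idx = lowered_vocab["<unk>"]

# Identity mapping, mirroring the self.words_to_words = nn.Parameter(torch.arange(...))
# line this commit adds; indexing into it just returns the lowered-vocab index.
identity_map = nn.Parameter(torch.arange(len(vocab)).long(), requires_grad=False)

# Every original word starts out pointing at the unknown embedding row.
words_to_words = nn.Parameter(torch.full((len(vocab),), fill_value=unknown_idx).long(),
                              requires_grad=False)

for word, index in vocab.items():
    lowered = word.lower()
    if lowered in lowered_vocab:
        # Map the original-cased word to the row of its lowercased form.
        # Reading from identity_map (not from words_to_words itself) is the fix.
        words_to_words[index] = identity_map[lowered_vocab[lowered]]

print(words_to_words.tolist())  # [0, 1, 2, 2, 3] with the toy vocabularies above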


