@@ -121,28 +121,27 @@ class StaticEmbedding(TokenEmbedding):
                 embedding = self._load_with_vocab(model_path, vocab=lowered_vocab, init_method=init_method)
             else:
                 embedding = self._randomly_init_embed(len(vocab), embedding_dim, init_method)
-                self.words_to_words = nn.Parameter(torch.arange(len(vocab)).long(), requires_grad=False)
+                self.register_buffer('words_to_words', torch.arange(len(vocab)).long())
             if lowered_vocab.unknown:
                 unknown_idx = lowered_vocab.unknown_idx
             else:
                 unknown_idx = embedding.size(0) - 1  # otherwise the last row is the unknown vector
-                self.words_to_words = nn.Parameter(torch.arange(len(vocab)).long(), requires_grad=False)
-            words_to_words = nn.Parameter(torch.full((len(vocab),), fill_value=unknown_idx).long(),
-                                          requires_grad=False)
+                self.register_buffer('words_to_words', torch.arange(len(vocab)).long())
+            words_to_words = torch.full((len(vocab),), fill_value=unknown_idx).long()
             for word, index in vocab:
                 if word not in lowered_vocab:
                     word = word.lower()
                     if word not in lowered_vocab and lowered_vocab._is_word_no_create_entry(word):
                         continue  # if no entry needs to be created, it already defaults to unknown
                 words_to_words[index] = self.words_to_words[lowered_vocab.to_index(word)]
-            self.words_to_words = words_to_words
+            self.register_buffer('words_to_words', words_to_words)
             self._word_unk_index = lowered_vocab.unknown_idx  # update the unknown index
         else:
             if model_path:
                 embedding = self._load_with_vocab(model_path, vocab=vocab, init_method=init_method)
             else:
                 embedding = self._randomly_init_embed(len(vocab), embedding_dim, init_method)
-                self.words_to_words = nn.Parameter(torch.arange(len(vocab)).long(), requires_grad=False)
+                self.register_buffer('words_to_words', torch.arange(len(vocab)).long())
         if not self.only_norm_found_vector and normalize:
             embedding /= (torch.norm(embedding, dim=1, keepdim=True) + 1e-12)
 
@@ -151,7 +150,7 @@ class StaticEmbedding(TokenEmbedding):
                 index_in_truncated_vocab = truncated_words_to_words[i]
                 truncated_words_to_words[i] = self.words_to_words[index_in_truncated_vocab]
             del self.words_to_words
-            self.words_to_words = nn.Parameter(truncated_words_to_words, requires_grad=False)
+            self.register_buffer('words_to_words', truncated_words_to_words)
         self.embedding = nn.Embedding(num_embeddings=embedding.shape[0], embedding_dim=embedding.shape[1],
                                       padding_idx=vocab.padding_idx,
@@ -273,8 +272,7 @@ class StaticEmbedding(TokenEmbedding):
                 vectors = torch.cat((vectors, torch.zeros(1, dim)), dim=0).contiguous()
             else:
                 unknown_idx = vocab.unknown_idx
-            self.words_to_words = nn.Parameter(torch.full((len(vocab), ), fill_value=unknown_idx).long(),
-                                               requires_grad=False)
+            self.register_buffer('words_to_words', torch.full((len(vocab), ), fill_value=unknown_idx).long())
             for index, (index_in_vocab, vec) in enumerate(matrix.items()):
                 if vec is not None: