@@ -175,6 +175,10 @@ class StaticEmbedding(TokenEmbedding):
                                       sparse=False, _weight=embedding)
         self._embed_size = self.embedding.weight.size(1)
         self.requires_grad = requires_grad
 
+    @property
+    def weight(self):
+        return self.embedding.weight
+
     def _randomly_init_embed(self, num_embedding, embedding_dim, init_embed=None):
         """
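The new `weight` property simply exposes the underlying `nn.Embedding` weight tensor. A minimal usage sketch, assuming fastNLP's public API; the vocabulary contents and dimension are illustrative, not from the patch:

import torch
from fastNLP import Vocabulary
from fastNLP.embeddings import StaticEmbedding

vocab = Vocabulary()
vocab.add_word_lst("the quick brown fox".split())

# model_dir_or_name=None randomly initializes the vectors instead of loading a pretrained file
embed = StaticEmbedding(vocab, model_dir_or_name=None, embedding_dim=50)
print(embed.weight.shape)  # torch.Size([6, 50]): 4 words plus <pad> and <unk>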
@@ -223,7 +227,7 @@ class StaticEmbedding(TokenEmbedding):
             else:
                 dim = len(parts) - 1
                 f.seek(0)
-            matrix = {}
+            matrix = {}  # keys are word indices in vocab; values are vectors, or None if the word was not found in the pretrained file
             if vocab.padding:
                 matrix[vocab.padding_idx] = torch.zeros(dim)
             if vocab.unknown:
@@ -270,11 +274,15 @@ class StaticEmbedding(TokenEmbedding):
             else:
                 unknown_idx = vocab.unknown_idx
             self.register_buffer('words_to_words', torch.full((len(vocab), ), fill_value=unknown_idx).long())
-            for index, (index_in_vocab, vec) in enumerate(matrix.items()):
-                if vec is not None:
-                    vectors[index] = vec
-                self.words_to_words[index_in_vocab] = index
+            index = 0
+            for word, index_in_vocab in vocab:
+                if index_in_vocab in matrix:
+                    vec = matrix.get(index_in_vocab)
+                    if vec is not None:  # use the vector found in the pretrained file; None means the row is trained from scratch
+                        vectors[index] = vec
+                    self.words_to_words[index_in_vocab] = index
+                    index += 1
 
             return vectors
 
     def forward(self, words):
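Taken together, the rewritten loop builds a compact lookup: `matrix` holds one entry per vocab word that should get its own row, and `words_to_words` routes every vocab id to a row of `vectors`, with absent words falling back to the unknown row; iterating the vocab (rather than `matrix.items()`) makes the row order deterministic. A standalone sketch of that remapping; the toy sizes and values are assumptions, not from the patch:

import torch

dim, unknown_idx = 4, 1
matrix = {0: torch.zeros(dim),   # padding row
          1: torch.zeros(dim),   # unknown row
          3: torch.ones(dim),    # a word found in the pretrained file
          4: None}               # in vocab but not pretrained: row is trained from scratch
vectors = torch.zeros(len(matrix), dim)
words_to_words = torch.full((6,), fill_value=unknown_idx).long()

index = 0
for index_in_vocab in range(6):  # stand-in for iterating the vocab in order, as the new loop does
    if index_in_vocab in matrix:
        vec = matrix[index_in_vocab]
        if vec is not None:
            vectors[index] = vec
        words_to_words[index_in_vocab] = index
        index += 1

words = torch.tensor([3, 2, 4])  # a batch of vocab ids
print(words_to_words[words])     # tensor([2, 1, 3]); id 2 falls back to the unknown row

At forward time the class indexes `words_to_words` with the incoming word ids before the embedding lookup, which is why ids with no entry in `matrix` resolve to the unknown vector.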