
Handle the parameters inside embeddings in a more PyTorch-idiomatic way

tags/v0.4.10
yh committed 5 years ago · commit 19bbaf11b6
4 changed files with 16 additions and 21 deletions
  1. +1 -1   fastNLP/embeddings/bert_embedding.py
  2. +6 -8   fastNLP/embeddings/char_embedding.py
  3. +2 -3   fastNLP/embeddings/elmo_embedding.py
  4. +7 -9   fastNLP/embeddings/static_embedding.py
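
For context, here is a minimal standalone sketch of the behaviour this commit relies on (the LookupTable module below is made up for illustration and is not fastNLP code): a tensor registered with register_buffer is still saved in state_dict() and still follows .to(device), just like a frozen nn.Parameter did, but it no longer shows up in parameters(), so optimizers and other gradient-related utilities skip it.

    import torch
    import torch.nn as nn

    class LookupTable(nn.Module):
        """Toy module contrasting a frozen nn.Parameter with a registered buffer."""
        def __init__(self, num_words: int):
            super().__init__()
            # Old style used before this commit: a non-trainable Parameter.
            # It is still returned by .parameters(), so optimizers see it.
            self.old_style = nn.Parameter(torch.arange(num_words).long(), requires_grad=False)
            # New style: a buffer. Saved in state_dict(), moved by .to()/.cuda(),
            # but invisible to .parameters().
            self.register_buffer('words_to_words', torch.arange(num_words).long())

    m = LookupTable(5)
    print([n for n, _ in m.named_parameters()])   # ['old_style']
    print([n for n, _ in m.named_buffers()])      # ['words_to_words']
    print(sorted(m.state_dict().keys()))          # ['old_style', 'words_to_words']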

fastNLP/embeddings/bert_embedding.py  (+1 -1)

@@ -345,7 +345,7 @@ class _WordBertModel(nn.Module):
         self._wordpiece_pad_index = self.tokenzier.vocab['[PAD]']  # needed for generating word_piece
         print("Found(Or segment into word pieces) {} words out of {}.".format(found_count, len(vocab)))
         self.word_to_wordpieces = np.array(word_to_wordpieces)
-        self.word_pieces_lengths = nn.Parameter(torch.LongTensor(word_pieces_lengths), requires_grad=False)
+        self.register_buffer('word_pieces_lengths', torch.LongTensor(word_pieces_lengths))
         print("Successfully generate word pieces.")
 
     def forward(self, words):


fastNLP/embeddings/char_embedding.py  (+6 -8)

@@ -82,10 +82,9 @@ class CNNCharEmbedding(TokenEmbedding):
print(f"In total, there are {len(self.char_vocab)} distinct characters.") print(f"In total, there are {len(self.char_vocab)} distinct characters.")
# 对vocab进行index # 对vocab进行index
max_word_len = max(map(lambda x: len(x[0]), vocab)) max_word_len = max(map(lambda x: len(x[0]), vocab))
self.words_to_chars_embedding = nn.Parameter(torch.full((len(vocab), max_word_len),
fill_value=self.char_pad_index, dtype=torch.long),
requires_grad=False)
self.word_lengths = nn.Parameter(torch.zeros(len(vocab)).long(), requires_grad=False)
self.register_buffer('words_to_chars_embedding', torch.full((len(vocab), max_word_len),
fill_value=self.char_pad_index, dtype=torch.long))
self.register_buffer('word_lengths', torch.zeros(len(vocab)).long())
for word, index in vocab: for word, index in vocab:
# if index!=vocab.padding_idx: # 如果是pad的话,直接就为pad_value了。修改为不区分pad, 这样所有的<pad>也是同一个embed # if index!=vocab.padding_idx: # 如果是pad的话,直接就为pad_value了。修改为不区分pad, 这样所有的<pad>也是同一个embed
self.words_to_chars_embedding[index, :len(word)] = \ self.words_to_chars_embedding[index, :len(word)] = \
@@ -235,10 +234,9 @@ class LSTMCharEmbedding(TokenEmbedding):
print(f"In total, there are {len(self.char_vocab)} distinct characters.") print(f"In total, there are {len(self.char_vocab)} distinct characters.")
# 对vocab进行index # 对vocab进行index
self.max_word_len = max(map(lambda x: len(x[0]), vocab)) self.max_word_len = max(map(lambda x: len(x[0]), vocab))
self.words_to_chars_embedding = nn.Parameter(torch.full((len(vocab), self.max_word_len),
fill_value=self.char_pad_index, dtype=torch.long),
requires_grad=False)
self.word_lengths = nn.Parameter(torch.zeros(len(vocab)).long(), requires_grad=False)
self.register_buffer('words_to_chars_embedding', torch.full((len(vocab), self.max_word_len),
fill_value=self.char_pad_index, dtype=torch.long))
self.register_buffer('word_lengths', torch.zeros(len(vocab)).long())
for word, index in vocab: for word, index in vocab:
# if index!=vocab.padding_idx: # 如果是pad的话,直接就为pad_value了. 修改为不区分pad与否 # if index!=vocab.padding_idx: # 如果是pad的话,直接就为pad_value了. 修改为不区分pad与否
self.words_to_chars_embedding[index, :len(word)] = \ self.words_to_chars_embedding[index, :len(word)] = \


fastNLP/embeddings/elmo_embedding.py  (+2 -3)

@@ -240,10 +240,9 @@ class _ElmoModel(nn.Module):
         # build the mapping from words to chars
         max_chars = config['char_cnn']['max_characters_per_token']
 
-        self.words_to_chars_embedding = nn.Parameter(torch.full((len(vocab) + 2, max_chars),
+        self.register_buffer('words_to_chars_embedding', torch.full((len(vocab) + 2, max_chars),
                                                                  fill_value=len(char_vocab),
-                                                                 dtype=torch.long),
-                                                     requires_grad=False)
+                                                                 dtype=torch.long))
         for word, index in list(iter(vocab)) + [(BOS_TAG, len(vocab)), (EOS_TAG, len(vocab) + 1)]:
             if len(word) + 2 > max_chars:
                 word = word[:max_chars - 2]


fastNLP/embeddings/static_embedding.py  (+7 -9)

@@ -121,28 +121,27 @@ class StaticEmbedding(TokenEmbedding):
                 embedding = self._load_with_vocab(model_path, vocab=lowered_vocab, init_method=init_method)
             else:
                 embedding = self._randomly_init_embed(len(vocab), embedding_dim, init_method)
-                self.words_to_words = nn.Parameter(torch.arange(len(vocab)).long(), requires_grad=False)
+                self.register_buffer('words_to_words', torch.arange(len(vocab)).long())
             if lowered_vocab.unknown:
                 unknown_idx = lowered_vocab.unknown_idx
             else:
                 unknown_idx = embedding.size(0) - 1  # otherwise the last one is unknown
-                self.words_to_words = nn.Parameter(torch.arange(len(vocab)).long(), requires_grad=False)
-            words_to_words = nn.Parameter(torch.full((len(vocab),), fill_value=unknown_idx).long(),
-                                          requires_grad=False)
+                self.register_buffer('words_to_words', torch.arange(len(vocab)).long())
+            words_to_words = torch.full((len(vocab),), fill_value=unknown_idx).long()
             for word, index in vocab:
                 if word not in lowered_vocab:
                     word = word.lower()
                     if word not in lowered_vocab and lowered_vocab._is_word_no_create_entry(word):
                         continue  # if there is no need to create an entry, it already defaults to unknown
                 words_to_words[index] = self.words_to_words[lowered_vocab.to_index(word)]
-            self.words_to_words = words_to_words
+            self.register_buffer('words_to_words', words_to_words)
             self._word_unk_index = lowered_vocab.unknown_idx  # replace the index of unknown
         else:
             if model_path:
                 embedding = self._load_with_vocab(model_path, vocab=vocab, init_method=init_method)
             else:
                 embedding = self._randomly_init_embed(len(vocab), embedding_dim, init_method)
-                self.words_to_words = nn.Parameter(torch.arange(len(vocab)).long(), requires_grad=False)
+                self.register_buffer('words_to_words', torch.arange(len(vocab)).long())
         if not self.only_norm_found_vector and normalize:
             embedding /= (torch.norm(embedding, dim=1, keepdim=True) + 1e-12)


@@ -151,7 +150,7 @@ class StaticEmbedding(TokenEmbedding):
                 index_in_truncated_vocab = truncated_words_to_words[i]
                 truncated_words_to_words[i] = self.words_to_words[index_in_truncated_vocab]
             del self.words_to_words
-            self.words_to_words = nn.Parameter(truncated_words_to_words, requires_grad=False)
+            self.register_buffer('words_to_words', truncated_words_to_words)
 
         self.embedding = nn.Embedding(num_embeddings=embedding.shape[0], embedding_dim=embedding.shape[1],
                                       padding_idx=vocab.padding_idx,
@@ -273,8 +272,7 @@ class StaticEmbedding(TokenEmbedding):
                 vectors = torch.cat((vectors, torch.zeros(1, dim)), dim=0).contiguous()
             else:
                 unknown_idx = vocab.unknown_idx
-            self.words_to_words = nn.Parameter(torch.full((len(vocab), ), fill_value=unknown_idx).long(),
-                                               requires_grad=False)
+            self.register_buffer('words_to_words', torch.full((len(vocab), ), fill_value=unknown_idx).long())
 
             for index, (index_in_vocab, vec) in enumerate(matrix.items()):
                 if vec is not None:
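
As a side note on the static_embedding.py hunks above: re-registering a buffer under an existing name simply replaces the stored tensor, and in-place indexing on a buffer works like on any other tensor. The RemapSketch module below is an illustrative standalone sketch of that pattern (names are made up; this is not the fastNLP class).

    import torch
    import torch.nn as nn

    class RemapSketch(nn.Module):
        """Toy illustration of the words_to_words re-registration pattern."""
        def __init__(self, vocab_size: int, unknown_idx: int):
            super().__init__()
            # Start with an identity mapping registered as a buffer.
            self.register_buffer('words_to_words', torch.arange(vocab_size).long())
            # Build a replacement mapping on a plain tensor, reading from the buffer ...
            remap = torch.full((vocab_size,), fill_value=unknown_idx).long()
            remap[0] = self.words_to_words[0]
            # ... and register it under the same name: this overwrites the old buffer,
            # so the final tensor is what ends up in state_dict().
            self.register_buffer('words_to_words', remap)

    m = RemapSketch(vocab_size=4, unknown_idx=3)
    print(m.words_to_words)                      # tensor([0, 3, 3, 3])
    print('words_to_words' in m.state_dict())    # True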

