
Fix a bug in BertEmbedding

tags/v0.4.10
yh 6 years ago
parent commit b0c50f7299
1 changed file with 1 addition and 1 deletion:

  fastNLP/embeddings/bert_embedding.py  +1  -1

fastNLP/embeddings/bert_embedding.py

@@ -311,7 +311,6 @@ class _WordBertModel(nn.Module):
         word_pieces = words.new_full((batch_size, max_word_piece_length+2), fill_value=self._wordpiece_pad_index)
         word_pieces[:, 0].fill_(self._cls_index)
         batch_indexes = torch.arange(batch_size).to(words)
-        word_pieces[batch_indexes, word_pieces_lengths+1] = self._sep_index
         attn_masks = torch.zeros_like(word_pieces)
         # 1. get the word_piece ids of the words, and the corresponding span ranges
         word_indexes = words.tolist()
@@ -320,6 +319,7 @@ class _WordBertModel(nn.Module):
             if self.auto_truncate and len(word_pieces_i)>self._max_position_embeddings-2:
                 word_pieces_i = word_pieces_i[:self._max_position_embeddings-2]
             word_pieces[i, 1:len(word_pieces_i)+1] = torch.LongTensor(word_pieces_i)
+            word_pieces[i, len(word_pieces_i)+1] = self._sep_index  # append the sep
             attn_masks[i, :word_pieces_lengths[i]+2].fill_(1)
         # 2. get the hidden outputs and pool them per word according to word_pieces
         # all_outputs: [batch_size x max_len x hidden_size, batch_size x max_len x hidden_size, ...]
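Judging from the hunks, the removed line placed [SEP] using word_pieces_lengths, which holds the pre-truncation word-piece counts; for a sentence shortened by auto_truncate, that index points past the truncated pieces (or past the buffer entirely). The added line computes the index from the truncated pieces instead. Below is a minimal, self-contained sketch of the fixed placement logic, not fastNLP's API: PAD, CLS, SEP, MAX_POSITION, and the toy batch are illustrative stand-ins for self._wordpiece_pad_index, self._cls_index, self._sep_index, and self._max_position_embeddings.

    import torch

    PAD, CLS, SEP = 0, 101, 102
    MAX_POSITION = 8  # stand-in for self._max_position_embeddings

    # Word-piece ids per sentence, before [CLS]/[SEP] are added.
    batch = [
        [5, 6, 7],                    # fits as-is
        [5, 6, 7, 8, 9, 10, 11, 12],  # longer than MAX_POSITION-2, gets truncated
    ]

    max_pieces = min(max(len(p) for p in batch), MAX_POSITION - 2)
    word_pieces = torch.full((len(batch), max_pieces + 2), PAD, dtype=torch.long)
    word_pieces[:, 0] = CLS
    attn_masks = torch.zeros_like(word_pieces)

    for i, pieces in enumerate(batch):
        pieces = pieces[:MAX_POSITION - 2]  # auto_truncate, as in the hunk above
        word_pieces[i, 1:len(pieces) + 1] = torch.LongTensor(pieces)
        # The fix: [SEP] goes right after the *truncated* pieces. Indexing with
        # the pre-truncation length would land past the last real piece here.
        word_pieces[i, len(pieces) + 1] = SEP
        attn_masks[i, :len(pieces) + 2] = 1

    print(word_pieces)
    # tensor([[101,   5,   6,   7, 102,   0,   0,   0],
    #         [101,   5,   6,   7,   8,   9,  10, 102]])

Moving the assignment into the per-sentence loop also matches where truncation happens, so the [SEP] index and the attention-mask length are derived from the same post-truncation state.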

