
Fix a bug in BertEmbedding

tags/v0.4.10
yh 6 years ago
parent commit 1b661e907a
1 changed file with 1 addition and 1 deletion:
  fastNLP/embeddings/bert_embedding.py (+1, -1)

fastNLP/embeddings/bert_embedding.py

@@ -294,7 +294,7 @@ class _WordBertModel(nn.Module):
         word_mask = words.ne(self._pad_index)
         seq_len = word_mask.sum(dim=-1)
         batch_word_pieces_length = self.word_pieces_lengths[words]  # batch_size x max_len
-        word_pieces_lengths = batch_word_pieces_length.masked_fill(word_mask, 0).sum(dim=-1)
+        word_pieces_lengths = batch_word_pieces_length.masked_fill(word_mask.eq(0), 0).sum(dim=-1)
         max_word_piece_length = word_pieces_lengths.max().item()
         real_max_word_piece_length = max_word_piece_length  # the word piece length before any truncation
         if max_word_piece_length+2>self._max_position_embeddings:
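
The change swaps the mask passed to masked_fill: with word_mask, the word-piece lengths of the real tokens were zeroed and the padding entries kept, so word_pieces_lengths under-counted each sentence; with word_mask.eq(0), the padding entries are zeroed instead, which is what the per-sentence sum needs. A minimal standalone sketch of the difference, using toy tensors that are not from the repo:

    import torch

    # Hypothetical toy batch: 2 sentences padded to max_len=4, pad index 0.
    words = torch.tensor([[5, 7, 2, 0],
                          [3, 0, 0, 0]])
    pad_index = 0
    word_mask = words.ne(pad_index)  # True at real tokens, False at padding

    # Hypothetical word-piece length per word id (index = word id); the pad id maps
    # to length 1 here, which is exactly the spurious count the fix masks out.
    word_pieces_lengths = torch.tensor([1, 1, 2, 3, 1, 2, 1, 4])
    batch_word_pieces_length = word_pieces_lengths[words]  # batch_size x max_len

    # Before the fix: zeroes the lengths of the *real* tokens, keeps the padding entries.
    buggy = batch_word_pieces_length.masked_fill(word_mask, 0).sum(dim=-1)

    # After the fix: zeroes the padding entries, so the sum is the true word-piece count.
    fixed = batch_word_pieces_length.masked_fill(word_mask.eq(0), 0).sum(dim=-1)

    print(buggy)  # tensor([1, 3]) -- only the padding lengths survive
    print(fixed)  # tensor([8, 3]) -- 2 + 4 + 2 for sentence 1, 3 for sentence 2

With the corrected mask, max_word_piece_length downstream reflects the longest real word-piece sequence in the batch rather than the padding count.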

