
Fix a bug in BertEmbedding

tags/v0.4.10
yh, 6 years ago
commit 2098a81f2f
1 changed file with 4 additions and 4 deletions

fastNLP/embeddings/bert_embedding.py  (+4 -4)

@@ -291,9 +291,10 @@ class _WordBertModel(nn.Module):
         :return: num_layers x batch_size x max_len x hidden_size or num_layers x batch_size x (max_len+2) x hidden_size
         """
         batch_size, max_word_len = words.size()
-        seq_len = words.ne(self._pad_index).sum(dim=-1)
+        word_mask = words.ne(self._pad_index)
+        seq_len = word_mask.sum(dim=-1)
         batch_word_pieces_length = self.word_pieces_lengths[words]  # batch_size x max_len
-        word_pieces_lengths = batch_word_pieces_length.sum(dim=-1)
+        word_pieces_lengths = batch_word_pieces_length.masked_fill(word_mask, 0).sum(dim=-1)
         max_word_piece_length = word_pieces_lengths.max().item()
         real_max_word_piece_length = max_word_piece_length  # length of the word pieces before any truncation
         if max_word_piece_length+2>self._max_position_embeddings:
@@ -319,8 +320,7 @@ class _WordBertModel(nn.Module):
             if self.auto_truncate and len(word_pieces_i)>self._max_position_embeddings-2:
                 word_pieces_i = word_pieces_i[:self._max_position_embeddings-2]
             word_pieces[i, 1:len(word_pieces_i)+1] = torch.LongTensor(word_pieces_i)
-            attn_masks[i, :len(word_pieces_i)+2].fill_(1)
-            # TODO truncate the part whose length exceeds the limit.
+            attn_masks[i, :word_pieces_lengths[i]+2].fill_(1)
         # 2. get the hidden results and pool them according to word_pieces
         # all_outputs: [batch_size x max_len x hidden_size, batch_size x max_len x hidden_size, ...]
         bert_outputs, pooled_cls = self.encoder(word_pieces, token_type_ids=None, attention_mask=attn_masks,
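
The length change in the first hunk relies on Tensor.masked_fill, which overwrites the positions where the mask is True. Below is a minimal, self-contained sketch of masking out padding positions before summing word-piece counts; the pad index, the tensors, and the variable names are made up for illustration and are not taken from the repository.

import torch

pad_index = 0                                    # hypothetical pad id
words = torch.tensor([[3, 5, 2, 0, 0],           # batch_size x max_len, 0 marks padding
                      [7, 4, 0, 0, 0]])
# hypothetical word-piece count looked up per word; note the padding rows are non-zero
batch_word_pieces_length = torch.tensor([[2, 1, 3, 1, 1],
                                         [1, 2, 1, 1, 1]])

pad_mask = words.eq(pad_index)                   # True at padding positions
word_pieces_lengths = batch_word_pieces_length.masked_fill(pad_mask, 0).sum(dim=-1)
print(word_pieces_lengths)                       # tensor([6, 3]); padding no longer inflates the sum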
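The second hunk fills the attention mask up to the per-sentence word-piece count plus two, reserving slots for the [CLS] and [SEP] tokens. A small sketch under assumed toy sizes (batch of 2, at most 6 word pieces; plain PyTorch, not fastNLP API):

import torch

batch_size, max_word_piece_length = 2, 6
word_pieces_lengths = torch.tensor([4, 2])       # pieces per sentence, already truncated if needed

attn_masks = torch.zeros(batch_size, max_word_piece_length + 2, dtype=torch.long)
for i in range(batch_size):
    # +2 leaves room for the [CLS] prefix and the [SEP] suffix around the pieces
    attn_masks[i, :word_pieces_lengths[i] + 2].fill_(1)

print(attn_masks)
# tensor([[1, 1, 1, 1, 1, 1, 0, 0],
#         [1, 1, 1, 1, 0, 0, 0, 0]])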

