|
@@ -291,9 +291,10 @@ class _WordBertModel(nn.Module): |
|
|
:return: num_layers x batch_size x max_len x hidden_size或者num_layers x batch_size x (max_len+2) x hidden_size |
|
|
:return: num_layers x batch_size x max_len x hidden_size或者num_layers x batch_size x (max_len+2) x hidden_size |
|
|
""" |
|
|
""" |
|
|
batch_size, max_word_len = words.size() |
|
|
batch_size, max_word_len = words.size() |
|
|
seq_len = words.ne(self._pad_index).sum(dim=-1) |
|
|
|
|
|
|
|
|
word_mask = words.ne(self._pad_index) |
|
|
|
|
|
seq_len = word_mask.sum(dim=-1) |
|
|
batch_word_pieces_length = self.word_pieces_lengths[words] # batch_size x max_len |
|
|
batch_word_pieces_length = self.word_pieces_lengths[words] # batch_size x max_len |
|
|
word_pieces_lengths = batch_word_pieces_length.sum(dim=-1) |
|
|
|
|
|
|
|
|
word_pieces_lengths = batch_word_pieces_length.masked_fill(word_mask, 0).sum(dim=-1) |
|
|
max_word_piece_length = word_pieces_lengths.max().item() |
|
|
max_word_piece_length = word_pieces_lengths.max().item() |
|
|
real_max_word_piece_length = max_word_piece_length # 表示没有截断的word piece的长度 |
|
|
real_max_word_piece_length = max_word_piece_length # 表示没有截断的word piece的长度 |
|
|
if max_word_piece_length+2>self._max_position_embeddings: |
|
|
if max_word_piece_length+2>self._max_position_embeddings: |
|
@@ -319,8 +320,7 @@ class _WordBertModel(nn.Module): |
|
|
if self.auto_truncate and len(word_pieces_i)>self._max_position_embeddings-2: |
|
|
if self.auto_truncate and len(word_pieces_i)>self._max_position_embeddings-2: |
|
|
word_pieces_i = word_pieces_i[:self._max_position_embeddings-2] |
|
|
word_pieces_i = word_pieces_i[:self._max_position_embeddings-2] |
|
|
word_pieces[i, 1:len(word_pieces_i)+1] = torch.LongTensor(word_pieces_i) |
|
|
word_pieces[i, 1:len(word_pieces_i)+1] = torch.LongTensor(word_pieces_i) |
|
|
attn_masks[i, :len(word_pieces_i)+2].fill_(1) |
|
|
|
|
|
# TODO 截掉长度超过的部分。 |
|
|
|
|
|
|
|
|
attn_masks[i, :word_pieces_lengths[i]+2].fill_(1) |
|
|
# 2. 获取hidden的结果,根据word_pieces进行对应的pool计算 |
|
|
# 2. 获取hidden的结果,根据word_pieces进行对应的pool计算 |
|
|
# all_outputs: [batch_size x max_len x hidden_size, batch_size x max_len x hidden_size, ...] |
|
|
# all_outputs: [batch_size x max_len x hidden_size, batch_size x max_len x hidden_size, ...] |
|
|
bert_outputs, pooled_cls = self.encoder(word_pieces, token_type_ids=None, attention_mask=attn_masks, |
|
|
bert_outputs, pooled_cls = self.encoder(word_pieces, token_type_ids=None, attention_mask=attn_masks, |
|
|