@@ -278,7 +278,7 @@ class _WordBertModel(nn.Module):
         print("Found(Or seg into word pieces) {} words out of {}.".format(found_count, len(vocab)))
         self._cls_index = self.tokenzier.vocab['[CLS]']
         self._sep_index = self.tokenzier.vocab['[SEP]']
-        self._pad_index = vocab.padding_idx
+        self._word_pad_index = vocab.padding_idx
         self._wordpiece_pad_index = self.tokenzier.vocab['[PAD]']  # needed for generating word_piece
         self.word_to_wordpieces = np.array(word_to_wordpieces)
         self.word_pieces_lengths = nn.Parameter(torch.LongTensor(word_pieces_lengths), requires_grad=False)
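Note on this hunk: the rename from `_pad_index` to `_word_pad_index` separates the word-level padding id from `_wordpiece_pad_index`, which pads the word-piece sequence. As a reminder of what the two bookkeeping attributes hold, here is a minimal, self-contained sketch with a toy vocabulary and a hand-written segmentation (illustrative only, not the repo's construction code):

```python
# Toy word-piece vocabulary and segmentation (hypothetical values).
wordpiece_ids = {'[PAD]': 0, '[CLS]': 1, '[SEP]': 2, 'play': 3, '##ing': 4, 'ball': 5}
word_vocab = ['playing', 'ball']                           # word-level vocabulary
splits = {'playing': ['play', '##ing'], 'ball': ['ball']}  # WordPiece segmentation

# word_to_wordpieces[i] lists the word-piece ids of word i;
# word_pieces_lengths[i] is the number of pieces that word i expands to.
word_to_wordpieces = [[wordpiece_ids[p] for p in splits[w]] for w in word_vocab]
word_pieces_lengths = [len(p) for p in word_to_wordpieces]

print(word_to_wordpieces)   # [[3, 4], [5]]
print(word_pieces_lengths)  # [2, 1]
```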
@@ -291,23 +291,22 @@ class _WordBertModel(nn.Module):
         :return: num_layers x batch_size x max_len x hidden_size or num_layers x batch_size x (max_len+2) x hidden_size
         """
         batch_size, max_word_len = words.size()
-        word_mask = words.ne(self._pad_index)
+        word_mask = words.ne(self._word_pad_index)  # positions equal to 1 contain a word
         seq_len = word_mask.sum(dim=-1)
         batch_word_pieces_length = self.word_pieces_lengths[words]  # batch_size x max_len
-        word_pieces_lengths = batch_word_pieces_length.masked_fill(word_mask.eq(0), 0).sum(dim=-1)
-        max_word_piece_length = word_pieces_lengths.max().item()
-        real_max_word_piece_length = max_word_piece_length  # the word piece length before truncation
-        if max_word_piece_length+2>self._max_position_embeddings:
+        word_pieces_lengths = batch_word_pieces_length.masked_fill(word_mask.eq(0), 0).sum(dim=-1)  # batch_size
+        word_piece_length = batch_word_pieces_length.sum(dim=-1).max().item()  # the word piece length (including padding)
+        if word_piece_length+2>self._max_position_embeddings:
             if self.auto_truncate:
                 word_pieces_lengths = word_pieces_lengths.masked_fill(word_pieces_lengths+2>self._max_position_embeddings,
                                                                       self._max_position_embeddings-2)
-                max_word_piece_length = self._max_position_embeddings-2
             else:
                 raise RuntimeError("After split words into word pieces, the lengths of word pieces are longer than the "
                                    f"maximum allowed sequence length:{self._max_position_embeddings} of bert.")

         # +2 because [CLS] and [SEP] need to be added
-        word_pieces = words.new_full((batch_size, max_word_piece_length+2), fill_value=self._wordpiece_pad_index)
+        word_pieces = words.new_full((batch_size, min(word_piece_length+2, self._max_position_embeddings)),
+                                     fill_value=self._wordpiece_pad_index)
         attn_masks = torch.zeros_like(word_pieces)
         # 1. get the word_piece ids of the words and the corresponding span ranges
         word_indexes = words.tolist()
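The core behavioural change in this hunk: `word_piece_length` is now the padded word-piece length of the longest sequence in the batch, and under `auto_truncate` only the per-sentence lengths are clipped to fit BERT's position limit. A small, self-contained illustration of that clipping with toy sizes and assumed values:

```python
import torch

# Per-word word-piece counts for a toy batch of 2 sentences; 0 marks padding positions.
batch_word_pieces_length = torch.LongTensor([[3, 2, 4, 0],
                                             [5, 1, 0, 0]])
max_position_embeddings = 8  # deliberately tiny limit for the example

word_pieces_lengths = batch_word_pieces_length.sum(dim=-1)             # tensor([9, 6])
word_piece_length = batch_word_pieces_length.sum(dim=-1).max().item()  # 9

if word_piece_length + 2 > max_position_embeddings:  # 9 + 2 > 8, so truncation applies
    word_pieces_lengths = word_pieces_lengths.masked_fill(
        word_pieces_lengths + 2 > max_position_embeddings,
        max_position_embeddings - 2)

print(word_pieces_lengths)  # tensor([6, 6]): the first sentence is clipped to 6 word pieces
```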
@@ -325,7 +324,7 @@ class _WordBertModel(nn.Module):
         # all_outputs: [batch_size x max_len x hidden_size, batch_size x max_len x hidden_size, ...]
         bert_outputs, pooled_cls = self.encoder(word_pieces, token_type_ids=None, attention_mask=attn_masks,
                                                 output_all_encoded_layers=True)
-        # output_layers = [self.layers]  # len(self.layers) x batch_size x max_word_piece_length x hidden_size
+        # output_layers = [self.layers]  # len(self.layers) x batch_size x real_word_piece_length x hidden_size

         if self.include_cls_sep:
             outputs = bert_outputs[-1].new_zeros(len(self.layers), batch_size, max_word_len + 2,
@@ -339,9 +338,10 @@ class _WordBertModel(nn.Module):
         batch_word_pieces_cum_length[:, 1:] = batch_word_pieces_length.cumsum(dim=-1)  # batch_size x max_len
         for l_index, l in enumerate(self.layers):
             output_layer = bert_outputs[l]
-            if real_max_word_piece_length > max_word_piece_length:  # if it was actually truncated
+            real_word_piece_length = output_layer.size(1) - 2
+            if word_piece_length > real_word_piece_length:  # if it was actually truncated
                 paddings = output_layer.new_zeros(batch_size,
-                                                  real_max_word_piece_length-max_word_piece_length,
+                                                  word_piece_length-real_word_piece_length,
                                                   output_layer.size(2))
                 output_layer = torch.cat((output_layer, paddings), dim=1).contiguous()
            # collapse from word_piece representations to word representations
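For clarity on the new padding-back step: when the batch was truncated, each encoder layer is only `real_word_piece_length + 2` positions long, so it is zero-padded back to the untruncated `word_piece_length` before word pieces are collapsed into word vectors; otherwise the cumulative-length indices computed above could point past the end of the layer. A self-contained sketch with assumed sizes:

```python
import torch

batch_size, hidden_size = 2, 4
word_piece_length = 9                                       # untruncated length (see earlier hunk)
output_layer = torch.randn(batch_size, 6 + 2, hidden_size)  # truncated layer: [CLS] + 6 pieces + [SEP]

real_word_piece_length = output_layer.size(1) - 2           # 6
if word_piece_length > real_word_piece_length:              # the layer was actually truncated
    paddings = output_layer.new_zeros(batch_size,
                                      word_piece_length - real_word_piece_length,
                                      output_layer.size(2))
    output_layer = torch.cat((output_layer, paddings), dim=1).contiguous()

print(output_layer.shape)  # torch.Size([2, 11, 4]); positions up to word_piece_length + 1 are now valid
```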