@@ -290,45 +290,45 @@ class _WordBertModel(nn.Module):
         :param words: torch.LongTensor, batch_size x max_len
         :return: num_layers x batch_size x max_len x hidden_size, or num_layers x batch_size x (max_len+2) x hidden_size
         """
-        batch_size, max_word_len = words.size()
-        word_mask = words.ne(self._word_pad_index)  # positions equal to 1 hold a word
-        seq_len = word_mask.sum(dim=-1)
-        batch_word_pieces_length = self.word_pieces_lengths[words]  # batch_size x max_len
-        word_pieces_lengths = batch_word_pieces_length.masked_fill(word_mask.eq(0), 0).sum(dim=-1)  # batch_size
-        word_piece_length = batch_word_pieces_length.sum(dim=-1).max().item()  # the word-piece length (including padding)
-        if word_piece_length+2>self._max_position_embeddings:
-            if self.auto_truncate:
-                word_pieces_lengths = word_pieces_lengths.masked_fill(word_pieces_lengths+2>self._max_position_embeddings,
-                                                                      self._max_position_embeddings-2)
-            else:
-                raise RuntimeError("After split words into word pieces, the lengths of word pieces are longer than the "
-                                   f"maximum allowed sequence length:{self._max_position_embeddings} of bert.")
-
-        # +2 because [CLS] and [SEP] need to be added
-        word_pieces = words.new_full((batch_size, min(word_piece_length+2, self._max_position_embeddings)),
-                                     fill_value=self._wordpiece_pad_index)
-        attn_masks = torch.zeros_like(word_pieces)
-        # 1. get the word-piece ids of the words, along with the corresponding span ranges
-        word_indexes = words.tolist()
-        for i in range(batch_size):
-            word_pieces_i = list(chain(*self.word_to_wordpieces[word_indexes[i]]))
-            if self.auto_truncate and len(word_pieces_i)>self._max_position_embeddings-2:
-                word_pieces_i = word_pieces_i[:self._max_position_embeddings-2]
-            word_pieces[i, 1:len(word_pieces_i)+1] = torch.LongTensor(word_pieces_i)
-            attn_masks[i, :word_pieces_lengths[i]+2].fill_(1)
-        # add [cls] and [sep]
-        word_pieces[:, 0].fill_(self._cls_index)
-        batch_indexes = torch.arange(batch_size).to(words)
-        word_pieces[batch_indexes, word_pieces_lengths+1] = self._sep_index
-        if self._has_sep_in_vocab:  # token_type_ids are only needed when [SEP] actually appears in the vocab
-            with torch.no_grad():
+        with torch.no_grad():
+            batch_size, max_word_len = words.size()
+            word_mask = words.ne(self._word_pad_index)  # positions equal to 1 hold a word
+            seq_len = word_mask.sum(dim=-1)
+            batch_word_pieces_length = self.word_pieces_lengths[words].masked_fill(word_mask.eq(0), 0)  # batch_size x max_len
+            word_pieces_lengths = batch_word_pieces_length.sum(dim=-1)  # batch_size
+            word_piece_length = batch_word_pieces_length.sum(dim=-1).max().item()  # the word-piece length (including padding)
+            if word_piece_length+2>self._max_position_embeddings:
+                if self.auto_truncate:
+                    word_pieces_lengths = word_pieces_lengths.masked_fill(word_pieces_lengths+2>self._max_position_embeddings,
+                                                                          self._max_position_embeddings-2)
+                else:
+                    raise RuntimeError("After split words into word pieces, the lengths of word pieces are longer than the "
+                                       f"maximum allowed sequence length:{self._max_position_embeddings} of bert.")
+
+            # +2 because [CLS] and [SEP] need to be added
+            word_pieces = words.new_full((batch_size, min(word_piece_length+2, self._max_position_embeddings)),
+                                         fill_value=self._wordpiece_pad_index)
+            attn_masks = torch.zeros_like(word_pieces)
+            # 1. get the word-piece ids of the words, along with the corresponding span ranges
+            word_indexes = words.cpu().numpy()
+            for i in range(batch_size):
+                word_pieces_i = list(chain(*self.word_to_wordpieces[word_indexes[i, :seq_len[i]]]))
+                if self.auto_truncate and len(word_pieces_i)>self._max_position_embeddings-2:
+                    word_pieces_i = word_pieces_i[:self._max_position_embeddings-2]
+                word_pieces[i, 1:word_pieces_lengths[i]+1] = torch.LongTensor(word_pieces_i)
+                attn_masks[i, :word_pieces_lengths[i]+2].fill_(1)
+            # add [cls] and [sep]
+            word_pieces[:, 0].fill_(self._cls_index)
+            batch_indexes = torch.arange(batch_size).to(words)
+            word_pieces[batch_indexes, word_pieces_lengths+1] = self._sep_index
+            if self._has_sep_in_vocab:  # token_type_ids are only needed when [SEP] actually appears in the vocab
                 sep_mask = word_pieces.eq(self._sep_index)  # batch_size x max_len
-                sep_mask_cumsum = sep_mask.flip(dim=-1).cumsum(dim=-1).flip(dim=-1)
-                token_type_ids = sep_mask_cumsum.fmod(2)
-                if token_type_ids[0, 0].item():  # if the first position is odd, flip the result, because the sequence must start with 0
-                    token_type_ids = token_type_ids.eq(0).float()
-        else:
-            token_type_ids = torch.zeros_like(word_pieces)
+                sep_mask_cumsum = sep_mask.flip(dims=[-1]).cumsum(dim=-1).flip(dims=[-1])
+                token_type_ids = sep_mask_cumsum.fmod(2)
+                if token_type_ids[0, 0].item():  # if the first position is odd, flip the result, because the sequence must start with 0
+                    token_type_ids = token_type_ids.eq(0).float()
+            else:
+                token_type_ids = torch.zeros_like(word_pieces)
         # 2. get the hidden states and do the corresponding pooling according to the word_pieces spans
         # all_outputs: [batch_size x max_len x hidden_size, batch_size x max_len x hidden_size, ...]
         bert_outputs, pooled_cls = self.encoder(word_pieces, token_type_ids=token_type_ids, attention_mask=attn_masks,
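The loop in this hunk packs each sentence's word pieces into word_pieces[:, 1:], reserves position 0 for [CLS] and the slot right after the last piece for [SEP], and switches on attn_masks for exactly those positions. The standalone sketch below reproduces that framing step on toy inputs; the id values and the word_to_wordpieces table are made-up stand-ins for illustration, not the module's real vocabulary.

import torch
from itertools import chain

# Hypothetical stand-ins for the module's bookkeeping (illustration only).
cls_index, sep_index, pad_index = 101, 102, 0
word_to_wordpieces = {1: [7], 2: [8, 9], 3: [10, 11, 12]}     # word id -> its word-piece ids

words = torch.tensor([[1, 2, 3, 0]])                          # batch_size x max_len, 0 = word padding
word_mask = words.ne(0)
# number of word pieces contributed by each word, zeroed at padding positions
batch_word_pieces_length = torch.tensor(
    [[len(word_to_wordpieces.get(w, [])) for w in row] for row in words.tolist()]
).masked_fill(~word_mask, 0)
word_pieces_lengths = batch_word_pieces_length.sum(dim=-1)    # pieces per sentence

# +2 leaves room for [CLS] and [SEP]
word_pieces = torch.full((words.size(0), int(word_pieces_lengths.max()) + 2), pad_index, dtype=torch.long)
attn_masks = torch.zeros_like(word_pieces)
for i, row in enumerate(words.tolist()):
    pieces = list(chain(*(word_to_wordpieces[w] for w in row if w != 0)))
    word_pieces[i, 1:len(pieces) + 1] = torch.tensor(pieces)  # slot 0 stays free for [CLS]
    attn_masks[i, :len(pieces) + 2] = 1                       # real pieces plus [CLS] and [SEP]
word_pieces[:, 0] = cls_index
word_pieces[torch.arange(words.size(0)), word_pieces_lengths + 1] = sep_index
# word_pieces is now [[101, 7, 8, 9, 10, 11, 12, 102]], with attn_masks set to 1 through the [SEP] slot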
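The token_type_ids branch relies on a reversed-cumsum trick: counting the [SEP] tokens at or to the right of each position and taking the parity yields the alternating segment ids that BERT's token type embedding expects. A minimal sketch of the same idea, assuming BERT-style ids (101 for [CLS], 102 for [SEP], 0 for padding):

import torch

sep_index = 102
word_pieces = torch.tensor([[101, 7, 8, 102, 9, 10, 102, 0]])  # [CLS] a b [SEP] c d [SEP] pad

sep_mask = word_pieces.eq(sep_index)
# cumulative count of [SEP]s from the right; the parity changes after each [SEP]
sep_mask_cumsum = sep_mask.flip(dims=[-1]).cumsum(dim=-1).flip(dims=[-1])
token_type_ids = sep_mask_cumsum.fmod(2)
if token_type_ids[0, 0].item():  # the first segment must be 0, so swap 0/1 when it is not
    token_type_ids = token_type_ids.eq(0).long()
print(token_type_ids)            # tensor([[0, 0, 0, 0, 1, 1, 1, 0]])

Padding positions come out as segment 0 because no [SEP] lies to their right, which matches the usual convention of using token type 0 for padding.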