|
|
@@ -294,8 +294,7 @@ class BertWordPieceEncoder(nn.Module): |
|
|
|
sep_mask = word_pieces.eq(self._sep_index) # batch_size x max_len |
|
|
|
sep_mask_cumsum = sep_mask.long().flip(dims=[-1]).cumsum(dim=-1).flip(dims=[-1]) |
|
|
|
token_type_ids = sep_mask_cumsum.fmod(2) |
|
|
|
if token_type_ids[0, 0].item(): # 如果开头是奇数,则需要flip一下结果,因为需要保证开头为0 |
|
|
|
token_type_ids = token_type_ids.eq(0).long() |
|
|
|
token_type_ids = token_type_ids[:, :1].__xor__(token_type_ids) # 如果开头是奇数,则需要flip一下结果,因为需要保证开头为0 |
|
|
|
|
|
|
|
word_pieces = self.drop_word(word_pieces) |
|
|
|
outputs = self.model(word_pieces, token_type_ids) |
|
|
@@ -465,8 +464,7 @@ class _BertWordModel(nn.Module): |
|
|
|
sep_mask = word_pieces.eq(self._sep_index).long() # batch_size x max_len |
|
|
|
sep_mask_cumsum = sep_mask.flip(dims=[-1]).cumsum(dim=-1).flip(dims=[-1]) |
|
|
|
token_type_ids = sep_mask_cumsum.fmod(2) |
|
|
|
if token_type_ids[0, 0].item(): # 如果开头是奇数,则需要flip一下结果,因为需要保证开头为0 |
|
|
|
token_type_ids = token_type_ids.eq(0).long() |
|
|
|
token_type_ids = token_type_ids[:, :1].__xor__(token_type_ids) # 如果开头是奇数,则需要flip一下结果,因为需要保证开头为0 |
|
|
|
else: |
|
|
|
token_type_ids = torch.zeros_like(word_pieces) |
|
|
|
# 2. 获取hidden的结果,根据word_pieces进行对应的pool计算 |
|
|
|