Browse Source

token_type_id_rev (#329)

tags/v0.6.0
stratoes GitHub 4 years ago
parent
commit
3270b8b48b
No known key found for this signature in database GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 2 additions and 4 deletions
  1. +2
    -4
      fastNLP/embeddings/bert_embedding.py

+ 2
- 4
fastNLP/embeddings/bert_embedding.py View File

@@ -294,8 +294,7 @@ class BertWordPieceEncoder(nn.Module):
sep_mask = word_pieces.eq(self._sep_index) # batch_size x max_len
sep_mask_cumsum = sep_mask.long().flip(dims=[-1]).cumsum(dim=-1).flip(dims=[-1])
token_type_ids = sep_mask_cumsum.fmod(2)
if token_type_ids[0, 0].item(): # 如果开头是奇数,则需要flip一下结果,因为需要保证开头为0
token_type_ids = token_type_ids.eq(0).long()
token_type_ids = token_type_ids[:, :1].__xor__(token_type_ids) # 如果开头是奇数,则需要flip一下结果,因为需要保证开头为0

word_pieces = self.drop_word(word_pieces)
outputs = self.model(word_pieces, token_type_ids)
@@ -465,8 +464,7 @@ class _BertWordModel(nn.Module):
sep_mask = word_pieces.eq(self._sep_index).long() # batch_size x max_len
sep_mask_cumsum = sep_mask.flip(dims=[-1]).cumsum(dim=-1).flip(dims=[-1])
token_type_ids = sep_mask_cumsum.fmod(2)
if token_type_ids[0, 0].item(): # 如果开头是奇数,则需要flip一下结果,因为需要保证开头为0
token_type_ids = token_type_ids.eq(0).long()
token_type_ids = token_type_ids[:, :1].__xor__(token_type_ids) # 如果开头是奇数,则需要flip一下结果,因为需要保证开头为0
else:
token_type_ids = torch.zeros_like(word_pieces)
# 2. 获取hidden的结果,根据word_pieces进行对应的pool计算


Loading…
Cancel
Save