From 3270b8b48bcc7ce2b015ded9554bad548c3dd344 Mon Sep 17 00:00:00 2001 From: stratoes <358651588@qq.com> Date: Tue, 20 Oct 2020 00:20:32 +0800 Subject: [PATCH] token_type_id_rev (#329) --- fastNLP/embeddings/bert_embedding.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/fastNLP/embeddings/bert_embedding.py b/fastNLP/embeddings/bert_embedding.py index e3b91934..ec2ba26b 100644 --- a/fastNLP/embeddings/bert_embedding.py +++ b/fastNLP/embeddings/bert_embedding.py @@ -294,8 +294,7 @@ class BertWordPieceEncoder(nn.Module): sep_mask = word_pieces.eq(self._sep_index) # batch_size x max_len sep_mask_cumsum = sep_mask.long().flip(dims=[-1]).cumsum(dim=-1).flip(dims=[-1]) token_type_ids = sep_mask_cumsum.fmod(2) - if token_type_ids[0, 0].item(): # 如果开头是奇数,则需要flip一下结果,因为需要保证开头为0 - token_type_ids = token_type_ids.eq(0).long() + token_type_ids = token_type_ids[:, :1].__xor__(token_type_ids) # 如果开头是奇数,则需要flip一下结果,因为需要保证开头为0 word_pieces = self.drop_word(word_pieces) outputs = self.model(word_pieces, token_type_ids) @@ -465,8 +464,7 @@ class _BertWordModel(nn.Module): sep_mask = word_pieces.eq(self._sep_index).long() # batch_size x max_len sep_mask_cumsum = sep_mask.flip(dims=[-1]).cumsum(dim=-1).flip(dims=[-1]) token_type_ids = sep_mask_cumsum.fmod(2) - if token_type_ids[0, 0].item(): # 如果开头是奇数,则需要flip一下结果,因为需要保证开头为0 - token_type_ids = token_type_ids.eq(0).long() + token_type_ids = token_type_ids[:, :1].__xor__(token_type_ids) # 如果开头是奇数,则需要flip一下结果,因为需要保证开头为0 else: token_type_ids = torch.zeros_like(word_pieces) # 2. 获取hidden的结果,根据word_pieces进行对应的pool计算