
Fix BertEmbedding producing incorrect values at the [SEP] position when include_cls_sep=True

tags/v0.5.5
yh_cc, 5 years ago
commit ace8aafdb2
6 changed files with 39 additions and 7 deletions

  1. fastNLP/embeddings/bert_embedding.py (+4, -4)
  2. fastNLP/modules/encoder/bert.py (+1, -1)
  3. test/data_for_tests/embedding/small_bert/config.json (+1, -1)
  4. test/data_for_tests/embedding/small_bert/small_pytorch_model.bin (BIN)
  5. test/data_for_tests/embedding/small_bert/vocab.txt (+1, -0)
  6. test/embeddings/test_bert_embedding.py (+32, -1)

fastNLP/embeddings/bert_embedding.py (+4, -4)

@@ -224,9 +224,9 @@ class BertWordPieceEncoder(nn.Module):
             Tokens up to and including the first [SEP] get 0; tokens after it up to and including the second [SEP] get 1; tokens up to and including the third [SEP] get 0 again, and so on.
         :return: torch.FloatTensor. batch_size x max_len x (768*len(self.layers))
         """
-        with torch.no_grad():
-            sep_mask = word_pieces.eq(self._sep_index)  # batch_size x max_len
-            if token_type_ids is None:
+        if token_type_ids is None:
+            with torch.no_grad():
+                sep_mask = word_pieces.eq(self._sep_index)  # batch_size x max_len
                 sep_mask_cumsum = sep_mask.long().flip(dims=[-1]).cumsum(dim=-1).flip(dims=[-1])
                 token_type_ids = sep_mask_cumsum.fmod(2)
                 if token_type_ids[0, 0].item():  # if the first position is odd, flip the result, since the sequence must start with segment 0
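For reference, the moved block derives token_type_ids from the [SEP] positions alone: it counts [SEP] tokens from the right and takes the count modulo 2, so each [SEP] shares the segment id of the tokens before it. A minimal standalone sketch of that computation, using a made-up [SEP] id of 102 and a toy word_pieces tensor (neither comes from the patch):

import torch

# Made-up ids purely for illustration: 102 stands in for self._sep_index.
sep_index = 102
word_pieces = torch.tensor([[101, 7, 8, 102, 9, 10, 102]])  # [CLS] a b [SEP] c d [SEP]

sep_mask = word_pieces.eq(sep_index)  # True at every [SEP] position
# Count [SEP] tokens from the right, then take the count modulo 2.
sep_mask_cumsum = sep_mask.long().flip(dims=[-1]).cumsum(dim=-1).flip(dims=[-1])
token_type_ids = sep_mask_cumsum.fmod(2)
if token_type_ids[0, 0].item():  # the first segment must be 0, so flip if it starts odd
    token_type_ids = token_type_ids.eq(0).long()

print(token_type_ids)  # tensor([[0, 0, 0, 0, 1, 1, 1]])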
@@ -462,7 +462,7 @@ class _WordBertModel(nn.Module):
                     outputs[l_index, :, 0] = pooled_cls
                 else:
                     outputs[l_index, :, 0] = output_layer[:, 0]
-                outputs[l_index, batch_indexes, seq_len + s_shift] = output_layer[batch_indexes, seq_len + s_shift]
+                outputs[l_index, batch_indexes, seq_len + s_shift] = output_layer[batch_indexes, word_pieces_lengths + s_shift]

         # 3. the final embedding result
         return outputs
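The second hunk is the fix the commit message refers to: outputs is indexed at the word level (seq_len words plus the CLS/SEP slots), while output_layer is indexed at the word-piece level, so whenever a word is split into several word pieces the [SEP] vector sits at word_pieces_lengths + s_shift, not at seq_len + s_shift. A toy sketch of the indexing, with made-up shapes and values rather than the actual fastNLP variables:

import torch

# Toy setup: a sentence of 3 words, one of which splits into two word pieces.
batch_indexes = torch.arange(1)          # one sentence in the batch
seq_len = torch.tensor([3])              # number of words in the sentence
word_pieces_lengths = torch.tensor([4])  # number of word pieces
s_shift = 1                              # [CLS] occupies position 0, so everything shifts by one

# Word-piece level output: [CLS] wp0 wp1 wp2 wp3 [SEP], hidden size 2
output_layer = torch.arange(6 * 2, dtype=torch.float).view(1, 6, 2)
# Word level output with CLS/SEP slots: [CLS] w0 w1 w2 [SEP]
outputs = torch.zeros(1, 5, 2)

# The old code read position seq_len + s_shift = 4, i.e. the last word piece, not [SEP].
wrong_sep = output_layer[batch_indexes, seq_len + s_shift]
# The fixed code reads position word_pieces_lengths + s_shift = 5, the actual [SEP] vector.
right_sep = output_layer[batch_indexes, word_pieces_lengths + s_shift]

# The SEP slot of the word-level output is still at seq_len + s_shift = 4.
outputs[batch_indexes, seq_len + s_shift] = right_sep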

fastNLP/modules/encoder/bert.py (+1, -1)

@@ -1011,7 +1011,7 @@ class _WordPieceBertModel(nn.Module):
             if word_pieces[0] != self._cls_index:
                 word_pieces.insert(0, self._cls_index)
             if word_pieces[-1] != self._sep_index:
-                word_pieces.insert(-1, self._sep_index)
+                word_pieces.append(self._sep_index)
             return word_pieces

         for index, dataset in enumerate(datasets):
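This one-line change matters because list.insert(-1, x) inserts before the last element rather than after it, so the [SEP] id ended up one position too early. A quick illustration with made-up token ids (101 for [CLS], 102 for [SEP]):

# Made-up token ids purely for illustration: 101 = [CLS], 102 = [SEP].
word_pieces = [101, 7, 8, 9]

buggy = list(word_pieces)
buggy.insert(-1, 102)  # inserts before the last element
print(buggy)           # [101, 7, 8, 102, 9] -- [SEP] lands in the wrong place

fixed = list(word_pieces)
fixed.append(102)      # appends after the last element
print(fixed)           # [101, 7, 8, 9, 102]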


test/data_for_tests/embedding/small_bert/config.json (+1, -1)

@@ -9,5 +9,5 @@
   "num_attention_heads": 4,
   "num_hidden_layers": 2,
   "type_vocab_size": 2,
-  "vocab_size": 20
+  "vocab_size": 21
 }

test/data_for_tests/embedding/small_bert/small_pytorch_model.bin (BIN)


test/data_for_tests/embedding/small_bert/vocab.txt (+1, -0)

@@ -18,3 +18,4 @@ for
 the
 whole
 text
+##a

test/embeddings/test_bert_embedding.py (+32, -1)

@@ -3,6 +3,8 @@ from fastNLP import Vocabulary
 from fastNLP.embeddings import BertEmbedding, BertWordPieceEncoder
 import torch
 import os
+from fastNLP import DataSet
+

 @unittest.skipIf('TRAVIS' in os.environ, "Skip in travis")
 class TestDownload(unittest.TestCase):
@@ -45,12 +47,41 @@ class TestBertEmbedding(unittest.TestCase):
         result = embed(words)
         self.assertEqual(result.size(), (1, 4, 16))

+        # truncate automatically instead of raising an error
+        embed = BertEmbedding(vocab, model_dir_or_name='test/data_for_tests/embedding/small_bert', word_dropout=0.1,
+                              only_use_pretrain_bpe=True, auto_truncate=True)
+        words = torch.LongTensor([[2, 3, 4, 0]*129])
+        result = embed(words)
+        self.assertEqual(result.size(), (1, 516, 16))
+

 class TestBertWordPieceEncoder(unittest.TestCase):
     def test_bert_word_piece_encoder(self):
         embed = BertWordPieceEncoder(model_dir_or_name='test/data_for_tests/embedding/small_bert', word_dropout=0.1)
-        from fastNLP import DataSet
         ds = DataSet({'words': ["this is a test . [SEP]".split()]})
         embed.index_datasets(ds, field_name='words')
         self.assertTrue(ds.has_field('word_pieces'))
         result = embed(torch.LongTensor([[1,2,3,4]]))
+
+    def test_bert_embed_eq_bert_piece_encoder(self):
+        ds = DataSet({'words': ["this is a texta model vocab".split(), 'this is'.split()]})
+        encoder = BertWordPieceEncoder(model_dir_or_name='test/data_for_tests/embedding/small_bert')
+        encoder.eval()
+        encoder.index_datasets(ds, field_name='words')
+        word_pieces = torch.LongTensor(ds['word_pieces'].get([0, 1]))
+        word_pieces_res = encoder(word_pieces)
+
+        vocab = Vocabulary()
+        vocab.from_dataset(ds, field_name='words')
+        vocab.index_dataset(ds, field_name='words', new_field_name='words')
+        ds.set_input('words')
+        words = torch.LongTensor(ds['words'].get([0, 1]))
+        embed = BertEmbedding(vocab, model_dir_or_name='test/data_for_tests/embedding/small_bert',
+                              pool_method='first', include_cls_sep=True, pooled_cls=False)
+        embed.eval()
+        words_res = embed(words)
+
+        # check that the word-piece handling works as expected
+        self.assertEqual((word_pieces_res[0, :5]-words_res[0, :5]).sum(), 0)
+        self.assertEqual((word_pieces_res[0, 6:]-words_res[0, 5:]).sum(), 0)
+        self.assertEqual((word_pieces_res[1, :3]-words_res[1, :3]).sum(), 0)
