diff --git a/fastNLP/embeddings/bert_embedding.py b/fastNLP/embeddings/bert_embedding.py
index 3bd448aa..660e803e 100644
--- a/fastNLP/embeddings/bert_embedding.py
+++ b/fastNLP/embeddings/bert_embedding.py
@@ -224,9 +224,9 @@ class BertWordPieceEncoder(nn.Module):
             the first [SEP] and everything before it is 0; from after the first [SEP] up to and including the second [SEP] is 1; from after the second [SEP] up to and including the third [SEP] is 0, and so on.
         :return: torch.FloatTensor. batch_size x max_len x (768*len(self.layers))
         """
-        with torch.no_grad():
-            sep_mask = word_pieces.eq(self._sep_index)  # batch_size x max_len
-            if token_type_ids is None:
+        if token_type_ids is None:
+            with torch.no_grad():
+                sep_mask = word_pieces.eq(self._sep_index)  # batch_size x max_len
                 sep_mask_cumsum = sep_mask.long().flip(dims=[-1]).cumsum(dim=-1).flip(dims=[-1])
                 token_type_ids = sep_mask_cumsum.fmod(2)
                 if token_type_ids[0, 0].item():  # if the first position is odd, flip the result, because the sequence must start with 0
@@ -462,7 +462,7 @@ class _WordBertModel(nn.Module):
                     outputs[l_index, :, 0] = pooled_cls
                 else:
                     outputs[l_index, :, 0] = output_layer[:, 0]
-                outputs[l_index, batch_indexes, seq_len + s_shift] = output_layer[batch_indexes, seq_len + s_shift]
+                outputs[l_index, batch_indexes, seq_len + s_shift] = output_layer[batch_indexes, word_pieces_lengths + s_shift]
 
         # 3. the final embedding result
         return outputs
diff --git a/fastNLP/modules/encoder/bert.py b/fastNLP/modules/encoder/bert.py
index 3496c5f6..4523163b 100644
--- a/fastNLP/modules/encoder/bert.py
+++ b/fastNLP/modules/encoder/bert.py
@@ -1011,7 +1011,7 @@ class _WordPieceBertModel(nn.Module):
             if word_pieces[0] != self._cls_index:
                 word_pieces.insert(0, self._cls_index)
             if word_pieces[-1] != self._sep_index:
-                word_pieces.insert(-1, self._sep_index)
+                word_pieces.append(self._sep_index)
             return word_pieces
 
         for index, dataset in enumerate(datasets):
diff --git a/test/data_for_tests/embedding/small_bert/config.json b/test/data_for_tests/embedding/small_bert/config.json
index 3e516872..da4cda35 100644
--- a/test/data_for_tests/embedding/small_bert/config.json
+++ b/test/data_for_tests/embedding/small_bert/config.json
@@ -9,5 +9,5 @@
   "num_attention_heads": 4,
   "num_hidden_layers": 2,
   "type_vocab_size": 2,
-  "vocab_size": 20
+  "vocab_size": 21
 }
\ No newline at end of file
diff --git a/test/data_for_tests/embedding/small_bert/small_pytorch_model.bin b/test/data_for_tests/embedding/small_bert/small_pytorch_model.bin
index fe968fb5..a0811def 100644
Binary files a/test/data_for_tests/embedding/small_bert/small_pytorch_model.bin and b/test/data_for_tests/embedding/small_bert/small_pytorch_model.bin differ
diff --git a/test/data_for_tests/embedding/small_bert/vocab.txt b/test/data_for_tests/embedding/small_bert/vocab.txt
index 565e67af..5c873094 100644
--- a/test/data_for_tests/embedding/small_bert/vocab.txt
+++ b/test/data_for_tests/embedding/small_bert/vocab.txt
@@ -18,3 +18,4 @@ for
 the
 whole
 text
+##a
\ No newline at end of file
diff --git a/test/embeddings/test_bert_embedding.py b/test/embeddings/test_bert_embedding.py
index fe4702ab..9cc0592f 100644
--- a/test/embeddings/test_bert_embedding.py
+++ b/test/embeddings/test_bert_embedding.py
@@ -3,6 +3,8 @@
 from fastNLP import Vocabulary
 from fastNLP.embeddings import BertEmbedding, BertWordPieceEncoder
 import torch
 import os
+from fastNLP import DataSet
+
 @unittest.skipIf('TRAVIS' in os.environ, "Skip in travis")
 class TestDownload(unittest.TestCase):
@@ -45,12 +47,41 @@ class TestBertEmbedding(unittest.TestCase):
         result = embed(words)
         self.assertEqual(result.size(), (1, 4, 16))
 
+        # automatically truncate instead of raising an error
+        embed = BertEmbedding(vocab, model_dir_or_name='test/data_for_tests/embedding/small_bert', word_dropout=0.1,
+                              only_use_pretrain_bpe=True, auto_truncate=True)
+        words = torch.LongTensor([[2, 3, 4, 0]*129])
+        result = embed(words)
+        self.assertEqual(result.size(), (1, 516, 16))
+
 
 class TestBertWordPieceEncoder(unittest.TestCase):
     def test_bert_word_piece_encoder(self):
         embed = BertWordPieceEncoder(model_dir_or_name='test/data_for_tests/embedding/small_bert', word_dropout=0.1)
-        from fastNLP import DataSet
         ds = DataSet({'words': ["this is a test . [SEP]".split()]})
         embed.index_datasets(ds, field_name='words')
         self.assertTrue(ds.has_field('word_pieces'))
         result = embed(torch.LongTensor([[1,2,3,4]]))
+
+    def test_bert_embed_eq_bert_piece_encoder(self):
+        ds = DataSet({'words': ["this is a texta model vocab".split(), 'this is'.split()]})
+        encoder = BertWordPieceEncoder(model_dir_or_name='test/data_for_tests/embedding/small_bert')
+        encoder.eval()
+        encoder.index_datasets(ds, field_name='words')
+        word_pieces = torch.LongTensor(ds['word_pieces'].get([0, 1]))
+        word_pieces_res = encoder(word_pieces)
+
+        vocab = Vocabulary()
+        vocab.from_dataset(ds, field_name='words')
+        vocab.index_dataset(ds, field_name='words', new_field_name='words')
+        ds.set_input('words')
+        words = torch.LongTensor(ds['words'].get([0, 1]))
+        embed = BertEmbedding(vocab, model_dir_or_name='test/data_for_tests/embedding/small_bert',
+                              pool_method='first', include_cls_sep=True, pooled_cls=False)
+        embed.eval()
+        words_res = embed(words)
+
+        # check that the word-piece alignment works as expected
+        self.assertEqual((word_pieces_res[0, :5]-words_res[0, :5]).sum(), 0)
+        self.assertEqual((word_pieces_res[0, 6:]-words_res[0, 5:]).sum(), 0)
+        self.assertEqual((word_pieces_res[1, :3]-words_res[1, :3]).sum(), 0)
\ No newline at end of file
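
For context on the first hunk, here is a standalone sketch (not part of the patch) of the segment-id trick that the relocated code uses: a reversed cumulative sum over the [SEP] mask, taken modulo 2, assigns segment 0 to everything up to and including the first [SEP], segment 1 up to the second, and so on. The [SEP] id and the input ids below are made-up values for illustration only.

import torch

sep_index = 102                                             # hypothetical [SEP] id
word_pieces = torch.tensor([[101, 7, 8, 102, 9, 10, 102]])  # [CLS] a b [SEP] c d [SEP]

sep_mask = word_pieces.eq(sep_index)                        # batch_size x max_len
# reversed cumulative sum: counts how many [SEP] tokens sit at or after each position
sep_mask_cumsum = sep_mask.long().flip(dims=[-1]).cumsum(dim=-1).flip(dims=[-1])
token_type_ids = sep_mask_cumsum.fmod(2)
if token_type_ids[0, 0].item():  # if the sequence starts with 1, invert so it starts with 0
    token_type_ids = token_type_ids.eq(0).long()

print(token_type_ids)            # tensor([[0, 0, 0, 0, 1, 1, 1]])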
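The one-line change in fastNLP/modules/encoder/bert.py is easy to miss: list.insert(-1, x) places x before the current last element rather than at the end, so the trailing [SEP] ended up second-to-last. A minimal plain-Python illustration with made-up ids:

cls_index, sep_index = 101, 102
word_pieces = [7, 8, 9]            # hypothetical word-piece ids without [CLS]/[SEP]

buggy = list(word_pieces)
buggy.insert(0, cls_index)
buggy.insert(-1, sep_index)        # old behaviour: [SEP] is misplaced
assert buggy == [101, 7, 8, 102, 9]

fixed = list(word_pieces)
fixed.insert(0, cls_index)
fixed.append(sep_index)            # patched behaviour: [SEP] goes last
assert fixed == [101, 7, 8, 9, 102]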