|
|
@@ -3,6 +3,8 @@ from fastNLP import Vocabulary |
|
|
|
from fastNLP.embeddings import BertEmbedding, BertWordPieceEncoder |
|
|
|
import torch |
|
|
|
import os |
|
|
|
from fastNLP import DataSet |
|
|
|
|
|
|
|
|
|
|
|
@unittest.skipIf('TRAVIS' in os.environ, "Skip in travis") |
|
|
|
class TestDownload(unittest.TestCase): |
|
|
@@ -45,12 +47,41 @@ class TestBertEmbedding(unittest.TestCase): |
|
|
|
result = embed(words) |
|
|
|
self.assertEqual(result.size(), (1, 4, 16)) |
|
|
|
|
|
|
|
# 自动截断而不报错 |
|
|
|
embed = BertEmbedding(vocab, model_dir_or_name='test/data_for_tests/embedding/small_bert', word_dropout=0.1, |
|
|
|
only_use_pretrain_bpe=True, auto_truncate=True) |
|
|
|
words = torch.LongTensor([[2, 3, 4, 0]*129]) |
|
|
|
result = embed(words) |
|
|
|
self.assertEqual(result.size(), (1, 516, 16)) |
|
|
|
|
|
|
|
|
|
|
|
class TestBertWordPieceEncoder(unittest.TestCase): |
|
|
|
def test_bert_word_piece_encoder(self): |
|
|
|
embed = BertWordPieceEncoder(model_dir_or_name='test/data_for_tests/embedding/small_bert', word_dropout=0.1) |
|
|
|
from fastNLP import DataSet |
|
|
|
ds = DataSet({'words': ["this is a test . [SEP]".split()]}) |
|
|
|
embed.index_datasets(ds, field_name='words') |
|
|
|
self.assertTrue(ds.has_field('word_pieces')) |
|
|
|
result = embed(torch.LongTensor([[1,2,3,4]])) |
|
|
|
|
|
|
|
def test_bert_embed_eq_bert_piece_encoder(self): |
|
|
|
ds = DataSet({'words': ["this is a texta model vocab".split(), 'this is'.split()]}) |
|
|
|
encoder = BertWordPieceEncoder(model_dir_or_name='test/data_for_tests/embedding/small_bert') |
|
|
|
encoder.eval() |
|
|
|
encoder.index_datasets(ds, field_name='words') |
|
|
|
word_pieces = torch.LongTensor(ds['word_pieces'].get([0, 1])) |
|
|
|
word_pieces_res = encoder(word_pieces) |
|
|
|
|
|
|
|
vocab = Vocabulary() |
|
|
|
vocab.from_dataset(ds, field_name='words') |
|
|
|
vocab.index_dataset(ds, field_name='words', new_field_name='words') |
|
|
|
ds.set_input('words') |
|
|
|
words = torch.LongTensor(ds['words'].get([0, 1])) |
|
|
|
embed = BertEmbedding(vocab, model_dir_or_name='test/data_for_tests/embedding/small_bert', |
|
|
|
pool_method='first', include_cls_sep=True, pooled_cls=False) |
|
|
|
embed.eval() |
|
|
|
words_res = embed(words) |
|
|
|
|
|
|
|
# 检查word piece什么的是正常work的 |
|
|
|
self.assertEqual((word_pieces_res[0, :5]-words_res[0, :5]).sum(), 0) |
|
|
|
self.assertEqual((word_pieces_res[0, 6:]-words_res[0, 5:]).sum(), 0) |
|
|
|
self.assertEqual((word_pieces_res[1, :3]-words_res[1, :3]).sum(), 0) |