|
- import unittest
-
- import torch
-
- from fastNLP.core import Vocabulary, Const
- from fastNLP.models.bert import BertForSequenceClassification, BertForQuestionAnswering, \
- BertForTokenClassification, BertForMultipleChoice, BertForSentenceMatching
- from fastNLP.embeddings.bert_embedding import BertEmbedding
-
-
class TestBert(unittest.TestCase):
    """Smoke tests for the BERT task heads in ``fastNLP.models.bert``.

    Each ``test_bert_N`` case builds a ``BertEmbedding`` over a tiny vocabulary,
    wraps it in one task model, runs a forward pass on a small ``LongTensor`` of
    word indices and checks the shape of the returned tensors.

    Each ``test_bert_N_w`` case builds the same model with the opposite
    ``include_cls_sep`` setting, expects the constructor to emit a ``Warning``,
    and checks the shape returned by ``predict``.
    """

    # Checkpoint handed to ``BertEmbedding``. The path is relative, so the tests
    # presumably have to be launched from the repository root.
    _BERT_DIR = 'test/data_for_tests/embedding/small_bert'

    # Sentences whose whitespace-separated tokens form the vocabulary.
    _PLAIN_SENTENCE = "this is a test ."
    _SEP_SENTENCE = "this is a test [SEP] ."

    def _build_embed(self, sentence, include_cls_sep):
        """Return a ``BertEmbedding`` over the tokens of *sentence*.

        :param sentence: whitespace-separated words added to a fresh ``Vocabulary``.
        :param include_cls_sep: forwarded unchanged to ``BertEmbedding``.
        """
        vocab = Vocabulary().add_word_lst(sentence.split())
        return BertEmbedding(vocab, model_dir_or_name=self._BERT_DIR,
                             include_cls_sep=include_cls_sep)

    def _assert_output(self, pred, key, shape):
        """Assert that *pred* is a dict whose entry *key* has exactly *shape*.

        ``assertIsInstance`` / ``assertIn`` are used instead of ``assertTrue`` so
        that a failure reports the offending value rather than ``False is not true``.
        """
        self.assertIsInstance(pred, dict)
        self.assertIn(key, pred)
        self.assertEqual(tuple(pred[key].shape), shape)

    def test_bert_1(self):
        """Sequence classification: forward yields ``(batch, num_labels)`` output."""
        embed = self._build_embed(self._PLAIN_SENTENCE, include_cls_sep=True)
        model = BertForSequenceClassification(embed, 2)
        input_ids = torch.LongTensor([[1, 2, 3], [5, 6, 0]])

        # The forward pass is deliberately exercised twice with the same input.
        for _ in range(2):
            self._assert_output(model(input_ids), Const.OUTPUT, (2, 2))

    def test_bert_1_w(self):
        """Sequence classification warns without CLS/SEP; predict yields ``(batch,)``."""
        embed = self._build_embed(self._PLAIN_SENTENCE, include_cls_sep=False)
        with self.assertWarns(Warning):
            model = BertForSequenceClassification(embed, 2)
        input_ids = torch.LongTensor([[1, 2, 3], [5, 6, 0]])

        self._assert_output(model.predict(input_ids), Const.OUTPUT, (2,))

    def test_bert_2(self):
        """Multiple choice: forward yields ``(batch, num_choices)`` output."""
        embed = self._build_embed(self._SEP_SENTENCE, include_cls_sep=True)
        model = BertForMultipleChoice(embed, 2)
        # One sample holding two candidate sequences of three tokens each.
        input_ids = torch.LongTensor([[[2, 6, 7], [1, 6, 5]]])

        self._assert_output(model(input_ids), Const.OUTPUT, (1, 2))

    def test_bert_2_w(self):
        """Multiple choice warns without CLS/SEP; predict yields ``(batch,)``."""
        embed = self._build_embed(self._SEP_SENTENCE, include_cls_sep=False)
        with self.assertWarns(Warning):
            model = BertForMultipleChoice(embed, 2)
        input_ids = torch.LongTensor([[[2, 6, 7], [1, 6, 5]]])

        self._assert_output(model.predict(input_ids), Const.OUTPUT, (1,))

    def test_bert_3(self):
        """Token classification: forward yields ``(batch, seq_len, num_labels)``."""
        embed = self._build_embed(self._SEP_SENTENCE, include_cls_sep=False)
        model = BertForTokenClassification(embed, 7)
        input_ids = torch.LongTensor([[1, 2, 3], [6, 5, 0]])

        self._assert_output(model(input_ids), Const.OUTPUT, (2, 3, 7))

    def test_bert_3_w(self):
        """Token classification warns with CLS/SEP; predict yields ``(batch, seq_len)``."""
        embed = self._build_embed(self._SEP_SENTENCE, include_cls_sep=True)
        with self.assertWarns(Warning):
            model = BertForTokenClassification(embed, 7)
        input_ids = torch.LongTensor([[1, 2, 3], [6, 5, 0]])

        self._assert_output(model.predict(input_ids), Const.OUTPUT, (2, 3))

    def test_bert_4(self):
        """Question answering: forward yields two ``(2, 5)`` outputs by default."""
        embed = self._build_embed(self._SEP_SENTENCE, include_cls_sep=True)
        model = BertForQuestionAnswering(embed)
        input_ids = torch.LongTensor([[1, 2, 3], [6, 5, 0]])

        pred = model(input_ids)
        # Width 5 for 3 input tokens: presumably the added [CLS]/[SEP] positions.
        self._assert_output(pred, Const.OUTPUTS(0), (2, 5))
        self._assert_output(pred, Const.OUTPUTS(1), (2, 5))

        # With a second constructor argument of 7 the result dict has 7 entries.
        model = BertForQuestionAnswering(embed, 7)
        pred = model(input_ids)
        self.assertIsInstance(pred, dict)
        self.assertEqual(len(pred), 7)

    def test_bert_4_w(self):
        """Question answering warns without CLS/SEP; predict yields ``(batch,)``."""
        embed = self._build_embed(self._SEP_SENTENCE, include_cls_sep=False)
        with self.assertWarns(Warning):
            model = BertForQuestionAnswering(embed)
        input_ids = torch.LongTensor([[1, 2, 3], [6, 5, 0]])

        self._assert_output(model.predict(input_ids), Const.OUTPUTS(1), (2,))

    def test_bert_5(self):
        """Sentence matching: forward yields a ``(2, 2)`` output."""
        embed = self._build_embed(self._SEP_SENTENCE, include_cls_sep=True)
        model = BertForSentenceMatching(embed)
        input_ids = torch.LongTensor([[1, 2, 3], [6, 5, 0]])

        self._assert_output(model(input_ids), Const.OUTPUT, (2, 2))

    def test_bert_5_w(self):
        """Sentence matching warns without CLS/SEP; predict yields ``(batch,)``."""
        embed = self._build_embed(self._SEP_SENTENCE, include_cls_sep=False)
        with self.assertWarns(Warning):
            model = BertForSentenceMatching(embed)
        input_ids = torch.LongTensor([[1, 2, 3], [6, 5, 0]])

        self._assert_output(model.predict(input_ids), Const.OUTPUT, (2,))
|