|
|
@@ -2,7 +2,8 @@ import unittest |
|
|
|
|
|
|
|
import torch |
|
|
|
|
|
|
|
from fastNLP.models.bert import * |
|
|
|
from fastNLP.models.bert import BertForSequenceClassification, BertForQuestionAnswering, \ |
|
|
|
BertForTokenClassification, BertForMultipleChoice |
|
|
|
|
|
|
|
|
|
|
|
class TestBert(unittest.TestCase): |
|
|
@@ -14,9 +15,14 @@ class TestBert(unittest.TestCase): |
|
|
|
|
|
|
|
input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]]) |
|
|
|
input_mask = torch.LongTensor([[1, 1, 1], [1, 1, 0]]) |
|
|
|
token_type_ids = torch.LongTensor([[0, 0, 1], [0, 1, 0]]) |
|
|
|
|
|
|
|
pred = model(input_ids, token_type_ids, input_mask) |
|
|
|
pred = model(input_ids, input_mask) |
|
|
|
self.assertTrue(isinstance(pred, dict)) |
|
|
|
self.assertTrue(Const.OUTPUT in pred) |
|
|
|
self.assertEqual(tuple(pred[Const.OUTPUT].shape), (2, 2)) |
|
|
|
|
|
|
|
input_mask = torch.LongTensor([3, 2]) |
|
|
|
pred = model(input_ids, input_mask) |
|
|
|
self.assertTrue(isinstance(pred, dict)) |
|
|
|
self.assertTrue(Const.OUTPUT in pred) |
|
|
|
self.assertEqual(tuple(pred[Const.OUTPUT].shape), (2, 2)) |
|
|
|