@@ -10,6 +10,7 @@ from .base_model import BaseModel | |||||
from ..core.const import Const | from ..core.const import Const | ||||
from ..modules.encoder import BertModel | from ..modules.encoder import BertModel | ||||
from ..modules.encoder.bert import BertConfig, CONFIG_FILE | from ..modules.encoder.bert import BertConfig, CONFIG_FILE | ||||
from ..core.utils import seq_len_to_mask | |||||
class BertForSequenceClassification(BaseModel): | class BertForSequenceClassification(BaseModel): | ||||
@@ -70,6 +71,10 @@ class BertForSequenceClassification(BaseModel): | |||||
return model | return model | ||||
def forward(self, words, seq_len=None, target=None): | def forward(self, words, seq_len=None, target=None): | ||||
if seq_len is None: | |||||
seq_len = torch.ones_like(words, dtype=words.dtype, device=words.device) | |||||
if len(seq_len.size()) + 1 == len(words.size()): | |||||
seq_len = seq_len_to_mask(seq_len, max_len=words.size(-1)) | |||||
_, pooled_output = self.bert(words, attention_mask=seq_len, output_all_encoded_layers=False) | _, pooled_output = self.bert(words, attention_mask=seq_len, output_all_encoded_layers=False) | ||||
pooled_output = self.dropout(pooled_output) | pooled_output = self.dropout(pooled_output) | ||||
logits = self.classifier(pooled_output) | logits = self.classifier(pooled_output) | ||||
@@ -2,7 +2,8 @@ import unittest | |||||
import torch | import torch | ||||
from fastNLP.models.bert import * | |||||
from fastNLP.models.bert import BertForSequenceClassification, BertForQuestionAnswering, \ | |||||
BertForTokenClassification, BertForMultipleChoice | |||||
class TestBert(unittest.TestCase): | class TestBert(unittest.TestCase): | ||||
@@ -14,9 +15,14 @@ class TestBert(unittest.TestCase): | |||||
input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]]) | input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]]) | ||||
input_mask = torch.LongTensor([[1, 1, 1], [1, 1, 0]]) | input_mask = torch.LongTensor([[1, 1, 1], [1, 1, 0]]) | ||||
token_type_ids = torch.LongTensor([[0, 0, 1], [0, 1, 0]]) | |||||
pred = model(input_ids, token_type_ids, input_mask) | |||||
pred = model(input_ids, input_mask) | |||||
self.assertTrue(isinstance(pred, dict)) | |||||
self.assertTrue(Const.OUTPUT in pred) | |||||
self.assertEqual(tuple(pred[Const.OUTPUT].shape), (2, 2)) | |||||
input_mask = torch.LongTensor([3, 2]) | |||||
pred = model(input_ids, input_mask) | |||||
self.assertTrue(isinstance(pred, dict)) | self.assertTrue(isinstance(pred, dict)) | ||||
self.assertTrue(Const.OUTPUT in pred) | self.assertTrue(Const.OUTPUT in pred) | ||||
self.assertEqual(tuple(pred[Const.OUTPUT].shape), (2, 2)) | self.assertEqual(tuple(pred[Const.OUTPUT].shape), (2, 2)) | ||||