models/bert: default and 1-D `seq_len` handling in
BertForSequenceClassification.forward, with updated tests.

- forward() now builds an all-ones attention mask when `seq_len` is None,
  and converts 1-D length vectors to a 2-D mask via `seq_len_to_mask`;
  an already-2-D mask is passed through unchanged.
- Review fix folded in: the test's new explicit import list dropped the
  star import that previously provided `Const`, so `Const.OUTPUT` in the
  assertions would raise NameError. `Const` is now imported explicitly
  from fastNLP.core.const (hunk counts/offsets adjusted accordingly).

diff --git a/fastNLP/models/bert.py b/fastNLP/models/bert.py
index ad7750ec..3afccc14 100644
--- a/fastNLP/models/bert.py
+++ b/fastNLP/models/bert.py
@@ -10,6 +10,7 @@ from .base_model import BaseModel
 from ..core.const import Const
 from ..modules.encoder import BertModel
 from ..modules.encoder.bert import BertConfig, CONFIG_FILE
+from ..core.utils import seq_len_to_mask
 
 
 class BertForSequenceClassification(BaseModel):
@@ -70,6 +71,10 @@ class BertForSequenceClassification(BaseModel):
         return model
 
     def forward(self, words, seq_len=None, target=None):
+        if seq_len is None:
+            seq_len = torch.ones_like(words, dtype=words.dtype, device=words.device)
+        if len(seq_len.size()) + 1 == len(words.size()):
+            seq_len = seq_len_to_mask(seq_len, max_len=words.size(-1))
         _, pooled_output = self.bert(words, attention_mask=seq_len, output_all_encoded_layers=False)
         pooled_output = self.dropout(pooled_output)
         logits = self.classifier(pooled_output)
diff --git a/test/models/test_bert.py b/test/models/test_bert.py
index 05ee6d5a..40b98c81 100644
--- a/test/models/test_bert.py
+++ b/test/models/test_bert.py
@@ -2,7 +2,9 @@ import unittest
 
 import torch
 
-from fastNLP.models.bert import *
+from fastNLP.core.const import Const
+from fastNLP.models.bert import BertForSequenceClassification, BertForQuestionAnswering, \
+    BertForTokenClassification, BertForMultipleChoice
 
 
 class TestBert(unittest.TestCase):
@@ -14,9 +16,14 @@ class TestBert(unittest.TestCase):
 
         input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]])
         input_mask = torch.LongTensor([[1, 1, 1], [1, 1, 0]])
-        token_type_ids = torch.LongTensor([[0, 0, 1], [0, 1, 0]])
 
-        pred = model(input_ids, token_type_ids, input_mask)
+        pred = model(input_ids, input_mask)
+        self.assertTrue(isinstance(pred, dict))
+        self.assertTrue(Const.OUTPUT in pred)
+        self.assertEqual(tuple(pred[Const.OUTPUT].shape), (2, 2))
+
+        input_mask = torch.LongTensor([3, 2])
+        pred = model(input_ids, input_mask)
         self.assertTrue(isinstance(pred, dict))
         self.assertTrue(Const.OUTPUT in pred)
         self.assertEqual(tuple(pred[Const.OUTPUT].shape), (2, 2))