Browse Source

update test codes in models/bert.py

tags/v0.4.10
xuyige 6 years ago
parent
commit
9560a4d367
2 changed files with 14 additions and 3 deletions
  1. +5
    -0
      fastNLP/models/bert.py
  2. +9
    -3
      test/models/test_bert.py

+ 5
- 0
fastNLP/models/bert.py View File

@@ -10,6 +10,7 @@ from .base_model import BaseModel
from ..core.const import Const
from ..modules.encoder import BertModel
from ..modules.encoder.bert import BertConfig, CONFIG_FILE
from ..core.utils import seq_len_to_mask


class BertForSequenceClassification(BaseModel):
@@ -70,6 +71,10 @@ class BertForSequenceClassification(BaseModel):
return model

def forward(self, words, seq_len=None, target=None):
if seq_len is None:
seq_len = torch.ones_like(words, dtype=words.dtype, device=words.device)
if len(seq_len.size()) + 1 == len(words.size()):
seq_len = seq_len_to_mask(seq_len, max_len=words.size(-1))
_, pooled_output = self.bert(words, attention_mask=seq_len, output_all_encoded_layers=False)
pooled_output = self.dropout(pooled_output)
logits = self.classifier(pooled_output)


+ 9
- 3
test/models/test_bert.py View File

@@ -2,7 +2,8 @@ import unittest

import torch

from fastNLP.models.bert import *
from fastNLP.models.bert import BertForSequenceClassification, BertForQuestionAnswering, \
BertForTokenClassification, BertForMultipleChoice


class TestBert(unittest.TestCase):
@@ -14,9 +15,14 @@ class TestBert(unittest.TestCase):

input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]])
input_mask = torch.LongTensor([[1, 1, 1], [1, 1, 0]])
token_type_ids = torch.LongTensor([[0, 0, 1], [0, 1, 0]])

pred = model(input_ids, token_type_ids, input_mask)
pred = model(input_ids, input_mask)
self.assertTrue(isinstance(pred, dict))
self.assertTrue(Const.OUTPUT in pred)
self.assertEqual(tuple(pred[Const.OUTPUT].shape), (2, 2))

input_mask = torch.LongTensor([3, 2])
pred = model(input_ids, input_mask)
self.assertTrue(isinstance(pred, dict))
self.assertTrue(Const.OUTPUT in pred)
self.assertEqual(tuple(pred[Const.OUTPUT].shape), (2, 2))


Loading…
Cancel
Save