From 016f02be3b3db52d08790f8e053fcea99858e85d Mon Sep 17 00:00:00 2001
From: xuyige
Date: Wed, 29 May 2019 14:46:48 +0800
Subject: [PATCH] fix bugs in model/bert.py and add testing codes

---
Review notes (free-form text between the three-dash line and the diff is not
part of the commit message):

- BertConfig stored its `intermediate_size` argument as `self.intermediate`,
  so `BertModel(**config.__dict__)` never passed an `intermediate_size`
  keyword (the keyword test/modules/encoder/test_bert.py passes to BertModel).
  The attribute is now `intermediate_size`.
- The four model constructors only defaulted `config` on the no-`bert_dir`
  path, then read `config.hidden_dropout_prob` / `config.hidden_size`
  unconditionally; `bert_dir=...` with `config=None` hit attribute access on
  None. `config` is now defaulted before the branch.
  NOTE(review): with `bert_dir` set and no `config`, head sizes come from the
  BertConfig defaults - confirm they match the checkpoint or pass `config`.
- Added-line counts are unchanged, so hunk headers and the diffstat still
  hold. The commit id and the post-image hash on the `index` line for
  fastNLP/models/bert.py predate these two fixes.
- Indentation was reconstructed (4-space convention); verify context and
  removed lines against the target tree before applying.
- Follow-up, not changed here: the docstring examples kept as context still
  call `BertConfig(vocab_size_or_config_json_file=...)`, which the BertConfig
  added below does not accept (its parameter is `vocab_size`).

 fastNLP/models/bert.py            | 71 ++++++++++++++++++++++++++-----
 test/models/test_bert.py          | 60 ++++++++++++++++++++++----
 test/modules/encoder/test_bert.py | 21 +++++++++
 3 files changed, 133 insertions(+), 19 deletions(-)
 create mode 100644 test/modules/encoder/test_bert.py

diff --git a/fastNLP/models/bert.py b/fastNLP/models/bert.py
index 960132ad..02227c0d 100644
--- a/fastNLP/models/bert.py
+++ b/fastNLP/models/bert.py
@@ -10,6 +10,35 @@ from ..core.const import Const
 from ..modules.encoder import BertModel
 
 
+class BertConfig:
+
+    def __init__(
+        self,
+        vocab_size=30522,
+        hidden_size=768,
+        num_hidden_layers=12,
+        num_attention_heads=12,
+        intermediate_size=3072,
+        hidden_act="gelu",
+        hidden_dropout_prob=0.1,
+        attention_probs_dropout_prob=0.1,
+        max_position_embeddings=512,
+        type_vocab_size=2,
+        initializer_range=0.02
+    ):
+        self.vocab_size = vocab_size
+        self.hidden_size = hidden_size
+        self.num_hidden_layers = num_hidden_layers
+        self.num_attention_heads = num_attention_heads
+        self.intermediate_size = intermediate_size  # key must match BertModel's keyword
+        self.hidden_act = hidden_act
+        self.hidden_dropout_prob = hidden_dropout_prob
+        self.attention_probs_dropout_prob = attention_probs_dropout_prob
+        self.max_position_embeddings = max_position_embeddings
+        self.type_vocab_size = type_vocab_size
+        self.initializer_range = initializer_range
+
+
 class BertForSequenceClassification(BaseModel):
     """BERT model for classification.
     This module is composed of the BERT model with a linear layer on top of
@@ -44,14 +73,19 @@ class BertForSequenceClassification(BaseModel):
     config = BertConfig(vocab_size_or_config_json_file=32000, hidden_size=768,
         num_hidden_layers=12, num_attention_heads=12, intermediate_size=3072)
     num_labels = 2
-    model = BertForSequenceClassification(config, num_labels)
+    model = BertForSequenceClassification(num_labels, config)
     logits = model(input_ids, token_type_ids, input_mask)
     ```
     """
-    def __init__(self, config, num_labels, bert_dir):
+    def __init__(self, num_labels, config=None, bert_dir=None):
         super(BertForSequenceClassification, self).__init__()
         self.num_labels = num_labels
-        self.bert = BertModel.from_pretrained(bert_dir)
+        if config is None:  # the layers below read config on both paths
+            config = BertConfig()
+        if bert_dir is not None:
+            self.bert = BertModel.from_pretrained(bert_dir)
+        else:
+            self.bert = BertModel(**config.__dict__)
         self.dropout = nn.Dropout(config.hidden_dropout_prob)
         self.classifier = nn.Linear(config.hidden_size, num_labels)
 
@@ -106,14 +140,19 @@ class BertForMultipleChoice(BaseModel):
     config = BertConfig(vocab_size_or_config_json_file=32000, hidden_size=768,
         num_hidden_layers=12, num_attention_heads=12, intermediate_size=3072)
     num_choices = 2
-    model = BertForMultipleChoice(config, num_choices, bert_dir)
+    model = BertForMultipleChoice(num_choices, config, bert_dir)
     logits = model(input_ids, token_type_ids, input_mask)
     ```
     """
-    def __init__(self, config, num_choices, bert_dir):
+    def __init__(self, num_choices, config=None, bert_dir=None):
         super(BertForMultipleChoice, self).__init__()
         self.num_choices = num_choices
-        self.bert = BertModel.from_pretrained(bert_dir)
+        if config is None:  # the layers below read config on both paths
+            config = BertConfig()
+        if bert_dir is not None:
+            self.bert = BertModel.from_pretrained(bert_dir)
+        else:
+            self.bert = BertModel(**config.__dict__)
         self.dropout = nn.Dropout(config.hidden_dropout_prob)
         self.classifier = nn.Linear(config.hidden_size, 1)
 
@@ -174,14 +213,19 @@ class BertForTokenClassification(BaseModel):
         num_hidden_layers=12, num_attention_heads=12, intermediate_size=3072)
     num_labels = 2
     bert_dir = 'your-bert-file-dir'
-    model = BertForTokenClassification(config, num_labels, bert_dir)
+    model = BertForTokenClassification(num_labels, config, bert_dir)
     logits = model(input_ids, token_type_ids, input_mask)
     ```
     """
-    def __init__(self, config, num_labels, bert_dir):
+    def __init__(self, num_labels, config=None, bert_dir=None):
         super(BertForTokenClassification, self).__init__()
         self.num_labels = num_labels
-        self.bert = BertModel.from_pretrained(bert_dir)
+        if config is None:  # the layers below read config on both paths
+            config = BertConfig()
+        if bert_dir is not None:
+            self.bert = BertModel.from_pretrained(bert_dir)
+        else:
+            self.bert = BertModel(**config.__dict__)
         self.dropout = nn.Dropout(config.hidden_dropout_prob)
         self.classifier = nn.Linear(config.hidden_size, num_labels)
 
@@ -252,9 +296,14 @@ class BertForQuestionAnswering(BaseModel):
     start_logits, end_logits = model(input_ids, token_type_ids, input_mask)
     ```
     """
-    def __init__(self, config, bert_dir):
+    def __init__(self, config=None, bert_dir=None):
         super(BertForQuestionAnswering, self).__init__()
-        self.bert = BertModel.from_pretrained(bert_dir)
+        if config is None:  # qa_outputs below reads config on both paths
+            config = BertConfig()
+        if bert_dir is not None:
+            self.bert = BertModel.from_pretrained(bert_dir)
+        else:
+            self.bert = BertModel(**config.__dict__)
         # TODO check with Google if it's normal there is no dropout on the token classifier of SQuAD in the TF version
         # self.dropout = nn.Dropout(config.hidden_dropout_prob)
         self.qa_outputs = nn.Linear(config.hidden_size, 2)
diff --git a/test/models/test_bert.py b/test/models/test_bert.py
index b2899a89..7177f31b 100644
--- a/test/models/test_bert.py
+++ b/test/models/test_bert.py
@@ -2,20 +2,64 @@ import unittest
 
 import torch
 
-from fastNLP.models.bert import BertModel
+from fastNLP.models.bert import *
 
 
 class TestBert(unittest.TestCase):
     def test_bert_1(self):
-        # model = BertModel.from_pretrained("/home/zyfeng/data/bert-base-chinese")
-        model = BertModel(vocab_size=32000, hidden_size=768,
-                          num_hidden_layers=12, num_attention_heads=12, intermediate_size=3072)
+        from fastNLP.core.const import Const
+
+        model = BertForSequenceClassification(2)
+
+        input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]])
+        input_mask = torch.LongTensor([[1, 1, 1], [1, 1, 0]])
+        token_type_ids = torch.LongTensor([[0, 0, 1], [0, 1, 0]])
+
+        pred = model(input_ids, token_type_ids, input_mask)
+        self.assertTrue(isinstance(pred, dict))
+        self.assertTrue(Const.OUTPUT in pred)
+        self.assertEqual(tuple(pred[Const.OUTPUT].shape), (2, 2))
+
+    def test_bert_2(self):
+        from fastNLP.core.const import Const
+
+        model = BertForMultipleChoice(2)
+
+        input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]])
+        input_mask = torch.LongTensor([[1, 1, 1], [1, 1, 0]])
+        token_type_ids = torch.LongTensor([[0, 0, 1], [0, 1, 0]])
+
+        pred = model(input_ids, token_type_ids, input_mask)
+        self.assertTrue(isinstance(pred, dict))
+        self.assertTrue(Const.OUTPUT in pred)
+        self.assertEqual(tuple(pred[Const.OUTPUT].shape), (1, 2))
+
+    def test_bert_3(self):
+        from fastNLP.core.const import Const
+
+        model = BertForTokenClassification(7)
+
+        input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]])
+        input_mask = torch.LongTensor([[1, 1, 1], [1, 1, 0]])
+        token_type_ids = torch.LongTensor([[0, 0, 1], [0, 1, 0]])
+
+        pred = model(input_ids, token_type_ids, input_mask)
+        self.assertTrue(isinstance(pred, dict))
+        self.assertTrue(Const.OUTPUT in pred)
+        self.assertEqual(tuple(pred[Const.OUTPUT].shape), (2, 3, 7))
+
+    def test_bert_4(self):
+        from fastNLP.core.const import Const
+
+        model = BertForQuestionAnswering()
 
         input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]])
         input_mask = torch.LongTensor([[1, 1, 1], [1, 1, 0]])
         token_type_ids = torch.LongTensor([[0, 0, 1], [0, 1, 0]])
 
-        all_encoder_layers, pooled_output = model(input_ids, token_type_ids, input_mask)
-        for layer in all_encoder_layers:
-            self.assertEqual(tuple(layer.shape), (2, 3, 768))
-        self.assertEqual(tuple(pooled_output.shape), (2, 768))
+        pred = model(input_ids, token_type_ids, input_mask)
+        self.assertTrue(isinstance(pred, dict))
+        self.assertTrue(Const.OUTPUTS(0) in pred)
+        self.assertTrue(Const.OUTPUTS(1) in pred)
+        self.assertEqual(tuple(pred[Const.OUTPUTS(0)].shape), (2, 3))
+        self.assertEqual(tuple(pred[Const.OUTPUTS(1)].shape), (2, 3))
diff --git a/test/modules/encoder/test_bert.py b/test/modules/encoder/test_bert.py
new file mode 100644
index 00000000..78bcf633
--- /dev/null
+++ b/test/modules/encoder/test_bert.py
@@ -0,0 +1,21 @@
+
+import unittest
+
+import torch
+
+from fastNLP.models.bert import BertModel
+
+
+class TestBert(unittest.TestCase):
+    def test_bert_1(self):
+        model = BertModel(vocab_size=32000, hidden_size=768,
+                          num_hidden_layers=12, num_attention_heads=12, intermediate_size=3072)
+
+        input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]])
+        input_mask = torch.LongTensor([[1, 1, 1], [1, 1, 0]])
+        token_type_ids = torch.LongTensor([[0, 0, 1], [0, 1, 0]])
+
+        all_encoder_layers, pooled_output = model(input_ids, token_type_ids, input_mask)
+        for layer in all_encoder_layers:
+            self.assertEqual(tuple(layer.shape), (2, 3, 768))
+        self.assertEqual(tuple(pooled_output.shape), (2, 768))
\ No newline at end of file