diff --git a/fastNLP/embeddings/bert_embedding.py b/fastNLP/embeddings/bert_embedding.py
index d1a5514a..f6c36623 100644
--- a/fastNLP/embeddings/bert_embedding.py
+++ b/fastNLP/embeddings/bert_embedding.py
@@ -393,7 +393,7 @@ class _WordBertModel(nn.Module):
         batch_indexes = torch.arange(batch_size).to(words)
         word_pieces[batch_indexes, word_pieces_lengths + 1] = self._sep_index
         if self._has_sep_in_vocab:  # token_type_ids are only needed when [SEP] actually appears in the vocab
-            sep_mask = word_pieces.eq(self._sep_index)  # batch_size x max_len
+            sep_mask = word_pieces.eq(self._sep_index).long()  # batch_size x max_len
             sep_mask_cumsum = sep_mask.flip(dims=[-1]).cumsum(dim=-1).flip(dims=[-1])
             token_type_ids = sep_mask_cumsum.fmod(2)
             if token_type_ids[0, 0].item():  # if the first position is odd, flip the result so the sequence starts with 0
diff --git a/fastNLP/models/bert.py b/fastNLP/models/bert.py
index 0a89b765..08f16db2 100644
--- a/fastNLP/models/bert.py
+++ b/fastNLP/models/bert.py
@@ -5,253 +5,145 @@ bert.py is modified from huggingface/pytorch-pretrained-BERT, which is licensed
 __all__ = []
 
-import os
+import warnings
 
 import torch
 from torch import nn
 
 from .base_model import BaseModel
 from ..core.const import Const
-from ..core.utils import seq_len_to_mask
+from ..core._logger import logger
 from ..modules.encoder import BertModel
 from ..modules.encoder.bert import BertConfig, CONFIG_FILE
+from ..embeddings.bert_embedding import BertEmbedding
 
 
 class BertForSequenceClassification(BaseModel):
     """BERT model for classification.
-    This module is composed of the BERT model with a linear layer on top of
-    the pooled output.
-    Params:
-        `config`: a BertConfig class instance with the configuration to build a new model.
-        `num_labels`: the number of classes for the classifier. Default = 2.
-    Inputs:
-        `input_ids`: a torch.LongTensor of shape [batch_size, sequence_length]
-            with the word token indices in the vocabulary. Items in the batch should begin with the special
-            "CLS" token. (see the tokens preprocessing logic in the scripts `extract_features.py`,
-            `run_classifier.py` and `run_squad.py`)
-        `token_type_ids`: an optional torch.LongTensor of shape [batch_size, sequence_length] with the token
-            types indices selected in [0, 1]. Type 0 corresponds to a `sentence A` and type 1 corresponds to
-            a `sentence B` token (see BERT paper for more details).
-        `attention_mask`: an optional torch.LongTensor of shape [batch_size, sequence_length] with indices
-            selected in [0, 1]. It's a mask to be used if the input sequence length is smaller than the max
-            input sequence length in the current batch. It's the mask that we typically use for attention when
-            a batch has varying length sentences.
-        `labels`: labels for the classification output: torch.LongTensor of shape [batch_size]
-            with indices selected in [0, ..., num_labels].
-    Outputs:
-        if `labels` is not `None`:
-            Outputs the CrossEntropy classification loss of the output with the labels.
-        if `labels` is `None`:
-            Outputs the classification logits of shape [batch_size, num_labels].
-    Example usage:
-    ```python
-    # Already been converted into WordPiece token ids
-    input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]])
-    input_mask = torch.LongTensor([[1, 1, 1], [1, 1, 0]])
-    token_type_ids = torch.LongTensor([[0, 0, 1], [0, 1, 0]])
-    config = BertConfig(vocab_size_or_config_json_file=32000, hidden_size=768,
-        num_hidden_layers=12, num_attention_heads=12, intermediate_size=3072)
-    num_labels = 2
-    model = BertForSequenceClassification(num_labels, config)
-    logits = model(input_ids, token_type_ids, input_mask)
-    ```
     """
-    def __init__(self, num_labels, config=None, bert_dir=None):
+    def __init__(self, init_embed: BertEmbedding, num_labels: int=2):
         super(BertForSequenceClassification, self).__init__()
+
         self.num_labels = num_labels
-        if bert_dir is not None:
-            self.bert = BertModel.from_pretrained(bert_dir)
-            config = BertConfig(os.path.join(bert_dir, CONFIG_FILE))
-        else:
-            if config is None:
-                config = BertConfig(30522)
-            self.bert = BertModel(config)
-        self.dropout = nn.Dropout(config.hidden_dropout_prob)
-        self.classifier = nn.Linear(config.hidden_size, num_labels)
-
-    @classmethod
-    def from_pretrained(cls, num_labels, pretrained_model_dir):
-        config = BertConfig(pretrained_model_dir)
-        model = cls(num_labels=num_labels, config=config, bert_dir=pretrained_model_dir)
-        return model
-
-    def forward(self, words, seq_len=None, target=None):
-        if seq_len is None:
-            seq_len = torch.ones_like(words, dtype=words.dtype, device=words.device)
-        if len(seq_len.size()) + 1 == len(words.size()):
-            seq_len = seq_len_to_mask(seq_len, max_len=words.size(-1))
-        _, pooled_output = self.bert(words, attention_mask=seq_len, output_all_encoded_layers=False)
-        pooled_output = self.dropout(pooled_output)
-        logits = self.classifier(pooled_output)
+        self.bert = init_embed
+        self.dropout = nn.Dropout(0.1)
+        self.classifier = nn.Linear(self.bert.embedding_dim, num_labels)
+
+        if not self.bert.model.include_cls_sep:
+            warn_msg = "Bert for sequence classification expects BertEmbedding `include_cls_sep` True, but got False."
+            logger.warn(warn_msg)
+            warnings.warn(warn_msg)
+
+    def forward(self, words):
+        hidden = self.dropout(self.bert(words))
+        cls_hidden = hidden[:, 0]
+        logits = self.classifier(cls_hidden)
+
+        return {Const.OUTPUT: logits}
+
+    def predict(self, words):
+        logits = self.forward(words)[Const.OUTPUT]
+        return {Const.OUTPUT: torch.argmax(logits, dim=-1)}
+
+
+class BertForSentenceMatching(BaseModel):
+
+    """BERT model for matching.
+    """
+    def __init__(self, init_embed: BertEmbedding, num_labels: int=2):
+        super(BertForSentenceMatching, self).__init__()
+        self.num_labels = num_labels
+        self.bert = init_embed
+        self.dropout = nn.Dropout(0.1)
+        self.classifier = nn.Linear(self.bert.embedding_dim, num_labels)
+
+        if not self.bert.model.include_cls_sep:
+            error_msg = "Bert for sentence matching expects BertEmbedding `include_cls_sep` True, but got False."
+            logger.error(error_msg)
+            raise RuntimeError(error_msg)
 
-        if target is not None:
-            loss_fct = nn.CrossEntropyLoss()
-            loss = loss_fct(logits, target)
-            return {Const.OUTPUT: logits, Const.LOSS: loss}
-        else:
-            return {Const.OUTPUT: logits}
+    def forward(self, words):
+        hidden = self.dropout(self.bert(words))
+        cls_hidden = hidden[:, 0]
+        logits = self.classifier(cls_hidden)
 
-    def predict(self, words, seq_len=None):
-        logits = self.forward(words, seq_len=seq_len)[Const.OUTPUT]
+        return {Const.OUTPUT: logits}
+
+    def predict(self, words):
+        logits = self.forward(words)[Const.OUTPUT]
         return {Const.OUTPUT: torch.argmax(logits, dim=-1)}
 
 
 class BertForMultipleChoice(BaseModel):
     """BERT model for multiple choice tasks.
-    This module is composed of the BERT model with a linear layer on top of
-    the pooled output.
-    Params:
-        `config`: a BertConfig class instance with the configuration to build a new model.
-        `num_choices`: the number of classes for the classifier. Default = 2.
-    Inputs:
-        `input_ids`: a torch.LongTensor of shape [batch_size, num_choices, sequence_length]
-            with the word token indices in the vocabulary(see the tokens preprocessing logic in the scripts
-            `extract_features.py`, `run_classifier.py` and `run_squad.py`)
-        `token_type_ids`: an optional torch.LongTensor of shape [batch_size, num_choices, sequence_length]
-            with the token types indices selected in [0, 1]. Type 0 corresponds to a `sentence A`
-            and type 1 corresponds to a `sentence B` token (see BERT paper for more details).
-        `attention_mask`: an optional torch.LongTensor of shape [batch_size, num_choices, sequence_length] with indices
-            selected in [0, 1]. It's a mask to be used if the input sequence length is smaller than the max
-            input sequence length in the current batch. It's the mask that we typically use for attention when
-            a batch has varying length sentences.
-        `labels`: labels for the classification output: torch.LongTensor of shape [batch_size]
-            with indices selected in [0, ..., num_choices].
-    Outputs:
-        if `labels` is not `None`:
-            Outputs the CrossEntropy classification loss of the output with the labels.
-        if `labels` is `None`:
-            Outputs the classification logits of shape [batch_size, num_labels].
-    Example usage:
-    ```python
-    # Already been converted into WordPiece token ids
-    input_ids = torch.LongTensor([[[31, 51, 99], [15, 5, 0]], [[12, 16, 42], [14, 28, 57]]])
-    input_mask = torch.LongTensor([[[1, 1, 1], [1, 1, 0]],[[1,1,0], [1, 0, 0]]])
-    token_type_ids = torch.LongTensor([[[0, 0, 1], [0, 1, 0]],[[0, 1, 1], [0, 0, 1]]])
-    config = BertConfig(vocab_size_or_config_json_file=32000, hidden_size=768,
-        num_hidden_layers=12, num_attention_heads=12, intermediate_size=3072)
-    num_choices = 2
-    model = BertForMultipleChoice(num_choices, config, bert_dir)
-    logits = model(input_ids, token_type_ids, input_mask)
-    ```
     """
-    def __init__(self, num_choices, config=None, bert_dir=None):
+    def __init__(self, init_embed: BertEmbedding, num_choices=2):
         super(BertForMultipleChoice, self).__init__()
+
         self.num_choices = num_choices
-        if bert_dir is not None:
-            self.bert = BertModel.from_pretrained(bert_dir)
-        else:
-            if config is None:
-                config = BertConfig(30522)
-            self.bert = BertModel(config)
-        self.dropout = nn.Dropout(config.hidden_dropout_prob)
-        self.classifier = nn.Linear(config.hidden_size, 1)
-
-    @classmethod
-    def from_pretrained(cls, num_choices, pretrained_model_dir):
-        config = BertConfig(pretrained_model_dir)
-        model = cls(num_choices=num_choices, config=config, bert_dir=pretrained_model_dir)
-        return model
-
-    def forward(self, words, seq_len1=None, seq_len2=None, target=None):
-        input_ids, token_type_ids, attention_mask = words, seq_len1, seq_len2
-        flat_input_ids = input_ids.view(-1, input_ids.size(-1))
-        flat_token_type_ids = token_type_ids.view(-1, token_type_ids.size(-1))
-        flat_attention_mask = attention_mask.view(-1, attention_mask.size(-1))
-        _, pooled_output = self.bert(flat_input_ids, flat_token_type_ids, flat_attention_mask, output_all_encoded_layers=False)
+        self.bert = init_embed
+        self.dropout = nn.Dropout(0.1)
+        self.classifier = nn.Linear(self.bert.embedding_dim, 1)
+        self.include_cls_sep = init_embed.model.include_cls_sep
+
+        if not self.bert.model.include_cls_sep:
+            error_msg = "Bert for multiple choice expects BertEmbedding `include_cls_sep` True, but got False."
+            logger.error(error_msg)
+            raise RuntimeError(error_msg)
+
+    def forward(self, words):
+        """
+        :param torch.Tensor words: [batch_size, num_choices, seq_len]
+        :return: [batch_size, num_labels]
+        """
+        batch_size, num_choices, seq_len = words.size()
+
+        input_ids = words.view(batch_size * num_choices, seq_len)
+        hidden = self.bert(input_ids)
+        pooled_output = hidden[:, 0]
         pooled_output = self.dropout(pooled_output)
         logits = self.classifier(pooled_output)
         reshaped_logits = logits.view(-1, self.num_choices)
 
-        if target is not None:
-            loss_fct = nn.CrossEntropyLoss()
-            loss = loss_fct(reshaped_logits, target)
-            return {Const.OUTPUT: reshaped_logits, Const.LOSS: loss}
-        else:
-            return {Const.OUTPUT: reshaped_logits}
+        return {Const.OUTPUT: reshaped_logits}
 
-    def predict(self, words, seq_len1=None, seq_len2=None,):
-        logits = self.forward(words, seq_len1=seq_len1, seq_len2=seq_len2)[Const.OUTPUT]
+    def predict(self, words):
+        logits = self.forward(words)[Const.OUTPUT]
         return {Const.OUTPUT: torch.argmax(logits, dim=-1)}
 
 
 class BertForTokenClassification(BaseModel):
     """BERT model for token-level classification.
-    This module is composed of the BERT model with a linear layer on top of
-    the full hidden state of the last layer.
-    Params:
-        `config`: a BertConfig class instance with the configuration to build a new model.
-        `num_labels`: the number of classes for the classifier. Default = 2.
-        `bert_dir`: a dir which contains the bert parameters within file `pytorch_model.bin`
-    Inputs:
-        `input_ids`: a torch.LongTensor of shape [batch_size, sequence_length]
-            with the word token indices in the vocabulary(see the tokens preprocessing logic in the scripts
-            `extract_features.py`, `run_classifier.py` and `run_squad.py`)
-        `token_type_ids`: an optional torch.LongTensor of shape [batch_size, sequence_length] with the token
-            types indices selected in [0, 1]. Type 0 corresponds to a `sentence A` and type 1 corresponds to
-            a `sentence B` token (see BERT paper for more details).
-        `attention_mask`: an optional torch.LongTensor of shape [batch_size, sequence_length] with indices
-            selected in [0, 1]. It's a mask to be used if the input sequence length is smaller than the max
-            input sequence length in the current batch. It's the mask that we typically use for attention when
-            a batch has varying length sentences.
-        `labels`: labels for the classification output: torch.LongTensor of shape [batch_size, sequence_length]
-            with indices selected in [0, ..., num_labels].
-    Outputs:
-        if `labels` is not `None`:
-            Outputs the CrossEntropy classification loss of the output with the labels.
-        if `labels` is `None`:
-            Outputs the classification logits of shape [batch_size, sequence_length, num_labels].
-    Example usage:
-    ```python
-    # Already been converted into WordPiece token ids
-    input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]])
-    input_mask = torch.LongTensor([[1, 1, 1], [1, 1, 0]])
-    token_type_ids = torch.LongTensor([[0, 0, 1], [0, 1, 0]])
-    config = BertConfig(vocab_size_or_config_json_file=32000, hidden_size=768,
-        num_hidden_layers=12, num_attention_heads=12, intermediate_size=3072)
-    num_labels = 2
-    bert_dir = 'your-bert-file-dir'
-    model = BertForTokenClassification(num_labels, config, bert_dir)
-    logits = model(input_ids, token_type_ids, input_mask)
-    ```
     """
-    def __init__(self, num_labels, config=None, bert_dir=None):
+    def __init__(self, init_embed: BertEmbedding, num_labels):
         super(BertForTokenClassification, self).__init__()
+
         self.num_labels = num_labels
-        if bert_dir is not None:
-            self.bert = BertModel.from_pretrained(bert_dir)
-        else:
-            if config is None:
-                config = BertConfig(30522)
-            self.bert = BertModel(config)
-        self.dropout = nn.Dropout(config.hidden_dropout_prob)
-        self.classifier = nn.Linear(config.hidden_size, num_labels)
-
-    @classmethod
-    def from_pretrained(cls, num_labels, pretrained_model_dir):
-        config = BertConfig(pretrained_model_dir)
-        model = cls(num_labels=num_labels, config=config, bert_dir=pretrained_model_dir)
-        return model
-
-    def forward(self, words, seq_len1=None, seq_len2=None, target=None):
-        sequence_output, _ = self.bert(words, seq_len1, seq_len2, output_all_encoded_layers=False)
+        self.bert = init_embed
+        self.dropout = nn.Dropout(0.1)
+        self.classifier = nn.Linear(self.bert.embedding_dim, num_labels)
+        self.include_cls_sep = init_embed.model.include_cls_sep
+
+        if self.include_cls_sep:
+            warn_msg = "Bert for token classification expects BertEmbedding `include_cls_sep` False, but got True."
+            warnings.warn(warn_msg)
+            logger.warn(warn_msg)
+
+    def forward(self, words):
+        """
+        :param torch.Tensor words: [batch_size, seq_len]
+        :return: [batch_size, seq_len, num_labels]
+        """
+        sequence_output = self.bert(words)
+        if self.include_cls_sep:
+            sequence_output = sequence_output[:, 1: -1]  # [batch_size, seq_len, embed_dim]
         sequence_output = self.dropout(sequence_output)
         logits = self.classifier(sequence_output)
 
-        if target is not None:
-            loss_fct = nn.CrossEntropyLoss()
-            # Only keep active parts of the loss
-            if seq_len2 is not None:
-                active_loss = seq_len2.view(-1) == 1
-                active_logits = logits.view(-1, self.num_labels)[active_loss]
-                active_labels = target.view(-1)[active_loss]
-                loss = loss_fct(active_logits, active_labels)
-            else:
-                loss = loss_fct(logits.view(-1, self.num_labels), target.view(-1))
-            return {Const.OUTPUT: logits, Const.LOSS: loss}
-        else:
-            return {Const.OUTPUT: logits}
-
-    def predict(self, words, seq_len1=None, seq_len2=None):
-        logits = self.forward(words, seq_len1, seq_len2)[Const.OUTPUT]
+        return {Const.OUTPUT: logits}
+
+    def predict(self, words):
+        logits = self.forward(words)[Const.OUTPUT]
         return {Const.OUTPUT: torch.argmax(logits, dim=-1)}
@@ -298,53 +190,24 @@ class BertForQuestionAnswering(BaseModel):
         start_logits, end_logits = model(input_ids, token_type_ids, input_mask)
     ```
     """
-    def __init__(self, config=None, bert_dir=None):
+    def __init__(self, init_embed: BertEmbedding, num_labels=2):
         super(BertForQuestionAnswering, self).__init__()
-        if bert_dir is not None:
-            self.bert = BertModel.from_pretrained(bert_dir)
-        else:
-            if config is None:
-                config = BertConfig(30522)
-            self.bert = BertModel(config)
-        # TODO check with Google if it's normal there is no dropout on the token classifier of SQuAD in the TF version
-        # self.dropout = nn.Dropout(config.hidden_dropout_prob)
-        self.qa_outputs = nn.Linear(config.hidden_size, 2)
-
-    @classmethod
-    def from_pretrained(cls, pretrained_model_dir):
-        config = BertConfig(pretrained_model_dir)
-        model = cls(config=config, bert_dir=pretrained_model_dir)
-        return model
-
-    def forward(self, words, seq_len1=None, seq_len2=None, target1=None, target2=None):
-        sequence_output, _ = self.bert(words, seq_len1, seq_len2, output_all_encoded_layers=False)
-        logits = self.qa_outputs(sequence_output)
-        start_logits, end_logits = logits.split(1, dim=-1)
-        start_logits = start_logits.squeeze(-1)
-        end_logits = end_logits.squeeze(-1)
-
-        if target1 is not None and target2 is not None:
-            # If we are on multi-GPU, split add a dimension
-            if len(target1.size()) > 1:
-                target1 = target1.squeeze(-1)
-            if len(target2.size()) > 1:
-                target2 = target2.squeeze(-1)
-            # sometimes the start/end positions are outside our model inputs, we ignore these terms
-            ignored_index = start_logits.size(1)
-            target1.clamp_(0, ignored_index)
-            target2.clamp_(0, ignored_index)
-
-            loss_fct = nn.CrossEntropyLoss(ignore_index=ignored_index)
-            start_loss = loss_fct(start_logits, target1)
-            end_loss = loss_fct(end_logits, target2)
-            total_loss = (start_loss + end_loss) / 2
-            return {Const.OUTPUTS(0): start_logits, Const.OUTPUTS(1): end_logits, Const.LOSS: total_loss}
-        else:
-            return {Const.OUTPUTS(0): start_logits, Const.OUTPUTS(1): end_logits}
-
-    def predict(self, words, seq_len1=None, seq_len2=None):
-        logits = self.forward(words, seq_len1, seq_len2)
-        start_logits = logits[Const.OUTPUTS(0)]
-        end_logits = logits[Const.OUTPUTS(1)]
-        return {Const.OUTPUTS(0): torch.argmax(start_logits, dim=-1),
-                Const.OUTPUTS(1): torch.argmax(end_logits, dim=-1)}
+
+        self.bert = init_embed
+        self.num_labels = num_labels
+        self.qa_outputs = nn.Linear(self.bert.embedding_dim, self.num_labels)
+
+        if not self.bert.model.include_cls_sep:
+            error_msg = "Bert for question answering expects BertEmbedding `include_cls_sep` True, but got False."
+            logger.error(error_msg)
+            raise RuntimeError(error_msg)
+
+    def forward(self, words):
+        sequence_output = self.bert(words)
+        logits = self.qa_outputs(sequence_output)  # [batch_size, seq_len, num_labels]
+
+        return {Const.OUTPUTS(i): logits[:, :, i] for i in range(self.num_labels)}
+
+    def predict(self, words):
+        logits = self.forward(words)
+        return {Const.OUTPUTS(i): torch.argmax(logits[Const.OUTPUTS(i)], dim=-1) for i in range(self.num_labels)}
diff --git a/fastNLP/modules/encoder/bert.py b/fastNLP/modules/encoder/bert.py
index e73a8172..6f6c4291 100644
--- a/fastNLP/modules/encoder/bert.py
+++ b/fastNLP/modules/encoder/bert.py
@@ -435,14 +435,14 @@ class BertModel(nn.Module):
         return encoded_layers, pooled_output
 
     @classmethod
-    def from_pretrained(cls, pretrained_model_dir_or_name, *inputs, **kwargs):
+    def from_pretrained(cls, model_dir_or_name, *inputs, **kwargs):
         state_dict = kwargs.get('state_dict', None)
         kwargs.pop('state_dict', None)
         kwargs.pop('cache_dir', None)
         kwargs.pop('from_tf', None)
 
         # get model dir from name or dir
-        pretrained_model_dir = _get_bert_dir(pretrained_model_dir_or_name)
+        pretrained_model_dir = _get_bert_dir(model_dir_or_name)
 
         # Load config
         config_file = _get_file_name_base_on_postfix(pretrained_model_dir, '.json')
diff --git a/test/__init__.py b/test/__init__.py
deleted file mode 100644
index c7a5f082..00000000
--- a/test/__init__.py
+++ /dev/null
@@ -1,3 +0,0 @@
-import fastNLP
-
-__all__ = ["fastNLP"]
diff --git a/test/data_for_tests/embedding/small_bert/config.json b/test/data_for_tests/embedding/small_bert/config.json
new file mode 100644
index 00000000..3e516872
--- /dev/null
+++ b/test/data_for_tests/embedding/small_bert/config.json
@@ -0,0 +1,13 @@
+{
+  "attention_probs_dropout_prob": 0.1,
+  "hidden_act": "gelu",
+  "hidden_dropout_prob": 0.1,
+  "hidden_size": 16,
+  "initializer_range": 0.02,
+  "intermediate_size": 64,
+  "max_position_embeddings": 32,
+  "num_attention_heads": 4,
+  "num_hidden_layers": 2,
+  "type_vocab_size": 2,
+  "vocab_size": 20
+}
\ No newline at end of file
diff --git a/test/data_for_tests/embedding/small_bert/small_pytorch_model.bin b/test/data_for_tests/embedding/small_bert/small_pytorch_model.bin
new file mode 100644
index 00000000..fe968fb5
Binary files /dev/null and b/test/data_for_tests/embedding/small_bert/small_pytorch_model.bin differ
diff --git a/test/data_for_tests/embedding/small_bert/vocab.txt b/test/data_for_tests/embedding/small_bert/vocab.txt
new file mode 100644
index 00000000..565e67af
--- /dev/null
+++ b/test/data_for_tests/embedding/small_bert/vocab.txt
@@ -0,0 +1,20 @@
+[PAD]
+[UNK]
+[CLS]
+[SEP]
+this
+is
+a
+small
+bert
+model
+vocab
+file
+and
+only
+twenty
+line
+for
+the
+whole
+text
diff --git a/test/data_for_tests/glove.6B.50d_test.txt b/test/data_for_tests/embedding/small_static_embedding/glove.6B.50d_test.txt
similarity index 100%
rename from test/data_for_tests/glove.6B.50d_test.txt
rename to test/data_for_tests/embedding/small_static_embedding/glove.6B.50d_test.txt
diff --git a/test/data_for_tests/word2vec_test.txt b/test/data_for_tests/embedding/small_static_embedding/word2vec_test.txt
similarity index 100%
rename from test/data_for_tests/word2vec_test.txt
rename to test/data_for_tests/embedding/small_static_embedding/word2vec_test.txt
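Reviewer's note, not part of the diff itself: the refactored `BertForQuestionAnswering` no longer returns separate start/end logits plus an optional loss; it returns one `[batch_size, seq_len]` logit tensor per label index under `Const.OUTPUTS(i)`. A minimal sketch of driving it with the small fixture added above; the vocabulary, fixture path, and the default `num_labels=2` mirror the updated tests rather than required values.

```python
import torch
from fastNLP.core import Vocabulary, Const
from fastNLP.embeddings.bert_embedding import BertEmbedding
from fastNLP.models.bert import BertForQuestionAnswering

vocab = Vocabulary().add_word_lst("this is a test [SEP] .".split())
# include_cls_sep=True is required here: the new __init__ raises a RuntimeError otherwise.
embed = BertEmbedding(vocab, model_dir_or_name='test/data_for_tests/embedding/small_bert',
                      include_cls_sep=True)
model = BertForQuestionAnswering(embed)  # num_labels defaults to 2 (start/end)

words = torch.LongTensor([[1, 2, 3], [6, 5, 0]])
pred = model(words)
start_logits = pred[Const.OUTPUTS(0)]  # [batch_size, seq_len + 2], [CLS]/[SEP] positions included
end_logits = pred[Const.OUTPUTS(1)]
positions = model.predict(words)       # argmax position for each output index
```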
diff --git a/test/embeddings/__init__.py b/test/embeddings/__init__.py
deleted file mode 100644
index e69de29b..00000000
diff --git a/test/embeddings/test_bert_embedding.py b/test/embeddings/test_bert_embedding.py
index da81c8c9..46ad74c3 100644
--- a/test/embeddings/test_bert_embedding.py
+++ b/test/embeddings/test_bert_embedding.py
@@ -18,4 +18,13 @@ class TestDownload(unittest.TestCase):
         embed = BertEmbedding(vocab, model_dir_or_name='en', dropout=0.1, word_dropout=0.2)
         for i in range(10):
             words = torch.LongTensor([[2, 3, 4, 0]])
-            print(embed(words).size())
\ No newline at end of file
+            print(embed(words).size())
+
+
+class TestBertEmbedding(unittest.TestCase):
+    def test_bert_embedding_1(self):
+        vocab = Vocabulary().add_word_lst("this is a test .".split())
+        embed = BertEmbedding(vocab, model_dir_or_name='test/data_for_tests/embedding/small_bert')
+        words = torch.LongTensor([[2, 3, 4, 0]])
+        result = embed(words)
+        self.assertEqual(result.size(), (1, 4, 16))
diff --git a/test/embeddings/test_static_embedding.py b/test/embeddings/test_static_embedding.py
index c17daa0a..7d1e8302 100644
--- a/test/embeddings/test_static_embedding.py
+++ b/test/embeddings/test_static_embedding.py
@@ -10,7 +10,8 @@ class TestLoad(unittest.TestCase):
     def test_norm1(self):  # test normalizing only the vectors found in the file
         vocab = Vocabulary().add_word_lst(['the', 'a', 'notinfile'])
-        embed = StaticEmbedding(vocab, model_dir_or_name='test/data_for_tests/glove.6B.50d_test.txt',
+        embed = StaticEmbedding(vocab, model_dir_or_name='test/data_for_tests/embedding/small_static_embedding/'
+                                                         'glove.6B.50d_test.txt',
                                 only_norm_found_vector=True)
         self.assertEqual(round(torch.norm(embed(torch.LongTensor([[2]]))).item(), 4), 1)
         self.assertNotEqual(torch.norm(embed(torch.LongTensor([[4]]))).item(), 1)
@@ -18,7 +19,8 @@ class TestLoad(unittest.TestCase):
     def test_norm2(self):  # test normalizing all vectors
         vocab = Vocabulary().add_word_lst(['the', 'a', 'notinfile'])
-        embed = StaticEmbedding(vocab, model_dir_or_name='test/data_for_tests/glove.6B.50d_test.txt',
+        embed = StaticEmbedding(vocab, model_dir_or_name='test/data_for_tests/embedding/small_static_embedding/'
+                                                         'glove.6B.50d_test.txt',
                                 normalize=True)
         self.assertEqual(round(torch.norm(embed(torch.LongTensor([[2]]))).item(), 4), 1)
         self.assertEqual(round(torch.norm(embed(torch.LongTensor([[4]]))).item(), 4), 1)
diff --git a/test/io/test_embed_loader.py b/test/io/test_embed_loader.py
index bbfe8858..70b367ec 100644
--- a/test/io/test_embed_loader.py
+++ b/test/io/test_embed_loader.py
@@ -8,8 +8,8 @@ class TestEmbedLoader(unittest.TestCase):
     def test_load_with_vocab(self):
         vocab = Vocabulary()
-        glove = "test/data_for_tests/glove.6B.50d_test.txt"
-        word2vec = "test/data_for_tests/word2vec_test.txt"
+        glove = "test/data_for_tests/embedding/small_static_embedding/glove.6B.50d_test.txt"
+        word2vec = "test/data_for_tests/embedding/small_static_embedding/word2vec_test.txt"
         vocab.add_word('the')
         vocab.add_word('none')
         g_m = EmbedLoader.load_with_vocab(glove, vocab)
@@ -20,8 +20,8 @@ class TestEmbedLoader(unittest.TestCase):
     def test_load_without_vocab(self):
         words = ['the', 'of', 'in', 'a', 'to', 'and']
-        glove = "test/data_for_tests/glove.6B.50d_test.txt"
-        word2vec = "test/data_for_tests/word2vec_test.txt"
+        glove = "test/data_for_tests/embedding/small_static_embedding/glove.6B.50d_test.txt"
+        word2vec = "test/data_for_tests/embedding/small_static_embedding/word2vec_test.txt"
         g_m, vocab = EmbedLoader.load_without_vocab(glove)
         self.assertEqual(g_m.shape, (8, 50))
         for word in words:
diff --git a/test/models/__init__.py b/test/models/__init__.py
deleted file mode 100644
index e69de29b..00000000
diff --git a/test/models/test_bert.py b/test/models/test_bert.py
index 40b98c81..2b310edf 100644
--- a/test/models/test_bert.py
+++ b/test/models/test_bert.py
@@ -2,74 +2,94 @@ import unittest
 
 import torch
 
+from fastNLP.core import Vocabulary, Const
 from fastNLP.models.bert import BertForSequenceClassification, BertForQuestionAnswering, \
-    BertForTokenClassification, BertForMultipleChoice
+    BertForTokenClassification, BertForMultipleChoice, BertForSentenceMatching
+from fastNLP.embeddings.bert_embedding import BertEmbedding
 
 
 class TestBert(unittest.TestCase):
     def test_bert_1(self):
-        from fastNLP.core.const import Const
-        from fastNLP.modules.encoder.bert import BertConfig
+        vocab = Vocabulary().add_word_lst("this is a test .".split())
+        embed = BertEmbedding(vocab, model_dir_or_name='test/data_for_tests/embedding/small_bert',
+                              include_cls_sep=True)
 
-        model = BertForSequenceClassification(2, BertConfig(32000))
+        model = BertForSequenceClassification(embed, 2)
 
-        input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]])
-        input_mask = torch.LongTensor([[1, 1, 1], [1, 1, 0]])
+        input_ids = torch.LongTensor([[1, 2, 3], [5, 6, 0]])
 
-        pred = model(input_ids, input_mask)
+        pred = model(input_ids)
         self.assertTrue(isinstance(pred, dict))
         self.assertTrue(Const.OUTPUT in pred)
         self.assertEqual(tuple(pred[Const.OUTPUT].shape), (2, 2))
 
-        input_mask = torch.LongTensor([3, 2])
-        pred = model(input_ids, input_mask)
+        pred = model.predict(input_ids)
         self.assertTrue(isinstance(pred, dict))
         self.assertTrue(Const.OUTPUT in pred)
-        self.assertEqual(tuple(pred[Const.OUTPUT].shape), (2, 2))
+        self.assertEqual(tuple(pred[Const.OUTPUT].shape), (2,))
 
     def test_bert_2(self):
-        from fastNLP.core.const import Const
-        from fastNLP.modules.encoder.bert import BertConfig
 
-        model = BertForMultipleChoice(2, BertConfig(32000))
+        vocab = Vocabulary().add_word_lst("this is a test [SEP] .".split())
+        embed = BertEmbedding(vocab, model_dir_or_name='test/data_for_tests/embedding/small_bert',
+                              include_cls_sep=True)
+
+        model = BertForMultipleChoice(embed, 2)
 
-        input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]])
-        input_mask = torch.LongTensor([[1, 1, 1], [1, 1, 0]])
-        token_type_ids = torch.LongTensor([[0, 0, 1], [0, 1, 0]])
+        input_ids = torch.LongTensor([[[2, 6, 7], [1, 6, 5]]])
+        print(input_ids.size())
 
-        pred = model(input_ids, token_type_ids, input_mask)
+        pred = model(input_ids)
         self.assertTrue(isinstance(pred, dict))
         self.assertTrue(Const.OUTPUT in pred)
         self.assertEqual(tuple(pred[Const.OUTPUT].shape), (1, 2))
 
     def test_bert_3(self):
-        from fastNLP.core.const import Const
-        from fastNLP.modules.encoder.bert import BertConfig
 
-        model = BertForTokenClassification(7, BertConfig(32000))
+        vocab = Vocabulary().add_word_lst("this is a test [SEP] .".split())
+        embed = BertEmbedding(vocab, model_dir_or_name='test/data_for_tests/embedding/small_bert',
+                              include_cls_sep=False)
+        model = BertForTokenClassification(embed, 7)
 
-        input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]])
-        input_mask = torch.LongTensor([[1, 1, 1], [1, 1, 0]])
-        token_type_ids = torch.LongTensor([[0, 0, 1], [0, 1, 0]])
+        input_ids = torch.LongTensor([[1, 2, 3], [6, 5, 0]])
 
-        pred = model(input_ids, token_type_ids, input_mask)
+        pred = model(input_ids)
         self.assertTrue(isinstance(pred, dict))
         self.assertTrue(Const.OUTPUT in pred)
         self.assertEqual(tuple(pred[Const.OUTPUT].shape), (2, 3, 7))
 
     def test_bert_4(self):
-        from fastNLP.core.const import Const
-        from fastNLP.modules.encoder.bert import BertConfig
 
-        model = BertForQuestionAnswering(BertConfig(32000))
+        vocab = Vocabulary().add_word_lst("this is a test [SEP] .".split())
+        embed = BertEmbedding(vocab, model_dir_or_name='test/data_for_tests/embedding/small_bert',
+                              include_cls_sep=True)
+        model = BertForQuestionAnswering(embed)
 
-        input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]])
-        input_mask = torch.LongTensor([[1, 1, 1], [1, 1, 0]])
-        token_type_ids = torch.LongTensor([[0, 0, 1], [0, 1, 0]])
+        input_ids = torch.LongTensor([[1, 2, 3], [6, 5, 0]])
 
-        pred = model(input_ids, token_type_ids, input_mask)
+        pred = model(input_ids)
         self.assertTrue(isinstance(pred, dict))
         self.assertTrue(Const.OUTPUTS(0) in pred)
         self.assertTrue(Const.OUTPUTS(1) in pred)
-        self.assertEqual(tuple(pred[Const.OUTPUTS(0)].shape), (2, 3))
-        self.assertEqual(tuple(pred[Const.OUTPUTS(1)].shape), (2, 3))
+        self.assertEqual(tuple(pred[Const.OUTPUTS(0)].shape), (2, 5))
+        self.assertEqual(tuple(pred[Const.OUTPUTS(1)].shape), (2, 5))
+
+        model = BertForQuestionAnswering(embed, 7)
+        pred = model(input_ids)
+        self.assertTrue(isinstance(pred, dict))
+        self.assertEqual(len(pred), 7)
+
+    def test_bert_5(self):
+
+        vocab = Vocabulary().add_word_lst("this is a test [SEP] .".split())
+        embed = BertEmbedding(vocab, model_dir_or_name='test/data_for_tests/embedding/small_bert',
+                              include_cls_sep=True)
+        model = BertForSentenceMatching(embed)
+
+        input_ids = torch.LongTensor([[1, 2, 3], [6, 5, 0]])
+
+        pred = model(input_ids)
+        self.assertTrue(isinstance(pred, dict))
+        self.assertTrue(Const.OUTPUT in pred)
+        self.assertEqual(tuple(pred[Const.OUTPUT].shape), (2, 2))
+
diff --git a/test/modules/__init__.py b/test/modules/__init__.py
deleted file mode 100644
index e69de29b..00000000
diff --git a/test/modules/decoder/__init__.py b/test/modules/decoder/__init__.py
deleted file mode 100644
index e69de29b..00000000
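For reference, a minimal usage sketch of the refactored `BertForSequenceClassification` (illustrative, not part of this diff): the model is now built from a `BertEmbedding` instead of a `BertConfig`/`bert_dir`, takes only `words`, and returns logits under `Const.OUTPUT`; the old in-model loss branch triggered by `target` is gone, so any training loss is computed outside the model. The vocabulary, fixture path, and `num_labels=2` follow the updated tests rather than any required values.

```python
import torch
from fastNLP.core import Vocabulary, Const
from fastNLP.embeddings.bert_embedding import BertEmbedding
from fastNLP.models.bert import BertForSequenceClassification

# Wrap the small test checkpoint added by this diff in a BertEmbedding;
# include_cls_sep=True keeps the [CLS] position that the classifier reads.
vocab = Vocabulary().add_word_lst("this is a test .".split())
embed = BertEmbedding(vocab, model_dir_or_name='test/data_for_tests/embedding/small_bert',
                      include_cls_sep=True)

model = BertForSequenceClassification(embed, num_labels=2)

words = torch.LongTensor([[2, 3, 4, 0]])       # word indices in `vocab`, same toy input as the tests
logits = model(words)[Const.OUTPUT]            # [batch_size, num_labels]
labels = model.predict(words)[Const.OUTPUT]    # [batch_size]
```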