diff --git a/fastNLP/models/bert.py b/fastNLP/models/bert.py index adecab60..ad7750ec 100644 --- a/fastNLP/models/bert.py +++ b/fastNLP/models/bert.py @@ -2,13 +2,14 @@ bert.py is modified from huggingface/pytorch-pretrained-BERT, which is licensed under the Apache License 2.0. """ +import os import torch from torch import nn from .base_model import BaseModel from ..core.const import Const from ..modules.encoder import BertModel -from ..modules.encoder.bert import BertConfig +from ..modules.encoder.bert import BertConfig, CONFIG_FILE class BertForSequenceClassification(BaseModel): @@ -54,6 +55,7 @@ class BertForSequenceClassification(BaseModel): self.num_labels = num_labels if bert_dir is not None: self.bert = BertModel.from_pretrained(bert_dir) + config = BertConfig(os.path.join(bert_dir, CONFIG_FILE)) else: if config is None: config = BertConfig(30522) @@ -67,20 +69,20 @@ class BertForSequenceClassification(BaseModel): model = cls(num_labels=num_labels, config=config, bert_dir=pretrained_model_dir) return model - def forward(self, input_ids, token_type_ids=None, attention_mask=None, labels=None): - _, pooled_output = self.bert(input_ids, token_type_ids, attention_mask, output_all_encoded_layers=False) + def forward(self, words, seq_len=None, target=None): + _, pooled_output = self.bert(words, attention_mask=seq_len, output_all_encoded_layers=False) pooled_output = self.dropout(pooled_output) logits = self.classifier(pooled_output) - if labels is not None: + if target is not None: loss_fct = nn.CrossEntropyLoss() - loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1)) + loss = loss_fct(logits, target) return {Const.OUTPUT: logits, Const.LOSS: loss} else: return {Const.OUTPUT: logits} - def predict(self, input_ids, token_type_ids=None, attention_mask=None): - logits = self.forward(input_ids, token_type_ids, attention_mask) + def predict(self, words, seq_len=None): + logits = self.forward(words, seq_len=seq_len)[Const.OUTPUT] return {Const.OUTPUT: torch.argmax(logits, dim=-1)} @@ -140,7 +142,8 @@ class BertForMultipleChoice(BaseModel): model = cls(num_choices=num_choices, config=config, bert_dir=pretrained_model_dir) return model - def forward(self, input_ids, token_type_ids=None, attention_mask=None, labels=None): + def forward(self, words, seq_len1=None, seq_len2=None, target=None): + input_ids, token_type_ids, attention_mask = words, seq_len1, seq_len2 flat_input_ids = input_ids.view(-1, input_ids.size(-1)) flat_token_type_ids = token_type_ids.view(-1, token_type_ids.size(-1)) flat_attention_mask = attention_mask.view(-1, attention_mask.size(-1)) @@ -149,15 +152,15 @@ class BertForMultipleChoice(BaseModel): logits = self.classifier(pooled_output) reshaped_logits = logits.view(-1, self.num_choices) - if labels is not None: + if target is not None: loss_fct = nn.CrossEntropyLoss() - loss = loss_fct(reshaped_logits, labels) + loss = loss_fct(reshaped_logits, target) return {Const.OUTPUT: reshaped_logits, Const.LOSS: loss} else: return {Const.OUTPUT: reshaped_logits} - def predict(self, input_ids, token_type_ids=None, attention_mask=None): - logits = self.forward(input_ids, token_type_ids, attention_mask)[Const.OUTPUT] + def predict(self, words, seq_len1=None, seq_len2=None,): + logits = self.forward(words, seq_len1=seq_len1, seq_len2=seq_len2)[Const.OUTPUT] return {Const.OUTPUT: torch.argmax(logits, dim=-1)} @@ -219,27 +222,27 @@ class BertForTokenClassification(BaseModel): model = cls(num_labels=num_labels, config=config, bert_dir=pretrained_model_dir) return model - def forward(self, input_ids, token_type_ids=None, attention_mask=None, labels=None): - sequence_output, _ = self.bert(input_ids, token_type_ids, attention_mask, output_all_encoded_layers=False) + def forward(self, words, seq_len1=None, seq_len2=None, target=None): + sequence_output, _ = self.bert(words, seq_len1, seq_len2, output_all_encoded_layers=False) sequence_output = self.dropout(sequence_output) logits = self.classifier(sequence_output) - if labels is not None: + if target is not None: loss_fct = nn.CrossEntropyLoss() # Only keep active parts of the loss - if attention_mask is not None: - active_loss = attention_mask.view(-1) == 1 + if seq_len2 is not None: + active_loss = seq_len2.view(-1) == 1 active_logits = logits.view(-1, self.num_labels)[active_loss] - active_labels = labels.view(-1)[active_loss] + active_labels = target.view(-1)[active_loss] loss = loss_fct(active_logits, active_labels) else: - loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1)) + loss = loss_fct(logits.view(-1, self.num_labels), target.view(-1)) return {Const.OUTPUT: logits, Const.LOSS: loss} else: return {Const.OUTPUT: logits} - def predict(self, input_ids, token_type_ids=None, attention_mask=None): - logits = self.forward(input_ids, token_type_ids, attention_mask)[Const.OUTPUT] + def predict(self, words, seq_len1=None, seq_len2=None): + logits = self.forward(words, seq_len1, seq_len2)[Const.OUTPUT] return {Const.OUTPUT: torch.argmax(logits, dim=-1)} @@ -304,34 +307,34 @@ class BertForQuestionAnswering(BaseModel): model = cls(config=config, bert_dir=pretrained_model_dir) return model - def forward(self, input_ids, token_type_ids=None, attention_mask=None, start_positions=None, end_positions=None): - sequence_output, _ = self.bert(input_ids, token_type_ids, attention_mask, output_all_encoded_layers=False) + def forward(self, words, seq_len1=None, seq_len2=None, target1=None, target2=None): + sequence_output, _ = self.bert(words, seq_len1, seq_len2, output_all_encoded_layers=False) logits = self.qa_outputs(sequence_output) start_logits, end_logits = logits.split(1, dim=-1) start_logits = start_logits.squeeze(-1) end_logits = end_logits.squeeze(-1) - if start_positions is not None and end_positions is not None: + if target1 is not None and target2 is not None: # If we are on multi-GPU, split add a dimension - if len(start_positions.size()) > 1: - start_positions = start_positions.squeeze(-1) - if len(end_positions.size()) > 1: - end_positions = end_positions.squeeze(-1) + if len(target1.size()) > 1: + target1 = target1.squeeze(-1) + if len(target2.size()) > 1: + target2 = target2.squeeze(-1) # sometimes the start/end positions are outside our model inputs, we ignore these terms ignored_index = start_logits.size(1) - start_positions.clamp_(0, ignored_index) - end_positions.clamp_(0, ignored_index) + target1.clamp_(0, ignored_index) + target2.clamp_(0, ignored_index) loss_fct = nn.CrossEntropyLoss(ignore_index=ignored_index) - start_loss = loss_fct(start_logits, start_positions) - end_loss = loss_fct(end_logits, end_positions) + start_loss = loss_fct(start_logits, target1) + end_loss = loss_fct(end_logits, target2) total_loss = (start_loss + end_loss) / 2 return {Const.OUTPUTS(0): start_logits, Const.OUTPUTS(1): end_logits, Const.LOSS: total_loss} else: return {Const.OUTPUTS(0): start_logits, Const.OUTPUTS(1): end_logits} - def predict(self, input_ids, token_type_ids=None, attention_mask=None, **kwargs): - logits = self.forward(input_ids, token_type_ids, attention_mask) + def predict(self, words, seq_len1=None, seq_len2=None): + logits = self.forward(words, seq_len1, seq_len2) start_logits = logits[Const.OUTPUTS(0)] end_logits = logits[Const.OUTPUTS(1)] return {Const.OUTPUTS(0): torch.argmax(start_logits, dim=-1),