diff --git a/fastNLP/models/bert.py b/fastNLP/models/bert.py
index 7934b435..960132ad 100644
--- a/fastNLP/models/bert.py
+++ b/fastNLP/models/bert.py
@@ -6,6 +6,7 @@ import torch
 from torch import nn
 
 from .base_model import BaseModel
+from ..core.const import Const
 from ..modules.encoder import BertModel
 
 
@@ -62,13 +63,13 @@ class BertForSequenceClassification(BaseModel):
         if labels is not None:
             loss_fct = nn.CrossEntropyLoss()
             loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
-            return {"pred": logits, "loss": loss}
+            return {Const.OUTPUT: logits, Const.LOSS: loss}
         else:
-            return {"pred": logits}
+            return {Const.OUTPUT: logits}
 
     def predict(self, input_ids, token_type_ids=None, attention_mask=None):
-        logits = self.forward(input_ids, token_type_ids, attention_mask)
-        return {"pred": torch.argmax(logits, dim=-1)}
+        logits = self.forward(input_ids, token_type_ids, attention_mask)[Const.OUTPUT]
+        return {Const.OUTPUT: torch.argmax(logits, dim=-1)}
 
 
 class BertForMultipleChoice(BaseModel):
@@ -128,13 +129,13 @@ class BertForMultipleChoice(BaseModel):
         if labels is not None:
             loss_fct = nn.CrossEntropyLoss()
             loss = loss_fct(reshaped_logits, labels)
-            return {"pred": reshaped_logits, "loss": loss}
+            return {Const.OUTPUT: reshaped_logits, Const.LOSS: loss}
         else:
-            return {"pred": reshaped_logits}
+            return {Const.OUTPUT: reshaped_logits}
 
     def predict(self, input_ids, token_type_ids=None, attention_mask=None):
-        logits = self.forward(input_ids, token_type_ids, attention_mask)["pred"]
-        return {"pred": torch.argmax(logits, dim=-1)}
+        logits = self.forward(input_ids, token_type_ids, attention_mask)[Const.OUTPUT]
+        return {Const.OUTPUT: torch.argmax(logits, dim=-1)}
 
 
 class BertForTokenClassification(BaseModel):
@@ -199,13 +200,13 @@ class BertForTokenClassification(BaseModel):
                 loss = loss_fct(active_logits, active_labels)
             else:
                 loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
-            return {"pred": logits, "loss": loss}
+            return {Const.OUTPUT: logits, Const.LOSS: loss}
         else:
-            return {"pred": logits}
+            return {Const.OUTPUT: logits}
 
     def predict(self, input_ids, token_type_ids=None, attention_mask=None):
-        logits = self.forward(input_ids, token_type_ids, attention_mask)["pred"]
-        return {"pred": torch.argmax(logits, dim=-1)}
+        logits = self.forward(input_ids, token_type_ids, attention_mask)[Const.OUTPUT]
+        return {Const.OUTPUT: torch.argmax(logits, dim=-1)}
 
 
 class BertForQuestionAnswering(BaseModel):
@@ -280,12 +281,13 @@ class BertForQuestionAnswering(BaseModel):
             start_loss = loss_fct(start_logits, start_positions)
             end_loss = loss_fct(end_logits, end_positions)
             total_loss = (start_loss + end_loss) / 2
-            return {"pred1": start_logits, "pred2": end_logits, "loss": total_loss}
+            return {Const.OUTPUTS(0): start_logits, Const.OUTPUTS(1): end_logits, Const.LOSS: total_loss}
         else:
-            return {"pred1": start_logits, "pred2": end_logits}
+            return {Const.OUTPUTS(0): start_logits, Const.OUTPUTS(1): end_logits}
 
     def predict(self, input_ids, token_type_ids=None, attention_mask=None, **kwargs):
         logits = self.forward(input_ids, token_type_ids, attention_mask)
-        start_logits = logits["pred1"]
-        end_logits = logits["pred2"]
-        return {"pred1": torch.argmax(start_logits, dim=-1), "pred2": torch.argmax(end_logits, dim=-1)}
+        start_logits = logits[Const.OUTPUTS(0)]
+        end_logits = logits[Const.OUTPUTS(1)]
+        return {Const.OUTPUTS(0): torch.argmax(start_logits, dim=-1),
+                Const.OUTPUTS(1): torch.argmax(end_logits, dim=-1)}
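
Note: the rename above is behavior-preserving only if the Const members resolve to the exact string keys they replace. A minimal sketch of that assumed contract, inferred from the replaced literals ("pred", "loss", "pred1", "pred2") rather than taken from fastNLP's actual fastNLP/core/const.py:

    # Hypothetical stand-in for fastNLP.core.const.Const, inferred from the
    # keys this patch replaces; the real definitions live in
    # fastNLP/core/const.py and may differ in detail.
    class Const:
        OUTPUT = "pred"  # key formerly hard-coded as "pred"
        LOSS = "loss"    # key formerly hard-coded as "loss"

        @staticmethod
        def OUTPUTS(i):
            # Assumed 1-based suffixing, so OUTPUTS(0) -> "pred1" and
            # OUTPUTS(1) -> "pred2", matching the old start/end-logits keys
            # in BertForQuestionAnswering.
            return Const.OUTPUT + str(int(i) + 1)

    # The refactor leaves every returned dict unchanged only if these hold:
    assert Const.OUTPUT == "pred"
    assert Const.LOSS == "loss"
    assert (Const.OUTPUTS(0), Const.OUTPUTS(1)) == ("pred1", "pred2")

Under that contract, downstream code that looks up predictions and loss by these field names, such as fastNLP's trainer and metrics, sees identical dictionaries before and after the patch; only the hard-coded literals move into one shared namespace.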