|
|
@@ -6,6 +6,7 @@ import torch |
|
|
|
from torch import nn |
|
|
|
|
|
|
|
from .base_model import BaseModel |
|
|
|
from ..core.const import Const |
|
|
|
from ..modules.encoder import BertModel |
|
|
|
|
|
|
|
|
|
|
@@ -62,13 +63,13 @@ class BertForSequenceClassification(BaseModel): |
|
|
|
if labels is not None: |
|
|
|
loss_fct = nn.CrossEntropyLoss() |
|
|
|
loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1)) |
|
|
|
return {"pred": logits, "loss": loss} |
|
|
|
return {Const.OUTPUT: logits, Const.LOSS: loss} |
|
|
|
else: |
|
|
|
return {"pred": logits} |
|
|
|
return {Const.OUTPUT: logits} |
|
|
|
|
|
|
|
def predict(self, input_ids, token_type_ids=None, attention_mask=None):
    """Run inference and return the predicted class index per example.

    :param input_ids: token id tensor — presumably (batch, seq_len), TODO confirm
    :param token_type_ids: optional segment id tensor
    :param attention_mask: optional attention mask tensor
    :return: dict mapping ``Const.OUTPUT`` to ``argmax`` over the last dim of the logits
    """
    # forward returns a dict keyed by Const.OUTPUT; index it before argmax
    # (previously the raw dict was passed to torch.argmax, which would fail).
    logits = self.forward(input_ids, token_type_ids, attention_mask)[Const.OUTPUT]
    return {Const.OUTPUT: torch.argmax(logits, dim=-1)}
|
|
|
|
|
|
|
|
|
|
|
class BertForMultipleChoice(BaseModel): |
|
|
@@ -128,13 +129,13 @@ class BertForMultipleChoice(BaseModel): |
|
|
|
if labels is not None: |
|
|
|
loss_fct = nn.CrossEntropyLoss() |
|
|
|
loss = loss_fct(reshaped_logits, labels) |
|
|
|
return {"pred": reshaped_logits, "loss": loss} |
|
|
|
return {Const.OUTPUT: reshaped_logits, Const.LOSS: loss} |
|
|
|
else: |
|
|
|
return {"pred": reshaped_logits} |
|
|
|
return {Const.OUTPUT: reshaped_logits} |
|
|
|
|
|
|
|
def predict(self, input_ids, token_type_ids=None, attention_mask=None):
    """Run inference and return the predicted choice index per example.

    :param input_ids: token id tensor for all choices — shape assumed
        (batch, num_choices, seq_len), TODO confirm against forward
    :param token_type_ids: optional segment id tensor
    :param attention_mask: optional attention mask tensor
    :return: dict mapping ``Const.OUTPUT`` to ``argmax`` over the last dim of the logits
    """
    # forward returns a dict keyed by Const.OUTPUT (the Const key replaces
    # the legacy "pred" string key); extract the logits, then argmax.
    logits = self.forward(input_ids, token_type_ids, attention_mask)[Const.OUTPUT]
    return {Const.OUTPUT: torch.argmax(logits, dim=-1)}
|
|
|
|
|
|
|
|
|
|
|
class BertForTokenClassification(BaseModel): |
|
|
@@ -199,13 +200,13 @@ class BertForTokenClassification(BaseModel): |
|
|
|
loss = loss_fct(active_logits, active_labels) |
|
|
|
else: |
|
|
|
loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1)) |
|
|
|
return {"pred": logits, "loss": loss} |
|
|
|
return {Const.OUTPUT: logits, Const.LOSS: loss} |
|
|
|
else: |
|
|
|
return {"pred": logits} |
|
|
|
return {Const.OUTPUT: logits} |
|
|
|
|
|
|
|
def predict(self, input_ids, token_type_ids=None, attention_mask=None):
    """Run inference and return the predicted tag index for each token.

    :param input_ids: token id tensor — presumably (batch, seq_len), TODO confirm
    :param token_type_ids: optional segment id tensor
    :param attention_mask: optional attention mask tensor
    :return: dict mapping ``Const.OUTPUT`` to ``argmax`` over the last dim of the
        per-token logits
    """
    # forward returns a dict keyed by Const.OUTPUT (replacing the legacy
    # "pred" string key); extract the logits, then argmax per token.
    logits = self.forward(input_ids, token_type_ids, attention_mask)[Const.OUTPUT]
    return {Const.OUTPUT: torch.argmax(logits, dim=-1)}
|
|
|
|
|
|
|
|
|
|
|
class BertForQuestionAnswering(BaseModel): |
|
|
@@ -280,12 +281,13 @@ class BertForQuestionAnswering(BaseModel): |
|
|
|
start_loss = loss_fct(start_logits, start_positions) |
|
|
|
end_loss = loss_fct(end_logits, end_positions) |
|
|
|
total_loss = (start_loss + end_loss) / 2 |
|
|
|
return {"pred1": start_logits, "pred2": end_logits, "loss": total_loss} |
|
|
|
return {Const.OUTPUTS(0): start_logits, Const.OUTPUTS(1): end_logits, Const.LOSS: total_loss} |
|
|
|
else: |
|
|
|
return {"pred1": start_logits, "pred2": end_logits} |
|
|
|
return {Const.OUTPUTS(0): start_logits, Const.OUTPUTS(1): end_logits} |
|
|
|
|
|
|
|
def predict(self, input_ids, token_type_ids=None, attention_mask=None, **kwargs):
    """Predict answer-span start and end positions for question answering.

    :param input_ids: token id tensor — presumably (batch, seq_len), TODO confirm
    :param token_type_ids: optional segment id tensor
    :param attention_mask: optional attention mask tensor
    :param kwargs: ignored; accepted for call-site compatibility
    :return: dict mapping ``Const.OUTPUTS(0)`` / ``Const.OUTPUTS(1)`` to the
        argmax over the last dim of the start / end logits respectively
    """
    # forward returns a dict keyed by Const.OUTPUTS(i) (replacing the legacy
    # "pred1"/"pred2" string keys).
    logits = self.forward(input_ids, token_type_ids, attention_mask)
    start_logits = logits[Const.OUTPUTS(0)]
    end_logits = logits[Const.OUTPUTS(1)]
    return {Const.OUTPUTS(0): torch.argmax(start_logits, dim=-1),
            Const.OUTPUTS(1): torch.argmax(end_logits, dim=-1)}