
update const files

commit 1d46ece326 (tag: v0.4.10)
xuyige, 6 years ago
1 changed file with 18 additions and 16 deletions:

fastNLP/models/bert.py (+18, -16)

@@ -6,6 +6,7 @@ import torch
 from torch import nn
 
 from .base_model import BaseModel
+from ..core.const import Const
 from ..modules.encoder import BertModel

@@ -62,13 +63,13 @@ class BertForSequenceClassification(BaseModel):
         if labels is not None:
             loss_fct = nn.CrossEntropyLoss()
             loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
-            return {"pred": logits, "loss": loss}
+            return {Const.OUTPUT: logits, Const.LOSS: loss}
         else:
-            return {"pred": logits}
+            return {Const.OUTPUT: logits}
 
     def predict(self, input_ids, token_type_ids=None, attention_mask=None):
         logits = self.forward(input_ids, token_type_ids, attention_mask)
-        return {"pred": torch.argmax(logits, dim=-1)}
+        return {Const.OUTPUT: torch.argmax(logits, dim=-1)}
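A hedged sketch of the return contract this hunk establishes, using stand-in names rather than this commit's classes (note that in the hunk above, predict appears to pass the dict returned by forward straight to torch.argmax rather than indexing [Const.OUTPUT] first, as the later classes do; as written that call would raise a TypeError):

import torch

# Stand-in for the assumed value of Const.OUTPUT.
OUTPUT = 'pred'

def consume(forward_out):
    # Downstream components (losses, metrics, trainers) look tensors up
    # by key, so the model and its callers must agree on the key names.
    logits = forward_out[OUTPUT]
    return torch.argmax(logits, dim=-1)

fake_logits = torch.randn(2, 3)        # batch of 2 examples, 3 classes
print(consume({OUTPUT: fake_logits}))  # tensor of predicted class ids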


@@ -128,13 +129,13 @@ class BertForMultipleChoice(BaseModel):
         if labels is not None:
             loss_fct = nn.CrossEntropyLoss()
             loss = loss_fct(reshaped_logits, labels)
-            return {"pred": reshaped_logits, "loss": loss}
+            return {Const.OUTPUT: reshaped_logits, Const.LOSS: loss}
         else:
-            return {"pred": reshaped_logits}
+            return {Const.OUTPUT: reshaped_logits}
 
     def predict(self, input_ids, token_type_ids=None, attention_mask=None):
-        logits = self.forward(input_ids, token_type_ids, attention_mask)["pred"]
-        return {"pred": torch.argmax(logits, dim=-1)}
+        logits = self.forward(input_ids, token_type_ids, attention_mask)[Const.OUTPUT]
+        return {Const.OUTPUT: torch.argmax(logits, dim=-1)}
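The hunk references reshaped_logits without the line that builds it; a sketch of the reshape it implies, with the set-up reconstructed under that assumption rather than quoted from the file:

import torch

batch, num_choices = 2, 4
# The classifier scores each (example, choice) pair on a flattened batch.
flat_logits = torch.randn(batch * num_choices, 1)
# The reshape `reshaped_logits` implies: one row per example,
# one column per answer choice.
reshaped_logits = flat_logits.view(-1, num_choices)
print(torch.argmax(reshaped_logits, dim=-1))  # chosen answer index per example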


@@ -199,13 +200,13 @@ class BertForTokenClassification(BaseModel):
                 loss = loss_fct(active_logits, active_labels)
             else:
                 loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
-            return {"pred": logits, "loss": loss}
+            return {Const.OUTPUT: logits, Const.LOSS: loss}
         else:
-            return {"pred": logits}
+            return {Const.OUTPUT: logits}
 
     def predict(self, input_ids, token_type_ids=None, attention_mask=None):
-        logits = self.forward(input_ids, token_type_ids, attention_mask)["pred"]
-        return {"pred": torch.argmax(logits, dim=-1)}
+        logits = self.forward(input_ids, token_type_ids, attention_mask)[Const.OUTPUT]
+        return {Const.OUTPUT: torch.argmax(logits, dim=-1)}
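The hunk shows active_logits and active_labels without the lines that define them; a sketch of the standard masked token-classification loss they imply, with the variable set-up reconstructed, not quoted from the file:

import torch
from torch import nn

loss_fct = nn.CrossEntropyLoss()
num_labels = 5
logits = torch.randn(2, 8, num_labels)         # (batch, seq_len, labels)
labels = torch.randint(0, num_labels, (2, 8))
attention_mask = torch.ones(2, 8, dtype=torch.long)

# Score only non-padding positions: flatten, then keep masked-in entries.
active_loss = attention_mask.view(-1) == 1
active_logits = logits.view(-1, num_labels)[active_loss]
active_labels = labels.view(-1)[active_loss]
loss = loss_fct(active_logits, active_labels)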


@@ -280,12 +281,13 @@ class BertForQuestionAnswering(BaseModel):
             start_loss = loss_fct(start_logits, start_positions)
             end_loss = loss_fct(end_logits, end_positions)
             total_loss = (start_loss + end_loss) / 2
-            return {"pred1": start_logits, "pred2": end_logits, "loss": total_loss}
+            return {Const.OUTPUTS(0): start_logits, Const.OUTPUTS(1): end_logits, Const.LOSS: total_loss}
         else:
-            return {"pred1": start_logits, "pred2": end_logits}
+            return {Const.OUTPUTS(0): start_logits, Const.OUTPUTS(1): end_logits}
 
     def predict(self, input_ids, token_type_ids=None, attention_mask=None, **kwargs):
         logits = self.forward(input_ids, token_type_ids, attention_mask)
-        start_logits = logits["pred1"]
-        end_logits = logits["pred2"]
-        return {"pred1": torch.argmax(start_logits, dim=-1), "pred2": torch.argmax(end_logits, dim=-1)}
+        start_logits = logits[Const.OUTPUTS(0)]
+        end_logits = logits[Const.OUTPUTS(1)]
+        return {Const.OUTPUTS(0): torch.argmax(start_logits, dim=-1),
+                Const.OUTPUTS(1): torch.argmax(end_logits, dim=-1)}
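For this two-output model, a sketch of how a caller reads the span predictions under the assumed key scheme; the key strings here are stand-ins for Const.OUTPUTS(0) / Const.OUTPUTS(1):

import torch

PRED0, PRED1 = 'pred0', 'pred1'     # assumed values of OUTPUTS(0) / OUTPUTS(1)

start_logits = torch.randn(2, 16)   # (batch, seq_len) scores for span start
end_logits = torch.randn(2, 16)     # ... and for span end
out = {PRED0: start_logits, PRED1: end_logits}

start = torch.argmax(out[PRED0], dim=-1)  # predicted start index per example
end = torch.argmax(out[PRED1], dim=-1)    # predicted end index per example

Note from the hunk itself: predict accepts **kwargs but never forwards them, so extra batch fields are simply ignored.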
