@@ -69,6 +69,7 @@ __all__ = [
     "LossFunc",
     "CrossEntropyLoss",
+    "MSELoss",
     "L1Loss",
     "BCELoss",
     "NLLLoss",
@@ -81,7 +82,7 @@ __all__ = [
     'logger',
     "init_logger_dist",
 ]
-__version__ = '0.5.0'
+__version__ = '0.5.5'
 
 import sys
@@ -65,6 +65,7 @@ __all__ = [
     "NLLLoss",
     "LossInForward",
     "CMRC2018Loss",
+    "MSELoss",
 
     "LossBase",
     "MetricBase",
@@ -94,7 +95,8 @@ from .const import Const
 from .dataset import DataSet
 from .field import FieldArray, Padder, AutoPadder, EngChar2DPadder
 from .instance import Instance
-from .losses import LossFunc, CrossEntropyLoss, L1Loss, BCELoss, NLLLoss, LossInForward, CMRC2018Loss, LossBase
+from .losses import LossFunc, CrossEntropyLoss, L1Loss, BCELoss, NLLLoss, \
+    LossInForward, CMRC2018Loss, LossBase, MSELoss
 from .metrics import AccuracyMetric, SpanFPreRecMetric, CMRC2018Metric, ClassifyFPreRecMetric, MetricBase, \
     ConfusionMatrixMetric
 from .optimizer import Optimizer, SGD, Adam, AdamW
@@ -12,6 +12,7 @@ __all__ = [
     "BCELoss",
     "L1Loss",
     "NLLLoss",
+    "MSELoss",
     "CMRC2018Loss"
@@ -265,6 +266,26 @@ class L1Loss(LossBase):
         return F.l1_loss(input=pred, target=target, reduction=self.reduction)
 
 
+class MSELoss(LossBase):
+    r"""
+    MSE loss function.
+
+    :param pred: mapping for `pred` in the parameter map; None means the mapping is `pred` -> `pred`
+    :param target: mapping for `target` in the parameter map; None means the mapping is `target` -> `target`
+    :param str reduction: one of 'mean', 'sum' and 'none'
+    """
+
+    def __init__(self, pred=None, target=None, reduction='mean'):
+        super(MSELoss, self).__init__()
+        self._init_param_map(pred=pred, target=target)
+        assert reduction in ('mean', 'sum', 'none')
+        self.reduction = reduction
+
+    def get_loss(self, pred, target):
+        return F.mse_loss(input=pred, target=target, reduction=self.reduction)
+
+
 class BCELoss(LossBase):
     r"""
     Binary cross-entropy loss function.
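For reference, the new loss can be exercised on its own; a minimal sketch (the field names `pred`/`target` are the defaults, the Trainer wiring and the sample tensors are illustrative only):

    import torch
    from fastNLP import MSELoss

    loss_fn = MSELoss(reduction='mean')
    pred = torch.randn(4)      # e.g. squeezed regression scores, shape [batch_size]
    target = torch.randn(4)    # gold values, same shape
    loss = loss_fn.get_loss(pred, target)  # scalar tensor via F.mse_loss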
@@ -76,6 +76,8 @@ class BertForSequenceClassification(BaseModel):
         hidden = self.dropout(self.bert(words))
         cls_hidden = hidden[:, 0]
         logits = self.classifier(cls_hidden)
+        if logits.size(-1) == 1:
+            logits = logits.squeeze(-1)
 
         return {Const.OUTPUT: logits}
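The squeeze makes single-label (regression) outputs shape-compatible with the new MSELoss: `[batch_size, 1]` becomes `[batch_size]`, matching a float target of shape `[batch_size]`. A standalone shape check (model omitted, batch size chosen arbitrarily):

    import torch

    logits = torch.randn(4, 1)        # classifier output when num_labels == 1
    if logits.size(-1) == 1:
        logits = logits.squeeze(-1)   # [batch_size, 1] -> [batch_size]
    assert logits.shape == (4,)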
@@ -85,7 +87,10 @@ class BertForSequenceClassification(BaseModel):
         :return: { :attr:`fastNLP.Const.OUTPUT` : logits}: torch.LongTensor [batch_size]
         """
         logits = self.forward(words)[Const.OUTPUT]
-        return {Const.OUTPUT: torch.argmax(logits, dim=-1)}
+        if self.num_labels > 1:
+            return {Const.OUTPUT: torch.argmax(logits, dim=-1)}
+        else:
+            return {Const.OUTPUT: logits}
 
 class BertForSentenceMatching(BaseModel):
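With this branch, `predict` only takes the argmax when the head is a classifier; for `num_labels == 1` the raw (already squeezed) scores pass through unchanged. A sketch of the two paths, using a hypothetical helper `predict_from_logits` that mirrors the new logic outside the model:

    import torch

    def predict_from_logits(logits, num_labels):
        # Class indices for classification, raw regression scores otherwise.
        if num_labels > 1:
            return torch.argmax(logits, dim=-1)
        return logits

    predict_from_logits(torch.randn(4, 3), num_labels=3)  # LongTensor, shape [4]
    predict_from_logits(torch.randn(4), num_labels=1)     # FloatTensor, shape [4]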