diff --git a/fastNLP/__init__.py b/fastNLP/__init__.py
index 5f18561a..4a963a46 100644
--- a/fastNLP/__init__.py
+++ b/fastNLP/__init__.py
@@ -69,6 +69,7 @@ __all__ = [
 
     "LossFunc",
     "CrossEntropyLoss",
+    "MSELoss",
     "L1Loss",
     "BCELoss",
     "NLLLoss",
@@ -81,7 +82,7 @@ __all__ = [
     'logger',
     "init_logger_dist",
 ]
-__version__ = '0.5.0'
+__version__ = '0.5.5'
 
 import sys
 
diff --git a/fastNLP/core/__init__.py b/fastNLP/core/__init__.py
index f4e42ab3..6eb3e424 100644
--- a/fastNLP/core/__init__.py
+++ b/fastNLP/core/__init__.py
@@ -65,6 +65,7 @@ __all__ = [
     "NLLLoss",
     "LossInForward",
     "CMRC2018Loss",
+    "MSELoss",
 
     "LossBase",
     "MetricBase",
@@ -94,7 +95,8 @@ from .const import Const
 from .dataset import DataSet
 from .field import FieldArray, Padder, AutoPadder, EngChar2DPadder
 from .instance import Instance
-from .losses import LossFunc, CrossEntropyLoss, L1Loss, BCELoss, NLLLoss, LossInForward, CMRC2018Loss, LossBase
+from .losses import LossFunc, CrossEntropyLoss, L1Loss, BCELoss, NLLLoss, \
+    LossInForward, CMRC2018Loss, LossBase, MSELoss
 from .metrics import AccuracyMetric, SpanFPreRecMetric, CMRC2018Metric, ClassifyFPreRecMetric, MetricBase,\
     ConfusionMatrixMetric
 from .optimizer import Optimizer, SGD, Adam, AdamW
diff --git a/fastNLP/core/losses.py b/fastNLP/core/losses.py
index 6788e6da..574738bb 100644
--- a/fastNLP/core/losses.py
+++ b/fastNLP/core/losses.py
@@ -12,6 +12,7 @@ __all__ = [
     "BCELoss",
     "L1Loss",
     "NLLLoss",
+    "MSELoss",
 
     "CMRC2018Loss"
 
@@ -265,6 +266,26 @@ class L1Loss(LossBase):
         return F.l1_loss(input=pred, target=target, reduction=self.reduction)
 
 
+class MSELoss(LossBase):
+    r"""
+    Mean squared error (MSE) loss.
+
+    :param pred: mapping for `pred` in the parameter map; None means the mapping is `pred` -> `pred`
+    :param target: mapping for `target` in the parameter map; None means the mapping is `target` -> `target`
+    :param str reduction: one of 'mean', 'sum' or 'none'.
+
+    """
+
+    def __init__(self, pred=None, target=None, reduction='mean'):
+        super(MSELoss, self).__init__()
+        self._init_param_map(pred=pred, target=target)
+        assert reduction in ('mean', 'sum', 'none')
+        self.reduction = reduction
+
+    def get_loss(self, pred, target):
+        return F.mse_loss(input=pred, target=target, reduction=self.reduction)
+
+
 class BCELoss(LossBase):
     r"""
     Binary cross-entropy loss function
diff --git a/fastNLP/models/bert.py b/fastNLP/models/bert.py
index 976b4638..827717d0 100644
--- a/fastNLP/models/bert.py
+++ b/fastNLP/models/bert.py
@@ -76,6 +76,8 @@ class BertForSequenceClassification(BaseModel):
         hidden = self.dropout(self.bert(words))
         cls_hidden = hidden[:, 0]
         logits = self.classifier(cls_hidden)
+        if logits.size(-1) == 1:
+            logits = logits.squeeze(-1)
 
         return {Const.OUTPUT: logits}
 
@@ -85,7 +87,10 @@ class BertForSequenceClassification(BaseModel):
         :return: { :attr:`fastNLP.Const.OUTPUT` : logits}: torch.LongTensor [batch_size]
         """
         logits = self.forward(words)[Const.OUTPUT]
-        return {Const.OUTPUT: torch.argmax(logits, dim=-1)}
+        if self.num_labels > 1:
+            return {Const.OUTPUT: torch.argmax(logits, dim=-1)}
+        else:
+            return {Const.OUTPUT: logits}
 
 
 class BertForSentenceMatching(BaseModel):
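
Usage note (not part of the patch): a minimal sketch exercising the new MSELoss, assuming this patch is installed so the class is importable from the top-level fastNLP package. The 'pred'/'target' key names shown are just the illustrative defaults of the parameter map.

    import torch
    from fastNLP import MSELoss  # exported by this patch

    # Default mapping: model output key 'pred' -> loss input 'pred',
    # dataset field 'target' -> loss input 'target'.
    loss_fn = MSELoss(pred='pred', target='target', reduction='mean')

    # With num_labels == 1, BertForSequenceClassification.forward() now
    # squeezes the last dimension, so predictions are [batch_size] floats
    # that line up with same-shaped regression targets.
    pred = torch.randn(4)
    target = torch.randn(4)
    print(loss_fn.get_loss(pred, target))  # scalar tensor: mean squared error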