Browse Source

[add] add MSELoss

tags/v0.6.0
Yige Xu 5 years ago
parent
commit
1f27d007d1
4 changed files with 32 additions and 3 deletions
  1. +2
    -1
      fastNLP/__init__.py
  2. +3
    -1
      fastNLP/core/__init__.py
  3. +21
    -0
      fastNLP/core/losses.py
  4. +6
    -1
      fastNLP/models/bert.py

+ 2
- 1
fastNLP/__init__.py View File

@@ -69,6 +69,7 @@ __all__ = [
"LossFunc",
"CrossEntropyLoss",
"MSELoss",
"L1Loss",
"BCELoss",
"NLLLoss",
@@ -81,7 +82,7 @@ __all__ = [
'logger',
"init_logger_dist",
]
__version__ = '0.5.0'
__version__ = '0.5.5'

import sys



+ 3
- 1
fastNLP/core/__init__.py View File

@@ -65,6 +65,7 @@ __all__ = [
"NLLLoss",
"LossInForward",
"CMRC2018Loss",
"MSELoss",
"LossBase",

"MetricBase",
@@ -94,7 +95,8 @@ from .const import Const
from .dataset import DataSet
from .field import FieldArray, Padder, AutoPadder, EngChar2DPadder
from .instance import Instance
from .losses import LossFunc, CrossEntropyLoss, L1Loss, BCELoss, NLLLoss, LossInForward, CMRC2018Loss, LossBase
from .losses import LossFunc, CrossEntropyLoss, L1Loss, BCELoss, NLLLoss, \
LossInForward, CMRC2018Loss, LossBase, MSELoss
from .metrics import AccuracyMetric, SpanFPreRecMetric, CMRC2018Metric, ClassifyFPreRecMetric, MetricBase,\
ConfusionMatrixMetric
from .optimizer import Optimizer, SGD, Adam, AdamW


+ 21
- 0
fastNLP/core/losses.py View File

@@ -12,6 +12,7 @@ __all__ = [
"BCELoss",
"L1Loss",
"NLLLoss",
"MSELoss",

"CMRC2018Loss"

@@ -265,6 +266,26 @@ class L1Loss(LossBase):
return F.l1_loss(input=pred, target=target, reduction=self.reduction)


class MSELoss(LossBase):
    r"""
    Mean squared error loss, wrapping ``torch.nn.functional.mse_loss``.

    :param pred: key mapping for `pred` in the parameter map; None means the
        mapping is `pred` -> `pred`
    :param target: key mapping for `target` in the parameter map; None means the
        mapping is `target` -> `target`
    :param str reduction: reduction mode, one of 'mean', 'sum' or 'none'.
    """

    def __init__(self, pred=None, target=None, reduction='mean'):
        super(MSELoss, self).__init__()
        # Register the pred/target key mapping with the LossBase machinery.
        self._init_param_map(pred=pred, target=target)
        # Fail fast on an unsupported reduction mode before training starts.
        assert reduction in ('mean', 'sum', 'none')
        self.reduction = reduction

    def get_loss(self, pred, target):
        """Return the MSE between `pred` and `target` using the configured reduction."""
        return F.mse_loss(input=pred, target=target, reduction=self.reduction)


class BCELoss(LossBase):
r"""
二分类交叉熵损失函数


+ 6
- 1
fastNLP/models/bert.py View File

@@ -76,6 +76,8 @@ class BertForSequenceClassification(BaseModel):
hidden = self.dropout(self.bert(words))
cls_hidden = hidden[:, 0]
logits = self.classifier(cls_hidden)
if logits.size(-1) == 1:
logits = logits.squeeze(-1)

return {Const.OUTPUT: logits}

@@ -85,7 +87,10 @@ class BertForSequenceClassification(BaseModel):
:return: { :attr:`fastNLP.Const.OUTPUT` : logits}: torch.LongTensor [batch_size]
"""
logits = self.forward(words)[Const.OUTPUT]
return {Const.OUTPUT: torch.argmax(logits, dim=-1)}
if self.num_labels > 1:
return {Const.OUTPUT: torch.argmax(logits, dim=-1)}
else:
return {Const.OUTPUT: logits}


class BertForSentenceMatching(BaseModel):


Loading…
Cancel
Save