Browse Source

Pre Merge pull request !6 from nilyt/N/A

pull/6/MERGE
nilyt Gitee 4 years ago
parent
commit
774f3b4036
1 changed files with 23 additions and 6 deletions
  1. +23
    -6
      fastNLP/core/losses.py

+ 23
- 6
fastNLP/core/losses.py View File

@@ -209,7 +209,7 @@ class LossFunc(LossBase):
class CrossEntropyLoss(LossBase):
r"""
交叉熵损失函数
:param pred: 参数映射表中 `pred` 的映射关系,None表示映射关系为 `pred` -> `pred`
:param target: 参数映射表中 `target` 的映射关系,None表示映射关系为 `target` -> `target`
:param seq_len: 句子的长度, 长度之外的token不会计算loss。
@@ -219,24 +219,27 @@ class CrossEntropyLoss(LossBase):
那么pred可能第二维也是长度维(存在误判的可能,如果有误判的情况,请显示设置该值)。其它大于0的值则认为该维度是class的维度。
:param padding_idx: padding的index,在计算loss时将忽略target中标号为padding_idx的内容, 可以通过该值代替
传入seq_len.
:param label_smoothing: label smoothing的eps系数,将标签编码为形如[1-eps, eps/(C-1), ..., eps/(C-1)]的向量(而不是oneone-hot向量),其中C是
标签类数,如果设置为None表示不进行label smoothing.
:param str reduction: 支持 `mean` ,`sum` 和 `none` .

Example::

loss = CrossEntropyLoss(pred='pred', target='label', padding_idx=0)
"""
def __init__(self, pred=None, target=None, seq_len=None, class_in_dim=-1, padding_idx=-100, reduction='mean'):
def __init__(self, pred=None, target=None, seq_len=None, class_in_dim=-1, padding_idx=-100, label_smoothing=None, reduction='mean'):
super(CrossEntropyLoss, self).__init__()
self._init_param_map(pred=pred, target=target, seq_len=seq_len)
self.padding_idx = padding_idx
assert reduction in ('mean', 'sum', 'none')
self.reduction = reduction
self.class_in_dim = class_in_dim
self.label_smoothing = label_smoothing

def get_loss(self, pred, target, seq_len=None):
if seq_len is not None and target.dim()>1:
if seq_len is not None and target.dim() > 1:
mask = seq_len_to_mask(seq_len, max_len=target.size(1)).eq(False)
target = target.masked_fill(mask, self.padding_idx)

@@ -249,6 +252,20 @@ class CrossEntropyLoss(LossBase):
pred = pred.reshape(-1, pred.size(-1))
target = target.reshape(-1)

if self.label_smoothing is not None:
n_class = pred.size(1)
one_hot = torch.zeros_like(pred).scatter(1, target.view(-1, 1), 1)
one_hot = one_hot * (1 - self.label_smoothing) + (1 - one_hot) * self.label_smoothing / (n_class - 1)
print(one_hot)
print(pred)
log_prb = F.log_softmax(pred, dim=1)
print(log_prb.shape)

non_pad_mask = target.ne(self.padding_idx)
print(non_pad_mask.shape)
loss = -(one_hot * log_prb).sum(dim=1)
return getattr(loss.masked_select(non_pad_mask), self.reduction)()

return F.cross_entropy(input=pred, target=target,
ignore_index=self.padding_idx, reduction=self.reduction)



Loading…
Cancel
Save