From a6dcf72e26842ae09ea2e0b9b4c6364e10db0156 Mon Sep 17 00:00:00 2001
From: nilyt
Date: Tue, 15 Dec 2020 00:17:56 +0800
Subject: [PATCH] Add label smoothing to CrossEntropyLoss
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 fastNLP/core/losses.py | 34 ++++++++++++++++++++++++++++++------
 1 file changed, 28 insertions(+), 6 deletions(-)

diff --git a/fastNLP/core/losses.py b/fastNLP/core/losses.py
index 642c8ef3..7f64fb93 100644
--- a/fastNLP/core/losses.py
+++ b/fastNLP/core/losses.py
@@ -209,7 +209,7 @@ class LossFunc(LossBase):
 class CrossEntropyLoss(LossBase):
     r"""
     Cross-entropy loss.
-    
+
     :param pred: the mapping for `pred` in the parameter map; None means the mapping is `pred` -> `pred`
     :param target: the mapping for `target` in the parameter map; None means the mapping is `target` -> `target`
     :param seq_len: sentence lengths; tokens beyond each sequence's length do not contribute to the loss.
@@ -219,24 +219,27 @@ class CrossEntropyLoss(LossBase):
         then the second dimension of pred may also be the length dimension (misjudgment is possible; if it occurs, please set this value explicitly). Any other value greater than 0 marks that dimension as the class dimension.
     :param padding_idx: the padding index; when computing the loss, entries of target equal to padding_idx are
         ignored. Can be passed instead of seq_len.
+    :param label_smoothing: the eps coefficient for label smoothing, which encodes each label as a vector of the form
+        [1-eps, eps/(C-1), ..., eps/(C-1)] (instead of a one-hot vector), where C is the number of classes; None means no label smoothing.
     :param str reduction: supports `mean`, `sum` and `none`.
 
     Example::
 
         loss = CrossEntropyLoss(pred='pred', target='label', padding_idx=0)
-    
+
     """
-    
-    def __init__(self, pred=None, target=None, seq_len=None, class_in_dim=-1, padding_idx=-100, reduction='mean'):
+
+    def __init__(self, pred=None, target=None, seq_len=None, class_in_dim=-1, padding_idx=-100, label_smoothing=None, reduction='mean'):
         super(CrossEntropyLoss, self).__init__()
         self._init_param_map(pred=pred, target=target, seq_len=seq_len)
         self.padding_idx = padding_idx
         assert reduction in ('mean', 'sum', 'none')
         self.reduction = reduction
         self.class_in_dim = class_in_dim
-    
+        self.label_smoothing = label_smoothing
+
     def get_loss(self, pred, target, seq_len=None):
-        if seq_len is not None and target.dim()>1:
+        if seq_len is not None and target.dim() > 1:
             mask = seq_len_to_mask(seq_len, max_len=target.size(1)).eq(False)
             target = target.masked_fill(mask, self.padding_idx)
@@ -249,6 +252,25 @@ class CrossEntropyLoss(LossBase):
             pred = pred.reshape(-1, pred.size(-1))
             target = target.reshape(-1)
 
+        if self.label_smoothing is not None:
+            # Padded targets (padding_idx, e.g. -100) are out of range for
+            # scatter, so zero them out before building the smoothed label
+            # distribution, and mask them out of the loss afterwards.
+            non_pad_mask = target.ne(self.padding_idx)
+            n_class = pred.size(1)
+            one_hot = torch.zeros_like(pred).scatter(1, target.masked_fill(~non_pad_mask, 0).view(-1, 1), 1)
+            # Encode each label as [1-eps, eps/(C-1), ..., eps/(C-1)] instead of one-hot.
+            one_hot = one_hot * (1 - self.label_smoothing) + (1 - one_hot) * self.label_smoothing / (n_class - 1)
+            log_prb = F.log_softmax(pred, dim=1)
+            loss = -(one_hot * log_prb).sum(dim=1).masked_select(non_pad_mask)
+            # 'none' is a valid reduction but not a tensor method, so it cannot
+            # be dispatched via getattr(); handle the three cases explicitly.
+            if self.reduction == 'mean':
+                return loss.mean()
+            if self.reduction == 'sum':
+                return loss.sum()
+            return loss
+
         return F.cross_entropy(input=pred, target=target, ignore_index=self.padding_idx,
                                reduction=self.reduction)
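
A quick usage sketch of the new option (a minimal example assuming the patch
above is applied; the module path and the `get_loss` signature are taken from
the diff, and the tensor values are illustrative only):

    import torch
    from fastNLP.core.losses import CrossEntropyLoss

    # 4 samples, 3 classes; targets equal to padding_idx are ignored.
    pred = torch.randn(4, 3)
    target = torch.tensor([0, 2, 1, 0])

    loss_fn = CrossEntropyLoss(padding_idx=0, label_smoothing=0.1)
    loss = loss_fn.get_loss(pred, target)

    # With eps=0.1 and C=3, each gold label is encoded as [0.9, 0.05, 0.05]
    # (up to permutation) rather than one-hot, and the result is
    # -sum(smoothed_label * log_softmax(pred)) averaged over the two
    # non-padding positions.
    print(loss)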