From 1eec5b234b3bf938dd9affde22ee5460493f7421 Mon Sep 17 00:00:00 2001 From: yh Date: Tue, 15 Oct 2019 19:14:50 +0800 Subject: [PATCH] =?UTF-8?q?=E4=BF=AE=E5=A4=8DCrossEntropyLoss=E5=9C=A8seq?= =?UTF-8?q?=5Flen=E4=B8=8D=E4=B8=BANone=E7=9A=84=E6=97=B6=E5=80=99?= =?UTF-8?q?=E4=BC=9A=E5=87=BA=E7=8E=B0=E8=AE=A1=E7=AE=97=E9=94=99=E8=AF=AF?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- fastNLP/core/losses.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/fastNLP/core/losses.py b/fastNLP/core/losses.py index 2166734d..909e90a9 100644 --- a/fastNLP/core/losses.py +++ b/fastNLP/core/losses.py @@ -228,17 +228,18 @@ class CrossEntropyLoss(LossBase): self.class_in_dim = class_in_dim def get_loss(self, pred, target, seq_len=None): + if seq_len is not None and target.dim()>1: + mask = seq_len_to_mask(seq_len, max_len=target.size(1)).eq(0) + target = target.masked_fill(mask, self.padding_idx) + if pred.dim() > 2: if self.class_in_dim == -1: if pred.size(1) != target.size(1): # 有可能顺序替换了 pred = pred.transpose(1, 2) else: - pred = pred.tranpose(-1, pred) + pred = pred.transpose(1, 2) pred = pred.reshape(-1, pred.size(-1)) target = target.reshape(-1) - if seq_len is not None and target.dim()>1: - mask = seq_len_to_mask(seq_len, max_len=target.size(1)).reshape(-1).eq(0) - target = target.masked_fill(mask, self.padding_idx) return F.cross_entropy(input=pred, target=target, ignore_index=self.padding_idx, reduction=self.reduction)