diff --git a/fastNLP/core/losses.py b/fastNLP/core/losses.py index 2166734d..909e90a9 100644 --- a/fastNLP/core/losses.py +++ b/fastNLP/core/losses.py @@ -228,17 +228,18 @@ class CrossEntropyLoss(LossBase): self.class_in_dim = class_in_dim def get_loss(self, pred, target, seq_len=None): + if seq_len is not None and target.dim()>1: + mask = seq_len_to_mask(seq_len, max_len=target.size(1)).eq(0) + target = target.masked_fill(mask, self.padding_idx) + if pred.dim() > 2: if self.class_in_dim == -1: if pred.size(1) != target.size(1): # 有可能顺序替换了 pred = pred.transpose(1, 2) else: - pred = pred.tranpose(-1, pred) + pred = pred.transpose(1, 2) pred = pred.reshape(-1, pred.size(-1)) target = target.reshape(-1) - if seq_len is not None and target.dim()>1: - mask = seq_len_to_mask(seq_len, max_len=target.size(1)).reshape(-1).eq(0) - target = target.masked_fill(mask, self.padding_idx) return F.cross_entropy(input=pred, target=target, ignore_index=self.padding_idx, reduction=self.reduction)