Browse Source

修复CrossEntropyLoss在seq_len不为None的时候会出现计算错误

tags/v0.5.5
yh 5 years ago
parent
commit
1eec5b234b
1 changed files with 5 additions and 4 deletions
  1. +5
    -4
      fastNLP/core/losses.py

+ 5
- 4
fastNLP/core/losses.py View File

@@ -228,17 +228,18 @@ class CrossEntropyLoss(LossBase):
self.class_in_dim = class_in_dim
def get_loss(self, pred, target, seq_len=None):
if seq_len is not None and target.dim()>1:
mask = seq_len_to_mask(seq_len, max_len=target.size(1)).eq(0)
target = target.masked_fill(mask, self.padding_idx)

if pred.dim() > 2:
if self.class_in_dim == -1:
if pred.size(1) != target.size(1): # 有可能顺序替换了
pred = pred.transpose(1, 2)
else:
pred = pred.tranpose(-1, pred)
pred = pred.transpose(1, 2)
pred = pred.reshape(-1, pred.size(-1))
target = target.reshape(-1)
if seq_len is not None and target.dim()>1:
mask = seq_len_to_mask(seq_len, max_len=target.size(1)).reshape(-1).eq(0)
target = target.masked_fill(mask, self.padding_idx)

return F.cross_entropy(input=pred, target=target,
ignore_index=self.padding_idx, reduction=self.reduction)


Loading…
Cancel
Save