|
|
@@ -228,17 +228,18 @@ class CrossEntropyLoss(LossBase): |
|
|
|
self.class_in_dim = class_in_dim |
|
|
|
|
|
|
|
def get_loss(self, pred, target, seq_len=None): |
|
|
|
if seq_len is not None and target.dim()>1: |
|
|
|
mask = seq_len_to_mask(seq_len, max_len=target.size(1)).eq(0) |
|
|
|
target = target.masked_fill(mask, self.padding_idx) |
|
|
|
|
|
|
|
if pred.dim() > 2: |
|
|
|
if self.class_in_dim == -1: |
|
|
|
if pred.size(1) != target.size(1): # 有可能顺序替换了 |
|
|
|
pred = pred.transpose(1, 2) |
|
|
|
else: |
|
|
|
pred = pred.tranpose(-1, pred) |
|
|
|
pred = pred.transpose(1, 2) |
|
|
|
pred = pred.reshape(-1, pred.size(-1)) |
|
|
|
target = target.reshape(-1) |
|
|
|
if seq_len is not None and target.dim()>1: |
|
|
|
mask = seq_len_to_mask(seq_len, max_len=target.size(1)).reshape(-1).eq(0) |
|
|
|
target = target.masked_fill(mask, self.padding_idx) |
|
|
|
|
|
|
|
return F.cross_entropy(input=pred, target=target, |
|
|
|
ignore_index=self.padding_idx, reduction=self.reduction) |
|
|
|