diff --git a/fastNLP/core/losses.py b/fastNLP/core/losses.py index 9b8b8d8f..b52244e5 100644 --- a/fastNLP/core/losses.py +++ b/fastNLP/core/losses.py @@ -251,7 +251,8 @@ class LossInForward(LossBase): if not (isinstance(loss, torch.Tensor) and len(loss.size()) == 0): if not isinstance(loss, torch.Tensor): raise TypeError(f"Loss excepted to be a torch.Tensor, got {type(loss)}") - raise RuntimeError(f"The size of loss excepts to be torch.Size([]), got {loss.size()}") + loss = torch.sum(loss) / (loss.view(-1)).size(0) + # raise RuntimeError(f"The size of loss excepts to be torch.Size([]), got {loss.size()}") return loss