|
@@ -251,7 +251,8 @@ class LossInForward(LossBase): |
|
|
if not (isinstance(loss, torch.Tensor) and len(loss.size()) == 0): |
|
|
if not (isinstance(loss, torch.Tensor) and len(loss.size()) == 0): |
|
|
if not isinstance(loss, torch.Tensor): |
|
|
if not isinstance(loss, torch.Tensor): |
|
|
raise TypeError(f"Loss excepted to be a torch.Tensor, got {type(loss)}") |
|
|
raise TypeError(f"Loss excepted to be a torch.Tensor, got {type(loss)}") |
|
|
raise RuntimeError(f"The size of loss excepts to be torch.Size([]), got {loss.size()}") |
|
|
|
|
|
|
|
|
loss = torch.sum(loss) / (loss.view(-1)).size(0) |
|
|
|
|
|
# raise RuntimeError(f"The size of loss excepts to be torch.Size([]), got {loss.size()}") |
|
|
|
|
|
|
|
|
return loss |
|
|
return loss |
|
|
|
|
|
|
|
|