Browse Source

support parallel loss

tags/v0.4.10
xuyige 5 years ago
parent
commit
16fdf20d26
1 changed files with 2 additions and 1 deletions
  1. +2
    -1
      fastNLP/core/losses.py

+ 2
- 1
fastNLP/core/losses.py View File

@@ -251,7 +251,8 @@ class LossInForward(LossBase):
if not (isinstance(loss, torch.Tensor) and len(loss.size()) == 0): if not (isinstance(loss, torch.Tensor) and len(loss.size()) == 0):
if not isinstance(loss, torch.Tensor): if not isinstance(loss, torch.Tensor):
raise TypeError(f"Loss excepted to be a torch.Tensor, got {type(loss)}") raise TypeError(f"Loss excepted to be a torch.Tensor, got {type(loss)}")
raise RuntimeError(f"The size of loss excepts to be torch.Size([]), got {loss.size()}")
loss = torch.sum(loss) / (loss.view(-1)).size(0)
# raise RuntimeError(f"The size of loss excepts to be torch.Size([]), got {loss.size()}")


return loss return loss




Loading…
Cancel
Save