From 16fdf20d2630df580e4f5b7af244c66b048f70f3 Mon Sep 17 00:00:00 2001 From: xuyige Date: Sun, 14 Apr 2019 22:20:39 +0800 Subject: [PATCH] support parallel loss --- fastNLP/core/losses.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/fastNLP/core/losses.py b/fastNLP/core/losses.py index 9b8b8d8f..b52244e5 100644 --- a/fastNLP/core/losses.py +++ b/fastNLP/core/losses.py @@ -251,7 +251,8 @@ class LossInForward(LossBase): if not (isinstance(loss, torch.Tensor) and len(loss.size()) == 0): if not isinstance(loss, torch.Tensor): raise TypeError(f"Loss excepted to be a torch.Tensor, got {type(loss)}") - raise RuntimeError(f"The size of loss excepts to be torch.Size([]), got {loss.size()}") + loss = torch.sum(loss) / (loss.view(-1)).size(0) + # raise RuntimeError(f"The size of loss excepts to be torch.Size([]), got {loss.size()}") return loss