From 4fd49cc333fd8e571e220169e376346c720f3293 Mon Sep 17 00:00:00 2001
From: xuyige
Date: Thu, 11 Apr 2019 15:00:10 +0800
Subject: [PATCH 1/3] add sigmoid activate function in MLP

---
 fastNLP/modules/decoder/MLP.py | 1 +
 1 file changed, 1 insertion(+)

diff --git a/fastNLP/modules/decoder/MLP.py b/fastNLP/modules/decoder/MLP.py
index d75f6b48..3a793f24 100644
--- a/fastNLP/modules/decoder/MLP.py
+++ b/fastNLP/modules/decoder/MLP.py
@@ -36,6 +36,7 @@ class MLP(nn.Module):
         actives = {
             'relu': nn.ReLU(),
             'tanh': nn.Tanh(),
+            'sigmoid': nn.Sigmoid(),
         }
         if not isinstance(activation, list):
             activation = [activation] * (len(size_layer) - 2)

From 6f010d488db843816a02a81c71c1e290c3077a1a Mon Sep 17 00:00:00 2001
From: Yunfan Shao
Date: Sun, 14 Apr 2019 21:11:16 +0800
Subject: [PATCH 2/3] update readme

---
 reproduction/README.md | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/reproduction/README.md b/reproduction/README.md
index 1c93c6bc..8d14d36d 100644
--- a/reproduction/README.md
+++ b/reproduction/README.md
@@ -8,7 +8,7 @@
 ## Star-Transformer
 [reference](https://arxiv.org/abs/1902.09113)
-### Performance
+### Performance (still in progress)
 |任务| 数据集 | SOTA | 模型表现 |
 |------|------| ------| ------|
 |Pos Tagging|CTB 9.0|-|ACC 92.31|

From 16fdf20d2630df580e4f5b7af244c66b048f70f3 Mon Sep 17 00:00:00 2001
From: xuyige
Date: Sun, 14 Apr 2019 22:20:39 +0800
Subject: [PATCH 3/3] support parallel loss

---
 fastNLP/core/losses.py | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/fastNLP/core/losses.py b/fastNLP/core/losses.py
index 9b8b8d8f..b52244e5 100644
--- a/fastNLP/core/losses.py
+++ b/fastNLP/core/losses.py
@@ -251,7 +251,8 @@ class LossInForward(LossBase):
         if not (isinstance(loss, torch.Tensor) and len(loss.size()) == 0):
             if not isinstance(loss, torch.Tensor):
                 raise TypeError(f"Loss excepted to be a torch.Tensor, got {type(loss)}")
-            raise RuntimeError(f"The size of loss excepts to be torch.Size([]), got {loss.size()}")
+            loss = torch.sum(loss) / (loss.view(-1)).size(0)
+            # raise RuntimeError(f"The size of loss excepts to be torch.Size([]), got {loss.size()}")
         return loss
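
A short usage sketch follows; it is not part of the patches above. It assumes the MLP
constructor signature MLP(size_layer, activation=...) with hypothetical layer sizes, and
reproduces the reduction that LossInForward now applies when the returned loss is not a
0-dim tensor (for example per-sample or per-device losses gathered under nn.DataParallel),
instead of raising a RuntimeError as before.

    import torch
    from fastNLP.modules.decoder.MLP import MLP

    # Patch 1/3: 'sigmoid' is now a valid activation name (layer sizes are hypothetical).
    mlp = MLP([64, 32, 1], activation='sigmoid')

    # Patch 3/3: a non-scalar loss is averaged down to a scalar instead of raising.
    loss = torch.tensor([0.9, 1.1, 1.3, 0.7])             # e.g. one value per device
    reduced = torch.sum(loss) / (loss.view(-1)).size(0)   # same expression as the patch
    assert torch.isclose(reduced, loss.mean())            # equivalent to a plain mean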