@@ -251,7 +251,8 @@ class LossInForward(LossBase):
         if not (isinstance(loss, torch.Tensor) and len(loss.size()) == 0):
             if not isinstance(loss, torch.Tensor):
                 raise TypeError(f"Loss expected to be a torch.Tensor, got {type(loss)}")
-            raise RuntimeError(f"The size of loss expected to be torch.Size([]), got {loss.size()}")
+            loss = torch.sum(loss) / (loss.view(-1)).size(0)
+            # raise RuntimeError(f"The size of loss expected to be torch.Size([]), got {loss.size()}")
         return loss
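For context, the new fallback reduces a non-scalar loss to its element-wise mean instead of raising. A minimal sketch of that reduction, assuming a hypothetical per-token loss tensor (the values are made up for illustration):

```python
import torch

# Hypothetical per-token losses returned from a model's forward().
loss = torch.tensor([[0.2, 0.4], [0.6, 0.8]])

# The fallback added above: sum over all elements, then divide by the
# element count, i.e. the element-wise mean.
if not (isinstance(loss, torch.Tensor) and len(loss.size()) == 0):
    loss = torch.sum(loss) / (loss.view(-1)).size(0)

print(loss)  # tensor(0.5000) -- same result as loss.mean()
```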
@@ -36,6 +36,7 @@ class MLP(nn.Module):
         actives = {
             'relu': nn.ReLU(),
             'tanh': nn.Tanh(),
+            'sigmoid': nn.Sigmoid(),
         }
         if not isinstance(activation, list):
             activation = [activation] * (len(size_layer) - 2)
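A standalone sketch of the lookup-and-broadcast logic this hunk touches, with the new `'sigmoid'` entry; the layer sizes and variable values below are illustrative, not part of fastNLP's API:

```python
import torch.nn as nn

# Name -> module table, now including the added 'sigmoid' entry.
actives = {
    'relu': nn.ReLU(),
    'tanh': nn.Tanh(),
    'sigmoid': nn.Sigmoid(),
}

size_layer = [100, 50, 10]   # input, hidden, output widths (illustrative)
activation = 'sigmoid'       # a single name...

# ...is broadcast to one activation per hidden layer, as in MLP.__init__.
if not isinstance(activation, list):
    activation = [activation] * (len(size_layer) - 2)

hidden_acts = [actives[name] for name in activation]
print(hidden_acts)  # [Sigmoid()]
```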
@@ -8,7 +8,7 @@
 ## Star-Transformer
 [reference](https://arxiv.org/abs/1902.09113)
-### Performance
+### Performance (still in progress)
 |Task|Dataset|SOTA|Model Performance|
 |------|------|------|------|
 |Pos Tagging|CTB 9.0|-|ACC 92.31|