diff --git a/fastNLP/modules/decoder/MLP.py b/fastNLP/modules/decoder/MLP.py
index 766dc225..0470e91b 100644
--- a/fastNLP/modules/decoder/MLP.py
+++ b/fastNLP/modules/decoder/MLP.py
@@ -4,12 +4,13 @@ from fastNLP.modules.utils import initial_parameter
 
 
 class MLP(nn.Module):
-    def __init__(self, size_layer, activation='relu', initial_method=None):
+    def __init__(self, size_layer, activation='relu', initial_method=None, dropout=0.0):
         """Multilayer Perceptrons as a decoder
 
         :param size_layer: list of int, define the size of MLP layers.
         :param activation: str or function, the activation function for hidden layers.
         :param initial_method: str, the name of init method.
+        :param dropout: float, the probability of dropout.
 
         .. note::
             There is no activation function applying on output layer.
@@ -24,6 +25,8 @@ class MLP(nn.Module):
             else:
                 self.hiddens.append(nn.Linear(size_layer[i-1], size_layer[i]))
 
+        self.dropout = nn.Dropout(p=dropout)
+
         actives = {
             'relu': nn.ReLU(),
             'tanh': nn.Tanh(),
@@ -38,8 +41,8 @@ class MLP(nn.Module):
 
     def forward(self, x):
         for layer in self.hiddens:
-            x = self.hidden_active(layer(x))
-        x = self.output(x)
+            x = self.dropout(self.hidden_active(layer(x)))
+        x = self.dropout(self.output(x))
         return x
 
 
diff --git a/fastNLP/modules/utils.py b/fastNLP/modules/utils.py
index 12efe1c8..21497037 100644
--- a/fastNLP/modules/utils.py
+++ b/fastNLP/modules/utils.py
@@ -32,9 +32,9 @@ def initial_parameter(net, initial_method=None):
     elif initial_method == 'xavier_normal':
         init_method = init.xavier_normal_
     elif initial_method == 'kaiming_normal' or initial_method == 'msra':
-        init_method = init.kaiming_normal
+        init_method = init.kaiming_normal_
     elif initial_method == 'kaiming_uniform':
-        init_method = init.kaiming_normal
+        init_method = init.kaiming_uniform_
     elif initial_method == 'orthogonal':
         init_method = init.orthogonal_
     elif initial_method == 'sparse':
@@ -42,7 +42,7 @@ def initial_parameter(net, initial_method=None):
     elif initial_method == 'normal':
         init_method = init.normal_
     elif initial_method == 'uniform':
-        initial_method = init.uniform_
+        init_method = init.uniform_
     else:
         init_method = init.xavier_normal_
 
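
A minimal usage sketch of the patched MLP; the layer sizes, batch shape, and variable names below are illustrative, not taken from the patch. With dropout=0.0 (the default) the module behaves as before; a non-zero value applies dropout after each hidden activation and, per the forward() change above, after the output layer as well. Passing initial_method='kaiming_uniform' exercises the initial_parameter branch corrected in utils.py.

    import torch
    from fastNLP.modules.decoder.MLP import MLP

    # Illustrative sizes: 20-dim input, one 100-dim hidden layer, 5-dim output.
    mlp = MLP([20, 100, 5], activation='relu', initial_method='kaiming_uniform', dropout=0.5)

    x = torch.randn(32, 20)   # a batch of 32 examples
    mlp.train()               # dropout is active in training mode
    y = mlp(x)                # shape (32, 5); dropout is also applied after the output layer
    mlp.eval()                # dropout is disabled for evaluation
    y_eval = mlp(x)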