diff --git a/fastNLP/modules/encoder/masked_rnn.py b/fastNLP/modules/encoder/masked_rnn.py index 17ebcfd6..76f828a9 100644 --- a/fastNLP/modules/encoder/masked_rnn.py +++ b/fastNLP/modules/encoder/masked_rnn.py @@ -273,7 +273,7 @@ class MaskedRNNBase(nn.Module): hx = (hx, hx) func = AutogradMaskedStep(num_layers=self.num_layers, - dropout=self.dropout, + dropout=self.step_dropout, train=self.training, lstm=lstm)