diff --git a/fastNLP/modules/encoder/seq2seq_encoder.py b/fastNLP/modules/encoder/seq2seq_encoder.py
index d280582a..5eae1e6d 100644
--- a/fastNLP/modules/encoder/seq2seq_encoder.py
+++ b/fastNLP/modules/encoder/seq2seq_encoder.py
@@ -132,7 +132,7 @@ class TransformerSeq2SeqEncoder(Seq2SeqEncoder):
         x = self.input_fc(x)
         x = F.dropout(x, p=self.dropout, training=self.training)
 
-        encoder_mask = seq_len_to_mask(seq_len)
+        encoder_mask = seq_len_to_mask(seq_len, max_len=max_src_len)
         encoder_mask = encoder_mask.to(device)
 
         for layer in self.layer_stacks:
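
Context for the fix: without max_len, seq_len_to_mask builds a mask whose width is seq_len.max(), so whenever the batch tensor x is padded wider than the longest length recorded in seq_len, the mask no longer lines up with x's time dimension. Passing max_len=max_src_len (presumably taken from x.size() earlier in forward) pins the mask width to the actual padded length. Below is a minimal, self-contained sketch of the assumed behavior; the helper mirrors what fastNLP's seq_len_to_mask appears to do, and the tensors and the padded-to-8 batch are hypothetical, purely for illustration.

import torch

def seq_len_to_mask(seq_len, max_len=None):
    # Sketch of the assumed fastNLP helper: True marks real (non-pad) positions.
    if max_len is None:
        max_len = int(seq_len.max())  # width collapses to the batch maximum
    positions = torch.arange(max_len, device=seq_len.device)       # (max_len,)
    return positions.unsqueeze(0) < seq_len.unsqueeze(1)           # (batch, max_len) bool

seq_len = torch.tensor([3, 5])   # true lengths of each sequence
x = torch.randn(2, 8, 16)        # hypothetical batch padded to 8 time steps

old_mask = seq_len_to_mask(seq_len)                     # (2, 5): mismatches x
new_mask = seq_len_to_mask(seq_len, max_len=x.size(1))  # (2, 8): aligns with x
assert new_mask.shape[:2] == x.shape[:2]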