|
@@ -132,7 +132,7 @@ class TransformerSeq2SeqEncoder(Seq2SeqEncoder): |
|
|
x = self.input_fc(x) |
|
|
x = self.input_fc(x) |
|
|
x = F.dropout(x, p=self.dropout, training=self.training) |
|
|
x = F.dropout(x, p=self.dropout, training=self.training) |
|
|
|
|
|
|
|
|
encoder_mask = seq_len_to_mask(seq_len) |
|
|
|
|
|
|
|
|
encoder_mask = seq_len_to_mask(seq_len, max_len=max_src_len) |
|
|
encoder_mask = encoder_mask.to(device) |
|
|
encoder_mask = encoder_mask.to(device) |
|
|
|
|
|
|
|
|
for layer in self.layer_stacks: |
|
|
for layer in self.layer_stacks: |
|
|