|
|
@@ -25,11 +25,11 @@ def prepare_env(): |
|
|
|
|
|
|
|
|
|
|
|
def train_model(model, src_words_idx, tgt_words_idx, tgt_seq_len, src_seq_len): |
|
|
|
optimizer = optim.Adam(model.parameters(), lr=1e-2) |
|
|
|
optimizer = optim.Adam(model.parameters(), lr=5e-3) |
|
|
|
mask = seq_len_to_mask(tgt_seq_len).eq(0) |
|
|
|
target = tgt_words_idx.masked_fill(mask, -100) |
|
|
|
|
|
|
|
for i in range(100): |
|
|
|
for i in range(50): |
|
|
|
optimizer.zero_grad() |
|
|
|
pred = model(src_words_idx, tgt_words_idx, src_seq_len)['pred'] # bsz x max_len x vocab_size |
|
|
|
loss = F.cross_entropy(pred.transpose(1, 2), target) |
|
|
|