You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; they can include dashes ('-') and can be up to 35 characters long.

train_hetu_transformer.py 1.8 kB

4 years ago
1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162
  1. from tqdm import tqdm
  2. import os
  3. import math
  4. import logging
  5. from hparams import Hparams
  6. from hetu_transformer import Transformer
  7. from data_load import DataLoader
  8. import hetu as ht
  9. import numpy as np
  10. # import time
  11. logging.basicConfig(level=logging.INFO)
  12. logging.info("# hparams")
  13. hparams = Hparams()
  14. parser = hparams.parser
  15. hp = parser.parse_args()
  16. print(hp)
  17. logging.info("# Prepare train/eval batches")
  18. dataloader = DataLoader(hp.train1, hp.train2, hp.maxlen1, hp.maxlen2, hp.vocab)
  19. ctx = ht.gpu(1)
  20. xs = ht.Variable(name='xs')
  21. ys1 = ht.Variable(name='ys1')
  22. ys2 = ht.Variable(name='ys2')
  23. nonpadding = ht.Variable(name='nonpadding')
  24. logging.info("# Load model")
  25. m = Transformer(hp)
  26. loss = m.train(xs, (ys1, ys2))
  27. loss = ht.div_op(ht.reduce_sum_op(loss * nonpadding,
  28. axes=[0, 1]), ht.reduce_sum_op(nonpadding, axes=[0, 1]) + 1e-7)
  29. opt = ht.optim.SGDOptimizer(hp.lr)
  30. train_op = opt.minimize(loss)
  31. executor = ht.Executor([loss, train_op], ctx=ctx)
  32. logging.info("# Session")
  33. for ep in range(hp.num_epochs):
  34. dataloader.make_epoch_data(hp.batch_size)
  35. for i in tqdm(range(dataloader.batch_num)):
  36. xs_val, ys_val = dataloader.get_batch()
  37. # st = time.time()
  38. xs_val = xs_val[0]
  39. ys1_val = ys_val[0][:, :-1]
  40. ys2_val = ys_val[0][:, 1:]
  41. nonpadding_val = np.not_equal(
  42. ys2_val, dataloader.get_pad()).astype(np.float32)
  43. _loss, _ = executor.run(
  44. feed_dict={xs: xs_val, ys1: ys1_val, ys2: ys2_val, nonpadding: nonpadding_val})
  45. # en = time.time()
  46. # if i == 100:
  47. # exit()
  48. log_str = 'Iteration %d, loss %f' % (i, _loss.asnumpy())
  49. print(log_str)
  50. # print('time: ', (en - st))
  51. logging.info("Done")