train_tf_transformer.py
import tensorflow as tf
from tqdm import tqdm
import os
import math
import logging
from hparams import Hparams
from tf_transformer import Transformer
from data_load import DataLoader
# import time

logging.basicConfig(level=logging.INFO)

logging.info("# hparams")
hparams = Hparams()
parser = hparams.parser
hp = parser.parse_args()
print(hp)
# save_hparams(hp, hp.logdir)

logging.info("# Prepare train/eval batches")
dataloader = DataLoader(hp.train1, hp.train2, hp.maxlen1, hp.maxlen2, hp.vocab)

# Static-shape placeholders: batch size 16, source length 100; the decoder
# input (ys1) and label (ys2) are length 99 because the target is fed shifted
# by one position (see the feed_dict below).
xs = tf.placeholder(name='xs', dtype=tf.int32, shape=[16, 100])
ys1 = tf.placeholder(name='ys1', dtype=tf.int32, shape=[16, 99])
ys2 = tf.placeholder(name='ys2', dtype=tf.int32, shape=[16, 99])

logging.info("# Load model")
m = Transformer(hp)
loss = m.train(xs, (ys1, ys2))
# Mask out <pad> positions so padding does not contribute to the loss.
nonpadding = tf.to_float(tf.not_equal(ys2, dataloader.get_pad()))  # 0: <pad>
loss = tf.reduce_sum(loss * nonpadding) / (tf.reduce_sum(nonpadding) + 1e-7)

global_step = tf.train.get_or_create_global_step()
optimizer = tf.train.GradientDescentOptimizer(hp.lr)
train_op = optimizer.minimize(loss, global_step=global_step)
# y_hat, eval_summaries = m.eval(xs, ys)
# y_hat = m.infer(xs, ys)

logging.info("# Session")
saver = tf.train.Saver(max_to_keep=hp.num_epochs)
with tf.Session() as sess:
    # Restore the latest checkpoint if one exists, otherwise initialize from scratch.
    ckpt = tf.train.latest_checkpoint(hp.logdir)
    if ckpt is None:
        logging.info("Initializing from scratch")
        sess.run(tf.global_variables_initializer())
        # save_variable_specs(os.path.join(hp.logdir, "specs"))
    else:
        saver.restore(sess, ckpt)

    _gs = sess.run(global_step)
    for ep in range(hp.num_epochs):
        dataloader.make_epoch_data(hp.batch_size)
        for i in tqdm(range(dataloader.batch_num)):
            xs_val, ys_val = dataloader.get_batch()
            # st = time.time()
            # Teacher forcing: the decoder input is the target shifted right
            # (drop the last token), the label is the target shifted left
            # (drop the first token).
            _loss, _, _gs = sess.run([loss, train_op, global_step], feed_dict={
                xs: xs_val[0], ys1: ys_val[0][:, :-1], ys2: ys_val[0][:, 1:]})
            # en = time.time()
            # if i == 100:
            #     exit()
            # epoch = math.ceil(_gs / num_train_batches)
            log_str = 'Iteration %d, loss %f' % (i, _loss)
            print(log_str)
            # print('time: ', (en - st))
        # logging.info("epoch {} is done".format(ep))
        # _loss = sess.run(loss)  # train loss
        # logging.info("# test evaluation")
        # _, _eval_summaries = sess.run([eval_init_op, eval_summaries])
        # summary_writer.add_summary(_eval_summaries, _gs)
        # logging.info("# get hypotheses")
        # hypotheses = get_hypotheses(num_eval_batches, num_eval_samples, sess, y_hat, m.idx2token)
        # logging.info("# write results")
        # model_output = "iwslt2016_E%02dL%.2f" % (epoch, _loss)
        # if not os.path.exists(hp.evaldir): os.makedirs(hp.evaldir)
        # translation = os.path.join(hp.evaldir, model_output)
        # with open(translation, 'w') as fout:
        #     fout.write("\n".join(hypotheses))
        # logging.info("# calc bleu score and append it to translation")
        # calc_bleu(hp.eval3, translation)
        # logging.info("# save models")
        # ckpt_name = os.path.join(hp.logdir, model_output)
        # saver.save(sess, ckpt_name, global_step=_gs)
        # logging.info("after training of {} epochs, {} has been saved.".format(epoch, ckpt_name))
        # logging.info("# fall back to train mode")

logging.info("Done")
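
Usage note: because the placeholder shapes are hard-coded, the hyperparameters must agree with them (batch size 16, source length 100, target length 100 before the one-token shift). A hypothetical invocation sketch, assuming hparams.Hparams exposes the attributes the script reads (train1, train2, vocab, maxlen1, maxlen2, batch_size, num_epochs, lr, logdir) as command-line flags; the paths and the lr/num_epochs values below are placeholders:

python train_tf_transformer.py --train1 data/train.src --train2 data/train.tgt --vocab data/vocab.txt --maxlen1 100 --maxlen2 100 --batch_size 16 --num_epochs 20 --lr 0.001 --logdir log/tf_transformer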