from tqdm import tqdm
import os
import math
import logging
from hparams import Hparams
from hetu_transformer import Transformer
from data_load import DataLoader
import hetu as ht
import numpy as np
# import time

logging.basicConfig(level=logging.INFO)

logging.info("# hparams")
hparams = Hparams()
parser = hparams.parser
hp = parser.parse_args()
print(hp)

logging.info("# Prepare train/eval batches")
dataloader = DataLoader(hp.train1, hp.train2, hp.maxlen1, hp.maxlen2, hp.vocab)

# Placeholders for the encoder input, the decoder input (target shifted
# right), the decoder target (target shifted left), and the padding mask.
ctx = ht.gpu(1)
xs = ht.Variable(name='xs')
ys1 = ht.Variable(name='ys1')
ys2 = ht.Variable(name='ys2')
nonpadding = ht.Variable(name='nonpadding')

logging.info("# Load model")
m = Transformer(hp)
loss = m.train(xs, (ys1, ys2))
# Average the per-token loss over non-padding positions only; the 1e-7
# guards against division by zero on an all-padding batch.
loss = ht.div_op(ht.reduce_sum_op(loss * nonpadding, axes=[0, 1]),
                 ht.reduce_sum_op(nonpadding, axes=[0, 1]) + 1e-7)

opt = ht.optim.SGDOptimizer(hp.lr)
train_op = opt.minimize(loss)
executor = ht.Executor([loss, train_op], ctx=ctx)

logging.info("# Session")
for ep in range(hp.num_epochs):
    dataloader.make_epoch_data(hp.batch_size)
    for i in tqdm(range(dataloader.batch_num)):
        xs_val, ys_val = dataloader.get_batch()
        # st = time.time()
        xs_val = xs_val[0]
        # Teacher forcing: feed the target shifted right as decoder input
        # and predict the target shifted left.
        ys1_val = ys_val[0][:, :-1]
        ys2_val = ys_val[0][:, 1:]
        # Mask out padded target positions so they contribute nothing to the loss.
        nonpadding_val = np.not_equal(
            ys2_val, dataloader.get_pad()).astype(np.float32)
        _loss, _ = executor.run(
            feed_dict={xs: xs_val, ys1: ys1_val, ys2: ys2_val,
                       nonpadding: nonpadding_val})
        # en = time.time()
        # if i == 100:
        #     exit()
        # Convert the returned array to a Python float before formatting;
        # '%f' on an ndarray is deprecated in recent numpy.
        log_str = 'Iteration %d, loss %f' % (i, float(_loss.asnumpy()))
        print(log_str)
        # print('time: ', (en - st))

logging.info("Done")
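
# --- Hedged sketch (not part of the original script): a plain-numpy ---
# --- illustration of the masked loss averaging used above. PAD_ID   ---
# --- stands in for dataloader.get_pad(); all values are made up.    ---
def _demo_masked_mean_loss():
    PAD_ID = 0  # assumed pad token id, purely for illustration
    demo_ys2 = np.array([[5, 7, PAD_ID, PAD_ID],
                         [3, 9, 2, PAD_ID]])
    # Per-token losses, e.g. cross-entropy values (illustrative numbers).
    demo_loss = np.array([[1.0, 2.0, 9.0, 9.0],
                          [0.5, 1.5, 2.5, 9.0]])
    # Same mask construction as in the training loop: 1.0 where the target
    # is a real token, 0.0 where it is padding.
    mask = np.not_equal(demo_ys2, PAD_ID).astype(np.float32)
    # Same reduction as the ht.div_op line: sum of masked losses divided by
    # the real-token count, with 1e-7 guarding against an all-padding batch.
    mean_loss = (demo_loss * mask).sum() / (mask.sum() + 1e-7)
    print(mean_loss)  # ≈ (1.0 + 2.0 + 0.5 + 1.5 + 2.5) / 5 = 1.5

# _demo_masked_mean_loss()  # uncomment to run the demo standalone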