import hetu as ht
from hetu.launcher import launch

import os
import numpy as np
import yaml
import time
import argparse
from tqdm import tqdm
from sklearn import metrics
from models import load_data, wdl_adult


def worker(args):
    """Build the WDL-adult graph, then train (and optionally validate) it.

    Args:
        args: parsed command-line namespace. The attributes read here are
            ``config``, ``val``, ``cache``, ``bsp``, ``bound`` and ``nepoch``.

    Prints per-epoch metrics and, at the end, the elapsed wall-clock time.
    """

    def train(iterations, auc_enabled=True, tqdm_enabled=False):
        """Run ``iterations`` steps of the 'train' subgraph.

        Returns ``(mean_loss, mean_acc)`` or, when ``auc_enabled`` is true,
        ``(mean_loss, mean_acc, mean_auc)``.
        """
        steps = range(iterations)
        if tqdm_enabled:
            steps = tqdm(steps)
        train_loss = []
        train_acc = []
        train_auc = []
        for _ in steps:
            loss_val, predict_y, y_val, _ = executor.run(
                'train', convert_to_numpy_ret_vals=True)
            # A sample counts as correct when the argmax over axis 1 of the
            # label matches the argmax over axis 1 of the prediction.
            acc_val = np.equal(
                np.argmax(y_val, 1),
                np.argmax(predict_y, 1)).astype(np.float32)
            train_loss.append(loss_val[0])
            train_acc.append(acc_val)
            if auc_enabled:
                train_auc.append(metrics.roc_auc_score(y_val, predict_y))
        if auc_enabled:
            return np.mean(train_loss), np.mean(train_acc), np.mean(train_auc)
        return np.mean(train_loss), np.mean(train_acc)

    def validate(iterations, tqdm_enabled=False):
        """Run ``iterations`` steps of the 'validate' subgraph.

        Returns ``(mean_loss, mean_acc, mean_auc)``.
        """
        steps = range(iterations)
        if tqdm_enabled:
            steps = tqdm(steps)
        test_loss = []
        test_acc = []
        test_auc = []
        for _ in steps:
            loss_val, test_y_predicted, y_test_val = executor.run(
                'validate', convert_to_numpy_ret_vals=True)
            correct_prediction = np.equal(
                np.argmax(y_test_val, 1),
                np.argmax(test_y_predicted, 1)).astype(np.float32)
            test_loss.append(loss_val[0])
            test_acc.append(correct_prediction)
            test_auc.append(metrics.roc_auc_score(
                y_test_val, test_y_predicted))
        return np.mean(test_loss), np.mean(test_acc), np.mean(test_auc)

    batch_size = 128

    # Device placement for the graph as a whole, selected by --config.
    ctx = {
        'local': 'gpu:0',
        'lps': 'cpu:0,gpu:0,gpu:1,gpu:2,gpu:7',
        'lhy': 'cpu:0,gpu:1,gpu:2,gpu:3,gpu:6',
        'rps': 'cpu:0;daim118:gpu:0;daim118:gpu:2;daim118:gpu:4;daim118:gpu:6;daim117:gpu:1;daim117:gpu:3',
        'rhy': 'cpu:0;daim118:gpu:0;daim118:gpu:2;daim118:gpu:4;daim118:gpu:6;daim117:gpu:1;daim117:gpu:3'
    }[args.config]
    # Device placement handed to the model builder for its dense parameters;
    # for 'lhy' and 'rhy' it is the same list as ``ctx`` without 'cpu:0'.
    dense_param_ctx = {
        'local': 'gpu:0',
        'lps': 'cpu:0,gpu:0,gpu:1,gpu:2,gpu:7',
        'lhy': 'gpu:1,gpu:2,gpu:3,gpu:6',
        'rps': 'cpu:0;daim118:gpu:0;daim118:gpu:2;daim118:gpu:4;daim118:gpu:6;daim117:gpu:1;daim117:gpu:3',
        'rhy': 'daim118:gpu:0;daim118:gpu:2;daim118:gpu:4;daim118:gpu:6;daim117:gpu:1;daim117:gpu:3'
    }[args.config]

    # NOTE(review): graph construction and Executor creation are kept inside
    # the context block; the epoch loop runs after it. Confirm this matches
    # the intended scope of ht.context for this script.
    with ht.context(ctx):
        x_train_deep, x_train_wide, y_train, x_test_deep, x_test_wide, y_test = \
            load_data.load_adult_data()
        # One dataloader per deep-feature column (columns 0..11).
        dense_input = [
            ht.dataloader_op([
                [x_train_deep[:, i], batch_size, 'train'],
                [x_test_deep[:, i], batch_size, 'validate'],
            ]) for i in range(12)
        ]
        sparse_input = ht.dataloader_op([
            [x_train_wide, batch_size, 'train'],
            [x_test_wide, batch_size, 'validate'],
        ])
        y_ = ht.dataloader_op([
            [y_train, batch_size, 'train'],
            [y_test, batch_size, 'validate'],
        ])
        print("Data loaded.")

        loss, prediction, y_, train_op = wdl_adult.wdl_adult(
            dense_input, sparse_input, y_, dense_param_ctx)

        eval_nodes = {'train': [loss, prediction, y_, train_op]}
        if args.val:
            print('Validation enabled...')
            eval_nodes['validate'] = [loss, prediction, y_]
        executor = ht.Executor(eval_nodes,
                               cstable_policy=args.cache, bsp=args.bsp,
                               cache_bound=args.bound, seed=123)

    total_epoch = args.nepoch if args.nepoch > 0 else 50
    # The first epochs are treated as warm-up and excluded from the final
    # 'all time' figure. The clock is started up front so that runs with
    # fewer than ``warmup_epochs + 1`` epochs still have a defined start
    # (previously ``start`` was unbound in that case and the final print
    # raised NameError); such short runs report their full duration.
    warmup_epochs = 5
    start = time.time()
    for ep in range(total_epoch):
        if ep == warmup_epochs:
            start = time.time()
        print("epoch %d" % ep)
        ep_st = time.time()
        train_loss, train_acc = train(
            executor.get_batch_num('train'), auc_enabled=False)
        ep_en = time.time()
        if args.val:
            val_loss, val_acc, val_auc = validate(
                executor.get_batch_num('validate'))
            print("train_loss: %.4f, train_acc: %.4f, train_time: %.4f, test_loss: %.4f, test_acc: %.4f, test_auc: %.4f"
                  % (train_loss, train_acc, ep_en - ep_st, val_loss, val_acc, val_auc))
        else:
            print("train_loss: %.4f, train_acc: %.4f, train_time: %.4f"
                  % (train_loss, train_acc, ep_en - ep_st))
    print('all time:', time.time() - start)


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--config', type=str, default='local',
                        help='[local, lps(localps), lhy(localhybrid), rps(remoteps), rhy]')
    parser.add_argument("--val", action="store_true",
                        help="whether to use validation")
    parser.add_argument("--all", action="store_true",
                        help="whether to use all data")
    parser.add_argument("--bsp", action="store_true",
                        help="whether to use bsp instead of asp")
    parser.add_argument("--cache", default=None, help="cache policy")
    # type=int keeps a user-supplied bound the same type as the default;
    # without it argparse would pass the command-line value through as str.
    parser.add_argument("--bound", type=int, default=100, help="cache bound")
    parser.add_argument("--nepoch", type=int, default=-1,
                        help="num of epochs, each train 1/10 data")
    args = parser.parse_args()
    worker(args)