import hetu as ht
from hetu.launcher import launch
import os
import os.path as osp
import numpy as np
import yaml
import time
import argparse
from tqdm import tqdm
from sklearn import metrics


def worker(args):
    """Train (and optionally validate) a CTR model on a single worker.

    Builds dataloaders for the chosen dataset (criteo or adult), constructs
    a Hetu executor on one GPU, and runs the epoch loop, printing (and for
    the full-criteo run, logging to file) per-epoch loss/accuracy/AUC.
    Supports local, 'PS' and 'Hybrid' communication modes via ``args.comm``.
    """

    def train(iterations, auc_enabled=True, tqdm_enabled=False):
        """Run `iterations` training steps; return mean loss, acc (and AUC)."""
        localiter = tqdm(range(iterations)
                         ) if tqdm_enabled else range(iterations)
        train_loss = []
        train_acc = []
        if auc_enabled:
            train_auc = []
        for it in localiter:
            loss_val, predict_y, y_val, _ = executor.run(
                'train', convert_to_numpy_ret_vals=True)
            if y_val.shape[1] == 1:
                # binary label (criteo case): threshold predicted probability
                acc_val = np.equal(
                    y_val, predict_y > 0.5).astype(np.float32)
            else:
                # one-hot labels: compare argmax of prediction and target
                acc_val = np.equal(
                    np.argmax(y_val, 1),
                    np.argmax(predict_y, 1)).astype(np.float32)
            train_loss.append(loss_val[0])
            train_acc.append(acc_val)
            if auc_enabled:
                train_auc.append(metrics.roc_auc_score(y_val, predict_y))
        if auc_enabled:
            return np.mean(train_loss), np.mean(train_acc), np.mean(train_auc)
        else:
            return np.mean(train_loss), np.mean(train_acc)

    def validate(iterations, tqdm_enabled=False):
        """Run `iterations` validation steps; return mean loss, acc and AUC."""
        localiter = tqdm(range(iterations)
                         ) if tqdm_enabled else range(iterations)
        test_loss = []
        test_acc = []
        test_auc = []
        for it in localiter:
            loss_val, test_y_predicted, y_test_val = executor.run(
                'validate', convert_to_numpy_ret_vals=True)
            if y_test_val.shape[1] == 1:  # binary label (criteo case)
                correct_prediction = np.equal(
                    y_test_val, test_y_predicted > 0.5).astype(np.float32)
            else:
                correct_prediction = np.equal(
                    np.argmax(y_test_val, 1),
                    np.argmax(test_y_predicted, 1)).astype(np.float32)
            test_loss.append(loss_val[0])
            test_acc.append(correct_prediction)
            test_auc.append(metrics.roc_auc_score(
                y_test_val, test_y_predicted))
        return np.mean(test_loss), np.mean(test_acc), np.mean(test_auc)

    def get_current_shard(data):
        """Return this worker's contiguous shard of `data` (all of it locally).

        `rank`/`nrank` come from the enclosing scope and are only set in the
        'PS'/'Hybrid' branches below, matching the `args.comm is not None` guard.
        """
        if args.comm is not None:
            part_size = data.shape[0] // nrank
            start = part_size * rank
            # last worker also takes the remainder rows
            end = start + part_size if rank != nrank - 1 else data.shape[0]
            return data[start:end]
        else:
            return data

    batch_size = 128
    dataset = args.dataset
    model = args.model
    device_id = 0

    if args.comm == 'PS':
        rank = ht.get_worker_communicate().rank()
        nrank = int(os.environ['DMLC_NUM_WORKER'])
        device_id = rank % 8
    elif args.comm == 'Hybrid':
        comm = ht.wrapped_mpi_nccl_init()
        device_id = comm.dev_id
        rank = comm.rank
        nrank = int(os.environ['DMLC_NUM_WORKER'])

    if dataset == 'criteo':
        # define models for criteo
        if args.all:
            from models.load_data import process_all_criteo_data
            dense, sparse, labels = process_all_criteo_data(
                return_val=args.val)
        elif args.val:
            from models.load_data import process_head_criteo_data
            dense, sparse, labels = process_head_criteo_data(return_val=True)
        else:
            from models.load_data import process_sampled_criteo_data
            dense, sparse, labels = process_sampled_criteo_data()
        if isinstance(dense, tuple):
            # loaders returned (train, validate) pairs
            dense_input = ht.dataloader_op([[get_current_shard(dense[0]), batch_size, 'train'], [
                get_current_shard(dense[1]), batch_size, 'validate']])
            sparse_input = ht.dataloader_op([[get_current_shard(sparse[0]), batch_size, 'train'], [
                get_current_shard(sparse[1]), batch_size, 'validate']])
            y_ = ht.dataloader_op([[get_current_shard(labels[0]), batch_size, 'train'], [
                get_current_shard(labels[1]), batch_size, 'validate']])
        else:
            dense_input = ht.dataloader_op(
                [[get_current_shard(dense), batch_size, 'train']])
            sparse_input = ht.dataloader_op(
                [[get_current_shard(sparse), batch_size, 'train']])
            y_ = ht.dataloader_op(
                [[get_current_shard(labels), batch_size, 'train']])
    elif dataset == 'adult':
        from models.load_data import load_adult_data
        x_train_deep, x_train_wide, y_train, x_test_deep, x_test_wide, y_test = load_adult_data()
        # one dataloader per deep (dense) column; wide features stay sparse
        dense_input = [
            ht.dataloader_op([
                [get_current_shard(x_train_deep[:, i]), batch_size, 'train'],
                [get_current_shard(x_test_deep[:, i]), batch_size, 'validate'],
            ]) for i in range(12)
        ]
        sparse_input = ht.dataloader_op([
            [get_current_shard(x_train_wide), batch_size, 'train'],
            [get_current_shard(x_test_wide), batch_size, 'validate'],
        ])
        y_ = ht.dataloader_op([
            [get_current_shard(y_train), batch_size, 'train'],
            [get_current_shard(y_test), batch_size, 'validate'],
        ])
    else:
        raise NotImplementedError
    print("Data loaded.")

    loss, prediction, y_, train_op = model(dense_input, sparse_input, y_)

    eval_nodes = {'train': [loss, prediction, y_, train_op]}
    if args.val:
        print('Validation enabled...')
        eval_nodes['validate'] = [loss, prediction, y_]
    executor_log_path = osp.join(osp.dirname(osp.abspath(__file__)), 'logs')
    executor = ht.Executor(eval_nodes, ctx=ht.gpu(device_id),
                           comm_mode=args.comm, cstable_policy=args.cache,
                           bsp=args.bsp, cache_bound=args.bound,
                           seed=123, log_path=executor_log_path)

    if args.all and dataset == 'criteo':
        print('Processing all data...')
        file_path = '%s_%s' % ({None: 'local', 'PS': 'ps', 'Hybrid': 'hybrid'}[
            args.comm], args.raw_model)
        file_path += '%d.log' % rank if args.comm else '.log'
        file_path = osp.join(osp.dirname(
            osp.abspath(__file__)), 'logs', file_path)
        log_file = open(file_path, 'w')
        total_epoch = args.nepoch if args.nepoch > 0 else 11
        for ep in range(total_epoch):
            print("ep: %d" % ep)
            ep_st = time.time()
            # each "epoch" consumes 1/10 of the training batches; every 10th
            # epoch additionally consumes the remainder batches
            train_loss, train_acc, train_auc = train(executor.get_batch_num(
                'train') // 10 + (ep % 10 == 9) * (executor.get_batch_num('train') % 10),
                tqdm_enabled=True)
            ep_en = time.time()
            if args.val:
                val_loss, val_acc, val_auc = validate(
                    executor.get_batch_num('validate'))
                printstr = "train_loss: %.4f, train_acc: %.4f, train_auc: %.4f, test_loss: %.4f, test_acc: %.4f, test_auc: %.4f, train_time: %.4f"\
                    % (train_loss, train_acc, train_auc, val_loss, val_acc, val_auc, ep_en - ep_st)
            else:
                printstr = "train_loss: %.4f, train_acc: %.4f, train_auc: %.4f, train_time: %.4f"\
                    % (train_loss, train_acc, train_auc, ep_en - ep_st)
            print(printstr)
            log_file.write(printstr + '\n')
            log_file.flush()
    else:
        total_epoch = args.nepoch if args.nepoch > 0 else 50
        # BUGFIX: initialize `start` before the loop; previously it was only
        # assigned at ep == 5, so runs with nepoch <= 5 crashed with a
        # NameError at the final timing print. The ep == 5 reset (which
        # excludes warm-up epochs from the reported time) is kept.
        start = time.time()
        for ep in range(total_epoch):
            if ep == 5:
                start = time.time()
            print("epoch %d" % ep)
            ep_st = time.time()
            train_loss, train_acc = train(
                executor.get_batch_num('train'), auc_enabled=False)
            ep_en = time.time()
            if args.val:
                val_loss, val_acc, val_auc = validate(
                    executor.get_batch_num('validate'))
                print("train_loss: %.4f, train_acc: %.4f, train_time: %.4f, test_loss: %.4f, test_acc: %.4f, test_auc: %.4f"
                      % (train_loss, train_acc, ep_en - ep_st, val_loss, val_acc, val_auc))
            else:
                print("train_loss: %.4f, train_acc: %.4f, train_time: %.4f"
                      % (train_loss, train_acc, ep_en - ep_st))
        print('all time:', time.time() - start)


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument("--model", type=str, required=True,
                        help="model to be tested")
    parser.add_argument("--val", action="store_true",
                        help="whether to use validation")
    parser.add_argument("--all", action="store_true",
                        help="whether to use all data")
    parser.add_argument("--comm", default=None,
                        help="whether to use distributed setting, can be None, AllReduce, PS, Hybrid")
    parser.add_argument("--bsp", action="store_true",
                        help="whether to use bsp instead of asp")
    parser.add_argument("--cache", default=None, help="cache policy")
    # BUGFIX: was missing type=int, so a CLI-supplied --bound arrived as a
    # string while the default stayed an int
    parser.add_argument("--bound", default=100, type=int, help="cache bound")
    parser.add_argument("--config", type=str, default=osp.join(osp.dirname(
        osp.abspath(__file__)), "./settings/local_s1_w4.yml"), help="configuration for ps")
    parser.add_argument("--nepoch", type=int, default=-1,
                        help="num of epochs, each train 1/10 data")
    args = parser.parse_args()

    import models
    print('Model:', args.model)
    # look the model builder up by attribute name instead of eval()-ing a
    # CLI-provided string
    model = getattr(models, args.model)
    args.dataset = args.model.split('_')[-1]
    args.raw_model = args.model
    args.model = model

    if args.comm is None:
        worker(args)
    elif args.comm == 'Hybrid':
        # BUGFIX: close the config file (was open(...).read() with no close)
        with open(args.config) as f:
            settings = yaml.load(f.read(), Loader=yaml.FullLoader)
        value = settings['shared']
        os.environ['DMLC_ROLE'] = 'worker'
        for k, v in value.items():
            os.environ[k] = str(v)
        worker(args)
    elif args.comm == 'PS':
        launch(worker, args)
    else:
        raise NotImplementedError