import os
import numpy as np
import tensorflow as tf
import time
import argparse
from tqdm import tqdm
from sklearn import metrics

from autodist import AutoDist
from autodist.resource_spec import ResourceSpec
from autodist.strategy import PS, PSLoadBalancing, PartitionedPS, AllReduce, Parallax
from autodist.strategy.base import Strategy
from autodist.kernel.common.utils import get_op_name
from tensorflow.python.framework import ops


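# Proxy environment variables inherited from the shell can interfere with the
# direct connections AutoDist sets up between workers, so drop them up front.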
def pop_env():
    for k in ['https_proxy', 'http_proxy']:
        if k in os.environ:
            os.environ.pop(k)


pop_env()

# Please DO NOT modify /etc/bash.bashrc to activate the conda environment.
# Use python_venv in the spec yml file instead.
# Use the absolute path of the python file.
# Here we use the TF native partitioner instead of AutoDist's PartitionedPS.


class Parallaxx(PSLoadBalancing, AllReduce):
    """
    Modified Parallax strategy that keeps variable replicas off the CPUs.
    """

    def __init__(self, chunk_size=128, local_proxy_variable=False, sync=True, staleness=0):
        PSLoadBalancing.__init__(self, local_proxy_variable, sync, staleness)
        AllReduce.__init__(self, chunk_size)

    # pylint: disable=attribute-defined-outside-init
    def build(self, graph_item, resource_spec):
        """Generate the strategy."""
        expr = Strategy()

        # Replicate the graph on every GPU; the CPUs serve as reduction devices.
        expr.graph_config.replicas.extend(
            [k for k, _ in resource_spec.gpu_devices])
        reduction_device_names = [k for k, _ in resource_spec.cpu_devices]
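        # Cumulative load assigned to each reduction device; PSLoadBalancing
        # consults this to place each PS variable on the least-loaded server.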
        self.loads = {ps: 0.0 for ps in reduction_device_names}

        # Generate a synchronizer config for each trainable variable.
        node_config = []
        for idx, var in enumerate(graph_item.trainable_var_op_to_var.values()):
            var_op_name = get_op_name(var.name)
            grad, _, _ = graph_item.var_op_name_to_grad_info[var_op_name]
            if isinstance(grad, ops.Tensor):  # a dense variable
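                # Dense variables are synchronized with all-reduce; variables in
                # the same chunk share a group so their gradients can be merged.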
                group_id = idx // self.chunk_size
                config = self._gen_all_reduce_node_config(
                    var.name, group=group_id)
            else:  # sparse updates go through the parameter servers
                # For the Parallax strategy, all PS variables are sparse, so no
                # local proxy variable is used: sparse variables tend to be large,
                # and each device usually needs only a small slice of them, so
                # keeping full local copies would be costly.
                config = self._gen_ps_node_config(
                    var,
                    False,
                    self._sync,
                    self._staleness
                )
            node_config.append(config)
        expr.node_config.extend(node_config)

        return expr


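# For reference, the resource spec consumed below is a YAML file; a minimal
# sketch (hypothetical address and device ids) could look like:
#
#   nodes:
#     - address: 127.0.0.1
#       gpus: [0, 1]
#       chief: true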
def train_criteo(model, args):
    resource_spec_file = os.path.join(
        os.path.dirname(__file__), 'settings', 'plx_local_spec.yml')
    autodist = AutoDist(resource_spec_file, Parallaxx())
    respec = ResourceSpec(resource_spec_file)
    if args.all:
        from models.load_data import process_all_criteo_data
        dense, sparse, all_labels = process_all_criteo_data()
        dense_feature, val_dense = dense
        sparse_feature, val_sparse = sparse
        labels, val_labels = all_labels
    else:
        from models.load_data import process_sampled_criteo_data
        dense_feature, sparse_feature, labels = process_sampled_criteo_data()

    # AutoDist will split the fed batches across the replicas.
    batch_size = 128
    with tf.Graph().as_default() as g, autodist.scope():
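        # Criteo examples have 13 dense (integer) features and 26 sparse
        # (categorical) features; the placeholder shapes mirror that layout.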
        dense_input = tf.compat.v1.placeholder(tf.float32, [batch_size, 13])
        sparse_input = tf.compat.v1.placeholder(tf.int32, [batch_size, 26])
        y_ = tf.compat.v1.placeholder(tf.float32, [batch_size, 1])
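        # Shard the embedding tables along axis 0, one shard per node, using TF's
        # native fixed-size partitioner (instead of AutoDist's PartitionedPS).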
        embed_partitioner = tf.compat.v1.fixed_size_partitioner(
            len(respec.nodes), 0) if len(respec.nodes) > 1 else None
        loss, y, opt = model(dense_input, sparse_input,
                             y_, embed_partitioner, False)
        train_op = opt.minimize(loss)

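        # create_distributed_session compiles the graph for the chosen strategy
        # and returns a session-like object, standing in for tf.compat.v1.Session.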
        sess = autodist.create_distributed_session()

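        # Preallocate the feed buffers once and refill them in place each step.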
        my_feed_dict = {
            dense_input: np.empty(shape=(batch_size, 13)),
            sparse_input: np.empty(shape=(batch_size, 26)),
            y_: np.empty(shape=(batch_size, 1)),
        }

        if args.all:
            raw_log_file = os.path.join(os.path.dirname(os.path.abspath(__file__)),
                                        'logs', 'tf_plx_%s.log' % (args.model))
            print('Processing all data, log to', raw_log_file)
            log_file = open(raw_log_file, 'w')
            iterations = dense_feature.shape[0] // batch_size
            total_epoch = 11
            start_index = 0
            for ep in range(total_epoch):
                print("epoch %d" % ep)
                st_time = time.time()
                train_loss, train_acc, train_auc = [], [], []
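                # Each reported "epoch" covers roughly a tenth of the dataset;
                # the leftover iterations are folded into epoch 9.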
                for it in tqdm(range(iterations // 10 + (ep % 10 == 9) * (iterations % 10))):
                    my_feed_dict[dense_input][:] = dense_feature[start_index: start_index + batch_size]
                    my_feed_dict[sparse_input][:] = sparse_feature[start_index: start_index + batch_size]
                    my_feed_dict[y_][:] = labels[start_index: start_index + batch_size]
                    start_index += batch_size
                    if start_index + batch_size > dense_feature.shape[0]:
                        start_index = 0
                    loss_val = sess.run(
                        [loss, y, y_, train_op], feed_dict=my_feed_dict)
                    pred_val = loss_val[1]
                    true_val = loss_val[2]
                    acc_val = np.equal(
                        true_val,
                        pred_val > 0.5)
                    train_loss.append(loss_val[0])
                    train_acc.append(acc_val)
                    train_auc.append(metrics.roc_auc_score(true_val, pred_val))
                tra_accuracy = np.mean(train_acc)
                tra_loss = np.mean(train_loss)
                tra_auc = np.mean(train_auc)
                en_time = time.time()
                train_time = en_time - st_time
                printstr = "train_loss: %.4f, train_acc: %.4f, train_auc: %.4f, train_time: %.4f"\
                    % (tra_loss, tra_accuracy, tra_auc, train_time)
                print(printstr)
                log_file.write(printstr + '\n')
                log_file.flush()

        else:
            iteration = dense_feature.shape[0] // batch_size

            epoch = 50
            for ep in range(epoch):
                print('epoch', ep)
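                # Treat the first five epochs as warm-up: overall timing starts
                # once epoch 5 begins.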
                if ep == 5:
                    start = time.time()
                ep_st = time.time()
                train_loss = []
                train_acc = []
                for idx in range(iteration):
                    start_index = idx * batch_size
                    my_feed_dict[dense_input][:] = dense_feature[start_index: start_index + batch_size]
                    my_feed_dict[sparse_input][:] = sparse_feature[start_index: start_index + batch_size]
                    my_feed_dict[y_][:] = labels[start_index: start_index + batch_size]

                    loss_val = sess.run(
                        [loss, y, y_, train_op], feed_dict=my_feed_dict)
                    pred_val = loss_val[1]
                    true_val = loss_val[2]
                    if pred_val.shape[1] == 1:  # binary classification (the Criteo case)
                        acc_val = np.equal(
                            true_val,
                            pred_val > 0.5)
                    else:  # multi-class: compare argmaxes of prediction and label
                        acc_val = np.equal(
                            np.argmax(pred_val, 1),
                            np.argmax(true_val, 1)).astype(np.float32)
                    train_loss.append(loss_val[0])
                    train_acc.append(acc_val)
                tra_accuracy = np.mean(train_acc)
                tra_loss = np.mean(train_loss)
                ep_en = time.time()
                print("train_loss: %.4f, train_acc: %.4f, train_time: %.4f"
                      % (tra_loss, tra_accuracy, ep_en - ep_st))
            print('all time:', (time.time() - start))


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("--model", type=str, required=True,
                        help="model to be tested")
    parser.add_argument("--all", action="store_true",
                        help="whether to use all data")
    args = parser.parse_args()
    raw_model = args.model
    import tf_models
    # Look up the model-building function by name; getattr avoids eval.
    model = getattr(tf_models, raw_model)
    dataset = raw_model.split('_')[-1]
    print('Model:', raw_model)

    if dataset == 'criteo':
        train_criteo(model, args)
    else:
        raise NotImplementedError


if __name__ == '__main__':
    main()
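# Example invocation (script and model names are placeholders; any builder
# exposed by tf_models whose name ends in '_criteo' fits the dispatch above):
#   python this_script.py --model wdl_criteo --all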