You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

tf_wdl_criteo.py 1.9 kB

4 years ago
12345678910111213141516171819202122232425262728293031323334353637383940
  1. import tensorflow as tf
  2. def wdl_criteo(dense_input, sparse_input, y_, partitioner=None, part_all=True, param_on_gpu=True):
  3. feature_dimension = 33762577
  4. embedding_size = 128
  5. learning_rate = 0.01 / 8 # here to comply with HETU
  6. all_partitioner, embed_partitioner = (
  7. partitioner, None) if part_all else (None, partitioner)
  8. with tf.compat.v1.variable_scope('wdl', dtype=tf.float32, initializer=tf.random_normal_initializer(stddev=0.01), partitioner=all_partitioner):
  9. with tf.device('/cpu:0'):
  10. Embedding = tf.compat.v1.get_variable(name="Embedding", shape=(
  11. feature_dimension, embedding_size), partitioner=embed_partitioner)
  12. sparse_input_embedding = tf.nn.embedding_lookup(
  13. Embedding, sparse_input)
  14. device = '/gpu:0' if param_on_gpu else '/cpu:0'
  15. with tf.device(device):
  16. W1 = tf.compat.v1.get_variable(name='W1', shape=[13, 256])
  17. W2 = tf.compat.v1.get_variable(name='W2', shape=[256, 256])
  18. W3 = tf.compat.v1.get_variable(name='W3', shape=[256, 256])
  19. W4 = tf.compat.v1.get_variable(
  20. name='W4', shape=[256 + 26 * embedding_size, 1])
  21. with tf.device('/gpu:0'):
  22. sparse_input_embedding = tf.reshape(
  23. sparse_input_embedding, (-1, 26*embedding_size))
  24. flatten = dense_input
  25. fc1 = tf.matmul(flatten, W1)
  26. relu1 = tf.nn.relu(fc1)
  27. fc2 = tf.matmul(relu1, W2)
  28. relu2 = tf.nn.relu(fc2)
  29. y3 = tf.matmul(relu2, W3)
  30. y4 = tf.concat((sparse_input_embedding, y3), 1)
  31. y = tf.matmul(y4, W4)
  32. loss = tf.reduce_mean(
  33. tf.nn.sigmoid_cross_entropy_with_logits(logits=y, labels=y_))
  34. optimizer = tf.compat.v1.train.GradientDescentOptimizer(
  35. learning_rate)
  36. return loss, y, optimizer