You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

tf_wdl_adult.py 2.5 kB

4 years ago
1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677
  1. import tensorflow as tf
  2. import numpy as np
  def wdl_adult(X_deep, X_wide, y_, cluster=None, task_id=None,
                learning_rate=5.0 / 128):
      """Build a Wide & Deep (WDL) classifier graph (TF1 graph mode).

      Intended for the Adult dataset (per the function name). The deep tower
      looks up 8 id features in their own 20-row, 8-dim embedding tables,
      appends 4 dense features, and applies two ReLU layers (50 and 20 units).
      The wide part concatenates ``X_wide`` with the deep tower's 20-unit
      output and applies one linear layer producing 2 logits. Training is
      plain SGD on the mean softmax cross-entropy.

      Args:
          X_deep: indexable of 12 tensors. ``X_deep[0:8]`` are integer id
              tensors, each looked up in its own 20-row embedding table (ids
              presumably in [0, 20) -- TODO confirm against the input
              pipeline). ``X_deep[8:12]`` are dense features, each reshaped
              to one column; they must be float32 so they can be
              concatenated with the float32 embeddings.
          X_wide: float32 tensor with 809 columns (``W`` has 809 + 20 rows,
              the 20 extra rows consuming the deep tower's output).
          y_: labels with the same shape as the (batch, 2) logits, as
              required by ``tf.nn.softmax_cross_entropy_with_logits``.
          cluster: cluster specification passed to
              ``tf.train.replica_device_setter``. When given, variables are
              placed by the setter (parameter servers) and the worker's ops
              run on ``/job:worker/task:<task_id>/gpu:0``.
          task_id: this worker's task index; required when ``cluster`` is set.
          learning_rate: SGD step size. Defaults to 5/128, written as a float
              literal so it is not truncated to 0 by Python 2 integer division.

      Returns:
          ``(loss, y, train_op)`` in single-GPU mode, or
          ``(loss, y, train_op, global_step)`` when ``cluster`` is given.

      Raises:
          ValueError: if ``cluster`` is given without ``task_id``.
      """
      # Model sizes. The deep input width is derived from the pieces that are
      # actually concatenated below (8 * 8 + 4 = 68), so the two cannot drift.
      dim_wide = 809
      num_embed_fields = 8
      embed_vocab = 20
      embed_dim = 8
      num_dense_fields = 4
      hidden1, hidden2 = 50, 20
      dim_deep = num_embed_fields * embed_dim + num_dense_fields

      use_ps = cluster is not None
      if use_ps:
          if task_id is None:
              raise ValueError("task_id is required when cluster is given")
          device = tf.device(tf.train.replica_device_setter(
              worker_device="/job:worker/task:%d/gpu:0" % task_id,
              cluster=cluster))
      else:
          device = tf.device('/gpu:0')
          # Single-machine mode: the step counter is created outside the
          # explicit '/gpu:0' scope, presumably so the int32 counter is not
          # pinned to the GPU. NOTE(review): confirm.
          global_step = tf.Variable(0, name="global_step", trainable=False)

      with device:
          if use_ps:
              # Created under replica_device_setter so the counter is placed
              # by the setter. Only one step counter exists per graph now.
              # Earlier code also built an unused one first, so in
              # checkpoints it wrote, this variable was named "global_step_1".
              global_step = tf.Variable(0, name="global_step", trainable=False)

          # A fixed seed plus a fixed draw order (W, W1, b1, W2, b2, then the
          # embeddings) keeps the initial weights reproducible.
          rand = np.random.RandomState(seed=123)

          def _init(shape):
              # N(0, 0.1^2) initial value drawn from the seeded generator.
              return tf.Variable(rand.normal(scale=0.1, size=shape),
                                 dtype=tf.float32)

          W = _init([dim_wide + hidden2, 2])
          W1 = _init([dim_deep, hidden1])
          b1 = _init([hidden1])
          W2 = _init([hidden1, hidden2])
          b2 = _init([hidden2])
          Embedding = [_init([embed_vocab, embed_dim])
                       for _ in range(num_embed_fields)]

          # Deep tower input: one embedding per id feature followed by the
          # dense features, concatenated once along the feature axis.
          deep_parts = [
              tf.reshape(tf.nn.embedding_lookup(Embedding[i], X_deep[i]),
                         (-1, embed_dim))
              for i in range(num_embed_fields)
          ]
          deep_parts += [
              tf.reshape(X_deep[num_embed_fields + i], (-1, 1))
              for i in range(num_dense_fields)
          ]
          X_deep_input = tf.concat(deep_parts, 1)

          # Two fully connected ReLU layers; no dropout is applied.
          relu1 = tf.nn.relu(tf.add(tf.matmul(X_deep_input, W1), b1))
          dmodel = tf.nn.relu(tf.add(tf.matmul(relu1, W2), b2))

          # Wide part: one linear layer over the wide features concatenated
          # with the deep tower's output, giving 2 logits per example.
          y = tf.matmul(tf.concat([X_wide, dmodel], 1), W)
          loss = tf.reduce_mean(
              tf.nn.softmax_cross_entropy_with_logits(labels=y_, logits=y))
          optimizer = tf.train.GradientDescentOptimizer(learning_rate)
          train_op = optimizer.minimize(loss, global_step=global_step)

      if use_ps:
          return loss, y, train_op, global_step
      return loss, y, train_op

分布式深度学习系统

Contributors (1)