You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

dcn_criteo.py 2.1 kB

4 years ago
1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768
  1. import hetu as ht
  2. from hetu import init
  3. import numpy as np
  4. import time
  5. def cross_layer(x0, x1):
  6. # x0: input embedding feature (batch_size, 26 * embedding_size + 13)
  7. # x1: the output of last layer (batch_size, 26 * embedding_size + 13)
  8. embedding_len = 26 * 128 + 13
  9. weight = init.random_normal(
  10. shape=(embedding_len, 1), stddev=0.01, name='weight')
  11. bias = init.random_normal(shape=(embedding_len,), stddev=0.01, name='bias')
  12. x1w = ht.matmul_op(x1, weight) # (batch_size, 1)
  13. y = ht.mul_op(x0, ht.broadcastto_op(x1w, x0))
  14. y = y + x1 + ht.broadcastto_op(bias, y)
  15. return y
  16. def build_cross_layer(x0, num_layers=3):
  17. x1 = x0
  18. for i in range(num_layers):
  19. x1 = cross_layer(x0, x1)
  20. return x1
  21. def dcn_criteo(dense_input, sparse_input, y_):
  22. feature_dimension = 33762577
  23. embedding_size = 128
  24. learning_rate = 0.003
  25. Embedding = init.random_normal(
  26. [feature_dimension, embedding_size], stddev=0.01, name="snd_order_embedding", ctx=ht.cpu(0))
  27. sparse_input = ht.embedding_lookup_op(
  28. Embedding, sparse_input, ctx=ht.cpu(0))
  29. sparse_input = ht.array_reshape_op(sparse_input, (-1, 26*embedding_size))
  30. x = ht.concat_op(sparse_input, dense_input, axis=1)
  31. # Cross Network
  32. cross_output = build_cross_layer(x, num_layers=3)
  33. # DNN
  34. flatten = x
  35. W1 = init.random_normal(
  36. [26*embedding_size + 13, 256], stddev=0.01, name="W1")
  37. W2 = init.random_normal([256, 256], stddev=0.01, name="W2")
  38. W3 = init.random_normal([256, 256], stddev=0.01, name="W3")
  39. W4 = init.random_normal(
  40. [256 + 26*embedding_size + 13, 1], stddev=0.01, name="W4")
  41. fc1 = ht.matmul_op(flatten, W1)
  42. relu1 = ht.relu_op(fc1)
  43. fc2 = ht.matmul_op(relu1, W2)
  44. relu2 = ht.relu_op(fc2)
  45. y3 = ht.matmul_op(relu2, W3)
  46. y4 = ht.concat_op(cross_output, y3, axis=1)
  47. y = ht.matmul_op(y4, W4)
  48. y = ht.sigmoid_op(y)
  49. loss = ht.binarycrossentropy_op(y, y_)
  50. loss = ht.reduce_mean_op(loss, [0])
  51. opt = ht.optim.SGDOptimizer(learning_rate=learning_rate)
  52. train_op = opt.minimize(loss)
  53. return loss, y, y_, train_op

分布式深度学习系统 (Distributed deep learning system)

Contributors (1)