You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; they can include dashes ('-') and can be up to 35 characters long.

wdl_criteo.py 1.3 kB

4 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142
  1. import hetu as ht
  2. from hetu import init
  3. import numpy as np
  4. import time
  5. def wdl_criteo(dense_input, sparse_input, y_):
  6. feature_dimension = 33762577
  7. embedding_size = 128
  8. learning_rate = 0.01
  9. Embedding = init.random_normal(
  10. [feature_dimension, embedding_size], stddev=0.01, name="snd_order_embedding", ctx=ht.cpu(0))
  11. sparse_input = ht.embedding_lookup_op(
  12. Embedding, sparse_input, ctx=ht.cpu(0))
  13. sparse_input = ht.array_reshape_op(sparse_input, (-1, 26*embedding_size))
  14. # DNN
  15. flatten = dense_input
  16. W1 = init.random_normal([13, 256], stddev=0.01, name="W1")
  17. W2 = init.random_normal([256, 256], stddev=0.01, name="W2")
  18. W3 = init.random_normal([256, 256], stddev=0.01, name="W3")
  19. W4 = init.random_normal(
  20. [256 + 26*embedding_size, 1], stddev=0.01, name="W4")
  21. fc1 = ht.matmul_op(flatten, W1)
  22. relu1 = ht.relu_op(fc1)
  23. fc2 = ht.matmul_op(relu1, W2)
  24. relu2 = ht.relu_op(fc2)
  25. y3 = ht.matmul_op(relu2, W3)
  26. y4 = ht.concat_op(sparse_input, y3, axis=1)
  27. y = ht.matmul_op(y4, W4)
  28. y = ht.sigmoid_op(y)
  29. loss = ht.binarycrossentropy_op(y, y_)
  30. loss = ht.reduce_mean_op(loss, [0])
  31. opt = ht.optim.SGDOptimizer(learning_rate=learning_rate)
  32. train_op = opt.minimize(loss)
  33. return loss, y, y_, train_op

分布式深度学习系统

Contributors (1)