You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

deepfm_criteo.py 2.1 kB

4 years ago
1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859
  1. import hetu as ht
  2. from hetu import init
  3. import numpy as np
  4. import time
  5. def dfm_criteo(dense_input, sparse_input, y_):
  6. feature_dimension = 33762577
  7. embedding_size = 128
  8. learning_rate = 0.01
  9. # FM
  10. Embedding1 = init.random_normal(
  11. [feature_dimension, 1], stddev=0.01, name="fst_order_embedding", ctx=ht.cpu(0))
  12. FM_W = init.random_normal([13, 1], stddev=0.01, name="dense_parameter")
  13. sparse_1dim_input = ht.embedding_lookup_op(
  14. Embedding1, sparse_input, ctx=ht.cpu(0))
  15. fm_dense_part = ht.matmul_op(dense_input, FM_W)
  16. fm_sparse_part = ht.reduce_sum_op(sparse_1dim_input, axes=1)
  17. # fst order output
  18. y1 = fm_dense_part + fm_sparse_part
  19. Embedding2 = init.random_normal(
  20. [feature_dimension, embedding_size], stddev=0.01, name="snd_order_embedding", ctx=ht.cpu(0))
  21. sparse_2dim_input = ht.embedding_lookup_op(
  22. Embedding2, sparse_input, ctx=ht.cpu(0))
  23. sparse_2dim_sum = ht.reduce_sum_op(sparse_2dim_input, axes=1)
  24. sparse_2dim_sum_square = ht.mul_op(sparse_2dim_sum, sparse_2dim_sum)
  25. sparse_2dim_square = ht.mul_op(sparse_2dim_input, sparse_2dim_input)
  26. sparse_2dim_square_sum = ht.reduce_sum_op(sparse_2dim_square, axes=1)
  27. sparse_2dim = sparse_2dim_sum_square + -1 * sparse_2dim_square_sum
  28. sparse_2dim_half = sparse_2dim * 0.5
  29. # snd order output
  30. y2 = ht.reduce_sum_op(sparse_2dim_half, axes=1, keepdims=True)
  31. # DNN
  32. flatten = ht.array_reshape_op(sparse_2dim_input, (-1, 26*embedding_size))
  33. W1 = init.random_normal([26*embedding_size, 256], stddev=0.01, name="W1")
  34. W2 = init.random_normal([256, 256], stddev=0.01, name="W2")
  35. W3 = init.random_normal([256, 1], stddev=0.01, name="W3")
  36. fc1 = ht.matmul_op(flatten, W1)
  37. relu1 = ht.relu_op(fc1)
  38. fc2 = ht.matmul_op(relu1, W2)
  39. relu2 = ht.relu_op(fc2)
  40. y3 = ht.matmul_op(relu2, W3)
  41. y4 = y1 + y2
  42. y = y4 + y3
  43. y = ht.sigmoid_op(y)
  44. loss = ht.binarycrossentropy_op(y, y_)
  45. loss = ht.reduce_mean_op(loss, [0])
  46. opt = ht.optim.SGDOptimizer(learning_rate=learning_rate)
  47. train_op = opt.minimize(loss)
  48. return loss, y, y_, train_op

分布式深度学习系统

Contributors (1)