You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

wdl_adult.py 1.8 kB

4 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657
  1. import hetu as ht
  2. from hetu import init
  3. def wdl_adult(X_deep, X_wide, y_, dense_param_ctx):
  4. lr = 5 / 128
  5. dim_wide = 809
  6. dim_deep = 68
  7. with ht.context(dense_param_ctx):
  8. W = init.random_normal([dim_wide+20, 2], stddev=0.1, name="W")
  9. W1 = init.random_normal([dim_deep, 50], stddev=0.1, name="W1")
  10. b1 = init.random_normal([50], stddev=0.1, name="b1")
  11. W2 = init.random_normal([50, 20], stddev=0.1, name="W2")
  12. b2 = init.random_normal([20], stddev=0.1, name="b2")
  13. # deep
  14. Embedding = []
  15. X_deep_input = None
  16. for i in range(8):
  17. Embedding_name = "Embedding_deep_" + str(i)
  18. Embedding.append(init.random_normal(
  19. [50, 8], stddev=0.1, name=Embedding_name))
  20. now = ht.embedding_lookup_op(Embedding[i], X_deep[i])
  21. now = ht.array_reshape_op(now, (-1, 8))
  22. if X_deep_input is None:
  23. X_deep_input = now
  24. else:
  25. X_deep_input = ht.concat_op(X_deep_input, now, 1)
  26. for i in range(4):
  27. now = ht.array_reshape_op(X_deep[i + 8], (-1, 1))
  28. X_deep_input = ht.concat_op(X_deep_input, now, 1)
  29. mat1 = ht.matmul_op(X_deep_input, W1)
  30. add1 = mat1 + ht.broadcastto_op(b1, mat1)
  31. relu1 = ht.relu_op(add1)
  32. dropout1 = relu1
  33. mat2 = ht.matmul_op(dropout1, W2)
  34. add2 = mat2 + ht.broadcastto_op(b2, mat2)
  35. relu2 = ht.relu_op(add2)
  36. dropout2 = relu2
  37. dmodel = dropout2
  38. # wide
  39. wmodel = ht.concat_op(X_wide, dmodel, 1)
  40. wmodel = ht.matmul_op(wmodel, W)
  41. prediction = wmodel
  42. loss = ht.softmaxcrossentropy_op(prediction, y_)
  43. loss = ht.reduce_mean_op(loss, [0])
  44. opt = ht.optim.SGDOptimizer(learning_rate=lr)
  45. train_op = opt.minimize(loss)
  46. return loss, prediction, y_, train_op