You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

wdl_adult.py 1.7 kB

4 years ago
1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556
  1. import hetu as ht
  2. from hetu import init
  def wdl_adult(X_deep, X_wide, y_):
      """Build a wide-and-deep classification training graph.

      The deep tower embeds 8 inputs, appends 4 raw inputs, and runs two
      ReLU layers. Its output is concatenated with the wide input and fed
      through a single linear layer producing 2 logits, trained with SGD on
      softmax cross-entropy. (The ``adult`` in the name presumably refers to
      the UCI Adult dataset — confirm against the data loader.)

      Args:
          X_deep: Indexable sequence of 12 input nodes.
              - Entries 0-7 are passed to ``ht.embedding_lookup_op``. They are
                presumably integer ids in ``[0, 50)``; TODO confirm.
              - Entries 8-11 are reshaped to ``(-1, 1)`` and concatenated as-is.
          X_wide: Wide-part input node. It must have ``dim_wide`` (809) columns
              so that, after appending the 20-column deep output, it matches
              the rows of ``W``.
          y_: Label node for ``ht.softmaxcrossentropy_op``. It is assumed to be
              one-hot over 2 classes, matching ``W``'s 2 output columns;
              confirm with the caller.

      Returns:
          ``(loss, prediction, y_, train_op)``:
              - ``loss``: the cross-entropy reduced with ``reduce_mean_op``
                over axis 0.
              - ``prediction``: the 2-column logits.
              - ``y_``: the label node, returned unchanged.
              - ``train_op``: the SGD minimize op.
      """
      lr = 5 / 128

      # Model sizes. The dependent dimensions are derived rather than
      # hard-coded, so W and W1 always match the tensors actually
      # concatenated below.
      num_sparse = 8   # inputs looked up in embedding tables: X_deep[0:8]
      num_dense = 4    # inputs appended directly: X_deep[8:12]
      vocab_size = 50  # rows in each embedding table
      embed_dim = 8    # width of each embedding
      hidden1 = 50
      hidden2 = 20     # deep-tower output width, concatenated into the wide part
      num_classes = 2
      dim_wide = 809
      dim_deep = num_sparse * embed_dim + num_dense  # = 68

      # Parameters are created in the same order as before.
      W = init.random_normal(
          [dim_wide + hidden2, num_classes], stddev=0.1, name="W")
      W1 = init.random_normal([dim_deep, hidden1], stddev=0.1, name="W1")
      b1 = init.random_normal([hidden1], stddev=0.1, name="b1")
      W2 = init.random_normal([hidden1, hidden2], stddev=0.1, name="W2")
      b2 = init.random_normal([hidden2], stddev=0.1, name="b2")

      # deep: one embedding table per embedded input, concatenated column-wise.
      X_deep_input = None
      for i in range(num_sparse):
          table = init.random_normal(
              [vocab_size, embed_dim], stddev=0.1,
              name="Embedding_deep_" + str(i))
          now = ht.embedding_lookup_op(table, X_deep[i])
          now = ht.array_reshape_op(now, (-1, embed_dim))
          # concat_op joins two nodes at a time, so fold left-to-right.
          if X_deep_input is None:
              X_deep_input = now
          else:
              X_deep_input = ht.concat_op(X_deep_input, now, 1)
      for i in range(num_dense):
          now = ht.array_reshape_op(X_deep[i + num_sparse], (-1, 1))
          X_deep_input = ht.concat_op(X_deep_input, now, 1)

      # Two fully connected ReLU layers. No dropout is applied here.
      mat1 = ht.matmul_op(X_deep_input, W1)
      relu1 = ht.relu_op(mat1 + ht.broadcastto_op(b1, mat1))
      mat2 = ht.matmul_op(relu1, W2)
      dmodel = ht.relu_op(mat2 + ht.broadcastto_op(b2, mat2))

      # wide: join the wide input with the deep-tower output, then apply a
      # single linear layer to get the class logits.
      wmodel = ht.concat_op(X_wide, dmodel, 1)
      prediction = ht.matmul_op(wmodel, W)

      loss = ht.softmaxcrossentropy_op(prediction, y_)
      loss = ht.reduce_mean_op(loss, [0])
      opt = ht.optim.SGDOptimizer(learning_rate=lr)
      train_op = opt.minimize(loss)
      return loss, prediction, y_, train_op

分布式深度学习系统

Contributors (1)