You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

test_datatransfer_op.py 2.0 kB

4 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960
  1. import numpy as np
  2. import hetu as ht
  3. def test_dense():
  4. npw = np.random.random((5, 10)).astype(np.float32)
  5. npx = np.random.random((7, 5)).astype(np.float32)
  6. cpuctx = ht.cpu(0)
  7. gpuctx = ht.gpu(0)
  8. X = ht.Variable(name="x")
  9. mid = X + 3
  10. W = ht.Variable(name='w', value=npw, ctx=cpuctx)
  11. y = ht.matmul_op(mid, W)
  12. opt = ht.optim.SGDOptimizer(learning_rate=0.1)
  13. train_op = opt.minimize(y)
  14. executor = ht.Executor([y, train_op], ctx=gpuctx)
  15. pred_y, _ = executor.run(
  16. feed_dict={X: npx}, convert_to_numpy_ret_vals=True)
  17. nppred_y = np.matmul((npx + 3), npw)
  18. np.testing.assert_allclose(pred_y, nppred_y, rtol=1e-6)
  19. new_npw = npw - 0.1 * \
  20. np.matmul((npx + 3).T, np.ones(nppred_y.shape).astype(np.float32))
  21. np.testing.assert_allclose(
  22. executor.config.placeholder_to_arr_map[W].asnumpy(), new_npw, rtol=1e-10)
  23. test_dense()
  24. def test_sparse():
  25. npemb = np.random.random((100, 20)).astype(np.float32)
  26. npind = np.array(np.random.randint(100, size=(10,)))
  27. npw = np.random.random((20, 30)).astype(np.float32)
  28. cpuctx = ht.cpu(0)
  29. gpuctx = ht.gpu(0)
  30. embedding = ht.Variable('embeddingtable', value=npemb, ctx=cpuctx)
  31. index = ht.Variable(name="index", ctx=cpuctx)
  32. W = ht.Variable(name="w", value=npw)
  33. y = ht.embedding_lookup_op(embedding, index) # (10, 20)
  34. y = ht.matmul_op(y, W)
  35. opt = ht.optim.SGDOptimizer(0.1)
  36. train_op = opt.minimize(y)
  37. executor = ht.Executor([y, train_op], ctx=gpuctx)
  38. out, _ = executor.run(feed_dict={index: npind.astype(
  39. np.float32)}, convert_to_numpy_ret_vals=True)
  40. np_out = np.matmul(npemb[npind], npw)
  41. np.testing.assert_allclose(out, np_out, rtol=1e-6)
  42. tmp_grad = np.matmul(np.ones(np_out.shape).astype(np.float32), npw.T)
  43. for i, localid in enumerate(npind):
  44. npemb[localid] -= 0.1 * tmp_grad[i]
  45. np.testing.assert_allclose(
  46. executor.config.placeholder_to_arr_map[embedding].asnumpy(), npemb, rtol=1e-6)
  47. test_sparse()

分布式深度学习系统