You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

Eager.Variables.cs 2.6 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566
  1. using Microsoft.VisualStudio.TestTools.UnitTesting;
  2. using Tensorflow.Eager;
  3. using static Tensorflow.Binding;
  namespace Tensorflow.Native.UnitTest.Eager
  {
      public partial class CApiEagerTest
      {
          /// <summary>
          /// Port of the C API test <c>TEST(CAPI, Variables)</c>: creates a variable
          /// holding <c>12.0f</c>, reads it back through a <c>ReadVariableOp</c>, and
          /// checks that the single output is a rank-0 <c>TF_FLOAT</c> tensor whose
          /// value is <c>12.0f</c>.
          /// </summary>
          [TestMethod]
          public unsafe void Variables()
          {
              using var status = c_api.TF_NewStatus();

              // The context options are only needed while the context is being created.
              static SafeContextHandle NewContext(SafeStatusHandle status)
              {
                  using var opts = c_api.TFE_NewContextOptions();
                  return c_api.TFE_NewContext(opts, status);
              }

              using var ctx = NewContext(status);
              ASSERT_EQ(TF_OK, TF_GetCode(status), TF_Message(status));

              using (var var_handle = CreateVariable(ctx, 12.0f, status))
              {
                  ASSERT_EQ(TF_OK, TF_GetCode(status), TF_Message(status));

                  var value_handle = new SafeEagerTensorHandle[1];
                  int num_retvals;
                  using (var op = TFE_NewOp(ctx, "ReadVariableOp", status))
                  {
                      ASSERT_EQ(TF_OK, TF_GetCode(status), TF_Message(status));
                      TFE_OpSetAttrType(op, "dtype", TF_FLOAT);
                      TFE_OpAddInput(op, var_handle, status);
                      ASSERT_EQ(TF_OK, TF_GetCode(status), TF_Message(status));
                      TFE_Execute(op, value_handle, out num_retvals, status);
                  }

                  try
                  {
                      // Check the status before the output count so that a failed
                      // execution reports TensorFlow's message rather than a bare
                      // count mismatch.
                      ASSERT_EQ(TF_OK, TF_GetCode(status), TF_Message(status));
                      ASSERT_EQ(1, num_retvals);
                      EXPECT_EQ(TF_FLOAT, TFE_TensorHandleDataType(value_handle[0]));
                      EXPECT_EQ(0, TFE_TensorHandleNumDims(value_handle[0], status));
                      ASSERT_EQ(TF_OK, TF_GetCode(status), TF_Message(status));

                      // `using` releases the resolved tensor even when an assertion
                      // below throws.
                      using var t = TFE_TensorHandleResolve(value_handle[0], status);
                      ASSERT_EQ(TF_OK, TF_GetCode(status), TF_Message(status));
                      ASSERT_EQ(sizeof(float), (int)TF_TensorByteSize(t));

                      var value = 0f;
                      tf.memcpy(&value, TF_TensorData(t).ToPointer(), sizeof(float));
                      EXPECT_EQ(12.0f, value);
                  }
                  finally
                  {
                      // The slot may still be null if TFE_Execute did not fill it;
                      // avoid a NullReferenceException masking the assertion failure.
                      value_handle[0]?.Dispose();
                  }
              }

              CHECK_EQ(TF_OK, TF_GetCode(status), TF_Message(status));
          }
      }
  }