You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

CApi.Eager.Variables.cs 2.6 kB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667
  1. using Microsoft.VisualStudio.TestTools.UnitTesting;
  2. using Tensorflow;
  3. using Tensorflow.Eager;
  4. using static Tensorflow.Binding;
  5. namespace TensorFlowNET.UnitTest.NativeAPI
  6. {
  7. public partial class CApiEagerTest
  8. {
  9. /// <summary>
  10. /// TEST(CAPI, Variables)
  11. /// </summary>
  12. [TestMethod]
  13. public unsafe void Variables()
  14. {
  15. using var status = c_api.TF_NewStatus();
  16. static SafeContextHandle NewContext(SafeStatusHandle status)
  17. {
  18. using var opts = c_api.TFE_NewContextOptions();
  19. return c_api.TFE_NewContext(opts, status);
  20. }
  21. using var ctx = NewContext(status);
  22. ASSERT_EQ(TF_OK, TF_GetCode(status), TF_Message(status));
  23. using (var var_handle = CreateVariable(ctx, 12.0f, status))
  24. {
  25. ASSERT_EQ(TF_OK, TF_GetCode(status), TF_Message(status));
  26. int num_retvals = 1;
  27. var value_handle = new SafeTensorHandleHandle[1];
  28. using (var op = TFE_NewOp(ctx, "ReadVariableOp", status))
  29. {
  30. ASSERT_EQ(TF_OK, TF_GetCode(status), TF_Message(status));
  31. TFE_OpSetAttrType(op, "dtype", TF_FLOAT);
  32. TFE_OpAddInput(op, var_handle, status);
  33. ASSERT_EQ(TF_OK, TF_GetCode(status), TF_Message(status));
  34. TFE_Execute(op, value_handle, out num_retvals, status);
  35. ASSERT_EQ(1, num_retvals);
  36. }
  37. try
  38. {
  39. ASSERT_EQ(TF_OK, TF_GetCode(status), TF_Message(status));
  40. ASSERT_EQ(1, num_retvals);
  41. EXPECT_EQ(TF_FLOAT, TFE_TensorHandleDataType(value_handle[0]));
  42. EXPECT_EQ(0, TFE_TensorHandleNumDims(value_handle[0], status));
  43. ASSERT_EQ(TF_OK, TF_GetCode(status), TF_Message(status));
  44. var value = 0f; // new float[1];
  45. var t = TFE_TensorHandleResolve(value_handle[0], status);
  46. ASSERT_EQ(TF_OK, TF_GetCode(status), TF_Message(status));
  47. ASSERT_EQ(sizeof(float), (int)TF_TensorByteSize(t));
  48. tf.memcpy(&value, TF_TensorData(t).ToPointer(), sizeof(float));
  49. c_api.TF_DeleteTensor(t);
  50. EXPECT_EQ(12.0f, value);
  51. }
  52. finally
  53. {
  54. value_handle[0].Dispose();
  55. }
  56. }
  57. CHECK_EQ(TF_OK, TF_GetCode(status), TF_Message(status));
  58. }
  59. }
  60. }