You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-'), and can be up to 35 characters long.

CApi.Eager.Variables.cs 2.7 kB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768
  1. using Microsoft.VisualStudio.TestTools.UnitTesting;
  2. using System;
  3. using Tensorflow;
  4. using Tensorflow.Eager;
  5. using static Tensorflow.Binding;
  namespace TensorFlowNET.UnitTest.NativeAPI
  {
      public partial class CApiEagerTest
      {
          /// <summary>
          /// Port of the native <c>TEST(CAPI, Variables)</c> test: creates a variable via
          /// <c>CreateVariable</c> with initial value 12.0f, reads it back by executing a
          /// <c>ReadVariableOp</c>, and checks that the single output is a rank-0
          /// <c>TF_FLOAT</c> tensor holding 12.0f.
          /// </summary>
          /// <remarks>
          /// The status, context, variable handle and op are scoped with <c>using</c>; the
          /// output tensor handle and the resolved tensor are released in <c>finally</c>
          /// blocks so a failed assertion does not leak them or mask the original failure.
          /// </remarks>
          [TestMethod]
          public unsafe void Variables()
          {
              using var status = c_api.TF_NewStatus();

              // Builds an eager context from default options; the options object is only
              // needed for the duration of the TFE_NewContext call.
              static SafeContextHandle NewContext(SafeStatusHandle status)
              {
                  using var opts = c_api.TFE_NewContextOptions();
                  return c_api.TFE_NewContext(opts, status);
              }

              using var ctx = NewContext(status);
              ASSERT_EQ(TF_OK, TF_GetCode(status), TF_Message(status));

              using (var var_handle = CreateVariable(ctx, 12.0f, status))
              {
                  ASSERT_EQ(TF_OK, TF_GetCode(status), TF_Message(status));

                  var value_handle = new SafeTensorHandleHandle[1];
                  try
                  {
                      int num_retvals;
                      using (var op = TFE_NewOp(ctx, "ReadVariableOp", status))
                      {
                          ASSERT_EQ(TF_OK, TF_GetCode(status), TF_Message(status));
                          TFE_OpSetAttrType(op, "dtype", TF_FLOAT);
                          TFE_OpAddInput(op, var_handle, status);
                          ASSERT_EQ(TF_OK, TF_GetCode(status), TF_Message(status));
                          TFE_Execute(op, value_handle, out num_retvals, status);
                      }

                      // Check the execution status before the output count, so a failing op
                      // surfaces TensorFlow's own error message rather than a count mismatch.
                      ASSERT_EQ(TF_OK, TF_GetCode(status), TF_Message(status));
                      ASSERT_EQ(1, num_retvals);

                      EXPECT_EQ(TF_FLOAT, TFE_TensorHandleDataType(value_handle[0]));
                      EXPECT_EQ(0, TFE_TensorHandleNumDims(value_handle[0], status));
                      ASSERT_EQ(TF_OK, TF_GetCode(status), TF_Message(status));

                      var value = 0f;
                      var t = TFE_TensorHandleResolve(value_handle[0], status);
                      ASSERT_EQ(TF_OK, TF_GetCode(status), TF_Message(status));
                      try
                      {
                          ASSERT_EQ(sizeof(float), (int)TF_TensorByteSize(t));
                          tf.memcpy(&value, TF_TensorData(t).ToPointer(), sizeof(float));
                      }
                      finally
                      {
                          // Only reached once the resolve succeeded, so t is a live tensor.
                          c_api.TF_DeleteTensor(t);
                      }

                      EXPECT_EQ(12.0f, value);
                  }
                  finally
                  {
                      // The output slot may still be unset if the op failed before producing
                      // a result; a null-safe dispose keeps the real failure visible.
                      value_handle[0]?.Dispose();
                  }
              }

              CHECK_EQ(TF_OK, TF_GetCode(status), TF_Message(status));
          }
      }
  }