You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

CApiVariableTest.cs 2.7 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778
  using System;
  using System.Runtime.InteropServices;
  using Microsoft.VisualStudio.TestTools.UnitTesting;
  using Tensorflow;
  using Tensorflow.Eager;
  namespace TensorFlowNET.UnitTest.Eager
  {
      /// <summary>
      /// Eager-mode variable tests ported from tensorflow\c\eager\c_api_test.cc.
      /// </summary>
      [TestClass]
      public class CApiVariableTest : CApiTest, IDisposable
      {
          // Status object shared by every native call made from this class.
          Status status = new Status();
          // Context options; released (and the field cleared) once the context exists.
          ContextOptions opts = new ContextOptions();
          // Eager context; only assigned inside Variables(), so it may still be null.
          Context ctx;

          // NOTE(review): the [TestMethod] attribute is commented out, so this test
          // does not currently run. Re-enable once the eager bindings are confirmed.
          /// <summary>
          /// Creates an eager context and a float variable initialized to 12.0,
          /// asserting that every step reports TF_OK and yields a non-null handle.
          /// </summary>
          //[TestMethod]
          public void Variables()
          {
              ctx = new Context(opts, status);
              ASSERT_EQ(TF_Code.TF_OK, status.Code);

              // The options are not needed after the context is built. Clear the field
              // so Dispose() does not release them a second time.
              opts.Dispose();
              opts = null;

              var var_handle = CreateVariable(ctx, 12.0F);
              ASSERT_EQ(TF_OK, TF_GetCode(status));
              Assert.AreNotEqual(IntPtr.Zero, var_handle);
          }

          /// <summary>
          /// Creates a scalar float resource variable in <paramref name="ctx"/> and
          /// assigns <paramref name="value"/> to it.
          /// </summary>
          /// <param name="ctx">Eager context the ops are executed in.</param>
          /// <param name="value">Initial value written to the variable.</param>
          /// <returns>
          /// The variable's tensor handle, or <see cref="IntPtr.Zero"/> as soon as any
          /// native call leaves a non-OK code in the shared status.
          /// </returns>
          private IntPtr CreateVariable(Context ctx, float value)
          {
              // Create the variable handle.
              var op = c_api.TFE_NewOp(ctx, "VarHandleOp", status);
              if (TF_GetCode(status) != TF_OK) return IntPtr.Zero;
              c_api.TFE_OpSetAttrType(op, "dtype", TF_DataType.TF_FLOAT);
              c_api.TFE_OpSetAttrShape(op, "shape", new long[0], 0, status);
              c_api.TFE_OpSetAttrString(op, "container", "", 0);
              c_api.TFE_OpSetAttrString(op, "shared_name", "", 0);
              if (TF_GetCode(status) != TF_OK) return IntPtr.Zero;

              // TFE_Execute writes its outputs into a caller-provided array of handle
              // pointers (TFE_TensorHandle** in the C API). Passing IntPtr.Zero by
              // value would leave nothing to read the new handle back from, so
              // allocate a one-slot native array and read the handle out of it.
              var var_handle = IntPtr.Zero;
              int[] num_retvals = { 1 };
              IntPtr retvals = Marshal.AllocHGlobal(IntPtr.Size);
              try
              {
                  Marshal.WriteIntPtr(retvals, IntPtr.Zero);
                  c_api.TFE_Execute(op, retvals, num_retvals, status);
                  var_handle = Marshal.ReadIntPtr(retvals);
              }
              finally
              {
                  Marshal.FreeHGlobal(retvals);
              }
              c_api.TFE_DeleteOp(op);
              if (TF_GetCode(status) != TF_OK) return IntPtr.Zero;
              // num_retvals is an in/out count: compare its element, not the array.
              ASSERT_EQ(1, num_retvals[0]);

              // Assign 'value' to it.
              op = c_api.TFE_NewOp(ctx, "AssignVariableOp", status);
              if (TF_GetCode(status) != TF_OK) return IntPtr.Zero;
              c_api.TFE_OpSetAttrType(op, "dtype", TF_FLOAT);
              c_api.TFE_OpAddInput(op, var_handle, status);

              // Convert 'value' to a TF_Tensor then a TFE_TensorHandle.
              // NOTE(review): t and value_handle are never released in this method;
              // confirm the binding's delete/dispose API and free them after use.
              var t = new Tensor(value);
              var value_handle = c_api.TFE_NewTensorHandle(t);
              // TFE_NewTensorHandle takes no status here, so this check only
              // re-reads the status left by the previous call.
              if (TF_GetCode(status) != TF_OK) return IntPtr.Zero;
              c_api.TFE_OpAddInput(op, value_handle, status);
              if (TF_GetCode(status) != TF_OK) return IntPtr.Zero;

              // AssignVariableOp produces no outputs, so no retvals array is needed.
              num_retvals = new int[] { 0 };
              c_api.TFE_Execute(op, IntPtr.Zero, num_retvals, status);
              c_api.TFE_DeleteOp(op);
              if (TF_GetCode(status) != TF_OK) return IntPtr.Zero;
              ASSERT_EQ(0, num_retvals[0]);
              return var_handle;
          }

          /// <summary>
          /// Releases the native objects owned by this test instance.
          /// </summary>
          public void Dispose()
          {
              status.Dispose();
              // opts is cleared by Variables() after it is released, and ctx is only
              // created there, so either may be null at this point.
              opts?.Dispose();
              ctx?.Dispose();
          }
      }
  }