You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

CApiVariableTest.cs 2.8 kB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980
  1. using Microsoft.VisualStudio.TestTools.UnitTesting;
  2. using System;
  3. using System.Collections.Generic;
  4. using System.Text;
  5. using Tensorflow;
  6. using Tensorflow.Eager;
  namespace TensorFlowNET.UnitTest.Eager
  {
      using System.Runtime.InteropServices;

      /// <summary>
      /// Port of the resource-variable test in tensorflow\c\eager\c_api_test.cc:
      /// creates an eager context, then a scalar float variable initialised to 12.0.
      /// </summary>
      [TestClass]
      public class CApiVariableTest : CApiTest, IDisposable
      {
          Status status = new Status();
          ContextOptions opts = new ContextOptions();
          Context ctx;

          [TestMethod]
          public void Variables()
          {
              ctx = new Context(opts, status);
              ASSERT_EQ(TF_Code.TF_OK, status.Code);

              // The options are released as soon as the context exists. Clear the
              // field so Dispose() does not release the same object a second time.
              opts.Dispose();
              opts = null;

              var var_handle = CreateVariable(ctx, 12.0F);
              ASSERT_EQ(TF_OK, TF_GetCode(status));
          }

          /// <summary>
          /// Creates a scalar float resource variable with VarHandleOp, then assigns
          /// <paramref name="value"/> to it with AssignVariableOp.
          /// Errors are reported through the shared <c>status</c> field.
          /// </summary>
          /// <param name="ctx">Eager context in which both ops are executed.</param>
          /// <param name="value">Initial value stored in the variable.</param>
          /// <returns>
          /// The variable's tensor handle, or <see cref="IntPtr.Zero"/> when any step
          /// fails (inspect <c>status</c> for the reason).
          /// </returns>
          private IntPtr CreateVariable(Context ctx, float value)
          {
              // Create the variable handle.
              var op = c_api.TFE_NewOp(ctx, "VarHandleOp", status);
              if (TF_GetCode(status) != TF_OK) return IntPtr.Zero;

              var var_handle = IntPtr.Zero;
              int[] num_retvals = { 1 };
              // TFE_Execute writes its outputs into a caller-provided array of handles.
              // Passing IntPtr.Zero by value leaves it nowhere to write, so hand it one
              // native pointer-sized slot and read the handle back afterwards.
              var retvals = Marshal.AllocHGlobal(IntPtr.Size);
              try
              {
                  c_api.TFE_OpSetAttrType(op, "dtype", TF_DataType.TF_FLOAT);
                  // Rank-0 shape: the variable is a scalar.
                  c_api.TFE_OpSetAttrShape(op, "shape", new long[0], 0, status);
                  c_api.TFE_OpSetAttrString(op, "container", "", 0);
                  c_api.TFE_OpSetAttrString(op, "shared_name", "", 0);
                  if (TF_GetCode(status) != TF_OK) return IntPtr.Zero;

                  Marshal.WriteIntPtr(retvals, IntPtr.Zero);
                  c_api.TFE_Execute(op, retvals, num_retvals, status);
                  var_handle = Marshal.ReadIntPtr(retvals);
              }
              finally
              {
                  // Always release the slot and the op, including on early return.
                  Marshal.FreeHGlobal(retvals);
                  c_api.TFE_DeleteOp(op);
              }
              if (TF_GetCode(status) != TF_OK) return IntPtr.Zero;
              ASSERT_EQ(1, num_retvals[0]);

              // Assign 'value' to it.
              op = c_api.TFE_NewOp(ctx, "AssignVariableOp", status);
              if (TF_GetCode(status) != TF_OK) return IntPtr.Zero;
              try
              {
                  c_api.TFE_OpSetAttrType(op, "dtype", TF_FLOAT);
                  c_api.TFE_OpAddInput(op, var_handle, status);

                  // Convert 'value' to a TF_Tensor then a TFE_TensorHandle.
                  // NOTE(review): `t` and `value_handle` are never released here;
                  // confirm whether they need explicit cleanup in this binding.
                  var t = new Tensor(value);
                  var value_handle = c_api.TFE_NewTensorHandle(t);
                  if (TF_GetCode(status) != TF_OK) return IntPtr.Zero;

                  c_api.TFE_OpAddInput(op, value_handle, status);
                  if (TF_GetCode(status) != TF_OK) return IntPtr.Zero;

                  // AssignVariableOp is expected to produce no outputs, so no retvals
                  // array is supplied.
                  num_retvals = new int[] { 0 };
                  c_api.TFE_Execute(op, IntPtr.Zero, num_retvals, status);
              }
              finally
              {
                  c_api.TFE_DeleteOp(op);
              }
              if (TF_GetCode(status) != TF_OK) return IntPtr.Zero;
              ASSERT_EQ(0, num_retvals[0]);
              return var_handle;
          }

          /// <summary>
          /// Releases the status, any options not already released by the test, and
          /// the context if one was created.
          /// </summary>
          public void Dispose()
          {
              status.Dispose();
              opts?.Dispose();
              ctx?.Dispose();
          }
      }
  }

TensorFlow 框架的 .NET 版本，提供了丰富的特性和 API，可以借此很方便地在 .NET 平台下搭建深度学习训练与推理流程。