You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; they can include dashes ('-') and can be up to 35 characters long.

CApi.Eager.cs 6.3 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160
  1. using Microsoft.VisualStudio.TestTools.UnitTesting;
  2. using System;
  3. using Tensorflow;
  4. using Tensorflow.Eager;
  5. using static Tensorflow.Binding;
  namespace TensorFlowNET.UnitTest.NativeAPI
  {
      /// <summary>
      /// tensorflow\c\eager\c_api_test.cc
      /// </summary>
      /// <remarks>
      /// Helpers ported from the TensorFlow eager C API test. Native objects that a helper only needs
      /// temporarily (statuses, scratch tensors, device lists, intermediate ops) are released on every
      /// exit path, including when a CHECK_EQ assertion fails.
      /// </remarks>
      [TestClass]
      public partial class CApiEagerTest : CApiTest
      {
          /// <summary>
          /// Creates an eager tensor handle wrapping a 2x2 TF_FLOAT tensor filled with 1, 2, 3, 4.
          /// </summary>
          /// <returns>The new tensor handle. It is not released here, so the caller owns it.</returns>
          IntPtr TestMatrixTensorHandle()
          {
              var dims = new long[] { 2, 2 };
              var data = new float[] { 1.0f, 2.0f, 3.0f, 4.0f };
              var t = c_api.TF_AllocateTensor(TF_FLOAT, dims, dims.Length, (ulong)data.Length * sizeof(float));
              try
              {
                  tf.memcpy(c_api.TF_TensorData(t), data, data.Length * sizeof(float));
                  using var status = c_api.TF_NewStatus();
                  var th = c_api.TFE_NewTensorHandle(t, status);
                  CHECK_EQ(TF_OK, TF_GetCode(status), TF_Message(status));
                  return th;
              }
              finally
              {
                  // The tensor is only needed to build the handle, so release it even if the check fails.
                  c_api.TF_DeleteTensor(t);
              }
          }

          /// <summary>
          /// Builds, but does not execute, a "MatMul" op with inputs <paramref name="a"/> and
          /// <paramref name="b"/>. Attribute "T" is set to the dtype of <paramref name="a"/>.
          /// </summary>
          /// <returns>The configured op. It is not deleted here, so the caller owns it.</returns>
          IntPtr MatMulOp(SafeContextHandle ctx, IntPtr a, IntPtr b)
          {
              using var status = TF_NewStatus();
              var op = TFE_NewOp(ctx, "MatMul", status);
              CHECK_EQ(TF_OK, TF_GetCode(status), TF_Message(status));
              TFE_OpAddInput(op, a, status);
              CHECK_EQ(TF_OK, TF_GetCode(status), TF_Message(status));
              TFE_OpAddInput(op, b, status);
              CHECK_EQ(TF_OK, TF_GetCode(status), TF_Message(status));
              TFE_OpSetAttrType(op, "T", TFE_TensorHandleDataType(a));
              return op;
          }

          /// <summary>
          /// Finds the first device in <paramref name="ctx"/> whose type equals <paramref name="device_type"/>.
          /// </summary>
          /// <param name="ctx">Eager context whose devices are listed.</param>
          /// <param name="device_name">Set to the matching device's name. It is left untouched when nothing matches.</param>
          /// <param name="device_type">Device type string to compare against.</param>
          /// <returns><c>true</c> if a matching device was found; otherwise <c>false</c>.</returns>
          bool GetDeviceName(SafeContextHandle ctx, ref string device_name, string device_type)
          {
              // Previously this status was never disposed. Siblings all scope it with `using`.
              using var status = TF_NewStatus();
              var devices = TFE_ContextListDevices(ctx, status);
              CHECK_EQ(TF_OK, TF_GetCode(status), TF_Message(status));
              try
              {
                  int num_devices = TF_DeviceListCount(devices);
                  for (int i = 0; i < num_devices; ++i)
                  {
                      var dev_type = TF_DeviceListType(devices, i, status);
                      CHECK_EQ(TF_OK, TF_GetCode(status), TF_Message(status));
                      var dev_name = TF_DeviceListName(devices, i, status);
                      CHECK_EQ(TF_OK, TF_GetCode(status), TF_Message(status));
                      if (dev_type == device_type)
                      {
                          device_name = dev_name;
                          return true;
                      }
                  }
                  return false;
              }
              finally
              {
                  // Single release point for the match, no-match and failed-assertion paths.
                  TF_DeleteDeviceList(devices);
              }
          }

          /// <summary>
          /// Builds, but does not execute, a "Shape" op over <paramref name="a"/>. Attribute "T" is set
          /// to the dtype of <paramref name="a"/>.
          /// </summary>
          /// <returns>The configured op. It is not deleted here, so the caller owns it.</returns>
          IntPtr ShapeOp(SafeContextHandle ctx, IntPtr a)
          {
              using var status = TF_NewStatus();
              var op = TFE_NewOp(ctx, "Shape", status);
              CHECK_EQ(TF_OK, TF_GetCode(status), TF_Message(status));
              TFE_OpAddInput(op, a, status);
              CHECK_EQ(TF_OK, TF_GetCode(status), TF_Message(status));
              TFE_OpSetAttrType(op, "T", TFE_TensorHandleDataType(a));
              return op;
          }

          /// <summary>
          /// Creates a scalar TF_FLOAT variable with "VarHandleOp", then assigns <paramref name="value"/>
          /// to it with "AssignVariableOp".
          /// </summary>
          /// <param name="ctx">Eager context the ops are created in.</param>
          /// <param name="value">Initial value written to the variable.</param>
          /// <param name="status">Receives the error of the first failing C API call.</param>
          /// <returns>
          /// The variable handle, owned by the caller, or <c>IntPtr.Zero</c> if any step reported a
          /// non-OK status. On that path every intermediate object, including the variable handle, is
          /// released here.
          /// </returns>
          unsafe IntPtr CreateVariable(SafeContextHandle ctx, float value, SafeStatusHandle status)
          {
              // Create the variable handle.
              var op = TFE_NewOp(ctx, "VarHandleOp", status);
              if (TF_GetCode(status) != TF_OK) return IntPtr.Zero;
              var var_handle = new IntPtr[1];
              int num_retvals = 1;
              try
              {
                  TFE_OpSetAttrType(op, "dtype", TF_FLOAT);
                  TFE_OpSetAttrShape(op, "shape", new long[0], 0, status);
                  TFE_OpSetAttrString(op, "container", "", 0);
                  TFE_OpSetAttrString(op, "shared_name", "", 0);
                  if (TF_GetCode(status) != TF_OK) return IntPtr.Zero;
                  TFE_Execute(op, var_handle, ref num_retvals, status);
              }
              finally
              {
                  // Delete the op after executing it, and also when attribute setup failed.
                  // Previously the failure path leaked it.
                  TFE_DeleteOp(op);
              }
              if (TF_GetCode(status) != TF_OK) return IntPtr.Zero;
              CHECK_EQ(1, num_retvals);

              // Assign 'value' to it.
              op = TFE_NewOp(ctx, "AssignVariableOp", status);
              if (TF_GetCode(status) != TF_OK)
              {
                  c_api.TFE_DeleteTensorHandle(var_handle[0]);
                  return IntPtr.Zero;
              }

              // Convert 'value' to a TF_Tensor then a TFE_TensorHandle.
              var t = c_api.TF_AllocateTensor(TF_DataType.TF_FLOAT, new long[0], 0, sizeof(float));
              var value_handle = IntPtr.Zero;
              var assigned = false;
              try
              {
                  TFE_OpSetAttrType(op, "dtype", TF_FLOAT);
                  TFE_OpAddInput(op, var_handle[0], status);
                  // Check right away. TFE_NewTensorHandle below writes to the same status, so an
                  // AddInput failure could otherwise go unnoticed.
                  if (TF_GetCode(status) != TF_OK) return IntPtr.Zero;

                  tf.memcpy(TF_TensorData(t).ToPointer(), &value, TF_TensorByteSize(t));
                  value_handle = c_api.TFE_NewTensorHandle(t, status);
                  if (TF_GetCode(status) != TF_OK) return IntPtr.Zero;

                  TFE_OpAddInput(op, value_handle, status);
                  if (TF_GetCode(status) != TF_OK) return IntPtr.Zero;

                  num_retvals = 0;
                  c_api.TFE_Execute(op, null, ref num_retvals, status);
                  if (TF_GetCode(status) != TF_OK) return IntPtr.Zero;
                  CHECK_EQ(0, num_retvals);

                  assigned = true;
                  return var_handle[0];
              }
              finally
              {
                  // Delete the op first, then the inputs it was built from.
                  TFE_DeleteOp(op);
                  if (value_handle != IntPtr.Zero)
                      c_api.TFE_DeleteTensorHandle(value_handle);
                  c_api.TF_DeleteTensor(t);
                  // On failure the caller only sees IntPtr.Zero, so the variable handle would be
                  // unreachable. Release it here.
                  if (!assigned)
                      c_api.TFE_DeleteTensorHandle(var_handle[0]);
              }
          }

          /// <summary>
          /// Creates an eager tensor handle wrapping a TF_INT32 tensor of shape [1] holding the value 1.
          /// </summary>
          /// <returns>The new tensor handle. It is not released here, so the caller owns it.</returns>
          IntPtr TestAxisTensorHandle()
          {
              var dims = new long[] { 1 };
              var data = new int[] { 1 };
              var t = c_api.TF_AllocateTensor(TF_DataType.TF_INT32, dims, 1, sizeof(int));
              try
              {
                  tf.memcpy(TF_TensorData(t), data, TF_TensorByteSize(t));
                  using var status = TF_NewStatus();
                  var th = c_api.TFE_NewTensorHandle(t, status);
                  CHECK_EQ(TF_OK, TF_GetCode(status), TF_Message(status));
                  return th;
              }
              finally
              {
                  // The tensor is only needed to build the handle, so release it even if the check fails.
                  TF_DeleteTensor(t);
              }
          }

          /// <summary>
          /// Creates an eager tensor handle wrapping a rank-0 (no dims) TF_BOOL tensor holding
          /// <paramref name="value"/>.
          /// </summary>
          /// <returns>The new tensor handle. It is not released here, so the caller owns it.</returns>
          IntPtr TestScalarTensorHandle(bool value)
          {
              var data = new[] { value };
              var t = c_api.TF_AllocateTensor(TF_BOOL, null, 0, sizeof(bool));
              try
              {
                  tf.memcpy(TF_TensorData(t), data, TF_TensorByteSize(t));
                  using var status = TF_NewStatus();
                  var th = TFE_NewTensorHandle(t, status);
                  CHECK_EQ(TF_OK, TF_GetCode(status), TF_Message(status));
                  return th;
              }
              finally
              {
                  // The tensor is only needed to build the handle, so release it even if the check fails.
                  TF_DeleteTensor(t);
              }
          }

          /// <summary>
          /// Creates an eager tensor handle wrapping a rank-0 (no dims) TF_FLOAT tensor holding
          /// <paramref name="value"/>.
          /// </summary>
          /// <returns>The new tensor handle. It is not released here, so the caller owns it.</returns>
          IntPtr TestScalarTensorHandle(float value)
          {
              var data = new[] { value };
              var t = c_api.TF_AllocateTensor(TF_FLOAT, null, 0, sizeof(float));
              try
              {
                  tf.memcpy(TF_TensorData(t), data, TF_TensorByteSize(t));
                  using var status = TF_NewStatus();
                  var th = TFE_NewTensorHandle(t, status);
                  CHECK_EQ(TF_OK, TF_GetCode(status), TF_Message(status));
                  return th;
              }
              finally
              {
                  // The tensor is only needed to build the handle, so release it even if the check fails.
                  TF_DeleteTensor(t);
              }
          }
      }
  }