You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

CApi.Eager.cs 6.5 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165
  1. using Microsoft.VisualStudio.TestTools.UnitTesting;
  2. using System;
  3. using Tensorflow;
  4. using static Tensorflow.Binding;
  namespace TensorFlowNET.UnitTest.NativeAPI
  {
      /// <summary>
      /// Port of the helper functions from tensorflow\c\eager\c_api_test.cc.
      /// The C++ original releases temporaries through std::unique_ptr; here the
      /// temporary native statuses, tensors, ops and device lists created by the
      /// helpers are released in <c>finally</c> blocks instead, so a failing
      /// CHECK_EQ or an early return does not leak them.
      /// </summary>
      [TestClass]
      public partial class CApiEagerTest : CApiTest
      {
          /// <summary>
          /// Wraps an already-filled TF_Tensor in a new eager tensor handle, then
          /// deletes the tensor and the temporary status. Both are deleted even
          /// when the status check fails.
          /// </summary>
          /// <param name="t">A TF_Tensor whose data has already been written.</param>
          /// <returns>The new TFE_TensorHandle.</returns>
          IntPtr NewTensorHandleFromTensor(IntPtr t)
          {
              var status = TF_NewStatus();
              try
              {
                  var th = TFE_NewTensorHandle(t, status);
                  CHECK_EQ(TF_OK, TF_GetCode(status), TF_Message(status));
                  return th;
              }
              finally
              {
                  // The original helpers also deleted the tensor right after creating
                  // the handle; doing it in finally covers the failure path too.
                  TF_DeleteTensor(t);
                  TF_DeleteStatus(status);
              }
          }

          /// <summary>
          /// Creates an eager tensor handle for a float tensor with dims {2, 2}
          /// and values {1, 2, 3, 4}.
          /// </summary>
          /// <returns>A new TFE_TensorHandle.</returns>
          IntPtr TestMatrixTensorHandle()
          {
              var dims = new long[] { 2, 2 };
              var data = new float[] { 1.0f, 2.0f, 3.0f, 4.0f };
              var t = c_api.TF_AllocateTensor(TF_FLOAT, dims, dims.Length, (ulong)data.Length * sizeof(float));
              tf.memcpy(c_api.TF_TensorData(t), data, TF_TensorByteSize(t));
              return NewTensorHandleFromTensor(t);
          }

          /// <summary>
          /// Creates (but does not execute) an eager op of type
          /// <paramref name="op_type"/>. It adds <paramref name="inputs"/> in order
          /// and sets attribute "T" to the dtype of the first input. The returned op
          /// is not deleted here.
          /// </summary>
          /// <param name="ctx">Eager context to create the op in.</param>
          /// <param name="op_type">Registered op name, e.g. "MatMul".</param>
          /// <param name="inputs">Tensor handles to add as inputs; at least one is required.</param>
          /// <returns>The new TFE_Op.</returns>
          IntPtr NewOpWithInputs(IntPtr ctx, string op_type, params IntPtr[] inputs)
          {
              var status = TF_NewStatus();
              try
              {
                  var op = TFE_NewOp(ctx, op_type, status);
                  CHECK_EQ(TF_OK, TF_GetCode(status), TF_Message(status));
                  foreach (var input in inputs)
                  {
                      TFE_OpAddInput(op, input, status);
                      CHECK_EQ(TF_OK, TF_GetCode(status), TF_Message(status));
                  }
                  TFE_OpSetAttrType(op, "T", TFE_TensorHandleDataType(inputs[0]));
                  return op;
              }
              finally
              {
                  TF_DeleteStatus(status);
              }
          }

          /// <summary>
          /// Creates (but does not execute) a "MatMul" op with inputs
          /// <paramref name="a"/> and <paramref name="b"/>. "T" is taken from
          /// <paramref name="a"/>.
          /// </summary>
          IntPtr MatMulOp(IntPtr ctx, IntPtr a, IntPtr b) => NewOpWithInputs(ctx, "MatMul", a, b);

          /// <summary>
          /// Looks up the first device of type <paramref name="device_type"/> in
          /// <paramref name="ctx"/>.
          /// </summary>
          /// <param name="ctx">Eager context whose devices are listed.</param>
          /// <param name="device_name">Receives the device name when a match is found; untouched otherwise.</param>
          /// <param name="device_type">Device type to match, e.g. "GPU".</param>
          /// <returns><c>true</c> if a matching device was found.</returns>
          bool GetDeviceName(IntPtr ctx, ref string device_name, string device_type)
          {
              var status = TF_NewStatus();
              try
              {
                  var devices = TFE_ContextListDevices(ctx, status);
                  CHECK_EQ(TF_OK, TF_GetCode(status), TF_Message(status));
                  try
                  {
                      int num_devices = TF_DeviceListCount(devices);
                      for (int i = 0; i < num_devices; ++i)
                      {
                          var dev_type = TF_DeviceListType(devices, i, status);
                          CHECK_EQ(TF_OK, TF_GetCode(status), TF_Message(status));
                          var dev_name = TF_DeviceListName(devices, i, status);
                          CHECK_EQ(TF_OK, TF_GetCode(status), TF_Message(status));
                          if (dev_type == device_type)
                          {
                              device_name = dev_name;
                              return true;
                          }
                      }
                      return false;
                  }
                  finally
                  {
                      TF_DeleteDeviceList(devices);
                  }
              }
              finally
              {
                  // Previously the status was never deleted (leaked on every call).
                  TF_DeleteStatus(status);
              }
          }

          /// <summary>
          /// Creates (but does not execute) a "Shape" op with input
          /// <paramref name="a"/>. "T" is taken from <paramref name="a"/>.
          /// </summary>
          IntPtr ShapeOp(IntPtr ctx, IntPtr a) => NewOpWithInputs(ctx, "Shape", a);

          /// <summary>
          /// Creates a scalar float resource variable (VarHandleOp) and assigns
          /// <paramref name="value"/> to it (AssignVariableOp).
          /// </summary>
          /// <param name="ctx">Eager context to run the ops in.</param>
          /// <param name="value">Initial value to assign.</param>
          /// <param name="status">Caller-owned status; holds the error when <c>IntPtr.Zero</c> is returned.</param>
          /// <returns>The variable handle, or <c>IntPtr.Zero</c> if any step fails.</returns>
          unsafe IntPtr CreateVariable(IntPtr ctx, float value, IntPtr status)
          {
              // Create the variable handle.
              var op = TFE_NewOp(ctx, "VarHandleOp", status);
              if (TF_GetCode(status) != TF_OK) return IntPtr.Zero;
              var var_handle = new IntPtr[1];
              try
              {
                  TFE_OpSetAttrType(op, "dtype", TF_FLOAT);
                  TFE_OpSetAttrShape(op, "shape", new long[0], 0, status);
                  TFE_OpSetAttrString(op, "container", "", 0);
                  TFE_OpSetAttrString(op, "shared_name", "", 0);
                  if (TF_GetCode(status) != TF_OK) return IntPtr.Zero;
                  int num_retvals = 1;
                  TFE_Execute(op, var_handle, ref num_retvals, status);
                  if (TF_GetCode(status) != TF_OK) return IntPtr.Zero;
                  CHECK_EQ(1, num_retvals);
              }
              finally
              {
                  // Previously skipped when attribute setting failed.
                  TFE_DeleteOp(op);
              }

              // Assign 'value' to it. The scalar is converted to a TF_Tensor, then to a
              // TFE_TensorHandle. Both are released afterwards; the C++ original does
              // this via unique_ptr.
              var t = c_api.TF_AllocateTensor(TF_FLOAT, new long[0], 0, sizeof(float));
              IntPtr assign_op = IntPtr.Zero;
              IntPtr value_handle = IntPtr.Zero;
              bool assigned = false;
              try
              {
                  assign_op = TFE_NewOp(ctx, "AssignVariableOp", status);
                  if (TF_GetCode(status) != TF_OK) return IntPtr.Zero;
                  TFE_OpSetAttrType(assign_op, "dtype", TF_FLOAT);
                  TFE_OpAddInput(assign_op, var_handle[0], status);
                  // Check before TFE_NewTensorHandle reuses (and may overwrite) the status.
                  if (TF_GetCode(status) != TF_OK) return IntPtr.Zero;

                  tf.memcpy(TF_TensorData(t).ToPointer(), &value, TF_TensorByteSize(t));
                  value_handle = c_api.TFE_NewTensorHandle(t, status);
                  if (TF_GetCode(status) != TF_OK) return IntPtr.Zero;
                  TFE_OpAddInput(assign_op, value_handle, status);
                  if (TF_GetCode(status) != TF_OK) return IntPtr.Zero;

                  int num_retvals = 0;
                  c_api.TFE_Execute(assign_op, null, ref num_retvals, status);
                  if (TF_GetCode(status) != TF_OK) return IntPtr.Zero;
                  CHECK_EQ(0, num_retvals);
                  assigned = true;
                  return var_handle[0];
              }
              finally
              {
                  if (assign_op != IntPtr.Zero)
                      TFE_DeleteOp(assign_op);
                  if (value_handle != IntPtr.Zero)
                      c_api.TFE_DeleteTensorHandle(value_handle);
                  c_api.TF_DeleteTensor(t);
                  // On failure the caller only receives IntPtr.Zero and could never free
                  // the variable handle, so release it here.
                  if (!assigned)
                      c_api.TFE_DeleteTensorHandle(var_handle[0]);
              }
          }

          /// <summary>
          /// Creates an eager tensor handle for an int32 tensor with dims {1}
          /// holding the value 1.
          /// </summary>
          /// <returns>A new TFE_TensorHandle.</returns>
          IntPtr TestAxisTensorHandle()
          {
              var dims = new long[] { 1 };
              var data = new int[] { 1 };
              var t = c_api.TF_AllocateTensor(TF_DataType.TF_INT32, dims, dims.Length, (ulong)data.Length * sizeof(int));
              tf.memcpy(TF_TensorData(t), data, TF_TensorByteSize(t));
              return NewTensorHandleFromTensor(t);
          }

          /// <summary>
          /// Creates an eager tensor handle for a rank-0 TF_BOOL tensor holding <paramref name="value"/>.
          /// </summary>
          /// <returns>A new TFE_TensorHandle.</returns>
          IntPtr TestScalarTensorHandle(bool value)
          {
              var data = new[] { value };
              var t = c_api.TF_AllocateTensor(TF_BOOL, null, 0, sizeof(bool));
              tf.memcpy(TF_TensorData(t), data, TF_TensorByteSize(t));
              return NewTensorHandleFromTensor(t);
          }

          /// <summary>
          /// Creates an eager tensor handle for a rank-0 TF_FLOAT tensor holding <paramref name="value"/>.
          /// </summary>
          /// <returns>A new TFE_TensorHandle.</returns>
          IntPtr TestScalarTensorHandle(float value)
          {
              var data = new[] { value };
              var t = c_api.TF_AllocateTensor(TF_FLOAT, null, 0, sizeof(float));
              tf.memcpy(TF_TensorData(t), data, TF_TensorByteSize(t));
              return NewTensorHandleFromTensor(t);
          }
      }
  }