You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; they can include dashes ('-') and can be up to 35 characters long.

Eager.cs 6.6 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158
  1. using Microsoft.VisualStudio.TestTools.UnitTesting;
  2. using System;
  3. using Tensorflow.Eager;
  4. using static Tensorflow.Binding;
  namespace Tensorflow.Native.UnitTest.Eager
  {
      /// <summary>
      /// Helper routines ported from tensorflow\c\eager\c_api_test.cc.
      /// </summary>
      [TestClass]
      public partial class CApiEagerTest : CApiTest
      {
          /// <summary>
          /// Creates a 2x2 TF_FLOAT tensor handle holding { 1, 2, 3, 4 }.
          /// </summary>
          /// <returns>A new tensor handle that the caller is responsible for disposing.</returns>
          SafeTensorHandleHandle TestMatrixTensorHandle()
          {
              var dims = new long[] { 2, 2 };
              var data = new float[] { 1.0f, 2.0f, 3.0f, 4.0f };
              // 't' is only needed to build the handle; 'using' releases it on exit,
              // including when the CHECK_EQ below throws.
              using var t = c_api.TF_AllocateTensor(TF_FLOAT, dims, dims.Length, (ulong)data.Length * sizeof(float));
              tf.memcpy(c_api.TF_TensorData(t), data, data.Length * sizeof(float));
              using var status = c_api.TF_NewStatus();
              var th = c_api.TFE_NewTensorHandle(t, status);
              CHECK_EQ(TF_OK, TF_GetCode(status), TF_Message(status));
              return th;
          }

          /// <summary>
          /// Builds (without executing) a "MatMul" op with inputs <paramref name="a"/> and <paramref name="b"/>.
          /// </summary>
          /// <param name="ctx">Eager context the op is created in.</param>
          /// <param name="a">First input; its dtype is used for the "T" attribute.</param>
          /// <param name="b">Second input.</param>
          /// <returns>The configured op, which the caller is responsible for disposing.</returns>
          SafeOpHandle MatMulOp(SafeContextHandle ctx, SafeTensorHandleHandle a, SafeTensorHandleHandle b)
          {
              using var status = TF_NewStatus();
              var op = TFE_NewOp(ctx, "MatMul", status);
              CHECK_EQ(TF_OK, TF_GetCode(status), TF_Message(status));
              TFE_OpAddInput(op, a, status);
              CHECK_EQ(TF_OK, TF_GetCode(status), TF_Message(status));
              TFE_OpAddInput(op, b, status);
              CHECK_EQ(TF_OK, TF_GetCode(status), TF_Message(status));
              TFE_OpSetAttrType(op, "T", TFE_TensorHandleDataType(a));
              return op;
          }

          /// <summary>
          /// Finds the first device in <paramref name="ctx"/> whose type equals <paramref name="device_type"/>.
          /// </summary>
          /// <param name="ctx">Eager context whose devices are listed.</param>
          /// <param name="device_name">Set to the matching device's name; left unchanged when no device matches.</param>
          /// <param name="device_type">Device type string to look for.</param>
          /// <returns>true if a device of the requested type was found, otherwise false.</returns>
          bool GetDeviceName(SafeContextHandle ctx, ref string device_name, string device_type)
          {
              using var status = TF_NewStatus();
              using var devices = TFE_ContextListDevices(ctx, status);
              CHECK_EQ(TF_OK, TF_GetCode(status), TF_Message(status));

              int num_devices = TF_DeviceListCount(devices);
              for (int i = 0; i < num_devices; ++i)
              {
                  var dev_type = TF_DeviceListType(devices, i, status);
                  // (expected, actual) order, consistent with the other checks in this class.
                  CHECK_EQ(TF_OK, TF_GetCode(status), TF_Message(status));
                  var dev_name = TF_DeviceListName(devices, i, status);
                  CHECK_EQ(TF_OK, TF_GetCode(status), TF_Message(status));
                  if (dev_type == device_type)
                  {
                      device_name = dev_name;
                      return true;
                  }
              }
              return false;
          }

          /// <summary>
          /// Builds (without executing) a "Shape" op taking <paramref name="a"/> as its input.
          /// </summary>
          /// <param name="ctx">Eager context the op is created in.</param>
          /// <param name="a">Input tensor; its dtype is used for the "T" attribute.</param>
          /// <returns>The configured op, which the caller is responsible for disposing.</returns>
          SafeOpHandle ShapeOp(SafeContextHandle ctx, SafeTensorHandleHandle a)
          {
              using var status = TF_NewStatus();
              var op = TFE_NewOp(ctx, "Shape", status);
              CHECK_EQ(TF_OK, TF_GetCode(status), TF_Message(status));
              TFE_OpAddInput(op, a, status);
              CHECK_EQ(TF_OK, TF_GetCode(status), TF_Message(status));
              TFE_OpSetAttrType(op, "T", TFE_TensorHandleDataType(a));
              return op;
          }

          /// <summary>
          /// Creates a scalar TF_FLOAT variable via "VarHandleOp" and assigns <paramref name="value"/>
          /// to it via "AssignVariableOp".
          /// </summary>
          /// <param name="ctx">Eager context the ops are executed in.</param>
          /// <param name="value">Value assigned to the new variable.</param>
          /// <param name="status">Status used for every TF call; inspect it when the returned handle is null.</param>
          /// <returns>
          /// The variable handle on success; on any failure, a handle wrapping <see cref="IntPtr.Zero"/>
          /// (any partially created variable handle is disposed first).
          /// </returns>
          unsafe SafeTensorHandleHandle CreateVariable(SafeContextHandle ctx, float value, SafeStatusHandle status)
          {
              var var_handle = new SafeTensorHandleHandle[1];
              int num_retvals;

              // Releases the variable handle if it was already produced, then returns the
              // null handle that signals failure to the caller.
              SafeTensorHandleHandle Fail()
              {
                  var_handle[0]?.Dispose();
                  return new SafeTensorHandleHandle(IntPtr.Zero);
              }

              // Create the variable handle.
              using (var op = TFE_NewOp(ctx, "VarHandleOp", status))
              {
                  if (TF_GetCode(status) != TF_OK) return Fail();
                  TFE_OpSetAttrType(op, "dtype", TF_FLOAT);
                  TFE_OpSetAttrShape(op, "shape", new long[0], 0, status);
                  TFE_OpSetAttrString(op, "container", "", 0);
                  TFE_OpSetAttrString(op, "shared_name", "", 0);
                  if (TF_GetCode(status) != TF_OK) return Fail();
                  TFE_Execute(op, var_handle, out num_retvals, status);
                  if (TF_GetCode(status) != TF_OK) return Fail();
                  CHECK_EQ(1, num_retvals);
              }

              // Assign 'value' to it.
              using (var op = TFE_NewOp(ctx, "AssignVariableOp", status))
              {
                  if (TF_GetCode(status) != TF_OK) return Fail();
                  TFE_OpSetAttrType(op, "dtype", TF_FLOAT);
                  TFE_OpAddInput(op, var_handle[0], status);
                  // Check right away: 'status' is passed to the next TF call, which may reset it
                  // and hide a failure from this one.
                  if (TF_GetCode(status) != TF_OK) return Fail();

                  // Convert 'value' to a TF_Tensor then a TFE_TensorHandle. Both are only needed
                  // while the op is built and executed, so they are released when this block exits.
                  using var t = c_api.TF_AllocateTensor(TF_DataType.TF_FLOAT, new long[0], 0, sizeof(float));
                  tf.memcpy(TF_TensorData(t).ToPointer(), &value, TF_TensorByteSize(t));
                  using var value_handle = c_api.TFE_NewTensorHandle(t, status);
                  if (TF_GetCode(status) != TF_OK) return Fail();
                  TFE_OpAddInput(op, value_handle, status);
                  if (TF_GetCode(status) != TF_OK) return Fail();
                  c_api.TFE_Execute(op, null, out num_retvals, status);
                  if (TF_GetCode(status) != TF_OK) return Fail();
                  CHECK_EQ(0, num_retvals);
              }

              return var_handle[0];
          }

          /// <summary>
          /// Creates a TF_INT32 tensor handle of shape [1] holding the single value 1.
          /// </summary>
          /// <returns>A new tensor handle that the caller is responsible for disposing.</returns>
          SafeTensorHandleHandle TestAxisTensorHandle()
          {
              var dims = new long[] { 1 };
              var data = new int[] { 1 };
              using var t = c_api.TF_AllocateTensor(TF_DataType.TF_INT32, dims, 1, sizeof(int));
              tf.memcpy(TF_TensorData(t), data, TF_TensorByteSize(t));
              using var status = TF_NewStatus();
              var th = c_api.TFE_NewTensorHandle(t, status);
              CHECK_EQ(TF_OK, TF_GetCode(status), TF_Message(status));
              return th;
          }

          /// <summary>
          /// Creates a scalar TF_BOOL tensor handle holding <paramref name="value"/>.
          /// </summary>
          /// <returns>A new tensor handle that the caller is responsible for disposing.</returns>
          SafeTensorHandleHandle TestScalarTensorHandle(bool value)
          {
              var data = new[] { value };
              using var t = c_api.TF_AllocateTensor(TF_BOOL, null, 0, sizeof(bool));
              tf.memcpy(TF_TensorData(t), data, TF_TensorByteSize(t));
              using var status = TF_NewStatus();
              var th = TFE_NewTensorHandle(t, status);
              CHECK_EQ(TF_OK, TF_GetCode(status), TF_Message(status));
              return th;
          }

          /// <summary>
          /// Creates a scalar TF_FLOAT tensor handle holding <paramref name="value"/>.
          /// </summary>
          /// <returns>A new tensor handle that the caller is responsible for disposing.</returns>
          SafeTensorHandleHandle TestScalarTensorHandle(float value)
          {
              var data = new[] { value };
              using var t = c_api.TF_AllocateTensor(TF_FLOAT, null, 0, sizeof(float));
              tf.memcpy(TF_TensorData(t), data, TF_TensorByteSize(t));
              using var status = TF_NewStatus();
              var th = TFE_NewTensorHandle(t, status);
              CHECK_EQ(TF_OK, TF_GetCode(status), TF_Message(status));
              return th;
          }
      }
  }