You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

CApiTest.cs 7.6 kB

6 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155
  1. using Microsoft.VisualStudio.TestTools.UnitTesting;
  2. using System;
  3. using Tensorflow.Device;
  4. using Tensorflow.Eager;
  namespace Tensorflow.Native.UnitTest
  {
      /// <summary>
      /// Shared base class for the native C API tests.
      /// It provides gtest-style assertion helpers (EXPECT_*, ASSERT_*, CHECK_*) that
      /// delegate to MSTest's <c>Assert</c>.
      /// It also provides thin wrappers named after TensorFlow C API functions, each
      /// forwarding straight to <c>c_api</c>, so derived tests can call them unqualified.
      /// NOTE(review): the naming presumably keeps tests ported from the upstream C++
      /// c_api tests close to their original text — confirm.
      /// </summary>
      public class CApiTest
      {
          #region Shorthand constants

          /// <summary>Alias for <c>TF_Code.TF_OK</c>.</summary>
          protected static readonly TF_Code TF_OK = TF_Code.TF_OK;

          /// <summary>Alias for <c>TF_DataType.TF_FLOAT</c>.</summary>
          protected static readonly TF_DataType TF_FLOAT = TF_DataType.TF_FLOAT;

          /// <summary>Alias for <c>TF_DataType.TF_BOOL</c>.</summary>
          protected static readonly TF_DataType TF_BOOL = TF_DataType.TF_BOOL;

          #endregion

          #region gtest-style assertions

          /// <summary>Fails the test unless <paramref name="expected"/> is true.</summary>
          protected void EXPECT_TRUE(bool expected, string msg = "")
          {
              Assert.IsTrue(expected, msg);
          }

          /// <summary>Fails the test unless <paramref name="condition"/> is true.</summary>
          protected void ASSERT_TRUE(bool condition, string msg = "")
          {
              Assert.IsTrue(condition, msg);
          }

          /// <summary>Fails the test unless the two values are equal (via <c>Assert.AreEqual</c>).</summary>
          protected static void EXPECT_EQ(object expected, object actual, string msg = "")
          {
              Assert.AreEqual(expected, actual, msg);
          }

          /// <summary>Same check as <c>EXPECT_EQ</c>, exposed as an instance member.</summary>
          protected void ASSERT_EQ(object expected, object actual, string msg = "")
          {
              Assert.AreEqual(expected, actual, msg);
          }

          /// <summary>Same check as <c>EXPECT_EQ</c>, exposed under the CHECK_* name.</summary>
          protected void CHECK_EQ(object expected, object actual, string msg = "")
          {
              Assert.AreEqual(expected, actual, msg);
          }

          /// <summary>Fails the test if the two values are equal (via <c>Assert.AreNotEqual</c>).</summary>
          protected void EXPECT_NE(object expected, object actual, string msg = "")
          {
              Assert.AreNotEqual(expected, actual, msg);
          }

          /// <summary>Same check as <c>EXPECT_NE</c>, exposed under the ASSERT_* name.</summary>
          protected void ASSERT_NE(object expected, object actual, string msg = "")
          {
              Assert.AreNotEqual(expected, actual, msg);
          }

          /// <summary>Same check as <c>EXPECT_NE</c>, exposed under the CHECK_* name.</summary>
          protected void CHECK_NE(object expected, object actual, string msg = "")
          {
              Assert.AreNotEqual(expected, actual, msg);
          }

          /// <summary>
          /// Fails the test unless <paramref name="expected"/> &gt;= <paramref name="actual"/>.
          /// Despite the parameter names, the FIRST argument is the one that must be greater
          /// than or equal, matching gtest's <c>EXPECT_GE(val1, val2)</c> ordering.
          /// </summary>
          protected void EXPECT_GE(int expected, int actual, string msg = "")
          {
              bool firstIsGreaterOrEqual = expected >= actual;
              Assert.IsTrue(firstIsGreaterOrEqual, msg);
          }

          #endregion

          #region Graph / operation building

          /// <summary>Starts building a new operation of <paramref name="opType"/> in <paramref name="graph"/>.</summary>
          protected OperationDescription TF_NewOperation(Graph graph, string opType, string opName)
          {
              return c_api.TF_NewOperation(graph, opType, opName);
          }

          /// <summary>Appends <paramref name="input"/> to the operation under construction.</summary>
          protected void TF_AddInput(OperationDescription desc, TF_Output input)
          {
              c_api.TF_AddInput(desc, input);
          }

          /// <summary>Sets a tensor-valued attribute; errors are reported through the handle of <paramref name="s"/>.</summary>
          protected void TF_SetAttrTensor(OperationDescription desc, string attrName, Tensor value, Status s)
          {
              c_api.TF_SetAttrTensor(desc, attrName, value, s.Handle);
          }

          /// <summary>Sets a dtype-valued attribute on the operation under construction.</summary>
          protected void TF_SetAttrType(OperationDescription desc, string attrName, TF_DataType dtype)
          {
              c_api.TF_SetAttrType(desc, attrName, dtype);
          }

          /// <summary>Sets a bool-valued attribute on the operation under construction.</summary>
          protected void TF_SetAttrBool(OperationDescription desc, string attrName, bool value)
          {
              c_api.TF_SetAttrBool(desc, attrName, value);
          }

          /// <summary>Finishes the operation; errors are reported through the handle of <paramref name="s"/>.</summary>
          protected Operation TF_FinishOperation(OperationDescription desc, Status s)
          {
              return c_api.TF_FinishOperation(desc, s.Handle);
          }

          #endregion

          #region Status

          /// <summary>Creates a new native status handle.</summary>
          protected SafeStatusHandle TF_NewStatus()
          {
              return c_api.TF_NewStatus();
          }

          /// <summary>Reads the code from the managed <see cref="Status"/> wrapper (its <c>Code</c> property).</summary>
          protected TF_Code TF_GetCode(Status s)
          {
              return s.Code;
          }

          /// <summary>Reads the code directly from a native status handle.</summary>
          protected static TF_Code TF_GetCode(SafeStatusHandle s)
          {
              return c_api.TF_GetCode(s);
          }

          /// <summary>Returns the status message, converted to a managed string with <c>c_api.StringPiece</c>.</summary>
          protected static string TF_Message(SafeStatusHandle s)
          {
              var rawMessage = c_api.TF_Message(s);
              return c_api.StringPiece(rawMessage);
          }

          #endregion

          #region Tensors

          /// <summary>Returns the pointer reported by <c>c_api.TF_TensorData</c> for <paramref name="t"/>.</summary>
          protected IntPtr TF_TensorData(SafeTensorHandle t)
          {
              return c_api.TF_TensorData(t);
          }

          /// <summary>Returns the byte size reported by <c>c_api.TF_TensorByteSize</c> for <paramref name="t"/>.</summary>
          protected ulong TF_TensorByteSize(SafeTensorHandle t)
          {
              return c_api.TF_TensorByteSize(t);
          }

          #endregion

          #region Eager context / executor

          /// <summary>Creates a new eager context-options handle.</summary>
          protected SafeContextOptionsHandle TFE_NewContextOptions()
          {
              return c_api.TFE_NewContextOptions();
          }

          /// <summary>Creates an eager context from <paramref name="opts"/>.</summary>
          protected SafeContextHandle TFE_NewContext(SafeContextOptionsHandle opts, SafeStatusHandle status)
          {
              return c_api.TFE_NewContext(opts, status);
          }

          /// <summary>Returns the executor handle for the current thread of <paramref name="ctx"/>.</summary>
          protected SafeExecutorHandle TFE_ContextGetExecutorForThread(SafeContextHandle ctx)
          {
              return c_api.TFE_ContextGetExecutorForThread(ctx);
          }

          /// <summary>Forwards to <c>c_api.TFE_ExecutorWaitForAllPendingNodes</c>.</summary>
          protected void TFE_ExecutorWaitForAllPendingNodes(SafeExecutorHandle executor, SafeStatusHandle status)
          {
              c_api.TFE_ExecutorWaitForAllPendingNodes(executor, status);
          }

          #endregion

          #region Eager operations

          /// <summary>Creates an eager op named <paramref name="op_or_function_name"/> in <paramref name="ctx"/>.</summary>
          protected SafeEagerOpHandle TFE_NewOp(SafeContextHandle ctx, string op_or_function_name, SafeStatusHandle status)
          {
              return c_api.TFE_NewOp(ctx, op_or_function_name, status);
          }

          /// <summary>Adds a single tensor-handle input to <paramref name="op"/>.</summary>
          protected void TFE_OpAddInput(SafeEagerOpHandle op, SafeEagerTensorHandle h, SafeStatusHandle status)
          {
              c_api.TFE_OpAddInput(op, h, status);
          }

          /// <summary>Adds <paramref name="num_inputs"/> handles from <paramref name="inputs"/> as a list input.</summary>
          protected int TFE_OpAddInputList(SafeEagerOpHandle op, SafeEagerTensorHandle[] inputs, int num_inputs, SafeStatusHandle status)
          {
              return c_api.TFE_OpAddInputList(op, inputs, num_inputs, status);
          }

          /// <summary>Returns the length reported for the named input argument of <paramref name="op"/>.</summary>
          protected int TFE_OpGetInputLength(SafeEagerOpHandle op, string input_name, SafeStatusHandle status)
          {
              return c_api.TFE_OpGetInputLength(op, input_name, status);
          }

          /// <summary>
          /// Returns the length reported for a named argument of <paramref name="op"/>.
          /// NOTE(review): presumably this is the op's OUTPUT argument name. The parameter is
          /// called <c>input_name</c> only for caller compatibility.
          /// </summary>
          protected int TFE_OpGetOutputLength(SafeEagerOpHandle op, string input_name, SafeStatusHandle status)
          {
              return c_api.TFE_OpGetOutputLength(op, input_name, status);
          }

          /// <summary>Sets a dtype-valued attribute on an eager op.</summary>
          protected void TFE_OpSetAttrType(SafeEagerOpHandle op, string attr_name, TF_DataType value)
          {
              c_api.TFE_OpSetAttrType(op, attr_name, value);
          }

          /// <summary>Sets a shape-valued attribute from the first <paramref name="num_dims"/> entries of <paramref name="dims"/>.</summary>
          protected void TFE_OpSetAttrShape(SafeEagerOpHandle op, string attr_name, long[] dims, int num_dims, SafeStatusHandle out_status)
          {
              c_api.TFE_OpSetAttrShape(op, attr_name, dims, num_dims, out_status);
          }

          /// <summary>Sets a string-valued attribute; <paramref name="length"/> is passed through unchanged.</summary>
          protected void TFE_OpSetAttrString(SafeEagerOpHandle op, string attr_name, string value, uint length)
          {
              c_api.TFE_OpSetAttrString(op, attr_name, value, length);
          }

          /// <summary>Pins <paramref name="op"/> to the device named <paramref name="device_name"/>.</summary>
          protected void TFE_OpSetDevice(SafeEagerOpHandle op, string device_name, SafeStatusHandle status)
          {
              c_api.TFE_OpSetDevice(op, device_name, status);
          }

          /// <summary>Executes <paramref name="op"/>, filling <paramref name="retvals"/> and reporting the count in <paramref name="num_retvals"/>.</summary>
          protected void TFE_Execute(SafeEagerOpHandle op, SafeEagerTensorHandle[] retvals, out int num_retvals, SafeStatusHandle status)
          {
              c_api.TFE_Execute(op, retvals, out num_retvals, status);
          }

          #endregion

          #region Eager tensor handles

          /// <summary>Wraps a native tensor in a new eager tensor handle.</summary>
          protected SafeEagerTensorHandle TFE_NewTensorHandle(SafeTensorHandle t, SafeStatusHandle status)
          {
              return c_api.TFE_NewTensorHandle(t, status);
          }

          /// <summary>Forwards the raw pointer <paramref name="h"/> to <c>c_api.TFE_DeleteTensorHandle</c>.</summary>
          protected void TFE_DeleteTensorHandle(IntPtr h)
          {
              c_api.TFE_DeleteTensorHandle(h);
          }

          /// <summary>Returns the dtype of an eager tensor handle.</summary>
          protected TF_DataType TFE_TensorHandleDataType(SafeEagerTensorHandle h)
          {
              return c_api.TFE_TensorHandleDataType(h);
          }

          /// <summary>Returns the number of dimensions reported for an eager tensor handle.</summary>
          protected int TFE_TensorHandleNumDims(SafeEagerTensorHandle h, SafeStatusHandle status)
          {
              return c_api.TFE_TensorHandleNumDims(h, status);
          }

          /// <summary>Resolves an eager tensor handle to a native tensor handle.</summary>
          protected SafeTensorHandle TFE_TensorHandleResolve(SafeEagerTensorHandle h, SafeStatusHandle status)
          {
              return c_api.TFE_TensorHandleResolve(h, status);
          }

          /// <summary>Returns the handle's device name, converted with <c>c_api.StringPiece</c>.</summary>
          protected string TFE_TensorHandleDeviceName(SafeEagerTensorHandle h, SafeStatusHandle status)
          {
              var rawName = c_api.TFE_TensorHandleDeviceName(h, status);
              return c_api.StringPiece(rawName);
          }

          /// <summary>Returns the handle's backing-device name, converted with <c>c_api.StringPiece</c>.</summary>
          protected string TFE_TensorHandleBackingDeviceName(SafeEagerTensorHandle h, SafeStatusHandle status)
          {
              var rawName = c_api.TFE_TensorHandleBackingDeviceName(h, status);
              return c_api.StringPiece(rawName);
          }

          /// <summary>Copies <paramref name="h"/> to the device named <paramref name="device_name"/> within <paramref name="ctx"/>.</summary>
          protected SafeEagerTensorHandle TFE_TensorHandleCopyToDevice(SafeEagerTensorHandle h, SafeContextHandle ctx, string device_name, SafeStatusHandle status)
          {
              return c_api.TFE_TensorHandleCopyToDevice(h, ctx, device_name, status);
          }

          #endregion

          #region Devices

          /// <summary>Lists the devices known to <paramref name="ctx"/>.</summary>
          protected SafeDeviceListHandle TFE_ContextListDevices(SafeContextHandle ctx, SafeStatusHandle status)
          {
              return c_api.TFE_ContextListDevices(ctx, status);
          }

          /// <summary>Returns the number of entries in a device list.</summary>
          protected int TF_DeviceListCount(SafeDeviceListHandle list)
          {
              return c_api.TF_DeviceListCount(list);
          }

          /// <summary>Returns the device type at <paramref name="index"/>, converted with <c>c_api.StringPiece</c>.</summary>
          protected string TF_DeviceListType(SafeDeviceListHandle list, int index, SafeStatusHandle status)
          {
              var rawType = c_api.TF_DeviceListType(list, index, status);
              return c_api.StringPiece(rawType);
          }

          /// <summary>Returns the device name at <paramref name="index"/>.</summary>
          protected string TF_DeviceListName(SafeDeviceListHandle list, int index, SafeStatusHandle status)
          {
              // Unlike TF_DeviceListType, the binding already yields a managed string, so no
              // StringPiece conversion is applied here.
              // NOTE(review): if that P/Invoke declares a `string` return, the default marshaler
              // frees the native buffer, which the TF device list owns. Confirm the binding.
              return c_api.TF_DeviceListName(list, index, status);
          }

          #endregion
      }
  }