You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

CApiTest.cs 7.5 kB

6 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168
  1. using Microsoft.VisualStudio.TestTools.UnitTesting;
  2. using System;
  3. using Tensorflow;
  4. using Tensorflow.Eager;
  5. using Buffer = System.Buffer;
  6. namespace TensorFlowNET.UnitTest
  7. {
  8. public class CApiTest
  9. {
  10. protected TF_Code TF_OK = TF_Code.TF_OK;
  11. protected TF_DataType TF_FLOAT = TF_DataType.TF_FLOAT;
  12. protected TF_DataType TF_BOOL = TF_DataType.TF_BOOL;
  13. protected void EXPECT_TRUE(bool expected, string msg = "")
  14. => Assert.IsTrue(expected, msg);
  15. protected void EXPECT_EQ(object expected, object actual, string msg = "")
  16. => Assert.AreEqual(expected, actual, msg);
  17. protected void CHECK_EQ(object expected, object actual, string msg = "")
  18. => Assert.AreEqual(expected, actual, msg);
  19. protected void EXPECT_NE(object expected, object actual, string msg = "")
  20. => Assert.AreNotEqual(expected, actual, msg);
  21. protected void CHECK_NE(object expected, object actual, string msg = "")
  22. => Assert.AreNotEqual(expected, actual, msg);
  23. protected void EXPECT_GE(int expected, int actual, string msg = "")
  24. => Assert.IsTrue(expected >= actual, msg);
  25. protected void ASSERT_EQ(object expected, object actual, string msg = "")
  26. => Assert.AreEqual(expected, actual, msg);
  27. protected void ASSERT_TRUE(bool condition, string msg = "")
  28. => Assert.IsTrue(condition, msg);
  29. protected OperationDescription TF_NewOperation(Graph graph, string opType, string opName)
  30. => c_api.TF_NewOperation(graph, opType, opName);
  31. protected void TF_AddInput(OperationDescription desc, TF_Output input)
  32. => c_api.TF_AddInput(desc, input);
  33. protected Operation TF_FinishOperation(OperationDescription desc, Status s)
  34. => c_api.TF_FinishOperation(desc, s.Handle);
  35. protected void TF_SetAttrTensor(OperationDescription desc, string attrName, Tensor value, Status s)
  36. => c_api.TF_SetAttrTensor(desc, attrName, value, s.Handle);
  37. protected void TF_SetAttrType(OperationDescription desc, string attrName, TF_DataType dtype)
  38. => c_api.TF_SetAttrType(desc, attrName, dtype);
  39. protected void TF_SetAttrBool(OperationDescription desc, string attrName, bool value)
  40. => c_api.TF_SetAttrBool(desc, attrName, value);
  41. protected TF_DataType TFE_TensorHandleDataType(IntPtr h)
  42. => c_api.TFE_TensorHandleDataType(h);
  43. protected int TFE_TensorHandleNumDims(IntPtr h, SafeStatusHandle status)
  44. => c_api.TFE_TensorHandleNumDims(h, status);
  45. protected TF_Code TF_GetCode(Status s)
  46. => s.Code;
  47. protected TF_Code TF_GetCode(SafeStatusHandle s)
  48. => c_api.TF_GetCode(s);
  49. protected string TF_Message(SafeStatusHandle s)
  50. => c_api.StringPiece(c_api.TF_Message(s));
  51. protected SafeStatusHandle TF_NewStatus()
  52. => c_api.TF_NewStatus();
  53. protected void TF_DeleteTensor(IntPtr t)
  54. => c_api.TF_DeleteTensor(t);
  55. protected IntPtr TF_TensorData(IntPtr t)
  56. => c_api.TF_TensorData(t);
  57. protected ulong TF_TensorByteSize(IntPtr t)
  58. => c_api.TF_TensorByteSize(t);
  59. protected void TFE_OpAddInput(IntPtr op, IntPtr h, SafeStatusHandle status)
  60. => c_api.TFE_OpAddInput(op, h, status);
  61. protected void TFE_OpSetAttrType(IntPtr op, string attr_name, TF_DataType value)
  62. => c_api.TFE_OpSetAttrType(op, attr_name, value);
  63. protected void TFE_OpSetAttrShape(IntPtr op, string attr_name, long[] dims, int num_dims, SafeStatusHandle out_status)
  64. => c_api.TFE_OpSetAttrShape(op, attr_name, dims, num_dims, out_status);
  65. protected void TFE_OpSetAttrString(IntPtr op, string attr_name, string value, uint length)
  66. => c_api.TFE_OpSetAttrString(op, attr_name, value, length);
  67. protected IntPtr TFE_NewOp(SafeContextHandle ctx, string op_or_function_name, SafeStatusHandle status)
  68. => c_api.TFE_NewOp(ctx, op_or_function_name, status);
  69. protected IntPtr TFE_NewTensorHandle(IntPtr t, SafeStatusHandle status)
  70. => c_api.TFE_NewTensorHandle(t, status);
  71. protected void TFE_Execute(IntPtr op, IntPtr[] retvals, ref int num_retvals, SafeStatusHandle status)
  72. => c_api.TFE_Execute(op, retvals, ref num_retvals, status);
  73. protected IntPtr TFE_NewContextOptions()
  74. => c_api.TFE_NewContextOptions();
  75. protected SafeContextHandle TFE_NewContext(IntPtr opts, SafeStatusHandle status)
  76. => c_api.TFE_NewContext(opts, status);
  77. protected void TFE_DeleteContextOptions(IntPtr opts)
  78. => c_api.TFE_DeleteContextOptions(opts);
  79. protected int TFE_OpGetInputLength(IntPtr op, string input_name, SafeStatusHandle status)
  80. => c_api.TFE_OpGetInputLength(op, input_name, status);
  81. protected int TFE_OpAddInputList(IntPtr op, IntPtr[] inputs, int num_inputs, SafeStatusHandle status)
  82. => c_api.TFE_OpAddInputList(op, inputs, num_inputs, status);
  83. protected int TFE_OpGetOutputLength(IntPtr op, string input_name, SafeStatusHandle status)
  84. => c_api.TFE_OpGetOutputLength(op, input_name, status);
  85. protected void TFE_DeleteTensorHandle(IntPtr h)
  86. => c_api.TFE_DeleteTensorHandle(h);
  87. protected void TFE_DeleteOp(IntPtr op)
  88. => c_api.TFE_DeleteOp(op);
  89. protected void TFE_DeleteExecutor(IntPtr executor)
  90. => c_api.TFE_DeleteExecutor(executor);
  91. protected IntPtr TFE_ContextGetExecutorForThread(SafeContextHandle ctx)
  92. => c_api.TFE_ContextGetExecutorForThread(ctx);
  93. protected void TFE_ExecutorWaitForAllPendingNodes(IntPtr executor, SafeStatusHandle status)
  94. => c_api.TFE_ExecutorWaitForAllPendingNodes(executor, status);
  95. protected IntPtr TFE_TensorHandleResolve(IntPtr h, SafeStatusHandle status)
  96. => c_api.TFE_TensorHandleResolve(h, status);
  97. protected string TFE_TensorHandleDeviceName(IntPtr h, SafeStatusHandle status)
  98. => c_api.StringPiece(c_api.TFE_TensorHandleDeviceName(h, status));
  99. protected string TFE_TensorHandleBackingDeviceName(IntPtr h, SafeStatusHandle status)
  100. => c_api.StringPiece(c_api.TFE_TensorHandleBackingDeviceName(h, status));
  101. protected IntPtr TFE_ContextListDevices(SafeContextHandle ctx, SafeStatusHandle status)
  102. => c_api.TFE_ContextListDevices(ctx, status);
  103. protected int TF_DeviceListCount(IntPtr list)
  104. => c_api.TF_DeviceListCount(list);
  105. protected string TF_DeviceListType(IntPtr list, int index, SafeStatusHandle status)
  106. => c_api.StringPiece(c_api.TF_DeviceListType(list, index, status));
  107. protected string TF_DeviceListName(IntPtr list, int index, SafeStatusHandle status)
  108. => c_api.StringPiece(c_api.TF_DeviceListName(list, index, status));
  109. protected void TF_DeleteDeviceList(IntPtr list)
  110. => c_api.TF_DeleteDeviceList(list);
  111. protected IntPtr TFE_TensorHandleCopyToDevice(IntPtr h, SafeContextHandle ctx, string device_name, SafeStatusHandle status)
  112. => c_api.TFE_TensorHandleCopyToDevice(h, ctx, device_name, status);
  113. protected void TFE_OpSetDevice(IntPtr op, string device_name, SafeStatusHandle status)
  114. => c_api.TFE_OpSetDevice(op, device_name, status);
  115. }
  116. }