From 48829d72063874f76c8945d334329f4c318e2540 Mon Sep 17 00:00:00 2001 From: Oceania2018 Date: Sat, 7 Mar 2020 09:59:43 -0600 Subject: [PATCH] Eager Context unit test. --- src/TensorFlowNET.Core/Device/c_api.device.cs | 27 +++++++ src/TensorFlowNET.Core/Eager/c_api.eager.cs | 43 ++++++---- .../TensorFlow.Binding.csproj | 1 + test/TensorFlowNET.UnitTest/CApiTest.cs | 65 +++++++--------- .../Eager/CApi.Eager.Context.cs | 40 ++++++++++ .../Eager/CApi.Eager.cs | 17 ++++ .../Eager/CApiVariableTest.cs | 78 ------------------- .../Tensorflow.UnitTest.csproj | 1 + 8 files changed, 140 insertions(+), 132 deletions(-) create mode 100644 test/TensorFlowNET.UnitTest/Eager/CApi.Eager.Context.cs create mode 100644 test/TensorFlowNET.UnitTest/Eager/CApi.Eager.cs delete mode 100644 test/TensorFlowNET.UnitTest/Eager/CApiVariableTest.cs diff --git a/src/TensorFlowNET.Core/Device/c_api.device.cs b/src/TensorFlowNET.Core/Device/c_api.device.cs index 2ce79a3e..b7b88bee 100644 --- a/src/TensorFlowNET.Core/Device/c_api.device.cs +++ b/src/TensorFlowNET.Core/Device/c_api.device.cs @@ -28,5 +28,32 @@ namespace Tensorflow /// [DllImport(TensorFlowLibName)] public static extern void TF_SetDevice(IntPtr desc, string device); + + /// + /// Counts the number of elements in the device list. + /// + /// TF_DeviceList* + /// + [DllImport(TensorFlowLibName)] + public static extern int TF_DeviceListCount(IntPtr list); + + /// + /// Deallocates the device list. + /// + /// TF_DeviceList* + [DllImport(TensorFlowLibName)] + public static extern void TF_DeleteDeviceList(IntPtr list); + + /// + /// Retrieves the full name of the device (e.g. /job:worker/replica:0/...) + /// The return value will be a pointer to a null terminated string. The caller + /// must not modify or delete the string. It will be deallocated upon a call to + /// TF_DeleteDeviceList. + /// + /// TF_DeviceList* + /// + /// TF_Status* + [DllImport(TensorFlowLibName)] + public static extern string TF_DeviceListName(IntPtr list, int index, IntPtr status); } } diff --git a/src/TensorFlowNET.Core/Eager/c_api.eager.cs b/src/TensorFlowNET.Core/Eager/c_api.eager.cs index 15a872e0..9e89d234 100644 --- a/src/TensorFlowNET.Core/Eager/c_api.eager.cs +++ b/src/TensorFlowNET.Core/Eager/c_api.eager.cs @@ -10,14 +10,14 @@ namespace Tensorflow /// /// TFE_ContextOptions* [DllImport(TensorFlowLibName)] - internal static extern IntPtr TFE_NewContextOptions(); + public static extern IntPtr TFE_NewContextOptions(); /// /// Destroy an options object. /// /// TFE_ContextOptions* [DllImport(TensorFlowLibName)] - internal static extern void TFE_DeleteContextOptions(IntPtr options); + public static extern void TFE_DeleteContextOptions(IntPtr options); /// /// @@ -26,14 +26,14 @@ namespace Tensorflow /// TF_Status* /// TFE_Context* [DllImport(TensorFlowLibName)] - internal static extern IntPtr TFE_NewContext(IntPtr opts, IntPtr status); + public static extern IntPtr TFE_NewContext(IntPtr opts, IntPtr status); /// /// /// /// TFE_Context* [DllImport(TensorFlowLibName)] - internal static extern void TFE_DeleteContext(IntPtr ctx); + public static extern void TFE_DeleteContext(IntPtr ctx); /// /// Execute the operation defined by 'op' and return handles to computed @@ -44,7 +44,7 @@ namespace Tensorflow /// int* /// TF_Status* [DllImport(TensorFlowLibName)] - internal static extern void TFE_Execute(IntPtr op, IntPtr[] retvals, ref int num_retvals, IntPtr status); + public static extern void TFE_Execute(IntPtr op, IntPtr[] retvals, ref int num_retvals, IntPtr status); /// /// @@ -54,14 +54,14 @@ namespace Tensorflow /// TF_Status* /// [DllImport(TensorFlowLibName)] - internal static extern IntPtr TFE_NewOp(IntPtr ctx, string op_or_function_name, IntPtr status); + public static extern IntPtr TFE_NewOp(IntPtr ctx, string op_or_function_name, IntPtr status); /// /// /// /// TFE_Op* [DllImport(TensorFlowLibName)] - internal static extern void TFE_DeleteOp(IntPtr op); + public static extern void TFE_DeleteOp(IntPtr op); /// /// @@ -70,10 +70,10 @@ namespace Tensorflow /// const char* /// TF_DataType [DllImport(TensorFlowLibName)] - internal static extern void TFE_OpSetAttrType(IntPtr op, string attr_name, TF_DataType value); + public static extern void TFE_OpSetAttrType(IntPtr op, string attr_name, TF_DataType value); [DllImport(TensorFlowLibName)] - internal static extern void TFE_OpSetAttrInt(IntPtr op, string attr_name, long value); + public static extern void TFE_OpSetAttrInt(IntPtr op, string attr_name, long value); /// /// @@ -84,7 +84,7 @@ namespace Tensorflow /// const int /// TF_Status* [DllImport(TensorFlowLibName)] - internal static extern void TFE_OpSetAttrShape(IntPtr op, string attr_name, long[] dims, int num_dims, Status out_status); + public static extern void TFE_OpSetAttrShape(IntPtr op, string attr_name, long[] dims, int num_dims, Status out_status); /// /// @@ -94,7 +94,7 @@ namespace Tensorflow /// const void* /// size_t [DllImport(TensorFlowLibName)] - internal static extern void TFE_OpSetAttrString(IntPtr op, string attr_name, string value, uint length); + public static extern void TFE_OpSetAttrString(IntPtr op, string attr_name, string value, uint length); /// /// @@ -103,7 +103,7 @@ namespace Tensorflow /// /// [DllImport(TensorFlowLibName)] - internal static extern void TFE_OpSetDevice(IntPtr op, string device_name, IntPtr status); + public static extern void TFE_OpSetDevice(IntPtr op, string device_name, IntPtr status); /// /// @@ -112,7 +112,7 @@ namespace Tensorflow /// TFE_TensorHandle* /// TF_Status* [DllImport(TensorFlowLibName)] - internal static extern void TFE_OpAddInput(IntPtr op, IntPtr h, IntPtr status); + public static extern void TFE_OpAddInput(IntPtr op, IntPtr h, IntPtr status); /// /// @@ -120,7 +120,7 @@ namespace Tensorflow /// const tensorflow::Tensor& /// TFE_TensorHandle* [DllImport(TensorFlowLibName)] - internal static extern IntPtr TFE_NewTensorHandle(IntPtr t, IntPtr status); + public static extern IntPtr TFE_NewTensorHandle(IntPtr t, IntPtr status); /// /// @@ -129,7 +129,7 @@ namespace Tensorflow /// /// [DllImport(TensorFlowLibName)] - internal static extern IntPtr TFE_DeleteTensorHandle(IntPtr t, IntPtr status); + public static extern IntPtr TFE_DeleteTensorHandle(IntPtr t, IntPtr status); /// /// @@ -137,7 +137,7 @@ namespace Tensorflow /// TFE_TensorHandle* /// [DllImport(TensorFlowLibName)] - internal static extern TF_DataType TFE_TensorHandleDataType(IntPtr h); + public static extern TF_DataType TFE_TensorHandleDataType(IntPtr h); /// /// This function will block till the operation that produces `h` has completed. @@ -146,6 +146,15 @@ namespace Tensorflow /// TF_Status* /// [DllImport(TensorFlowLibName)] - internal static extern int TFE_TensorHandleNumDims(IntPtr h, IntPtr status); + public static extern int TFE_TensorHandleNumDims(IntPtr h, IntPtr status); + + /// + /// + /// + /// TFE_Context* + /// TF_Status* + /// + [DllImport(TensorFlowLibName)] + public static extern IntPtr TFE_ContextListDevices(IntPtr ctx, IntPtr status); } } diff --git a/src/TensorFlowNET.Core/TensorFlow.Binding.csproj b/src/TensorFlowNET.Core/TensorFlow.Binding.csproj index 0fcb21a4..b1cbd1f1 100644 --- a/src/TensorFlowNET.Core/TensorFlow.Binding.csproj +++ b/src/TensorFlowNET.Core/TensorFlow.Binding.csproj @@ -36,6 +36,7 @@ https://tensorflownet.readthedocs.io true TRACE;DEBUG;SERIALIZABLE_ + x64 diff --git a/test/TensorFlowNET.UnitTest/CApiTest.cs b/test/TensorFlowNET.UnitTest/CApiTest.cs index c3cdf227..5a7c29f4 100644 --- a/test/TensorFlowNET.UnitTest/CApiTest.cs +++ b/test/TensorFlowNET.UnitTest/CApiTest.cs @@ -1,4 +1,5 @@ using Microsoft.VisualStudio.TestTools.UnitTesting; +using System; using Tensorflow; namespace TensorFlowNET.UnitTest @@ -8,59 +9,49 @@ namespace TensorFlowNET.UnitTest protected TF_Code TF_OK = TF_Code.TF_OK; protected TF_DataType TF_FLOAT = TF_DataType.TF_FLOAT; - protected void EXPECT_TRUE(bool expected) - { - Assert.IsTrue(expected); - } + protected void EXPECT_TRUE(bool expected, string msg = "") + => Assert.IsTrue(expected, msg); - protected void EXPECT_EQ(object expected, object actual) - { - Assert.AreEqual(expected, actual); - } + protected void EXPECT_EQ(object expected, object actual, string msg = "") + => Assert.AreEqual(expected, actual, msg); - protected void ASSERT_EQ(object expected, object actual) - { - Assert.AreEqual(expected, actual); - } + protected void EXPECT_NE(object expected, object actual, string msg = "") + => Assert.AreNotEqual(expected, actual, msg); - protected void ASSERT_TRUE(bool condition) - { - Assert.IsTrue(condition); - } + protected void EXPECT_GE(int expected, int actual, string msg = "") + => Assert.IsTrue(expected >= actual, msg); + + protected void ASSERT_EQ(object expected, object actual, string msg = "") + => Assert.AreEqual(expected, actual, msg); + + protected void ASSERT_TRUE(bool condition, string msg = "") + => Assert.IsTrue(condition, msg); protected OperationDescription TF_NewOperation(Graph graph, string opType, string opName) - { - return c_api.TF_NewOperation(graph, opType, opName); - } + => c_api.TF_NewOperation(graph, opType, opName); protected void TF_AddInput(OperationDescription desc, TF_Output input) - { - c_api.TF_AddInput(desc, input); - } + => c_api.TF_AddInput(desc, input); protected Operation TF_FinishOperation(OperationDescription desc, Status s) - { - return c_api.TF_FinishOperation(desc, s); - } + => c_api.TF_FinishOperation(desc, s); protected void TF_SetAttrTensor(OperationDescription desc, string attrName, Tensor value, Status s) - { - c_api.TF_SetAttrTensor(desc, attrName, value, s); - } + => c_api.TF_SetAttrTensor(desc, attrName, value, s); protected void TF_SetAttrType(OperationDescription desc, string attrName, TF_DataType dtype) - { - c_api.TF_SetAttrType(desc, attrName, dtype); - } + => c_api.TF_SetAttrType(desc, attrName, dtype); protected void TF_SetAttrBool(OperationDescription desc, string attrName, bool value) - { - c_api.TF_SetAttrBool(desc, attrName, value); - } + => c_api.TF_SetAttrBool(desc, attrName, value); protected TF_Code TF_GetCode(Status s) - { - return s.Code; - } + => s.Code; + + protected TF_Code TF_GetCode(IntPtr s) + => c_api.TF_GetCode(s); + + protected string TF_Message(IntPtr s) + => c_api.StringPiece(c_api.TF_Message(s)); } } diff --git a/test/TensorFlowNET.UnitTest/Eager/CApi.Eager.Context.cs b/test/TensorFlowNET.UnitTest/Eager/CApi.Eager.Context.cs new file mode 100644 index 00000000..05d34d20 --- /dev/null +++ b/test/TensorFlowNET.UnitTest/Eager/CApi.Eager.Context.cs @@ -0,0 +1,40 @@ +using Microsoft.VisualStudio.TestTools.UnitTesting; +using System; +using Tensorflow; +using Tensorflow.Eager; + +namespace TensorFlowNET.UnitTest.Eager +{ + public partial class CApiEagerTest + { + /// + /// TEST(CAPI, Context) + /// + [TestMethod] + public void Context() + { + var status = c_api.TF_NewStatus(); + var opts = c_api.TFE_NewContextOptions(); + var ctx = c_api.TFE_NewContext(opts, status); + + c_api.TFE_DeleteContextOptions(opts); + + var devices = c_api.TFE_ContextListDevices(ctx, status); + EXPECT_EQ(TF_OK, TF_GetCode(status), TF_Message(status)); + + c_api.TFE_DeleteContext(ctx); + EXPECT_EQ(TF_OK, TF_GetCode(status), TF_Message(status)); + + int num_devices = c_api.TF_DeviceListCount(devices); + EXPECT_GE(num_devices, 1, TF_Message(status)); + for (int i = 0; i < num_devices; ++i) + { + EXPECT_NE("", c_api.TF_DeviceListName(devices, i, status), TF_Message(status)); + EXPECT_EQ(TF_OK, TF_GetCode(status), TF_Message(status)); + } + + c_api.TF_DeleteDeviceList(devices); + c_api.TF_DeleteStatus(status); + } + } +} diff --git a/test/TensorFlowNET.UnitTest/Eager/CApi.Eager.cs b/test/TensorFlowNET.UnitTest/Eager/CApi.Eager.cs new file mode 100644 index 00000000..c79ffc78 --- /dev/null +++ b/test/TensorFlowNET.UnitTest/Eager/CApi.Eager.cs @@ -0,0 +1,17 @@ +using Microsoft.VisualStudio.TestTools.UnitTesting; +using System; +using Tensorflow; +using Tensorflow.Eager; + +namespace TensorFlowNET.UnitTest.Eager +{ + /// + /// tensorflow\c\eager\c_api_test.cc + /// + [TestClass] + public partial class CApiEagerTest : CApiTest + { + + + } +} diff --git a/test/TensorFlowNET.UnitTest/Eager/CApiVariableTest.cs b/test/TensorFlowNET.UnitTest/Eager/CApiVariableTest.cs deleted file mode 100644 index 8eed6a8f..00000000 --- a/test/TensorFlowNET.UnitTest/Eager/CApiVariableTest.cs +++ /dev/null @@ -1,78 +0,0 @@ -using Microsoft.VisualStudio.TestTools.UnitTesting; -using System; -using Tensorflow; -using Tensorflow.Eager; - -namespace TensorFlowNET.UnitTest.Eager -{ - /// - /// tensorflow\c\eager\c_api_test.cc - /// - [TestClass] - public class CApiVariableTest : CApiTest, IDisposable - { - Status status = new Status(); - ContextOptions opts = new ContextOptions(); - Context ctx; - - //[TestMethod] - public void Variables() - { - ctx = new Context(opts, status); - ASSERT_EQ(TF_Code.TF_OK, status.Code); - opts.Dispose(); - - var var_handle = CreateVariable(ctx, 12.0F); - ASSERT_EQ(TF_OK, TF_GetCode(status)); - } - - private IntPtr CreateVariable(Context ctx, float value) - { - // Create the variable handle. - var op = c_api.TFE_NewOp(ctx, "VarHandleOp", status); - if (TF_GetCode(status) != TF_OK) return IntPtr.Zero; - - c_api.TFE_OpSetAttrType(op, "dtype", TF_DataType.TF_FLOAT); - c_api.TFE_OpSetAttrShape(op, "shape", new long[0], 0, status); - c_api.TFE_OpSetAttrString(op, "container", "", 0); - c_api.TFE_OpSetAttrString(op, "shared_name", "", 0); - if (TF_GetCode(status) != TF_OK) return IntPtr.Zero; - var var_handle = IntPtr.Zero; - int[] num_retvals = { 1 }; - c_api.TFE_Execute(op, var_handle, num_retvals, status); - c_api.TFE_DeleteOp(op); - if (TF_GetCode(status) != TF_OK) return IntPtr.Zero; - ASSERT_EQ(1, num_retvals); - - // Assign 'value' to it. - op = c_api.TFE_NewOp(ctx, "AssignVariableOp", status); - if (TF_GetCode(status) != TF_OK) return IntPtr.Zero; - c_api.TFE_OpSetAttrType(op, "dtype", TF_FLOAT); - c_api.TFE_OpAddInput(op, var_handle, status); - - // Convert 'value' to a TF_Tensor then a TFE_TensorHandle. - var t = new Tensor(value); - - var value_handle = c_api.TFE_NewTensorHandle(t); - if (TF_GetCode(status) != TF_OK) return IntPtr.Zero; - - c_api.TFE_OpAddInput(op, value_handle, status); - if (TF_GetCode(status) != TF_OK) return IntPtr.Zero; - - num_retvals = new int[] { 0 }; - c_api.TFE_Execute(op, IntPtr.Zero, num_retvals, status); - c_api.TFE_DeleteOp(op); - if (TF_GetCode(status) != TF_OK) return IntPtr.Zero; - ASSERT_EQ(0, num_retvals); - - return var_handle; - } - - public void Dispose() - { - status.Dispose(); - opts.Dispose(); - ctx.Dispose(); - } - } -} diff --git a/test/TensorFlowNET.UnitTest/Tensorflow.UnitTest.csproj b/test/TensorFlowNET.UnitTest/Tensorflow.UnitTest.csproj index 6e8aa936..e64571e9 100644 --- a/test/TensorFlowNET.UnitTest/Tensorflow.UnitTest.csproj +++ b/test/TensorFlowNET.UnitTest/Tensorflow.UnitTest.csproj @@ -17,6 +17,7 @@ DEBUG;TRACE true + x64