From 09600a87f7cb365db2b6845852974d859b66d329 Mon Sep 17 00:00:00 2001 From: Sam Harwell Date: Mon, 29 Jun 2020 10:18:05 -0700 Subject: [PATCH] Implement SafeDeviceListHandle as a wrapper for TF_DeviceList --- .../Device/SafeDeviceListHandle.cs | 40 +++++++++++++++++++ src/TensorFlowNET.Core/Device/c_api.device.cs | 7 ++-- src/TensorFlowNET.Core/Eager/c_api.eager.cs | 3 +- .../NativeAPI/CApiTest.cs | 24 +++++------ .../NativeAPI/Eager/CApi.Eager.Context.cs | 12 +++--- .../NativeAPI/Eager/CApi.Eager.cs | 6 +-- 6 files changed, 65 insertions(+), 27 deletions(-) create mode 100644 src/TensorFlowNET.Core/Device/SafeDeviceListHandle.cs diff --git a/src/TensorFlowNET.Core/Device/SafeDeviceListHandle.cs b/src/TensorFlowNET.Core/Device/SafeDeviceListHandle.cs new file mode 100644 index 00000000..86e2a4fd --- /dev/null +++ b/src/TensorFlowNET.Core/Device/SafeDeviceListHandle.cs @@ -0,0 +1,40 @@ +/***************************************************************************** + Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +******************************************************************************/ + +using System; +using Tensorflow.Util; + +namespace Tensorflow.Device +{ + public sealed class SafeDeviceListHandle : SafeTensorflowHandle + { + private SafeDeviceListHandle() + { + } + + public SafeDeviceListHandle(IntPtr handle) + : base(handle) + { + } + + protected override bool ReleaseHandle() + { + c_api.TF_DeleteDeviceList(handle); + SetHandle(IntPtr.Zero); + return true; + } + } +} diff --git a/src/TensorFlowNET.Core/Device/c_api.device.cs b/src/TensorFlowNET.Core/Device/c_api.device.cs index 8715ad26..f0ff00ad 100644 --- a/src/TensorFlowNET.Core/Device/c_api.device.cs +++ b/src/TensorFlowNET.Core/Device/c_api.device.cs @@ -16,6 +16,7 @@ using System; using System.Runtime.InteropServices; +using Tensorflow.Device; using Tensorflow.Eager; namespace Tensorflow @@ -36,7 +37,7 @@ namespace Tensorflow /// TF_DeviceList* /// [DllImport(TensorFlowLibName)] - public static extern int TF_DeviceListCount(IntPtr list); + public static extern int TF_DeviceListCount(SafeDeviceListHandle list); /// /// Retrieves the type of the device at the given index. @@ -46,7 +47,7 @@ namespace Tensorflow /// TF_Status* /// [DllImport(TensorFlowLibName)] - public static extern IntPtr TF_DeviceListType(IntPtr list, int index, SafeStatusHandle status); + public static extern IntPtr TF_DeviceListType(SafeDeviceListHandle list, int index, SafeStatusHandle status); /// /// Deallocates the device list. @@ -77,6 +78,6 @@ namespace Tensorflow /// /// TF_Status* [DllImport(TensorFlowLibName)] - public static extern IntPtr TF_DeviceListName(IntPtr list, int index, SafeStatusHandle status); + public static extern IntPtr TF_DeviceListName(SafeDeviceListHandle list, int index, SafeStatusHandle status); } } diff --git a/src/TensorFlowNET.Core/Eager/c_api.eager.cs b/src/TensorFlowNET.Core/Eager/c_api.eager.cs index 1b2fb6b8..a8652beb 100644 --- a/src/TensorFlowNET.Core/Eager/c_api.eager.cs +++ b/src/TensorFlowNET.Core/Eager/c_api.eager.cs @@ -1,5 +1,6 @@ using System; using System.Runtime.InteropServices; +using Tensorflow.Device; using Tensorflow.Eager; using TFE_Executor = System.IntPtr; @@ -317,7 +318,7 @@ namespace Tensorflow /// TF_Status* /// [DllImport(TensorFlowLibName)] - public static extern IntPtr TFE_ContextListDevices(SafeContextHandle ctx, SafeStatusHandle status); + public static extern SafeDeviceListHandle TFE_ContextListDevices(SafeContextHandle ctx, SafeStatusHandle status); /// /// diff --git a/test/TensorFlowNET.UnitTest/NativeAPI/CApiTest.cs b/test/TensorFlowNET.UnitTest/NativeAPI/CApiTest.cs index 07abf18e..3353f531 100644 --- a/test/TensorFlowNET.UnitTest/NativeAPI/CApiTest.cs +++ b/test/TensorFlowNET.UnitTest/NativeAPI/CApiTest.cs @@ -1,6 +1,7 @@ using Microsoft.VisualStudio.TestTools.UnitTesting; using System; using Tensorflow; +using Tensorflow.Device; using Tensorflow.Eager; using Buffer = System.Buffer; @@ -8,14 +9,14 @@ namespace TensorFlowNET.UnitTest { public class CApiTest { - protected TF_Code TF_OK = TF_Code.TF_OK; - protected TF_DataType TF_FLOAT = TF_DataType.TF_FLOAT; - protected TF_DataType TF_BOOL = TF_DataType.TF_BOOL; + protected static readonly TF_Code TF_OK = TF_Code.TF_OK; + protected static readonly TF_DataType TF_FLOAT = TF_DataType.TF_FLOAT; + protected static readonly TF_DataType TF_BOOL = TF_DataType.TF_BOOL; protected void EXPECT_TRUE(bool expected, string msg = "") => Assert.IsTrue(expected, msg); - protected void EXPECT_EQ(object expected, object actual, string msg = "") + protected static void EXPECT_EQ(object expected, object actual, string msg = "") => Assert.AreEqual(expected, actual, msg); protected void CHECK_EQ(object expected, object actual, string msg = "") @@ -63,10 +64,10 @@ namespace TensorFlowNET.UnitTest protected TF_Code TF_GetCode(Status s) => s.Code; - protected TF_Code TF_GetCode(SafeStatusHandle s) + protected static TF_Code TF_GetCode(SafeStatusHandle s) => c_api.TF_GetCode(s); - protected string TF_Message(SafeStatusHandle s) + protected static string TF_Message(SafeStatusHandle s) => c_api.StringPiece(c_api.TF_Message(s)); protected SafeStatusHandle TF_NewStatus() @@ -141,21 +142,18 @@ namespace TensorFlowNET.UnitTest protected string TFE_TensorHandleBackingDeviceName(IntPtr h, SafeStatusHandle status) => c_api.StringPiece(c_api.TFE_TensorHandleBackingDeviceName(h, status)); - protected IntPtr TFE_ContextListDevices(SafeContextHandle ctx, SafeStatusHandle status) + protected SafeDeviceListHandle TFE_ContextListDevices(SafeContextHandle ctx, SafeStatusHandle status) => c_api.TFE_ContextListDevices(ctx, status); - protected int TF_DeviceListCount(IntPtr list) + protected int TF_DeviceListCount(SafeDeviceListHandle list) => c_api.TF_DeviceListCount(list); - protected string TF_DeviceListType(IntPtr list, int index, SafeStatusHandle status) + protected string TF_DeviceListType(SafeDeviceListHandle list, int index, SafeStatusHandle status) => c_api.StringPiece(c_api.TF_DeviceListType(list, index, status)); - protected string TF_DeviceListName(IntPtr list, int index, SafeStatusHandle status) + protected string TF_DeviceListName(SafeDeviceListHandle list, int index, SafeStatusHandle status) => c_api.StringPiece(c_api.TF_DeviceListName(list, index, status)); - protected void TF_DeleteDeviceList(IntPtr list) - => c_api.TF_DeleteDeviceList(list); - protected IntPtr TFE_TensorHandleCopyToDevice(IntPtr h, SafeContextHandle ctx, string device_name, SafeStatusHandle status) => c_api.TFE_TensorHandleCopyToDevice(h, ctx, device_name, status); diff --git a/test/TensorFlowNET.UnitTest/NativeAPI/Eager/CApi.Eager.Context.cs b/test/TensorFlowNET.UnitTest/NativeAPI/Eager/CApi.Eager.Context.cs index 84b2f54c..ea737e3a 100644 --- a/test/TensorFlowNET.UnitTest/NativeAPI/Eager/CApi.Eager.Context.cs +++ b/test/TensorFlowNET.UnitTest/NativeAPI/Eager/CApi.Eager.Context.cs @@ -1,6 +1,6 @@ using Microsoft.VisualStudio.TestTools.UnitTesting; -using System; using Tensorflow; +using Tensorflow.Device; using Tensorflow.Eager; namespace TensorFlowNET.UnitTest.NativeAPI @@ -21,13 +21,15 @@ namespace TensorFlowNET.UnitTest.NativeAPI return c_api.TFE_NewContext(opts, status); } - IntPtr devices; - using (var ctx = NewContext(status)) + static SafeDeviceListHandle ListDevices(SafeStatusHandle status) { - devices = c_api.TFE_ContextListDevices(ctx, status); + using var ctx = NewContext(status); + var devices = c_api.TFE_ContextListDevices(ctx, status); EXPECT_EQ(TF_OK, TF_GetCode(status), TF_Message(status)); + return devices; } + using var devices = ListDevices(status); EXPECT_EQ(TF_OK, TF_GetCode(status), TF_Message(status)); int num_devices = c_api.TF_DeviceListCount(devices); @@ -37,8 +39,6 @@ namespace TensorFlowNET.UnitTest.NativeAPI EXPECT_NE("", c_api.TF_DeviceListName(devices, i, status), TF_Message(status)); EXPECT_EQ(TF_OK, TF_GetCode(status), TF_Message(status)); } - - c_api.TF_DeleteDeviceList(devices); } } } diff --git a/test/TensorFlowNET.UnitTest/NativeAPI/Eager/CApi.Eager.cs b/test/TensorFlowNET.UnitTest/NativeAPI/Eager/CApi.Eager.cs index ebf17f7c..4f9f6571 100644 --- a/test/TensorFlowNET.UnitTest/NativeAPI/Eager/CApi.Eager.cs +++ b/test/TensorFlowNET.UnitTest/NativeAPI/Eager/CApi.Eager.cs @@ -43,8 +43,8 @@ namespace TensorFlowNET.UnitTest.NativeAPI bool GetDeviceName(SafeContextHandle ctx, ref string device_name, string device_type) { - var status = TF_NewStatus(); - var devices = TFE_ContextListDevices(ctx, status); + using var status = TF_NewStatus(); + using var devices = TFE_ContextListDevices(ctx, status); CHECK_EQ(TF_OK, TF_GetCode(status), TF_Message(status)); int num_devices = TF_DeviceListCount(devices); @@ -57,12 +57,10 @@ namespace TensorFlowNET.UnitTest.NativeAPI if (dev_type == device_type) { device_name = dev_name; - TF_DeleteDeviceList(devices); return true; } } - TF_DeleteDeviceList(devices); return false; }