diff --git a/src/TensorFlowNET.Core/Device/SafeDeviceListHandle.cs b/src/TensorFlowNET.Core/Device/SafeDeviceListHandle.cs
new file mode 100644
index 00000000..86e2a4fd
--- /dev/null
+++ b/src/TensorFlowNET.Core/Device/SafeDeviceListHandle.cs
@@ -0,0 +1,40 @@
+/*****************************************************************************
+ Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved.
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+******************************************************************************/
+
+using System;
+using Tensorflow.Util;
+
+namespace Tensorflow.Device
+{
+ public sealed class SafeDeviceListHandle : SafeTensorflowHandle
+ {
+ private SafeDeviceListHandle()
+ {
+ }
+
+ public SafeDeviceListHandle(IntPtr handle)
+ : base(handle)
+ {
+ }
+
+ protected override bool ReleaseHandle()
+ {
+ c_api.TF_DeleteDeviceList(handle);
+ SetHandle(IntPtr.Zero);
+ return true;
+ }
+ }
+}
diff --git a/src/TensorFlowNET.Core/Device/c_api.device.cs b/src/TensorFlowNET.Core/Device/c_api.device.cs
index 8715ad26..f0ff00ad 100644
--- a/src/TensorFlowNET.Core/Device/c_api.device.cs
+++ b/src/TensorFlowNET.Core/Device/c_api.device.cs
@@ -16,6 +16,7 @@
using System;
using System.Runtime.InteropServices;
+using Tensorflow.Device;
using Tensorflow.Eager;
namespace Tensorflow
@@ -36,7 +37,7 @@ namespace Tensorflow
/// TF_DeviceList*
///
[DllImport(TensorFlowLibName)]
- public static extern int TF_DeviceListCount(IntPtr list);
+ public static extern int TF_DeviceListCount(SafeDeviceListHandle list);
///
/// Retrieves the type of the device at the given index.
@@ -46,7 +47,7 @@ namespace Tensorflow
/// TF_Status*
///
[DllImport(TensorFlowLibName)]
- public static extern IntPtr TF_DeviceListType(IntPtr list, int index, SafeStatusHandle status);
+ public static extern IntPtr TF_DeviceListType(SafeDeviceListHandle list, int index, SafeStatusHandle status);
///
/// Deallocates the device list.
@@ -77,6 +78,6 @@ namespace Tensorflow
///
/// TF_Status*
[DllImport(TensorFlowLibName)]
- public static extern IntPtr TF_DeviceListName(IntPtr list, int index, SafeStatusHandle status);
+ public static extern IntPtr TF_DeviceListName(SafeDeviceListHandle list, int index, SafeStatusHandle status);
}
}
diff --git a/src/TensorFlowNET.Core/Eager/c_api.eager.cs b/src/TensorFlowNET.Core/Eager/c_api.eager.cs
index 1b2fb6b8..a8652beb 100644
--- a/src/TensorFlowNET.Core/Eager/c_api.eager.cs
+++ b/src/TensorFlowNET.Core/Eager/c_api.eager.cs
@@ -1,5 +1,6 @@
using System;
using System.Runtime.InteropServices;
+using Tensorflow.Device;
using Tensorflow.Eager;
using TFE_Executor = System.IntPtr;
@@ -317,7 +318,7 @@ namespace Tensorflow
/// TF_Status*
///
[DllImport(TensorFlowLibName)]
- public static extern IntPtr TFE_ContextListDevices(SafeContextHandle ctx, SafeStatusHandle status);
+ public static extern SafeDeviceListHandle TFE_ContextListDevices(SafeContextHandle ctx, SafeStatusHandle status);
///
///
diff --git a/test/TensorFlowNET.UnitTest/NativeAPI/CApiTest.cs b/test/TensorFlowNET.UnitTest/NativeAPI/CApiTest.cs
index 07abf18e..3353f531 100644
--- a/test/TensorFlowNET.UnitTest/NativeAPI/CApiTest.cs
+++ b/test/TensorFlowNET.UnitTest/NativeAPI/CApiTest.cs
@@ -1,6 +1,7 @@
using Microsoft.VisualStudio.TestTools.UnitTesting;
using System;
using Tensorflow;
+using Tensorflow.Device;
using Tensorflow.Eager;
using Buffer = System.Buffer;
@@ -8,14 +9,14 @@ namespace TensorFlowNET.UnitTest
{
public class CApiTest
{
- protected TF_Code TF_OK = TF_Code.TF_OK;
- protected TF_DataType TF_FLOAT = TF_DataType.TF_FLOAT;
- protected TF_DataType TF_BOOL = TF_DataType.TF_BOOL;
+ protected static readonly TF_Code TF_OK = TF_Code.TF_OK;
+ protected static readonly TF_DataType TF_FLOAT = TF_DataType.TF_FLOAT;
+ protected static readonly TF_DataType TF_BOOL = TF_DataType.TF_BOOL;
protected void EXPECT_TRUE(bool expected, string msg = "")
=> Assert.IsTrue(expected, msg);
- protected void EXPECT_EQ(object expected, object actual, string msg = "")
+ protected static void EXPECT_EQ(object expected, object actual, string msg = "")
=> Assert.AreEqual(expected, actual, msg);
protected void CHECK_EQ(object expected, object actual, string msg = "")
@@ -63,10 +64,10 @@ namespace TensorFlowNET.UnitTest
protected TF_Code TF_GetCode(Status s)
=> s.Code;
- protected TF_Code TF_GetCode(SafeStatusHandle s)
+ protected static TF_Code TF_GetCode(SafeStatusHandle s)
=> c_api.TF_GetCode(s);
- protected string TF_Message(SafeStatusHandle s)
+ protected static string TF_Message(SafeStatusHandle s)
=> c_api.StringPiece(c_api.TF_Message(s));
protected SafeStatusHandle TF_NewStatus()
@@ -141,21 +142,18 @@ namespace TensorFlowNET.UnitTest
protected string TFE_TensorHandleBackingDeviceName(IntPtr h, SafeStatusHandle status)
=> c_api.StringPiece(c_api.TFE_TensorHandleBackingDeviceName(h, status));
- protected IntPtr TFE_ContextListDevices(SafeContextHandle ctx, SafeStatusHandle status)
+ protected SafeDeviceListHandle TFE_ContextListDevices(SafeContextHandle ctx, SafeStatusHandle status)
=> c_api.TFE_ContextListDevices(ctx, status);
- protected int TF_DeviceListCount(IntPtr list)
+ protected int TF_DeviceListCount(SafeDeviceListHandle list)
=> c_api.TF_DeviceListCount(list);
- protected string TF_DeviceListType(IntPtr list, int index, SafeStatusHandle status)
+ protected string TF_DeviceListType(SafeDeviceListHandle list, int index, SafeStatusHandle status)
=> c_api.StringPiece(c_api.TF_DeviceListType(list, index, status));
- protected string TF_DeviceListName(IntPtr list, int index, SafeStatusHandle status)
+ protected string TF_DeviceListName(SafeDeviceListHandle list, int index, SafeStatusHandle status)
=> c_api.StringPiece(c_api.TF_DeviceListName(list, index, status));
- protected void TF_DeleteDeviceList(IntPtr list)
- => c_api.TF_DeleteDeviceList(list);
-
protected IntPtr TFE_TensorHandleCopyToDevice(IntPtr h, SafeContextHandle ctx, string device_name, SafeStatusHandle status)
=> c_api.TFE_TensorHandleCopyToDevice(h, ctx, device_name, status);
diff --git a/test/TensorFlowNET.UnitTest/NativeAPI/Eager/CApi.Eager.Context.cs b/test/TensorFlowNET.UnitTest/NativeAPI/Eager/CApi.Eager.Context.cs
index 84b2f54c..ea737e3a 100644
--- a/test/TensorFlowNET.UnitTest/NativeAPI/Eager/CApi.Eager.Context.cs
+++ b/test/TensorFlowNET.UnitTest/NativeAPI/Eager/CApi.Eager.Context.cs
@@ -1,6 +1,6 @@
using Microsoft.VisualStudio.TestTools.UnitTesting;
-using System;
using Tensorflow;
+using Tensorflow.Device;
using Tensorflow.Eager;
namespace TensorFlowNET.UnitTest.NativeAPI
@@ -21,13 +21,15 @@ namespace TensorFlowNET.UnitTest.NativeAPI
return c_api.TFE_NewContext(opts, status);
}
- IntPtr devices;
- using (var ctx = NewContext(status))
+ static SafeDeviceListHandle ListDevices(SafeStatusHandle status)
{
- devices = c_api.TFE_ContextListDevices(ctx, status);
+ using var ctx = NewContext(status);
+ var devices = c_api.TFE_ContextListDevices(ctx, status);
EXPECT_EQ(TF_OK, TF_GetCode(status), TF_Message(status));
+ return devices;
}
+ using var devices = ListDevices(status);
EXPECT_EQ(TF_OK, TF_GetCode(status), TF_Message(status));
int num_devices = c_api.TF_DeviceListCount(devices);
@@ -37,8 +39,6 @@ namespace TensorFlowNET.UnitTest.NativeAPI
EXPECT_NE("", c_api.TF_DeviceListName(devices, i, status), TF_Message(status));
EXPECT_EQ(TF_OK, TF_GetCode(status), TF_Message(status));
}
-
- c_api.TF_DeleteDeviceList(devices);
}
}
}
diff --git a/test/TensorFlowNET.UnitTest/NativeAPI/Eager/CApi.Eager.cs b/test/TensorFlowNET.UnitTest/NativeAPI/Eager/CApi.Eager.cs
index ebf17f7c..4f9f6571 100644
--- a/test/TensorFlowNET.UnitTest/NativeAPI/Eager/CApi.Eager.cs
+++ b/test/TensorFlowNET.UnitTest/NativeAPI/Eager/CApi.Eager.cs
@@ -43,8 +43,8 @@ namespace TensorFlowNET.UnitTest.NativeAPI
bool GetDeviceName(SafeContextHandle ctx, ref string device_name, string device_type)
{
- var status = TF_NewStatus();
- var devices = TFE_ContextListDevices(ctx, status);
+ using var status = TF_NewStatus();
+ using var devices = TFE_ContextListDevices(ctx, status);
CHECK_EQ(TF_OK, TF_GetCode(status), TF_Message(status));
int num_devices = TF_DeviceListCount(devices);
@@ -57,12 +57,10 @@ namespace TensorFlowNET.UnitTest.NativeAPI
if (dev_type == device_type)
{
device_name = dev_name;
- TF_DeleteDeviceList(devices);
return true;
}
}
- TF_DeleteDeviceList(devices);
return false;
}