@@ -0,0 +1,40 @@ | |||
/***************************************************************************** | |||
Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved. | |||
Licensed under the Apache License, Version 2.0 (the "License"); | |||
you may not use this file except in compliance with the License. | |||
You may obtain a copy of the License at | |||
http://www.apache.org/licenses/LICENSE-2.0 | |||
Unless required by applicable law or agreed to in writing, software | |||
distributed under the License is distributed on an "AS IS" BASIS, | |||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
See the License for the specific language governing permissions and | |||
limitations under the License. | |||
******************************************************************************/ | |||
using System; | |||
using Tensorflow.Util; | |||
namespace Tensorflow.Device | |||
{ | |||
public sealed class SafeDeviceListHandle : SafeTensorflowHandle | |||
{ | |||
private SafeDeviceListHandle() | |||
{ | |||
} | |||
public SafeDeviceListHandle(IntPtr handle) | |||
: base(handle) | |||
{ | |||
} | |||
protected override bool ReleaseHandle() | |||
{ | |||
c_api.TF_DeleteDeviceList(handle); | |||
SetHandle(IntPtr.Zero); | |||
return true; | |||
} | |||
} | |||
} |
@@ -16,6 +16,7 @@ | |||
using System; | |||
using System.Runtime.InteropServices; | |||
using Tensorflow.Device; | |||
using Tensorflow.Eager; | |||
namespace Tensorflow | |||
@@ -36,7 +37,7 @@ namespace Tensorflow | |||
/// <param name="list">TF_DeviceList*</param> | |||
/// <returns></returns> | |||
[DllImport(TensorFlowLibName)] | |||
public static extern int TF_DeviceListCount(IntPtr list); | |||
public static extern int TF_DeviceListCount(SafeDeviceListHandle list); | |||
/// <summary> | |||
/// Retrieves the type of the device at the given index. | |||
@@ -46,7 +47,7 @@ namespace Tensorflow | |||
/// <param name="status">TF_Status*</param> | |||
/// <returns></returns> | |||
[DllImport(TensorFlowLibName)] | |||
public static extern IntPtr TF_DeviceListType(IntPtr list, int index, SafeStatusHandle status); | |||
public static extern IntPtr TF_DeviceListType(SafeDeviceListHandle list, int index, SafeStatusHandle status); | |||
/// <summary> | |||
/// Deallocates the device list. | |||
@@ -77,6 +78,6 @@ namespace Tensorflow | |||
/// <param name="index"></param> | |||
/// <param name="status">TF_Status*</param> | |||
[DllImport(TensorFlowLibName)] | |||
public static extern IntPtr TF_DeviceListName(IntPtr list, int index, SafeStatusHandle status); | |||
public static extern IntPtr TF_DeviceListName(SafeDeviceListHandle list, int index, SafeStatusHandle status); | |||
} | |||
} |
@@ -1,5 +1,6 @@ | |||
using System; | |||
using System.Runtime.InteropServices; | |||
using Tensorflow.Device; | |||
using Tensorflow.Eager; | |||
using TFE_Executor = System.IntPtr; | |||
@@ -317,7 +318,7 @@ namespace Tensorflow | |||
/// <param name="status">TF_Status*</param> | |||
/// <returns></returns> | |||
[DllImport(TensorFlowLibName)] | |||
public static extern IntPtr TFE_ContextListDevices(SafeContextHandle ctx, SafeStatusHandle status); | |||
public static extern SafeDeviceListHandle TFE_ContextListDevices(SafeContextHandle ctx, SafeStatusHandle status); | |||
/// <summary> | |||
/// | |||
@@ -1,6 +1,7 @@ | |||
using Microsoft.VisualStudio.TestTools.UnitTesting; | |||
using System; | |||
using Tensorflow; | |||
using Tensorflow.Device; | |||
using Tensorflow.Eager; | |||
using Buffer = System.Buffer; | |||
@@ -8,14 +9,14 @@ namespace TensorFlowNET.UnitTest | |||
{ | |||
public class CApiTest | |||
{ | |||
protected TF_Code TF_OK = TF_Code.TF_OK; | |||
protected TF_DataType TF_FLOAT = TF_DataType.TF_FLOAT; | |||
protected TF_DataType TF_BOOL = TF_DataType.TF_BOOL; | |||
protected static readonly TF_Code TF_OK = TF_Code.TF_OK; | |||
protected static readonly TF_DataType TF_FLOAT = TF_DataType.TF_FLOAT; | |||
protected static readonly TF_DataType TF_BOOL = TF_DataType.TF_BOOL; | |||
protected void EXPECT_TRUE(bool expected, string msg = "") | |||
=> Assert.IsTrue(expected, msg); | |||
protected void EXPECT_EQ(object expected, object actual, string msg = "") | |||
protected static void EXPECT_EQ(object expected, object actual, string msg = "") | |||
=> Assert.AreEqual(expected, actual, msg); | |||
protected void CHECK_EQ(object expected, object actual, string msg = "") | |||
@@ -63,10 +64,10 @@ namespace TensorFlowNET.UnitTest | |||
protected TF_Code TF_GetCode(Status s) | |||
=> s.Code; | |||
protected TF_Code TF_GetCode(SafeStatusHandle s) | |||
protected static TF_Code TF_GetCode(SafeStatusHandle s) | |||
=> c_api.TF_GetCode(s); | |||
protected string TF_Message(SafeStatusHandle s) | |||
protected static string TF_Message(SafeStatusHandle s) | |||
=> c_api.StringPiece(c_api.TF_Message(s)); | |||
protected SafeStatusHandle TF_NewStatus() | |||
@@ -141,21 +142,18 @@ namespace TensorFlowNET.UnitTest | |||
protected string TFE_TensorHandleBackingDeviceName(IntPtr h, SafeStatusHandle status) | |||
=> c_api.StringPiece(c_api.TFE_TensorHandleBackingDeviceName(h, status)); | |||
protected IntPtr TFE_ContextListDevices(SafeContextHandle ctx, SafeStatusHandle status) | |||
protected SafeDeviceListHandle TFE_ContextListDevices(SafeContextHandle ctx, SafeStatusHandle status) | |||
=> c_api.TFE_ContextListDevices(ctx, status); | |||
protected int TF_DeviceListCount(IntPtr list) | |||
protected int TF_DeviceListCount(SafeDeviceListHandle list) | |||
=> c_api.TF_DeviceListCount(list); | |||
protected string TF_DeviceListType(IntPtr list, int index, SafeStatusHandle status) | |||
protected string TF_DeviceListType(SafeDeviceListHandle list, int index, SafeStatusHandle status) | |||
=> c_api.StringPiece(c_api.TF_DeviceListType(list, index, status)); | |||
protected string TF_DeviceListName(IntPtr list, int index, SafeStatusHandle status) | |||
protected string TF_DeviceListName(SafeDeviceListHandle list, int index, SafeStatusHandle status) | |||
=> c_api.StringPiece(c_api.TF_DeviceListName(list, index, status)); | |||
protected void TF_DeleteDeviceList(IntPtr list) | |||
=> c_api.TF_DeleteDeviceList(list); | |||
protected IntPtr TFE_TensorHandleCopyToDevice(IntPtr h, SafeContextHandle ctx, string device_name, SafeStatusHandle status) | |||
=> c_api.TFE_TensorHandleCopyToDevice(h, ctx, device_name, status); | |||
@@ -1,6 +1,6 @@ | |||
using Microsoft.VisualStudio.TestTools.UnitTesting; | |||
using System; | |||
using Tensorflow; | |||
using Tensorflow.Device; | |||
using Tensorflow.Eager; | |||
namespace TensorFlowNET.UnitTest.NativeAPI | |||
@@ -21,13 +21,15 @@ namespace TensorFlowNET.UnitTest.NativeAPI | |||
return c_api.TFE_NewContext(opts, status); | |||
} | |||
IntPtr devices; | |||
using (var ctx = NewContext(status)) | |||
static SafeDeviceListHandle ListDevices(SafeStatusHandle status) | |||
{ | |||
devices = c_api.TFE_ContextListDevices(ctx, status); | |||
using var ctx = NewContext(status); | |||
var devices = c_api.TFE_ContextListDevices(ctx, status); | |||
EXPECT_EQ(TF_OK, TF_GetCode(status), TF_Message(status)); | |||
return devices; | |||
} | |||
using var devices = ListDevices(status); | |||
EXPECT_EQ(TF_OK, TF_GetCode(status), TF_Message(status)); | |||
int num_devices = c_api.TF_DeviceListCount(devices); | |||
@@ -37,8 +39,6 @@ namespace TensorFlowNET.UnitTest.NativeAPI | |||
EXPECT_NE("", c_api.TF_DeviceListName(devices, i, status), TF_Message(status)); | |||
EXPECT_EQ(TF_OK, TF_GetCode(status), TF_Message(status)); | |||
} | |||
c_api.TF_DeleteDeviceList(devices); | |||
} | |||
} | |||
} |
@@ -43,8 +43,8 @@ namespace TensorFlowNET.UnitTest.NativeAPI | |||
bool GetDeviceName(SafeContextHandle ctx, ref string device_name, string device_type) | |||
{ | |||
var status = TF_NewStatus(); | |||
var devices = TFE_ContextListDevices(ctx, status); | |||
using var status = TF_NewStatus(); | |||
using var devices = TFE_ContextListDevices(ctx, status); | |||
CHECK_EQ(TF_OK, TF_GetCode(status), TF_Message(status)); | |||
int num_devices = TF_DeviceListCount(devices); | |||
@@ -57,12 +57,10 @@ namespace TensorFlowNET.UnitTest.NativeAPI | |||
if (dev_type == device_type) | |||
{ | |||
device_name = dev_name; | |||
TF_DeleteDeviceList(devices); | |||
return true; | |||
} | |||
} | |||
TF_DeleteDeviceList(devices); | |||
return false; | |||
} | |||