diff --git a/src/TensorFlowNET.Core/Device/c_api.device.cs b/src/TensorFlowNET.Core/Device/c_api.device.cs index f0ff00ad..7ba6fed7 100644 --- a/src/TensorFlowNET.Core/Device/c_api.device.cs +++ b/src/TensorFlowNET.Core/Device/c_api.device.cs @@ -18,6 +18,7 @@ using System; using System.Runtime.InteropServices; using Tensorflow.Device; using Tensorflow.Eager; +using Tensorflow.Util; namespace Tensorflow { @@ -68,6 +69,18 @@ namespace Tensorflow [DllImport(TensorFlowLibName)] public static extern IntPtr TFE_TensorHandleCopyToDevice(IntPtr h, SafeContextHandle ctx, string device_name, SafeStatusHandle status); + /// + /// Retrieves the full name of the device (e.g. /job:worker/replica:0/...) + /// + /// TF_DeviceList* + /// + /// TF_Status* + public static string TF_DeviceListName(SafeDeviceListHandle list, int index, SafeStatusHandle status) + { + using var _ = list.Lease(); + return StringPiece(TF_DeviceListNameImpl(list, index, status)); + } + /// /// Retrieves the full name of the device (e.g. /job:worker/replica:0/...) /// The return value will be a pointer to a null terminated string. The caller @@ -77,7 +90,7 @@ namespace Tensorflow /// TF_DeviceList* /// /// TF_Status* - [DllImport(TensorFlowLibName)] - public static extern IntPtr TF_DeviceListName(SafeDeviceListHandle list, int index, SafeStatusHandle status); + [DllImport(TensorFlowLibName, EntryPoint = "TF_DeviceListName")] + private static extern IntPtr TF_DeviceListNameImpl(SafeDeviceListHandle list, int index, SafeStatusHandle status); } } diff --git a/test/TensorFlowNET.UnitTest/NativeAPI/CApiTest.cs b/test/TensorFlowNET.UnitTest/NativeAPI/CApiTest.cs index 4b9f061a..0a5a0545 100644 --- a/test/TensorFlowNET.UnitTest/NativeAPI/CApiTest.cs +++ b/test/TensorFlowNET.UnitTest/NativeAPI/CApiTest.cs @@ -147,7 +147,7 @@ namespace TensorFlowNET.UnitTest => c_api.StringPiece(c_api.TF_DeviceListType(list, index, status)); protected string TF_DeviceListName(SafeDeviceListHandle list, int index, SafeStatusHandle status) - => c_api.StringPiece(c_api.TF_DeviceListName(list, index, status)); + => c_api.TF_DeviceListName(list, index, status); protected IntPtr TFE_TensorHandleCopyToDevice(IntPtr h, SafeContextHandle ctx, string device_name, SafeStatusHandle status) => c_api.TFE_TensorHandleCopyToDevice(h, ctx, device_name, status);