Browse Source

Ensure SafeDeviceListHandle is not released before calling StringPiece

tags/v0.20
Sam Harwell Esther Hu 5 years ago
parent
commit
28e19ea443
2 changed files with 16 additions and 3 deletions
  1. +15
    -2
      src/TensorFlowNET.Core/Device/c_api.device.cs
  2. +1
    -1
      test/TensorFlowNET.UnitTest/NativeAPI/CApiTest.cs

+ 15
- 2
src/TensorFlowNET.Core/Device/c_api.device.cs View File

@@ -18,6 +18,7 @@ using System;
using System.Runtime.InteropServices;
using Tensorflow.Device;
using Tensorflow.Eager;
using Tensorflow.Util;

namespace Tensorflow
{
@@ -68,6 +69,18 @@ namespace Tensorflow
[DllImport(TensorFlowLibName)]
public static extern IntPtr TFE_TensorHandleCopyToDevice(IntPtr h, SafeContextHandle ctx, string device_name, SafeStatusHandle status);

/// <summary>
/// Retrieves the full name of the device (e.g. /job:worker/replica:0/...)
/// </summary>
/// <param name="list">TF_DeviceList*</param>
/// <param name="index"></param>
/// <param name="status">TF_Status*</param>
public static string TF_DeviceListName(SafeDeviceListHandle list, int index, SafeStatusHandle status)
{
using var _ = list.Lease();
return StringPiece(TF_DeviceListNameImpl(list, index, status));
}

/// <summary>
/// Retrieves the full name of the device (e.g. /job:worker/replica:0/...)
/// The return value will be a pointer to a null terminated string. The caller
@@ -77,7 +90,7 @@ namespace Tensorflow
/// <param name="list">TF_DeviceList*</param>
/// <param name="index"></param>
/// <param name="status">TF_Status*</param>
[DllImport(TensorFlowLibName)]
public static extern IntPtr TF_DeviceListName(SafeDeviceListHandle list, int index, SafeStatusHandle status);
[DllImport(TensorFlowLibName, EntryPoint = "TF_DeviceListName")]
private static extern IntPtr TF_DeviceListNameImpl(SafeDeviceListHandle list, int index, SafeStatusHandle status);
}
}

+ 1
- 1
test/TensorFlowNET.UnitTest/NativeAPI/CApiTest.cs View File

@@ -147,7 +147,7 @@ namespace TensorFlowNET.UnitTest
=> c_api.StringPiece(c_api.TF_DeviceListType(list, index, status));

protected string TF_DeviceListName(SafeDeviceListHandle list, int index, SafeStatusHandle status)
=> c_api.StringPiece(c_api.TF_DeviceListName(list, index, status));
=> c_api.TF_DeviceListName(list, index, status);

protected IntPtr TFE_TensorHandleCopyToDevice(IntPtr h, SafeContextHandle ctx, string device_name, SafeStatusHandle status)
=> c_api.TFE_TensorHandleCopyToDevice(h, ctx, device_name, status);


Loading…
Cancel
Save