Browse Source

Implement SafeDeviceListHandle as a wrapper for TF_DeviceList

tags/v0.20
Sam Harwell Haiping 5 years ago
parent
commit
09600a87f7
6 changed files with 65 additions and 27 deletions
  1. +40
    -0
      src/TensorFlowNET.Core/Device/SafeDeviceListHandle.cs
  2. +4
    -3
      src/TensorFlowNET.Core/Device/c_api.device.cs
  3. +2
    -1
      src/TensorFlowNET.Core/Eager/c_api.eager.cs
  4. +11
    -13
      test/TensorFlowNET.UnitTest/NativeAPI/CApiTest.cs
  5. +6
    -6
      test/TensorFlowNET.UnitTest/NativeAPI/Eager/CApi.Eager.Context.cs
  6. +2
    -4
      test/TensorFlowNET.UnitTest/NativeAPI/Eager/CApi.Eager.cs

+ 40
- 0
src/TensorFlowNET.Core/Device/SafeDeviceListHandle.cs View File

@@ -0,0 +1,40 @@
/*****************************************************************************
Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
******************************************************************************/

using System;
using Tensorflow.Util;

namespace Tensorflow.Device
{
public sealed class SafeDeviceListHandle : SafeTensorflowHandle
{
private SafeDeviceListHandle()
{
}

public SafeDeviceListHandle(IntPtr handle)
: base(handle)
{
}

protected override bool ReleaseHandle()
{
c_api.TF_DeleteDeviceList(handle);
SetHandle(IntPtr.Zero);
return true;
}
}
}

+ 4
- 3
src/TensorFlowNET.Core/Device/c_api.device.cs View File

@@ -16,6 +16,7 @@

using System;
using System.Runtime.InteropServices;
using Tensorflow.Device;
using Tensorflow.Eager;

namespace Tensorflow
@@ -36,7 +37,7 @@ namespace Tensorflow
/// <param name="list">TF_DeviceList*</param>
/// <returns></returns>
[DllImport(TensorFlowLibName)]
public static extern int TF_DeviceListCount(IntPtr list);
public static extern int TF_DeviceListCount(SafeDeviceListHandle list);

/// <summary>
/// Retrieves the type of the device at the given index.
@@ -46,7 +47,7 @@ namespace Tensorflow
/// <param name="status">TF_Status*</param>
/// <returns></returns>
[DllImport(TensorFlowLibName)]
public static extern IntPtr TF_DeviceListType(IntPtr list, int index, SafeStatusHandle status);
public static extern IntPtr TF_DeviceListType(SafeDeviceListHandle list, int index, SafeStatusHandle status);

/// <summary>
/// Deallocates the device list.
@@ -77,6 +78,6 @@ namespace Tensorflow
/// <param name="index"></param>
/// <param name="status">TF_Status*</param>
[DllImport(TensorFlowLibName)]
public static extern IntPtr TF_DeviceListName(IntPtr list, int index, SafeStatusHandle status);
public static extern IntPtr TF_DeviceListName(SafeDeviceListHandle list, int index, SafeStatusHandle status);
}
}

+ 2
- 1
src/TensorFlowNET.Core/Eager/c_api.eager.cs View File

@@ -1,5 +1,6 @@
using System;
using System.Runtime.InteropServices;
using Tensorflow.Device;
using Tensorflow.Eager;
using TFE_Executor = System.IntPtr;

@@ -317,7 +318,7 @@ namespace Tensorflow
/// <param name="status">TF_Status*</param>
/// <returns></returns>
[DllImport(TensorFlowLibName)]
public static extern IntPtr TFE_ContextListDevices(SafeContextHandle ctx, SafeStatusHandle status);
public static extern SafeDeviceListHandle TFE_ContextListDevices(SafeContextHandle ctx, SafeStatusHandle status);

/// <summary>
///


+ 11
- 13
test/TensorFlowNET.UnitTest/NativeAPI/CApiTest.cs View File

@@ -1,6 +1,7 @@
using Microsoft.VisualStudio.TestTools.UnitTesting;
using System;
using Tensorflow;
using Tensorflow.Device;
using Tensorflow.Eager;
using Buffer = System.Buffer;

@@ -8,14 +9,14 @@ namespace TensorFlowNET.UnitTest
{
public class CApiTest
{
protected TF_Code TF_OK = TF_Code.TF_OK;
protected TF_DataType TF_FLOAT = TF_DataType.TF_FLOAT;
protected TF_DataType TF_BOOL = TF_DataType.TF_BOOL;
protected static readonly TF_Code TF_OK = TF_Code.TF_OK;
protected static readonly TF_DataType TF_FLOAT = TF_DataType.TF_FLOAT;
protected static readonly TF_DataType TF_BOOL = TF_DataType.TF_BOOL;

protected void EXPECT_TRUE(bool expected, string msg = "")
=> Assert.IsTrue(expected, msg);

protected void EXPECT_EQ(object expected, object actual, string msg = "")
protected static void EXPECT_EQ(object expected, object actual, string msg = "")
=> Assert.AreEqual(expected, actual, msg);

protected void CHECK_EQ(object expected, object actual, string msg = "")
@@ -63,10 +64,10 @@ namespace TensorFlowNET.UnitTest
protected TF_Code TF_GetCode(Status s)
=> s.Code;

protected TF_Code TF_GetCode(SafeStatusHandle s)
protected static TF_Code TF_GetCode(SafeStatusHandle s)
=> c_api.TF_GetCode(s);

protected string TF_Message(SafeStatusHandle s)
protected static string TF_Message(SafeStatusHandle s)
=> c_api.StringPiece(c_api.TF_Message(s));

protected SafeStatusHandle TF_NewStatus()
@@ -141,21 +142,18 @@ namespace TensorFlowNET.UnitTest
protected string TFE_TensorHandleBackingDeviceName(IntPtr h, SafeStatusHandle status)
=> c_api.StringPiece(c_api.TFE_TensorHandleBackingDeviceName(h, status));

protected IntPtr TFE_ContextListDevices(SafeContextHandle ctx, SafeStatusHandle status)
protected SafeDeviceListHandle TFE_ContextListDevices(SafeContextHandle ctx, SafeStatusHandle status)
=> c_api.TFE_ContextListDevices(ctx, status);

protected int TF_DeviceListCount(IntPtr list)
protected int TF_DeviceListCount(SafeDeviceListHandle list)
=> c_api.TF_DeviceListCount(list);

protected string TF_DeviceListType(IntPtr list, int index, SafeStatusHandle status)
protected string TF_DeviceListType(SafeDeviceListHandle list, int index, SafeStatusHandle status)
=> c_api.StringPiece(c_api.TF_DeviceListType(list, index, status));

protected string TF_DeviceListName(IntPtr list, int index, SafeStatusHandle status)
protected string TF_DeviceListName(SafeDeviceListHandle list, int index, SafeStatusHandle status)
=> c_api.StringPiece(c_api.TF_DeviceListName(list, index, status));

protected void TF_DeleteDeviceList(IntPtr list)
=> c_api.TF_DeleteDeviceList(list);

protected IntPtr TFE_TensorHandleCopyToDevice(IntPtr h, SafeContextHandle ctx, string device_name, SafeStatusHandle status)
=> c_api.TFE_TensorHandleCopyToDevice(h, ctx, device_name, status);



+ 6
- 6
test/TensorFlowNET.UnitTest/NativeAPI/Eager/CApi.Eager.Context.cs View File

@@ -1,6 +1,6 @@
using Microsoft.VisualStudio.TestTools.UnitTesting;
using System;
using Tensorflow;
using Tensorflow.Device;
using Tensorflow.Eager;

namespace TensorFlowNET.UnitTest.NativeAPI
@@ -21,13 +21,15 @@ namespace TensorFlowNET.UnitTest.NativeAPI
return c_api.TFE_NewContext(opts, status);
}

IntPtr devices;
using (var ctx = NewContext(status))
static SafeDeviceListHandle ListDevices(SafeStatusHandle status)
{
devices = c_api.TFE_ContextListDevices(ctx, status);
using var ctx = NewContext(status);
var devices = c_api.TFE_ContextListDevices(ctx, status);
EXPECT_EQ(TF_OK, TF_GetCode(status), TF_Message(status));
return devices;
}

using var devices = ListDevices(status);
EXPECT_EQ(TF_OK, TF_GetCode(status), TF_Message(status));

int num_devices = c_api.TF_DeviceListCount(devices);
@@ -37,8 +39,6 @@ namespace TensorFlowNET.UnitTest.NativeAPI
EXPECT_NE("", c_api.TF_DeviceListName(devices, i, status), TF_Message(status));
EXPECT_EQ(TF_OK, TF_GetCode(status), TF_Message(status));
}

c_api.TF_DeleteDeviceList(devices);
}
}
}

+ 2
- 4
test/TensorFlowNET.UnitTest/NativeAPI/Eager/CApi.Eager.cs View File

@@ -43,8 +43,8 @@ namespace TensorFlowNET.UnitTest.NativeAPI

bool GetDeviceName(SafeContextHandle ctx, ref string device_name, string device_type)
{
var status = TF_NewStatus();
var devices = TFE_ContextListDevices(ctx, status);
using var status = TF_NewStatus();
using var devices = TFE_ContextListDevices(ctx, status);
CHECK_EQ(TF_OK, TF_GetCode(status), TF_Message(status));

int num_devices = TF_DeviceListCount(devices);
@@ -57,12 +57,10 @@ namespace TensorFlowNET.UnitTest.NativeAPI
if (dev_type == device_type)
{
device_name = dev_name;
TF_DeleteDeviceList(devices);
return true;
}
}

TF_DeleteDeviceList(devices);
return false;
}



Loading…
Cancel
Save