Browse Source

Eager Context unit test.

tags/v0.20
Oceania2018 5 years ago
parent
commit
48829d7206
8 changed files with 140 additions and 132 deletions
  1. +27
    -0
      src/TensorFlowNET.Core/Device/c_api.device.cs
  2. +26
    -17
      src/TensorFlowNET.Core/Eager/c_api.eager.cs
  3. +1
    -0
      src/TensorFlowNET.Core/TensorFlow.Binding.csproj
  4. +28
    -37
      test/TensorFlowNET.UnitTest/CApiTest.cs
  5. +40
    -0
      test/TensorFlowNET.UnitTest/Eager/CApi.Eager.Context.cs
  6. +17
    -0
      test/TensorFlowNET.UnitTest/Eager/CApi.Eager.cs
  7. +0
    -78
      test/TensorFlowNET.UnitTest/Eager/CApiVariableTest.cs
  8. +1
    -0
      test/TensorFlowNET.UnitTest/Tensorflow.UnitTest.csproj

+ 27
- 0
src/TensorFlowNET.Core/Device/c_api.device.cs View File

@@ -28,5 +28,32 @@ namespace Tensorflow
/// <param name="device"></param>
[DllImport(TensorFlowLibName)]
public static extern void TF_SetDevice(IntPtr desc, string device);

/// <summary>
/// Counts the number of elements in the device list.
/// </summary>
/// <param name="list">TF_DeviceList*</param>
/// <returns></returns>
[DllImport(TensorFlowLibName)]
public static extern int TF_DeviceListCount(IntPtr list);

/// <summary>
/// Deallocates the device list.
/// </summary>
/// <param name="list">TF_DeviceList*</param>
[DllImport(TensorFlowLibName)]
public static extern void TF_DeleteDeviceList(IntPtr list);

/// <summary>
/// Retrieves the full name of the device (e.g. /job:worker/replica:0/...)
/// The return value will be a pointer to a null terminated string. The caller
/// must not modify or delete the string. It will be deallocated upon a call to
/// TF_DeleteDeviceList.
/// </summary>
/// <param name="list">TF_DeviceList*</param>
/// <param name="index"></param>
/// <param name="status">TF_Status*</param>
[DllImport(TensorFlowLibName)]
public static extern string TF_DeviceListName(IntPtr list, int index, IntPtr status);
}
}

+ 26
- 17
src/TensorFlowNET.Core/Eager/c_api.eager.cs View File

@@ -10,14 +10,14 @@ namespace Tensorflow
/// </summary>
/// <returns>TFE_ContextOptions*</returns>
[DllImport(TensorFlowLibName)]
internal static extern IntPtr TFE_NewContextOptions();
public static extern IntPtr TFE_NewContextOptions();

/// <summary>
/// Destroy an options object.
/// </summary>
/// <param name="options">TFE_ContextOptions*</param>
[DllImport(TensorFlowLibName)]
internal static extern void TFE_DeleteContextOptions(IntPtr options);
public static extern void TFE_DeleteContextOptions(IntPtr options);

/// <summary>
///
@@ -26,14 +26,14 @@ namespace Tensorflow
/// <param name="status">TF_Status*</param>
/// <returns>TFE_Context*</returns>
[DllImport(TensorFlowLibName)]
internal static extern IntPtr TFE_NewContext(IntPtr opts, IntPtr status);
public static extern IntPtr TFE_NewContext(IntPtr opts, IntPtr status);

/// <summary>
///
/// </summary>
/// <param name="ctx">TFE_Context*</param>
[DllImport(TensorFlowLibName)]
internal static extern void TFE_DeleteContext(IntPtr ctx);
public static extern void TFE_DeleteContext(IntPtr ctx);

/// <summary>
/// Execute the operation defined by 'op' and return handles to computed
@@ -44,7 +44,7 @@ namespace Tensorflow
/// <param name="num_retvals">int*</param>
/// <param name="status">TF_Status*</param>
[DllImport(TensorFlowLibName)]
internal static extern void TFE_Execute(IntPtr op, IntPtr[] retvals, ref int num_retvals, IntPtr status);
public static extern void TFE_Execute(IntPtr op, IntPtr[] retvals, ref int num_retvals, IntPtr status);

/// <summary>
///
@@ -54,14 +54,14 @@ namespace Tensorflow
/// <param name="status">TF_Status*</param>
/// <returns></returns>
[DllImport(TensorFlowLibName)]
internal static extern IntPtr TFE_NewOp(IntPtr ctx, string op_or_function_name, IntPtr status);
public static extern IntPtr TFE_NewOp(IntPtr ctx, string op_or_function_name, IntPtr status);

/// <summary>
///
/// </summary>
/// <param name="op">TFE_Op*</param>
[DllImport(TensorFlowLibName)]
internal static extern void TFE_DeleteOp(IntPtr op);
public static extern void TFE_DeleteOp(IntPtr op);

/// <summary>
///
@@ -70,10 +70,10 @@ namespace Tensorflow
/// <param name="attr_name">const char*</param>
/// <param name="value">TF_DataType</param>
[DllImport(TensorFlowLibName)]
internal static extern void TFE_OpSetAttrType(IntPtr op, string attr_name, TF_DataType value);
public static extern void TFE_OpSetAttrType(IntPtr op, string attr_name, TF_DataType value);

[DllImport(TensorFlowLibName)]
internal static extern void TFE_OpSetAttrInt(IntPtr op, string attr_name, long value);
public static extern void TFE_OpSetAttrInt(IntPtr op, string attr_name, long value);

/// <summary>
///
@@ -84,7 +84,7 @@ namespace Tensorflow
/// <param name="num_dims">const int</param>
/// <param name="out_status">TF_Status*</param>
[DllImport(TensorFlowLibName)]
internal static extern void TFE_OpSetAttrShape(IntPtr op, string attr_name, long[] dims, int num_dims, Status out_status);
public static extern void TFE_OpSetAttrShape(IntPtr op, string attr_name, long[] dims, int num_dims, Status out_status);

/// <summary>
///
@@ -94,7 +94,7 @@ namespace Tensorflow
/// <param name="value">const void*</param>
/// <param name="length">size_t</param>
[DllImport(TensorFlowLibName)]
internal static extern void TFE_OpSetAttrString(IntPtr op, string attr_name, string value, uint length);
public static extern void TFE_OpSetAttrString(IntPtr op, string attr_name, string value, uint length);

/// <summary>
///
@@ -103,7 +103,7 @@ namespace Tensorflow
/// <param name="device_name"></param>
/// <param name="status"></param>
[DllImport(TensorFlowLibName)]
internal static extern void TFE_OpSetDevice(IntPtr op, string device_name, IntPtr status);
public static extern void TFE_OpSetDevice(IntPtr op, string device_name, IntPtr status);

/// <summary>
///
@@ -112,7 +112,7 @@ namespace Tensorflow
/// <param name="h">TFE_TensorHandle*</param>
/// <param name="status">TF_Status*</param>
[DllImport(TensorFlowLibName)]
internal static extern void TFE_OpAddInput(IntPtr op, IntPtr h, IntPtr status);
public static extern void TFE_OpAddInput(IntPtr op, IntPtr h, IntPtr status);

/// <summary>
///
@@ -120,7 +120,7 @@ namespace Tensorflow
/// <param name="t">const tensorflow::Tensor&</param>
/// <returns>TFE_TensorHandle*</returns>
[DllImport(TensorFlowLibName)]
internal static extern IntPtr TFE_NewTensorHandle(IntPtr t, IntPtr status);
public static extern IntPtr TFE_NewTensorHandle(IntPtr t, IntPtr status);

/// <summary>
///
@@ -129,7 +129,7 @@ namespace Tensorflow
/// <param name="status"></param>
/// <returns></returns>
[DllImport(TensorFlowLibName)]
internal static extern IntPtr TFE_DeleteTensorHandle(IntPtr t, IntPtr status);
public static extern IntPtr TFE_DeleteTensorHandle(IntPtr t, IntPtr status);

/// <summary>
///
@@ -137,7 +137,7 @@ namespace Tensorflow
/// <param name="h">TFE_TensorHandle*</param>
/// <returns></returns>
[DllImport(TensorFlowLibName)]
internal static extern TF_DataType TFE_TensorHandleDataType(IntPtr h);
public static extern TF_DataType TFE_TensorHandleDataType(IntPtr h);

/// <summary>
/// This function will block till the operation that produces `h` has completed.
@@ -146,6 +146,15 @@ namespace Tensorflow
/// <param name="status">TF_Status*</param>
/// <returns></returns>
[DllImport(TensorFlowLibName)]
internal static extern int TFE_TensorHandleNumDims(IntPtr h, IntPtr status);
public static extern int TFE_TensorHandleNumDims(IntPtr h, IntPtr status);

/// <summary>
///
/// </summary>
/// <param name="ctx">TFE_Context*</param>
/// <param name="status">TF_Status*</param>
/// <returns></returns>
[DllImport(TensorFlowLibName)]
public static extern IntPtr TFE_ContextListDevices(IntPtr ctx, IntPtr status);
}
}

+ 1
- 0
src/TensorFlowNET.Core/TensorFlow.Binding.csproj View File

@@ -36,6 +36,7 @@ https://tensorflownet.readthedocs.io</Description>
<PropertyGroup Condition="'$(Configuration)|$(Platform)'=='Debug|AnyCPU'">
<AllowUnsafeBlocks>true</AllowUnsafeBlocks>
<DefineConstants>TRACE;DEBUG;SERIALIZABLE_</DefineConstants>
<PlatformTarget>x64</PlatformTarget>
</PropertyGroup>

<PropertyGroup Condition="'$(Configuration)|$(Platform)'=='Release|AnyCPU'">


+ 28
- 37
test/TensorFlowNET.UnitTest/CApiTest.cs View File

@@ -1,4 +1,5 @@
using Microsoft.VisualStudio.TestTools.UnitTesting;
using System;
using Tensorflow;

namespace TensorFlowNET.UnitTest
@@ -8,59 +9,49 @@ namespace TensorFlowNET.UnitTest
protected TF_Code TF_OK = TF_Code.TF_OK;
protected TF_DataType TF_FLOAT = TF_DataType.TF_FLOAT;

protected void EXPECT_TRUE(bool expected)
{
Assert.IsTrue(expected);
}
protected void EXPECT_TRUE(bool expected, string msg = "")
=> Assert.IsTrue(expected, msg);

protected void EXPECT_EQ(object expected, object actual)
{
Assert.AreEqual(expected, actual);
}
protected void EXPECT_EQ(object expected, object actual, string msg = "")
=> Assert.AreEqual(expected, actual, msg);

protected void ASSERT_EQ(object expected, object actual)
{
Assert.AreEqual(expected, actual);
}
protected void EXPECT_NE(object expected, object actual, string msg = "")
=> Assert.AreNotEqual(expected, actual, msg);

protected void ASSERT_TRUE(bool condition)
{
Assert.IsTrue(condition);
}
protected void EXPECT_GE(int expected, int actual, string msg = "")
=> Assert.IsTrue(expected >= actual, msg);

protected void ASSERT_EQ(object expected, object actual, string msg = "")
=> Assert.AreEqual(expected, actual, msg);

protected void ASSERT_TRUE(bool condition, string msg = "")
=> Assert.IsTrue(condition, msg);

protected OperationDescription TF_NewOperation(Graph graph, string opType, string opName)
{
return c_api.TF_NewOperation(graph, opType, opName);
}
=> c_api.TF_NewOperation(graph, opType, opName);

protected void TF_AddInput(OperationDescription desc, TF_Output input)
{
c_api.TF_AddInput(desc, input);
}
=> c_api.TF_AddInput(desc, input);

protected Operation TF_FinishOperation(OperationDescription desc, Status s)
{
return c_api.TF_FinishOperation(desc, s);
}
=> c_api.TF_FinishOperation(desc, s);

protected void TF_SetAttrTensor(OperationDescription desc, string attrName, Tensor value, Status s)
{
c_api.TF_SetAttrTensor(desc, attrName, value, s);
}
=> c_api.TF_SetAttrTensor(desc, attrName, value, s);

protected void TF_SetAttrType(OperationDescription desc, string attrName, TF_DataType dtype)
{
c_api.TF_SetAttrType(desc, attrName, dtype);
}
=> c_api.TF_SetAttrType(desc, attrName, dtype);

protected void TF_SetAttrBool(OperationDescription desc, string attrName, bool value)
{
c_api.TF_SetAttrBool(desc, attrName, value);
}
=> c_api.TF_SetAttrBool(desc, attrName, value);

protected TF_Code TF_GetCode(Status s)
{
return s.Code;
}
=> s.Code;

protected TF_Code TF_GetCode(IntPtr s)
=> c_api.TF_GetCode(s);

protected string TF_Message(IntPtr s)
=> c_api.StringPiece(c_api.TF_Message(s));
}
}

+ 40
- 0
test/TensorFlowNET.UnitTest/Eager/CApi.Eager.Context.cs View File

@@ -0,0 +1,40 @@
using Microsoft.VisualStudio.TestTools.UnitTesting;
using System;
using Tensorflow;
using Tensorflow.Eager;

namespace TensorFlowNET.UnitTest.Eager
{
public partial class CApiEagerTest
{
/// <summary>
/// TEST(CAPI, Context)
/// </summary>
[TestMethod]
public void Context()
{
var status = c_api.TF_NewStatus();
var opts = c_api.TFE_NewContextOptions();
var ctx = c_api.TFE_NewContext(opts, status);

c_api.TFE_DeleteContextOptions(opts);

var devices = c_api.TFE_ContextListDevices(ctx, status);
EXPECT_EQ(TF_OK, TF_GetCode(status), TF_Message(status));

c_api.TFE_DeleteContext(ctx);
EXPECT_EQ(TF_OK, TF_GetCode(status), TF_Message(status));

int num_devices = c_api.TF_DeviceListCount(devices);
EXPECT_GE(num_devices, 1, TF_Message(status));
for (int i = 0; i < num_devices; ++i)
{
EXPECT_NE("", c_api.TF_DeviceListName(devices, i, status), TF_Message(status));
EXPECT_EQ(TF_OK, TF_GetCode(status), TF_Message(status));
}

c_api.TF_DeleteDeviceList(devices);
c_api.TF_DeleteStatus(status);
}
}
}

+ 17
- 0
test/TensorFlowNET.UnitTest/Eager/CApi.Eager.cs View File

@@ -0,0 +1,17 @@
using Microsoft.VisualStudio.TestTools.UnitTesting;
using System;
using Tensorflow;
using Tensorflow.Eager;

namespace TensorFlowNET.UnitTest.Eager
{
/// <summary>
/// tensorflow\c\eager\c_api_test.cc
/// </summary>
[TestClass]
public partial class CApiEagerTest : CApiTest
{


}
}

+ 0
- 78
test/TensorFlowNET.UnitTest/Eager/CApiVariableTest.cs View File

@@ -1,78 +0,0 @@
using Microsoft.VisualStudio.TestTools.UnitTesting;
using System;
using Tensorflow;
using Tensorflow.Eager;

namespace TensorFlowNET.UnitTest.Eager
{
/// <summary>
/// tensorflow\c\eager\c_api_test.cc
/// </summary>
[TestClass]
public class CApiVariableTest : CApiTest, IDisposable
{
Status status = new Status();
ContextOptions opts = new ContextOptions();
Context ctx;

//[TestMethod]
public void Variables()
{
ctx = new Context(opts, status);
ASSERT_EQ(TF_Code.TF_OK, status.Code);
opts.Dispose();

var var_handle = CreateVariable(ctx, 12.0F);
ASSERT_EQ(TF_OK, TF_GetCode(status));
}

private IntPtr CreateVariable(Context ctx, float value)
{
// Create the variable handle.
var op = c_api.TFE_NewOp(ctx, "VarHandleOp", status);
if (TF_GetCode(status) != TF_OK) return IntPtr.Zero;

c_api.TFE_OpSetAttrType(op, "dtype", TF_DataType.TF_FLOAT);
c_api.TFE_OpSetAttrShape(op, "shape", new long[0], 0, status);
c_api.TFE_OpSetAttrString(op, "container", "", 0);
c_api.TFE_OpSetAttrString(op, "shared_name", "", 0);
if (TF_GetCode(status) != TF_OK) return IntPtr.Zero;
var var_handle = IntPtr.Zero;
int[] num_retvals = { 1 };
c_api.TFE_Execute(op, var_handle, num_retvals, status);
c_api.TFE_DeleteOp(op);
if (TF_GetCode(status) != TF_OK) return IntPtr.Zero;
ASSERT_EQ(1, num_retvals);

// Assign 'value' to it.
op = c_api.TFE_NewOp(ctx, "AssignVariableOp", status);
if (TF_GetCode(status) != TF_OK) return IntPtr.Zero;
c_api.TFE_OpSetAttrType(op, "dtype", TF_FLOAT);
c_api.TFE_OpAddInput(op, var_handle, status);

// Convert 'value' to a TF_Tensor then a TFE_TensorHandle.
var t = new Tensor(value);

var value_handle = c_api.TFE_NewTensorHandle(t);
if (TF_GetCode(status) != TF_OK) return IntPtr.Zero;

c_api.TFE_OpAddInput(op, value_handle, status);
if (TF_GetCode(status) != TF_OK) return IntPtr.Zero;

num_retvals = new int[] { 0 };
c_api.TFE_Execute(op, IntPtr.Zero, num_retvals, status);
c_api.TFE_DeleteOp(op);
if (TF_GetCode(status) != TF_OK) return IntPtr.Zero;
ASSERT_EQ(0, num_retvals);

return var_handle;
}

public void Dispose()
{
status.Dispose();
opts.Dispose();
ctx.Dispose();
}
}
}

+ 1
- 0
test/TensorFlowNET.UnitTest/Tensorflow.UnitTest.csproj View File

@@ -17,6 +17,7 @@
<PropertyGroup Condition="'$(Configuration)|$(Platform)'=='Debug|AnyCPU'">
<DefineConstants>DEBUG;TRACE</DefineConstants>
<AllowUnsafeBlocks>true</AllowUnsafeBlocks>
<PlatformTarget>x64</PlatformTarget>
</PropertyGroup>

<PropertyGroup Condition="'$(Configuration)|$(Platform)'=='Release|AnyCPU'">


Loading…
Cancel
Save