Browse Source

Eager.TensorHandle unit test.

tags/v0.20
Oceania2018 5 years ago
parent
commit
9dedbb5f0e
5 changed files with 78 additions and 3 deletions
  1. +18
    -0
      src/TensorFlowNET.Core/Eager/c_api.eager.cs
  2. +3
    -0
      test/TensorFlowNET.UnitTest/CApiTest.cs
  3. +1
    -1
      test/TensorFlowNET.UnitTest/Eager/CApi.Eager.Context.cs
  4. +38
    -0
      test/TensorFlowNET.UnitTest/Eager/CApi.Eager.TensorHandle.cs
  5. +18
    -2
      test/TensorFlowNET.UnitTest/Eager/CApi.Eager.cs

+ 18
- 0
src/TensorFlowNET.Core/Eager/c_api.eager.cs View File

@@ -139,6 +139,17 @@ namespace Tensorflow
[DllImport(TensorFlowLibName)]
public static extern TF_DataType TFE_TensorHandleDataType(IntPtr h);

/// <summary>
/// This function will block till the operation that produces `h` has
/// completed. The memory returned might alias the internal memory used by
/// TensorFlow.
/// </summary>
/// <param name="h">TFE_TensorHandle*</param>
/// <param name="status">TF_Status*</param>
/// <returns></returns>
[DllImport(TensorFlowLibName)]
public static extern IntPtr TFE_TensorHandleResolve(IntPtr h, IntPtr status);

/// <summary>
/// This function will block till the operation that produces `h` has completed.
/// </summary>
@@ -156,5 +167,12 @@ namespace Tensorflow
/// <returns></returns>
[DllImport(TensorFlowLibName)]
public static extern IntPtr TFE_ContextListDevices(IntPtr ctx, IntPtr status);

/// <summary>
///
/// </summary>
/// <param name="h">TFE_TensorHandle*</param>
[DllImport(TensorFlowLibName)]
public static extern void TFE_DeleteTensorHandle(IntPtr h);
}
}

+ 3
- 0
test/TensorFlowNET.UnitTest/CApiTest.cs View File

@@ -15,6 +15,9 @@ namespace TensorFlowNET.UnitTest
protected void EXPECT_EQ(object expected, object actual, string msg = "")
=> Assert.AreEqual(expected, actual, msg);

protected void CHECK_EQ(object expected, object actual, string msg = "")
=> Assert.AreEqual(expected, actual, msg);

protected void EXPECT_NE(object expected, object actual, string msg = "")
=> Assert.AreNotEqual(expected, actual, msg);



+ 1
- 1
test/TensorFlowNET.UnitTest/Eager/CApi.Eager.Context.cs View File

@@ -33,7 +33,7 @@ namespace TensorFlowNET.UnitTest.Eager
EXPECT_EQ(TF_OK, TF_GetCode(status), TF_Message(status));
}

c_api.TF_DeleteDeviceList(devices);
// c_api.TF_DeleteDeviceList(devices);
c_api.TF_DeleteStatus(status);
}
}


+ 38
- 0
test/TensorFlowNET.UnitTest/Eager/CApi.Eager.TensorHandle.cs View File

@@ -0,0 +1,38 @@
using Microsoft.VisualStudio.TestTools.UnitTesting;
using System;
using Tensorflow;
using Tensorflow.Eager;
using Buffer = System.Buffer;

namespace TensorFlowNET.UnitTest.Eager
{
public partial class CApiEagerTest
{
/// <summary>
/// TEST(CAPI, TensorHandle)
/// </summary>
[TestMethod]
public unsafe void TensorHandle()
{
var h = TestMatrixTensorHandle();
EXPECT_EQ(TF_FLOAT, c_api.TFE_TensorHandleDataType(h));

var status = c_api.TF_NewStatus();
var t = c_api.TFE_TensorHandleResolve(h, status);
ASSERT_EQ(16ul, c_api.TF_TensorByteSize(t));

var data = new float[] { 0f, 0f, 0f, 0f };
fixed (void* src = &data[0])
{
Buffer.MemoryCopy((void*)c_api.TF_TensorData(t), src, data.Length * sizeof(float), (long)c_api.TF_TensorByteSize(t));
}

EXPECT_EQ(1.0f, data[0]);
EXPECT_EQ(2.0f, data[1]);
EXPECT_EQ(3.0f, data[2]);
EXPECT_EQ(4.0f, data[3]);
c_api.TF_DeleteTensor(t);
c_api.TFE_DeleteTensorHandle(h);
}
}
}

+ 18
- 2
test/TensorFlowNET.UnitTest/Eager/CApi.Eager.cs View File

@@ -2,6 +2,7 @@
using System;
using Tensorflow;
using Tensorflow.Eager;
using Buffer = System.Buffer;

namespace TensorFlowNET.UnitTest.Eager
{
@@ -11,7 +12,22 @@ namespace TensorFlowNET.UnitTest.Eager
[TestClass]
public partial class CApiEagerTest : CApiTest
{


unsafe IntPtr TestMatrixTensorHandle()
{
var dims = new long[] { 2, 2 };
var data = new float[] { 1.0f, 2.0f, 3.0f, 4.0f };
var t = c_api.TF_AllocateTensor(TF_FLOAT, dims, dims.Length, (ulong)data.Length * sizeof(float));
fixed(void *src = &data[0])
{
Buffer.MemoryCopy(src, (void*)c_api.TF_TensorData(t), (long)c_api.TF_TensorByteSize(t), data.Length * sizeof(float));
}
var status = c_api.TF_NewStatus();
var th = c_api.TFE_NewTensorHandle(t, status);
CHECK_EQ(TF_OK, TF_GetCode(status), TF_Message(status));
c_api.TF_DeleteTensor(t);
c_api.TF_DeleteStatus(status);
return th;
}
}
}

Loading…
Cancel
Save