@@ -139,6 +139,17 @@ namespace Tensorflow | |||||
[DllImport(TensorFlowLibName)] | [DllImport(TensorFlowLibName)] | ||||
public static extern TF_DataType TFE_TensorHandleDataType(IntPtr h); | public static extern TF_DataType TFE_TensorHandleDataType(IntPtr h); | ||||
/// <summary> | |||||
/// This function will block till the operation that produces `h` has | |||||
/// completed. The memory returned might alias the internal memory used by | |||||
/// TensorFlow. | |||||
/// </summary> | |||||
/// <param name="h">TFE_TensorHandle*</param> | |||||
/// <param name="status">TF_Status*</param> | |||||
/// <returns></returns> | |||||
[DllImport(TensorFlowLibName)] | |||||
public static extern IntPtr TFE_TensorHandleResolve(IntPtr h, IntPtr status); | |||||
/// <summary> | /// <summary> | ||||
/// This function will block till the operation that produces `h` has completed. | /// This function will block till the operation that produces `h` has completed. | ||||
/// </summary> | /// </summary> | ||||
@@ -156,5 +167,12 @@ namespace Tensorflow | |||||
/// <returns></returns> | /// <returns></returns> | ||||
[DllImport(TensorFlowLibName)] | [DllImport(TensorFlowLibName)] | ||||
public static extern IntPtr TFE_ContextListDevices(IntPtr ctx, IntPtr status); | public static extern IntPtr TFE_ContextListDevices(IntPtr ctx, IntPtr status); | ||||
/// <summary> | |||||
/// | |||||
/// </summary> | |||||
/// <param name="h">TFE_TensorHandle*</param> | |||||
[DllImport(TensorFlowLibName)] | |||||
public static extern void TFE_DeleteTensorHandle(IntPtr h); | |||||
} | } | ||||
} | } |
@@ -15,6 +15,9 @@ namespace TensorFlowNET.UnitTest | |||||
protected void EXPECT_EQ(object expected, object actual, string msg = "") | protected void EXPECT_EQ(object expected, object actual, string msg = "") | ||||
=> Assert.AreEqual(expected, actual, msg); | => Assert.AreEqual(expected, actual, msg); | ||||
protected void CHECK_EQ(object expected, object actual, string msg = "") | |||||
=> Assert.AreEqual(expected, actual, msg); | |||||
protected void EXPECT_NE(object expected, object actual, string msg = "") | protected void EXPECT_NE(object expected, object actual, string msg = "") | ||||
=> Assert.AreNotEqual(expected, actual, msg); | => Assert.AreNotEqual(expected, actual, msg); | ||||
@@ -33,7 +33,7 @@ namespace TensorFlowNET.UnitTest.Eager | |||||
EXPECT_EQ(TF_OK, TF_GetCode(status), TF_Message(status)); | EXPECT_EQ(TF_OK, TF_GetCode(status), TF_Message(status)); | ||||
} | } | ||||
c_api.TF_DeleteDeviceList(devices); | |||||
// c_api.TF_DeleteDeviceList(devices); | |||||
c_api.TF_DeleteStatus(status); | c_api.TF_DeleteStatus(status); | ||||
} | } | ||||
} | } | ||||
@@ -0,0 +1,38 @@ | |||||
using Microsoft.VisualStudio.TestTools.UnitTesting; | |||||
using System; | |||||
using Tensorflow; | |||||
using Tensorflow.Eager; | |||||
using Buffer = System.Buffer; | |||||
namespace TensorFlowNET.UnitTest.Eager | |||||
{ | |||||
public partial class CApiEagerTest | |||||
{ | |||||
/// <summary> | |||||
/// TEST(CAPI, TensorHandle) | |||||
/// </summary> | |||||
[TestMethod] | |||||
public unsafe void TensorHandle() | |||||
{ | |||||
var h = TestMatrixTensorHandle(); | |||||
EXPECT_EQ(TF_FLOAT, c_api.TFE_TensorHandleDataType(h)); | |||||
var status = c_api.TF_NewStatus(); | |||||
var t = c_api.TFE_TensorHandleResolve(h, status); | |||||
ASSERT_EQ(16ul, c_api.TF_TensorByteSize(t)); | |||||
var data = new float[] { 0f, 0f, 0f, 0f }; | |||||
fixed (void* src = &data[0]) | |||||
{ | |||||
Buffer.MemoryCopy((void*)c_api.TF_TensorData(t), src, data.Length * sizeof(float), (long)c_api.TF_TensorByteSize(t)); | |||||
} | |||||
EXPECT_EQ(1.0f, data[0]); | |||||
EXPECT_EQ(2.0f, data[1]); | |||||
EXPECT_EQ(3.0f, data[2]); | |||||
EXPECT_EQ(4.0f, data[3]); | |||||
c_api.TF_DeleteTensor(t); | |||||
c_api.TFE_DeleteTensorHandle(h); | |||||
} | |||||
} | |||||
} |
@@ -2,6 +2,7 @@ | |||||
using System; | using System; | ||||
using Tensorflow; | using Tensorflow; | ||||
using Tensorflow.Eager; | using Tensorflow.Eager; | ||||
using Buffer = System.Buffer; | |||||
namespace TensorFlowNET.UnitTest.Eager | namespace TensorFlowNET.UnitTest.Eager | ||||
{ | { | ||||
@@ -11,7 +12,22 @@ namespace TensorFlowNET.UnitTest.Eager | |||||
[TestClass] | [TestClass] | ||||
public partial class CApiEagerTest : CApiTest | public partial class CApiEagerTest : CApiTest | ||||
{ | { | ||||
unsafe IntPtr TestMatrixTensorHandle() | |||||
{ | |||||
var dims = new long[] { 2, 2 }; | |||||
var data = new float[] { 1.0f, 2.0f, 3.0f, 4.0f }; | |||||
var t = c_api.TF_AllocateTensor(TF_FLOAT, dims, dims.Length, (ulong)data.Length * sizeof(float)); | |||||
fixed(void *src = &data[0]) | |||||
{ | |||||
Buffer.MemoryCopy(src, (void*)c_api.TF_TensorData(t), (long)c_api.TF_TensorByteSize(t), data.Length * sizeof(float)); | |||||
} | |||||
var status = c_api.TF_NewStatus(); | |||||
var th = c_api.TFE_NewTensorHandle(t, status); | |||||
CHECK_EQ(TF_OK, TF_GetCode(status), TF_Message(status)); | |||||
c_api.TF_DeleteTensor(t); | |||||
c_api.TF_DeleteStatus(status); | |||||
return th; | |||||
} | |||||
} | } | ||||
} | } |