@@ -9,6 +9,7 @@ namespace TensorFlowNET.UnitTest | |||||
{ | { | ||||
protected TF_Code TF_OK = TF_Code.TF_OK; | protected TF_Code TF_OK = TF_Code.TF_OK; | ||||
protected TF_DataType TF_FLOAT = TF_DataType.TF_FLOAT; | protected TF_DataType TF_FLOAT = TF_DataType.TF_FLOAT; | ||||
protected TF_DataType TF_BOOL = TF_DataType.TF_BOOL; | |||||
protected void EXPECT_TRUE(bool expected, string msg = "") | protected void EXPECT_TRUE(bool expected, string msg = "") | ||||
=> Assert.IsTrue(expected, msg); | => Assert.IsTrue(expected, msg); | ||||
@@ -73,6 +74,9 @@ namespace TensorFlowNET.UnitTest | |||||
protected void TF_DeleteStatus(IntPtr s) | protected void TF_DeleteStatus(IntPtr s) | ||||
=> c_api.TF_DeleteStatus(s); | => c_api.TF_DeleteStatus(s); | ||||
protected void TF_DeleteTensor(IntPtr t) | |||||
=> c_api.TF_DeleteTensor(t); | |||||
protected IntPtr TF_TensorData(IntPtr t) | protected IntPtr TF_TensorData(IntPtr t) | ||||
=> c_api.TF_TensorData(t); | => c_api.TF_TensorData(t); | ||||
@@ -94,6 +98,9 @@ namespace TensorFlowNET.UnitTest | |||||
protected IntPtr TFE_NewOp(IntPtr ctx, string op_or_function_name, IntPtr status) | protected IntPtr TFE_NewOp(IntPtr ctx, string op_or_function_name, IntPtr status) | ||||
=> c_api.TFE_NewOp(ctx, op_or_function_name, status); | => c_api.TFE_NewOp(ctx, op_or_function_name, status); | ||||
protected IntPtr TFE_NewTensorHandle(IntPtr t, IntPtr status) | |||||
=> c_api.TFE_NewTensorHandle(t, status); | |||||
protected void TFE_Execute(IntPtr op, IntPtr[] retvals, ref int num_retvals, IntPtr status) | protected void TFE_Execute(IntPtr op, IntPtr[] retvals, ref int num_retvals, IntPtr status) | ||||
=> c_api.TFE_Execute(op, retvals, ref num_retvals, status); | => c_api.TFE_Execute(op, retvals, ref num_retvals, status); | ||||
@@ -0,0 +1,57 @@ | |||||
using Microsoft.VisualStudio.TestTools.UnitTesting; | |||||
using System; | |||||
using System.Collections.Generic; | |||||
using Tensorflow; | |||||
using Tensorflow.Eager; | |||||
using Buffer = System.Buffer; | |||||
using System.Linq; | |||||
namespace TensorFlowNET.UnitTest.Eager | |||||
{ | |||||
public partial class CApiEagerTest | |||||
{ | |||||
/// <summary> | |||||
/// TEST(CAPI, TestTFE_OpInferMixedTypeInputListAttrs) | |||||
/// </summary> | |||||
[TestMethod] | |||||
public unsafe void OpInferMixedTypeInputListAttrs() | |||||
{ | |||||
var status = TF_NewStatus(); | |||||
var opts = TFE_NewContextOptions(); | |||||
var ctx = TFE_NewContext(opts, status); | |||||
CHECK_EQ(TF_OK, TF_GetCode(status), TF_Message(status)); | |||||
TFE_DeleteContextOptions(opts); | |||||
var condition = TestScalarTensorHandle(true); | |||||
var t1 = TestMatrixTensorHandle(); | |||||
var t2 = TestAxisTensorHandle(); | |||||
var assertOp = TFE_NewOp(ctx, "Assert", status); | |||||
CHECK_EQ(TF_OK, TF_GetCode(status), TF_Message(status)); | |||||
TFE_OpAddInput(assertOp, condition, status); | |||||
CHECK_EQ(TF_OK, TF_GetCode(status), TF_Message(status)); | |||||
var data = new[] { condition, t1, t2 }; | |||||
TFE_OpAddInputList(assertOp, data, 3, status); | |||||
CHECK_EQ(TF_OK, TF_GetCode(status), TF_Message(status)); | |||||
var attr_values = Graph.TFE_GetOpDef("Assert").Attr; | |||||
var attr_found = attr_values.First(x => x.Name == "T"); | |||||
EXPECT_NE(attr_found, attr_values.Last()); | |||||
// EXPECT_EQ(attr_found.Type[0], "DT_BOOL"); | |||||
//EXPECT_EQ(attr_found->second.list().type(1), tensorflow::DataType::DT_FLOAT); | |||||
//EXPECT_EQ(attr_found->second.list().type(2), tensorflow::DataType::DT_INT32); | |||||
var retvals = new IntPtr[1]; | |||||
int num_retvals = 1; | |||||
TFE_Execute(assertOp, retvals, ref num_retvals, status); | |||||
EXPECT_EQ(TF_OK, TF_GetCode(status), TF_Message(status)); | |||||
TF_DeleteStatus(status); | |||||
TFE_DeleteOp(assertOp); | |||||
TFE_DeleteTensorHandle(condition); | |||||
TFE_DeleteTensorHandle(t1); | |||||
TFE_DeleteTensorHandle(t2); | |||||
TFE_DeleteTensorHandle(retvals[0]); | |||||
TFE_DeleteContext(ctx); | |||||
} | |||||
} | |||||
} |
@@ -120,5 +120,45 @@ namespace TensorFlowNET.UnitTest.Eager | |||||
return var_handle[0]; | return var_handle[0]; | ||||
} | } | ||||
IntPtr TestAxisTensorHandle() | |||||
{ | |||||
var dims = new long[] { 1 }; | |||||
var data = new int[] { 1 }; | |||||
var t = c_api.TF_AllocateTensor(TF_DataType.TF_INT32, dims, 1, sizeof(int)); | |||||
memcpy(TF_TensorData(t), data, TF_TensorByteSize(t)); | |||||
var status = TF_NewStatus(); | |||||
var th = c_api.TFE_NewTensorHandle(t, status); | |||||
CHECK_EQ(TF_OK, TF_GetCode(status), TF_Message(status)); | |||||
TF_DeleteTensor(t); | |||||
TF_DeleteStatus(status); | |||||
return th; | |||||
} | |||||
IntPtr TestScalarTensorHandle(bool value) | |||||
{ | |||||
var data = new[] { value }; | |||||
var t = c_api.TF_AllocateTensor(TF_BOOL, null, 0, sizeof(bool)); | |||||
memcpy(TF_TensorData(t), data, TF_TensorByteSize(t)); | |||||
var status = TF_NewStatus(); | |||||
var th = TFE_NewTensorHandle(t, status); | |||||
CHECK_EQ(TF_OK, TF_GetCode(status), TF_Message(status)); | |||||
TF_DeleteTensor(t); | |||||
TF_DeleteStatus(status); | |||||
return th; | |||||
} | |||||
IntPtr TestScalarTensorHandle(float value) | |||||
{ | |||||
var data = new [] { value }; | |||||
var t = c_api.TF_AllocateTensor(TF_FLOAT, null, 0, sizeof(float)); | |||||
memcpy(TF_TensorData(t), data, TF_TensorByteSize(t)); | |||||
var status = TF_NewStatus(); | |||||
var th = TFE_NewTensorHandle(t, status); | |||||
CHECK_EQ(TF_OK, TF_GetCode(status), TF_Message(status)); | |||||
TF_DeleteTensor(t); | |||||
TF_DeleteStatus(status); | |||||
return th; | |||||
} | |||||
} | } | ||||
} | } |