@@ -9,6 +9,7 @@ namespace TensorFlowNET.UnitTest | |||
{ | |||
protected TF_Code TF_OK = TF_Code.TF_OK; | |||
protected TF_DataType TF_FLOAT = TF_DataType.TF_FLOAT; | |||
protected TF_DataType TF_BOOL = TF_DataType.TF_BOOL; | |||
protected void EXPECT_TRUE(bool expected, string msg = "") | |||
=> Assert.IsTrue(expected, msg); | |||
@@ -73,6 +74,9 @@ namespace TensorFlowNET.UnitTest | |||
protected void TF_DeleteStatus(IntPtr s) | |||
=> c_api.TF_DeleteStatus(s); | |||
protected void TF_DeleteTensor(IntPtr t) | |||
=> c_api.TF_DeleteTensor(t); | |||
protected IntPtr TF_TensorData(IntPtr t) | |||
=> c_api.TF_TensorData(t); | |||
@@ -94,6 +98,9 @@ namespace TensorFlowNET.UnitTest | |||
protected IntPtr TFE_NewOp(IntPtr ctx, string op_or_function_name, IntPtr status) | |||
=> c_api.TFE_NewOp(ctx, op_or_function_name, status); | |||
protected IntPtr TFE_NewTensorHandle(IntPtr t, IntPtr status) | |||
=> c_api.TFE_NewTensorHandle(t, status); | |||
protected void TFE_Execute(IntPtr op, IntPtr[] retvals, ref int num_retvals, IntPtr status) | |||
=> c_api.TFE_Execute(op, retvals, ref num_retvals, status); | |||
@@ -0,0 +1,57 @@ | |||
using Microsoft.VisualStudio.TestTools.UnitTesting; | |||
using System; | |||
using System.Collections.Generic; | |||
using Tensorflow; | |||
using Tensorflow.Eager; | |||
using Buffer = System.Buffer; | |||
using System.Linq; | |||
namespace TensorFlowNET.UnitTest.Eager | |||
{ | |||
public partial class CApiEagerTest | |||
{ | |||
/// <summary> | |||
/// TEST(CAPI, TestTFE_OpInferMixedTypeInputListAttrs) | |||
/// </summary> | |||
[TestMethod] | |||
public unsafe void OpInferMixedTypeInputListAttrs() | |||
{ | |||
var status = TF_NewStatus(); | |||
var opts = TFE_NewContextOptions(); | |||
var ctx = TFE_NewContext(opts, status); | |||
CHECK_EQ(TF_OK, TF_GetCode(status), TF_Message(status)); | |||
TFE_DeleteContextOptions(opts); | |||
var condition = TestScalarTensorHandle(true); | |||
var t1 = TestMatrixTensorHandle(); | |||
var t2 = TestAxisTensorHandle(); | |||
var assertOp = TFE_NewOp(ctx, "Assert", status); | |||
CHECK_EQ(TF_OK, TF_GetCode(status), TF_Message(status)); | |||
TFE_OpAddInput(assertOp, condition, status); | |||
CHECK_EQ(TF_OK, TF_GetCode(status), TF_Message(status)); | |||
var data = new[] { condition, t1, t2 }; | |||
TFE_OpAddInputList(assertOp, data, 3, status); | |||
CHECK_EQ(TF_OK, TF_GetCode(status), TF_Message(status)); | |||
var attr_values = Graph.TFE_GetOpDef("Assert").Attr; | |||
var attr_found = attr_values.First(x => x.Name == "T"); | |||
EXPECT_NE(attr_found, attr_values.Last()); | |||
// EXPECT_EQ(attr_found.Type[0], "DT_BOOL"); | |||
//EXPECT_EQ(attr_found->second.list().type(1), tensorflow::DataType::DT_FLOAT); | |||
//EXPECT_EQ(attr_found->second.list().type(2), tensorflow::DataType::DT_INT32); | |||
var retvals = new IntPtr[1]; | |||
int num_retvals = 1; | |||
TFE_Execute(assertOp, retvals, ref num_retvals, status); | |||
EXPECT_EQ(TF_OK, TF_GetCode(status), TF_Message(status)); | |||
TF_DeleteStatus(status); | |||
TFE_DeleteOp(assertOp); | |||
TFE_DeleteTensorHandle(condition); | |||
TFE_DeleteTensorHandle(t1); | |||
TFE_DeleteTensorHandle(t2); | |||
TFE_DeleteTensorHandle(retvals[0]); | |||
TFE_DeleteContext(ctx); | |||
} | |||
} | |||
} |
@@ -120,5 +120,45 @@ namespace TensorFlowNET.UnitTest.Eager | |||
return var_handle[0]; | |||
} | |||
IntPtr TestAxisTensorHandle() | |||
{ | |||
var dims = new long[] { 1 }; | |||
var data = new int[] { 1 }; | |||
var t = c_api.TF_AllocateTensor(TF_DataType.TF_INT32, dims, 1, sizeof(int)); | |||
memcpy(TF_TensorData(t), data, TF_TensorByteSize(t)); | |||
var status = TF_NewStatus(); | |||
var th = c_api.TFE_NewTensorHandle(t, status); | |||
CHECK_EQ(TF_OK, TF_GetCode(status), TF_Message(status)); | |||
TF_DeleteTensor(t); | |||
TF_DeleteStatus(status); | |||
return th; | |||
} | |||
IntPtr TestScalarTensorHandle(bool value) | |||
{ | |||
var data = new[] { value }; | |||
var t = c_api.TF_AllocateTensor(TF_BOOL, null, 0, sizeof(bool)); | |||
memcpy(TF_TensorData(t), data, TF_TensorByteSize(t)); | |||
var status = TF_NewStatus(); | |||
var th = TFE_NewTensorHandle(t, status); | |||
CHECK_EQ(TF_OK, TF_GetCode(status), TF_Message(status)); | |||
TF_DeleteTensor(t); | |||
TF_DeleteStatus(status); | |||
return th; | |||
} | |||
IntPtr TestScalarTensorHandle(float value) | |||
{ | |||
var data = new [] { value }; | |||
var t = c_api.TF_AllocateTensor(TF_FLOAT, null, 0, sizeof(float)); | |||
memcpy(TF_TensorData(t), data, TF_TensorByteSize(t)); | |||
var status = TF_NewStatus(); | |||
var th = TFE_NewTensorHandle(t, status); | |||
CHECK_EQ(TF_OK, TF_GetCode(status), TF_Message(status)); | |||
TF_DeleteTensor(t); | |||
TF_DeleteStatus(status); | |||
return th; | |||
} | |||
} | |||
} |