diff --git a/test/TensorFlowNET.UnitTest/CApiTest.cs b/test/TensorFlowNET.UnitTest/CApiTest.cs index 07e19109..a8b1caea 100644 --- a/test/TensorFlowNET.UnitTest/CApiTest.cs +++ b/test/TensorFlowNET.UnitTest/CApiTest.cs @@ -9,6 +9,7 @@ namespace TensorFlowNET.UnitTest { protected TF_Code TF_OK = TF_Code.TF_OK; protected TF_DataType TF_FLOAT = TF_DataType.TF_FLOAT; + protected TF_DataType TF_BOOL = TF_DataType.TF_BOOL; protected void EXPECT_TRUE(bool expected, string msg = "") => Assert.IsTrue(expected, msg); @@ -73,6 +74,9 @@ namespace TensorFlowNET.UnitTest protected void TF_DeleteStatus(IntPtr s) => c_api.TF_DeleteStatus(s); + protected void TF_DeleteTensor(IntPtr t) + => c_api.TF_DeleteTensor(t); + protected IntPtr TF_TensorData(IntPtr t) => c_api.TF_TensorData(t); @@ -94,6 +98,9 @@ namespace TensorFlowNET.UnitTest protected IntPtr TFE_NewOp(IntPtr ctx, string op_or_function_name, IntPtr status) => c_api.TFE_NewOp(ctx, op_or_function_name, status); + protected IntPtr TFE_NewTensorHandle(IntPtr t, IntPtr status) + => c_api.TFE_NewTensorHandle(t, status); + protected void TFE_Execute(IntPtr op, IntPtr[] retvals, ref int num_retvals, IntPtr status) => c_api.TFE_Execute(op, retvals, ref num_retvals, status); diff --git a/test/TensorFlowNET.UnitTest/Eager/CApi.Eager.OpInferMixedTypeInputListAttrs.cs b/test/TensorFlowNET.UnitTest/Eager/CApi.Eager.OpInferMixedTypeInputListAttrs.cs new file mode 100644 index 00000000..4ce86574 --- /dev/null +++ b/test/TensorFlowNET.UnitTest/Eager/CApi.Eager.OpInferMixedTypeInputListAttrs.cs @@ -0,0 +1,57 @@ +using Microsoft.VisualStudio.TestTools.UnitTesting; +using System; +using System.Collections.Generic; +using Tensorflow; +using Tensorflow.Eager; +using Buffer = System.Buffer; +using System.Linq; + +namespace TensorFlowNET.UnitTest.Eager +{ + public partial class CApiEagerTest + { + /// + /// TEST(CAPI, TestTFE_OpInferMixedTypeInputListAttrs) + /// + [TestMethod] + public unsafe void OpInferMixedTypeInputListAttrs() + { + var status = TF_NewStatus(); + var opts = TFE_NewContextOptions(); + var ctx = TFE_NewContext(opts, status); + CHECK_EQ(TF_OK, TF_GetCode(status), TF_Message(status)); + TFE_DeleteContextOptions(opts); + + var condition = TestScalarTensorHandle(true); + var t1 = TestMatrixTensorHandle(); + var t2 = TestAxisTensorHandle(); + var assertOp = TFE_NewOp(ctx, "Assert", status); + CHECK_EQ(TF_OK, TF_GetCode(status), TF_Message(status)); + TFE_OpAddInput(assertOp, condition, status); + CHECK_EQ(TF_OK, TF_GetCode(status), TF_Message(status)); + var data = new[] { condition, t1, t2 }; + TFE_OpAddInputList(assertOp, data, 3, status); + CHECK_EQ(TF_OK, TF_GetCode(status), TF_Message(status)); + + var attr_values = Graph.TFE_GetOpDef("Assert").Attr; + var attr_found = attr_values.First(x => x.Name == "T"); + EXPECT_NE(attr_found, attr_values.Last()); + // EXPECT_EQ(attr_found.Type[0], "DT_BOOL"); + //EXPECT_EQ(attr_found->second.list().type(1), tensorflow::DataType::DT_FLOAT); + //EXPECT_EQ(attr_found->second.list().type(2), tensorflow::DataType::DT_INT32); + + var retvals = new IntPtr[1]; + int num_retvals = 1; + TFE_Execute(assertOp, retvals, ref num_retvals, status); + EXPECT_EQ(TF_OK, TF_GetCode(status), TF_Message(status)); + + TF_DeleteStatus(status); + TFE_DeleteOp(assertOp); + TFE_DeleteTensorHandle(condition); + TFE_DeleteTensorHandle(t1); + TFE_DeleteTensorHandle(t2); + TFE_DeleteTensorHandle(retvals[0]); + TFE_DeleteContext(ctx); + } + } +} diff --git a/test/TensorFlowNET.UnitTest/Eager/CApi.Eager.cs b/test/TensorFlowNET.UnitTest/Eager/CApi.Eager.cs index 80150600..9363212a 100644 --- a/test/TensorFlowNET.UnitTest/Eager/CApi.Eager.cs +++ b/test/TensorFlowNET.UnitTest/Eager/CApi.Eager.cs @@ -120,5 +120,45 @@ namespace TensorFlowNET.UnitTest.Eager return var_handle[0]; } + + IntPtr TestAxisTensorHandle() + { + var dims = new long[] { 1 }; + var data = new int[] { 1 }; + var t = c_api.TF_AllocateTensor(TF_DataType.TF_INT32, dims, 1, sizeof(int)); + memcpy(TF_TensorData(t), data, TF_TensorByteSize(t)); + var status = TF_NewStatus(); + var th = c_api.TFE_NewTensorHandle(t, status); + CHECK_EQ(TF_OK, TF_GetCode(status), TF_Message(status)); + TF_DeleteTensor(t); + TF_DeleteStatus(status); + return th; + } + + IntPtr TestScalarTensorHandle(bool value) + { + var data = new[] { value }; + var t = c_api.TF_AllocateTensor(TF_BOOL, null, 0, sizeof(bool)); + memcpy(TF_TensorData(t), data, TF_TensorByteSize(t)); + var status = TF_NewStatus(); + var th = TFE_NewTensorHandle(t, status); + CHECK_EQ(TF_OK, TF_GetCode(status), TF_Message(status)); + TF_DeleteTensor(t); + TF_DeleteStatus(status); + return th; + } + + IntPtr TestScalarTensorHandle(float value) + { + var data = new [] { value }; + var t = c_api.TF_AllocateTensor(TF_FLOAT, null, 0, sizeof(float)); + memcpy(TF_TensorData(t), data, TF_TensorByteSize(t)); + var status = TF_NewStatus(); + var th = TFE_NewTensorHandle(t, status); + CHECK_EQ(TF_OK, TF_GetCode(status), TF_Message(status)); + TF_DeleteTensor(t); + TF_DeleteStatus(status); + return th; + } } }