Browse Source

TestTFE_OpInferMixedTypeInputListAttrs

tags/v0.20
Oceania2018 5 years ago
parent
commit
f115a05741
3 changed files with 104 additions and 0 deletions
  1. +7
    -0
      test/TensorFlowNET.UnitTest/CApiTest.cs
  2. +57
    -0
      test/TensorFlowNET.UnitTest/Eager/CApi.Eager.OpInferMixedTypeInputListAttrs.cs
  3. +40
    -0
      test/TensorFlowNET.UnitTest/Eager/CApi.Eager.cs

+ 7
- 0
test/TensorFlowNET.UnitTest/CApiTest.cs View File

@@ -9,6 +9,7 @@ namespace TensorFlowNET.UnitTest
{
protected TF_Code TF_OK = TF_Code.TF_OK;
protected TF_DataType TF_FLOAT = TF_DataType.TF_FLOAT;
protected TF_DataType TF_BOOL = TF_DataType.TF_BOOL;

protected void EXPECT_TRUE(bool expected, string msg = "")
=> Assert.IsTrue(expected, msg);
@@ -73,6 +74,9 @@ namespace TensorFlowNET.UnitTest
protected void TF_DeleteStatus(IntPtr s)
=> c_api.TF_DeleteStatus(s);

protected void TF_DeleteTensor(IntPtr t)
=> c_api.TF_DeleteTensor(t);

protected IntPtr TF_TensorData(IntPtr t)
=> c_api.TF_TensorData(t);

@@ -94,6 +98,9 @@ namespace TensorFlowNET.UnitTest
protected IntPtr TFE_NewOp(IntPtr ctx, string op_or_function_name, IntPtr status)
=> c_api.TFE_NewOp(ctx, op_or_function_name, status);

protected IntPtr TFE_NewTensorHandle(IntPtr t, IntPtr status)
=> c_api.TFE_NewTensorHandle(t, status);

protected void TFE_Execute(IntPtr op, IntPtr[] retvals, ref int num_retvals, IntPtr status)
=> c_api.TFE_Execute(op, retvals, ref num_retvals, status);



+ 57
- 0
test/TensorFlowNET.UnitTest/Eager/CApi.Eager.OpInferMixedTypeInputListAttrs.cs View File

@@ -0,0 +1,57 @@
using Microsoft.VisualStudio.TestTools.UnitTesting;
using System;
using System.Collections.Generic;
using Tensorflow;
using Tensorflow.Eager;
using Buffer = System.Buffer;
using System.Linq;

namespace TensorFlowNET.UnitTest.Eager
{
public partial class CApiEagerTest
{
/// <summary>
/// TEST(CAPI, TestTFE_OpInferMixedTypeInputListAttrs)
/// </summary>
[TestMethod]
public unsafe void OpInferMixedTypeInputListAttrs()
{
var status = TF_NewStatus();
var opts = TFE_NewContextOptions();
var ctx = TFE_NewContext(opts, status);
CHECK_EQ(TF_OK, TF_GetCode(status), TF_Message(status));
TFE_DeleteContextOptions(opts);

var condition = TestScalarTensorHandle(true);
var t1 = TestMatrixTensorHandle();
var t2 = TestAxisTensorHandle();
var assertOp = TFE_NewOp(ctx, "Assert", status);
CHECK_EQ(TF_OK, TF_GetCode(status), TF_Message(status));
TFE_OpAddInput(assertOp, condition, status);
CHECK_EQ(TF_OK, TF_GetCode(status), TF_Message(status));
var data = new[] { condition, t1, t2 };
TFE_OpAddInputList(assertOp, data, 3, status);
CHECK_EQ(TF_OK, TF_GetCode(status), TF_Message(status));

var attr_values = Graph.TFE_GetOpDef("Assert").Attr;
var attr_found = attr_values.First(x => x.Name == "T");
EXPECT_NE(attr_found, attr_values.Last());
// EXPECT_EQ(attr_found.Type[0], "DT_BOOL");
//EXPECT_EQ(attr_found->second.list().type(1), tensorflow::DataType::DT_FLOAT);
//EXPECT_EQ(attr_found->second.list().type(2), tensorflow::DataType::DT_INT32);

var retvals = new IntPtr[1];
int num_retvals = 1;
TFE_Execute(assertOp, retvals, ref num_retvals, status);
EXPECT_EQ(TF_OK, TF_GetCode(status), TF_Message(status));

TF_DeleteStatus(status);
TFE_DeleteOp(assertOp);
TFE_DeleteTensorHandle(condition);
TFE_DeleteTensorHandle(t1);
TFE_DeleteTensorHandle(t2);
TFE_DeleteTensorHandle(retvals[0]);
TFE_DeleteContext(ctx);
}
}
}

+ 40
- 0
test/TensorFlowNET.UnitTest/Eager/CApi.Eager.cs View File

@@ -120,5 +120,45 @@ namespace TensorFlowNET.UnitTest.Eager

return var_handle[0];
}

IntPtr TestAxisTensorHandle()
{
var dims = new long[] { 1 };
var data = new int[] { 1 };
var t = c_api.TF_AllocateTensor(TF_DataType.TF_INT32, dims, 1, sizeof(int));
memcpy(TF_TensorData(t), data, TF_TensorByteSize(t));
var status = TF_NewStatus();
var th = c_api.TFE_NewTensorHandle(t, status);
CHECK_EQ(TF_OK, TF_GetCode(status), TF_Message(status));
TF_DeleteTensor(t);
TF_DeleteStatus(status);
return th;
}

IntPtr TestScalarTensorHandle(bool value)
{
var data = new[] { value };
var t = c_api.TF_AllocateTensor(TF_BOOL, null, 0, sizeof(bool));
memcpy(TF_TensorData(t), data, TF_TensorByteSize(t));
var status = TF_NewStatus();
var th = TFE_NewTensorHandle(t, status);
CHECK_EQ(TF_OK, TF_GetCode(status), TF_Message(status));
TF_DeleteTensor(t);
TF_DeleteStatus(status);
return th;
}

IntPtr TestScalarTensorHandle(float value)
{
var data = new [] { value };
var t = c_api.TF_AllocateTensor(TF_FLOAT, null, 0, sizeof(float));
memcpy(TF_TensorData(t), data, TF_TensorByteSize(t));
var status = TF_NewStatus();
var th = TFE_NewTensorHandle(t, status);
CHECK_EQ(TF_OK, TF_GetCode(status), TF_Message(status));
TF_DeleteTensor(t);
TF_DeleteStatus(status);
return th;
}
}
}

Loading…
Cancel
Save