using Microsoft.VisualStudio.TestTools.UnitTesting;
using System;
using System.Collections.Generic;
using System.Linq;
using System.Runtime.InteropServices;
using Tensorflow;
using Tensorflow.Functions;
using static TensorFlowNET.UnitTest.c_test_util;
namespace TensorFlowNET.UnitTest.NativeAPI
{
    /// <summary>
    /// Port of tensorflow\c\c_api_function_test.cc, `class CApiFunctionTest`.
    /// </summary>
    /// <remarks>
    /// Tests follow the upstream pattern: build a function body in <c>func_graph_</c>,
    /// convert it with <c>TF_GraphToFunction</c> and copy it into <c>host_graph_</c>
    /// (<see cref="Define"/>), add a call node for it to <c>host_graph_</c> (<see cref="Use"/>),
    /// execute it (<see cref="Run(KeyValuePair[], Operation, int)"/>) and inspect the
    /// resulting FunctionDef (<see cref="VerifyFDef"/>).
    /// </remarks>
    [TestClass]
    public class CApiFunctionTest : CApiTest, IDisposable
    {
        Graph func_graph_;
        Graph host_graph_;
        string func_name_ = "MyFunc";
        string func_node_name_ = "MyFunc_0";
        Status s_;
        // Native TF_Function handle produced by DefineT; owned by this fixture and freed by ReleaseFunction.
        IntPtr func_;

        /// <summary>
        /// Creates a fresh function-body graph, host graph and status before each test.
        /// </summary>
        [TestInitialize]
        public void Initialize()
        {
            func_graph_ = new Graph();
            host_graph_ = new Graph();
            s_ = new Status();
        }

        [TestMethod]
        public void OneOp_ZeroInputs_OneOutput()
        {
            var c = ScalarConst(10, func_graph_, s_, "scalar10");
            // Define
            Define(-1, new Operation[0], new Operation[0], new[] { c }, new string[0]);
            // Use, run, and verify
            var func_op = Use(new Operation[0]);
            Run(new KeyValuePair[0], func_op, 10);
            VerifyFDef(new[] { "scalar10_0" });
        }

        /// <summary>
        /// Convenience overload of <see cref="DefineT"/> that uses output 0 of every
        /// input/output operation.
        /// </summary>
        void Define(int num_opers, Operation[] opers,
            Operation[] inputs, Operation[] outputs,
            string[] output_names, bool expect_failure = false)
            => DefineT(num_opers, opers,
                inputs.Select(x => new TF_Output(x, 0)).ToArray(),
                outputs.Select(x => new TF_Output(x, 0)).ToArray(),
                output_names, expect_failure);

        /// <summary>
        /// Converts <c>func_graph_</c> into a function named <c>func_name_</c>, stores the
        /// handle in <c>func_</c> and copies it into <c>host_graph_</c>.
        /// </summary>
        /// <param name="num_opers">Number of operations in <paramref name="opers"/>, or -1 to use every operation in the graph.</param>
        /// <param name="opers">Body operations; ignored when <paramref name="num_opers"/> is -1.</param>
        /// <param name="inputs">Function inputs.</param>
        /// <param name="outputs">Function outputs.</param>
        /// <param name="output_names">Either empty/null (let TF pick names) or one name per output.</param>
        /// <param name="expect_failure">When true, only asserts that conversion returned a null function.</param>
        /// <exception cref="ArgumentException">
        /// <paramref name="output_names"/> is non-empty and its length differs from <paramref name="outputs"/>.
        /// </exception>
        void DefineT(int num_opers, Operation[] opers,
            TF_Output[] inputs, TF_Output[] outputs,
            string[] output_names, bool expect_failure = false)
        {
            output_names = output_names ?? new string[0];
            // TF_GraphToFunction reads exactly `noutputs` entries from a non-null names array,
            // so a shorter array would be read out of bounds.
            if (output_names.Length != 0 && output_names.Length != outputs.Length)
            {
                throw new ArgumentException(
                    $"Expected no output names or exactly {outputs.Length}, got {output_names.Length}.",
                    nameof(output_names));
            }

            // Don't leak a function left over from an earlier Define in the same test.
            ReleaseFunction();

            // Build an unmanaged `const char**` for the names; TF copies the strings, so they
            // can be freed right after the call.
            var name_ptrs = output_names.Select(name => Marshal.StringToHGlobalAnsi(name)).ToArray();
            IntPtr output_names_ptr = IntPtr.Zero;
            try
            {
                if (name_ptrs.Length > 0)
                {
                    output_names_ptr = Marshal.AllocHGlobal(IntPtr.Size * name_ptrs.Length);
                    Marshal.Copy(name_ptrs, 0, output_names_ptr, name_ptrs.Length);
                }

                func_ = c_api.TF_GraphToFunction(func_graph_, func_name_, false,
                    num_opers, num_opers == -1 ? new IntPtr[0] : opers.Select(x => (IntPtr)x).ToArray(),
                    inputs.Length, inputs,
                    outputs.Length, outputs,
                    output_names_ptr, IntPtr.Zero, null, s_.Handle);
            }
            finally
            {
                if (output_names_ptr != IntPtr.Zero)
                {
                    Marshal.FreeHGlobal(output_names_ptr);
                }
                foreach (var ptr in name_ptrs)
                {
                    Marshal.FreeHGlobal(ptr);
                }
            }

            if (expect_failure)
            {
                ASSERT_EQ(IntPtr.Zero, func_);
                return;
            }

            ASSERT_EQ(TF_OK, s_.Code, s_.Message);
            ASSERT_EQ(func_name_, c_api.StringPiece(c_api.TF_FunctionName(func_)));
            c_api.TF_GraphCopyFunction(host_graph_, func_, IntPtr.Zero, s_.Handle);
            ASSERT_EQ(TF_OK, s_.Code, s_.Message);
        }

        /// <summary>
        /// Adds a call to the defined function to <c>host_graph_</c>, feeding output 0 of each input.
        /// </summary>
        Operation Use(Operation[] inputs)
            => UseT(inputs.Select(x => new TF_Output(x, 0)).ToArray());

        /// <summary>
        /// Adds a call to the defined function to <c>host_graph_</c> with explicit input endpoints.
        /// </summary>
        Operation UseT(TF_Output[] inputs)
            => UseHelper(inputs);

        /// <summary>
        /// Creates node <c>func_node_name_</c> of op type <c>func_name_</c> on "/cpu:0" in
        /// <c>host_graph_</c> and asserts that it was created successfully.
        /// </summary>
        Operation UseHelper(TF_Output[] inputs)
        {
            var desc = TF_NewOperation(host_graph_, func_name_, func_node_name_);
            foreach (var input in inputs)
            {
                TF_AddInput(desc, input);
            }
            c_api.TF_SetDevice(desc, "/cpu:0");
            var op = TF_FinishOperation(desc, s_);
            ASSERT_EQ(TF_OK, s_.Code, s_.Message);
            // Compare the native handle, not the managed wrapper object, against null.
            Assert.IsNotNull(op);
            ASSERT_NE((IntPtr)op, IntPtr.Zero);
            return op;
        }

        /// <summary>
        /// Runs <c>host_graph_</c> fetching output 0 of <paramref name="output"/> and checks it
        /// equals <paramref name="expected_result"/>.
        /// </summary>
        void Run(KeyValuePair[] inputs, Operation output, int expected_result)
            => Run(inputs, new[] { new TF_Output(output, 0) }, new[] { expected_result });

        /// <summary>
        /// Runs <c>host_graph_</c> in a new CSession, feeding <paramref name="inputs"/> and
        /// fetching <paramref name="outputs"/>, then asserts that each fetched tensor is an
        /// int32 scalar equal to the corresponding entry of <paramref name="expected_results"/>.
        /// </summary>
        unsafe void Run(KeyValuePair[] inputs, TF_Output[] outputs, int[] expected_results)
        {
            // NOTE(review): the CSession is not explicitly closed here; confirm whether CSession
            // exposes a close/dispose API that should be called once the checks are done.
            var csession = new CSession(host_graph_, s_);
            ASSERT_EQ(TF_OK, s_.Code, s_.Message);
            csession.SetInputs(inputs);
            csession.SetOutputs(outputs);
            csession.Run(s_);
            ASSERT_EQ(TF_OK, s_.Code, s_.Message);

            for (int i = 0; i < expected_results.Length; ++i)
            {
                var output = csession.output_tensor(i);
                ASSERT_NE(output, IntPtr.Zero);
                EXPECT_EQ(TF_DataType.TF_INT32, c_api.TF_TensorType(output));
                EXPECT_EQ(0, c_api.TF_NumDims(output));
                ASSERT_EQ(sizeof(int), (int)c_api.TF_TensorByteSize(output));
                var output_contents = c_api.TF_TensorData(output);
                EXPECT_EQ(expected_results[i], *(int*)output_contents.ToPointer());
            }
        }

        /// <summary>
        /// Reads the FunctionDef of <c>func_</c> and checks it contains exactly <paramref name="nodes"/>.
        /// </summary>
        void VerifyFDef(string[] nodes)
        {
            var fdef = GetFunctionDef(func_);
            // fdef is a managed FunctionDef (it is passed to VerifyFDefNodes), so the meaningful
            // check is against null; comparing it with IntPtr.Zero could never fail.
            Assert.IsNotNull(fdef, "Failed to read the FunctionDef of " + func_name_);
            VerifyFDefNodes(fdef, nodes);
        }

        /// <summary>
        /// Asserts that <paramref name="fdef"/> has exactly as many nodes as <paramref name="nodes"/>
        /// and that every node name appears in <paramref name="nodes"/>.
        /// </summary>
        void VerifyFDefNodes(FunctionDef fdef, string[] nodes)
        {
            ASSERT_EQ(nodes.Length, fdef.NodeDef.Count);
            var expected = new HashSet<string>(nodes);
            foreach (var node_def in fdef.NodeDef)
            {
                Assert.IsTrue(expected.Contains(node_def.Name),
                    $"Got unexpected node: {node_def.Name} in fdef: {fdef}");
            }
        }

        /// <summary>
        /// Deletes the native function handle, if any, and clears <c>func_</c>. Safe to call repeatedly.
        /// </summary>
        void ReleaseFunction()
        {
            if (func_ != IntPtr.Zero)
            {
                c_api.TF_DeleteFunction(func_);
                func_ = IntPtr.Zero;
            }
        }

        /// <summary>
        /// Releases the function handle, both graphs and the status, mirroring the upstream
        /// fixture destructor. Fields are nulled so repeated calls are harmless.
        /// </summary>
        public void Dispose()
        {
            ReleaseFunction();
            func_graph_?.Dispose();
            func_graph_ = null;
            host_graph_?.Dispose();
            host_graph_ = null;
            s_?.Dispose();
            s_ = null;
        }
    }
}