|
|
@@ -0,0 +1,105 @@ |
|
|
|
using Microsoft.VisualStudio.TestTools.UnitTesting; |
|
|
|
using System; |
|
|
|
using System.Collections.Generic; |
|
|
|
using System.Runtime.InteropServices; |
|
|
|
using System.Text; |
|
|
|
using Tensorflow; |
|
|
|
|
|
|
|
namespace TensorFlowNET.UnitTest |
|
|
|
{ |
|
|
|
/// <summary> |
|
|
|
/// tensorflow\c\c_api_test.cc |
|
|
|
/// `class CApiColocationTest` |
|
|
|
/// </summary> |
|
|
|
[TestClass] |
|
|
|
public class CApiColocationTest : CApiTest, IDisposable |
|
|
|
{ |
|
|
|
private Graph graph_ = new Graph(); |
|
|
|
private Status s_ = new Status(); |
|
|
|
private Operation feed1_; |
|
|
|
private Operation feed2_; |
|
|
|
private Operation constant_; |
|
|
|
private OperationDescription desc_; |
|
|
|
|
|
|
|
[TestInitialize] |
|
|
|
public void SetUp() |
|
|
|
{ |
|
|
|
feed1_ = c_test_util.Placeholder(graph_, s_, "feed1"); |
|
|
|
feed2_ = c_test_util.Placeholder(graph_, s_, "feed2"); |
|
|
|
constant_ = c_test_util.ScalarConst(10, graph_, s_); |
|
|
|
desc_ = graph_.NewOperation("AddN", "add"); |
|
|
|
|
|
|
|
TF_Output[] inputs = { new TF_Output(feed1_, 0), new TF_Output(constant_, 0) }; |
|
|
|
desc_.AddInputList(inputs); |
|
|
|
} |
|
|
|
|
|
|
|
private void SetViaStringList(OperationDescription desc, string[] list) |
|
|
|
{ |
|
|
|
string[] list_ptrs = new string[list.Length]; |
|
|
|
uint[] list_lens = new uint[list.Length]; |
|
|
|
StringVectorToArrays(list, list_ptrs, list_lens); |
|
|
|
c_api.TF_SetAttrStringList(desc, "_class", list_ptrs, list_lens, list.Length); |
|
|
|
} |
|
|
|
|
|
|
|
private void StringVectorToArrays(string[] v, string[] ptrs, uint[] lens) |
|
|
|
{ |
|
|
|
for (int i = 0; i < v.Length; ++i) |
|
|
|
{ |
|
|
|
ptrs[i] = v[i];// Marshal.StringToHGlobalAnsi(v[i]); |
|
|
|
lens[i] = (uint)v[i].Length; |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
private void FinishAndVerify(OperationDescription desc, string[] expected) |
|
|
|
{ |
|
|
|
Operation op = c_api.TF_FinishOperation(desc_, s_); |
|
|
|
ASSERT_EQ(TF_Code.TF_OK, s_.Code); |
|
|
|
VerifyCollocation(op, expected); |
|
|
|
} |
|
|
|
|
|
|
|
private void VerifyCollocation(Operation op, string[] expected) |
|
|
|
{ |
|
|
|
var handle = c_api.TF_OperationGetAttrMetadata(op, "_class", s_); |
|
|
|
TF_AttrMetadata m = new TF_AttrMetadata(); |
|
|
|
if (expected.Length == 0) |
|
|
|
{ |
|
|
|
ASSERT_EQ(TF_Code.TF_INVALID_ARGUMENT, s_.Code); |
|
|
|
EXPECT_EQ("Operation 'add' has no attr named '_class'.", s_.Message); |
|
|
|
return; |
|
|
|
} |
|
|
|
EXPECT_EQ(TF_Code.TF_OK, s_.Code); |
|
|
|
EXPECT_EQ(1, m.is_list); |
|
|
|
EXPECT_EQ(expected.Length, m.list_size); |
|
|
|
EXPECT_EQ(TF_AttrType.TF_ATTR_STRING, m.type); |
|
|
|
string[] values = new string[expected.Length]; |
|
|
|
uint[] lens = new uint[expected.Length]; |
|
|
|
string[] storage = new string[m.total_size]; |
|
|
|
//c_api.TF_OperationGetAttrStringList(op, "_class", values, lens, expected.Length, storage, m.total_size, s_); |
|
|
|
EXPECT_EQ(TF_Code.TF_OK, s_.Code); |
|
|
|
for (int i = 0; i < expected.Length; ++i) |
|
|
|
{ |
|
|
|
EXPECT_EQ(expected[i], values[i] + lens[i]); |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
[TestMethod] |
|
|
|
public void ColocateWith() |
|
|
|
{ |
|
|
|
|
|
|
|
} |
|
|
|
|
|
|
|
[TestMethod] |
|
|
|
public void StringList() |
|
|
|
{ |
|
|
|
SetViaStringList(desc_, new string[] { "loc:@feed1" }); |
|
|
|
FinishAndVerify(desc_, new string[] { "loc:@feed1" }); |
|
|
|
} |
|
|
|
|
|
|
|
[TestCleanup] |
|
|
|
public void Dispose() |
|
|
|
{ |
|
|
|
graph_.Dispose(); |
|
|
|
s_.Dispose(); |
|
|
|
} |
|
|
|
} |
|
|
|
} |