Browse Source

Merge pull request #90 from Esther2013/master

add FinishOperation to OperationDescription
tags/v0.1.0-Tensor
Haiping GitHub 6 years ago
parent
commit
3615dbab12
No known key found for this signature in database GPG Key ID: 4AEE18F83AFDEB23
7 changed files with 37 additions and 47 deletions
  1. +1
    -13
      src/TensorFlowNET.Core/Graphs/Graph.Operation.cs
  2. +11
    -0
      src/TensorFlowNET.Core/Operations/Operation.cs
  3. +5
    -0
      src/TensorFlowNET.Core/Operations/OperationDescription.cs
  4. +1
    -1
      src/TensorFlowNET.Core/Operations/c_api.ops.cs
  5. +6
    -8
      test/TensorFlowNET.UnitTest/CApiColocationTest.cs
  6. +13
    -13
      test/TensorFlowNET.UnitTest/GraphTest.cs
  7. +0
    -12
      test/TensorFlowNET.UnitTest/c_test_util.cs

+ 1
- 13
src/TensorFlowNET.Core/Graphs/Graph.Operation.cs View File

@@ -20,19 +20,7 @@ namespace Tensorflow

public OperationDescription NewOperation(string opType, string opName)
{
OperationDescription desc = c_api.TF_NewOperation(_handle, opType, opName);
return desc;

/*c_api.TF_SetAttrTensor(desc, "value", tensor, Status);
Status.Check();

c_api.TF_SetAttrType(desc, "dtype", tensor.dtype);

var op = c_api.TF_FinishOperation(desc, Status);
Status.Check();

return op;*/
return c_api.TF_NewOperation(_handle, opType, opName);
}
}
}

+ 11
- 0
src/TensorFlowNET.Core/Operations/Operation.cs View File

@@ -145,6 +145,17 @@ namespace Tensorflow
return ret;
}

public NodeDef GetNodeDef()
{
using (var s = new Status())
using (var buffer = new Buffer())
{
c_api.TF_OperationToNodeDef(_handle, buffer, s);
s.Check();
return NodeDef.Parser.ParseFrom(buffer);
}
}

public static implicit operator Operation(IntPtr handle) => new Operation(handle);
public static implicit operator IntPtr(Operation op) => op._handle;



+ 5
- 0
src/TensorFlowNET.Core/Operations/OperationDescription.cs View File

@@ -18,6 +18,11 @@ namespace Tensorflow
c_api.TF_AddInputList(_handle, inputs, inputs.Length);
}

public Operation FinishOperation(Status status)
{
return c_api.TF_FinishOperation(_handle, status);
}

public static implicit operator OperationDescription(IntPtr handle)
{
return new OperationDescription(handle);


+ 1
- 1
src/TensorFlowNET.Core/Operations/c_api.ops.cs View File

@@ -232,7 +232,7 @@ namespace Tensorflow
/// <param name="lengths"></param>
/// <param name="num_values"></param>
[DllImport(TensorFlowLibName)]
public static extern void TF_SetAttrStringList(IntPtr desc, string attr_name, string[] values, uint[] lengths, int num_values);
public static extern void TF_SetAttrStringList(IntPtr desc, string attr_name, IntPtr[] values, uint[] lengths, int num_values);

[DllImport(TensorFlowLibName)]
public static extern void TF_SetAttrTensor(IntPtr desc, string attr_name, IntPtr value, IntPtr status);


+ 6
- 8
test/TensorFlowNET.UnitTest/CApiColocationTest.cs View File

@@ -30,34 +30,32 @@ namespace TensorFlowNET.UnitTest
s_.Check();
constant_ = c_test_util.ScalarConst(10, graph_, s_);
s_.Check();
desc_ = c_api.TF_NewOperation(graph_, "AddN", "add");
s_.Check();

desc_ = graph_.NewOperation("AddN", "add");
TF_Output[] inputs = { new TF_Output(feed1_, 0), new TF_Output(constant_, 0) };
desc_.AddInputList(inputs);
s_.Check();
}

private void SetViaStringList(OperationDescription desc, string[] list)
{
string[] list_ptrs = new string[list.Length];
uint[] list_lens = new uint[list.Length];
var list_ptrs = new IntPtr[list.Length];
var list_lens = new uint[list.Length];
StringVectorToArrays(list, list_ptrs, list_lens);
c_api.TF_SetAttrStringList(desc, "_class", list_ptrs, list_lens, list.Length);
}

private void StringVectorToArrays(string[] v, string[] ptrs, uint[] lens)
private void StringVectorToArrays(string[] v, IntPtr[] ptrs, uint[] lens)
{
for (int i = 0; i < v.Length; ++i)
{
ptrs[i] = v[i];// Marshal.StringToHGlobalAnsi(v[i]);
ptrs[i] = Marshal.StringToHGlobalAnsi(v[i]);
lens[i] = (uint)v[i].Length;
}
}

private void FinishAndVerify(OperationDescription desc, string[] expected)
{
Operation op = c_api.TF_FinishOperation(desc_, s_);
var op = desc_.FinishOperation(s_);
ASSERT_EQ(TF_Code.TF_OK, s_.Code);
VerifyCollocation(op, expected);
}


+ 13
- 13
test/TensorFlowNET.UnitTest/GraphTest.cs View File

@@ -130,7 +130,7 @@ namespace TensorFlowNET.UnitTest
EXPECT_EQ(TF_Code.TF_OK, s.Code);

// Serialize to NodeDef.
var node_def = c_test_util.GetNodeDef(neg);
var node_def = neg.GetNodeDef();

// Validate NodeDef is what we expect.
ASSERT_TRUE(c_test_util.IsNeg(node_def, "add"));
@@ -145,13 +145,13 @@ namespace TensorFlowNET.UnitTest
// Look up some nodes by name.
Operation neg2 = c_api.TF_GraphOperationByName(graph, "neg");
EXPECT_EQ(neg, neg2);
var node_def2 = c_test_util.GetNodeDef(neg2);
var node_def2 = neg2.GetNodeDef();
EXPECT_EQ(node_def.ToString(), node_def2.ToString());

Operation feed2 = c_api.TF_GraphOperationByName(graph, "feed");
EXPECT_EQ(feed, feed2);
node_def = c_test_util.GetNodeDef(feed);
node_def2 = c_test_util.GetNodeDef(feed2);
node_def = feed.GetNodeDef();
node_def2 = feed2.GetNodeDef();
EXPECT_EQ(node_def.ToString(), node_def2.ToString());

// Test iterating through the nodes of a graph.
@@ -162,7 +162,7 @@ namespace TensorFlowNET.UnitTest
uint pos = 0;
Operation oper;

while((oper = c_api.TF_GraphNextOperation(graph, ref pos)) != IntPtr.Zero)
while ((oper = c_api.TF_GraphNextOperation(graph, ref pos)) != IntPtr.Zero)
{
if (oper.Equals(feed))
{
@@ -186,7 +186,7 @@ namespace TensorFlowNET.UnitTest
}
else
{
node_def = c_test_util.GetNodeDef(oper);
node_def = oper.GetNodeDef();
Assert.Fail($"Unexpected Node: {node_def.ToString()}");
}
}
@@ -256,7 +256,7 @@ namespace TensorFlowNET.UnitTest
EXPECT_EQ(0, neg.GetControlInputs().Length);
EXPECT_EQ(0, neg.NumControlOutputs);
EXPECT_EQ(0, neg.GetControlOutputs().Length);
// Import it again, with an input mapping, return outputs, and a return
// operation, into the same graph.
c_api.TF_DeleteImportGraphDefOptions(opts);
@@ -270,7 +270,7 @@ namespace TensorFlowNET.UnitTest
EXPECT_EQ(1, c_api.TF_ImportGraphDefOptionsNumReturnOperations(opts));
var results = c_api.TF_GraphImportGraphDefWithResults(graph, graph_def, opts, s);
EXPECT_EQ(TF_Code.TF_OK, s.Code);
Operation scalar2 = graph.OperationByName("imported2/scalar");
Operation feed2 = graph.OperationByName("imported2/feed");
Operation neg2 = graph.OperationByName("imported2/neg");
@@ -287,7 +287,7 @@ namespace TensorFlowNET.UnitTest
EXPECT_EQ(0, return_outputs[0].index);
EXPECT_EQ(scalar, return_outputs[1].oper); // remapped
EXPECT_EQ(0, return_outputs[1].index);
// Check return operation
var return_opers = graph.ReturnOperations(results);
ASSERT_EQ(1, return_opers.Length);
@@ -302,26 +302,26 @@ namespace TensorFlowNET.UnitTest
c_api.TF_ImportGraphDefOptionsAddControlDependency(opts, feed2);
c_api.TF_GraphImportGraphDef(graph, graph_def, opts, s);
EXPECT_EQ(TF_Code.TF_OK, s.Code);
var scalar3 = graph.OperationByName("imported3/scalar");
var feed3 = graph.OperationByName("imported3/feed");
var neg3 = graph.OperationByName("imported3/neg");
ASSERT_TRUE(scalar3 != IntPtr.Zero);
ASSERT_TRUE(feed3 != IntPtr.Zero);
ASSERT_TRUE(neg3 != IntPtr.Zero);
// Check that newly-imported scalar and feed have control deps (neg3 will
// inherit them from input)
var control_inputs = scalar3.GetControlInputs();
ASSERT_EQ(2, scalar3.NumControlInputs);
EXPECT_EQ(feed, control_inputs[0]);
EXPECT_EQ(feed2, control_inputs[1]);
control_inputs = feed3.GetControlInputs();
ASSERT_EQ(2, feed3.NumControlInputs);
EXPECT_EQ(feed, control_inputs[0]);
EXPECT_EQ(feed2, control_inputs[1]);
// Export to a graph def so we can import a graph with control dependencies
graph_def.Dispose();
graph_def = new Buffer();


+ 0
- 12
test/TensorFlowNET.UnitTest/c_test_util.cs View File

@@ -51,18 +51,6 @@ namespace TensorFlowNET.UnitTest
return def;
}

public static NodeDef GetNodeDef(Operation oper)
{
var s = new Status();
var buffer = new Buffer();
c_api.TF_OperationToNodeDef(oper, buffer, s);
s.Check();
var ret = NodeDef.Parser.ParseFrom(buffer);
buffer.Dispose();
s.Dispose();
return ret;
}

public static bool IsAddN(NodeDef node_def, int n)
{
if (node_def.Op != "AddN" || node_def.Name != "add" ||


Loading…
Cancel
Save