add FinishOperation to OperationDescriptiontags/v0.1.0-Tensor
@@ -20,19 +20,7 @@ namespace Tensorflow | |||
public OperationDescription NewOperation(string opType, string opName) | |||
{ | |||
OperationDescription desc = c_api.TF_NewOperation(_handle, opType, opName); | |||
return desc; | |||
/*c_api.TF_SetAttrTensor(desc, "value", tensor, Status); | |||
Status.Check(); | |||
c_api.TF_SetAttrType(desc, "dtype", tensor.dtype); | |||
var op = c_api.TF_FinishOperation(desc, Status); | |||
Status.Check(); | |||
return op;*/ | |||
return c_api.TF_NewOperation(_handle, opType, opName); | |||
} | |||
} | |||
} |
@@ -145,6 +145,17 @@ namespace Tensorflow | |||
return ret; | |||
} | |||
public NodeDef GetNodeDef() | |||
{ | |||
using (var s = new Status()) | |||
using (var buffer = new Buffer()) | |||
{ | |||
c_api.TF_OperationToNodeDef(_handle, buffer, s); | |||
s.Check(); | |||
return NodeDef.Parser.ParseFrom(buffer); | |||
} | |||
} | |||
public static implicit operator Operation(IntPtr handle) => new Operation(handle); | |||
public static implicit operator IntPtr(Operation op) => op._handle; | |||
@@ -18,6 +18,11 @@ namespace Tensorflow | |||
c_api.TF_AddInputList(_handle, inputs, inputs.Length); | |||
} | |||
public Operation FinishOperation(Status status) | |||
{ | |||
return c_api.TF_FinishOperation(_handle, status); | |||
} | |||
public static implicit operator OperationDescription(IntPtr handle) | |||
{ | |||
return new OperationDescription(handle); | |||
@@ -232,7 +232,7 @@ namespace Tensorflow | |||
/// <param name="lengths"></param> | |||
/// <param name="num_values"></param> | |||
[DllImport(TensorFlowLibName)] | |||
public static extern void TF_SetAttrStringList(IntPtr desc, string attr_name, string[] values, uint[] lengths, int num_values); | |||
public static extern void TF_SetAttrStringList(IntPtr desc, string attr_name, IntPtr[] values, uint[] lengths, int num_values); | |||
[DllImport(TensorFlowLibName)] | |||
public static extern void TF_SetAttrTensor(IntPtr desc, string attr_name, IntPtr value, IntPtr status); | |||
@@ -30,34 +30,32 @@ namespace TensorFlowNET.UnitTest | |||
s_.Check(); | |||
constant_ = c_test_util.ScalarConst(10, graph_, s_); | |||
s_.Check(); | |||
desc_ = c_api.TF_NewOperation(graph_, "AddN", "add"); | |||
s_.Check(); | |||
desc_ = graph_.NewOperation("AddN", "add"); | |||
TF_Output[] inputs = { new TF_Output(feed1_, 0), new TF_Output(constant_, 0) }; | |||
desc_.AddInputList(inputs); | |||
s_.Check(); | |||
} | |||
private void SetViaStringList(OperationDescription desc, string[] list) | |||
{ | |||
string[] list_ptrs = new string[list.Length]; | |||
uint[] list_lens = new uint[list.Length]; | |||
var list_ptrs = new IntPtr[list.Length]; | |||
var list_lens = new uint[list.Length]; | |||
StringVectorToArrays(list, list_ptrs, list_lens); | |||
c_api.TF_SetAttrStringList(desc, "_class", list_ptrs, list_lens, list.Length); | |||
} | |||
private void StringVectorToArrays(string[] v, string[] ptrs, uint[] lens) | |||
private void StringVectorToArrays(string[] v, IntPtr[] ptrs, uint[] lens) | |||
{ | |||
for (int i = 0; i < v.Length; ++i) | |||
{ | |||
ptrs[i] = v[i];// Marshal.StringToHGlobalAnsi(v[i]); | |||
ptrs[i] = Marshal.StringToHGlobalAnsi(v[i]); | |||
lens[i] = (uint)v[i].Length; | |||
} | |||
} | |||
private void FinishAndVerify(OperationDescription desc, string[] expected) | |||
{ | |||
Operation op = c_api.TF_FinishOperation(desc_, s_); | |||
var op = desc_.FinishOperation(s_); | |||
ASSERT_EQ(TF_Code.TF_OK, s_.Code); | |||
VerifyCollocation(op, expected); | |||
} | |||
@@ -130,7 +130,7 @@ namespace TensorFlowNET.UnitTest | |||
EXPECT_EQ(TF_Code.TF_OK, s.Code); | |||
// Serialize to NodeDef. | |||
var node_def = c_test_util.GetNodeDef(neg); | |||
var node_def = neg.GetNodeDef(); | |||
// Validate NodeDef is what we expect. | |||
ASSERT_TRUE(c_test_util.IsNeg(node_def, "add")); | |||
@@ -145,13 +145,13 @@ namespace TensorFlowNET.UnitTest | |||
// Look up some nodes by name. | |||
Operation neg2 = c_api.TF_GraphOperationByName(graph, "neg"); | |||
EXPECT_EQ(neg, neg2); | |||
var node_def2 = c_test_util.GetNodeDef(neg2); | |||
var node_def2 = neg2.GetNodeDef(); | |||
EXPECT_EQ(node_def.ToString(), node_def2.ToString()); | |||
Operation feed2 = c_api.TF_GraphOperationByName(graph, "feed"); | |||
EXPECT_EQ(feed, feed2); | |||
node_def = c_test_util.GetNodeDef(feed); | |||
node_def2 = c_test_util.GetNodeDef(feed2); | |||
node_def = feed.GetNodeDef(); | |||
node_def2 = feed2.GetNodeDef(); | |||
EXPECT_EQ(node_def.ToString(), node_def2.ToString()); | |||
// Test iterating through the nodes of a graph. | |||
@@ -162,7 +162,7 @@ namespace TensorFlowNET.UnitTest | |||
uint pos = 0; | |||
Operation oper; | |||
while((oper = c_api.TF_GraphNextOperation(graph, ref pos)) != IntPtr.Zero) | |||
while ((oper = c_api.TF_GraphNextOperation(graph, ref pos)) != IntPtr.Zero) | |||
{ | |||
if (oper.Equals(feed)) | |||
{ | |||
@@ -186,7 +186,7 @@ namespace TensorFlowNET.UnitTest | |||
} | |||
else | |||
{ | |||
node_def = c_test_util.GetNodeDef(oper); | |||
node_def = oper.GetNodeDef(); | |||
Assert.Fail($"Unexpected Node: {node_def.ToString()}"); | |||
} | |||
} | |||
@@ -256,7 +256,7 @@ namespace TensorFlowNET.UnitTest | |||
EXPECT_EQ(0, neg.GetControlInputs().Length); | |||
EXPECT_EQ(0, neg.NumControlOutputs); | |||
EXPECT_EQ(0, neg.GetControlOutputs().Length); | |||
// Import it again, with an input mapping, return outputs, and a return | |||
// operation, into the same graph. | |||
c_api.TF_DeleteImportGraphDefOptions(opts); | |||
@@ -270,7 +270,7 @@ namespace TensorFlowNET.UnitTest | |||
EXPECT_EQ(1, c_api.TF_ImportGraphDefOptionsNumReturnOperations(opts)); | |||
var results = c_api.TF_GraphImportGraphDefWithResults(graph, graph_def, opts, s); | |||
EXPECT_EQ(TF_Code.TF_OK, s.Code); | |||
Operation scalar2 = graph.OperationByName("imported2/scalar"); | |||
Operation feed2 = graph.OperationByName("imported2/feed"); | |||
Operation neg2 = graph.OperationByName("imported2/neg"); | |||
@@ -287,7 +287,7 @@ namespace TensorFlowNET.UnitTest | |||
EXPECT_EQ(0, return_outputs[0].index); | |||
EXPECT_EQ(scalar, return_outputs[1].oper); // remapped | |||
EXPECT_EQ(0, return_outputs[1].index); | |||
// Check return operation | |||
var return_opers = graph.ReturnOperations(results); | |||
ASSERT_EQ(1, return_opers.Length); | |||
@@ -302,26 +302,26 @@ namespace TensorFlowNET.UnitTest | |||
c_api.TF_ImportGraphDefOptionsAddControlDependency(opts, feed2); | |||
c_api.TF_GraphImportGraphDef(graph, graph_def, opts, s); | |||
EXPECT_EQ(TF_Code.TF_OK, s.Code); | |||
var scalar3 = graph.OperationByName("imported3/scalar"); | |||
var feed3 = graph.OperationByName("imported3/feed"); | |||
var neg3 = graph.OperationByName("imported3/neg"); | |||
ASSERT_TRUE(scalar3 != IntPtr.Zero); | |||
ASSERT_TRUE(feed3 != IntPtr.Zero); | |||
ASSERT_TRUE(neg3 != IntPtr.Zero); | |||
// Check that newly-imported scalar and feed have control deps (neg3 will | |||
// inherit them from input) | |||
var control_inputs = scalar3.GetControlInputs(); | |||
ASSERT_EQ(2, scalar3.NumControlInputs); | |||
EXPECT_EQ(feed, control_inputs[0]); | |||
EXPECT_EQ(feed2, control_inputs[1]); | |||
control_inputs = feed3.GetControlInputs(); | |||
ASSERT_EQ(2, feed3.NumControlInputs); | |||
EXPECT_EQ(feed, control_inputs[0]); | |||
EXPECT_EQ(feed2, control_inputs[1]); | |||
// Export to a graph def so we can import a graph with control dependencies | |||
graph_def.Dispose(); | |||
graph_def = new Buffer(); | |||
@@ -51,18 +51,6 @@ namespace TensorFlowNET.UnitTest | |||
return def; | |||
} | |||
public static NodeDef GetNodeDef(Operation oper) | |||
{ | |||
var s = new Status(); | |||
var buffer = new Buffer(); | |||
c_api.TF_OperationToNodeDef(oper, buffer, s); | |||
s.Check(); | |||
var ret = NodeDef.Parser.ParseFrom(buffer); | |||
buffer.Dispose(); | |||
s.Dispose(); | |||
return ret; | |||
} | |||
public static bool IsAddN(NodeDef node_def, int n) | |||
{ | |||
if (node_def.Op != "AddN" || node_def.Name != "add" || | |||