diff --git a/src/TensorFlowNET.Core/Graphs/Graph.Import.cs b/src/TensorFlowNET.Core/Graphs/Graph.Import.cs index f17551a1..a8c6563a 100644 --- a/src/TensorFlowNET.Core/Graphs/Graph.Import.cs +++ b/src/TensorFlowNET.Core/Graphs/Graph.Import.cs @@ -13,13 +13,13 @@ namespace Tensorflow var num_return_outputs = opts.NumReturnOutputs; var return_outputs = new TF_Output[num_return_outputs]; int size = Marshal.SizeOf(); - TF_Output* return_output_handle = (TF_Output*)Marshal.AllocHGlobal(size * num_return_outputs); + var return_output_handle = Marshal.AllocHGlobal(size * num_return_outputs); c_api.TF_GraphImportGraphDefWithReturnOutputs(_handle, graph_def, opts, return_output_handle, num_return_outputs, s); for (int i = 0; i < num_return_outputs; i++) { var handle = return_output_handle + i * size; - return_outputs[i] = new TF_Output((*handle).oper, (*handle).index); + return_outputs[i] = Marshal.PtrToStructure(handle); } return return_outputs; diff --git a/src/TensorFlowNET.Core/Graphs/c_api.graph.cs b/src/TensorFlowNET.Core/Graphs/c_api.graph.cs index b2988a22..6e7a5bb3 100644 --- a/src/TensorFlowNET.Core/Graphs/c_api.graph.cs +++ b/src/TensorFlowNET.Core/Graphs/c_api.graph.cs @@ -61,7 +61,7 @@ namespace Tensorflow /// int /// TF_Status* [DllImport(TensorFlowLibName)] - public static extern unsafe void TF_GraphImportGraphDefWithReturnOutputs(IntPtr graph, IntPtr graph_def, IntPtr options, TF_Output* return_outputs, int num_return_outputs, IntPtr status); + public static extern unsafe void TF_GraphImportGraphDefWithReturnOutputs(IntPtr graph, IntPtr graph_def, IntPtr options, IntPtr return_outputs, int num_return_outputs, IntPtr status); /// /// Import the graph serialized in `graph_def` into `graph`. Returns nullptr and diff --git a/src/TensorFlowNET.Core/Sessions/BaseSession.cs b/src/TensorFlowNET.Core/Sessions/BaseSession.cs index ab17c7f1..f05db687 100644 --- a/src/TensorFlowNET.Core/Sessions/BaseSession.cs +++ b/src/TensorFlowNET.Core/Sessions/BaseSession.cs @@ -9,12 +9,12 @@ namespace Tensorflow { public class BaseSession : IDisposable { - private Graph _graph; - private bool _opened; - private bool _closed; - private int _current_version; - private byte[] _target; - private IntPtr _session; + protected Graph _graph; + protected bool _opened; + protected bool _closed; + protected int _current_version; + protected byte[] _target; + protected IntPtr _session; public BaseSession(string target = "", Graph graph = null) { @@ -28,11 +28,11 @@ namespace Tensorflow } _target = UTF8Encoding.UTF8.GetBytes(target); - //var opts = c_api.TF_NewSessionOptions(); - //var status = new Status(); - //_session = c_api.TF_NewSession(_graph, opts, status); + var opts = c_api.TF_NewSessionOptions(); + var status = new Status(); + _session = c_api.TF_NewSession(_graph, opts, status); - //c_api.TF_DeleteSessionOptions(opts); + c_api.TF_DeleteSessionOptions(opts); } public void Dispose() @@ -102,18 +102,18 @@ namespace Tensorflow var output_values = fetch_list.Select(x => IntPtr.Zero).ToArray(); - /*c_api.TF_SessionRun(_session, - run_options: IntPtr.Zero, + c_api.TF_SessionRun(_session, + run_options: null, inputs: feed_dict.Select(f => f.Key).ToArray(), input_values: new IntPtr[] { }, ninputs: 0, outputs: fetch_list, output_values: output_values, noutputs: fetch_list.Length, - target_opers: new IntPtr[] { }, + target_opers: IntPtr.Zero, ntargets: 0, run_metadata: IntPtr.Zero, - status: status);*/ + status: status); var result = output_values.Select(x => c_api.TF_TensorData(x)) .Select(x => (object)*(float*)x) diff --git a/src/TensorFlowNET.Core/Sessions/Session.cs b/src/TensorFlowNET.Core/Sessions/Session.cs index 936cf99b..32801444 100644 --- a/src/TensorFlowNET.Core/Sessions/Session.cs +++ b/src/TensorFlowNET.Core/Sessions/Session.cs @@ -8,6 +8,10 @@ namespace Tensorflow { private IntPtr _handle; + public Session(string target = "", Graph graph = null) + { + } + public Session(IntPtr handle) { _handle = handle; diff --git a/src/TensorFlowNET.Core/Sessions/c_api.session.cs b/src/TensorFlowNET.Core/Sessions/c_api.session.cs index 29118088..894f0598 100644 --- a/src/TensorFlowNET.Core/Sessions/c_api.session.cs +++ b/src/TensorFlowNET.Core/Sessions/c_api.session.cs @@ -70,10 +70,18 @@ namespace Tensorflow /// int /// TF_Buffer* /// TF_Status* + [DllImport(TensorFlowLibName)] + public static extern unsafe void TF_SessionRun(IntPtr session, TF_Buffer* run_options, + TF_Output[] inputs, IntPtr[] input_values, int ninputs, + TF_Output[] outputs, IntPtr[] output_values, int noutputs, + IntPtr target_opers, int ntargets, + IntPtr run_metadata, + IntPtr status); + [DllImport(TensorFlowLibName)] public static extern unsafe void TF_SessionRun(IntPtr session, TF_Buffer* run_options, IntPtr inputs, IntPtr input_values, int ninputs, - IntPtr outputs, ref IntPtr output_values, int noutputs, + IntPtr outputs, IntPtr[] output_values, int noutputs, IntPtr target_opers, int ntargets, IntPtr run_metadata, IntPtr status); diff --git a/src/TensorFlowNET.Core/Tensors/Tensor.cs b/src/TensorFlowNET.Core/Tensors/Tensor.cs index 54813168..b7344586 100644 --- a/src/TensorFlowNET.Core/Tensors/Tensor.cs +++ b/src/TensorFlowNET.Core/Tensors/Tensor.cs @@ -33,7 +33,7 @@ namespace Tensorflow { var dims = new long[rank]; for (int i = 0; i < rank; i++) - shape[i] = c_api.TF_Dim(_handle, i); + dims[i] = c_api.TF_Dim(_handle, i); return dims; } diff --git a/src/TensorFlowNET.Core/tf.cs b/src/TensorFlowNET.Core/tf.cs index 8c643597..76bbd28e 100644 --- a/src/TensorFlowNET.Core/tf.cs +++ b/src/TensorFlowNET.Core/tf.cs @@ -51,7 +51,7 @@ namespace Tensorflow public static Session Session() { - return (Session)new BaseSession(); + return new Session(); } } } diff --git a/test/TensorFlowNET.UnitTest/CSession.cs b/test/TensorFlowNET.UnitTest/CSession.cs index 6678dcf4..6be4c7c0 100644 --- a/test/TensorFlowNET.UnitTest/CSession.cs +++ b/test/TensorFlowNET.UnitTest/CSession.cs @@ -79,17 +79,17 @@ namespace TensorFlowNET.UnitTest IntPtr inputs_ptr = inputs_.Count == 0 ? IntPtr.Zero : inputs_[0]; IntPtr input_values_ptr = inputs_.Count == 0 ? IntPtr.Zero : input_values_[0]; IntPtr outputs_ptr = outputs_.Count == 0 ? IntPtr.Zero : outputs_[0]; - IntPtr output_values_ptr = output_values_.Count == 0 ? IntPtr.Zero : output_values_[0]; + IntPtr[] output_values_ptr = output_values_.ToArray();// output_values_.Count == 0 ? IntPtr.Zero : output_values_[0]; IntPtr targets_ptr = IntPtr.Zero; - c_api.TF_SessionRun(session_, null, inputs_ptr, input_values_ptr, inputs_.Count, - outputs_ptr, ref output_values_ptr, outputs_.Count, + c_api.TF_SessionRun(session_, null, inputs_ptr, input_values_ptr, 0, + outputs_ptr, output_values_ptr, outputs_.Count, targets_ptr, targets_.Count, IntPtr.Zero, s); s.Check(); - output_values_[0] = output_values_ptr; + output_values_[0] = output_values_ptr[0]; } public IntPtr output_tensor(int i) diff --git a/test/TensorFlowNET.UnitTest/GraphTest.cs b/test/TensorFlowNET.UnitTest/GraphTest.cs index 30066676..c8843812 100644 --- a/test/TensorFlowNET.UnitTest/GraphTest.cs +++ b/test/TensorFlowNET.UnitTest/GraphTest.cs @@ -77,7 +77,7 @@ namespace TensorFlowNET.UnitTest ASSERT_TRUE(c_test_util.GetAttrValue(add, "T", ref attr_value, s)); EXPECT_EQ(DataType.DtInt32, attr_value.Type); ASSERT_TRUE(c_test_util.GetAttrValue(add, "N", ref attr_value, s)); - EXPECT_EQ(2, attr_value.I); + EXPECT_EQ(2, (int)attr_value.I); // Placeholder oper now has a consumer. EXPECT_EQ(1, feed.OutputNumConsumers(0)); @@ -353,7 +353,7 @@ namespace TensorFlowNET.UnitTest c_test_util.Add(feed, scalar, graph, s); ASSERT_EQ(TF_Code.TF_OK, s.Code); - //graph.Dispose(); + graph.Dispose(); s.Dispose(); } @@ -368,12 +368,12 @@ namespace TensorFlowNET.UnitTest var graph = new Graph(); // Create a graph with two nodes: x and 3 - var feed = c_test_util.Placeholder(graph, s); - EXPECT_EQ(feed, graph.OperationByName("feed")); - var scalar = c_test_util.ScalarConst(3, graph, s); - EXPECT_EQ(scalar, graph.OperationByName("scalar")); - var neg = c_test_util.Neg(scalar, graph, s); - EXPECT_EQ(neg, graph.OperationByName("neg")); + c_test_util.Placeholder(graph, s); + ASSERT_TRUE(graph.OperationByName("feed") != null); + var oper = c_test_util.ScalarConst(3, graph, s); + ASSERT_TRUE(graph.OperationByName("scalar") != null); + c_test_util.Neg(oper, graph, s); + ASSERT_TRUE(graph.OperationByName("neg") != null); // Export to a GraphDef. var graph_def = graph.ToGraphDef(s); @@ -389,9 +389,9 @@ namespace TensorFlowNET.UnitTest var return_outputs = graph.ImportGraphDefWithReturnOutputs(graph_def, opts, s); ASSERT_EQ(TF_Code.TF_OK, s.Code); - scalar = graph.OperationByName("scalar"); - feed = graph.OperationByName("feed"); - neg = graph.OperationByName("neg"); + var scalar = graph.OperationByName("scalar"); + var feed = graph.OperationByName("feed"); + var neg = graph.OperationByName("neg"); ASSERT_TRUE(scalar != IntPtr.Zero); ASSERT_TRUE(feed != IntPtr.Zero); ASSERT_TRUE(neg != IntPtr.Zero); diff --git a/test/TensorFlowNET.UnitTest/TensorTest.cs b/test/TensorFlowNET.UnitTest/TensorTest.cs index aac36d06..2f6ac2d3 100644 --- a/test/TensorFlowNET.UnitTest/TensorTest.cs +++ b/test/TensorFlowNET.UnitTest/TensorTest.cs @@ -57,9 +57,9 @@ namespace TensorFlowNET.UnitTest EXPECT_EQ(tensor.dtype, TF_DataType.TF_FLOAT); EXPECT_EQ(tensor.rank, nd.ndim); - EXPECT_EQ(tensor.shape[0], nd.shape[0]); - EXPECT_EQ(tensor.shape[1], nd.shape[1]); - EXPECT_EQ(tensor.bytesize, (uint)nd.size * sizeof(float)); + EXPECT_EQ((int)tensor.shape[0], nd.shape[0]); + EXPECT_EQ((int)tensor.shape[1], nd.shape[1]); + EXPECT_EQ(tensor.bytesize, (ulong)nd.size * sizeof(float)); Assert.IsTrue(Enumerable.SequenceEqual(nd.Data(), array)); } @@ -118,8 +118,8 @@ namespace TensorFlowNET.UnitTest c_api.TF_GraphGetTensorShape(graph, feed_out_0, returned_dims, num_dims, s); Assert.IsTrue(s.Code == TF_Code.TF_OK); EXPECT_EQ(2, num_dims); - EXPECT_EQ(2, returned_dims[0]); - EXPECT_EQ(3, returned_dims[1]); + EXPECT_EQ(2, (int)returned_dims[0]); + EXPECT_EQ(3, (int)returned_dims[1]); // Try to set 'unknown' with same rank on the shape and see that // it doesn't change. @@ -130,8 +130,8 @@ namespace TensorFlowNET.UnitTest c_api.TF_GraphGetTensorShape(graph, feed_out_0, returned_dims, num_dims, s); Assert.IsTrue(s.Code == TF_Code.TF_OK); EXPECT_EQ(2, num_dims); - EXPECT_EQ(2, returned_dims[0]); - EXPECT_EQ(3, returned_dims[1]); + EXPECT_EQ(2, (int)returned_dims[0]); + EXPECT_EQ(3, (int)returned_dims[1]); // Try to fetch a shape with the wrong num_dims c_api.TF_GraphGetTensorShape(graph, feed_out_0, returned_dims, 5, s); diff --git a/test/TensorFlowNET.UnitTest/c_test_util.cs b/test/TensorFlowNET.UnitTest/c_test_util.cs index c242299b..2ed657f8 100644 --- a/test/TensorFlowNET.UnitTest/c_test_util.cs +++ b/test/TensorFlowNET.UnitTest/c_test_util.cs @@ -13,13 +13,6 @@ namespace TensorFlowNET.UnitTest public static class c_test_util { public static Operation Add(Operation l, Operation r, Graph graph, Status s, string name = "add") - { - Operation op = null; - AddOpHelper(l, r, graph, s, name, ref op, true); - return op; - } - - public static void AddOpHelper(Operation l, Operation r, Graph graph, Status s, string name, ref Operation op, bool check) { var desc = c_api.TF_NewOperation(graph, "AddN", name); @@ -31,8 +24,10 @@ namespace TensorFlowNET.UnitTest c_api.TF_AddInputList(desc, inputs, inputs.Length); - op = c_api.TF_FinishOperation(desc, s); + var op = c_api.TF_FinishOperation(desc, s); s.Check(); + + return op; } public static bool GetAttrValue(Operation oper, string attr_name, ref AttrValue attr_value, Status s) @@ -196,39 +191,29 @@ namespace TensorFlowNET.UnitTest return op; } - public static void PlaceholderHelper(Graph graph, Status s, string name, TF_DataType dtype, long[] dims, ref Operation op) + public static Operation Placeholder(Graph graph, Status s, string name = "feed", TF_DataType dtype = TF_DataType.TF_INT32, long[] dims = null) { var desc = c_api.TF_NewOperation(graph, "Placeholder", name); c_api.TF_SetAttrType(desc, "dtype", dtype); - if(dims != null) + if (dims != null) { c_api.TF_SetAttrShape(desc, "shape", dims, dims.Length); } - op = c_api.TF_FinishOperation(desc, s); + var op = c_api.TF_FinishOperation(desc, s); s.Check(); - } - public static Operation Placeholder(Graph graph, Status s, string name = "feed", TF_DataType dtype = TF_DataType.TF_INT32, long[] dims = null) - { - Operation op = null; - PlaceholderHelper(graph, s, name, dtype, dims, ref op); return op; } - public static void ConstHelper(Tensor t, Graph graph, Status s, string name, ref Operation op) + public static Operation Const(Tensor t, Graph graph, Status s, string name) { var desc = c_api.TF_NewOperation(graph, "Const", name); c_api.TF_SetAttrTensor(desc, "value", t, s); s.Check(); c_api.TF_SetAttrType(desc, "dtype", t.dtype); - op = c_api.TF_FinishOperation(desc, s); + var op = c_api.TF_FinishOperation(desc, s); s.Check(); - } - public static Operation Const(Tensor t, Graph graph, Status s, string name) - { - Operation op = null; - ConstHelper(t, graph, s, name, ref op); return op; } @@ -247,7 +232,7 @@ namespace TensorFlowNET.UnitTest (IntPtr values, IntPtr len, ref bool closure) => { // Free the original buffer and set flag - Marshal.FreeHGlobal(dotHandle); + // Marshal.FreeHGlobal(dotHandle); }, ref deallocator_called); } }