diff --git a/src/TensorFlowNET.Core/Operations/Operation.cs b/src/TensorFlowNET.Core/Operations/Operation.cs index 3ab9d6f9..550925e6 100644 --- a/src/TensorFlowNET.Core/Operations/Operation.cs +++ b/src/TensorFlowNET.Core/Operations/Operation.cs @@ -1,5 +1,6 @@ using System; using System.Collections.Generic; +using System.Runtime.InteropServices; using System.Text; namespace Tensorflow @@ -14,16 +15,16 @@ namespace Tensorflow private Status status = new Status(); - public string name { get; } - public string optype { get; } - public string device { get; } - public int NumOutputs { get; } - public TF_DataType OutputType { get; } - public int OutputListLength { get; } - public int NumInputs { get; } - public int NumConsumers { get; } - public int NumControlInputs { get; } - public int NumControlOutputs { get; } + public string name => c_api.StringPiece(c_api.TF_OperationName(_handle)); + public string optype => c_api.StringPiece(c_api.TF_OperationOpType(_handle)); + public string device => c_api.StringPiece(c_api.TF_OperationDevice(_handle)); + public int NumOutputs => c_api.TF_OperationNumOutputs(_handle); + public TF_DataType OutputType => c_api.TF_OperationOutputType(new TF_Output(_handle, 0)); + public int OutputListLength => c_api.TF_OperationOutputListLength(_handle, "output", status); + public int NumInputs => c_api.TF_OperationNumInputs(_handle); + public int NumConsumers => c_api.TF_OperationOutputNumConsumers(new TF_Output(_handle, 0)); + public int NumControlInputs => c_api.TF_OperationNumControlInputs(_handle); + public int NumControlOutputs => c_api.TF_OperationNumControlOutputs(_handle); private Tensor[] _outputs; public Tensor[] outputs => _outputs; @@ -35,17 +36,6 @@ namespace Tensorflow return; _handle = handle; - - name = c_api.TF_OperationName(_handle); - optype = c_api.TF_OperationOpType(_handle); - device = "";// c_api.TF_OperationDevice(_handle); - NumOutputs = c_api.TF_OperationNumOutputs(_handle); - OutputType = c_api.TF_OperationOutputType(new TF_Output(_handle, 0)); - OutputListLength = c_api.TF_OperationOutputListLength(_handle, "output", status); - NumInputs = c_api.TF_OperationNumInputs(_handle); - NumConsumers = c_api.TF_OperationOutputNumConsumers(new TF_Output(_handle, 0)); - NumControlInputs = c_api.TF_OperationNumControlInputs(_handle); - NumControlOutputs = c_api.TF_OperationNumControlOutputs(_handle); } public Operation(Graph g, string opType, string oper_name) @@ -62,8 +52,8 @@ namespace Tensorflow Graph = g; _id_value = Graph._next_id(); + _handle = ops._create_c_op(g, node_def, inputs); - NumOutputs = c_api.TF_OperationNumOutputs(_handle); _outputs = new Tensor[NumOutputs]; for (int i = 0; i < NumOutputs; i++) diff --git a/src/TensorFlowNET.Core/Operations/c_api.ops.cs b/src/TensorFlowNET.Core/Operations/c_api.ops.cs index cfde53c1..7bd829f1 100644 --- a/src/TensorFlowNET.Core/Operations/c_api.ops.cs +++ b/src/TensorFlowNET.Core/Operations/c_api.ops.cs @@ -38,7 +38,7 @@ namespace Tensorflow public static extern IntPtr TF_NewOperation(IntPtr graph, string opType, string oper_name); [DllImport(TensorFlowLibName)] - public static extern string TF_OperationDevice(IntPtr oper); + public static extern IntPtr TF_OperationDevice(IntPtr oper); /// /// Sets `output_attr_value` to the binary-serialized AttrValue proto @@ -50,13 +50,13 @@ namespace Tensorflow public static extern int TF_OperationGetAttrValueProto(IntPtr oper, string attr_name, IntPtr output_attr_value, IntPtr status); [DllImport(TensorFlowLibName)] - public static extern string TF_OperationName(IntPtr oper); + public static extern IntPtr TF_OperationName(IntPtr oper); [DllImport(TensorFlowLibName)] public static extern int TF_OperationNumInputs(IntPtr oper); [DllImport(TensorFlowLibName)] - public static extern string TF_OperationOpType(IntPtr oper); + public static extern IntPtr TF_OperationOpType(IntPtr oper); /// /// Get the number of control inputs to an operation. diff --git a/src/TensorFlowNET.Core/Operations/ops.cs b/src/TensorFlowNET.Core/Operations/ops.cs index f0eefdda..85ef04ea 100644 --- a/src/TensorFlowNET.Core/Operations/ops.cs +++ b/src/TensorFlowNET.Core/Operations/ops.cs @@ -30,12 +30,12 @@ namespace Tensorflow // Add inputs if(inputs != null && inputs.Count > 0) { - /*foreach (var op_input in inputs) + foreach (var op_input in inputs) { c_api.TF_AddInput(op_desc, op_input._as_tf_output()); - }*/ + } - c_api.TF_AddInputList(op_desc, inputs.Select(x => x._as_tf_output()).ToArray(), inputs.Count); + //c_api.TF_AddInputList(op_desc, inputs.Select(x => x._as_tf_output()).ToArray(), inputs.Count); } var status = new Status(); diff --git a/src/TensorFlowNET.Core/Status/Status.cs b/src/TensorFlowNET.Core/Status/Status.cs index e0d3edca..ec1c017f 100644 --- a/src/TensorFlowNET.Core/Status/Status.cs +++ b/src/TensorFlowNET.Core/Status/Status.cs @@ -15,7 +15,7 @@ namespace Tensorflow /// /// Error message /// - public string Message => c_api.TF_Message(_handle); + public string Message => c_api.StringPiece(c_api.TF_Message(_handle)); /// /// Error code diff --git a/src/TensorFlowNET.Core/Status/c_api.status.cs b/src/TensorFlowNET.Core/Status/c_api.status.cs index efd2a959..5ba62136 100644 --- a/src/TensorFlowNET.Core/Status/c_api.status.cs +++ b/src/TensorFlowNET.Core/Status/c_api.status.cs @@ -12,7 +12,7 @@ namespace Tensorflow /// /// [DllImport(TensorFlowLibName)] - public static unsafe extern void TF_DeleteStatus(IntPtr s); + public static extern void TF_DeleteStatus(IntPtr s); /// /// Return the code record in *s. @@ -20,7 +20,7 @@ namespace Tensorflow /// /// [DllImport(TensorFlowLibName)] - public static extern unsafe TF_Code TF_GetCode(IntPtr s); + public static extern TF_Code TF_GetCode(IntPtr s); /// /// Return a pointer to the (null-terminated) error message in *s. @@ -30,7 +30,7 @@ namespace Tensorflow /// /// [DllImport(TensorFlowLibName)] - public static extern unsafe string TF_Message(IntPtr s); + public static extern IntPtr TF_Message(IntPtr s); /// /// Return a new status object. diff --git a/src/TensorFlowNET.Core/c_api.cs b/src/TensorFlowNET.Core/c_api.cs index dc7c3927..70d933a3 100644 --- a/src/TensorFlowNET.Core/c_api.cs +++ b/src/TensorFlowNET.Core/c_api.cs @@ -12,20 +12,27 @@ namespace Tensorflow /// The API leans towards simplicity and uniformity instead of convenience /// since most usage will be by language specific wrappers. /// - /// The params type mapping between .net and c_api + /// The params type mapping between c_api and .NET /// TF_XX** => ref IntPtr (TF_Operation** op) => (ref IntPtr op) /// TF_XX* => IntPtr (TF_Graph* graph) => (IntPtr graph) /// struct => struct (TF_Output output) => (TF_Output output) + /// struct* => struct (TF_Output* output) => (TF_Output[] output) /// const char* => string /// int32_t => int /// int64_t* => long[] /// size_t* => unlong[] /// void* => IntPtr + /// string => IntPtr c_api.StringPiece(IntPtr) /// public static partial class c_api { public const string TensorFlowLibName = "tensorflow"; + public static string StringPiece(IntPtr handle) + { + return handle == IntPtr.Zero ? String.Empty : Marshal.PtrToStringAnsi(handle); + } + public delegate void Deallocator(IntPtr data, IntPtr size, ref bool deallocator); [DllImport(TensorFlowLibName)] diff --git a/src/TensorFlowNET.Core/tf.cs b/src/TensorFlowNET.Core/tf.cs index 2e860410..76bbd28e 100644 --- a/src/TensorFlowNET.Core/tf.cs +++ b/src/TensorFlowNET.Core/tf.cs @@ -37,7 +37,7 @@ namespace Tensorflow context.default_execution_mode = Context.EAGER_MODE; } - public static string VERSION => Marshal.PtrToStringAnsi(c_api.TF_Version()); + public static string VERSION => c_api.StringPiece(c_api.TF_Version()); public static Graph get_default_graph() { diff --git a/test/TensorFlowNET.UnitTest/GraphTest.cs b/test/TensorFlowNET.UnitTest/GraphTest.cs index 9b95820d..9dab6da7 100644 --- a/test/TensorFlowNET.UnitTest/GraphTest.cs +++ b/test/TensorFlowNET.UnitTest/GraphTest.cs @@ -39,17 +39,14 @@ namespace TensorFlowNET.UnitTest // Test not found errors in TF_Operation*() query functions. Assert.AreEqual(-1, c_api.TF_OperationOutputListLength(feed, "bogus", s)); Assert.AreEqual(TF_Code.TF_INVALID_ARGUMENT, s.Code); - //Assert.IsFalse(c_test_util.GetAttrValue(feed, "missing", ref attr_value, s)); - //Assert.AreEqual("Operation '' has no attr named 'missing'.", s.Message); + Assert.IsFalse(c_test_util.GetAttrValue(feed, "missing", ref attr_value, s)); + Assert.AreEqual("Operation 'feed' has no attr named 'missing'.", s.Message); // Make a constant oper with the scalar "3". var three = c_test_util.ScalarConst(3, graph, s); // Add oper. var add = c_test_util.Add(feed, three, graph, s); - - NodeDef node_def = null; - c_test_util.GetNodeDef(feed, ref node_def); } } } diff --git a/test/TensorFlowNET.UnitTest/OperationsTest.cs b/test/TensorFlowNET.UnitTest/OperationsTest.cs index 24c0e701..c45a146e 100644 --- a/test/TensorFlowNET.UnitTest/OperationsTest.cs +++ b/test/TensorFlowNET.UnitTest/OperationsTest.cs @@ -43,7 +43,7 @@ namespace TensorFlowNET.UnitTest public void addInConstant() { var a = tf.constant(4.0f); - var b = tf.placeholder(tf.float32); + var b = tf.constant(5.0f); var c = tf.add(a, b); using (var sess = tf.Session()) diff --git a/test/TensorFlowNET.UnitTest/c_test_util.cs b/test/TensorFlowNET.UnitTest/c_test_util.cs index 489226b1..f62433ae 100644 --- a/test/TensorFlowNET.UnitTest/c_test_util.cs +++ b/test/TensorFlowNET.UnitTest/c_test_util.cs @@ -23,11 +23,13 @@ namespace TensorFlowNET.UnitTest { var desc = c_api.TF_NewOperation(graph, "AddN", name); - c_api.TF_AddInputList(desc, new TF_Output[] + var inputs = new TF_Output[] { new TF_Output(l, 0), new TF_Output(r, 0), - }, 2); + }; + + c_api.TF_AddInputList(desc, inputs, inputs.Length); op = c_api.TF_FinishOperation(desc, s); s.Check();