From 7b24f538fb2d99e74a17d83ca96229af01260fbd Mon Sep 17 00:00:00 2001 From: Oceania2018 Date: Thu, 13 Dec 2018 23:18:11 -0600 Subject: [PATCH] added TF_Status, TF_SetAttrValueProto --- .gitignore | 1 - src/TensorFlowNET.Core/Status.cs | 22 +++++++++ .../TensorFlowNET.Core.csproj | 2 +- src/TensorFlowNET.Core/Tensorflow/TF_Code.cs | 27 +++++++++++ src/TensorFlowNET.Core/c_api.cs | 10 ++++- src/TensorFlowNET.Core/ops.cs | 45 +++---------------- test/TensorFlowNET.Examples/HelloWorld.cs | 3 +- 7 files changed, 67 insertions(+), 43 deletions(-) create mode 100644 src/TensorFlowNET.Core/Status.cs create mode 100644 src/TensorFlowNET.Core/Tensorflow/TF_Code.cs diff --git a/.gitignore b/.gitignore index 6dcd3b51..1a6a75a2 100644 --- a/.gitignore +++ b/.gitignore @@ -333,5 +333,4 @@ ASALocalRun/ /tensorflowlib/osx/native/libtensorflow.dylib /tensorflowlib/linux/native/libtensorflow_framework.so /tensorflowlib/linux/native/libtensorflow.so -/src/TensorFlowNET.Core/libtensorflow.dll /src/TensorFlowNET.Core/tensorflow.dll diff --git a/src/TensorFlowNET.Core/Status.cs b/src/TensorFlowNET.Core/Status.cs new file mode 100644 index 00000000..d9bd90ea --- /dev/null +++ b/src/TensorFlowNET.Core/Status.cs @@ -0,0 +1,22 @@ +using System; +using System.Collections.Generic; +using System.Text; +using Tensorflow; + +namespace TensorFlowNET.Core +{ + public class Status + { + private IntPtr _handle; + public IntPtr Handle => _handle; + + public string ErrorMessage => c_api.TF_Message(_handle); + + public TF_Code Code => c_api.TF_GetCode(_handle); + + public Status() + { + _handle = c_api.TF_NewStatus(); + } + } +} diff --git a/src/TensorFlowNET.Core/TensorFlowNET.Core.csproj b/src/TensorFlowNET.Core/TensorFlowNET.Core.csproj index 929c7940..3b550231 100644 --- a/src/TensorFlowNET.Core/TensorFlowNET.Core.csproj +++ b/src/TensorFlowNET.Core/TensorFlowNET.Core.csproj @@ -22,7 +22,7 @@ - + PreserveNewest diff --git a/src/TensorFlowNET.Core/Tensorflow/TF_Code.cs b/src/TensorFlowNET.Core/Tensorflow/TF_Code.cs new file mode 100644 index 00000000..5e8e3c8d --- /dev/null +++ b/src/TensorFlowNET.Core/Tensorflow/TF_Code.cs @@ -0,0 +1,27 @@ +using System; +using System.Collections.Generic; +using System.Text; + +namespace Tensorflow +{ + public enum TF_Code + { + TF_OK = 0, + TF_CANCELLED = 1, + TF_UNKNOWN = 2, + TF_INVALID_ARGUMENT = 3, + TF_DEADLINE_EXCEEDED = 4, + TF_NOT_FOUND = 5, + TF_ALREADY_EXISTS = 6, + TF_PERMISSION_DENIED = 7, + TF_UNAUTHENTICATED = 16, + TF_RESOURCE_EXHAUSTED = 8, + TF_FAILED_PRECONDITION = 9, + TF_ABORTED = 10, + TF_OUT_OF_RANGE = 11, + TF_UNIMPLEMENTED = 12, + TF_INTERNAL = 13, + TF_UNAVAILABLE = 14, + TF_DATA_LOSS = 15 + } +} diff --git a/src/TensorFlowNET.Core/c_api.cs b/src/TensorFlowNET.Core/c_api.cs index 7e2e0251..62d4d654 100644 --- a/src/TensorFlowNET.Core/c_api.cs +++ b/src/TensorFlowNET.Core/c_api.cs @@ -11,7 +11,7 @@ using TF_Status = System.IntPtr; using TF_Tensor = System.IntPtr; using TF_DataType = Tensorflow.DataType; - +using Tensorflow; using static TensorFlowNET.Core.Tensorflow; namespace TensorFlowNET.Core @@ -23,6 +23,12 @@ namespace TensorFlowNET.Core [DllImport(TensorFlowLibName)] public static unsafe extern TF_Operation TF_FinishOperation(TF_OperationDescription desc, TF_Status status); + [DllImport(TensorFlowLibName)] + public static extern unsafe TF_Code TF_GetCode(TF_Status s); + + [DllImport(TensorFlowLibName)] + public static extern unsafe string TF_Message(TF_Status s); + [DllImport(TensorFlowLibName)] public static unsafe extern IntPtr TF_NewGraph(); @@ -39,7 +45,7 @@ namespace TensorFlowNET.Core public static extern unsafe int TF_OperationNumOutputs(TF_Operation oper); [DllImport(TensorFlowLibName)] - public static extern unsafe void TF_SetAttrValueProto(TF_OperationDescription desc, string attr_name, void* proto, size_t proto_len, TF_Status status); + public static extern unsafe void TF_SetAttrValueProto(TF_OperationDescription desc, string attr_name, IntPtr proto, size_t proto_len, TF_Status status); [DllImport(TensorFlowLibName)] public static extern unsafe void TF_SetAttrTensor(TF_OperationDescription desc, string attr_name, TF_Tensor value, TF_Status status); diff --git a/src/TensorFlowNET.Core/ops.cs b/src/TensorFlowNET.Core/ops.cs index 1cda3bad..a2477fee 100644 --- a/src/TensorFlowNET.Core/ops.cs +++ b/src/TensorFlowNET.Core/ops.cs @@ -7,6 +7,7 @@ using Tensorflow; using tf = TensorFlowNET.Core.Tensorflow; using TF_DataType = Tensorflow.DataType; using node_def_pb2 = Tensorflow; +using Google.Protobuf; namespace TensorFlowNET.Core { @@ -20,49 +21,17 @@ namespace TensorFlowNET.Core public static unsafe IntPtr _create_c_op(Graph graph, NodeDef node_def, object inputs) { var op_desc = c_api.TF_NewOperation(graph.handle, node_def.Op, node_def.Name); - var status = c_api.TF_NewStatus(); - - // Doesn't work - /*foreach(var attr in node_def.Attr) - { - if (attr.Value.Tensor != null) - { - switch (attr.Value.Tensor.Dtype) - { - case DataType.DtDouble: - var proto = (double*)Marshal.AllocHGlobal(sizeof(double)); - *proto = attr.Value.Tensor.DoubleVal[0]; - c_api.TF_SetAttrValueProto(op_desc, attr.Key, proto, proto_len: (UIntPtr)sizeof(double), status: status); - break; - } - } - else - { - //c_api.TF_SetAttrValueProto(op_desc, attr.Key, null, proto_len: UIntPtr.Zero, status: status); - } - } */ + var status = new Status(); foreach (var attr in node_def.Attr) { - if (attr.Value.Tensor == null) continue; - switch (attr.Value.Tensor.Dtype) - { - case DataType.DtDouble: - var v = (double*)Marshal.AllocHGlobal(sizeof(double)); - *v = attr.Value.Tensor.DoubleVal[0]; - var tensor = c_api.TF_NewTensor(TF_DataType.DtDouble, 0, 0, data: (IntPtr)v, len: (UIntPtr)sizeof(double), deallocator: Tensorflow.FreeTensorDataDelegate, deallocator_arg: IntPtr.Zero); - c_api.TF_SetAttrTensor(op_desc, "value", tensor, status); - c_api.TF_SetAttrType(op_desc, "dtype", TF_DataType.DtDouble); - break; - case DataType.DtString: - - var proto = Marshal.StringToHGlobalAnsi(attr.Value.Tensor.StringVal[0].ToStringUtf8()); - c_api.TF_SetAttrValueProto(op_desc, attr.Key, proto.ToPointer(), proto_len: (UIntPtr)32, status: status); - break; - } + var bytes = attr.Value.ToByteArray(); + var proto = Marshal.AllocHGlobal(bytes.Length); + Marshal.Copy(bytes, 0, proto, bytes.Length); + c_api.TF_SetAttrValueProto(op_desc, attr.Key, proto, proto_len: (UIntPtr)bytes.Length, status: status.Handle); } - var c_op = c_api.TF_FinishOperation(op_desc, status); + var c_op = c_api.TF_FinishOperation(op_desc, status.Handle); return c_op; } diff --git a/test/TensorFlowNET.Examples/HelloWorld.cs b/test/TensorFlowNET.Examples/HelloWorld.cs index fd935363..a32700e4 100644 --- a/test/TensorFlowNET.Examples/HelloWorld.cs +++ b/test/TensorFlowNET.Examples/HelloWorld.cs @@ -19,7 +19,8 @@ namespace TensorFlowNET.Examples The value returned by the constructor represents the output of the Constant op.*/ var graph = tf.get_default_graph(); - var hello = tf.constant("Hello, TensorFlow!"); + var hello = tf.constant(4.0); + //var hello = tf.constant("Hello, TensorFlow!"); // Start tf session // var sess = tf.Session();