diff --git a/src/TensorFlowNET.Core/Graphs/Graph.cs b/src/TensorFlowNET.Core/Graphs/Graph.cs index 48353916..a20fe727 100644 --- a/src/TensorFlowNET.Core/Graphs/Graph.cs +++ b/src/TensorFlowNET.Core/Graphs/Graph.cs @@ -24,23 +24,36 @@ namespace Tensorflow private List _unfetchable_ops = new List(); private string _name_stack; + public Status Status { get; } public Graph() { _handle = c_api.TF_NewGraph(); + Status = new Status(); } public Graph(IntPtr graph) { _handle = graph; + Status = new Status(); _nodes_by_id = new Dictionary(); _nodes_by_name = new Dictionary(); _names_in_use = new Dictionary(); } - public OperationDescription NewOperation(string opType, string opName) + public Operation NewOperation(string opType, string opName, Tensor t) { - return c_api.TF_NewOperation(_handle, opType, opName); + var desc = c_api.TF_NewOperation(_handle, opType, opName); + + c_api.TF_SetAttrTensor(desc, "value", t, Status); + Status.Check(); + + c_api.TF_SetAttrType(desc, "dtype", t.dtype); + + var op = c_api.TF_FinishOperation(desc, Status); + Status.Check(); + + return op; } public T as_graph_element(T obj, bool allow_tensor = true, bool allow_operation = true) diff --git a/src/TensorFlowNET.Core/Operations/c_api.ops.cs b/src/TensorFlowNET.Core/Operations/c_api.ops.cs index afdb886b..262fe531 100644 --- a/src/TensorFlowNET.Core/Operations/c_api.ops.cs +++ b/src/TensorFlowNET.Core/Operations/c_api.ops.cs @@ -7,6 +7,18 @@ namespace Tensorflow { public static partial class c_api { + /// + /// Request that `desc` be co-located on the device where `op` + /// is placed. + /// + /// Use of this is discouraged since the implementation of device placement is + /// subject to change. Primarily intended for internal libraries + /// + /// + /// + [DllImport(TensorFlowLibName)] + public static extern void TF_ColocateWith(IntPtr desc, IntPtr op); + /// /// Get the OpList of all OpDefs defined in this address space. /// @@ -209,7 +221,7 @@ namespace Tensorflow /// const void* /// size_t [DllImport(TensorFlowLibName)] - public static extern void TF_SetAttrString(IntPtr desc, string attr_name, string value, uint length); + public static extern void TF_SetAttrString(IntPtr desc, string attr_name, IntPtr value, uint length); /// /// diff --git a/src/TensorFlowNET.Core/Sessions/BaseSession.cs b/src/TensorFlowNET.Core/Sessions/BaseSession.cs index f05db687..d9739f7b 100644 --- a/src/TensorFlowNET.Core/Sessions/BaseSession.cs +++ b/src/TensorFlowNET.Core/Sessions/BaseSession.cs @@ -119,6 +119,9 @@ namespace Tensorflow .Select(x => (object)*(float*)x) .ToArray(); + var op = new Operation(fetch_list[0].oper); + //var metadata = c_api.TF_OperationGetAttrMetadata(fetch_list[0].oper, "dtype", status); + return result; } diff --git a/src/TensorFlowNET.Core/Sessions/Session.cs b/src/TensorFlowNET.Core/Sessions/Session.cs index 32801444..70644aac 100644 --- a/src/TensorFlowNET.Core/Sessions/Session.cs +++ b/src/TensorFlowNET.Core/Sessions/Session.cs @@ -7,9 +7,19 @@ namespace Tensorflow public class Session : BaseSession { private IntPtr _handle; + public Status Status { get; } + public SessionOptions Options { get; } public Session(string target = "", Graph graph = null) { + Status = new Status(); + if(graph == null) + { + graph = tf.get_default_graph(); + } + Options = new SessionOptions(); + _handle = c_api.TF_NewSession(graph, Options, Status); + Status.Check(); } public Session(IntPtr handle) diff --git a/src/TensorFlowNET.Core/Status/Status.cs b/src/TensorFlowNET.Core/Status/Status.cs index ec1c017f..a4648307 100644 --- a/src/TensorFlowNET.Core/Status/Status.cs +++ b/src/TensorFlowNET.Core/Status/Status.cs @@ -36,12 +36,15 @@ namespace Tensorflow /// Check status /// Throw exception with error message if code != TF_OK /// - public void Check() + public void Check(bool throwException = false) { if(Code != TF_Code.TF_OK) { Console.WriteLine(Message); - // throw new Exception(Message); + if (throwException) + { + throw new Exception(Message); + } } } diff --git a/src/TensorFlowNET.Core/Tensors/Tensor.cs b/src/TensorFlowNET.Core/Tensors/Tensor.cs index b7344586..c321af8a 100644 --- a/src/TensorFlowNET.Core/Tensors/Tensor.cs +++ b/src/TensorFlowNET.Core/Tensors/Tensor.cs @@ -69,6 +69,7 @@ namespace Tensorflow private IntPtr Allocate(NDArray nd) { var dotHandle = Marshal.AllocHGlobal(nd.dtypesize * nd.size); + ulong size = (ulong)(nd.size * nd.dtypesize); switch (nd.dtype.Name) { @@ -81,16 +82,21 @@ namespace Tensorflow case "Double": Marshal.Copy(nd.Data(), 0, dotHandle, nd.size); break; + case "String": + dotHandle = Marshal.StringToHGlobalAuto(nd.Data()[0]); + size = (ulong)nd.Data()[0].Length; + break; default: throw new NotImplementedException("Marshal.Copy failed."); } var dataType = ToTFDataType(nd.dtype); + var tfHandle = c_api.TF_NewTensor(dataType, nd.shape.Select(x => (long)x).ToArray(), // shape nd.ndim, dotHandle, - (ulong)(nd.size * nd.dtypesize), + size, (IntPtr values, IntPtr len, ref bool closure) => { // Free the original buffer and set flag @@ -154,6 +160,8 @@ namespace Tensorflow return TF_DataType.TF_FLOAT; case "Double": return TF_DataType.TF_DOUBLE; + case "String": + return TF_DataType.TF_STRING; } return TF_DataType.DtInvalid; diff --git a/src/TensorFlowNET.Core/Tensors/constant_op.cs b/src/TensorFlowNET.Core/Tensors/constant_op.cs index 9b9eba0b..3aa643d0 100644 --- a/src/TensorFlowNET.Core/Tensors/constant_op.cs +++ b/src/TensorFlowNET.Core/Tensors/constant_op.cs @@ -34,12 +34,13 @@ namespace Tensorflow attrs["dtype"] = dtype_value; attrs["value"] = tensor_value; - var const_tensor = g.create_op("Const", - null, - new TF_DataType[] { (TF_DataType)dtype_value.Type }, + var op = g.create_op("Const", + null, + new TF_DataType[] { (TF_DataType)dtype_value.Type }, attrs: attrs, - name: name).outputs[0]; + name: name); + var const_tensor = op.outputs[0]; const_tensor.value = nd.Data(); return const_tensor; diff --git a/src/TensorFlowNET.Core/Tensors/tf.constant.cs b/src/TensorFlowNET.Core/Tensors/tf.constant.cs index df56b835..d60fb50e 100644 --- a/src/TensorFlowNET.Core/Tensors/tf.constant.cs +++ b/src/TensorFlowNET.Core/Tensors/tf.constant.cs @@ -7,9 +7,9 @@ namespace Tensorflow { public static partial class tf { - public static Tensor constant(NDArray value, string name = "Const", bool verify_shape = false) + public static Tensor constant(NDArray nd, string name = "Const", bool verify_shape = false) { - return constant_op.Create(value, name, verify_shape); + return constant_op.Create(nd, name, verify_shape); } } } diff --git a/test/TensorFlowNET.Examples/HelloWorld.cs b/test/TensorFlowNET.Examples/HelloWorld.cs index 2913c6b6..290abc45 100644 --- a/test/TensorFlowNET.Examples/HelloWorld.cs +++ b/test/TensorFlowNET.Examples/HelloWorld.cs @@ -24,7 +24,7 @@ namespace TensorFlowNET.Examples var sess = tf.Session(); // Run the op - sess.run(hello); + Console.WriteLine(sess.run(hello)); } } } diff --git a/test/TensorFlowNET.Examples/Program.cs b/test/TensorFlowNET.Examples/Program.cs index 70adeb7d..6f5f8744 100644 --- a/test/TensorFlowNET.Examples/Program.cs +++ b/test/TensorFlowNET.Examples/Program.cs @@ -23,6 +23,8 @@ namespace TensorFlowNET.Examples Console.ReadLine(); } } + + Console.ReadLine(); } } } diff --git a/test/TensorFlowNET.UnitTest/CApiColocationTest.cs b/test/TensorFlowNET.UnitTest/CApiColocationTest.cs index a31ffbc1..8b90c669 100644 --- a/test/TensorFlowNET.UnitTest/CApiColocationTest.cs +++ b/test/TensorFlowNET.UnitTest/CApiColocationTest.cs @@ -25,12 +25,17 @@ namespace TensorFlowNET.UnitTest public void SetUp() { feed1_ = c_test_util.Placeholder(graph_, s_, "feed1"); + s_.Check(); feed2_ = c_test_util.Placeholder(graph_, s_, "feed2"); + s_.Check(); constant_ = c_test_util.ScalarConst(10, graph_, s_); - desc_ = graph_.NewOperation("AddN", "add"); + s_.Check(); + desc_ = c_api.TF_NewOperation(graph_, "AddN", "add"); + s_.Check(); TF_Output[] inputs = { new TF_Output(feed1_, 0), new TF_Output(constant_, 0) }; desc_.AddInputList(inputs); + s_.Check(); } private void SetViaStringList(OperationDescription desc, string[] list) @@ -85,7 +90,8 @@ namespace TensorFlowNET.UnitTest [TestMethod] public void ColocateWith() { - + c_api.TF_ColocateWith(desc_, feed1_); + FinishAndVerify(desc_, new string[] { "loc:@feed1" }); } [TestMethod]