diff --git a/README.md b/README.md index a0004962..bef21312 100644 --- a/README.md +++ b/README.md @@ -4,7 +4,7 @@ TensorFlow.NET provides .NET Standard binding for [TensorFlow](https://www.tenso [![Join the chat at https://gitter.im/publiclab/publiclab](https://badges.gitter.im/Join%20Chat.svg)](https://gitter.im/sci-sharp/community) ![Tensorflow.NET](https://ci.appveyor.com/api/projects/status/tensorflow-net-p7kmsjyo10ey?svg=true) -TensorFlow.NET is a member project of SciSharp stack. +TensorFlow.NET is a member project of [SciSharp](https://github.com/SciSharp) stack. ![tensors_flowing](docs/assets/tensors_flowing.gif) @@ -45,3 +45,5 @@ using(var sess = tf.Session()) var o = sess.run(c, feed_dict); } ``` + +Star me or raise issue on [Github](https://github.com/SciSharp/TensorFlow.NET) feel free. diff --git a/src/TensorFlowNET.Core/Graphs/Graph.cs b/src/TensorFlowNET.Core/Graphs/Graph.cs index a27d031e..1fbc3789 100644 --- a/src/TensorFlowNET.Core/Graphs/Graph.cs +++ b/src/TensorFlowNET.Core/Graphs/Graph.cs @@ -13,7 +13,7 @@ namespace Tensorflow /// then create a TensorFlow session to run parts of the graph across a set of local and remote devices. /// https://www.tensorflow.org/guide/graphs /// - public class Graph + public class Graph : IDisposable { private IntPtr _handle; private Dictionary _nodes_by_id; @@ -25,6 +25,11 @@ namespace Tensorflow private string _name_stack; + public Graph() + { + _handle = c_api.TF_NewGraph(); + } + public Graph(IntPtr graph) { _handle = graph; @@ -171,6 +176,11 @@ namespace Tensorflow return _nodes_by_name.Values.Select(x => x).ToArray(); } + public void Dispose() + { + c_api.TF_DeleteGraph(_handle); + } + public static implicit operator IntPtr(Graph graph) { return graph._handle; diff --git a/src/TensorFlowNET.Core/Graphs/c_api.graph.cs b/src/TensorFlowNET.Core/Graphs/c_api.graph.cs index 21900f15..8578f33a 100644 --- a/src/TensorFlowNET.Core/Graphs/c_api.graph.cs +++ b/src/TensorFlowNET.Core/Graphs/c_api.graph.cs @@ -7,6 +7,14 @@ namespace Tensorflow { public static partial class c_api { + /// + /// Destroy an options object. Graph will be deleted once no more + /// TFSession's are referencing it. + /// + /// + [DllImport(TensorFlowLibName)] + public static extern void TF_DeleteGraph(IntPtr graph); + [DllImport(TensorFlowLibName)] public static extern void TF_GraphGetOpDef(IntPtr graph, string op_name, IntPtr output_op_def, IntPtr status); @@ -21,14 +29,14 @@ namespace Tensorflow /// /// [DllImport(TensorFlowLibName)] - public static extern void TF_GraphGetTensorShape(IntPtr graph, TF_Output output, int[] dims, int num_dims, IntPtr status); + public static extern void TF_GraphGetTensorShape(IntPtr graph, TF_Output output, long[] dims, int num_dims, IntPtr status); /// /// Sets the shape of the Tensor referenced by `output` in `graph` to /// the shape described by `dims` and `num_dims`. /// [DllImport(TensorFlowLibName)] - public static extern void TF_GraphSetTensorShape(IntPtr graph, TF_Output output, int[] dims, int num_dims, IntPtr status); + public static extern void TF_GraphSetTensorShape(IntPtr graph, TF_Output output, long[] dims, int num_dims, IntPtr status); /// /// Returns the number of dimensions of the Tensor referenced by `output` diff --git a/src/TensorFlowNET.Core/Operations/Operation.cs b/src/TensorFlowNET.Core/Operations/Operation.cs index f89a16bc..1ab17cf8 100644 --- a/src/TensorFlowNET.Core/Operations/Operation.cs +++ b/src/TensorFlowNET.Core/Operations/Operation.cs @@ -14,16 +14,16 @@ namespace Tensorflow private Status status = new Status(); - public string name => c_api.TF_OperationName(_handle); - public string optype => c_api.TF_OperationOpType(_handle); - public string device => c_api.TF_OperationDevice(_handle); - public int NumOutputs => c_api.TF_OperationNumOutputs(_handle); - public TF_DataType OutputType => c_api.TF_OperationOutputType(new TF_Output(_handle, 0)); - public int OutputListLength => c_api.TF_OperationOutputListLength(_handle, "output", status); - public int NumInputs => c_api.TF_OperationNumInputs(_handle); - public int NumConsumers => c_api.TF_OperationOutputNumConsumers(new TF_Output(_handle, 0)); - public int NumControlInputs => c_api.TF_OperationNumControlInputs(_handle); - public int NumControlOutputs => c_api.TF_OperationNumControlOutputs(_handle); + public string name { get; } + public string optype { get; } + public string device { get; } + public int NumOutputs { get; } + public TF_DataType OutputType { get; } + public int OutputListLength { get; } + public int NumInputs { get; } + public int NumConsumers { get; } + public int NumControlInputs { get; } + public int NumControlOutputs { get; } private Tensor[] _outputs; public Tensor[] outputs => _outputs; @@ -31,7 +31,21 @@ namespace Tensorflow public Operation(IntPtr handle) { + if (handle == IntPtr.Zero) + return; + _handle = handle; + + name = c_api.TF_OperationName(_handle); + optype = c_api.TF_OperationOpType(_handle); + device = "";// c_api.TF_OperationDevice(_handle); + NumOutputs = c_api.TF_OperationNumOutputs(_handle); + OutputType = c_api.TF_OperationOutputType(new TF_Output(_handle, 0)); + OutputListLength = c_api.TF_OperationOutputListLength(_handle, "output", status); + NumInputs = c_api.TF_OperationNumInputs(_handle); + NumConsumers = c_api.TF_OperationOutputNumConsumers(new TF_Output(_handle, 0)); + NumControlInputs = c_api.TF_OperationNumControlInputs(_handle); + NumControlOutputs = c_api.TF_OperationNumControlOutputs(_handle); } public Operation(Graph g, string opType, string oper_name) diff --git a/src/TensorFlowNET.Core/Operations/TF_Output.cs b/src/TensorFlowNET.Core/Operations/TF_Output.cs index 16d0285a..76edd6bb 100644 --- a/src/TensorFlowNET.Core/Operations/TF_Output.cs +++ b/src/TensorFlowNET.Core/Operations/TF_Output.cs @@ -14,7 +14,7 @@ namespace Tensorflow this.index = index; } - public IntPtr oper; + public unsafe IntPtr oper; public int index; } } diff --git a/src/TensorFlowNET.Core/Operations/c_api.ops.cs b/src/TensorFlowNET.Core/Operations/c_api.ops.cs index b97e620d..bff7fd0c 100644 --- a/src/TensorFlowNET.Core/Operations/c_api.ops.cs +++ b/src/TensorFlowNET.Core/Operations/c_api.ops.cs @@ -22,6 +22,15 @@ namespace Tensorflow [DllImport(TensorFlowLibName)] public static extern void TF_AddInput(IntPtr desc, TF_Output input); + /// + /// For inputs that take a list of tensors. + /// inputs must point to TF_Output[num_inputs]. + /// + /// + /// + [DllImport(TensorFlowLibName)] + public static extern void TF_AddInputList(IntPtr desc, TF_Output[] inputs, int num_inputs); + [DllImport(TensorFlowLibName)] public static extern IntPtr TF_FinishOperation(IntPtr desc, IntPtr status); diff --git a/src/TensorFlowNET.Core/Tensors/Tensor.cs b/src/TensorFlowNET.Core/Tensors/Tensor.cs index 56748f3a..4c4f19ff 100644 --- a/src/TensorFlowNET.Core/Tensors/Tensor.cs +++ b/src/TensorFlowNET.Core/Tensors/Tensor.cs @@ -11,7 +11,7 @@ namespace Tensorflow /// A tensor is a generalization of vectors and matrices to potentially higher dimensions. /// Internally, TensorFlow represents tensors as n-dimensional arrays of base datatypes. /// - public class Tensor + public class Tensor : IDisposable { private readonly IntPtr _handle; @@ -38,6 +38,7 @@ namespace Tensorflow /// n n-Tensor (you get the idea) /// public int rank; + public int NDims => rank; /// /// if original buffer is free. @@ -96,7 +97,7 @@ namespace Tensorflow nd.shape.Select(x => (long)x).ToArray(), // shape nd.ndim, dotHandle, - (UIntPtr)(nd.size * nd.dtypesize), + (ulong)(nd.size * nd.dtypesize), (IntPtr values, IntPtr len, ref bool closure) => { // Free the original buffer and set flag @@ -160,9 +161,19 @@ namespace Tensorflow return TF_DataType.DtInvalid; } + public void Dispose() + { + c_api.TF_DeleteTensor(_handle); + } + public static implicit operator IntPtr(Tensor tensor) { return tensor._handle; } + + public static implicit operator Tensor(IntPtr handle) + { + return new Tensor(handle); + } } } diff --git a/src/TensorFlowNET.Core/Tensors/c_api.tensor.cs b/src/TensorFlowNET.Core/Tensors/c_api.tensor.cs index a99d14bb..741899d9 100644 --- a/src/TensorFlowNET.Core/Tensors/c_api.tensor.cs +++ b/src/TensorFlowNET.Core/Tensors/c_api.tensor.cs @@ -7,20 +7,23 @@ namespace Tensorflow { public static partial class c_api { + [DllImport(TensorFlowLibName)] + public static extern IntPtr TF_AllocateTensor(TF_DataType dtype, long[] dims, int num_dims, ulong len); + /// /// returns the sizeof() for the underlying type corresponding to the given TF_DataType enum value. /// /// /// [DllImport(TensorFlowLibName)] - public static unsafe extern ulong TF_DataTypeSize(TF_DataType dt); + public static extern ulong TF_DataTypeSize(TF_DataType dt); /// /// Destroy a tensor. /// /// [DllImport(TensorFlowLibName)] - public static unsafe extern void TF_DeleteTensor(IntPtr tensor); + public static extern void TF_DeleteTensor(IntPtr tensor); /// /// Return the length of the tensor in the "dim_index" dimension. @@ -30,7 +33,7 @@ namespace Tensorflow /// /// [DllImport(TensorFlowLibName)] - public static extern unsafe long TF_Dim(IntPtr tensor, int dim_index); + public static extern long TF_Dim(IntPtr tensor, int dim_index); /// /// Return a new tensor that holds the bytes data[0,len-1] @@ -44,7 +47,7 @@ namespace Tensorflow /// /// [DllImport(TensorFlowLibName)] - public static extern unsafe IntPtr TF_NewTensor(TF_DataType dataType, long[] dims, int num_dims, IntPtr data, UIntPtr len, Deallocator deallocator, ref bool deallocator_arg); + public static extern IntPtr TF_NewTensor(TF_DataType dataType, long[] dims, int num_dims, IntPtr data, ulong len, Deallocator deallocator, ref bool deallocator_arg); /// /// Return the number of dimensions that the tensor has. @@ -52,7 +55,7 @@ namespace Tensorflow /// /// [DllImport(TensorFlowLibName)] - public static extern unsafe int TF_NumDims(IntPtr tensor); + public static extern int TF_NumDims(IntPtr tensor); /// /// Return the size of the underlying data in bytes. @@ -60,7 +63,7 @@ namespace Tensorflow /// /// [DllImport(TensorFlowLibName)] - public static extern unsafe ulong TF_TensorByteSize(IntPtr tensor); + public static extern ulong TF_TensorByteSize(IntPtr tensor); /// /// Return a pointer to the underlying data buffer. @@ -68,7 +71,7 @@ namespace Tensorflow /// /// [DllImport(TensorFlowLibName)] - public static extern unsafe IntPtr TF_TensorData(IntPtr tensor); + public static extern IntPtr TF_TensorData(IntPtr tensor); /// /// Return the type of a tensor element. @@ -76,6 +79,6 @@ namespace Tensorflow /// /// [DllImport(TensorFlowLibName)] - public static extern unsafe TF_DataType TF_TensorType(IntPtr tensor); + public static extern TF_DataType TF_TensorType(IntPtr tensor); } } diff --git a/src/TensorFlowNET.Core/c_api.cs b/src/TensorFlowNET.Core/c_api.cs index 94e530c0..dc7c3927 100644 --- a/src/TensorFlowNET.Core/c_api.cs +++ b/src/TensorFlowNET.Core/c_api.cs @@ -7,6 +7,7 @@ namespace Tensorflow { /// /// C API for TensorFlow. + /// Port from tensorflow\c\c_api.h /// /// The API leans towards simplicity and uniformity instead of convenience /// since most usage will be by language specific wrappers. diff --git a/test/TensorFlowNET.UnitTest/GraphTest.cs b/test/TensorFlowNET.UnitTest/GraphTest.cs index 30c08425..708e44ae 100644 --- a/test/TensorFlowNET.UnitTest/GraphTest.cs +++ b/test/TensorFlowNET.UnitTest/GraphTest.cs @@ -9,17 +9,21 @@ namespace TensorFlowNET.UnitTest [TestClass] public class GraphTest { + /// + /// Port from c_api_test.cc + /// `TEST(CAPI, Graph)` + /// [TestMethod] - public void Graph() + public void c_api_Graph() { var s = new Status(); - var graph = tf.get_default_graph(); + var graph = new Graph(); // Make a placeholder operation. var feed = c_test_util.Placeholder(graph, s); Assert.AreEqual("feed", feed.name); Assert.AreEqual("Placeholder", feed.optype); - //Assert.AreEqual("", feed.device); + Assert.AreEqual("", feed.device); Assert.AreEqual(1, feed.NumOutputs); Assert.AreEqual(TF_DataType.TF_INT32, feed.OutputType); Assert.AreEqual(1, feed.OutputListLength); @@ -30,6 +34,19 @@ namespace TensorFlowNET.UnitTest AttrValue attr_value = null; c_test_util.GetAttrValue(feed, "dtype", ref attr_value, s); + Assert.AreEqual(attr_value.Type, DataType.DtInt32); + + // Test not found errors in TF_Operation*() query functions. + // Assert.AreEqual(-1, c_api.TF_OperationOutputListLength(feed, "bogus", s)); + // Assert.AreEqual(TF_Code.TF_INVALID_ARGUMENT, s.Code); + // Assert.IsFalse(c_test_util.GetAttrValue(feed, "missing", ref attr_value, s)); + // Assert.AreEqual("Operation 'feed' has no attr named 'missing'.", s.Message); + + // Make a constant oper with the scalar "3". + var three = c_test_util.ScalarConst(3, graph, s); + + // Add oper. + var add = c_test_util.Add(feed, three, graph, s); } } } diff --git a/test/TensorFlowNET.UnitTest/OperationsTest.cs b/test/TensorFlowNET.UnitTest/OperationsTest.cs index 7c3907dc..c45a146e 100644 --- a/test/TensorFlowNET.UnitTest/OperationsTest.cs +++ b/test/TensorFlowNET.UnitTest/OperationsTest.cs @@ -10,8 +10,12 @@ namespace TensorFlowNET.UnitTest [TestClass] public class OperationsTest { + /// + /// Port from tensorflow\c\c_api_test.cc + /// `TEST(CAPI, GetAllOpList)` + /// [TestMethod] - public void GetAllOpList() + public void c_api_GetAllOpList() { var handle = c_api.TF_GetAllOpList(); var buffer = new Buffer(handle); diff --git a/test/TensorFlowNET.UnitTest/TensorTest.cs b/test/TensorFlowNET.UnitTest/TensorTest.cs index 7b083b16..4737f953 100644 --- a/test/TensorFlowNET.UnitTest/TensorTest.cs +++ b/test/TensorFlowNET.UnitTest/TensorTest.cs @@ -12,8 +12,29 @@ namespace TensorFlowNET.UnitTest [TestClass] public class TensorTest { + /// + /// Port from c_api_test.cc + /// `TEST(CAPI, AllocateTensor)` + /// + [TestMethod] + public void c_api_AllocateTensor() + { + ulong num_bytes = 6 * sizeof(float); + long[] dims = { 2, 3 }; + Tensor t = c_api.TF_AllocateTensor(TF_DataType.TF_FLOAT, dims, 2, num_bytes); + Assert.AreEqual(TF_DataType.TF_FLOAT, t.dtype); + Assert.AreEqual(2, t.NDims); + Assert.IsTrue(Enumerable.SequenceEqual(dims, t.shape)); + Assert.AreEqual(num_bytes, t.bytesize); + t.Dispose(); + } + + /// + /// Port from c_api_test.cc + /// `TEST(CAPI, Tensor)` + /// [TestMethod] - public void NewTensor() + public void c_api_Tensor() { var nd = np.array(1f, 2f, 3f, 4f, 5f, 6f).reshape(2, 3); @@ -30,46 +51,38 @@ namespace TensorFlowNET.UnitTest /// /// Port from tensorflow\c\c_api_test.cc + /// `TEST(CAPI, SetShape)` /// [TestMethod] - public void SetShape() + public void c_api_SetShape() { var s = new Status(); - var graph = tf.get_default_graph(); + var graph = new Graph(); - var desc = c_api.TF_NewOperation(graph, "Placeholder", ""); - c_api.TF_SetAttrType(desc, "dtype", TF_DataType.TF_FLOAT); - //if (!dims.empty()) - { - //TF_SetAttrShape(desc, "shape", dims.data(), dims.size()); - } - var op = c_api.TF_FinishOperation(desc, s); - - Assert.IsTrue(s.Code == TF_Code.TF_OK); - Assert.IsNotNull(op); + var feed = c_test_util.Placeholder(graph, s); + var feed_out_0 = new TF_Output(feed, 0); // Fetch the shape, it should be completely unknown. - var feed_out_0 = new TF_Output { oper = op, index = 0 }; int num_dims = c_api.TF_GraphGetTensorNumDims(graph, feed_out_0, s); Assert.IsTrue(s.Code == TF_Code.TF_OK); Assert.AreEqual(-1, num_dims); // Set the shape to be unknown, expect no change. - c_api.TF_GraphSetTensorShape(graph, feed_out_0, new int[0], -1, s); + c_api.TF_GraphSetTensorShape(graph, feed_out_0, null, -1, s); Assert.IsTrue(s.Code == TF_Code.TF_OK); num_dims = c_api.TF_GraphGetTensorNumDims(graph, feed_out_0, s); Assert.AreEqual(-1, num_dims); // Set the shape to be 2 x Unknown - var dims = new int[] { 2, -1 }; + long[] dims = { 2, -1 }; c_api.TF_GraphSetTensorShape(graph, feed_out_0, dims, dims.Length, s); Assert.IsTrue(s.Code == TF_Code.TF_OK); num_dims = c_api.TF_GraphGetTensorNumDims(graph, feed_out_0, s); Assert.AreEqual(2, num_dims); // Get the dimension vector appropriately. - var returned_dims = new int[dims.Length]; + var returned_dims = new long[dims.Length]; c_api.TF_GraphGetTensorShape(graph, feed_out_0, returned_dims, num_dims, s); Assert.IsTrue(s.Code == TF_Code.TF_OK); Assert.IsTrue(Enumerable.SequenceEqual(dims, returned_dims)); @@ -77,19 +90,57 @@ namespace TensorFlowNET.UnitTest // Set to a new valid shape: [2, 3] dims[1] = 3; c_api.TF_GraphSetTensorShape(graph, feed_out_0, dims, dims.Length, s); - //Assert.IsTrue(s.Code == TF_Code.TF_OK); + Assert.IsTrue(s.Code == TF_Code.TF_OK); // Fetch and see that the new value is returned. c_api.TF_GraphGetTensorShape(graph, feed_out_0, returned_dims, num_dims, s); - //Assert.IsTrue(s.Code == TF_Code.TF_OK); - //Assert.IsTrue(Enumerable.SequenceEqual(dims, returned_dims)); + Assert.IsTrue(s.Code == TF_Code.TF_OK); + Assert.IsTrue(Enumerable.SequenceEqual(dims, returned_dims)); + + // Try to set 'unknown' with unknown rank on the shape and see that + // it doesn't change. + c_api.TF_GraphSetTensorShape(graph, feed_out_0, null, -1, s); + Assert.IsTrue(s.Code == TF_Code.TF_OK); + c_api.TF_GraphGetTensorShape(graph, feed_out_0, returned_dims, num_dims, s); + Assert.IsTrue(s.Code == TF_Code.TF_OK); + Assert.AreEqual(2, num_dims); + Assert.AreEqual(2, returned_dims[0]); + Assert.AreEqual(3, returned_dims[1]); + + // Try to set 'unknown' with same rank on the shape and see that + // it doesn't change. + dims[0] = -1; + dims[1] = -1; + c_api.TF_GraphSetTensorShape(graph, feed_out_0, dims, 2, s); + Assert.IsTrue(s.Code == TF_Code.TF_OK); + c_api.TF_GraphGetTensorShape(graph, feed_out_0, returned_dims, num_dims, s); + Assert.IsTrue(s.Code == TF_Code.TF_OK); + Assert.AreEqual(2, num_dims); + Assert.AreEqual(2, returned_dims[0]); + Assert.AreEqual(3, returned_dims[1]); + + // Try to fetch a shape with the wrong num_dims + c_api.TF_GraphGetTensorShape(graph, feed_out_0, returned_dims, 5, s); + Assert.IsTrue(s.Code == TF_Code.TF_INVALID_ARGUMENT); + + // Try to set an invalid shape (cannot change 2x3 to a 2x5). + dims[1] = 5; + c_api.TF_GraphSetTensorShape(graph, feed_out_0, dims, 2, s); + Assert.IsTrue(s.Code == TF_Code.TF_INVALID_ARGUMENT); // Test for a scalar. var three = c_test_util.ScalarConst(3, graph, s); Assert.IsTrue(s.Code == TF_Code.TF_OK); - var three_out_0 = new TF_Output { oper = three }; + var three_out_0 = new TF_Output(three, 0); + num_dims = c_api.TF_GraphGetTensorNumDims(graph, three_out_0, s); + Assert.IsTrue(s.Code == TF_Code.TF_OK); Assert.AreEqual(0, num_dims); + c_api.TF_GraphGetTensorShape(graph, feed_out_0, null, num_dims, s); + //Assert.IsTrue(s.Code == TF_Code.TF_OK); + + graph.Dispose(); + s.Dispose(); } } } diff --git a/test/TensorFlowNET.UnitTest/c_test_util.cs b/test/TensorFlowNET.UnitTest/c_test_util.cs index 744afe4e..fa845147 100644 --- a/test/TensorFlowNET.UnitTest/c_test_util.cs +++ b/test/TensorFlowNET.UnitTest/c_test_util.cs @@ -7,8 +7,32 @@ using Buffer = Tensorflow.Buffer; namespace TensorFlowNET.UnitTest { + /// + /// Port from `tensorflow\c\c_test_util.cc` + /// public static class c_test_util { + public static Operation Add(Operation l, Operation r, Graph graph, Status s, string name = "add") + { + Operation op = null; + AddOpHelper(l, r, graph, s, name, ref op, true); + return op; + } + + public static void AddOpHelper(Operation l, Operation r, Graph graph, Status s, string name, ref Operation op, bool check) + { + var desc = c_api.TF_NewOperation(graph, "AddN", name); + + c_api.TF_AddInputList(desc, new TF_Output[] + { + new TF_Output(l, 0), + new TF_Output(r, 0), + }, 2); + + op = c_api.TF_FinishOperation(desc, s); + s.Check(); + } + public static bool GetAttrValue(Operation oper, string attr_name, ref AttrValue attr_value, Status s) { var buffer = new Buffer(); @@ -58,7 +82,7 @@ namespace TensorFlowNET.UnitTest return op; } - public static Operation ScalarConst(int v, Graph graph, Status s, string name = "Const") + public static Operation ScalarConst(int v, Graph graph, Status s, string name = "scalar") { return Const(new Tensor(v), graph, s, name); }