diff --git a/README.md b/README.md index f1aa73f7..4d7d81d2 100644 --- a/README.md +++ b/README.md @@ -4,6 +4,8 @@ TensorFlow.NET provides .NET Standard binding for [TensorFlow](https://www.tenso TensorFlow.NET is a member project of SciSharp stack. +![tensors_flowing](docs/assets/tensors_flowing.gif) + ### How to use ```cs using tf = TensorFlowNET.Core.Tensorflow; @@ -14,7 +16,7 @@ namespace TensorFlowNET.Examples { public void Run() { - + var hello = tf.constant("Hello, TensorFlow!"); } } } diff --git a/src/TensorFlowNET.Core/Graph.cs b/src/TensorFlowNET.Core/Graph.cs index 67c98366..696cf38a 100644 --- a/src/TensorFlowNET.Core/Graph.cs +++ b/src/TensorFlowNET.Core/Graph.cs @@ -14,10 +14,16 @@ namespace TensorFlowNET.Core public class Graph { public IntPtr handle; + private Dictionary _nodes_by_id; + private Dictionary _nodes_by_name; + public int _version; + private int _next_id_counter; public Graph(IntPtr graph) { this.handle = graph; + _nodes_by_id = new Dictionary(); + _nodes_by_name = new Dictionary(); } public unsafe Operation create_op(object inputs, string op_type = "", string name = "") @@ -28,8 +34,26 @@ namespace TensorFlowNET.Core } var op = new Operation(this, inputs); + op.name = name; return op; } + + public void _add_op(Operation op) + { + _nodes_by_id[op._id] = op; + //_nodes_by_name[op.name] = op; + _version = Math.Max(_version, op._id); + } + + public int _next_id() + { + return ++_next_id_counter; + } + + public void get_operations() + { + + } } } diff --git a/src/TensorFlowNET.Core/Operation.cs b/src/TensorFlowNET.Core/Operation.cs index 78da3d47..ecfbe5e0 100644 --- a/src/TensorFlowNET.Core/Operation.cs +++ b/src/TensorFlowNET.Core/Operation.cs @@ -8,12 +8,17 @@ namespace TensorFlowNET.Core { private Graph _graph; private IntPtr _c_op; + public int _id => _id_value; + private int _id_value; + public string name; public Operation(Graph g, object inputs) { _graph = g; + _id_value = _graph._next_id(); _c_op = ops._create_c_op(g, inputs); + _graph._add_op(this); } } } diff --git a/src/TensorFlowNET.Core/c_api.cs b/src/TensorFlowNET.Core/c_api.cs index e00c5bb7..b3303ba4 100644 --- a/src/TensorFlowNET.Core/c_api.cs +++ b/src/TensorFlowNET.Core/c_api.cs @@ -31,7 +31,7 @@ namespace TensorFlowNET.Core public static unsafe extern TF_Status TF_NewStatus(); [DllImport(TensorFlowLibName)] - public static extern unsafe TF_Tensor TF_NewTensor(TF_DataType dataType, IntPtr zeroDims, int num_dims, IntPtr data, size_t len, Deallocator deallocator, IntPtr deallocator_arg); + public static extern unsafe TF_Tensor TF_NewTensor(TF_DataType dataType, Int64 dims, int num_dims, IntPtr data, size_t len, Deallocator deallocator, IntPtr deallocator_arg); [DllImport(TensorFlowLibName)] public static extern unsafe void TF_SetAttrTensor(TF_OperationDescription desc, string attr_name, TF_Tensor value, TF_Status status); diff --git a/src/TensorFlowNET.Core/ops.cs b/src/TensorFlowNET.Core/ops.cs index 7e96f365..252fdbf6 100644 --- a/src/TensorFlowNET.Core/ops.cs +++ b/src/TensorFlowNET.Core/ops.cs @@ -26,7 +26,7 @@ namespace TensorFlowNET.Core case double value: var v = (double*)Marshal.AllocHGlobal(sizeof(double)); *v = value; - tensor = c_api.TF_NewTensor(TF_DataType.TF_DOUBLE, IntPtr.Zero, 0, data: (IntPtr)v, len: (UIntPtr)sizeof(double), deallocator: Tensorflow.FreeTensorDataDelegate, deallocator_arg: IntPtr.Zero); + tensor = c_api.TF_NewTensor(TF_DataType.TF_DOUBLE, 0, 0, data: (IntPtr)v, len: (UIntPtr)sizeof(double), deallocator: Tensorflow.FreeTensorDataDelegate, deallocator_arg: IntPtr.Zero); c_api.TF_SetAttrType(op_desc, "dtype", TF_DataType.TF_DOUBLE); break; }