@@ -4,6 +4,8 @@ TensorFlow.NET provides .NET Standard binding for [TensorFlow](https://www.tenso | |||||
TensorFlow.NET is a member project of SciSharp stack. | TensorFlow.NET is a member project of SciSharp stack. | ||||
 | |||||
### How to use | ### How to use | ||||
```cs | ```cs | ||||
using tf = TensorFlowNET.Core.Tensorflow; | using tf = TensorFlowNET.Core.Tensorflow; | ||||
@@ -14,7 +16,7 @@ namespace TensorFlowNET.Examples | |||||
{ | { | ||||
public void Run() | public void Run() | ||||
{ | { | ||||
var hello = tf.constant("Hello, TensorFlow!"); | |||||
} | } | ||||
} | } | ||||
} | } |
@@ -14,10 +14,16 @@ namespace TensorFlowNET.Core | |||||
public class Graph | public class Graph | ||||
{ | { | ||||
public IntPtr handle; | public IntPtr handle; | ||||
private Dictionary<int, Operation> _nodes_by_id; | |||||
private Dictionary<string, Operation> _nodes_by_name; | |||||
public int _version; | |||||
private int _next_id_counter; | |||||
public Graph(IntPtr graph) | public Graph(IntPtr graph) | ||||
{ | { | ||||
this.handle = graph; | this.handle = graph; | ||||
_nodes_by_id = new Dictionary<int, Operation>(); | |||||
_nodes_by_name = new Dictionary<string, Operation>(); | |||||
} | } | ||||
public unsafe Operation create_op(object inputs, string op_type = "", string name = "") | public unsafe Operation create_op(object inputs, string op_type = "", string name = "") | ||||
@@ -28,8 +34,26 @@ namespace TensorFlowNET.Core | |||||
} | } | ||||
var op = new Operation(this, inputs); | var op = new Operation(this, inputs); | ||||
op.name = name; | |||||
return op; | return op; | ||||
} | } | ||||
public void _add_op(Operation op) | |||||
{ | |||||
_nodes_by_id[op._id] = op; | |||||
//_nodes_by_name[op.name] = op; | |||||
_version = Math.Max(_version, op._id); | |||||
} | |||||
public int _next_id() | |||||
{ | |||||
return ++_next_id_counter; | |||||
} | |||||
public void get_operations() | |||||
{ | |||||
} | |||||
} | } | ||||
} | } |
@@ -8,12 +8,17 @@ namespace TensorFlowNET.Core | |||||
{ | { | ||||
private Graph _graph; | private Graph _graph; | ||||
private IntPtr _c_op; | private IntPtr _c_op; | ||||
public int _id => _id_value; | |||||
private int _id_value; | |||||
public string name; | |||||
public Operation(Graph g, object inputs) | public Operation(Graph g, object inputs) | ||||
{ | { | ||||
_graph = g; | _graph = g; | ||||
_id_value = _graph._next_id(); | |||||
_c_op = ops._create_c_op(g, inputs); | _c_op = ops._create_c_op(g, inputs); | ||||
_graph._add_op(this); | |||||
} | } | ||||
} | } | ||||
} | } |
@@ -31,7 +31,7 @@ namespace TensorFlowNET.Core | |||||
public static unsafe extern TF_Status TF_NewStatus(); | public static unsafe extern TF_Status TF_NewStatus(); | ||||
[DllImport(TensorFlowLibName)] | [DllImport(TensorFlowLibName)] | ||||
public static extern unsafe TF_Tensor TF_NewTensor(TF_DataType dataType, IntPtr zeroDims, int num_dims, IntPtr data, size_t len, Deallocator deallocator, IntPtr deallocator_arg); | |||||
public static extern unsafe TF_Tensor TF_NewTensor(TF_DataType dataType, Int64 dims, int num_dims, IntPtr data, size_t len, Deallocator deallocator, IntPtr deallocator_arg); | |||||
[DllImport(TensorFlowLibName)] | [DllImport(TensorFlowLibName)] | ||||
public static extern unsafe void TF_SetAttrTensor(TF_OperationDescription desc, string attr_name, TF_Tensor value, TF_Status status); | public static extern unsafe void TF_SetAttrTensor(TF_OperationDescription desc, string attr_name, TF_Tensor value, TF_Status status); | ||||
@@ -26,7 +26,7 @@ namespace TensorFlowNET.Core | |||||
case double value: | case double value: | ||||
var v = (double*)Marshal.AllocHGlobal(sizeof(double)); | var v = (double*)Marshal.AllocHGlobal(sizeof(double)); | ||||
*v = value; | *v = value; | ||||
tensor = c_api.TF_NewTensor(TF_DataType.TF_DOUBLE, IntPtr.Zero, 0, data: (IntPtr)v, len: (UIntPtr)sizeof(double), deallocator: Tensorflow.FreeTensorDataDelegate, deallocator_arg: IntPtr.Zero); | |||||
tensor = c_api.TF_NewTensor(TF_DataType.TF_DOUBLE, 0, 0, data: (IntPtr)v, len: (UIntPtr)sizeof(double), deallocator: Tensorflow.FreeTensorDataDelegate, deallocator_arg: IntPtr.Zero); | |||||
c_api.TF_SetAttrType(op_desc, "dtype", TF_DataType.TF_DOUBLE); | c_api.TF_SetAttrType(op_desc, "dtype", TF_DataType.TF_DOUBLE); | ||||
break; | break; | ||||
} | } | ||||