diff --git a/src/TensorFlowNET.Core/BaseSession.cs b/src/TensorFlowNET.Core/BaseSession.cs new file mode 100644 index 00000000..6b44b28e --- /dev/null +++ b/src/TensorFlowNET.Core/BaseSession.cs @@ -0,0 +1,40 @@ +using System; +using System.Collections.Generic; +using System.Text; + +namespace TensorFlowNET.Core +{ + public class BaseSession + { + private Graph _graph; + private bool _opened; + private bool _closed; + private int _current_version; + private byte[] _target; + private IntPtr _session; + + public BaseSession(string target = "", Graph graph = null) + { + if(graph is null) + { + _graph = ops.get_default_graph(); + } + else + { + _graph = graph; + } + + _target = UTF8Encoding.UTF8.GetBytes(target); + var opts = c_api.TF_NewSessionOptions(); + var status = new Status(); + _session = c_api.TF_NewSession(_graph.Handle, opts, status.Handle); + + c_api.TF_DeleteSessionOptions(opts); + } + + public virtual byte[] run(Tensor fetches) + { + return null; + } + } +} diff --git a/src/TensorFlowNET.Core/Graph.cs b/src/TensorFlowNET.Core/Graph.cs index 5087f6f6..216db8f6 100644 --- a/src/TensorFlowNET.Core/Graph.cs +++ b/src/TensorFlowNET.Core/Graph.cs @@ -16,7 +16,8 @@ namespace TensorFlowNET.Core /// public class Graph { - public IntPtr handle; + private IntPtr _c_graph; + public IntPtr Handle => _c_graph; private Dictionary _nodes_by_id; private Dictionary _nodes_by_name; private Dictionary _names_in_use; @@ -25,7 +26,7 @@ namespace TensorFlowNET.Core public Graph(IntPtr graph) { - this.handle = graph; + this._c_graph = graph; _nodes_by_id = new Dictionary(); _nodes_by_name = new Dictionary(); _names_in_use = new Dictionary(); diff --git a/src/TensorFlowNET.Core/Session.cs b/src/TensorFlowNET.Core/Session.cs index b8ae6544..0755539f 100644 --- a/src/TensorFlowNET.Core/Session.cs +++ b/src/TensorFlowNET.Core/Session.cs @@ -4,7 +4,13 @@ using System.Text; namespace TensorFlowNET.Core { - public class Session + public class Session : BaseSession { + public override byte[] run(Tensor fetches) + { + var ret = base.run(fetches); + + return ret; + } } } diff --git a/src/TensorFlowNET.Core/c_api.cs b/src/TensorFlowNET.Core/c_api.cs index 62d4d654..ad052e23 100644 --- a/src/TensorFlowNET.Core/c_api.cs +++ b/src/TensorFlowNET.Core/c_api.cs @@ -9,6 +9,8 @@ using TF_OperationDescription = System.IntPtr; using TF_Operation = System.IntPtr; using TF_Status = System.IntPtr; using TF_Tensor = System.IntPtr; +using TF_Session = System.IntPtr; +using TF_SessionOptions = System.IntPtr; using TF_DataType = Tensorflow.DataType; using Tensorflow; @@ -20,6 +22,9 @@ namespace TensorFlowNET.Core { public const string TensorFlowLibName = "tensorflow"; + [DllImport(TensorFlowLibName)] + public static unsafe extern void TF_DeleteSessionOptions(TF_SessionOptions opts); + [DllImport(TensorFlowLibName)] public static unsafe extern TF_Operation TF_FinishOperation(TF_OperationDescription desc, TF_Status status); @@ -53,6 +58,12 @@ namespace TensorFlowNET.Core [DllImport(TensorFlowLibName)] public static extern unsafe void TF_SetAttrType(TF_OperationDescription desc, string attr_name, TF_DataType value); + [DllImport(TensorFlowLibName)] + public static extern TF_Session TF_NewSession(TF_Graph graph, TF_SessionOptions opts, TF_Status status); + + [DllImport(TensorFlowLibName)] + public static extern TF_SessionOptions TF_NewSessionOptions(); + [DllImport(TensorFlowLibName)] public static unsafe extern IntPtr TF_Version(); } diff --git a/src/TensorFlowNET.Core/ops.cs b/src/TensorFlowNET.Core/ops.cs index a2477fee..59a2d682 100644 --- a/src/TensorFlowNET.Core/ops.cs +++ b/src/TensorFlowNET.Core/ops.cs @@ -20,7 +20,7 @@ namespace TensorFlowNET.Core public static unsafe IntPtr _create_c_op(Graph graph, NodeDef node_def, object inputs) { - var op_desc = c_api.TF_NewOperation(graph.handle, node_def.Op, node_def.Name); + var op_desc = c_api.TF_NewOperation(graph.Handle, node_def.Op, node_def.Name); var status = new Status(); foreach (var attr in node_def.Attr) diff --git a/test/TensorFlowNET.Examples/HelloWorld.cs b/test/TensorFlowNET.Examples/HelloWorld.cs index 31328bb8..b3991160 100644 --- a/test/TensorFlowNET.Examples/HelloWorld.cs +++ b/test/TensorFlowNET.Examples/HelloWorld.cs @@ -22,6 +22,9 @@ namespace TensorFlowNET.Examples // Start tf session var sess = tf.Session(); + + // Run the op + sess.run(hello); } } }