@@ -0,0 +1,40 @@ | |||||
using System; | |||||
using System.Collections.Generic; | |||||
using System.Text; | |||||
namespace TensorFlowNET.Core | |||||
{ | |||||
public class BaseSession | |||||
{ | |||||
private Graph _graph; | |||||
private bool _opened; | |||||
private bool _closed; | |||||
private int _current_version; | |||||
private byte[] _target; | |||||
private IntPtr _session; | |||||
public BaseSession(string target = "", Graph graph = null) | |||||
{ | |||||
if(graph is null) | |||||
{ | |||||
_graph = ops.get_default_graph(); | |||||
} | |||||
else | |||||
{ | |||||
_graph = graph; | |||||
} | |||||
_target = UTF8Encoding.UTF8.GetBytes(target); | |||||
var opts = c_api.TF_NewSessionOptions(); | |||||
var status = new Status(); | |||||
_session = c_api.TF_NewSession(_graph.Handle, opts, status.Handle); | |||||
c_api.TF_DeleteSessionOptions(opts); | |||||
} | |||||
public virtual byte[] run(Tensor fetches) | |||||
{ | |||||
return null; | |||||
} | |||||
} | |||||
} |
@@ -16,7 +16,8 @@ namespace TensorFlowNET.Core | |||||
/// </summary> | /// </summary> | ||||
public class Graph | public class Graph | ||||
{ | { | ||||
public IntPtr handle; | |||||
private IntPtr _c_graph; | |||||
public IntPtr Handle => _c_graph; | |||||
private Dictionary<int, Operation> _nodes_by_id; | private Dictionary<int, Operation> _nodes_by_id; | ||||
private Dictionary<string, Operation> _nodes_by_name; | private Dictionary<string, Operation> _nodes_by_name; | ||||
private Dictionary<string, int> _names_in_use; | private Dictionary<string, int> _names_in_use; | ||||
@@ -25,7 +26,7 @@ namespace TensorFlowNET.Core | |||||
public Graph(IntPtr graph) | public Graph(IntPtr graph) | ||||
{ | { | ||||
this.handle = graph; | |||||
this._c_graph = graph; | |||||
_nodes_by_id = new Dictionary<int, Operation>(); | _nodes_by_id = new Dictionary<int, Operation>(); | ||||
_nodes_by_name = new Dictionary<string, Operation>(); | _nodes_by_name = new Dictionary<string, Operation>(); | ||||
_names_in_use = new Dictionary<string, int>(); | _names_in_use = new Dictionary<string, int>(); | ||||
@@ -4,7 +4,13 @@ using System.Text; | |||||
namespace TensorFlowNET.Core | namespace TensorFlowNET.Core | ||||
{ | { | ||||
public class Session | |||||
public class Session : BaseSession | |||||
{ | { | ||||
public override byte[] run(Tensor fetches) | |||||
{ | |||||
var ret = base.run(fetches); | |||||
return ret; | |||||
} | |||||
} | } | ||||
} | } |
@@ -9,6 +9,8 @@ using TF_OperationDescription = System.IntPtr; | |||||
using TF_Operation = System.IntPtr; | using TF_Operation = System.IntPtr; | ||||
using TF_Status = System.IntPtr; | using TF_Status = System.IntPtr; | ||||
using TF_Tensor = System.IntPtr; | using TF_Tensor = System.IntPtr; | ||||
using TF_Session = System.IntPtr; | |||||
using TF_SessionOptions = System.IntPtr; | |||||
using TF_DataType = Tensorflow.DataType; | using TF_DataType = Tensorflow.DataType; | ||||
using Tensorflow; | using Tensorflow; | ||||
@@ -20,6 +22,9 @@ namespace TensorFlowNET.Core | |||||
{ | { | ||||
public const string TensorFlowLibName = "tensorflow"; | public const string TensorFlowLibName = "tensorflow"; | ||||
[DllImport(TensorFlowLibName)] | |||||
public static unsafe extern void TF_DeleteSessionOptions(TF_SessionOptions opts); | |||||
[DllImport(TensorFlowLibName)] | [DllImport(TensorFlowLibName)] | ||||
public static unsafe extern TF_Operation TF_FinishOperation(TF_OperationDescription desc, TF_Status status); | public static unsafe extern TF_Operation TF_FinishOperation(TF_OperationDescription desc, TF_Status status); | ||||
@@ -53,6 +58,12 @@ namespace TensorFlowNET.Core | |||||
[DllImport(TensorFlowLibName)] | [DllImport(TensorFlowLibName)] | ||||
public static extern unsafe void TF_SetAttrType(TF_OperationDescription desc, string attr_name, TF_DataType value); | public static extern unsafe void TF_SetAttrType(TF_OperationDescription desc, string attr_name, TF_DataType value); | ||||
[DllImport(TensorFlowLibName)] | |||||
public static extern TF_Session TF_NewSession(TF_Graph graph, TF_SessionOptions opts, TF_Status status); | |||||
[DllImport(TensorFlowLibName)] | |||||
public static extern TF_SessionOptions TF_NewSessionOptions(); | |||||
[DllImport(TensorFlowLibName)] | [DllImport(TensorFlowLibName)] | ||||
public static unsafe extern IntPtr TF_Version(); | public static unsafe extern IntPtr TF_Version(); | ||||
} | } | ||||
@@ -20,7 +20,7 @@ namespace TensorFlowNET.Core | |||||
public static unsafe IntPtr _create_c_op(Graph graph, NodeDef node_def, object inputs) | public static unsafe IntPtr _create_c_op(Graph graph, NodeDef node_def, object inputs) | ||||
{ | { | ||||
var op_desc = c_api.TF_NewOperation(graph.handle, node_def.Op, node_def.Name); | |||||
var op_desc = c_api.TF_NewOperation(graph.Handle, node_def.Op, node_def.Name); | |||||
var status = new Status(); | var status = new Status(); | ||||
foreach (var attr in node_def.Attr) | foreach (var attr in node_def.Attr) | ||||
@@ -22,6 +22,9 @@ namespace TensorFlowNET.Examples | |||||
// Start tf session | // Start tf session | ||||
var sess = tf.Session(); | var sess = tf.Session(); | ||||
// Run the op | |||||
sess.run(hello); | |||||
} | } | ||||
} | } | ||||
} | } |