@@ -4,7 +4,7 @@ using System.Text; | |||
namespace Tensorflow | |||
{ | |||
public class BaseSession | |||
public class BaseSession : IDisposable | |||
{ | |||
private Graph _graph; | |||
private bool _opened; | |||
@@ -32,18 +32,23 @@ namespace Tensorflow | |||
c_api.TF_DeleteSessionOptions(opts); | |||
} | |||
public virtual byte[] run(Tensor fetches) | |||
public void Dispose() | |||
{ | |||
return _run(fetches); | |||
} | |||
private unsafe byte[] _run(Tensor fetches) | |||
public virtual byte[] run(Tensor fetches, Dictionary<Tensor, object> feed_dict = null) | |||
{ | |||
return _run(fetches, feed_dict); | |||
} | |||
private unsafe byte[] _run(Tensor fetches, Dictionary<Tensor, object> feed_dict = null) | |||
{ | |||
var status = new Status(); | |||
c_api.TF_SessionRun(_session, | |||
run_options: null, | |||
inputs: new TF_Input[] { }, | |||
inputs: new TF_Output[] { }, | |||
input_values: new IntPtr[] { }, | |||
ninputs: 1, | |||
outputs: new TF_Output[] { }, | |||
@@ -31,7 +31,7 @@ namespace Tensorflow | |||
_names_in_use = new Dictionary<string, int>(); | |||
} | |||
public unsafe Operation create_op(string op_type, object inputs, TF_DataType[] dtypes, | |||
public unsafe Operation create_op(string op_type, List<Tensor> inputs, TF_DataType[] dtypes, | |||
TF_DataType[] input_types = null, string name = "", | |||
Dictionary<string, AttrValue> attrs = null, OpDef op_def = null) | |||
{ | |||
@@ -43,9 +43,13 @@ namespace Tensorflow | |||
name = name.EndsWith("/") ? ops._name_from_scope_name(name) : unique_name(name); | |||
var node_def = ops._NodeDef(op_type, name, device: "", attrs: attrs); | |||
var op = new Operation(node_def, this, | |||
var op = new Operation(node_def, | |||
this, | |||
inputs: inputs, | |||
output_types: dtypes, | |||
control_inputs: new object[] { }, | |||
input_types: input_types, | |||
original_op: null, | |||
op_def: op_def); | |||
return op; | |||
@@ -73,6 +77,7 @@ namespace Tensorflow | |||
else | |||
{ | |||
_names_in_use[name_key] = 1; | |||
return name; | |||
} | |||
@@ -47,17 +47,11 @@ namespace Tensorflow | |||
} | |||
var attrs = new Dictionary<string, object>(); | |||
// Perform input type inference | |||
var inputs = new List<Tensor>(); | |||
var input_types = new List<DataType>(); | |||
foreach (var attr in op_def.Attr) | |||
{ | |||
if (keywords.ContainsKey(attr.Name)) | |||
{ | |||
attrs[attr.Name] = keywords[attr.Name]; | |||
} | |||
} | |||
foreach (var input_arg in op_def.InputArg) | |||
{ | |||
var input_name = input_arg.Name; | |||
@@ -70,18 +64,38 @@ namespace Tensorflow | |||
{ | |||
attrs[input_arg.TypeAttr] = DataType.DtFloat; | |||
} | |||
if (input_arg.IsRef) | |||
{ | |||
} | |||
else | |||
{ | |||
input_types.Add((keywords[input_name] as Tensor).dtype); | |||
} | |||
} | |||
// Process remaining attrs | |||
foreach (var attr in op_def.Attr) | |||
{ | |||
if (keywords.ContainsKey(attr.Name)) | |||
{ | |||
attrs[attr.Name] = keywords[attr.Name]; | |||
} | |||
} | |||
// Convert attr values to AttrValue protos. | |||
var attr_protos = new Dictionary<string, AttrValue>(); | |||
foreach (var attr_def in op_def.Attr) | |||
{ | |||
var key = attr_def.Name; | |||
var value = attrs[key]; | |||
var attr_value = new AttrValue(); | |||
switch (attr_def.Type) | |||
{ | |||
case "type": | |||
attr_value.Type = (DataType)keywords["dtype"]; | |||
attr_value.Type = _MakeType(value, attr_def); | |||
break; | |||
case "shape": | |||
attr_value.Shape = new TensorShapeProto(); | |||
@@ -91,6 +105,7 @@ namespace Tensorflow | |||
attr_protos[key] = attr_value; | |||
} | |||
// Determine output types (possibly using attrs) | |||
var output_types = new List<DataType>(); | |||
foreach (var arg in op_def.OutputArg) | |||
@@ -105,6 +120,7 @@ namespace Tensorflow | |||
} | |||
} | |||
// Add Op to graph | |||
var op = g.create_op(op_type_name, inputs, output_types.ToArray(), | |||
name: scope, | |||
input_types: input_types.ToArray(), | |||
@@ -113,5 +129,10 @@ namespace Tensorflow | |||
return op; | |||
} | |||
public DataType _MakeType(Object v, AttrDef attr_def) | |||
{ | |||
return DataType.DtFloat; | |||
} | |||
} | |||
} |
@@ -8,7 +8,7 @@ namespace Tensorflow | |||
public class Operation | |||
{ | |||
private Graph _graph; | |||
private IntPtr _c_op; | |||
public IntPtr _c_op; | |||
public int _id => _id_value; | |||
private int _id_value; | |||
public string name; | |||
@@ -27,7 +27,7 @@ namespace Tensorflow | |||
c_api.TF_FinishOperation(desc, status.Handle); | |||
} | |||
public Operation(NodeDef node_def, Graph g, object inputs = null, TF_DataType[] output_types = null, object control_inputs = null, TF_DataType[] input_types = null, string original_op = "", OpDef op_def = null) | |||
public Operation(NodeDef node_def, Graph g, List<Tensor> inputs = null, TF_DataType[] output_types = null, object control_inputs = null, TF_DataType[] input_types = null, string original_op = "", OpDef op_def = null) | |||
{ | |||
_graph = g; | |||
@@ -38,7 +38,7 @@ namespace Tensorflow | |||
_outputs = new Tensor[num_outputs]; | |||
for (int i = 0; i < num_outputs; i++) | |||
{ | |||
_outputs[i] = new Tensor(this, i, TF_DataType.DtDouble); | |||
_outputs[i] = new Tensor(this, i, TF_DataType.DtFloat); | |||
} | |||
_graph._add_op(this); | |||
@@ -6,11 +6,5 @@ namespace Tensorflow | |||
{ | |||
public class Session : BaseSession | |||
{ | |||
public override byte[] run(Tensor fetches) | |||
{ | |||
var ret = base.run(fetches); | |||
return ret; | |||
} | |||
} | |||
} |
@@ -6,9 +6,12 @@ namespace Tensorflow | |||
{ | |||
public class Tensor | |||
{ | |||
private Operation _op; | |||
private int _value_index; | |||
private readonly Operation _op; | |||
public Operation op => _op; | |||
private readonly int _value_index; | |||
public int value_index => _value_index; | |||
private DataType _dtype; | |||
public DataType dtype => _dtype; | |||
public Tensor(Operation op, int value_index, DataType dtype) | |||
{ | |||
@@ -16,5 +19,10 @@ namespace Tensorflow | |||
_value_index = value_index; | |||
_dtype = dtype; | |||
} | |||
public TF_Output _as_tf_output() | |||
{ | |||
return c_api_util.tf_output(_op._c_op, _value_index); | |||
} | |||
} | |||
} |
@@ -19,6 +19,14 @@ namespace Tensorflow | |||
{ | |||
public const string TensorFlowLibName = "tensorflow"; | |||
/// <summary> | |||
/// For inputs that take a single tensor. | |||
/// </summary> | |||
/// <param name="desc"></param> | |||
/// <param name="input"></param> | |||
[DllImport(TensorFlowLibName)] | |||
public static unsafe extern void TF_AddInput(TF_OperationDescription desc, TF_Output input); | |||
[DllImport(TensorFlowLibName)] | |||
public static unsafe extern void TF_DeleteSessionOptions(TF_SessionOptions opts); | |||
@@ -60,11 +68,11 @@ namespace Tensorflow | |||
[DllImport(TensorFlowLibName)] | |||
public static extern unsafe void TF_SessionRun(TF_Session session, TF_Buffer* run_options, | |||
TF_Input[] inputs, TF_Tensor[] input_values, | |||
int ninputs, TF_Output[] outputs, | |||
TF_Tensor[] output_values, int noutputs, | |||
TF_Output[] inputs, TF_Tensor[] input_values, int ninputs, | |||
TF_Output[] outputs, TF_Tensor[] output_values, int noutputs, | |||
TF_Operation[] target_opers, int ntargets, | |||
TF_Buffer* run_metadata, TF_Status status); | |||
TF_Buffer* run_metadata, | |||
TF_Status status); | |||
[DllImport(TensorFlowLibName)] | |||
public static extern unsafe void TF_SetAttrType(TF_OperationDescription desc, string attr_name, TF_DataType value); | |||
@@ -0,0 +1,18 @@ | |||
using System; | |||
using System.Collections.Generic; | |||
using System.Text; | |||
namespace Tensorflow | |||
{ | |||
public class c_api_util | |||
{ | |||
public static TF_Output tf_output(IntPtr c_op, int index) | |||
{ | |||
var ret = new TF_Output(); | |||
ret.oper = c_op; | |||
ret.index = index; | |||
return ret; | |||
} | |||
} | |||
} |
@@ -16,21 +16,35 @@ namespace Tensorflow | |||
return tf.Graph(); | |||
} | |||
public static unsafe IntPtr _create_c_op(Graph graph, NodeDef node_def, object inputs) | |||
public static unsafe IntPtr _create_c_op(Graph graph, NodeDef node_def, List<Tensor> inputs) | |||
{ | |||
var op_desc = c_api.TF_NewOperation(graph.Handle, node_def.Op, node_def.Name); | |||
// Add inputs | |||
foreach(var op_input in inputs) | |||
{ | |||
c_api.TF_AddInput(op_desc, op_input._as_tf_output()); | |||
} | |||
var status = new Status(); | |||
// Add control inputs | |||
// Add attrs | |||
foreach (var attr in node_def.Attr) | |||
{ | |||
var bytes = attr.Value.ToByteArray(); | |||
var proto = Marshal.AllocHGlobal(bytes.Length); | |||
Marshal.Copy(bytes, 0, proto, bytes.Length); | |||
c_api.TF_SetAttrValueProto(op_desc, attr.Key, proto, proto_len: (UIntPtr)bytes.Length, status: status.Handle); | |||
if(status.Code != TF_Code.TF_OK) throw new Exception(status.Message); | |||
} | |||
var c_op = c_api.TF_FinishOperation(op_desc, status.Handle); | |||
if (status.Code != TF_Code.TF_OK) throw new Exception(status.Message); | |||
return c_op; | |||
} | |||
@@ -17,7 +17,9 @@ namespace Tensorflow | |||
var _op = _op_def_lib._apply_op_helper("Add", name: "add", keywords: keywords); | |||
return null; | |||
var tensor = new Tensor(_op, 0, DataType.DtFloat); | |||
return tensor; | |||
} | |||
private static OpDefLibrary _InitOpDefLibrary() | |||
@@ -28,8 +28,14 @@ namespace TensorFlowNET.UnitTest | |||
var b = tf.placeholder(tf.float32); | |||
var c = tf.add(a, b); | |||
//sess.run(adder_node, { a: 3, b: 4.5}) | |||
//sess.run(adder_node, {a: [1,3], b: [2, 4]}) | |||
using(var sess = tf.Session()) | |||
{ | |||
var feed_dict = new Dictionary<Tensor, object>(); | |||
feed_dict.Add(a, 3); | |||
feed_dict.Add(b, 2); | |||
sess.run(c, feed_dict); | |||
} | |||
} | |||
} | |||
} |