Browse Source

sess.run not finished yet

tags/v0.1.0-Tensor
haiping008 6 years ago
parent
commit
2995c36c9d
11 changed files with 117 additions and 36 deletions
  1. +10
    -5
      src/TensorFlowNET.Core/BaseSession.cs
  2. +7
    -2
      src/TensorFlowNET.Core/Graph.cs
  3. +31
    -10
      src/TensorFlowNET.Core/OpDefLibrary.cs
  4. +3
    -3
      src/TensorFlowNET.Core/Operation.cs
  5. +0
    -6
      src/TensorFlowNET.Core/Session.cs
  6. +10
    -2
      src/TensorFlowNET.Core/Tensor.cs
  7. +12
    -4
      src/TensorFlowNET.Core/c_api.cs
  8. +18
    -0
      src/TensorFlowNET.Core/c_api_util.cs
  9. +15
    -1
      src/TensorFlowNET.Core/ops.cs
  10. +3
    -1
      src/TensorFlowNET.Core/ops/gen_math_ops.cs
  11. +8
    -2
      test/TensorFlowNET.UnitTest/OperationsTest.cs

+ 10
- 5
src/TensorFlowNET.Core/BaseSession.cs View File

@@ -4,7 +4,7 @@ using System.Text;


namespace Tensorflow namespace Tensorflow
{ {
public class BaseSession
public class BaseSession : IDisposable
{ {
private Graph _graph; private Graph _graph;
private bool _opened; private bool _opened;
@@ -32,18 +32,23 @@ namespace Tensorflow
c_api.TF_DeleteSessionOptions(opts); c_api.TF_DeleteSessionOptions(opts);
} }


public virtual byte[] run(Tensor fetches)
public void Dispose()
{ {
return _run(fetches);
} }


private unsafe byte[] _run(Tensor fetches)
public virtual byte[] run(Tensor fetches, Dictionary<Tensor, object> feed_dict = null)
{
return _run(fetches, feed_dict);
}

private unsafe byte[] _run(Tensor fetches, Dictionary<Tensor, object> feed_dict = null)
{ {
var status = new Status(); var status = new Status();


c_api.TF_SessionRun(_session, c_api.TF_SessionRun(_session,
run_options: null, run_options: null,
inputs: new TF_Input[] { },
inputs: new TF_Output[] { },
input_values: new IntPtr[] { }, input_values: new IntPtr[] { },
ninputs: 1, ninputs: 1,
outputs: new TF_Output[] { }, outputs: new TF_Output[] { },


+ 7
- 2
src/TensorFlowNET.Core/Graph.cs View File

@@ -31,7 +31,7 @@ namespace Tensorflow
_names_in_use = new Dictionary<string, int>(); _names_in_use = new Dictionary<string, int>();
} }


public unsafe Operation create_op(string op_type, object inputs, TF_DataType[] dtypes,
public unsafe Operation create_op(string op_type, List<Tensor> inputs, TF_DataType[] dtypes,
TF_DataType[] input_types = null, string name = "", TF_DataType[] input_types = null, string name = "",
Dictionary<string, AttrValue> attrs = null, OpDef op_def = null) Dictionary<string, AttrValue> attrs = null, OpDef op_def = null)
{ {
@@ -43,9 +43,13 @@ namespace Tensorflow
name = name.EndsWith("/") ? ops._name_from_scope_name(name) : unique_name(name); name = name.EndsWith("/") ? ops._name_from_scope_name(name) : unique_name(name);
var node_def = ops._NodeDef(op_type, name, device: "", attrs: attrs); var node_def = ops._NodeDef(op_type, name, device: "", attrs: attrs);


var op = new Operation(node_def, this,
var op = new Operation(node_def,
this,
inputs: inputs, inputs: inputs,
output_types: dtypes, output_types: dtypes,
control_inputs: new object[] { },
input_types: input_types,
original_op: null,
op_def: op_def); op_def: op_def);


return op; return op;
@@ -73,6 +77,7 @@ namespace Tensorflow
else else
{ {
_names_in_use[name_key] = 1; _names_in_use[name_key] = 1;
return name;
} }




+ 31
- 10
src/TensorFlowNET.Core/OpDefLibrary.cs View File

@@ -47,17 +47,11 @@ namespace Tensorflow
} }


var attrs = new Dictionary<string, object>(); var attrs = new Dictionary<string, object>();

// Perform input type inference
var inputs = new List<Tensor>(); var inputs = new List<Tensor>();
var input_types = new List<DataType>(); var input_types = new List<DataType>();

foreach (var attr in op_def.Attr)
{
if (keywords.ContainsKey(attr.Name))
{
attrs[attr.Name] = keywords[attr.Name];
}
}

foreach (var input_arg in op_def.InputArg) foreach (var input_arg in op_def.InputArg)
{ {
var input_name = input_arg.Name; var input_name = input_arg.Name;
@@ -70,18 +64,38 @@ namespace Tensorflow
{ {
attrs[input_arg.TypeAttr] = DataType.DtFloat; attrs[input_arg.TypeAttr] = DataType.DtFloat;
} }

if (input_arg.IsRef)
{

}
else
{
input_types.Add((keywords[input_name] as Tensor).dtype);
}
} }


// Process remaining attrs
foreach (var attr in op_def.Attr)
{
if (keywords.ContainsKey(attr.Name))
{
attrs[attr.Name] = keywords[attr.Name];
}
}

// Convert attr values to AttrValue protos.
var attr_protos = new Dictionary<string, AttrValue>(); var attr_protos = new Dictionary<string, AttrValue>();
foreach (var attr_def in op_def.Attr) foreach (var attr_def in op_def.Attr)
{ {
var key = attr_def.Name; var key = attr_def.Name;
var value = attrs[key];
var attr_value = new AttrValue(); var attr_value = new AttrValue();
switch (attr_def.Type) switch (attr_def.Type)
{ {
case "type": case "type":
attr_value.Type = (DataType)keywords["dtype"];
attr_value.Type = _MakeType(value, attr_def);
break; break;
case "shape": case "shape":
attr_value.Shape = new TensorShapeProto(); attr_value.Shape = new TensorShapeProto();
@@ -91,6 +105,7 @@ namespace Tensorflow
attr_protos[key] = attr_value; attr_protos[key] = attr_value;
} }


// Determine output types (possibly using attrs)
var output_types = new List<DataType>(); var output_types = new List<DataType>();


foreach (var arg in op_def.OutputArg) foreach (var arg in op_def.OutputArg)
@@ -105,6 +120,7 @@ namespace Tensorflow
} }
} }


// Add Op to graph
var op = g.create_op(op_type_name, inputs, output_types.ToArray(), var op = g.create_op(op_type_name, inputs, output_types.ToArray(),
name: scope, name: scope,
input_types: input_types.ToArray(), input_types: input_types.ToArray(),
@@ -113,5 +129,10 @@ namespace Tensorflow


return op; return op;
} }

public DataType _MakeType(Object v, AttrDef attr_def)
{
return DataType.DtFloat;
}
} }
} }

+ 3
- 3
src/TensorFlowNET.Core/Operation.cs View File

@@ -8,7 +8,7 @@ namespace Tensorflow
public class Operation public class Operation
{ {
private Graph _graph; private Graph _graph;
private IntPtr _c_op;
public IntPtr _c_op;
public int _id => _id_value; public int _id => _id_value;
private int _id_value; private int _id_value;
public string name; public string name;
@@ -27,7 +27,7 @@ namespace Tensorflow
c_api.TF_FinishOperation(desc, status.Handle); c_api.TF_FinishOperation(desc, status.Handle);
} }


public Operation(NodeDef node_def, Graph g, object inputs = null, TF_DataType[] output_types = null, object control_inputs = null, TF_DataType[] input_types = null, string original_op = "", OpDef op_def = null)
public Operation(NodeDef node_def, Graph g, List<Tensor> inputs = null, TF_DataType[] output_types = null, object control_inputs = null, TF_DataType[] input_types = null, string original_op = "", OpDef op_def = null)
{ {
_graph = g; _graph = g;


@@ -38,7 +38,7 @@ namespace Tensorflow
_outputs = new Tensor[num_outputs]; _outputs = new Tensor[num_outputs];
for (int i = 0; i < num_outputs; i++) for (int i = 0; i < num_outputs; i++)
{ {
_outputs[i] = new Tensor(this, i, TF_DataType.DtDouble);
_outputs[i] = new Tensor(this, i, TF_DataType.DtFloat);
} }


_graph._add_op(this); _graph._add_op(this);


+ 0
- 6
src/TensorFlowNET.Core/Session.cs View File

@@ -6,11 +6,5 @@ namespace Tensorflow
{ {
public class Session : BaseSession public class Session : BaseSession
{ {
public override byte[] run(Tensor fetches)
{
var ret = base.run(fetches);

return ret;
}
} }
} }

+ 10
- 2
src/TensorFlowNET.Core/Tensor.cs View File

@@ -6,9 +6,12 @@ namespace Tensorflow
{ {
public class Tensor public class Tensor
{ {
private Operation _op;
private int _value_index;
private readonly Operation _op;
public Operation op => _op;
private readonly int _value_index;
public int value_index => _value_index;
private DataType _dtype; private DataType _dtype;
public DataType dtype => _dtype;


public Tensor(Operation op, int value_index, DataType dtype) public Tensor(Operation op, int value_index, DataType dtype)
{ {
@@ -16,5 +19,10 @@ namespace Tensorflow
_value_index = value_index; _value_index = value_index;
_dtype = dtype; _dtype = dtype;
} }

public TF_Output _as_tf_output()
{
return c_api_util.tf_output(_op._c_op, _value_index);
}
} }
} }

+ 12
- 4
src/TensorFlowNET.Core/c_api.cs View File

@@ -19,6 +19,14 @@ namespace Tensorflow
{ {
public const string TensorFlowLibName = "tensorflow"; public const string TensorFlowLibName = "tensorflow";


/// <summary>
/// For inputs that take a single tensor.
/// </summary>
/// <param name="desc"></param>
/// <param name="input"></param>
[DllImport(TensorFlowLibName)]
public static unsafe extern void TF_AddInput(TF_OperationDescription desc, TF_Output input);

[DllImport(TensorFlowLibName)] [DllImport(TensorFlowLibName)]
public static unsafe extern void TF_DeleteSessionOptions(TF_SessionOptions opts); public static unsafe extern void TF_DeleteSessionOptions(TF_SessionOptions opts);


@@ -60,11 +68,11 @@ namespace Tensorflow


[DllImport(TensorFlowLibName)] [DllImport(TensorFlowLibName)]
public static extern unsafe void TF_SessionRun(TF_Session session, TF_Buffer* run_options, public static extern unsafe void TF_SessionRun(TF_Session session, TF_Buffer* run_options,
TF_Input[] inputs, TF_Tensor[] input_values,
int ninputs, TF_Output[] outputs,
TF_Tensor[] output_values, int noutputs,
TF_Output[] inputs, TF_Tensor[] input_values, int ninputs,
TF_Output[] outputs, TF_Tensor[] output_values, int noutputs,
TF_Operation[] target_opers, int ntargets, TF_Operation[] target_opers, int ntargets,
TF_Buffer* run_metadata, TF_Status status);
TF_Buffer* run_metadata,
TF_Status status);


[DllImport(TensorFlowLibName)] [DllImport(TensorFlowLibName)]
public static extern unsafe void TF_SetAttrType(TF_OperationDescription desc, string attr_name, TF_DataType value); public static extern unsafe void TF_SetAttrType(TF_OperationDescription desc, string attr_name, TF_DataType value);


+ 18
- 0
src/TensorFlowNET.Core/c_api_util.cs View File

@@ -0,0 +1,18 @@
using System;
using System.Collections.Generic;
using System.Text;

namespace Tensorflow
{
public class c_api_util
{
public static TF_Output tf_output(IntPtr c_op, int index)
{
var ret = new TF_Output();
ret.oper = c_op;
ret.index = index;

return ret;
}
}
}

+ 15
- 1
src/TensorFlowNET.Core/ops.cs View File

@@ -16,21 +16,35 @@ namespace Tensorflow
return tf.Graph(); return tf.Graph();
} }


public static unsafe IntPtr _create_c_op(Graph graph, NodeDef node_def, object inputs)
public static unsafe IntPtr _create_c_op(Graph graph, NodeDef node_def, List<Tensor> inputs)
{ {
var op_desc = c_api.TF_NewOperation(graph.Handle, node_def.Op, node_def.Name); var op_desc = c_api.TF_NewOperation(graph.Handle, node_def.Op, node_def.Name);

// Add inputs
foreach(var op_input in inputs)
{
c_api.TF_AddInput(op_desc, op_input._as_tf_output());
}

var status = new Status(); var status = new Status();


// Add control inputs

// Add attrs
foreach (var attr in node_def.Attr) foreach (var attr in node_def.Attr)
{ {
var bytes = attr.Value.ToByteArray(); var bytes = attr.Value.ToByteArray();
var proto = Marshal.AllocHGlobal(bytes.Length); var proto = Marshal.AllocHGlobal(bytes.Length);
Marshal.Copy(bytes, 0, proto, bytes.Length); Marshal.Copy(bytes, 0, proto, bytes.Length);
c_api.TF_SetAttrValueProto(op_desc, attr.Key, proto, proto_len: (UIntPtr)bytes.Length, status: status.Handle); c_api.TF_SetAttrValueProto(op_desc, attr.Key, proto, proto_len: (UIntPtr)bytes.Length, status: status.Handle);

if(status.Code != TF_Code.TF_OK) throw new Exception(status.Message);
} }


var c_op = c_api.TF_FinishOperation(op_desc, status.Handle); var c_op = c_api.TF_FinishOperation(op_desc, status.Handle);


if (status.Code != TF_Code.TF_OK) throw new Exception(status.Message);

return c_op; return c_op;
} }




+ 3
- 1
src/TensorFlowNET.Core/ops/gen_math_ops.cs View File

@@ -17,7 +17,9 @@ namespace Tensorflow


var _op = _op_def_lib._apply_op_helper("Add", name: "add", keywords: keywords); var _op = _op_def_lib._apply_op_helper("Add", name: "add", keywords: keywords);


return null;
var tensor = new Tensor(_op, 0, DataType.DtFloat);

return tensor;
} }


private static OpDefLibrary _InitOpDefLibrary() private static OpDefLibrary _InitOpDefLibrary()


+ 8
- 2
test/TensorFlowNET.UnitTest/OperationsTest.cs View File

@@ -28,8 +28,14 @@ namespace TensorFlowNET.UnitTest
var b = tf.placeholder(tf.float32); var b = tf.placeholder(tf.float32);
var c = tf.add(a, b); var c = tf.add(a, b);


//sess.run(adder_node, { a: 3, b: 4.5})
//sess.run(adder_node, {a: [1,3], b: [2, 4]})
using(var sess = tf.Session())
{
var feed_dict = new Dictionary<Tensor, object>();
feed_dict.Add(a, 3);
feed_dict.Add(b, 2);

sess.run(c, feed_dict);
}
} }
} }
} }

Loading…
Cancel
Save