Browse Source

protobuf #3

tags/v0.1.0-Tensor
haiping008 6 years ago
parent
commit
a4e4c36390
11 changed files with 202 additions and 28 deletions
  1. +1
    -0
      .gitignore
  2. +25
    -5
      src/TensorFlowNET.Core/Graph.cs
  3. +14
    -2
      src/TensorFlowNET.Core/Operation.cs
  4. +11
    -0
      src/TensorFlowNET.Core/Tensor.cs
  5. +4
    -0
      src/TensorFlowNET.Core/TensorFlowNET.Core.csproj
  6. +31
    -0
      src/TensorFlowNET.Core/TensorShape.cs
  7. +16
    -3
      src/TensorFlowNET.Core/Tensorflow.cs
  8. +7
    -1
      src/TensorFlowNET.Core/c_api.cs
  9. +59
    -12
      src/TensorFlowNET.Core/ops.cs
  10. +33
    -3
      src/TensorFlowNET.Core/tensor_util.cs
  11. +1
    -2
      test/TensorFlowNET.Examples/HelloWorld.cs

+ 1
- 0
.gitignore View File

@@ -334,3 +334,4 @@ ASALocalRun/
/tensorflowlib/linux/native/libtensorflow_framework.so
/tensorflowlib/linux/native/libtensorflow.so
/src/TensorFlowNET.Core/libtensorflow.dll
/src/TensorFlowNET.Core/tensorflow.dll

+ 25
- 5
src/TensorFlowNET.Core/Graph.cs View File

@@ -3,7 +3,7 @@ using System.Collections.Generic;
using System.Linq;
using System.Runtime.InteropServices;
using System.Text;
using Tensorflow;
using TF_DataType = Tensorflow.DataType;

namespace TensorFlowNET.Core
@@ -19,6 +19,7 @@ namespace TensorFlowNET.Core
public IntPtr handle;
private Dictionary<int, Operation> _nodes_by_id;
private Dictionary<string, Operation> _nodes_by_name;
private Dictionary<string, int> _names_in_use;
public int _version;
private int _next_id_counter;

@@ -27,17 +28,20 @@ namespace TensorFlowNET.Core
this.handle = graph;
_nodes_by_id = new Dictionary<int, Operation>();
_nodes_by_name = new Dictionary<string, Operation>();
_names_in_use = new Dictionary<string, int>();
}

public unsafe Operation create_op(string op_type, object inputs, TF_DataType[] dtypes, TF_DataType[] input_types = null, string name = "")
public unsafe Operation create_op(string op_type, object inputs, TF_DataType[] dtypes, TF_DataType[] input_types = null, Dictionary<string, AttrValue> attrs = null, string name = "Const")
{
if (String.IsNullOrEmpty(name))
{
op_type = name;
name = op_type;
}

var op = new Operation(this, inputs);
op.name = name;
name = unique_name(name);
var node_def = ops._NodeDef(op_type, name, device: "", attrs: attrs);

var op = new Operation(node_def, this, inputs, dtypes);

return op;
}
@@ -54,6 +58,22 @@ namespace TensorFlowNET.Core
return ++_next_id_counter;
}

public string unique_name(string name)
{
var name_key = name.ToLower();
if (_names_in_use.ContainsKey(name_key))
{
_names_in_use[name_key]++;
}
else
{
_names_in_use[name_key] = 1;
}

return $"{name}_{_names_in_use[name_key]}";
}

public Operation[] get_operations()
{
return _nodes_by_name.Values.Select(x => x).ToArray();


+ 14
- 2
src/TensorFlowNET.Core/Operation.cs View File

@@ -1,6 +1,8 @@
using System;
using System.Collections.Generic;
using System.Text;
using Tensorflow;
using TF_DataType = Tensorflow.DataType;

namespace TensorFlowNET.Core
{
@@ -11,13 +13,23 @@ namespace TensorFlowNET.Core
public int _id => _id_value;
private int _id_value;
public string name;
private Tensor[] _outputs;
public Tensor[] outputs => _outputs;

public Operation(Graph g, object inputs)
public Operation(NodeDef node_def, Graph g, object inputs = null, TF_DataType[] output_types = null, object control_inputs = null, TF_DataType[] input_types = null, string original_op = "", string op_def = "")
{
_graph = g;

_id_value = _graph._next_id();
_c_op = ops._create_c_op(g, inputs);
_c_op = ops._create_c_op(g, node_def, inputs);
var num_outputs = c_api.TF_OperationNumOutputs(_c_op);

_outputs = new Tensor[num_outputs];
for (int i = 0; i < num_outputs; i++)
{
_outputs[i] = new Tensor(this, i, TF_DataType.DtDouble);
}

_graph._add_op(this);
}
}


+ 11
- 0
src/TensorFlowNET.Core/Tensor.cs View File

@@ -1,10 +1,21 @@
using System;
using System.Collections.Generic;
using System.Text;
using Tensorflow;

namespace TensorFlowNET.Core
{
public class Tensor
{
private Operation _op;
private int _value_index;
private DataType _dtype;

public Tensor(Operation op, int value_index, DataType dtype)
{
_op = op;
_value_index = value_index;
_dtype = dtype;
}
}
}

+ 4
- 0
src/TensorFlowNET.Core/TensorFlowNET.Core.csproj View File

@@ -9,6 +9,10 @@
<DefineConstants>DEBUG;TRACE</DefineConstants>
</PropertyGroup>

<ItemGroup>
<None Remove="Tensorflow\README.md" />
</ItemGroup>

<ItemGroup>
<PackageReference Include="Google.Protobuf" Version="3.6.1" />
</ItemGroup>


+ 31
- 0
src/TensorFlowNET.Core/TensorShape.cs View File

@@ -0,0 +1,31 @@
using Google.Protobuf.Collections;
using System;
using System.Collections.Generic;
using System.Text;
using Tensorflow;
using tensor_shape_pb2 = Tensorflow;

namespace TensorFlowNET.Core
{
public class TensorShape
{
private int[] _dims;

public TensorShape()
{

}

public TensorShape as_shape()
{
return this;
}

public TensorShapeProto as_proto()
{
TensorShapeProto dim = new TensorShapeProto();

return new TensorShapeProto(dim);
}
}
}

+ 16
- 3
src/TensorFlowNET.Core/Tensorflow.cs View File

@@ -3,6 +3,8 @@ using System.Collections.Generic;
using System.Runtime.InteropServices;
using System.Text;
using TF_DataType = Tensorflow.DataType;
using attr_value_pb2 = Tensorflow;
using Tensorflow;

namespace TensorFlowNET.Core
{
@@ -13,9 +15,20 @@ namespace TensorFlowNET.Core
public static unsafe Tensor constant(object value)
{
var g = ops.get_default_graph();
g.create_op("Const", value, new TF_DataType[] { TF_DataType.DtDouble });

return new Tensor();
var tensor_value = new attr_value_pb2.AttrValue();
var tensor_pb = tensor_util.make_tensor_proto(value);
tensor_value.Tensor = tensor_pb;
var dtype_value = new attr_value_pb2.AttrValue
{
Type = tensor_value.Tensor.Dtype,
};

var attrs = new Dictionary<string, AttrValue>();
attrs["dtype"] = dtype_value;
attrs["value"] = tensor_value;
var const_tensor = g.create_op("Const", null, new TF_DataType[] { dtype_value.Type }, attrs: attrs).outputs[0];

return const_tensor;
}

public static Deallocator FreeTensorDataDelegate = FreeTensorData;


+ 7
- 1
src/TensorFlowNET.Core/c_api.cs View File

@@ -18,7 +18,7 @@ namespace TensorFlowNET.Core
{
public static class c_api
{
public const string TensorFlowLibName = "libtensorflow";
public const string TensorFlowLibName = "tensorflow";

[DllImport(TensorFlowLibName)]
public static unsafe extern TF_Operation TF_FinishOperation(TF_OperationDescription desc, TF_Status status);
@@ -35,6 +35,12 @@ namespace TensorFlowNET.Core
[DllImport(TensorFlowLibName)]
public static extern unsafe TF_Tensor TF_NewTensor(TF_DataType dataType, Int64 dims, int num_dims, IntPtr data, size_t len, Deallocator deallocator, IntPtr deallocator_arg);

[DllImport(TensorFlowLibName)]
public static extern unsafe int TF_OperationNumOutputs(TF_Operation oper);

[DllImport(TensorFlowLibName)]
public static extern unsafe void TF_SetAttrValueProto(TF_OperationDescription desc, string attr_name, void* proto, size_t proto_len, TF_Status status);

[DllImport(TensorFlowLibName)]
public static extern unsafe void TF_SetAttrTensor(TF_OperationDescription desc, string attr_name, TF_Tensor value, TF_Status status);



+ 59
- 12
src/TensorFlowNET.Core/ops.cs View File

@@ -3,8 +3,10 @@ using System.Collections.Generic;
using System.Runtime.InteropServices;
using System.Text;
using System.Threading;
using Tensorflow;
using tf = TensorFlowNET.Core.Tensorflow;
using TF_DataType = Tensorflow.DataType;
using node_def_pb2 = Tensorflow;

namespace TensorFlowNET.Core
{
@@ -15,28 +17,73 @@ namespace TensorFlowNET.Core
return tf.Graph();
}

public static unsafe IntPtr _create_c_op(Graph graph, object inputs)
public static unsafe IntPtr _create_c_op(Graph graph, NodeDef node_def, object inputs)
{
var op_desc = c_api.TF_NewOperation(graph.handle, "Const", "Const0");
var op_desc = c_api.TF_NewOperation(graph.handle, node_def.Op, node_def.Name);
var status = c_api.TF_NewStatus();

IntPtr tensor = IntPtr.Zero;
// Doesn't work
/*foreach(var attr in node_def.Attr)
{
if (attr.Value.Tensor != null)
{
switch (attr.Value.Tensor.Dtype)
{
case DataType.DtDouble:
var proto = (double*)Marshal.AllocHGlobal(sizeof(double));
*proto = attr.Value.Tensor.DoubleVal[0];
c_api.TF_SetAttrValueProto(op_desc, attr.Key, proto, proto_len: (UIntPtr)sizeof(double), status: status);
break;
}
}
else
{
//c_api.TF_SetAttrValueProto(op_desc, attr.Key, null, proto_len: UIntPtr.Zero, status: status);
}
} */

switch (inputs)
foreach (var attr in node_def.Attr)
{
case double value:
var v = (double*)Marshal.AllocHGlobal(sizeof(double));
*v = value;
tensor = c_api.TF_NewTensor(TF_DataType.DtDouble, 0, 0, data: (IntPtr)v, len: (UIntPtr)sizeof(double), deallocator: Tensorflow.FreeTensorDataDelegate, deallocator_arg: IntPtr.Zero);
c_api.TF_SetAttrType(op_desc, "dtype", TF_DataType.DtDouble);
break;
if (attr.Value.Tensor == null) continue;
switch (attr.Value.Tensor.Dtype)
{
case DataType.DtDouble:
var v = (double*)Marshal.AllocHGlobal(sizeof(double));
*v = attr.Value.Tensor.DoubleVal[0];
var tensor = c_api.TF_NewTensor(TF_DataType.DtDouble, 0, 0, data: (IntPtr)v, len: (UIntPtr)sizeof(double), deallocator: Tensorflow.FreeTensorDataDelegate, deallocator_arg: IntPtr.Zero);
c_api.TF_SetAttrTensor(op_desc, "value", tensor, status);
c_api.TF_SetAttrType(op_desc, "dtype", TF_DataType.DtDouble);
break;
case DataType.DtString:
var proto = Marshal.StringToHGlobalAnsi(attr.Value.Tensor.StringVal[0].ToStringUtf8());
c_api.TF_SetAttrValueProto(op_desc, attr.Key, proto.ToPointer(), proto_len: (UIntPtr)32, status: status);
break;
}
}

c_api.TF_SetAttrTensor(op_desc, "value", tensor, status);

var c_op = c_api.TF_FinishOperation(op_desc, status);

return c_op;
}

public static NodeDef _NodeDef(string op_type, string name, string device = "", Dictionary<string, AttrValue> attrs = null)
{
var node_def = new node_def_pb2.NodeDef();
node_def.Op = op_type;
node_def.Name = name;

foreach (var attr in attrs)
{
node_def.Attr.Add(attr.Key, attr.Value);
}
return node_def;
}

public static int uid()
{
return 1;
}
}
}

+ 33
- 3
src/TensorFlowNET.Core/tensor_util.cs View File

@@ -1,15 +1,45 @@
using System;
using NumSharp.Core;
using System;
using System.Collections.Generic;
using System.Text;
using Tensorflow;
using np = NumSharp.Core.NumPy;
using tensor_pb2 = Tensorflow;

namespace TensorFlowNET.Core
{
public static class tensor_util
{
public static void make_tensor_proto(object values, Type dtype = null)
public static TensorProto make_tensor_proto(object values, Type dtype = null)
{
var nparray = np.array(values as Array, dtype);
NDArray nparray;
TensorProto tensor_proto = null;
TensorShape tensor_shape = new TensorShape();

switch (values)
{
case double val:
nparray = np.array(new double[] { val }, np.float64);
tensor_proto = new tensor_pb2.TensorProto
{
Dtype = DataType.DtDouble,
TensorShape = tensor_shape.as_shape().as_proto()
};
tensor_proto.DoubleVal.Add(val);
break;

case string val:
nparray = np.array(new string[] { val }, np.chars);
tensor_proto = new tensor_pb2.TensorProto
{
Dtype = DataType.DtString,
TensorShape = tensor_shape.as_shape().as_proto()
};
tensor_proto.StringVal.Add(Google.Protobuf.ByteString.CopyFrom(val, Encoding.UTF8));
break;
}

return tensor_proto;
}
}
}

+ 1
- 2
test/TensorFlowNET.Examples/HelloWorld.cs View File

@@ -19,8 +19,7 @@ namespace TensorFlowNET.Examples
The value returned by the constructor represents the output
of the Constant op.*/
var graph = tf.get_default_graph();
var hello = tf.constant(4.0);
//var hello = tf.constant("Hello, TensorFlow!");
var hello = tf.constant("Hello, TensorFlow!");

// Start tf session
// var sess = tf.Session();


Loading…
Cancel
Save