@@ -4,7 +4,7 @@ TensorFlow.NET provides .NET Standard binding for [TensorFlow](https://www.tenso | |||
[](https://gitter.im/sci-sharp/community) | |||
 | |||
TensorFlow.NET is a member project of SciSharp stack. | |||
TensorFlow.NET is a member project of [SciSharp](https://github.com/SciSharp) stack. | |||
 | |||
@@ -45,3 +45,5 @@ using(var sess = tf.Session()) | |||
var o = sess.run(c, feed_dict); | |||
} | |||
``` | |||
Star me or raise issue on [Github](https://github.com/SciSharp/TensorFlow.NET) feel free. |
@@ -13,7 +13,7 @@ namespace Tensorflow | |||
/// then create a TensorFlow session to run parts of the graph across a set of local and remote devices. | |||
/// https://www.tensorflow.org/guide/graphs | |||
/// </summary> | |||
public class Graph | |||
public class Graph : IDisposable | |||
{ | |||
private IntPtr _handle; | |||
private Dictionary<int, Operation> _nodes_by_id; | |||
@@ -25,6 +25,11 @@ namespace Tensorflow | |||
private string _name_stack; | |||
public Graph() | |||
{ | |||
_handle = c_api.TF_NewGraph(); | |||
} | |||
public Graph(IntPtr graph) | |||
{ | |||
_handle = graph; | |||
@@ -171,6 +176,11 @@ namespace Tensorflow | |||
return _nodes_by_name.Values.Select(x => x).ToArray(); | |||
} | |||
public void Dispose() | |||
{ | |||
c_api.TF_DeleteGraph(_handle); | |||
} | |||
public static implicit operator IntPtr(Graph graph) | |||
{ | |||
return graph._handle; | |||
@@ -7,6 +7,14 @@ namespace Tensorflow | |||
{ | |||
public static partial class c_api | |||
{ | |||
/// <summary> | |||
/// Destroy an options object. Graph will be deleted once no more | |||
/// TFSession's are referencing it. | |||
/// </summary> | |||
/// <param name="graph"></param> | |||
[DllImport(TensorFlowLibName)] | |||
public static extern void TF_DeleteGraph(IntPtr graph); | |||
[DllImport(TensorFlowLibName)] | |||
public static extern void TF_GraphGetOpDef(IntPtr graph, string op_name, IntPtr output_op_def, IntPtr status); | |||
@@ -21,14 +29,14 @@ namespace Tensorflow | |||
/// <param name="num_dims"></param> | |||
/// <param name="status"></param> | |||
[DllImport(TensorFlowLibName)] | |||
public static extern void TF_GraphGetTensorShape(IntPtr graph, TF_Output output, int[] dims, int num_dims, IntPtr status); | |||
public static extern void TF_GraphGetTensorShape(IntPtr graph, TF_Output output, long[] dims, int num_dims, IntPtr status); | |||
/// <summary> | |||
/// Sets the shape of the Tensor referenced by `output` in `graph` to | |||
/// the shape described by `dims` and `num_dims`. | |||
/// </summary> | |||
[DllImport(TensorFlowLibName)] | |||
public static extern void TF_GraphSetTensorShape(IntPtr graph, TF_Output output, int[] dims, int num_dims, IntPtr status); | |||
public static extern void TF_GraphSetTensorShape(IntPtr graph, TF_Output output, long[] dims, int num_dims, IntPtr status); | |||
/// <summary> | |||
/// Returns the number of dimensions of the Tensor referenced by `output` | |||
@@ -14,16 +14,16 @@ namespace Tensorflow | |||
private Status status = new Status(); | |||
public string name => c_api.TF_OperationName(_handle); | |||
public string optype => c_api.TF_OperationOpType(_handle); | |||
public string device => c_api.TF_OperationDevice(_handle); | |||
public int NumOutputs => c_api.TF_OperationNumOutputs(_handle); | |||
public TF_DataType OutputType => c_api.TF_OperationOutputType(new TF_Output(_handle, 0)); | |||
public int OutputListLength => c_api.TF_OperationOutputListLength(_handle, "output", status); | |||
public int NumInputs => c_api.TF_OperationNumInputs(_handle); | |||
public int NumConsumers => c_api.TF_OperationOutputNumConsumers(new TF_Output(_handle, 0)); | |||
public int NumControlInputs => c_api.TF_OperationNumControlInputs(_handle); | |||
public int NumControlOutputs => c_api.TF_OperationNumControlOutputs(_handle); | |||
public string name { get; } | |||
public string optype { get; } | |||
public string device { get; } | |||
public int NumOutputs { get; } | |||
public TF_DataType OutputType { get; } | |||
public int OutputListLength { get; } | |||
public int NumInputs { get; } | |||
public int NumConsumers { get; } | |||
public int NumControlInputs { get; } | |||
public int NumControlOutputs { get; } | |||
private Tensor[] _outputs; | |||
public Tensor[] outputs => _outputs; | |||
@@ -31,7 +31,21 @@ namespace Tensorflow | |||
public Operation(IntPtr handle) | |||
{ | |||
if (handle == IntPtr.Zero) | |||
return; | |||
_handle = handle; | |||
name = c_api.TF_OperationName(_handle); | |||
optype = c_api.TF_OperationOpType(_handle); | |||
device = "";// c_api.TF_OperationDevice(_handle); | |||
NumOutputs = c_api.TF_OperationNumOutputs(_handle); | |||
OutputType = c_api.TF_OperationOutputType(new TF_Output(_handle, 0)); | |||
OutputListLength = c_api.TF_OperationOutputListLength(_handle, "output", status); | |||
NumInputs = c_api.TF_OperationNumInputs(_handle); | |||
NumConsumers = c_api.TF_OperationOutputNumConsumers(new TF_Output(_handle, 0)); | |||
NumControlInputs = c_api.TF_OperationNumControlInputs(_handle); | |||
NumControlOutputs = c_api.TF_OperationNumControlOutputs(_handle); | |||
} | |||
public Operation(Graph g, string opType, string oper_name) | |||
@@ -14,7 +14,7 @@ namespace Tensorflow | |||
this.index = index; | |||
} | |||
public IntPtr oper; | |||
public unsafe IntPtr oper; | |||
public int index; | |||
} | |||
} |
@@ -22,6 +22,15 @@ namespace Tensorflow | |||
[DllImport(TensorFlowLibName)] | |||
public static extern void TF_AddInput(IntPtr desc, TF_Output input); | |||
/// <summary> | |||
/// For inputs that take a list of tensors. | |||
/// inputs must point to TF_Output[num_inputs]. | |||
/// </summary> | |||
/// <param name="desc"></param> | |||
/// <param name="inputs"></param> | |||
[DllImport(TensorFlowLibName)] | |||
public static extern void TF_AddInputList(IntPtr desc, TF_Output[] inputs, int num_inputs); | |||
[DllImport(TensorFlowLibName)] | |||
public static extern IntPtr TF_FinishOperation(IntPtr desc, IntPtr status); | |||
@@ -11,7 +11,7 @@ namespace Tensorflow | |||
/// A tensor is a generalization of vectors and matrices to potentially higher dimensions. | |||
/// Internally, TensorFlow represents tensors as n-dimensional arrays of base datatypes. | |||
/// </summary> | |||
public class Tensor | |||
public class Tensor : IDisposable | |||
{ | |||
private readonly IntPtr _handle; | |||
@@ -38,6 +38,7 @@ namespace Tensorflow | |||
/// n n-Tensor (you get the idea) | |||
/// </summary> | |||
public int rank; | |||
public int NDims => rank; | |||
/// <summary> | |||
/// if original buffer is free. | |||
@@ -96,7 +97,7 @@ namespace Tensorflow | |||
nd.shape.Select(x => (long)x).ToArray(), // shape | |||
nd.ndim, | |||
dotHandle, | |||
(UIntPtr)(nd.size * nd.dtypesize), | |||
(ulong)(nd.size * nd.dtypesize), | |||
(IntPtr values, IntPtr len, ref bool closure) => | |||
{ | |||
// Free the original buffer and set flag | |||
@@ -160,9 +161,19 @@ namespace Tensorflow | |||
return TF_DataType.DtInvalid; | |||
} | |||
public void Dispose() | |||
{ | |||
c_api.TF_DeleteTensor(_handle); | |||
} | |||
public static implicit operator IntPtr(Tensor tensor) | |||
{ | |||
return tensor._handle; | |||
} | |||
public static implicit operator Tensor(IntPtr handle) | |||
{ | |||
return new Tensor(handle); | |||
} | |||
} | |||
} |
@@ -7,20 +7,23 @@ namespace Tensorflow | |||
{ | |||
public static partial class c_api | |||
{ | |||
[DllImport(TensorFlowLibName)] | |||
public static extern IntPtr TF_AllocateTensor(TF_DataType dtype, long[] dims, int num_dims, ulong len); | |||
/// <summary> | |||
/// returns the sizeof() for the underlying type corresponding to the given TF_DataType enum value. | |||
/// </summary> | |||
/// <param name="dt"></param> | |||
/// <returns></returns> | |||
[DllImport(TensorFlowLibName)] | |||
public static unsafe extern ulong TF_DataTypeSize(TF_DataType dt); | |||
public static extern ulong TF_DataTypeSize(TF_DataType dt); | |||
/// <summary> | |||
/// Destroy a tensor. | |||
/// </summary> | |||
/// <param name="tensor"></param> | |||
[DllImport(TensorFlowLibName)] | |||
public static unsafe extern void TF_DeleteTensor(IntPtr tensor); | |||
public static extern void TF_DeleteTensor(IntPtr tensor); | |||
/// <summary> | |||
/// Return the length of the tensor in the "dim_index" dimension. | |||
@@ -30,7 +33,7 @@ namespace Tensorflow | |||
/// <param name="dim_index"></param> | |||
/// <returns></returns> | |||
[DllImport(TensorFlowLibName)] | |||
public static extern unsafe long TF_Dim(IntPtr tensor, int dim_index); | |||
public static extern long TF_Dim(IntPtr tensor, int dim_index); | |||
/// <summary> | |||
/// Return a new tensor that holds the bytes data[0,len-1] | |||
@@ -44,7 +47,7 @@ namespace Tensorflow | |||
/// <param name="deallocator_arg"></param> | |||
/// <returns></returns> | |||
[DllImport(TensorFlowLibName)] | |||
public static extern unsafe IntPtr TF_NewTensor(TF_DataType dataType, long[] dims, int num_dims, IntPtr data, UIntPtr len, Deallocator deallocator, ref bool deallocator_arg); | |||
public static extern IntPtr TF_NewTensor(TF_DataType dataType, long[] dims, int num_dims, IntPtr data, ulong len, Deallocator deallocator, ref bool deallocator_arg); | |||
/// <summary> | |||
/// Return the number of dimensions that the tensor has. | |||
@@ -52,7 +55,7 @@ namespace Tensorflow | |||
/// <param name="tensor"></param> | |||
/// <returns></returns> | |||
[DllImport(TensorFlowLibName)] | |||
public static extern unsafe int TF_NumDims(IntPtr tensor); | |||
public static extern int TF_NumDims(IntPtr tensor); | |||
/// <summary> | |||
/// Return the size of the underlying data in bytes. | |||
@@ -60,7 +63,7 @@ namespace Tensorflow | |||
/// <param name="tensor"></param> | |||
/// <returns></returns> | |||
[DllImport(TensorFlowLibName)] | |||
public static extern unsafe ulong TF_TensorByteSize(IntPtr tensor); | |||
public static extern ulong TF_TensorByteSize(IntPtr tensor); | |||
/// <summary> | |||
/// Return a pointer to the underlying data buffer. | |||
@@ -68,7 +71,7 @@ namespace Tensorflow | |||
/// <param name="tensor"></param> | |||
/// <returns></returns> | |||
[DllImport(TensorFlowLibName)] | |||
public static extern unsafe IntPtr TF_TensorData(IntPtr tensor); | |||
public static extern IntPtr TF_TensorData(IntPtr tensor); | |||
/// <summary> | |||
/// Return the type of a tensor element. | |||
@@ -76,6 +79,6 @@ namespace Tensorflow | |||
/// <param name="tensor"></param> | |||
/// <returns></returns> | |||
[DllImport(TensorFlowLibName)] | |||
public static extern unsafe TF_DataType TF_TensorType(IntPtr tensor); | |||
public static extern TF_DataType TF_TensorType(IntPtr tensor); | |||
} | |||
} |
@@ -7,6 +7,7 @@ namespace Tensorflow | |||
{ | |||
/// <summary> | |||
/// C API for TensorFlow. | |||
/// Port from tensorflow\c\c_api.h | |||
/// | |||
/// The API leans towards simplicity and uniformity instead of convenience | |||
/// since most usage will be by language specific wrappers. | |||
@@ -9,17 +9,21 @@ namespace TensorFlowNET.UnitTest | |||
[TestClass] | |||
public class GraphTest | |||
{ | |||
/// <summary> | |||
/// Port from c_api_test.cc | |||
/// `TEST(CAPI, Graph)` | |||
/// </summary> | |||
[TestMethod] | |||
public void Graph() | |||
public void c_api_Graph() | |||
{ | |||
var s = new Status(); | |||
var graph = tf.get_default_graph(); | |||
var graph = new Graph(); | |||
// Make a placeholder operation. | |||
var feed = c_test_util.Placeholder(graph, s); | |||
Assert.AreEqual("feed", feed.name); | |||
Assert.AreEqual("Placeholder", feed.optype); | |||
//Assert.AreEqual("", feed.device); | |||
Assert.AreEqual("", feed.device); | |||
Assert.AreEqual(1, feed.NumOutputs); | |||
Assert.AreEqual(TF_DataType.TF_INT32, feed.OutputType); | |||
Assert.AreEqual(1, feed.OutputListLength); | |||
@@ -30,6 +34,19 @@ namespace TensorFlowNET.UnitTest | |||
AttrValue attr_value = null; | |||
c_test_util.GetAttrValue(feed, "dtype", ref attr_value, s); | |||
Assert.AreEqual(attr_value.Type, DataType.DtInt32); | |||
// Test not found errors in TF_Operation*() query functions. | |||
// Assert.AreEqual(-1, c_api.TF_OperationOutputListLength(feed, "bogus", s)); | |||
// Assert.AreEqual(TF_Code.TF_INVALID_ARGUMENT, s.Code); | |||
// Assert.IsFalse(c_test_util.GetAttrValue(feed, "missing", ref attr_value, s)); | |||
// Assert.AreEqual("Operation 'feed' has no attr named 'missing'.", s.Message); | |||
// Make a constant oper with the scalar "3". | |||
var three = c_test_util.ScalarConst(3, graph, s); | |||
// Add oper. | |||
var add = c_test_util.Add(feed, three, graph, s); | |||
} | |||
} | |||
} |
@@ -10,8 +10,12 @@ namespace TensorFlowNET.UnitTest | |||
[TestClass] | |||
public class OperationsTest | |||
{ | |||
/// <summary> | |||
/// Port from tensorflow\c\c_api_test.cc | |||
/// `TEST(CAPI, GetAllOpList)` | |||
/// </summary> | |||
[TestMethod] | |||
public void GetAllOpList() | |||
public void c_api_GetAllOpList() | |||
{ | |||
var handle = c_api.TF_GetAllOpList(); | |||
var buffer = new Buffer(handle); | |||
@@ -12,8 +12,29 @@ namespace TensorFlowNET.UnitTest | |||
[TestClass] | |||
public class TensorTest | |||
{ | |||
/// <summary> | |||
/// Port from c_api_test.cc | |||
/// `TEST(CAPI, AllocateTensor)` | |||
/// </summary> | |||
[TestMethod] | |||
public void c_api_AllocateTensor() | |||
{ | |||
ulong num_bytes = 6 * sizeof(float); | |||
long[] dims = { 2, 3 }; | |||
Tensor t = c_api.TF_AllocateTensor(TF_DataType.TF_FLOAT, dims, 2, num_bytes); | |||
Assert.AreEqual(TF_DataType.TF_FLOAT, t.dtype); | |||
Assert.AreEqual(2, t.NDims); | |||
Assert.IsTrue(Enumerable.SequenceEqual(dims, t.shape)); | |||
Assert.AreEqual(num_bytes, t.bytesize); | |||
t.Dispose(); | |||
} | |||
/// <summary> | |||
/// Port from c_api_test.cc | |||
/// `TEST(CAPI, Tensor)` | |||
/// </summary> | |||
[TestMethod] | |||
public void NewTensor() | |||
public void c_api_Tensor() | |||
{ | |||
var nd = np.array(1f, 2f, 3f, 4f, 5f, 6f).reshape(2, 3); | |||
@@ -30,46 +51,38 @@ namespace TensorFlowNET.UnitTest | |||
/// <summary> | |||
/// Port from tensorflow\c\c_api_test.cc | |||
/// `TEST(CAPI, SetShape)` | |||
/// </summary> | |||
[TestMethod] | |||
public void SetShape() | |||
public void c_api_SetShape() | |||
{ | |||
var s = new Status(); | |||
var graph = tf.get_default_graph(); | |||
var graph = new Graph(); | |||
var desc = c_api.TF_NewOperation(graph, "Placeholder", ""); | |||
c_api.TF_SetAttrType(desc, "dtype", TF_DataType.TF_FLOAT); | |||
//if (!dims.empty()) | |||
{ | |||
//TF_SetAttrShape(desc, "shape", dims.data(), dims.size()); | |||
} | |||
var op = c_api.TF_FinishOperation(desc, s); | |||
Assert.IsTrue(s.Code == TF_Code.TF_OK); | |||
Assert.IsNotNull(op); | |||
var feed = c_test_util.Placeholder(graph, s); | |||
var feed_out_0 = new TF_Output(feed, 0); | |||
// Fetch the shape, it should be completely unknown. | |||
var feed_out_0 = new TF_Output { oper = op, index = 0 }; | |||
int num_dims = c_api.TF_GraphGetTensorNumDims(graph, feed_out_0, s); | |||
Assert.IsTrue(s.Code == TF_Code.TF_OK); | |||
Assert.AreEqual(-1, num_dims); | |||
// Set the shape to be unknown, expect no change. | |||
c_api.TF_GraphSetTensorShape(graph, feed_out_0, new int[0], -1, s); | |||
c_api.TF_GraphSetTensorShape(graph, feed_out_0, null, -1, s); | |||
Assert.IsTrue(s.Code == TF_Code.TF_OK); | |||
num_dims = c_api.TF_GraphGetTensorNumDims(graph, feed_out_0, s); | |||
Assert.AreEqual(-1, num_dims); | |||
// Set the shape to be 2 x Unknown | |||
var dims = new int[] { 2, -1 }; | |||
long[] dims = { 2, -1 }; | |||
c_api.TF_GraphSetTensorShape(graph, feed_out_0, dims, dims.Length, s); | |||
Assert.IsTrue(s.Code == TF_Code.TF_OK); | |||
num_dims = c_api.TF_GraphGetTensorNumDims(graph, feed_out_0, s); | |||
Assert.AreEqual(2, num_dims); | |||
// Get the dimension vector appropriately. | |||
var returned_dims = new int[dims.Length]; | |||
var returned_dims = new long[dims.Length]; | |||
c_api.TF_GraphGetTensorShape(graph, feed_out_0, returned_dims, num_dims, s); | |||
Assert.IsTrue(s.Code == TF_Code.TF_OK); | |||
Assert.IsTrue(Enumerable.SequenceEqual(dims, returned_dims)); | |||
@@ -77,19 +90,57 @@ namespace TensorFlowNET.UnitTest | |||
// Set to a new valid shape: [2, 3] | |||
dims[1] = 3; | |||
c_api.TF_GraphSetTensorShape(graph, feed_out_0, dims, dims.Length, s); | |||
//Assert.IsTrue(s.Code == TF_Code.TF_OK); | |||
Assert.IsTrue(s.Code == TF_Code.TF_OK); | |||
// Fetch and see that the new value is returned. | |||
c_api.TF_GraphGetTensorShape(graph, feed_out_0, returned_dims, num_dims, s); | |||
//Assert.IsTrue(s.Code == TF_Code.TF_OK); | |||
//Assert.IsTrue(Enumerable.SequenceEqual(dims, returned_dims)); | |||
Assert.IsTrue(s.Code == TF_Code.TF_OK); | |||
Assert.IsTrue(Enumerable.SequenceEqual(dims, returned_dims)); | |||
// Try to set 'unknown' with unknown rank on the shape and see that | |||
// it doesn't change. | |||
c_api.TF_GraphSetTensorShape(graph, feed_out_0, null, -1, s); | |||
Assert.IsTrue(s.Code == TF_Code.TF_OK); | |||
c_api.TF_GraphGetTensorShape(graph, feed_out_0, returned_dims, num_dims, s); | |||
Assert.IsTrue(s.Code == TF_Code.TF_OK); | |||
Assert.AreEqual(2, num_dims); | |||
Assert.AreEqual(2, returned_dims[0]); | |||
Assert.AreEqual(3, returned_dims[1]); | |||
// Try to set 'unknown' with same rank on the shape and see that | |||
// it doesn't change. | |||
dims[0] = -1; | |||
dims[1] = -1; | |||
c_api.TF_GraphSetTensorShape(graph, feed_out_0, dims, 2, s); | |||
Assert.IsTrue(s.Code == TF_Code.TF_OK); | |||
c_api.TF_GraphGetTensorShape(graph, feed_out_0, returned_dims, num_dims, s); | |||
Assert.IsTrue(s.Code == TF_Code.TF_OK); | |||
Assert.AreEqual(2, num_dims); | |||
Assert.AreEqual(2, returned_dims[0]); | |||
Assert.AreEqual(3, returned_dims[1]); | |||
// Try to fetch a shape with the wrong num_dims | |||
c_api.TF_GraphGetTensorShape(graph, feed_out_0, returned_dims, 5, s); | |||
Assert.IsTrue(s.Code == TF_Code.TF_INVALID_ARGUMENT); | |||
// Try to set an invalid shape (cannot change 2x3 to a 2x5). | |||
dims[1] = 5; | |||
c_api.TF_GraphSetTensorShape(graph, feed_out_0, dims, 2, s); | |||
Assert.IsTrue(s.Code == TF_Code.TF_INVALID_ARGUMENT); | |||
// Test for a scalar. | |||
var three = c_test_util.ScalarConst(3, graph, s); | |||
Assert.IsTrue(s.Code == TF_Code.TF_OK); | |||
var three_out_0 = new TF_Output { oper = three }; | |||
var three_out_0 = new TF_Output(three, 0); | |||
num_dims = c_api.TF_GraphGetTensorNumDims(graph, three_out_0, s); | |||
Assert.IsTrue(s.Code == TF_Code.TF_OK); | |||
Assert.AreEqual(0, num_dims); | |||
c_api.TF_GraphGetTensorShape(graph, feed_out_0, null, num_dims, s); | |||
//Assert.IsTrue(s.Code == TF_Code.TF_OK); | |||
graph.Dispose(); | |||
s.Dispose(); | |||
} | |||
} | |||
} |
@@ -7,8 +7,32 @@ using Buffer = Tensorflow.Buffer; | |||
namespace TensorFlowNET.UnitTest | |||
{ | |||
/// <summary> | |||
/// Port from `tensorflow\c\c_test_util.cc` | |||
/// </summary> | |||
public static class c_test_util | |||
{ | |||
public static Operation Add(Operation l, Operation r, Graph graph, Status s, string name = "add") | |||
{ | |||
Operation op = null; | |||
AddOpHelper(l, r, graph, s, name, ref op, true); | |||
return op; | |||
} | |||
public static void AddOpHelper(Operation l, Operation r, Graph graph, Status s, string name, ref Operation op, bool check) | |||
{ | |||
var desc = c_api.TF_NewOperation(graph, "AddN", name); | |||
c_api.TF_AddInputList(desc, new TF_Output[] | |||
{ | |||
new TF_Output(l, 0), | |||
new TF_Output(r, 0), | |||
}, 2); | |||
op = c_api.TF_FinishOperation(desc, s); | |||
s.Check(); | |||
} | |||
public static bool GetAttrValue(Operation oper, string attr_name, ref AttrValue attr_value, Status s) | |||
{ | |||
var buffer = new Buffer(); | |||
@@ -58,7 +82,7 @@ namespace TensorFlowNET.UnitTest | |||
return op; | |||
} | |||
public static Operation ScalarConst(int v, Graph graph, Status s, string name = "Const") | |||
public static Operation ScalarConst(int v, Graph graph, Status s, string name = "scalar") | |||
{ | |||
return Const(new Tensor(v), graph, s, name); | |||
} | |||