@@ -4,7 +4,7 @@ TensorFlow.NET provides .NET Standard binding for [TensorFlow](https://www.tenso | |||||
[](https://gitter.im/sci-sharp/community) | [](https://gitter.im/sci-sharp/community) | ||||
 |  | ||||
TensorFlow.NET is a member project of SciSharp stack. | |||||
TensorFlow.NET is a member project of [SciSharp](https://github.com/SciSharp) stack. | |||||
 |  | ||||
@@ -45,3 +45,5 @@ using(var sess = tf.Session()) | |||||
var o = sess.run(c, feed_dict); | var o = sess.run(c, feed_dict); | ||||
} | } | ||||
``` | ``` | ||||
Star me or raise issue on [Github](https://github.com/SciSharp/TensorFlow.NET) feel free. |
@@ -13,7 +13,7 @@ namespace Tensorflow | |||||
/// then create a TensorFlow session to run parts of the graph across a set of local and remote devices. | /// then create a TensorFlow session to run parts of the graph across a set of local and remote devices. | ||||
/// https://www.tensorflow.org/guide/graphs | /// https://www.tensorflow.org/guide/graphs | ||||
/// </summary> | /// </summary> | ||||
public class Graph | |||||
public class Graph : IDisposable | |||||
{ | { | ||||
private IntPtr _handle; | private IntPtr _handle; | ||||
private Dictionary<int, Operation> _nodes_by_id; | private Dictionary<int, Operation> _nodes_by_id; | ||||
@@ -25,6 +25,11 @@ namespace Tensorflow | |||||
private string _name_stack; | private string _name_stack; | ||||
public Graph() | |||||
{ | |||||
_handle = c_api.TF_NewGraph(); | |||||
} | |||||
public Graph(IntPtr graph) | public Graph(IntPtr graph) | ||||
{ | { | ||||
_handle = graph; | _handle = graph; | ||||
@@ -171,6 +176,11 @@ namespace Tensorflow | |||||
return _nodes_by_name.Values.Select(x => x).ToArray(); | return _nodes_by_name.Values.Select(x => x).ToArray(); | ||||
} | } | ||||
public void Dispose() | |||||
{ | |||||
c_api.TF_DeleteGraph(_handle); | |||||
} | |||||
public static implicit operator IntPtr(Graph graph) | public static implicit operator IntPtr(Graph graph) | ||||
{ | { | ||||
return graph._handle; | return graph._handle; | ||||
@@ -7,6 +7,14 @@ namespace Tensorflow | |||||
{ | { | ||||
public static partial class c_api | public static partial class c_api | ||||
{ | { | ||||
/// <summary> | |||||
/// Destroy an options object. Graph will be deleted once no more | |||||
/// TFSession's are referencing it. | |||||
/// </summary> | |||||
/// <param name="graph"></param> | |||||
[DllImport(TensorFlowLibName)] | |||||
public static extern void TF_DeleteGraph(IntPtr graph); | |||||
[DllImport(TensorFlowLibName)] | [DllImport(TensorFlowLibName)] | ||||
public static extern void TF_GraphGetOpDef(IntPtr graph, string op_name, IntPtr output_op_def, IntPtr status); | public static extern void TF_GraphGetOpDef(IntPtr graph, string op_name, IntPtr output_op_def, IntPtr status); | ||||
@@ -21,14 +29,14 @@ namespace Tensorflow | |||||
/// <param name="num_dims"></param> | /// <param name="num_dims"></param> | ||||
/// <param name="status"></param> | /// <param name="status"></param> | ||||
[DllImport(TensorFlowLibName)] | [DllImport(TensorFlowLibName)] | ||||
public static extern void TF_GraphGetTensorShape(IntPtr graph, TF_Output output, int[] dims, int num_dims, IntPtr status); | |||||
public static extern void TF_GraphGetTensorShape(IntPtr graph, TF_Output output, long[] dims, int num_dims, IntPtr status); | |||||
/// <summary> | /// <summary> | ||||
/// Sets the shape of the Tensor referenced by `output` in `graph` to | /// Sets the shape of the Tensor referenced by `output` in `graph` to | ||||
/// the shape described by `dims` and `num_dims`. | /// the shape described by `dims` and `num_dims`. | ||||
/// </summary> | /// </summary> | ||||
[DllImport(TensorFlowLibName)] | [DllImport(TensorFlowLibName)] | ||||
public static extern void TF_GraphSetTensorShape(IntPtr graph, TF_Output output, int[] dims, int num_dims, IntPtr status); | |||||
public static extern void TF_GraphSetTensorShape(IntPtr graph, TF_Output output, long[] dims, int num_dims, IntPtr status); | |||||
/// <summary> | /// <summary> | ||||
/// Returns the number of dimensions of the Tensor referenced by `output` | /// Returns the number of dimensions of the Tensor referenced by `output` | ||||
@@ -14,16 +14,16 @@ namespace Tensorflow | |||||
private Status status = new Status(); | private Status status = new Status(); | ||||
public string name => c_api.TF_OperationName(_handle); | |||||
public string optype => c_api.TF_OperationOpType(_handle); | |||||
public string device => c_api.TF_OperationDevice(_handle); | |||||
public int NumOutputs => c_api.TF_OperationNumOutputs(_handle); | |||||
public TF_DataType OutputType => c_api.TF_OperationOutputType(new TF_Output(_handle, 0)); | |||||
public int OutputListLength => c_api.TF_OperationOutputListLength(_handle, "output", status); | |||||
public int NumInputs => c_api.TF_OperationNumInputs(_handle); | |||||
public int NumConsumers => c_api.TF_OperationOutputNumConsumers(new TF_Output(_handle, 0)); | |||||
public int NumControlInputs => c_api.TF_OperationNumControlInputs(_handle); | |||||
public int NumControlOutputs => c_api.TF_OperationNumControlOutputs(_handle); | |||||
public string name { get; } | |||||
public string optype { get; } | |||||
public string device { get; } | |||||
public int NumOutputs { get; } | |||||
public TF_DataType OutputType { get; } | |||||
public int OutputListLength { get; } | |||||
public int NumInputs { get; } | |||||
public int NumConsumers { get; } | |||||
public int NumControlInputs { get; } | |||||
public int NumControlOutputs { get; } | |||||
private Tensor[] _outputs; | private Tensor[] _outputs; | ||||
public Tensor[] outputs => _outputs; | public Tensor[] outputs => _outputs; | ||||
@@ -31,7 +31,21 @@ namespace Tensorflow | |||||
public Operation(IntPtr handle) | public Operation(IntPtr handle) | ||||
{ | { | ||||
if (handle == IntPtr.Zero) | |||||
return; | |||||
_handle = handle; | _handle = handle; | ||||
name = c_api.TF_OperationName(_handle); | |||||
optype = c_api.TF_OperationOpType(_handle); | |||||
device = "";// c_api.TF_OperationDevice(_handle); | |||||
NumOutputs = c_api.TF_OperationNumOutputs(_handle); | |||||
OutputType = c_api.TF_OperationOutputType(new TF_Output(_handle, 0)); | |||||
OutputListLength = c_api.TF_OperationOutputListLength(_handle, "output", status); | |||||
NumInputs = c_api.TF_OperationNumInputs(_handle); | |||||
NumConsumers = c_api.TF_OperationOutputNumConsumers(new TF_Output(_handle, 0)); | |||||
NumControlInputs = c_api.TF_OperationNumControlInputs(_handle); | |||||
NumControlOutputs = c_api.TF_OperationNumControlOutputs(_handle); | |||||
} | } | ||||
public Operation(Graph g, string opType, string oper_name) | public Operation(Graph g, string opType, string oper_name) | ||||
@@ -14,7 +14,7 @@ namespace Tensorflow | |||||
this.index = index; | this.index = index; | ||||
} | } | ||||
public IntPtr oper; | |||||
public unsafe IntPtr oper; | |||||
public int index; | public int index; | ||||
} | } | ||||
} | } |
@@ -22,6 +22,15 @@ namespace Tensorflow | |||||
[DllImport(TensorFlowLibName)] | [DllImport(TensorFlowLibName)] | ||||
public static extern void TF_AddInput(IntPtr desc, TF_Output input); | public static extern void TF_AddInput(IntPtr desc, TF_Output input); | ||||
/// <summary> | |||||
/// For inputs that take a list of tensors. | |||||
/// inputs must point to TF_Output[num_inputs]. | |||||
/// </summary> | |||||
/// <param name="desc"></param> | |||||
/// <param name="inputs"></param> | |||||
[DllImport(TensorFlowLibName)] | |||||
public static extern void TF_AddInputList(IntPtr desc, TF_Output[] inputs, int num_inputs); | |||||
[DllImport(TensorFlowLibName)] | [DllImport(TensorFlowLibName)] | ||||
public static extern IntPtr TF_FinishOperation(IntPtr desc, IntPtr status); | public static extern IntPtr TF_FinishOperation(IntPtr desc, IntPtr status); | ||||
@@ -11,7 +11,7 @@ namespace Tensorflow | |||||
/// A tensor is a generalization of vectors and matrices to potentially higher dimensions. | /// A tensor is a generalization of vectors and matrices to potentially higher dimensions. | ||||
/// Internally, TensorFlow represents tensors as n-dimensional arrays of base datatypes. | /// Internally, TensorFlow represents tensors as n-dimensional arrays of base datatypes. | ||||
/// </summary> | /// </summary> | ||||
public class Tensor | |||||
public class Tensor : IDisposable | |||||
{ | { | ||||
private readonly IntPtr _handle; | private readonly IntPtr _handle; | ||||
@@ -38,6 +38,7 @@ namespace Tensorflow | |||||
/// n n-Tensor (you get the idea) | /// n n-Tensor (you get the idea) | ||||
/// </summary> | /// </summary> | ||||
public int rank; | public int rank; | ||||
public int NDims => rank; | |||||
/// <summary> | /// <summary> | ||||
/// if original buffer is free. | /// if original buffer is free. | ||||
@@ -96,7 +97,7 @@ namespace Tensorflow | |||||
nd.shape.Select(x => (long)x).ToArray(), // shape | nd.shape.Select(x => (long)x).ToArray(), // shape | ||||
nd.ndim, | nd.ndim, | ||||
dotHandle, | dotHandle, | ||||
(UIntPtr)(nd.size * nd.dtypesize), | |||||
(ulong)(nd.size * nd.dtypesize), | |||||
(IntPtr values, IntPtr len, ref bool closure) => | (IntPtr values, IntPtr len, ref bool closure) => | ||||
{ | { | ||||
// Free the original buffer and set flag | // Free the original buffer and set flag | ||||
@@ -160,9 +161,19 @@ namespace Tensorflow | |||||
return TF_DataType.DtInvalid; | return TF_DataType.DtInvalid; | ||||
} | } | ||||
public void Dispose() | |||||
{ | |||||
c_api.TF_DeleteTensor(_handle); | |||||
} | |||||
public static implicit operator IntPtr(Tensor tensor) | public static implicit operator IntPtr(Tensor tensor) | ||||
{ | { | ||||
return tensor._handle; | return tensor._handle; | ||||
} | } | ||||
public static implicit operator Tensor(IntPtr handle) | |||||
{ | |||||
return new Tensor(handle); | |||||
} | |||||
} | } | ||||
} | } |
@@ -7,20 +7,23 @@ namespace Tensorflow | |||||
{ | { | ||||
public static partial class c_api | public static partial class c_api | ||||
{ | { | ||||
[DllImport(TensorFlowLibName)] | |||||
public static extern IntPtr TF_AllocateTensor(TF_DataType dtype, long[] dims, int num_dims, ulong len); | |||||
/// <summary> | /// <summary> | ||||
/// returns the sizeof() for the underlying type corresponding to the given TF_DataType enum value. | /// returns the sizeof() for the underlying type corresponding to the given TF_DataType enum value. | ||||
/// </summary> | /// </summary> | ||||
/// <param name="dt"></param> | /// <param name="dt"></param> | ||||
/// <returns></returns> | /// <returns></returns> | ||||
[DllImport(TensorFlowLibName)] | [DllImport(TensorFlowLibName)] | ||||
public static unsafe extern ulong TF_DataTypeSize(TF_DataType dt); | |||||
public static extern ulong TF_DataTypeSize(TF_DataType dt); | |||||
/// <summary> | /// <summary> | ||||
/// Destroy a tensor. | /// Destroy a tensor. | ||||
/// </summary> | /// </summary> | ||||
/// <param name="tensor"></param> | /// <param name="tensor"></param> | ||||
[DllImport(TensorFlowLibName)] | [DllImport(TensorFlowLibName)] | ||||
public static unsafe extern void TF_DeleteTensor(IntPtr tensor); | |||||
public static extern void TF_DeleteTensor(IntPtr tensor); | |||||
/// <summary> | /// <summary> | ||||
/// Return the length of the tensor in the "dim_index" dimension. | /// Return the length of the tensor in the "dim_index" dimension. | ||||
@@ -30,7 +33,7 @@ namespace Tensorflow | |||||
/// <param name="dim_index"></param> | /// <param name="dim_index"></param> | ||||
/// <returns></returns> | /// <returns></returns> | ||||
[DllImport(TensorFlowLibName)] | [DllImport(TensorFlowLibName)] | ||||
public static extern unsafe long TF_Dim(IntPtr tensor, int dim_index); | |||||
public static extern long TF_Dim(IntPtr tensor, int dim_index); | |||||
/// <summary> | /// <summary> | ||||
/// Return a new tensor that holds the bytes data[0,len-1] | /// Return a new tensor that holds the bytes data[0,len-1] | ||||
@@ -44,7 +47,7 @@ namespace Tensorflow | |||||
/// <param name="deallocator_arg"></param> | /// <param name="deallocator_arg"></param> | ||||
/// <returns></returns> | /// <returns></returns> | ||||
[DllImport(TensorFlowLibName)] | [DllImport(TensorFlowLibName)] | ||||
public static extern unsafe IntPtr TF_NewTensor(TF_DataType dataType, long[] dims, int num_dims, IntPtr data, UIntPtr len, Deallocator deallocator, ref bool deallocator_arg); | |||||
public static extern IntPtr TF_NewTensor(TF_DataType dataType, long[] dims, int num_dims, IntPtr data, ulong len, Deallocator deallocator, ref bool deallocator_arg); | |||||
/// <summary> | /// <summary> | ||||
/// Return the number of dimensions that the tensor has. | /// Return the number of dimensions that the tensor has. | ||||
@@ -52,7 +55,7 @@ namespace Tensorflow | |||||
/// <param name="tensor"></param> | /// <param name="tensor"></param> | ||||
/// <returns></returns> | /// <returns></returns> | ||||
[DllImport(TensorFlowLibName)] | [DllImport(TensorFlowLibName)] | ||||
public static extern unsafe int TF_NumDims(IntPtr tensor); | |||||
public static extern int TF_NumDims(IntPtr tensor); | |||||
/// <summary> | /// <summary> | ||||
/// Return the size of the underlying data in bytes. | /// Return the size of the underlying data in bytes. | ||||
@@ -60,7 +63,7 @@ namespace Tensorflow | |||||
/// <param name="tensor"></param> | /// <param name="tensor"></param> | ||||
/// <returns></returns> | /// <returns></returns> | ||||
[DllImport(TensorFlowLibName)] | [DllImport(TensorFlowLibName)] | ||||
public static extern unsafe ulong TF_TensorByteSize(IntPtr tensor); | |||||
public static extern ulong TF_TensorByteSize(IntPtr tensor); | |||||
/// <summary> | /// <summary> | ||||
/// Return a pointer to the underlying data buffer. | /// Return a pointer to the underlying data buffer. | ||||
@@ -68,7 +71,7 @@ namespace Tensorflow | |||||
/// <param name="tensor"></param> | /// <param name="tensor"></param> | ||||
/// <returns></returns> | /// <returns></returns> | ||||
[DllImport(TensorFlowLibName)] | [DllImport(TensorFlowLibName)] | ||||
public static extern unsafe IntPtr TF_TensorData(IntPtr tensor); | |||||
public static extern IntPtr TF_TensorData(IntPtr tensor); | |||||
/// <summary> | /// <summary> | ||||
/// Return the type of a tensor element. | /// Return the type of a tensor element. | ||||
@@ -76,6 +79,6 @@ namespace Tensorflow | |||||
/// <param name="tensor"></param> | /// <param name="tensor"></param> | ||||
/// <returns></returns> | /// <returns></returns> | ||||
[DllImport(TensorFlowLibName)] | [DllImport(TensorFlowLibName)] | ||||
public static extern unsafe TF_DataType TF_TensorType(IntPtr tensor); | |||||
public static extern TF_DataType TF_TensorType(IntPtr tensor); | |||||
} | } | ||||
} | } |
@@ -7,6 +7,7 @@ namespace Tensorflow | |||||
{ | { | ||||
/// <summary> | /// <summary> | ||||
/// C API for TensorFlow. | /// C API for TensorFlow. | ||||
/// Port from tensorflow\c\c_api.h | |||||
/// | /// | ||||
/// The API leans towards simplicity and uniformity instead of convenience | /// The API leans towards simplicity and uniformity instead of convenience | ||||
/// since most usage will be by language specific wrappers. | /// since most usage will be by language specific wrappers. | ||||
@@ -9,17 +9,21 @@ namespace TensorFlowNET.UnitTest | |||||
[TestClass] | [TestClass] | ||||
public class GraphTest | public class GraphTest | ||||
{ | { | ||||
/// <summary> | |||||
/// Port from c_api_test.cc | |||||
/// `TEST(CAPI, Graph)` | |||||
/// </summary> | |||||
[TestMethod] | [TestMethod] | ||||
public void Graph() | |||||
public void c_api_Graph() | |||||
{ | { | ||||
var s = new Status(); | var s = new Status(); | ||||
var graph = tf.get_default_graph(); | |||||
var graph = new Graph(); | |||||
// Make a placeholder operation. | // Make a placeholder operation. | ||||
var feed = c_test_util.Placeholder(graph, s); | var feed = c_test_util.Placeholder(graph, s); | ||||
Assert.AreEqual("feed", feed.name); | Assert.AreEqual("feed", feed.name); | ||||
Assert.AreEqual("Placeholder", feed.optype); | Assert.AreEqual("Placeholder", feed.optype); | ||||
//Assert.AreEqual("", feed.device); | |||||
Assert.AreEqual("", feed.device); | |||||
Assert.AreEqual(1, feed.NumOutputs); | Assert.AreEqual(1, feed.NumOutputs); | ||||
Assert.AreEqual(TF_DataType.TF_INT32, feed.OutputType); | Assert.AreEqual(TF_DataType.TF_INT32, feed.OutputType); | ||||
Assert.AreEqual(1, feed.OutputListLength); | Assert.AreEqual(1, feed.OutputListLength); | ||||
@@ -30,6 +34,19 @@ namespace TensorFlowNET.UnitTest | |||||
AttrValue attr_value = null; | AttrValue attr_value = null; | ||||
c_test_util.GetAttrValue(feed, "dtype", ref attr_value, s); | c_test_util.GetAttrValue(feed, "dtype", ref attr_value, s); | ||||
Assert.AreEqual(attr_value.Type, DataType.DtInt32); | |||||
// Test not found errors in TF_Operation*() query functions. | |||||
// Assert.AreEqual(-1, c_api.TF_OperationOutputListLength(feed, "bogus", s)); | |||||
// Assert.AreEqual(TF_Code.TF_INVALID_ARGUMENT, s.Code); | |||||
// Assert.IsFalse(c_test_util.GetAttrValue(feed, "missing", ref attr_value, s)); | |||||
// Assert.AreEqual("Operation 'feed' has no attr named 'missing'.", s.Message); | |||||
// Make a constant oper with the scalar "3". | |||||
var three = c_test_util.ScalarConst(3, graph, s); | |||||
// Add oper. | |||||
var add = c_test_util.Add(feed, three, graph, s); | |||||
} | } | ||||
} | } | ||||
} | } |
@@ -10,8 +10,12 @@ namespace TensorFlowNET.UnitTest | |||||
[TestClass] | [TestClass] | ||||
public class OperationsTest | public class OperationsTest | ||||
{ | { | ||||
/// <summary> | |||||
/// Port from tensorflow\c\c_api_test.cc | |||||
/// `TEST(CAPI, GetAllOpList)` | |||||
/// </summary> | |||||
[TestMethod] | [TestMethod] | ||||
public void GetAllOpList() | |||||
public void c_api_GetAllOpList() | |||||
{ | { | ||||
var handle = c_api.TF_GetAllOpList(); | var handle = c_api.TF_GetAllOpList(); | ||||
var buffer = new Buffer(handle); | var buffer = new Buffer(handle); | ||||
@@ -12,8 +12,29 @@ namespace TensorFlowNET.UnitTest | |||||
[TestClass] | [TestClass] | ||||
public class TensorTest | public class TensorTest | ||||
{ | { | ||||
/// <summary> | |||||
/// Port from c_api_test.cc | |||||
/// `TEST(CAPI, AllocateTensor)` | |||||
/// </summary> | |||||
[TestMethod] | |||||
public void c_api_AllocateTensor() | |||||
{ | |||||
ulong num_bytes = 6 * sizeof(float); | |||||
long[] dims = { 2, 3 }; | |||||
Tensor t = c_api.TF_AllocateTensor(TF_DataType.TF_FLOAT, dims, 2, num_bytes); | |||||
Assert.AreEqual(TF_DataType.TF_FLOAT, t.dtype); | |||||
Assert.AreEqual(2, t.NDims); | |||||
Assert.IsTrue(Enumerable.SequenceEqual(dims, t.shape)); | |||||
Assert.AreEqual(num_bytes, t.bytesize); | |||||
t.Dispose(); | |||||
} | |||||
/// <summary> | |||||
/// Port from c_api_test.cc | |||||
/// `TEST(CAPI, Tensor)` | |||||
/// </summary> | |||||
[TestMethod] | [TestMethod] | ||||
public void NewTensor() | |||||
public void c_api_Tensor() | |||||
{ | { | ||||
var nd = np.array(1f, 2f, 3f, 4f, 5f, 6f).reshape(2, 3); | var nd = np.array(1f, 2f, 3f, 4f, 5f, 6f).reshape(2, 3); | ||||
@@ -30,46 +51,38 @@ namespace TensorFlowNET.UnitTest | |||||
/// <summary> | /// <summary> | ||||
/// Port from tensorflow\c\c_api_test.cc | /// Port from tensorflow\c\c_api_test.cc | ||||
/// `TEST(CAPI, SetShape)` | |||||
/// </summary> | /// </summary> | ||||
[TestMethod] | [TestMethod] | ||||
public void SetShape() | |||||
public void c_api_SetShape() | |||||
{ | { | ||||
var s = new Status(); | var s = new Status(); | ||||
var graph = tf.get_default_graph(); | |||||
var graph = new Graph(); | |||||
var desc = c_api.TF_NewOperation(graph, "Placeholder", ""); | |||||
c_api.TF_SetAttrType(desc, "dtype", TF_DataType.TF_FLOAT); | |||||
//if (!dims.empty()) | |||||
{ | |||||
//TF_SetAttrShape(desc, "shape", dims.data(), dims.size()); | |||||
} | |||||
var op = c_api.TF_FinishOperation(desc, s); | |||||
Assert.IsTrue(s.Code == TF_Code.TF_OK); | |||||
Assert.IsNotNull(op); | |||||
var feed = c_test_util.Placeholder(graph, s); | |||||
var feed_out_0 = new TF_Output(feed, 0); | |||||
// Fetch the shape, it should be completely unknown. | // Fetch the shape, it should be completely unknown. | ||||
var feed_out_0 = new TF_Output { oper = op, index = 0 }; | |||||
int num_dims = c_api.TF_GraphGetTensorNumDims(graph, feed_out_0, s); | int num_dims = c_api.TF_GraphGetTensorNumDims(graph, feed_out_0, s); | ||||
Assert.IsTrue(s.Code == TF_Code.TF_OK); | Assert.IsTrue(s.Code == TF_Code.TF_OK); | ||||
Assert.AreEqual(-1, num_dims); | Assert.AreEqual(-1, num_dims); | ||||
// Set the shape to be unknown, expect no change. | // Set the shape to be unknown, expect no change. | ||||
c_api.TF_GraphSetTensorShape(graph, feed_out_0, new int[0], -1, s); | |||||
c_api.TF_GraphSetTensorShape(graph, feed_out_0, null, -1, s); | |||||
Assert.IsTrue(s.Code == TF_Code.TF_OK); | Assert.IsTrue(s.Code == TF_Code.TF_OK); | ||||
num_dims = c_api.TF_GraphGetTensorNumDims(graph, feed_out_0, s); | num_dims = c_api.TF_GraphGetTensorNumDims(graph, feed_out_0, s); | ||||
Assert.AreEqual(-1, num_dims); | Assert.AreEqual(-1, num_dims); | ||||
// Set the shape to be 2 x Unknown | // Set the shape to be 2 x Unknown | ||||
var dims = new int[] { 2, -1 }; | |||||
long[] dims = { 2, -1 }; | |||||
c_api.TF_GraphSetTensorShape(graph, feed_out_0, dims, dims.Length, s); | c_api.TF_GraphSetTensorShape(graph, feed_out_0, dims, dims.Length, s); | ||||
Assert.IsTrue(s.Code == TF_Code.TF_OK); | Assert.IsTrue(s.Code == TF_Code.TF_OK); | ||||
num_dims = c_api.TF_GraphGetTensorNumDims(graph, feed_out_0, s); | num_dims = c_api.TF_GraphGetTensorNumDims(graph, feed_out_0, s); | ||||
Assert.AreEqual(2, num_dims); | Assert.AreEqual(2, num_dims); | ||||
// Get the dimension vector appropriately. | // Get the dimension vector appropriately. | ||||
var returned_dims = new int[dims.Length]; | |||||
var returned_dims = new long[dims.Length]; | |||||
c_api.TF_GraphGetTensorShape(graph, feed_out_0, returned_dims, num_dims, s); | c_api.TF_GraphGetTensorShape(graph, feed_out_0, returned_dims, num_dims, s); | ||||
Assert.IsTrue(s.Code == TF_Code.TF_OK); | Assert.IsTrue(s.Code == TF_Code.TF_OK); | ||||
Assert.IsTrue(Enumerable.SequenceEqual(dims, returned_dims)); | Assert.IsTrue(Enumerable.SequenceEqual(dims, returned_dims)); | ||||
@@ -77,19 +90,57 @@ namespace TensorFlowNET.UnitTest | |||||
// Set to a new valid shape: [2, 3] | // Set to a new valid shape: [2, 3] | ||||
dims[1] = 3; | dims[1] = 3; | ||||
c_api.TF_GraphSetTensorShape(graph, feed_out_0, dims, dims.Length, s); | c_api.TF_GraphSetTensorShape(graph, feed_out_0, dims, dims.Length, s); | ||||
//Assert.IsTrue(s.Code == TF_Code.TF_OK); | |||||
Assert.IsTrue(s.Code == TF_Code.TF_OK); | |||||
// Fetch and see that the new value is returned. | // Fetch and see that the new value is returned. | ||||
c_api.TF_GraphGetTensorShape(graph, feed_out_0, returned_dims, num_dims, s); | c_api.TF_GraphGetTensorShape(graph, feed_out_0, returned_dims, num_dims, s); | ||||
//Assert.IsTrue(s.Code == TF_Code.TF_OK); | |||||
//Assert.IsTrue(Enumerable.SequenceEqual(dims, returned_dims)); | |||||
Assert.IsTrue(s.Code == TF_Code.TF_OK); | |||||
Assert.IsTrue(Enumerable.SequenceEqual(dims, returned_dims)); | |||||
// Try to set 'unknown' with unknown rank on the shape and see that | |||||
// it doesn't change. | |||||
c_api.TF_GraphSetTensorShape(graph, feed_out_0, null, -1, s); | |||||
Assert.IsTrue(s.Code == TF_Code.TF_OK); | |||||
c_api.TF_GraphGetTensorShape(graph, feed_out_0, returned_dims, num_dims, s); | |||||
Assert.IsTrue(s.Code == TF_Code.TF_OK); | |||||
Assert.AreEqual(2, num_dims); | |||||
Assert.AreEqual(2, returned_dims[0]); | |||||
Assert.AreEqual(3, returned_dims[1]); | |||||
// Try to set 'unknown' with same rank on the shape and see that | |||||
// it doesn't change. | |||||
dims[0] = -1; | |||||
dims[1] = -1; | |||||
c_api.TF_GraphSetTensorShape(graph, feed_out_0, dims, 2, s); | |||||
Assert.IsTrue(s.Code == TF_Code.TF_OK); | |||||
c_api.TF_GraphGetTensorShape(graph, feed_out_0, returned_dims, num_dims, s); | |||||
Assert.IsTrue(s.Code == TF_Code.TF_OK); | |||||
Assert.AreEqual(2, num_dims); | |||||
Assert.AreEqual(2, returned_dims[0]); | |||||
Assert.AreEqual(3, returned_dims[1]); | |||||
// Try to fetch a shape with the wrong num_dims | |||||
c_api.TF_GraphGetTensorShape(graph, feed_out_0, returned_dims, 5, s); | |||||
Assert.IsTrue(s.Code == TF_Code.TF_INVALID_ARGUMENT); | |||||
// Try to set an invalid shape (cannot change 2x3 to a 2x5). | |||||
dims[1] = 5; | |||||
c_api.TF_GraphSetTensorShape(graph, feed_out_0, dims, 2, s); | |||||
Assert.IsTrue(s.Code == TF_Code.TF_INVALID_ARGUMENT); | |||||
// Test for a scalar. | // Test for a scalar. | ||||
var three = c_test_util.ScalarConst(3, graph, s); | var three = c_test_util.ScalarConst(3, graph, s); | ||||
Assert.IsTrue(s.Code == TF_Code.TF_OK); | Assert.IsTrue(s.Code == TF_Code.TF_OK); | ||||
var three_out_0 = new TF_Output { oper = three }; | |||||
var three_out_0 = new TF_Output(three, 0); | |||||
num_dims = c_api.TF_GraphGetTensorNumDims(graph, three_out_0, s); | num_dims = c_api.TF_GraphGetTensorNumDims(graph, three_out_0, s); | ||||
Assert.IsTrue(s.Code == TF_Code.TF_OK); | |||||
Assert.AreEqual(0, num_dims); | Assert.AreEqual(0, num_dims); | ||||
c_api.TF_GraphGetTensorShape(graph, feed_out_0, null, num_dims, s); | |||||
//Assert.IsTrue(s.Code == TF_Code.TF_OK); | |||||
graph.Dispose(); | |||||
s.Dispose(); | |||||
} | } | ||||
} | } | ||||
} | } |
@@ -7,8 +7,32 @@ using Buffer = Tensorflow.Buffer; | |||||
namespace TensorFlowNET.UnitTest | namespace TensorFlowNET.UnitTest | ||||
{ | { | ||||
/// <summary> | |||||
/// Port from `tensorflow\c\c_test_util.cc` | |||||
/// </summary> | |||||
public static class c_test_util | public static class c_test_util | ||||
{ | { | ||||
public static Operation Add(Operation l, Operation r, Graph graph, Status s, string name = "add") | |||||
{ | |||||
Operation op = null; | |||||
AddOpHelper(l, r, graph, s, name, ref op, true); | |||||
return op; | |||||
} | |||||
public static void AddOpHelper(Operation l, Operation r, Graph graph, Status s, string name, ref Operation op, bool check) | |||||
{ | |||||
var desc = c_api.TF_NewOperation(graph, "AddN", name); | |||||
c_api.TF_AddInputList(desc, new TF_Output[] | |||||
{ | |||||
new TF_Output(l, 0), | |||||
new TF_Output(r, 0), | |||||
}, 2); | |||||
op = c_api.TF_FinishOperation(desc, s); | |||||
s.Check(); | |||||
} | |||||
public static bool GetAttrValue(Operation oper, string attr_name, ref AttrValue attr_value, Status s) | public static bool GetAttrValue(Operation oper, string attr_name, ref AttrValue attr_value, Status s) | ||||
{ | { | ||||
var buffer = new Buffer(); | var buffer = new Buffer(); | ||||
@@ -58,7 +82,7 @@ namespace TensorFlowNET.UnitTest | |||||
return op; | return op; | ||||
} | } | ||||
public static Operation ScalarConst(int v, Graph graph, Status s, string name = "Const") | |||||
public static Operation ScalarConst(int v, Graph graph, Status s, string name = "scalar") | |||||
{ | { | ||||
return Const(new Tensor(v), graph, s, name); | return Const(new Tensor(v), graph, s, name); | ||||
} | } | ||||