diff --git a/README.md b/README.md index f0c3ed6c..99afcc93 100644 --- a/README.md +++ b/README.md @@ -2,7 +2,7 @@ TensorFlow.NET provides .NET Standard binding for [TensorFlow](https://www.tensorflow.org/). [![Join the chat at https://gitter.im/publiclab/publiclab](https://badges.gitter.im/Join%20Chat.svg)](https://gitter.im/sci-sharp/community) -[![Tensorflow.NET](https://ci.appveyor.com/api/projects/status/tensorflow-net-p7kmsjyo10ey?svg=true)](https://ci.appveyor.com/project/Haiping-Chen/tensorflow-net) +[![Tensorflow.NET](https://ci.appveyor.com/api/projects/status/wx4td43v2d3f2xj6?svg=true)](https://ci.appveyor.com/project/Haiping-Chen/tensorflow-net) [![codecov](https://codecov.io/gh/SciSharp/NumSharp/branch/master/graph/badge.svg)](https://codecov.io/gh/SciSharp/NumSharp) [![NuGet](https://img.shields.io/nuget/dt/TensorFlow.NET.svg)](https://www.nuget.org/packages/TensorFlow.NET) diff --git a/src/TensorFlowNET.Core/Graphs/Graph.Import.cs b/src/TensorFlowNET.Core/Graphs/Graph.Import.cs index b22bad4b..f17551a1 100644 --- a/src/TensorFlowNET.Core/Graphs/Graph.Import.cs +++ b/src/TensorFlowNET.Core/Graphs/Graph.Import.cs @@ -12,13 +12,13 @@ namespace Tensorflow { var num_return_outputs = opts.NumReturnOutputs; var return_outputs = new TF_Output[num_return_outputs]; - TF_Output* return_output_handle = (TF_Output*)Marshal.AllocHGlobal(Marshal.SizeOf() * 2); + int size = Marshal.SizeOf(); + TF_Output* return_output_handle = (TF_Output*)Marshal.AllocHGlobal(size * num_return_outputs); c_api.TF_GraphImportGraphDefWithReturnOutputs(_handle, graph_def, opts, return_output_handle, num_return_outputs, s); for (int i = 0; i < num_return_outputs; i++) { - var handle = return_output_handle + i * Marshal.SizeOf(); - + var handle = return_output_handle + i * size; return_outputs[i] = new TF_Output((*handle).oper, (*handle).index); } diff --git a/src/TensorFlowNET.Core/Operations/Operation.cs b/src/TensorFlowNET.Core/Operations/Operation.cs index d9b91c8d..69b4e13c 100644 --- a/src/TensorFlowNET.Core/Operations/Operation.cs +++ b/src/TensorFlowNET.Core/Operations/Operation.cs @@ -140,15 +140,8 @@ namespace Tensorflow return ret; } - public static implicit operator Operation(IntPtr handle) - { - return new Operation(handle); - } - - public static implicit operator IntPtr(Operation op) - { - return op._handle; - } + public static implicit operator Operation(IntPtr handle) => new Operation(handle); + public static implicit operator IntPtr(Operation op) => op._handle; public override bool Equals(object obj) { diff --git a/src/TensorFlowNET.Core/Operations/TF_Output.cs b/src/TensorFlowNET.Core/Operations/TF_Output.cs index 76edd6bb..16d0285a 100644 --- a/src/TensorFlowNET.Core/Operations/TF_Output.cs +++ b/src/TensorFlowNET.Core/Operations/TF_Output.cs @@ -14,7 +14,7 @@ namespace Tensorflow this.index = index; } - public unsafe IntPtr oper; + public IntPtr oper; public int index; } } diff --git a/src/TensorFlowNET.Core/Sessions/BaseSession.cs b/src/TensorFlowNET.Core/Sessions/BaseSession.cs index 3c8a8112..ab17c7f1 100644 --- a/src/TensorFlowNET.Core/Sessions/BaseSession.cs +++ b/src/TensorFlowNET.Core/Sessions/BaseSession.cs @@ -28,11 +28,11 @@ namespace Tensorflow } _target = UTF8Encoding.UTF8.GetBytes(target); - var opts = c_api.TF_NewSessionOptions(); - var status = new Status(); - _session = c_api.TF_NewSession(_graph, opts, status); + //var opts = c_api.TF_NewSessionOptions(); + //var status = new Status(); + //_session = c_api.TF_NewSession(_graph, opts, status); - c_api.TF_DeleteSessionOptions(opts); + //c_api.TF_DeleteSessionOptions(opts); } public void Dispose() @@ -102,7 +102,7 @@ namespace Tensorflow var output_values = fetch_list.Select(x => IntPtr.Zero).ToArray(); - c_api.TF_SessionRun(_session, + /*c_api.TF_SessionRun(_session, run_options: IntPtr.Zero, inputs: feed_dict.Select(f => f.Key).ToArray(), input_values: new IntPtr[] { }, @@ -113,7 +113,7 @@ namespace Tensorflow target_opers: new IntPtr[] { }, ntargets: 0, run_metadata: IntPtr.Zero, - status: status); + status: status);*/ var result = output_values.Select(x => c_api.TF_TensorData(x)) .Select(x => (object)*(float*)x) diff --git a/src/TensorFlowNET.Core/Sessions/Session.cs b/src/TensorFlowNET.Core/Sessions/Session.cs index 20c31f1f..936cf99b 100644 --- a/src/TensorFlowNET.Core/Sessions/Session.cs +++ b/src/TensorFlowNET.Core/Sessions/Session.cs @@ -6,5 +6,19 @@ namespace Tensorflow { public class Session : BaseSession { + private IntPtr _handle; + + public Session(IntPtr handle) + { + _handle = handle; + } + + public Session(Graph graph, SessionOptions opts, Status s) + { + _handle = c_api.TF_NewSession(graph, opts, s); + } + + public static implicit operator IntPtr(Session session) => session._handle; + public static implicit operator Session(IntPtr handle) => new Session(handle); } } diff --git a/src/TensorFlowNET.Core/Sessions/SessionOptions.cs b/src/TensorFlowNET.Core/Sessions/SessionOptions.cs new file mode 100644 index 00000000..72ae890c --- /dev/null +++ b/src/TensorFlowNET.Core/Sessions/SessionOptions.cs @@ -0,0 +1,31 @@ +using System; +using System.Collections.Generic; +using System.Runtime.InteropServices; +using System.Text; + +namespace Tensorflow +{ + public class SessionOptions : IDisposable + { + private IntPtr _handle; + + public unsafe SessionOptions() + { + var opts = c_api.TF_NewSessionOptions(); + _handle = opts; + } + + public unsafe SessionOptions(IntPtr handle) + { + _handle = handle; + } + + public void Dispose() + { + c_api.TF_DeleteSessionOptions(_handle); + } + + public static implicit operator IntPtr(SessionOptions opts) => opts._handle; + public static implicit operator SessionOptions(IntPtr handle) => new SessionOptions(handle); + } +} diff --git a/src/TensorFlowNET.Core/Sessions/TF_SessionOptions.cs b/src/TensorFlowNET.Core/Sessions/TF_SessionOptions.cs new file mode 100644 index 00000000..8d8e1d6a --- /dev/null +++ b/src/TensorFlowNET.Core/Sessions/TF_SessionOptions.cs @@ -0,0 +1,13 @@ +using System; +using System.Collections.Generic; +using System.Runtime.InteropServices; +using System.Text; + +namespace Tensorflow +{ + [StructLayout(LayoutKind.Sequential)] + public struct TF_SessionOptions + { + public IntPtr options; + } +} diff --git a/src/TensorFlowNET.Core/Sessions/c_api.session.cs b/src/TensorFlowNET.Core/Sessions/c_api.session.cs index a3d07826..29118088 100644 --- a/src/TensorFlowNET.Core/Sessions/c_api.session.cs +++ b/src/TensorFlowNET.Core/Sessions/c_api.session.cs @@ -10,7 +10,7 @@ namespace Tensorflow /// /// Destroy an options object. /// - /// + /// TF_SessionOptions* [DllImport(TensorFlowLibName)] public static unsafe extern void TF_DeleteSessionOptions(IntPtr opts); @@ -18,41 +18,63 @@ namespace Tensorflow /// Return a new execution session with the associated graph, or NULL on /// error. Does not take ownership of any input parameters. /// - /// - /// - /// - /// + /// TF_Graph* + /// const TF_SessionOptions* + /// TF_Status* + /// TF_Session* [DllImport(TensorFlowLibName)] public static extern IntPtr TF_NewSession(IntPtr graph, IntPtr opts, IntPtr status); /// /// Return a new options object. /// - /// + /// TF_SessionOptions* [DllImport(TensorFlowLibName)] - public static extern IntPtr TF_NewSessionOptions(); + public static extern unsafe IntPtr TF_NewSessionOptions(); /// /// Run the graph associated with the session starting with the supplied inputs /// (inputs[0,ninputs-1] with corresponding values in input_values[0,ninputs-1]). + /// + /// Any NULL and non-NULL value combinations for (`run_options`, + /// `run_metadata`) are valid. + /// + /// - `run_options` may be NULL, in which case it will be ignored; or + /// non-NULL, in which case it must point to a `TF_Buffer` containing the + /// serialized representation of a `RunOptions` protocol buffer. + /// - `run_metadata` may be NULL, in which case it will be ignored; or + /// non-NULL, in which case it must point to an empty, freshly allocated + /// `TF_Buffer` that may be updated to contain the serialized representation + /// of a `RunMetadata` protocol buffer. + /// + /// The caller retains ownership of `input_values` (which can be deleted using + /// TF_DeleteTensor). The caller also retains ownership of `run_options` and/or + /// `run_metadata` (when not NULL) and should manually call TF_DeleteBuffer on + /// them. + /// + /// On success, the tensors corresponding to outputs[0,noutputs-1] are placed in + /// output_values[]. Ownership of the elements of output_values[] is transferred + /// to the caller, which must eventually call TF_DeleteTensor on them. + /// + /// On failure, output_values[] contains NULLs. /// - /// - /// - /// TF_Output - /// TF_Tensor - /// - /// - /// - /// - /// - /// - /// - /// + /// TF_Session* + /// const TF_Buffer* + /// const TF_Output* + /// TF_Tensor* const* + /// int + /// const TF_Output* + /// TF_Tensor** + /// int + /// const TF_Operation* const* + /// int + /// TF_Buffer* + /// TF_Status* [DllImport(TensorFlowLibName)] - public static extern unsafe void TF_SessionRun(IntPtr session, IntPtr run_options, - TF_Output[] inputs, IntPtr[] input_values, int ninputs, - TF_Output[] outputs, IntPtr[] output_values, int noutputs, - IntPtr[] target_opers, int ntargets, + public static extern unsafe void TF_SessionRun(IntPtr session, TF_Buffer* run_options, + IntPtr inputs, IntPtr input_values, int ninputs, + IntPtr outputs, ref IntPtr output_values, int noutputs, + IntPtr target_opers, int ntargets, IntPtr run_metadata, IntPtr status); } diff --git a/src/TensorFlowNET.Core/Tensors/TF_Tensor.cs b/src/TensorFlowNET.Core/Tensors/TF_Tensor.cs new file mode 100644 index 00000000..3332064a --- /dev/null +++ b/src/TensorFlowNET.Core/Tensors/TF_Tensor.cs @@ -0,0 +1,15 @@ +using System; +using System.Collections.Generic; +using System.Runtime.InteropServices; +using System.Text; + +namespace Tensorflow +{ + [StructLayout(LayoutKind.Sequential)] + public struct TF_Tensor + { + public TF_DataType dtype; + public IntPtr shape; + public IntPtr buffer; + } +} diff --git a/src/TensorFlowNET.Core/Tensors/Tensor.cs b/src/TensorFlowNET.Core/Tensors/Tensor.cs index c3a634da..54813168 100644 --- a/src/TensorFlowNET.Core/Tensors/Tensor.cs +++ b/src/TensorFlowNET.Core/Tensors/Tensor.cs @@ -22,12 +22,22 @@ namespace Tensorflow public object value; public int value_index { get; } - public TF_DataType dtype { get; } - public ulong bytesize { get; } - public ulong dataTypeSize { get;} - public ulong size => bytesize / dataTypeSize; - public IntPtr buffer { get; } - public long[] shape { get; } + public TF_DataType dtype => _handle == IntPtr.Zero ? TF_DataType.DtInvalid : c_api.TF_TensorType(_handle); + public ulong bytesize => _handle == IntPtr.Zero ? 0 : c_api.TF_TensorByteSize(_handle); + public ulong dataTypeSize => _handle == IntPtr.Zero ? 0 : c_api.TF_DataTypeSize(dtype); + public ulong size => _handle == IntPtr.Zero ? 0 : bytesize / dataTypeSize; + public IntPtr buffer => _handle == IntPtr.Zero ? IntPtr.Zero : c_api.TF_TensorData(_handle); + public long[] shape + { + get + { + var dims = new long[rank]; + for (int i = 0; i < rank; i++) + shape[i] = c_api.TF_Dim(_handle, i); + + return dims; + } + } /// /// number of dimensions @@ -37,7 +47,7 @@ namespace Tensorflow /// 3 3-Tensor (cube of numbers) /// n n-Tensor (you get the idea) /// - public int rank; + public int rank => _handle == IntPtr.Zero ? 0 : c_api.TF_NumDims(_handle); public int NDims => rank; /// @@ -48,29 +58,12 @@ namespace Tensorflow public Tensor(IntPtr handle) { _handle = handle; - dtype = c_api.TF_TensorType(handle); - rank = c_api.TF_NumDims(handle); - bytesize = c_api.TF_TensorByteSize(handle); - buffer = c_api.TF_TensorData(handle); - dataTypeSize = c_api.TF_DataTypeSize(dtype); - - shape = new long[rank]; - for (int i = 0; i < rank; i++) - shape[i] = c_api.TF_Dim(handle, i); } public Tensor(NDArray nd) { _handle = Allocate(nd); - dtype = c_api.TF_TensorType(_handle); - rank = c_api.TF_NumDims(_handle); - bytesize = c_api.TF_TensorByteSize(_handle); - buffer = c_api.TF_TensorData(_handle); - dataTypeSize = c_api.TF_DataTypeSize(dtype); - - shape = new long[rank]; - for (int i = 0; i < rank; i++) - shape[i] = c_api.TF_Dim(_handle, i); + value = nd.Data(); } private IntPtr Allocate(NDArray nd) @@ -113,7 +106,6 @@ namespace Tensorflow { this.op = op; this.value_index = value_index; - this.dtype = dtype; } public TF_Output _as_tf_output() @@ -146,6 +138,12 @@ namespace Tensorflow return data; } + public Tensor MaybeMove() + { + var tensor = c_api.TF_TensorMaybeMove(_handle); + return tensor; + } + public TF_DataType ToTFDataType(Type type) { switch (type.Name) diff --git a/src/TensorFlowNET.Core/Tensors/c_api.tensor.cs b/src/TensorFlowNET.Core/Tensors/c_api.tensor.cs index 741899d9..872bd537 100644 --- a/src/TensorFlowNET.Core/Tensors/c_api.tensor.cs +++ b/src/TensorFlowNET.Core/Tensors/c_api.tensor.cs @@ -73,6 +73,15 @@ namespace Tensorflow [DllImport(TensorFlowLibName)] public static extern IntPtr TF_TensorData(IntPtr tensor); + /// + /// Deletes `tensor` and returns a new TF_Tensor with the same content if + /// possible. Returns nullptr and leaves `tensor` untouched if not. + /// + /// + /// + [DllImport(TensorFlowLibName)] + public static extern IntPtr TF_TensorMaybeMove(IntPtr tensor); + /// /// Return the type of a tensor element. /// diff --git a/src/TensorFlowNET.Core/tf.cs b/src/TensorFlowNET.Core/tf.cs index 76bbd28e..8c643597 100644 --- a/src/TensorFlowNET.Core/tf.cs +++ b/src/TensorFlowNET.Core/tf.cs @@ -51,7 +51,7 @@ namespace Tensorflow public static Session Session() { - return new Session(); + return (Session)new BaseSession(); } } } diff --git a/test/TensorFlowNET.UnitTest/CSession.cs b/test/TensorFlowNET.UnitTest/CSession.cs new file mode 100644 index 00000000..6678dcf4 --- /dev/null +++ b/test/TensorFlowNET.UnitTest/CSession.cs @@ -0,0 +1,100 @@ +using NumSharp.Core; +using System; +using System.Collections.Generic; +using System.Runtime.InteropServices; +using System.Text; +using Tensorflow; + +namespace TensorFlowNET.UnitTest +{ + /// + /// tensorflow\c\c_test_util.cc + /// + public class CSession + { + private IntPtr session_; + + private List inputs_ = new List(); + private List input_values_ = new List(); + private List outputs_ = new List(); + private List output_values_ = new List(); + + private List targets_ = new List(); + + public CSession(Graph graph, Status s, bool user_XLA = false) + { + var opts = new SessionOptions(); + session_ = new Session(graph, opts, s); + } + + public void SetInputs(Dictionary inputs) + { + DeleteInputValues(); + inputs_.Clear(); + foreach (var input in inputs) + { + var handle = Marshal.AllocHGlobal(Marshal.SizeOf()); + Marshal.StructureToPtr(new TF_Output(input.Key, 0), handle, false); + inputs_.Add(handle); + + input_values_.Add(input.Value); + } + } + + private void DeleteInputValues() + { + for (var i = 0; i < input_values_.Count; ++i) + { + //input_values_[i].Dispose(); + } + input_values_.Clear(); + } + + public void SetOutputs(List outputs) + { + ResetOutputValues(); + outputs_.Clear(); + foreach (var output in outputs) + { + var handle = Marshal.AllocHGlobal(Marshal.SizeOf()); + Marshal.StructureToPtr(new TF_Output(output, 0), handle, true); + outputs_.Add(handle); + handle = Marshal.AllocHGlobal(Marshal.SizeOf()); + output_values_.Add(IntPtr.Zero); + } + } + + private void ResetOutputValues() + { + for (var i = 0; i < output_values_.Count; ++i) + { + //if (output_values_[i] != IntPtr.Zero) + //output_values_[i].Dispose(); + } + output_values_.Clear(); + } + + public unsafe void Run(Status s) + { + IntPtr inputs_ptr = inputs_.Count == 0 ? IntPtr.Zero : inputs_[0]; + IntPtr input_values_ptr = inputs_.Count == 0 ? IntPtr.Zero : input_values_[0]; + IntPtr outputs_ptr = outputs_.Count == 0 ? IntPtr.Zero : outputs_[0]; + IntPtr output_values_ptr = output_values_.Count == 0 ? IntPtr.Zero : output_values_[0]; + IntPtr targets_ptr = IntPtr.Zero; + + c_api.TF_SessionRun(session_, null, inputs_ptr, input_values_ptr, inputs_.Count, + outputs_ptr, ref output_values_ptr, outputs_.Count, + targets_ptr, targets_.Count, + IntPtr.Zero, s); + + s.Check(); + + output_values_[0] = output_values_ptr; + } + + public IntPtr output_tensor(int i) + { + return output_values_[i]; + } + } +} diff --git a/test/TensorFlowNET.UnitTest/GraphTest.cs b/test/TensorFlowNET.UnitTest/GraphTest.cs index 00293c2f..30066676 100644 --- a/test/TensorFlowNET.UnitTest/GraphTest.cs +++ b/test/TensorFlowNET.UnitTest/GraphTest.cs @@ -16,7 +16,7 @@ namespace TensorFlowNET.UnitTest /// `TEST(CAPI, Graph)` /// [TestMethod] - public void c_api_Graph() + public void Graph() { var s = new Status(); var graph = new Graph(); @@ -205,7 +205,7 @@ namespace TensorFlowNET.UnitTest /// `TEST(CAPI, ImportGraphDef)` /// [TestMethod] - public void c_api_ImportGraphDef() + public void ImportGraphDef() { var s = new Status(); var graph = new Graph(); @@ -362,7 +362,7 @@ namespace TensorFlowNET.UnitTest /// `TEST(CAPI, ImportGraphDef_WithReturnOutputs)` /// [TestMethod] - public void c_api_ImportGraphDef_WithReturnOutputs() + public void ImportGraphDef_WithReturnOutputs() { var s = new Status(); var graph = new Graph(); @@ -407,5 +407,16 @@ namespace TensorFlowNET.UnitTest graph.Dispose(); s.Dispose(); } + + /// + /// `TEST(CAPI, ImportGraphDef_MissingUnusedInputMappings)` + /// + [TestMethod] + public void ImportGraphDef_MissingUnusedInputMappings() + { + + } + + } } diff --git a/test/TensorFlowNET.UnitTest/OperationsTest.cs b/test/TensorFlowNET.UnitTest/OperationsTest.cs index 349d82b3..bfd1d5d8 100644 --- a/test/TensorFlowNET.UnitTest/OperationsTest.cs +++ b/test/TensorFlowNET.UnitTest/OperationsTest.cs @@ -15,7 +15,7 @@ namespace TensorFlowNET.UnitTest /// `TEST(CAPI, GetAllOpList)` /// [TestMethod] - public void c_api_GetAllOpList() + public void GetAllOpList() { var handle = c_api.TF_GetAllOpList(); var buffer = new Buffer(handle); diff --git a/test/TensorFlowNET.UnitTest/SessionTest.cs b/test/TensorFlowNET.UnitTest/SessionTest.cs new file mode 100644 index 00000000..5ac87608 --- /dev/null +++ b/test/TensorFlowNET.UnitTest/SessionTest.cs @@ -0,0 +1,51 @@ +using Microsoft.VisualStudio.TestTools.UnitTesting; +using System; +using System.Collections.Generic; +using System.Text; +using Tensorflow; + +namespace TensorFlowNET.UnitTest +{ + [TestClass] + public class SessionTest : CApiTest + { + /// + /// tensorflow\c\c_api_test.cc + /// `TEST(CAPI, Session)` + /// + [TestMethod] + public void Session() + { + var s = new Status(); + var graph = new Graph(); + + // Make a placeholder operation. + var feed = c_test_util.ScalarConst(3, graph, s, "scalar1"); //c_test_util.Placeholder(graph, s); + + // Make a constant operation with the scalar "2". + var two = c_test_util.ScalarConst(2, graph, s, "scalar2"); + + // Add operation. + var add = c_test_util.Add(feed, two, graph, s); + + var csession = new CSession(graph, s); + ASSERT_EQ(TF_Code.TF_OK, s.Code); + + // Run the graph. + var inputs = new Dictionary(); + inputs.Add(feed, c_test_util.Int32Tensor(3)); + //csession.SetInputs(inputs); + + var outputs = new List { add }; + csession.SetOutputs(outputs); + + csession.Run(s); + Tensor outTensor = csession.output_tensor(0); + EXPECT_EQ(TF_DataType.TF_INT32, outTensor.dtype); + EXPECT_EQ(0, outTensor.NDims); + ASSERT_EQ((ulong)sizeof(uint), outTensor.bytesize); + var output_contents = outTensor.Data(); + EXPECT_EQ(3 + 2, output_contents[0]); + } + } +} diff --git a/test/TensorFlowNET.UnitTest/TensorTest.cs b/test/TensorFlowNET.UnitTest/TensorTest.cs index 65bab728..aac36d06 100644 --- a/test/TensorFlowNET.UnitTest/TensorTest.cs +++ b/test/TensorFlowNET.UnitTest/TensorTest.cs @@ -17,7 +17,7 @@ namespace TensorFlowNET.UnitTest /// `TEST(CAPI, AllocateTensor)` /// [TestMethod] - public void c_api_AllocateTensor() + public void AllocateTensor() { ulong num_bytes = 6 * sizeof(float); long[] dims = { 2, 3 }; @@ -29,12 +29,26 @@ namespace TensorFlowNET.UnitTest t.Dispose(); } + /// + /// Port from c_api_test.cc + /// `TEST(CAPI, MaybeMove)` + /// + [TestMethod] + public void MaybeMove() + { + NDArray nd = np.array(2, 3); + Tensor t = new Tensor(nd); + Tensor o = t.MaybeMove(); + ASSERT_TRUE(o == IntPtr.Zero); // It is unsafe to move memory TF might not own. + t.Dispose(); + } + /// /// Port from c_api_test.cc /// `TEST(CAPI, Tensor)` /// [TestMethod] - public void c_api_Tensor() + public void Tensor() { var nd = np.array(1f, 2f, 3f, 4f, 5f, 6f).reshape(2, 3); @@ -54,7 +68,7 @@ namespace TensorFlowNET.UnitTest /// `TEST(CAPI, SetShape)` /// [TestMethod] - public void c_api_SetShape() + public void SetShape() { var s = new Status(); var graph = new Graph(); diff --git a/test/TensorFlowNET.UnitTest/c_test_util.cs b/test/TensorFlowNET.UnitTest/c_test_util.cs index 128e50df..c242299b 100644 --- a/test/TensorFlowNET.UnitTest/c_test_util.cs +++ b/test/TensorFlowNET.UnitTest/c_test_util.cs @@ -236,5 +236,19 @@ namespace TensorFlowNET.UnitTest { return Const(new Tensor(v), graph, s, name); } + + public static unsafe IntPtr Int32Tensor(int v) + { + bool deallocator_called = false; + const int num_bytes = sizeof(int); + var dotHandle = Marshal.AllocHGlobal(num_bytes * 1); + *(int*)dotHandle = v; + return c_api.TF_NewTensor(TF_DataType.TF_INT32, new long[0], 0, dotHandle, num_bytes, + (IntPtr values, IntPtr len, ref bool closure) => + { + // Free the original buffer and set flag + Marshal.FreeHGlobal(dotHandle); + }, ref deallocator_called); + } } }