diff --git a/test/TensorFlowNET.UnitTest/CSession.cs b/test/TensorFlowNET.UnitTest/CSession.cs index 10ffdab1..ef6d469e 100644 --- a/test/TensorFlowNET.UnitTest/CSession.cs +++ b/test/TensorFlowNET.UnitTest/CSession.cs @@ -1,6 +1,7 @@ using NumSharp.Core; using System; using System.Collections.Generic; +using System.Linq; using System.Runtime.InteropServices; using System.Text; using Tensorflow; @@ -9,15 +10,16 @@ namespace TensorFlowNET.UnitTest { /// /// tensorflow\c\c_test_util.cc + /// TEST(CAPI, Session) /// public class CSession { private IntPtr session_; private List inputs_ = new List(); - private List input_values_ = new List(); + private List input_values_ = new List(); private List outputs_ = new List(); - private List output_values_ = new List(); + private List output_values_ = new List(); private List targets_ = new List(); @@ -27,17 +29,13 @@ namespace TensorFlowNET.UnitTest session_ = new Session(graph, opts, s); } - public void SetInputs(Dictionary inputs) + public void SetInputs(Dictionary inputs) { DeleteInputValues(); inputs_.Clear(); foreach (var input in inputs) { - var i = new TF_Output(input.Key, 0); - var handle = Marshal.AllocHGlobal(Marshal.SizeOf()); - Marshal.StructureToPtr(i, handle, false); - inputs_.Add(i); - + inputs_.Add(new TF_Output(input.Key, 0)); input_values_.Add(input.Value); } } @@ -46,7 +44,7 @@ namespace TensorFlowNET.UnitTest { for (var i = 0; i < input_values_.Count; ++i) { - //input_values_[i].Dispose(); + input_values_[i].Dispose(); } input_values_.Clear(); } @@ -57,10 +55,7 @@ namespace TensorFlowNET.UnitTest outputs_.Clear(); foreach (var output in outputs) { - var handle = Marshal.AllocHGlobal(Marshal.SizeOf()); - Marshal.StructureToPtr(new TF_Output(output, 0), handle, true); outputs_.Add(new TF_Output(output, 0)); - handle = Marshal.AllocHGlobal(Marshal.SizeOf()); output_values_.Add(IntPtr.Zero); } } @@ -69,18 +64,18 @@ namespace TensorFlowNET.UnitTest { for (var i = 0; i < output_values_.Count; ++i) { - //if (output_values_[i] != IntPtr.Zero) - //output_values_[i].Dispose(); + if (output_values_[i] != IntPtr.Zero) + output_values_[i].Dispose(); } output_values_.Clear(); } public unsafe void Run(Status s) { - var inputs_ptr = inputs_.ToArray();// inputs_.Count == 0 ? IntPtr.Zero : inputs_[0]; - var input_values_ptr = input_values_.ToArray();// inputs_.Count == 0 ? IntPtr.Zero : input_values_[0]; - var outputs_ptr = outputs_.ToArray();// outputs_.Count == 0 ? IntPtr.Zero : outputs_[0]; - var output_values_ptr = output_values_.ToArray();// output_values_.Count == 0 ? IntPtr.Zero : output_values_[0]; + var inputs_ptr = inputs_.ToArray(); + var input_values_ptr = input_values_.Select(x => (IntPtr)x).ToArray(); + var outputs_ptr = outputs_.ToArray(); + var output_values_ptr = output_values_.Select(x => (IntPtr)x).ToArray(); IntPtr targets_ptr = IntPtr.Zero; c_api.TF_SessionRun(session_, null, inputs_ptr, input_values_ptr, 1, @@ -97,5 +92,11 @@ namespace TensorFlowNET.UnitTest { return output_values_[i]; } + + public void CloseAndDelete(Status s) + { + DeleteInputValues(); + ResetOutputValues(); + } } } diff --git a/test/TensorFlowNET.UnitTest/SessionTest.cs b/test/TensorFlowNET.UnitTest/SessionTest.cs index 13440741..3d79d747 100644 --- a/test/TensorFlowNET.UnitTest/SessionTest.cs +++ b/test/TensorFlowNET.UnitTest/SessionTest.cs @@ -32,8 +32,8 @@ namespace TensorFlowNET.UnitTest ASSERT_EQ(TF_Code.TF_OK, s.Code); // Run the graph. - var inputs = new Dictionary(); - inputs.Add(feed, c_test_util.Int32Tensor(3)); + var inputs = new Dictionary(); + inputs.Add(feed, new Tensor(3)); csession.SetInputs(inputs); var outputs = new List { add }; @@ -46,6 +46,33 @@ namespace TensorFlowNET.UnitTest ASSERT_EQ((ulong)sizeof(uint), outTensor.bytesize); var output_contents = outTensor.Data(); EXPECT_EQ(3 + 2, output_contents[0]); + + // Add another operation to the graph. + var neg = c_test_util.Neg(add, graph, s); + ASSERT_EQ(TF_Code.TF_OK, s.Code); + + // Run up to the new operation. + inputs = new Dictionary(); + inputs.Add(feed, new Tensor(7)); + csession.SetInputs(inputs); + outputs = new List { neg }; + csession.SetOutputs(outputs); + csession.Run(s); + ASSERT_EQ(TF_Code.TF_OK, s.Code); + + outTensor = csession.output_tensor(0); + ASSERT_TRUE(outTensor != IntPtr.Zero); + EXPECT_EQ(TF_DataType.TF_INT32, outTensor.dtype); + EXPECT_EQ(0, outTensor.NDims); // scalar + ASSERT_EQ((ulong)sizeof(uint), outTensor.bytesize); + output_contents = outTensor.Data(); + EXPECT_EQ(-(7 + 2), output_contents[0]); + + // Clean up + csession.CloseAndDelete(s); + ASSERT_EQ(TF_Code.TF_OK, s.Code); + graph.Dispose(); + s.Dispose(); } } }