diff --git a/test/TensorFlowNET.UnitTest/CSession.cs b/test/TensorFlowNET.UnitTest/CSession.cs
index 10ffdab1..ef6d469e 100644
--- a/test/TensorFlowNET.UnitTest/CSession.cs
+++ b/test/TensorFlowNET.UnitTest/CSession.cs
@@ -1,6 +1,7 @@
using NumSharp.Core;
using System;
using System.Collections.Generic;
+using System.Linq;
using System.Runtime.InteropServices;
using System.Text;
using Tensorflow;
@@ -9,15 +10,16 @@ namespace TensorFlowNET.UnitTest
{
///
/// tensorflow\c\c_test_util.cc
+ /// TEST(CAPI, Session)
///
public class CSession
{
private IntPtr session_;
private List inputs_ = new List();
- private List input_values_ = new List();
+ private List input_values_ = new List();
private List outputs_ = new List();
- private List output_values_ = new List();
+ private List output_values_ = new List();
private List targets_ = new List();
@@ -27,17 +29,13 @@ namespace TensorFlowNET.UnitTest
session_ = new Session(graph, opts, s);
}
- public void SetInputs(Dictionary inputs)
+ public void SetInputs(Dictionary inputs)
{
DeleteInputValues();
inputs_.Clear();
foreach (var input in inputs)
{
- var i = new TF_Output(input.Key, 0);
- var handle = Marshal.AllocHGlobal(Marshal.SizeOf());
- Marshal.StructureToPtr(i, handle, false);
- inputs_.Add(i);
-
+ inputs_.Add(new TF_Output(input.Key, 0));
input_values_.Add(input.Value);
}
}
@@ -46,7 +44,7 @@ namespace TensorFlowNET.UnitTest
{
for (var i = 0; i < input_values_.Count; ++i)
{
- //input_values_[i].Dispose();
+ input_values_[i].Dispose();
}
input_values_.Clear();
}
@@ -57,10 +55,7 @@ namespace TensorFlowNET.UnitTest
outputs_.Clear();
foreach (var output in outputs)
{
- var handle = Marshal.AllocHGlobal(Marshal.SizeOf());
- Marshal.StructureToPtr(new TF_Output(output, 0), handle, true);
outputs_.Add(new TF_Output(output, 0));
- handle = Marshal.AllocHGlobal(Marshal.SizeOf());
output_values_.Add(IntPtr.Zero);
}
}
@@ -69,18 +64,18 @@ namespace TensorFlowNET.UnitTest
{
for (var i = 0; i < output_values_.Count; ++i)
{
- //if (output_values_[i] != IntPtr.Zero)
- //output_values_[i].Dispose();
+ if (output_values_[i] != IntPtr.Zero)
+ output_values_[i].Dispose();
}
output_values_.Clear();
}
public unsafe void Run(Status s)
{
- var inputs_ptr = inputs_.ToArray();// inputs_.Count == 0 ? IntPtr.Zero : inputs_[0];
- var input_values_ptr = input_values_.ToArray();// inputs_.Count == 0 ? IntPtr.Zero : input_values_[0];
- var outputs_ptr = outputs_.ToArray();// outputs_.Count == 0 ? IntPtr.Zero : outputs_[0];
- var output_values_ptr = output_values_.ToArray();// output_values_.Count == 0 ? IntPtr.Zero : output_values_[0];
+ var inputs_ptr = inputs_.ToArray();
+ var input_values_ptr = input_values_.Select(x => (IntPtr)x).ToArray();
+ var outputs_ptr = outputs_.ToArray();
+ var output_values_ptr = output_values_.Select(x => (IntPtr)x).ToArray();
IntPtr targets_ptr = IntPtr.Zero;
c_api.TF_SessionRun(session_, null, inputs_ptr, input_values_ptr, 1,
@@ -97,5 +92,11 @@ namespace TensorFlowNET.UnitTest
{
return output_values_[i];
}
+
+ public void CloseAndDelete(Status s)
+ {
+ DeleteInputValues();
+ ResetOutputValues();
+ }
}
}
diff --git a/test/TensorFlowNET.UnitTest/SessionTest.cs b/test/TensorFlowNET.UnitTest/SessionTest.cs
index 13440741..3d79d747 100644
--- a/test/TensorFlowNET.UnitTest/SessionTest.cs
+++ b/test/TensorFlowNET.UnitTest/SessionTest.cs
@@ -32,8 +32,8 @@ namespace TensorFlowNET.UnitTest
ASSERT_EQ(TF_Code.TF_OK, s.Code);
// Run the graph.
- var inputs = new Dictionary();
- inputs.Add(feed, c_test_util.Int32Tensor(3));
+ var inputs = new Dictionary();
+ inputs.Add(feed, new Tensor(3));
csession.SetInputs(inputs);
var outputs = new List { add };
@@ -46,6 +46,33 @@ namespace TensorFlowNET.UnitTest
ASSERT_EQ((ulong)sizeof(uint), outTensor.bytesize);
var output_contents = outTensor.Data();
EXPECT_EQ(3 + 2, output_contents[0]);
+
+ // Add another operation to the graph.
+ var neg = c_test_util.Neg(add, graph, s);
+ ASSERT_EQ(TF_Code.TF_OK, s.Code);
+
+ // Run up to the new operation.
+ inputs = new Dictionary();
+ inputs.Add(feed, new Tensor(7));
+ csession.SetInputs(inputs);
+ outputs = new List { neg };
+ csession.SetOutputs(outputs);
+ csession.Run(s);
+ ASSERT_EQ(TF_Code.TF_OK, s.Code);
+
+ outTensor = csession.output_tensor(0);
+ ASSERT_TRUE(outTensor != IntPtr.Zero);
+ EXPECT_EQ(TF_DataType.TF_INT32, outTensor.dtype);
+ EXPECT_EQ(0, outTensor.NDims); // scalar
+ ASSERT_EQ((ulong)sizeof(uint), outTensor.bytesize);
+ output_contents = outTensor.Data();
+ EXPECT_EQ(-(7 + 2), output_contents[0]);
+
+ // Clean up
+ csession.CloseAndDelete(s);
+ ASSERT_EQ(TF_Code.TF_OK, s.Code);
+ graph.Dispose();
+ s.Dispose();
}
}
}