Browse Source

TEST(CAPI, Session) port completed.

tags/v0.1.0-Tensor
Oceania2018 6 years ago
parent
commit
4c912dfe6a
2 changed files with 48 additions and 20 deletions
  1. +19
    -18
      test/TensorFlowNET.UnitTest/CSession.cs
  2. +29
    -2
      test/TensorFlowNET.UnitTest/SessionTest.cs

+ 19
- 18
test/TensorFlowNET.UnitTest/CSession.cs View File

@@ -1,6 +1,7 @@
using NumSharp.Core; using NumSharp.Core;
using System; using System;
using System.Collections.Generic; using System.Collections.Generic;
using System.Linq;
using System.Runtime.InteropServices; using System.Runtime.InteropServices;
using System.Text; using System.Text;
using Tensorflow; using Tensorflow;
@@ -9,15 +10,16 @@ namespace TensorFlowNET.UnitTest
{ {
/// <summary> /// <summary>
/// tensorflow\c\c_test_util.cc /// tensorflow\c\c_test_util.cc
/// TEST(CAPI, Session)
/// </summary> /// </summary>
public class CSession public class CSession
{ {
private IntPtr session_; private IntPtr session_;


private List<TF_Output> inputs_ = new List<TF_Output>(); private List<TF_Output> inputs_ = new List<TF_Output>();
private List<IntPtr> input_values_ = new List<IntPtr>();
private List<Tensor> input_values_ = new List<Tensor>();
private List<TF_Output> outputs_ = new List<TF_Output>(); private List<TF_Output> outputs_ = new List<TF_Output>();
private List<IntPtr> output_values_ = new List<IntPtr>();
private List<Tensor> output_values_ = new List<Tensor>();


private List<IntPtr> targets_ = new List<IntPtr>(); private List<IntPtr> targets_ = new List<IntPtr>();


@@ -27,17 +29,13 @@ namespace TensorFlowNET.UnitTest
session_ = new Session(graph, opts, s); session_ = new Session(graph, opts, s);
} }


public void SetInputs(Dictionary<IntPtr, IntPtr> inputs)
public void SetInputs(Dictionary<Operation, Tensor> inputs)
{ {
DeleteInputValues(); DeleteInputValues();
inputs_.Clear(); inputs_.Clear();
foreach (var input in inputs) foreach (var input in inputs)
{ {
var i = new TF_Output(input.Key, 0);
var handle = Marshal.AllocHGlobal(Marshal.SizeOf<TF_Output>());
Marshal.StructureToPtr(i, handle, false);
inputs_.Add(i);

inputs_.Add(new TF_Output(input.Key, 0));
input_values_.Add(input.Value); input_values_.Add(input.Value);
} }
} }
@@ -46,7 +44,7 @@ namespace TensorFlowNET.UnitTest
{ {
for (var i = 0; i < input_values_.Count; ++i) for (var i = 0; i < input_values_.Count; ++i)
{ {
//input_values_[i].Dispose();
input_values_[i].Dispose();
} }
input_values_.Clear(); input_values_.Clear();
} }
@@ -57,10 +55,7 @@ namespace TensorFlowNET.UnitTest
outputs_.Clear(); outputs_.Clear();
foreach (var output in outputs) foreach (var output in outputs)
{ {
var handle = Marshal.AllocHGlobal(Marshal.SizeOf<TF_Output>());
Marshal.StructureToPtr(new TF_Output(output, 0), handle, true);
outputs_.Add(new TF_Output(output, 0)); outputs_.Add(new TF_Output(output, 0));
handle = Marshal.AllocHGlobal(Marshal.SizeOf<IntPtr>());
output_values_.Add(IntPtr.Zero); output_values_.Add(IntPtr.Zero);
} }
} }
@@ -69,18 +64,18 @@ namespace TensorFlowNET.UnitTest
{ {
for (var i = 0; i < output_values_.Count; ++i) for (var i = 0; i < output_values_.Count; ++i)
{ {
//if (output_values_[i] != IntPtr.Zero)
//output_values_[i].Dispose();
if (output_values_[i] != IntPtr.Zero)
output_values_[i].Dispose();
} }
output_values_.Clear(); output_values_.Clear();
} }


public unsafe void Run(Status s) public unsafe void Run(Status s)
{ {
var inputs_ptr = inputs_.ToArray();// inputs_.Count == 0 ? IntPtr.Zero : inputs_[0];
var input_values_ptr = input_values_.ToArray();// inputs_.Count == 0 ? IntPtr.Zero : input_values_[0];
var outputs_ptr = outputs_.ToArray();// outputs_.Count == 0 ? IntPtr.Zero : outputs_[0];
var output_values_ptr = output_values_.ToArray();// output_values_.Count == 0 ? IntPtr.Zero : output_values_[0];
var inputs_ptr = inputs_.ToArray();
var input_values_ptr = input_values_.Select(x => (IntPtr)x).ToArray();
var outputs_ptr = outputs_.ToArray();
var output_values_ptr = output_values_.Select(x => (IntPtr)x).ToArray();
IntPtr targets_ptr = IntPtr.Zero; IntPtr targets_ptr = IntPtr.Zero;


c_api.TF_SessionRun(session_, null, inputs_ptr, input_values_ptr, 1, c_api.TF_SessionRun(session_, null, inputs_ptr, input_values_ptr, 1,
@@ -97,5 +92,11 @@ namespace TensorFlowNET.UnitTest
{ {
return output_values_[i]; return output_values_[i];
} }

public void CloseAndDelete(Status s)
{
DeleteInputValues();
ResetOutputValues();
}
} }
} }

+ 29
- 2
test/TensorFlowNET.UnitTest/SessionTest.cs View File

@@ -32,8 +32,8 @@ namespace TensorFlowNET.UnitTest
ASSERT_EQ(TF_Code.TF_OK, s.Code); ASSERT_EQ(TF_Code.TF_OK, s.Code);


// Run the graph. // Run the graph.
var inputs = new Dictionary<IntPtr, IntPtr>();
inputs.Add(feed, c_test_util.Int32Tensor(3));
var inputs = new Dictionary<Operation, Tensor>();
inputs.Add(feed, new Tensor(3));
csession.SetInputs(inputs); csession.SetInputs(inputs);


var outputs = new List<IntPtr> { add }; var outputs = new List<IntPtr> { add };
@@ -46,6 +46,33 @@ namespace TensorFlowNET.UnitTest
ASSERT_EQ((ulong)sizeof(uint), outTensor.bytesize); ASSERT_EQ((ulong)sizeof(uint), outTensor.bytesize);
var output_contents = outTensor.Data<int>(); var output_contents = outTensor.Data<int>();
EXPECT_EQ(3 + 2, output_contents[0]); EXPECT_EQ(3 + 2, output_contents[0]);

// Add another operation to the graph.
var neg = c_test_util.Neg(add, graph, s);
ASSERT_EQ(TF_Code.TF_OK, s.Code);

// Run up to the new operation.
inputs = new Dictionary<Operation, Tensor>();
inputs.Add(feed, new Tensor(7));
csession.SetInputs(inputs);
outputs = new List<IntPtr> { neg };
csession.SetOutputs(outputs);
csession.Run(s);
ASSERT_EQ(TF_Code.TF_OK, s.Code);

outTensor = csession.output_tensor(0);
ASSERT_TRUE(outTensor != IntPtr.Zero);
EXPECT_EQ(TF_DataType.TF_INT32, outTensor.dtype);
EXPECT_EQ(0, outTensor.NDims); // scalar
ASSERT_EQ((ulong)sizeof(uint), outTensor.bytesize);
output_contents = outTensor.Data<int>();
EXPECT_EQ(-(7 + 2), output_contents[0]);

// Clean up
csession.CloseAndDelete(s);
ASSERT_EQ(TF_Code.TF_OK, s.Code);
graph.Dispose();
s.Dispose();
} }
} }
} }

Loading…
Cancel
Save