Browse Source

TEST(CAPI, Session) port completed.

tags/v0.1.0-Tensor
Oceania2018 6 years ago
parent
commit
4c912dfe6a
2 changed files with 48 additions and 20 deletions
  1. +19
    -18
      test/TensorFlowNET.UnitTest/CSession.cs
  2. +29
    -2
      test/TensorFlowNET.UnitTest/SessionTest.cs

+ 19
- 18
test/TensorFlowNET.UnitTest/CSession.cs View File

@@ -1,6 +1,7 @@
using NumSharp.Core;
using System;
using System.Collections.Generic;
using System.Linq;
using System.Runtime.InteropServices;
using System.Text;
using Tensorflow;
@@ -9,15 +10,16 @@ namespace TensorFlowNET.UnitTest
{
/// <summary>
/// tensorflow\c\c_test_util.cc
/// TEST(CAPI, Session)
/// </summary>
public class CSession
{
private IntPtr session_;

private List<TF_Output> inputs_ = new List<TF_Output>();
private List<IntPtr> input_values_ = new List<IntPtr>();
private List<Tensor> input_values_ = new List<Tensor>();
private List<TF_Output> outputs_ = new List<TF_Output>();
private List<IntPtr> output_values_ = new List<IntPtr>();
private List<Tensor> output_values_ = new List<Tensor>();

private List<IntPtr> targets_ = new List<IntPtr>();

@@ -27,17 +29,13 @@ namespace TensorFlowNET.UnitTest
session_ = new Session(graph, opts, s);
}

public void SetInputs(Dictionary<IntPtr, IntPtr> inputs)
public void SetInputs(Dictionary<Operation, Tensor> inputs)
{
DeleteInputValues();
inputs_.Clear();
foreach (var input in inputs)
{
var i = new TF_Output(input.Key, 0);
var handle = Marshal.AllocHGlobal(Marshal.SizeOf<TF_Output>());
Marshal.StructureToPtr(i, handle, false);
inputs_.Add(i);

inputs_.Add(new TF_Output(input.Key, 0));
input_values_.Add(input.Value);
}
}
@@ -46,7 +44,7 @@ namespace TensorFlowNET.UnitTest
{
for (var i = 0; i < input_values_.Count; ++i)
{
//input_values_[i].Dispose();
input_values_[i].Dispose();
}
input_values_.Clear();
}
@@ -57,10 +55,7 @@ namespace TensorFlowNET.UnitTest
outputs_.Clear();
foreach (var output in outputs)
{
var handle = Marshal.AllocHGlobal(Marshal.SizeOf<TF_Output>());
Marshal.StructureToPtr(new TF_Output(output, 0), handle, true);
outputs_.Add(new TF_Output(output, 0));
handle = Marshal.AllocHGlobal(Marshal.SizeOf<IntPtr>());
output_values_.Add(IntPtr.Zero);
}
}
@@ -69,18 +64,18 @@ namespace TensorFlowNET.UnitTest
{
for (var i = 0; i < output_values_.Count; ++i)
{
//if (output_values_[i] != IntPtr.Zero)
//output_values_[i].Dispose();
if (output_values_[i] != IntPtr.Zero)
output_values_[i].Dispose();
}
output_values_.Clear();
}

public unsafe void Run(Status s)
{
var inputs_ptr = inputs_.ToArray();// inputs_.Count == 0 ? IntPtr.Zero : inputs_[0];
var input_values_ptr = input_values_.ToArray();// inputs_.Count == 0 ? IntPtr.Zero : input_values_[0];
var outputs_ptr = outputs_.ToArray();// outputs_.Count == 0 ? IntPtr.Zero : outputs_[0];
var output_values_ptr = output_values_.ToArray();// output_values_.Count == 0 ? IntPtr.Zero : output_values_[0];
var inputs_ptr = inputs_.ToArray();
var input_values_ptr = input_values_.Select(x => (IntPtr)x).ToArray();
var outputs_ptr = outputs_.ToArray();
var output_values_ptr = output_values_.Select(x => (IntPtr)x).ToArray();
IntPtr targets_ptr = IntPtr.Zero;

c_api.TF_SessionRun(session_, null, inputs_ptr, input_values_ptr, 1,
@@ -97,5 +92,11 @@ namespace TensorFlowNET.UnitTest
{
return output_values_[i];
}

public void CloseAndDelete(Status s)
{
DeleteInputValues();
ResetOutputValues();
}
}
}

+ 29
- 2
test/TensorFlowNET.UnitTest/SessionTest.cs View File

@@ -32,8 +32,8 @@ namespace TensorFlowNET.UnitTest
ASSERT_EQ(TF_Code.TF_OK, s.Code);

// Run the graph.
var inputs = new Dictionary<IntPtr, IntPtr>();
inputs.Add(feed, c_test_util.Int32Tensor(3));
var inputs = new Dictionary<Operation, Tensor>();
inputs.Add(feed, new Tensor(3));
csession.SetInputs(inputs);

var outputs = new List<IntPtr> { add };
@@ -46,6 +46,33 @@ namespace TensorFlowNET.UnitTest
ASSERT_EQ((ulong)sizeof(uint), outTensor.bytesize);
var output_contents = outTensor.Data<int>();
EXPECT_EQ(3 + 2, output_contents[0]);

// Add another operation to the graph.
var neg = c_test_util.Neg(add, graph, s);
ASSERT_EQ(TF_Code.TF_OK, s.Code);

// Run up to the new operation.
inputs = new Dictionary<Operation, Tensor>();
inputs.Add(feed, new Tensor(7));
csession.SetInputs(inputs);
outputs = new List<IntPtr> { neg };
csession.SetOutputs(outputs);
csession.Run(s);
ASSERT_EQ(TF_Code.TF_OK, s.Code);

outTensor = csession.output_tensor(0);
ASSERT_TRUE(outTensor != IntPtr.Zero);
EXPECT_EQ(TF_DataType.TF_INT32, outTensor.dtype);
EXPECT_EQ(0, outTensor.NDims); // scalar
ASSERT_EQ((ulong)sizeof(uint), outTensor.bytesize);
output_contents = outTensor.Data<int>();
EXPECT_EQ(-(7 + 2), output_contents[0]);

// Clean up
csession.CloseAndDelete(s);
ASSERT_EQ(TF_Code.TF_OK, s.Code);
graph.Dispose();
s.Dispose();
}
}
}

Loading…
Cancel
Save