using NumSharp.Core;
using System;
using System.Collections.Generic;
using System.Runtime.InteropServices;
using System.Text;
using Tensorflow;
namespace TensorFlowNET.UnitTest
{
/// <summary>
/// Test helper that runs a TensorFlow session through the C API
/// (port of tensorflow\c\c_test_util.cc).
/// </summary>
public class CSession
{
    private readonly IntPtr session_;

    // inputs_ / outputs_ hold one pointer per TF_Output element. All elements of a list
    // live in ONE contiguous unmanaged block starting at list[0], because TF_SessionRun
    // reads its `inputs` / `outputs` arguments as C arrays. The block is owned by this
    // object and released by FreeOutputArray.
    private readonly List<IntPtr> inputs_ = new List<IntPtr>();
    private readonly List<IntPtr> input_values_ = new List<IntPtr>();
    private readonly List<IntPtr> outputs_ = new List<IntPtr>();
    private readonly List<IntPtr> output_values_ = new List<IntPtr>();

    // Never populated here; only its Count (0) is passed to TF_SessionRun.
    private readonly List<IntPtr> targets_ = new List<IntPtr>();

    /// <summary>Creates a session over <paramref name="graph"/> with default session options.</summary>
    /// <param name="graph">Graph the session runs.</param>
    /// <param name="s">Status handed to the <c>Session</c> constructor.</param>
    /// <param name="user_XLA">Currently ignored.</param>
    public CSession(Graph graph, Status s, bool user_XLA = false)
    {
        var opts = new SessionOptions();
        session_ = new Session(graph, opts, s);
    }

    /// <summary>
    /// Replaces the feeds for the next <see cref="Run"/>: each key is an operation whose
    /// output 0 is fed, each value the tensor handle fed to it.
    /// </summary>
    /// <exception cref="ArgumentNullException"><paramref name="inputs"/> is null.</exception>
    public void SetInputs(Dictionary<IntPtr, IntPtr> inputs)
    {
        if (inputs == null)
            throw new ArgumentNullException(nameof(inputs));

        DeleteInputValues();
        FreeOutputArray(inputs_);

        // Collect keys and values in the same pass so element i of both arrays pair up.
        var opers = new List<IntPtr>(inputs.Count);
        foreach (var input in inputs)
        {
            opers.Add(input.Key);
            input_values_.Add(input.Value);
        }

        AllocOutputArray(opers, inputs_);
    }

    /// <summary>
    /// Forgets the current input tensor handles.
    /// NOTE(review): the tensors are not disposed here (disposal was disabled in the
    /// original), so whoever created them presumably still owns them — confirm.
    /// </summary>
    private void DeleteInputValues()
    {
        input_values_.Clear();
    }

    /// <summary>
    /// Replaces the fetches for the next <see cref="Run"/>: output 0 of each operation
    /// in <paramref name="outputs"/> is fetched.
    /// </summary>
    /// <exception cref="ArgumentNullException"><paramref name="outputs"/> is null.</exception>
    public void SetOutputs(List<IntPtr> outputs)
    {
        if (outputs == null)
            throw new ArgumentNullException(nameof(outputs));

        ResetOutputValues();
        FreeOutputArray(outputs_);

        AllocOutputArray(outputs, outputs_);

        // One result slot per fetch, filled in by Run.
        for (var i = 0; i < outputs.Count; ++i)
            output_values_.Add(IntPtr.Zero);
    }

    /// <summary>
    /// Forgets the tensors produced by the previous <see cref="Run"/>.
    /// NOTE(review): TF_SessionRun hands ownership of output tensors to the caller; they
    /// are not deleted here, so they leak unless freed elsewhere — confirm intent.
    /// </summary>
    private void ResetOutputValues()
    {
        output_values_.Clear();
    }

    /// <summary>
    /// Allocates one contiguous unmanaged TF_Output array (index 0 of every operation in
    /// <paramref name="opers"/>) and appends a pointer to each element to
    /// <paramref name="handles"/>; handles[0] is the start of the block.
    /// </summary>
    private static void AllocOutputArray(IList<IntPtr> opers, List<IntPtr> handles)
    {
        if (opers.Count == 0)
            return;

        var elementSize = Marshal.SizeOf<TF_Output>();
        var block = Marshal.AllocHGlobal(elementSize * opers.Count);
        for (var i = 0; i < opers.Count; ++i)
        {
            var element = IntPtr.Add(block, i * elementSize);
            // The memory is freshly allocated and holds no structure yet, so fDeleteOld is false.
            Marshal.StructureToPtr(new TF_Output(opers[i], 0), element, false);
            handles.Add(element);
        }
    }

    /// <summary>Frees a block created by <see cref="AllocOutputArray"/> and empties the list.</summary>
    private static void FreeOutputArray(List<IntPtr> handles)
    {
        // Only handles[0] is an allocation; the other entries point inside it.
        if (handles.Count > 0)
            Marshal.FreeHGlobal(handles[0]);
        handles.Clear();
    }

    /// <summary>
    /// Runs the session with the configured feeds and fetches, calls <c>s.Check()</c>, and
    /// stores the fetched tensor handles for <see cref="output_tensor"/>.
    /// </summary>
    public unsafe void Run(Status s)
    {
        IntPtr inputs_ptr = inputs_.Count == 0 ? IntPtr.Zero : inputs_[0];
        IntPtr outputs_ptr = outputs_.Count == 0 ? IntPtr.Zero : outputs_[0];
        IntPtr targets_ptr = IntPtr.Zero;

        // The C API takes arrays of tensor handles, so copy them into managed arrays and
        // pin those for the duration of the call. The output array always has at least
        // one slot so the `ref` argument refers to valid memory even with no fetches.
        var inputValues = input_values_.ToArray();
        var outputValues = new IntPtr[Math.Max(outputs_.Count, 1)];

        fixed (IntPtr* inputValuesPtr = inputValues)
        fixed (IntPtr* outputValuesPtr = outputValues)
        {
            // A fixed pointer over an empty array is null, i.e. IntPtr.Zero here.
            c_api.TF_SessionRun(session_, null,
                                inputs_ptr, (IntPtr)inputValuesPtr, inputs_.Count,
                                outputs_ptr, ref *outputValuesPtr, outputs_.Count,
                                targets_ptr, targets_.Count,
                                IntPtr.Zero, s);
        }

        s.Check();

        for (var i = 0; i < output_values_.Count; ++i)
            output_values_[i] = outputValues[i];
    }

    /// <summary>Returns the tensor handle fetched for output <paramref name="i"/> by the last <see cref="Run"/>.</summary>
    public IntPtr output_tensor(int i)
    {
        return output_values_[i];
    }
}
}