using System;
using System.Collections.Generic;
using System.Linq;
using Tensorflow;
using Tensorflow.Util;
namespace TensorFlowNET.UnitTest
{
///
/// tensorflow\c\c_test_util.cc
/// TEST(CAPI, Session)
///
public class CSession
{
private IntPtr session_;
private List inputs_ = new List();
private List input_values_ = new List();
private List outputs_ = new List();
private List output_values_ = new List();
private List targets_ = new List();
public CSession(Graph graph, Status s, bool user_XLA = false)
{
lock (Locks.ProcessWide)
{
var config = new ConfigProto {InterOpParallelismThreads = 4};
session_ = new Session(graph, config, s);
}
}
public void SetInputs(Dictionary inputs)
{
DeleteInputValues();
inputs_.Clear();
foreach (var input in inputs)
{
inputs_.Add(new TF_Output(input.Key, 0));
input_values_.Add(input.Value);
}
}
public void SetInputs(KeyValuePair[] inputs)
{
DeleteInputValues();
inputs_.Clear();
foreach (var input in inputs)
{
inputs_.Add(new TF_Output(input.Key, 0));
input_values_.Add(input.Value);
}
}
private void DeleteInputValues()
{
//clearing is enough as they will be disposed by the GC unless they are referenced else-where.
input_values_.Clear();
}
public void SetOutputs(TF_Output[] outputs)
{
ResetOutputValues();
outputs_.Clear();
foreach (var output in outputs)
{
outputs_.Add(output);
output_values_.Add(IntPtr.Zero);
}
}
private void ResetOutputValues()
{
//clearing is enough as they will be disposed by the GC unless they are referenced else-where.
output_values_.Clear();
}
public unsafe void Run(Status s)
{
var inputs_ptr = inputs_.ToArray();
var input_values_ptr = input_values_.Select(x => (IntPtr) x).ToArray();
var outputs_ptr = outputs_.ToArray();
var output_values_ptr = output_values_.Select(x => IntPtr.Zero).ToArray();
IntPtr[] targets_ptr = new IntPtr[0];
c_api.TF_SessionRun(session_, null, inputs_ptr, input_values_ptr, inputs_ptr.Length,
outputs_ptr, output_values_ptr, outputs_.Count,
targets_ptr, targets_.Count,
IntPtr.Zero, s.Handle);
s.Check();
for (var i = 0; i < outputs_.Count; i++)
output_values_[i] = output_values_ptr[i];
}
public IntPtr output_tensor(int i)
{
return output_values_[i];
}
public void CloseAndDelete(Status s)
{
DeleteInputValues();
ResetOutputValues();
}
}
}