using NumSharp.Core;
using System;
using System.Collections.Generic;
using System.Linq;
using System.Runtime.InteropServices;
using System.Text;
using Tensorflow;

namespace TensorFlowNET.UnitTest
{
    /// <summary>
    /// C# port of the <c>CSession</c> test helper from tensorflow\c\c_test_util.cc,
    /// used by TEST(CAPI, Session). It holds a native session handle plus the
    /// feeds, fetches and targets that <see cref="Run"/> hands to
    /// <c>c_api.TF_SessionRun</c>.
    /// </summary>
    public class CSession
    {
        // Native session handle obtained from the Session created in the constructor.
        private readonly IntPtr session_;

        // Feed endpoints and the tensors fed to them; SetInputs keeps them index-aligned.
        private readonly List<TF_Output> inputs_ = new List<TF_Output>();
        private readonly List<Tensor> input_values_ = new List<Tensor>();

        // Fetch endpoints and the tensors produced for them by the most recent Run.
        private readonly List<TF_Output> outputs_ = new List<TF_Output>();
        private readonly List<Tensor> output_values_ = new List<Tensor>();

        // Target operations. Nothing in this class adds to it, so Run currently passes none.
        private readonly List<IntPtr> targets_ = new List<IntPtr>();

        /// <summary>
        /// Creates a session over <paramref name="graph"/> whose config sets
        /// InterOpParallelismThreads = 4.
        /// </summary>
        /// <param name="graph">Graph the session executes.</param>
        /// <param name="s">Status handed to the Session constructor (presumably receives any creation error).</param>
        /// <param name="user_XLA">Currently ignored: no XLA option is applied here.</param>
        public CSession(Graph graph, Status s, bool user_XLA = false)
        {
            var opts = new SessionOptions();
            opts.SetConfig(new ConfigProto { InterOpParallelismThreads = 4 });
            session_ = new Session(graph, opts, s);
        }

        /// <summary>
        /// Replaces the feeds used by the next <see cref="Run"/>. Each key is fed
        /// through its output 0. This object takes ownership of the value tensors:
        /// they are disposed by the next SetInputs or by CloseAndDelete.
        /// </summary>
        /// <param name="inputs">Operation to feed mapped to the tensor fed into it.</param>
        public void SetInputs(Dictionary<Operation, Tensor> inputs)
        {
            DeleteInputValues();
            inputs_.Clear();
            foreach (var input in inputs)
            {
                inputs_.Add(new TF_Output(input.Key, 0));
                input_values_.Add(input.Value);
            }
        }

        /// <summary>Disposes and forgets all fed tensors (leaves the feed endpoints in place).</summary>
        private void DeleteInputValues()
        {
            for (var i = 0; i < input_values_.Count; ++i)
            {
                input_values_[i].Dispose();
            }
            input_values_.Clear();
        }

        /// <summary>
        /// Replaces the fetch list used by the next <see cref="Run"/>. Any tensors
        /// fetched earlier are disposed and every slot starts out empty (IntPtr.Zero).
        /// </summary>
        /// <param name="outputs">Endpoints whose values Run should fetch.</param>
        public void SetOutputs(TF_Output[] outputs)
        {
            ResetOutputValues();
            outputs_.Clear();
            foreach (var output in outputs)
            {
                outputs_.Add(output);
                output_values_.Add(IntPtr.Zero);
            }
        }

        /// <summary>Disposes every non-empty fetched tensor and clears the result list.</summary>
        private void ResetOutputValues()
        {
            for (var i = 0; i < output_values_.Count; ++i)
            {
                if (output_values_[i] != IntPtr.Zero)
                {
                    output_values_[i].Dispose();
                }
            }
            output_values_.Clear();
        }

        /// <summary>
        /// Runs the session with the current feeds and fetches, then stores one
        /// result per fetch, readable through <see cref="output_tensor"/>.
        /// </summary>
        /// <param name="s">Status passed to TF_SessionRun and checked afterwards.</param>
        /// <exception cref="InvalidOperationException">
        /// The feed endpoints and feed values are out of step, for example after
        /// CloseAndDelete. Running anyway would let native code read past the value array.
        /// </exception>
        public void Run(Status s)
        {
            if (inputs_.Count != input_values_.Count)
            {
                throw new InvalidOperationException("Call SetInputs() before Run()");
            }

            // Release tensors fetched by an earlier Run so repeated runs do not leak
            // them, then recreate one empty slot per fetch, as SetOutputs does.
            ResetOutputValues();
            for (var i = 0; i < outputs_.Count; ++i)
            {
                output_values_.Add(IntPtr.Zero);
            }

            var inputs_ptr = inputs_.ToArray();
            var input_values_ptr = input_values_.Select(x => (IntPtr)x).ToArray();
            var outputs_ptr = outputs_.ToArray();
            // TF_SessionRun writes one tensor handle per fetch into this array.
            var output_values_ptr = new IntPtr[outputs_ptr.Length];
            var targets_ptr = targets_.ToArray();

            c_api.TF_SessionRun(session_, null, inputs_ptr, input_values_ptr, inputs_ptr.Length,
                outputs_ptr, output_values_ptr, outputs_ptr.Length,
                targets_ptr, targets_ptr.Length,
                IntPtr.Zero, s);

            s.Check();

            // Copy back every fetched handle, not only the first one. Otherwise extra
            // fetches are lost (and leaked), and an empty fetch list makes [0] throw.
            for (var i = 0; i < output_values_ptr.Length; ++i)
            {
                output_values_[i] = output_values_ptr[i];
            }
        }

        /// <summary>
        /// Returns the handle of the tensor fetched for output <paramref name="i"/>
        /// by the most recent Run, or IntPtr.Zero if nothing was stored there.
        /// </summary>
        /// <exception cref="ArgumentOutOfRangeException"><paramref name="i"/> is not a valid fetch index.</exception>
        public IntPtr output_tensor(int i)
        {
            return output_values_[i];
        }

        /// <summary>
        /// Disposes all fed and fetched tensors held by this object.
        /// </summary>
        /// <param name="s">Currently unused.</param>
        /// <remarks>
        /// NOTE(review): this does not close or delete the native session handle.
        /// The C++ original does both. Confirm who owns <c>session_</c> before adding
        /// TF_CloseSession/TF_DeleteSession here, to avoid a double free.
        /// </remarks>
        public void CloseAndDelete(Status s)
        {
            DeleteInputValues();
            ResetOutputValues();
        }
    }
}