|
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104 |
- using System;
- using System.Collections.Generic;
- using System.Linq;
- using Tensorflow.Util;
-
namespace Tensorflow.Native.UnitTest
{
    /// <summary>
    /// C# port of the <c>CSession</c> helper from tensorflow\c\c_test_util.cc,
    /// used by TEST(CAPI, Session).
    /// Holds a native session plus the feed / fetch / target lists that are handed
    /// to <c>TF_SessionRun</c> on each <see cref="Run"/>.
    /// </summary>
    public class CSession
    {
        // Native session handle; null once CloseAndDelete has released it.
        private SafeSessionHandle session_;

        // Feeds: inputs_[i] is fed input_values_[i]. Both lists are always the same length.
        private readonly List<TF_Output> inputs_ = new List<TF_Output>();
        private readonly List<Tensor> input_values_ = new List<Tensor>();

        // Fetches: output_values_[i] receives the result for outputs_[i] after Run.
        private readonly List<TF_Output> outputs_ = new List<TF_Output>();
        private readonly List<Tensor> output_values_ = new List<Tensor>();

        // Target operations to execute without fetching. Nothing in this class adds to it,
        // so it is currently always empty.
        private readonly List<IntPtr> targets_ = new List<IntPtr>();

        /// <summary>
        /// Creates a session over <paramref name="graph"/> configured with 4 inter-op threads.
        /// </summary>
        /// <param name="graph">Graph the session executes.</param>
        /// <param name="s">Status object passed to the <c>Session</c> constructor.</param>
        /// <param name="user_XLA">Ignored by this port; kept so the signature mirrors the C++ helper.</param>
        public CSession(Graph graph, Status s, bool user_XLA = false)
        {
            var config = new ConfigProto { InterOpParallelismThreads = 4 };
            session_ = new Session(graph, config, s);
        }

        /// <summary>
        /// Replaces the feed list: output 0 of each operation key is fed the paired tensor.
        /// </summary>
        /// <param name="inputs">Operation-to-tensor feeds.</param>
        /// <exception cref="ArgumentNullException"><paramref name="inputs"/> is null.</exception>
        public void SetInputs(Dictionary<Operation, Tensor> inputs)
            => SetInputsCore(inputs);

        /// <summary>
        /// Replaces the feed list: output 0 of each operation key is fed the paired tensor,
        /// in array order.
        /// </summary>
        /// <param name="inputs">Operation-to-tensor feeds.</param>
        /// <exception cref="ArgumentNullException"><paramref name="inputs"/> is null.</exception>
        public void SetInputs(KeyValuePair<Operation, Tensor>[] inputs)
            => SetInputsCore(inputs);

        // Shared body of both SetInputs overloads; validates before touching any state.
        private void SetInputsCore(IEnumerable<KeyValuePair<Operation, Tensor>> inputs)
        {
            if (inputs == null)
                throw new ArgumentNullException(nameof(inputs));

            DeleteInputValues();
            inputs_.Clear();
            foreach (var input in inputs)
            {
                inputs_.Add(new TF_Output(input.Key, 0));
                input_values_.Add(input.Value);
            }
        }

        private void DeleteInputValues()
        {
            //clearing is enough as they will be disposed by the GC unless they are referenced else-where.
            input_values_.Clear();
        }

        /// <summary>
        /// Replaces the fetch list. Result slots are reset to null until the next <see cref="Run"/>.
        /// </summary>
        /// <param name="outputs">Graph outputs to fetch.</param>
        /// <exception cref="ArgumentNullException"><paramref name="outputs"/> is null.</exception>
        public void SetOutputs(TF_Output[] outputs)
        {
            if (outputs == null)
                throw new ArgumentNullException(nameof(outputs));

            ResetOutputValues();
            outputs_.Clear();
            foreach (var output in outputs)
            {
                outputs_.Add(output);
                output_values_.Add(null);
            }
        }

        private void ResetOutputValues()
        {
            //clearing is enough as they will be disposed by the GC unless they are referenced else-where.
            output_values_.Clear();
        }

        /// <summary>
        /// Runs the session with the current feeds, fetches and targets, then stores one
        /// result per fetch (readable via <see cref="output_tensor"/>).
        /// </summary>
        /// <param name="s">Receives the native status; <c>s.Check()</c> is called right after the run.</param>
        /// <exception cref="ObjectDisposedException"><see cref="CloseAndDelete"/> was already called.</exception>
        public unsafe void Run(Status s)
        {
            if (session_ == null)
                throw new ObjectDisposedException(nameof(CSession));

            var inputs_ptr = inputs_.ToArray();
            // input_values_ keeps the tensors referenced for the duration of the native call.
            var input_values_ptr = input_values_.Select(x => x.Handle.DangerousGetHandle()).ToArray();
            var outputs_ptr = outputs_.ToArray();
            // Sized from the fetch list actually passed, so the native side never writes past it.
            var output_values_ptr = new IntPtr[outputs_ptr.Length];
            // Pass the real target array so the pointer and the count always agree.
            var targets_ptr = targets_.ToArray();

            c_api.TF_SessionRun(session_, null, inputs_ptr, input_values_ptr, inputs_ptr.Length,
                outputs_ptr, output_values_ptr, outputs_ptr.Length,
                targets_ptr, targets_ptr.Length,
                IntPtr.Zero, s);

            s.Check();

            for (var i = 0; i < output_values_ptr.Length; i++)
                output_values_[i] = new SafeTensorHandle(output_values_ptr[i]);
        }

        /// <summary>
        /// Returns the handle of the <paramref name="i"/>-th fetched tensor from the last <see cref="Run"/>.
        /// </summary>
        /// <param name="i">Index into the list given to <see cref="SetOutputs"/>.</param>
        public SafeTensorHandle output_tensor(int i)
        {
            return output_values_[i].Handle;
        }

        /// <summary>
        /// Drops all feed/fetch state and releases the native session handle.
        /// Calling it more than once is harmless.
        /// </summary>
        /// <param name="s">Not used by this port; kept so the signature mirrors the C++ helper.</param>
        public void CloseAndDelete(Status s)
        {
            DeleteInputValues();
            inputs_.Clear();
            ResetOutputValues();
            outputs_.Clear();

            // Release the native session deterministically instead of waiting for finalization.
            session_?.Dispose();
            session_ = null;
        }
    }
}
|