You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

CSession.cs 3.0 kB

6 years ago
6 years ago
6 years ago
6 years ago
12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697
  1. using System;
  2. using System.Collections.Generic;
  3. using System.Linq;
  4. using Tensorflow;
  5. using Tensorflow.Util;
  namespace TensorFlowNET.UnitTest
  {
      /// <summary>
      /// Port of the CSession test helper from tensorflow\c\c_test_util.cc,
      /// used by TEST(CAPI, Session). It holds a session handle plus the
      /// feeds, fetches and targets for a <c>c_api.TF_SessionRun</c> call.
      /// </summary>
      public class CSession
      {
          // Managed wrapper of the session. Kept in a field so it stays reachable for as long
          // as this helper uses the raw handle below (the original dropped the only reference).
          // NOTE(review): assumes the Session wrapper's lifetime governs the native session
          // (e.g. via a finalizer) — confirm; holding the reference is harmless either way.
          private readonly Session session_owner_;

          // Raw session handle passed to c_api.TF_SessionRun.
          private readonly IntPtr session_;
          private readonly List<TF_Output> inputs_ = new List<TF_Output>();
          private readonly List<Tensor> input_values_ = new List<Tensor>();
          private readonly List<TF_Output> outputs_ = new List<TF_Output>();
          private readonly List<Tensor> output_values_ = new List<Tensor>();
          private readonly List<IntPtr> targets_ = new List<IntPtr>();

          /// <summary>
          /// Creates a session on <paramref name="graph"/> whose config sets
          /// InterOpParallelismThreads to 4. Creation happens under Locks.ProcessWide.
          /// </summary>
          /// <param name="graph">Graph the session is created for.</param>
          /// <param name="s">Status passed to the Session constructor.</param>
          /// <param name="user_XLA">Currently ignored; no XLA option is applied here.</param>
          public CSession(Graph graph, Status s, bool user_XLA = false)
          {
              lock (Locks.ProcessWide)
              {
                  var opts = new SessionOptions();
                  opts.SetConfig(new ConfigProto { InterOpParallelismThreads = 4 });
                  session_owner_ = new Session(graph, opts, s);
                  session_ = session_owner_;
              }
          }

          /// <summary>
          /// Replaces the feed list: output 0 of each key operation is fed with the paired tensor.
          /// </summary>
          /// <param name="inputs">Operations to feed, mapped to the tensors fed into them.</param>
          public void SetInputs(Dictionary<Operation, Tensor> inputs)
          {
              DeleteInputValues();
              inputs_.Clear();
              foreach (var input in inputs)
              {
                  inputs_.Add(new TF_Output(input.Key, 0));
                  input_values_.Add(input.Value);
              }
          }

          private void DeleteInputValues()
          {
              // Clearing is enough: the tensors are left to the GC unless referenced elsewhere.
              input_values_.Clear();
          }

          /// <summary>
          /// Replaces the fetch list and resets the stored output values to
          /// placeholders built from IntPtr.Zero, one per requested output.
          /// </summary>
          /// <param name="outputs">Outputs to fetch on the next <see cref="Run"/>.</param>
          public void SetOutputs(TF_Output[] outputs)
          {
              ResetOutputValues();
              outputs_.Clear();
              foreach (var output in outputs)
              {
                  outputs_.Add(output);
                  output_values_.Add(IntPtr.Zero);
              }
          }

          private void ResetOutputValues()
          {
              // Clearing is enough: the tensors are left to the GC unless referenced elsewhere.
              output_values_.Clear();
          }

          /// <summary>
          /// Runs the session with the current feeds, fetches and targets, then stores
          /// every fetched output handle so it can be read with <see cref="output_tensor"/>.
          /// </summary>
          /// <param name="s">Status passed to TF_SessionRun; <c>Check()</c> is called on it afterwards.</param>
          public void Run(Status s)
          {
              var inputs_ptr = inputs_.ToArray();
              var input_values_ptr = input_values_.Select(x => (IntPtr) x).ToArray();
              var outputs_ptr = outputs_.ToArray();
              // Size the result buffer from the requested outputs: TF_SessionRun writes
              // outputs_ptr.Length handles into it, even if output_values_ was reset.
              var output_values_ptr = new IntPtr[outputs_ptr.Length];
              // Pass the actual targets so the array always agrees with the count argument.
              var targets_ptr = targets_.ToArray();

              c_api.TF_SessionRun(session_, null, inputs_ptr, input_values_ptr, inputs_ptr.Length,
                  outputs_ptr, output_values_ptr, outputs_ptr.Length,
                  targets_ptr, targets_ptr.Length,
                  IntPtr.Zero, s);
              s.Check();

              // Copy back every fetched output, not just the first; a run with no outputs is a no-op here.
              output_values_.Clear();
              foreach (var handle in output_values_ptr)
              {
                  output_values_.Add(handle);
              }
          }

          /// <summary>
          /// Returns the raw handle of output <paramref name="i"/> as stored by the last
          /// <see cref="Run"/> (or the placeholder stored by <see cref="SetOutputs"/>).
          /// </summary>
          public IntPtr output_tensor(int i)
          {
              return output_values_[i];
          }

          /// <summary>
          /// Clears the stored input and output values.
          /// NOTE(review): unlike the C++ helper, this does not close or delete the native
          /// session, and <paramref name="s"/> is unused.
          /// </summary>
          public void CloseAndDelete(Status s)
          {
              DeleteInputValues();
              ResetOutputValues();
          }
      }
  }