You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

CSession.cs 3.3 kB

6 years ago
6 years ago
6 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107
  1. using System;
  2. using System.Collections.Generic;
  3. using System.Linq;
  4. using Tensorflow.Util;
  5. namespace Tensorflow.Native.UnitTest
  6. {
  /// <summary>
  /// Test helper that bundles a session with the feeds (inputs), fetches
  /// (outputs) and targets used for a single <c>TF_SessionRun</c> call.
  /// Ported from tensorflow\c\c_test_util.cc, used by TEST(CAPI, Session).
  /// </summary>
  public class CSession
  {
      // The managed Session object is stored (not just its raw handle) so the
      // owner of the native handle stays reachable for as long as this wrapper is.
      // NOTE(review): assumes Session releases its native handle when it is
      // collected or disposed - confirm against the Session implementation.
      private readonly Session session_;

      // inputs_[i] is fed with input_values_[i]; the two lists are filled in lock-step.
      private readonly List<TF_Output> inputs_ = new List<TF_Output>();
      private readonly List<Tensor> input_values_ = new List<Tensor>();

      // outputs_[i] is fetched into output_values_[i] by Run.
      private readonly List<TF_Output> outputs_ = new List<TF_Output>();
      private readonly List<Tensor> output_values_ = new List<Tensor>();

      // Target operations handed to TF_SessionRun; nothing in this class populates it.
      private readonly List<IntPtr> targets_ = new List<IntPtr>();

      /// <summary>
      /// Creates the underlying session for <paramref name="graph"/> while holding
      /// <c>Locks.ProcessWide</c>, using a config with 4 inter-op parallelism threads.
      /// </summary>
      /// <param name="graph">Graph the session is created for.</param>
      /// <param name="s">Status object passed to the session constructor.</param>
      /// <param name="user_XLA">Currently unused; kept for signature compatibility.</param>
      public CSession(Graph graph, Status s, bool user_XLA = false)
      {
          lock (Locks.ProcessWide)
          {
              var config = new ConfigProto { InterOpParallelismThreads = 4 };
              session_ = new Session(graph, config, s);
          }
      }

      /// <summary>
      /// Replaces the current feeds with the given operation/tensor pairs.
      /// Output index 0 of each operation is used as the feed point.
      /// </summary>
      /// <param name="inputs">Operation-to-tensor pairs to feed.</param>
      /// <exception cref="ArgumentNullException"><paramref name="inputs"/> is null.</exception>
      public void SetInputs(Dictionary<Operation, Tensor> inputs)
      {
          ReplaceInputs(inputs);
      }

      /// <summary>
      /// Replaces the current feeds with the given operation/tensor pairs.
      /// Output index 0 of each operation is used as the feed point.
      /// </summary>
      /// <param name="inputs">Operation-to-tensor pairs to feed, in feed order.</param>
      /// <exception cref="ArgumentNullException"><paramref name="inputs"/> is null.</exception>
      public void SetInputs(KeyValuePair<Operation, Tensor>[] inputs)
      {
          ReplaceInputs(inputs);
      }

      // Shared implementation of both SetInputs overloads: drops the previous feeds
      // and records the new ones, keeping inputs_ and input_values_ in lock-step.
      private void ReplaceInputs(IEnumerable<KeyValuePair<Operation, Tensor>> inputs)
      {
          // Validate before touching state so a bad call leaves the old feeds intact.
          if (inputs == null)
          {
              throw new ArgumentNullException(nameof(inputs));
          }

          DeleteInputValues();
          inputs_.Clear();
          foreach (var input in inputs)
          {
              inputs_.Add(new TF_Output(input.Key, 0));
              input_values_.Add(input.Value);
          }
      }

      // Drops this wrapper's references to the fed tensors.
      private void DeleteInputValues()
      {
          // Clearing is enough: the tensors are left to the GC unless referenced elsewhere.
          input_values_.Clear();
      }

      /// <summary>
      /// Replaces the current fetches and resets every output value to a
      /// placeholder created from <c>IntPtr.Zero</c>.
      /// </summary>
      /// <param name="outputs">Outputs to fetch on the next <see cref="Run"/>.</param>
      /// <exception cref="ArgumentNullException"><paramref name="outputs"/> is null.</exception>
      public void SetOutputs(TF_Output[] outputs)
      {
          // Validate before touching state so a bad call leaves the old fetches intact.
          if (outputs == null)
          {
              throw new ArgumentNullException(nameof(outputs));
          }

          ResetOutputValues();
          outputs_.Clear();
          foreach (var output in outputs)
          {
              outputs_.Add(output);
              output_values_.Add(IntPtr.Zero);
          }
      }

      // Drops this wrapper's references to previously fetched tensors.
      private void ResetOutputValues()
      {
          // Clearing is enough: the tensors are left to the GC unless referenced elsewhere.
          output_values_.Clear();
      }

      /// <summary>
      /// Runs the session with the configured feeds, fetches and targets, then
      /// stores the fetched tensor handles so they can be read through
      /// <see cref="output_tensor"/>.
      /// </summary>
      /// <param name="s">Status that receives the result of the run; <c>Check()</c> is called on it afterwards.</param>
      /// <exception cref="InvalidOperationException">
      /// The feed operations and feed values are out of sync (for example after
      /// <see cref="CloseAndDelete"/> without a new <c>SetInputs</c> call).
      /// </exception>
      public unsafe void Run(Status s)
      {
          // The native call reads `ninputs` entries from both arrays, so a length
          // mismatch would make it read past the end of the shorter one.
          if (inputs_.Count != input_values_.Count)
          {
              throw new InvalidOperationException(
                  "Input operations and input values are out of sync; call SetInputs before Run.");
          }

          var inputs_ptr = inputs_.ToArray();
          var input_values_ptr = input_values_.Select(x => (IntPtr)x).ToArray();
          var outputs_ptr = outputs_.ToArray();
          // Size the receive buffer from the fetch list itself so the count passed
          // to the native call always matches the buffer length.
          var output_values_ptr = new IntPtr[outputs_ptr.Length];
          // Array and count both come from targets_ so they cannot disagree.
          var targets_ptr = targets_.ToArray();

          // Same implicit Session-to-IntPtr conversion the class relied on before;
          // session_ itself stays referenced by this instance during the call.
          IntPtr session_handle = session_;

          c_api.TF_SessionRun(session_handle, null, inputs_ptr, input_values_ptr, inputs_ptr.Length,
              outputs_ptr, output_values_ptr, outputs_ptr.Length,
              targets_ptr, targets_ptr.Length,
              IntPtr.Zero, s.Handle);
          s.Check();

          // Rebuild instead of assigning by index: output_values_ may have been
          // cleared since SetOutputs and would then be shorter than outputs_.
          output_values_.Clear();
          foreach (var output_value in output_values_ptr)
          {
              output_values_.Add(output_value);
          }
      }

      /// <summary>
      /// Returns the handle of the i-th fetched tensor.
      /// </summary>
      /// <param name="i">Zero-based index into the configured outputs.</param>
      /// <returns>The stored output value converted to <see cref="IntPtr"/>.</returns>
      public IntPtr output_tensor(int i)
      {
          return output_values_[i];
      }

      /// <summary>
      /// Drops the references to the fed and fetched tensors.
      /// </summary>
      /// <param name="s">Currently unused; kept for signature compatibility.</param>
      public void CloseAndDelete(Status s)
      {
          // NOTE(review): despite the name, the session itself is not closed here,
          // which matches the previous behaviour. Consider closing/disposing
          // session_ once Session's disposal API and locking needs are confirmed.
          DeleteInputValues();
          ResetOutputValues();
      }
  }
  92. }