You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

CSession.cs 3.0 kB

6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103
  1. using NumSharp.Core;
  2. using System;
  3. using System.Collections.Generic;
  4. using System.Linq;
  5. using System.Runtime.InteropServices;
  6. using System.Text;
  7. using Tensorflow;
  namespace TensorFlowNET.UnitTest
  {
      /// <summary>
      /// C# port of the CSession test helper from tensorflow\c\c_test_util.cc,
      /// used by TEST(CAPI, Session). It holds a session handle together with
      /// the feeds, fetches and targets used for a TF_SessionRun call.
      /// </summary>
      public class CSession
      {
          // Native session handle; assigned from a Session object in the constructor.
          private readonly IntPtr session_;
          // Fed endpoints and their tensors; the two lists are kept index-aligned.
          private readonly List<TF_Output> inputs_ = new List<TF_Output>();
          private readonly List<Tensor> input_values_ = new List<Tensor>();
          // Fetched endpoints and the tensors produced for them by Run; index-aligned.
          private readonly List<TF_Output> outputs_ = new List<TF_Output>();
          private readonly List<Tensor> output_values_ = new List<Tensor>();
          // Target operations; nothing in this class adds to it, so it is always empty.
          private readonly List<IntPtr> targets_ = new List<IntPtr>();

          /// <summary>
          /// Creates a session over <paramref name="graph"/> configured with
          /// InterOpParallelismThreads = 4.
          /// </summary>
          /// <param name="graph">Graph the session executes.</param>
          /// <param name="s">Status passed to the Session constructor.</param>
          /// <param name="user_XLA">
          /// Currently ignored. NOTE(review): the C++ helper's use_XLA flag enables XLA
          /// compilation; this port never reads the flag — confirm whether that is intended.
          /// </param>
          public CSession(Graph graph, Status s, bool user_XLA = false)
          {
              var opts = new SessionOptions();
              opts.SetConfig(new ConfigProto { InterOpParallelismThreads = 4 });
              // Presumably relies on an implicit Session -> IntPtr conversion.
              // NOTE(review): opts is never released here — confirm whether SessionOptions needs disposing.
              session_ = new Session(graph, opts, s);
          }

          /// <summary>
          /// Replaces the feeds for the next <see cref="Run"/>. Each key is fed at output
          /// index 0 with its paired tensor. Tensors from a previous call are disposed first,
          /// and the new tensors are disposed by a later SetInputs or <see cref="CloseAndDelete"/>.
          /// </summary>
          /// <param name="inputs">Operation-to-tensor feeds.</param>
          public void SetInputs(Dictionary<Operation, Tensor> inputs)
          {
              DeleteInputValues();
              inputs_.Clear();
              foreach (var input in inputs)
              {
                  inputs_.Add(new TF_Output(input.Key, 0));
                  input_values_.Add(input.Value);
              }
          }

          /// <summary>Disposes every stored input tensor and empties the input value list.</summary>
          private void DeleteInputValues()
          {
              for (var i = 0; i < input_values_.Count; ++i)
              {
                  input_values_[i].Dispose();
              }
              input_values_.Clear();
          }

          /// <summary>
          /// Replaces the fetches for the next <see cref="Run"/>. Output tensors from a
          /// previous run are disposed, and an empty slot is reserved for each new output.
          /// </summary>
          /// <param name="outputs">Endpoints to fetch.</param>
          public void SetOutputs(TF_Output[] outputs)
          {
              ResetOutputValues();
              outputs_.Clear();
              foreach (var output in outputs)
              {
                  outputs_.Add(output);
                  output_values_.Add(IntPtr.Zero);
              }
          }

          /// <summary>Disposes every non-empty output tensor and empties the output value list.</summary>
          private void ResetOutputValues()
          {
              for (var i = 0; i < output_values_.Count; ++i)
              {
                  if (output_values_[i] != IntPtr.Zero)
                      output_values_[i].Dispose();
              }
              output_values_.Clear();
          }

          /// <summary>
          /// Runs the session with the current feeds, fetches and targets. Output tensors
          /// left over from an earlier call are disposed first. After the call, every
          /// requested output is available through <see cref="output_tensor"/>.
          /// </summary>
          /// <param name="s">Status receiving the TF_SessionRun result; <c>s.Check()</c> is called on it afterwards.</param>
          public unsafe void Run(Status s)
          {
              // Release outputs of a previous Run so repeated calls do not leak them,
              // then reserve one empty slot per requested output (mirrors the C++ helper).
              ResetOutputValues();
              for (var i = 0; i < outputs_.Count; ++i)
              {
                  output_values_.Add(IntPtr.Zero);
              }

              var inputs_ptr = inputs_.ToArray();
              var input_values_ptr = input_values_.Select(x => (IntPtr)x).ToArray();
              var outputs_ptr = outputs_.ToArray();
              var output_values_ptr = new IntPtr[outputs_ptr.Length];
              var targets_ptr = targets_.ToArray();

              // `unsafe` is kept because null is passed for the run-options argument
              // (presumably a pointer-typed parameter).
              c_api.TF_SessionRun(session_, null, inputs_ptr, input_values_ptr, inputs_ptr.Length,
                  outputs_ptr, output_values_ptr, outputs_ptr.Length,
                  targets_ptr, targets_ptr.Length,
                  IntPtr.Zero, s);
              s.Check();

              // Copy back every fetched handle rather than only the first one. This also
              // avoids indexing into an empty list when no outputs were requested.
              for (var i = 0; i < output_values_ptr.Length; ++i)
              {
                  output_values_[i] = output_values_ptr[i];
              }
          }

          /// <summary>Returns the handle of the <paramref name="i"/>-th output produced by the last <see cref="Run"/>.</summary>
          /// <param name="i">Index into the outputs passed to <see cref="SetOutputs"/>.</param>
          public IntPtr output_tensor(int i)
          {
              return output_values_[i];
          }

          /// <summary>
          /// Disposes the stored input and output tensors.
          /// NOTE(review): unlike the C++ helper, this does not close or delete the native
          /// session, and <paramref name="s"/> is unused — confirm whether that is intended.
          /// </summary>
          /// <param name="s">Currently unused.</param>
          public void CloseAndDelete(Status s)
          {
              DeleteInputValues();
              ResetOutputValues();
          }
      }
  }

TensorFlow 框架的 .NET 版本，提供了丰富的特性和 API，可以借此很方便地在 .NET 平台下搭建深度学习训练与推理流程。