You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

CSession.cs 3.0 kB

6 years ago
6 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102
  1. using NumSharp.Core;
  2. using System;
  3. using System.Collections.Generic;
  4. using System.Linq;
  5. using System.Runtime.InteropServices;
  6. using System.Text;
  7. using Tensorflow;
  8. namespace TensorFlowNET.UnitTest
  9. {
  10. /// <summary>
  11. /// tensorflow\c\c_test_util.cc
  12. /// TEST(CAPI, Session)
  13. /// </summary>
  14. public class CSession
  15. {
  16. private IntPtr session_;
  17. private List<TF_Output> inputs_ = new List<TF_Output>();
  18. private List<Tensor> input_values_ = new List<Tensor>();
  19. private List<TF_Output> outputs_ = new List<TF_Output>();
  20. private List<Tensor> output_values_ = new List<Tensor>();
  21. private List<IntPtr> targets_ = new List<IntPtr>();
  22. public CSession(Graph graph, Status s, bool user_XLA = false)
  23. {
  24. var opts = new SessionOptions();
  25. session_ = new Session(graph, opts, s);
  26. }
  27. public void SetInputs(Dictionary<Operation, Tensor> inputs)
  28. {
  29. DeleteInputValues();
  30. inputs_.Clear();
  31. foreach (var input in inputs)
  32. {
  33. inputs_.Add(new TF_Output(input.Key, 0));
  34. input_values_.Add(input.Value);
  35. }
  36. }
  37. private void DeleteInputValues()
  38. {
  39. for (var i = 0; i < input_values_.Count; ++i)
  40. {
  41. input_values_[i].Dispose();
  42. }
  43. input_values_.Clear();
  44. }
  45. public void SetOutputs(List<IntPtr> outputs)
  46. {
  47. ResetOutputValues();
  48. outputs_.Clear();
  49. foreach (var output in outputs)
  50. {
  51. outputs_.Add(new TF_Output(output, 0));
  52. output_values_.Add(IntPtr.Zero);
  53. }
  54. }
  55. private void ResetOutputValues()
  56. {
  57. for (var i = 0; i < output_values_.Count; ++i)
  58. {
  59. if (output_values_[i] != IntPtr.Zero)
  60. output_values_[i].Dispose();
  61. }
  62. output_values_.Clear();
  63. }
  64. public unsafe void Run(Status s)
  65. {
  66. var inputs_ptr = inputs_.ToArray();
  67. var input_values_ptr = input_values_.Select(x => (IntPtr)x).ToArray();
  68. var outputs_ptr = outputs_.ToArray();
  69. var output_values_ptr = output_values_.Select(x => (IntPtr)x).ToArray();
  70. IntPtr targets_ptr = IntPtr.Zero;
  71. c_api.TF_SessionRun(session_, null, inputs_ptr, input_values_ptr, 1,
  72. outputs_ptr, output_values_ptr, outputs_.Count,
  73. targets_ptr, targets_.Count,
  74. IntPtr.Zero, s);
  75. s.Check();
  76. output_values_[0] = output_values_ptr[0];
  77. }
  78. public IntPtr output_tensor(int i)
  79. {
  80. return output_values_[i];
  81. }
  82. public void CloseAndDelete(Status s)
  83. {
  84. DeleteInputValues();
  85. ResetOutputValues();
  86. }
  87. }
  88. }

TensorFlow框架的.NET版本，提供了丰富的特性和API，可以借此很方便地在.NET平台下搭建深度学习训练与推理流程。

Contributors (1)