You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

BaseSession.cs 7.4 kB

6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186
  1. using NumSharp.Core;
  2. using System;
  3. using System.Collections.Generic;
  4. using System.Linq;
  5. using System.Runtime.InteropServices;
  6. using System.Text;
  namespace Tensorflow
  {
      /// <summary>
      /// Base class for sessions: wraps a native TensorFlow session created for a <see cref="Graph"/>
      /// and runs fetches/targets on it through the TensorFlow C API (<c>c_api</c>).
      /// </summary>
      public class BaseSession
      {
          protected Graph _graph;
          protected bool _opened;
          protected bool _closed;
          protected int _current_version;
          // UTF-8 bytes of the constructor's `target` argument; not read anywhere in this class.
          protected byte[] _target;
          // Native session handle returned by c_api.TF_NewSession.
          protected IntPtr _session;

          /// <summary>
          /// Creates a native session for <paramref name="graph"/>, or for the default graph when none is given.
          /// </summary>
          /// <param name="target">Execution target; stored UTF-8 encoded in <c>_target</c>. <c>null</c> is treated as "".</param>
          /// <param name="graph">Graph to run; when <c>null</c>, <c>ops.get_default_graph()</c> is used.</param>
          public BaseSession(string target = "", Graph graph = null)
          {
              _graph = graph ?? ops.get_default_graph();
              _target = Encoding.UTF8.GetBytes(target ?? string.Empty);

              var opts = c_api.TF_NewSessionOptions();
              var status = new Status();
              try
              {
                  _session = c_api.TF_NewSession(_graph, opts, status);
              }
              finally
              {
                  // The options object is only needed while creating the session; free it even on failure.
                  c_api.TF_DeleteSessionOptions(opts);
              }

              // Surface session-creation errors the same way _call_tf_sessionrun surfaces run errors
              // (previously this status was never inspected).
              status.Check(true);
          }

          /// <summary>
          /// Runs the graph to compute <paramref name="fetches"/>, substituting the values in <paramref name="feed_dict"/>.
          /// </summary>
          /// <typeparam name="T">Type of the fetch description handed to <c>_FetchHandler&lt;T&gt;</c>.</typeparam>
          /// <param name="fetches">What to fetch.</param>
          /// <param name="feed_dict">Optional tensor/value pairs to feed; <c>null</c> means no feeds.</param>
          /// <returns>The result assembled by the fetch handler's <c>build_results</c>.</returns>
          public virtual NDArray run<T>(T fetches, FeedItem[] feed_dict = null)
          {
              return _run(fetches, feed_dict);
          }

          private NDArray _run<T>(T fetches, FeedItem[] feed_dict = null)
          {
              // Dictionary.Add throws on a duplicate tensor key, so each tensor may be fed at most once.
              var feed_dict_tensor = new Dictionary<Tensor, NDArray>();
              if (feed_dict != null)
              {
                  foreach (var item in feed_dict)
                  {
                      feed_dict_tensor.Add(item.Key, item.Value);
                  }
              }

              // Create a fetch handler to take care of the structure of fetches.
              var fetch_handler = new _FetchHandler<T>(_graph, fetches, feed_dict_tensor);

              // Ported from the Python client, where the returned movers must stay alive during _do_run().
              // In this port _update_with_movers() always returns an empty list.
              var _ = _update_with_movers();

              var final_fetches = fetch_handler.fetches();
              var final_targets = fetch_handler.targets();

              // NOTE(review): the ported comment said the run should only happen when fetches or targets
              // are provided (or for a partial run with feeds); this port calls _do_run unconditionally.
              var results = _do_run(final_targets.Select(x => (Operation)(object)x).ToList(), final_fetches, feed_dict_tensor);
              return fetch_handler.build_results(this, results);
          }

          /// <summary>
          /// Runs a step based on the given fetches and feeds.
          /// </summary>
          /// <param name="target_list">Operations to be run, but not fetched.</param>
          /// <param name="fetch_list">Tensors whose values are returned.</param>
          /// <param name="feed_dict">Values to feed; each value is wrapped in a new <see cref="Tensor"/> before the call.</param>
          /// <returns>One <see cref="NDArray"/> per element of <paramref name="fetch_list"/>, in the same order.</returns>
          private NDArray[] _do_run(List<Operation> target_list, List<Tensor> fetch_list, Dictionary<Tensor, NDArray> feed_dict)
          {
              var feeds = feed_dict
                  .Select(x => new KeyValuePair<TF_Output, Tensor>(x.Key._as_tf_output(), new Tensor(x.Value)))
                  .ToArray();
              var fetches = fetch_list.Select(x => x._as_tf_output()).ToArray();
              return _call_tf_sessionrun(feeds, fetches, target_list);
          }

          /// <summary>
          /// Calls <c>TF_SessionRun</c> on the native session, checks the resulting status and converts
          /// every fetched output into an <see cref="NDArray"/>.
          /// </summary>
          private unsafe NDArray[] _call_tf_sessionrun(KeyValuePair<TF_Output, Tensor>[] feed_dict, TF_Output[] fetch_list, List<Operation> target_list)
          {
              // Ensure any changes to the graph are reflected in the runtime.
              _extend_graph();

              var status = new Status();
              // One slot per fetch (initialized to IntPtr.Zero); read back after TF_SessionRun returns.
              var output_values = new IntPtr[fetch_list.Length];

              c_api.TF_SessionRun(_session,
                  run_options: null,
                  inputs: feed_dict.Select(f => f.Key).ToArray(),
                  input_values: feed_dict.Select(f => (IntPtr)f.Value).ToArray(),
                  ninputs: feed_dict.Length,
                  outputs: fetch_list,
                  output_values: output_values,
                  noutputs: fetch_list.Length,
                  target_opers: target_list.Select(f => (IntPtr)f).ToArray(),
                  ntargets: target_list.Count,
                  run_metadata: IntPtr.Zero,
                  status: status);
              status.Check(true);

              var result = new NDArray[fetch_list.Length];
              for (int i = 0; i < result.Length; i++)
              {
                  result[i] = fetchValue(output_values[i]);
              }
              return result;
          }

          /// <summary>
          /// Copies the contents of a fetched native tensor into a new <see cref="NDArray"/>.
          /// </summary>
          /// <param name="output">Tensor handle produced by <c>TF_SessionRun</c>.</param>
          /// <returns>
          /// For TF_STRING, an array built from the decoded (scalar) string; for TF_INT16/INT32/FLOAT/DOUBLE,
          /// an array with the tensor's elements reshaped to the tensor's shape.
          /// </returns>
          /// <exception cref="NotImplementedException">The tensor's dtype is not one of the supported types.</exception>
          private NDArray fetchValue(IntPtr output)
          {
              var tensor = new Tensor(output);
              var dims = tensor.shape.Select(x => (int)x).ToArray();

              switch (tensor.dtype)
              {
                  case TF_DataType.TF_STRING:
                      return np.array(DecodeScalarString(tensor.Data())).reshape();
                  case TF_DataType.TF_INT16:
                  {
                      int count = checked((int)tensor.size);
                      var shorts = new short[count];
                      // Marshal.Copy rejects a null source, which an empty tensor may have.
                      if (count > 0)
                      {
                          Marshal.Copy(c_api.TF_TensorData(output), shorts, 0, count);
                      }
                      return np.array(shorts).reshape(dims);
                  }
                  case TF_DataType.TF_INT32:
                  {
                      int count = checked((int)tensor.size);
                      var ints = new int[count];
                      if (count > 0)
                      {
                          Marshal.Copy(c_api.TF_TensorData(output), ints, 0, count);
                      }
                      return np.array(ints).reshape(dims);
                  }
                  case TF_DataType.TF_FLOAT:
                  {
                      int count = checked((int)tensor.size);
                      var floats = new float[count];
                      if (count > 0)
                      {
                          Marshal.Copy(c_api.TF_TensorData(output), floats, 0, count);
                      }
                      return np.array(floats).reshape(dims);
                  }
                  case TF_DataType.TF_DOUBLE:
                  {
                      int count = checked((int)tensor.size);
                      var doubles = new double[count];
                      if (count > 0)
                      {
                          Marshal.Copy(c_api.TF_TensorData(output), doubles, 0, count);
                      }
                      return np.array(doubles).reshape(dims);
                  }
                  default:
                      throw new NotImplementedException("can't fetch output");
              }
          }

          /// <summary>
          /// Decodes the single element of a scalar TF_STRING tensor buffer as UTF-8.
          /// </summary>
          /// <remarks>
          /// TF 1.x C API layout: a uint64 offset table (one entry per element; a single entry for a scalar),
          /// followed by each string encoded as a varint length and then its bytes. The previous hard-coded
          /// offset of 9 (8-byte table + 1-byte varint) was only correct for strings shorter than 128 bytes.
          /// NOTE(review): assumes a scalar tensor, as the original code did.
          /// </remarks>
          private static string DecodeScalarString(byte[] bytes)
          {
              const int offsetTableSize = sizeof(ulong);
              int pos = offsetTableSize + checked((int)BitConverter.ToUInt64(bytes, 0));

              // Read the varint-encoded byte length: 7 bits per byte, high bit set means "more bytes follow".
              ulong length = 0;
              int shift = 0;
              byte b;
              do
              {
                  b = bytes[pos++];
                  length |= (ulong)(b & 0x7F) << shift;
                  shift += 7;
              } while ((b & 0x80) != 0 && shift < 64);

              if (length > (ulong)(bytes.Length - pos))
              {
                  throw new InvalidOperationException("Malformed TF_STRING tensor buffer.");
              }

              return Encoding.UTF8.GetString(bytes, pos, (int)length);
          }

          /// <summary>
          /// If a tensor handle that is fed to a device incompatible placeholder,
          /// we move the tensor to the right device, generate a new tensor handle,
          /// and update feed_dict to use the new handle.
          /// </summary>
          /// <remarks>Not implemented in this port: always returns an empty list.</remarks>
          private List<object> _update_with_movers()
          {
              return new List<object>();
          }

          /// <summary>
          /// Hook for propagating graph changes to the runtime before a run; currently a no-op.
          /// </summary>
          private void _extend_graph()
          {
          }
      }
  }

tensorflow框架的.NET版本,提供了丰富的特性和API,可以借此很方便地在.NET平台下搭建深度学习训练与推理流程。