You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, may include dashes ('-'), and can be up to 35 characters long.

BaseSession.cs 10 kB

6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250
  1. using NumSharp.Core;
  2. using System;
  3. using System.Collections.Generic;
  4. using System.Linq;
  5. using System.Runtime.InteropServices;
  6. using System.Text;
  7. namespace Tensorflow
  8. {
  /// <summary>
  /// Base class for sessions that execute a <see cref="Graph"/> through the TensorFlow C API.
  /// Mirrors the structure of the Python <c>BaseSession</c>; only the subset of features
  /// implemented below is supported.
  /// </summary>
  public class BaseSession
  {
      protected Graph _graph;
      protected bool _opened;
      protected bool _closed;
      protected int _current_version;
      protected byte[] _target;
      protected IntPtr _session;

      /// <summary>
      /// Creates a native session bound to <paramref name="graph"/>.
      /// </summary>
      /// <param name="target">
      /// Execution target. It is stored UTF-8 encoded in <see cref="_target"/> but is not
      /// currently passed to the native session options. <c>null</c> is treated as "".
      /// </param>
      /// <param name="graph">Graph to execute; when <c>null</c>, <c>ops.get_default_graph()</c> is used.</param>
      /// <remarks>Throws via <c>Status.Check(true)</c> if the native session cannot be created.</remarks>
      public BaseSession(string target = "", Graph graph = null)
      {
          _graph = graph ?? ops.get_default_graph();
          _target = Encoding.UTF8.GetBytes(target ?? string.Empty);

          var opts = c_api.TF_NewSessionOptions();
          var status = new Status();
          _session = c_api.TF_NewSession(_graph, opts, status);
          // Options are only needed for creation; free them before a failed status can throw,
          // so they are not leaked on the error path.
          c_api.TF_DeleteSessionOptions(opts);
          status.Check(true);
      }

      /// <summary>
      /// Runs the graph to evaluate <paramref name="fetches"/>, substituting the values in
      /// <paramref name="feed_dict"/> for the corresponding graph elements.
      /// </summary>
      public virtual NDArray run(object fetches, params FeedItem[] feed_dict)
      {
          return _run(fetches, feed_dict);
      }

      /// <summary>
      /// Validates the feeds, resolves fetches/targets through <c>_FetchHandler</c>, executes the
      /// step (only when there is something to fetch or run), and builds the result.
      /// </summary>
      private NDArray _run(object fetches, FeedItem[] feed_dict = null)
      {
          var feed_dict_tensor = new Dictionary<object, object>();
          var feed_map = new Dictionary<object, object>();

          // Each FeedItem currently expands to exactly one (key, value) pair.
          Func<FeedItem, IEnumerable<(object, object)>> feed_fn = item => new (object, object)[] { (item.Key, item.Value) };

          // Validate and process feed_dict.
          if (feed_dict != null)
          {
              foreach (var feed in feed_dict)
              {
                  foreach (var (subfeed, subfeed_val) in feed_fn(feed))
                  {
                      var subfeed_t = _graph.as_graph_element(subfeed, allow_tensor: true, allow_operation: false);
                      feed_dict_tensor[subfeed_t] = _normalize_feed_value(subfeed_val);
                      feed_map[subfeed_t.name] = (subfeed_t, subfeed_val);
                  }
              }
          }

          // Create a fetch handler to take care of the structure of fetches.
          var fetch_handler = new _FetchHandler(_graph, fetches, feed_dict_tensor);

          // Device movers are not implemented yet (always empty); the call is kept so the
          // flow matches the Python implementation.
          _update_with_movers();

          var final_fetches = fetch_handler.fetches();
          var final_targets = fetch_handler.targets();

          // Only perform the run if there is something to fetch or execute.
          NDArray[] results = final_fetches.Any() || final_targets.Any()
              ? _do_run(final_targets.Select(x => (Operation)(object)x).ToList(), final_fetches, feed_dict_tensor)
              : new NDArray[0];

          return fetch_handler.build_results(this, results);
      }

      /// <summary>
      /// Converts a user-supplied feed value into a form <see cref="_do_run"/> can turn into a Tensor.
      /// </summary>
      /// <exception cref="NotImplementedException">The value's type is not supported.</exception>
      private static object _normalize_feed_value(object value)
      {
          switch (value)
          {
              case IntPtr pointer:
                  return pointer;
              case NDArray nd:
                  return nd;
              case float floatVal:
                  return (NDArray)floatVal;
              case int intVal:
                  return (NDArray)intVal;
              case double doubleVal:
                  // Passed through unchanged: _do_run builds a Tensor from a double directly.
                  return doubleVal;
              case string str:
                  return (NDArray)str;
              case byte[] bytes:
                  return (NDArray)bytes;
              default:
                  throw new NotImplementedException("_run subfeed");
          }
      }

      /// <summary>
      /// Runs a step based on the given fetches and feeds.
      /// </summary>
      /// <param name="target_list">A list of operations to be run, but not fetched.</param>
      /// <param name="fetch_list">Tensors whose values are returned.</param>
      /// <param name="feed_dict">Map from graph tensor to feed value (IntPtr, Tensor, NDArray, int, float or double).</param>
      /// <returns>
      /// One NDArray per element of <paramref name="fetch_list"/>, in the same order.
      /// </returns>
      private NDArray[] _do_run(List<Operation> target_list, List<Tensor> fetch_list, Dictionary<object, object> feed_dict)
      {
          // NOTE(review): Tensors created here from NDArray/int/float/double values are never
          // explicitly released; confirm whether Tensor frees its native buffer itself.
          var feeds = feed_dict.Select(x =>
          {
              if (x.Key is Tensor tensor)
              {
                  switch (x.Value)
                  {
                      case IntPtr pointer:
                          return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), pointer);
                      case Tensor t1:
                          return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), t1);
                      case NDArray nd:
                          return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(nd));
                      case int intVal:
                          return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(intVal));
                      case float floatVal:
                          return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(floatVal));
                      case double doubleVal:
                          return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(doubleVal));
                      default:
                          throw new NotImplementedException("feed_dict data type");
                  }
              }
              throw new NotImplementedException("_do_run.feed_dict");
          }).ToArray();

          var fetches = fetch_list.Select(x => x._as_tf_output()).ToArray();
          return _call_tf_sessionrun(feeds, fetches, target_list);
      }

      /// <summary>
      /// Invokes <c>TF_SessionRun</c> and converts every fetched native tensor into an NDArray.
      /// </summary>
      /// <remarks>Throws via <c>Status.Check(true)</c> when the native run fails.</remarks>
      private NDArray[] _call_tf_sessionrun(KeyValuePair<TF_Output, Tensor>[] feed_dict, TF_Output[] fetch_list, List<Operation> target_list)
      {
          // Ensure any changes to the graph are reflected in the runtime.
          _extend_graph();

          var status = new Status();
          var output_values = fetch_list.Select(x => IntPtr.Zero).ToArray();

          c_api.TF_SessionRun(_session,
              run_options: null,
              inputs: feed_dict.Select(f => f.Key).ToArray(),
              input_values: feed_dict.Select(f => (IntPtr)f.Value).ToArray(),
              ninputs: feed_dict.Length,
              outputs: fetch_list,
              output_values: output_values,
              noutputs: fetch_list.Length,
              target_opers: target_list.Select(f => (IntPtr)f).ToArray(),
              ntargets: target_list.Count,
              run_metadata: IntPtr.Zero,
              status: status);

          status.Check(true);

          // NOTE(review): TF_SessionRun hands ownership of output_values to the caller; confirm
          // that the Tensor wrapper created in fetchValue deletes them.
          var result = new NDArray[fetch_list.Length];
          for (int i = 0; i < fetch_list.Length; i++)
          {
              result[i] = fetchValue(output_values[i]);
          }
          return result;
      }

      /// <summary>
      /// Copies the contents of a native TF_Tensor into a managed NDArray with the tensor's shape.
      /// </summary>
      /// <exception cref="NotImplementedException">The tensor's dtype is not supported.</exception>
      private NDArray fetchValue(IntPtr output)
      {
          var tensor = new Tensor(output);
          var ndims = tensor.shape.Select(x => (int)x).ToArray();
          var data = c_api.TF_TensorData(output);

          // Numeric tensors are stored contiguously, so a bulk Marshal.Copy replaces the
          // per-element pointer loop (and its int-overflowing offset arithmetic).
          switch (tensor.dtype)
          {
              case TF_DataType.TF_STRING:
                  return np.array(_decode_scalar_string(tensor.Data())).reshape();
              case TF_DataType.TF_INT16:
                  var shorts = new short[tensor.size];
                  Marshal.Copy(data, shorts, 0, shorts.Length);
                  return np.array(shorts).reshape(ndims);
              case TF_DataType.TF_INT32:
                  var ints = new int[tensor.size];
                  Marshal.Copy(data, ints, 0, ints.Length);
                  return np.array(ints).reshape(ndims);
              case TF_DataType.TF_FLOAT:
                  var floats = new float[tensor.size];
                  Marshal.Copy(data, floats, 0, floats.Length);
                  return np.array(floats).reshape(ndims);
              case TF_DataType.TF_DOUBLE:
                  var doubles = new double[tensor.size];
                  Marshal.Copy(data, doubles, 0, doubles.Length);
                  return np.array(doubles).reshape(ndims);
              default:
                  throw new NotImplementedException("can't fetch output");
          }
      }

      /// <summary>
      /// Decodes the single element of a scalar TF_STRING tensor buffer as UTF-8.
      /// </summary>
      /// <remarks>
      /// Layout assumed (TensorFlow 1.x C API): an 8-byte offset entry, then a varint64 byte
      /// length, then the raw bytes. The length is a varint, so strings of 128 bytes or more
      /// use multiple length bytes.
      /// </remarks>
      /// <exception cref="FormatException">The length prefix is not a valid varint64.</exception>
      private static string _decode_scalar_string(byte[] buffer)
      {
          const int offsetTableSize = 8; // one uint64 offset entry for a scalar
          int pos = offsetTableSize;
          ulong length = 0;
          for (int shift = 0; ; shift += 7)
          {
              if (shift > 63)
              {
                  throw new FormatException("Malformed varint length in TF_STRING tensor.");
              }
              byte b = buffer[pos++];
              length |= (ulong)(b & 0x7F) << shift;
              if ((b & 0x80) == 0)
              {
                  break;
              }
          }
          return Encoding.UTF8.GetString(buffer, pos, checked((int)length));
      }

      /// <summary>
      /// If a tensor handle that is fed to a device incompatible placeholder,
      /// we move the tensor to the right device, generate a new tensor handle,
      /// and update feed_dict to use the new handle.
      /// Not implemented yet: always returns an empty list.
      /// </summary>
      private List<object> _update_with_movers()
      {
          return new List<object> { };
      }

      /// <summary>
      /// Placeholder for propagating graph changes to the runtime; currently does nothing.
      /// </summary>
      private void _extend_graph()
      {
      }
  }
  222. }

TensorFlow 框架的 .NET 版本，提供了丰富的特性和 API，可以借此很方便地在 .NET 平台下搭建深度学习训练与推理流程。