/*****************************************************************************
   Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved.

   Licensed under the Apache License, Version 2.0 (the "License");
   you may not use this file except in compliance with the License.
   You may obtain a copy of the License at

       http://www.apache.org/licenses/LICENSE-2.0

   Unless required by applicable law or agreed to in writing, software
   distributed under the License is distributed on an "AS IS" BASIS,
   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
   See the License for the specific language governing permissions and
   limitations under the License.
******************************************************************************/

using Google.Protobuf;
using Tensorflow.NumPy;
using System;
using System.Collections;
using System.Collections.Generic;
using System.Linq;
using System.Numerics;
using System.Text;
using Tensorflow.Util;
using static Tensorflow.Binding;
namespace Tensorflow
{
    public class BaseSession : DisposableObject
    {
        protected Graph _graph;
        public Graph graph => _graph;

        public BaseSession(IntPtr handle, Graph g)
        {
            _handle = handle;
            _graph = g ?? ops.get_default_graph();
        }
        public BaseSession(string target = "", Graph g = null, ConfigProto config = null, Status status = null)
        {
            _graph = g ?? ops.get_default_graph();
            if (!_graph.building_function)
            {
                if (ops.get_default_graph() != _graph)
                    _graph.as_default();
            }

            // Create the native session for the chosen graph.
            using var opts = new SessionOptions(target, config);
            status = status ?? tf.Status;
            _handle = c_api.TF_NewSession(_graph, opts.Handle, status.Handle);
            status.Check(true);
        }
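
        // Example (illustrative sketch, not part of this class): the constructor above is
        // normally reached through the derived Session type rather than used directly.
        // A hypothetical caller that pins the session to a specific graph and config:
        //
        //     var graph = new Graph().as_default();
        //     var config = new ConfigProto { LogDevicePlacement = true };
        //     using var sess = new Session(graph, config);   // Session derives from BaseSession
        //
        // `Session`, `ConfigProto.LogDevicePlacement`, and the exact constructor shape are
        // assumptions based on common TensorFlow.NET usage; consult the Session subclass for
        // the signatures actually exposed.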
        public virtual void run(Operation op, params FeedItem[] feed_dict)
        {
            _run(op, feed_dict);
        }

        public virtual NDArray run(Tensor fetche, params FeedItem[] feed_dict)
        {
            return _run(fetche, feed_dict)[0];
        }

        public virtual NDArray run(ITensorOrOperation fetche, params FeedItem[] feed_dict)
        {
            var results = _run(fetche, feed_dict);
            // Operations have no fetchable value, so only tensors produce a result.
            return fetche is Tensor ? results[0] : null;
        }

        public virtual (NDArray, NDArray, NDArray, NDArray, NDArray) run(
            (ITensorOrOperation, ITensorOrOperation, ITensorOrOperation, ITensorOrOperation, ITensorOrOperation) fetches,
            params FeedItem[] feed_dict)
        {
            var results = _run(new object[] { fetches.Item1, fetches.Item2, fetches.Item3, fetches.Item4, fetches.Item5 }, feed_dict);
            return (results[0], results[1], results[2], results[3], results[4]);
        }

        public virtual (NDArray, NDArray, NDArray, NDArray) run((ITensorOrOperation, ITensorOrOperation, ITensorOrOperation, ITensorOrOperation) fetches, params FeedItem[] feed_dict)
        {
            var results = _run(new object[] { fetches.Item1, fetches.Item2, fetches.Item3, fetches.Item4 }, feed_dict);
            return (results[0], results[1], results[2], results[3]);
        }

        public virtual (NDArray, NDArray, NDArray) run((ITensorOrOperation, ITensorOrOperation, ITensorOrOperation) fetches, params FeedItem[] feed_dict)
        {
            var results = _run(new object[] { fetches.Item1, fetches.Item2, fetches.Item3 }, feed_dict);
            return (results[0], results[1], results[2]);
        }

        public virtual (NDArray, NDArray) run((ITensorOrOperation, ITensorOrOperation) fetches, params FeedItem[] feed_dict)
        {
            var results = _run(new object[] { fetches.Item1, fetches.Item2 }, feed_dict);
            return (results[0], results[1]);
        }

        public virtual NDArray[] run(object fetches, params FeedItem[] feed_dict)
        {
            return _run(fetches, feed_dict);
        }

        public virtual NDArray[] run(object fetches, Hashtable feed_dict = null)
        {
            var feed_items = feed_dict == null
                ? new FeedItem[0]
                : feed_dict.Keys.OfType<object>().Select(key => new FeedItem(key, feed_dict[key])).ToArray();
            return _run(fetches, feed_items);
        }
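
        // Example (illustrative sketch): feeding placeholders and fetching a tensor through
        // the overloads above. `tf.placeholder`, `tf.constant`, `tf.Session()` and the tensor
        // arithmetic operators are assumed to be available as in typical TensorFlow.NET
        // graph-mode code.
        //
        //     var x = tf.placeholder(tf.float32);
        //     var y = tf.placeholder(tf.float32);
        //     var z = x + y;
        //     using var sess = tf.Session();
        //     NDArray result = sess.run(z, new FeedItem(x, 2.0f), new FeedItem(y, 3.0f));
        //     // result holds 5.0f
        //
        // The tuple overloads return one NDArray per fetched tensor, e.g.
        //
        //     var (a, b) = sess.run((z, x), new FeedItem(x, 2.0f), new FeedItem(y, 3.0f));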
        private NDArray[] _run(object fetches, FeedItem[] feed_dict = null)
        {
            var feed_dict_tensor = new Dictionary<object, object>();
            //var feed_map = new Dictionary<object, object>();

            // Validate and process feed_dict.
            if (feed_dict != null)
            {
                foreach (var subfeed in feed_dict)
                {
                    var subfeed_t = _graph.as_graph_element(subfeed.Key, allow_tensor: true, allow_operation: false);
                    //var target_dtype = subfeed_t.dtype.as_numpy_typecode(); // subfeed_dtype was never used
                    feed_dict_tensor[subfeed_t] = subfeed.Value;
                    //feed_map[subfeed_t.name] = (subfeed_t, subfeed.Value);
                }
            }

            // Create a fetch handler to take care of the structure of fetches.
            var fetch_handler = new _FetchHandler(_graph, fetches, feed_dict_tensor);

            // Run request and get response.
            // We need to keep the returned movers alive for the following _do_run().
            // These movers are no longer needed when _do_run() completes, and
            // are deleted when `movers` goes out of scope when this _run() ends.
            var _ = _update_with_movers();
            var final_fetches = fetch_handler.fetches();
            var final_targets = fetch_handler.targets();

            // We only want to really perform the run if fetches or targets are provided,
            // or if the call is a partial run that specifies feeds.
            var results = _do_run(final_targets.Select(x => (Operation)x).ToList(), final_fetches, feed_dict_tensor);
            return fetch_handler.build_results(this, results);
        }
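
        // Example (illustrative sketch): feed keys are resolved through graph.as_graph_element,
        // so a feed may reference a tensor either by object or, hypothetically, by its graph name.
        // Name-based lookup is an assumption mirroring the Python API; verify that
        // as_graph_element accepts strings in this port before relying on it.
        //
        //     sess.run(z, new FeedItem(x, 1.0f));                 // key is the Tensor itself
        //     sess.run(z, new FeedItem("Placeholder:0", 1.0f));   // key is a tensor name (assumed)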
        /// <summary>
        /// Runs a step based on the given fetches and feeds.
        /// </summary>
        /// <param name="target_list">A list of operations to be run, but not fetched.</param>
        /// <param name="fetch_list">A list of tensors to be fetched.</param>
        /// <param name="feed_dict">A dictionary that maps graph elements to values.</param>
        /// <returns>
        /// A list of numpy ndarrays, corresponding to the elements of
        /// `fetch_list`. If the ith element of `fetch_list` contains the
        /// name of an operation, the first Tensor output of that operation
        /// will be returned for that element.
        /// </returns>
        private NDArray[] _do_run(List<Operation> target_list, List<Tensor> fetch_list, Dictionary<object, object> feed_dict)
        {
            // Convert each feed into a (TF_Output, Tensor) pair understood by the C API.
            var feeds = new KeyValuePair<TF_Output, Tensor>[feed_dict.Count];
            int i = 0;
            foreach (var x in feed_dict)
            {
                if (x.Key is Tensor key)
                {
                    switch (x.Value)
                    {
                        case Tensor v:
                            if (v.dtype != key.dtype)
                                throw new ValueError($"Tensor {v} does not match the expected dtype {key.dtype}, actual dtype: {v.dtype}");
                            feeds[i++] = new KeyValuePair<TF_Output, Tensor>(key._as_tf_output(), v);
                            break;
                        case SafeTensorHandle v:
                            var tensor = new Tensor(v);
                            if (tensor.dtype != key.dtype)
                                throw new ValueError($"Tensor {v} does not match the expected dtype {key.dtype}, actual dtype: {tensor.dtype}");
                            feeds[i++] = new KeyValuePair<TF_Output, Tensor>(key._as_tf_output(), tensor);
                            break;
                        case bool v:
                            feeds[i++] = new KeyValuePair<TF_Output, Tensor>(key._as_tf_output(), new Tensor(v));
                            break;
                        case byte v:
                            feeds[i++] = new KeyValuePair<TF_Output, Tensor>(key._as_tf_output(), new Tensor(v));
                            break;
                        case int v:
                            feeds[i++] = new KeyValuePair<TF_Output, Tensor>(key._as_tf_output(), new Tensor(v));
                            break;
                        case long v:
                            feeds[i++] = new KeyValuePair<TF_Output, Tensor>(key._as_tf_output(), new Tensor(v));
                            break;
                        case float v:
                            feeds[i++] = new KeyValuePair<TF_Output, Tensor>(key._as_tf_output(), new Tensor(v));
                            break;
                        case double v:
                            feeds[i++] = new KeyValuePair<TF_Output, Tensor>(key._as_tf_output(), new Tensor(v));
                            break;
                        case string v:
                            feeds[i++] = new KeyValuePair<TF_Output, Tensor>(key._as_tf_output(), new Tensor(v));
                            break;
                        case Array v:
                            feeds[i++] = new KeyValuePair<TF_Output, Tensor>(key._as_tf_output(), new Tensor(v, v.GetShape()));
                            break;
                        default:
                            throw new NotImplementedException($"Unsupported feed value type: {x.Value?.GetType()}");
                    }
                }
                else
                    throw new NotImplementedException($"Unsupported feed key type: {x.Key.GetType()}");
            }

            var fetches = fetch_list.Select(x => x._as_tf_output()).ToArray();
            //var targets = target_list;
            return _call_tf_sessionrun(feeds, fetches, target_list);
        }
        private unsafe NDArray[] _call_tf_sessionrun(KeyValuePair<TF_Output, Tensor>[] feed_dict, TF_Output[] fetch_list, List<Operation> target_list)
        {
            // Ensure any changes to the graph are reflected in the runtime.
            _extend_graph();

            var status = tf.Status;
            var output_values = fetch_list.Select(x => IntPtr.Zero).ToArray();

            c_api.TF_SessionRun(_handle,
                run_options: null,
                inputs: feed_dict.Select(f => f.Key).ToArray(),
                input_values: feed_dict.Select(f => f.Value.Handle.DangerousGetHandle()).ToArray(),
                ninputs: feed_dict.Length,
                outputs: fetch_list,
                output_values: output_values,
                noutputs: fetch_list.Length,
                target_opers: target_list.Select(f => (IntPtr)f).ToArray(),
                ntargets: target_list.Count,
                run_metadata: IntPtr.Zero,
                status: status.Handle);

            status.Check(true);

            var result = new NDArray[fetch_list.Length];
            for (int i = 0; i < fetch_list.Length; i++)
                result[i] = fetchValue(new SafeTensorHandle(output_values[i]));

            return result;
        }
        public unsafe Tensor eval(Tensor tensor)
        {
            var status = tf.Status;
            var output_values = new IntPtr[1];
            var fetch_list = new[] { tensor._as_tf_output() };

            c_api.TF_SessionRun(_handle,
                run_options: null,
                inputs: new TF_Output[0],
                input_values: new IntPtr[0],
                ninputs: 0,
                outputs: fetch_list,
                output_values: output_values,
                noutputs: 1,
                target_opers: new IntPtr[0],
                ntargets: 0,
                run_metadata: IntPtr.Zero,
                status: status.Handle);

            status.Check(true);
            return new Tensor(new SafeTensorHandle(output_values[0]));
        }
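
        // Example (illustrative sketch): eval runs a single tensor with no feeds or targets and
        // hands back the raw output Tensor rather than an NDArray; similar in spirit to
        // sess.run(c), but without going through the fetch handler.
        //
        //     var c = tf.constant(42);
        //     Tensor value = sess.eval(c);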
        private static unsafe NDArray fetchValue(SafeTensorHandle output)
        {
            var tensor = new Tensor(output);
            return tensor.numpy();
        }

        /// <summary>
        /// If a tensor handle is fed to a device-incompatible placeholder,
        /// we move the tensor to the right device, generate a new tensor handle,
        /// and update feed_dict to use the new handle.
        /// </summary>
        private List<object> _update_with_movers()
        {
            return new List<object> { };
        }

        private void _extend_graph()
        { }

        protected override void DisposeUnmanagedResources(IntPtr handle)
        {
            // c_api.TF_CloseSession(handle, tf.Status.Handle);
            // Note: the status created here is discarded without being checked.
            c_api.TF_DeleteSession(handle, c_api.TF_NewStatus());
        }
    }
}