You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

ops.py.cs 21 kB

6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499
  1. using System;
  2. using System.Collections.Generic;
  3. using System.Runtime.InteropServices;
  4. using System.Text;
  5. using System.Threading;
  6. using Tensorflow;
  7. using Google.Protobuf;
  8. using System.Linq;
  9. using NumSharp;
  10. using System.ComponentModel;
  11. using static Tensorflow.Python;
  12. namespace Tensorflow
  13. {
  14. public partial class ops
  15. {
  /// <summary>
  /// Adds <paramref name="value"/> to the collection named <paramref name="name"/>
  /// on the graph returned by <c>tf.get_default_graph()</c>.
  /// </summary>
  /// <typeparam name="T">Type of the stored value.</typeparam>
  /// <param name="name">Key of the target collection.</param>
  /// <param name="value">Value to add.</param>
  public static void add_to_collection<T>(string name, T value)
      => tf.get_default_graph().add_to_collection(name, value);
  /// <summary>
  /// Adds <paramref name="value"/> to every collection listed in <paramref name="names"/>
  /// on the graph returned by <c>tf.get_default_graph()</c>.
  /// </summary>
  /// <typeparam name="T">Type of the stored value.</typeparam>
  /// <param name="names">Keys of the target collections.</param>
  /// <param name="value">Value to add.</param>
  public static void add_to_collections<T>(List<string> names, T value)
      => tf.get_default_graph().add_to_collections(names, value);
  /// <summary>
  /// Wrapper for `Graph.get_collection()` using the default graph.
  /// </summary>
  /// <param name="key">
  /// The key for the collection. For example, the `GraphKeys` class
  /// contains many standard names for collections.
  /// </param>
  /// <param name="scope">Optional scope, forwarded unchanged to `Graph.get_collection`.</param>
  /// <returns>
  /// The list of values in the collection with the given `key`, or
  /// an empty list if no value has been added to that collection. The
  /// list contains the values in the order under which they were
  /// collected.
  /// </returns>
  public static object get_collection(string key, string scope = null)
  {
      Graph graph = get_default_graph();
      return graph.get_collection(key, scope);
  }
  /// <summary>
  /// Wrapper for `Graph.get_collection_ref()` using the default graph.
  /// </summary>
  /// <param name="key">The key for the collection.</param>
  /// <returns>Whatever the default graph's `get_collection_ref` returns for <paramref name="key"/>.</returns>
  public static object get_collection_ref(string key)
  {
      Graph graph = get_default_graph();
      return graph.get_collection_ref(key);
  }
  /// <summary>
  /// Static holder of the default graph, shared by <c>get_default_graph</c>,
  /// <c>set_default_graph</c> and <c>reset_default_graph</c>.
  /// </summary>
  public static DefaultGraphStack default_graph_stack = new DefaultGraphStack();

  /// <summary>
  /// Returns the default graph, i.e. whatever <see cref="default_graph_stack"/>
  /// currently reports through <c>get_controller()</c>.
  /// </summary>
  /// <remarks>
  /// NOTE(review): in the Python original the default graph is a per-thread property.
  /// Here it comes from a static field, so per-thread behavior depends entirely on
  /// <c>DefaultGraphStack</c>. Confirm before relying on it.
  /// TODO: the Python original reads `_default_graph_stack.get_default()`.
  /// </remarks>
  /// <returns>The current default <see cref="Graph"/>.</returns>
  public static Graph get_default_graph()
      => default_graph_stack.get_controller();
  /// <summary>
  /// Installs <paramref name="graph"/> as the default graph.
  /// </summary>
  /// <param name="graph">Graph to register on <see cref="default_graph_stack"/>.</param>
  /// <returns>The graph the stack reports after the update.</returns>
  /// <remarks>
  /// TODO: the Python original has no `set_default_graph`; it relies on a `_default_graph_stack`.
  /// </remarks>
  public static Graph set_default_graph(Graph graph)
  {
      default_graph_stack.set_controller(graph);
      Graph current = default_graph_stack.get_controller();
      return current;
  }
  /// <summary>
  /// Clears the default graph stack and resets the global default graph.
  ///
  /// Calling this while a `tf.Session` or `tf.InteractiveSession` is active results in
  /// undefined behavior. Using any previously created `tf.Operation` or `tf.Tensor`
  /// objects after calling it also results in undefined behavior.
  /// </summary>
  /// <remarks>
  /// TODO: the Python original first refuses to run unless the graph stack is cleared
  /// ("Do not use tf.reset_default_graph() to clear nested graphs..."). That guard is not
  /// ported yet; this simply calls <c>default_graph_stack.reset()</c>.
  /// </remarks>
  public static void reset_default_graph()
      => default_graph_stack.reset();
  /// <summary>
  /// <c>params</c> convenience overload of <c>_get_graph_from_inputs</c> that passes no explicit graph.
  /// </summary>
  /// <param name="op_input_list">Inputs of the op being built.</param>
  /// <returns>The graph chosen by the two-argument overload.</returns>
  public static Graph _get_graph_from_inputs(params Tensor[] op_input_list)
  {
      return _get_graph_from_inputs(op_input_list: op_input_list, graph: null);
  }
  /// <summary>
  /// Returns the graph in which an op built from <paramref name="op_input_list"/> should be created.
  /// </summary>
  /// <param name="op_input_list">Inputs of the op being built (currently not inspected).</param>
  /// <param name="graph">Explicit graph to use; when null the default graph is returned.</param>
  /// <returns><paramref name="graph"/> if supplied, otherwise the result of <c>get_default_graph()</c>.</returns>
  /// <remarks>
  /// TODO: the inputs are not inspected yet, so mixing tensors from different graphs is not detected.
  /// </remarks>
  public static Graph _get_graph_from_inputs(Tensor[] op_input_list, Graph graph = null)
  {
      // An explicitly requested graph wins; it used to be silently ignored, and the
      // former loop over the inputs did nothing.
      return graph ?? get_default_graph();
  }
  /// <summary>
  /// Converts the given `value` to a `Tensor`.
  /// </summary>
  /// <param name="value">Object to convert; see <c>internal_convert_to_tensor</c> for the accepted kinds.</param>
  /// <param name="dtype">Requested element type, or <c>DtInvalid</c> for none.</param>
  /// <param name="name">Optional name for any newly created op.</param>
  /// <param name="preferred_dtype">Forwarded as <c>dtype_hint</c> to <c>convert_to_tensor_v2</c>.</param>
  /// <returns>The converted tensor.</returns>
  public static Tensor convert_to_tensor(object value, TF_DataType dtype = TF_DataType.DtInvalid, string name = null, TF_DataType preferred_dtype = TF_DataType.DtInvalid)
      => convert_to_tensor_v2(value, dtype: dtype, dtype_hint: preferred_dtype, name: name);
  /// <summary>
  /// Converts <paramref name="value"/> to a non-reference tensor, passing
  /// <paramref name="dtype_hint"/> as the preferred type.
  /// </summary>
  /// <param name="value">Object to convert.</param>
  /// <param name="dtype">Requested element type, or <c>DtInvalid</c>.</param>
  /// <param name="dtype_hint">Forwarded as <c>preferred_dtype</c>.</param>
  /// <param name="name">Optional name for any newly created op.</param>
  /// <returns>The converted tensor.</returns>
  public static Tensor convert_to_tensor_v2(object value, TF_DataType dtype = TF_DataType.DtInvalid, TF_DataType dtype_hint = TF_DataType.DtInvalid, string name = null)
      => internal_convert_to_tensor(value, dtype: dtype, name: name, preferred_dtype: dtype_hint, as_ref: false);
  /// <summary>
  /// Non-reference variant of <c>internal_convert_to_tensor_or_composite</c>.
  /// </summary>
  /// <param name="value">Tensor to convert.</param>
  /// <param name="dtype">Requested element type, or <c>DtInvalid</c>.</param>
  /// <param name="name">Optional op name.</param>
  /// <returns>The converted tensor.</returns>
  public static Tensor convert_to_tensor_or_composite(Tensor value, TF_DataType dtype = TF_DataType.DtInvalid, string name = null)
      => internal_convert_to_tensor_or_composite(value: value, dtype: dtype, name: name, as_ref: false);
  /// <summary>
  /// Routes <paramref name="value"/> through <c>internal_convert_to_tensor</c>. A non-null
  /// tensor hits that method's <c>Tensor</c> case and comes back unchanged.
  /// </summary>
  /// <param name="value">Tensor to convert.</param>
  /// <param name="dtype">Requested element type, or <c>DtInvalid</c>.</param>
  /// <param name="name">Optional op name.</param>
  /// <param name="as_ref">Forwarded unchanged.</param>
  /// <returns>The converted tensor.</returns>
  public static Tensor internal_convert_to_tensor_or_composite(Tensor value, TF_DataType dtype = TF_DataType.DtInvalid, string name = null, bool as_ref = false)
      => internal_convert_to_tensor(value, dtype: dtype, name: name, as_ref: as_ref);
  /// <summary>
  /// Wrapper for `Graph.control_dependencies()` using the default graph.
  /// See `tf.Graph.control_dependencies` for more details.
  /// </summary>
  /// <param name="control_inputs">
  /// A list of `Operation` or `Tensor` objects that must be executed or computed before
  /// running the operations defined in the context. Can also be `null` to clear the
  /// control dependencies.
  /// </param>
  /// <returns>
  /// A controller that specifies control dependencies for all operations constructed
  /// within its scope.
  /// </returns>
  public static _ControlDependenciesController control_dependencies(object[] control_inputs)
      => get_default_graph().control_dependencies(control_inputs);
  /// <summary>
  /// Typed overload of <c>control_dependencies</c>: copies the inputs into an <c>object[]</c>
  /// and forwards. A null array stays null. Because of <c>OfType</c>, null elements are dropped.
  /// </summary>
  /// <param name="control_inputs">Operations/tensors to depend on, or null.</param>
  /// <returns>The controller produced by the <c>object[]</c> overload.</returns>
  public static _ControlDependenciesController control_dependencies(ITensorOrOperation[] control_inputs)
  {
      object[] boxed = null;
      if (control_inputs != null)
          boxed = control_inputs.OfType<object>().ToArray();
      return control_dependencies(boxed);
  }
  /// <summary>
  /// Creates a TF_Operation.
  /// </summary>
  /// <typeparam name="T">Element type of <paramref name="inputs"/>; each element must be a <c>Tensor</c> or a <c>Tensor[]</c>.</typeparam>
  /// <param name="graph">a `Graph`.</param>
  /// <param name="node_def">`node_def_pb2.NodeDef` for the operation to create; its Op, Name and Attr are used.</param>
  /// <param name="inputs">
  /// A list of `Tensor`s (corresponding to scalar inputs) and lists of
  /// `Tensor`s (corresponding to sequence inputs, e.g. "int64 * N",
  /// "list(int64)"). The length of the list should be equal to the number of
  /// inputs specified by this operation's op def.
  /// </param>
  /// <param name="control_inputs">A list of `Operation`s to set as control dependencies; null means none.</param>
  /// <returns>A tuple of the finished operation handle and the operation-description handle.</returns>
  /// <exception cref="NotImplementedException">An element of <paramref name="inputs"/> is neither a <c>Tensor</c> nor a <c>Tensor[]</c>.</exception>
  public static (IntPtr, IntPtr) _create_c_op<T>(Graph graph, NodeDef node_def, T[] inputs, Operation[] control_inputs)
  {
      var op_desc = graph.NewOperation(node_def.Op, node_def.Name);

      // Add inputs
      foreach (var op_input in inputs)
      {
          if (op_input is Tensor[] op_inputs)
          {
              c_api.TF_AddInputList(op_desc, op_inputs.Select(x => x._as_tf_output()).ToArray(), op_inputs.Length);
          }
          else if (op_input is Tensor op_input1)
          {
              // NOTE(review): a tensor without a producing op is wired to output 0 of the
              // op being created; kept as-is, confirm intent.
              if (op_input1.op == null)
                  c_api.TF_AddInput(op_desc, new TF_Output(op_desc, 0));
              else
                  c_api.TF_AddInput(op_desc, op_input1._as_tf_output());
          }
          else
          {
              throw new NotImplementedException("_create_c_op");
          }
      }

      var status = new Status();

      // Add control inputs
      if (control_inputs != null)
      {
          foreach (var control_input in control_inputs)
              c_api.TF_AddControlInput(op_desc, control_input);
      }

      // Add attrs
      foreach (var attr in node_def.Attr)
      {
          var bytes = attr.Value.ToByteArray();
          var proto = Marshal.AllocHGlobal(bytes.Length);
          try
          {
              Marshal.Copy(bytes, 0, proto, bytes.Length);
              c_api.TF_SetAttrValueProto(op_desc, attr.Key, proto, proto_len: (uint)bytes.Length, status: status);
              status.Check(true);
          }
          finally
          {
              // TF_SetAttrValueProto parses the serialized AttrValue during the call, so the
              // unmanaged copy can be released right away. It previously leaked once per attr.
              Marshal.FreeHGlobal(proto);
          }
      }

      var c_op = c_api.TF_FinishOperation(op_desc, status);
      status.Check(true);
      return (c_op, op_desc);
  }
  /// <summary>
  /// Returns the op definition that <paramref name="graph"/> reports for op type <paramref name="type"/>.
  /// </summary>
  /// <param name="graph">Graph to query.</param>
  /// <param name="type">Op type name.</param>
  /// <returns>The result of <c>graph.GetOpDef(type)</c>.</returns>
  public static OpDef _get_op_def(Graph graph, string type) => graph.GetOpDef(type);
  /// <summary>
  /// Builds a <see cref="NodeDef"/> protocol buffer.
  /// </summary>
  /// <param name="op_type">Value for <c>NodeDef.Op</c>.</param>
  /// <param name="name">Value for <c>NodeDef.Name</c>.</param>
  /// <param name="device">Device string for <c>NodeDef.Device</c>; null or empty leaves it unset.</param>
  /// <param name="attrs">Attributes copied into <c>NodeDef.Attr</c>; may be null.</param>
  /// <returns>The populated node definition.</returns>
  public static NodeDef _NodeDef(string op_type, string name, string device = "", Dictionary<string, AttrValue> attrs = null)
  {
      var node_def = new NodeDef
      {
          Op = op_type,
          Name = name,
      };

      if (attrs != null)
      {
          foreach (var attr in attrs)
              node_def.Attr.Add(attr.Key, attr.Value);
      }

      // `device` used to be accepted but silently dropped. Protobuf string setters
      // reject null, so only assign a real value.
      if (!string.IsNullOrEmpty(device))
          node_def.Device = device;

      return node_def;
  }
  /// <summary>
  /// Returns the name of an op given the name of its scope.
  /// </summary>
  /// <param name="name">A scope name, possibly ending in "/"; may be null or empty.</param>
  /// <returns><paramref name="name"/> with one trailing "/" removed, otherwise unchanged (including null).</returns>
  public static string _name_from_scope_name(string name)
  {
      // Null used to throw NullReferenceException; pass null/empty through unchanged.
      if (string.IsNullOrEmpty(name))
          return name;

      // "/" is a literal separator, so compare ordinally rather than culture-sensitively.
      return name.EndsWith("/", StringComparison.Ordinal)
          ? name.Substring(0, name.Length - 1)
          : name;
  }
  /// <summary>
  /// Partial port of Python's <c>init_scope</c>.
  /// </summary>
  /// <remarks>
  /// Unlike the Python context manager, this returns nothing. It reads the default graph's
  /// name scope (appending "/" when the scope is non-empty and lacks one), then passes
  /// <c>ops.control_dependencies(null)</c> to <c>with</c> along with a callback that looks
  /// up the default graph again. Neither the computed scope nor that graph is used yet.
  /// TODO: port the device-stack and outer-context handling.
  /// </remarks>
  public static void init_scope()
  {
      // Entering `init_scope` should preserve the active name scope. Names ending in a
      // slash are treated by `name_scope` as absolute.
      var default_graph = get_default_graph();
      var scope = default_graph.get_name_scope();
      bool needsTrailingSlash = !string.IsNullOrEmpty(scope) && !scope.EndsWith("/");
      if (needsTrailingSlash)
      {
          scope += "/";
      }

      with(ops.control_dependencies(null), delegate
      {
          var outer_graph = get_default_graph();
      });
  }
  private static int uid_number = 0;

  /// <summary>
  /// A unique (within this program execution) integer.
  /// The sequence starts at 0 and increases by one per call.
  /// </summary>
  /// <remarks>
  /// Uses <see cref="Interlocked.Increment(ref int)"/> so concurrent callers never receive
  /// the same value. The previous <c>uid_number++</c> was a non-atomic read-modify-write.
  /// </remarks>
  /// <returns>The next identifier.</returns>
  public static int uid()
  {
      // Increment returns the updated value; subtract one to keep the 0-based sequence.
      return Interlocked.Increment(ref uid_number) - 1;
  }
  /// <summary>
  /// Colocation request for an operation: forwards to <c>_colocate_with_for_gradient</c>
  /// with no gradient uid.
  /// </summary>
  /// <param name="op">Operation to colocate with.</param>
  /// <param name="ignore_existing">Forwarded unchanged.</param>
  public static void colocate_with(Operation op, bool ignore_existing = false)
      => _colocate_with_for_gradient(op, null, ignore_existing);
  /// <summary>
  /// Colocation request for a tensor: uses the tensor's <c>op</c> and forwards to
  /// <c>_colocate_with_for_gradient</c> with no gradient uid.
  /// </summary>
  /// <param name="tensor">Tensor whose producing op is used.</param>
  /// <param name="ignore_existing">Forwarded unchanged.</param>
  public static void colocate_with(Tensor tensor, bool ignore_existing = false)
  {
      var producer = tensor.op;
      _colocate_with_for_gradient(producer, null, ignore_existing);
  }
  /// <summary>
  /// Forwards a colocation request to the default graph's <c>_colocate_with_for_gradient</c>.
  /// </summary>
  /// <param name="op">Operation to colocate with.</param>
  /// <param name="gradient_uid">Gradient uid, or null.</param>
  /// <param name="ignore_existing">Forwarded unchanged.</param>
  public static void _colocate_with_for_gradient(Operation op, string gradient_uid, bool ignore_existing = false)
      => get_default_graph()._colocate_with_for_gradient(op, gradient_uid, ignore_existing);
  /// <summary>
  /// Evaluates a tensor using the given session, or the default session when none is given.
  /// </summary>
  /// <param name="tensor">The tensor to evaluate.</param>
  /// <param name="feed_dict">Feed items passed to <c>session.run</c>.</param>
  /// <param name="graph">The graph in which the tensor is defined.</param>
  /// <param name="session">Session to use; null selects <c>get_default_session()</c>.</param>
  /// <returns>The value returned by <c>session.run(tensor, feed_dict)</c>.</returns>
  /// <exception cref="ValueError">No session is available, or the session's graph is not <paramref name="graph"/>.</exception>
  public static NDArray _eval_using_default_session(Tensor tensor, FeedItem[] feed_dict, Graph graph, Session session = null)
  {
      if (session == null)
      {
          session = get_default_session();
          if (session == null)
              throw new ValueError("Cannot evaluate tensor using `eval()`: No default " +
                  "session is registered. Use `with " +
                  "sess.as_default()` or pass an explicit session to " +
                  "`eval(session=sess)`");
          if (session.graph != graph)
              throw new ValueError("Cannot use the default session to evaluate tensor: " +
                  "the tensor's graph is different from the session's " +
                  "graph. Pass an explicit session to " +
                  "`eval(session=sess)`.");
      }
      else if (session.graph != graph)
      {
          // The caller already passed a session explicitly, so the "default session"
          // wording and the advice to pass an explicit session did not apply here.
          throw new ValueError("Cannot use the given session to evaluate tensor: " +
              "the tensor's graph is different from the session's " +
              "graph.");
      }

      return session.run(tensor, feed_dict);
  }
  /// <summary>
  /// Returns the session stored in <c>tf.defaultSession</c>, creating it with
  /// <c>tf.Session()</c> the first time it is found to be null.
  /// </summary>
  /// <remarks>
  /// NOTE(review): the storage is <c>tf.defaultSession</c>; whether it is per-thread
  /// cannot be seen from here.
  /// </remarks>
  /// <returns>The default `Session`.</returns>
  public static Session get_default_session()
  {
      bool missing = tf.defaultSession == null;
      if (missing)
      {
          tf.defaultSession = tf.Session();
      }
      return tf.defaultSession;
  }
  /// <summary>
  /// Prepends name scope to a name.
  /// </summary>
  /// <param name="name">
  /// The name to prefix. As in the Python implementation, a leading "^" or "loc:@"
  /// marker stays in front of the scope rather than being buried behind it.
  /// </param>
  /// <param name="import_scope">
  /// Scope to prepend; one trailing "/" is ignored. Null or empty returns
  /// <paramref name="name"/> unchanged.
  /// </param>
  /// <returns>
  /// For example ("a", "s") gives "s/a", ("^a", "s/") gives "^s/a", and ("loc:@a", "s") gives "loc:@s/a".
  /// </returns>
  public static string prepend_name_scope(string name, string import_scope)
  {
      if (string.IsNullOrEmpty(import_scope))
          return name;

      if (import_scope.EndsWith("/", StringComparison.Ordinal))
          import_scope = import_scope.Substring(0, import_scope.Length - 1);

      // Mirrors the Python regex `([\^]|loc:@|^)(.*)`; "^" is checked before "loc:@".
      var body = name ?? string.Empty;
      var prefix = string.Empty;
      if (body.StartsWith("^", StringComparison.Ordinal))
          prefix = "^";
      else if (body.StartsWith("loc:@", StringComparison.Ordinal))
          prefix = "loc:@";

      return $"{prefix}{import_scope}/{body.Substring(prefix.Length)}";
  }
  /// <summary>
  /// Runs an operation using the given session, or the default session when none is given.
  /// </summary>
  /// <param name="operation">The operation to run.</param>
  /// <param name="feed_dict">Feed items passed to <c>session.run</c>.</param>
  /// <param name="graph">The graph in which the operation is defined.</param>
  /// <param name="session">Session to use; null selects <c>get_default_session()</c>.</param>
  /// <exception cref="ValueError">No session is available, or the session's graph is not <paramref name="graph"/>.</exception>
  public static void _run_using_default_session(Operation operation, FeedItem[] feed_dict, Graph graph, Session session)
  {
      if (session == null)
      {
          session = get_default_session();
          if (session == null)
              throw new ValueError("Cannot execute operation using `run()`: No default " +
                  "session is registered. Use `with " +
                  "sess.as_default():` or pass an explicit session to " +
                  "`run(session=sess)`");
          if (session.graph != graph)
              throw new ValueError("Cannot use the default session to execute operation: " +
                  "the operation's graph is different from the " +
                  "session's graph. Pass an explicit session to " +
                  "run(session=sess).");
      }
      else if (session.graph != graph)
      {
          // The session was passed explicitly, so the "default session" wording and
          // the advice to pass an explicit session did not apply here.
          throw new ValueError("Cannot use the given session to execute operation: " +
              "the operation's graph is different from the " +
              "session's graph.");
      }

      session.run(operation, feed_dict);
  }
  /// <summary>
  /// Converts each element of <paramref name="values"/> to a tensor (non-reference variant of
  /// <c>internal_convert_n_to_tensor</c>).
  /// </summary>
  /// <param name="values">Objects to convert.</param>
  /// <param name="dtype">Requested element type, or <c>DtInvalid</c>.</param>
  /// <param name="name">Base name; element i is named "{name}_{i}".</param>
  /// <returns>One tensor per input element.</returns>
  public static Tensor[] convert_n_to_tensor(object[] values, TF_DataType dtype = TF_DataType.DtInvalid, string name = null)
  {
      return internal_convert_n_to_tensor(values, dtype: dtype, name: name, as_ref: false);
  }
  /// <summary>
  /// Forwards to <c>internal_convert_n_to_tensor_or_indexed_slices</c> with its default <c>as_ref</c>.
  /// </summary>
  /// <param name="values">Tensors to convert; null entries are kept.</param>
  /// <param name="dtype">Requested element type, or <c>DtInvalid</c>.</param>
  /// <param name="name">Base name for the elements.</param>
  /// <returns>The converted tensors.</returns>
  public static Tensor[] convert_n_to_tensor_or_indexed_slices(Tensor[] values, TF_DataType dtype = TF_DataType.DtInvalid, string name = null)
  {
      return internal_convert_n_to_tensor_or_indexed_slices(values, dtype: dtype, name: name);
  }
  /// <summary>
  /// Non-reference variant of <c>internal_convert_to_tensor_or_indexed_slices</c>.
  /// </summary>
  /// <param name="value">Tensor to convert.</param>
  /// <param name="dtype">Requested element type, or <c>DtInvalid</c>.</param>
  /// <param name="name">Optional name.</param>
  /// <returns>The converted tensor.</returns>
  public static Tensor convert_to_tensor_or_indexed_slices(Tensor value, TF_DataType dtype = TF_DataType.DtInvalid, string name = null)
  {
      return internal_convert_to_tensor_or_indexed_slices(value: value, dtype: dtype, name: name, as_ref: false);
  }
  /// <summary>
  /// Identity conversion: returns <paramref name="value"/> unchanged.
  /// </summary>
  /// <remarks>
  /// <paramref name="dtype"/>, <paramref name="name"/> and <paramref name="as_ref"/> are
  /// currently ignored. In particular, no dtype compatibility check is performed.
  /// </remarks>
  /// <param name="value">Tensor to return.</param>
  /// <param name="dtype">Unused.</param>
  /// <param name="name">Unused.</param>
  /// <param name="as_ref">Unused.</param>
  /// <returns><paramref name="value"/>.</returns>
  public static Tensor internal_convert_to_tensor_or_indexed_slices(Tensor value, TF_DataType dtype = TF_DataType.DtInvalid, string name = null, bool as_ref = false)
  {
      return value;
  }
  /// <summary>
  /// Converts each element of <paramref name="values"/> with
  /// <c>internal_convert_to_tensor_or_indexed_slices</c>, keeping null entries in place.
  /// </summary>
  /// <param name="values">Tensors to convert.</param>
  /// <param name="dtype">Requested element type, or <c>DtInvalid</c>.</param>
  /// <param name="name">Base name; element i gets "{name}_{i}", or "" when <paramref name="name"/> is null/empty.</param>
  /// <param name="as_ref">Forwarded unchanged.</param>
  /// <returns>A new array of the same length as <paramref name="values"/>.</returns>
  public static Tensor[] internal_convert_n_to_tensor_or_indexed_slices(Tensor[] values, TF_DataType dtype = TF_DataType.DtInvalid, string name = null, bool as_ref = false)
  {
      var converted = new Tensor[values.Length];
      for (int index = 0; index < values.Length; index++)
      {
          Tensor item = values[index];
          if (item == null)
          {
              converted[index] = item;
              continue;
          }

          string elementName = string.IsNullOrEmpty(name) ? "" : $"{name}_{index}";
          converted[index] = internal_convert_to_tensor_or_indexed_slices(item, dtype: dtype, name: elementName, as_ref: as_ref);
      }
      return converted;
  }
  /// <summary>
  /// Converts every element of <paramref name="values"/> with <c>internal_convert_to_tensor</c>.
  /// </summary>
  /// <param name="values">
  /// A sequence of objects: anything implementing <c>IEnumerable&lt;object&gt;</c>, such as
  /// <c>object[]</c>, <c>Tensor[]</c> or <c>List&lt;object&gt;</c>.
  /// </param>
  /// <param name="dtype">Requested element type, or <c>DtInvalid</c>.</param>
  /// <param name="name">Base name; element i gets "{name}_{i}", or "" when <paramref name="name"/> is null/empty.</param>
  /// <param name="preferred_dtype">Forwarded unchanged.</param>
  /// <param name="as_ref">Forwarded unchanged.</param>
  /// <returns>One tensor per element, in order.</returns>
  /// <exception cref="ArgumentException"><paramref name="values"/> is null or not a sequence of objects.</exception>
  public static Tensor[] internal_convert_n_to_tensor(object values, TF_DataType dtype = TF_DataType.DtInvalid,
      string name = null, TF_DataType preferred_dtype = TF_DataType.DtInvalid,
      bool as_ref = false)
  {
      // Previously `values as object[]` turned every other sequence (e.g. List<object>)
      // into null, which then crashed with a NullReferenceException. Accept any
      // reference-typed sequence and fail with a clear message otherwise.
      if (!(values is IEnumerable<object> items))
          throw new ArgumentException("values must be a sequence.", nameof(values));

      var ret = new List<Tensor>();
      int i = 0;
      foreach (var value in items)
      {
          string n = string.IsNullOrEmpty(name) ? "" : $"{name}_{i}";
          ret.Add(internal_convert_to_tensor(value, dtype: dtype, name: n, as_ref: as_ref, preferred_dtype: preferred_dtype));
          i++;
      }
      return ret.ToArray();
  }
  /// <summary>
  /// Converts <paramref name="value"/> to a <see cref="Tensor"/>.
  /// </summary>
  /// <param name="value">
  /// Handled by type:
  /// <list type="bullet">
  /// <item><c>NDArray</c> becomes a constant.</item>
  /// <item><c>Tensor</c> is returned as-is.</item>
  /// <item><c>Tensor[]</c> is packed via <c>array_ops._autopacking_helper</c>.</item>
  /// <item><c>RefVariable</c> is converted via <c>_TensorConversionFunction</c>.</item>
  /// <item><c>object[]</c> goes to <c>array_ops._autopacking_conversion_function</c>.</item>
  /// <item>Anything else goes to <c>constant_op.constant</c>.</item>
  /// </list>
  /// </param>
  /// <param name="dtype">Requested type; when <c>DtInvalid</c>, <paramref name="preferred_dtype"/> is used instead.</param>
  /// <param name="name">Name for newly created ops; <c>Tensor[]</c> packing defaults it to "packed".</param>
  /// <param name="preferred_dtype">Substituted for <paramref name="dtype"/> when that is <c>DtInvalid</c>.</param>
  /// <param name="as_ref">Forwarded to <c>RefVariable._TensorConversionFunction</c>.</param>
  /// <param name="scope">Currently unused.</param>
  /// <returns>The converted tensor.</returns>
  /// <exception cref="NotImplementedException"><paramref name="value"/> is a <c>ResourceVariable</c>.</exception>
  public static Tensor internal_convert_to_tensor(object value, TF_DataType dtype = TF_DataType.DtInvalid,
      string name = null, TF_DataType preferred_dtype = TF_DataType.DtInvalid,
      bool as_ref = false,
      string scope = null)
  {
      // NOTE(review): preferred_dtype is applied exactly like dtype here, with no
      // fallback if conversion to it fails.
      if (dtype == TF_DataType.DtInvalid)
          dtype = preferred_dtype;

      switch (value)
      {
          case NDArray nd:
              return constant_op.constant(nd, dtype: dtype, name: name);
          case Tensor tensor:
              return tensor;
          case Tensor[] tensors:
              return array_ops._autopacking_helper(tensors, dtype, name ?? "packed");
          case RefVariable varVal:
              return varVal._TensorConversionFunction(as_ref: as_ref);
          case ResourceVariable _:
              // Used to return null, which only surfaced later as an unrelated failure
              // far from the call site.
              throw new NotImplementedException("internal_convert_to_tensor: converting a ResourceVariable to a Tensor is not supported yet.");
          case object[] objects:
              return array_ops._autopacking_conversion_function(objects, dtype: dtype, name: name);
          default:
              return constant_op.constant(value, dtype: dtype, name: name);
      }
  }
  /// <summary>
  /// Removes an export scope from a name.
  /// </summary>
  /// <param name="name">The name to strip; null is returned unchanged.</param>
  /// <param name="export_scope">
  /// Scope to remove; one trailing "/" is ignored. Null or empty returns
  /// <paramref name="name"/> unchanged.
  /// </param>
  /// <returns>
  /// <paramref name="name"/> with the first "export_scope/" (followed by any number of slashes)
  /// removed when it appears at the start of the name or right after a "^" or "loc:@" marker.
  /// For example ("s/a", "s") gives "a" and ("^s/a", "s") gives "^a".
  /// </returns>
  public static string strip_name_scope(string name, string export_scope = "")
  {
      if (string.IsNullOrEmpty(export_scope) || name == null)
          return name;

      if (export_scope.EndsWith("/", StringComparison.Ordinal))
          export_scope = export_scope.Substring(0, export_scope.Length - 1);

      // Same pattern as the Python implementation (`([\^]|loc:@|^)<scope>[\/]+(.*)`, count=1).
      // The scope is escaped so it matches literally instead of being read as a regex.
      var pattern = @"([\^]|loc:@|^)" + System.Text.RegularExpressions.Regex.Escape(export_scope) + @"/+(.*)";
      return new System.Text.RegularExpressions.Regex(pattern).Replace(name, "$1$2", 1);
  }
  /// <summary>
  /// Returns the current name scope of the default graph.
  /// </summary>
  /// <returns>The result of <c>get_name_scope()</c> on the graph from <c>get_default_graph()</c>.</returns>
  public static string get_name_scope() => get_default_graph().get_name_scope();
  436. }
  437. }