You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

ops.cs 24 kB

4 years ago
6 years ago
4 years ago
6 years ago
6 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
6 years ago
4 years ago
4 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
5 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
4 years ago
4 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613
  1. /*****************************************************************************
  2. Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved.
  3. Licensed under the Apache License, Version 2.0 (the "License");
  4. you may not use this file except in compliance with the License.
  5. You may obtain a copy of the License at
  6. http://www.apache.org/licenses/LICENSE-2.0
  7. Unless required by applicable law or agreed to in writing, software
  8. distributed under the License is distributed on an "AS IS" BASIS,
  9. WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  10. See the License for the specific language governing permissions and
  11. limitations under the License.
  12. ******************************************************************************/
  13. using Google.Protobuf;
  14. using Google.Protobuf.Collections;
  15. using Tensorflow.NumPy;
  16. using System;
  17. using System.Collections.Generic;
  18. using System.Linq;
  19. using System.Threading;
  20. using Tensorflow.Contexts;
  21. using Tensorflow.Eager;
  22. using Tensorflow.Graphs;
  23. using Tensorflow.Util;
  24. using static Tensorflow.Binding;
  25. using static Tensorflow.CppShapeInferenceResult.Types;
  26. namespace Tensorflow
  27. {
  28. public partial class ops
  29. {
  30. public static long tensor_id(Tensor tensor)
  31. {
  32. return tensor.Id;
  33. }
  34. public static void add_to_collection<T>(string name, T value)
  35. {
  36. var graph = tf.get_default_graph();
  37. graph.add_to_collection(name, value);
  38. }
  39. public static void add_to_collections<T>(List<string> names, T value)
  40. {
  41. var graph = tf.get_default_graph();
  42. graph.add_to_collections(names, value);
  43. }
  44. /// <summary>
  45. /// Wrapper for `Graph.get_collection()` using the default graph.
  46. /// contains many standard names for collections.
  47. /// </summary>
  48. /// <param name="key">
  49. /// The key for the collection. For example, the `GraphKeys` class
  50. /// </param>
  51. /// <param name="scope"></param>
  52. /// <returns>
  53. /// The list of values in the collection with the given `name`, or
  54. /// an empty list if no value has been added to that collection. The
  55. /// list contains the values in the order under which they were
  56. /// collected.
  57. /// </returns>
  58. public static object get_collection(string key, string scope = null)
  59. {
  60. return get_default_graph().get_collection(key, scope);
  61. }
  62. public static List<T> get_collection<T>(string key, string scope = null)
  63. {
  64. return get_default_graph().get_collection<T>(key, scope);
  65. }
  66. public static List<T> get_collection_ref<T>(string key)
  67. {
  68. return get_default_graph().get_collection_ref<T>(key);
  69. }
  70. public static Graph _get_graph_from_inputs(params object[] op_input_list)
  71. {
  72. var current_default_graph = get_default_graph();
  73. if (current_default_graph.building_function)
  74. return current_default_graph;
  75. Graph graph = null;
  76. foreach (var op_input in op_input_list)
  77. {
  78. if (op_input is Tensor op_input_tensor)
  79. graph = graph ?? op_input_tensor.graph;
  80. }
  81. return graph ?? current_default_graph;
  82. }
  83. public static Graph _get_graph_from_inputs(Tensors op_input_list)
  84. => _get_graph_from_inputs(op_input_list: op_input_list, graph: null);
  85. public static Graph _get_graph_from_inputs(Tensors op_input_list, Graph graph = null)
  86. {
  87. foreach (var op_input in op_input_list)
  88. {
  89. // Determine if this is a valid graph_element.
  90. // var graph_element = op_input;
  91. }
  92. return get_default_graph();
  93. }
  94. /// <summary>
  95. /// Converts the given `value` to a `Tensor`.
  96. /// </summary>
  97. /// <param name="value"></param>
  98. /// <param name="dtype"></param>
  99. /// <param name="name"></param>
  100. /// <returns></returns>
  101. public static Tensor convert_to_tensor(object value,
  102. TF_DataType dtype = TF_DataType.DtInvalid,
  103. string name = null,
  104. bool as_ref = false,
  105. TF_DataType preferred_dtype = TF_DataType.DtInvalid,
  106. Context ctx = null)
  107. {
  108. if (dtype == TF_DataType.DtInvalid)
  109. dtype = preferred_dtype;
  110. if (dtype == TF_DataType.DtInvalid)
  111. dtype = value.GetDataType();
  112. if (value is EagerTensor eager_tensor)
  113. {
  114. if (tf.executing_eagerly())
  115. {
  116. if (dtype != TF_DataType.DtInvalid && dtype != eager_tensor.dtype)
  117. return gen_math_ops.cast(eager_tensor, dtype.as_base_dtype(), name: name);
  118. return eager_tensor;
  119. }
  120. else
  121. {
  122. var graph = get_default_graph();
  123. if (!graph.building_function)
  124. throw new RuntimeError("Attempting to capture an EagerTensor without building a function.");
  125. return (graph as FuncGraph).capture(eager_tensor, name: name);
  126. }
  127. }
  128. // graph mode
  129. Tensor ret = value switch
  130. {
  131. NDArray nd => constant_op.constant(nd, dtype: dtype, name: name),
  132. EagerTensor tensor => tensor.dtype == TF_DataType.TF_RESOURCE
  133. ? tensor.AsPlaceholder(name: name)
  134. : tensor.AsConstant(name: name),
  135. Tensor tensor => tensor,
  136. IEnumerable<Tensor> tensors => array_ops._autopacking_helper(tensors, dtype, name == null ? "packed" : name),
  137. RefVariable varVal => varVal._TensorConversionFunction(dtype: dtype, name: name, as_ref: as_ref),
  138. ResourceVariable varVal => varVal._TensorConversionFunction(dtype: dtype, name: name, as_ref: as_ref),
  139. Axis ts => constant_op.constant(ts, dtype: dtype, name: name),
  140. Shape ts => constant_op.constant(ts.dims, dtype: dtype, name: name),
  141. string str => constant_op.constant(str, dtype: tf.@string, name: name),
  142. string[] str => constant_op.constant(str, dtype: tf.@string, name: name),
  143. IEnumerable<object> objects => array_ops._autopacking_conversion_function(objects, dtype: dtype, name: name),
  144. _ => constant_op.constant(value, dtype: dtype, name: name)
  145. };
  146. if (dtype == TF_DataType.TF_STRING)
  147. return ret;
  148. if (dtype != TF_DataType.DtInvalid && dtype.as_base_dtype() != ret.dtype.as_base_dtype())
  149. ret = gen_math_ops.cast(ret, dtype, name: name);
  150. return ret;
  151. }
  152. public static Tensor convert_to_tensor_or_composite(Tensor value, TF_DataType dtype = TF_DataType.DtInvalid, string name = null)
  153. {
  154. return internal_convert_to_tensor_or_composite(value: value, dtype: dtype, name: name, as_ref: false);
  155. }
  156. public static Tensor internal_convert_to_tensor_or_composite(Tensor value, TF_DataType dtype = TF_DataType.DtInvalid, string name = null, bool as_ref = false)
  157. => convert_to_tensor(value, dtype: dtype, name: name, as_ref: as_ref);
  158. /// <summary>
  159. /// Wrapper for `Graph.control_dependencies()` using the default graph.
  160. ///
  161. /// See `tf.Graph.control_dependencies` for more details.
  162. ///
  163. /// When eager execution is enabled, any callable object in the `control_inputs`
  164. /// list will be called.
  165. /// </summary>
  166. /// <param name="control_inputs">
  167. /// A list of `Operation` or `Tensor` objects which
  168. /// must be executed or computed before running the operations
  169. /// defined in the context.Can also be `None` to clear the control
  170. /// dependencies.If eager execution is enabled, any callable object in the
  171. /// `control_inputs` list will be called.
  172. /// </param>
  173. /// <returns>
  174. /// A context manager that specifies control dependencies for all
  175. /// operations constructed within the context.
  176. /// </returns>
  177. public static _ControlDependenciesController control_dependencies(object[] control_inputs)
  178. => get_default_graph().control_dependencies(control_inputs);
  179. /// <summary>
  180. /// Creates a TF_Operation.
  181. /// </summary>
  182. /// <param name="graph">a `Graph`.</param>
  183. /// <param name="node_def">`node_def_pb2.NodeDef` for the operation to create.</param>
  184. /// <param name="inputs">
  185. /// A list of `Tensor`s (corresponding to scalar inputs) and lists of
  186. /// `Tensor`s (corresponding to sequence inputs, e.g. "int64 * N",
  187. /// "list(int64)"). The length of the list should be equal to the number of
  188. /// inputs specified by this operation's op def.
  189. /// </param>
  190. /// <param name="control_inputs">A list of `Operation`s to set as control dependencies.</param>
  191. /// <returns>A wrapped TF_Operation*.</returns>
  192. public static (IntPtr, OperationDescription) _create_c_op(Graph graph, NodeDef node_def, Tensor[] inputs, Operation[] control_inputs,
  193. OpDef op_def = null)
  194. {
  195. if (op_def == null)
  196. op_def = graph.GetOpDef(node_def.Op);
  197. var input_tensors = _reconstruct_sequence_inputs(op_def, inputs, node_def.Attr);
  198. var op_desc = graph.NewOperation(node_def.Op, node_def.Name);
  199. if (!string.IsNullOrEmpty(node_def.Device))
  200. c_api.TF_SetDevice(op_desc, node_def.Device);
  201. // Add inputs
  202. foreach (var op_input in input_tensors)
  203. {
  204. if (op_input.IsList)
  205. c_api.TF_AddInputList(op_desc, op_input.Select(x => x._as_tf_output()).ToArray(), op_input.Count());
  206. else if (op_input.Count() == 1)
  207. c_api.TF_AddInput(op_desc, op_input[0]._as_tf_output());
  208. }
  209. var status = tf.Status;
  210. // Add control inputs
  211. foreach (var control_input in control_inputs)
  212. c_api.TF_AddControlInput(op_desc, control_input);
  213. // Add attrs
  214. foreach (var attr in node_def.Attr)
  215. {
  216. var bytes = attr.Value.ToByteArray();
  217. c_api.TF_SetAttrValueProto(op_desc, attr.Key, bytes, proto_len: (ulong)bytes.Length, status: status);
  218. status.Check(true);
  219. }
  220. var c_op = op_desc.FinishOperation(status);
  221. status.Check(true);
  222. return (c_op, op_desc);
  223. }
  224. public static Tensors[] _reconstruct_sequence_inputs(OpDef op_def, Tensor[] inputs, MapField<string, AttrValue> attrs)
  225. {
  226. var grouped_inputs = new List<Tensors>();
  227. int i = 0;
  228. foreach (var input_arg in op_def.InputArg)
  229. {
  230. int input_len = 1;
  231. bool is_sequence = false;
  232. if (!string.IsNullOrEmpty(input_arg.NumberAttr))
  233. {
  234. input_len = (int)attrs[input_arg.NumberAttr].I;
  235. is_sequence = true;
  236. }
  237. else if (!string.IsNullOrEmpty(input_arg.TypeListAttr))
  238. {
  239. input_len = attrs[input_arg.TypeListAttr].List.Type.Count;
  240. is_sequence = true;
  241. }
  242. if (is_sequence)
  243. {
  244. var input_tensors = new Tensors(inputs.Skip(i).Take(input_len).ToArray());
  245. input_tensors.IsList = true;
  246. grouped_inputs.Add(input_tensors);
  247. }
  248. else
  249. grouped_inputs.Add(inputs[i]);
  250. i += input_len;
  251. }
  252. return grouped_inputs.ToArray();
  253. }
  254. public static OpDef _get_op_def(Graph graph, string type)
  255. {
  256. return graph.GetOpDef(type);
  257. }
  258. public static NodeDef _NodeDef(string op_type, string name, Dictionary<string, AttrValue> attrs = null)
  259. {
  260. var node_def = new NodeDef();
  261. node_def.Op = op_type;
  262. node_def.Name = name;
  263. if (attrs != null)
  264. {
  265. foreach (var attr in attrs)
  266. node_def.Attr.Add(attr.Key, attr.Value);
  267. }
  268. return node_def;
  269. }
  270. public static string name_from_scope_name(string name)
  271. {
  272. if (name == null)
  273. return null;
  274. else if (name.EndsWith("/"))
  275. return name.Substring(0, name.Length - 1);
  276. else
  277. return name;
  278. }
  279. /// <summary>
  280. /// A context manager that lifts ops out of control-flow scopes and function-building graphs.
  281. /// </summary>
  282. /// <returns></returns>
  283. public static NameScope init_scope()
  284. {
  285. // Retrieve the active name scope: entering an `init_scope` preserves
  286. // the name scope of the current context.
  287. var default_graph = get_default_graph();
  288. var scope = default_graph.get_name_scope();
  289. if (!String.IsNullOrEmpty(scope) && !scope.EndsWith("/"))
  290. // Names that end with trailing slashes are treated by `name_scope` as
  291. // absolute.
  292. scope += "/";
  293. // inner_device_stack = default_graph._device_function_stack
  294. // var outer_context = default_graph.as_default;
  295. tf_with(ops.control_dependencies(null), delegate
  296. {
  297. // var outer_graph = get_default_graph();
  298. // outer_device_stack = None
  299. });
  300. tf.Context.ScopeName = scope;
  301. return ops.name_scope(scope);
  302. }
  303. private static int uid_number = -1;
  304. /// <summary>
  305. /// A unique (within this program execution) integer.
  306. /// Not thread safe
  307. /// </summary>
  308. /// <returns></returns>
  309. public static int uid()
  310. {
  311. return Interlocked.Increment(ref uid_number);
  312. }
  313. static int graph_uid_number = -1;
  314. public static int GraphUniqueId()
  315. {
  316. return Interlocked.Increment(ref graph_uid_number);
  317. }
  318. static int uid_number_for_function = 0;
  319. public static int uid_function()
  320. => Interlocked.Increment(ref uid_number_for_function);
  321. static int uid_number_for_layer = 0;
  322. public static int uid_layer()
  323. => Interlocked.Increment(ref uid_number_for_layer);
  324. public static void reset_uid()
  325. {
  326. uid_number = -1;
  327. graph_uid_number = -1;
  328. uid_number_for_function = 0;
  329. uid_number_for_layer = 0;
  330. }
  331. public static void colocate_with(bool ignore_existing = false)
  332. {
  333. _colocate_with_for_gradient(null, null, ignore_existing);
  334. }
  335. public static void colocate_with(Operation op, bool ignore_existing = false)
  336. {
  337. _colocate_with_for_gradient(op, null, ignore_existing);
  338. }
  339. public static void colocate_with(Tensor tensor, bool ignore_existing = false)
  340. {
  341. _colocate_with_for_gradient(tensor.op, null, ignore_existing);
  342. }
  343. public static void colocate_with(IVariableV1 variable, bool ignore_existing = false)
  344. {
  345. _colocate_with_for_gradient(variable.AsTensor(), null, ignore_existing);
  346. }
  347. public static void _colocate_with_for_gradient(Operation op, string gradient_uid, bool ignore_existing = false)
  348. {
  349. var default_graph = get_default_graph();
  350. default_graph._colocate_with_for_gradient(op, gradient_uid, ignore_existing);
  351. }
  352. /// <summary>
  353. /// Uses the default session to evaluate one or more tensors.
  354. /// </summary>
  355. /// <param name="tensor">A single Tensor, or a list of Tensor objects.</param>
  356. /// <param name="feed_dict">
  357. /// A dictionary that maps Tensor objects (or tensor names) to lists,
  358. /// numpy ndarrays, TensorProtos, or strings.
  359. /// </param>
  360. /// <param name="graph">The graph in which the tensors are defined.</param>
  361. /// <param name="session">A different session to use to evaluate "tensors".</param>
  362. /// <returns>
  363. /// Either a single numpy ndarray if "tensors" is a single tensor; or a list
  364. /// of numpy ndarrays that each correspond to the respective element in
  365. /// "tensors".
  366. /// </returns>
  367. public static NDArray _eval_using_default_session(Tensor tensor, FeedItem[] feed_dict, Graph graph, Session session = null)
  368. {
  369. if (session == null)
  370. {
  371. session = get_default_session();
  372. if (session == null)
  373. throw new ValueError("Cannot evaluate tensor using `eval()`: No default " +
  374. "session is registered. Use `with " +
  375. "sess.as_default()` or pass an explicit session to " +
  376. "`eval(session=sess)`");
  377. if (session.graph != graph)
  378. throw new ValueError("Cannot use the default session to evaluate tensor: " +
  379. "the tensor's graph is different from the session's " +
  380. "graph. Pass an explicit session to " +
  381. "`eval(session=sess)`.");
  382. }
  383. else
  384. {
  385. if (session.graph != graph)
  386. throw new ValueError("Cannot use the default session to evaluate tensor: " +
  387. "the tensor's graph is different from the session's " +
  388. "graph. Pass an explicit session to " +
  389. "`eval(session=sess)`.");
  390. }
  391. return session.run(tensor, feed_dict);
  392. }
  393. /// <summary>
  394. /// Prepends name scope to a name.
  395. /// </summary>
  396. /// <param name="name"></param>
  397. /// <param name="import_scope"></param>
  398. /// <returns></returns>
  399. public static string prepend_name_scope(string name, string import_scope)
  400. {
  401. if (!string.IsNullOrEmpty(import_scope))
  402. {
  403. if (import_scope.EndsWith("/"))
  404. import_scope = import_scope.Substring(0, import_scope.Length - 1);
  405. return $"{import_scope}/{name}";
  406. }
  407. else
  408. return name;
  409. }
  410. public static void _run_using_default_session(Operation operation, FeedItem[] feed_dict, Graph graph, Session session)
  411. {
  412. if (session == null)
  413. {
  414. session = get_default_session();
  415. if (session == null)
  416. throw new ValueError("Cannot execute operation using `run()`: No default " +
  417. "session is registered. Use `with " +
  418. "sess.as_default():` or pass an explicit session to " +
  419. "`run(session=sess)`");
  420. }
  421. if (session.graph != graph)
  422. throw new ValueError("Cannot use the default session to execute operation: " +
  423. "the operation's graph is different from the " +
  424. "session's graph. Pass an explicit session to " +
  425. "run(session=sess).");
  426. session.run(operation, feed_dict);
  427. }
  428. public static Tensor[] convert_n_to_tensor(object[] values, TF_DataType dtype = TF_DataType.DtInvalid, string name = null)
  429. => internal_convert_n_to_tensor(values, dtype: dtype, name: name, as_ref: false);
  430. public static Tensor[] convert_n_to_tensor_or_indexed_slices(Tensor[] values, TF_DataType dtype = TF_DataType.DtInvalid, string name = null)
  431. => internal_convert_n_to_tensor_or_indexed_slices(values, dtype: dtype, name: name);
  432. public static Tensor convert_to_tensor_or_indexed_slices(Tensor value, TF_DataType dtype = TF_DataType.DtInvalid, string name = null)
  433. => internal_convert_to_tensor_or_indexed_slices(value: value, dtype: dtype, name: name, as_ref: false);
  434. public static Tensor internal_convert_to_tensor_or_indexed_slices(Tensor value, TF_DataType dtype = TF_DataType.DtInvalid, string name = null, bool as_ref = false)
  435. => value;
  436. public static Tensor[] internal_convert_n_to_tensor_or_indexed_slices(Tensor[] values, TF_DataType dtype = TF_DataType.DtInvalid, string name = null, bool as_ref = false)
  437. {
  438. var ret = new List<Tensor>();
  439. foreach (var (i, value) in enumerate(values))
  440. {
  441. if (value == null)
  442. {
  443. ret.Add(value);
  444. }
  445. else
  446. {
  447. var n = string.IsNullOrEmpty(name) ? "" : $"{name}_{i}";
  448. ret.Add(internal_convert_to_tensor_or_indexed_slices(value, dtype: dtype, name: n, as_ref: as_ref));
  449. }
  450. }
  451. return ret.ToArray();
  452. }
  453. public static Tensor[] internal_convert_n_to_tensor(object[] values, TF_DataType dtype = TF_DataType.DtInvalid,
  454. string name = null, TF_DataType preferred_dtype = TF_DataType.DtInvalid,
  455. bool as_ref = false)
  456. {
  457. var ret = new List<Tensor>();
  458. foreach ((int i, object value) in enumerate(values))
  459. {
  460. string n = string.IsNullOrEmpty(name) ? "" : $"{name}_{i}";
  461. ret.Add(convert_to_tensor(value, dtype: dtype, name: n, as_ref: as_ref, preferred_dtype: preferred_dtype));
  462. }
  463. return ret.ToArray();
  464. }
  465. public static string strip_name_scope(string name, string export_scope = "")
  466. {
  467. if (!string.IsNullOrEmpty(export_scope))
  468. {
  469. throw new NotImplementedException("ops.strip_name_scope");
  470. }
  471. else
  472. {
  473. return name;
  474. }
  475. }
  476. public static string get_name_scope()
  477. {
  478. var g = get_default_graph();
  479. return g.get_name_scope();
  480. }
  481. public static bool executing_eagerly_outside_functions()
  482. {
  483. if (tf.Context.executing_eagerly())
  484. return true;
  485. else
  486. throw new NotImplementedException("");
  487. }
  488. public static bool inside_function()
  489. {
  490. return get_default_graph().building_function;
  491. }
  492. public static HandleData get_resource_handle_data(Tensor graph_op)
  493. {
  494. var handle_data = c_api.TFC_GetHandleShapeAndType(graph_op.graph.c_graph, graph_op._as_tf_output());
  495. var handle_str = c_api.ByteStringPiece(handle_data.DangerousGetHandle() == IntPtr.Zero ? null : new Buffer(handle_data));
  496. return HandleData.Parser.ParseFrom(handle_str);
  497. }
  498. public static void dismantle_graph(Graph graph)
  499. {
  500. }
  501. public static ITensorFlowObject device(string device_name)
  502. {
  503. if (tf.Context.executing_eagerly())
  504. {
  505. return tf.Context.device(device_name);
  506. }
  507. //else if (ops.executing_eagerly_outside_functions())
  508. //{
  509. // throw new NotImplementedException();
  510. //}
  511. else
  512. {
  513. return get_default_graph().device(device_name);
  514. }
  515. // TODO(Rinne): deal with `ops.executing_eagerly_outside_functions()`.
  516. }
  517. public class NullContextManager: IDisposable
  518. {
  519. public void Dispose()
  520. {
  521. }
  522. }
  523. }
  524. }