You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

Graph.cs 23 kB

6 years ago
5 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612
  1. /*****************************************************************************
  2. Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved.
  3. Licensed under the Apache License, Version 2.0 (the "License");
  4. you may not use this file except in compliance with the License.
  5. You may obtain a copy of the License at
  6. http://www.apache.org/licenses/LICENSE-2.0
  7. Unless required by applicable law or agreed to in writing, software
  8. distributed under the License is distributed on an "AS IS" BASIS,
  9. WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  10. See the License for the specific language governing permissions and
  11. limitations under the License.
  12. ******************************************************************************/
  13. using System;
  14. using System.Collections;
  15. using System.Collections.Generic;
  16. using System.Collections.Specialized;
  17. using System.Linq;
  18. using Tensorflow.Framework;
  19. using Tensorflow.Functions;
  20. using Tensorflow.Common.Extensions;
  21. using Tensorflow.Graphs;
  22. using static Tensorflow.Binding;
  23. namespace Tensorflow
  24. {
  25. /*
  26. A TensorFlow computation, represented as a dataflow graph.
  27. A `Graph` contains a set of
  28. `tf.Operation` objects,
  29. which represent units of computation; and
  30. `tf.Tensor` objects, which represent
  31. the units of data that flow between operations.
  32. A default `Graph` is always registered, and accessible by calling
  33. `tf.get_default_graph`.
  34. To add an operation to the default graph, simply call one of the functions
  35. that defines a new `Operation`:
  36. ```python
  37. c = tf.constant(4.0)
  38. assert c.graph is tf.get_default_graph()
  39. ```
  40. Another typical usage involves the
  41. `tf.Graph.as_default`
  42. context manager, which overrides the current default graph for the
  43. lifetime of the context:
  44. ```python
  45. g = tf.Graph()
  46. with g.as_default():
  47. # Define operations and tensors in `g`.
  48. c = tf.constant(30.0)
  49. assert c.graph is g
  50. ```
  51. Important note: This class *is not* thread-safe for graph construction. All
  52. operations should be created from a single thread, or external
  53. synchronization must be provided. Unless otherwise specified, all methods
  54. are not thread-safe.
  55. A `Graph` instance supports an arbitrary number of "collections"
  56. that are identified by name. For convenience when building a large
  57. graph, collections can store groups of related objects: for
  58. example, the `tf.Variable` uses a collection (named
  59. `tf.GraphKeys.GLOBAL_VARIABLES`) for
  60. all variables that are created during the construction of a graph. The caller
  61. may define additional collections by specifying a new name.
  62. */
  63. /// <summary>
  64. /// TensorFlow uses a dataflow graph to represent your computation in terms of the dependencies between individual operations.
  65. /// This leads to a low-level programming model in which you first define the dataflow graph,
  66. /// then create a TensorFlow session to run parts of the graph across a set of local and remote devices.
  67. /// </summary>
  68. /// <remarks>https://www.tensorflow.org/guide/graphs <br></br>https://www.tensorflow.org/api_docs/python/tf/Graph</remarks>
  69. public partial class Graph : IEnumerable<Operation>
  70. {
  71. protected new SafeGraphHandle _handle;
  72. private Dictionary<int, ITensorOrOperation> _nodes_by_id;
  73. public Dictionary<string, ITensorOrOperation> _nodes_by_name;
  74. private Dictionary<string, int> _names_in_use;
  75. public int _version;
  76. private int _next_id_counter;
  77. private List<Operation> _unfetchable_ops = new List<Operation>();
  78. private List<Tensor> _unfeedable_tensors = new List<Tensor>();
  79. private Dictionary<string, EagerDefinedFunction> _functions = new();
  80. internal Dictionary<string, Func<Operation, object[], Tensor[]>> _gradient_function_map = new();
  81. private VersionDef _graph_def_versions = new VersionDef()
  82. {
  83. Producer = versions.GRAPH_DEF_VERSION,
  84. MinConsumer = versions.GRAPH_DEF_VERSION_MIN_CONSUMER
  85. };
  86. public string _name_stack = "";
  87. protected string _graph_key;
  88. public string graph_key => _graph_key;
  89. public string _last_loss_reduction;
  90. public bool _is_loss_scaled_by_optimizer { get; set; }
  91. /// <summary>
  92. /// True if the graph is considered "finalized". In that case no
  93. /// new operations can be added.
  94. /// </summary>
  95. private bool _finalized = false;
  96. /// <summary>
  97. /// Arbitrary collections of objects.
  98. /// </summary>
  99. private Dictionary<string, object> _collections = new Dictionary<string, object>();
  100. public bool building_function;
  101. string _container = "";
  102. public string Container => _container;
  103. int _seed;
  104. public int seed
  105. {
  106. get => _seed;
  107. set
  108. {
  109. _seed = value;
  110. }
  111. }
  112. internal Graph outer_graph;
  113. public Graph OuterGraph => outer_graph;
  114. public Dictionary<string, EagerDefinedFunction> Functions => _functions;
  115. public SafeGraphHandle c_graph => _handle;
  116. public Graph()
  117. {
  118. _handle = c_api.TF_NewGraph();
  119. _nodes_by_id = new Dictionary<int, ITensorOrOperation>();
  120. _nodes_by_name = new Dictionary<string, ITensorOrOperation>();
  121. _names_in_use = new Dictionary<string, int>();
  122. _graph_key = $"graph-{ops.GraphUniqueId()}/";
  123. }
  124. public ITensorOrOperation as_graph_element(object obj, bool allow_tensor = true, bool allow_operation = true)
  125. {
  126. return _as_graph_element_locked(obj, allow_tensor, allow_operation);
  127. }
  128. /// <summary>
  129. /// Returns a context manager that makes this `Graph` the default graph.
  130. /// Must call Exit() to pop graph
  131. /// </summary>
  132. /// <returns></returns>
  133. public virtual Graph as_default()
  134. {
  135. tf.Context.graph_mode(isFunc: false);
  136. return ops.set_default_graph(this);
  137. }
  138. public bool IsFunction(string name)
  139. {
  140. return _functions.ContainsKey(tf.compat.as_str(name));
  141. }
  142. internal void AddFunction(EagerDefinedFunction function)
  143. {
  144. _check_not_finalized();
  145. var name = function.Name;
  146. if(function._grad_func_name is not null && function.csharp_grad_func is not null)
  147. {
  148. throw new ValueError($"Gradient defined twice for function {name}");
  149. }
  150. var c_graph = this.c_graph;
  151. var func = function._c_func.Get();
  152. Status status = new();
  153. if (function._grad_func is not null)
  154. {
  155. var gradient = function._grad_func._c_func.Get();
  156. c_api.TF_GraphCopyFunction(c_graph, func, gradient, status);
  157. status.Check(true);
  158. }
  159. else
  160. {
  161. c_api.TF_GraphCopyFunction(c_graph, func, new SafeFuncGraphHandle(IntPtr.Zero), status);
  162. status.Check(true);
  163. }
  164. _functions[tf.compat.as_str(name)] = function;
  165. if(_graph_def_versions.MinConsumer < 12)
  166. {
  167. _graph_def_versions.MinConsumer = 12;
  168. }
  169. }
  170. private Tensor _as_graph_element(object obj)
  171. {
  172. if (obj is RefVariable var)
  173. return var._as_graph_element();
  174. else if (obj is ResourceVariable resVar)
  175. return resVar.GraphElement;
  176. return null;
  177. }
  178. private ITensorOrOperation _as_graph_element_locked(object obj, bool allow_tensor = true, bool allow_operation = true)
  179. {
  180. string types_str = "";
  181. if (allow_tensor && allow_operation)
  182. {
  183. types_str = "Tensor or Operation";
  184. }
  185. else if (allow_tensor)
  186. {
  187. types_str = "Tensor";
  188. }
  189. else if (allow_operation)
  190. {
  191. types_str = "Operation";
  192. }
  193. var temp_obj = _as_graph_element(obj);
  194. if (temp_obj != null)
  195. obj = temp_obj;
  196. // If obj appears to be a name...
  197. if (obj is string name)
  198. {
  199. if (name.Contains(":") && allow_tensor)
  200. {
  201. string op_name = name.Split(':')[0];
  202. int out_n = int.Parse(name.Split(':')[1]);
  203. if (_nodes_by_name.ContainsKey(op_name))
  204. return _nodes_by_name[op_name].outputs[out_n];
  205. else
  206. throw new KeyError($"The name {name} refers to a Tensor which does not " +
  207. $"exist. The operation, {op_name}, does not exist in the " +
  208. "graph.");
  209. }
  210. else if (!name.Contains(":") & allow_operation)
  211. {
  212. if (!_nodes_by_name.ContainsKey(name))
  213. throw new KeyError($"The name {name} refers to an Operation not in the graph.");
  214. return _nodes_by_name[name];
  215. }
  216. else if (!name.Contains(":") & !allow_operation)
  217. {
  218. // Looks like an Operation name but can't be an Operation.
  219. if (_nodes_by_name.ContainsKey(name))
  220. // Yep, it's an Operation name
  221. throw new ValueError($"The name {name} refers to an Operation, not a {types_str}.");
  222. else
  223. throw new ValueError(
  224. $"The name {name} looks like an (invalid) Operation name, not a {types_str}" +
  225. " Tensor names must be of the form \"<op_name>:<output_index>\".");
  226. }
  227. }
  228. if (obj is Tensor tensor && allow_tensor)
  229. {
  230. if (tensor.graph.Equals(this))
  231. {
  232. return tensor;
  233. }
  234. else
  235. {
  236. throw new Exception($"Tensor {obj} is not an element of this graph.");
  237. }
  238. }
  239. else if (obj is Operation op && allow_operation)
  240. {
  241. if (op.graph.Equals(this))
  242. {
  243. return op;
  244. }
  245. else
  246. {
  247. throw new Exception($"Operation {obj} is not an element of this graph.");
  248. }
  249. }
  250. throw new Exception($"Can not convert a {obj.GetType().Name} into a {types_str}.");
  251. }
  252. public void add_to_collection<T>(string name, T value)
  253. {
  254. _check_not_finalized();
  255. if (_collections.ContainsKey(name))
  256. (_collections[name] as List<T>).Add(value);
  257. else
  258. _collections[name] = new List<T> { value };
  259. }
  260. public void add_to_collections<T>(List<string> names, T value)
  261. {
  262. foreach (string name in names)
  263. add_to_collection(name, value);
  264. }
  265. private void _check_not_finalized()
  266. {
  267. if (_finalized)
  268. throw new RuntimeError("Graph is finalized and cannot be modified.");
  269. }
  270. public virtual Operation create_op(string op_type, Tensor[] inputs, TF_DataType[] dtypes,
  271. TF_DataType[] input_types = null, string name = null,
  272. Dictionary<string, AttrValue> attrs = null, OpDef op_def = null,
  273. bool compute_device = true)
  274. {
  275. if (inputs == null)
  276. inputs = new Tensor[0];
  277. if (string.IsNullOrEmpty(name))
  278. name = op_type;
  279. // If a names ends with a '/' it is a "name scope" and we use it as-is,
  280. // after removing the trailing '/'.
  281. // This was causing duplicate graph node name errors, when testing a conv2d autoencoder
  282. // https://keras.io/guides/functional_api/#:~:text=keras.,graph%20(DAG)%20of%20layers.
  283. // name = name.EndsWith("/") ? ops.name_from_scope_name(name) : unique_name(name);
  284. name = name.EndsWith("/") ? ops.name_from_scope_name(name) : unique_name(name);
  285. var node_def = ops._NodeDef(op_type, name, attrs: attrs);
  286. var input_ops = inputs.Select(x => x.op).ToArray();
  287. var control_inputs = _control_dependencies_for_inputs(input_ops);
  288. var op = new Operation(node_def,
  289. this,
  290. inputs: inputs,
  291. output_types: dtypes,
  292. control_inputs: control_inputs,
  293. input_types: input_types,
  294. original_op: null,
  295. op_def: op_def);
  296. _create_op_helper(op, compute_device);
  297. return op;
  298. }
  299. public ITensorFlowObject device(string device_name)
  300. {
  301. return new GraphDeviceContext(this, device_name);
  302. }
  303. private void add_device_to_stack(string device_name, int offset = 0)
  304. {
  305. // TODO(Rinne): deal with device spec.
  306. int total_offset = offset + 1;
  307. }
  308. private void _create_op_helper(Operation op, bool compute_device = true)
  309. {
  310. // high priority
  311. // TODO(Rinne): complete the implementation.
  312. op._gradient_function = _gradient_function_map.GetOrDefault(op.type, null);
  313. _record_op_seen_by_control_dependencies(op);
  314. }
  315. public void _add_op(Operation op)
  316. {
  317. op._id_value = _next_id();
  318. _nodes_by_id[op._id] = op;
  319. _nodes_by_name[op.name] = op;
  320. _version = Math.Max(_version, op._id);
  321. }
  322. public int _next_id()
  323. {
  324. return ++_next_id_counter;
  325. }
  326. public bool is_fetchable<T>(T tensor_or_op)
  327. {
  328. if (tensor_or_op is Tensor tensor)
  329. {
  330. return !_unfetchable_ops.Contains(tensor); ;
  331. }
  332. else if (tensor_or_op is Operation op)
  333. {
  334. return !_unfetchable_ops.Contains(op);
  335. }
  336. return false;
  337. }
  338. public string get_name_scope()
  339. {
  340. return _name_stack;
  341. }
  342. public string name_scope(string name)
  343. {
  344. string new_stack = "";
  345. if (string.IsNullOrEmpty(name))
  346. new_stack = "";
  347. else if (name.EndsWith("/"))
  348. new_stack = ops.name_from_scope_name(name);
  349. else
  350. new_stack = unique_name(name);
  351. _name_stack = new_stack;
  352. return String.IsNullOrEmpty(new_stack) ? "" : new_stack + "/";
  353. }
  354. /// <summary>
  355. /// Return a unique operation name for `name`.
  356. ///
  357. /// Note: You rarely need to call `unique_name()` directly.Most of
  358. /// the time you just need to create `with g.name_scope()` blocks to
  359. /// generate structured names.
  360. ///
  361. /// `unique_name` is used to generate structured names, separated by
  362. /// `"/"`, to help identify operations when debugging a graph.
  363. /// Operation names are displayed in error messages reported by the
  364. /// TensorFlow runtime, and in various visualization tools such as
  365. /// TensorBoard.
  366. ///
  367. /// If `mark_as_used` is set to `True`, which is the default, a new
  368. /// unique name is created and marked as in use.If it's set to `False`,
  369. /// the unique name is returned without actually being marked as used.
  370. /// This is useful when the caller simply wants to know what the name
  371. /// to be created will be.
  372. /// </summary>
  373. /// <param name="name">The name for an operation.</param>
  374. /// <param name="mark_as_used"> Whether to mark this name as being used.</param>
  375. /// <returns>A string to be passed to `create_op()` that will be used
  376. /// to name the operation being created.</returns>
  377. public string unique_name(string name, bool mark_as_used = true)
  378. {
  379. if (!String.IsNullOrEmpty(_name_stack))
  380. name = _name_stack + "/" + name;
  381. // For the sake of checking for names in use, we treat names as case
  382. // insensitive (e.g. foo = Foo).
  383. var name_key = name.ToLower();
  384. int i = 0;
  385. if (_names_in_use.ContainsKey(name_key))
  386. i = _names_in_use[name_key];
  387. // Increment the number for "name_key".
  388. if (mark_as_used)
  389. _names_in_use[name_key] = i + 1;
  390. if (i > 0)
  391. {
  392. // Make sure the composed name key is not already used.
  393. var base_name_key = name_key;
  394. while (_names_in_use.ContainsKey(name_key))
  395. {
  396. name_key = $"{base_name_key}_{i}";
  397. i += 1;
  398. }
  399. // Mark the composed name_key as used in case someone wants
  400. // to call unique_name("name_1").
  401. if (mark_as_used)
  402. _names_in_use[name_key] = 1;
  403. // Return the new name with the original capitalization of the given name.
  404. name = $"{name}_{i - 1}";
  405. }
  406. return name;
  407. }
  408. public TF_Output[] ReturnOutputs(SafeImportGraphDefResultsHandle results)
  409. {
  410. IntPtr return_output_handle = IntPtr.Zero;
  411. int num_return_outputs = 0;
  412. c_api.TF_ImportGraphDefResultsReturnOutputs(results, ref num_return_outputs, ref return_output_handle);
  413. TF_Output[] return_outputs = new TF_Output[num_return_outputs];
  414. unsafe
  415. {
  416. var tf_output_ptr = (TF_Output*)return_output_handle;
  417. for (int i = 0; i < num_return_outputs; i++)
  418. return_outputs[i] = *(tf_output_ptr + i);
  419. return return_outputs;
  420. }
  421. }
  422. public string[] get_all_collection_keys()
  423. {
  424. return _collections.Keys.Where(x => !x.StartsWith("__")).ToArray();
  425. }
  426. public object get_collection(string name, string scope = null)
  427. {
  428. return _collections.ContainsKey(name) ? _collections[name] : null;
  429. }
  430. public List<T> get_collection<T>(string name, string scope = null)
  431. {
  432. List<T> t = default;
  433. var collection = _collections.ContainsKey(name) ? _collections[name] : new List<T>();
  434. switch (collection)
  435. {
  436. case List<IVariableV1> list:
  437. t = list.Select(x => (T)(object)x).ToList();
  438. break;
  439. case List<ResourceVariable> list:
  440. t = list.Select(x => (T)(object)x).ToList();
  441. break;
  442. case List<RefVariable> list:
  443. t = list.Select(x => (T)(object)x).ToList();
  444. break;
  445. case List<Tensor> list:
  446. t = list.Select(x => (T)(object)x).ToList();
  447. break;
  448. case List<Operation> list:
  449. t = list.Select(x => (T)(object)x).ToList();
  450. break;
  451. default:
  452. throw new NotImplementedException($"get_collection<{typeof(T).FullName}>");
  453. }
  454. return t;
  455. }
  456. public List<T> get_collection_ref<T>(string name)
  457. {
  458. if (!_collections.ContainsKey(name))
  459. _collections[name] = new List<T>();
  460. return _collections[name] as List<T>;
  461. }
  462. public void prevent_feeding(Tensor tensor)
  463. {
  464. _unfeedable_tensors.Add(tensor);
  465. }
  466. public void prevent_fetching(Operation op)
  467. {
  468. _unfetchable_ops.Add(op);
  469. }
  470. public Tensor get_tensor_by_tf_output(TF_Output tf_output)
  471. {
  472. var op = _get_operation_by_tf_operation(tf_output.oper);
  473. return op.outputs[tf_output.index];
  474. }
  475. /// <summary>
  476. /// Returns the <see cref="Tensor"/> with the given <paramref name="name"/>.
  477. /// This method may be called concurrently from multiple threads.
  478. /// </summary>
  479. /// <param name="name">The name of the `Tensor` to return.</param>
  480. /// <exception cref="KeyError">If <paramref name="name"/> does not correspond to a tensor in this graph.</exception>
  481. /// <returns>The `Tensor` with the given <paramref name="name"/>.</returns>
  482. public Tensor get_tensor_by_name(string name)
  483. {
  484. return (Tensor)this.as_graph_element(name, allow_tensor: true, allow_operation: false);
  485. }
  486. public Shape GetTensorShape(TF_Output output)
  487. {
  488. var status = tf.Status;
  489. var ndim = c_api.TF_GraphGetTensorNumDims(_handle, output, status);
  490. status.Check();
  491. if (ndim == -1)
  492. return Shape.Null;
  493. var dims = new long[ndim];
  494. c_api.TF_GraphGetTensorShape(_handle, output, dims, dims.Length, status);
  495. status.Check();
  496. return new Shape(dims.Select(x => (int)x).ToArray());
  497. }
  498. public virtual void Exit()
  499. {
  500. tf.Context.restore_mode();
  501. ops.pop_graph();
  502. }
  503. internal EagerDefinedFunction _get_function(string name)
  504. {
  505. return _functions.GetOrDefault(name, null);
  506. }
  507. string debugString = string.Empty;
  508. public override string ToString()
  509. {
  510. return $"{graph_key}, 0x{_handle.DangerousGetHandle().ToString("x16")}";
  511. /*if (string.IsNullOrEmpty(debugString))
  512. {
  513. int len = 0;
  514. debugString = c_api.TF_GraphDebugString(_handle, out len);
  515. }
  516. return debugString;*/
  517. }
  518. private IEnumerable<Operation> GetEnumerable()
  519. => c_api_util.tf_operations(this);
  520. IEnumerator<Operation> IEnumerable<Operation>.GetEnumerator()
  521. => GetEnumerable().GetEnumerator();
  522. IEnumerator IEnumerable.GetEnumerator()
  523. => throw new NotImplementedException();
  524. public static implicit operator SafeGraphHandle(Graph graph)
  525. {
  526. return graph._handle;
  527. }
  528. }
  529. }