You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-') and can be up to 35 characters long.

Graph.cs 19 kB

6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508
  1. /*****************************************************************************
  2. Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved.
  3. Licensed under the Apache License, Version 2.0 (the "License");
  4. you may not use this file except in compliance with the License.
  5. You may obtain a copy of the License at
  6. http://www.apache.org/licenses/LICENSE-2.0
  7. Unless required by applicable law or agreed to in writing, software
  8. distributed under the License is distributed on an "AS IS" BASIS,
  9. WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  10. See the License for the specific language governing permissions and
  11. limitations under the License.
  12. ******************************************************************************/
  13. using System;
  14. using System.Collections;
  15. using System.Collections.Generic;
  16. using System.Linq;
  17. using System.Runtime.InteropServices;
  18. namespace Tensorflow
  19. {
  20. /// <summary>
  21. /// TensorFlow uses a dataflow graph to represent your computation in terms of the dependencies between individual operations.
  22. /// This leads to a low-level programming model in which you first define the dataflow graph,
  23. /// then create a TensorFlow session to run parts of the graph across a set of local and remote devices.
  24. /// https://www.tensorflow.org/guide/graphs
  25. /// </summary>
  26. /*
  27. A TensorFlow computation, represented as a dataflow graph.
  28. A `Graph` contains a set of
  29. `tf.Operation` objects,
  30. which represent units of computation; and
  31. `tf.Tensor` objects, which represent
  32. the units of data that flow between operations.
  33. A default `Graph` is always registered, and accessible by calling
  34. `tf.get_default_graph`.
  35. To add an operation to the default graph, simply call one of the functions
  36. that defines a new `Operation`:
  37. ```python
  38. c = tf.constant(4.0)
  39. assert c.graph is tf.get_default_graph()
  40. ```
  41. Another typical usage involves the
  42. `tf.Graph.as_default`
  43. context manager, which overrides the current default graph for the
  44. lifetime of the context:
  45. ```python
  46. g = tf.Graph()
  47. with g.as_default():
  48. # Define operations and tensors in `g`.
  49. c = tf.constant(30.0)
  50. assert c.graph is g
  51. ```
  52. Important note: This class *is not* thread-safe for graph construction. All
  53. operations should be created from a single thread, or external
  54. synchronization must be provided. Unless otherwise specified, all methods
  55. are not thread-safe.
  56. A `Graph` instance supports an arbitrary number of "collections"
  57. that are identified by name. For convenience when building a large
  58. graph, collections can store groups of related objects: for
  59. example, the `tf.Variable` uses a collection (named
  60. `tf.GraphKeys.GLOBAL_VARIABLES`) for
  61. all variables that are created during the construction of a graph. The caller
  62. may define additional collections by specifying a new name.
  63. */
  64. public partial class Graph : DisposableObject, IEnumerable<Operation>
  65. {
  66. private Dictionary<int, ITensorOrOperation> _nodes_by_id;
  67. public Dictionary<string, ITensorOrOperation> _nodes_by_name;
  68. private Dictionary<string, int> _names_in_use;
  69. public int _version;
  70. private int _next_id_counter;
  71. private List<Operation> _unfetchable_ops = new List<Operation>();
  72. private List<Tensor> _unfeedable_tensors = new List<Tensor>();
  73. public string _name_stack = "";
  74. private string _graph_key;
  75. public string graph_key => _graph_key;
  76. public string _last_loss_reduction;
  77. public bool _is_loss_scaled_by_optimizer { get; set; }
  78. /// <summary>
  79. /// True if the graph is considered "finalized". In that case no
  80. /// new operations can be added.
  81. /// </summary>
  82. private bool _finalized = false;
  83. /// <summary>
  84. /// Arbitrary collections of objects.
  85. /// </summary>
  86. private Dictionary<string, object> _collections = new Dictionary<string, object>();
  87. public bool building_function;
  88. public Graph()
  89. {
  90. _handle = c_api.TF_NewGraph();
  91. _nodes_by_id = new Dictionary<int, ITensorOrOperation>();
  92. _nodes_by_name = new Dictionary<string, ITensorOrOperation>();
  93. _names_in_use = new Dictionary<string, int>();
  94. _graph_key = $"grap-key-{ops.uid()}/";
  95. }
  96. public Graph(IntPtr handle)
  97. {
  98. _handle = handle;
  99. _nodes_by_id = new Dictionary<int, ITensorOrOperation>();
  100. _nodes_by_name = new Dictionary<string, ITensorOrOperation>();
  101. _names_in_use = new Dictionary<string, int>();
  102. _graph_key = $"grap-key-{ops.uid()}/";
  103. }
  104. public ITensorOrOperation as_graph_element(object obj, bool allow_tensor = true, bool allow_operation = true)
  105. {
  106. return _as_graph_element_locked(obj, allow_tensor, allow_operation);
  107. }
  108. /// <summary>
  109. /// Returns a context manager that makes this `Graph` the default graph.
  110. /// </summary>
  111. /// <returns></returns>
  112. public Graph as_default()
  113. {
  114. return ops.set_default_graph(this);
  115. }
  116. private Tensor _as_graph_element(object obj)
  117. {
  118. if (obj is RefVariable var)
  119. return var._as_graph_element();
  120. return null;
  121. }
  122. private ITensorOrOperation _as_graph_element_locked(object obj, bool allow_tensor = true, bool allow_operation = true)
  123. {
  124. string types_str = "";
  125. if (allow_tensor && allow_operation)
  126. {
  127. types_str = "Tensor or Operation";
  128. }
  129. else if (allow_tensor)
  130. {
  131. types_str = "Tensor";
  132. }
  133. else if (allow_operation)
  134. {
  135. types_str = "Operation";
  136. }
  137. var temp_obj = _as_graph_element(obj);
  138. if (temp_obj != null)
  139. obj = temp_obj;
  140. // If obj appears to be a name...
  141. if (obj is string name)
  142. {
  143. if (name.Contains(":") && allow_tensor)
  144. {
  145. string op_name = name.Split(':')[0];
  146. int out_n = int.Parse(name.Split(':')[1]);
  147. if (_nodes_by_name.ContainsKey(op_name))
  148. return _nodes_by_name[op_name].outputs[out_n];
  149. }
  150. else if (!name.Contains(":") & allow_operation)
  151. {
  152. if (!_nodes_by_name.ContainsKey(name))
  153. throw new KeyError($"The name {name} refers to an Operation not in the graph.");
  154. return _nodes_by_name[name];
  155. }
  156. else if (!name.Contains(":") & !allow_operation)
  157. {
  158. // Looks like an Operation name but can't be an Operation.
  159. if (_nodes_by_name.ContainsKey(name))
  160. // Yep, it's an Operation name
  161. throw new ValueError($"The name {name} refers to an Operation, not a {types_str}.");
  162. else
  163. throw new ValueError(
  164. $"The name {name} looks like an (invalid) Operation name, not a {types_str}" +
  165. " Tensor names must be of the form \"<op_name>:<output_index>\".");
  166. }
  167. }
  168. if (obj is Tensor tensor && allow_tensor)
  169. {
  170. if (tensor.graph.Equals(this))
  171. {
  172. return tensor;
  173. }
  174. else
  175. {
  176. throw new Exception($"Tensor {obj} is not an element of this graph.");
  177. }
  178. }
  179. else if (obj is Operation op && allow_operation)
  180. {
  181. if (op.graph.Equals(this))
  182. {
  183. return op;
  184. }
  185. else
  186. {
  187. throw new Exception($"Operation {obj} is not an element of this graph.");
  188. }
  189. }
  190. throw new Exception($"Can not convert a {obj.GetType().Name} into a {types_str}.");
  191. }
  192. public void add_to_collection<T>(string name, T value)
  193. {
  194. _check_not_finalized();
  195. if (_collections.ContainsKey(name))
  196. (_collections[name] as List<T>).Add(value);
  197. else
  198. _collections[name] = new List<T> { value };
  199. }
  200. public void add_to_collections<T>(List<string> names, T value)
  201. {
  202. foreach (string name in names)
  203. add_to_collection(name, value);
  204. }
  205. private void _check_not_finalized()
  206. {
  207. if (_finalized)
  208. throw new RuntimeError("Graph is finalized and cannot be modified.");
  209. }
  210. public unsafe Operation create_op(string op_type, Tensor[] inputs, TF_DataType[] dtypes,
  211. TF_DataType[] input_types = null, string name = null,
  212. Dictionary<string, AttrValue> attrs = null, OpDef op_def = null)
  213. {
  214. if (inputs == null)
  215. inputs = new Tensor[0];
  216. foreach ((int idx, Tensor a) in Python.enumerate(inputs))
  217. {
  218. }
  219. if (String.IsNullOrEmpty(name))
  220. name = op_type;
  221. // If a names ends with a '/' it is a "name scope" and we use it as-is,
  222. // after removing the trailing '/'.
  223. name = name.EndsWith("/") ? ops._name_from_scope_name(name) : unique_name(name);
  224. var node_def = ops._NodeDef(op_type, name, device: "", attrs: attrs);
  225. var input_ops = inputs.Select(x => x.op).ToArray();
  226. var control_inputs = _control_dependencies_for_inputs(input_ops);
  227. var op = new Operation(node_def,
  228. this,
  229. inputs: inputs,
  230. output_types: dtypes,
  231. control_inputs: control_inputs,
  232. input_types: input_types,
  233. original_op: null,
  234. op_def: op_def);
  235. _create_op_helper(op, true);
  236. /*Console.Write($"create_op: {op_type} '{node_def.Name}'");
  237. Console.Write($", inputs: {(inputs.Length == 0 ? "empty" : String.Join(", ", inputs.Select(x => x.name)))}");
  238. Console.Write($", control_inputs: {(control_inputs.Length == 0 ? "empty" : String.Join(", ", control_inputs.Select(x => x.name)))}");
  239. Console.Write($", outputs: {(op.outputs.Length == 0 ? "empty" : String.Join(", ", op.outputs.Select(x => x.name)))}");
  240. Console.WriteLine();*/
  241. return op;
  242. }
  243. private void _create_op_helper(Operation op, bool compute_device = true)
  244. {
  245. _record_op_seen_by_control_dependencies(op);
  246. }
  247. public void _add_op(Operation op)
  248. {
  249. op._id_value = _next_id();
  250. _nodes_by_id[op._id] = op;
  251. _nodes_by_name[op.name] = op;
  252. _version = Math.Max(_version, op._id);
  253. }
  254. public int _next_id()
  255. {
  256. return ++_next_id_counter;
  257. }
  258. public bool is_fetchable<T>(T tensor_or_op)
  259. {
  260. if (tensor_or_op is Tensor tensor)
  261. {
  262. return !_unfetchable_ops.Contains(tensor); ;
  263. }
  264. else if (tensor_or_op is Operation op)
  265. {
  266. return !_unfetchable_ops.Contains(op);
  267. }
  268. return false;
  269. }
  270. public string get_name_scope()
  271. {
  272. return _name_stack;
  273. }
  274. public string name_scope(string name)
  275. {
  276. string new_stack = "";
  277. if (string.IsNullOrEmpty(name))
  278. new_stack = "";
  279. else if (name.EndsWith("/"))
  280. new_stack = ops._name_from_scope_name(name);
  281. else
  282. new_stack = unique_name(name);
  283. _name_stack = new_stack;
  284. return String.IsNullOrEmpty(new_stack) ? "" : new_stack + "/";
  285. }
  286. /// <summary>
  287. /// Return a unique operation name for `name`.
  288. ///
  289. /// Note: You rarely need to call `unique_name()` directly.Most of
  290. /// the time you just need to create `with g.name_scope()` blocks to
  291. /// generate structured names.
  292. ///
  293. /// `unique_name` is used to generate structured names, separated by
  294. /// `"/"`, to help identify operations when debugging a graph.
  295. /// Operation names are displayed in error messages reported by the
  296. /// TensorFlow runtime, and in various visualization tools such as
  297. /// TensorBoard.
  298. ///
  299. /// If `mark_as_used` is set to `True`, which is the default, a new
  300. /// unique name is created and marked as in use.If it's set to `False`,
  301. /// the unique name is returned without actually being marked as used.
  302. /// This is useful when the caller simply wants to know what the name
  303. /// to be created will be.
  304. /// </summary>
  305. /// <param name="name">The name for an operation.</param>
  306. /// <param name="mark_as_used"> Whether to mark this name as being used.</param>
  307. /// <returns>A string to be passed to `create_op()` that will be used
  308. /// to name the operation being created.</returns>
  309. public string unique_name(string name, bool mark_as_used = true)
  310. {
  311. if (!String.IsNullOrEmpty(_name_stack))
  312. name = _name_stack + "/" + name;
  313. // For the sake of checking for names in use, we treat names as case
  314. // insensitive (e.g. foo = Foo).
  315. var name_key = name.ToLower();
  316. int i = 0;
  317. if (_names_in_use.ContainsKey(name_key))
  318. i = _names_in_use[name_key];
  319. // Increment the number for "name_key".
  320. if (mark_as_used)
  321. _names_in_use[name_key] = i + 1;
  322. if (i > 0)
  323. {
  324. // Make sure the composed name key is not already used.
  325. var base_name_key = name_key;
  326. while (_names_in_use.ContainsKey(name_key))
  327. {
  328. name_key = $"{base_name_key}_{i}";
  329. i += 1;
  330. }
  331. // Mark the composed name_key as used in case someone wants
  332. // to call unique_name("name_1").
  333. if (mark_as_used)
  334. _names_in_use[name_key] = 1;
  335. // Return the new name with the original capitalization of the given name.
  336. name = $"{name}_{i-1}";
  337. }
  338. return name;
  339. }
  340. public TF_Output[] ReturnOutputs(IntPtr results)
  341. {
  342. IntPtr return_output_handle = IntPtr.Zero;
  343. int num_return_outputs = 0;
  344. c_api.TF_ImportGraphDefResultsReturnOutputs(results, ref num_return_outputs, ref return_output_handle);
  345. TF_Output[] return_outputs = new TF_Output[num_return_outputs];
  346. for (int i = 0; i < num_return_outputs; i++)
  347. {
  348. var handle = return_output_handle + (Marshal.SizeOf<TF_Output>() * i);
  349. return_outputs[i] = Marshal.PtrToStructure<TF_Output>(handle);
  350. }
  351. return return_outputs;
  352. }
  353. public string[] get_all_collection_keys()
  354. {
  355. return _collections.Keys.Where(x => !x.StartsWith("__")).ToArray();
  356. }
  357. public object get_collection(string name, string scope = null)
  358. {
  359. return _collections.ContainsKey(name) ? _collections[name] : null;
  360. }
  361. public List<T> get_collection<T>(string name, string scope = null)
  362. {
  363. return _collections.ContainsKey(name) ? _collections[name] as List<T> : new List<T>();
  364. }
  365. public object get_collection_ref(string name)
  366. {
  367. if (!_collections.ContainsKey(name))
  368. _collections[name] = new List<object>();
  369. return _collections[name];
  370. }
  371. public void prevent_feeding(Tensor tensor)
  372. {
  373. _unfeedable_tensors.Add(tensor);
  374. }
  375. public void prevent_fetching(Operation op)
  376. {
  377. _unfetchable_ops.Add(op);
  378. }
  379. protected override void DisposeManagedState()
  380. {
  381. ops.default_graph_stack.remove(this);
  382. }
  383. protected override void DisposeUnManagedState(IntPtr handle)
  384. {
  385. c_api.TF_DeleteGraph(handle);
  386. }
  387. /// <summary>
  388. /// Returns the <see cref="Tensor"/> with the given <paramref name="name"/>.
  389. /// This method may be called concurrently from multiple threads.
  390. /// </summary>
  391. /// <param name="name">The name of the `Tensor` to return.</param>
  392. /// <exception cref="KeyError">If <paramref name="name"/> does not correspond to a tensor in this graph.</exception>
  393. /// <returns>The `Tensor` with the given <paramref name="name"/>.</returns>
  394. public Tensor get_tensor_by_name(string name)
  395. {
  396. return (Tensor)this.as_graph_element(name, allow_tensor: true, allow_operation: false);
  397. }
  398. public TensorShape GetTensorShape(TF_Output output)
  399. {
  400. var status = new Status();
  401. var ndim = c_api.TF_GraphGetTensorNumDims(_handle, output, status);
  402. status.Check();
  403. if (ndim == -1)
  404. return new TensorShape();
  405. var dims = new long[ndim];
  406. c_api.TF_GraphGetTensorShape(_handle, output, dims, dims.Length, status);
  407. status.Check();
  408. return new TensorShape(dims.Select(x => (int)x).ToArray());
  409. }
  410. string debugString = string.Empty;
  411. public override string ToString()
  412. {
  413. return $"{graph_key}, ({_handle})";
  414. /*if (string.IsNullOrEmpty(debugString))
  415. {
  416. int len = 0;
  417. debugString = c_api.TF_GraphDebugString(_handle, out len);
  418. }
  419. return debugString;*/
  420. }
  421. private IEnumerable<Operation> GetEnumerable()
  422. => c_api_util.tf_operations(this);
  423. IEnumerator<Operation> IEnumerable<Operation>.GetEnumerator()
  424. => GetEnumerable().GetEnumerator();
  425. IEnumerator IEnumerable.GetEnumerator()
  426. {
  427. throw new NotImplementedException();
  428. }
  429. public static implicit operator IntPtr(Graph graph)
  430. {
  431. return graph._handle;
  432. }
  433. }
  434. }