You can not select more than 25 topics Topics must start with a chinese character,a letter or number, can include dashes ('-') and can be up to 35 characters long.

Graph.cs 20 kB

6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556
  1. /*****************************************************************************
  2. Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved.
  3. Licensed under the Apache License, Version 2.0 (the "License");
  4. you may not use this file except in compliance with the License.
  5. You may obtain a copy of the License at
  6. http://www.apache.org/licenses/LICENSE-2.0
  7. Unless required by applicable law or agreed to in writing, software
  8. distributed under the License is distributed on an "AS IS" BASIS,
  9. WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  10. See the License for the specific language governing permissions and
  11. limitations under the License.
  12. ******************************************************************************/
  13. using System;
  14. using System.Collections;
  15. using System.Collections.Generic;
  16. using System.Linq;
  17. using System.Runtime.InteropServices;
  18. using static Tensorflow.Binding;
  19. namespace Tensorflow
  20. {
  21. /*
  22. A TensorFlow computation, represented as a dataflow graph.
  23. A `Graph` contains a set of
  24. `tf.Operation` objects,
  25. which represent units of computation; and
  26. `tf.Tensor` objects, which represent
  27. the units of data that flow between operations.
  28. A default `Graph` is always registered, and accessible by calling
  29. `tf.get_default_graph`.
  30. To add an operation to the default graph, simply call one of the functions
  31. that defines a new `Operation`:
  32. ```python
  33. c = tf.constant(4.0)
  34. assert c.graph is tf.get_default_graph()
  35. ```
  36. Another typical usage involves the
  37. `tf.Graph.as_default`
  38. context manager, which overrides the current default graph for the
  39. lifetime of the context:
  40. ```python
  41. g = tf.Graph()
  42. with g.as_default():
  43. # Define operations and tensors in `g`.
  44. c = tf.constant(30.0)
  45. assert c.graph is g
  46. ```
  47. Important note: This class *is not* thread-safe for graph construction. All
  48. operations should be created from a single thread, or external
  49. synchronization must be provided. Unless otherwise specified, all methods
  50. are not thread-safe.
  51. A `Graph` instance supports an arbitrary number of "collections"
  52. that are identified by name. For convenience when building a large
  53. graph, collections can store groups of related objects: for
  54. example, the `tf.Variable` uses a collection (named
  55. `tf.GraphKeys.GLOBAL_VARIABLES`) for
  56. all variables that are created during the construction of a graph. The caller
  57. may define additional collections by specifying a new name.
  58. */
  59. /// <summary>
  60. /// TensorFlow uses a dataflow graph to represent your computation in terms of the dependencies between individual operations.
  61. /// This leads to a low-level programming model in which you first define the dataflow graph,
  62. /// then create a TensorFlow session to run parts of the graph across a set of local and remote devices.
  63. /// </summary>
  64. /// <remarks>https://www.tensorflow.org/guide/graphs <br></br>https://www.tensorflow.org/api_docs/python/tf/Graph</remarks>
  65. public partial class Graph : DisposableObject
  66. , IEnumerable<Operation>
  67. {
  // Operations registered through _add_op, keyed by the id read from op._id.
  68. private Dictionary<int, ITensorOrOperation> _nodes_by_id;
  // Operations registered through _add_op, keyed by op.name; used for name lookups in as_graph_element.
  69. public Dictionary<string, ITensorOrOperation> _nodes_by_name;
  // Lower-cased name key -> number of times unique_name has handed that name out.
  70. private Dictionary<string, int> _names_in_use;
  // Highest operation id recorded so far by _add_op.
  71. public int _version;
  // Counter behind _next_id(); pre-incremented, so the first id returned is 1.
  72. private int _next_id_counter;
  // Operations registered through prevent_fetching; consulted by is_fetchable.
  73. private List<Operation> _unfetchable_ops = new List<Operation>();
  // Tensors registered through prevent_feeding.
  74. private List<Tensor> _unfeedable_tensors = new List<Tensor>();
  // Current name-scope prefix (without trailing '/'); set by name_scope and prepended by unique_name.
  75. public string _name_stack = "";
  // Identifier assigned once in the constructors ("grap-key-{uid}/").
  76. private string _graph_key;
  // Read-only view of _graph_key.
  77. public string graph_key => _graph_key;
  // NOTE(review): not read or written anywhere in this part of the class; meaning must be confirmed against callers.
  78. public string _last_loss_reduction;
  // NOTE(review): not read or written anywhere in this part of the class; meaning must be confirmed against callers.
  79. public bool _is_loss_scaled_by_optimizer { get; set; }
  80. /// <summary>
  81. /// True if the graph is considered "finalized". In that case no
  82. /// new operations can be added.
  83. /// </summary>
  84. private bool _finalized = false;
  85. /// <summary>
  86. /// Arbitrary collections of objects.
  87. /// </summary>
  // Values are stored as List<T> instances created by add_to_collection / get_collection_ref.
  88. private Dictionary<string, object> _collections = new Dictionary<string, object>();
  // NOTE(review): not used in this part of the class; presumably set while tracing a function - confirm.
  89. public bool building_function;
  // Backing field of the seed property.
  90. int _seed;
  /// <summary>
  /// Gets or sets the seed value kept in the <c>_seed</c> backing field.
  /// </summary>
  public int seed
  {
      get { return _seed; }
      set { _seed = value; }
  }
  /// <summary>
  /// Creates a graph backed by a newly allocated native graph (<c>c_api.TF_NewGraph</c>).
  /// </summary>
  public Graph() : this(c_api.TF_NewGraph())
  {
  }

  /// <summary>
  /// Wraps an existing native graph handle and sets up the managed lookup tables
  /// and the graph key.
  /// </summary>
  /// <param name="handle">Native graph handle stored in <c>_handle</c>.</param>
  public Graph(IntPtr handle)
  {
      _handle = handle;

      // Bookkeeping tables start empty; they are filled by _add_op and unique_name.
      _nodes_by_id = new Dictionary<int, ITensorOrOperation>();
      _nodes_by_name = new Dictionary<string, ITensorOrOperation>();
      _names_in_use = new Dictionary<string, int>();

      _graph_key = $"grap-key-{ops.uid()}/";
  }
  /// <summary>
  /// Resolves <paramref name="obj"/> to a tensor or operation of this graph by
  /// delegating to <c>_as_graph_element_locked</c>.
  /// </summary>
  /// <param name="obj">A name, tensor, operation, or object convertible to a graph element.</param>
  /// <param name="allow_tensor">Whether a tensor is an acceptable result.</param>
  /// <param name="allow_operation">Whether an operation is an acceptable result.</param>
  /// <returns>The resolved graph element.</returns>
  public ITensorOrOperation as_graph_element(object obj, bool allow_tensor = true, bool allow_operation = true)
      => _as_graph_element_locked(obj, allow_tensor: allow_tensor, allow_operation: allow_operation);
  /// <summary>
  /// Registers this graph as the default graph via <c>ops.set_default_graph</c>.
  /// </summary>
  /// <returns>Whatever <c>ops.set_default_graph</c> returns for this graph.</returns>
  public Graph as_default()
      => ops.set_default_graph(this);
  /// <summary>
  /// Unwraps objects that know how to present themselves as a graph element.
  /// Only <c>RefVariable</c> is handled here.
  /// </summary>
  /// <param name="obj">Candidate object.</param>
  /// <returns>The variable's graph element, or <c>null</c> when <paramref name="obj"/> is not a <c>RefVariable</c>.</returns>
  private Tensor _as_graph_element(object obj)
  {
      return obj is RefVariable variable
          ? variable._as_graph_element()
          : null;
  }
  /// <summary>
  /// Resolves <paramref name="obj"/> to a tensor or operation that belongs to this graph.
  /// Accepts a tensor name (<c>"op_name:output_index"</c>), an operation name, a
  /// <c>Tensor</c>, an <c>Operation</c>, or a <c>RefVariable</c>.
  /// </summary>
  /// <param name="obj">The object or name to resolve.</param>
  /// <param name="allow_tensor">Whether a tensor is an acceptable result.</param>
  /// <param name="allow_operation">Whether an operation is an acceptable result.</param>
  /// <returns>The matching tensor or operation.</returns>
  /// <exception cref="KeyError">The name refers to an operation or output that is not in the graph.</exception>
  /// <exception cref="ValueError">The name is malformed, or names an operation where only a tensor is allowed.</exception>
  /// <exception cref="Exception">The object belongs to another graph or cannot be converted.</exception>
  private ITensorOrOperation _as_graph_element_locked(object obj, bool allow_tensor = true, bool allow_operation = true)
  {
      // Human-readable description of the accepted kinds, used in error messages.
      string types_str = "";
      if (allow_tensor && allow_operation)
          types_str = "Tensor or Operation";
      else if (allow_tensor)
          types_str = "Tensor";
      else if (allow_operation)
          types_str = "Operation";

      // Let convertible objects (variables) substitute their underlying tensor.
      var temp_obj = _as_graph_element(obj);
      if (temp_obj != null)
          obj = temp_obj;

      // If obj appears to be a name...
      if (obj is string name)
      {
          bool has_output_index = name.Contains(":");

          if (has_output_index && allow_tensor)
          {
              var parts = name.Split(':');
              string op_name = parts[0];

              // A non-numeric index used to surface as a raw FormatException.
              if (!int.TryParse(parts[1], out int out_n))
                  throw new ValueError(
                      $"The name {name} looks like a Tensor name, but is not a valid one." +
                      " Tensor names must be of the form \"<op_name>:<output_index>\".");

              if (!_nodes_by_name.TryGetValue(op_name, out var node))
                  throw new KeyError($"The name {name} refers to a Tensor which does not " +
                      $"exist. The operation, {op_name}, does not exist in the " +
                      "graph.");

              // An out-of-range index used to surface as IndexOutOfRangeException;
              // report it as a missing tensor, like the missing-operation case above.
              var outputs = node.outputs;
              if (out_n < 0 || out_n >= outputs.Length)
                  throw new KeyError($"The name {name} refers to a Tensor which does not " +
                      $"exist. The operation, {op_name}, exists but only has " +
                      $"{outputs.Length} outputs.");

              return outputs[out_n];
          }
          else if (!has_output_index && allow_operation)
          {
              if (!_nodes_by_name.TryGetValue(name, out var node))
                  throw new KeyError($"The name {name} refers to an Operation not in the graph.");
              return node;
          }
          else if (!has_output_index && !allow_operation)
          {
              // Looks like an Operation name but can't be an Operation.
              if (_nodes_by_name.ContainsKey(name))
                  // Yep, it's an Operation name
                  throw new ValueError($"The name {name} refers to an Operation, not a {types_str}.");
              else
                  throw new ValueError(
                      $"The name {name} looks like an (invalid) Operation name, not a {types_str}" +
                      " Tensor names must be of the form \"<op_name>:<output_index>\".");
          }
      }

      if (obj is Tensor tensor && allow_tensor)
      {
          if (tensor.graph.Equals(this))
              return tensor;

          throw new Exception($"Tensor {obj} is not an element of this graph.");
      }
      else if (obj is Operation op && allow_operation)
      {
          if (op.graph.Equals(this))
              return op;

          throw new Exception($"Operation {obj} is not an element of this graph.");
      }

      // obj may be null here; avoid a NullReferenceException while building the message.
      throw new Exception($"Can not convert a {obj?.GetType().Name ?? "null"} into a {types_str}.");
  }
  /// <summary>
  /// Appends <paramref name="value"/> to the collection called <paramref name="name"/>,
  /// creating the collection as a <c>List&lt;T&gt;</c> when it does not exist yet.
  /// </summary>
  /// <typeparam name="T">Static type of the value being stored.</typeparam>
  /// <param name="name">Collection key.</param>
  /// <param name="value">Item to append.</param>
  /// <exception cref="RuntimeError">The graph is finalized.</exception>
  /// <exception cref="ArgumentException">The existing collection cannot hold a value of this type.</exception>
  /// <exception cref="InvalidOperationException">The stored object under <paramref name="name"/> is not a list.</exception>
  public void add_to_collection<T>(string name, T value)
  {
      _check_not_finalized();

      if (!_collections.TryGetValue(name, out var existing))
      {
          _collections[name] = new List<T> { value };
          return;
      }

      if (existing is List<T> typed)
      {
          typed.Add(value);
      }
      else if (existing is IList untyped)
      {
          // The collection was created with a different element type (for example a base
          // class of T). Previously this dereferenced a failed `as` cast and threw
          // NullReferenceException; IList.Add accepts compatible values and throws
          // ArgumentException for incompatible ones.
          untyped.Add(value);
      }
      else
      {
          throw new InvalidOperationException(
              $"Collection '{name}' is not a list and cannot accept a {typeof(T).Name}.");
      }
  }
  /// <summary>
  /// Appends <paramref name="value"/> to every collection listed in <paramref name="names"/>,
  /// in list order, via <c>add_to_collection</c>.
  /// </summary>
  /// <typeparam name="T">Static type of the value being stored.</typeparam>
  /// <param name="names">Collection keys.</param>
  /// <param name="value">Item to append to each collection.</param>
  public void add_to_collections<T>(List<string> names, T value)
  {
      foreach (var collection_name in names)
      {
          add_to_collection(collection_name, value);
      }
  }
  /// <summary>
  /// Guard used before mutating the graph: throws when <c>_finalized</c> is set.
  /// </summary>
  /// <exception cref="RuntimeError">The graph is finalized.</exception>
  private void _check_not_finalized()
  {
      if (!_finalized)
      {
          return;
      }

      throw new RuntimeError("Graph is finalized and cannot be modified.");
  }
  /// <summary>
  /// Builds a new <c>Operation</c> of type <paramref name="op_type"/> in this graph.
  /// </summary>
  /// <param name="op_type">Operation type; also used as the name when <paramref name="name"/> is empty.</param>
  /// <param name="inputs">Input tensors; <c>null</c> is treated as no inputs.</param>
  /// <param name="dtypes">Output types passed to the operation.</param>
  /// <param name="input_types">Input types passed to the operation.</param>
  /// <param name="name">Requested name. A trailing '/' marks a name scope that is used as-is (minus the slash).</param>
  /// <param name="attrs">Attributes placed on the node definition.</param>
  /// <param name="op_def">Operation definition passed to the operation.</param>
  /// <returns>The newly created operation.</returns>
  public Operation create_op(string op_type, Tensor[] inputs, TF_DataType[] dtypes,
      TF_DataType[] input_types = null, string name = null,
      Dictionary<string, AttrValue> attrs = null, OpDef op_def = null)
  {
      inputs = inputs ?? new Tensor[0];

      var requested_name = string.IsNullOrEmpty(name) ? op_type : name;

      // A name ending in '/' is a name scope: strip the slash and use it verbatim.
      // Anything else is made unique within the graph.
      string final_name;
      if (requested_name.EndsWith("/"))
          final_name = ops.name_from_scope_name(requested_name);
      else
          final_name = unique_name(requested_name);

      var node_def = ops._NodeDef(op_type, final_name, device: "", attrs: attrs);

      // Control dependencies are derived from the operations producing the inputs.
      var producers = inputs.Select(t => t.op).ToArray();
      var control_inputs = _control_dependencies_for_inputs(producers);

      var created = new Operation(node_def,
          this,
          inputs: inputs,
          output_types: dtypes,
          control_inputs: control_inputs,
          input_types: input_types,
          original_op: null,
          op_def: op_def);

      _create_op_helper(created, true);
      return created;
  }
  /// <summary>
  /// Not implemented; always throws.
  /// </summary>
  /// <param name="device_name">Ignored.</param>
  /// <exception cref="NotImplementedException">Always.</exception>
  public void device(string device_name)
      => throw new NotImplementedException("");
  /// <summary>
  /// Post-creation hook for a new operation: records it with the active
  /// control-dependency tracking.
  /// </summary>
  /// <param name="op">The operation that was just created.</param>
  /// <param name="compute_device">Currently unused.</param>
  private void _create_op_helper(Operation op, bool compute_device = true)
      => _record_op_seen_by_control_dependencies(op);
  /// <summary>
  /// Assigns the next id to <paramref name="op"/>, indexes it by id and by name,
  /// and raises <c>_version</c> to that id when it is larger.
  /// </summary>
  /// <param name="op">Operation to register.</param>
  public void _add_op(Operation op)
  {
      op._id_value = _next_id();

      var assigned_id = op._id;
      _nodes_by_id[assigned_id] = op;
      _nodes_by_name[op.name] = op;

      if (assigned_id > _version)
      {
          _version = assigned_id;
      }
  }
  /// <summary>
  /// Advances the id counter and returns the new value (the first call returns 1).
  /// </summary>
  public int _next_id()
  {
      _next_id_counter += 1;
      return _next_id_counter;
  }
  /// <summary>
  /// Reports whether <paramref name="tensor_or_op"/> may be fetched, i.e. it is a
  /// tensor or operation that is not found in <c>_unfetchable_ops</c>.
  /// </summary>
  /// <typeparam name="T">Static type of the argument.</typeparam>
  /// <param name="tensor_or_op">A <c>Tensor</c> or an <c>Operation</c>.</param>
  /// <returns><c>false</c> for blocked elements and for any other kind of object.</returns>
  public bool is_fetchable<T>(T tensor_or_op)
  {
      switch (tensor_or_op)
      {
          case Tensor tensor:
              return !_unfetchable_ops.Contains(tensor);
          case Operation op:
              return !_unfetchable_ops.Contains(op);
          default:
              return false;
      }
  }
  /// <summary>
  /// Returns the current name-scope prefix held in <c>_name_stack</c>.
  /// </summary>
  public string get_name_scope() => _name_stack;
  /// <summary>
  /// Makes <paramref name="name"/> the current name scope and returns the scope prefix.
  /// </summary>
  /// <param name="name">
  /// Scope name. Empty or <c>null</c> resets the scope; a trailing '/' means the name is
  /// used as-is (minus the slash); anything else is passed through <c>unique_name</c>.
  /// </param>
  /// <returns>The new scope followed by '/', or an empty string when the scope is empty.</returns>
  public string name_scope(string name)
  {
      string scope;
      if (string.IsNullOrEmpty(name))
      {
          scope = "";
      }
      else
      {
          scope = name.EndsWith("/")
              ? ops.name_from_scope_name(name)
              : unique_name(name);
      }

      _name_stack = scope;

      if (String.IsNullOrEmpty(scope))
          return "";
      return scope + "/";
  }
  /// <summary>
  /// Return a unique operation name for `name`.
  ///
  /// Note: You rarely need to call `unique_name()` directly. Most of
  /// the time you just need to create `with g.name_scope()` blocks to
  /// generate structured names.
  ///
  /// `unique_name` is used to generate structured names, separated by
  /// `"/"`, to help identify operations when debugging a graph.
  ///
  /// If `mark_as_used` is set to `true`, which is the default, a new
  /// unique name is created and marked as in use. If it's set to `false`,
  /// the unique name is returned without actually being marked as used.
  /// This is useful when the caller simply wants to know what the name
  /// to be created will be.
  /// </summary>
  /// <param name="name">The name for an operation.</param>
  /// <param name="mark_as_used">Whether to mark this name as being used.</param>
  /// <returns>A string to be passed to `create_op()` that will be used
  /// to name the operation being created.</returns>
  /// <exception cref="ArgumentNullException"><paramref name="name"/> is null.</exception>
  public string unique_name(string name, bool mark_as_used = true)
  {
      if (name == null)
          throw new ArgumentNullException(nameof(name));

      if (!String.IsNullOrEmpty(_name_stack))
          name = _name_stack + "/" + name;

      // For the sake of checking for names in use, we treat names as case
      // insensitive (e.g. foo = Foo). The key must not depend on the current
      // culture (ToLower() maps 'I' differently under e.g. tr-TR), otherwise the
      // same graph could produce different names on different machines.
      var name_key = name.ToLowerInvariant();

      // Leaves i at 0 when the key has never been used.
      _names_in_use.TryGetValue(name_key, out int i);

      // Increment the number for "name_key".
      if (mark_as_used)
          _names_in_use[name_key] = i + 1;

      if (i > 0)
      {
          // Make sure the composed name key is not already used.
          var base_name_key = name_key;
          while (_names_in_use.ContainsKey(name_key))
          {
              name_key = $"{base_name_key}_{i}";
              i += 1;
          }

          // Mark the composed name_key as used in case someone wants
          // to call unique_name("name_1").
          if (mark_as_used)
              _names_in_use[name_key] = 1;

          // Return the new name with the original capitalization of the given name.
          name = $"{name}_{i - 1}";
      }

      return name;
  }
  /// <summary>
  /// Copies the return outputs reported by
  /// <c>c_api.TF_ImportGraphDefResultsReturnOutputs</c> into a managed array.
  /// </summary>
  /// <param name="results">Native import-results handle passed to the C API.</param>
  /// <returns>One <c>TF_Output</c> per reported return output.</returns>
  public TF_Output[] ReturnOutputs(IntPtr results)
  {
      int count = 0;
      IntPtr native_outputs = IntPtr.Zero;
      c_api.TF_ImportGraphDefResultsReturnOutputs(results, ref count, ref native_outputs);

      var outputs = new TF_Output[count];
      unsafe
      {
          // Read the native buffer as a contiguous TF_Output array.
          TF_Output* cursor = (TF_Output*)native_outputs;
          for (int index = 0; index < count; index++)
          {
              outputs[index] = cursor[index];
          }
      }

      return outputs;
  }
  /// <summary>
  /// Lists the names of all collections whose key does not start with "__".
  /// </summary>
  public string[] get_all_collection_keys()
  {
      var visible_keys = new List<string>();
      foreach (var key in _collections.Keys)
      {
          if (key.StartsWith("__"))
              continue;
          visible_keys.Add(key);
      }
      return visible_keys.ToArray();
  }
  /// <summary>
  /// Returns the raw object stored for collection <paramref name="name"/>.
  /// </summary>
  /// <param name="name">Collection key.</param>
  /// <param name="scope">Currently unused.</param>
  /// <returns>The stored collection object, or <c>null</c> when no such collection exists.</returns>
  public object get_collection(string name, string scope = null)
  {
      object stored;
      if (_collections.TryGetValue(name, out stored))
          return stored;
      return null;
  }
  /// <summary>
  /// Returns a copy of collection <paramref name="name"/> with every item cast to
  /// <typeparamref name="T"/>.
  /// </summary>
  /// <typeparam name="T">Element type requested by the caller.</typeparam>
  /// <param name="name">Collection key.</param>
  /// <param name="scope">Currently unused.</param>
  /// <returns>
  /// A new list; empty when the collection does not exist. Mutating it does not affect
  /// the stored collection (use <c>get_collection_ref</c> for that).
  /// </returns>
  /// <exception cref="InvalidCastException">A stored item is not a <typeparamref name="T"/>.</exception>
  /// <exception cref="NotImplementedException">The stored object is not a list.</exception>
  public List<T> get_collection<T>(string name, string scope = null)
  {
      // A missing collection is simply empty. Previously this threw
      // NotImplementedException unless T was one of five hard-coded types.
      if (!_collections.TryGetValue(name, out var collection))
          return new List<T>();

      if (collection is List<T> typed)
          return new List<T>(typed);

      // Any other list (e.g. a List<RefVariable> read as a base type) is cast item by
      // item, which is what the per-type switch did for its five supported types.
      if (collection is IList items)
          return items.Cast<T>().ToList();

      throw new NotImplementedException($"get_collection<{typeof(T).FullName}>");
  }
  /// <summary>
  /// Returns the stored list for collection <paramref name="name"/> itself (not a copy),
  /// creating an empty <c>List&lt;T&gt;</c> first when the collection does not exist.
  /// </summary>
  /// <typeparam name="T">Expected element type.</typeparam>
  /// <param name="name">Collection key.</param>
  /// <returns>The stored list, or <c>null</c> when the stored object is not a <c>List&lt;T&gt;</c>.</returns>
  public List<T> get_collection_ref<T>(string name)
  {
      object stored;
      if (!_collections.TryGetValue(name, out stored))
      {
          stored = new List<T>();
          _collections[name] = stored;
      }
      return stored as List<T>;
  }
  /// <summary>
  /// Adds <paramref name="tensor"/> to the list of unfeedable tensors.
  /// </summary>
  /// <param name="tensor">Tensor to mark.</param>
  public void prevent_feeding(Tensor tensor)
      => _unfeedable_tensors.Add(tensor);
  /// <summary>
  /// Adds <paramref name="op"/> to the list of unfetchable operations consulted by
  /// <c>is_fetchable</c>.
  /// </summary>
  /// <param name="op">Operation to mark.</param>
  public void prevent_fetching(Operation op)
      => _unfetchable_ops.Add(op);
  /// <summary>
  /// Managed cleanup: removes this graph from <c>ops.default_graph_stack</c>.
  /// </summary>
  protected override void DisposeManagedResources()
      => ops.default_graph_stack.remove(this);
  /// <summary>
  /// Native cleanup: releases the graph through <c>c_api.TF_DeleteGraph</c>.
  /// </summary>
  /// <param name="handle">Native graph handle to delete.</param>
  protected override void DisposeUnmanagedResources(IntPtr handle)
      => c_api.TF_DeleteGraph(handle);
  /// <summary>
  /// Maps a native <c>TF_Output</c> (operation + output index) to the matching managed tensor.
  /// </summary>
  /// <param name="tf_output">Native output descriptor.</param>
  /// <returns>The output tensor at <c>tf_output.index</c> of the owning operation.</returns>
  public Tensor get_tensor_by_tf_output(TF_Output tf_output)
  {
      var owner = _get_operation_by_tf_operation(tf_output.oper);
      var owner_outputs = owner.outputs;
      return owner_outputs[tf_output.index];
  }
  /// <summary>
  /// Looks up the <see cref="Tensor"/> called <paramref name="name"/> through
  /// <c>as_graph_element</c>, with operations disallowed as a result.
  /// </summary>
  /// <param name="name">Tensor name of the form <c>"op_name:output_index"</c>.</param>
  /// <exception cref="KeyError">If <paramref name="name"/> does not correspond to a tensor in this graph.</exception>
  /// <returns>The `Tensor` with the given <paramref name="name"/>.</returns>
  public Tensor get_tensor_by_name(string name)
  {
      var element = as_graph_element(name, allow_tensor: true, allow_operation: false);
      return (Tensor)element;
  }
  /// <summary>
  /// Queries the native graph for the shape of <paramref name="output"/>.
  /// </summary>
  /// <param name="output">Native output descriptor to inspect.</param>
  /// <returns>
  /// A default <c>TensorShape</c> when the native call reports -1 dimensions; otherwise a
  /// shape built from the reported dimension sizes (narrowed from <c>long</c> to <c>int</c>).
  /// </returns>
  public TensorShape GetTensorShape(TF_Output output)
  {
      // NOTE(review): assumes Status is IDisposable (it exposes a native Handle); the
      // original never released it, leaking one status object per call - confirm.
      using (var status = new Status())
      {
          var ndim = c_api.TF_GraphGetTensorNumDims(_handle, output, status.Handle);
          status.Check();

          if (ndim == -1)
              return new TensorShape();

          var dims = new long[ndim];
          c_api.TF_GraphGetTensorShape(_handle, output, dims, dims.Length, status.Handle);
          status.Check();

          return new TensorShape(dims.Select(x => (int)x).ToArray());
      }
  }
  457. string debugString = string.Empty;
  /// <summary>
  /// Short description of the graph: its graph key followed by the native handle.
  /// </summary>
  public override string ToString() => $"{graph_key}, ({_handle})";
  /// <summary>
  /// Enumerates the operations of this graph as reported by <c>c_api_util.tf_operations</c>.
  /// </summary>
  private IEnumerable<Operation> GetEnumerable()
      => c_api_util.tf_operations(this);

  /// <summary>
  /// Typed enumerator over the graph's operations.
  /// </summary>
  IEnumerator<Operation> IEnumerable<Operation>.GetEnumerator()
      => GetEnumerable().GetEnumerator();

  /// <summary>
  /// Non-generic enumerator. Previously threw <c>NotImplementedException</c>, which broke
  /// any consumer going through <c>IEnumerable</c> (e.g. <c>Cast</c>/<c>OfType</c> or an
  /// untyped <c>foreach</c>); it now yields the same sequence as the typed enumerator.
  /// </summary>
  IEnumerator IEnumerable.GetEnumerator()
      => GetEnumerable().GetEnumerator();
  /// <summary>
  /// Exposes the graph's native handle so a <c>Graph</c> can be passed where an
  /// <c>IntPtr</c> is expected.
  /// </summary>
  /// <param name="graph">Graph whose <c>_handle</c> is returned.</param>
  public static implicit operator IntPtr(Graph graph) => graph._handle;
  478. }
  479. }