/*****************************************************************************
   Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved.

   Licensed under the Apache License, Version 2.0 (the "License");
   you may not use this file except in compliance with the License.
   You may obtain a copy of the License at

       http://www.apache.org/licenses/LICENSE-2.0

   Unless required by applicable law or agreed to in writing, software
   distributed under the License is distributed on an "AS IS" BASIS,
   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
   See the License for the specific language governing permissions and
   limitations under the License.
******************************************************************************/

using System;
using System.Collections;
using System.Collections.Generic;
using System.Linq;
using System.Runtime.InteropServices;

namespace Tensorflow
{
    /// <summary>
    /// TensorFlow uses a dataflow graph to represent your computation in terms of the dependencies between individual operations.
    /// This leads to a low-level programming model in which you first define the dataflow graph,
    /// then create a TensorFlow session to run parts of the graph across a set of local and remote devices.
    /// https://www.tensorflow.org/guide/graphs
    /// </summary>
    /*
    A TensorFlow computation, represented as a dataflow graph.

    A `Graph` contains a set of
    `tf.Operation` objects,
    which represent units of computation; and
    `tf.Tensor` objects, which represent
    the units of data that flow between operations.

    A default `Graph` is always registered, and accessible by calling
    `tf.get_default_graph`.
    To add an operation to the default graph, simply call one of the functions
    that defines a new `Operation`:

    ```python
    c = tf.constant(4.0)
    assert c.graph is tf.get_default_graph()
    ```

    Another typical usage involves the
    `tf.Graph.as_default`
    context manager, which overrides the current default graph for the
    lifetime of the context:

    ```python
    g = tf.Graph()
    with g.as_default():
      # Define operations and tensors in `g`.
      c = tf.constant(30.0)
      assert c.graph is g
    ```

    Important note: This class *is not* thread-safe for graph construction. All
    operations should be created from a single thread, or external
    synchronization must be provided. Unless otherwise specified, all methods
    are not thread-safe.

    A `Graph` instance supports an arbitrary number of "collections"
    that are identified by name. For convenience when building a large
    graph, collections can store groups of related objects: for
    example, the `tf.Variable` uses a collection (named
    `tf.GraphKeys.GLOBAL_VARIABLES`) for
    all variables that are created during the construction of a graph. The caller
    may define additional collections by specifying a new name.
    */
    public partial class Graph : DisposableObject, IEnumerable<Operation>
    {
        // Operations registered through _add_op, indexed by numeric id and by name.
        private Dictionary<int, ITensorOrOperation> _nodes_by_id;
        public Dictionary<string, ITensorOrOperation> _nodes_by_name;

        // Lower-cased name -> number of times it has been handed out by unique_name.
        private Dictionary<string, int> _names_in_use;

        // Highest operation id seen so far (see _add_op).
        public int _version;
        private int _next_id_counter;

        private List<Operation> _unfetchable_ops = new List<Operation>();
        private List<Tensor> _unfeedable_tensors = new List<Tensor>();

        // Current name scope; prepended (with a '/') by unique_name.
        public string _name_stack = "";

        private string _graph_key;
        public string graph_key => _graph_key;

        public string _last_loss_reduction;
        public bool _is_loss_scaled_by_optimizer { get; set; }

        /// <summary>
        /// True if the graph is considered "finalized". In that case no
        /// new operations can be added.
        /// </summary>
        private bool _finalized = false;

        /// <summary>
        /// Arbitrary collections of objects.
        /// </summary>
        private Dictionary<string, object> _collections = new Dictionary<string, object>();

        public bool building_function;

        /// <summary>
        /// Creates a new graph backed by a freshly allocated native graph handle.
        /// </summary>
        public Graph() : this(c_api.TF_NewGraph())
        {
        }

        /// <summary>
        /// Wraps an existing native graph handle.
        /// </summary>
        /// <param name="handle">The native graph handle to wrap.</param>
        public Graph(IntPtr handle)
        {
            _handle = handle;
            _nodes_by_id = new Dictionary<int, ITensorOrOperation>();
            _nodes_by_name = new Dictionary<string, ITensorOrOperation>();
            _names_in_use = new Dictionary<string, int>();
            _graph_key = $"grap-key-{ops.uid()}/";
        }

        /// <summary>
        /// Resolves <paramref name="obj"/> (a name, tensor, operation or variable) to an element of this graph.
        /// </summary>
        /// <param name="obj">A string name, a <see cref="Tensor"/>, an <see cref="Operation"/> or a <see cref="RefVariable"/>.</param>
        /// <param name="allow_tensor">Whether a tensor is an acceptable result.</param>
        /// <param name="allow_operation">Whether an operation is an acceptable result.</param>
        /// <returns>The matching graph element.</returns>
        public ITensorOrOperation as_graph_element(object obj, bool allow_tensor = true, bool allow_operation = true)
        {
            return _as_graph_element_locked(obj, allow_tensor, allow_operation);
        }

        /// <summary>
        /// Returns a context manager that makes this `Graph` the default graph.
        /// </summary>
        /// <returns>The value returned by <c>ops.set_default_graph(this)</c>.</returns>
        public Graph as_default()
        {
            return ops.set_default_graph(this);
        }

        /// <summary>
        /// Converts objects that know how to present themselves as a graph element
        /// (currently only <see cref="RefVariable"/>); returns null for anything else.
        /// </summary>
        private Tensor _as_graph_element(object obj)
        {
            if (obj is RefVariable var)
                return var._as_graph_element();

            return null;
        }

        /// <summary>
        /// Implementation of <see cref="as_graph_element"/>. Despite the name, no lock is taken here.
        /// </summary>
        /// <param name="obj">The object or name to resolve.</param>
        /// <param name="allow_tensor">Whether a tensor is an acceptable result.</param>
        /// <param name="allow_operation">Whether an operation is an acceptable result.</param>
        /// <returns>The resolved tensor or operation.</returns>
        /// <exception cref="KeyError">A name refers to an operation that is not in the graph.</exception>
        /// <exception cref="ValueError">An operation-style name was given but operations are not allowed.</exception>
        /// <exception cref="Exception">The element belongs to another graph, or the object cannot be converted.</exception>
        private ITensorOrOperation _as_graph_element_locked(object obj, bool allow_tensor = true, bool allow_operation = true)
        {
            string types_str = "";
            if (allow_tensor && allow_operation)
                types_str = "Tensor or Operation";
            else if (allow_tensor)
                types_str = "Tensor";
            else if (allow_operation)
                types_str = "Operation";

            var temp_obj = _as_graph_element(obj);
            if (temp_obj != null)
                obj = temp_obj;

            // If obj appears to be a name...
            if (obj is string name)
            {
                bool has_output_index = name.Contains(":");

                if (has_output_index && allow_tensor)
                {
                    // Tensor names have the form "<op_name>:<output_index>".
                    var parts = name.Split(':');
                    string op_name = parts[0];
                    int out_n = int.Parse(parts[1]);

                    // Previously a missing operation fell through to the generic
                    // "Can not convert a String" error, which hid the real cause.
                    if (!_nodes_by_name.TryGetValue(op_name, out var producer))
                        throw new KeyError($"The name {name} refers to a Tensor which does not exist. " +
                            $"The operation, {op_name}, does not exist in the graph.");

                    return producer.outputs[out_n];
                }
                else if (!has_output_index && allow_operation)
                {
                    if (!_nodes_by_name.TryGetValue(name, out var node))
                        throw new KeyError($"The name {name} refers to an Operation not in the graph.");

                    return node;
                }
                else if (!has_output_index && !allow_operation)
                {
                    // Looks like an Operation name but can't be an Operation.
                    if (_nodes_by_name.ContainsKey(name))
                        // Yep, it's an Operation name
                        throw new ValueError($"The name {name} refers to an Operation, not a {types_str}.");
                    else
                        throw new ValueError(
                            $"The name {name} looks like an (invalid) Operation name, not a {types_str}." +
                            " Tensor names must be of the form \"<op_name>:<output_index>\".");
                }
                // A tensor-style name with allow_tensor == false falls through to the generic error below.
            }

            if (obj is Tensor tensor && allow_tensor)
            {
                if (tensor.graph.Equals(this))
                    return tensor;

                throw new Exception($"Tensor {obj} is not an element of this graph.");
            }
            else if (obj is Operation op && allow_operation)
            {
                if (op.graph.Equals(this))
                    return op;

                throw new Exception($"Operation {obj} is not an element of this graph.");
            }

            // Guard against a null argument so the caller gets the descriptive error, not a NullReferenceException.
            string type_name = obj?.GetType().Name ?? "null";
            throw new Exception($"Can not convert a {type_name} into a {types_str}.");
        }

        /// <summary>
        /// Appends <paramref name="value"/> to the collection called <paramref name="name"/>, creating it if needed.
        /// </summary>
        /// <typeparam name="T">Element type of the collection.</typeparam>
        /// <param name="name">The collection key.</param>
        /// <param name="value">The value to add.</param>
        /// <exception cref="RuntimeError">The graph is finalized.</exception>
        /// <exception cref="InvalidCastException">The collection already exists with a different element type.</exception>
        public void add_to_collection<T>(string name, T value)
        {
            _check_not_finalized();

            if (_collections.TryGetValue(name, out var existing))
                ((List<T>)existing).Add(value);
            else
                _collections[name] = new List<T> { value };
        }

        /// <summary>
        /// Adds <paramref name="value"/> to every collection listed in <paramref name="names"/>.
        /// </summary>
        /// <typeparam name="T">Element type of the collections.</typeparam>
        /// <param name="names">The collection keys.</param>
        /// <param name="value">The value to add.</param>
        public void add_to_collections<T>(List<string> names, T value)
        {
            foreach (string name in names)
                add_to_collection(name, value);
        }

        /// <summary>
        /// Throws if the graph has been finalized.
        /// </summary>
        /// <exception cref="RuntimeError">The graph is finalized.</exception>
        private void _check_not_finalized()
        {
            if (_finalized)
                throw new RuntimeError("Graph is finalized and cannot be modified.");
        }

        /// <summary>
        /// Creates an <see cref="Operation"/> in this graph.
        /// </summary>
        /// <param name="op_type">The operation type name; also used as the operation name when <paramref name="name"/> is empty.</param>
        /// <param name="inputs">Input tensors; null is treated as no inputs.</param>
        /// <param name="dtypes">Output data types.</param>
        /// <param name="input_types">Optional input data types, forwarded to the operation.</param>
        /// <param name="name">Optional name. A trailing '/' marks a name scope, which is used as-is.</param>
        /// <param name="attrs">Optional attributes for the node definition.</param>
        /// <param name="op_def">Optional operation definition, forwarded to the operation.</param>
        /// <returns>The newly created operation.</returns>
        public unsafe Operation create_op(string op_type, Tensor[] inputs, TF_DataType[] dtypes,
            TF_DataType[] input_types = null, string name = null,
            Dictionary<string, AttrValue> attrs = null, OpDef op_def = null)
        {
            if (inputs == null)
                inputs = new Tensor[0];

            if (String.IsNullOrEmpty(name))
                name = op_type;

            // If a names ends with a '/' it is a "name scope" and we use it as-is,
            // after removing the trailing '/'.
            name = name.EndsWith("/") ? ops._name_from_scope_name(name) : unique_name(name);
            var node_def = ops._NodeDef(op_type, name, device: "", attrs: attrs);

            var input_ops = inputs.Select(x => x.op).ToArray();
            var control_inputs = _control_dependencies_for_inputs(input_ops);

            var op = new Operation(node_def,
                this,
                inputs: inputs,
                output_types: dtypes,
                control_inputs: control_inputs,
                input_types: input_types,
                original_op: null,
                op_def: op_def);

            _create_op_helper(op, true);

            /*Console.Write($"create_op: {op_type} '{node_def.Name}'");
            Console.Write($", inputs: {(inputs.Length == 0 ? "empty" : String.Join(", ", inputs.Select(x => x.name)))}");
            Console.Write($", control_inputs: {(control_inputs.Length == 0 ? "empty" : String.Join(", ", control_inputs.Select(x => x.name)))}");
            Console.Write($", outputs: {(op.outputs.Length == 0 ? "empty" : String.Join(", ", op.outputs.Select(x => x.name)))}");
            Console.WriteLine();*/

            return op;
        }

        /// <summary>
        /// Post-creation bookkeeping for a new operation; currently only records it for control dependencies.
        /// </summary>
        /// <param name="op">The operation that was just created.</param>
        /// <param name="compute_device">Currently unused.</param>
        private void _create_op_helper(Operation op, bool compute_device = true)
        {
            _record_op_seen_by_control_dependencies(op);
        }

        /// <summary>
        /// Registers <paramref name="op"/> with this graph: assigns it the next id and indexes it by id and name.
        /// </summary>
        /// <param name="op">The operation to register.</param>
        public void _add_op(Operation op)
        {
            op._id_value = _next_id();
            _nodes_by_id[op._id] = op;
            _nodes_by_name[op.name] = op;
            _version = Math.Max(_version, op._id);
        }

        /// <summary>
        /// Increments and returns the operation id counter (the first id handed out is 1).
        /// </summary>
        public int _next_id()
        {
            return ++_next_id_counter;
        }

        /// <summary>
        /// Tells whether a tensor or operation may be fetched, i.e. was not passed to <see cref="prevent_fetching"/>.
        /// </summary>
        /// <typeparam name="T">Type of the argument.</typeparam>
        /// <param name="tensor_or_op">A <see cref="Tensor"/> or an <see cref="Operation"/>.</param>
        /// <returns>False for unfetchable elements and for arguments that are neither a tensor nor an operation.</returns>
        public bool is_fetchable<T>(T tensor_or_op)
        {
            if (tensor_or_op is Tensor tensor)
            {
                // A tensor is fetchable when the operation producing it is.
                return !_unfetchable_ops.Contains(tensor.op);
            }
            else if (tensor_or_op is Operation op)
            {
                return !_unfetchable_ops.Contains(op);
            }

            return false;
        }

        /// <summary>
        /// Returns the current name scope.
        /// </summary>
        public string get_name_scope()
        {
            return _name_stack;
        }

        /// <summary>
        /// Sets the current name scope and returns it with a trailing '/'.
        /// </summary>
        /// <param name="name">
        /// The scope name. Null or empty resets the scope; a trailing '/' means the name is used as-is;
        /// anything else is made unique via <see cref="unique_name"/>.
        /// </param>
        /// <returns>The new scope followed by '/', or an empty string when the scope was reset.</returns>
        public string name_scope(string name)
        {
            string new_stack = "";

            if (string.IsNullOrEmpty(name))
                new_stack = "";
            else if (name.EndsWith("/"))
                new_stack = ops._name_from_scope_name(name);
            else
                new_stack = unique_name(name);

            _name_stack = new_stack;

            return String.IsNullOrEmpty(new_stack) ? "" : new_stack + "/";
        }

        /// <summary>
        /// Return a unique operation name for `name`.
        ///
        /// Note: You rarely need to call `unique_name()` directly. Most of
        /// the time you just need to create `with g.name_scope()` blocks to
        /// generate structured names.
        ///
        /// `unique_name` is used to generate structured names, separated by
        /// `"/"`, to help identify operations when debugging a graph.
        /// Operation names are displayed in error messages reported by the
        /// TensorFlow runtime, and in various visualization tools such as
        /// TensorBoard.
        ///
        /// If `mark_as_used` is set to `True`, which is the default, a new
        /// unique name is created and marked as in use. If it's set to `False`,
        /// the unique name is returned without actually being marked as used.
        /// This is useful when the caller simply wants to know what the name
        /// to be created will be.
        /// </summary>
        /// <param name="name">The name for an operation.</param>
        /// <param name="mark_as_used">Whether to mark this name as being used.</param>
        /// <returns>A string to be passed to `create_op()` that will be used
        /// to name the operation being created.</returns>
        public string unique_name(string name, bool mark_as_used = true)
        {
            if (!String.IsNullOrEmpty(_name_stack))
                name = _name_stack + "/" + name;

            // For the sake of checking for names in use, we treat names as case
            // insensitive (e.g. foo = Foo). The invariant culture keeps the key
            // independent of the current thread culture.
            var name_key = name.ToLowerInvariant();

            // i stays 0 when the key has never been used.
            _names_in_use.TryGetValue(name_key, out int i);

            // Increment the number for "name_key".
            if (mark_as_used)
                _names_in_use[name_key] = i + 1;

            if (i > 0)
            {
                // Make sure the composed name key is not already used.
                var base_name_key = name_key;
                while (_names_in_use.ContainsKey(name_key))
                {
                    name_key = $"{base_name_key}_{i}";
                    i += 1;
                }

                // Mark the composed name_key as used in case someone wants
                // to call unique_name("name_1").
                if (mark_as_used)
                    _names_in_use[name_key] = 1;

                // Return the new name with the original capitalization of the given name.
                name = $"{name}_{i - 1}";
            }

            return name;
        }

        /// <summary>
        /// Copies the return outputs of an import-graph-def result into a managed array.
        /// </summary>
        /// <param name="results">Native handle passed to <c>TF_ImportGraphDefResultsReturnOutputs</c>.</param>
        /// <returns>One <see cref="TF_Output"/> per returned output.</returns>
        public TF_Output[] ReturnOutputs(IntPtr results)
        {
            IntPtr return_output_handle = IntPtr.Zero;
            int num_return_outputs = 0;
            c_api.TF_ImportGraphDefResultsReturnOutputs(results, ref num_return_outputs, ref return_output_handle);

            TF_Output[] return_outputs = new TF_Output[num_return_outputs];
            int stride = Marshal.SizeOf<TF_Output>();
            for (int i = 0; i < num_return_outputs; i++)
            {
                // The native side returns a contiguous array of TF_Output structs.
                var handle = return_output_handle + (stride * i);
                return_outputs[i] = Marshal.PtrToStructure<TF_Output>(handle);
            }

            return return_outputs;
        }

        /// <summary>
        /// Returns the names of all collections, excluding those whose name starts with "__".
        /// </summary>
        public string[] get_all_collection_keys()
        {
            return _collections.Keys.Where(x => !x.StartsWith("__")).ToArray();
        }

        /// <summary>
        /// Returns the collection stored under <paramref name="name"/>, or null when there is none.
        /// </summary>
        /// <param name="name">The collection key.</param>
        /// <param name="scope">Currently unused.</param>
        public object get_collection(string name, string scope = null)
        {
            return _collections.TryGetValue(name, out var collection) ? collection : null;
        }

        /// <summary>
        /// Returns the collection stored under <paramref name="name"/> as a typed list.
        /// </summary>
        /// <typeparam name="T">Expected element type.</typeparam>
        /// <param name="name">The collection key.</param>
        /// <param name="scope">Currently unused.</param>
        /// <returns>
        /// The stored list; a new empty list when the key is absent; null when the stored
        /// collection is not a <c>List&lt;T&gt;</c>.
        /// </returns>
        public List<T> get_collection<T>(string name, string scope = null)
        {
            return _collections.TryGetValue(name, out var collection)
                ? collection as List<T>
                : new List<T>();
        }

        /// <summary>
        /// Returns the collection object stored under <paramref name="name"/>, creating an empty one if absent.
        /// </summary>
        /// <param name="name">The collection key.</param>
        public object get_collection_ref(string name)
        {
            if (!_collections.TryGetValue(name, out var collection))
            {
                collection = new List<object>();
                _collections[name] = collection;
            }

            return collection;
        }

        /// <summary>
        /// Marks <paramref name="tensor"/> as unfeedable.
        /// </summary>
        public void prevent_feeding(Tensor tensor)
        {
            _unfeedable_tensors.Add(tensor);
        }

        /// <summary>
        /// Marks <paramref name="op"/> as unfetchable (see <see cref="is_fetchable"/>).
        /// </summary>
        public void prevent_fetching(Operation op)
        {
            _unfetchable_ops.Add(op);
        }

        /// <summary>
        /// Removes this graph from the default graph stack.
        /// </summary>
        protected override void DisposeManagedState()
        {
            ops.default_graph_stack.remove(this);
        }

        /// <summary>
        /// Deletes the native graph.
        /// </summary>
        /// <param name="handle">The native graph handle to delete.</param>
        protected override void DisposeUnManagedState(IntPtr handle)
        {
            // NOTE(review): diagnostic output on every dispose; consider routing this through a logger.
            Console.WriteLine($"Destroy graph {handle}");
            c_api.TF_DeleteGraph(handle);
        }

        /// <summary>
        /// Returns the <see cref="Tensor"/> with the given <paramref name="name"/>.
        /// No synchronization is performed by this method.
        /// </summary>
        /// <param name="name">The name of the `Tensor` to return.</param>
        /// <exception cref="KeyError">If <paramref name="name"/> does not correspond to a tensor in this graph.</exception>
        /// <returns>The `Tensor` with the given <paramref name="name"/>.</returns>
        public Tensor get_tensor_by_name(string name)
        {
            return (Tensor)this.as_graph_element(name, allow_tensor: true, allow_operation: false);
        }

        /// <summary>
        /// Queries the native graph for the shape of <paramref name="output"/>.
        /// </summary>
        /// <param name="output">The output whose shape is requested.</param>
        /// <returns>The shape, or an empty <see cref="TensorShape"/> when the native call reports -1 dimensions.</returns>
        public TensorShape GetTensorShape(TF_Output output)
        {
            var status = new Status();
            var ndim = c_api.TF_GraphGetTensorNumDims(_handle, output, status);
            status.Check();

            if (ndim == -1)
                return new TensorShape();

            var dims = new long[ndim];
            c_api.TF_GraphGetTensorShape(_handle, output, dims, dims.Length, status);
            status.Check();

            return new TensorShape(dims.Select(x => (int)x).ToArray());
        }

        string debugString = string.Empty;

        /// <summary>
        /// Returns the graph key followed by the native handle.
        /// </summary>
        public override string ToString()
        {
            return $"{graph_key}, ({_handle})";
            /*if (string.IsNullOrEmpty(debugString))
            {
                int len = 0;
                debugString = c_api.TF_GraphDebugString(_handle, out len);
            }

            return debugString;*/
        }

        private IEnumerable<Operation> GetEnumerable()
            => c_api_util.tf_operations(this);

        IEnumerator<Operation> IEnumerable<Operation>.GetEnumerator()
            => GetEnumerable().GetEnumerator();

        // Delegate to the generic enumerator so non-generic consumers
        // (e.g. `foreach (object o in (IEnumerable)graph)`) work instead of throwing.
        IEnumerator IEnumerable.GetEnumerator()
            => GetEnumerable().GetEnumerator();

        /// <summary>
        /// Exposes the native graph handle.
        /// </summary>
        public static implicit operator IntPtr(Graph graph)
        {
            return graph._handle;
        }
    }
}